Python torch.einsum() Examples

The following are 30 code examples of torch.einsum().
Example #1
def forward(self, x, y):
        """
        The computation logic of MatchingTensor.

        :param x: first input tensor.
        :param y: second input tensor.
        :return: the einsum contraction of ``x``, ``self.interaction_matrix``
            and ``y``.
        """
        # Optionally L2-normalize both inputs along their last dimension.
        if self._normalize:
            x = F.normalize(x, p=2, dim=-1)
            y = F.normalize(y, p=2, dim=-1)

        # output = [b, c, l, r]
        # NOTE(review): the equation string was absent from this listing. It
        # is reconstructed from the output-layout comment above, assuming
        # x is [b, l, d], interaction_matrix is [c, d, e] and y is [b, r, e]
        # -- confirm against the upstream module.
        output = torch.einsum(
            'bld,cde,bre->bclr',
            x, self.interaction_matrix, y
        )
        return output
Example #2
def forward(self, x, rbf, sbf, idx_kj, idx_ji):
        """Update the embeddings ``x`` using the two basis inputs.

        ``rbf`` and ``sbf`` are each passed through their own linear layer,
        combined with two activated projections of ``x``, and the result is
        refined by the layer stacks before and after the skip connection.
        """
        radial = self.lin_rbf(rbf)
        spherical = self.lin_sbf(sbf)

        # Two activated linear views of the incoming embeddings.
        x_ji = self.act(self.lin_ji(x))
        x_kj = self.act(self.lin_kj(x)) * radial

        # Gather rows by idx_kj, contract them with the projected sbf and
        # self.W, then scatter along dim 0 by idx_ji into x.size(0) rows.
        gathered = x_kj[idx_kj]
        x_kj = torch.einsum('wj,wl,ijl->wi', spherical, gathered, self.W)
        x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0))

        h = x_ji + x_kj
        for pre_skip_layer in self.layers_before_skip:
            h = pre_skip_layer(h)

        # Skip connection back to the block input.
        h = self.act(self.lin(h)) + x
        for post_skip_layer in self.layers_after_skip:
            h = post_skip_layer(h)

        return h
Example #3
def _make_hessian_mat_prod(self, module, g_inp, g_out):
        """Multiplication of the input Hessian with a matrix."""

        probs = self._get_probs(module)

        def hessian_mat_prod(mat):
            Hmat = einsum("bi,cbi->cbi", (probs, mat)) - einsum(
                "bi,bj,cbj->cbi", (probs, probs, mat)

            if module.reduction == "mean":
                N = module.input0.shape[0]
                Hmat /= N

            return Hmat

        return hessian_mat_prod 
Example #4
def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):

        M = mc_samples
        C = module.input0.shape[1]

        probs = self._get_probs(module)
        V_dim = 0
        probs_unsqueezed = probs.unsqueeze(V_dim).repeat(M, 1, 1)

        multi = multinomial(probs, M, replacement=True)
        classes = one_hot(multi, num_classes=C)
        classes = einsum("nvc->vnc", classes).float()

        sqrt_mc_h = (probs_unsqueezed - classes) / sqrt(M)

        if module.reduction == "mean":
            N = module.input0.shape[0]
            sqrt_mc_h /= sqrt(N)

        return sqrt_mc_h 
Example #5
def pacconv2d(input, kernel, weight, bias=None, stride=1, padding=0, dilation=1, shared_filters=False,
              native_impl=False):
    """Apply a 2-D convolution whose columns are modulated by ``kernel``.

    Two code paths exist: an einsum-based one built on ``nd2col`` columns
    (``native_impl=True``) and ``PacConv2dFn.apply`` otherwise.

    NOTE(review): the end of the signature and both ``else:`` lines were
    missing from this listing. ``native_impl`` is restored because the body
    reads it; its default of ``False`` is an assumption -- confirm upstream.

    :param input: input tensor.
    :param kernel: tensor multiplied elementwise with the unfolded columns.
    :param weight: filter tensor; its last two dimensions give the kernel size.
    :param bias: optional bias, added per output channel in the einsum path.
    :param stride: int or pair, expanded with ``_pair``.
    :param padding: int or pair, expanded with ``_pair``.
    :param dilation: int or pair, expanded with ``_pair``.
    :param shared_filters: selects which einsum contraction is used.
    :param native_impl: use the einsum path instead of ``PacConv2dFn``.
    :return: the convolution output tensor.
    """
    kernel_size = tuple(weight.shape[-2:])
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)

    if native_impl:
        # im2col on input
        im_cols = nd2col(input, kernel_size, stride=stride, padding=padding, dilation=dilation)

        # main computation
        if shared_filters:
            output = torch.einsum('ijklmn,zykl->ijmn', (im_cols * kernel, weight))
        else:
            output = torch.einsum('ijklmn,ojkl->iomn', (im_cols * kernel, weight))

        if bias is not None:
            output += bias.view(1, -1, 1, 1)
    else:
        output = PacConv2dFn.apply(input, kernel, weight, bias, stride, padding, dilation, shared_filters)

    return output
Example #6
def forward(ctx, input, kernel, kernel_size, stride=1, padding=0, dilation=1):
        """Multiply unfolded input patches by ``kernel`` and sum the window axes.

        Geometry settings and the tensors kept for later are stored on ``ctx``.

        :raises ValueError: if ``kernel`` has more than one channel and that
            count differs from the channel count of ``input``.
        """
        batch_channels, spatial = input.shape[:2], input.shape[2:]
        bs, ch = batch_channels
        kernel_ch = kernel.size(1)
        if kernel_ch > 1 and kernel_ch != ch:
            raise ValueError('Incompatible input and kernel sizes.')

        ctx.input_size = spatial
        ctx.kernel_size = _pair(kernel_size)
        ctx.kernel_ch = kernel_ch
        ctx.dilation = _pair(dilation)
        ctx.padding = _pair(padding)
        ctx.stride = _pair(stride)

        # Each tensor is kept only when the gradient of the *other* argument
        # is requested (assuming needs_input_grad follows the argument order:
        # index 0 for `input`, index 1 for `kernel`).
        saved_input = input if ctx.needs_input_grad[1] else None
        saved_kernel = kernel if ctx.needs_input_grad[0] else None
        ctx.save_for_backward(saved_input, saved_kernel)
        ctx._backend = type2backend[input.type()]

        cols = F.unfold(input, ctx.kernel_size, ctx.dilation, ctx.padding, ctx.stride)

        weighted = cols.view(bs, ch, *kernel.shape[2:]) * kernel
        # Sum out the k and l axes of the six-axis product.
        summed = torch.einsum('ijklmn->ijmn', weighted)

        return summed.clone()  # TODO check whether a .clone() is needed here
Example #7
def msg_func(edges):
    """Send messages along edges.

    Parameters
    ----------
    edges : EdgeBatch
        A batch of edges.

    Returns
    -------
    dict mapping 'm' to Float32 tensor of shape (E, K * T)
        Messages computed. E for the number of edges, K for the number of
        radial filters and T for the number of features to use
        (types of atomic number in the paper).
    """
    # NOTE(review): the second operand appeared as a bare ``['he']`` in this
    # listing; ``edges.data`` is restored as the edge-feature container --
    # confirm against the upstream source.
    # Per-edge outer product of the two feature vectors, flattened per edge.
    return {'m': th.einsum(
        'ij,ik->ijk', edges.src['hv'], edges.data['he']).view(len(edges), -1)}
Example #8
def bdd_message_func(self, edges):
        """Message function for block-diagonal-decomposition regularizer

        NOTE(review): this listing had lost every ``edges.data`` reference,
        the end of the ``th.empty(...)`` call and the ``else:`` line. They are
        restored here; the ``device=`` argument in particular is an
        assumption -- confirm against the upstream source.

        :param edges: edge batch exposing ``src['h']`` and ``data['type']``
            (and optionally ``data['norm']``).
        :return: dict mapping ``'msg'`` to the per-edge message tensor.
        :raises TypeError: if ``src['h']`` is a 1-D int64 tensor.
        """
        if edges.src['h'].dtype == th.int64 and len(edges.src['h'].shape) == 1:
            raise TypeError('Block decomposition does not allow integer ID feature.')

        # calculate msg @ W_r before put msg into edge
        if self.low_mem:
            # Process one edge type at a time and write into a preallocated
            # buffer.
            etypes = th.unique(edges.data['type'])
            msg = th.empty((edges.src['h'].shape[0], self.out_feat),
                           device=edges.src['h'].device)
            for etype in etypes:
                loc = edges.data['type'] == etype
                w = self.weight[etype].view(self.num_bases, self.submat_in, self.submat_out)
                src = edges.src['h'][loc].view(-1, self.num_bases, self.submat_in)
                sub_msg = th.einsum('abc,bcd->abd', src, w)
                sub_msg = sub_msg.reshape(-1, self.out_feat)
                msg[loc] = sub_msg
        else:
            # Gather one weight per edge and multiply all blocks in a single bmm.
            weight = self.weight.index_select(0, edges.data['type']).view(
                -1, self.submat_in, self.submat_out)
            node = edges.src['h'].view(-1, 1, self.submat_in)
            msg = th.bmm(node, weight).view(-1, self.out_feat)
        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}
Example #9
def kfac_mat_prod(factors):
    """Return function v ↦ (A ⊗ B ⊗ ...)v for `factors = [A, B, ...]` """
    assert all_tensors_of_order(order=2, tensors=factors)

    # Only the column counts of the factors are needed to reshape the input.
    _row_dims, col_dims = zip(*(list(factor.size()) for factor in factors))
    equation = kfac_mat_prod_einsum_equation(len(factors))

    def kfacmp(mat):
        """Apply the Kronecker product of the factors to the matrix ``mat``."""
        assert is_matrix(mat)
        num_cols = mat.shape[1]
        # Split the row axis of `mat` into one axis per factor.
        unfolded = mat.view(*col_dims, num_cols)
        product = einsum(equation, unfolded, *factors)
        return product.contiguous().view(-1, num_cols)

    return kfacmp
Example #10
def _get_overlaps_tensor(self, L):
        """Transforms the input label matrix to a three-way overlaps tensor.

            L: (np.array) An n x m array of LF output labels, in {0,...,k} if
                self.abstains, else in {1,...,k}, generated by m conditionally
                independent LFs on n data points

            O: (torch.Tensor) A (m, m, m, k, k, k) tensor of the label-specific
            empirical overlap rates; that is,

                O[i,j,k,y1,y2,y3] = P(\lf_i = y1, \lf_j = y2, \lf_k = y3)

            where this quantity is computed empirically by this function, based
            on the label matrix L.
        n, m = L.shape

        # Convert from a (n,m) matrix of ints to a (k_lf, n, m) indicator tensor
        LY = np.array([np.where(L == y, 1, 0) for y in range(self.k_0, self.k + 1)])

        # Form the three-way overlaps matrix
        O = np.einsum("abc,dbe,fbg->cegadf", LY, LY, LY) / n
        return torch.from_numpy(O).float() 
Example #11
def weight(self, ext, module, g_inp, g_out, backproped):
        """Contract the squared output gradient with the squared module input.

        The shared leading axis ``n`` is summed out, leaving an ``[i, j]`` result.
        """
        grad_sq = g_out[0] ** 2
        input_sq = module.input0 ** 2
        return einsum("ni,nj->ij", grad_sq, input_sq)
Example #12 (source: backpack, MIT License)
def make_random_psd_kfacs(self, num_facs=None):
        """Return the factors from ``make_random_kfacs``, each made PSD.

        Every factor ``A`` is replaced by ``AAᵀ`` plus the identity scaled by
        ``self.PSD_KFAC_MIN_EIGVAL``.
        """

        def as_psd(factor):
            """Make matrix positive semi-definite: A -> AAᵀ."""
            gram = einsum("ij,kj->ik", factor, factor)
            return gram + self.PSD_KFAC_MIN_EIGVAL * self.torch_eye_like(gram)

        return list(map(as_psd, self.make_random_kfacs(num_facs=num_facs)))

    # Torch helpers
Example #13
def bias(self, ext, module, g_inp, g_out, backproped):
        """Sum ``g_out[0]`` over h and w, square, then sum over the first axis."""
        per_sample = einsum("nchw->nc", g_out[0])
        squared = per_sample ** 2
        # Axis 0 is the `n` axis of the "nc" result.
        return squared.sum(0)
Example #14
def _jac_t_mat_prod(self, module, g_inp, g_out, mat):

        df_elementwise = self.df(module, g_inp, g_out)
        return einsum("...,v...->v...", (df_elementwise, mat)) 
Example #15
def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
        The Jacobian is *not independent* among the batch dimension, i.e.
        D z_i = D z_i(x_1, ..., x_B).

        This structure breaks the computation of the GGN diagonal,
        for curvature-matrix products it should still work.

        assert module.affine is True

        N = self.get_batch(module)
        x_hat, var = self.get_normalized_input_and_var(module)
        ivar = 1.0 / (var + module.eps).sqrt()

        dx_hat = einsum("vni,i->vni", (mat, module.weight))

        jac_t_mat = N * dx_hat
        jac_t_mat -= dx_hat.sum(1).unsqueeze(1).expand_as(jac_t_mat)
        jac_t_mat -= einsum("ni,vsi,si->vni", (x_hat, dx_hat, x_hat))
        jac_t_mat = einsum("vni,i->vni", (jac_t_mat, ivar / N))

        return jac_t_mat 
Example #16
def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
        x_hat, _ = self.get_normalized_input_and_var(module)
        return einsum("ni,vi->vni", (x_hat, mat)) 
Example #17
def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch):
        if not sum_batch:
                "BatchNorm batch summation disabled."
                "This may not compute meaningful quantities"
        x_hat, _ = self.get_normalized_input_and_var(module)
        equation = "vni,ni->v{}i".format("" if sum_batch is True else "n")
        operands = [mat, x_hat]
        return einsum(equation, operands) 
Example #18
def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
        """Contract the unfolded module input with a regrouped ``mat``."""
        # Presumably merges the trailing i, h, w axes of `mat` into one axis
        # (eingroup is defined elsewhere -- verify).
        grouped = eingroup("v,o,i,h,w->v,o,ihw", mat)
        unfolded = self.get_unfolded_input(module)

        product = einsum("nij,vki->vnkj", unfolded, grouped)
        return self.reshape_like_output(product, module)
Example #19
def _sqrt_hessian(self, module, g_inp, g_out):

        probs = self._get_probs(module)
        tau = torchsqrt(probs)
        V_dim, C_dim = 0, 2
        Id = diag_embed(ones_like(probs), dim1=V_dim, dim2=C_dim)
        Id_tautau = Id - einsum("nv,nc->vnc", tau, tau)
        sqrt_H = einsum("nc,vnc->vnc", tau, Id_tautau)

        if module.reduction == "mean":
            N = module.input0.shape[0]
            sqrt_H /= sqrt(N)

        return sqrt_H 
Example #20
def _factor_from_sqrt(self, module, backproped):
        """Form an ``[i, l]`` factor from the backpropagated square root."""
        separated = convUtils.separate_channels_and_pixels(module, backproped)
        # Sum out the trailing axis before forming the outer product.
        reduced = einsum("cbij->cbi", separated)
        # Outer product over i and l, summed over the c and b axes.
        return einsum("cbi,cbl->il", reduced, reduced)
Example #21
def bias(self, ext, module, g_inp, g_out, backproped):
        """Sum ``g_out[0]`` over h and w, square, then sum over the second axis."""
        per_channel = einsum("nchw->nc", g_out[0])
        squared = per_channel ** 2
        # Axis 1 is the `c` axis of the "nc" result; one value per sample remains.
        return squared.sum(1)
Example #22
def weight(self, ext, module, g_inp, g_out, backproped):
        """Return, per sample ``n``, the product of two sums of squares.

        The squared output gradient is summed over ``i`` and the squared
        module input over ``j``; only the ``n`` axis survives.
        """
        grad_sq = g_out[0] ** 2
        input_sq = module.input0 ** 2
        return einsum("ni,nj->n", grad_sq, input_sq)
Example #23
def two_kfacs_to_mat(A, B):
    """Given A, B, return A ⊗ B."""
    assert is_matrix(A)
    assert is_matrix(B)

    # Rows and columns of the Kronecker product multiply.
    mat_shape = (
        A.shape[0] * B.shape[0],
        A.shape[1] * B.shape[1],
    )
    # "ij,kl->ikjl" orders the axes as (row A, row B, col A, col B), so the
    # view below flattens the row pair and the column pair.
    mat = einsum("ij,kl->ikjl", (A, B)).contiguous().view(mat_shape)
    return mat
Example #24 (source: backpack, MIT License)
def extract_bias_diagonal(module, sqrt):
    """
    `sqrt` must be the backpropagated quantity for DiagH or DiagGGN(MC)
    """
    V_axis, N_axis = 0, 1
    # Sum over h and w first, then square and sum over the v and n axes.
    bias_diagonal = (einsum("vnchw->vnc", sqrt) ** 2).sum([V_axis, N_axis])
    return bias_diagonal
Example #25 (source: backpack, MIT License)
def extract_weight_diagonal(module, input, grad_output):
    """
    input must be the unfolded input to the convolution (see unfold_func)
    and grad_output the backpropagated gradient
    """
    grad_output_viewed = separate_channels_and_pixels(module, grad_output)
    # Contract the shared trailing axis l; the result is indexed [v, n, k, m].
    AX = einsum("nkl,vnml->vnkm", (input, grad_output_viewed))
    # Square, sum over v and n, and swap the two remaining axes before
    # reshaping to the weight's shape.
    weight_diagonal = (AX ** 2).sum([0, 1]).transpose(0, 1)
    return weight_diagonal.view_as(module.weight)
Example #26
def extract_bias_diagonal(module, backproped):
    """Square ``backproped`` and sum out its first two axes (v and n)."""
    squared = backproped ** 2
    return einsum("vno->o", squared)
Example #27
def extract_weight_diagonal(module, backproped):
    """Contract squared ``backproped`` with the squared module input.

    The v and n axes are summed out, leaving an ``[o, i]`` result.
    """
    backproped_sq = backproped ** 2
    input_sq = module.input0 ** 2
    return einsum("vno,ni->oi", backproped_sq, input_sq)
Example #28
def _get_mixed_outputs(self, outputs):
        """Mix the stacked layer outputs into one tensor using softmax weights.

        NOTE(review): the final ``return`` statement has lost its left operand
        in this listing (``return * outputs`` is not valid Python). Presumably
        an attribute of ``self`` scaled the mixture -- recover it from the
        upstream source before using this method.
        """
        # outputs: num_layers x batch_size x max_len x hidden_size
        # return: batch_size x max_len x hidden_size
        # Softmax over dim 0 of the layer weights (offset by 1 / len(outputs)),
        # then converted with `.to(outputs)` to match the stacked outputs.
        weights = F.softmax(self.layer_weights + 1 / len(outputs), dim=0).to(outputs)
        # Weighted sum over the leading (layer) axis.
        outputs = torch.einsum('l,lbij->bij', weights, outputs)
        return * outputs
Example #29
def forward(self, inputs):
        """Score a query/document pair from kernel-pooled n-gram matches.

        NOTE(review): this listing had lost the ``continue`` statement, the
        einsum equation with its closing parenthesis, and the statement that
        fills ``KM``. They are reconstructed here (the equation assumes both
        operands are [batch, length, dim]) -- confirm against the upstream
        source.

        :param inputs: mapping with ``'text_left'`` (query ids) and
            ``'text_right'`` (document ids).
        :return: the output of ``self.out`` applied to the stacked features.
        """

        query, doc = inputs['text_left'], inputs['text_right']

        # Embed, then move the embedding axis to position 1 for the convs.
        q_embed = self.embedding(query.long()).transpose(1, 2)
        d_embed = self.embedding(doc.long()).transpose(1, 2)

        q_convs = []
        d_convs = []
        for q_conv, d_conv in zip(self.q_convs, self.d_convs):
            q_convs.append(q_conv(q_embed).transpose(1, 2))
            d_convs.append(d_conv(d_embed).transpose(1, 2))

        KM = []
        for qi in range(self._params['max_ngram']):
            for di in range(self._params['max_ngram']):
                # do not match n-gram with different length if use crossmatch
                if not self._params['use_crossmatch'] and qi != di:
                    continue
                # Dot products of the L2-normalized vectors for every
                # (query position, document position) pair.
                mm = torch.einsum(
                    'bld,brd->blr',
                    F.normalize(q_convs[qi], p=2, dim=-1),
                    F.normalize(d_convs[di], p=2, dim=-1)
                )
                for kernel in self.kernels:
                    K = torch.log1p(kernel(mm).sum(dim=-1)).sum(dim=-1)
                    KM.append(K)

        phi = torch.stack(KM, dim=1)

        out = self.out(phi)
        return out
Example #30
def forward(self, inputs):
        """Combine per-term histogram scores with query attention weights.

        Scalar dimensions referenced in the comments below:
          B = batch size (number of sequences)
          D = embedding size
          L = `input_left` sequence length
          R = `input_right` sequence length
          H = histogram size
          K = size of top-k
        """
        # Left input and right input.
        # query: shape = [B, L]
        # doc: shape = [B, L, H]
        # Note here, the doc is the matching histogram between original query
        # and original document.
        query = inputs['text_left']
        match_hist = inputs['match_histogram']

        # Positions equal to the mask value. shape = [B, L]
        query_mask = (query == self._params['mask_value'])

        # Process left input.
        # shape = [B, L, D]
        query_embedding = self.embedding(query.long())

        # shape = [B, L]
        attention_probs = self.attention(query_embedding, query_mask)

        # shape = [B, L]
        term_scores = self.mlp(match_hist).squeeze(dim=-1)

        # Attention-weighted sum of the term scores, one scalar per sequence.
        pooled = torch.einsum('bl,bl->b', term_scores, attention_probs)

        return self.out(pooled.unsqueeze(dim=-1))