Python torch.einsum() Examples

The following are 30 code examples of torch.einsum(), drawn from open-source projects. Each example lists its original project and source file. You may also want to check out all available functions and classes of the module torch.
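Before the project-specific examples, here is a minimal standalone sketch (not taken from any of the projects below) of how the subscript notation works: indices shared between operands are multiplied together, and indices absent from the output are summed over.

import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)

# matrix product: the shared index j is summed
mm = torch.einsum('ij,jk->ik', a, b)
assert torch.allclose(mm, a @ b)

# batched matrix product: the batch index b is kept, k is summed
x = torch.randn(2, 3, 4)
y = torch.randn(2, 4, 5)
bmm = torch.einsum('bik,bkj->bij', x, y)
assert torch.allclose(bmm, torch.bmm(x, y))

# trace: a repeated index with an empty output sums the diagonal
t = torch.einsum('ii->', a @ a.t())
assert torch.allclose(t, torch.trace(a @ a.t()))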
Example #1
Source File: matching_tensor.py    From MatchZoo-py with Apache License 2.0
def forward(self, x, y):
        """
        The computation logic of MatchingTensor.

        :param x: left input tensor.
        :param y: right input tensor.
        """

        if self._normalize:
            x = F.normalize(x, p=2, dim=-1)
            y = F.normalize(y, p=2, dim=-1)

        # output = [b, c, l, r]
        output = torch.einsum(
            'bld,cde,bre->bclr',
            x, self.interaction_matrix, y
        )
        return output 
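The contraction 'bld,cde,bre->bclr' evaluates one bilinear form x M_c yᵀ per interaction channel c. A standalone sketch with illustrative sizes (M stands in for self.interaction_matrix):

import torch

b, l, r, d, e, c = 2, 5, 7, 8, 8, 4       # illustrative sizes only
x = torch.randn(b, l, d)
y = torch.randn(b, r, e)
M = torch.randn(c, d, e)                  # stands in for interaction_matrix

out = torch.einsum('bld,cde,bre->bclr', x, M, y)

# equivalent per-channel bilinear form: x @ M[ci] @ y^T
ref = torch.stack([x @ M[ci] @ y.transpose(1, 2) for ci in range(c)], dim=1)
assert torch.allclose(out, ref, atol=1e-5)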
Example #2
Source File: dimenet.py    From pytorch_geometric with MIT License
def forward(self, x, rbf, sbf, idx_kj, idx_ji):
        rbf = self.lin_rbf(rbf)
        sbf = self.lin_sbf(sbf)

        x_ji = self.act(self.lin_ji(x))
        x_kj = self.act(self.lin_kj(x))
        x_kj = x_kj * rbf
        x_kj = torch.einsum('wj,wl,ijl->wi', sbf, x_kj[idx_kj], self.W)
        x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0))

        h = x_ji + x_kj
        for layer in self.layers_before_skip:
            h = layer(h)
        h = self.act(self.lin(h)) + x
        for layer in self.layers_after_skip:
            h = layer(h)

        return h 
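The contraction 'wj,wl,ijl->wi' applies a learned bilinear form (self.W) to the spherical-basis features and the gathered messages, one output feature at a time. A standalone sketch with illustrative sizes:

import torch

num_triplets, num_sph, hidden, out_dim = 10, 7, 8, 8   # illustrative sizes
sbf = torch.randn(num_triplets, num_sph)               # stands in for sbf
xk = torch.randn(num_triplets, hidden)                 # stands in for x_kj[idx_kj]
W = torch.randn(out_dim, num_sph, hidden)              # stands in for self.W

out = torch.einsum('wj,wl,ijl->wi', sbf, xk, W)

# the same bilinear form written per output feature
ref = torch.stack([(sbf @ W[i] * xk).sum(dim=1) for i in range(out_dim)], dim=1)
assert torch.allclose(out, ref, atol=1e-5)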
Example #3
Source File: crossentropyloss.py    From backpack with MIT License
def _make_hessian_mat_prod(self, module, g_inp, g_out):
        """Multiplication of the input Hessian with a matrix."""
        self._check_2nd_order_parameters(module)

        probs = self._get_probs(module)

        def hessian_mat_prod(mat):
            Hmat = einsum("bi,cbi->cbi", (probs, mat)) - einsum(
                "bi,bj,cbj->cbi", (probs, probs, mat)
            )

            if module.reduction == "mean":
                N = module.input0.shape[0]
                Hmat /= N

            return Hmat

        return hessian_mat_prod 
Example #4
Source File: crossentropyloss.py    From backpack with MIT License
def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
        self._check_2nd_order_parameters(module)

        M = mc_samples
        C = module.input0.shape[1]

        probs = self._get_probs(module)
        V_dim = 0
        probs_unsqueezed = probs.unsqueeze(V_dim).repeat(M, 1, 1)

        multi = multinomial(probs, M, replacement=True)
        classes = one_hot(multi, num_classes=C)
        classes = einsum("nvc->vnc", classes).float()

        sqrt_mc_h = (probs_unsqueezed - classes) / sqrt(M)

        if module.reduction == "mean":
            N = module.input0.shape[0]
            sqrt_mc_h /= sqrt(N)

        return sqrt_mc_h 
Example #5
Source File: pac.py    From openseg.pytorch with MIT License
def pacconv2d(input, kernel, weight, bias=None, stride=1, padding=0, dilation=1, shared_filters=False,
              native_impl=False):
    kernel_size = tuple(weight.shape[-2:])
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)

    if native_impl:
        # im2col on input
        im_cols = nd2col(input, kernel_size, stride=stride, padding=padding, dilation=dilation)

        # main computation
        if shared_filters:
            output = torch.einsum('ijklmn,zykl->ijmn', (im_cols * kernel, weight))
        else:
            output = torch.einsum('ijklmn,ojkl->iomn', (im_cols * kernel, weight))

        if bias is not None:
            output += bias.view(1, -1, 1, 1)
    else:
        output = PacConv2dFn.apply(input, kernel, weight, bias, stride, padding, dilation, shared_filters)

    return output 
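The helper nd2col is not shown in this excerpt; the sketch below uses F.unfold to build the same im2col layout. With an all-ones adapting kernel and shared_filters=False, the einsum 'ijklmn,ojkl->iomn' reduces to an ordinary convolution, so the result can be checked against F.conv2d (sizes illustrative):

import torch
import torch.nn.functional as F

bs, in_ch, H, W, out_ch, k = 2, 3, 8, 8, 4, 3
x = torch.randn(bs, in_ch, H, W)
weight = torch.randn(out_ch, in_ch, k, k)

# im2col: (bs, in_ch, k, k, out_h, out_w), standing in for nd2col
out_h, out_w = H - k + 1, W - k + 1
im_cols = F.unfold(x, k).view(bs, in_ch, k, k, out_h, out_w)

kernel = torch.ones_like(im_cols)          # all-ones adapting kernel
out = torch.einsum('ijklmn,ojkl->iomn', im_cols * kernel, weight)
assert torch.allclose(out, F.conv2d(x, weight), atol=1e-4)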
Example #6
Source File: pac.py    From openseg.pytorch with MIT License
def forward(ctx, input, kernel, kernel_size, stride=1, padding=0, dilation=1):
        (bs, ch), in_sz = input.shape[:2], input.shape[2:]
        if kernel.size(1) > 1 and kernel.size(1) != ch:
            raise ValueError('Incompatible input and kernel sizes.')
        ctx.input_size = in_sz
        ctx.kernel_size = _pair(kernel_size)
        ctx.kernel_ch = kernel.size(1)
        ctx.dilation = _pair(dilation)
        ctx.padding = _pair(padding)
        ctx.stride = _pair(stride)
        ctx.save_for_backward(input if ctx.needs_input_grad[1] else None,
                              kernel if ctx.needs_input_grad[0] else None)
        ctx._backend = type2backend[input.type()]

        cols = F.unfold(input, ctx.kernel_size, ctx.dilation, ctx.padding, ctx.stride)

        output = cols.view(bs, ch, *kernel.shape[2:]) * kernel
        output = torch.einsum('ijklmn->ijmn', (output,))

        return output.clone()  # TODO check whether a .clone() is needed here 
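The reduction 'ijklmn->ijmn' simply sums the kernel-weighted patches over the two kernel axes. A one-line check with illustrative sizes:

import torch

x = torch.randn(2, 3, 4, 5, 6, 7)          # (bs, ch, kh, kw, out_h, out_w), illustrative
out = torch.einsum('ijklmn->ijmn', x)
assert torch.allclose(out, x.sum(dim=(2, 3)))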
Example #7
Source File: atomicconv.py    From dgl with Apache License 2.0
def msg_func(edges):
    """Send messages along edges.

    Parameters
    ----------
    edges : EdgeBatch
        A batch of edges.

    Returns
    -------
    dict mapping 'm' to Float32 tensor of shape (E, K * T)
        Messages computed. E for the number of edges, K for the number of
        radial filters and T for the number of features to use
        (types of atomic number in the paper).
    """
    return {'m': th.einsum(
        'ij,ik->ijk', edges.src['hv'], edges.data['he']).view(len(edges), -1)} 
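The equation 'ij,ik->ijk' forms an outer product per edge between the node features and the edge features; flattening the last two axes gives the per-edge message vector described in the docstring. A standalone sketch of the same operation written with broadcasting (sizes illustrative):

import torch

E, J, K = 6, 4, 3                  # illustrative sizes
a = torch.randn(E, J)              # stands in for edges.src['hv']
b = torch.randn(E, K)              # stands in for edges.data['he']

m = torch.einsum('ij,ik->ijk', a, b).view(E, -1)
ref = (a.unsqueeze(2) * b.unsqueeze(1)).view(E, -1)
assert torch.allclose(m, ref)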
Example #8
Source File: relgraphconv.py    From dgl with Apache License 2.0
def bdd_message_func(self, edges):
        """Message function for block-diagonal-decomposition regularizer"""
        if edges.src['h'].dtype == th.int64 and len(edges.src['h'].shape) == 1:
            raise TypeError('Block decomposition does not allow integer ID feature.')

        # calculate msg @ W_r before put msg into edge
        if self.low_mem:
            etypes = th.unique(edges.data['type'])
            msg = th.empty((edges.src['h'].shape[0], self.out_feat),
                           device=edges.src['h'].device)
            for etype in etypes:
                loc = edges.data['type'] == etype
                w = self.weight[etype].view(self.num_bases, self.submat_in, self.submat_out)
                src = edges.src['h'][loc].view(-1, self.num_bases, self.submat_in)
                sub_msg = th.einsum('abc,bcd->abd', src, w)
                sub_msg = sub_msg.reshape(-1, self.out_feat)
                msg[loc] = sub_msg
        else:
            weight = self.weight.index_select(0, edges.data['type']).view(
                -1, self.submat_in, self.submat_out)
            node = edges.src['h'].view(-1, 1, self.submat_in)
            msg = th.bmm(node, weight).view(-1, self.out_feat)
        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg} 
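In the low-memory branch, 'abc,bcd->abd' multiplies each block of the reshaped source features by its own submatrix, i.e. src[:, b, :] @ w[b] for every basis block b. A standalone sketch with illustrative sizes:

import torch

n_edges, num_bases, submat_in, submat_out = 10, 3, 4, 5   # illustrative sizes
src = torch.randn(n_edges, num_bases, submat_in)
w = torch.randn(num_bases, submat_in, submat_out)

msg = torch.einsum('abc,bcd->abd', src, w)

# the same thing as a per-base matrix product
ref = torch.stack([src[:, b, :] @ w[b] for b in range(num_bases)], dim=1)
assert torch.allclose(msg, ref, atol=1e-6)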
Example #9
Source File: kroneckers.py    From backpack with MIT License
def kfac_mat_prod(factors):
    """Return function v ↦ (A ⊗ B ⊗ ...)v for `factors = [A, B, ...]` """
    assert all_tensors_of_order(order=2, tensors=factors)

    shapes = [list(f.size()) for f in factors]
    _, col_dims = zip(*shapes)

    num_factors = len(shapes)
    equation = kfac_mat_prod_einsum_equation(num_factors)

    @kfacmp_unsqueeze_if_missing_dim(mat_dim=2)
    def kfacmp(mat):
        assert is_matrix(mat)
        _, mat_cols = mat.shape
        mat_reshaped = mat.view(*(col_dims), mat_cols)
        return einsum(equation, mat_reshaped, *factors).contiguous().view(-1, mat_cols)

    return kfacmp 
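The helper kfac_mat_prod_einsum_equation is not shown in this excerpt. The sketch below illustrates the underlying identity for two factors, independent of the project's exact index convention: a Kronecker product can be applied to a matrix without ever materializing it, which is checked here against an explicitly formed Kronecker product (torch.kron, PyTorch 1.8+).

import torch

A = torch.randn(3, 4)
B = torch.randn(5, 6)
mat = torch.randn(4 * 6, 7)                       # (cols(A) * cols(B), num_vectors)

# apply (A ⊗ B) without materializing it
mat_reshaped = mat.view(4, 6, 7)
out = torch.einsum('ij,kl,jlm->ikm', A, B, mat_reshaped).reshape(3 * 5, 7)

ref = torch.kron(A, B) @ mat                      # explicit Kronecker product
assert torch.allclose(out, ref, atol=1e-5)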
Example #10
Source File: class_balance.py    From metal with Apache License 2.0
def _get_overlaps_tensor(self, L):
        """Transforms the input label matrix to a three-way overlaps tensor.

        Args:
            L: (np.array) An n x m array of LF output labels, in {0,...,k} if
                self.abstains, else in {1,...,k}, generated by m conditionally
                independent LFs on n data points

        Outputs:
            O: (torch.Tensor) A (m, m, m, k, k, k) tensor of the label-specific
            empirical overlap rates; that is,

                O[i,j,k,y1,y2,y3] = P(\lf_i = y1, \lf_j = y2, \lf_k = y3)

            where this quantity is computed empirically by this function, based
            on the label matrix L.
        """
        n, m = L.shape

        # Convert from a (n,m) matrix of ints to a (k_lf, n, m) indicator tensor
        LY = np.array([np.where(L == y, 1, 0) for y in range(self.k_0, self.k + 1)])

        # Form the three-way overlaps matrix
        O = np.einsum("abc,dbe,fbg->cegadf", LY, LY, LY) / n
        return torch.from_numpy(O).float() 
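The einsum 'abc,dbe,fbg->cegadf' contracts three copies of the indicator tensor over the shared sample axis b, so O[i,j,l,y1,y2,y3] is the empirical rate at which LF i outputs y1, LF j outputs y2 and LF l outputs y3 on the same data point. A small NumPy check against the brute-force definition (tiny random data, no abstains, purely illustrative):

import numpy as np

n, m, k = 50, 3, 2                       # data points, LFs, label values (illustrative)
L = np.random.randint(1, k + 1, size=(n, m))
LY = np.array([np.where(L == y, 1, 0) for y in range(1, k + 1)])   # (k, n, m)

O = np.einsum("abc,dbe,fbg->cegadf", LY, LY, LY) / n

# brute-force empirical overlap rate for one entry
i, j, l, y1, y2, y3 = 0, 1, 2, 1, 2, 1
rate = np.mean((L[:, i] == y1) & (L[:, j] == y2) & (L[:, l] == y3))
assert np.isclose(O[i, j, l, y1 - 1, y2 - 1, y3 - 1], rate)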
Example #11
Source File: linear.py    From backpack with MIT License
def weight(self, ext, module, g_inp, g_out, backproped):
        return einsum("ni,nj->ij", (g_out[0] ** 2, module.input0 ** 2)) 
Example #12
Source File: utils_test.py    From backpack with MIT License
def make_random_psd_kfacs(self, num_facs=None):
        def make_quadratic_psd(mat):
            """Make matrix positive semi-definite: A -> AAᵀ."""
            mat_squared = einsum("ij,kj->ik", (mat, mat))
            shift = self.PSD_KFAC_MIN_EIGVAL * self.torch_eye_like(mat_squared)
            return mat_squared + shift

        kfacs = self.make_random_kfacs(num_facs=num_facs)
        return [make_quadratic_psd(fac) for fac in kfacs]

    # Torch helpers
    ######################################################################### 
Example #13
Source File: conv2d.py    From backpack with MIT License
def bias(self, ext, module, g_inp, g_out, backproped):
        N_axis = 0
        return (einsum("nchw->nc", g_out[0]) ** 2).sum(N_axis) 
Example #14
Source File: elementwise.py    From backpack with MIT License
def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
        self._no_inplace(module)

        df_elementwise = self.df(module, g_inp, g_out)
        return einsum("...,v...->v...", (df_elementwise, mat)) 
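The ellipsis equation '...,v...->v...' broadcasts the elementwise derivative across the leading axis v that holds the backpropagated vectors; it is an ordinary elementwise product. A standalone sketch with illustrative shapes:

import torch

df = torch.randn(3, 4)                     # stands in for the elementwise derivative
mat = torch.randn(5, 3, 4)                 # leading axis v holds the backpropagated vectors
out = torch.einsum('...,v...->v...', df, mat)
assert torch.allclose(out, df.unsqueeze(0) * mat)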
Example #15
Source File: batchnorm1d.py    From backpack with MIT License
def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
        """
        Note:
        -----
        The Jacobian is *not independent* among the batch dimension, i.e.
        D z_i = D z_i(x_1, ..., x_B).

        This structure breaks the computation of the GGN diagonal,
        for curvature-matrix products it should still work.

        References:
        -----------
        https://kevinzakka.github.io/2016/09/14/batch_normalization/
        https://chrisyeh96.github.io/2017/08/28/deriving-batchnorm-backprop.html
        """
        assert module.affine is True

        N = self.get_batch(module)
        x_hat, var = self.get_normalized_input_and_var(module)
        ivar = 1.0 / (var + module.eps).sqrt()

        dx_hat = einsum("vni,i->vni", (mat, module.weight))

        jac_t_mat = N * dx_hat
        jac_t_mat -= dx_hat.sum(1).unsqueeze(1).expand_as(jac_t_mat)
        jac_t_mat -= einsum("ni,vsi,si->vni", (x_hat, dx_hat, x_hat))
        jac_t_mat = einsum("vni,i->vni", (jac_t_mat, ivar / N))

        return jac_t_mat 
Example #16
Source File: batchnorm1d.py    From backpack with MIT License
def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
        x_hat, _ = self.get_normalized_input_and_var(module)
        return einsum("ni,vi->vni", (x_hat, mat)) 
Example #17
Source File: batchnorm1d.py    From backpack with MIT License
def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch):
        if not sum_batch:
            warn(
                "BatchNorm batch summation disabled."
                "This may not compute meaningful quantities"
            )
        x_hat, _ = self.get_normalized_input_and_var(module)
        equation = "vni,ni->v{}i".format("" if sum_batch is True else "n")
        operands = [mat, x_hat]
        return einsum(equation, operands) 
Example #18
Source File: conv2d.py    From backpack with MIT License
def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
        jac_mat = eingroup("v,o,i,h,w->v,o,ihw", mat)
        X = self.get_unfolded_input(module)

        jac_mat = einsum("nij,vki->vnkj", (X, jac_mat))
        return self.reshape_like_output(jac_mat, module) 
Example #19
Source File: crossentropyloss.py    From backpack with MIT License
def _sqrt_hessian(self, module, g_inp, g_out):
        self._check_2nd_order_parameters(module)

        probs = self._get_probs(module)
        tau = torchsqrt(probs)
        V_dim, C_dim = 0, 2
        Id = diag_embed(ones_like(probs), dim1=V_dim, dim2=C_dim)
        Id_tautau = Id - einsum("nv,nc->vnc", tau, tau)
        sqrt_H = einsum("nc,vnc->vnc", tau, Id_tautau)

        if module.reduction == "mean":
            N = module.input0.shape[0]
            sqrt_H /= sqrt(N)

        return sqrt_H 
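A standalone numerical check (ignoring the 1/√N factor applied for reduction="mean") that these square-root factors reconstruct the softmax cross-entropy Hessian diag(p) − p pᵀ with respect to the logits:

import torch

N, C = 4, 5                               # illustrative sizes
logits = torch.randn(N, C)
probs = torch.softmax(logits, dim=1)
tau = probs.sqrt()

Id = torch.diag_embed(torch.ones_like(probs), dim1=0, dim2=2)
Id_tautau = Id - torch.einsum('nv,nc->vnc', tau, tau)
sqrt_H = torch.einsum('nc,vnc->vnc', tau, Id_tautau)

# contracting the V axis recovers the per-sample Hessian diag(p) - p pᵀ
H = torch.einsum('vni,vnj->nij', sqrt_H, sqrt_H)
ref = torch.diag_embed(probs) - torch.einsum('ni,nj->nij', probs, probs)
assert torch.allclose(H, ref, atol=1e-6)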
Example #20
Source File: conv2d.py    From backpack with MIT License
def _factor_from_sqrt(self, module, backproped):
        sqrt_ggn = backproped

        sqrt_ggn = convUtils.separate_channels_and_pixels(module, sqrt_ggn)
        sqrt_ggn = einsum("cbij->cbi", (sqrt_ggn,))
        return einsum("cbi,cbl->il", (sqrt_ggn, sqrt_ggn)) 
Example #21
Source File: conv2d.py    From backpack with MIT License
def bias(self, ext, module, g_inp, g_out, backproped):
        C_axis = 1
        return (einsum("nchw->nc", g_out[0]) ** 2).sum(C_axis) 
Example #22
Source File: linear.py    From backpack with MIT License
def weight(self, ext, module, g_inp, g_out, backproped):
        return einsum("ni,nj->n", (g_out[0] ** 2, module.input0 ** 2)) 
Example #23
Source File: kroneckers.py    From backpack with MIT License
def two_kfacs_to_mat(A, B):
    """Given A, B, return A ⊗ B."""
    assert is_matrix(A)
    assert is_matrix(B)

    mat_shape = (
        A.shape[0] * B.shape[0],
        A.shape[1] * B.shape[1],
    )
    mat = einsum("ij,kl->ikjl", (A, B)).contiguous().view(mat_shape)
    return mat 
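The reshaped einsum 'ij,kl->ikjl' is exactly the Kronecker product. On recent PyTorch versions (1.8+) this can be checked against torch.kron:

import torch

A = torch.randn(2, 3)
B = torch.randn(4, 5)

mat_shape = (A.shape[0] * B.shape[0], A.shape[1] * B.shape[1])
mat = torch.einsum('ij,kl->ikjl', A, B).contiguous().view(mat_shape)

assert torch.allclose(mat, torch.kron(A, B))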
Example #24
Source File: conv.py    From backpack with MIT License
def extract_bias_diagonal(module, sqrt):
    """
    `sqrt` must be the backpropagated quantity for DiagH or DiagGGN(MC)
    """
    V_axis, N_axis = 0, 1
    bias_diagonal = (einsum("vnchw->vnc", sqrt) ** 2).sum([V_axis, N_axis])
    return bias_diagonal 
Example #25
Source File: conv.py    From backpack with MIT License
def extract_weight_diagonal(module, input, grad_output):
    """
    input must be the unfolded input to the convolution (see unfold_func)
    and grad_output the backpropagated gradient
    """
    grad_output_viewed = separate_channels_and_pixels(module, grad_output)
    AX = einsum("nkl,vnml->vnkm", (input, grad_output_viewed))
    weight_diagonal = (AX ** 2).sum([0, 1]).transpose(0, 1)
    return weight_diagonal.view_as(module.weight) 
Example #26
Source File: linear.py    From backpack with MIT License
def extract_bias_diagonal(module, backproped):
    return einsum("vno->o", backproped ** 2) 
Example #27
Source File: linear.py    From backpack with MIT License
def extract_weight_diagonal(module, backproped):
    return einsum("vno,ni->oi", (backproped ** 2, module.input0 ** 2)) 
Example #28
Source File: elmo_embedding.py    From fastNLP with Apache License 2.0
def _get_mixed_outputs(self, outputs):
        # outputs: num_layers x batch_size x max_len x hidden_size
        # return: batch_size x max_len x hidden_size
        weights = F.softmax(self.layer_weights + 1 / len(outputs), dim=0).to(outputs)
        outputs = torch.einsum('l,lbij->bij', weights, outputs)
        return self.gamma.to(outputs) * outputs 
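The equation 'l,lbij->bij' is simply a softmax-weighted sum over the layer axis. A standalone sketch of the equivalent broadcasting formulation (sizes illustrative):

import torch

num_layers, B, T, H = 3, 2, 5, 8          # illustrative sizes
outputs = torch.randn(num_layers, B, T, H)
weights = torch.softmax(torch.randn(num_layers), dim=0)

mixed = torch.einsum('l,lbij->bij', weights, outputs)
ref = (weights.view(-1, 1, 1, 1) * outputs).sum(dim=0)
assert torch.allclose(mixed, ref, atol=1e-6)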
Example #29
Source File: conv_knrm.py    From MatchZoo-py with Apache License 2.0
def forward(self, inputs):
        """Forward."""

        query, doc = inputs['text_left'], inputs['text_right']

        q_embed = self.embedding(query.long()).transpose(1, 2)
        d_embed = self.embedding(doc.long()).transpose(1, 2)

        q_convs = []
        d_convs = []
        for q_conv, d_conv in zip(self.q_convs, self.d_convs):
            q_convs.append(q_conv(q_embed).transpose(1, 2))
            d_convs.append(d_conv(d_embed).transpose(1, 2))

        KM = []
        for qi in range(self._params['max_ngram']):
            for di in range(self._params['max_ngram']):
                # do not match n-gram with different length if use crossmatch
                if not self._params['use_crossmatch'] and qi != di:
                    continue
                mm = torch.einsum(
                    'bld,brd->blr',
                    F.normalize(q_convs[qi], p=2, dim=-1),
                    F.normalize(d_convs[di], p=2, dim=-1)
                )
                for kernel in self.kernels:
                    K = torch.log1p(kernel(mm).sum(dim=-1)).sum(dim=-1)
                    KM.append(K)

        phi = torch.stack(KM, dim=1)

        out = self.out(phi)
        return out 
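With both inputs L2-normalized along the embedding dimension, 'bld,brd->blr' produces a batch of cosine-similarity matrices between query n-grams and document n-grams; it is the same contraction as a batched matrix product. A small standalone sketch:

import torch
import torch.nn.functional as F

B, L, R, D = 2, 4, 6, 8                   # illustrative sizes
q = F.normalize(torch.randn(B, L, D), p=2, dim=-1)
d = F.normalize(torch.randn(B, R, D), p=2, dim=-1)

mm = torch.einsum('bld,brd->blr', q, d)
ref = torch.bmm(q, d.transpose(1, 2))     # same contraction as a batched matmul
assert torch.allclose(mm, ref, atol=1e-6)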
Example #30
Source File: drmm.py    From MatchZoo-py with Apache License 2.0
def forward(self, inputs):
        """Forward."""

        # Scalar dimensions referenced here:
        #   B = batch size (number of sequences)
        #   D = embedding size
        #   L = `input_left` sequence length
        #   R = `input_right` sequence length
        #   H = histogram size
        #   K = size of top-k

        # Left input and right input.
        # query: shape = [B, L]
        # doc: shape = [B, L, H]
        # Note here, the doc is the matching histogram between original query
        # and original document.

        query, match_hist = inputs['text_left'], inputs['match_histogram']

        # shape = [B, L]
        mask_query = (query == self._params['mask_value'])

        # Process left input.
        # shape = [B, L, D]
        embed_query = self.embedding(query.long())

        # shape = [B, L]
        attention_probs = self.attention(embed_query, mask_query)

        # shape = [B, L]
        dense_output = self.mlp(match_hist).squeeze(dim=-1)

        x = torch.einsum('bl,bl->b', dense_output, attention_probs)

        out = self.out(x.unsqueeze(dim=-1))
        return out
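The final contraction 'bl,bl->b' is a per-example dot product between the per-term MLP scores and the attention weights, i.e. an elementwise product followed by a sum over the length axis. A brief standalone sketch:

import torch

B, L = 3, 7                               # illustrative sizes
dense_output = torch.randn(B, L)
attention_probs = torch.softmax(torch.randn(B, L), dim=-1)

x = torch.einsum('bl,bl->b', dense_output, attention_probs)
ref = (dense_output * attention_probs).sum(dim=-1)
assert torch.allclose(x, ref)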