Python torch.transpose() Examples

The following are 30 code examples of torch.transpose(), drawn from open-source projects. Each example lists the project and source file it was taken from. You may also want to check out all available functions/classes of the module torch, or try the search function.
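Before the project examples, here is a minimal standalone sketch (not taken from any of the projects below) of what torch.transpose() does: it swaps two dimensions and returns a view that shares storage with the input.

import torch

x = torch.arange(6).reshape(2, 3)   # shape [2, 3]
y = torch.transpose(x, 0, 1)        # shape [3, 2]; equivalent to x.transpose(0, 1), or x.T for 2-D tensors
print(y.shape)                      # torch.Size([3, 2])
# The result is a view and is generally non-contiguous, so call .contiguous()
# before .view() if you need to reshape it afterwards.
print(y.is_contiguous())            # False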
Example #1
Source File: torch.py    From yolo2-pytorch with GNU Lesser General Public License v3.0
def intersection_area(yx_min1, yx_max1, yx_min2, yx_max2):
    """
    Calculates the intersection area of two lists of bounding boxes.
    :author 申瑞珉 (Ruimin Shen)
    :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
    :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
    :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
    :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
    :return: The matrix (size [N1, N2]) of the intersection area.
    """
    ymin1, xmin1 = torch.split(yx_min1, 1, -1)
    ymax1, xmax1 = torch.split(yx_max1, 1, -1)
    ymin2, xmin2 = torch.split(yx_min2, 1, -1)
    ymax2, xmax2 = torch.split(yx_max2, 1, -1)
    max_ymin = torch.max(ymin1.repeat(1, ymin2.size(0)), torch.transpose(ymin2, 0, 1).repeat(ymin1.size(0), 1)) # PyTorch's bug
    min_ymax = torch.min(ymax1.repeat(1, ymax2.size(0)), torch.transpose(ymax2, 0, 1).repeat(ymax1.size(0), 1)) # PyTorch's bug
    height = torch.clamp(min_ymax - max_ymin, min=0)
    max_xmin = torch.max(xmin1.repeat(1, xmin2.size(0)), torch.transpose(xmin2, 0, 1).repeat(xmin1.size(0), 1)) # PyTorch's bug
    min_xmax = torch.min(xmax1.repeat(1, xmax2.size(0)), torch.transpose(xmax2, 0, 1).repeat(xmax1.size(0), 1)) # PyTorch's bug
    width = torch.clamp(min_xmax - max_xmin, min=0)
    return height * width 
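As a purely hypothetical usage sketch (tensor names and sizes invented here, not part of yolo2-pytorch), the function above takes two box lists and returns an [N1, N2] matrix of pairwise intersection areas:

yx_min1 = torch.rand(4, 2); yx_max1 = yx_min1 + torch.rand(4, 2)   # N1 = 4 boxes
yx_min2 = torch.rand(6, 2); yx_max2 = yx_min2 + torch.rand(6, 2)   # N2 = 6 boxes
inter = intersection_area(yx_min1, yx_max1, yx_min2, yx_max2)
print(inter.shape)   # torch.Size([4, 6])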
Example #2
Source File: SINet.py    From ext_portrait_segmentation with MIT License
def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()

    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups,
               channels_per_group, height, width)

    # transpose
    # - contiguous() required if transpose() is used before view().
    #   See https://github.com/pytorch/pytorch/issues/764
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x 
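A quick sanity check (input shape invented here, assuming torch is imported): channel shuffling keeps the tensor shape and only reorders channels, and the .contiguous() call above is what allows the final .view() on the transposed tensor.

x = torch.randn(2, 8, 16, 16)        # [batch, channels, H, W]
y = channel_shuffle(x, groups=4)
print(y.shape)                       # torch.Size([2, 8, 16, 16])
# Without .contiguous(), calling .view() on the transposed (non-contiguous)
# tensor would raise a RuntimeError.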
Example #3
Source File: model.py    From atec-nlp with MIT License
def forward(self, inp, hidden):
        outp = self.bilstm.forward(inp, hidden)[0]
        size = outp.size()  # [bsz, len, nhid]
        compressed_embeddings = outp.view(-1, size[2])  # [bsz*len, nhid*2]
        transformed_inp = torch.transpose(inp, 0, 1).contiguous()  # [bsz, len]
        transformed_inp = transformed_inp.view(size[0], 1, size[1])  # [bsz, 1, len]
        concatenated_inp = [transformed_inp for i in range(self.attention_hops)]
        concatenated_inp = torch.cat(concatenated_inp, 1)  # [bsz, hop, len]

        hbar = self.tanh(self.ws1(self.drop(compressed_embeddings)))  # [bsz*len, attention-unit]
        alphas = self.ws2(hbar).view(size[0], size[1], -1)  # [bsz, len, hop]
        alphas = torch.transpose(alphas, 1, 2).contiguous()  # [bsz, hop, len]
        penalized_alphas = alphas + (
            -10000 * (concatenated_inp == self.dictionary.word2idx['<pad>']).float())
            # [bsz, hop, len] + [bsz, hop, len]
        alphas = self.softmax(penalized_alphas.view(-1, size[1]))  # [bsz*hop, len]
        alphas = alphas.view(size[0], self.attention_hops, size[1])  # [bsz, hop, len]
        # Performs a batch matrix-matrix product of matrices
        return torch.bmm(alphas, outp), alphas 
Example #4
Source File: mmd.py    From transferlearning with MIT License
def cmmd(source, target, s_label, t_label, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    s_label = s_label.cpu()
    s_label = s_label.view(32,1)
    s_label = torch.zeros(32, 31).scatter_(1, s_label.data, 1)
    s_label = Variable(s_label).cuda()

    t_label = t_label.cpu()
    t_label = t_label.view(32, 1)
    t_label = torch.zeros(32, 31).scatter_(1, t_label.data, 1)
    t_label = Variable(t_label).cuda()

    batch_size = int(source.size()[0])
    kernels = guassian_kernel(source, target,
                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    loss = 0
    XX = kernels[:batch_size, :batch_size]
    YY = kernels[batch_size:, batch_size:]
    XY = kernels[:batch_size, batch_size:]
    loss += torch.mean(torch.mm(s_label, torch.transpose(s_label, 0, 1)) * XX +
                      torch.mm(t_label, torch.transpose(t_label, 0, 1)) * YY -
                      2 * torch.mm(s_label, torch.transpose(t_label, 0, 1)) * XY)
    return loss 
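The scatter_-based one-hot encoding used at the top of cmmd() can be illustrated in isolation (smaller hypothetical sizes here; the original hard-codes a batch of 32 and 31 classes, and the Variable/.cuda() calls are omitted):

labels = torch.tensor([2, 0, 1]).view(3, 1)          # 3 samples
one_hot = torch.zeros(3, 4).scatter_(1, labels, 1)   # 4 classes
# one_hot @ one_hot.T is 1 where two samples share a label and 0 otherwise,
# which is what torch.mm(s_label, torch.transpose(s_label, 0, 1)) computes above.
print(one_hot)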
Example #5
Source File: test_manifold_basic.py    From geoopt with Apache License 2.0
def sphere_compliment_case():
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    complement = torch.rand(shape[-1], 1, dtype=torch.float64)

    Q, _ = geoopt.linalg.batch_linalg.qr(complement)
    P = -Q @ Q.transpose(-1, -2)
    P[..., torch.arange(P.shape[-2]), torch.arange(P.shape[-2])] += 1

    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    x = (ex @ P.t()) / torch.norm(ex @ P.t())
    v = (ev - (x @ ev) * x) @ P.t()

    manifold = geoopt.Sphere(complement=complement)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case
    manifold = geoopt.SphereExact(complement=complement)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case 
Example #6
Source File: Patient2Vec.py    From Patient2Vec with MIT License
def convolutional_layer(self, inputs):
        convolution_all = []
        conv_wts = []
        for i in range(self.seq_len):
            convolution_one_month = []
            for j in range(self.pad_size):
                convolution = self.conv(torch.unsqueeze(inputs[:, i, j], dim=1))
                convolution_one_month.append(convolution)
            convolution_one_month = torch.stack(convolution_one_month)
            convolution_one_month = torch.squeeze(convolution_one_month, dim=3)
            convolution_one_month = torch.transpose(convolution_one_month, 0, 1)
            convolution_one_month = torch.transpose(convolution_one_month, 1, 2)
            convolution_one_month = torch.squeeze(convolution_one_month, dim=1)
            convolution_one_month = self.func_tanh(convolution_one_month)
            convolution_one_month = torch.unsqueeze(convolution_one_month, dim=1)
            vec = torch.bmm(convolution_one_month, inputs[:, i])
            convolution_all.append(vec)
            conv_wts.append(convolution_one_month)
        convolution_all = torch.stack(convolution_all, dim=1)
        convolution_all = torch.squeeze(convolution_all, dim=2)
        conv_wts = torch.squeeze(torch.stack(conv_wts, dim=1), dim=2)
        return convolution_all, conv_wts 
Example #7
Source File: cls_fe_dct_bases.py    From signaltrain with GNU General Public License v3.0
def tied_transform(analysis, x_ft, hop):
    """
        A method to compute an orthogonal transform for audio signals.
        It will use the analysis weights to perform the reconstruction, via
        transposed convolution.

        Arguments              :
            analysis           :   (object)         Pytorch module
            x_ft               :   (Torch Tensor)   Tensor containing the transformed signal
            hop                :   (int)            Hop-size
        Returns                :
            wave_form          :   (Torch Tensor)   Reconstructed waveform
    """

    sz = analysis.conv_analysis.weight.size()[0]
    wave_form = nn.functional.conv_transpose2d(torch.transpose(x_ft, 2, 1).unsqueeze(3),
                                               analysis.conv_analysis.weight.unsqueeze(3),
                                               padding=(sz, 0), stride=(hop, 1))
    return wave_form.squeeze(3) 
Example #8
Source File: Patient2Vec.py    From Patient2Vec with MIT License
def get_loss(pred, y, criterion, mtr, a=0.5):
    """
    To calculate loss
    :param pred: predicted value
    :param y: actual value
    :param criterion: nn.CrossEntropyLoss
    :param mtr: beta matrix
    """
    mtr_t = torch.transpose(mtr, 1, 2)
    aa = torch.bmm(mtr, mtr_t)
    loss_fn = 0
    for i in range(aa.size()[0]):
        aai = torch.add(aa[i, ], Variable(torch.neg(torch.eye(mtr.size()[1]))))
        loss_fn += torch.trace(torch.mul(aai, aai).data)
    loss_fn /= aa.size()[0]
    loss = torch.add(criterion(pred, y), Variable(torch.FloatTensor([loss_fn * a])))
    return loss 
Example #9
Source File: module.py    From Tacotron-pytorch with Apache License 2.0
def forward(self, input_):
        """

        :param input_: sequences
        :return: outputs
        """
        batch_size = input_.size()[0]
        if self.time_dim == 2:
            input_ = input_.transpose(1, 2).contiguous()
        input_ = input_.view(-1, self.input_size)

        out = self.linear(input_).view(batch_size, -1, self.output_size)

        if self.time_dim == 2:
            out = out.contiguous().transpose(1, 2)

        return out 
Example #10
Source File: Modules.py    From GST-Tacotron with MIT License
def max_pool1d(inputs, kernel_size, stride=1, padding='same'):
    '''
    inputs: [N, T, C]
    outputs: [N, T // stride, C]
    '''
    inputs = inputs.transpose(1, 2)  # [N, C, T]
    if padding == 'same':
        left = (kernel_size - 1) // 2
        right = (kernel_size - 1) - left
        pad = (left, right)
    else:
        pad = (0, 0)
    inputs = F.pad(inputs, pad)
    outputs = F.max_pool1d(inputs, kernel_size, stride)  # [N, C, T]
    outputs = outputs.transpose(1, 2)  # [N, T, C]

    return outputs 
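A hypothetical call (shapes chosen here for illustration, assuming torch and torch.nn.functional as F are imported as in the original module): with padding='same' and stride=1 the time dimension is preserved.

x = torch.randn(2, 10, 8)            # [N, T, C]
y = max_pool1d(x, kernel_size=3)     # stride=1, padding='same'
print(y.shape)                       # torch.Size([2, 10, 8])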
Example #11
Source File: test_manifold_basic.py    From geoopt with Apache License 2.0
def birkhoff_case():
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.BirkhoffPolytope]
    ex = torch.randn(*shape, dtype=torch.float64).abs()
    ev = torch.randn(*shape, dtype=torch.float64)
    max_iter = 100
    eps = 1e-12
    tol = 1e-5
    iter = 0
    c = 1.0 / (torch.sum(ex, dim=-2, keepdim=True) + eps)
    r = 1.0 / (torch.matmul(ex, c.transpose(-1, -2)) + eps)
    while iter < max_iter:
        iter += 1
        cinv = torch.matmul(r.transpose(-1, -2), ex)
        if torch.max(torch.abs(cinv * c - 1)) <= tol:
            break
        c = 1.0 / (cinv + eps)
        r = 1.0 / ((ex @ c.transpose(-1, -2)) + eps)
    x = ex * (r @ c)

    v = proju_original(x, ev)
    manifold = geoopt.manifolds.BirkhoffPolytope()
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case 
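The while loop above appears to be a Sinkhorn-style normalization: it alternately rescales rows and columns until x is approximately doubly stochastic. A quick hypothetical check (assuming torch is imported), applied to x right after x = ex * (r @ c):

# Row and column sums of a doubly stochastic matrix are (approximately) all ones.
row_sums = x.sum(dim=-1)
col_sums = x.sum(dim=-2)
print(torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-3))   # True
print(torch.allclose(col_sums, torch.ones_like(col_sums), atol=1e-3))   # True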
Example #12
Source File: ReadoutFunction.py    From nmp_qc with MIT License
def r_duvenaud(self, h):
        # layers
        aux = []
        for l in range(len(h)):
            param_sz = self.learn_args[l].size()
            parameter_mat = torch.t(self.learn_args[l])[None, ...].expand(h[l].size(0), param_sz[1],
                                                                                      param_sz[0])

            aux.append(torch.transpose(torch.bmm(parameter_mat, torch.transpose(h[l], 1, 2)), 1, 2))

            for j in range(0, aux[l].size(1)):
                # Mask whole 0 vectors
                aux[l][:, j, :] = nn.Softmax()(aux[l][:, j, :].clone())*(torch.sum(aux[l][:, j, :] != 0, 1) > 0).expand_as(aux[l][:, j, :]).type_as(aux[l])

        aux = torch.sum(torch.sum(torch.stack(aux, 3), 3), 1)
        return self.learn_modules[0](torch.squeeze(aux)) 
Example #13
Source File: score_fun.py    From dgl with Apache License 2.0
def create_neg(self, neg_head):
        if neg_head:
            def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
                hidden_dim = heads.shape[1]
                heads = heads.reshape(num_chunks, neg_sample_size, hidden_dim)
                heads = th.transpose(heads, 1, 2)
                tmp = (tails * relations).reshape(num_chunks, chunk_size, hidden_dim)
                return th.bmm(tmp, heads)
            return fn
        else:
            def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
                hidden_dim = tails.shape[1]
                tails = tails.reshape(num_chunks, neg_sample_size, hidden_dim)
                tails = th.transpose(tails, 1, 2)
                tmp = (heads * relations).reshape(num_chunks, chunk_size, hidden_dim)
                return th.bmm(tmp, tails)
            return fn 
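A shape-level sketch of the neg_head branch (sizes invented here; th is the torch alias used in the original module): transposing the reshaped head embeddings lets a single batched matmul score every (relation-tail, negative head) pair within a chunk.

num_chunks, chunk_size, neg_sample_size, hidden_dim = 2, 3, 5, 4
heads = torch.randn(num_chunks * neg_sample_size, hidden_dim)
tails = torch.randn(num_chunks * chunk_size, hidden_dim)
relations = torch.randn(num_chunks * chunk_size, hidden_dim)
heads = heads.reshape(num_chunks, neg_sample_size, hidden_dim)
heads = torch.transpose(heads, 1, 2)                          # [chunks, hidden, neg]
tmp = (tails * relations).reshape(num_chunks, chunk_size, hidden_dim)
print(torch.bmm(tmp, heads).shape)                            # torch.Size([2, 3, 5])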
Example #14
Source File: esim.py    From video_captioning_rl with MIT License
def similarity(self, s1, l1, s2, l2):
        """
        :param s1: [B, t1, D]
        :param l1: [B]
        :param s2: [B, t2, D]
        :param l2: [B]
        :return:
        """
        batch_size = s1.size(0)
        t1 = s1.size(1)
        t2 = s2.size(1)
        S = torch.bmm(s1, s2.transpose(1,
                                       2))  # [B, t1, D] * [B, D, t2] -> [B, t1, t2] S is the similarity matrix from biDAF paper. [B, T1, T2]

        s_mask = S.data.new(*S.size()).fill_(1).byte()  # [B, T1, T2]
        # Init similarity mask using lengths
        for i, (l_1, l_2) in enumerate(zip(l1, l2)):
            s_mask[i][:l_1, :l_2] = 0

        s_mask = Variable(s_mask)
        S.data.masked_fill_(s_mask.data.byte(), -math.inf)
        return S 
Example #15
Source File: torch.py    From yolo2-pytorch with GNU Lesser General Public License v3.0
def batch_intersection_area(yx_min1, yx_max1, yx_min2, yx_max2):
    """
    Calculates the intersection area of two lists of bounding boxes for N independent batches.
    :author 申瑞珉 (Ruimin Shen)
    :param yx_min1: The top left coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes.
    :param yx_max1: The bottom right coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes.
    :param yx_min2: The top left coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes.
    :param yx_max2: The bottom right coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes.
    :return: The matrices (size [N, N1, N2]) of the intersection areas.
    """
    ymin1, xmin1 = torch.split(yx_min1, 1, -1)
    ymax1, xmax1 = torch.split(yx_max1, 1, -1)
    ymin2, xmin2 = torch.split(yx_min2, 1, -1)
    ymax2, xmax2 = torch.split(yx_max2, 1, -1)
    max_ymin = torch.max(ymin1.repeat(1, 1, ymin2.size(1)), torch.transpose(ymin2, 1, 2).repeat(1, ymin1.size(1), 1)) # PyTorch's bug
    min_ymax = torch.min(ymax1.repeat(1, 1, ymax2.size(1)), torch.transpose(ymax2, 1, 2).repeat(1, ymax1.size(1), 1)) # PyTorch's bug
    height = torch.clamp(min_ymax - max_ymin, min=0)
    max_xmin = torch.max(xmin1.repeat(1, 1, xmin2.size(1)), torch.transpose(xmin2, 1, 2).repeat(1, xmin1.size(1), 1)) # PyTorch's bug
    min_xmax = torch.min(xmax1.repeat(1, 1, xmax2.size(1)), torch.transpose(xmax2, 1, 2).repeat(1, xmax1.size(1), 1)) # PyTorch's bug
    width = torch.clamp(min_xmax - max_xmin, min=0)
    return height * width 
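As with the unbatched version in Example #1, a hypothetical call (batch of 3, sizes invented here) returns one intersection matrix per batch element:

yx_min1 = torch.rand(3, 4, 2); yx_max1 = yx_min1 + torch.rand(3, 4, 2)   # N1 = 4
yx_min2 = torch.rand(3, 6, 2); yx_max2 = yx_min2 + torch.rand(3, 6, 2)   # N2 = 6
print(batch_intersection_area(yx_min1, yx_max1, yx_min2, yx_max2).shape)
# torch.Size([3, 4, 6])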
Example #16
Source File: transpose.py    From rlgraph with Apache License 2.0
def _graph_fn_call(self, key, inputs):
        """
        Transposes the input by flipping batch and time ranks.
        """
        if get_backend() == "tf":
            # Flip around ranks 0 and 1.
            transposed = tf.transpose(
                inputs,
                perm=(1, 0) + tuple(i for i in range(2, len(inputs.shape.as_list()))), name="transpose"
            )
            if self.output_is_time_major is None:
                transposed._time_rank = 0 if self.output_time_majors[key] is True else 1
                transposed._batch_rank = 0 if self.output_time_majors[key] is False else 1
            else:
                transposed._time_rank = 0 if self.output_is_time_major is True else 1
                transposed._batch_rank = 0 if self.output_is_time_major is False else 1

            return transposed

        elif get_backend() == "pytorch":
            # torch.transpose() takes two dimension indices, not a permutation tuple;
            # flipping batch and time ranks means swapping dims 0 and 1.
            return torch.transpose(inputs, 0, 1) 
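The distinction the PyTorch branch relies on (a minimal illustration, not from rlgraph): torch.transpose() swaps exactly two dimensions, while Tensor.permute() accepts a full permutation; for flipping the first two ranks the two are interchangeable.

x = torch.randn(4, 2, 3)                  # [batch, time, feature]
print(torch.transpose(x, 0, 1).shape)     # torch.Size([2, 4, 3])  - swap two dims
print(x.permute(1, 0, 2).shape)           # torch.Size([2, 4, 3])  - full permutation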
Example #17
Source File: layers.py    From robustness with Apache License 2.0
def _dropping(self, delta):
        weight = self.conv.weight * self.mask
        ### Sum up all kernels
        ### Assume only apply to 1x1 conv to speed up
        assert weight.size()[-1] == 1
        weight = weight.abs().squeeze()
        assert weight.size()[0] == self.out_channels
        assert weight.size()[1] == self.in_channels
        d_out = self.out_channels // self.groups
        ### Shuffle weight
        weight = weight.view(d_out, self.groups, self.in_channels)
        weight = weight.transpose(0, 1).contiguous()
        weight = weight.view(self.out_channels, self.in_channels)
        ### Sort and drop
        for i in range(self.groups):
            wi = weight[i * d_out:(i + 1) * d_out, :]
            ### Take corresponding delta index
            di = wi.sum(0).sort()[1][self.count:self.count + delta]
            for d in di.data:
                self._mask[i::self.groups, d, :, :].fill_(0)
        self.count = self.count + delta 
Example #18
Source File: MessageFunction.py    From nmp_qc with MIT License
def m_ggnn(self, h_v, h_w, e_vw, opt={}):

        m = Variable(torch.zeros(h_w.size(0), h_w.size(1), self.args['out']).type_as(h_w.data))

        for w in range(h_w.size(1)):
            if torch.nonzero(e_vw[:, w, :].data).size():
                for i, el in enumerate(self.args['e_label']):
                    ind = (el == e_vw[:,w,:]).type_as(self.learn_args[0][i])

                    parameter_mat = self.learn_args[0][i][None, ...].expand(h_w.size(0), self.learn_args[0][i].size(0),
                                                                            self.learn_args[0][i].size(1))

                    m_w = torch.transpose(torch.bmm(torch.transpose(parameter_mat, 1, 2),
                                                                        torch.transpose(torch.unsqueeze(h_w[:, w, :], 1),
                                                                                        1, 2)), 1, 2)
                    m_w = torch.squeeze(m_w)
                    m[:,w,:] = ind.expand_as(m_w)*m_w
        return m 
Example #19
Source File: deeplabv2.py    From SPNet with MIT License
def __init__(self, n_classes, n_blocks, pyramids, class_emb):
        super(DeepLabV2, self).__init__()
        if class_emb is not None:
            self.emb_size = class_emb.shape[1]
            self.class_emb = torch.transpose(class_emb, 1, 0).float().cuda()
        self.add_module(
            "layer1",
            nn.Sequential(
                OrderedDict(
                    [
                        ("conv1", _ConvBatchNormReLU(3, 64, 7, 2, 3, 1)),
                        ("pool", nn.MaxPool2d(3, 2, 1, ceil_mode=True)),
                    ]
                )
            ),
        )
        self.add_module("layer2", _ResBlock(n_blocks[0], 64, 64, 256, 1, 1))
        self.add_module("layer3", _ResBlock(n_blocks[1], 256, 128, 512, 2, 1))
        self.add_module("layer4", _ResBlock(n_blocks[2], 512, 256, 1024, 1, 2))
        self.add_module("layer5", _ResBlock(n_blocks[3], 1024, 512, 2048, 1, 4))
        self.add_module("aspp", _ASPPModule(2048, n_classes, pyramids))

        print("DeepLab Outputs: {}".format(n_classes)) 
Example #20
Source File: tensor.py    From dgl with Apache License 2.0
def swapaxes(input, axis1, axis2):
    return th.transpose(input, axis1, axis2) 
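This is dgl's NumPy-style alias; a hypothetical call (shape chosen here) behaves exactly like torch.transpose():

x = torch.randn(3, 5, 7)
print(swapaxes(x, 0, 2).shape)                                    # torch.Size([7, 5, 3])
print(torch.equal(swapaxes(x, 0, 2), torch.transpose(x, 0, 2)))   # True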
Example #21
Source File: model.py    From atec-nlp with MIT License
def forward(self, inp, hidden):
        emb = self.drop(self.encoder(inp))
        outp = self.bilstm(emb, hidden)[0]
        if self.pooling == 'mean':
            outp = torch.mean(outp, 0).squeeze()
        elif self.pooling == 'max':
            outp = torch.max(outp, 0)[0].squeeze()
        elif self.pooling == 'all' or self.pooling == 'all-word':
            outp = torch.transpose(outp, 0, 1).contiguous()
        return outp, emb 
Example #22
Source File: attention_augmented_conv.py    From Attention-Augmented-Conv2d with MIT License
def relative_logits_1d(self, q, rel_k, H, W, Nh, case):
        rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)
        rel_logits = torch.reshape(rel_logits, (-1, Nh * H, W, 2 * W - 1))
        rel_logits = self.rel_to_abs(rel_logits)

        rel_logits = torch.reshape(rel_logits, (-1, Nh, H, W, W))
        rel_logits = torch.unsqueeze(rel_logits, dim=3)
        rel_logits = rel_logits.repeat((1, 1, 1, H, 1, 1))

        if case == "w":
            rel_logits = torch.transpose(rel_logits, 3, 4)
        elif case == "h":
            rel_logits = torch.transpose(rel_logits, 2, 4).transpose(4, 5).transpose(3, 5)
        rel_logits = torch.reshape(rel_logits, (-1, Nh, H * W, H * W))
        return rel_logits 
Example #23
Source File: attention_augmented_conv.py    From Attention-Augmented-Conv2d with MIT License
def relative_logits(self, q):
        B, Nh, dk, H, W = q.size()
        q = torch.transpose(q, 2, 4).transpose(2, 3)

        key_rel_w = nn.Parameter(torch.randn((2 * W - 1, dk), requires_grad=True)).to(device)
        rel_logits_w = self.relative_logits_1d(q, key_rel_w, H, W, Nh, "w")

        key_rel_h = nn.Parameter(torch.randn((2 * H - 1, dk), requires_grad=True)).to(device)
        rel_logits_h = self.relative_logits_1d(torch.transpose(q, 2, 3), key_rel_h, W, H, Nh, "h")

        return rel_logits_h, rel_logits_w 
Example #24
Source File: attention_augmented_conv.py    From Attention-Augmented-Conv2d with MIT License
def forward(self, x):
        # Input x
        # (batch_size, channels, height, width)
        batch, _, height, width = x.size()

        # conv_out
        # (batch_size, out_channels, height, width)
        conv_out = self.conv_out(x)

        # flat_q, flat_k, flat_v
        # (batch_size, Nh, height * width, dvh or dkh)
        # dvh = dv / Nh, dkh = dk / Nh
        # q, k, v
        # (batch_size, Nh, height, width, dv or dk)
        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)
        logits = torch.matmul(flat_q.transpose(2, 3), flat_k)

        if self.relative:
            h_rel_logits, w_rel_logits = self.relative_logits(q)
            logits += h_rel_logits
            logits += w_rel_logits
        weights = F.softmax(logits, dim=-1)

        # attn_out
        # (batch, Nh, height * width, dvh)
        attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
        attn_out = torch.reshape(attn_out, (batch, self.Nh, self.dv // self.Nh, height, width))
        # combine_heads_2d
        # (batch, out_channels, height, width)
        attn_out = self.combine_heads_2d(attn_out)
        attn_out = self.attn_out(attn_out)
        return torch.cat((conv_out, attn_out), dim=1) 
Example #25
Source File: attention_augmented_conv.py    From Attention-Augmented-Conv2d with MIT License
def relative_logits(self, q):
        B, Nh, dk, H, W = q.size()
        q = torch.transpose(q, 2, 4).transpose(2, 3)

        rel_logits_w = self.relative_logits_1d(q, self.key_rel_w, H, W, Nh, "w")
        rel_logits_h = self.relative_logits_1d(torch.transpose(q, 2, 3), self.key_rel_h, W, H, Nh, "h")

        return rel_logits_h, rel_logits_w 
Example #26
Source File: attention_augmented_conv.py    From Attention-Augmented-Conv2d with MIT License
def forward(self, x):
        # Input x
        # (batch_size, channels, height, width)
        # batch, _, height, width = x.size()

        # conv_out
        # (batch_size, out_channels, height, width)
        conv_out = self.conv_out(x)
        batch, _, height, width = conv_out.size()

        # flat_q, flat_k, flat_v
        # (batch_size, Nh, height * width, dvh or dkh)
        # dvh = dv / Nh, dkh = dk / Nh
        # q, k, v
        # (batch_size, Nh, height, width, dv or dk)
        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)
        logits = torch.matmul(flat_q.transpose(2, 3), flat_k)
        if self.relative:
            h_rel_logits, w_rel_logits = self.relative_logits(q)
            logits += h_rel_logits
            logits += w_rel_logits
        weights = F.softmax(logits, dim=-1)

        # attn_out
        # (batch, Nh, height * width, dvh)
        attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
        attn_out = torch.reshape(attn_out, (batch, self.Nh, self.dv // self.Nh, height, width))
        # combine_heads_2d
        # (batch, out_channels, height, width)
        attn_out = self.combine_heads_2d(attn_out)
        attn_out = self.attn_out(attn_out)
        return torch.cat((conv_out, attn_out), dim=1) 
Example #27
Source File: attention_augmented_conv.py    From Attention-Augmented-Conv2d with MIT License
def relative_logits_1d(self, q, rel_k, H, W, Nh, case):
        rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)
        rel_logits = torch.reshape(rel_logits, (-1, Nh * H, W, 2 * W - 1))
        rel_logits = self.rel_to_abs(rel_logits)

        rel_logits = torch.reshape(rel_logits, (-1, Nh, H, W, W))
        rel_logits = torch.unsqueeze(rel_logits, dim=3)
        rel_logits = rel_logits.repeat((1, 1, 1, H, 1, 1))

        if case == "w":
            rel_logits = torch.transpose(rel_logits, 3, 4)
        elif case == "h":
            rel_logits = torch.transpose(rel_logits, 2, 4).transpose(4, 5).transpose(3, 5)
        rel_logits = torch.reshape(rel_logits, (-1, Nh, H * W, H * W))
        return rel_logits 
Example #28
Source File: attention_augmented_conv.py    From Attention-Augmented-Conv2d with MIT License
def relative_logits(self, q):
        B, Nh, dk, H, W = q.size()
        q = torch.transpose(q, 2, 4).transpose(2, 3)

        rel_logits_w = self.relative_logits_1d(q, self.key_rel_w, H, W, Nh, "w")
        rel_logits_h = self.relative_logits_1d(torch.transpose(q, 2, 3), self.key_rel_h, W, H, Nh, "h")

        return rel_logits_h, rel_logits_w 
Example #29
Source File: attention_augmented_conv.py    From Attention-Augmented-Conv2d with MIT License
def forward(self, x):
        # Input x
        # (batch_size, channels, height, width)
        # batch, _, height, width = x.size()

        # conv_out
        # (batch_size, out_channels, height, width)
        conv_out = self.conv_out(x)
        batch, _, height, width = conv_out.size()

        # flat_q, flat_k, flat_v
        # (batch_size, Nh, height * width, dvh or dkh)
        # dvh = dv / Nh, dkh = dk / Nh
        # q, k, v
        # (batch_size, Nh, height, width, dv or dk)
        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)
        logits = torch.matmul(flat_q.transpose(2, 3), flat_k)
        if self.relative:
            h_rel_logits, w_rel_logits = self.relative_logits(q)
            logits += h_rel_logits
            logits += w_rel_logits
        weights = F.softmax(logits, dim=-1)

        # attn_out
        # (batch, Nh, height * width, dvh)
        attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
        attn_out = torch.reshape(attn_out, (batch, self.Nh, self.dv // self.Nh, height, width))
        # combine_heads_2d
        # (batch, out_channels, height, width)
        attn_out = self.combine_heads_2d(attn_out)
        attn_out = self.attn_out(attn_out)
        return torch.cat((conv_out, attn_out), dim=1) 
Example #30
Source File: lednet.py    From SegmenTron with Apache License 2.0
def channel_shuffle(x, groups):
        n, c, h, w = x.size()

        channels_per_group = c // groups
        x = x.view(n, groups, channels_per_group, h, w)
        x = torch.transpose(x, 1, 2).contiguous()
        x = x.view(n, -1, h, w)

        return x