Python torch.bmm() Examples

The following are 30 code examples of torch.bmm(). Each example is taken from an open-source project; the project, source file, and license are noted above the code, and the original source can be consulted for full context. You may also want to check out the other available functions and classes of the torch module.
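As a quick reminder of the semantics: torch.bmm(input, mat2) performs a batch matrix multiply. Both arguments must be 3-D tensors with the same batch size, shapes (b, n, m) and (b, m, p), and the result has shape (b, n, p); unlike torch.matmul, no broadcasting is applied. A minimal sketch:

import torch

a = torch.randn(4, 3, 5)
b = torch.randn(4, 5, 2)
c = torch.bmm(a, b)  # (4, 3, 5) @ (4, 5, 2) -> (4, 3, 2)

# Equivalent per-batch computation:
c_manual = torch.stack([a[i] @ b[i] for i in range(a.size(0))])
assert torch.allclose(c, c_manual)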
Example #1
Source File: modules.py    From BAMnet with Apache License 2.0
def forward(self, x, x_len, atten_mask):
        CoAtt = torch.bmm(x, x.transpose(1, 2))
        CoAtt = atten_mask.unsqueeze(1) * CoAtt - (1 - atten_mask).unsqueeze(1) * INF
        CoAtt = torch.softmax(CoAtt, dim=-1)
        new_x = torch.cat([torch.bmm(CoAtt, x), x], -1)

        sorted_x_len, indx = torch.sort(x_len, 0, descending=True)
        new_x = pack_padded_sequence(new_x[indx], sorted_x_len.data.tolist(), batch_first=True)

        h0 = to_cuda(torch.zeros(2, x_len.size(0), self.hidden_size // 2), self.use_cuda)
        c0 = to_cuda(torch.zeros(2, x_len.size(0), self.hidden_size // 2), self.use_cuda)
        packed_h, (packed_h_t, _) = self.model(new_x, (h0, c0))

        # restore the sorting
        _, inverse_indx = torch.sort(indx, 0)
        packed_h_t = torch.cat([packed_h_t[i] for i in range(packed_h_t.size(0))], -1)
        restore_packed_h_t = packed_h_t[inverse_indx]
        output = restore_packed_h_t
        return output 
Example #2
Source File: model.py    From VSE-C with MIT License
def forward(self, encoding, lengths):
        lengths = Variable(torch.LongTensor(lengths))
        if torch.cuda.is_available():
            lengths = lengths.cuda()
        if self.method == 'mean':
            encoding_pad = nn.utils.rnn.pack_padded_sequence(encoding, lengths.data.tolist(), batch_first=True)
            encoding = nn.utils.rnn.pad_packed_sequence(encoding_pad, batch_first=True, padding_value=0)[0]
            lengths = lengths.float().view(-1, 1)
            return encoding.sum(1) / lengths, None
        elif self.method == 'max':
            return encoding.max(1)  # [bsz, in_dim], [bsz, in_dim] (position)
        elif self.method == 'attn':
            size = encoding.size()  # [bsz, len, in_dim]
            x_flat = encoding.contiguous().view(-1, size[2])  # [bsz*len, in_dim]
            hbar = self.tanh(self.ws1(x_flat))  # [bsz*len, attn_hid]
            alphas = self.ws2(hbar).view(size[0], size[1])  # [bsz, len]
            alphas = nn.utils.rnn.pack_padded_sequence(alphas, lengths.data.tolist(), batch_first=True)
            alphas = nn.utils.rnn.pad_packed_sequence(alphas, batch_first=True, padding_value=-1e8)[0]
            alphas = functional.softmax(alphas, dim=1)  # [bsz, len]
            alphas = alphas.view(size[0], 1, size[1])  # [bsz, 1, len]
            return torch.bmm(alphas, encoding).squeeze(1), alphas  # [bsz, in_dim], [bsz, len]
        elif self.method == 'last':
            return torch.stack([encoding[i][lengths[i] - 1] for i in range(encoding.size(0))], dim=0), None  # stack (not cat) keeps the [bsz, in_dim] shape
Example #3
Source File: Patient2Vec.py    From Patient2Vec with MIT License
def get_loss(pred, y, criterion, mtr, a=0.5):
    """
    To calculate loss
    :param pred: predicted value
    :param y: actual value
    :param criterion: nn.CrossEntropyLoss
    :param mtr: beta matrix
    """
    mtr_t = torch.transpose(mtr, 1, 2)
    aa = torch.bmm(mtr, mtr_t)
    loss_fn = 0
    for i in range(aa.size()[0]):
        aai = torch.add(aa[i, ], Variable(torch.neg(torch.eye(mtr.size()[1]))))
        loss_fn += torch.trace(torch.mul(aai, aai).data)
    loss_fn /= aa.size()[0]
    loss = torch.add(criterion(pred, y), Variable(torch.FloatTensor([loss_fn * a])))
    return loss 
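The loop above adds an orthogonality penalty on the attention (beta) matrix; note that trace(M ⊙ M) keeps only the squared diagonal of M = AAᵀ − I. For reference, a vectorized sketch of the closely related full Frobenius-norm penalty ‖AAᵀ − I‖²_F, assuming mtr has shape (batch, rows, cols):

import torch

def orthogonality_penalty(mtr):
    # mtr: (batch, rows, cols); mean over the batch of the squared
    # Frobenius norm of (mtr @ mtr^T - I).
    aa = torch.bmm(mtr, mtr.transpose(1, 2))        # (batch, rows, rows)
    eye = torch.eye(aa.size(1), device=mtr.device)  # broadcast over batch
    return ((aa - eye) ** 2).sum(dim=(1, 2)).mean()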
Example #4
Source File: modules.py    From BAMnet with Apache License 2.0
def forward(self, query_embed, in_memory_embed, atten_mask=None):
        if self.atten_type == 'simple': # simple attention
            attention = torch.bmm(in_memory_embed, query_embed.unsqueeze(2)).squeeze(2)
        elif self.atten_type == 'mul': # multiplicative attention
            attention = torch.bmm(in_memory_embed, torch.mm(query_embed, self.W).unsqueeze(2)).squeeze(2)
        elif self.atten_type == 'add': # additive attention
            attention = torch.tanh(torch.mm(in_memory_embed.view(-1, in_memory_embed.size(-1)), self.W2)\
                .view(in_memory_embed.size(0), -1, self.W2.size(-1)) \
                + torch.mm(query_embed, self.W).unsqueeze(1))
            attention = torch.mm(attention.view(-1, attention.size(-1)), self.W3).view(attention.size(0), -1)
        else:
            raise RuntimeError('Unknown atten_type: {}'.format(self.atten_type))

        if atten_mask is not None:
            # Exclude masked elements from the softmax
            attention = atten_mask * attention - (1 - atten_mask) * INF
        return attention 
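The masking step above uses a common idiom: atten_mask * attention - (1 - atten_mask) * INF pushes masked positions to a large negative value so that a subsequent softmax assigns them near-zero weight. A small standalone illustration, with INF set to 1e30 here (the BAMnet source defines its own constant):

import torch

INF = 1e30
scores = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])  # 1 = keep, 0 = exclude

masked = mask * scores - (1 - mask) * INF
print(torch.softmax(masked, dim=-1))  # the excluded position gets ~0 weight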
Example #5
Source File: operators.py    From Fast_Seg with Apache License 2.0
def forward(self, x):
        res = x
        A = self.down(res)
        B = self.gather_down(res)
        b, c, h, w = A.size()
        A = A.view(b, c, -1)  # (b, c, h*w)
        B = B.view(b, c, -1)  # (b, c, h*w)
        B = self.softmax(B)
        B = B.permute(0, 2, 1)  # (b, h*w, c)

        G = torch.bmm(A, B)  # (b,c,c)

        C = self.distribue_down(res)
        C = C.view(b, c, -1)  # (b, c, h*w)
        C = self.softmax(C)
        C = C.permute(0, 2, 1)  # (b, h*w, c)

        atten = torch.bmm(C, G)  # (b, h*w, c)
        atten = atten.permute(0, 2, 1).view(b, c, h, -1)
        atten = self.up(atten)

        out = res + atten
        return out 
Example #6
Source File: attention.py    From nsf with MIT License
def forward(self, inputs, y=None):
        # Apply convs
        theta = self.theta(inputs)
        phi = F.max_pool2d(self.phi(inputs), [2, 2])
        g = F.max_pool2d(self.g(inputs), [2, 2])
        # Perform reshapes
        theta = theta.view(-1, self.channels // self.heads, inputs.shape[2] * inputs.shape[3])
        phi = phi.view(-1, self.channels // self.heads, inputs.shape[2] * inputs.shape[3] // 4)
        g = g.view(-1, self.channels // 2, inputs.shape[2] * inputs.shape[3] // 4)
        # Matmul and softmax to get attention maps
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path
        o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.channels // 2, inputs.shape[2],
                                                           inputs.shape[3]))
        outputs = self.gamma * o + inputs
        return outputs 
Example #7
Source File: common.py    From decaNLP with BSD 3-Clause "New" or "Revised" License
def forward(self, context, question, context_padding, question_padding): 
        context_padding = torch.cat([context.new_zeros((context.size(0), 1), dtype=torch.long)==1, context_padding], 1)
        question_padding = torch.cat([question.new_zeros((question.size(0), 1), dtype=torch.long)==1, question_padding], 1)

        context_sentinel = self.embed_sentinel(context.new_zeros((context.size(0), 1), dtype=torch.long))
        context = torch.cat([context_sentinel, self.dropout(context)], 1) # batch_size x (context_length + 1) x features

        question_sentinel = self.embed_sentinel(question.new_ones((question.size(0), 1), dtype=torch.long))
        question = torch.cat([question_sentinel, question], 1) # batch_size x (question_length + 1) x features
        question = torch.tanh(self.proj(question)) # batch_size x (question_length + 1) x features

        affinity = context.bmm(question.transpose(1,2)) # batch_size x (context_length + 1) x (question_length + 1)
        attn_over_context = self.normalize(affinity, context_padding) # batch_size x (context_length + 1) x 1
        attn_over_question = self.normalize(affinity.transpose(1,2), question_padding) # batch_size x (question_length + 1) x 1
        sum_of_context = self.attn(attn_over_context, context) # batch_size x (question_length + 1) x features
        sum_of_question = self.attn(attn_over_question, question) # batch_size x (context_length + 1) x features
        coattn_context = self.attn(attn_over_question, sum_of_context) # batch_size x (context_length + 1) x features
        coattn_question = self.attn(attn_over_context, sum_of_question) # batch_size x (question_length + 1) x features
        return torch.cat([coattn_context, sum_of_question], 2)[:, 1:], torch.cat([coattn_question, sum_of_context], 2)[:, 1:] 
Example #8
Source File: gat_layers.py    From DeepInf with MIT License
def forward(self, h, adj):
        n = h.size(0) # h is of size n x f_in
        h_prime = torch.matmul(h.unsqueeze(0), self.w) #  n_head x n x f_out
        attn_src = torch.bmm(h_prime, self.a_src) # n_head x n x 1
        attn_dst = torch.bmm(h_prime, self.a_dst) # n_head x n x 1
        attn = attn_src.expand(-1, -1, n) + attn_dst.expand(-1, -1, n).permute(0, 2, 1) # n_head x n x n

        attn = self.leaky_relu(attn)
        attn.data.masked_fill_(~adj.bool(), float("-inf"))  # mask out non-edges
        attn = self.softmax(attn) # n_head x n x n
        attn = self.dropout(attn)
        output = torch.bmm(attn, h_prime) # n_head x n x f_out

        if self.bias is not None:
            return output + self.bias
        else:
            return output 
Example #9
Source File: attention.py    From Character-Level-Language-Modeling-with-Deeper-Self-Attention-pytorch with MIT License
def calc_score(self, att_query, att_keys):
        """
        att_query is: b x t_q x n
        att_keys is b x t_k x n
        return b x t_q x t_k scores
        """

        b, t_k, n = list(att_keys.size())
        t_q = att_query.size(1)
        if self.mode == 'bahdanau':
            att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
            att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
            sum_qk = att_query + att_keys
            sum_qk = sum_qk.view(b * t_k * t_q, n)
            out = self.linear_att(torch.tanh(sum_qk)).view(b, t_q, t_k)
        elif self.mode == 'dot_prod':
            out = torch.bmm(att_query, att_keys.transpose(1, 2))
            if hasattr(self, 'scale'):
                out = out * self.scale
        return out 
Example #10
Source File: pointnet.py    From TreeGAN with MIT License
def forward(self, x):
        batchsize = x.size()[0]
        n_pts = x.size()[2]
        trans = self.stn(x)
        x = x.transpose(2,1)
        x = torch.bmm(x, trans)
        x = x.transpose(2,1)
        x = F.relu(self.bn1(self.conv1(x)))
        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        if self.global_feat:
            return x, trans
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans 
Example #11
Source File: Patient2Vec.py    From Patient2Vec with MIT License
def convolutional_layer(self, inputs):
        convolution_all = []
        conv_wts = []
        for i in range(self.seq_len):
            convolution_one_month = []
            for j in range(self.pad_size):
                convolution = self.conv(torch.unsqueeze(inputs[:, i, j], dim=1))
                convolution_one_month.append(convolution)
            convolution_one_month = torch.stack(convolution_one_month)
            convolution_one_month = torch.squeeze(convolution_one_month, dim=3)
            convolution_one_month = torch.transpose(convolution_one_month, 0, 1)
            convolution_one_month = torch.transpose(convolution_one_month, 1, 2)
            convolution_one_month = torch.squeeze(convolution_one_month, dim=1)
            convolution_one_month = self.func_tanh(convolution_one_month)
            convolution_one_month = torch.unsqueeze(convolution_one_month, dim=1)
            vec = torch.bmm(convolution_one_month, inputs[:, i])
            convolution_all.append(vec)
            conv_wts.append(convolution_one_month)
        convolution_all = torch.stack(convolution_all, dim=1)
        convolution_all = torch.squeeze(convolution_all, dim=2)
        conv_wts = torch.squeeze(torch.stack(conv_wts, dim=1), dim=2)
        return convolution_all, conv_wts 
Example #12
Source File: model_utils.py    From TVQAplus with MIT License
def find_max_triples(p1, p2, topN=5, prob_thd=None):
    """ Find a list of (k1, k2) where k1 >= k2 with the maximum values of p1[k1] * p2[k2]
    Args:
        p1 (torch.CudaTensor): (N, L) batched start_idx probabilities
        p2 (torch.CudaTensor): (N, L) batched end_idx probabilities
        topN (int): return topN pairs with highest values
        prob_thd (float):
    Returns:
        batched_sorted_triple: N * [(st_idx, ed_idx, confidence), ...]
    """
    product = torch.bmm(p1.unsqueeze(2), p2.unsqueeze(1))  # (N, L, L), end_idx >= start_idx
    upper_product = torch.stack([torch.triu(p) for p in product]
                                ).data.cpu().numpy()  # (N, L, L) the lower part becomes zeros
    batched_sorted_triple = []
    for idx, e in enumerate(upper_product):
        sorted_triple = topN_array_2d(e, topN=topN)
        if prob_thd is not None:
            sorted_triple = [t for t in sorted_triple if t[2] >= prob_thd]
        batched_sorted_triple.append(sorted_triple)
    return batched_sorted_triple 
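topN_array_2d is defined elsewhere in TVQAplus; based on how it is used here, it returns the top-N (row, col, value) triples of a 2-D array. A plausible sketch (an assumption, not the repo's exact code):

import numpy as np

def topN_array_2d(arr_2d, topN=5):
    # Top-N (row, col, value) triples, sorted by value descending.
    flat_idx = np.argsort(arr_2d, axis=None)[::-1][:topN]
    rows, cols = np.unravel_index(flat_idx, arr_2d.shape)
    return [(int(r), int(c), float(arr_2d[r, c])) for r, c in zip(rows, cols)]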
Example #13
Source File: tutorial.py    From TaskBot with GNU General Public License v3.0
def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights 
Example #14
Source File: MessageFunction.py    From nmp_qc with MIT License
def m_ggnn(self, h_v, h_w, e_vw, opt={}):

        m = Variable(torch.zeros(h_w.size(0), h_w.size(1), self.args['out']).type_as(h_w.data))

        for w in range(h_w.size(1)):
            if torch.nonzero(e_vw[:, w, :].data).size():
                for i, el in enumerate(self.args['e_label']):
                    ind = (el == e_vw[:,w,:]).type_as(self.learn_args[0][i])

                    parameter_mat = self.learn_args[0][i][None, ...].expand(h_w.size(0), self.learn_args[0][i].size(0),
                                                                            self.learn_args[0][i].size(1))

                    m_w = torch.transpose(torch.bmm(torch.transpose(parameter_mat, 1, 2),
                                                                        torch.transpose(torch.unsqueeze(h_w[:, w, :], 1),
                                                                                        1, 2)), 1, 2)
                    m_w = torch.squeeze(m_w)
                    m[:,w,:] = ind.expand_as(m_w)*m_w
        return m 
Example #15
Source File: tsd_net.py    From ConvLab with MIT License
def forward(self, z_enc_out, u_enc_out, u_input_np, m_t_input, degree_input, last_hidden, z_input_np):
        sparse_z_input = Variable(self.get_sparse_selective_input(z_input_np), requires_grad=False)

        m_embed = self.emb(m_t_input)
        z_context = self.attn_z(last_hidden, z_enc_out)
        u_context = self.attn_u(last_hidden, u_enc_out)
        gru_in = torch.cat([m_embed, u_context, z_context, degree_input.unsqueeze(0)], dim=2)
        gru_out, last_hidden = self.gru(gru_in, last_hidden)
        gen_score = self.proj(torch.cat([z_context, u_context, gru_out], 2)).squeeze(0)
        z_copy_score = torch.tanh(self.proj_copy2(z_enc_out.transpose(0, 1)))
        z_copy_score = torch.matmul(z_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2)
        z_copy_score = z_copy_score.cpu()
        z_copy_score_max = torch.max(z_copy_score, dim=1, keepdim=True)[0]
        z_copy_score = torch.exp(z_copy_score - z_copy_score_max)  # [B,T]
        z_copy_score = torch.log(torch.bmm(z_copy_score.unsqueeze(1), sparse_z_input)).squeeze(
            1) + z_copy_score_max  # [B,V]
        z_copy_score = cuda_(z_copy_score)

        scores = F.softmax(torch.cat([gen_score, z_copy_score], dim=1), dim=1)
        gen_score, z_copy_score = scores[:, :cfg.vocab_size], \
                                  scores[:, cfg.vocab_size:]
        proba = gen_score + z_copy_score[:, :cfg.vocab_size]  # [B,V]
        proba = torch.cat([proba, z_copy_score[:, cfg.vocab_size:]], 1)
        return proba, last_hidden, gru_out 
Example #16
Source File: model.py    From c3dpo_nrsfm with MIT License
def apply_similarity_t(self, S, R, T, s):
        return torch.bmm(R, s[:, None, None] * S) + T[:, :, None] 
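This one-liner applies a batched similarity transform X ↦ s · R · X + T. A shape-annotated sketch with hypothetical sizes, showing how the broadcasting works:

import torch

B, N = 4, 10                # batch size and number of 3-D points (illustrative)
S = torch.randn(B, 3, N)    # point clouds
R = torch.randn(B, 3, 3)    # rotations
T = torch.randn(B, 3)       # translations
s = torch.randn(B)          # per-batch scales

# s[:, None, None] scales each (3, N) cloud; T[:, :, None] adds the
# translation to every one of the N points.
out = torch.bmm(R, s[:, None, None] * S) + T[:, :, None]  # (B, 3, N)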
Example #17
Source File: module.py    From Tacotron-pytorch with Apache License 2.0
def forward(self, decoder_input, memory, attn_hidden, gru1_hidden, gru2_hidden):

        memory_len = memory.size()[1]
        batch_size = memory.size()[0]

        # Get keys
        keys = self.W1(memory.contiguous().view(-1, self.num_units))
        keys = keys.view(-1, memory_len, self.num_units)

        # Get hidden state (query) passed through GRUcell
        d_t = self.attn_grucell(decoder_input, attn_hidden)

        # Duplicate query with same dimension of keys for matrix operation (Speed up)
        d_t_duplicate = self.W2(d_t).unsqueeze(1).expand_as(memory)

        # Calculate attention score and get attention weights
        attn_weights = self.v(torch.tanh(keys + d_t_duplicate).view(-1, self.num_units)).view(-1, memory_len, 1)
        attn_weights = attn_weights.squeeze(2)
        attn_weights = F.softmax(attn_weights, dim=1)

        # Concatenate with original query
        d_t_prime = torch.bmm(attn_weights.view([batch_size,1,-1]), memory).squeeze(1)

        # Residual GRU
        gru1_input = self.attn_projection(torch.cat([d_t, d_t_prime], 1))
        gru1_hidden = self.gru1(gru1_input, gru1_hidden)
        gru2_input = gru1_input + gru1_hidden

        gru2_hidden = self.gru2(gru2_input, gru2_hidden)
        bf_out = gru2_input + gru2_hidden

        # Output
        output = self.out(bf_out).view(-1, hp.num_mels, hp.outputs_per_step)

        return output, d_t, gru1_hidden, gru2_hidden 
Example #18
Source File: model.py    From c3dpo_nrsfm with MIT License
def canonicalization_loss(self, phi_out, class_mask=None):

        shape_canonical = phi_out['shape_canonical']

        dtype = shape_canonical.type()
        ba = shape_canonical.shape[0]

        n_sample = self.canonicalization['n_rand_samples']

        # rotate the canonical point cloud
        # generate random rotation around all axes
        R_rand = rand_rot(ba * n_sample,
                          dtype=dtype,
                          max_rot_angle=self.canonicalization['rot_angle'],
                          axes=(1, 1, 1))

        unrotated = shape_canonical.repeat(n_sample, 1, 1)
        rotated = torch.bmm(R_rand, unrotated)

        psi_out = self.run_psi(rotated)  # psi3( Rrand X )

        a, b = psi_out['shape_canonical'], unrotated
        l_canonicalization = avg_l2_huber(a, b,
                                          scaling=self.huber_scaling,
                                          mask=class_mask.repeat(n_sample, 1)
                                          if class_mask is not None else None)

        # reshape the outputs in the output list
        psi_out = {k: v.view(
            self.canonicalization['n_rand_samples'],
            ba, *v.shape[1:]) for k, v in psi_out.items()}

        return l_canonicalization, psi_out 
Example #19
Source File: so3.py    From c3dpo_nrsfm with MIT License
def so3_exponential_map(log_rot: torch.Tensor, eps: float = 0.0001):
    """
    Convert a batch of logarithmic representations of rotation matrices
    `log_rot` to a batch of 3x3 rotation matrices using Rodrigues formula.
    The conversion has a singularity around 0 which is handled by clamping
    controlled with the `eps` argument.

    Args:
        log_rot: batch of vectors of shape `(minibatch , 3)`
        eps: a float constant handling the conversion singularity around 0

    Returns:
        batch of rotation matrices of shape `(minibatch , 3 , 3)`

    Raises:
        ValueError if `log_rot` is of incorrect shape
    """

    _, dim = log_rot.shape
    if dim != 3:
        raise ValueError('Input tensor shape has to be Nx3.')

    nrms = (log_rot * log_rot).sum(1)
    phis = torch.clamp(nrms, 0.).sqrt()
    phisi = 1. / (phis+eps)
    fac1 = phisi * phis.sin()
    fac2 = phisi * phisi * (1. - phis.cos())
    ss = hat(log_rot)

    R = fac1[:, None, None] * ss + \
        fac2[:, None, None] * torch.bmm(ss, ss) + \
        torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]

    return R 
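As a quick sanity check, the Rodrigues formula should agree with a generic matrix exponential of the skew-symmetric matrix. The repo defines hat elsewhere in the same file; the sketch below assumes the usual skew-symmetric convention and uses torch.matrix_exp (available in recent PyTorch):

import torch

def hat(v):
    # Batched 3-vectors -> skew-symmetric matrices (standard convention).
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    zero = torch.zeros_like(x)
    return torch.stack([
        torch.stack([zero, -z, y], dim=-1),
        torch.stack([z, zero, -x], dim=-1),
        torch.stack([-y, x, zero], dim=-1),
    ], dim=-2)

log_rot = torch.randn(8, 3)
R = so3_exponential_map(log_rot)
assert torch.allclose(R, torch.matrix_exp(hat(log_rot)), atol=1e-3)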
Example #20
Source File: self_attention.py    From BiaffineDependencyParsing with MIT License
def forward(self, q, k, v, mask=None):
        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature

        if mask is not None:
            attn = attn.masked_fill(mask, -np.inf)

        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)

        return output, attn 
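The module above is standard scaled dot-product attention (the temperature is typically sqrt(d_k)). A self-contained functional sketch of the same computation, with illustrative names and a boolean padding mask:

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # q: (b, t_q, d), k: (b, t_k, d), v: (b, t_k, d_v)
    attn = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(q.size(-1))
    if mask is not None:
        attn = attn.masked_fill(mask, float('-inf'))
    attn = F.softmax(attn, dim=-1)
    return torch.bmm(attn, v), attn

q, k, v = torch.randn(2, 4, 8), torch.randn(2, 6, 8), torch.randn(2, 6, 8)
mask = torch.zeros(2, 4, 6, dtype=torch.bool)
mask[:, :, -1] = True  # e.g. ignore the last key position
out, attn = scaled_dot_product_attention(q, k, v, mask)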
Example #21
Source File: train.py    From MomentumContrast.pytorch with MIT License
def train(model_q, model_k, device, train_loader, queue, optimizer, epoch, temp=0.07):
    model_q.train()
    total_loss = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        x_q = data[0]
        x_k = data[1]

        x_q, x_k = x_q.to(device), x_k.to(device)
        q = model_q(x_q)
        k = model_k(x_k)
        k = k.detach()

        N = data[0].shape[0]
        K = queue.shape[0]
        l_pos = torch.bmm(q.view(N,1,-1), k.view(N,-1,1))
        l_neg = torch.mm(q.view(N,-1), queue.T.view(-1,K))

        logits = torch.cat([l_pos.view(N, 1), l_neg], dim=1)

        labels = torch.zeros(N, dtype=torch.long)
        labels = labels.to(device)

        cross_entropy_loss = nn.CrossEntropyLoss()
        loss = cross_entropy_loss(logits/temp, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        momentum_update(model_q, model_k)

        queue = queue_data(queue, k)
        queue = dequeue_data(queue)

    total_loss /= len(train_loader.dataset)

    print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, total_loss)) 
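queue_data, dequeue_data, and momentum_update are defined elsewhere in MomentumContrast.pytorch. A hedged sketch of implementations consistent with how the loop uses them (assumptions, not the repo's exact code):

import torch

def queue_data(queue, k):
    # Append the newest keys to the negative-sample queue.
    return torch.cat([queue, k], dim=0)

def dequeue_data(queue, max_size=4096):
    # Drop the oldest keys once the queue exceeds its capacity.
    return queue[-max_size:] if queue.size(0) > max_size else queue

def momentum_update(model_q, model_k, beta=0.999):
    # Slowly move the key encoder toward the query encoder (EMA).
    with torch.no_grad():
        for pq, pk in zip(model_q.parameters(), model_k.parameters()):
            pk.data.mul_(beta).add_(pq.data, alpha=1 - beta)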
Example #22
Source File: attention.py    From Character-Level-Language-Modeling-with-Deeper-Self-Attention-pytorch with MIT License
def forward(self, q, k, v):
        b_q, t_q, dim_q = list(q.size())
        b_k, t_k, dim_k = list(k.size())
        b_v, t_v, dim_v = list(v.size())
        assert(b_q == b_k and b_k == b_v)  # batch size should be equal
        assert(dim_q == dim_k)  # dims should be equal
        assert(t_k == t_v)  # times should be equal
        b = b_q
        qk = torch.bmm(q, k.transpose(1, 2))  # b x t_q x t_k
        qk.div_(dim_k ** 0.5)
        mask = None
        with torch.no_grad():
            if self.causal and t_q > 1:
                causal_mask = q.data.new(t_q, t_k).byte().fill_(1).triu_(1)
                mask = causal_mask.unsqueeze(0).expand(b, t_q, t_k)
            if self.mask_k is not None:
                mask_k = self.mask_k.unsqueeze(1).expand(b, t_q, t_k)
                mask = mask_k if mask is None else mask | mask_k
            if self.mask_q is not None:
                mask_q = self.mask_q.unsqueeze(2).expand(b, t_q, t_k)
                mask = mask_q if mask is None else mask | mask_q
            if mask is not None:
                qk.masked_fill_(mask, -1e9)

        sm_qk = F.softmax(qk, dim=2)
        sm_qk = self.dropout(sm_qk)
        return torch.bmm(sm_qk, v), sm_qk  # b x t_q x dim_v 
Example #23
Source File: selection_predict.py    From SQL_Database_Optimization with BSD 3-Clause "New" or "Revised" License
def forward(self, x_emb_var, x_len, col_inp_var,
            col_name_len, col_len, col_num):
        B = len(x_emb_var)
        max_x_len = max(x_len)

        e_col, _ = col_name_encode(col_inp_var, col_name_len,
                col_len, self.sel_col_name_enc)

        if self.use_ca:
            h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)
            att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, :, num:] = -100
            att = self.softmax(att_val.view((-1, max_x_len))).view(
                    B, -1, max_x_len)
            K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
        else:
            h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)
            att_val = self.sel_att(h_enc).squeeze()
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, num:] = -100
            att = self.softmax(att_val)
            K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
            K_sel_expand=K_sel.unsqueeze(1)

        sel_score = self.sel_out( self.sel_out_K(K_sel_expand) + \
                self.sel_out_col(e_col) ).squeeze()
        max_col_num = max(col_num)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                sel_score[idx, num:] = -100

        return sel_score 
Example #24
Source File: aggregator_predict.py    From SQL_Database_Optimization with BSD 3-Clause "New" or "Revised" License
def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None,
            col_len=None, col_num=None, gt_sel=None):
        B = len(x_emb_var)
        max_x_len = max(x_len)

        h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)
        if self.use_ca:
            e_col, _ = col_name_encode(col_inp_var, col_name_len, 
                    col_len, self.agg_col_name_enc)
            chosen_sel_idx = torch.LongTensor(gt_sel)
            aux_range = torch.LongTensor(range(len(gt_sel)))
            if x_emb_var.is_cuda:
                chosen_sel_idx = chosen_sel_idx.cuda()
                aux_range = aux_range.cuda()
            chosen_e_col = e_col[aux_range, chosen_sel_idx]
            att_val = torch.bmm(self.agg_att(h_enc), 
                    chosen_e_col.unsqueeze(2)).squeeze()
        else:
            att_val = self.agg_att(h_enc).squeeze()

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        att = self.softmax(att_val)

        K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        agg_score = self.agg_out(K_agg)
        return agg_score 
Example #25
Source File: decoders.py    From ConvLab with MIT License
def forward(self, output, context):
        # output: (batch_size, output_seq_len, dec_cell_size)
        # context: (batch_size, max_ctx_len, ctx_cell_size)
        batch_size = output.size(0)
        max_ctx_len = context.size(1)

        if self.attn_mode == 'dot':
            attn = th.bmm(output, context.transpose(1, 2)) # (batch_size, output_seq_len, max_ctx_len)
        elif self.attn_mode == 'general':
            mapped_output = self.dec_w(output) # (batch_size, output_seq_len, ctx_cell_size)
            attn = th.bmm(mapped_output, context.transpose(1, 2)) # (batch_size, output_seq_len, max_ctx_len)
        elif self.attn_mode == 'cat':
            mapped_output = self.dec_w(output) # (batch_size, output_seq_len, dec_cell_size)
            mapped_attn = self.attn_w(context) # (batch_size, max_ctx_len, dec_cell_size)
            tiled_output = mapped_output.unsqueeze(2).repeat(1, 1, max_ctx_len, 1) # (batch_size, output_seq_len, max_ctx_len, dec_cell_size)
            tiled_attn = mapped_attn.unsqueeze(1) # (batch_size, 1, max_ctx_len, dec_cell_size)
            fc1 = torch.tanh(tiled_output+tiled_attn) # (batch_size, output_seq_len, max_ctx_len, dec_cell_size)
            attn = self.query_w(fc1).squeeze(-1) # (batch_size, output_seq_len, max_ctx_len)
        else:
            raise ValueError('Unknown attention mode')

        # TODO mask
        # if self.mask is not None:

        attn = F.softmax(attn.view(-1, max_ctx_len), dim=1).view(batch_size, -1, max_ctx_len) # (batch_size, output_seq_len, max_ctx_len)
        mix = th.bmm(attn, context) # (batch_size, output_seq_len, ctx_cell_size)
        combined = th.cat((mix, output), dim=2) # (batch_size, output_seq_len, dec_cell_size+ctx_cell_size)
        if self.linear_out is None:
            return combined, attn
        else:
            output = torch.tanh(
                self.linear_out(combined.view(-1, self.dec_cell_size+self.ctx_cell_size))).view(
                batch_size, -1, self.dec_cell_size) # (batch_size, output_seq_len, dec_cell_size)
            return output, attn 
Example #26
Source File: Transformer.py    From ConvLab with MIT License
def forward(self, q, k, v, mask=None):

        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature

        if mask is not None:
            attn = attn.masked_fill(mask, -np.inf)

        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)

        return output, attn 
Example #27
Source File: model.py    From ConvLab with MIT License
def forward(self, input, hidden, encoder_outputs):
        if isinstance(hidden, tuple):
            h_t = hidden[0]
        else:
            h_t = hidden
        encoder_outputs = encoder_outputs.transpose(0, 1)
        embedded = self.embedding(input)  # .view(1, 1, -1)
        # embedded = F.dropout(embedded, self.dropout_p)

        # SCORE 3
        max_len = encoder_outputs.size(1)
        h_t = h_t.transpose(0, 1)  # [1,B,D] -> [B,1,D]
        h_t = h_t.repeat(1, max_len, 1)  # [B,1,D]  -> [B,T,D]
        energy = self.attn(torch.cat((h_t, encoder_outputs), 2))  # [B,T,2D] -> [B,T,D]
        energy = torch.tanh(energy)
        energy = energy.transpose(2, 1)  # [B,H,T]
        v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1)  # [B,1,H]
        energy = torch.bmm(v, energy)  # [B,1,T]
        attn_weights = F.softmax(energy, dim=2)  # [B,1,T]

        # getting context
        context = torch.bmm(attn_weights, encoder_outputs)  # [B,1,H]

        # context = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0)) #[B,1,H]
        # Combine embedded input word and attended context, run through RNN
        rnn_input = torch.cat((embedded, context), 2)
        rnn_input = rnn_input.transpose(0, 1)
        output, hidden = self.rnn(rnn_input, hidden)
        output = output.squeeze(0)  # (1,B,V)->(B,V)

        output = F.log_softmax(self.out(output), dim=1)
        return output, hidden  # , attn_weights 
Example #28
Source File: model.py    From ConvLab with MIT License
def score(self, hidden, encoder_outputs):
        cat = torch.cat([hidden, encoder_outputs], 2)
        energy = torch.tanh(self.attn(cat)) # [B*T*2H]->[B*T*H]
        energy = energy.transpose(2,1) # [B*H*T]
        v = self.v.repeat(encoder_outputs.data.shape[0],1).unsqueeze(1) #[B*1*H]
        energy = torch.bmm(v,energy)  # [B*1*T]
        return energy.squeeze(1)  # [B*T] 
Example #29
Source File: enc_Luong.py    From ConvLab with MIT License
def forward(self, input_seq, last_hidden, encoder_outputs):
        # Note: we run this one step at a time

        # Get the embedding of the current input word (last output word)
        batch_size = input_seq.size(0)
        max_len = encoder_outputs.size(0)
        encoder_outputs = encoder_outputs.transpose(0,1)
        embedded = self.embedding(input_seq)
        embedded = self.dropout(embedded)
        embedded = embedded.view(1, batch_size, self.hidden_size) # S=1 x B x N

        # Get current hidden state from input word and last hidden state
        rnn_output, hidden = self.lstm(embedded, last_hidden)

        s_t = hidden[0][-1].unsqueeze(0)
        H = s_t.repeat(max_len,1,1).transpose(0,1)

        energy = torch.tanh(self.W1(torch.cat([H,encoder_outputs], 2)))
        energy = energy.transpose(2,1)
        v = self.v.repeat(encoder_outputs.data.shape[0],1).unsqueeze(1) #[B*1*H]
        energy = torch.bmm(v,energy) # [B*1*T]
        a = F.softmax(energy, dim=2)  # [B,1,T]
        context = a.bmm(encoder_outputs)

        # Attentional vector using the RNN hidden state and context vector
        # concatenated together (Luong eq. 5)
        rnn_output = rnn_output.squeeze(0) # S=1 x B x N -> B x N
        context = context.squeeze(1)       # B x S=1 x N -> B x N
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))

        # Finally predict next token (Luong eq. 6, without softmax)
        output = self.out(concat_output)
        # Return final output, hidden state, and attention weights (for visualization)
        return output, hidden 
Example #30
Source File: tsd_net.py    From ConvLab with MIT License
def score(self, hidden, encoder_outputs):
        max_len = encoder_outputs.size(1)
        H = hidden.repeat(max_len, 1, 1).transpose(0, 1)
        energy = torch.tanh(self.attn(torch.cat([H, encoder_outputs], 2)))  # [B,T,2H]->[B,T,H]
        energy = energy.transpose(2, 1)  # [B,H,T]
        v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1)  # [B,1,H]
        energy = torch.bmm(v, energy)  # [B,1,T]
        return energy