Python torch.index_select() Examples

The following are 30 code examples of torch.index_select(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
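
Before the project examples, here is a minimal, self-contained sketch of what torch.index_select does (this snippet is illustrative only and not taken from any of the projects below): it returns a new tensor that gathers entries of the input along one dimension, in the order given by a 1-D index tensor of dtype long.

import torch

x = torch.arange(12).view(3, 4)
rows = torch.index_select(x, 0, torch.tensor([2, 0]))  # rows 2 and 0, shape (2, 4)
cols = torch.index_select(x, 1, torch.tensor([1, 3]))  # columns 1 and 3, shape (3, 2)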
Example #1
Source File: model_re.py    From fastNLP with Apache License 2.0
def sort_mention(self, mention_start, mention_end, candidate_mention_emb, candidate_mention_score, seq_lens):
        # sort the records so that higher-scoring segments come first
        mention_score, mention_ids = torch.sort(candidate_mention_score, descending=True)
        preserve_mention_num = int(self.config.mention_ratio * sum(seq_lens))
        mention_ids = mention_ids[0:preserve_mention_num]
        mention_score = mention_score[0:preserve_mention_num]

        mention_start_tensor = torch.from_numpy(mention_start).to(self.device).index_select(
            dim=0, index=mention_ids)  # [lambda * word_num]
        mention_end_tensor = torch.from_numpy(mention_end).to(self.device).index_select(
            dim=0, index=mention_ids)  # [lambda * word_num]
        mention_emb = candidate_mention_emb.index_select(index=mention_ids, dim=0)  # [lambda * word_num, emb]
        assert mention_score.shape[0] == preserve_mention_num
        assert mention_start_tensor.shape[0] == preserve_mention_num
        assert mention_end_tensor.shape[0] == preserve_mention_num
        assert mention_emb.shape[0] == preserve_mention_num
        # TODO: no handling yet for mentions that cross/overlap each other

        # re-sort by start position so that earlier positions come first
        # TODO: only start is considered here, not end
        mention_start_tensor, temp_index = torch.sort(mention_start_tensor)
        mention_end_tensor = mention_end_tensor.index_select(dim=0, index=temp_index)
        mention_emb = mention_emb.index_select(dim=0, index=temp_index)
        mention_score = mention_score.index_select(dim=0, index=temp_index)
        return mention_start_tensor, mention_end_tensor, mention_score, mention_emb 
Example #2
Source File: graphML.py    From graph-neural-networks with GNU General Public License v3.0
def forward(self, x):
        # x is of shape: batchSize x dimInFeatures x numberNodesIn
        B = x.shape[0]
        F = x.shape[1]
        Nin = x.shape[2]
        # And now we add the zero padding
        if Nin < self.N:
            x = torch.cat((x,
                           torch.zeros(B, F, self.N-Nin)\
                                   .type(x.dtype).to(x.device)
                          ), dim = 2)
        # Compute the filter output
        u = LSIGF(self.weight, self.S, x, self.bias)
        # So far, u is of shape batchSize x dimOutFeatures x numberNodes
        # And we want to return a tensor of shape
        # batchSize x dimOutFeatures x numberNodesIn
        # since the nodes between numberNodesIn and numberNodes are not required
        if Nin < self.N:
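            # index_select copies the first Nin nodes; the values equal u[:, :, :Nin]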
            u = torch.index_select(u, 2, torch.arange(Nin).to(u.device))
        return u 
Example #3
Source File: anchor_generator.py    From mmdetection with Apache License 2.0
def gen_base_anchors(self):
        """Generate base anchors.

        Returns:
            list(torch.Tensor): Base anchors of a feature grid in multiple
                feature levels.
        """
        multi_level_base_anchors = []
        for i, base_size in enumerate(self.base_sizes):
            base_anchors = self.gen_single_level_base_anchors(
                base_size,
                scales=self.scales[i],
                ratios=self.ratios[i],
                center=self.centers[i])
            indices = list(range(len(self.ratios[i])))
            indices.insert(1, len(indices))
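            # indices now moves the last base anchor to position 1,
            # e.g. [0, 1, 2] -> [0, 3, 1, 2] when there are three ratios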
            base_anchors = torch.index_select(base_anchors, 0,
                                              torch.LongTensor(indices))
            multi_level_base_anchors.append(base_anchors)
        return multi_level_base_anchors 
Example #4
Source File: graphML.py    From graph-neural-networks with GNU General Public License v3.0
def forward(self, x):
        # x is of shape: batchSize x dimInFeatures x numberNodesIn
        B = x.shape[0]
        F = x.shape[1]
        Nin = x.shape[2]
        # If we have less filter coefficients than the required ones, we need
        # to use the copying scheme
        if self.M == self.N:
            self.h = self.weight
        else:
            self.h = torch.index_select(self.weight, 4, self.copyNodes)
        # And now we add the zero padding
        if Nin < self.N:
            zeroPad = torch.zeros(B, F, self.N-Nin).type(x.dtype).to(x.device)
            x = torch.cat((x, zeroPad), dim = 2)
        # Compute the filter output
        u = NVGF(self.h, self.S, x, self.bias)
        # So far, u is of shape batchSize x dimOutFeatures x numberNodes
        # And we want to return a tensor of shape
        # batchSize x dimOutFeatures x numberNodesIn
        # since the nodes between numberNodesIn and numberNodes are not required
        if Nin < self.N:
            u = torch.index_select(u, 2, torch.arange(Nin).to(u.device))
        return u 
Example #5
Source File: DDPAE_utils.py    From DDPAE-video-prediction with MIT License
def pose_inv_full(pose):
  '''
  param pose: N x 6
  Inverse the 2x3 transformer matrix.
  '''
  N, _ = pose.size()
  b = pose.view(N, 2, 3)[:, :, 2:]
  # A^{-1}
  # Calculate determinant
  determinant = (pose[:, 0] * pose[:, 4] - pose[:, 1] * pose[:, 3] + 1e-8).view(N, 1)
  indices = Variable(torch.LongTensor([4, 1, 3, 0]).cuda())
  scale = Variable(torch.Tensor([1, -1, -1, 1]).cuda())
  A_inv = torch.index_select(pose, 1, indices) * scale / determinant
  A_inv = A_inv.view(N, 2, 2)
  # b' = - A^{-1} b
  b_inv = - A_inv.matmul(b).view(N, 2, 1)
  transformer_inv = torch.cat([A_inv, b_inv], dim=2)
  return transformer_inv 
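
The index_select call above encodes the closed-form inverse of a 2x2 matrix: for A = [[a, b], [c, d]], the inverse is [[d, -b], [-c, a]] / det(A), so selecting pose columns [4, 1, 3, 0] and flipping the sign of the middle two entries builds the adjugate. A quick sanity check, written for this page rather than taken from the DDPAE source (the random, well-conditioned pose tensor is an assumption):

import torch

pose = torch.randn(5, 6) * 0.1
pose[:, 0] += 1.0  # keep the 2x2 blocks well-conditioned so the check is stable
pose[:, 4] += 1.0
A = pose.view(5, 2, 3)[:, :, :2]
det = (pose[:, 0] * pose[:, 4] - pose[:, 1] * pose[:, 3]).view(5, 1)
A_inv = torch.index_select(pose, 1, torch.tensor([4, 1, 3, 0])) \
        * torch.tensor([1.0, -1.0, -1.0, 1.0]) / det
assert torch.allclose(A_inv.view(5, 2, 2), torch.inverse(A), atol=1e-5)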
Example #6
Source File: check_chamfer.py    From DIB-R with MIT License
def sample(verts, faces, num=10000, ret_choice = False):
    dist_uni = torch.distributions.Uniform(torch.tensor([0.0]).cuda(), torch.tensor([1.0]).cuda())
    x1,x2,x3 = torch.split(torch.index_select(verts, 0, faces[:,0]) - torch.index_select(verts, 0, faces[:,1]), 1, dim = 1)
    y1,y2,y3 = torch.split(torch.index_select(verts, 0, faces[:,1]) - torch.index_select(verts, 0, faces[:,2]), 1, dim = 1)
    a = (x2*y3 - x3*y2)**2
    b = (x3*y1 - x1*y3)**2
    c = (x1*y2 - x2*y1)**2
    Areas = torch.sqrt(a+b+c)/2
    Areas = Areas / torch.sum(Areas)
    cat_dist = torch.distributions.Categorical(Areas.view(-1))
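    # note: Distribution.sample_n is deprecated in recent PyTorch; sample((num,)) is the equivalent call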
    choices = cat_dist.sample_n(num)
    select_faces = faces[choices]
    xs = torch.index_select(verts, 0,select_faces[:,0])
    ys = torch.index_select(verts, 0,select_faces[:,1])
    zs = torch.index_select(verts, 0,select_faces[:,2])
    u = torch.sqrt(dist_uni.sample_n(num))
    v = dist_uni.sample_n(num)
    points = (1- u)*xs + (u*(1-v))*ys + u*v*zs
    if ret_choice:
        return points, choices
    else:
        return points 
Example #7
Source File: main.py    From Extremely-Fine-Grained-Entity-Typing with MIT License
def visualize(args):
  saved_path = constant.EXP_ROOT
  model = models.Model(args, constant.ANSWER_NUM_DICT[args.goal])
  model.cuda()
  model.eval()
  model.load_state_dict(torch.load(saved_path + '/' + args.model_id + '_best.pt')["state_dict"])

  label2id = constant.ANS2ID_DICT["open"] 
  visualize = SummaryWriter("../visualize/" + args.model_id)
  # label_list = ["person", "leader", "president", "politician", "organization", "company", "athlete","adult",  "male",  "man", "television_program", "event"]
  label_list = list(label2id.keys())
  ids = [label2id[_] for _ in label_list]
  if args.gcn:
    # connection_matrix = model.decoder.label_matrix + model.decoder.weight * model.decoder.affinity
    connection_matrix = model.decoder.label_matrix + model.decoder.weight * model.decoder.affinity
    label_vectors = model.decoder.transform(connection_matrix.mm(model.decoder.linear.weight) / connection_matrix.sum(1, keepdim=True))
  else:
    label_vectors = model.decoder.linear.weight.data

  interested_vectors = torch.index_select(label_vectors, 0, torch.tensor(ids).to(torch.device("cuda")))
  visualize.add_embedding(interested_vectors, metadata=label_list, label_img=None) 
Example #8
Source File: pytorch_util.py    From rlgraph with Apache License 2.0
def pytorch_tile(tensor, n_tile, dim=0):
    """
    Tile utility as there is not `torch.tile`.
    Args:
        tensor (torch.Tensor): Tensor to tile.
        n_tile (int): Num tiles.
        dim (int): Dim to tile.

    Returns:
        torch.Tensor: Tiled tensor.
    """
    if isinstance(n_tile, torch.Size):
        n_tile = n_tile[0]
    init_dim = tensor.size(dim)
    repeat_idx = [1] * tensor.dim()
    repeat_idx[dim] = n_tile
    tensor = tensor.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(tensor, dim, order_index)


# TODO remove when we have handled pytorch placeholder inference better. 
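
For reference, a hypothetical call (not from the rlgraph tests) showing that this tiles element-wise, like numpy.repeat or torch.repeat_interleave, rather than appending whole copies; it assumes the definition above together with import numpy as np:

import torch
out = pytorch_tile(torch.tensor([1, 2, 3]), 2)  # tensor([1, 1, 2, 2, 3, 3])
# equivalent to torch.tensor([1, 2, 3]).repeat_interleave(2); plain .repeat(2)
# would instead give tensor([1, 2, 3, 1, 2, 3])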
Example #9
Source File: transformation.py    From weakalign with MIT License
def symmetricImagePad(self, image_batch, padding_factor):
        b, c, h, w = image_batch.size()
        pad_h, pad_w = int(h*padding_factor), int(w*padding_factor)
        idx_pad_left = torch.LongTensor(range(pad_w-1,-1,-1))
        idx_pad_right = torch.LongTensor(range(w-1,w-pad_w-1,-1))
        idx_pad_top = torch.LongTensor(range(pad_h-1,-1,-1))
        idx_pad_bottom = torch.LongTensor(range(h-1,h-pad_h-1,-1))
        if self.use_cuda:
                idx_pad_left = idx_pad_left.cuda()
                idx_pad_right = idx_pad_right.cuda()
                idx_pad_top = idx_pad_top.cuda()
                idx_pad_bottom = idx_pad_bottom.cuda()
        image_batch = torch.cat((image_batch.index_select(3,idx_pad_left),image_batch,
                                 image_batch.index_select(3,idx_pad_right)),3)
        image_batch = torch.cat((image_batch.index_select(2,idx_pad_top),image_batch,
                                 image_batch.index_select(2,idx_pad_bottom)),2)
        return image_batch 
Example #10
Source File: projection.py    From Pointnet2.ScanNet with MIT License
def project(self, label, lin_indices_3d, lin_indices_2d, num_points):
        """
        forward pass of backprojection for 2d features onto 3d points

        :param label: image features (shape: (num_input_channels, proj_image_dims[0], proj_image_dims[1]))
        :param lin_indices_3d: point indices from projection (shape: (num_input_channels, num_points_sample))
        :param lin_indices_2d: pixel indices from projection (shape: (num_input_channels, num_points_sample))
        :param num_points: number of points in one sample
        :return: array of points in sample with projected features (shape: (num_input_channels, num_points))
        """
        
        num_label_ft = 1 if len(label.shape) == 2 else label.shape[0] # = num_input_channels

        output = label.new(num_label_ft, num_points).fill_(0)
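        # lin_indices_3d / lin_indices_2d use a packed layout: element 0 stores the
        # number of valid indices, and the indices themselves sit at positions 1..count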
        num_ind = lin_indices_3d[0]
        if num_ind > 0:
            # selects values from image_features at indices given by lin_indices_2d
            vals = torch.index_select(label.view(num_label_ft, -1), 1, lin_indices_2d[1:1+num_ind])
            output.view(num_label_ft, -1)[:, lin_indices_3d[1:1+num_ind]] = vals
        
        return output


# Inherit from Function 
Example #11
Source File: projection.py    From Pointnet2.ScanNet with MIT License
def forward(ctx, label, lin_indices_3d, lin_indices_2d, num_points):
        """
        forward pass of backprojection for 2d features onto 3d points

        :param label: image features (shape: (num_input_channels, proj_image_dims[0], proj_image_dims[1]))
        :param lin_indices_3d: point indices from projection (shape: (num_input_channels, num_points_sample))
        :param lin_indices_2d: pixel indices from projection (shape: (num_input_channels, num_points_sample))
        :param num_points: number of points in one sample
        :return: array of points in sample with projected features (shape: (num_input_channels, num_points))
        """
        # ctx.save_for_backward(lin_indices_3d, lin_indices_2d)
        num_label_ft = 1 if len(label.shape) == 2 else label.shape[0] # = num_input_channels

        output = label.new(num_label_ft, num_points).fill_(0)
        num_ind = lin_indices_3d[0]
        if num_ind > 0:
            # selects values from image_features at indices given by lin_indices_2d
            vals = torch.index_select(label.view(num_label_ft, -1), 1, lin_indices_2d[1:1+num_ind])
            output.view(num_label_ft, -1)[:, lin_indices_3d[1:1+num_ind]] = vals
        return output 
Example #12
Source File: eval_base.py    From person-reid-lib with MIT License
def _feature_distance(self, feaMat):
        probe_feature = torch.index_select(feaMat, dim=0, index=torch.from_numpy(self.probe_index).long().cuda())
        gallery_feature = torch.index_select(feaMat, dim=0, index=torch.from_numpy(self.gallery_index).long().cuda())

        idx = 0
        while idx + self.probe_dst_max < self.probe_num:
            tmp_probe_fea = probe_feature[idx:idx+self.probe_dst_max]
            dst_pg = self._feature_distance_mini(tmp_probe_fea, gallery_feature)
            self.distMat[idx:idx+self.probe_dst_max] += dst_pg
            idx += self.probe_dst_max
        tmp_probe_fea = probe_feature[idx:self.probe_num]
        dst_pg = self._feature_distance_mini(tmp_probe_fea, gallery_feature)
        self.distMat[idx:self.probe_num] += dst_pg

        for i_p, p in enumerate(self.probe_index):
            for i_g, g in enumerate(self.gallery_index):
                if self.test_info[p, 0] != self.test_info[g, 0]:
                    self.avgDiff = self.avgDiff + self.distMat[i_p, i_g]
                    self.avgDiffCount = self.avgDiffCount + 1
                elif p != g:
                    self.avgSame = self.avgSame + self.distMat[i_p, i_g]
                    self.avgSameCount = self.avgSameCount + 1 
Example #13
Source File: Sets2Sets.py    From Sets2Sets with Apache License 2.0
def forward(self, pred, target, weights):
        mseloss = torch.sum(weights * torch.pow((pred - target), 2))
        pred = pred.data
        target = target.data
        #
        ones_idx_set = (target == 1).nonzero()
        zeros_idx_set = (target == 0).nonzero()
        # zeros_idx_set = (target == -1).nonzero()
        
        ones_set = torch.index_select(pred, 1, ones_idx_set[:, 1])
        zeros_set = torch.index_select(pred, 1, zeros_idx_set[:, 1])
        
        repeat_ones = ones_set.repeat(1, zeros_set.shape[1])
        repeat_zeros_set = torch.transpose(zeros_set.repeat(ones_set.shape[1], 1), 0, 1).clone()
        repeat_zeros = repeat_zeros_set.view(1, -1)
        difference_val = -(repeat_ones - repeat_zeros)
        exp_val = torch.exp(difference_val)
        exp_loss = torch.sum(exp_val)
        normalized_loss = exp_loss / (zeros_set.shape[1] * ones_set.shape[1])
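        # NOTE: labmda (sic) is not defined in this snippet; in the original
        # Sets2Sets source it is presumably a module-level weight for the set loss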
        set_loss = Variable(torch.FloatTensor([labmda * normalized_loss]), requires_grad=True)
        if use_cuda:
            set_loss = set_loss.cuda()
        loss = mseloss + set_loss
        #loss = mseloss
        return loss 
Example #14
Source File: model_re.py    From fastNLP with Apache License 2.0
def get_mention_emb(self, flatten_lstm, mention_start, mention_end):
        mention_start_tensor = torch.from_numpy(mention_start).to(self.device)
        mention_end_tensor = torch.from_numpy(mention_end).to(self.device)
        emb_start = flatten_lstm.index_select(dim=0, index=mention_start_tensor)  # [mention_num,embed]
        emb_end = flatten_lstm.index_select(dim=0, index=mention_end_tensor)  # [mention_num,embed]
        return emb_start, emb_end 
Example #15
Source File: model_re.py    From fastNLP with Apache License 2.0
def flat_lstm(self, lstm_out, seq_lens):
        batch = lstm_out.shape[0]
        seq = lstm_out.shape[1]
        dim = lstm_out.shape[2]
        l = [j + i * seq for i, seq_len in enumerate(seq_lens) for j in range(seq_len)]
        flatted = torch.index_select(lstm_out.view(batch * seq, dim), 0, torch.LongTensor(l).to(self.device))
        return flatted 
Example #16
Source File: model_re.py    From fastNLP with Apache License 2.0
def reorder_sequence(self, sequence_emb, order, batch_first=True):
        """
        sequence_emb: [T, B, D] if not batch_first
        order: list of batch indices giving the desired order
        """
        batch_dim = 0 if batch_first else 1
        assert len(order) == sequence_emb.size()[batch_dim]

        order = torch.LongTensor(order)
        order = order.to(sequence_emb).long()

        sorted_ = sequence_emb.index_select(index=order, dim=batch_dim)

        del order
        return sorted_ 
Example #17
Source File: DDPAE_utils.py    From DDPAE-video-prediction with MIT License
def expand_pose(pose):
  '''
  param pose: N x 3
  Takes 3-dimensional vectors, and massages them into 2x3 affine transformation matrices:
  [s,x,y] -> [[s,0,x],
              [0,s,y]]
  '''
  n = pose.size(0)
  expansion_indices = Variable(torch.LongTensor([1, 0, 2, 0, 1, 3]).cuda(), requires_grad=False)
  zeros = Variable(torch.zeros(n, 1).cuda(), requires_grad=False)
  out = torch.cat([zeros, pose], dim=1)
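  # each row of out is [0, s, x, y]; selecting columns [1, 0, 2, 0, 1, 3] gives
  # [s, 0, x, 0, s, y], which .view(n, 2, 3) reshapes to [[s, 0, x], [0, s, y]]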
  return torch.index_select(out, 1, expansion_indices).view(n, 2, 3) 
Example #18
Source File: jobhopping.py    From prediction_api with MIT License
def _id2PackedSequence(self, affi_id):
        # The input shape can be (T x B x *): T is the longest sequence length, B is the batch size, and * stands for any number of trailing dimensions (possibly none). With batch_first=True, the corresponding input size is (B x T x *).
        ret = torch.zeros(1, len(affi_id), self._INPUT_DIM)
        indices = torch.tensor(affi_id, device='cpu', dtype=torch.long)
        ret[0] = torch.index_select(self._affi, 0, indices)
        return torch.nn.utils.rnn.pack_padded_sequence(ret, [len(affi_id)], batch_first=True) 
Example #19
Source File: losses.py    From centerpose with MIT License
def compute_rot_loss(output, target_bin, target_res, mask):
    # output: (B, 128, 8) [bin1_cls[0], bin1_cls[1], bin1_sin, bin1_cos, 
    #                 bin2_cls[0], bin2_cls[1], bin2_sin, bin2_cos]
    # target_bin: (B, 128, 2) [bin1_cls, bin2_cls]
    # target_res: (B, 128, 2) [bin1_res, bin2_res]
    # mask: (B, 128, 1)
    # import pdb; pdb.set_trace()
    output = output.view(-1, 8)
    target_bin = target_bin.view(-1, 2)
    target_res = target_res.view(-1, 2)
    mask = mask.view(-1, 1)
    loss_bin1 = compute_bin_loss(output[:, 0:2], target_bin[:, 0], mask)
    loss_bin2 = compute_bin_loss(output[:, 4:6], target_bin[:, 1], mask)
    loss_res = torch.zeros_like(loss_bin1)
    if target_bin[:, 0].nonzero().shape[0] > 0:
        idx1 = target_bin[:, 0].nonzero()[:, 0]
        valid_output1 = torch.index_select(output, 0, idx1.long())
        valid_target_res1 = torch.index_select(target_res, 0, idx1.long())
        loss_sin1 = compute_res_loss(
          valid_output1[:, 2], torch.sin(valid_target_res1[:, 0]))
        loss_cos1 = compute_res_loss(
          valid_output1[:, 3], torch.cos(valid_target_res1[:, 0]))
        loss_res += loss_sin1 + loss_cos1
    if target_bin[:, 1].nonzero().shape[0] > 0:
        idx2 = target_bin[:, 1].nonzero()[:, 0]
        valid_output2 = torch.index_select(output, 0, idx2.long())
        valid_target_res2 = torch.index_select(target_res, 0, idx2.long())
        loss_sin2 = compute_res_loss(
          valid_output2[:, 6], torch.sin(valid_target_res2[:, 1]))
        loss_cos2 = compute_res_loss(
          valid_output2[:, 7], torch.cos(valid_target_res2[:, 1]))
        loss_res += loss_sin2 + loss_cos2
    return loss_bin1 + loss_bin2 + loss_res 
Example #20
Source File: network_utils.py    From 3d-vehicle-tracking with BSD 3-Clause "New" or "Revised" License
def compute_rot_loss(output, target_bin, target_res):
    # output: (B, 8) [bin1_cls[0], bin1_cls[1], bin1_sin, bin1_cos, 
    #                 bin2_cls[0], bin2_cls[1], bin2_sin, bin2_cos]
    # target_bin: (B, 2) [bin1_cls, bin2_cls]
    # target_res: (B, 2) [bin1_res, bin2_res]

    loss_bin1 = F.cross_entropy(output[:, 0:2], target_bin[:, 0])
    loss_bin2 = F.cross_entropy(output[:, 4:6], target_bin[:, 1])
    loss_res = torch.zeros_like(loss_bin1)
    if target_bin[:, 0].nonzero().shape[0] > 0:
        idx1 = target_bin[:, 0].nonzero()[:, 0]
        valid_output1 = torch.index_select(output, 0, idx1.long())
        valid_target_res1 = torch.index_select(target_res, 0, idx1.long())
        loss_sin1 = F.smooth_l1_loss(valid_output1[:, 2],
                                     torch.sin(valid_target_res1[:, 0]))
        loss_cos1 = F.smooth_l1_loss(valid_output1[:, 3],
                                     torch.cos(valid_target_res1[:, 0]))
        loss_res += loss_sin1 + loss_cos1
    if target_bin[:, 1].nonzero().shape[0] > 0:
        idx2 = target_bin[:, 1].nonzero()[:, 0]
        valid_output2 = torch.index_select(output, 0, idx2.long())
        valid_target_res2 = torch.index_select(target_res, 0, idx2.long())
        loss_sin2 = F.smooth_l1_loss(valid_output2[:, 6],
                                     torch.sin(valid_target_res2[:, 1]))
        loss_cos2 = F.smooth_l1_loss(valid_output2[:, 7],
                                     torch.cos(valid_target_res2[:, 1]))
        loss_res += loss_sin2 + loss_cos2
    return loss_bin1 + loss_bin2 + loss_res 
Example #21
Source File: losses.py    From centerpose with MIT License
def tile(a, dim, n_tile):
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(a, dim, order_index) 
Example #22
Source File: pointwise.py    From pykg2vec with MIT License
def _concat_selected_embeddings(e1, t1, e2, t2):
        return torch.cat([torch.index_select(e1.weight, 0, t1), torch.index_select(e2.weight, 0, t2)], 1) 
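
A hypothetical usage sketch (the embedding sizes and indices below are made up for illustration): given two nn.Embedding tables, this selects one batch of rows from each and concatenates them feature-wise.

import torch
import torch.nn as nn

e1, e2 = nn.Embedding(10, 4), nn.Embedding(10, 4)
t1 = torch.tensor([0, 3])  # rows to take from e1
t2 = torch.tensor([1, 2])  # rows to take from e2
out = torch.cat([torch.index_select(e1.weight, 0, t1),
                 torch.index_select(e2.weight, 0, t2)], 1)  # shape (2, 8)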
Example #23
Source File: projection.py    From Pointnet2.ScanNet with MIT License
def backward(ctx, grad_output):
        grad_label = grad_output.clone()
        num_ft = grad_output.shape[0]
        grad_label.resize_(num_ft, 32, 41)
        lin_indices_3d, lin_indices_2d = ctx.saved_variables
        num_ind = lin_indices_3d.data[0]
        vals = torch.index_select(grad_output.data.contiguous().view(num_ft, -1), 1, lin_indices_3d.data[1:1+num_ind])
        grad_label.data.view(num_ft, -1)[:, lin_indices_2d.data[1:1+num_ind]] = vals
        
        return grad_label, None, None, None 
Example #24
Source File: models.py    From KernelGAT with MIT License
def self_attention(self, inputs, inputs_hiddens, mask, mask_evidence, index):
        idx = torch.LongTensor([index]).cuda()
        mask = mask.view([-1, self.evi_num, self.max_len])
        mask_evidence = mask_evidence.view([-1, self.evi_num, self.max_len])
        own_hidden = torch.index_select(inputs_hiddens, 1, idx)
        own_mask = torch.index_select(mask, 1, idx)
        own_input = torch.index_select(inputs, 1, idx)
        own_hidden = own_hidden.repeat(1, self.evi_num, 1, 1)
        own_mask = own_mask.repeat(1, self.evi_num, 1)
        own_input = own_input.repeat(1, self.evi_num, 1)

        hiddens_norm = F.normalize(inputs_hiddens, p=2, dim=-1)
        own_norm = F.normalize(own_hidden, p=2, dim=-1)

        att_score = self.get_intersect_matrix_att(hiddens_norm.view(-1, self.max_len, self.bert_hidden_dim), own_norm.view(-1, self.max_len, self.bert_hidden_dim),
                                                  mask_evidence.view(-1, self.max_len), own_mask.view(-1, self.max_len))
        att_score = att_score.view(-1, self.evi_num, self.max_len, 1)
        #if index == 1:
        #    for i in range(self.evi_num):
        #print (att_score.view(-1, self.evi_num, self.max_len)[0, 1, :])
        denoise_inputs = torch.sum(att_score * inputs_hiddens, 2)
        weight_inp = torch.cat([own_input, inputs], -1)
        weight_inp = self.proj_gat(weight_inp)
        weight_inp = F.softmax(weight_inp, dim=1)
        outputs = (inputs * weight_inp).sum(dim=1)
        weight_de = torch.cat([own_input, denoise_inputs], -1)
        weight_de = self.proj_gat(weight_de)
        weight_de = F.softmax(weight_de, dim=1)
        outputs_de = (denoise_inputs * weight_de).sum(dim=1)
        return outputs, outputs_de 
Example #25
Source File: transformation.py    From weakalign with MIT License
def __call__(self, batch):
        image_batch, theta_batch = batch['image'], batch['theta'] 
#        theta_aff=torch.index_select(theta_batch[:,:6],1,self.aff_reorder_idx)
        theta_aff=theta_batch[:,:6].contiguous()
        theta_tps=theta_batch[:,6:]

        if self.use_cuda:
            image_batch = image_batch.cuda()
            theta_aff = theta_aff.cuda()
            theta_tps = theta_tps.cuda()
            
        b, c, h, w = image_batch.size()
              
        # generate symmetrically padded image for bigger sampling region
        image_batch = self.symmetricImagePad(image_batch,self.padding_factor)
        
        # convert to variables
        image_batch = Variable(image_batch,requires_grad=False)
        theta_aff =  Variable(theta_aff,requires_grad=False)        
        theta_tps =  Variable(theta_tps,requires_grad=False) 

        # get cropped image
        cropped_image_batch = self.rescalingTnf(image_batch=image_batch,
                                                theta_batch=None,
                                                padding_factor=self.padding_factor,
                                                crop_factor=self.crop_factor) # Identity is used as no theta given
        # get transformed image
        warped_image_aff = self.affTnf(image_batch=image_batch,
                                         theta_batch=theta_aff,
                                         padding_factor=self.padding_factor,
                                         crop_factor=self.crop_factor) 
        
        warped_image_tps = self.tpsTnf(image_batch=image_batch,
                                       theta_batch=theta_tps,
                                       padding_factor=self.padding_factor,
                                       crop_factor=self.crop_factor) 
            
        return {'source_image': cropped_image_batch, 'target_image_aff': warped_image_aff, 'target_image_tps': warped_image_tps, 'theta_GT_aff': theta_aff,  'theta_GT_tps': theta_tps} 
Example #26
Source File: transformation.py    From weakalign with MIT License
def __call__(self, batch):
        image_batch, theta_batch = batch['image'], batch['theta'] 
#        theta_aff=torch.index_select(theta_batch[:,:6],1,self.aff_reorder_idx)
        theta_aff=theta_batch[:,:6].contiguous()
        theta_tps=theta_batch[:,6:]

        if self.use_cuda:
            image_batch = image_batch.cuda()
            theta_aff = theta_aff.cuda()
            theta_tps = theta_tps.cuda()
            
        b, c, h, w = image_batch.size()
              
        # generate symmetrically padded image for bigger sampling region
        image_batch = self.symmetricImagePad(image_batch,self.padding_factor)
        
        # convert to variables
        image_batch = Variable(image_batch,requires_grad=False)
        theta_aff =  Variable(theta_aff,requires_grad=False)        
        theta_tps =  Variable(theta_tps,requires_grad=False)        

        # get cropped image
        cropped_image_batch = self.rescalingTnf(image_batch=image_batch,
                                                theta_batch=None,
                                                padding_factor=self.padding_factor,
                                                crop_factor=self.crop_factor) # Identity is used as no theta given
        # get transformed image
        warped_image_batch = self.geometricTnf(image_batch=image_batch,
                                               theta_aff=theta_aff,
                                               theta_aff_tps=theta_tps)
        
        return {'source_image': cropped_image_batch, 'target_image': warped_image_batch,  'theta_GT_aff': theta_aff,  'theta_GT_tps': theta_tps} 
Example #27
Source File: graphML.py    From graph-neural-networks with GNU General Public License v3.0
def forward(self, x):
        # x is of shape: batchSize x dimInFeatures x numberNodesIn
        B = x.shape[0]
        F = x.shape[1]
        Nin = x.shape[2]

        # Check if we have enough spectral filter coefficients as needed, or if
        # we need to fill out the rest using the spline kernel.
        if self.M == self.N:
            self.h = self.weight # F x E x G x N (because N = M)
        else:
            # Adjust dimensions for proper algebraic matrix multiplication
            splineKernel = self.splineKernel.reshape([1,self.E,self.N,self.M])
            # We will multiply a 1 x E x N x M matrix with a F x E x M x G
            # matrix to get the proper F x E x N x G coefficients
            self.h = torch.matmul(splineKernel, self.weight.permute(0,1,3,2))
            # And now we rearrange it to the same shape that the function takes
            self.h = self.h.permute(0,1,3,2) # F x E x G x N
        # And now we add the zero padding (if this comes from a pooling
        # operation)
        if Nin < self.N:
            zeroPad = torch.zeros(B, F, self.N-Nin).type(x.dtype).to(x.device)
            x = torch.cat((x, zeroPad), dim = 2)
        # Compute the filter output
        u = spectralGF(self.h, self.V, self.VH, x, self.bias)
        # So far, u is of shape batchSize x dimOutFeatures x numberNodes
        # And we want to return a tensor of shape
        # batchSize x dimOutFeatures x numberNodesIn
        # since the nodes between numberNodesIn and numberNodes are not required
        if Nin < self.N:
            u = torch.index_select(u, 2, torch.arange(Nin).to(u.device))
        return u 
Example #28
Source File: transformation.py    From weakalign with MIT License
def __call__(self, batch):
        image_batch, theta_batch = batch['image'], batch['theta'] 
        if self.use_cuda:
            image_batch = image_batch.cuda()
            theta_batch = theta_batch.cuda()
            
        b, c, h, w = image_batch.size()
              
        # generate symmetrically padded image for bigger sampling region
        image_batch = self.symmetricImagePad(image_batch,self.padding_factor)
        
        # convert to variables
        image_batch = Variable(image_batch,requires_grad=False)
        theta_batch =  Variable(theta_batch,requires_grad=False)        

        # get cropped image
        cropped_image_batch = self.rescalingTnf(image_batch=image_batch,
                                                theta_batch=None,
                                                padding_factor=self.padding_factor,
                                                crop_factor=self.crop_factor) # Identity is used as no theta given
        # get transformed image
        warped_image_batch = self.geometricTnf(image_batch=image_batch,
                                               theta_batch=theta_batch,
                                               padding_factor=self.padding_factor,
                                               crop_factor=self.crop_factor) # Identity is used as no theta given
        
        if self.supervision=='strong':
            return {'source_image': cropped_image_batch, 'target_image': warped_image_batch, 'theta_GT': theta_batch}
        
        elif self.supervision=='weak':
            pos_batch_idx = torch.LongTensor(range(int(b/2)))
            neg_batch_idx = torch.LongTensor(range(int(b/2),b))
            if self.use_cuda:
                pos_batch_idx = pos_batch_idx.cuda()
                neg_batch_idx = neg_batch_idx.cuda()
            source_image = torch.cat((torch.index_select(cropped_image_batch,0,pos_batch_idx),
                                      torch.index_select(cropped_image_batch,0,pos_batch_idx)),0)
            target_image = torch.cat((torch.index_select(warped_image_batch,0,pos_batch_idx),
                                      torch.index_select(cropped_image_batch,0,neg_batch_idx)),0)
            return {'source_image': source_image, 'target_image': target_image, 'theta_GT': theta_batch} 
Example #29
Source File: tensor.py    From dgl with Apache License 2.0
def take(data, indices, dim):
    new_shape = data.shape[:dim] + indices.shape + data.shape[dim+1:]
    return th.index_select(data, dim, indices.view(-1)).view(new_shape) 
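
As a usage sketch (written for this page, not from the dgl test suite), take generalizes index_select to multi-dimensional index tensors, mirroring numpy.take along an axis:

import torch as th

data = th.arange(6).view(2, 3)
indices = th.tensor([[0, 2], [1, 1]])
out = take(data, indices, 1)  # shape (2, 2, 2), like numpy.take(data, indices, axis=1)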
Example #30
Source File: utils.py    From OpenKiwi with GNU Affero General Public License v3.0
def align_source(
    source,
    trg2src_alignments,
    max_aligned,
    unaligned_idx,
    padding_idx,
    pad_size,
):
    assert len(source.shape) == 2
    window_size = source.shape[1]

    assert len(trg2src_alignments) <= pad_size
    aligned_source = source.new_full(
        (pad_size, max_aligned, window_size), padding_idx
    )
    unaligned = source.new_full((window_size,), unaligned_idx)
    nb_alignments = source.new_ones(pad_size, dtype=torch.float)

    for i, source_positions in enumerate(trg2src_alignments):
        if not source_positions:
            aligned_source[i, 0] = unaligned
        else:
            selected = torch.index_select(
                source,
                0,
                torch.tensor(
                    source_positions[:max_aligned], device=source.device
                ),
            )
            aligned_source[i, : len(selected)] = selected
            # counts how many tokens is a target token aligned to
            nb_alignments[i] = len(selected)
    return aligned_source, nb_alignments