Python torch.topk() Examples

The following are 30 code examples of torch.topk(), drawn from open-source projects; the source file, project, and license are listed above each example. You may also want to check out all available functions and classes of the torch module.
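torch.topk(input, k, dim=None, largest=True, sorted=True) returns the k largest elements of input along the given dimension (or the k smallest with largest=False), as a (values, indices) tuple. A minimal standalone sketch with made-up values, before the project examples below:

import torch

x = torch.tensor([1.0, 5.0, 3.0, 4.0, 2.0])
values, indices = torch.topk(x, k=2)
print(values)   # tensor([5., 4.])
print(indices)  # tensor([1, 3])

values, indices = torch.topk(x, k=2, largest=False)  # k smallest instead
print(values)   # tensor([1., 2.])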
Example #1
Source File: unsup_net.py    From SEDST with MIT License
def pz_selective_sampling(self, pz_proba):
        """
        Selective sampling of pz(do max-sampling but prevent repeated words)
        """
        pz_proba = pz_proba.data
        z_proba, z_token = torch.topk(pz_proba, pz_proba.size(0), dim=2)
        z_token = z_token.transpose(0, 1)  # [B,Tz,top_Tz]
        all_sampled_z = []
        for b in range(z_token.size(0)):
            sampled_z = []
            for t in range(z_token.size(1)):
                for i in range(z_token.size(2)):
                    if z_token[b][t][i] not in sampled_z:
                        sampled_z.append(z_token[b][t][i])
                        break
            all_sampled_z.append(sampled_z)
        return all_sampled_z 
Example #2
Source File: unsup_net.py    From SEDST with MIT License
def greedy_decode(self, pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden, flag):
        """
        greedy decoding of the response
        :param pz_dec_outs:
        :param u_enc_out:
        :param m_tm1:
        :param last_hidden:
        :return: nested-list
        """
        decoded = []
        decoder = self.m_decoder if not flag else self.p_decoder
        for t in range(self.max_ts):
            proba, last_hidden = decoder(pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden)
            mt_proba, mt_index = torch.topk(proba, 1)  # [B,1]
            mt_index = mt_index.data.view(-1)
            decoded.append(mt_index)
            m_tm1 = cuda_(Variable(mt_index).view(1, -1))
        decoded = torch.stack(decoded, dim=0).transpose(0, 1)
        decoded = list(decoded)
        return [list(_) for _ in decoded] 
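The torch.topk(proba, 1) call in the decoding loop above is a common way to take the greedy (argmax) token at each step. A quick check of the equivalence, on an illustrative tensor rather than project code:

import torch

proba = torch.tensor([[0.1, 0.7, 0.2],
                      [0.6, 0.3, 0.1]])
_, top1 = torch.topk(proba, 1)  # indices of shape [B, 1]
assert torch.equal(top1.view(-1), proba.argmax(dim=-1))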
Example #3
Source File: recurrent.py    From Tagger with BSD 3-Clause "New" or "Revised" License
def top_k_softmax(logits, k, n):
        top_logits, top_indices = torch.topk(logits, k=min(k + 1, n))

        top_k_logits = top_logits[:, :k]
        top_k_indices = top_indices[:, :k]

        probs = torch.softmax(top_k_logits, dim=-1)
        batch = top_k_logits.shape[0]
        k = top_k_logits.shape[1]

        # Flatten to 1D and offset each row's indices into the flat
        # [batch * n] buffer (floor division by k recovers the row id)
        indices_flat = torch.reshape(top_k_indices, [-1])
        indices_flat = indices_flat + torch.div(
            torch.arange(batch * k, device=logits.device), k,
            rounding_mode='floor') * n

        tensor = torch.zeros([batch * n], dtype=logits.dtype,
                             device=logits.device)
        tensor = tensor.scatter_add(0, indices_flat.long(),
                                    torch.reshape(probs, [-1]))

        return torch.reshape(tensor, [batch, n]) 
Example #4
Source File: interact.py    From dialogue-generation with MIT License
def select_topk(args, logits, force_no_eos_id=None):
    """
    Applies topk sampling decoding.
    """        
    if force_no_eos_id is not None:
        logits[:, force_no_eos_id] = float('-inf')

    indices_to_remove = logits < \
        torch.topk(logits, args.top_k, dim=-1)[0][
            ..., -1, None]

    logits[indices_to_remove] = float('-inf')

    return logits


# implementation is from Huggingface/transformers repo 
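select_topk above only masks the logits; the caller still has to renormalize and sample. A hedged sketch of that final step, assuming a multinomial draw (the function name and setup here are illustrative, not part of the dialogue-generation repo):

import torch
import torch.nn.functional as F

def sample_from_topk(logits, top_k=10):
    # keep only the top_k logits per row, push the rest to -inf
    kth = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
    logits = logits.masked_fill(logits < kth, float('-inf'))
    # renormalize over the survivors and draw one token id per row
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)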
Example #5
Source File: centernet_tensorrt_engine.py    From centerpose with MIT License
def _topk(self, scores, K=40):
        batch, cat, height, width = scores.size()
          
        topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)

        topk_inds = topk_inds % (height * width)
        # index -> (y, x) grid coordinates; the .int() truncation acts as
        # floor division for these non-negative indices
        topk_ys   = (topk_inds / width).int().float()
        topk_xs   = (topk_inds % width).int().float()
          
        topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
        topk_clses = (topk_ind / K).int()
        topk_inds = _gather_feat(
            topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
        topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
        topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)

        return topk_score, topk_inds, topk_clses, topk_ys, topk_xs 
Example #6
Source File: decode.py    From centerpose with MIT License
def _topk(scores, K=40):
    batch, cat, height, width = scores.size()
      
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)

    topk_inds = topk_inds % (height * width)
    topk_ys   = (topk_inds / width).int().float()
    topk_xs   = (topk_inds % width).int().float()
      
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
    topk_clses = (topk_ind / K).int()
    topk_inds = _gather_feat(
        topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)

    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs 
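Examples #5 and #6 depend on a _gather_feat helper that is not shown on this page. In the CenterNet family of repositories it is usually implemented along these lines (a sketch of the common pattern, not the verbatim centerpose source):

def _gather_feat(feat, ind):
    # feat: [B, H*W, C], ind: [B, K] -> gathered features [B, K, C]
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    return feat.gather(1, ind)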
Example #7
Source File: model.py    From ConvLab with MIT License
def greedy_decode(self, decoder_hidden, encoder_outputs, target_tensor):
        decoded_sentences = []
        batch_size, seq_len = target_tensor.size()
        decoder_input = torch.tensor([[SOS_token] for _ in range(batch_size)],
                                     dtype=torch.long, device=self.device)

        decoded_words = torch.zeros((batch_size, self.max_len))
        for t in range(self.max_len):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)

            topv, topi = decoder_output.data.topk(1)  # get candidates
            topi = topi.view(-1)

            decoded_words[:, t] = topi
            decoder_input = topi.detach().view(-1, 1)

        for sentence in decoded_words:
            sent = []
            for ind in sentence:
                if self.output_index2word(str(int(ind.item()))) == self.output_index2word(str(EOS_token)):
                    break
                sent.append(self.output_index2word(str(int(ind.item()))))
            decoded_sentences.append(' '.join(sent))

        return decoded_sentences 
Example #8
Source File: tsd_net.py    From ConvLab with MIT License
def greedy_decode(self, pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, degree_input, bspan_index):
        decoded = []
        bspan_index_np = pad_sequences(bspan_index).transpose((1, 0))
        for t in range(self.max_ts):
            proba, last_hidden, _ = self.m_decoder(pz_dec_outs, u_enc_out, u_input_np, m_tm1,
                                                   degree_input, last_hidden, bspan_index_np)
            proba = torch.cat((proba[:, :2], proba[:, 3:]), 1)
            mt_proba, mt_index = torch.topk(proba, 1)  # [B,1]
            mt_index.add_(mt_index.ge(2).long())
            mt_index = mt_index.data.view(-1)
            decoded.append(mt_index.clone())
            for i in range(mt_index.size(0)):
                if mt_index[i] >= cfg.vocab_size:
                    mt_index[i] = 2  # unk
            m_tm1 = cuda_(Variable(mt_index).view(1, -1))
        decoded = torch.stack(decoded, dim=0).transpose(0, 1)
        decoded = list(decoded)
        return [list(_) for _ in decoded] 
Example #9
Source File: semi_sup_net.py    From SEDST with MIT License
def greedy_decode(self, pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden, degree_input):
        """
        greedy decoding of the response
        :param pz_dec_outs:
        :param u_enc_out:
        :param m_tm1:
        :param last_hidden:
        :return: nested-list
        """
        decoded = []
        for t in range(self.max_ts):
            proba, last_hidden, _ = self.m_decoder(pz_dec_outs, pz_proba, u_enc_out, m_tm1, degree_input, last_hidden)
            mt_proba, mt_index = torch.topk(proba, 1)  # [B,1]
            mt_index = mt_index.data.view(-1)
            decoded.append(mt_index)
            m_tm1 = cuda_(Variable(mt_index).view(1, -1))
        decoded = torch.stack(decoded, dim=0).transpose(0, 1)
        decoded = list(decoded)
        return [list(_) for _ in decoded] 
Example #10
Source File: memory.py    From LSH_Memory with Apache License 2.0
def predict(self, x):
        batch_size, dims = x.size()
        query = F.normalize(self.query_proj(x), dim=1)

        # Find the k-nearest neighbors of the query
        scores = torch.matmul(query, torch.t(self.keys_var))
        cosine_similarity, topk_indices_var = torch.topk(scores, self.top_k, dim=1)

        # softmax of cosine similarities - embedding
        softmax_score = F.softmax(self.softmax_temperature * cosine_similarity, dim=1)

        # retrieve memory values - prediction
        y_hat_indices = topk_indices_var.data[:, 0]
        y_hat = self.values[y_hat_indices]

        return y_hat, softmax_score 
Example #11
Source File: test_geometry.py    From dgl with Apache License 2.0
def test_knn():
    x = th.randn(8, 3)
    kg = dgl.nn.KNNGraph(3)
    d = th.cdist(x, x)

    def check_knn(g, x, start, end):
        for v in range(start, end):
            src, _ = g.in_edges(v)
            src = set(src.numpy())
            i = v - start
            src_ans = set(th.topk(d[start:end, start:end][i], 3, largest=False)[1].numpy() + start)
            assert src == src_ans

    g = kg(x)
    check_knn(g, x, 0, 8)

    g = kg(x.view(2, 4, 3))
    check_knn(g, x, 0, 4)
    check_knn(g, x, 4, 8)

    kg = dgl.nn.SegmentedKNNGraph(3)
    g = kg(x, [3, 5])
    check_knn(g, x, 0, 3)
    check_knn(g, x, 3, 8) 
Example #12
Source File: competing_completed.py    From translate with BSD 3-Clause "New" or "Revised" License
def select_next_words(
        self, word_scores, bsz, beam_size, possible_translation_tokens
    ):
        cand_scores, cand_indices = torch.topk(word_scores.view(bsz, -1), k=beam_size)
        possible_tokens_size = self.vocab_size
        if possible_translation_tokens is not None:
            possible_tokens_size = possible_translation_tokens.size(0)
        cand_beams = torch.div(cand_indices, possible_tokens_size, rounding_mode='floor')
        cand_indices.fmod_(possible_tokens_size)
        # Handle vocab reduction
        if possible_translation_tokens is not None:
            possible_translation_tokens = possible_translation_tokens.view(
                1, possible_tokens_size
            ).expand(cand_indices.size(0), possible_tokens_size)
            cand_indices = torch.gather(
                possible_translation_tokens, dim=1, index=cand_indices, out=cand_indices
            )
        return cand_scores, cand_indices, cand_beams 
Example #13
Source File: word_predictor.py    From translate with BSD 3-Clause "New" or "Revised" License
def get_topk_predicted_tokens(self, net_output, src_tokens, log_probs: bool):
        """
        Get self.topk_labels_per_source_token top predicted words for vocab
        reduction (per source token).
        """
        assert (
            isinstance(self.topk_labels_per_source_token, int)
            and self.topk_labels_per_source_token > 0
        ), "topk_labels_per_source_token must be a positive int, or None"

        # number of labels to predict for each example in batch
        k = src_tokens.size(1) * self.topk_labels_per_source_token
        # [batch_size, vocab_size]
        probs = self.get_normalized_probs(net_output, log_probs)
        _, topk_indices = torch.topk(probs, k, dim=1)

        return topk_indices 
Example #14
Source File: beam_decode.py    From translate with BSD 3-Clause "New" or "Revised" License
def diversity_sibling_rank(self, logprobs, gamma):
        """
        See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation"
        for details
        """
        _, beam_size, vocab_size = logprobs.size()
        logprobs = logprobs.view(-1, vocab_size)
        # Keep consistent with beamsearch class in fairseq
        k = min(2 * beam_size, vocab_size)
        _, indices = torch.topk(logprobs, k)
        # Set diverse penalty as k for all words
        diverse_penalty = torch.ones_like(logprobs) * k
        diversity_sibling_rank = (
            torch.arange(0, k).view(-1, 1).expand(k, logprobs.size(0)).type_as(logprobs)
        )
        # Set diversity penalty accordingly for top-k words
        diverse_penalty[
            torch.arange(0, logprobs.size(0)).long(), indices.transpose(0, 1)
        ] = diversity_sibling_rank
        logprobs -= gamma * diverse_penalty
        return logprobs 
Example #15
Source File: rnn_model.py    From vmf_vae_nlp with MIT License
def forward_decode(self, args, input, ntokens):

        seq_len = input.size()[0]
        batch_sz = input.size()[1]
        # emb: seq_len, batchsz, hid_dim
        # hidden: ([2(nlayers),10(batchsz),200],[])
        hidden = None
        outputs_prob = Variable(torch.FloatTensor(seq_len, batch_sz, ntokens))
        if args.cuda:
            outputs_prob = outputs_prob.cuda()
        outputs = torch.LongTensor(seq_len, batch_sz)

        # First time step sos
        sos = Variable(torch.ones(batch_sz).long())  # id for sos =1
        unk = Variable(torch.ones(batch_sz).long()) * 2  # id for unk =2
        if args.cuda:
            sos = sos.cuda()
            unk = unk.cuda()

        emb_0 = self.drop(self.encoder(sos)).unsqueeze(0)
        emb_t = self.drop(self.encoder(unk)).unsqueeze(0)

        for t in range(seq_len):
            # input (seq_len, batch, input_size)
            if t == 0:
                emb = emb_0
            else:
                emb = emb_t

            output, hidden = self.rnn(emb, hidden)
            output_prob = self.decoder(self.drop(output))
            output_prob = output_prob.squeeze(0)
            outputs_prob[t] = output_prob
            value, ind = torch.topk(output_prob, 1, dim=1)
            outputs[t] = ind.squeeze(1).data
        return outputs_prob, outputs 
Example #16
Source File: search.py    From fairseq with MIT License
def step(self, step: int, lprobs, scores: Optional[Tensor]):
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()
        else:
            # make probs contain cumulative scores for each hypothesis
            assert scores is not None
            lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)

        top_prediction = torch.topk(
            lprobs.view(bsz, -1),
            k=min(
                # Take the best 2 x beam_size predictions. We'll choose the first
                # beam_size of these which don't predict eos to continue with.
                beam_size * 2,
                lprobs.view(bsz, -1).size(1) - 1,  # -1 so we never select pad
            ),
        )
        scores_buf = top_prediction[0]
        indices_buf = top_prediction[1]
        beams_buf = indices_buf // vocab_size
        indices_buf = indices_buf.fmod(vocab_size)
        return scores_buf, indices_buf, beams_buf 
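In the fairseq step above, topk runs over a flattened [beam_size * vocab_size] axis, so each flat index encodes (beam, token) as beam * vocab_size + token; the two lines after the topk call recover both parts. A small check with made-up sizes:

import torch

vocab_size = 7
flat = torch.tensor([0, 8, 13])   # flat indices into a [beam, vocab] grid
beams = flat // vocab_size        # tensor([0, 1, 1])
tokens = flat.fmod(vocab_size)    # tensor([0, 1, 6])
assert torch.equal(beams * vocab_size + tokens, flat)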
Example #17
Source File: tasks.py    From openseg.pytorch with MIT License
def _get_multilabel_prediction(dir_logits, no_offset_mask=None, topk=8):
        h, w, _ = dir_logits.shape
        dir_logits = torch.from_numpy(
            dir_logits
        ).unsqueeze(0).permute(0, 3, 1, 2)
        offsets = []
        if topk == dir_logits.shape[1]:
            for i in range(topk):
                offset_i = DTOffsetHelper.label_to_vector(
                    torch.tensor([i]).view(1, 1, 1)
                ).repeat(1, 1, h, w)
                offset_i = offset_i.float() * dir_logits[:, i:i+1, :, :]
                offsets.append(offset_i)
        else:
            dir_logits, dir_pred = torch.topk(dir_logits, topk, dim=1)
            for i in range(topk):
                dir_pred_i = dir_pred[:, i, :, :]
                offset_i = DTOffsetHelper.label_to_vector(dir_pred_i)
                offset_i = offset_i.float() * dir_logits[:, i:i+1, :, :]
                offsets.append(offset_i)

        offset = sum(offsets)
        dir_pred = DTOffsetHelper.vector_to_label(
            offset.permute(0, 2, 3, 1),
            num_classes=8,
            return_tensor=True
        )

        dir_pred = dir_pred.squeeze(0).numpy()

        if no_offset_mask is not None:
            dir_pred[no_offset_mask] = 8

        return dir_pred 
Example #18
Source File: inference.py    From Clothing-Detection with GNU General Public License v3.0
def select_over_all_levels(self, boxlists):
        num_images = len(boxlists)
        # different behavior during training and during testing:
        # during training, post_nms_top_n is over *all* the proposals combined, while
        # during testing, it is over the proposals for each image
        # NOTE: it should be per image, and not per batch. However, to be consistent 
        # with Detectron, the default is per batch (see Issue #672)
        if self.training and self.fpn_post_nms_per_batch:
            objectness = torch.cat(
                [boxlist.get_field("objectness") for boxlist in boxlists], dim=0
            )
            box_sizes = [len(boxlist) for boxlist in boxlists]
            post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
            _, inds_sorted = torch.topk(objectness, post_nms_top_n, dim=0, sorted=True)
            inds_mask = torch.zeros_like(objectness, dtype=torch.uint8)
            inds_mask[inds_sorted] = 1
            inds_mask = inds_mask.split(box_sizes)
            for i in range(num_images):
                boxlists[i] = boxlists[i][inds_mask[i]]
        else:
            for i in range(num_images):
                objectness = boxlists[i].get_field("objectness")
                post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
                _, inds_sorted = torch.topk(
                    objectness, post_nms_top_n, dim=0, sorted=True
                )
                boxlists[i] = boxlists[i][inds_sorted]
        return boxlists 
Example #19
Source File: helper.py    From vmf_vae_nlp with MIT License
def demo(txt_seq, pred_seq, gold_seq, p_gens, attn_list, meta):
    tgt_len = len(attn_list)
    print(tgt_len)
    tgt_len_ = len(p_gens)
    print(tgt_len_)
    pred = []
    for x in pred_seq[0]:
        pred += x
    tgt_len__ = len(pred)
    assert tgt_len == tgt_len_ == tgt_len__

    txt = []
    for x in txt_seq[0]:
        txt += x

    gold = []
    for x in gold_seq[0][0]:
        gold += x
    output_txt = ' '.join(txt)
    output_gold = ' '.join(gold)
    output_meta = "%s" % (meta)
    output_pred = []
    for t in range(tgt_len):
        p_gen = p_gens[t][0]
        att = attn_list[t]
        val, idx = torch.topk(att, 1)
        attn_val = val[0]
        attn_txt = txt[idx[0]]
        o = '{:10s} G{:.2f} {:10s} A{:.2f}\n'.format(pred[t], p_gen, attn_txt, attn_val)
        output_pred.append(o)
    output_pred = '\t'.join(output_pred)

    output_string = '\nMeta: %s\nText: %s\nGold: %s\nPred:\n%s\n' % (output_meta, output_txt, output_gold, output_pred)
    return output_string, pred_seq, gold_seq 
Example #20
Source File: tasks.py    From openseg.pytorch with MIT License
def eval(outputs, meta, running_scores):
        distance_map = meta['ori_distance_map']
        seg_label_map = meta['ori_target']
        dir_label_map = meta['ori_multi_label_direction_map']
        dir_label_map = DTOffsetHelper.encode_multi_labels(dir_label_map)
        dir_label_map[seg_label_map == -1, :] = -1
        gt_mask_label = DTOffsetHelper.distance_to_mask_label(
            distance_map,
            seg_label_map,
            return_tensor=False
        )

        mask_pred = MaskTask.get_mask_pred(outputs['mask'])
        dir_pred = MLDirectionTask._get_multilabel_prediction(
            outputs['ml_dir'],
            no_offset_mask=mask_pred == 0,
            topk=8
        )

        running_scores['ML dir (mask)'].update(
            dir_pred, dir_label_map,
            (mask_pred == 1) & (seg_label_map != -1)
        )
        running_scores['ML dir (GT)'].update(
            dir_pred, dir_label_map,
            gt_mask_label == 1
        ) 
Example #21
Source File: projection.py    From pykg2vec with MIT License
def predict_head_rank(self, e, r, topk=-1):
        _, rank = torch.topk(-self.forward(e, r, direction="head"), k=topk)
        return rank 
Example #22
Source File: projection.py    From pykg2vec with MIT License
def predict_tail_rank(self, e, r, topk=-1):
        _, rank = torch.topk(-self.forward(e, r, direction="tail"), k=topk)
        return rank 
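Examples #21 and #22 negate the scores before calling torch.topk, turning "k largest of -x" into "k smallest of x", i.e. a ranking by ascending score. Barring ties, the same result can be had with largest=False (an illustrative check, not pykg2vec code):

import torch

x = torch.tensor([0.9, 0.1, 0.5])
_, a = torch.topk(-x, k=3)
_, b = torch.topk(x, k=3, largest=False)
assert torch.equal(a, b)  # both order indices by ascending score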
Example #23
Source File: tensor.py    From dgl with Apache License 2.0
def topk(input, k, dim, descending=True):
    return th.topk(input, k, dim, largest=descending)[0] 
Example #24
Source File: mask_point_head.py    From mmdetection with Apache License 2.0
def get_roi_rel_points_test(self, mask_pred, pred_label, cfg):
        """Get ``num_points`` most uncertain points during test.

        Args:
            mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
                mask_height, mask_width) for class-specific or class-agnostic
                prediction.
            pred_label (list): The predicted class for each instance.
            cfg (dict): Testing config of point head.

        Returns:
            point_indices (Tensor): A tensor of shape (num_rois, num_points)
                that contains indices from [0, mask_height x mask_width) of the
                most uncertain points.
            point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
                that contains [0, 1] x [0, 1] normalized coordinates of the
                most uncertain points from the [mask_height, mask_width] grid.
        """
        num_points = cfg.subdivision_num_points
        uncertainty_map = self._get_uncertainty(mask_pred, pred_label)
        num_rois, _, mask_height, mask_width = uncertainty_map.shape
        h_step = 1.0 / mask_height
        w_step = 1.0 / mask_width

        uncertainty_map = uncertainty_map.view(num_rois,
                                               mask_height * mask_width)
        num_points = min(mask_height * mask_width, num_points)
        point_indices = uncertainty_map.topk(num_points, dim=1)[1]
        point_coords = uncertainty_map.new_zeros(num_rois, num_points, 2)
        point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
                                                mask_width).float() * w_step
        point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
                                                mask_width).float() * h_step
        return point_indices, point_coords 
Example #25
Source File: evaluator.py    From pykg2vec with MIT License
def test_rel_rank(self, h, t, topk=-1):
        if hasattr(self.model, 'predict_rel_rank'):
            # TODO: this broke training on ProjE_pointwise
            # h = h.unsqueeze(0).to(self.config.device)
            # t = t.unsqueeze(0).to(self.config.device)
            rank = self.model.predict_rel_rank(h, t, topk=topk)
            return rank.squeeze(0)

        h_batch = torch.LongTensor([h]).repeat([self.config.tot_relation]).to(self.config.device)
        rel_array = torch.LongTensor(list(range(self.config.tot_relation))).to(self.config.device)
        t_batch = torch.LongTensor([t]).repeat([self.config.tot_relation]).to(self.config.device)

        preds = self.model.forward(h_batch, rel_array, t_batch)
        _, rank = torch.topk(preds, k=topk)
        return rank 
Example #26
Source File: evaluator.py    From pykg2vec with MIT License
def test_head_rank(self, r, t, topk=-1):
        if hasattr(self.model, 'predict_head_rank'):
            # TODO: this broke training on ProjE_pointwise
            # t = t.unsqueeze(0).to(self.config.device)
            # r = r.unsqueeze(0).to(self.config.device)
            rank = self.model.predict_head_rank(t, r, topk=topk)
            return rank.squeeze(0)

        entity_array = torch.LongTensor(list(range(self.config.tot_entity))).to(self.config.device)
        r_batch = torch.LongTensor([r]).repeat([self.config.tot_entity]).to(self.config.device)
        t_batch = torch.LongTensor([t]).repeat([self.config.tot_entity]).to(self.config.device)

        preds = self.model.forward(entity_array, r_batch, t_batch)
        _, rank = torch.topk(preds, k=topk)
        return rank 
Example #27
Source File: evaluator.py    From pykg2vec with MIT License
def test_tail_rank(self, h, r, topk=-1):
        if hasattr(self.model, 'predict_tail_rank'):
            # TODO: this broke training on ProjE_pointwise
            # h = h.unsqueeze(0).to(self.config.device)
            # r = r.unsqueeze(0).to(self.config.device)
            rank = self.model.predict_tail_rank(h, r, topk=topk)
            return rank.squeeze(0)

        h_batch = torch.LongTensor([h]).repeat([self.config.tot_entity]).to(self.config.device)
        r_batch = torch.LongTensor([r]).repeat([self.config.tot_entity]).to(self.config.device)
        entity_array = torch.LongTensor(list(range(self.config.tot_entity))).to(self.config.device)

        preds = self.model.forward(h_batch, r_batch, entity_array)
        _, rank = torch.topk(preds, k=topk)
        return rank 
Example #28
Source File: tensor.py    From dgl with Apache License 2.0
def argtopk(input, k, dim, descending=True):
    return th.topk(input, k, dim, largest=descending)[1] 
Example #29
Source File: primitives.py    From torchbearer with MIT License
def process(self, *args):
        state = args[0]
        y_pred = state[self.pred_key]
        y_true = state[self.target_key]
        mask = y_true.eq(self.ignore_index).eq(0)
        y_pred = y_pred[mask]
        y_true = y_true[mask]

        sorted_indices = torch.topk(y_pred, self.k, dim=1)[1]
        expanded_y = y_true.view(-1, 1).expand(-1, self.k)
        return torch.sum(torch.eq(sorted_indices, expanded_y), dim=1).float() 
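The torchbearer metric above marks a sample as correct when its true label appears among the k highest-scoring predictions. The core pattern in isolation, with made-up logits and labels:

import torch

y_pred = torch.tensor([[0.10, 0.50, 0.40],
                       [0.80, 0.15, 0.05]])
y_true = torch.tensor([2, 2])
topk_idx = torch.topk(y_pred, k=2, dim=1)[1]               # [N, k]
hits = torch.eq(topk_idx, y_true.view(-1, 1)).any(dim=1)
print(hits.float().mean())  # top-2 accuracy: tensor(0.5000)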
Example #30
Source File: memory.py    From LSH_Memory with Apache License 2.0
def update(self, query, y, y_hat, y_hat_indices):
        batch_size, dims = query.size()

        # 1) Untouched: increment the age of every memory slot by 1
        self.age += 1

        # Divide batch by correctness
        result = torch.squeeze(torch.eq(y_hat, torch.unsqueeze(y.data, dim=1))).float()
        incorrect_examples = torch.squeeze(torch.nonzero(1-result))
        correct_examples = torch.squeeze(torch.nonzero(result))

        incorrect = len(incorrect_examples.size()) > 0
        correct = len(correct_examples.size()) > 0

        # 2) Correct: if V[n1] = v
        # Update Key k[n1] <- normalize(q + K[n1]), Reset Age A[n1] <- 0
        if correct:
            correct_indices = y_hat_indices[correct_examples]
            correct_keys = self.keys[correct_indices]
            correct_query = query.data[correct_examples]

            new_correct_keys = F.normalize(correct_keys + correct_query, dim=1)
            self.keys[correct_indices] = new_correct_keys
            self.age[correct_indices] = 0

        # 3) Incorrect: if V[n1] != v
        # Select item with oldest age, Add random offset - n' = argmax_i(A[i]) + r_i 
        # K[n'] <- q, V[n'] <- v, A[n'] <- 0
        if incorrect:
            incorrect_size = incorrect_examples.size()[0]
            incorrect_query = query.data[incorrect_examples]
            incorrect_values = y.data[incorrect_examples]

            age_with_noise = self.age + random_uniform((self.memory_size, 1), -self.age_noise, self.age_noise, cuda=True)
            topk_values, topk_indices = torch.topk(age_with_noise, incorrect_size, dim=0)
            oldest_indices = torch.squeeze(topk_indices)

            self.keys[oldest_indices] = incorrect_query
            self.values[oldest_indices] = incorrect_values
            self.age[oldest_indices] = 0