Python torch.topk() Examples
The following are 30 code examples of torch.topk().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: unsup_net.py From SEDST with MIT License | 7 votes |
def pz_selective_sampling(self, pz_proba):
    """Selective sampling of pz: max-sampling that never repeats a token.

    For every batch row and every step, the highest-ranked candidate token
    that has not already been chosen for that row is taken. The number of
    candidates per step is ``pz_proba.size(0)``, i.e. the number of steps, so
    an unused candidate always exists.

    Returns:
        list of lists: per batch row, the chosen token (a 0-d tensor) per step.
    """
    probs = pz_proba.data
    _, candidates = torch.topk(probs, probs.size(0), dim=2)
    candidates = candidates.transpose(0, 1)  # [B,Tz,top_Tz]

    result = []
    for row in range(candidates.size(0)):
        chosen = []
        for step in range(candidates.size(1)):
            for rank in range(candidates.size(2)):
                token = candidates[row][step][rank]
                if token not in chosen:
                    chosen.append(token)
                    break
        result.append(chosen)
    return result
Example #2
Source File: unsup_net.py From SEDST with MIT License | 7 votes |
def greedy_decode(self, pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden, flag):
    """Greedy decoding of the response for ``self.max_ts`` steps.

    Each step runs the selected decoder once and feeds its top-1 token back
    in as the next ``m_tm1``.

    :param pz_dec_outs: forwarded unchanged to the decoder.
    :param pz_proba: forwarded unchanged to the decoder.
    :param u_enc_out: forwarded unchanged to the decoder.
    :param m_tm1: decoder input for the first step.
    :param last_hidden: initial decoder hidden state.
    :param flag: falsy selects ``self.m_decoder``, truthy ``self.p_decoder``.
    :return: nested list; per batch row, one token tensor per step.
    """
    decoder = self.p_decoder if flag else self.m_decoder
    step_tokens = []
    for _ in range(self.max_ts):
        proba, last_hidden = decoder(pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden)
        _, best = torch.topk(proba, 1)  # [B,1]
        best = best.data.view(-1)
        step_tokens.append(best)
        # Feed the chosen tokens back as a [1, B] input for the next step.
        m_tm1 = cuda_(Variable(best).view(1, -1))
    by_row = torch.stack(step_tokens, dim=0).transpose(0, 1)
    return [list(row) for row in by_row]
Example #3
Source File: recurrent.py From Tagger with BSD 3-Clause "New" or "Revised" License | 6 votes |
def top_k_softmax(logits, k, n):
    """Return a dense [batch, n] distribution that keeps only the top-k logits.

    The k largest entries of each row of ``logits`` are softmax-normalised and
    scattered back to their original columns; every other column is zero.

    Args:
        logits: 2-D tensor of shape [batch, n].
        k: number of entries to keep per row (effectively capped at ``n``).
        n: size of the last dimension of ``logits``.

    Returns:
        Tensor of shape [batch, n] with the dtype/device of ``logits``.
    """
    top_logits, top_indices = torch.topk(logits, k=min(k + 1, n))
    top_k_logits = top_logits[:, :k]
    top_k_indices = top_indices[:, :k]
    probs = torch.softmax(top_k_logits, dim=-1)

    batch = top_k_logits.shape[0]
    k = top_k_logits.shape[1]

    # Flatten to 1-D: row b, column c lives at b * n + c. The row offsets are
    # built with integer ops only. The previous ``torch.div(arange, k) * n``
    # performs true division on integer tensors in PyTorch >= 1.7, giving
    # fractional offsets that scatter probabilities into the wrong cells.
    row_offsets = torch.arange(batch, device=logits.device).repeat_interleave(k) * n
    indices_flat = torch.reshape(top_k_indices, [-1]) + row_offsets

    tensor = torch.zeros([batch * n], dtype=logits.dtype, device=logits.device)
    tensor = tensor.scatter_add(0, indices_flat.long(), torch.reshape(probs, [-1]))

    return torch.reshape(tensor, [batch, n])
Example #4
Source File: interact.py From dialogue-generation with MIT License | 6 votes |
def select_topk(args, logits, force_no_eos_id=None):
    """Applies topk sampling decoding: filter ``logits`` in place and return them.

    Every entry smaller than the k-th largest value along the last dimension
    is set to ``-inf`` (entries tied with the k-th value are kept). The
    implementation follows the Huggingface/transformers repo.

    Args:
        args: object with an integer ``top_k`` attribute. ``top_k <= 0``
            disables filtering, and values larger than the last dimension are
            clamped to it. Previously both cases raised inside ``torch.topk``
            or the indexing that follows it.
        logits: tensor whose last dimension is the vocabulary. It is indexed
            as ``logits[:, id]`` when ``force_no_eos_id`` is given, so that
            path needs at least two dimensions.
        force_no_eos_id: optional column index that is always set to ``-inf``.

    Returns:
        The same ``logits`` tensor, modified in place.
    """
    if force_no_eos_id is not None:
        logits[:, force_no_eos_id] = float('-inf')

    top_k = min(args.top_k, logits.size(-1))
    if top_k <= 0:
        return logits

    kth_largest = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
    logits[logits < kth_largest] = float('-inf')
    return logits
Example #5
Source File: centernet_tensorrt_engine.py From centerpose with MIT License | 6 votes |
def _topk(self, scores, K=40):
    """Select the K highest-scoring peaks across all classes of ``scores``.

    ``scores`` is unpacked as (batch, cat, height, width). The function first
    takes the best K positions in every class plane, then the best K of those
    ``cat * K`` candidates per batch item.

    Returns:
        (scores, spatial indices, class ids, ys, xs), each viewed as
        [batch, K]. Class ids are int32; ys/xs are derived from the flat
        spatial index and ``width``.
    """
    batch, cat, height, width = scores.size()
    plane = height * width

    # Stage 1: best K positions inside every class plane.
    per_class_scores, per_class_inds = torch.topk(scores.view(batch, cat, -1), K)
    per_class_inds = per_class_inds % plane
    per_class_ys = (per_class_inds / width).int().float()
    per_class_xs = (per_class_inds % width).int().float()

    # Stage 2: best K among the flattened cat * K candidates; a candidate at
    # flat position c * K + j belongs to class c.
    best_scores, best_pos = torch.topk(per_class_scores.view(batch, -1), K)
    best_classes = (best_pos / K).int()

    def _pick(values):
        # Re-index a per-class [batch, cat, K] tensor by the stage-2 winners.
        return _gather_feat(values.view(batch, -1, 1), best_pos).view(batch, K)

    best_inds = _pick(per_class_inds)
    best_ys = _pick(per_class_ys)
    best_xs = _pick(per_class_xs)
    return best_scores, best_inds, best_classes, best_ys, best_xs
Example #6
Source File: decode.py From centerpose with MIT License | 6 votes |
def _topk(scores, K=40):
    """Select the K highest-scoring peaks across all classes of ``scores``.

    ``scores`` is unpacked as (batch, cat, height, width). The best K
    positions of each class plane are found first, then the best K of those
    ``cat * K`` candidates per batch item.

    Returns:
        (scores, spatial indices, class ids, ys, xs), each viewed as
        [batch, K]. Class ids are int32.
    """
    batch, cat, height, width = scores.size()

    # Per-class top-K over the flattened height * width plane.
    cls_scores, cls_inds = torch.topk(scores.view(batch, cat, -1), K)
    cls_inds = cls_inds % (height * width)
    cls_ys = (cls_inds / width).int().float()
    cls_xs = (cls_inds % width).int().float()

    # Global top-K over all classes; flat position c * K + j -> class c.
    top_score, top_pos = torch.topk(cls_scores.view(batch, -1), K)
    top_clses = (top_pos / K).int()

    top_inds, top_ys, top_xs = (
        _gather_feat(per_class.view(batch, -1, 1), top_pos).view(batch, K)
        for per_class in (cls_inds, cls_ys, cls_xs)
    )
    return top_score, top_inds, top_clses, top_ys, top_xs
Example #7
Source File: model.py From ConvLab with MIT License | 6 votes |
def greedy_decode(self, decoder_hidden, encoder_outputs, target_tensor):
    """Greedily decode ``self.max_len`` tokens per batch row and detokenise.

    ``target_tensor`` is only used for its batch size (it is unpacked as a
    2-D size). Each step feeds the decoder's top-1 token back in as the next
    input. Ids are mapped to words with ``self.output_index2word`` and each
    sentence is cut at the first word equal to the ``EOS_token`` word.

    Returns:
        list[str]: one space-joined sentence per batch row.
    """
    batch_size, _ = target_tensor.size()
    # The legacy ``torch.LongTensor(data, device=...)`` constructor only
    # accepts CPU devices, so build the SOS column with a factory function.
    decoder_input = torch.full(
        (batch_size, 1), SOS_token, dtype=torch.long, device=self.device
    )
    decoded_words = torch.zeros((batch_size, self.max_len))

    for t in range(self.max_len):
        decoder_output, decoder_hidden = self.decoder(
            decoder_input, decoder_hidden, encoder_outputs
        )
        _, topi = decoder_output.data.topk(1)  # get candidates
        topi = topi.view(-1)
        decoded_words[:, t] = topi
        decoder_input = topi.detach().view(-1, 1)

    # Loop-invariant: the word that terminates a sentence.
    eos_word = self.output_index2word(str(EOS_token))
    decoded_sentences = []
    for sentence in decoded_words:
        sent = []
        for ind in sentence:
            word = self.output_index2word(str(int(ind.item())))
            if word == eos_word:
                break
            sent.append(word)
        decoded_sentences.append(' '.join(sent))
    return decoded_sentences
Example #8
Source File: tsd_net.py From ConvLab with MIT License | 6 votes |
def greedy_decode(self, pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, degree_input, bspan_index):
    """Greedily decode a response for up to ``self.max_ts`` steps.

    Token id 2 (the unk id, per the substitution below) is never selected.
    Its column is removed before ``topk``, and chosen ids >= 2 are shifted up
    by one to restore the original indexing. Ids >= ``cfg.vocab_size`` are
    recorded as-is in the output, but fed back to the decoder as 2.

    :return: nested list; per batch row, one token tensor per step.
    """
    decoded = []
    # Pad ``bspan_index`` and swap its two axes before passing it to the
    # decoder (presumably to time-major order - TODO confirm).
    bspan_index_np = pad_sequences(bspan_index).transpose((1, 0))
    for t in range(self.max_ts):
        proba, last_hidden, _ = self.m_decoder(pz_dec_outs, u_enc_out, u_input_np, m_tm1,
                                               degree_input, last_hidden, bspan_index_np)
        # Drop column 2 so it can never be selected.
        proba = torch.cat((proba[:, :2], proba[:, 3:]), 1)
        _, mt_index = torch.topk(proba, 1)  # [B,1]
        # Undo the column removal: ids >= 2 shift up by one.
        mt_index.add_(mt_index.ge(2).long())
        mt_index = mt_index.data.view(-1)
        # Record the decoded ids before the in-place substitution below.
        decoded.append(mt_index.clone())
        # Ids outside the base vocabulary are fed back as unk. This vectorised
        # form replaces a Python loop that inspected one element at a time.
        mt_index.masked_fill_(mt_index >= cfg.vocab_size, 2)  # unk
        m_tm1 = cuda_(Variable(mt_index).view(1, -1))
    decoded = torch.stack(decoded, dim=0).transpose(0, 1)
    decoded = list(decoded)
    return [list(_) for _ in decoded]
Example #9
Source File: semi_sup_net.py From SEDST with MIT License | 6 votes |
def greedy_decode(self, pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden, degree_input):
    """Greedy decoding of the response for ``self.max_ts`` steps.

    Each step runs ``self.m_decoder`` once and feeds its top-1 token back in
    as the next ``m_tm1``; the decoder's third return value is ignored.

    :param pz_dec_outs: forwarded unchanged to the decoder.
    :param pz_proba: forwarded unchanged to the decoder.
    :param u_enc_out: forwarded unchanged to the decoder.
    :param m_tm1: decoder input for the first step.
    :param last_hidden: initial decoder hidden state.
    :param degree_input: forwarded unchanged to the decoder.
    :return: nested list; per batch row, one token tensor per step.
    """
    per_step = []
    for _ in range(self.max_ts):
        proba, last_hidden, _extra = self.m_decoder(pz_dec_outs, pz_proba, u_enc_out, m_tm1,
                                                    degree_input, last_hidden)
        best = torch.topk(proba, 1)[1].data.view(-1)  # [B]
        per_step.append(best)
        m_tm1 = cuda_(Variable(best).view(1, -1))
    rows = torch.stack(per_step, dim=0).transpose(0, 1)
    return [list(row) for row in rows]
Example #10
Source File: memory.py From LSH_Memory with Apache License 2.0 | 6 votes |
def predict(self, x):
    """Look up the memory slots nearest to each query and return their values.

    Args:
        x: batch of inputs (expected 2-D). Projected with ``self.query_proj``
            and L2-normalised along dim 1.

    Returns:
        (y_hat, softmax_score): ``y_hat`` is ``self.values`` at the single
        best-scoring key of each row. ``softmax_score`` is the softmax over
        the row's top ``self.top_k`` scores, multiplied by
        ``self.softmax_temperature``.
    """
    query = F.normalize(self.query_proj(x), dim=1)

    # Find the k-nearest neighbors of the query. The query rows are unit
    # length, so these dot products are cosine similarities provided the
    # keys are unit length too (NOTE(review): confirm keys_var is normalised).
    scores = torch.matmul(query, torch.t(self.keys_var))
    cosine_similarity, topk_indices_var = torch.topk(scores, self.top_k, dim=1)

    # Softmax over each row's neighbours. dim=1 is explicit: calling softmax
    # without ``dim`` is deprecated and warns on every call.
    softmax_score = F.softmax(self.softmax_temperature * cosine_similarity, dim=1)

    # Retrieve memory values - prediction from the nearest neighbour.
    y_hat_indices = topk_indices_var.data[:, 0]
    y_hat = self.values[y_hat_indices]
    return y_hat, softmax_score
Example #11
Source File: test_geometry.py From dgl with Apache License 2.0 | 6 votes |
def test_knn():
    """Check that KNNGraph / SegmentedKNNGraph link each point to its 3 nearest points."""
    points = th.randn(8, 3)
    knn = dgl.nn.KNNGraph(3)
    dist = th.cdist(points, points)

    def check_knn(g, x, start, end):
        # Inside [start, end) the in-neighbours of each node must equal the
        # 3 smallest-distance points of that segment.
        segment = dist[start:end, start:end]
        for offset, v in enumerate(range(start, end)):
            src, _ = g.in_edges(v)
            nearest = th.topk(segment[offset], 3, largest=False)[1].numpy() + start
            assert set(src.numpy()) == set(nearest)

    check_knn(knn(points), points, 0, 8)
    batched = knn(points.view(2, 4, 3))
    check_knn(batched, points, 0, 4)
    check_knn(batched, points, 4, 8)

    segmented = dgl.nn.SegmentedKNNGraph(3)
    seg_graph = segmented(points, [3, 5])
    check_knn(seg_graph, points, 0, 3)
    check_knn(seg_graph, points, 3, 8)
Example #12
Source File: competing_completed.py From translate with BSD 3-Clause "New" or "Revised" License | 6 votes |
def select_next_words(
    self, word_scores, bsz, beam_size, possible_translation_tokens
):
    """Pick the ``beam_size`` best (beam, token) candidates per batch item.

    ``word_scores`` is viewed as [bsz, -1], so a flat index encodes both the
    source beam (``index // possible_tokens_size``) and the token
    (``index % possible_tokens_size``).

    Returns:
        (cand_scores, cand_indices, cand_beams), each [bsz, beam_size].
        ``cand_indices`` are mapped through ``possible_translation_tokens``
        when vocab reduction is active; ``cand_beams`` are integer indices.
    """
    cand_scores, cand_indices = torch.topk(word_scores.view(bsz, -1), k=beam_size)
    possible_tokens_size = self.vocab_size
    if possible_translation_tokens is not None:
        possible_tokens_size = possible_translation_tokens.size(0)
    # Integer floor division. ``torch.div`` on integer tensors is true
    # division in PyTorch >= 1.7 and would yield fractional beam indices.
    cand_beams = cand_indices // possible_tokens_size
    cand_indices.fmod_(possible_tokens_size)
    # Handle vocab reduction
    if possible_translation_tokens is not None:
        possible_translation_tokens = possible_translation_tokens.view(
            1, possible_tokens_size
        ).expand(cand_indices.size(0), possible_tokens_size)
        # Write to a fresh tensor rather than ``out=cand_indices``, which
        # aliased the output with the index tensor being read.
        cand_indices = torch.gather(
            possible_translation_tokens, dim=1, index=cand_indices
        )
    return cand_scores, cand_indices, cand_beams
Example #13
Source File: word_predictor.py From translate with BSD 3-Clause "New" or "Revised" License | 6 votes |
def get_topk_predicted_tokens(self, net_output, src_tokens, log_probs: bool):
    """
    Get self.topk_labels_per_source_token top predicted words for vocab
    reduction (per source token).

    The requested count, ``src_tokens.size(1) *
    self.topk_labels_per_source_token``, is capped at the vocabulary size
    (``probs.size(1)``). Without the cap, ``torch.topk`` raises whenever
    that product exceeds the vocabulary.

    Returns:
        Indices [batch_size, k] along dim 1 of the probabilities, best first.
    """
    assert (
        isinstance(self.topk_labels_per_source_token, int)
        and self.topk_labels_per_source_token > 0
    ), "topk_labels_per_source_token must be a positive int, or None"
    # [batch_size, vocab_size]
    probs = self.get_normalized_probs(net_output, log_probs)
    # number of labels to predict for each example in batch, capped at vocab size
    k = min(src_tokens.size(1) * self.topk_labels_per_source_token, probs.size(1))
    _, topk_indices = torch.topk(probs, k, dim=1)
    return topk_indices
Example #14
Source File: beam_decode.py From translate with BSD 3-Clause "New" or "Revised" License | 6 votes |
def diversity_sibling_rank(self, logprobs, gamma):
    """
    See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation"
    for details

    ``logprobs`` is unpacked as (_, beam_size, vocab_size) and flattened to
    rows of length ``vocab_size``. In each row the top ``k`` words get
    penalty 0..k-1 in rank order and every other word gets penalty ``k``,
    where ``k = min(2 * beam_size, vocab_size)``. ``gamma * penalty`` is then
    subtracted in place, so the caller's tensor storage is modified as well.

    Returns:
        The penalised scores as a 2-D (-1, vocab_size) view.
    """
    _, beam_size, vocab_size = logprobs.size()
    rows = logprobs.view(-1, vocab_size)
    num_rows = rows.size(0)

    # Keep consistent with beamsearch class in fairseq
    k = min(2 * beam_size, vocab_size)
    top_words = torch.topk(rows, k)[1]

    # Default penalty k; the top-k words are overwritten with their rank.
    penalty = torch.ones_like(rows) * k
    ranks = torch.arange(0, k).type_as(rows).view(1, -1).expand(num_rows, k)
    penalty.scatter_(1, top_words, ranks)

    rows -= gamma * penalty
    return rows
Example #15
Source File: rnn_model.py From vmf_vae_nlp with MIT License | 5 votes |
def forward_decode(self, args, input, ntokens):
    """Decode ``input.size()[0]`` steps and record the top-1 token per step.

    The first step is fed the embedding of id 1 (sos); every later step is
    fed the embedding of id 2 (unk) rather than the previous prediction. The
    RNN hidden state is carried across steps.

    Returns:
        (outputs_prob, outputs): decoder outputs per step, shape
        [seq_len, batch_sz, ntokens], and the top-1 token per step as a CPU
        LongTensor of shape [seq_len, batch_sz].
    """
    seq_len, batch_sz = input.size()[0], input.size()[1]

    outputs_prob = Variable(torch.FloatTensor(seq_len, batch_sz, ntokens))
    if args.cuda:
        outputs_prob = outputs_prob.cuda()
    outputs = torch.LongTensor(seq_len, batch_sz)

    sos = Variable(torch.ones(batch_sz).long())  # id for sos =1
    unk = Variable(torch.ones(batch_sz).long()) * 2  # id for unk =2
    if args.cuda:
        sos = sos.cuda()
        unk = unk.cuda()
    first_emb = self.drop(self.encoder(sos)).unsqueeze(0)
    later_emb = self.drop(self.encoder(unk)).unsqueeze(0)

    hidden = None
    for step in range(seq_len):
        emb = first_emb if step == 0 else later_emb
        step_out, hidden = self.rnn(emb, hidden)
        step_prob = self.decoder(self.drop(step_out)).squeeze(0)
        outputs_prob[step] = step_prob
        _, best = torch.topk(step_prob, 1, dim=1)
        outputs[step] = best.squeeze(1).data
    return outputs_prob, outputs
Example #16
Source File: search.py From fairseq with MIT License | 5 votes |
def step(self, step: int, lprobs, scores: Optional[Tensor]):
    """Select the top ``2 * beam_size`` (beam, token) expansions per sentence.

    At ``step == 0`` only the first beam of each sentence is used. Later, the
    cumulative hypothesis scores ``scores[:, :, step - 1]`` are added to
    ``lprobs`` before ranking.

    Returns:
        (scores_buf, indices_buf, beams_buf): candidate scores, token ids in
        ``[0, vocab_size)``, and the beam each candidate extends.
    """
    bsz, beam_size, vocab_size = lprobs.size()

    if step == 0:
        # at the first step all hypotheses are equally likely, so use
        # only the first beam
        lprobs = lprobs[:, ::beam_size, :].contiguous()
    else:
        # make probs contain cumulative scores for each hypothesis
        assert scores is not None
        lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)

    flat = lprobs.view(bsz, -1)
    # Take the best 2 x beam_size predictions. We'll choose the first
    # beam_size of these which don't predict eos to continue with.
    # -1 so we never select pad
    num_candidates = min(beam_size * 2, flat.size(1) - 1)
    scores_buf, flat_indices = torch.topk(flat, k=num_candidates)

    beams_buf = flat_indices // vocab_size
    indices_buf = flat_indices.fmod(vocab_size)
    return scores_buf, indices_buf, beams_buf
Example #17
Source File: tasks.py From openseg.pytorch with MIT License | 5 votes |
def _get_multilabel_prediction(dir_logits, no_offset_mask=None, topk=8):
    """Collapse multi-label direction logits into one direction label per pixel.

    The ``topk`` best directions per pixel (or every direction, when ``topk``
    equals the channel count) are turned into vectors with
    ``DTOffsetHelper.label_to_vector``, weighted by their logits and summed.
    The sum is mapped back to a label with
    ``DTOffsetHelper.vector_to_label(..., num_classes=8)``.

    Args:
        dir_logits: numpy array unpacked as (h, w, channels).
        no_offset_mask: optional index/mask into the result; those entries
            are set to 8.
        topk: number of directions combined per pixel.

    Returns:
        numpy array of per-pixel direction labels.
    """
    h, w, _ = dir_logits.shape
    # (h, w, C) numpy array -> (1, C, h, w) tensor.
    logits = torch.from_numpy(dir_logits).unsqueeze(0).permute(0, 3, 1, 2)

    weighted = []
    if topk == logits.shape[1]:
        # Every direction: the vector for label i is repeated over the grid.
        for label in range(topk):
            vec = DTOffsetHelper.label_to_vector(
                torch.tensor([label]).view(1, 1, 1)
            ).repeat(1, 1, h, w)
            weighted.append(vec.float() * logits[:, label:label + 1, :, :])
    else:
        top_logits, top_labels = torch.topk(logits, topk, dim=1)
        for rank in range(topk):
            vec = DTOffsetHelper.label_to_vector(top_labels[:, rank, :, :])
            weighted.append(vec.float() * top_logits[:, rank:rank + 1, :, :])

    offset = sum(weighted)
    labels = DTOffsetHelper.vector_to_label(
        offset.permute(0, 2, 3, 1),
        num_classes=8,
        return_tensor=True
    )
    labels = labels.squeeze(0).numpy()
    if no_offset_mask is not None:
        labels[no_offset_mask] = 8
    return labels
Example #18
Source File: inference.py From Clothing-Detection with GNU General Public License v3.0 | 5 votes |
def select_over_all_levels(self, boxlists):
    """Keep only the ``self.fpn_post_nms_top_n`` highest-objectness boxes.

    In training with ``self.fpn_post_nms_per_batch`` the budget is shared by
    all images: boxes are ranked jointly and each image keeps its winners in
    their original order. Otherwise each image keeps its own top boxes,
    reordered by descending objectness. ``boxlists`` is updated in place and
    returned.
    """
    num_images = len(boxlists)
    # different behavior during training and during testing:
    # during training, post_nms_top_n is over *all* the proposals combined, while
    # during testing, it is over the proposals for each image
    # NOTE: it should be per image, and not per batch. However, to be consistent
    # with Detectron, the default is per batch (see Issue #672)
    if self.training and self.fpn_post_nms_per_batch:
        objectness = torch.cat(
            [boxlist.get_field("objectness") for boxlist in boxlists], dim=0
        )
        box_sizes = [len(boxlist) for boxlist in boxlists]
        post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
        _, inds_sorted = torch.topk(objectness, post_nms_top_n, dim=0, sorted=True)
        # Boolean mask: indexing with uint8 masks is deprecated in PyTorch.
        inds_mask = torch.zeros_like(objectness, dtype=torch.bool)
        inds_mask[inds_sorted] = True
        inds_mask = inds_mask.split(box_sizes)
        for i in range(num_images):
            boxlists[i] = boxlists[i][inds_mask[i]]
    else:
        for i in range(num_images):
            objectness = boxlists[i].get_field("objectness")
            post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
            _, inds_sorted = torch.topk(
                objectness, post_nms_top_n, dim=0, sorted=True
            )
            boxlists[i] = boxlists[i][inds_sorted]
    return boxlists
Example #19
Source File: helper.py From vmf_vae_nlp with MIT License | 5 votes |
def demo(txt_seq, pred_seq, gold_seq, p_gens, attn_list, meta):
    """Render one example's prediction next to its attention for inspection.

    Prints ``len(attn_list)`` and ``len(p_gens)``, asserts both equal the
    length of the flattened prediction, and builds a report. The report has
    the meta, the source text, the gold text, and one tab-separated line per
    predicted token showing its ``p_gens`` value and the top-attended source
    token with its attention weight.

    Returns:
        (output_string, pred_seq, gold_seq)
    """
    n_attn = len(attn_list)
    print(n_attn)
    n_gen = len(p_gens)
    print(n_gen)

    pred = [tok for piece in pred_seq[0] for tok in piece]
    assert n_attn == n_gen == len(pred)

    txt = [tok for piece in txt_seq[0] for tok in piece]
    gold = [tok for piece in gold_seq[0][0] for tok in piece]

    text_line = ' '.join(txt)
    gold_line = ' '.join(gold)
    meta_line = "%s" % (meta)

    lines = []
    for step in range(n_attn):
        gen_prob = p_gens[step][0]
        best_val, best_idx = torch.topk(attn_list[step], 1)
        lines.append('{:10s} G{:.2f} {:10s} A{:.2f}\n'.format(
            pred[step], gen_prob, txt[best_idx[0]], best_val[0]))

    report = '\nMeta: %s\nText: %s\nGold: %s\nPred:\n%s\n' % (
        meta_line, text_line, gold_line, '\t'.join(lines))
    return report, pred_seq, gold_seq
Example #20
Source File: tasks.py From openseg.pytorch with MIT License | 5 votes |
def eval(outputs, meta, running_scores):
    """Update the two multi-label direction score trackers in ``running_scores``.

    'ML dir (mask)' is restricted to pixels with ``mask_pred == 1`` and a
    segmentation label other than -1. 'ML dir (GT)' is restricted to pixels
    where the ground-truth mask label derived from the distance map is 1.
    """
    distance_map = meta['ori_distance_map']
    seg_label_map = meta['ori_target']
    dir_label_map = DTOffsetHelper.encode_multi_labels(
        meta['ori_multi_label_direction_map']
    )
    # Pixels whose segmentation label is -1 get direction label -1 as well.
    dir_label_map[seg_label_map == -1, :] = -1

    gt_mask_label = DTOffsetHelper.distance_to_mask_label(
        distance_map, seg_label_map, return_tensor=False
    )
    mask_pred = MaskTask.get_mask_pred(outputs['mask'])
    dir_pred = MLDirectionTask._get_multilabel_prediction(
        outputs['ml_dir'], no_offset_mask=mask_pred == 0, topk=8
    )

    valid = seg_label_map != -1
    running_scores['ML dir (mask)'].update(dir_pred, dir_label_map, (mask_pred == 1) & valid)
    running_scores['ML dir (GT)'].update(dir_pred, dir_label_map, gt_mask_label == 1)
Example #21
Source File: projection.py From pykg2vec with MIT License | 5 votes |
def predict_head_rank(self, e, r, topk=-1):
    """Rank candidate heads for ``(e, r)``, lowest ``forward`` score first.

    Args:
        e: entity passed to ``self.forward``.
        r: relation passed to ``self.forward``.
        topk: how many ranked candidates to return. A negative value (the
            default) returns all of them. ``torch.topk`` rejects negative
            ``k``, so the default previously always raised.

    Returns:
        Indices along the last dimension of the scores, best first.
    """
    scores = self.forward(e, r, direction="head")
    k = scores.size(-1) if topk < 0 else topk
    _, rank = torch.topk(-scores, k=k)
    return rank
Example #22
Source File: projection.py From pykg2vec with MIT License | 5 votes |
def predict_tail_rank(self, e, r, topk=-1):
    """Rank candidate tails for ``(e, r)``, lowest ``forward`` score first.

    Args:
        e: entity passed to ``self.forward``.
        r: relation passed to ``self.forward``.
        topk: how many ranked candidates to return. A negative value (the
            default) returns all of them. ``torch.topk`` rejects negative
            ``k``, so the default previously always raised.

    Returns:
        Indices along the last dimension of the scores, best first.
    """
    scores = self.forward(e, r, direction="tail")
    k = scores.size(-1) if topk < 0 else topk
    _, rank = torch.topk(-scores, k=k)
    return rank
Example #23
Source File: tensor.py From dgl with Apache License 2.0 | 5 votes |
def topk(input, k, dim, descending=True):
    """Return the ``k`` extreme values of ``input`` along ``dim``.

    ``descending=True`` selects the largest values, otherwise the smallest.
    Only the values are returned, not their indices.
    """
    values, _ = th.topk(input, k, dim, largest=descending)
    return values
Example #24
Source File: mask_point_head.py From mmdetection with Apache License 2.0 | 5 votes |
def get_roi_rel_points_test(self, mask_pred, pred_label, cfg):
    """Get ``num_points`` most uncertain points during test.

    Args:
        mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
            mask_height, mask_width) for class-specific or class-agnostic
            prediction.
        pred_label (list): The predicted class for each instance.
        cfg: Testing config of point head; ``cfg.subdivision_num_points``
            is read as the requested number of points.

    Returns:
        point_indices (Tensor): A tensor of shape (num_rois, num_points)
            holding flat indices in [0, mask_height x mask_width) of the
            most uncertain points.
        point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
            holding [0, 1] x [0, 1] normalized (x, y) cell-centre coordinates
            of those points on the [mask_height, mask_width] grid.
    """
    requested = cfg.subdivision_num_points
    uncertainty = self._get_uncertainty(mask_pred, pred_label)
    num_rois, _, mask_height, mask_width = uncertainty.shape
    grid_cells = mask_height * mask_width
    h_step = 1.0 / mask_height
    w_step = 1.0 / mask_width

    flat_uncertainty = uncertainty.view(num_rois, grid_cells)
    # Never ask for more points than the grid has cells.
    num_points = min(grid_cells, requested)
    _, point_indices = flat_uncertainty.topk(num_points, dim=1)

    # Flat index -> (column, row) -> centre of that cell.
    cols = (point_indices % mask_width).float()
    rows = (point_indices // mask_width).float()
    point_coords = flat_uncertainty.new_zeros(num_rois, num_points, 2)
    point_coords[:, :, 0] = w_step / 2.0 + cols * w_step
    point_coords[:, :, 1] = h_step / 2.0 + rows * h_step
    return point_indices, point_coords
Example #25
Source File: evaluator.py From pykg2vec with MIT License | 5 votes |
def test_rel_rank(self, h, t, topk=-1): if hasattr(self.model, 'predict_rel_rank'): # TODO: this broke training on ProjE_pointwise # h = h.unsqueeze(0).to(self.config.device) # t = t.unsqueeze(0).to(self.config.device) rank = self.model.predict_rel_rank(h, t, topk=topk) return rank.squeeze(0) h_batch = torch.LongTensor([h]).repeat([self.config.tot_relation]).to(self.config.device) rel_array = torch.LongTensor(list(range(self.config.tot_relation))).to(self.config.device) t_batch = torch.LongTensor([t]).repeat([self.config.tot_relation]).to(self.config.device) preds = self.model.forward(h_batch, rel_array, t_batch) _, rank = torch.topk(preds, k=topk) return rank
Example #26
Source File: evaluator.py From pykg2vec with MIT License | 5 votes |
def test_head_rank(self, r, t, topk=-1): if hasattr(self.model, 'predict_head_rank'): # TODO: this broke training on ProjE_pointwise # t = t.unsqueeze(0).to(self.config.device) # r = r.unsqueeze(0).to(self.config.device) rank = self.model.predict_head_rank(t, r, topk=topk) return rank.squeeze(0) entity_array = torch.LongTensor(list(range(self.config.tot_entity))).to(self.config.device) r_batch = torch.LongTensor([r]).repeat([self.config.tot_entity]).to(self.config.device) t_batch = torch.LongTensor([t]).repeat([self.config.tot_entity]).to(self.config.device) preds = self.model.forward(entity_array, r_batch, t_batch) _, rank = torch.topk(preds, k=topk) return rank
Example #27
Source File: evaluator.py From pykg2vec with MIT License | 5 votes |
def test_tail_rank(self, h, r, topk=-1): if hasattr(self.model, 'predict_tail_rank'): # TODO: this broke training on ProjE_pointwise # h = h.unsqueeze(0).to(self.config.device) # r = r.unsqueeze(0).to(self.config.device) rank = self.model.predict_tail_rank(h, r, topk=topk) return rank.squeeze(0) h_batch = torch.LongTensor([h]).repeat([self.config.tot_entity]).to(self.config.device) r_batch = torch.LongTensor([r]).repeat([self.config.tot_entity]).to(self.config.device) entity_array = torch.LongTensor(list(range(self.config.tot_entity))).to(self.config.device) preds = self.model.forward(h_batch, r_batch, entity_array) _, rank = torch.topk(preds, k=topk) return rank
Example #28
Source File: tensor.py From dgl with Apache License 2.0 | 5 votes |
def argtopk(input, k, dim, descending=True):
    """Return the indices of the ``k`` extreme values of ``input`` along ``dim``.

    ``descending=True`` picks the largest values, otherwise the smallest.
    """
    _, indices = th.topk(input, k, dim, largest=descending)
    return indices
Example #29
Source File: primitives.py From torchbearer with MIT License | 5 votes |
def process(self, *args):
    """Per-sample top-k hit indicator for the prediction/target pair in the state.

    ``args[0]`` is the state mapping; predictions and targets are read from
    ``self.pred_key`` and ``self.target_key``. Samples whose target equals
    ``self.ignore_index`` are dropped. Each remaining sample scores 1.0 when
    its target is among the ``self.k`` highest-scoring entries along dim 1,
    else 0.0.
    """
    state = args[0]
    predictions = state[self.pred_key]
    targets = state[self.target_key]

    keep = targets.ne(self.ignore_index)
    predictions = predictions[keep]
    targets = targets[keep]

    top_classes = torch.topk(predictions, self.k, dim=1)[1]
    hits = torch.eq(top_classes, targets.view(-1, 1)).sum(dim=1)
    return hits.float()
Example #30
Source File: memory.py From LSH_Memory with Apache License 2.0 | 5 votes |
def update(self, query, y, y_hat, y_hat_indices):
    """Update the memory after a batch of predictions.

    Every slot ages by 1. For correctly predicted examples, the used key is
    refreshed with the query and its age reset. For wrong ones, the query
    and label overwrite the oldest slots (age plus uniform noise) and those
    ages are reset.

    Args:
        query: batch of query vectors, one row per example.
        y: true labels, one per example.
        y_hat: predictions, compared element-wise with ``y.unsqueeze(1)``
            (presumably shape [batch, 1] - TODO confirm against ``predict``).
        y_hat_indices: memory slot used for each example's prediction.
    """
    # 1) Untouched: Increment memory by 1
    self.age += 1

    # Divide batch by correctness. Flatten with reshape(-1) and test with
    # numel(). The former squeeze() + ``len(size()) > 0`` check collapsed
    # single-example results to 0-d tensors, which misclassified batches
    # with exactly one (or zero) correct/incorrect examples.
    result = torch.eq(y_hat, torch.unsqueeze(y.data, dim=1)).float().reshape(-1)
    incorrect_examples = torch.nonzero(1 - result).reshape(-1)
    correct_examples = torch.nonzero(result).reshape(-1)

    incorrect = incorrect_examples.numel() > 0
    correct = correct_examples.numel() > 0

    # 2) Correct: if V[n1] = v
    # Update Key k[n1] <- normalize(q + K[n1]), Reset Age A[n1] <- 0
    if correct:
        correct_indices = y_hat_indices[correct_examples]
        correct_keys = self.keys[correct_indices]
        correct_query = query.data[correct_examples]

        new_correct_keys = F.normalize(correct_keys + correct_query, dim=1)
        self.keys[correct_indices] = new_correct_keys
        self.age[correct_indices] = 0

    # 3) Incorrect: if V[n1] != v
    # Select item with oldest age, Add random offset - n' = argmax_i(A[i]) + r_i
    # K[n'] <- q, V[n'] <- v, A[n'] <- 0
    if incorrect:
        incorrect_size = incorrect_examples.numel()
        incorrect_query = query.data[incorrect_examples]
        incorrect_values = y.data[incorrect_examples]

        # NOTE(review): random_uniform is always asked for CUDA output here.
        age_with_noise = self.age + random_uniform((self.memory_size, 1),
                                                   -self.age_noise, self.age_noise, cuda=True)
        _, topk_indices = torch.topk(age_with_noise, incorrect_size, dim=0)
        # Keep a 1-D index even when only one slot is replaced.
        oldest_indices = topk_indices.reshape(-1)

        self.keys[oldest_indices] = incorrect_query
        self.values[oldest_indices] = incorrect_values
        self.age[oldest_indices] = 0