Python numpy.argpartition() Examples
The following are 30 code examples showing how to use numpy.argpartition(). These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
You may check out the related API usage on the sidebar.
You may also want to check out all available functions/classes of the module numpy, or try the search function.
Example 1
Project: DOTA_models Author: ringringyi File: replay_buffer.py License: Apache License 2.0 | 6 votes |
def remove_n(self, n):
    """Get n items for removal.

    Chooses ``n`` buffer indices to evict according to
    ``self.eviction_strategy``:

    * ``'rand'`` -- ``n`` distinct indices sampled uniformly from
      ``[self.init_length, self.cur_size)``.
    * ``'fifo'`` -- ``n`` consecutive indices starting at
      ``self.init_length + self.remove_idx`` and wrapping inside
      ``[self.init_length, self.max_size)``; advances ``self.remove_idx``.
    * ``'rank'`` -- positions of the ``n`` smallest ``self.priorities``
      (in no particular order).

    Args:
      n: number of indices to return.

    Returns:
      A list (rand/fifo) or numpy array (rank) of ``n`` indices.

    Raises:
      AssertionError: if fewer than ``n`` removable items exist.
      ValueError: if ``self.eviction_strategy`` is not recognised.
    """
    assert self.init_length + n <= self.cur_size

    if self.eviction_strategy == 'rand':
        # random removal; `range` works on Python 2 and 3 (`xrange` is Py2-only)
        idxs = random.sample(range(self.init_length, self.cur_size), n)
    elif self.eviction_strategy == 'fifo':
        # overwrite elements in cyclical fashion, never touching the first
        # `init_length` slots
        span = self.max_size - self.init_length
        idxs = [self.init_length + (self.remove_idx + i) % span
                for i in range(n)]
        if idxs:  # with n == 0, idxs[-1] would raise IndexError
            self.remove_idx = idxs[-1] + 1 - self.init_length
    elif self.eviction_strategy == 'rank':
        # remove lowest-priority indices. kth=n-1 puts the n smallest in the
        # first n slots and, unlike kth=n, stays in bounds when
        # n == len(self.priorities).
        if n == 0:
            idxs = np.array([], dtype=np.intp)
        else:
            idxs = np.argpartition(self.priorities, n - 1)[:n]
    else:
        raise ValueError(
            'unknown eviction strategy: %r' % (self.eviction_strategy,))
    return idxs
Example 2
Project: pyscf Author: pyscf File: direct_spin1.py License: Apache License 2.0 | 6 votes |
def _get_init_guess(na, nb, nroots, hdiag): '''Initial guess is the single Slater determinant ''' # The "nroots" lowest determinats based on energy expectation value. ci0 = [] try: addrs = numpy.argpartition(hdiag, nroots-1)[:nroots] except AttributeError: addrs = numpy.argsort(hdiag)[:nroots] for addr in addrs: x = numpy.zeros((na*nb)) x[addr] = 1 ci0.append(x.ravel()) # Add noise ci0[0][0 ] += 1e-5 ci0[0][-1] -= 1e-5 return ci0
Example 3
Project: iAI Author: aimuch File: resnet_as_a_service.py License: MIT License | 6 votes |
def analyze(output_data, labels=None, top_k=5):
    """Decode one image's classifier scores into a best label and a top-k list.

    Results from the engine are returned as a list of 5D numpy arrays:
    (Number of Batches x Batch Size x C x H x W); ``output_data`` must hold
    exactly one score per label, since it is reshaped to ``len(labels)``.

    Args:
        output_data: numpy array with one score per label.
        labels: sequence of class names indexed by class id. Defaults to the
            module-level ``LABELS``.
        top_k: number of ranked predictions to return (default 5). Clamped to
            the number of labels, so small label sets do not raise.

    Returns:
        ``[top, topk_classes]``, where ``top`` is the best label and
        ``topk_classes`` is a list of ``(label, score)`` tuples sorted by
        descending score.
    """
    if labels is None:
        labels = LABELS
    output = output_data.reshape(len(labels))
    # Get result
    top = labels[np.argmax(output)]
    # Get top-k: argpartition selects the k best in O(n); only those k are
    # then sorted (descending).
    k = min(top_k, output.size)
    if k <= 0:
        return [top, []]
    topk = np.argpartition(output, -k, axis=-1)[-k:]
    topk = topk[np.argsort(output[topk])][::-1]
    topk_classes = [(labels[i], output[i]) for i in topk]
    return [top, topk_classes]

#Arguments to create lite engine
Example 4
Project: iAI Author: aimuch File: resnet_as_a_service.py License: MIT License | 6 votes |
def analyze(output_data, labels=None, top_k=5):
    """Decode one image's classifier scores into a best label and a top-k list.

    Results from the engine are returned as a list of 5D numpy arrays:
    (Number of Batches x Batch Size x C x H x W); ``output_data`` must hold
    exactly one score per label, since it is reshaped to ``len(labels)``.

    Args:
        output_data: numpy array with one score per label.
        labels: sequence of class names indexed by class id. Defaults to the
            module-level ``LABELS``.
        top_k: number of ranked predictions to return (default 5). Clamped to
            the number of labels, so small label sets do not raise.

    Returns:
        ``[top, topk_classes]``, where ``top`` is the best label and
        ``topk_classes`` is a list of ``(label, score)`` tuples sorted by
        descending score.
    """
    if labels is None:
        labels = LABELS
    output = output_data.reshape(len(labels))
    # Get result
    top = labels[np.argmax(output)]
    # Get top-k: argpartition selects the k best in O(n); only those k are
    # then sorted (descending).
    k = min(top_k, output.size)
    if k <= 0:
        return [top, []]
    topk = np.argpartition(output, -k, axis=-1)[-k:]
    topk = topk[np.argsort(output[topk])][::-1]
    topk_classes = [(labels[i], output[i]) for i in topk]
    return [top, topk_classes]

#Arguments to create lite engine
Example 5
Project: recruit Author: Frank-qlu File: test_shape_base.py License: Apache License 2.0 | 6 votes |
def test_argequivalent(self):
    """ Test it translates from arg<func> to <func> """
    from numpy.random import rand
    a = rand(3, 4, 5)

    # (value function, index function, extra kwargs): gathering `a` with the
    # indices from argfunc must reproduce func's output.
    funcs = [
        (np.sort, np.argsort, dict()),
        (_add_keepdims(np.min), _add_keepdims(np.argmin), dict()),
        (_add_keepdims(np.max), _add_keepdims(np.argmax), dict()),
        (np.partition, np.argpartition, dict(kth=2)),
    ]

    for func, argfunc, kwargs in funcs:
        # every axis, plus axis=None (flattened input)
        for axis in list(range(a.ndim)) + [None]:
            a_func = func(a, axis=axis, **kwargs)
            ai_func = argfunc(a, axis=axis, **kwargs)
            assert_equal(a_func, take_along_axis(a, ai_func, axis=axis))
Example 6
Project: OpenQA Author: thunlp File: tfidf_doc_ranker.py License: MIT License | 6 votes |
def closest_docs(self, query, k=1):
    """Closest docs by dot product between query and documents
    in tfidf weighted word vector space.

    Args:
        query: raw query, vectorised with ``self.text2spvec``.
        k: maximum number of documents to return.

    Returns:
        ``(doc_ids, doc_scores)``: ids from ``self.get_doc_id`` and their
        scores, ordered by descending score. Only entries stored in the
        sparse product ``spvec * self.doc_mat`` are candidates.
    """
    spvec = self.text2spvec(query)
    res = spvec * self.doc_mat
    neg_scores = -res.data

    if len(neg_scores) <= k:
        # few enough hits: just sort all of them
        o_sort = np.argsort(neg_scores)
    else:
        # select the k best in linear time, then sort only those k
        o = np.argpartition(neg_scores, k)[0:k]
        o_sort = o[np.argsort(neg_scores[o])]

    doc_scores = res.data[o_sort]
    doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]]
    return doc_ids, doc_scores
Example 7
Project: knmt Author: fabiencro File: beam_search.py License: GNU General Public License v3.0 | 6 votes |
def iterate_eos_scores(new_scores, eos_idx, existing_cases = None, beam_width=None)->Tuple[Sequence, Sequence, Sequence]:
    """
    Return the indices and scores corresponding to the eos word.
    Meaning of returned values is the same as for iterate_best_score

    Args:
        new_scores: 2D (cases x vocabulary) score array; the eos column is
            moved to host memory with ``cuda.to_cpu``.
        eos_idx: vocabulary index of the end-of-sequence token.
        existing_cases: optional case indices to leave out of the result.
        beam_width: if given and smaller than the number of remaining cases,
            keep only the ``beam_width`` cases with the highest eos score
            (in no particular order).

    Returns:
        ``(num_cases, idx_in_cases, scores)``: case indices, ``eos_idx``
        repeated once per case, and the *negated* eos scores.
    """
    nb_cases, _ = new_scores.shape  # new_scores must be 2D
    num_cases = np.arange(nb_cases, dtype=np.int32)
    # Negate so that "smallest" means "best" for argpartition below.
    scores = -cuda.to_cpu(new_scores[:, eos_idx])

    if existing_cases is not None:
        # Drop cases the caller already has.
        need_to_return = np.logical_not(np.isin(num_cases, existing_cases))
        num_cases = num_cases[need_to_return]
        scores = scores[need_to_return]

    idx_in_cases = np.full(num_cases.shape[0], eos_idx, dtype=np.int32)

    if beam_width is not None and beam_width < len(scores):
        # keep the beam_width smallest negated scores (unordered)
        idx_to_keep = np.argpartition(scores, beam_width)[:beam_width]
        scores = scores[idx_to_keep]
        num_cases = num_cases[idx_to_keep]
        idx_in_cases = idx_in_cases[idx_to_keep]

    return num_cases, idx_in_cases, scores
Example 8
Project: justcopy-backend Author: ailabstw File: tfidf_doc_ranker.py License: MIT License | 6 votes |
def closest_docs(self, query, k=1):
    """Closest docs by dot product between query and documents
    in tfidf weighted word vector space.

    Args:
        query: raw query, vectorised with ``self.text2spvec``.
        k: maximum number of documents to return.

    Returns:
        ``(doc_ids, doc_scores)``: ids from ``self.get_doc_id`` and their
        scores, ordered by descending score. Only entries stored in the
        sparse product ``spvec * self.doc_mat`` are candidates.
    """
    spvec = self.text2spvec(query)
    res = spvec * self.doc_mat
    neg_scores = -res.data

    if len(neg_scores) <= k:
        # few enough hits: just sort all of them
        o_sort = np.argsort(neg_scores)
    else:
        # select the k best in linear time, then sort only those k
        o = np.argpartition(neg_scores, k)[0:k]
        o_sort = o[np.argsort(neg_scores[o])]

    doc_scores = res.data[o_sort]
    doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]]
    return doc_ids, doc_scores
Example 9
Project: DeepHash Author: thulab File: __init__.py License: MIT License | 6 votes |
def get_mAPs(q_output, q_labels, db_output, db_labels, Rs, dist_type):
    """Mean average precision over the ``Rs`` nearest database items.

    A retrieved item counts as relevant when it shares at least one label
    with the query (label matrices presumably 0/1 multi-hot -- confirm).

    Args:
        q_output, db_output: query / database codes passed to ``distance``.
        q_labels, db_labels: 2D label matrices, one row per item. Not
            modified.
        Rs: number of retrieved items per query.
        dist_type: metric name forwarded to ``distance``.

    Returns:
        Mean AP over queries with at least one relevant hit (NaN if none).
    """
    dist = distance(q_output, db_output, dist_type=dist_type, pair=True)
    # For every query (row), the Rs nearest database items, unordered.
    unsorted_ids = np.argpartition(dist, Rs - 1)[:, :Rs]
    APx = []
    for i in range(dist.shape[0]):
        # Copy: q_labels[i, :] is a view, and the caller's labels must not be
        # overwritten by the 0 -> -1 remapping below.
        label = q_labels[i, :].copy()
        # 0 -> -1 so that np.equal only matches labels the query actually has
        label[label == 0] = -1
        idx = unsorted_ids[i, :]
        idx = idx[np.argsort(dist[i, :][idx])]  # nearest first
        imatch = np.sum(np.equal(db_labels[idx[0: Rs], :], label), 1) > 0
        rel = np.sum(imatch)
        Lx = np.cumsum(imatch)
        Px = Lx.astype(float) / np.arange(1, Rs + 1, 1)  # precision@j
        if rel != 0:
            APx.append(np.sum(Px * imatch) / rel)
    return np.mean(np.array(APx))
Example 10
Project: scikit-learn-extra Author: scikit-learn-contrib File: _k_medoids.py License: BSD 3-Clause "New" or "Revised" License | 6 votes |
def _initialize_medoids(self, D, n_clusters, random_state_): """Select initial mediods when beginning clustering.""" if self.init == "random": # Random initialization # Pick random k medoids as the initial ones. medoids = random_state_.choice(len(D), n_clusters) elif self.init == "k-medoids++": medoids = self._kpp_init(D, n_clusters, random_state_) elif self.init == "heuristic": # Initialization by heuristic # Pick K first data points that have the smallest sum distance # to every other point. These are the initial medoids. medoids = np.argpartition(np.sum(D, axis=1), n_clusters - 1)[ :n_clusters ] else: raise ValueError(f"init value '{self.init}' not recognized") return medoids # Copied from sklearn.cluster.k_means_._k_init
Example 11
Project: siamese-triplet Author: adambielski File: utils.py License: BSD 3-Clause "New" or "Revised" License | 6 votes |
def get_pairs(self, embeddings, labels):
    """Return all positive pairs and up to as many hardest negative pairs.

    Args:
        embeddings: torch tensor of embeddings, one row per sample; moved to
            CPU first when ``self.cpu`` is set.
        labels: torch tensor of class labels, one per sample.

    Returns:
        ``(positive_pairs, top_negative_pairs)``: LongTensors of index pairs
        (i < j). ``top_negative_pairs`` holds the negative pairs with the
        smallest ``pdist`` distance, at most ``len(positive_pairs)`` of them,
        in no particular order.
    """
    if self.cpu:
        embeddings = embeddings.cpu()
    distance_matrix = pdist(embeddings)

    labels = labels.cpu().data.numpy()
    # every unordered pair (i, j) with i < j
    all_pairs = np.array(list(combinations(range(len(labels)), 2)))
    all_pairs = torch.LongTensor(all_pairs)
    positive_pairs = all_pairs[(labels[all_pairs[:, 0]] == labels[all_pairs[:, 1]]).nonzero()]
    negative_pairs = all_pairs[(labels[all_pairs[:, 0]] != labels[all_pairs[:, 1]]).nonzero()]

    negative_distances = distance_matrix[negative_pairs[:, 0], negative_pairs[:, 1]]
    negative_distances = negative_distances.cpu().data.numpy()

    # Hard negative mining: keep the closest negatives, as many as there are
    # positives. argpartition needs kth < len, so when there are not more
    # negatives than that, keep them all.
    n_keep = min(len(positive_pairs), len(negative_distances))
    if n_keep < len(negative_distances):
        top_negatives = np.argpartition(negative_distances, n_keep)[:n_keep]
    else:
        top_negatives = np.arange(len(negative_distances))
    top_negative_pairs = negative_pairs[torch.LongTensor(top_negatives)]

    return positive_pairs, top_negative_pairs
Example 12
Project: youtube-8m Author: wangheda File: inference-sample-error-analysis.py License: Apache License 2.0 | 6 votes |
def format_lines(video_ids, predictions, labels, top_k):
    """Yield one tab-separated error-analysis line per video.

    Line layout: ``url  1-PERR  <top_k predictions>  <true labels>``, where
    each prediction/label entry is ``class_index<TAB>score`` (score formatted
    with ``%f``), sorted by descending predicted score. PERR is the fraction
    of the ``n_recall`` highest-scored classes that are labelled, with
    ``n_recall`` = number of labels (at least 1).

    Args:
        video_ids: sequence of byte-string YouTube ids (decoded as UTF-8).
        predictions: per-video score vectors, indexed [video][class].
        labels: per-video label vectors, same layout as ``predictions``;
            their sum is used as the number of positives (presumably 0/1).
        top_k: number of top predictions to print; clamped to the number of
            classes.
    """
    batch_size = len(video_ids)
    for video_index in range(batch_size):
        video_preds = predictions[video_index]
        video_labels = labels[video_index]
        num_classes = len(video_preds)
        n_recall = max(int(numpy.sum(video_labels)), 1)

        # labels
        label_indices = numpy.argpartition(video_labels, -n_recall)[-n_recall:]
        label_predictions = [(class_index, video_preds[class_index])
                             for class_index in label_indices]
        label_predictions = sorted(label_predictions, key=lambda p: -p[1])
        label_str = "\t".join(["%d\t%f" % (x, y) for x, y in label_predictions])

        # predictions; clamp so that -k is a valid kth
        k = min(top_k, num_classes)
        if k > 0:
            top_k_indices = numpy.argpartition(video_preds, -k)[-k:]
        else:
            top_k_indices = []
        top_k_predictions = [(class_index, video_preds[class_index])
                             for class_index in top_k_indices]
        top_k_predictions = sorted(top_k_predictions, key=lambda p: -p[1])
        top_k_str = "\t".join(["%d\t%f" % (x, y) for x, y in top_k_predictions])

        # compute PERR
        top_n_indices = numpy.argpartition(video_preds, -n_recall)[-n_recall:]
        positives = [video_labels[class_index] for class_index in top_n_indices]
        perr = sum(positives) / float(n_recall)

        # URL
        url = "https://www.youtube.com/watch?v=" + video_ids[video_index].decode('utf-8')
        yield url + "\t" + str(1-perr) + "\t" + top_k_str + "\t" + label_str + "\n"
Example 13
Project: Computable Author: ktraunmueller File: test_multiarray.py License: MIT License | 6 votes |
def test_partition_cdtype(self):
    """partition/argpartition agree with sort on structured and string arrays."""
    # Structured dtype, ordered by (age, height).
    d = array([('Galahad', 1.7, 38), ('Arthur', 1.8, 41),
               ('Lancelot', 1.9, 38)],
              dtype=[('name', '|S10'), ('height', '<f8'), ('age', '<i4')])

    tgt = np.sort(d, order=['age', 'height'])
    # Partitioning at every position must give the fully sorted array.
    assert_array_equal(np.partition(d, range(d.size),
                                    order=['age', 'height']),
                       tgt)
    assert_array_equal(d[np.argpartition(d, range(d.size),
                                         order=['age', 'height'])],
                       tgt)
    # A single kth must put the k-th element in its sorted position.
    for k in range(d.size):
        assert_equal(np.partition(d, k, order=['age', 'height'])[k],
                     tgt[k])
        assert_equal(d[np.argpartition(d, k, order=['age', 'height'])][k],
                     tgt[k])

    # Plain string array.
    d = array(['Galahad', 'Arthur', 'zebra', 'Lancelot'])
    tgt = np.sort(d)
    assert_array_equal(np.partition(d, range(d.size)), tgt)
    for k in range(d.size):
        assert_equal(np.partition(d, k)[k], tgt[k])
        assert_equal(d[np.argpartition(d, k)][k], tgt[k])
Example 14
Project: BrainSpace Author: MICA-MNI File: utils.py License: BSD 3-Clause "New" or "Revised" License | 6 votes |
def _dominant_set_sparse(s, k, is_thresh=False, norm=False): """Compute dominant set for a sparse matrix.""" if is_thresh: mask = s > k idx, data = np.where(mask), s[mask] s = ssp.coo_matrix((data, idx), shape=s.shape) else: # keep top k nr, nc = s.shape idx = np.argpartition(s, nc - k, axis=1) col = idx[:, -k:].ravel() # idx largest row = np.broadcast_to(np.arange(nr)[:, None], (nr, k)).ravel() data = s[row, col].ravel() s = ssp.coo_matrix((data, (row, col)), shape=s.shape) if norm: s.data /= s.sum(axis=1).A1[s.row] return s.tocsr(copy=False)
Example 15
Project: BrainSpace Author: MICA-MNI File: utils.py License: BSD 3-Clause "New" or "Revised" License | 6 votes |
def _dominant_set_dense(s, k, is_thresh=False, norm=False, copy=True): """Compute dominant set for a dense matrix.""" if is_thresh: s = s.copy() if copy else s s[s <= k] = 0 else: # keep top k nr, nc = s.shape idx = np.argpartition(s, nc - k, axis=1) row = np.arange(nr)[:, None] if copy: col = idx[:, -k:] # idx largest data = s[row, col] s = np.zeros_like(s) s[row, col] = data else: col = idx[:, :-k] # idx smallest s[row, col] = 0 if norm: s /= np.nansum(s, axis=1, keepdims=True) return s
Example 16
Project: otalign Author: dmelis File: bilind.py License: GNU General Public License v3.0 | 6 votes |
def csls_sparse(X, Y, idx_x, idx_y, knn = 10):
    """Match rows of X to rows of Y with CSLS on the given index subsets.

    For each source row ``X[idx_x[i]]``, returns the position ``j`` in
    ``idx_y`` that maximises ``2*cos(x, y) - r_T(x) - r_S(y)``. Here
    ``r_T(x)`` is the mean cosine similarity of ``x`` to its ``knn`` nearest
    rows of all of Y, and ``r_S(y)`` is the same for ``y`` against all of X.

    Args:
        X, Y: 2D arrays of source / target vectors, one per row.
        idx_x, idx_y: indices of the rows of X / Y to match.
        knn: neighbourhood size for the hubness penalty (must not exceed the
            number of candidate rows).

    Returns:
        list of ints, one per entry of ``idx_x``, indexing into ``idx_y``.
    """
    def mean_similarity_sparse(X, Y, seeds, knn, axis = 1, metric = 'cosine'):
        """Mean similarity of each seed row to its ``knn`` nearest rows in the
        other matrix (seeds index X when axis == 1, Y otherwise)."""
        if axis == 1:
            dists = sp.spatial.distance.cdist(X[seeds,:], Y, metric=metric)
        else:
            dists = sp.spatial.distance.cdist(X, Y[seeds,:], metric=metric).T
        # argpartition finds the knn nearest per row without sorting whole
        # rows; kth=knn-1 stays in bounds even when knn == number of columns.
        nghbs = np.argpartition(dists, knn - 1, axis = 1)[:, :knn]
        nghbs_dists = np.take_along_axis(dists, nghbs, axis=1)
        nghbs_sims = 1 - nghbs_dists
        return nghbs_sims.mean(axis = 1)

    src_ms = mean_similarity_sparse(X, Y, idx_x, knn, axis = 1)
    trg_ms = mean_similarity_sparse(X, Y, idx_y, knn, axis = 0)
    # Cosine similarity, consistent with the neighbourhood terms above. The
    # cdist default metric is euclidean, for which 1 - d is not a similarity
    # on the same scale.
    sims = 1 - sp.spatial.distance.cdist(X[idx_x,:], Y[idx_y,:], metric='cosine')
    normalized_sims = ((2*sims - trg_ms).T - src_ms).T
    nn = normalized_sims.argmax(axis=1).tolist()
    return nn
Example 17
Project: Mastering-Elasticsearch-7.0 Author: PacktPublishing File: test_shape_base.py License: MIT License | 6 votes |
def test_argequivalent(self):
    """ Test it translates from arg<func> to <func> """
    from numpy.random import rand
    a = rand(3, 4, 5)

    # (value function, index function, extra kwargs): gathering `a` with the
    # indices from argfunc must reproduce func's output.
    funcs = [
        (np.sort, np.argsort, dict()),
        (_add_keepdims(np.min), _add_keepdims(np.argmin), dict()),
        (_add_keepdims(np.max), _add_keepdims(np.argmax), dict()),
        (np.partition, np.argpartition, dict(kth=2)),
    ]

    for func, argfunc, kwargs in funcs:
        # every axis, plus axis=None (flattened input)
        for axis in list(range(a.ndim)) + [None]:
            a_func = func(a, axis=axis, **kwargs)
            ai_func = argfunc(a, axis=axis, **kwargs)
            assert_equal(a_func, take_along_axis(a, ai_func, axis=axis))
Example 18
Project: cryptotrader Author: naripok File: risk.py License: MIT License | 6 votes |
def polar_returns(ret, k):
    """
    Calculate polar return
    :param ret: 2D array-like of returns, one observation per row
    :param k: fraction of observations to keep; the int(n_rows * k) + 1 rows
        with the largest L1 norm are selected (capped at n_rows)
    :return: return radius, return angles -- radius is a 1D array in
        descending order; angle holds the matching rows of ret / radius
        (as an np.matrix)
    """
    # np.mat was removed in NumPy 2.0; np.asmatrix is the same function.
    ret = np.asmatrix(ret)
    # Find the radius and the angle decomposition on price relative vectors
    radius = np.linalg.norm(ret, ord=1, axis=1)
    angle = np.divide(ret, np.asmatrix(radius).T)
    # Select the 'window' greater values on the observation
    window = min(int(ret.shape[0] * k) + 1, ret.shape[0])
    index = np.argpartition(radius, -window)[-window:]
    index = index[np.argsort(radius[index])]
    # Return the radius and the angle for extreme found values
    return radius[index][::-1], angle[index][::-1]

# Pareto Extreme Risk Index
Example 19
Project: AugmentedAutoencoder Author: DLR-RM File: codebook.py License: MIT License | 6 votes |
def nearest_rotation(self, session, x, top_n=1, upright=False, return_idcs=False):
    """Find the codebook entries most cosine-similar to the encoding of ``x``.

    Args:
        session: session whose ``run`` evaluates ``self.cos_similarity`` with
            ``x`` fed to ``self._encoder.x``.
        x: image or batch of images; uint8 input is scaled by 1/255 and a
            3D input gets a leading batch axis.
        top_n: number of entries to return. With top_n == 1 the argmax is
            taken per batch row; otherwise the similarities are squeezed
            (a single query is assumed) and the top_n are returned sorted by
            decreasing similarity.
        upright: only for top_n == 1: consider every
            ``num_cyclo``-th codebook entry only.
        return_idcs: return codebook indices instead of
            ``viewsphere_for_embedding`` entries.
    """
    #R_model2cam
    if x.dtype == 'uint8':
        x = x/255.
    if x.ndim == 3:
        x = np.expand_dims(x, 0)

    cosine_similarity = session.run(self.cos_similarity, {self._encoder.x: x})
    if top_n == 1:
        if upright:
            # search a strided subset, then map back to full codebook indices
            step = int(self._dataset._kw['num_cyclo'])
            idcs = np.argmax(cosine_similarity[:, ::step], axis=1) * step
        else:
            idcs = np.argmax(cosine_similarity, axis=1)
    else:
        sims = cosine_similarity.squeeze()
        unsorted_max_idcs = np.argpartition(-sims, top_n)[:top_n]
        idcs = unsorted_max_idcs[np.argsort(-sims[unsorted_max_idcs])]

    if return_idcs:
        return idcs
    return self._dataset.viewsphere_for_embedding[idcs].squeeze()
Example 20
Project: hred-latent-piecewise Author: julianser File: search.py License: GNU General Public License v3.0 | 6 votes |
def select_next_words(self, next_costs, next_probs, step_num, how_many):
    """Pick the ``how_many`` lowest-cost entries of ``next_costs``.

    Args:
        next_costs: 2D cost array (presumably hypotheses x vocabulary).
        next_probs: not used here.
        step_num: decoding step; at step 0 only the first row is searched.
        how_many: number of entries to return; if there are fewer
            candidates, all of them are returned.

    Returns:
        ``((rows, cols), costs)``: indices into ``next_costs`` and the
        matching costs, sorted by increasing cost.
    """
    # Pick only on the first line (for the beginning of sampling)
    # This will avoid duplicate <q> token.
    # (No masking of finished utterances happens in this function.)
    if step_num == 0:
        flat_next_costs = next_costs[:1, :].flatten()
    else:
        flat_next_costs = next_costs.flatten()

    if how_many < flat_next_costs.size:
        # cheapest how_many in linear time, then sort only those
        args = numpy.argpartition(flat_next_costs, how_many)[:how_many]
    else:
        # argpartition requires kth < size: every candidate is selected
        args = numpy.arange(flat_next_costs.size)
    args = args[numpy.argsort(flat_next_costs[args])]
    # Row-0 flat indices equal full-matrix flat indices, so unravelling
    # against the full shape is valid at step 0 too.
    return numpy.unravel_index(args, next_costs.shape), flat_next_costs[args]
Example 21
Project: pyxclib Author: kunaldahiya File: misc.py License: MIT License | 6 votes |
def _update_predicted(start_idx, predicted_batch_labels, predicted_labels, top_k=10): """ Update the predicted answers for the batch Args: predicted_batch_labels predicted_labels """ def _select_topk(vec, k): batch_size = vec.shape[0] top_ind = np.argpartition(vec, -k)[:, -k:] ind = np.zeros((k*batch_size, 2), dtype=np.int) ind[:, 0] = np.repeat(np.arange(0, batch_size, 1), [k]*batch_size) ind[:, 1] = top_ind.flatten('C') return top_ind.flatten('C'), vec[ind[:, 0], ind[:, 1]] batch_size = predicted_batch_labels.shape[0] top_indices, top_vals = _select_topk(predicted_batch_labels, k=top_k) ind = np.zeros((top_k*batch_size, 2), dtype=np.int) ind[:, 0] = np.repeat( np.arange(start_idx, start_idx+batch_size, 1), [top_k]*batch_size) ind[:, 1] = top_indices predicted_labels[ind[:, 0], ind[:, 1]] = top_vals
Example 22
Project: xam Author: MaxHalford File: top_terms.py License: MIT License | 6 votes |
def fit(self, X, y=None, **fit_params):
    """Record, per class, the columns of X with the largest column sums.

    Args:
        X: array or CSR sparse matrix (samples x terms).
        y: class label per sample.
        **fit_params: ignored.

    Returns:
        self, with ``top_terms_per_class_`` mapping each label in ``y`` to
        the set of its ``min(self.n_terms, n_columns)`` top column indices.
    """
    # scikit-learn checks
    X, y = utils.check_X_y(X, y, accept_sparse='csr', order='C')

    n_terms = min(self.n_terms, X.shape[1])

    # Get a list of unique labels from y
    labels = np.unique(y)

    # Determine the n top terms per class. For sparse X the column sums
    # come back as a (1, n_features) np.matrix, whose rows cannot be put in
    # a set; flatten so dense and sparse input behave the same.
    self.top_terms_per_class_ = {}
    for c in labels:
        totals = np.asarray(np.sum(X[y == c], axis=0)).ravel()
        if n_terms > 0:
            top = np.argpartition(totals, -n_terms)[-n_terms:]
        else:
            top = []  # `[-0:]` would otherwise select every column
        self.top_terms_per_class_[c] = set(top)

    # Return the classifier
    return self
Example 23
Project: yolo_v2 Author: rky0930 File: replay_buffer.py License: Apache License 2.0 | 6 votes |
def remove_n(self, n):
    """Get n items for removal.

    Chooses ``n`` buffer indices to evict according to
    ``self.eviction_strategy``:

    * ``'rand'`` -- ``n`` distinct indices sampled uniformly from
      ``[self.init_length, self.cur_size)``.
    * ``'fifo'`` -- ``n`` consecutive indices starting at
      ``self.init_length + self.remove_idx`` and wrapping inside
      ``[self.init_length, self.max_size)``; advances ``self.remove_idx``.
    * ``'rank'`` -- positions of the ``n`` smallest ``self.priorities``
      (in no particular order).

    Args:
      n: number of indices to return.

    Returns:
      A list (rand/fifo) or numpy array (rank) of ``n`` indices.

    Raises:
      AssertionError: if fewer than ``n`` removable items exist.
      ValueError: if ``self.eviction_strategy`` is not recognised.
    """
    assert self.init_length + n <= self.cur_size

    if self.eviction_strategy == 'rand':
        # random removal; `range` works on Python 2 and 3 (`xrange` is Py2-only)
        idxs = random.sample(range(self.init_length, self.cur_size), n)
    elif self.eviction_strategy == 'fifo':
        # overwrite elements in cyclical fashion, never touching the first
        # `init_length` slots
        span = self.max_size - self.init_length
        idxs = [self.init_length + (self.remove_idx + i) % span
                for i in range(n)]
        if idxs:  # with n == 0, idxs[-1] would raise IndexError
            self.remove_idx = idxs[-1] + 1 - self.init_length
    elif self.eviction_strategy == 'rank':
        # remove lowest-priority indices. kth=n-1 puts the n smallest in the
        # first n slots and, unlike kth=n, stays in bounds when
        # n == len(self.priorities).
        if n == 0:
            idxs = np.array([], dtype=np.intp)
        else:
            idxs = np.argpartition(self.priorities, n - 1)[:n]
    else:
        raise ValueError(
            'unknown eviction strategy: %r' % (self.eviction_strategy,))
    return idxs
Example 24
Project: neural-pipeline Author: toodef File: train_config.py License: MIT License | 5 votes |
def exec(self, data_processor: TrainDataProcessor, losses: np.ndarray, indices: []) -> None:
    """Run ``self._run`` on the samples with the largest losses.

    The number of samples is ``int(losses.size * self._part)``. If that is
    zero, nothing is run.

    Args:
        data_processor: forwarded to ``self._run``.
        losses: per-sample losses, aligned with ``indices``.
        indices: dataset indices of the samples that produced ``losses``;
            the selected ones are passed to ``self.data_producer.get_loader``.
    """
    num_losses = int(losses.size * self._part)
    if num_losses <= 0:
        # `[-0:]` below would select *every* sample instead of none.
        return
    # positions of the num_losses largest losses (unordered)
    idxs = np.argpartition(losses, -num_losses)[-num_losses:]
    loader = self.data_producer.get_loader([indices[i] for i in idxs])
    self._run(loader, self.name(), data_processor)
Example 25
Project: torch-toolbox Author: PistonY File: metric.py License: BSD 3-Clause "New" or "Revised" License | 5 votes |
def update(self, preds, labels):
    """Update status.

    Counts how many samples have their true label among the ``self.topK``
    highest-scoring classes, and how many samples were seen.

    Args:
        preds (Tensor): Model outputs, one row of class scores per sample.
        labels (Tensor): True label (class index) per sample.
    """
    preds = to_numpy(preds).astype('float32')
    labels = to_numpy(labels).astype('float32')
    # class indices of the topK highest scores per sample (unordered)
    top_k = np.argpartition(preds, -self.topK)[:, -self.topK:]
    # Vectorised form of the per-sample `label in top_k_row` test; like zip,
    # only the first min(len(labels), len(top_k)) samples are compared.
    n = min(len(labels), len(top_k))
    hits = (top_k[:n] == labels[:n].reshape(n, 1)).any(axis=1)
    self.num_metric += int(hits.sum())
    self.num_inst += n
Example 26
Project: torch-toolbox Author: PistonY File: test_metric.py License: BSD 3-Clause "New" or "Revised" License | 5 votes |
def get_ture_top3(label, pred):
    """Return the fraction of samples whose label is among their top-3 classes.

    Args:
        label: sequence of true class indices, one per sample.
        pred: 2D array of class scores, one row per sample (>= 3 columns).

    Returns:
        float: number of hits divided by the number of samples compared
        (0.0 for empty input).
    """
    top3 = np.argpartition(pred, -3)[:, -3:]
    pairs = list(zip(label, top3))
    if not pairs:
        return 0.0
    num_ture_idx = sum(1 for l, p in pairs if l in p)
    # divide by the real sample count rather than a fixed batch size of 10
    return num_ture_idx / len(pairs)
Example 27
Project: pyscf Author: pyscf File: direct_uhf.py License: Apache License 2.0 | 5 votes |
def pspace(h1e, eri, norb, nelec, hdiag=None, np=400):
    """Build the Hamiltonian over a small space of low-energy determinants.

    Args:
        h1e: (alpha, beta) one-electron integral matrices.
        eri: (aa, ab, bb) two-electron integrals, expanded with
            ``ao2mo.restore(1, ..., norb)``.
        norb: number of orbitals.
        nelec: electron count(s), unpacked by ``direct_spin1._unpack_nelec``.
        hdiag: Hamiltonian diagonal; computed with ``make_hdiag`` if None.
        np: maximum number of determinants kept. (This name shadows the
            usual numpy alias; numpy is referred to as ``numpy`` here.)

    Returns:
        (addr, h0): addresses of the selected determinants (those with the
        lowest ``hdiag``, unordered) and the ``len(addr) x len(addr)``
        Hamiltonian matrix over them.
    """
    neleca, nelecb = direct_spin1._unpack_nelec(nelec)
    h1e_a = numpy.ascontiguousarray(h1e[0])
    h1e_b = numpy.ascontiguousarray(h1e[1])
    g2e_aa = ao2mo.restore(1, eri[0], norb)
    g2e_ab = ao2mo.restore(1, eri[1], norb)
    g2e_bb = ao2mo.restore(1, eri[2], norb)
    if hdiag is None:
        hdiag = make_hdiag(h1e, eri, norb, nelec)

    # Choose (up to) np determinants with the lowest diagonal energy.
    if hdiag.size < np:
        addr = numpy.arange(hdiag.size)
    else:
        try:
            addr = numpy.argpartition(hdiag, np-1)[:np]
        except AttributeError:  # numpy without argpartition
            addr = numpy.argsort(hdiag)[:np]

    # Split each determinant address into alpha / beta string addresses.
    nb = cistring.num_strings(norb, nelecb)
    addra = addr // nb
    addrb = addr % nb
    stra = cistring.addrs2str(norb, neleca, addra)
    strb = cistring.addrs2str(norb, nelecb, addrb)

    np = len(addr)
    h0 = numpy.zeros((np,np))
    libfci.FCIpspace_h0tril_uhf(h0.ctypes.data_as(ctypes.c_void_p),
                                h1e_a.ctypes.data_as(ctypes.c_void_p),
                                h1e_b.ctypes.data_as(ctypes.c_void_p),
                                g2e_aa.ctypes.data_as(ctypes.c_void_p),
                                g2e_ab.ctypes.data_as(ctypes.c_void_p),
                                g2e_bb.ctypes.data_as(ctypes.c_void_p),
                                stra.ctypes.data_as(ctypes.c_void_p),
                                strb.ctypes.data_as(ctypes.c_void_p),
                                ctypes.c_int(norb), ctypes.c_int(np))

    # Diagonal comes straight from hdiag; then complete the matrix from its
    # triangle with lib.hermi_triu.
    diag = numpy.arange(np)
    h0[diag, diag] = hdiag[addr]
    h0 = lib.hermi_triu(h0)
    return addr, h0

# be careful with single determinant initial guess. It may lead to the
# eigvalue of first davidson iter being equal to hdiag
Example 28
Project: MnemonicReader Author: HKUST-KnowComp File: model.py License: BSD 3-Clause "New" or "Revised" License | 5 votes |
def decode(score_s, score_e, top_n=1, max_len=None):
    """Take argmax of constrained score_s * score_e.

    Args:
        score_s: independent start predictions (batch x length tensor)
        score_e: independent end predictions (batch x length tensor)
        top_n: number of top scored pairs to take
        max_len: max span length to consider (defaults to the full length)

    Returns:
        (pred_s, pred_e, pred_score): per batch item, arrays of span start
        indices, end indices and scores, best first.
    """
    pred_s = []
    pred_e = []
    pred_score = []
    max_len = max_len or score_s.size(1)
    for i in range(score_s.size(0)):
        # Outer product of scores to get full p_s * p_e matrix
        scores = torch.ger(score_s[i], score_e[i])

        # Zero out negative length and over-length span scores
        scores.triu_().tril_(max_len - 1)

        # Take argmax or top n
        scores = scores.numpy()
        scores_flat = scores.flatten()
        if top_n == 1:
            idx_sort = [np.argmax(scores_flat)]
        elif len(scores_flat) <= top_n:
            # argpartition needs kth < len; with this few entries sort all
            idx_sort = np.argsort(-scores_flat)
        else:
            idx = np.argpartition(-scores_flat, top_n)[0:top_n]
            idx_sort = idx[np.argsort(-scores_flat[idx])]
        s_idx, e_idx = np.unravel_index(idx_sort, scores.shape)
        pred_s.append(s_idx)
        pred_e.append(e_idx)
        pred_score.append(scores_flat[idx_sort])
    return pred_s, pred_e, pred_score
Example 29
Project: devicehive-audio-analysis Author: devicehive File: processor.py License: Apache License 2.0 | 5 votes |
def _filter_predictions(self, predictions):
    """Return the top predictions above the hit threshold, best first.

    Looks at the ``params.PREDICTIONS_COUNT_LIMIT`` highest scores of
    ``predictions[0]`` and keeps those greater than
    ``params.PREDICTIONS_HIT_LIMIT``.

    Returns:
        list of ``(class_name, score)`` tuples sorted by descending score,
        with names taken from ``self._class_map``.
    """
    scores = predictions[0]
    # clamp so that -count is a valid kth for short score vectors
    count = min(params.PREDICTIONS_COUNT_LIMIT, len(scores))
    hit = params.PREDICTIONS_HIT_LIMIT
    if count <= 0:
        return []

    top_indices = np.argpartition(scores, -count)[-count:]
    line = ((self._class_map[i], float(scores[i]))
            for i in top_indices if scores[i] > hit)
    return sorted(line, key=lambda p: -p[1])
Example 30
Project: mabwiser Author: fidelity File: simulator.py License: Apache License 2.0 | 5 votes |
def _predict_contexts(self, contexts: np.ndarray, is_predict: bool,
                      seeds: Optional[np.ndarray] = None, start_index: Optional[int] = None) -> List:
    """Predict for each context row from its ``self.k`` nearest neighbors.

    Args:
        contexts: 2D array of contexts, one per row.
        is_predict: forwarded to ``self._get_nhood_predictions``.
        seeds: one random seed per row (``seeds[index]`` is read, so it is
            required despite the Optional annotation).
        start_index: offset of ``contexts[0]`` into ``self.distances``
            (also required despite the Optional annotation).

    Returns:
        list with one ``[prediction, exp, self.k, stats]`` entry per row.
    """
    # Copy Learning Policy object and set random state
    lp = deepcopy(self.lp)

    predictions = []
    for index, row in enumerate(contexts):
        # Get random generator
        lp.rng = create_rng(seed=seeds[index])

        # Distances to the historical contexts are precomputed in
        # self.distances; the row itself is passed on as a 1 x d array.
        row_2d = row[np.newaxis, :]
        distances_to_row = self.distances[start_index + index]

        # Find the k nearest neighbor indices (unordered)
        indices = np.argpartition(distances_to_row, self.k - 1)[:self.k]
        prediction, exp, stats = self._get_nhood_predictions(lp, row_2d, indices, is_predict)
        predictions.append([prediction, exp, self.k, stats])

    # Return the list of predictions
    return predictions