from .ann import NearestNeighbor, HNSW, HNSWM
import logging
from sklearn.preprocessing import normalize
import _pickle as pickle
from .clustering import Cluster
import numpy as np
from collections import OrderedDict
import numba as nb
import time
from ..utils.dense import compute_centroid
import os
import math


@nb.njit(cache=True)
def bin_index(array, item):
    """Binary search: return the index of `item` in sorted `array`, or -1."""
    first, last = 0, len(array) - 1
    while first <= last:
        mid = (first + last) // 2
        if array[mid] == item:
            return mid
        if item < array[mid]:
            last = mid - 1
        else:
            first = mid + 1
    return -1


# NOTE: this function was defined twice (byte-identical); the duplicate
# definition has been removed.
@nb.njit(cache=True)
def safe_normalize(array):
    """Scale `array` by its maximum value; return it unchanged when the
    maximum is 0 (avoids division by zero)."""
    _max = np.max(array)
    if _max != 0:
        return array / _max
    else:
        return array


@nb.njit(nb.types.Tuple(
    (nb.int64[:], nb.float32[:]))(nb.int64[:, :], nb.float32[:], nb.int64))
def map_one(indices_labels, similarity, pad_ind):
    """Collapse one point's neighbour labels into unique labels.

    Each neighbour contributes its similarity to every one of its labels;
    the accumulated per-label scores are then max-normalized.

    Arguments
    ---------
    indices_labels: 2-d array; labels of each neighbour, padded with pad_ind
    similarity: similarity of each neighbour to the point
    pad_ind: padding value to ignore in indices_labels

    Returns
    -------
    (unique labels, normalized per-label similarity) — aligned arrays
    """
    unique_point_labels = np.unique(indices_labels)
    unique_point_labels = unique_point_labels[unique_point_labels != pad_ind]
    point_label_similarity = np.zeros(
        (len(unique_point_labels), ), dtype=np.float32)
    for j in range(len(indices_labels)):
        for lbl in indices_labels[j]:
            if lbl != pad_ind:
                # unique_point_labels is sorted (np.unique), so binary
                # search locates the accumulator slot for this label
                _ind = bin_index(unique_point_labels, lbl)
                point_label_similarity[_ind] += similarity[j]
    point_label_similarity = safe_normalize(point_label_similarity)
    return unique_point_labels, point_label_similarity


@nb.njit(nb.types.Tuple(
    (nb.int64[:, :], nb.float32[:, :]))(nb.int64[:, :], nb.float32[:, :],
                                        nb.int64[:, :], nb.int64, nb.int64,
                                        nb.float32), parallel=True)
def map_neighbors(indices, similarity, labels, top_k, pad_ind, pad_val):
    """Map each point's nearest-neighbour instances to its top_k labels.

    Rows with fewer than top_k unique labels are padded with
    (pad_ind, pad_val). Runs in parallel over points.
    """
    m = indices.shape[0]
    point_labels = np.full(
        (m, top_k), pad_ind, dtype=np.int64)
    point_label_sims = np.full(
        (m, top_k), pad_val, dtype=np.float32)
    for i in nb.prange(m):
        unique_point_labels, point_label_sim = map_one(
            labels[indices[i]], similarity[i], pad_ind)
        if top_k < len(unique_point_labels):
            # keep the top_k highest-scoring labels, best first
            top_indices = np.argsort(
                point_label_sim)[-1 * top_k:][::-1]
            point_labels[i] = unique_point_labels[top_indices]
            point_label_sims[i] = point_label_sim[top_indices]
        else:
            point_labels[i, :len(unique_point_labels)] = unique_point_labels
            point_label_sims[i, :len(unique_point_labels)] = point_label_sim
    return point_labels, point_label_sims


@nb.njit(cache=True)
def _remap_centroid_one(indices, sims, mapping):
    """Map one point's centroid indices back to label indices.

    Multiple centroids may map to the same label (super-head labels);
    the maximum similarity among a label's centroids is retained.
    """
    mapped_indices = mapping[indices]
    unique_mapped_indices = np.unique(mapped_indices)
    unique_mapped_sims = np.zeros(
        (len(unique_mapped_indices), ), dtype=np.float32)
    for i in range(len(unique_mapped_indices)):
        ind = unique_mapped_indices[i]
        unique_mapped_sims[i] = np.max(sims[mapped_indices == ind])
    return unique_mapped_indices, unique_mapped_sims


# FIX: parallel=True was missing; without it nb.prange silently degrades
# to a serial range. Per-row writes are independent, so results are
# unchanged.
@nb.njit(parallel=True)
def map_centroids(indices, sims, mapping, pad_ind, pad_val):
    """Remap centroid indices to label indices for a batch of points;
    rows are padded with (pad_ind, pad_val)."""
    mapped_indices = np.full(
        indices.shape, fill_value=pad_ind, dtype=np.int64)
    mapped_sims = np.full(
        indices.shape, fill_value=pad_val, dtype=np.float32)
    for i in nb.prange(indices.shape[0]):
        _ind, _sim = _remap_centroid_one(indices[i], sims[i], mapping)
        mapped_indices[i, :len(_ind)] = _ind
        mapped_sims[i, :len(_sim)] = _sim
    return mapped_indices, mapped_sims


def construct_shortlist(method, num_neighbours, M, efC, efS,
                        order='centroids', space='cosine', num_threads=-1,
                        num_clusters=1, threshold_freq=10000, verbose=True):
    """Factory: build a label shortlist indexed over label centroids
    ('centroids') or over training instances ('instances')."""
    if order == 'centroids':
        return ShortlistCentroids(
            method, num_neighbours, M, efC, efS, num_threads,
            space, verbose, num_clusters, threshold_freq)
    elif order == 'instances':
        return ShortlistInstances(
            method, num_neighbours, M, efC, efS, num_threads,
            space, verbose)
    else:
        raise NotImplementedError("Unknown order")


class Shortlist(object):
    """Get nearest neighbors using brute or HNSW algorithm

    Parameters
    ----------
    method: str
        brute or hnsw
    num_neighbours: int
        number of neighbors
    M: int
        HNSW M (Usually 100)
    efC: int
        construction parameter (Usually 300)
    efS: int
        search parameter (Usually 300)
    num_threads: int, optional, default=24
        use multiple threads to cluster
    """

    def __init__(self, method, num_neighbours, M, efC, efS, num_threads=24):
        self.method = method
        self.num_neighbours = num_neighbours
        self.M = M
        self.efC = efC
        self.efS = efS
        self.num_threads = num_threads
        self.index = None
        self._construct()

    def _construct(self):
        """Instantiate the underlying ANN index.

        Raises
        ------
        ValueError
            for an unknown method (previously only printed a warning and
            left self.index = None, deferring the failure to first use).
        """
        if self.method == 'brute':
            self.index = NearestNeighbor(
                num_neighbours=self.num_neighbours,
                method='brute',
                num_threads=self.num_threads
            )
        elif self.method == 'hnsw':
            self.index = HNSW(
                M=self.M,
                efC=self.efC,
                efS=self.efS,
                num_neighbours=self.num_neighbours,
                num_threads=self.num_threads
            )
        else:
            raise ValueError("Unknown NN method!")

    def fit(self, data):
        """Train the ANN index on the given dense representations."""
        self.index.fit(data)

    def query(self, data, *args, **kwargs):
        """Return (indices, similarities) of nearest neighbors.

        The index returns distances; similarity = 1 - distance.
        """
        indices, distances = self.index.predict(data, *args, **kwargs)
        return indices, 1 - distances

    def save(self, fname):
        self.index.save(fname)

    def load(self, fname):
        self.index.load(fname)

    def reset(self):
        # TODO Do we need to delete it!
        del self.index
        self._construct()

    @property
    def model_size(self):
        # size on disk (MB); see if there is a better solution
        import tempfile
        with tempfile.NamedTemporaryFile() as tmp:
            self.index.save(tmp.name)
            _size = os.path.getsize(tmp.name) / math.pow(2, 20)
        return _size

    def __repr__(self):
        # FIX: efC/efS were passed in swapped order relative to the labels
        return "efC: {}, efS: {}, M: {}, num_nbrs: {}, num_threads: {}".format(
            self.efC, self.efS, self.M, self.num_neighbours,
            self.num_threads)


class ShortlistCentroids(Shortlist):
    """Get nearest labels using KCentroids
    * centroid(l) = mean_{i=1}^{N}{x_i*y_il}
    * brute or HNSW algorithm for search

    Parameters
    ----------
    method: str, optional, default='hnsw'
        brute or hnsw
    num_neighbours: int
        number of neighbors (same as efS)
        * may be useful if the NN search retrieve less number of labels
        * typically doesn't happen with HNSW etc.
    M: int, optional, default=100
        HNSW M (Usually 100)
    efC: int, optional, default=300
        construction parameter (Usually 300)
    efS: int, optional, default=300
        search parameter (Usually 300)
    num_threads: int, optional, default=24
        use multiple threads to cluster
    space: str, optional, default='cosine'
        metric to use while querying
    verbose: boolean, optional, default=True
        print progress
    num_clusters: int, optional, default=1
        cluster instances => multiple representatives for chosen labels
    threshold: int, optional, default=7500
        cluster instances if a label appear in more than 'threshold'
        training points
    """

    def __init__(self, method='hnsw', num_neighbours=300, M=100, efC=300,
                 efS=300, num_threads=24, space='cosine', verbose=True,
                 num_clusters=1, threshold=7500, pad_val=-10000):
        super().__init__(method, num_neighbours, M, efC, efS, num_threads)
        self.num_clusters = num_clusters
        self.space = space
        self.pad_ind = -1          # padding index (set to #labels in fit)
        self.mapping = None        # centroid index -> label index
        self.ext_head = None       # indices of super-head labels
        self.threshold = threshold
        self.pad_val = pad_val

    def _cluster_multiple_rep(self, features, labels, label_centroids,
                              multi_centroid_indices):
        """Cluster instances of super-head labels and append the resulting
        cluster centers as extra representatives."""
        embedding_dims = features.shape[1]
        _cluster_obj = Cluster(
            indices=multi_centroid_indices,
            embedding_dims=embedding_dims,
            num_clusters=self.num_clusters,
            max_iter=50, n_init=2, num_threads=-1)
        _cluster_obj.fit(features, labels)
        label_centroids = np.vstack(
            [label_centroids, _cluster_obj.predict()])
        return label_centroids

    def process_multiple_rep(self, features, labels, label_centroids):
        """Create multiple representatives for frequent (super-head) labels.

        Builds self.mapping so extra centroids can be remapped back to
        their label at query time. No-op when no label crosses the
        frequency threshold or num_clusters == 1.
        """
        freq = np.array(labels.sum(axis=0)).ravel()
        # NOTE(review): guard uses '>' but selection uses '>=' — a label
        # with freq exactly equal to threshold is only clustered when some
        # other label strictly exceeds it; confirm intended.
        if np.max(freq) > self.threshold and self.num_clusters > 1:
            self.ext_head = np.where(freq >= self.threshold)[0]
            print("Found {} super-head labels".format(len(self.ext_head)))
            self.mapping = np.arange(label_centroids.shape[0])
            for idx in self.ext_head:
                self.mapping = np.append(
                    self.mapping, [idx] * self.num_clusters)
            return self._cluster_multiple_rep(
                features, labels, label_centroids, self.ext_head)
        else:
            return label_centroids

    def fit(self, features, labels, *args, **kwargs):
        """Compute label centroids (plus extra representatives for
        super-head labels) and index them for NN search."""
        self.pad_ind = labels.shape[1]
        label_centroids = compute_centroid(features, labels, reduction='mean')
        label_centroids = self.process_multiple_rep(
            features, labels, label_centroids)
        # FIX: removed an unused 'norms' computation that had no effect
        super().fit(label_centroids)

    def query(self, data, *args, **kwargs):
        """Return (label indices, similarities); centroid indices are
        remapped to labels when multiple representatives exist."""
        indices, sim = super().query(data, *args, **kwargs)
        return self._remap(indices, sim)

    def _remap(self, indices, sims):
        # identity mapping => nothing to collapse
        if self.mapping is None:
            return indices, sims
        return map_centroids(
            indices, sims, self.mapping, self.pad_ind, self.pad_val)

    def load(self, fname):
        temp = pickle.load(open(fname + ".metadata", 'rb'))
        self.pad_ind = temp['pad_ind']
        self.pad_val = temp['pad_val']
        self.mapping = temp['mapping']
        self.ext_head = temp['ext_head']
        super().load(fname + ".index")

    def save(self, fname):
        metadata = {
            'pad_ind': self.pad_ind,
            'pad_val': self.pad_val,
            'mapping': self.mapping,
            'ext_head': self.ext_head,
        }
        pickle.dump(metadata, open(fname + ".metadata", 'wb'))
        super().save(fname + ".index")

    def purge(self, fname):
        # purge files from disk
        if os.path.isfile(fname + ".index"):
            os.remove(fname + ".index")
        if os.path.isfile(fname + ".metadata"):
            os.remove(fname + ".metadata")

    def __repr__(self):
        s = "efC: {efC}, efS: {efS}, M: {M}, num_nbrs: {num_neighbours}" \
            ", pad_ind: {pad_ind}, num_threads: {num_threads}" \
            ", pad_val: {pad_val}, threshold: {threshold}" \
            ", num_clusters: {num_clusters}"
        return s.format(**self.__dict__)


class ShortlistInstances(Shortlist):
    """Get nearest labels using KNN
    * brute or HNSW algorithm for search

    Parameters
    ----------
    method: str, optional, default='hnsw'
        brute or hnsw
    num_neighbours: int
        number of labels to keep per data point
        * labels may be shared across fetched instances
        * union of labels can be large when dataset is densely tagged
    M: int, optional, default=100
        HNSW M (Usually 100)
    efC: int, optional, default=300
        construction parameter (Usually 300)
    efS: int, optional, default=300
        search parameter (Usually 300)
    num_threads: int, optional, default=24
        use multiple threads to cluster
    space: str, optional, default='cosine'
        metric to use while querying
    verbose: boolean, optional, default=False
        print progress
    pad_val: int, optional, default=-10000
        value for padding indices
        - Useful as all documents may have different number of nearest
          labels after collapsing them
    """

    def __init__(self, method='hnsw', num_neighbours=300, M=100, efC=300,
                 efS=300, num_threads=24, space='cosine', verbose=False,
                 pad_val=-10000):
        super().__init__(method, num_neighbours, M, efC, efS, num_threads)
        self.labels = None         # per-instance label lists (padded array)
        self.space = space
        self.pad_ind = None        # padding index (set to #labels in fit)
        self.pad_val = pad_val

    def _remove_invalid(self, features, labels):
        # Keep data points with nnz features and at least one label
        ind_ft = np.where(np.sum(np.square(features), axis=1) > 0)[0]
        ind_lb = np.where(np.sum(labels, axis=1) > 0)[0]
        ind = np.intersect1d(ind_ft, ind_lb)
        return features[ind], labels[ind]

    def _as_array(self, labels):
        """Convert ragged per-instance label lists to a rectangular int64
        array padded with self.pad_ind."""
        n_pos_labels = list(map(len, labels))
        _labels = np.full(
            (len(labels), max(n_pos_labels)), self.pad_ind, np.int64)
        for ind, _lab in enumerate(labels):
            _labels[ind, :n_pos_labels[ind]] = labels[ind]
        return _labels

    def _remap(self, indices, distances):
        # collapse neighbour instances into labels; similarity = 1 - distance
        return map_neighbors(
            indices, 1 - distances,
            self.labels, self.num_neighbours,
            self.pad_ind, self.pad_val)

    def fit(self, features, labels):
        """Index valid training instances and store their labels for
        query-time remapping."""
        features, labels = self._remove_invalid(features, labels)
        self.index.fit(features)
        self.pad_ind = labels.shape[1]
        self.labels = self._as_array(labels.tolil().rows)

    def query(self, data, *args, **kwargs):
        """Return (label indices, similarities) for each query point.

        FIX: *args/**kwargs are now forwarded to the underlying index
        (they were previously accepted but silently dropped), matching
        the base-class behavior.
        """
        indices, distances = self.index.predict(data, *args, **kwargs)
        indices, similarities = self._remap(indices, distances)
        return indices, similarities

    def save(self, fname):
        self.index.save(fname + ".index")
        pickle.dump(
            {'labels': self.labels,
             'pad_ind': self.pad_ind,
             'pad_val': self.pad_val,
             'num_neighbours': self.num_neighbours,
             'space': self.space},
            open(fname + ".metadata", 'wb'))

    def load(self, fname):
        self.index.load(fname + ".index")
        obj = pickle.load(
            open(fname + ".metadata", 'rb'))
        self.num_neighbours = obj['num_neighbours']
        self.space = obj['space']
        self.labels = obj['labels']
        self.pad_ind = obj['pad_ind']
        self.pad_val = obj['pad_val']

    def purge(self, fname):
        # purge files from disk
        if os.path.isfile(fname + ".index"):
            os.remove(fname + ".index")
        if os.path.isfile(fname + ".metadata"):
            os.remove(fname + ".metadata")

    def __repr__(self):
        s = "efC: {efC}, efS: {efS}, M: {M}, num_nbrs: {num_neighbours}" \
            ", pad_ind: {pad_ind}, num_threads: {num_threads}" \
            ", pad_val: {pad_val}"
        return s.format(**self.__dict__)