#!/usr/bin/env python
# coding: utf-8

from abc import ABC, abstractmethod
from typing import Optional

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K

_MAX_LIM = 250000000  # 1 GB maximum matrix size allowed (assuming float32).


class GlobalMetric(ABC):
    """
    Global metric abstract class (all implemented metrics must inherit from it).
    """

    def __init__(self, name):
        self.name = name

    @abstractmethod
    def compute_metric(self, predictions, labels):
        raise NotImplementedError("This function is currently not implemented.")

    @staticmethod
    def check_disrepancies(predictions: np.ndarray, labels: np.ndarray):
        """
        Validate that 'predictions' and 'labels' are mutually consistent.

        Arguments:
            predictions: 2D ndarray, one row per sample.
            labels: 1D ndarray of labels, or 2D ndarray (0/1 matrix of labels).

        Raises:
            AssertionError: if the number of samples differs.
            ValueError: if 'labels' is not 1D/2D or 'predictions' is not 2D.
        """
        # Number of samples. Raised explicitly (rather than through `assert`) so that the check
        # survives `python -O` while keeping the same exception type for callers.
        if len(predictions) != len(labels):
            raise AssertionError(
                "'predictions' and 'labels' got a different number of samples. Got "
                "(predictions){}=\\={}(labels).".format(len(predictions), len(labels)))

        # Labels and predictions formats:
        if labels.ndim not in (1, 2):
            raise ValueError("Could not understand 'labels' format. "
                             "Expect a 1D or 2D ndarray, but got an array with shape {}.".format(np.shape(labels)))

        if predictions.ndim != 2:
            raise ValueError("Could not understand 'predictions' format. "
                             "Expect a 2D ndarray, but got an array with shape {}.".format(np.shape(predictions)))


class KerasRecallAtK(GlobalMetric):
    """
    Recall@K class as a global metrics.

    Arguments:
        ind_queries: indexes of all queries
        ind_collection: indexes of all images from the collection (None means every sample)
        k_list: list of values of K (defaults to [1])
        similarity_measure: similarity measure. Should be 'cosine' or 'l2'
        queries_in_collection: set True if the queries are also in the collection for the search

    Returns:
        Instance of GlobalMetric to use in a GlobalMetricCallback.
    """

    def __init__(self,
                 ind_queries,
                 ind_collection=None,
                 k_list: Optional[list] = None,
                 similarity_measure='cosine',
                 queries_in_collection=True):
        super().__init__('Recall@')

        # Avoid a shared mutable default: build a fresh list per instance.
        k_list = [1] if k_list is None else list(k_list)
        if len(k_list) == 0:
            raise ValueError('No value of K given.')
        self.k_list = k_list
        self.k_max = max(k_list)

        # `len` (not truthiness) because 'ind_queries' may be a numpy array.
        if len(ind_queries) == 0:
            raise ValueError('No query indexes given.')
        self.ind_queries = ind_queries
        self.ind_collection = ind_collection
        self.similarity_measure = similarity_measure

        # When the queries belong to the collection, the best-ranked item is skipped.
        offset = 1 if queries_in_collection else 0

        # Ranking computation graph:
        if similarity_measure == 'cosine':
            build_graph = _build_tf_cosine_similarity
        elif similarity_measure == 'l2':
            build_graph = _build_tf_l2_similarity
        else:
            raise NotImplementedError(
                "Unknown similarity measure {!r}. Should be 'cosine' or 'l2'.".format(similarity_measure))

        (self.all_representations, self.input_labels, self.batch_representations,
         self.batch_labels, self.ranking) = build_graph(max_rank=0, offset=offset)

        # 1. where the label of the ranked item equals the label of the query, 0. elsewhere.
        self.bin_ranking = K.cast(K.equal(self.ranking, self.batch_labels), K.floatx())

    def compute_metric(self, predictions: np.ndarray, labels: np.ndarray):
        """
        Compute the Recall@K for every K of 'k_list'.

        Arguments:
            predictions: 2D ndarray of representations, one row per sample.
            labels: 1D ndarray of labels, or 2D 0/1 matrix of labels (reduced with argmax).

        Returns:
            float32 ndarray of shape (len(k_list), 2): column 0 holds K, column 1 the recall@K in percent.
        """
        self.check_disrepancies(predictions, labels)
        if labels.ndim == 2:
            # The graph expects one scalar label per sample.
            labels = np.argmax(labels, axis=1)

        # Kept local on purpose: storing the default on `self` would freeze the collection size
        # of the first call and break later calls made with a different number of samples.
        if self.ind_collection is None:
            ind_collection = np.arange(len(predictions), dtype=np.int32)
            collection = predictions.transpose()
        else:
            ind_collection = self.ind_collection
            collection = predictions[ind_collection, :].transpose()

        # We can compute recall@K
        num_queries = len(self.ind_queries)
        batch = int(np.ceil(_MAX_LIM / float(predictions.shape[0])))
        print('Computing {} steps.'.format(np.ceil(len(self.ind_queries) / batch)))

        retrieved = np.zeros((len(self.k_list), 2), dtype=np.float32)
        retrieved[:, 0] = self.k_list

        sess = tf.get_default_session()
        owns_session = sess is None
        if owns_session:
            sess = tf.Session()

        try:
            for start in range(0, num_queries, batch):
                batch_queries = self.ind_queries[start:start + batch]
                rnk = sess.run(self.bin_ranking,
                               feed_dict={self.all_representations: collection,
                                          self.input_labels: labels[ind_collection],
                                          self.batch_representations: predictions[batch_queries, :],
                                          self.batch_labels: labels[batch_queries, None]})

                # A query counts as retrieved at K when any of its K first neighbours shares its label.
                for i, k in enumerate(self.k_list):
                    retrieved[i, 1] += np.sum(np.float32(np.max(rnk[:, 0:k], axis=1)))
        finally:
            # Only release the session created here; a caller-provided one is left untouched.
            if owns_session:
                sess.close()

        retrieved[:, 1] = (retrieved[:, 1] * 100) / float(num_queries)

        return retrieved


def _build_tf_cosine_similarity(max_rank=0, offset=1, eps=1e-12):
    """
    Build the graph ranking collection labels by decreasing cosine similarity.

    Arguments:
        max_rank: if > 0, only the 'max_rank' + 'offset' best items are ranked (top_k); else full sort.
        offset: number of best-ranked items dropped from the ranking.
        eps: value added to the norms to avoid dividing by zero.

    Returns:
        Tuple (collection placeholder, collection labels placeholder, query batch placeholder,
        query labels placeholder, tensor of ranked collection labels).
    """
    # We build the graph (See utils.generic_utils.tf_recall_at_k for original implementation):
    tf_db = K.placeholder(ndim=2, dtype=K.floatx())  # Where to find
    tf_labels = K.placeholder(ndim=1, dtype=K.floatx())  # and their labels
    tf_batch_query = K.placeholder(ndim=2, dtype=K.floatx())  # Used in case of memory issues
    batch_labels = K.placeholder(ndim=2, dtype=K.floatx())  # and their labels

    all_representations_T = K.expand_dims(tf_db, axis=0)  # 1 x D x N
    batch_representations = K.expand_dims(tf_batch_query, axis=0)  # 1 x n x D
    sim = K.batch_dot(batch_representations, all_representations_T)  # 1 x n x N
    sim = K.squeeze(sim, axis=0)  # n x N
    sim /= tf.linalg.norm(tf_batch_query, axis=1, keepdims=True) + eps
    sim /= tf.linalg.norm(tf_db, axis=0, keepdims=True) + eps

    if max_rank > 0:  # computing r@K or mAP@K
        index_ranking = tf.nn.top_k(sim, k=max_rank + offset).indices
    else:
        index_ranking = tf.contrib.framework.argsort(sim, axis=-1, direction='DESCENDING', stable=True)

    top_k = index_ranking[:, offset:]
    tf_ranking = tf.gather(tf_labels, top_k)

    return tf_db, tf_labels, tf_batch_query, batch_labels, tf_ranking


def _build_tf_l2_similarity(max_rank=0, offset=1):
    """
    Build the graph ranking collection labels by increasing squared L2 distance.

    Arguments:
        max_rank: if > 0, only the 'max_rank' + 'offset' closest items are ranked (top_k); else full sort.
        offset: number of closest items dropped from the ranking.

    Returns:
        Tuple (collection placeholder, collection labels placeholder, query batch placeholder,
        query labels placeholder, tensor of ranked collection labels).
    """
    # We build the graph (See utils.generic_utils.tf_recall_at_k for original implementation):
    tf_db = K.placeholder(ndim=2, dtype=K.floatx())  # Where to find
    tf_labels = K.placeholder(ndim=1, dtype=K.floatx())  # and their labels
    tf_batch_query = K.placeholder(ndim=2, dtype=K.floatx())  # Used in case of memory issues
    batch_labels = K.placeholder(ndim=2, dtype=K.floatx())  # and their labels

    all_representations_T = K.expand_dims(tf_db, axis=0)  # 1 x D x N
    batch_representations = K.expand_dims(tf_batch_query, axis=0)  # 1 x n x D
    # ||q - c||^2 = ||q||^2 - 2 q.c + ||c||^2
    dist = -2. * K.batch_dot(batch_representations, all_representations_T)  # 1 x n x N
    dist = K.squeeze(dist, axis=0)  # n x N
    dist += K.sum(tf_batch_query * tf_batch_query, axis=1, keepdims=True)
    dist += K.sum(tf_db * tf_db, axis=0, keepdims=True)

    if max_rank > 0:  # computing r@K or mAP@K
        # top_k finds the k greatest entries and we want the lowest, hence the negation.
        # Once negated, the smallest distances come first in the ranking.
        dist = -dist
        index_ranking = tf.nn.top_k(dist, k=max_rank + offset).indices
    else:
        index_ranking = tf.contrib.framework.argsort(dist, axis=-1, direction='ASCENDING', stable=True)

    index_ranking = index_ranking[:, offset:]
    tf_ranking = tf.gather(tf_labels, index_ranking)

    return tf_db, tf_labels, tf_batch_query, batch_labels, tf_ranking