#!/usr/bin/env python # coding: utf-8 import numpy as np from tensorflow.keras.datasets.mnist import load_data from .databases import RetrievalDb class MnistRet(RetrievalDb): """ Mnist wrapper. Refs: Arguments: Returns: Instance of MnistRet to get images and labels from train/val/test sets for DML tasks. """ def __init__(self, **kwargs): super(MnistRet, self).__init__(name=None, queries_in_collection=True) self.name = "RET_MNIST" (x_train, y_train), (x_test, y_test) = load_data() idx_train = np.where(y_train < 5)[0] idx_test = np.where(y_test < 5)[0] self.train_images = np.concatenate([x_train[idx_train], x_test[idx_test]], axis=0) self.train_labels = np.concatenate([y_train[idx_train], y_test[idx_test]], axis=0) idx_train = np.where(y_train >= 5)[0] idx_test = np.where(y_test >= 5)[0] self.test_images = np.concatenate([x_train[idx_train], x_test[idx_test]], axis=0) self.test_labels = np.concatenate([y_train[idx_train], y_test[idx_test]], axis=0) self.train_images = self.train_images[..., None] self.test_images = self.test_images[..., None] def get_training_set(self, **kwargs): return self.train_images, self.train_labels def get_validation_set(self, **kwargs): super(MnistRet).get_validation_set(**kwargs) def get_testing_set(self, **kwargs): return self.test_images, self.test_labels @staticmethod def get_usual_retrieval_rank(): return [1, 2, 10] def get_queries_idx(self, db_set): """ Get the set of query images from which metrics are evaluated. :param db_set: string containing either 'train', 'training', 'validation', 'val', 'testing' or 'test'. :return: a nd-array of query indexes. """ if db_set.lower() == 'train' or db_set.lower() == 'training': return np.arange(len(self.train_images), dtype=np.int32) elif db_set.lower() == 'validation' or db_set.lower() == 'val': raise ValueError('There is no validation set for {}.'.format(self.name)) elif db_set.lower() == 'testing' or db_set.lower() == 'test': return np.arange(len(self.test_images), dtype=np.int32) else: raise ValueError("'db_set' unrecognized." "Expected 'train', 'training', 'validation', 'val', 'testing', 'test'" "Got {}".format(db_set)) def get_collection_idx(self, db_set): """ Get the set of collection images for retrieval tasks. :param db_set: string containing either 'train', 'training', 'validation', 'val', 'testing' or 'test'. :return: a nd-array of the collection indexes. """ if db_set.lower() == 'train' or db_set.lower() == 'training': return np.arange(len(self.train_images), dtype=np.int32) elif db_set.lower() == 'validation' or db_set.lower() == 'val': raise ValueError('There is no validation set for {}.'.format(self.name)) elif db_set.lower() == 'testing' or db_set.lower() == 'test': return np.arange(len(self.test_images), dtype=np.int32) else: raise ValueError("'db_set' unrecognized." "Expected 'train', 'training', 'validation', 'val', 'testing', 'test'" "Got {}".format(db_set))