# Copyright 2019 Karsten Roth and Biagio Brattoli
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


################## LIBRARIES ##############################
import warnings
warnings.filterwarnings("ignore")

import numpy as np, os, sys, pandas as pd, csv, random, datetime

import torch, torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

import pickle as pkl
from sklearn import metrics
from sklearn import cluster

import faiss

import losses as losses


# =============================================================================================================
################# ACQUIRE NUMBER OF WEIGHTS #################
def gimme_params(model):
    """
    Provide number of trainable parameters (i.e. those requiring gradient computation) for input network.

    Args:
        model: PyTorch Network
    Returns:
        int, number of parameters.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


################# SAVE TRAINING PARAMETERS IN NICE STRING #################
def gimme_save_string(opt):
    """
    Convert the set of training parameters into an easy-to-read string which can be stored later.

    Every parameter name is followed by its value on an indented new line; dictionary-valued
    parameters are expanded into one indented "key: value" line per entry.

    Args:
        opt: argparse.Namespace, contains all training-specific parameters.
    Returns:
        string, returns string summary of parameters.
    """
    parts = []
    for key, value in vars(opt).items():
        parts.append(str(key))
        if isinstance(value, dict):
            parts.extend('\n\t' + str(sub_key) + ': ' + str(sub_item) for sub_key, sub_item in value.items())
        else:
            parts.append('\n\t' + str(value))
        parts.append('\n\n')
    return ''.join(parts)


def f1_score(model_generated_cluster_labels, target_labels, feature_coll, computed_centroids):
    """
    Compute the pairwise clustering F1-score between predicted clusters and ground truth classes.

    NOTE: MOSTLY ADAPTED FROM https://github.com/wzzheng/HDML on Hardness-Aware Deep Metric Learning.

    Args:
        model_generated_cluster_labels: np.ndarray [n_samples x 1], Cluster labels computed on top of data embeddings.
        target_labels:                  np.ndarray [n_samples x 1], ground truth labels for each data sample.
        feature_coll:                   np.ndarray [n_samples x embed_dim], total data embedding made by network.
        computed_centroids:             np.ndarray [num_cluster=num_classes x embed_dim], cluster coordinates
    Returns:
        float, F1-score. 0.0 is returned if there is not a single correctly grouped pair
        (the score would otherwise be an undefined 0/0).
    """
    from scipy.special import comb

    cluster_labels = np.asarray(model_generated_cluster_labels).reshape(-1)
    true_labels = np.asarray(target_labels).reshape(-1)

    # Distance of every sample to the centroid of the cluster it was assigned to.
    dist_to_centroid = np.linalg.norm(feature_coll - computed_centroids[cluster_labels], axis=1)

    # Every cluster is identified by the index of the member closest to its centroid.
    labels_pred = np.zeros(len(feature_coll))
    for cluster_id in np.unique(cluster_labels):
        members = np.where(cluster_labels == cluster_id)[0]
        labels_pred[members] = members[np.argmin(dist_to_centroid[members])]

    # Group sizes for ground truth classes, predicted clusters and their intersections.
    _, class_idx, class_counts = np.unique(true_labels, return_inverse=True, return_counts=True)
    _, item_idx, item_counts = np.unique(labels_pred, return_inverse=True, return_counts=True)
    # A joint (class, cluster) key avoids building a dense n_classes x n_clusters table.
    joint_key = class_idx.reshape(-1).astype(np.int64) * len(item_counts) + item_idx.reshape(-1)
    _, joint_counts = np.unique(joint_key, return_counts=True)

    # comb(n, 2) is zero for n < 2, so groups with a single member contribute nothing.
    # Number of sample pairs sharing a ground truth class (TP + FP in the reference code).
    tp_fp = comb(class_counts, 2).sum()
    # Number of sample pairs sharing both ground truth class and predicted cluster (TP).
    tp = comb(joint_counts, 2).sum()
    # Number of sample pairs sharing a predicted cluster (TP + FN in the reference code).
    tp_fn = comb(item_counts, 2).sum()

    fp = tp_fp - tp
    fn = tp_fn - tp

    if tp == 0:
        # Precision and recall are both zero (or undefined); avoid 0/0.
        return 0.0

    # compute F measure
    beta = 1
    P = tp / (tp + fp)
    R = tp / (tp + fn)
    F1 = (beta*beta + 1) * P * R / (beta*beta * P + R)
    return F1


# =============================================================================================================
def _extract_embeddings(model, dataloader, device, desc):
    """
    Run the network over a dataloader and collect labels and embeddings.

    Each batch is expected to carry the targets at position 0 and the input images at position -1.

    Args:
        model:      PyTorch network used to embed the images.
        dataloader: PyTorch Dataloader to iterate over.
        device:     torch.device, Device to run inference on.
        desc:       str, description shown by the progress bar.
    Returns:
        labels (np.ndarray [n_samples x 1]), embeddings (np.ndarray of dtype float32)
    """
    labels, features = [], []
    for batch in tqdm(dataloader, desc=desc):
        input_img, target = batch[-1], batch[0]
        labels.append(target.numpy().reshape(-1))
        out = model(input_img.to(device))
        features.append(out.cpu().detach().numpy())
    return np.concatenate(labels).reshape(-1, 1), np.vstack(features).astype('float32')


def _cluster_features(features, n_classes):
    """
    Cluster embeddings with faiss k-means and assign every sample to its closest centroid.

    Args:
        features:  np.ndarray [n_samples x embed_dim], float32 embeddings.
        n_classes: int, number of clusters to compute.
    Returns:
        computed_centroids (np.ndarray [n_classes x embed_dim]), cluster assignments (np.ndarray [n_samples x 1])
    """
    embed_dim = features.shape[-1]

    ### Set Faiss CPU Cluster index
    cpu_cluster_index = faiss.IndexFlatL2(embed_dim)
    kmeans = faiss.Clustering(embed_dim, n_classes)
    kmeans.niter = 20
    kmeans.min_points_per_centroid = 1
    kmeans.max_points_per_centroid = 1000000000

    ### Train Kmeans
    kmeans.train(features, cpu_cluster_index)
    computed_centroids = faiss.vector_float_to_array(kmeans.centroids).reshape(n_classes, embed_dim)

    ### Assign feature points to clusters
    faiss_search_index = faiss.IndexFlatL2(computed_centroids.shape[-1])
    faiss_search_index.add(computed_centroids)
    _, cluster_assignments = faiss_search_index.search(features, 1)
    return computed_centroids, cluster_assignments


def _recall_at_ks(target_labels, k_closest_classes, k_vals):
    """
    Compute Recall @ k, i.e. the fraction of samples with at least one neighbour of the same class
    among their k nearest neighbours.

    Args:
        target_labels:     np.ndarray [n_samples x 1], class of every query sample.
        k_closest_classes: np.ndarray [n_samples x max(k_vals)], classes of the nearest neighbours, closest first.
        k_vals:            list of int, Recall values to compute.
    Returns:
        list of float, one recall value per entry in k_vals.
    """
    targets = target_labels.reshape(-1, 1)
    return [np.mean(np.any(k_closest_classes[:, :k] == targets, axis=1)) for k in k_vals]


def eval_metrics_one_dataset(model, test_dataloader, device, k_vals, opt):
    """
    Compute evaluation metrics on test-dataset, e.g. NMI, F1 and Recall @ k.

    Args:
        model:           PyTorch network, network to compute evaluation metrics for.
        test_dataloader: PyTorch Dataloader, dataloader for test dataset, should have no shuffling and correct processing.
        device:          torch.device, Device to run inference on.
        k_vals:          list of int, Recall values to compute
        opt:             argparse.Namespace, contains all training-specific parameters.
    Returns:
        F1 score (float), NMI score (float), recall_at_k (list of float), data embedding (np.ndarray)
    """
    torch.cuda.empty_cache()

    _ = model.eval()
    n_classes = len(test_dataloader.dataset.avail_classes)

    ### For all test images, extract features
    with torch.no_grad():
        target_labels, feature_coll = _extract_embeddings(model, test_dataloader, device,
                                                          'Computing Evaluation Metrics...')
    torch.cuda.empty_cache()

    ### Cluster the embedding and assign feature points to clusters
    computed_centroids, model_generated_cluster_labels = _cluster_features(feature_coll, n_classes)

    ### Compute NMI
    NMI = metrics.cluster.normalized_mutual_info_score(model_generated_cluster_labels.reshape(-1), target_labels.reshape(-1))

    ### Recover max(k_vals) nearest neighbours to use for recall computation
    faiss_search_index = faiss.IndexFlatL2(feature_coll.shape[-1])
    faiss_search_index.add(feature_coll)
    # One additional neighbour is retrieved because the first column is dropped below
    # (queries are part of the searched set).
    _, k_closest_points = faiss_search_index.search(feature_coll, int(np.max(k_vals)+1))
    k_closest_classes = target_labels.reshape(-1)[k_closest_points[:, 1:]]

    ### Compute Recall
    recall_all_k = _recall_at_ks(target_labels, k_closest_classes, k_vals)

    ### Compute F1 Score
    F1 = f1_score(model_generated_cluster_labels, target_labels, feature_coll, computed_centroids)

    return F1, NMI, recall_all_k, feature_coll


def eval_metrics_query_and_gallery_dataset(model, query_dataloader, gallery_dataloader, device, k_vals, opt):
    """
    Compute evaluation metrics on test-dataset, e.g. NMI, F1 and Recall @ k.

    NMI and F1 are computed on the union of query and gallery samples, Recall @ k by retrieving
    gallery samples for every query sample.

    Args:
        model:              PyTorch network, network to compute evaluation metrics for.
        query_dataloader:   PyTorch Dataloader, dataloader for query dataset, for which nearest neighbours in the gallery dataset are retrieved.
        gallery_dataloader: PyTorch Dataloader, dataloader for gallery dataset, provides target samples which are to be retrieved in correspondence to the query dataset.
        device:             torch.device, Device to run inference on.
        k_vals:             list of int, Recall values to compute
        opt:                argparse.Namespace, contains all training-specific parameters.
    Returns:
        F1 score (float), NMI score (float), recall_at_ks (list of float), query data embedding (np.ndarray), gallery data embedding (np.ndarray)
    """
    torch.cuda.empty_cache()

    _ = model.eval()
    n_classes = len(query_dataloader.dataset.avail_classes)

    with torch.no_grad():
        ### For all query test images, extract features
        query_target_labels, query_feature_coll = _extract_embeddings(model, query_dataloader, device,
                                                                      'Extraction Query Features')
        ### For all gallery test images, extract features
        gallery_target_labels, gallery_feature_coll = _extract_embeddings(model, gallery_dataloader, device,
                                                                          'Extraction Gallery Features')
    torch.cuda.empty_cache()

    ### Cluster the joint query/gallery embedding and assign feature points to clusters
    stackset = np.concatenate([query_feature_coll, gallery_feature_coll], axis=0)
    stacklabels = np.concatenate([query_target_labels, gallery_target_labels], axis=0)
    computed_centroids, model_generated_cluster_labels = _cluster_features(stackset, n_classes)

    ### Compute NMI
    NMI = metrics.cluster.normalized_mutual_info_score(model_generated_cluster_labels.reshape(-1), stacklabels.reshape(-1))

    ### Recover max(k_vals) nearest neighbours to use for recall computation
    faiss_search_index = faiss.IndexFlatL2(gallery_feature_coll.shape[-1])
    faiss_search_index.add(gallery_feature_coll)
    _, k_closest_points = faiss_search_index.search(query_feature_coll, int(np.max(k_vals)))
    k_closest_classes = gallery_target_labels.reshape(-1)[k_closest_points]

    ### Compute Recall
    recall_all_k = _recall_at_ks(query_target_labels, k_closest_classes, k_vals)

    ### Compute F1 score
    F1 = f1_score(model_generated_cluster_labels, stacklabels, stackset, computed_centroids)

    return F1, NMI, recall_all_k, query_feature_coll, gallery_feature_coll


# =============================================================================================================
def _plot_sample_recoveries(sample_paths, save_path, n_image_samples, n_closest):
    """
    Plot a grid of images with one row per sample: the query image followed by its recoveries.

    The first image of every row is marked with a red bar, all other images with a green bar.

    Args:
        sample_paths:    np.ndarray [n_image_samples x (n_closest+1)], paths of the images to show.
        save_path:       str, where to store sample image.
        n_image_samples: Number of rows of the grid.
        n_closest:       Number of recoveries per row.
    Returns:
        Nothing!
    """
    # squeeze=False always provides a 2D array of axes, even for a single row or column.
    f, axes = plt.subplots(n_image_samples, n_closest+1, squeeze=False)
    for i, (ax, plot_path) in enumerate(zip(axes.reshape(-1), sample_paths.reshape(-1))):
        with Image.open(plot_path) as img:
            ax.imshow(np.array(img))
        ax.set_xticks([])
        ax.set_yticks([])
        if i % (n_closest+1):
            ax.axvline(x=0, color='g', linewidth=13)
        else:
            ax.axvline(x=0, color='r', linewidth=13)
    f.set_size_inches(10, 20)
    f.tight_layout()
    f.savefig(save_path)
    plt.close(f)


####### RECOVER CLOSEST EXAMPLE IMAGES #######
def recover_closest_one_dataset(feature_matrix_all, image_paths, save_path, n_image_samples=10, n_closest=3):
    """
    Provide sample recoveries.

    Args:
        feature_matrix_all: np.ndarray [n_samples x embed_dim], full data embedding of test samples.
        image_paths:        list [n_samples], list of datapaths corresponding to <feature_matrix_all>.
                            The path is read from the first element of every entry.
        save_path:          str, where to store sample image.
        n_image_samples:    Number of sample recoveries.
        n_closest:          Number of closest recoveries to show.
    Returns:
        Nothing!
    """
    image_paths = np.array([x[0] for x in image_paths])
    sample_idxs = np.random.choice(np.arange(len(feature_matrix_all)), n_image_samples)

    faiss_search_index = faiss.IndexFlatL2(feature_matrix_all.shape[-1])
    faiss_search_index.add(feature_matrix_all)
    # Only the sampled rows are queried; the closest hit is the sample itself.
    _, closest_feature_idxs = faiss_search_index.search(feature_matrix_all[sample_idxs], n_closest+1)

    sample_paths = image_paths[closest_feature_idxs]
    _plot_sample_recoveries(sample_paths, save_path, n_image_samples, n_closest)


####### RECOVER CLOSEST EXAMPLE IMAGES #######
def recover_closest_inshop(query_feature_matrix_all, gallery_feature_matrix_all, query_image_paths, gallery_image_paths, save_path, n_image_samples=10, n_closest=3):
    """
    Provide sample recoveries.

    Args:
        query_feature_matrix_all:   np.ndarray [n_query_samples x embed_dim], full data embedding of query samples.
        gallery_feature_matrix_all: np.ndarray [n_gallery_samples x embed_dim], full data embedding of gallery samples.
        query_image_paths:          list [n_samples], list of datapaths corresponding to <query_feature_matrix_all>
        gallery_image_paths:        list [n_samples], list of datapaths corresponding to <gallery_feature_matrix_all>
        save_path:          str, where to store sample image.
        n_image_samples:    Number of sample recoveries.
        n_closest:          Number of closest recoveries to show.
    Returns:
        Nothing!
    """
    query_image_paths, gallery_image_paths = np.array(query_image_paths), np.array(gallery_image_paths)
    sample_idxs = np.random.choice(np.arange(len(query_feature_matrix_all)), n_image_samples)

    faiss_search_index = faiss.IndexFlatL2(gallery_feature_matrix_all.shape[-1])
    faiss_search_index.add(gallery_feature_matrix_all)
    _, closest_feature_idxs = faiss_search_index.search(query_feature_matrix_all[sample_idxs], n_closest)

    # Every row holds the query image followed by its retrieved gallery images. The rows are
    # selected with the sampled query indices only; gallery indices must not be used to index
    # this query-aligned table.
    sample_paths = np.concatenate([query_image_paths[sample_idxs].reshape(-1, 1),
                                   gallery_image_paths[closest_feature_idxs]], axis=-1)
    _plot_sample_recoveries(sample_paths, save_path, n_image_samples, n_closest)


# =============================================================================================================
################## SET NETWORK TRAINING CHECKPOINT #####################
def set_checkpoint(model, opt, progress_saver, savepath):
    """
    Store relevant parameters (model and progress saver, as well as parameter-namespace).
    Can be easily extended for other stuff.

    Args:
        model:          PyTorch network, network whose parameters are to be saved.
        opt:            argparse.Namespace, includes all training-specific parameters
        progress_saver: subclass of LOGGER-class, contains a running memory of all training metrics.
        savepath:       str, where to save checkpoint.
    Returns:
        Nothing!
    """
    torch.save({'state_dict': model.state_dict(), 'opt': opt,
                'progress': progress_saver}, savepath)


# =============================================================================================================
################## WRITE TO CSV FILE #####################
class CSV_Writer():
    """
    Class to append newly computed training metrics to a csv file
    for data logging.
    Is used together with the LOGGER class.
    """
    def __init__(self, save_path, columns):
        """
        Append the column names as a header row to the csv file.

        Args:
            save_path: str, where to store the csv file
            columns:   list of str, name of csv columns under which the resp. metrics are stored.
        Returns:
            Nothing!
        """
        self.save_path = save_path
        self.columns = columns

        self._append_row(self.columns)

    def _append_row(self, row):
        """Append a single row to the csv file."""
        # newline='' lets the csv module control line endings (no blank rows on Windows).
        with open(self.save_path, "a", newline='') as csv_file:
            writer = csv.writer(csv_file, delimiter=",")
            writer.writerow(row)

    def log(self, inputs):
        """
        log one set of entries to the csv.

        Args:
            inputs: [list of int/str/float], values to append to the csv. Has to be of the same length as self.columns.
        Returns:
            Nothing!
        """
        self._append_row(inputs)


################## PLOT SUMMARY IMAGE #####################
class InfoPlotter():
    """
    Plotter class to visualize training progression by showing
    different metrics.
    """
    def __init__(self, save_path, title='Training Log', figsize=(20,15)):
        """
        Args:
            save_path: str, where to store the created plot.
            title:     placeholder title of plot
            figsize:   base size of saved figure
        Returns:
            Nothing!
        """
        self.save_path = save_path
        self.title = title
        self.figsize = figsize
        #Colors for validation lines
        self.v_colors = ['r','g','b','y','m','k','c']
        #Colors for training lines
        self.t_colors = ['k','b','r','g']

    def make_plot(self, t_epochs, v_epochs, t_metrics, v_metrics, t_labels, v_labels, appendix=None):
        """
        Given a list of iterated epochs, visualize the progression of various training/testing metrics.

        Args:
            t_epochs:  [list of int/float], list of epochs for which training metrics were collected (e.g. Training Loss)
            v_epochs:  [list of int/float], list of epochs for which validation metrics were collected (e.g. Recall @ k)
            t_metrics: [list of float], list of training metrics per epoch
            v_metrics: [list of list of int/float], contains all computed validation metrics
            t_labels, v_labels: [list of str], names for each metric that is plotted.
            appendix:  currently unused.
        Returns:
            Nothing!
        """
        plt.style.use('ggplot')

        f, axes = plt.subplots(1, 2)

        panels = [
            #Visualize Training Loss
            (axes[0], t_epochs, t_metrics, t_labels, self.t_colors, 'Training Performance'),
            #Visualize Validation metrics
            (axes[1], v_epochs, v_metrics, v_labels, self.v_colors, self.title),
        ]
        for ax, epochs, metric_lists, labels, colors, title in panels:
            for i, metric in enumerate(metric_lists):
                # Colors are cycled so that more metrics than colors do not raise an IndexError.
                color = colors[i % len(colors)]
                ax.plot(epochs, metric, '-{}'.format(color), linewidth=1, label=labels[i])
            ax.set_title(title, fontsize=19)
            ax.legend(fontsize=16)
            ax.tick_params(axis='both', which='major', labelsize=16)
            ax.tick_params(axis='both', which='minor', labelsize=16)

        f.set_size_inches(2*self.figsize[0], self.figsize[1])
        f.savefig(self.save_path, bbox_inches='tight')
        plt.close(f)


################## GENERATE LOGGING FOLDER/FILES #######################
def set_logging(opt):
    """
    Generate the folder in which everything is saved.
    If opt.savename is given, folder will take on said name.
    If not, a name based on the start time is provided.
    If the folder already exists, a running index is appended until it can be created without
    deleting existing data.
    The current opt.save_path will be extended to account for the new save_folder name.

    Args:
        opt: argparse.Namespace, contains all training-specific parameters.
    Returns:
        Nothing!
    """
    if opt.savename == '':
        #Create start-time-based name if opt.savename is not given.
        date = datetime.datetime.now()
        time_string = '{}-{}-{}-{}-{}-{}'.format(date.year, date.month, date.day, date.hour, date.minute, date.second)
        base_folder = opt.save_path+'/{}_{}_'.format(opt.dataset.upper(), opt.arch.upper())+time_string
    else:
        base_folder = opt.save_path+'/'+opt.savename

    #If folder already exists, append a running index until it doesn't.
    #The index is appended to the full base name so that time-based names are kept as well.
    checkfolder = base_folder
    counter = 1
    while os.path.exists(checkfolder):
        checkfolder = base_folder+'_'+str(counter)
        counter += 1

    #Create Folder
    os.makedirs(checkfolder)
    opt.save_path = checkfolder

    #Store training parameters as text and pickle in said folder.
    with open(opt.save_path+'/Parameter_Info.txt', 'w') as f:
        f.write(gimme_save_string(opt))
    with open(opt.save_path+"/hypa.pkl", "wb") as f:
        pkl.dump(opt, f)


class LOGGER():
    """
    This class provides a collection of logging properties that are useful for training.
    These include setting the save folder, in which progression of training/testing metrics is visualized,
    csv log-files are stored, sample recoveries are plotted and an internal data saver.
    """
    def __init__(self, opt, metrics_to_log, name='Basic', start_new=True):
        """
        Args:
            opt:               argparse.Namespace, contains all training-specific parameters.
            metrics_to_log:    dict, dictionary which shows in what structure the data should be saved.
                               is given as the output of aux.metrics_to_examine. Example:
                               {'train': ['Epochs', 'Time', 'Train Loss', 'Time'],
                                'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']}
            name:              Name of this logger. Will be used to distinguish logged files from other LOGGER instances.
            start_new:         If set to true, a new save folder will be created initially.
        Returns:
            Nothing!
        """
        self.prop = opt
        self.metrics_to_log = metrics_to_log

        ### Make Logging Directories
        if start_new:
            set_logging(opt)

        ### Set INFO-PLOTS
        if self.prop.dataset != 'vehicle_id':
            self.info_plot = InfoPlotter(opt.save_path+'/InfoPlot_{}.svg'.format(name))
        else:
            self.info_plot = {'Set {}'.format(i): InfoPlotter(opt.save_path+'/InfoPlot_{}_Set{}.svg'.format(name,i+1)) for i in range(3)}

        ### Set Progress Saver Dict
        self.progress_saver = self.provide_progress_saver(metrics_to_log)

        ### Set CSV Writers
        self.csv_loggers = {mode: CSV_Writer(opt.save_path+'/log_'+mode+'_'+name+'.csv', lognames) for mode, lognames in metrics_to_log.items()}

    def provide_progress_saver(self, metrics_to_log):
        """
        Provide Progress Saver dictionary.

        Args:
            metrics_to_log: see __init__(). Describes the structure of Progress_Saver.
        Returns:
            dict, an empty list for every metric of every main key.
        """
        Progress_Saver = {key: {sub_key: [] for sub_key in sub_keys} for key, sub_keys in metrics_to_log.items()}
        return Progress_Saver

    def log(self, main_keys, metric_keys, values):
        """
        Actually log new values in csv and Progress Saver dict internally.

        Args:
            main_keys:      Main key in which data will be stored. Normally is either 'train' for training metrics or 'val' for validation metrics.
            metric_keys:    Needs to follow the list length of self.progress_saver[main_key(s)]. List of metric keys that are extended with new values.
            values:         Needs to be a list of the same structure as metric_keys. Actual values that are appended.
        Returns:
            Nothing!
        """
        if not isinstance(main_keys, list):   main_keys = [main_keys]
        if not isinstance(metric_keys, list): metric_keys = [metric_keys]
        if not isinstance(values, list):      values = [values]

        for main_key in main_keys:
            #Log data to progress saver dict.
            for value, metric_key in zip(values, metric_keys):
                self.progress_saver[main_key][metric_key].append(value)

            #Append data to the csv belonging to this main key.
            self.csv_loggers[main_key].log(values)

    def update_info_plot(self):
        """
        Create a new updated version of training/metric progression plot.

        Args:
            None
        Returns:
            Nothing!
        """
        train_log = self.progress_saver['train']
        val_log = self.progress_saver['val']

        v_epochs = val_log['Epochs']
        t_loss_list = [train_log['Train Loss']]
        t_legend_handles = ['Train Loss']

        # The training loss is plotted against the training epochs whenever these were logged
        # alongside it; otherwise the validation epochs are used.
        t_epochs = train_log.get('Epochs', v_epochs)
        if len(t_epochs) != len(train_log['Train Loss']):
            t_epochs = v_epochs

        metric_keys = [key for key in val_log if key not in ['Time', 'Epochs']]

        #Because Vehicle-ID normally uses three different test sets, a distinction has to be made.
        if self.prop.dataset != 'vehicle_id':
            title = ' | '.join(key+': {0:3.3f}'.format(np.max(val_log[key])) for key in metric_keys)
            self.info_plot.title = title
            v_metric_list = [val_log[key] for key in metric_keys]
            self.info_plot.make_plot(t_epochs, v_epochs, t_loss_list, v_metric_list, t_legend_handles, metric_keys)
        else:
            #Iterate over all test sets.
            for i in range(3):
                set_name = 'Set {}'.format(i)
                set_keys = [key for key in metric_keys if set_name in key]
                title = ' | '.join(key+': {0:3.3f}'.format(np.max(val_log[key])) for key in set_keys)
                self.info_plot[set_name].title = title
                v_metric_list = [val_log[key] for key in set_keys]
                self.info_plot[set_name].make_plot(t_epochs, v_epochs, t_loss_list, v_metric_list, t_legend_handles, set_keys, appendix='set_{}'.format(i))


def metrics_to_examine(dataset, k_vals):
    """
    Please only use either of the following keys:
    -> Epochs, Time, Train Loss for training
    -> Epochs, Time, NMI, F1 & Recall @ k for validation

    Args:
        dataset: str, dataset for which a storing structure for LOGGER.progress_saver is to be made.
        k_vals:  list of int, Recall @ k - values.
    Returns:
        metric_dict: Dictionary representing the storing structure for LOGGER.progress_saver. See LOGGER.__init__() for an example.
    """
    metric_dict = {'train': ['Epochs', 'Time', 'Train Loss']}

    if dataset == 'vehicle_id':
        val_metrics = ['Epochs', 'Time']
        #Vehicle_ID uses three test sets
        for i in range(3):
            val_metrics += ['Set {} NMI'.format(i), 'Set {} F1'.format(i)]
            val_metrics += ['Set {} Recall @ {}'.format(i, k) for k in k_vals]
    else:
        val_metrics = ['Epochs', 'Time', 'NMI', 'F1']
        val_metrics += ['Recall @ {}'.format(k) for k in k_vals]

    metric_dict['val'] = val_metrics
    return metric_dict


# =================================================================================================
def run_kmeans(features, n_cluster):
    """
    Run kmeans on a set of features to find <n_cluster> cluster.

    Args:
        features:  np.ndarray [n_samples x embed_dim], embedding training/testing samples for which kmeans should be performed.
        n_cluster: int, number of cluster.
    Returns:
        cluster_assignments: np.ndarray [n_samples x 1], per sample provide the respective cluster label it belongs to.
    """
    n_samples, dim = features.shape
    kmeans = faiss.Kmeans(dim, n_cluster)
    # The clustering parameters live on <kmeans.cp>; attributes set directly on the wrapper
    # object are not read by train(). Note that the iteration count is named <niter>.
    kmeans.cp.niter = 20
    kmeans.cp.min_points_per_centroid = 5
    kmeans.cp.max_points_per_centroid = 1000000000
    kmeans.train(features)
    _, cluster_assignments = kmeans.index.search(features, 1)
    return cluster_assignments


# =============================================================================================================
def save_graph(opt, model):
    """
    Generate Network Graph.
    NOTE: Requires the installation of the graphviz library on your system.

    Args:
        opt:   argparse.Namespace, contains all training-specific parameters.
        model: PyTorch Network, network for which the computational graph should be visualized.
    Returns:
        Nothing!
    Raises:
        Exception: if the folder opt.save_path does not exist.
    """
    inp = torch.randn((1,3,224,224)).to(opt.device)
    network_output = model(inp)
    if isinstance(network_output, dict):
        network_output = network_output['Class']

    from graphviz import Digraph

    def make_dot(var, savename, params=None):
        """
        Generate a symbolic representation of the network graph.

        Args:
            var:      torch.Tensor, network output whose autograd graph is traversed.
            savename: str, where to save the graph source.
            params:   optional dict of tensors; only checked for its value types.
        Returns:
            graphviz.Digraph
        """
        if params is not None:
            if not all(isinstance(p, torch.Tensor) for p in params.values()):
                raise TypeError('All values of <params> need to be tensors!')

        node_attr = dict(style='filled',
                         shape='box',
                         align='left',
                         fontsize='6',
                         ranksep='0.1',
                         height='0.6',
                         width='1')
        dot = Digraph(node_attr=node_attr, format='svg', graph_attr=dict(size="40,10", rankdir='LR', rank='same'))
        seen = set()

        def size_to_str(size):
            return '('+(', ').join(['%d' % v for v in size])+')'

        def add_nodes(var):
            replacements = ['Backward', 'Th', 'Cudnn']
            color_assigns = {'Convolution':'orange',
                             'ConvolutionTranspose': 'lightblue',
                             'Add': 'red',
                             'Cat': 'green',
                             'Softmax': 'yellow',
                             'Sigmoid': 'yellow',
                             'Copys': 'yellow'}
            if var in seen:
                return

            # op1: node is a tensor, op2: node is a graph function other than AccumulateGrad.
            op1 = torch.is_tensor(var)
            op2 = not torch.is_tensor(var) and str(type(var).__name__) != 'AccumulateGrad'

            text = str(type(var).__name__)
            for rep in replacements:
                text = text.replace(rep, '')
            color = color_assigns.get(text, 'gray')
            if 'Pool' in text:
                color = 'lightblue'

            if (op1 or op2) and hasattr(var, 'next_functions'):
                # Sizes of all parameters that directly feed into this node.
                param_sizes = [size_to_str(u[0].variable.size()) for u in var.next_functions
                               if str(type(u[0]).__name__) == 'AccumulateGrad']
                if param_sizes:
                    text += '\nParameter Sizes:\n' + ''.join(size + ' ' for size in param_sizes)

            if op1:
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            if op2:
                dot.node(str(id(var)), text, fillcolor=color)

            seen.add(var)

            if op1 or op2:
                if hasattr(var, 'next_functions'):
                    for u in var.next_functions:
                        if u[0] is not None:
                            if str(type(u[0]).__name__) != 'AccumulateGrad':
                                dot.edge(str(id(u[0])), str(id(var)))
                                add_nodes(u[0])
                if hasattr(var, 'saved_tensors'):
                    for t in var.saved_tensors:
                        dot.edge(str(id(t)), str(id(var)))
                        add_nodes(t)

        add_nodes(var.grad_fn)
        dot.save(savename)
        return dot

    if not os.path.exists(opt.save_path):
        raise Exception('No save folder {} available!'.format(opt.save_path))

    viz_graph = make_dot(network_output, opt.save_path+"/Network_Graphs"+"/{}_network_graph".format(opt.arch))
    viz_graph.format = 'svg'
    viz_graph.render()

    torch.cuda.empty_cache()