"""Render the top-ranked retrieval results of each query as a row of thumbnails."""
import os
import pickle
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np

try:
    from PIL import Image  # Pillow exposes the module under the PIL package
except ImportError:
    import Image  # legacy stand-alone PIL, as imported by the original file

from params import get_params
from eval import Evaluator

# Class-name table indexed by class id.
# NOTE(review): presumably the PASCAL VOC label set; it is not used by the
# active code in this file -- confirm before removing.
CLASSES = ('__background__',
           'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor')


class Visualization(object):
    """Draw, for every query, the query image followed by its top results.

    Each figure is a single row: the query (with its bounding box) in the
    first cell and the first ``N_display`` ranked images after it.  Every
    cell gets a coloured border that encodes its ground-truth label.
    """

    # Side length, in pixels, of the square canvas every image is fitted to.
    THUMB_SIZE = (800, 800)
    # Thickness, in pixels, of the coloured frame around each thumbnail.
    BORDER_WIDTH = 30
    # Colour of bounding boxes; images are RGB when boxes are drawn.
    BOX_COLOR = (255, 0, 0)

    # Border colours (RGB).
    QUERY_BORDER = [0, 0, 255]
    GOOD_BORDER = [0, 255, 0]
    # NOTE(review): the original comments labelled 'ok' as yellow and 'junk'
    # as orange, yet the code drew both in green.  The drawn colours are kept
    # unchanged here; confirm which one was intended.
    OK_BORDER = [0, 255, 0]
    JUNK_BORDER = [0, 255, 0]
    NEGATIVE_BORDER = [255, 0, 0]

    def __init__(self, params):
        """Copy the settings used for visualization out of ``params``.

        Args:
            params: mapping that provides the keys read below.  The file
                named by ``params['query_list']`` is read immediately, one
                query name per line.
        """
        self.dataset = params['dataset']
        self.image_path = params['database_images']
        self.class_scores = params['use_class_scores']
        self.queries = params['query_names']
        self.rankings_dir = params['rankings_dir']
        self.size_box = params['size_box']
        self.stage = params['stage']
        self.N_display = params['N_display']
        self.figsize = params['figsize']
        self.figures_path = params['figures_path']
        self.reranking_path = params['reranking_path']

        with open(params['query_list'], 'r') as f:
            self.query_names = f.read().splitlines()

        self.ground_truth = params['ground_truth_file']

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------
    def _frame_path(self, frame):
        """Return the path of the ``.jpg`` file for image name ``frame``.

        Raises:
            ValueError: if ``self.dataset`` is neither 'paris' nor 'oxford'.
        """
        # '==' rather than 'is': the dataset name usually comes from a config
        # or the command line, so it is not the same object as the literal.
        if self.dataset == 'paris':
            # Paris images live in a sub-directory named after the second
            # underscore-separated token of the image name.
            return os.path.join(self.image_path, frame.split('_')[1],
                                frame + '.jpg')
        if self.dataset == 'oxford':
            return os.path.join(self.image_path, frame + '.jpg')
        raise ValueError("Unsupported dataset %r (expected 'paris' or 'oxford')"
                         % (self.dataset,))

    @staticmethod
    def _load_rgb(path):
        """Read the image at ``path`` and return it as an RGB array.

        Raises:
            IOError: if OpenCV cannot read the file (``imread`` gives None).
        """
        im = cv2.imread(path)
        if im is None:
            raise IOError("Could not read image: %s" % path)
        return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    def _draw_box(self, im, bbx):
        """Draw the rectangle ``bbx`` (two corner points, x/y each) on ``im`` in place."""
        cv2.rectangle(im,
                      (int(bbx[0]), int(bbx[1])),
                      (int(bbx[2]), int(bbx[3])),
                      self.BOX_COLOR, self.size_box)

    def _add_border(self, im, color):
        """Return ``im`` surrounded by a constant ``color`` frame."""
        w = self.BORDER_WIDTH
        return cv2.copyMakeBorder(im, w, w, w, w, cv2.BORDER_CONSTANT,
                                  value=color)

    def _read_gt_names(self, query, suffix):
        """Return the set of image names listed in one ground-truth file.

        The file is ``<ground_truth>/<query base name><suffix>``, where the
        base name is the query file name with ``_query.txt`` stripped.
        """
        base = os.path.basename(query).split('_query.txt')[0]
        with open(os.path.join(self.ground_truth, base + suffix), 'r') as f:
            # Whitespace-separated names; a set gives O(1) membership tests.
            return set(f.read().split())

    def _load_rerank_boxes(self, query):
        """Return the re-ranking boxes of ``query`` sorted like the ranking.

        The pickle file stores ``distances`` followed by ``locations`` (and
        further objects that are not needed here).  Boxes are ordered by
        ascending distance, or descending when class scores are used.
        """
        name = os.path.basename(query.split('_query')[0]) + '.pkl'
        # NOTE: pickle must only be used on files produced by this project;
        # never point ``reranking_path`` at untrusted data.
        with open(os.path.join(self.reranking_path, name), 'rb') as f:
            distances = pickle.load(f)
            locations = pickle.load(f)

        order = np.argsort(distances)
        if self.class_scores:
            order = order[::-1]
        return np.array(locations)[order]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def read_ranking(self, query):
        """Return the ranked image names stored for ``query``, best first.

        The ranking file is ``<rankings_dir>/<name>.txt`` where ``<name>`` is
        the base name of ``query`` up to ``_query``.
        """
        name = os.path.basename(query.split('_query')[0]) + '.txt'
        with open(os.path.join(self.rankings_dir, name), 'r') as f:
            return f.read().splitlines()

    def query_info(self, filename):
        """For oxford and paris, get the query image path and its box.

        Args:
            filename: query description file; its first field is the image
                name and the remaining fields are the box coordinates.

        Returns:
            ``(image_path, bbx)`` with ``bbx`` an integer array.

        Raises:
            ValueError: if ``self.dataset`` is neither 'paris' nor 'oxford'.
        """
        data = np.loadtxt(filename, dtype="str")

        if self.dataset == 'paris':
            name = data[0]
        elif self.dataset == 'oxford':
            # Oxford query files prefix the image name with 'oxc1_'.
            name = data[0].split('oxc1_')[1]
        else:
            raise ValueError("Unsupported dataset %r (expected 'paris' or "
                             "'oxford')" % (self.dataset,))

        # Coordinates are stored as text floats; truncate them to ints.
        bbx = data[1:].astype(float).astype(int)
        return self._frame_path(name), bbx

    def get_query_im(self, query):
        """Return the RGB query image with its bounding box drawn on it."""
        image_file, bbx = self.query_info(query)
        im = self._load_rgb(image_file)
        self._draw_box(im, bbx)
        return im

    def create_thumb(self, im):
        """Fit ``im`` into a black square canvas and return it as RGB.

        The image is shrunk (aspect ratio kept) to fit ``THUMB_SIZE`` and
        pasted centred on a black background of exactly that size.
        """
        size = self.THUMB_SIZE
        image = Image.fromarray(im)

        # ANTIALIAS is the historical name of the LANCZOS filter and was
        # removed from recent Pillow releases; prefer LANCZOS when present.
        resample = getattr(Image, 'LANCZOS', None)
        if resample is None:
            resample = Image.ANTIALIAS
        image.thumbnail(size, resample)

        background = Image.new('RGBA', size, "black")
        # Floor division: paste() needs integer offsets on Python 3 as well.
        offset = ((size[0] - image.size[0]) // 2,
                  (size[1] - image.size[1]) // 2)
        background.paste(image, offset)

        # Drop the alpha channel.
        return np.array(background)[:, :, 0:3]

    def vis_one_query(self, query, ranking=None):
        """Save the figure showing ``query`` and its top ranked images.

        Args:
            query: path of the query description file.
            ranking: ranked image names, best first.  When ``None`` the
                ranking is read from disk with :meth:`read_ranking`.

        The figure is written to ``<figures_path>/<query name>.png``.
        """
        if ranking is None:
            ranking = self.read_ranking(query)

        grid_size_x = self.N_display + 1
        grid_size_y = 1

        good = self._read_gt_names(query, '_good.txt')
        ok = self._read_gt_names(query, '_ok.txt')
        junk = self._read_gt_names(query, '_junk.txt')

        boxes = None
        if self.stage == 'rerank':
            boxes = self._load_rerank_boxes(query)

        fig = plt.figure(figsize=self.figsize)
        try:
            # First cell: the query itself, framed in the query colour.
            ax = fig.add_subplot(grid_size_y, grid_size_x, 1)
            query_im = self.create_thumb(self.get_query_im(query))
            ax.imshow(self._add_border(query_im, self.QUERY_BORDER))
            ax.axes.get_xaxis().set_visible(False)
            ax.axes.get_yaxis().set_visible(False)

            # Never index past the end of a short ranking.
            n_shown = min(self.N_display, len(ranking))
            for i in range(n_shown):
                frame = ranking[i]
                im = self._load_rgb(self._frame_path(frame))

                if boxes is not None and i < len(boxes):
                    # paint box too
                    bbx = boxes[i, :]
                    print(bbx)
                    self._draw_box(im, bbx)

                im = self.create_thumb(im)

                name = os.path.basename(frame).split('.jpg')[0]
                if name in good:
                    color = self.GOOD_BORDER
                elif name in ok:
                    color = self.OK_BORDER
                elif name in junk:
                    color = self.JUNK_BORDER
                else:
                    color = self.NEGATIVE_BORDER
                im = self._add_border(im, color)

                ax2 = fig.add_subplot(grid_size_y, grid_size_x, i + 2)
                ax2.imshow(im)
                ax2.axes.get_xaxis().set_visible(False)
                ax2.axes.get_yaxis().set_visible(False)

            fig.tight_layout()
            out_name = os.path.basename(query).split('_query')[0] + '.png'
            fig.savefig(os.path.join(self.figures_path, out_name))
        finally:
            # Always release the figure, even if drawing failed.
            plt.close(fig)

    def vis(self):
        """Create and save one figure for every query in ``query_names``."""
        for query in self.query_names:
            print(query)
            ranking = self.read_ranking(query)
            self.vis_one_query(query, ranking)


if __name__ == '__main__':
    params = get_params()
    V = Visualization(params)
    V.vis()