# test.py: eval query image
import torch
from math import floor
import torch.nn.functional as F
from torchvision.transforms import (
    ToTensor,
    Normalize,
    Resize,
    RandomHorizontalFlip,
    RandomRotation,
    Compose,
)
import utils.transforms as trf
from PIL import Image


def extract_query(net, dataset, q_idx, scale=800, crop=True, flip=None, rotate=None):
    """Extract the feature of one query image with a single forward pass.

    The image is opened from ``dataset.get_query_filename(q_idx)``, optionally
    cropped to ``dataset.get_query_roi(q_idx)``, then passed through the
    project helpers in ``utils.transforms`` (resize, optional flip, optional
    rotation, tensor conversion, normalization) before being fed to ``net``.

    Args:
        net: callable model; invoked as ``net(batch)`` on a batch of one image.
        dataset: object providing ``get_query_filename`` and ``get_query_roi``.
        q_idx: query index forwarded to the dataset accessors.
        scale: value handed to ``trf.resize_image`` (reported as "pixels" in
            the progress message; presumably the target edge -- confirm in
            ``utils.transforms``).
        crop: when truthy, crop the image to the query ROI first.
        flip: when truthy, forwarded to ``trf.flip_image``.
        rotate: when truthy, forwarded to ``trf.rotate``.

    Returns:
        Tuple ``(q_feat, img)``: the network output as a NumPy array and the
        transformed PIL image that was fed to the network.
    """
    # Load query image
    img = Image.open(dataset.get_query_filename(q_idx))

    # Crop the query ROI
    if crop:
        img = img.crop(tuple(dataset.get_query_roi(q_idx)))

    # Apply transformations
    img = trf.resize_image(img, scale)
    if flip:
        img = trf.flip_image(img, flip)
    if rotate:
        img = trf.rotate(img, rotate)

    # Convert to a PyTorch tensor, normalize, and add the batch dimension
    tensor = trf.to_tensor(img)
    tensor = trf.normalize(
        tensor,
        dict(rgb_means=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    )
    tensor = tensor.unsqueeze(0)  # .to(device)

    # Forward pass to extract the features (no autograd bookkeeping needed)
    with torch.no_grad():
        print('Extracting features using input={} pixels...'.format(scale))
        q_feat = net(tensor).cpu().numpy()

    return q_feat, img


def _make_resizer(maxedge):
    """Return a one-argument transform that resizes an image to ``maxedge``."""
    # A factory binds ``maxedge`` per transform. A lambda created inside a
    # comprehension would look the loop variable up late, so every transform
    # would end up using the last edge of the list.
    def _resize(im):
        return imresize(im, maxedge)

    return _resize


def q_eval(net, dataset, q_idx, flip=False, rotate=False, scale=1):
    """Compute the L2-normalized descriptor of a query image.

    One transformation chain is built per evaluated scale; the network output
    of every chain is averaged and the mean is L2-normalized.

    Args:
        net: model exposing ``eval()``, ``out_features`` and being callable on
            a batch of shape ``(1, C, H, W)``.
        dataset: object providing ``get_image(q_idx)`` (the result is fed to
            torchvision / PIL transforms).
        q_idx: query index forwarded to ``dataset.get_image``.
        flip: when truthy, always apply a horizontal flip.
        rotate: when truthy, rotation angle in degrees (applied
            deterministically through a degenerate ``RandomRotation`` range).
        scale: resize preset -- ``0``: fixed 224x224, ``1``: max edge 800,
            ``1.5``: max edge 1200, ``2``: multiscale (600, 800, 1000, 1200).

    Returns:
        1-D NumPy array of length ``net.out_features``.

    Raises:
        ValueError: if ``scale`` is not one of the supported presets.
    """
    # load query image
    q_im = dataset.get_image(q_idx)

    # list of transformation lists
    trfs_chains = [[]]
    if rotate:
        # RandomRotation samples from a range; a tiny interval around the
        # requested angle makes the rotation effectively deterministic.
        eps = 1e-6
        trfs_chains[0] += [RandomRotation((rotate - eps, rotate + eps))]
    if flip:
        trfs_chains[0] += [RandomHorizontalFlip(1)]

    if scale == 0:
        # AlexNet asks for resized images of 224x224
        resize_list = [Resize((224, 224))]
    elif scale == 1:
        resize_list = [_make_resizer(800)]
    elif scale == 1.5:
        resize_list = [_make_resizer(1200)]
    elif scale == 2:
        # multiscale
        resize_list = [_make_resizer(edge) for edge in (600, 800, 1000, 1200)]
    else:
        raise ValueError(
            'unsupported scale {!r}; expected one of 0, 1, 1.5, 2'.format(scale)
        )

    if len(resize_list) == 1:
        trfs_chains[0] += resize_list
    else:
        # NOTE(review): add_trf appends the resized variants but keeps the
        # base chain, so the un-resized image is evaluated as well -- confirm
        # that this extra native-resolution pass is intended.
        add_trf(trfs_chains, resize_list)

    # default transformations
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    for chain in trfs_chains:
        chain += [ToTensor(), Normalize(mean, std)]

    net = net.eval()
    q_feat = torch.zeros((len(trfs_chains), net.out_features))
    print('Computing the forward pass and extracting the image representation...')
    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for i, chain in enumerate(trfs_chains):
            q_tensor = Compose(chain)(q_im)
            # Add the batch dimension expected by the network.
            q_feat[i] = net(q_tensor.unsqueeze(0))

    return F.normalize(q_feat.mean(dim=0), dim=0).detach().numpy()


def add_trf(trfs_chains, new_trfs):
    """Extend ``trfs_chains`` in place with one new chain per (transform, chain).

    For every transform in ``new_trfs`` and every chain present when the call
    starts, a copy of that chain with the transform appended is added to the
    end of ``trfs_chains``. The pre-existing chains are left untouched and
    remain in the list. Returns ``None``.
    """
    # Snapshot the chains present on entry so the newly appended ones are not
    # themselves extended.
    base_chains = list(trfs_chains)
    trfs_chains.extend(
        chain + [transform]
        for transform in new_trfs
        for chain in base_chains
    )


def imresize(im, maxedge):
    """Resize ``im`` so that its longer edge equals ``maxedge``.

    The aspect ratio is maintained, i.e. ``max(width, height) == maxedge`` for
    the returned image.

    Args:
        im: PIL image (anything exposing ``size`` as ``(width, height)`` and
            ``resize``).
        maxedge: target length in pixels of the longer edge.

    Returns:
        The resized image as returned by ``im.resize``.
    """
    # PIL reports size as (width, height), not (height, width).
    width, height = im.size
    if width < height:
        # Portrait: height is the long edge.
        new_size = (max(1, int(width / height * maxedge)), maxedge)
    else:
        # Landscape or square: width is the long edge.
        new_size = (maxedge, max(1, int(height / width * maxedge)))
    return im.resize(new_size)