############################################################## # Copyright (c) 2018-present, Facebook, Inc. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ############################################################## from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import os import cv2 import numpy as np import logging import sys import torch from torchvision import models os.environ['TORCH_MODEL_ZOO'] = \ '/mnt/vol/gfsai-east/ai-group/users/rgirdhar/StandardModels/PyTorch/ImNet' FORMAT = '%(levelname)s %(filename)s:%(lineno)4d: %(message)s' logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) logger = logging.getLogger(__name__) default_model = models.resnet18(pretrained=True) def prepare_image(im): im = im[..., (2, 1, 0)] # convert to rgb try: im = cv2.resize(im, (224, 224)) except cv2.error: im = np.zeros((224, 224, 3)) # dummy image logger.warning('Invalid patch, replaced with 0 image.') im = im.transpose(2, 0, 1) mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1) std = np.array([0.229, 0.224, 0.224]).reshape(1, 3, 1, 1) im = (im / 255.0 - mean) / std im = torch.FloatTensor(im).cuda() im = torch.autograd.Variable(im, volatile=True) return im def extract_features(im, test_model=None, layers=('layer3',)): """ Args: im (np.ndarray): Image, read using cv2.imread so is in BGR format. Returns: features (list): List of features from each layer in the list layers. """ model = test_model or default_model model.eval() # Preprocess the image im = prepare_image(im) # Extract the features x = im outputs = [] layers = list(layers) for name, module in model._modules.items(): if len(layers) == 0: break if name == 'fc': # Not sure why I need to do this... x = torch.squeeze(x) x = module.cuda()(x) if name in layers: outputs += [x.data.cpu().clone().numpy()] del layers[layers.index(name)] return outputs