"""Classify a single image with a trained ShuffleNet checkpoint.

Usage:
    python <this script> IMAGE CHECKPOINT IDX_TO_CLASS [--num_classes N] [--topk K]
"""
import argparse
import os
import json

from model import ShuffleNet
from torchvision import transforms
from torch.autograd import Variable
import torch
from PIL import Image
import numpy as np


def get_transformer():
    """Build the preprocessing pipeline applied to input images.

    The shorter image side is resized to 128 pixels (aspect ratio preserved),
    the image is converted to a float tensor, and each of the three channels
    is normalized with the mean/std values below.

    Returns:
        A ``transforms.Compose`` pipeline mapping a PIL image to a tensor.
    """
    # NOTE(review): these appear to be the usual ImageNet channel statistics;
    # confirm they match the statistics used when the checkpoint was trained.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
    transformer = transforms.Compose([
        transforms.Resize(128),
        transforms.ToTensor(),
        normalize,
    ])
    return transformer


def preprocess(image, transformer):
    """Apply *transformer* to a PIL image and add a batch dimension of size 1.

    Args:
        image: PIL image to classify.
        transformer: callable returned by ``get_transformer``.

    Returns:
        The transformed image with a leading batch axis, wrapped in
        ``Variable`` (which simply returns a tensor on PyTorch >= 0.4).
    """
    x = transformer(image)
    return Variable(x.unsqueeze(0))


def infer(args):
    """Run ShuffleNet on one image and print the top predicted class names.

    Args:
        args: namespace with ``image`` (image path), ``checkpoint`` (path to a
            checkpoint dict holding a ``'state_dict'`` entry), ``idx_to_class``
            (path to a JSON file keyed by class index as a string),
            ``num_classes``, and optionally ``topk`` (defaults to 10 when the
            attribute is absent, so older callers keep working).
    """
    topk = getattr(args, 'topk', 10)

    # make ShuffleNet model
    print('Creating ShuffleNet model')
    net = ShuffleNet(num_classes=args.num_classes, in_channels=3)

    # load trained checkpoint
    print('Loading checkpoint')
    # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
    # The map_location lambda keeps every storage on the CPU, whatever device
    # it was saved from.
    checkpoint = torch.load(args.checkpoint,
                            map_location=lambda storage, loc: storage)
    net.load_state_dict(checkpoint['state_dict'])

    print('Loading index-class map')
    with open(args.idx_to_class, 'r') as f:
        mapping = json.load(f)

    # image transformer
    transformer = get_transformer()

    # make input tensor
    print('Loading image')
    # The network is built with in_channels=3 and Normalize has 3 channel
    # statistics, so grayscale/RGBA/palette images must become RGB first.
    # convert() forces the pixel data to load, so the file can be closed.
    with Image.open(args.image) as img:
        image = img.convert('RGB')

    print('Preprocessing')
    x = preprocess(image, transformer)

    # predict output
    print('Inferring on image {}'.format(args.image))
    net.eval()
    # Pure inference: don't build an autograd graph or keep activations.
    with torch.no_grad():
        y = net(x)

    scores = y.cpu().numpy().ravel()
    # argsort is ascending; reverse it and keep the first `topk` indices
    # (highest score first).
    top_idxs = np.argsort(scores)[::-1][:topk].tolist()

    print('==========================================')
    for rank, idx in enumerate(top_idxs, start=1):
        # assumes each JSON value is a sequence whose element at index 1 is
        # the human-readable class name -- confirm against the mapping file.
        class_name = mapping[str(idx)][1]
        print('{}.\t{}'.format(rank, class_name))
    print('==========================================')


def _positive_int(value):
    """argparse ``type``: parse *value* as an int and require it to be >= 1."""
    ivalue = int(value)
    if ivalue < 1:
        raise argparse.ArgumentTypeError(
            'expected a positive integer, got {}'.format(value))
    return ivalue


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('image', type=str,
                        help='Path to image that we want to classify')
    parser.add_argument('checkpoint', type=str,
                        help='Path to ShuffleNet checkpoint with trained weights')
    parser.add_argument('idx_to_class', type=str,
                        help='Path to JSON file mapping indexes to class names')
    parser.add_argument('--num_classes', type=int,
                        help='Number of classes to predict', default=1000)
    parser.add_argument('--topk', type=_positive_int,
                        help='Number of top predictions to print', default=10)
    args = parser.parse_args()
    infer(args)