"""tieredImageNet dataset loader.

Reads pre-packed image arrays (``<split>_images.npz``) and label pickles
(``<split>_labels.pkl``) from ``IMAGE_PATH`` and exposes them as a
``torch.utils.data.Dataset`` with backbone-specific normalization.
"""
from __future__ import print_function

import os
import os.path as osp
import pickle
import sys

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image

# Set the appropriate paths of the datasets here.
THIS_PATH = osp.dirname(__file__)
ROOT_PATH1 = osp.abspath(osp.join(THIS_PATH, '..', '..', '..'))
ROOT_PATH2 = osp.abspath(osp.join(THIS_PATH, '..', '..'))
IMAGE_PATH = osp.join(ROOT_PATH1, 'data/tieredimagenet/')
# NOTE(review): not used in this module and points at the miniimagenet tree;
# kept in case other modules import it -- verify before removing.
SPLIT_PATH = osp.join(ROOT_PATH2, 'data/miniimagenet/split')


def buildLabelIndex(labels):
    """Group sample indices by label.

    Args:
        labels: iterable of hashable labels, one per sample.

    Returns:
        dict mapping each distinct label to the list of indices (in
        iteration order) at which that label occurs.
    """
    label2inds = {}
    for idx, label in enumerate(labels):
        label2inds.setdefault(label, []).append(idx)
    return label2inds


def load_data(file):
    """Unpickle and return the object stored in ``file``.

    A first attempt uses the default decoding. If that fails with a
    Unicode error -- the usual symptom of a pickle written by Python 2 that
    contains non-ASCII byte strings or NumPy arrays -- the file is re-read
    with ``encoding='latin1'``, which maps every byte losslessly.

    Args:
        file: path to a pickle file.

    Returns:
        The unpickled object.

    Raises:
        OSError: if the file cannot be opened.
        Any unpickling error other than a Unicode error propagates unchanged.
    """
    # NOTE(review): pickle executes arbitrary code on load; only point this
    # at trusted dataset files.
    with open(file, 'rb') as fo:
        try:
            return pickle.load(fo)
        except UnicodeError:  # also covers UnicodeDecodeError
            fo.seek(0)
            return pickle.load(fo, encoding='latin1')


# split name -> [images .npz path, labels .pkl path]
file_path = {
    split: [os.path.join(IMAGE_PATH, split + '_images.npz'),
            os.path.join(IMAGE_PATH, split + '_labels.pkl')]
    for split in ('train', 'val', 'test')
}

# Per-channel (mean, std) passed to transforms.Normalize after ToTensor,
# keyed by ``args.backbone_class``.
_IMAGENET_STATS = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
_NORMALIZE_STATS = {
    'ConvNet': _IMAGENET_STATS,
    'ResNet': ([x / 255.0 for x in [125.3, 123.0, 113.9]],
               [x / 255.0 for x in [63.0, 62.1, 66.7]]),
    'Res12': ([x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]],
              [x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]]),
    'Res18': _IMAGENET_STATS,
    'WRN': _IMAGENET_STATS,
}


class tieredImageNet(data.Dataset):
    """tieredImageNet split as a map-style dataset of (image tensor, label).

    Attributes:
        data: array loaded from the split's ``.npz`` under key ``'images'``;
            each element is passed to ``PIL.Image.fromarray`` (presumably
            HWC uint8 images -- TODO confirm).
        label: list of int class indices, one per sample, numbered 0..C-1 in
            order of first appearance in the label file.
        wnids: list of the raw label values, where ``wnids[k]`` is the raw
            label mapped to class index ``k``.
        num_class: number of distinct classes in this split.
        transform: torchvision pipeline applied in ``__getitem__``.
    """

    def __init__(self, setname, args, augment=False):
        """Load one split.

        Args:
            setname: one of ``'train'``, ``'val'``, ``'test'``.
            args: object whose ``backbone_class`` attribute selects the
                normalization statistics.
            augment: if True and ``setname == 'train'``, apply random crop,
                color jitter and horizontal flip before normalization.

        Raises:
            ValueError: on an unknown ``setname`` or ``args.backbone_class``.
        """
        if setname not in ('train', 'val', 'test'):
            raise ValueError(
                "setname must be 'train', 'val' or 'test', got {!r}".format(setname))
        # Validate the backbone before loading the (large) image arrays.
        stats = _NORMALIZE_STATS.get(args.backbone_class)
        if stats is None:
            raise ValueError('Non-supported Network Types. Please Revise Data Pre-Processing Scripts.')
        mean, std = stats

        image_path, label_path = file_path[setname]
        labels = load_data(label_path)['labels']
        # NpzFile keeps the archive open; the with-block closes it once the
        # 'images' array has been read into memory.
        with np.load(image_path) as npz:
            self.data = npz['images']

        # Remap raw labels to contiguous ints in first-seen order. A dict
        # gives O(1) membership and preserves insertion order.
        wnid_to_label = {}
        self.label = [wnid_to_label.setdefault(wnid, len(wnid_to_label))
                      for wnid in labels]
        self.wnids = list(wnid_to_label)
        self.num_class = len(self.wnids)

        if augment and setname == 'train':
            transforms_list = [
                transforms.RandomCrop(84, padding=8),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        else:
            transforms_list = [
                transforms.ToTensor(),
            ]
        self.transform = transforms.Compose(
            transforms_list + [transforms.Normalize(mean=mean, std=std)])

    def __getitem__(self, index):
        """Return ``(transformed image tensor, int label)`` for ``index``."""
        img, label = self.data[index], self.label[index]
        img = self.transform(Image.fromarray(img))
        return img, label

    def __len__(self):
        """Number of samples in the split."""
        return len(self.data)