import random
import torch
import os
from glob import glob
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image


class SeasonTransferDataset(Dataset):
    """Multi-domain image dataset read from per-domain sub-directories.

    Sub-directories of ``opt.dataroot`` whose names contain ``'train'`` (or
    ``'test'`` when ``opt.is_train`` is false) are sorted by name and each one
    is treated as one domain.  One item of the dataset is a list holding one
    image per domain together with a matching list of one-hot domain labels.

    Attributes read from ``opt``: ``dataroot``, ``is_train``, ``n_attribute``,
    ``load_size``, ``fine_size`` and ``is_flip``.
    """

    def __init__(self, opt):
        """Scan the data root and build the image transforms.

        Args:
            opt: Options object providing the attributes listed in the class
                docstring.
        """
        self.image_path = opt.dataroot
        self.is_train = opt.is_train
        self.d_num = opt.n_attribute

        print('Start preprocessing dataset..!')
        # NOTE: this reseeds the *global* ``random`` module so that the
        # shuffle in preprocess() is reproducible; it also affects every
        # later user of the global generator.
        random.seed(1234)
        self.preprocess()
        print('Finished preprocessing dataset..!')

        # ``Image.ANTIALIAS`` was removed in Pillow 10; ``Image.LANCZOS`` is
        # the same resampling filter under its current name.
        resize = transforms.Resize(opt.load_size, interpolation=Image.LANCZOS)
        if self.is_train:
            trs = [resize, transforms.RandomCrop(opt.fine_size)]
        else:
            trs = [resize, transforms.CenterCrop(opt.fine_size)]
        if opt.is_flip:
            trs.append(transforms.RandomHorizontalFlip())
        self.transform = transforms.Compose(trs)

        self.norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # The dataset is as long as its largest domain; smaller domains are
        # resampled at random in __getitem__.
        self.num_data = max(self.num)

    def _list_images(self, subdir):
        """Return the sorted ``*.jpg`` and ``*.png`` paths inside *subdir*."""
        paths = []
        for ext in ('jpg', 'png'):
            paths += glob("{}/{}/*.{}".format(self.image_path, subdir, ext))
        paths.sort()
        return paths

    def preprocess(self):
        """Collect the file lists, per-domain counts and one-hot labels.

        Populates ``self.filenames`` (one list of paths per domain),
        ``self.num`` (number of files per domain) and ``self.labels`` (a
        ``d_num`` x ``d_num`` identity matrix as nested lists).  Training file
        lists are shuffled with the global ``random`` generator; test file
        lists stay sorted.

        Raises:
            AssertionError: If the number of ``'train'`` directories differs
                from ``self.d_num`` (checked in both train and test mode).
        """
        entries = os.listdir(self.image_path)
        train_dirs = sorted(name for name in entries if 'train' in name)
        test_dirs = sorted(name for name in entries if 'test' in name)
        assert len(train_dirs) == self.d_num, (
            "expected {} 'train' directories in {!r}, found {}".format(
                self.d_num, self.image_path, len(train_dirs)))

        self.filenames = []
        self.num = []
        for subdir in (train_dirs if self.is_train else test_dirs):
            paths = self._list_images(subdir)
            if self.is_train:
                random.shuffle(paths)
            self.filenames.append(paths)
            self.num.append(len(paths))

        self.labels = [[int(j == i) for j in range(self.d_num)]
                       for i in range(self.d_num)]

    def __getitem__(self, index):
        """Return one transformed image and one label for every domain.

        Args:
            index: Position in the dataset.  For a domain with no more than
                ``index`` files a random file of that domain is used instead.

        Returns:
            A tuple ``(imgs, labels)`` of two lists with ``d_num`` entries
            each: the transformed, normalised image tensors and the one-hot
            ``torch.FloatTensor`` labels of the corresponding domains.
        """
        imgs = []
        labels = []
        for d in range(self.d_num):
            if index < self.num[d]:
                index_d = index
            else:
                index_d = random.randint(0, self.num[d] - 1)
            # Close the file handle deterministically; convert() returns an
            # independent in-memory image.
            with Image.open(self.filenames[d][index_d]) as raw:
                img = raw.convert('RGB')
            img = self.transform(img)
            img = self.norm(img)
            imgs.append(img)
            labels.append(torch.FloatTensor(self.labels[d]))
        return imgs, labels

    def __len__(self):
        """Return the size of the largest domain."""
        return self.num_data