import os import torch import torchvision.transforms as transforms import torchvision.datasets as dset import torch.utils.data as data class FolderSubset(data.Dataset): def __init__(self, dataset, classes, indices): self.dataset = dataset self.classes = classes self.indices = indices self.update_classes() def update_classes(self): for i in self.indices: img_path, cls = self.dataset.samples[i] cls = self.classes.index(cls) self.dataset.samples[i] = (img_path, cls) def __getitem__(self, idx): return self.dataset[self.indices[idx]] def __len__(self): return len(self.indices) def check_split(opt): splits = [] for split in ['train', 'val', 'test']: splits.append(torch.load('split/' + opt.datasplit + '-' + split)) return splits def check_dataset(opt): normalize_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) train_large_transform = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]) val_large_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)]) train_small_transform = transforms.Compose([transforms.Pad(4), transforms.RandomCrop(32), transforms.RandomHorizontalFlip()]) splits = check_split(opt) if opt.dataset in ['cub200', 'indoor', 'stanford40', 'dog']: train, val = 'train', 'test' train_transform = transforms.Compose([train_large_transform, normalize_transform]) val_transform = transforms.Compose([val_large_transform, normalize_transform]) sets = [dset.ImageFolder(root=os.path.join(opt.dataroot, train), transform=train_transform), dset.ImageFolder(root=os.path.join(opt.dataroot, train), transform=val_transform), dset.ImageFolder(root=os.path.join(opt.dataroot, val), transform=val_transform)] sets = [FolderSubset(dataset, *split) for dataset, split in zip(sets, splits)] opt.num_classes = len(splits[0][0]) else: raise Exception('Unknown dataset') loaders = [torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, shuffle=True, num_workers=0) for dataset in sets] return loaders