"""Data loaders for CIFAR-10 and the 32x32 / 64x64 ImageNet tar archives."""
from __future__ import print_function

import math
import numbers
import os
import os.path
import pickle
import sys
import tarfile
from os.path import join

import numpy as np
import torch
import torch.utils.data as data_utils
import torchvision
from PIL import Image
from scipy.io import loadmat
from torch.utils.data import ConcatDataset
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import functional as vf


class ToTensorNoNorm():
    """Transform turning an HWC image into a CHW tensor, values untouched."""

    def __call__(self, X_i):
        """Convert ``X_i`` to a tensor and move its last axis to the front.

        Args:
            X_i: Array-like image with three axes (e.g. a PIL image);
                presumably laid out as (height, width, channels).

        Returns:
            torch.Tensor: The same data with axes reordered by
            ``permute(2, 0, 1)``. No scaling or dtype change is applied.
        """
        # np.asarray copies only when it has to.  The former
        # ``np.array(..., copy=False)`` raises under NumPy >= 2 whenever a
        # copy cannot be avoided.
        return torch.from_numpy(np.asarray(X_i)).permute(2, 0, 1)


class PadToMultiple(object):
    """Pad an image on the right/bottom so both sides are multiples of a number.

    Args:
        multiple (number): Width and height are rounded up to a multiple of
            this value.
        fill (number, str or tuple): Fill value forwarded to
            ``torchvision.transforms.functional.pad``.
        padding_mode (str): One of ``'constant'``, ``'edge'``, ``'reflect'``
            or ``'symmetric'``.
    """

    def __init__(self, multiple, fill=0, padding_mode='constant'):
        assert isinstance(multiple, numbers.Number)
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']

        self.multiple = multiple
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.

        Returns:
            PIL Image: Padded image.
        """
        w, h = img.size
        m = self.multiple

        # Round each side up to the next multiple of m.
        nw = (w // m + int((w % m) != 0)) * m
        nh = (h // m + int((h % m) != 0)) * m

        # Padding order is (left, top, right, bottom): only the right and
        # bottom edges grow, so the original pixels keep their coordinates.
        return vf.pad(
            img, (0, 0, nw - w, nh - h), self.fill, self.padding_mode)

    def __repr__(self):
        # Fixed: this used to read ``self.mulitple`` and raised AttributeError.
        return '{0}(multiple={1}, fill={2}, padding_mode={3})'.format(
            self.__class__.__name__, self.multiple, self.fill,
            self.padding_mode)


class CustomTensorDataset(Dataset):
    """Dataset wrapping tensors, with an optional transform on the inputs.

    Each sample is retrieved by indexing the tensors along the first
    dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first
            dimension. ``__getitem__`` unpacks exactly two of them
            (inputs and targets).
        transform (callable, optional): applied to every input sample. Its
            output is converted back to a tensor and permuted with
            ``(2, 0, 1)``, so it is expected to return an HWC image such as
            a PIL image.
    """

    def __init__(self, *tensors, transform=None):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at ``index``."""
        X, y = self.tensors
        X_i, y_i = X[index], y[index]

        if self.transform:
            X_i = self.transform(X_i)
            # Back from the transform's HWC image to a CHW tensor.  See
            # ToTensorNoNorm for why np.asarray replaces ``copy=False``.
            X_i = torch.from_numpy(np.asarray(X_i))
            X_i = X_i.permute(2, 0, 1)

        return X_i, y_i

    def __len__(self):
        return self.tensors[0].size(0)


def load_cifar10(args, **kwargs):
    """Build train / validation / test loaders for CIFAR-10.

    The data is obtained through ``keras.datasets.cifar10``. The last 10000
    training samples are held out as the validation set. Only the training
    loader is shuffled and augmented.

    Args:
        args: Namespace-like object. Reads ``batch_size`` and
            ``data_augmentation_level`` (2: flip + padded random translation,
            1: flip only, anything else: no augmentation). Sets
            ``input_size``, ``input_type`` and ``dynamic_binarization``.
        **kwargs: Forwarded to every ``DataLoader``.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, args)``.
    """
    # set args
    args.input_size = [3, 32, 32]
    args.input_type = 'continuous'
    args.dynamic_binarization = False

    # Imported here so keras is only required when CIFAR-10 is requested.
    from keras.datasets import cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    # Move the last axis to position 1 (presumably NHWC -> NCHW, matching
    # args.input_size above).
    x_train = x_train.transpose(0, 3, 1, 2)
    x_test = x_test.transpose(0, 3, 1, 2)

    # Every level starts by converting the tensor sample to a PIL image.
    steps = [transforms.ToPILImage()]
    level = args.data_augmentation_level
    if level in (1, 2):
        steps.append(transforms.RandomHorizontalFlip())
    if level == 2:
        # Pad by 5% of the side, translate randomly by up to 5%, then crop
        # back to 32x32.
        steps.extend([
            transforms.Pad(int(math.ceil(32 * 0.05)), padding_mode='edge'),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.CenterCrop(32),
        ])
    data_transform = transforms.Compose(steps)

    n_val = 10000
    x_val, y_val = x_train[-n_val:], y_train[-n_val:]
    x_train, y_train = x_train[:-n_val], y_train[:-n_val]

    train = CustomTensorDataset(
        torch.from_numpy(x_train), torch.from_numpy(y_train),
        transform=data_transform)
    train_loader = data_utils.DataLoader(
        train, batch_size=args.batch_size, shuffle=True, **kwargs)

    validation = data_utils.TensorDataset(
        torch.from_numpy(x_val), torch.from_numpy(y_val))
    val_loader = data_utils.DataLoader(
        validation, batch_size=args.batch_size, shuffle=False, **kwargs)

    test = data_utils.TensorDataset(
        torch.from_numpy(x_test), torch.from_numpy(y_test))
    test_loader = data_utils.DataLoader(
        test, batch_size=args.batch_size, shuffle=False, **kwargs)

    return train_loader, val_loader, test_loader, args


def extract_tar(tarpath, files_per_dir=50000):
    """Extract a ``.tar`` archive into numbered sub-directories.

    Files are written flat (their archive directory prefix is dropped) into
    ``<tarpath without .tar>/images0``, ``images1``, ... with at most
    ``files_per_dir`` files each. Directory entries in the archive are
    skipped. If the target directory already exists nothing is extracted.

    Args:
        tarpath (str): Path to the archive; must end in ``'.tar'``.
        files_per_dir (int): Maximum number of files per ``imagesK``
            directory.

    Returns:
        str: The extraction directory, with a trailing ``'/'``.
    """
    assert tarpath.endswith('.tar')

    startdir = tarpath[:-4] + '/'

    # NOTE: an interrupted earlier extraction also leaves this directory
    # behind and is therefore treated as complete.
    if os.path.exists(startdir):
        return startdir

    print('Extracting', tarpath)
    os.makedirs(startdir, exist_ok=True)

    with tarfile.open(name=tarpath) as tar:
        path = None
        n_extracted = 0
        for member in tar:
            # Skip directories.  Doing this in the main loop avoids the old
            # crash when the archive ended on a directory entry.
            if member.isdir():
                continue

            # Open a new shard only when there is a file to put in it, so no
            # empty trailing directory is created.
            if n_extracted % files_per_dir == 0:
                path = join(
                    startdir, 'images{}'.format(n_extracted // files_per_dir))
                os.makedirs(path, exist_ok=True)
                print(path)

            # Keep only the file name so the file lands directly in `path`.
            member.name = member.name.split('/')[-1]
            tar.extract(member, path=path)
            n_extracted += 1

    return startdir


def load_imagenet(resolution, args, **kwargs):
    """Build train / validation / test loaders for 32x32 or 64x64 ImageNet.

    Expects ``../imagenet{res}/train_{res}x{res}.tar`` and
    ``../imagenet{res}/valid_{res}x{res}.tar`` relative to the working
    directory and extracts them on first use. 20000 randomly chosen images of
    the train archive form the validation set; the ``valid`` archive is used
    as the test set.

    Args:
        resolution (int): 32 or 64.
        args: Namespace-like object. Reads ``batch_size``; sets
            ``input_size``.
        **kwargs: Forwarded to every ``DataLoader``.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, args)``.
    """
    assert resolution == 32 or resolution == 64

    args.input_size = [3, resolution, resolution]

    trainpath = extract_tar(
        '../imagenet{res}/train_{res}x{res}.tar'.format(res=resolution))
    valpath = extract_tar(
        '../imagenet{res}/valid_{res}x{res}.tar'.format(res=resolution))

    data_transform = transforms.Compose([
        ToTensorNoNorm()
    ])

    print('Starting loading ImageNet')

    imagenet_data = torchvision.datasets.ImageFolder(
        trainpath,
        transform=data_transform)

    n_images = len(imagenet_data)
    print('Number of data images', n_images)

    # The split uses NumPy's global RNG, which is not seeded here: it only
    # repeats across runs if the caller seeds NumPy beforehand.
    val_idcs = np.random.choice(n_images, size=20000, replace=False)
    train_idcs = np.setdiff1d(np.arange(n_images), val_idcs)

    train_dataset = torch.utils.data.dataset.Subset(
        imagenet_data, train_idcs)
    val_dataset = torch.utils.data.dataset.Subset(
        imagenet_data, val_idcs)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        **kwargs)

    test_dataset = torchvision.datasets.ImageFolder(
        valpath,
        transform=data_transform)

    print('Number of val images:', len(test_dataset))

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        **kwargs)

    return train_loader, val_loader, test_loader, args


def load_dataset(args, **kwargs):
    """Dispatch on ``args.dataset`` to the matching loader.

    Args:
        args: Namespace-like object with a ``dataset`` attribute equal to
            ``'cifar10'``, ``'imagenet32'`` or ``'imagenet64'``, plus
            whatever the selected loader reads.
        **kwargs: Forwarded to the selected loader.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, args)``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    if args.dataset == 'cifar10':
        return load_cifar10(args, **kwargs)
    if args.dataset == 'imagenet32':
        return load_imagenet(32, args, **kwargs)
    if args.dataset == 'imagenet64':
        return load_imagenet(64, args, **kwargs)

    # ValueError subclasses Exception, so existing handlers keep working.
    raise ValueError('Wrong name of the dataset!')


if __name__ == '__main__':
    class Args():
        def __init__(self):
            self.batch_size = 128

    # Fixed: this used to call the undefined name ``load_imagenet32``.
    train_loader, val_loader, test_loader, args = load_imagenet(32, Args())