"""PyTorch implementation of Self-Supervised GAN.

Reference: "Self-Supervised GANs via Auxiliary Rotation Loss"
Authors: Ting Chen, Xiaohua Zhai, Marvin Ritter, Mario Lucic and Neil Houlsby
https://arxiv.org/abs/1811.11212  CVPR 2019.

Script Author: Vandit Jain. GitHub: vandit15

This module provides the dataset loaders (MNIST, Fashion-MNIST, LSUN).
"""
import random
from typing import Sequence

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF


def _get_32px_dataloaders(dataset_cls, path_to_data, batch_size):
    """Build train/test DataLoaders for a torchvision MNIST-style dataset class.

    Images are resized with ``transforms.Resize(32)`` and converted to tensors.

    Parameters
    ----------
    dataset_cls : type
        A torchvision dataset class accepting ``(root, train=..., download=...,
        transform=...)``, e.g. ``datasets.MNIST`` or ``datasets.FashionMNIST``.
    path_to_data : str
        Root directory handed to ``dataset_cls``.
    batch_size : int
        Number of samples per batch for both loaders.

    Returns
    -------
    tuple of (DataLoader, DataLoader)
        ``(train_loader, test_loader)``, both with ``shuffle=True``.
    """
    # Resize images so they are a power of 2
    all_transforms = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])

    # Only the train split requests download=True; the test split is read from
    # the same root afterwards, matching the original behaviour.
    train_data = dataset_cls(path_to_data, train=True, download=True,
                             transform=all_transforms)
    test_data = dataset_cls(path_to_data, train=False, transform=all_transforms)

    # NOTE(review): shuffling the test split is unusual but kept so existing
    # callers see unchanged behaviour.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader


def get_mnist_dataloaders(batch_size=128, path_to_data='../data'):
    """MNIST dataloaders with (32, 32) sized images.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch for both loaders.
    path_to_data : str
        Root directory for the MNIST files (default ``'../data'``).

    Returns
    -------
    tuple of (DataLoader, DataLoader)
        ``(train_loader, test_loader)``.
    """
    return _get_32px_dataloaders(datasets.MNIST, path_to_data, batch_size)


def get_fashion_mnist_dataloaders(batch_size=128, path_to_data='../fashion_data'):
    """Fashion-MNIST dataloaders with (32, 32) sized images.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch for both loaders.
    path_to_data : str
        Root directory for the Fashion-MNIST files (default ``'../fashion_data'``).

    Returns
    -------
    tuple of (DataLoader, DataLoader)
        ``(train_loader, test_loader)``.
    """
    return _get_32px_dataloaders(datasets.FashionMNIST, path_to_data, batch_size)


def get_lsun_dataloader(path_to_data='../lsun', dataset='bedroom_train',
                        batch_size=64):
    """LSUN dataloader with (128, 128) sized images.

    Parameters
    ----------
    path_to_data : str
        Root directory handed to ``torchvision.datasets.LSUN``.
    dataset : str
        LSUN class/split to load, e.g. ``'bedroom_val'`` or ``'bedroom_train'``.
    batch_size : int
        Number of samples per batch.

    Returns
    -------
    DataLoader
        A shuffled loader over the selected LSUN split.
    """
    # Resize the shorter edge to 128, then center-crop to a 128x128 square.
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
    ])

    # Pass the data root positionally: old torchvision releases named this
    # argument ``db_path`` while current releases name it ``root``, so a
    # keyword call breaks on one or the other.
    lsun_dset = datasets.LSUN(path_to_data, classes=[dataset], transform=transform)

    return DataLoader(lsun_dset, batch_size=batch_size, shuffle=True)