import sys
from urllib import request

import numpy as np
import torch
from torch.utils.data import Dataset

# Make the sibling "semi-supervised" package directory importable
# (get_mnist imports `utils` from it at call time).
sys.path.append("../semi-supervised")

# Number of digit classes kept from MNIST (labels 0 .. n_labels - 1).
n_labels = 10
# Used as the `pin_memory` flag of the data loaders below.
cuda = torch.cuda.is_available()


class SpriteDataset(Dataset):
    """
    A PyTorch wrapper for the dSprites dataset by
    Matthey et al. 2017. The dataset provides a 2D scene
    with a sprite under different transformations:
    * color
    * shape
    * scale
    * orientation
    * x-position
    * y-position
    """

    def __init__(self, transform=None, path="./dsprites.npz"):
        """Load the ``imgs`` array, downloading the archive if it is missing.

        Args:
            transform: Optional callable applied to every sample returned by
                ``__getitem__``.
            path: Location of the ``.npz`` archive. When no file exists
                there, the archive is downloaded to this path first.
        """
        self.transform = transform
        url = "https://github.com/deepmind/dsprites-dataset/raw/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"

        try:
            self.dset = self._load_images(path)
        except FileNotFoundError:
            request.urlretrieve(url, path)
            self.dset = self._load_images(path)

    @staticmethod
    def _load_images(path):
        """Return the ``imgs`` array stored in the ``.npz`` archive at *path*."""
        # The context manager closes the underlying zip file as soon as the
        # array has been read, instead of leaving the handle open.
        with np.load(path, encoding="bytes") as archive:
            return archive["imgs"]

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.dset)

    def __getitem__(self, idx):
        """Return the sample at *idx*, passed through ``transform`` if set."""
        sample = self.dset[idx]

        if self.transform is not None:
            sample = self.transform(sample)

        return sample


def _balanced_indices(labels, n_classes, n=None):
    """Pick shuffled sample indices grouped by class, at most *n* per class.

    Args:
        labels: 1-D array-like of integer class labels.
        n_classes: Only labels in ``range(n_classes)`` are considered.
        n: Maximum number of indices kept per class; ``None`` keeps all.

    Returns:
        An integer ``numpy.ndarray`` of indices into *labels*, ordered class
        by class (class 0 first). Within a class the order is random, drawn
        from NumPy's global random state.
    """
    labels = np.asarray(labels)

    # Only keep samples whose label is one of the requested classes.
    candidates = np.flatnonzero(np.isin(labels, np.arange(n_classes)))

    # Shuffle once so the per-class prefix below is a random subset.
    np.random.shuffle(candidates)
    shuffled_labels = labels[candidates]

    # Boolean masks keep every chunk an integer array, even when a class has
    # no samples, so the concatenated result never degrades to float.
    per_class = [candidates[shuffled_labels == c][:n] for c in range(n_classes)]
    return np.concatenate(per_class)


def _dataset_labels(dataset):
    """Return the labels of a torchvision MNIST dataset as a numpy array."""
    targets = getattr(dataset, "targets", None)
    if targets is None:
        # Older torchvision releases expose the labels under split-specific
        # attribute names instead of `targets`.
        targets = dataset.train_labels if dataset.train else dataset.test_labels
    return np.asarray(targets)


def get_mnist(location="./", batch_size=64, labels_per_class=100):
    """Build labelled, unlabelled and validation data loaders for MNIST.

    Images are converted to tensors, flattened and stochastically binarised
    (each pixel is sampled from a Bernoulli distribution whose probability is
    the pixel intensity). Targets are passed through ``onehot(n_labels)``
    from the project's ``utils`` module.

    Args:
        location: Directory where MNIST is stored (downloaded if missing).
        batch_size: Batch size used by all three loaders.
        labels_per_class: Number of training samples per digit class exposed
            through the labelled loader.

    Returns:
        A tuple ``(labelled, unlabelled, validation)`` of ``DataLoader``
        objects. ``labelled`` and ``unlabelled`` both draw from the training
        split; ``validation`` draws from the test split.
    """
    from torch.utils.data.sampler import SubsetRandomSampler
    from torchvision.datasets import MNIST
    import torchvision.transforms as transforms
    from utils import onehot

    def flatten_bernoulli(x):
        return transforms.ToTensor()(x).view(-1).bernoulli()

    mnist_train = MNIST(location, train=True, download=True,
                        transform=flatten_bernoulli,
                        target_transform=onehot(n_labels))
    mnist_valid = MNIST(location, train=False, download=True,
                        transform=flatten_bernoulli,
                        target_transform=onehot(n_labels))

    def get_sampler(labels, n=None):
        # Uniform number of samples per digit, restricted to n_labels classes.
        indices = torch.from_numpy(_balanced_indices(labels, n_labels, n))
        return SubsetRandomSampler(indices)

    def make_loader(dataset, sampler):
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           num_workers=2, pin_memory=cuda,
                                           sampler=sampler)

    train_labels = _dataset_labels(mnist_train)
    valid_labels = _dataset_labels(mnist_valid)

    # Dataloaders for MNIST
    labelled = make_loader(mnist_train, get_sampler(train_labels, labels_per_class))
    unlabelled = make_loader(mnist_train, get_sampler(train_labels))
    validation = make_loader(mnist_valid, get_sampler(valid_labels))

    return labelled, unlabelled, validation