Python torchvision.datasets.SVHN Examples

The following are 27 code examples showing how to use torchvision.datasets.SVHN(). The examples are extracted from open-source projects; the project, author, source file, and license are listed above each example.

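Before the project examples, here is a minimal, self-contained sketch of typical usage; the root path, batch size, and transform are illustrative choices rather than values taken from any example below:

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# split is one of 'train', 'test', or 'extra'; download=True fetches the .mat file if missing
svhn_train = datasets.SVHN(root='./data', split='train', download=True,
                           transform=transforms.ToTensor())

loader = torch.utils.data.DataLoader(svhn_train, batch_size=64, shuffle=True)
images, labels = next(iter(loader))  # images: [64, 3, 32, 32]; labels: classes 0-9 (digit '0' is class 0)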

Example 1
Project: mnist-svhn-transfer   Author: yunjey   File: data_loader.py    License: MIT License
def get_loader(config):
    """Builds and returns Dataloader for MNIST and SVHN dataset."""
    
    transform = transforms.Compose([
                    transforms.Scale(config.image_size),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    svhn = datasets.SVHN(root=config.svhn_path, download=True, transform=transform)
    mnist = datasets.MNIST(root=config.mnist_path, download=True, transform=transform)

    svhn_loader = torch.utils.data.DataLoader(dataset=svhn,
                                              batch_size=config.batch_size,
                                              shuffle=True,
                                              num_workers=config.num_workers)

    mnist_loader = torch.utils.data.DataLoader(dataset=mnist,
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=config.num_workers)
    return svhn_loader, mnist_loader 
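Note: transforms.Scale, used here and in several later examples, was deprecated in favor of transforms.Resize and has been removed from recent torchvision releases; on a current install, substitute:

    transforms.Resize(config.image_size),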
Example 2
Project: pytorch-atda   Author: corenel   File: svhn.py    License: MIT License
def get_svhn(train, get_dataset=False, batch_size=cfg.batch_size):
    """Get SVHN dataset loader."""
    # image pre-processing
    pre_process = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize(
                                          mean=cfg.dataset_mean,
                                          std=cfg.dataset_std)])

    # dataset and data loader
    svhn_dataset = datasets.SVHN(root=cfg.data_root,
                                 split='train' if train else 'test',
                                 transform=pre_process,
                                 download=True)

    if get_dataset:
        return svhn_dataset
    else:
        svhn_data_loader = torch.utils.data.DataLoader(
            dataset=svhn_dataset,
            batch_size=batch_size,
            shuffle=True)
        return svhn_data_loader 
Example 3
Project: BatchBALD   Author: BlackHC   File: dataset_enum.py    License: GNU General Public License v3.0
def get_targets(dataset):
    """Get the targets of a dataset without any target target transforms(!)."""
    if isinstance(dataset, TransformedDataset):
        return get_targets(dataset.dataset)
    if isinstance(dataset, data.Subset):
        targets = get_targets(dataset.dataset)
        return torch.as_tensor(targets)[dataset.indices]
    if isinstance(dataset, data.ConcatDataset):
        return torch.cat([get_targets(sub_dataset) for sub_dataset in dataset.datasets])

    if isinstance(
            dataset, (datasets.MNIST, datasets.ImageFolder,)
    ):
        return torch.as_tensor(dataset.targets)
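    # SVHN stores its targets in the .labels attribute (a NumPy array), not .targets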
    if isinstance(dataset, datasets.SVHN):
        return dataset.labels

    raise NotImplementedError(f"Unknown dataset {dataset}!") 
Example 4
Project: OCDVAEContinualLearning   Author: MrtnMndt   File: datasets.py    License: MIT License
def get_dataset(self):
        """
        Uses torchvision.datasets.SVHN to load the dataset.
        Downloads the dataset if it doesn't already exist.
        Returns:
             trainset (SVHN train and extra splits, concatenated), valset
        """

        trainset = datasets.SVHN('datasets/SVHN/train/', split='train', transform=self.train_transforms,
                                 target_transform=None, download=True)
        valset = datasets.SVHN('datasets/SVHN/test/', split='test', transform=self.val_transforms,
                               target_transform=None, download=True)
        extraset = datasets.SVHN('datasets/SVHN/extra', split='extra', transform=self.train_transforms,
                                 target_transform=None, download=True)

        trainset = torch.utils.data.ConcatDataset([trainset, extraset])

        return trainset, valset 
Example 5
Project: Deep_Openset_Recognition_through_Uncertainty   Author: MrtnMndt   File: datasets.py    License: MIT License
def get_dataset(self):
        """
        Uses torchvision.datasets.SVHN to load the dataset.
        Downloads the dataset if it doesn't already exist.
        Returns:
             trainset (SVHN train and extra splits, concatenated), valset
        """

        trainset = datasets.SVHN('datasets/SVHN/train/', split='train', transform=self.train_transforms,
                                 target_transform=None, download=True)
        valset = datasets.SVHN('datasets/SVHN/test/', split='test', transform=self.val_transforms,
                               target_transform=None, download=True)
        extraset = datasets.SVHN('datasets/SVHN/extra', split='extra', transform=self.train_transforms,
                                 target_transform=None, download=True)

        trainset = torch.utils.data.ConcatDataset([trainset, extraset])

        return trainset, valset 
Example 6
Project: dfw   Author: oval-group   File: loaders.py    License: MIT License
def loaders_svhn(dataset, batch_size, cuda,
                 train_size=63257, augment=False, val_size=10000, test_size=26032,
                 test_batch_size=1000, **kwargs):

    assert dataset == 'svhn'

    root = '{}/{}'.format(os.environ['VISION_DATA'], dataset)

    # Data loading code
    mean = [0.4380, 0.4440, 0.4730]
    std = [0.1751, 0.1771, 0.1744]

    normalize = transforms.Normalize(mean=mean,
                                     std=std)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize])

    if augment:
        print('Using data augmentation on SVHN data set.')
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize])
    else:
        print('Not using data augmentation on SVHN data set.')
        transform_train = transform_test

    # define two datasets in order to have different transforms
    # on training and validation (no augmentation on validation)
    dataset = datasets.SVHN
    dataset_train = dataset(root=root, split='train',
                            transform=transform_train)
    dataset_val = dataset(root=root, split='train',
                          transform=transform_test)
    dataset_test = dataset(root=root, split='test',
                           transform=transform_test)

    return create_loaders(dataset_train, dataset_val,
                          dataset_test, train_size, val_size, test_size,
                          batch_size, test_batch_size, cuda, num_workers=4) 
Example 7
Project: cycada_release   Author: jhoffman   File: svhn_balanced.py    License: BSD 2-Clause "Simplified" License
def __init__(self, root, train=True,
            transform=None, target_transform=None, download=False):
        if train:
            split = 'train'
        else:
            split = 'test'
        super(SVHN, self).__init__(root, split=split, transform=transform,
                target_transform=target_transform, download=download)

        # Subsample images to balance the training set
       
        if split == 'train':
            # compute the histogram of original label set
            label_set = np.unique(self.labels)
            num_cls = len(label_set)
            count,_ = np.histogram(self.labels.squeeze(), bins=num_cls)
            min_num = min(count)
            
            # subsample
            ind = np.zeros((num_cls, min_num), dtype=int)
            for i in label_set:
                binary_ind = np.where(self.labels.squeeze() == i)[0]
                np.random.shuffle(binary_ind)
                
                ind[i % num_cls,:] = binary_ind[:min_num]
            
            ind = ind.flatten()
            # shuffle the flattened indices repeatedly
            for i in range(100):
                np.random.shuffle(ind)
            self.labels = self.labels[ind]
            self.data = self.data[ind] 
Example 8
Project: cycada_release   Author: jhoffman   File: svhn.py    License: BSD 2-Clause "Simplified" License
def __init__(self, root, train=True,
            transform=None, target_transform=None, download=False):
        if train:
            split = 'train'
        else:
            split = 'test'
        super(SVHN, self).__init__(root, split=split, transform=transform,
                target_transform=target_transform, download=download) 
Example 9
Project: imgclsmob   Author: osmr   File: svhn_cls_dataset.py    License: MIT License
def __init__(self):
        super(SVHNMetaInfo, self).__init__()
        self.label = "SVHN"
        self.root_dir_name = "svhn"
        self.dataset_class = SVHNFine
        self.num_training_samples = 73257 
Example 10
Project: convex_adversarial   Author: locuslab   File: problems.py    License: MIT License
def svhn_loaders(batch_size): 
    train = datasets.SVHN("./data", split='train', download=True, transform=transforms.ToTensor(), target_transform=replace_10_with_0)
    test = datasets.SVHN("./data", split='test', download=True, transform=transforms.ToTensor(), target_transform=replace_10_with_0)
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False, pin_memory=True)
    return train_loader, test_loader 
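replace_10_with_0 is a helper defined elsewhere in problems.py. Current torchvision already maps the raw SVHN label 10 (digit '0') to class 0 inside datasets.SVHN, so with a recent install this target_transform is effectively an identity; its body is presumably something like:

def replace_10_with_0(y):
    return 0 if y == 10 else y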
Example 11
Project: UFDN   Author: Alexander-H-Liu   File: data.py    License: MIT License
def LoadSVHN(data_root, batch_size=32, split='train', shuffle=True):
    if not os.path.exists(data_root):
        os.makedirs(data_root)
    svhn_dataset = datasets.SVHN(data_root, split=split, download=True,
                                 transform=transforms.ToTensor())
    return DataLoader(svhn_dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True)
Example 12
Project: RobustDARTS   Author: automl   File: args.py    License: Apache License 2.0
def get_train_val_loaders(self):
        if self.args.dataset == 'cifar10':
            train_transform, valid_transform = utils._data_transforms_cifar10(self.args)
            train_data = dset.CIFAR10(
                root=self.args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR10(
                root=self.args.data, train=False, download=True, transform=valid_transform)
        elif self.args.dataset == 'cifar100':
            train_transform, valid_transform = utils._data_transforms_cifar100(self.args)
            train_data = dset.CIFAR100(
                root=self.args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR100(
                root=self.args.data, train=False, download=True, transform=valid_transform)
        elif self.args.dataset == 'svhn':
            train_transform, valid_transform = utils._data_transforms_svhn(self.args)
            train_data = dset.SVHN(
                root=self.args.data, split='train', download=True, transform=train_transform)
            valid_data = dset.SVHN(
                root=self.args.data, split='test', download=True, transform=valid_transform)

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size,
            shuffle=True, pin_memory=True, num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            valid_data, batch_size=self.args.batch_size,
            shuffle=False, pin_memory=True, num_workers=2)

        return train_queue, valid_queue, train_transform, valid_transform 
Example 13
Project: RobustDARTS   Author: automl   File: args.py    License: Apache License 2.0
def get_train_val_loaders(self):
        if self.args.dataset == 'cifar10':
            train_transform, valid_transform = utils._data_transforms_cifar10(self.args)
            train_data = dset.CIFAR10(root=self.args.data, train=True, download=True, transform=train_transform)
        elif self.args.dataset == 'cifar100':
            train_transform, valid_transform = utils._data_transforms_cifar100(self.args)
            train_data = dset.CIFAR100(root=self.args.data, train=True, download=True, transform=train_transform)
        elif self.args.dataset == 'svhn':
            train_transform, valid_transform = utils._data_transforms_svhn(self.args)
            train_data = dset.SVHN(root=self.args.data, split='train', download=True, transform=train_transform)

        num_train = len(train_data)
        indices = list(range(num_train))
        split = int(np.floor(self.args.train_portion * num_train))

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
            pin_memory=True, num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
            pin_memory=True, num_workers=2)

        return train_queue, valid_queue, train_transform, valid_transform 
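Note that both queues sample from the same train_data object, so the validation subset here is evaluated with train_transform, augmentation included.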
Example 14
Project: Confident_classifier   Author: alinlab   File: data_loader.py    License: MIT License
def getSVHN(batch_size, img_size=32, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
    data_root = os.path.expanduser(os.path.join(data_root, 'svhn-data'))
    num_workers = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)
    print("Building SVHN data loader with {} workers".format(num_workers))

    def target_transform(target):
        new_target = target - 1
        if new_target == -1:
            new_target = 9
        return new_target

    ds = []
    if train:
        train_loader = torch.utils.data.DataLoader(
            datasets.SVHN(
                root=data_root, split='train', download=True,
                transform=transforms.Compose([
                    transforms.Scale(img_size),
                    transforms.ToTensor(),
                ]),
                target_transform=target_transform,
            ),
            batch_size=batch_size, shuffle=True, **kwargs)
        ds.append(train_loader)

    if val:
        test_loader = torch.utils.data.DataLoader(
            datasets.SVHN(
                root=data_root, split='test', download=True,
                transform=transforms.Compose([
                    transforms.Scale(img_size),
                    transforms.ToTensor(),
                ]),
                target_transform=target_transform
            ),
            batch_size=batch_size, shuffle=False, **kwargs)
        ds.append(test_loader)
    ds = ds[0] if len(ds) == 1 else ds
    return ds 
Example 15
Project: pytorch-playground   Author: aaron-xichen   File: dataset.py    License: MIT License
def get(batch_size, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
    data_root = os.path.expanduser(os.path.join(data_root, 'svhn-data'))
    num_workers = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)
    print("Building SVHN data loader with {} workers".format(num_workers))

    def target_transform(target):
        return int(target) - 1

    ds = []
    if train:
        train_loader = torch.utils.data.DataLoader(
            datasets.SVHN(
                root=data_root, split='train', download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                target_transform=target_transform,
            ),
            batch_size=batch_size, shuffle=True, **kwargs)
        ds.append(train_loader)

    if val:
        test_loader = torch.utils.data.DataLoader(
            datasets.SVHN(
                root=data_root, split='test', download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                target_transform=target_transform
            ),
            batch_size=batch_size, shuffle=False, **kwargs)
        ds.append(test_loader)
    ds = ds[0] if len(ds) == 1 else ds
    return ds 
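Note: the target_transform in this and the previous example shifts each label down by one (Example 14 wraps -1 back around to 9; this example does not). That matches the raw SVHN convention in which digits are labeled 1-10, but current torchvision already remaps the labels to 0-9 with digit '0' as class 0, so under a recent install the shift merely changes the class indexing (and here a label of 0 would become -1).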
Example 16
Project: ssl_bad_gan   Author: kimiyoung   File: data.py    License: MIT License
def get_svhn_loaders(config):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    training_set = SVHN(config.data_root, split='train', download=True, transform=transform)
    dev_set = SVHN(config.data_root, split='test', download=True, transform=transform)

    def preprocess(data_set):
        # assumes the old torchvision behavior where SVHN labels kept their
        # raw (N, 1) shape and used the label 10 for the digit '0'
        for i in range(len(data_set.data)):
            if data_set.labels[i][0] == 10:
                data_set.labels[i][0] = 0
    preprocess(training_set)
    preprocess(dev_set)

    indices = np.arange(len(training_set))
    np.random.shuffle(indices)
    mask = np.zeros(indices.shape[0], dtype=bool)  # np.bool was removed in recent NumPy
    labels = np.array([training_set[i][1] for i in indices], dtype=np.int64)
    for i in range(10):
        # integer division, so the slice bound stays an int under Python 3
        mask[np.where(labels == i)[0][: config.size_labeled_data // 10]] = True
    # labeled_indices, unlabeled_indices = indices[mask], indices[~ mask]
    labeled_indices, unlabeled_indices = indices[mask], indices
    print('labeled size', labeled_indices.shape[0], 'unlabeled size', unlabeled_indices.shape[0], 'dev size', len(dev_set))

    labeled_loader = DataLoader(config, training_set, labeled_indices, config.train_batch_size)
    unlabeled_loader = DataLoader(config, training_set, unlabeled_indices, config.train_batch_size)
    unlabeled_loader2 = DataLoader(config, training_set, unlabeled_indices, config.train_batch_size_2)
    dev_loader = DataLoader(config, dev_set, np.arange(len(dev_set)), config.dev_batch_size)

    special_set = []
    for i in range(10):
        special_set.append(training_set[indices[np.where(labels==i)[0][0]]][0])
    special_set = torch.stack(special_set)

    return labeled_loader, unlabeled_loader, unlabeled_loader2, dev_loader, special_set 
Example 17
Project: homura   Author: moskomule   File: datasets.py    License: Apache License 2.0
def __new__(cls,
                root,
                train=True,
                transform=None,
                download=False):
        if train:
            return (datasets.SVHN(root, split='train', transform=transform, download=download) +
                    datasets.SVHN(root, split='extra', transform=transform, download=download))
        else:
            return OriginalSVHN(root, train=False, transform=transform, download=download) 
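Adding two Dataset objects with + calls Dataset.__add__, which returns a torch.utils.data.ConcatDataset; the train branch above is therefore equivalent to:

torch.utils.data.ConcatDataset([
    datasets.SVHN(root, split='train', transform=transform, download=download),
    datasets.SVHN(root, split='extra', transform=transform, download=download),
])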
Example 18
Project: realistic-ssl-evaluation-pytorch   Author: perrying   File: build_dataset.py    License: MIT License
def _load_svhn():
    splits = {}
    for split in ["train", "test", "extra"]:
        tv_data = datasets.SVHN(_DATA_DIR, split, download=True)
        data = {}
        data["images"] = tv_data.data
        data["labels"] = tv_data.labels
        splits[split] = data
    return splits.values() 
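Since dicts preserve insertion order (guaranteed from Python 3.7), splits.values() yields the three dicts in the order "train", "test", "extra".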
Example 19
Project: VLAE   Author: yookoon   File: datasets.py    License: MIT License
def __init__(self, batch_size, binarize=False, logit_transform=False):
        """ [-1, 3, 32, 32]
        """
        if binarize:
            raise NotImplementedError

        self.logit_transform = logit_transform

        directory='./datasets/SVHN'
        if not os.path.exists(directory):
            os.makedirs(directory)

        kwargs = {'num_workers': num_workers, 'pin_memory': True} if torch.cuda.is_available() else {}
        self.train_loader = DataLoader(
            datasets.SVHN(root=directory,split='train', download=True,
                           transform=transforms.ToTensor()),
            batch_size=batch_size, shuffle=True, **kwargs)
        self.test_loader = DataLoader(
            datasets.SVHN(root=directory, split='test', download=True, transform=transforms.ToTensor()),
            batch_size=batch_size, shuffle=False, **kwargs)

        self.dim = [3, 32, 32]

        train = torch.stack([data for data, _ in
                                list(self.train_loader.dataset)], 0).cuda()
        train = train.view(train.shape[0], -1)
        if self.logit_transform:
            train = train * 255.0
            train = (train + torch.rand_like(train)) / 256.0
            train = lamb + (1 - 2.0 * lamb) * train
            train = torch.log(train) - torch.log(1.0 - train)

        self.mean = train.mean(0)
        self.logvar = torch.log(torch.mean((train - self.mean)**2)).unsqueeze(0) 
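num_workers and lamb are not defined in this excerpt; presumably they are module-level settings in the original datasets.py. From its use above, lamb is the small constant that keeps the logit transform finite: rescaling pixels into (lamb, 1 - lamb) before taking log(x) - log(1 - x) avoids infinities at exactly 0 or 1.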
Example 20
Project: imgclsmob   Author: osmr   File: cifar1.py    License: MIT License
def add_dataset_parser_arguments(parser,
                                 dataset_name):
    if dataset_name == "CIFAR10":
        parser.add_argument(
            '--data-dir',
            type=str,
            default='../imgclsmob_data/cifar10',
            help='path to directory with CIFAR-10 dataset')
        parser.add_argument(
            '--num-classes',
            type=int,
            default=10,
            help='number of classes')
    elif dataset_name == "CIFAR100":
        parser.add_argument(
            '--data-dir',
            type=str,
            default='../imgclsmob_data/cifar100',
            help='path to directory with CIFAR-100 dataset')
        parser.add_argument(
            '--num-classes',
            type=int,
            default=100,
            help='number of classes')
    elif dataset_name == "SVHN":
        parser.add_argument(
            '--data-dir',
            type=str,
            default='../imgclsmob_data/svhn',
            help='path to directory with SVHN dataset')
        parser.add_argument(
            '--num-classes',
            type=int,
            default=10,
            help='number of classes')
    else:
        raise Exception('Unrecognized dataset: {}'.format(dataset_name))
    parser.add_argument(
        '--in-channels',
        type=int,
        default=3,
        help='number of input channels') 
Example 21
Project: imgclsmob   Author: osmr   File: cifar1.py    License: MIT License
def get_train_data_loader(dataset_name,
                          dataset_dir,
                          batch_size,
                          num_workers):
    mean_rgb = (0.4914, 0.4822, 0.4465)
    std_rgb = (0.2023, 0.1994, 0.2010)
    jitter_param = 0.4

    transform_train = transforms.Compose([
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=jitter_param,
            contrast=jitter_param,
            saturation=jitter_param),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=mean_rgb,
            std=std_rgb),
    ])

    if dataset_name == "CIFAR10":
        dataset = datasets.CIFAR10(
            root=dataset_dir,
            train=True,
            transform=transform_train,
            download=True)
    elif dataset_name == "CIFAR100":
        dataset = datasets.CIFAR100(
            root=dataset_dir,
            train=True,
            transform=transform_train,
            download=True)
    elif dataset_name == "SVHN":
        dataset = datasets.SVHN(
            root=dataset_dir,
            split="train",
            transform=transform_train,
            download=True)
    else:
        raise Exception('Unrecognized dataset: {}'.format(dataset_name))

    train_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True)

    return train_loader 
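Note: mean_rgb and std_rgb here are the usual CIFAR-10 statistics, reused unchanged for SVHN rather than SVHN-specific values; compare the SVHN statistics in Examples 6 and 25.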
Example 22
Project: imgclsmob   Author: osmr   File: cifar1.py    License: MIT License
def get_val_data_loader(dataset_name,
                        dataset_dir,
                        batch_size,
                        num_workers):
    mean_rgb = (0.4914, 0.4822, 0.4465)
    std_rgb = (0.2023, 0.1994, 0.2010)

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=mean_rgb,
            std=std_rgb),
    ])

    if dataset_name == "CIFAR10":
        dataset = datasets.CIFAR10(
            root=dataset_dir,
            train=False,
            transform=transform_val,
            download=True)
    elif dataset_name == "CIFAR100":
        dataset = datasets.CIFAR100(
            root=dataset_dir,
            train=False,
            transform=transform_val,
            download=True)
    elif dataset_name == "SVHN":
        dataset = datasets.SVHN(
            root=dataset_dir,
            split="test",
            transform=transform_val,
            download=True)
    else:
        raise Exception('Unrecognized dataset: {}'.format(dataset_name))

    val_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)

    return val_loader 
Example 23
Project: pytorch_DANN   Author: CuthbertCai   File: utils.py    License: MIT License
def get_train_loader(dataset):
    """
    Get the training dataloader for the source or target domain
    :return: dataloader
    """
    if dataset == 'MNIST':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = datasets.MNIST(root=params.mnist_path, train=True, transform=transform,
                              download=True)

        dataloader = DataLoader(dataset=data, batch_size=params.batch_size, shuffle=True)

    elif dataset == 'MNIST_M':
        transform = transforms.Compose([
            transforms.RandomCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = datasets.ImageFolder(root=params.mnistm_path + '/train', transform=transform)

        dataloader = DataLoader(dataset=data, batch_size=params.batch_size, shuffle=True)

    elif dataset == 'SVHN':
        transform = transforms.Compose([
            transforms.RandomCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        # combine the 'train' and 'extra' splits for a larger training set
        data1 = datasets.SVHN(root=params.svhn_path, split='train', transform=transform, download=True)
        data2 = datasets.SVHN(root=params.svhn_path, split='extra', transform=transform, download=True)

        data = torch.utils.data.ConcatDataset((data1, data2))

        dataloader = DataLoader(dataset=data, batch_size=params.batch_size, shuffle=True)

    elif dataset == 'SynDig':
        transform = transforms.Compose([
            transforms.RandomCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = SynDig.SynDig(root=params.syndig_path, split='train', transform=transform, download=False)

        dataloader = DataLoader(dataset=data, batch_size=params.batch_size, shuffle=True)

    else:
        raise Exception('There is no dataset named {}'.format(str(dataset)))

    return dataloader
Example 24
Project: pytorch_DANN   Author: CuthbertCai   File: utils.py    License: MIT License
def get_test_loader(dataset):
    """
    Get the test dataloader for the source or target domain
    :return: dataloader
    """
    if dataset == 'MNIST':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = datasets.MNIST(root=params.mnist_path, train=False, transform=transform,
                              download=True)

        dataloader = DataLoader(dataset=data, batch_size=1, shuffle=False)

    elif dataset == 'MNIST_M':
        transform = transforms.Compose([
            # transforms.RandomCrop((28)),
            transforms.CenterCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = datasets.ImageFolder(root=params.mnistm_path + '/test', transform=transform)

        dataloader = DataLoader(dataset=data, batch_size=1, shuffle=False)

    elif dataset == 'SVHN':
        transform = transforms.Compose([
            transforms.CenterCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = datasets.SVHN(root=params.svhn_path, split='test', transform=transform, download=True)

        dataloader = DataLoader(dataset=data, batch_size=1, shuffle=False)

    elif dataset == 'SynDig':
        transform = transforms.Compose([
            transforms.CenterCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = SynDig.SynDig(root=params.syndig_path, split='test', transform=transform, download=False)

        dataloader = DataLoader(dataset=data, batch_size=1, shuffle=False)

    else:
        raise Exception('There is no dataset named {}'.format(str(dataset)))

    return dataloader
Example 25
Project: CROWN-IBP   Author: huanzhang12   File: datasets.py    License: BSD 2-Clause "Simplified" License
def svhn_loaders(batch_size, shuffle_train=True, shuffle_test=False, train_random_transform=False, normalize_input=False, num_examples=None, test_batch_size=None):
    if normalize_input:
        mean = [0.43768206, 0.44376972, 0.47280434]
        std = [0.19803014, 0.20101564, 0.19703615]
        normalize = transforms.Normalize(mean=mean, std=std)
    else:
        std = [1.0, 1.0, 1.0]
        mean = [0, 0, 0]
        normalize = transforms.Normalize(mean=mean, std=std)
    if train_random_transform:
        if normalize_input:
            train = datasets.SVHN('./data', split='train', download=True, 
                transform=transforms.Compose([
                    transforms.RandomCrop(32, 4),
                    transforms.ToTensor(),
                    normalize,
                ]))
        else:
            train = datasets.SVHN('./data', split='train', download=True, 
                transform=transforms.Compose([
                    transforms.RandomCrop(32, 4),
                    transforms.ToTensor(),
                ]))
    else:
        train = datasets.SVHN('./data', split='train', download=True, 
            transform=transforms.Compose([transforms.ToTensor(),normalize]))
    test = datasets.SVHN('./data', split='test', download=True,
        transform=transforms.Compose([transforms.ToTensor(), normalize]))
    
    if num_examples:
        indices = list(range(num_examples))
        train = data.Subset(train, indices)
        test = data.Subset(test, indices)

    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size,
        shuffle=shuffle_train, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),6))
    if test_batch_size:
        batch_size = test_batch_size
    test_loader = torch.utils.data.DataLoader(test, batch_size=max(batch_size, 1),
        shuffle=shuffle_test, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),6))
    train_loader.std = std
    test_loader.std = std
    train_loader.mean = mean
    test_loader.mean = mean
    mean, std = get_stats(train_loader)
    print('dataset mean = ', mean.numpy(), 'std = ', std.numpy())
    return train_loader, test_loader

# when new loaders are added, they must be registered here
Example 26
Project: ffjord   Author: rtqichen   File: viz_cnf.py    License: MIT License
def get_dataset(args):
    trans = lambda im_size: tforms.Compose([tforms.Resize(im_size), tforms.ToTensor(), add_noise])

    if args.data == "mnist":
        im_dim = 1
        im_size = 28 if args.imagesize is None else args.imagesize
        train_set = dset.MNIST(root="./data", train=True, transform=trans(im_size), download=True)
        test_set = dset.MNIST(root="./data", train=False, transform=trans(im_size), download=True)
    elif args.data == "svhn":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.SVHN(root="./data", split="train", transform=trans(im_size), download=True)
        test_set = dset.SVHN(root="./data", split="test", transform=trans(im_size), download=True)
    elif args.data == "cifar10":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.CIFAR10(root="./data", train=True, transform=trans(im_size), download=True)
        test_set = dset.CIFAR10(root="./data", train=False, transform=trans(im_size), download=True)
    elif args.data == 'celeba':  # args.data is the flag used by the other branches
        im_dim = 3
        im_size = 64 if args.imagesize is None else args.imagesize
        train_set = dset.CelebA(
            train=True, transform=tforms.Compose([
                tforms.ToPILImage(),
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
                tforms.ToTensor(),
                add_noise,
            ])
        )
        test_set = dset.CelebA(
            train=False, transform=tforms.Compose([
                tforms.ToPILImage(),
                tforms.Resize(im_size),
                tforms.ToTensor(),
                add_noise,
            ])
        )
    data_shape = (im_dim, im_size, im_size)

    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=args.batch_size, shuffle=False)
    return train_loader, test_loader, data_shape 
Example 27
Project: Evolutionary-Autoencoders   Author: sg-nm   File: cnn_train.py    License: MIT License
def __init__(self, dataset_name, validation=True, verbose=True, imgSize=64, batchsize=16):
        # dataset_name: name of data set ('celebA' or 'cars' or 'svhn')
        # validation  : [True]  model train/validation mode
        #               [False] model test mode for final evaluation of the evolved model
        # verbose     : flag of display
        self.verbose = verbose
        self.imgSize = imgSize
        self.validation = validation
        self.batchsize = batchsize
        self.channel = 3
        num_work = 2

        # load dataset
        if dataset_name in ('svhn', 'celebA', 'cars'):
            if dataset_name == 'svhn':
                if self.validation:
                    dataset = dset.SVHN(root='./svhn', split='train', download=True,
                            transform=transforms.Compose([transforms.RandomHorizontalFlip(),
                                                          transforms.Scale(self.imgSize),
                                                          transforms.ToTensor()]))
                    test_dataset = dset.SVHN(root='./svhn', split='extra', download=True,
                            transform=transforms.Compose([transforms.Scale(self.imgSize),
                                                          transforms.ToTensor()]))
                else:
                    dataset = dset.SVHN(root='./svhn', split='train', download=True,
                            transform=transforms.Compose([transforms.Scale(self.imgSize),
                                                          transforms.ToTensor()]))
                    test_dataset = dset.SVHN(root='./svhn', split='test', download=True,
                            transform=transforms.Compose([transforms.Scale(self.imgSize),
                                                          transforms.ToTensor()]))
                self.dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batchsize, shuffle=True,
                                                              num_workers=int(num_work), drop_last=True)
                self.test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False,
                                                                   num_workers=int(num_work))
            elif dataset_name == 'celebA':
                if self.validation:
                    data_transform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor()])
                    test_data_transform = transforms.Compose([transforms.ToTensor()])
                    dataset = dset.ImageFolder(root='/dataset/celebA/train', transform=data_transform)
                    test_dataset = dset.ImageFolder(root='/dataset/celebA/val', transform=test_data_transform)
                else:
                    data_transform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor()])
                    test_data_transform = transforms.Compose([transforms.ToTensor()])
                    dataset = dset.ImageFolder(root='/dataset/celebA/train', transform=data_transform)
                    test_dataset = dset.ImageFolder(root='/dataset/celebA/test', transform=test_data_transform)
                self.dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batchsize, shuffle=True, num_workers=int(num_work), drop_last=True)
                self.test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=int(num_work))
            elif dataset_name == 'cars':
                if self.validation:
                    data_transform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor()])
                    test_data_transform = transforms.Compose([transforms.ToTensor()])
                    dataset = dset.ImageFolder(root='/dataset/cars/train', transform=data_transform)
                    test_dataset = dset.ImageFolder(root='/dataset/cars/val', transform=test_data_transform)
                else:
                    data_transform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor()])
                    test_data_transform = transforms.Compose([transforms.ToTensor()])
                    dataset = dset.ImageFolder(root='/dataset/cars/retrain', transform=data_transform)
                    test_dataset = dset.ImageFolder(root='/dataset/cars/test', transform=test_data_transform)
                self.dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batchsize, shuffle=True, num_workers=int(num_work), drop_last=True)
                self.test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=int(num_work))
            print('train num', len(self.dataloader.dataset))
            print('test num ', len(self.test_dataloader.dataset))
        else:
            print('\tInvalid input dataset name at CNN_train()')
            exit(1)