Python torchvision.datasets.CIFAR10 Examples

The following are 30 code examples of torchvision.datasets.CIFAR10(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torchvision.datasets, or try the search function.
Example #1
Source File: train.py    From pytorch-multigpu with MIT License 7 votes vote down vote up
def main():
    """Build a DataParallel ``pyramidnet`` and train it on CIFAR-10.

    Reads ``args.batch_size``, ``args.num_worker`` and ``args.lr`` from the
    module-level ``args`` and hands the model, loss, optimizer and loader to
    ``train``. The 10 CIFAR-10 classes are plane, car, bird, cat, deer, dog,
    frog, horse, ship and truck.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('==> Preparing data..')
    transforms_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

    dataset_train = CIFAR10(root='../data', train=True, download=True,
                            transform=transforms_train)

    train_loader = DataLoader(dataset_train, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_worker)

    print('==> Making model..')

    net = pyramidnet()
    net = nn.DataParallel(net)
    net = net.to(device)
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('The number of parameters of model is', num_params)

    criterion = nn.CrossEntropyLoss()
    # Adam is used here; SGD with momentum=0.9 and weight_decay=1e-4 was the
    # commented-out alternative in the original script.
    optimizer = optim.Adam(net.parameters(), lr=args.lr)

    train(net, criterion, optimizer, train_loader, device)
Example #2
Source File: problems.py    From convex_adversarial with MIT License 6 votes vote down vote up
def cifar_loaders(batch_size, shuffle_test=False):
    """Return ``(train_loader, test_loader)`` for CIFAR-10 stored under ./data.

    The training split is augmented with random horizontal flips and 4-pixel
    padded random crops. Both splits share the same per-channel normalization.
    Download is requested only for the training split. The training loader
    always shuffles; the test loader shuffles only when *shuffle_test* is true.
    """
    channel_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.225, 0.225, 0.225])
    train_pipeline = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        channel_norm,
    ])
    train_set = datasets.CIFAR10('./data', train=True, download=True,
                                 transform=train_pipeline)

    test_pipeline = transforms.Compose([transforms.ToTensor(), channel_norm])
    test_set = datasets.CIFAR10('./data', train=False, transform=test_pipeline)

    loader_of = torch.utils.data.DataLoader
    train_loader = loader_of(train_set, batch_size=batch_size,
                             shuffle=True, pin_memory=True)
    test_loader = loader_of(test_set, batch_size=batch_size,
                            shuffle=shuffle_test, pin_memory=True)
    return train_loader, test_loader
Example #3
Source File: conv_cifar_2.py    From cwcf with MIT License 6 votes vote down vote up
def get_data(train):
    """Load CIFAR-10 as a flattened 20x20 grayscale table with a binary label.

    Class ids below 5 are mapped to label 0 and the remaining ids to label 1.

    Returns:
        tuple: (DataFrame of pixel columns plus ``COLUMN_LABEL``,
                mean of all pixel values, std of all pixel values)
    """
    pipeline = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((20, 20)),
        transforms.ToTensor(),
        lambda img: img.numpy().flatten(),
    ])
    raw = datasets.CIFAR10('../data/dl/', train=train, download=True, transform=pipeline)

    features, labels = zip(*raw)
    features = np.array(features)

    # Binarize into an int32 column vector: ids < 5 -> 0, otherwise -> 1.
    labels = np.array(labels, dtype='int32').reshape(-1, 1)
    labels = np.where(labels < 5, 0, 1).astype('int32')

    frame = pd.DataFrame(features)
    frame[COLUMN_LABEL] = labels
    return frame, features.mean(), features.std()

#--- 
Example #4
Source File: cifar10.py    From Deep-SAD-PyTorch with MIT License 6 votes vote down vote up
def __getitem__(self, index):
    """Override torchvision's CIFAR10 item access to expose extra fields.

    Args:
        index (int): Index

    Returns:
        tuple: (image, target, semi_target, index)
    """
    raw_image = self.data[index]
    target = self.targets[index]
    semi_target = int(self.semi_targets[index])

    # Work on a PIL Image, consistent with the other datasets in the project.
    image = Image.fromarray(raw_image)
    if self.transform is not None:
        image = self.transform(image)
    if self.target_transform is not None:
        target = self.target_transform(target)

    return image, target, semi_target, index
Example #5
Source File: conv_cifar.py    From cwcf with MIT License 6 votes vote down vote up
def get_data(train):
    """Load CIFAR-10 as a flattened 20x20 grayscale table with the class label.

    Returns:
        tuple: (DataFrame of pixel columns plus ``COLUMN_LABEL``,
                mean of all pixel values, std of all pixel values)
    """
    pipeline = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((20, 20)),
        transforms.ToTensor(),
        lambda img: img.numpy().flatten(),
    ])
    raw = datasets.CIFAR10('../data/dl/', train=train, download=True, transform=pipeline)

    features, labels = zip(*raw)
    features = np.array(features)
    # Labels are stored as an int32 column vector.
    label_column = np.array(labels, dtype='int32').reshape(-1, 1)

    frame = pd.DataFrame(features)
    frame[COLUMN_LABEL] = label_column
    return frame, features.mean(), features.std()

#--- 
Example #6
Source File: cifar10_cls_dataset.py    From imgclsmob with MIT License 6 votes vote down vote up
def __init__(self):
    """Describe CIFAR-10 for the classification framework (sizes, metrics, transforms)."""
    super(CIFAR10MetaInfo, self).__init__()

    # Dataset identity.
    self.label = "CIFAR10"
    self.short_label = "cifar"
    self.root_dir_name = "cifar10"
    self.dataset_class = CIFAR10Fine

    # Data geometry.
    self.num_training_samples = 50000
    self.in_channels = 3
    self.num_classes = 10
    self.input_image_size = (32, 32)

    # Train and validation both report top-1 error; each phase gets its own
    # freshly built list objects.
    for phase, caption in (("train", "Train.Err"), ("val", "Val.Err")):
        setattr(self, phase + "_metric_capts", [caption])
        setattr(self, phase + "_metric_names", ["Top1Error"])
        setattr(self, phase + "_metric_extra_kwargs", [{"name": "err"}])
    self.saver_acc_ind = 0

    # Preprocessing; validation and test share one transform.
    self.train_transform = cifar10_train_transform
    self.val_transform = self.test_transform = cifar10_val_transform
    self.ml_type = "imgcls"
Example #7
Source File: evaluate.py    From pytorch_deephash with MIT License 6 votes vote down vote up
def load_data():
    """Build unshuffled CIFAR-10 train/test loaders at 227-pixel resolution.

    Both splits use identical deterministic preprocessing (resize, tensor
    conversion, per-channel normalization), batch size 100 and no worker
    processes.
    """
    def make_transform():
        return transforms.Compose([
            transforms.Resize(227),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

    def make_loader(is_train, transform):
        split = datasets.CIFAR10(root='./data', train=is_train, download=True,
                                 transform=transform)
        return torch.utils.data.DataLoader(split, batch_size=100,
                                           shuffle=False, num_workers=0)

    train_transform = make_transform()
    test_transform = make_transform()
    return make_loader(True, train_transform), make_loader(False, test_transform)
Example #8
Source File: train.py    From pytorch_deephash with MIT License 6 votes vote down vote up
def init_dataset():
    """Create CIFAR-10 loaders for training (augmented 227px crops) and testing.

    Training images are resized to 256, randomly cropped to 227 and randomly
    flipped; test images are resized to 227. Both loaders shuffle and use no
    worker processes (train batch 128, test batch 100).
    """
    norm_stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(227),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(*norm_stats),
    ])
    test_tf = transforms.Compose([
        transforms.Resize(227),
        transforms.ToTensor(),
        transforms.Normalize(*norm_stats),
    ])

    def loader(is_train, tf, batch):
        split = datasets.CIFAR10(root='./data', train=is_train, download=True, transform=tf)
        return torch.utils.data.DataLoader(split, batch_size=batch,
                                           shuffle=True, num_workers=0)

    trainloader = loader(True, train_tf, 128)
    testloader = loader(False, test_tf, 100)
    return trainloader, testloader
Example #9
Source File: vgg_mcdropout_cifar10.py    From baal with Apache License 2.0 6 votes vote down vote up
def get_datasets(initial_pool):
    """Build the active-learning pool over CIFAR-10 train and the CIFAR-10 test set.

    Args:
        initial_pool: forwarded to ``active_set.label_randomly``; presumably the
            number of training samples to label at random (confirm against baal).

    Returns:
        tuple: (active_set, test_set)
    """
    side = (224, 224)
    train_transform = transforms.Compose([
        transforms.Resize(side),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ToTensor(),
        transforms.Normalize(3 * [0.5], 3 * [0.5]),
    ])
    eval_transform = transforms.Compose([
        transforms.Resize(side),
        transforms.ToTensor(),
        transforms.Normalize(3 * [0.5], 3 * [0.5]),
    ])

    # Note: We use the test set here as an example. You should make your own validation set.
    train_ds = datasets.CIFAR10('.', train=True, transform=train_transform,
                                target_transform=None, download=True)
    test_set = datasets.CIFAR10('.', train=False, transform=eval_transform,
                                target_transform=None, download=True)

    # Unlabelled pool samples are served with the deterministic eval transform.
    active_set = ActiveLearningDataset(train_ds, pool_specifics={'transform': eval_transform})
    # Seed the labelled set at random.
    active_set.label_randomly(initial_pool)
    return active_set, test_set
Example #10
Source File: cifar10.py    From Deep-SVDD-PyTorch with MIT License 6 votes vote down vote up
def __getitem__(self, index):
    """Override the original method of the CIFAR10 class.

    Args:
        index (int): Index

    Returns:
        triple: (image, target, index) where target is index of the target class.
    """
    # Old torchvision releases kept the splits in train_data/train_labels and
    # test_data/test_labels. Newer releases dropped those attributes and store
    # the active split in data/targets. Prefer the legacy layout when it exists
    # so old installs behave exactly as before; otherwise use the new one.
    if self.train and hasattr(self, 'train_data'):
        img, target = self.train_data[index], self.train_labels[index]
    elif not self.train and hasattr(self, 'test_data'):
        img, target = self.test_data[index], self.test_labels[index]
    else:
        img, target = self.data[index], self.targets[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img)

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target, index  # unlike torchvision, also return the sample index
Example #11
Source File: dataset.py    From jdit with Apache License 2.0 6 votes vote down vote up
def build_datasets(self):
    """Load the datasets used by this object; override this method in a subclass.

    An override should assign:

    * :attr:`self.dataset_train` -- the training ``dataset``.
    * :attr:`self.dataset_valid` -- the validation ``dataset``.
    * :attr:`self.dataset_test` -- an optional test ``dataset``.
      If it is not assigned, ``self.dataset_valid`` is used in its place.

    Example::

        self.dataset_train = datasets.CIFAR10(root, train=True, download=True,
                                              transform=transforms.Compose(self.train_transform_list))
        self.dataset_valid = datasets.CIFAR10(root, train=False, download=True,
                                              transform=transforms.Compose(self.valid_transform_list))
    """
    pass
Example #12
Source File: acc_under_attack.py    From RobGAN with MIT License 6 votes vote down vote up
def make_dataset():
    """Return a batch-size-100 DataLoader for the dataset named by ``opt.dataset``.

    Image-folder datasets (imagenet, dog_and_cat_64/128) are read with
    ImageFolder and not shuffled. CIFAR-10 uses its training split without
    downloading and is shuffled. Images are resized to ``opt.img_width`` and
    normalized with mean 0.5 / std 0.5 per channel.

    Raises:
        ValueError: for any other ``opt.dataset``.
    """
    def build_transform():
        return tfs.Compose([
            tfs.Resize(opt.img_width),
            tfs.ToTensor(),
            tfs.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])])

    if opt.dataset in ("imagenet", "dog_and_cat_64", "dog_and_cat_128"):
        data = ImageFolder(opt.root, transform=build_transform())
        shuffle = False
    elif opt.dataset == "cifar10":
        data = CIFAR10(root=opt.root, train=True, download=False, transform=build_transform())
        shuffle = True
    else:
        raise ValueError(f"Unknown dataset: {opt.dataset}")

    return DataLoader(data, batch_size=100, shuffle=shuffle, num_workers=opt.workers)
Example #13
Source File: data_loaders.py    From ModelFeast with MIT License 5 votes vote down vote up
def __init__(self, data_dir, batch_size, shuffle, validation_split, num_workers, training=True):
    """Wrap the CIFAR-10 train split (or the test split when ``training`` is False).

    The dataset is downloaded into ``data_dir`` if missing and passed, with
    the remaining loader options, to the base class constructor.
    """
    self.data_dir = data_dir
    split = datasets.CIFAR10(self.data_dir, train=training, download=True,
                             transform=self._tansform_)
    self.dataset = split
    super(CIFAR10DataLoader, self).__init__(
        split, batch_size, shuffle, validation_split, num_workers)
Example #14
Source File: cifar10.py    From novelty-detection with MIT License 5 votes vote down vote up
def __init__(self, path):
    # type: (str) -> None
    """
    Class constructor.

    Downloads (if needed) the raw CIFAR10 train/test splits into ``path``,
    draws a random order of the training indexes and prepares the per-mode
    transforms. No mode is active until one is selected.

    :param path: The folder in which to download CIFAR10.
    """
    super(CIFAR10, self).__init__()

    self.path = path
    self.normal_class = None

    # Untransformed splits; transforms are applied per mode.
    self.train_split, self.test_split = [
        datasets.CIFAR10(self.path, train=is_train, download=True, transform=None)
        for is_train in (True, False)
    ]

    # Random order of training indexes from the global NumPy RNG, used to
    # build a validation set (see val()).
    self.shuffled_train_idx = np.random.permutation(len(self.train_split))

    # Per-mode transforms; the active one lives in self.transform.
    self.val_transform = transforms.Compose([ToFloatTensor2D()])
    self.test_transform = transforms.Compose([ToFloat32(), OCToFloatTensor2D()])
    self.transform = None

    # Mode bookkeeping, filled in once a mode is selected.
    self.mode = self.length = self.val_idxs = None
Example #15
Source File: cifar10.py    From novelty-detection with MIT License 5 votes vote down vote up
def test(self, normal_class):
    # type: (int) -> None
    """
    Switch CIFAR10 to test mode.

    Every test sample is exposed, transformed with ``test_transform``.

    :param normal_class: the class to be considered normal.
    """
    self.normal_class = int(normal_class)
    self.mode = 'test'
    self.transform = self.test_transform
    n_test = len(self.test_split)
    self.length = n_test
Example #16
Source File: cifar10.py    From novelty-detection with MIT License 5 votes vote down vote up
def val(self, normal_class):
    # type: (int) -> None
    """
    Switch CIFAR10 to validation mode.

    The validation pool is the last 10% of the shuffled training indexes,
    restricted to samples whose label equals ``normal_class``.

    :param normal_class: the class to be considered normal.
    """
    self.normal_class = int(normal_class)
    self.mode = 'val'
    self.transform = self.val_transform

    cut = int(0.9 * len(self.shuffled_train_idx))
    tail = self.shuffled_train_idx[cut:]
    self.val_idxs = [i for i in tail if self.train_split[i][1] == self.normal_class]
    self.length = len(self.val_idxs)
Example #17
Source File: cifar10.py    From novelty-detection with MIT License 5 votes vote down vote up
def __repr__(self):
    """Describe the one-class wrapper together with its current normal class."""
    return 'ONE-CLASS CIFAR10 (normal class = {})'.format(self.normal_class)
Example #18
Source File: dataset.py    From jdit with Apache License 2.0 5 votes vote down vote up
def build_datasets(self):
    """Create the CIFAR-10 training and validation datasets under ``self.root``."""
    def load(is_train, transform_list):
        return datasets.CIFAR10(self.root, train=is_train, download=True,
                                transform=transforms.Compose(transform_list))

    self.dataset_train = load(True, self.train_transform_list)
    self.dataset_valid = load(False, self.valid_transform_list)
Example #19
Source File: dataset.py    From jdit with Apache License 2.0 5 votes vote down vote up
def build_datasets(self):
    """Create the CIFAR-10 training and validation datasets under ``self.root``."""
    def load(is_train, transform_list):
        return datasets.CIFAR10(self.root, train=is_train, download=True,
                                transform=transforms.Compose(transform_list))

    self.dataset_train = load(True, self.train_transform_list)
    self.dataset_valid = load(False, self.valid_transform_list)
Example #20
Source File: denseprune.py    From rethinking-network-pruning with MIT License 5 votes vote down vote up
def test(model):
    """Evaluate *model* on the test split selected by ``args.dataset``.

    Reads ``args.cuda``, ``args.dataset`` and ``args.test_batch_size``, prints
    the accuracy and returns the fraction of correctly classified samples.

    Raises:
        ValueError: if ``args.dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        dataset_cls, root = datasets.CIFAR10, './data.cifar10'
    elif args.dataset == 'cifar100':
        dataset_cls, root = datasets.CIFAR100, './data.cifar100'
    else:
        raise ValueError("No valid dataset is given.")

    # NOTE(review): CIFAR-100 is normalized with the CIFAR-10 channel statistics,
    # as in the original code; kept unchanged so results stay comparable --
    # confirm against the training script.
    test_loader = torch.utils.data.DataLoader(
        dataset_cls(root, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
        batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model.eval()
    correct = 0
    # Variable(..., volatile=True) has no effect since PyTorch 0.4 (it only
    # warns), so gradients were still tracked; no_grad() actually disables them.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).cpu().sum()

    num_samples = len(test_loader.dataset)
    # int() so the count prints as a number rather than "tensor(N)".
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        int(correct), num_samples, 100. * correct / num_samples))
    return correct / float(num_samples)
Example #21
Source File: vggprune.py    From rethinking-network-pruning with MIT License 5 votes vote down vote up
def test(model):
    """Evaluate *model* on the test split selected by ``args.dataset``.

    Reads ``args.cuda``, ``args.dataset`` and ``args.test_batch_size``, prints
    the accuracy and returns the fraction of correctly classified samples.

    Raises:
        ValueError: if ``args.dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        dataset_cls, root = datasets.CIFAR10, './data.cifar10'
    elif args.dataset == 'cifar100':
        dataset_cls, root = datasets.CIFAR100, './data.cifar100'
    else:
        raise ValueError("No valid dataset is given.")

    # NOTE(review): CIFAR-100 is normalized with the CIFAR-10 channel statistics,
    # as in the original code; kept unchanged so results stay comparable --
    # confirm against the training script.
    # Shuffling cannot change the accuracy of an evaluation pass, so iterate in a
    # fixed order like the sibling pruning scripts do.
    test_loader = torch.utils.data.DataLoader(
        dataset_cls(root, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
        batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model.eval()
    correct = 0
    # Variable(..., volatile=True) has no effect since PyTorch 0.4 (it only
    # warns), so gradients were still tracked; no_grad() actually disables them.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).cpu().sum()

    num_samples = len(test_loader.dataset)
    # int() so the count prints as a number rather than "tensor(N)".
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        int(correct), num_samples, 100. * correct / num_samples))
    return correct / float(num_samples)
Example #22
Source File: datasets.py    From shake-drop_pytorch with MIT License 5 votes vote down vote up
def fetch_bylabel(label):
    """Return ``(normalizer, dataset_class)`` for CIFAR-10 or CIFAR-100.

    ``label == 10`` selects CIFAR-10; any other value selects CIFAR-100. The
    normalizer uses that dataset's per-channel mean and std.
    """
    if label != 10:
        return (transforms.Normalize(mean=[0.5071, 0.4867, 0.4408],
                                     std=[0.2675, 0.2565, 0.2761]),
                datasets.CIFAR100)
    return (transforms.Normalize(mean=[0.4914, 0.4824, 0.4467],
                                 std=[0.2471, 0.2435, 0.2616]),
            datasets.CIFAR10)
Example #23
Source File: cifar10_module.py    From PyTorch_CIFAR10 with MIT License 5 votes vote down vote up
def val_dataloader(self):
    """Return a loader over the CIFAR-10 test split, normalized with ``self.mean``/``self.std``."""
    pipeline = transforms.Compose([transforms.ToTensor(),
                                   transforms.Normalize(self.mean, self.std)])
    cifar_test = CIFAR10(root=self.hparams.data_dir, train=False, transform=pipeline)
    return DataLoader(cifar_test, batch_size=self.hparams.batch_size,
                      num_workers=4, pin_memory=True)
Example #24
Source File: cifar10_module.py    From PyTorch_CIFAR10 with MIT License 5 votes vote down vote up
def train_dataloader(self):
    """Return the shuffled, augmented CIFAR-10 training loader (last partial batch dropped)."""
    augmentation = [transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip()]
    to_normalized_tensor = [transforms.ToTensor(),
                            transforms.Normalize(self.mean, self.std)]
    pipeline = transforms.Compose(augmentation + to_normalized_tensor)
    cifar_train = CIFAR10(root=self.hparams.data_dir, train=True, transform=pipeline)
    return DataLoader(cifar_train, batch_size=self.hparams.batch_size, num_workers=4,
                      shuffle=True, drop_last=True, pin_memory=True)
Example #25
Source File: load_dataset.py    From Generative_Continual_Learning with MIT License 5 votes vote down vote up
def load_dataset_test(data_dir, dataset, batch_size):
    """Build the test split of *dataset* and collect the label of every sample.

    Args:
        data_dir: base directory; data lives under ``<data_dir>/Datasets/<dataset>``.
        dataset: one of 'mnist', 'fashion', 'cifar10', 'celebA', 'timagenet'.
        batch_size: forwarded to the celebA loader.

    Returns:
        tuple: (dataset_test, list_classes_test) where ``list_classes_test`` is a
        numpy array holding ``dataset_test[i][1]`` for every index ``i``.

    Raises:
        ValueError: if *dataset* is not one of the supported names (previously
            this surfaced as an UnboundLocalError).
    """
    path = os.path.join(data_dir, 'Datasets', dataset)

    if dataset == 'mnist':
        dataset_test = datasets.MNIST(path, train=False, download=True,
                                      transform=transforms.Compose([transforms.ToTensor()]))
    elif dataset == 'fashion':
        dataset_test = fashion(path, train=False, download=True, transform=transforms.ToTensor())
    elif dataset == 'cifar10':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_test = datasets.CIFAR10(root=path, train=False,
                                        download=True, transform=transform)
    elif dataset == 'celebA':
        # transforms.Scale was deprecated and later removed; Resize is its
        # drop-in replacement.
        # NOTE(review): ``path + 'celebA'`` has no separator -- confirm intended.
        dataset_test = utils.load_celebA(path + 'celebA', transform=transforms.Compose(
            [transforms.CenterCrop(160), transforms.Resize(64), transforms.ToTensor()]),
            batch_size=batch_size)
    elif dataset == 'timagenet':
        dataset_test, labels = get_test_image_folders(path)
        all_classes = np.asarray([labels[i] for i in range(len(dataset_test))])
        # Keep only the samples whose label is below 10.
        dataset_test = Subset(dataset_test, np.where(all_classes < 10)[0])
    else:
        raise ValueError('Unknown dataset: {}'.format(dataset))

    list_classes_test = np.asarray([dataset_test[i][1] for i in range(len(dataset_test))])

    return dataset_test, list_classes_test
Example #26
Source File: disjoint.py    From Generative_Continual_Learning with MIT License 5 votes vote down vote up
def load_cifar10(self):
    """Load CIFAR-10 and materialise both splits as in-memory tensors.

    Returns:
        tuple: ``(tensor_data, tensor_label, tensor_test, tensor_label_test)`` --
        float image tensors shaped ``(N, 3, 32, 32)`` and int64 label tensors.
    """
    # Both splits share one download directory; previously the test split used
    # './data' while the train split used './Datasets', fetching the archive twice.
    root = './Datasets'

    def _to_tensors(dataset):
        # Fetch each sample once (the old loop called dataset[i] twice per index)
        # and stack images and labels.
        images, labels = zip(*(dataset[i] for i in range(len(dataset))))
        return torch.stack(images), torch.tensor(labels, dtype=torch.long)

    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    dataset_train = datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)
    # NOTE(review): the random flip is applied once while materialising, so the
    # augmentation is frozen for the returned tensors -- confirm this is intended.
    tensor_data, tensor_label = _to_tensors(dataset_train)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    dataset_test = datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)
    tensor_test, tensor_label_test = _to_tensors(dataset_test)

    return tensor_data, tensor_label, tensor_test, tensor_label_test
Example #27
Source File: lottery_res110prune.py    From rethinking-network-pruning with MIT License 5 votes vote down vote up
def test(model):
    """Evaluate *model* on the test split selected by ``args.dataset``.

    Reads ``args.cuda``, ``args.dataset`` and ``args.test_batch_size``, prints
    the accuracy and returns the fraction of correctly classified samples.

    Raises:
        ValueError: if ``args.dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        dataset_cls, root = datasets.CIFAR10, './data.cifar10'
    elif args.dataset == 'cifar100':
        dataset_cls, root = datasets.CIFAR100, './data.cifar100'
    else:
        raise ValueError("No valid dataset is given.")

    # NOTE(review): CIFAR-100 is normalized with the CIFAR-10 channel statistics,
    # as in the original code; kept unchanged so results stay comparable --
    # confirm against the training script.
    test_loader = torch.utils.data.DataLoader(
        dataset_cls(root, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
        batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model.eval()
    correct = 0
    # Variable(..., volatile=True) has no effect since PyTorch 0.4 (it only
    # warns), so gradients were still tracked; no_grad() actually disables them.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).cpu().sum()

    num_samples = len(test_loader.dataset)
    # int() so the count prints as a number rather than "tensor(N)".
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        int(correct), num_samples, 100. * correct / num_samples))
    return correct / float(num_samples)
Example #28
Source File: load_data.py    From Deep-Expander-Networks with GNU General Public License v3.0 5 votes vote down vote up
def __init__(self, opt):
    """Build shuffled CIFAR-10 train and validation loaders configured by ``opt``.

    Reads ``opt.data_dir``, ``opt.workers`` and ``opt.batch_size``. Only the
    training split is created with ``download=True``. Both loaders shuffle and
    use pinned memory.
    """
    loader_kwargs = dict(num_workers=opt.workers, batch_size=opt.batch_size,
                         shuffle=True, pin_memory=True)

    def normalization():
        # Channel statistics written on the 0-255 scale, rescaled to [0, 1].
        return transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                    std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalization(),
    ])
    self.train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(opt.data_dir, train=True, download=True, transform=train_pipeline),
        **loader_kwargs)

    val_pipeline = transforms.Compose([transforms.ToTensor(), normalization()])
    self.val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(opt.data_dir, train=False, transform=val_pipeline),
        **loader_kwargs)
Example #29
Source File: lottery_resprune.py    From rethinking-network-pruning with MIT License 5 votes vote down vote up
def test(model):
    """Evaluate *model* on the test split selected by ``args.dataset``.

    Reads ``args.cuda``, ``args.dataset`` and ``args.test_batch_size``, prints
    the accuracy and returns the fraction of correctly classified samples.

    Raises:
        ValueError: if ``args.dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        dataset_cls, root = datasets.CIFAR10, './data.cifar10'
    elif args.dataset == 'cifar100':
        dataset_cls, root = datasets.CIFAR100, './data.cifar100'
    else:
        raise ValueError("No valid dataset is given.")

    # NOTE(review): CIFAR-100 is normalized with the CIFAR-10 channel statistics,
    # as in the original code; kept unchanged so results stay comparable --
    # confirm against the training script.
    test_loader = torch.utils.data.DataLoader(
        dataset_cls(root, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
        batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model.eval()
    correct = 0
    # Variable(..., volatile=True) has no effect since PyTorch 0.4 (it only
    # warns), so gradients were still tracked; no_grad() actually disables them.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).cpu().sum()

    num_samples = len(test_loader.dataset)
    # int() so the count prints as a number rather than "tensor(N)".
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        int(correct), num_samples, 100. * correct / num_samples))
    return correct / float(num_samples)
Example #30
Source File: resprune.py    From rethinking-network-pruning with MIT License 5 votes vote down vote up
def test(model):
    """Evaluate *model* on the test split selected by ``args.dataset``.

    Reads ``args.cuda``, ``args.dataset`` and ``args.test_batch_size``, prints
    the accuracy and returns the fraction of correctly classified samples.

    Raises:
        ValueError: if ``args.dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        dataset_cls, root = datasets.CIFAR10, './data.cifar10'
    elif args.dataset == 'cifar100':
        dataset_cls, root = datasets.CIFAR100, './data.cifar100'
    else:
        raise ValueError("No valid dataset is given.")

    # NOTE(review): CIFAR-100 is normalized with the CIFAR-10 channel statistics,
    # as in the original code; kept unchanged so results stay comparable --
    # confirm against the training script.
    test_loader = torch.utils.data.DataLoader(
        dataset_cls(root, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
        batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model.eval()
    correct = 0
    # Variable(..., volatile=True) has no effect since PyTorch 0.4 (it only
    # warns), so gradients were still tracked; no_grad() actually disables them.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).cpu().sum()

    num_samples = len(test_loader.dataset)
    # int() so the count prints as a number rather than "tensor(N)".
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        int(correct), num_samples, 100. * correct / num_samples))
    return correct / float(num_samples)