Python torchvision.datasets.MNIST Examples
The following are 30 code examples of torchvision.datasets.MNIST().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torchvision.datasets, or try the search function.
Example #1
Source File: data_loader.py From mnist-svhn-transfer with MIT License | 9 votes |
def get_loader(config):
    """Build and return DataLoaders for the SVHN and MNIST datasets.

    Args:
        config: object exposing ``image_size``, ``svhn_path``, ``mnist_path``,
            ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(svhn_loader, mnist_loader)``; both shuffle and use
        ``config.num_workers`` workers.
    """
    # ``transforms.Scale`` was deprecated in favour of the equivalent
    # ``transforms.Resize`` and is gone from recent torchvision releases.
    svhn_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # MNIST tensors have a single channel. A 3-channel Normalize only worked on
    # old torchvision (which zipped over channels and so used just the first
    # mean/std); current versions raise a broadcasting error. Using the same
    # 0.5 / 0.5 statistics for one channel keeps the output values unchanged.
    mnist_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    svhn = datasets.SVHN(root=config.svhn_path, download=True, transform=svhn_transform)
    mnist = datasets.MNIST(root=config.mnist_path, download=True, transform=mnist_transform)

    svhn_loader = torch.utils.data.DataLoader(dataset=svhn,
                                              batch_size=config.batch_size,
                                              shuffle=True,
                                              num_workers=config.num_workers)
    mnist_loader = torch.utils.data.DataLoader(dataset=mnist,
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=config.num_workers)
    return svhn_loader, mnist_loader
Example #2
Source File: model.py From iAI with MIT License | 8 votes |
def __init__(self):
    # Optimisation / logging hyper-parameters.
    self.batch_size = 64
    self.test_batch_size = 100
    self.learning_rate = 0.01
    self.sgd_momentum = 0.9
    self.log_interval = 100

    # Fetch MNIST data set. Both splits share one ToTensor + Normalize
    # pipeline (mean 0.1307, std 0.3081).
    mnist_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_split = datasets.MNIST('/tmp/mnist/data', train=True, download=True,
                                 transform=mnist_pipeline)
    test_split = datasets.MNIST('/tmp/mnist/data', train=False,
                                transform=mnist_pipeline)
    self.train_loader = torch.utils.data.DataLoader(
        train_split, batch_size=self.batch_size, shuffle=True)
    self.test_loader = torch.utils.data.DataLoader(
        test_split, batch_size=self.test_batch_size, shuffle=True)

    self.network = Net()

# Train the network for several epochs, validating after each epoch.
Example #3
Source File: mnist_m.py From pytorch-atda with MIT License | 6 votes |
def get_mnist_m(train, get_dataset=False, batch_size=cfg.batch_size):
    """Return the MNIST-M dataset, or a shuffled DataLoader over it.

    Args:
        train: passed through as ``MNIST_M(train=...)`` to pick the split.
        get_dataset: when true, return the ``MNIST_M`` dataset itself.
        batch_size: loader batch size; its default is read from
            ``cfg.batch_size`` once, when the function is defined.
    """
    # Tensor conversion followed by normalisation with the configured stats.
    normalize = transforms.Normalize(mean=cfg.dataset_mean, std=cfg.dataset_std)
    pre_process = transforms.Compose([transforms.ToTensor(), normalize])

    dataset = MNIST_M(root=cfg.data_root, train=train, transform=pre_process, download=True)
    if get_dataset:
        return dataset
    return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
Example #4
Source File: mnist.py From pytorch-arda with MIT License | 6 votes |
def get_mnist(train):
    """Return a shuffled MNIST DataLoader configured from ``params``.

    Args:
        train: passed through as ``datasets.MNIST(train=...)`` to pick the split.
    """
    # Image pre-processing: tensor conversion, then normalisation.
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std),
    ]
    dataset = datasets.MNIST(root=params.data_root,
                             train=train,
                             transform=transforms.Compose(steps),
                             download=True)
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=params.batch_size,
                                       shuffle=True)
Example #5
Source File: sampling.py From federated-learning with MIT License | 6 votes |
def mnist_noniid(dataset, num_users, num_shards=200, num_imgs=300):
    """Sample non-I.I.D. client data from an MNIST-style dataset.

    Sample indices are sorted by label and cut into ``num_shards`` contiguous
    shards of ``num_imgs`` samples. Each user receives two distinct shards
    picked at random, without reuse across users.

    :param dataset: object exposing its labels as ``targets`` or, as a
        fallback, the deprecated ``train_labels``. Labels may be a tensor
        (anything with ``.numpy()``) or an array-like.
    :param num_users: number of clients to partition the data for.
    :param num_shards: number of label-sorted shards (default 200).
    :param num_imgs: samples per shard (default 300); the dataset must hold
        exactly ``num_shards * num_imgs`` labels.
    :return: dict mapping user id -> ``int64`` numpy array of sample indices.
    :raises ValueError: if there are fewer than two shards per user, or the
        label count differs from ``num_shards * num_imgs``.
    """
    if 2 * num_users > num_shards:
        raise ValueError(
            "need at least 2 shards per user: num_users=%d, num_shards=%d"
            % (num_users, num_shards))

    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = dataset.train_labels  # legacy attribute name
    if hasattr(labels, "numpy"):
        labels = labels.numpy()
    labels = np.asarray(labels)

    num_samples = num_shards * num_imgs
    if len(labels) != num_samples:
        raise ValueError(
            "dataset has %d labels but num_shards * num_imgs = %d"
            % (len(labels), num_samples))

    idx_shard = list(range(num_shards))
    dict_users = {i: np.array([], dtype="int64") for i in range(num_users)}
    # Sample indices ordered by label, so each shard is (nearly) single-label.
    # Equivalent to the old vstack/argsort dance, but never up-casts the
    # indices to the label dtype.
    idxs = np.argsort(labels)

    # Divide and assign. The set-based bookkeeping is kept as before so that a
    # given RNG seed yields the same shard assignment.
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand * num_imgs:(rand + 1) * num_imgs]), axis=0)
    return dict_users
Example #6
Source File: dataset.py From jdit with Apache License 2.0 | 6 votes |
def get_fashion_mnist_dataloaders(root=r'.\dataset\fashion_data', batch_size=128, resize=32, transform_list=None,
                                  num_workers=-1):
    """Fashion-MNIST train/test DataLoaders.

    Args:
        root: dataset directory. The default is a Windows-style relative path
            written with backslashes.
        batch_size: batch size for both loaders.
        resize: size given to ``transforms.Resize`` in the default pipeline
            (32 by default, a power of 2).
        transform_list: optional list of transforms that replaces the default
            ``Resize -> ToTensor -> Normalize([0.5], [0.5])`` pipeline.
        num_workers: DataLoader worker count. ``-1`` means one worker per CPU
            as reported by ``psutil.cpu_count()``.

    Returns:
        tuple: ``(train_loader, test_loader)``. Both shuffle and use
        ``drop_last=True``, so the final partial batch of each split is skipped.
    """
    if num_workers == -1:
        # psutil.cpu_count() returns None when the CPU count is undetermined.
        # Previously that crashed the "%d" print and DataLoader. Fall back to
        # 0, which loads data in the main process.
        num_workers = psutil.cpu_count() or 0
        print("use %d thread!" % num_workers)
    if transform_list is None:
        transform_list = [
            transforms.Resize(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    all_transforms = transforms.Compose(transform_list)

    # Get train and test data
    train_data = datasets.FashionMNIST(root, train=True, download=True, transform=all_transforms)
    test_data = datasets.FashionMNIST(root, train=False, transform=all_transforms)

    # Create dataloaders
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True,
                              num_workers=num_workers)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True, drop_last=True,
                             num_workers=num_workers)
    return train_loader, test_loader
Example #7
Source File: dataloaders.py From Self-Supervised-Gans-Pytorch with MIT License | 6 votes |
def get_fashion_mnist_dataloaders(batch_size=128):
    """Return shuffled ``(train_loader, test_loader)`` for Fashion-MNIST.

    Images are resized to 32 (a power of 2) and converted to tensors.
    Only the training split triggers a download into ``../fashion_data``.
    """
    pipeline = transforms.Compose([transforms.Resize(32), transforms.ToTensor()])
    data_dir = '../fashion_data'

    train_set = datasets.FashionMNIST(data_dir, train=True, download=True, transform=pipeline)
    test_set = datasets.FashionMNIST(data_dir, train=False, transform=pipeline)

    return (DataLoader(train_set, batch_size=batch_size, shuffle=True),
            DataLoader(test_set, batch_size=batch_size, shuffle=True))
Example #8
Source File: dataloaders.py From Self-Supervised-Gans-Pytorch with MIT License | 6 votes |
def get_mnist_dataloaders(batch_size=128):
    """Return shuffled ``(train_loader, test_loader)`` for MNIST.

    Images are resized to 32 (a power of 2) and converted to tensors.
    Only the training split triggers a download into ``../data``.
    """
    to_32px_tensor = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])
    splits = (
        datasets.MNIST('../data', train=True, download=True, transform=to_32px_tensor),
        datasets.MNIST('../data', train=False, transform=to_32px_tensor),
    )
    train_loader, test_loader = (DataLoader(split, batch_size=batch_size, shuffle=True)
                                 for split in splits)
    return train_loader, test_loader
Example #9
Source File: mnist_fid.py From misgan with MIT License | 6 votes |
def __init__(self):
    """Load the MNIST classifier and reference MNIST activation statistics.

    The statistics are read from ``mnist_act_{feature_layer}.npz`` if that
    file exists. Otherwise they are computed over the MNIST training images
    and saved to that file for later runs.
    """
    model = mnist_model.Net().to(device)
    model.eval()
    # Load the checkpoint onto the CPU when CUDA is not in use.
    map_location = None if use_cuda else 'cpu'
    model.load_state_dict(
        torch.load('mnist.pth', map_location=map_location))

    stats_file = f'mnist_act_{feature_layer}.npz'
    try:
        # ``with`` closes the archive even if a key is missing. The former
        # explicit ``f.close()`` was skipped whenever a lookup raised.
        with np.load(stats_file) as f:
            m_mnist, s_mnist = f['mu'][:], f['sigma'][:]
    except FileNotFoundError:
        data = datasets.MNIST('data', train=True, download=True,
                              transform=transforms.ToTensor())
        images = len(data)
        batch_size = 64
        # Only the images are needed for activation statistics; drop labels.
        data_loader = DataLoader([image for image, _ in data],
                                 batch_size=batch_size)
        m_mnist, s_mnist = calculate_activation_statistics(
            data_loader, images, model, verbose=True)
        np.savez(stats_file, mu=m_mnist, sigma=s_mnist)

    self.model = model
    self.mnist_stats = m_mnist, s_mnist
Example #10
Source File: mnist.py From sagemaker-python-sdk with Apache License 2.0 | 6 votes |
def _get_train_data_loader(training_dir, is_distributed, batch_size, **kwargs):
    """Build the MNIST training DataLoader and, when distributed, its sampler.

    Returns:
        tuple: ``(train_sampler, train_loader)``. ``train_sampler`` is a
        ``DistributedSampler`` when ``is_distributed`` is true, else ``None``.
    """
    logger.info("Get train data loader")
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    dataset = datasets.MNIST(
        training_dir,
        train=True,
        transform=transforms.Compose([transforms.ToTensor(), normalize]),
        download=False,  # True sets a dependency on an external site for our canaries.
    )

    train_sampler = None
    if is_distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)

    # DataLoader does not accept shuffle=True together with a sampler, so
    # shuffle only when no sampler was created.
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train_sampler is None,
        sampler=train_sampler,
        **kwargs
    )
    return train_sampler, train_loader
Example #11
Source File: dataset.py From jdit with Apache License 2.0 | 6 votes |
def get_mnist_dataloaders(root=r'..\data', batch_size=128, resize=28):
    """MNIST train/test DataLoaders with images resized to ``resize``.

    Args:
        root: dataset directory. Only the training split triggers a download.
        batch_size: batch size for both loaders.
        resize: size given to ``transforms.Resize``. The default of 28 keeps
            the previous output size; pass e.g. 32 for a power-of-2 size.

    Returns:
        tuple: ``(train_loader, test_loader)``, both shuffled. Pixel values
        are left as produced by ``ToTensor``; no normalisation is applied.
    """
    all_transforms = transforms.Compose([
        transforms.Resize(resize),
        transforms.ToTensor(),
    ])
    # Get train and test data
    train_data = datasets.MNIST(root, train=True, download=True, transform=all_transforms)
    test_data = datasets.MNIST(root, train=False, transform=all_transforms)
    # Create dataloaders
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader
Example #12
Source File: mnist.py From pytorch-atda with MIT License | 6 votes |
def get_mnist(train, get_dataset=False, batch_size=cfg.batch_size):
    """Return the 3-channel MNIST dataset, or a shuffled DataLoader over it.

    Args:
        train: passed through as ``datasets.MNIST(train=...)`` to pick the split.
        get_dataset: when true, return the dataset instead of a loader.
        batch_size: loader batch size; its default is read from
            ``cfg.batch_size`` once, when the function is defined.
    """
    def _to_three_channels(img):
        # Repeat the single channel along dim 0 (the channel dim after
        # ToTensor), presumably to match 3-channel inputs elsewhere.
        return torch.cat([img, img, img], 0)

    pre_process = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg.dataset_mean, std=cfg.dataset_std),
        transforms.Lambda(_to_three_channels),
    ])
    mnist_dataset = datasets.MNIST(root=cfg.data_root, train=train,
                                   transform=pre_process, download=True)
    if not get_dataset:
        return torch.utils.data.DataLoader(dataset=mnist_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
    return mnist_dataset
Example #13
Source File: mnist.py From Deep-SAD-PyTorch with MIT License | 6 votes |
def __getitem__(self, index):
    """Return one sample; overrides the original method of the MNIST class.

    Args:
        index (int): Index

    Returns:
        tuple: (image, target, semi_target, index)
    """
    raw = self.data[index]
    target = int(self.targets[index])
    semi_target = int(self.semi_targets[index])

    # Wrap the array in a PIL Image (mode 'L', 8-bit grey) so the transforms
    # receive the same input type as with all other datasets.
    img = Image.fromarray(raw.numpy(), mode='L')
    img = img if self.transform is None else self.transform(img)
    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target, semi_target, index
Example #14
Source File: loaders.py From dfw with MIT License | 6 votes |
def loaders_mnist(dataset, batch_size=64, cuda=0,
                  train_size=50000, val_size=10000, test_size=10000,
                  test_batch_size=1000, **kwargs):
    """Build MNIST train / validation / test loaders via ``create_loaders``.

    Data is read from ``$VISION_DATA/mnist``; no download is attempted.
    Training and validation use two separate dataset objects over the MNIST
    training split. ``create_loaders`` presumably carves ``train_size`` /
    ``val_size`` out of them; verify against its definition. Extra
    ``**kwargs`` are accepted and ignored.

    Raises:
        ValueError: if ``dataset`` is not ``'mnist'`` (was a bare ``assert``,
            which is stripped under ``python -O``).
        KeyError: if the ``VISION_DATA`` environment variable is not set.
    """
    if dataset != 'mnist':
        raise ValueError("loaders_mnist only supports dataset='mnist', got {!r}".format(dataset))
    try:
        vision_data = os.environ['VISION_DATA']
    except KeyError:
        raise KeyError("environment variable VISION_DATA must point to the vision data root") from None
    root = '{}/{}'.format(vision_data, dataset)

    # Data loading code
    normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    # Separate train / validation dataset objects allow giving each its own
    # transform later; currently both use the same one.
    dataset_train = datasets.MNIST(root=root, train=True, transform=transform)
    dataset_val = datasets.MNIST(root=root, train=True, transform=transform)
    dataset_test = datasets.MNIST(root=root, train=False, transform=transform)

    return create_loaders(dataset_train, dataset_val,
                          dataset_test, train_size, val_size, test_size,
                          batch_size=batch_size,
                          test_batch_size=test_batch_size,
                          cuda=cuda, num_workers=0)
Example #15
Source File: model.py From iAI with MIT License | 6 votes |
def __init__(self):
    # Optimisation / logging hyper-parameters.
    self.batch_size = 64
    self.test_batch_size = 100
    self.learning_rate = 0.0025
    self.sgd_momentum = 0.9
    self.log_interval = 100

    # Fetch MNIST data set. Both splits share one ToTensor + Normalize
    # pipeline (mean 0.1307, std 0.3081).
    mnist_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_split = datasets.MNIST('/tmp/mnist/data', train=True, download=True,
                                 transform=mnist_pipeline)
    test_split = datasets.MNIST('/tmp/mnist/data', train=False,
                                transform=mnist_pipeline)
    self.train_loader = torch.utils.data.DataLoader(
        train_split, batch_size=self.batch_size, shuffle=True)
    self.test_loader = torch.utils.data.DataLoader(
        test_split, batch_size=self.test_batch_size, shuffle=True)

    self.network = Net()

# Train the network for one or more epochs, validating after each epoch.
Example #16
Source File: model.py From iAI with MIT License | 6 votes |
def __init__(self):
    # Optimisation / logging hyper-parameters.
    self.batch_size = 64
    self.test_batch_size = 100
    self.learning_rate = 0.0025
    self.sgd_momentum = 0.9
    self.log_interval = 100

    # Fetch MNIST data set. Both splits share one ToTensor + Normalize
    # pipeline (mean 0.1307, std 0.3081).
    mnist_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_split = datasets.MNIST('/tmp/mnist/data', train=True, download=True,
                                 transform=mnist_pipeline)
    test_split = datasets.MNIST('/tmp/mnist/data', train=False,
                                transform=mnist_pipeline)
    self.train_loader = torch.utils.data.DataLoader(
        train_split, batch_size=self.batch_size, shuffle=True)
    self.test_loader = torch.utils.data.DataLoader(
        test_split, batch_size=self.test_batch_size, shuffle=True)

    self.network = Net()

# Train the network for one or more epochs, validating after each epoch.
Example #17
Source File: datasets.py From pytorch-deep-sets with MIT License | 6 votes |
def __init__(self, min_len: int, max_len: int, dataset_len: int, train: bool = True, transform: Compose = None):
    """Build ``dataset_len`` random multisets of MNIST sample indices.

    Each item's size is drawn uniformly, with replacement, from
    ``[min_len, max_len]`` (inclusive). Its members are drawn uniformly,
    with replacement, from all indices of the chosen MNIST split.
    """
    self.min_len = min_len
    self.max_len = max_len
    self.dataset_len = dataset_len
    self.train = train
    self.transform = transform

    self.mnist = MNIST(DATA_ROOT, train=self.train, transform=self.transform, download=True)
    population = np.arange(0, len(self.mnist))

    set_sizes = np.random.choice(np.arange(self.min_len, self.max_len + 1),
                                 size=self.dataset_len, replace=True)
    self.mnist_items = [np.random.choice(population, size=k, replace=True)
                        for k in set_sizes]
Example #18
Source File: mnistm.py From PyTorch-GAN with MIT License | 6 votes |
def __init__(self, root, mnist_root="data", train=True, transform=None, target_transform=None, download=False):
    """Init MNIST-M dataset.

    Args:
        root: MNIST-M directory (``~`` is expanded).
        mnist_root: directory of the underlying MNIST data (``~`` expanded).
        train: load the training file when true, the test file otherwise.
        transform: optional transform stored for later use.
        target_transform: optional target transform stored for later use.
        download: call ``self.download()`` before loading.

    Raises:
        RuntimeError: if ``self._check_exists()`` reports the data missing.
    """
    super(MNISTM, self).__init__()
    self.root = os.path.expanduser(root)
    self.mnist_root = os.path.expanduser(mnist_root)
    self.transform = transform
    self.target_transform = target_transform
    self.train = train  # training set or test set

    if download:
        self.download()
    if not self._check_exists():
        raise RuntimeError("Dataset not found. You can use download=True to download it")

    split_file = self.training_file if self.train else self.test_file
    images, labels = torch.load(os.path.join(self.root, self.processed_folder, split_file))
    if self.train:
        self.train_data, self.train_labels = images, labels
    else:
        self.test_data, self.test_labels = images, labels
Example #19
Source File: mnistm.py From PyTorch-GAN with MIT License | 5 votes |
def __init__(self, root, mnist_root="data", train=True, transform=None, target_transform=None, download=False):
    """Init MNIST-M dataset.

    Stores the expanded paths and transforms, optionally downloads, then
    loads either the training or the test tensors from the processed folder.

    Raises:
        RuntimeError: if ``self._check_exists()`` reports the data missing.
    """
    super(MNISTM, self).__init__()
    self.root = os.path.expanduser(root)
    self.mnist_root = os.path.expanduser(mnist_root)
    self.transform = transform
    self.target_transform = target_transform
    self.train = train  # training set or test set

    if download:
        self.download()

    if not self._check_exists():
        raise RuntimeError('Dataset not found. You can use download=True to download it')

    processed_dir = os.path.join(self.root, self.processed_folder)
    if not self.train:
        self.test_data, self.test_labels = torch.load(os.path.join(processed_dir, self.test_file))
        return
    self.train_data, self.train_labels = torch.load(os.path.join(processed_dir, self.training_file))
Example #20
Source File: dataloaders.py From pixel-constrained-cnn-pytorch with Apache License 2.0 | 5 votes |
def mnist(batch_size=128, num_colors=256, size=28,
          path_to_data='../mnist_data'):
    """MNIST dataloaders with (size, size) images quantized to ``num_colors``.

    Parameters
    ----------
    batch_size : int
        Batch size of both loaders.

    num_colors : int
        Number of colors to quantize images into. Typically 256, but can be
        lower for e.g. binary images.

    size : int
        Size (height and width) of each image. Default is 28 for no resizing.

    path_to_data : string
        Path to MNIST data files. Only the training split is downloaded.

    Returns
    -------
    tuple of DataLoader
        ``(train_loader, test_loader)``, both shuffled.
    """
    quantize = get_quantize_func(num_colors)
    pipeline = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Lambda(quantize),
    ])

    train_set = datasets.MNIST(path_to_data, train=True, download=True,
                               transform=pipeline)
    test_set = datasets.MNIST(path_to_data, train=False,
                              transform=pipeline)

    return (DataLoader(train_set, batch_size=batch_size, shuffle=True),
            DataLoader(test_set, batch_size=batch_size, shuffle=True))
Example #21
Source File: problems.py From convex_adversarial with MIT License | 5 votes |
def mnist_loaders(batch_size, shuffle_test=False):
    """Return ``(train_loader, test_loader)`` for MNIST stored under ``./data``.

    The training loader always shuffles; the test loader shuffles only when
    ``shuffle_test`` is true. Both loaders use pinned memory.
    """
    to_tensor = transforms.ToTensor()
    loaders = []
    for is_train, shuffle in ((True, True), (False, shuffle_test)):
        split = datasets.MNIST("./data", train=is_train, download=True, transform=to_tensor)
        loaders.append(torch.utils.data.DataLoader(split, batch_size=batch_size,
                                                   shuffle=shuffle, pin_memory=True))
    train_loader, test_loader = loaders
    return train_loader, test_loader
Example #22
Source File: problems.py From convex_adversarial with MIT License | 5 votes |
def fashion_mnist_loaders(batch_size):
    """Return ``(train_loader, test_loader)`` for Fashion-MNIST under ``./fashion_mnist``.

    Only the training loader shuffles; both loaders use pinned memory.
    """
    # Bug fix: this used ``datasets.MNIST``, so the "Fashion-MNIST" loaders
    # actually served handwritten digits.
    fmnist_train = datasets.FashionMNIST("./fashion_mnist", train=True,
                                         download=True, transform=transforms.ToTensor())
    fmnist_test = datasets.FashionMNIST("./fashion_mnist", train=False,
                                        download=True, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(fmnist_train, batch_size=batch_size,
                                               shuffle=True, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(fmnist_test, batch_size=batch_size,
                                              shuffle=False, pin_memory=True)
    return train_loader, test_loader
Example #23
Source File: model_pixelcnn_bmnist.py From torchkit with MIT License | 5 votes |
def __init__(self):
    """Create the PixelCNN model, its Adam optimizer and an MNIST training loader."""
    self.mdl = iaf_modules.PixelCNN(1, 16, 4, 5, num_outlayers=1)
    if cuda:
        self.mdl = self.mdl.cuda()
    self.optim = optim.Adam(self.mdl.parameters(), lr=0.001, betas=(0.9, 0.999))

    to_tensor = transforms.Compose([transforms.ToTensor()])
    mnist_train = datasets.MNIST('data/mnist', train=True, download=True, transform=to_tensor)
    self.data_loader = DataLoader(mnist_train, batch_size=32, shuffle=True)
Example #24
Source File: mnist_utils.py From CrypTen with MIT License | 5 votes |
def _get_norm_mnist(dir, reduced=None, binary=False):
    """Download MNIST into ``dir`` and return standardized images with labels.

    Both splits are standardized with one mean and one standard deviation,
    computed over the training and test pixels together.

    Args:
        dir: download / cache directory.
        reduced: if not None, keep only the first ``reduced`` samples of each split.
        binary: if true, relabel every non-zero digit as 1, modifying the
            datasets' ``targets`` in place.

    Returns:
        tuple: ``((train_images, test_images), (train_labels, test_labels))``.
    """
    train_set = datasets.MNIST(dir, download=True, train=True)
    test_set = datasets.MNIST(dir, download=True, train=False)

    # Global normalization statistics over both splits.
    pixels = torch.cat([train_set.data, test_set.data]).float()
    mean = pixels.mean().unsqueeze(0)
    std = pixels.std().unsqueeze(0)

    train_images = transforms.functional.normalize(train_set.data.float(), mean, std)
    test_images = transforms.functional.normalize(test_set.data.float(), mean, std)

    if binary:
        for split in (train_set, test_set):
            split.targets[split.targets != 0] = 1

    if reduced is None:
        return (train_images, test_images), (train_set.targets, test_set.targets)
    return ((train_images[:reduced], test_images[:reduced]),
            (train_set.targets[:reduced], test_set.targets[:reduced]))
Example #25
Source File: data.py From ganzo with Apache License 2.0 | 5 votes |
def __init__(self, options):
    """Build the training dataset chosen by ``options.dataset`` and its loader.

    Stores a shuffled ``DataLoader`` (``drop_last=True``) as
    ``self.dataloader`` and an iterator over it as ``self.iterator``.
    Unrecognized dataset names fall back to ``ImageFolder`` on
    ``options.data_dir``.
    """
    steps = []
    if options.image_size is not None:
        side = options.image_size
        steps.append(transforms.Resize((side, side)))
    steps.append(transforms.ToTensor())
    if options.image_colors in (1, 3):
        # Per-channel mean/std of 0.5 for 1- or 3-channel images.
        steps.append(transforms.Normalize(mean=[0.5] * options.image_colors,
                                          std=[0.5] * options.image_colors))
    transform = transforms.Compose(steps)

    name = options.dataset
    standard = {
        'mnist': datasets.MNIST,
        'fashion-mnist': datasets.FashionMNIST,
        'cifar10': datasets.CIFAR10,
        'cifar100': datasets.CIFAR100,
    }
    if name in standard:
        dataset = standard[name](options.data_dir, train=True, download=True, transform=transform)
    elif name == 'emnist':
        # Updated URL from https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
        # (note: this patches the EMNIST class attribute globally).
        datasets.EMNIST.url = 'https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download'
        dataset = datasets.EMNIST(options.data_dir, split=options.image_class, train=True,
                                  download=True, transform=transform)
    elif name == 'lsun':
        dataset = datasets.LSUN(options.data_dir, classes=[options.image_class + '_train'],
                                transform=transform)
    else:
        dataset = datasets.ImageFolder(root=options.data_dir, transform=transform)

    self.dataloader = DataLoader(
        dataset,
        batch_size=options.batch_size,
        num_workers=options.loader_workers,
        shuffle=True,
        drop_last=True,
        pin_memory=options.pin_memory,
    )
    self.iterator = iter(self.dataloader)
Example #26
Source File: odenet_mnist.py From torchdiffeq with MIT License | 5 votes |
def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
    """Return ``(train_loader, test_loader, train_eval_loader)`` for MNIST in ``.data/mnist``.

    Args:
        data_aug: if true, training images get ``RandomCrop(28, padding=4)``.
        batch_size: batch size of the shuffled training loader.
        test_batch_size: batch size of the two unshuffled evaluation loaders.
        perc: accepted for interface compatibility; not used here.

    All three loaders use 2 workers and ``drop_last=True``, so a trailing
    partial batch is skipped even during evaluation.
    """
    eval_transform = transforms.Compose([transforms.ToTensor()])
    augment = [transforms.RandomCrop(28, padding=4)] if data_aug else []
    train_transform = transforms.Compose(augment + [transforms.ToTensor()])

    def _loader(train, transform, size, shuffle):
        # One MNIST split wrapped in a DataLoader with the shared settings.
        split = datasets.MNIST(root='.data/mnist', train=train, download=True, transform=transform)
        return DataLoader(split, batch_size=size, shuffle=shuffle, num_workers=2, drop_last=True)

    train_loader = _loader(True, train_transform, batch_size, True)
    train_eval_loader = _loader(True, eval_transform, test_batch_size, False)
    test_loader = _loader(False, eval_transform, test_batch_size, False)
    return train_loader, test_loader, train_eval_loader
Example #27
Source File: mnist.py From sagemaker-python-sdk with Apache License 2.0 | 5 votes |
def _get_test_data_loader(training_dir, **kwargs):
    """Build the MNIST test-split DataLoader (batch size 1000, shuffled).

    Extra ``**kwargs`` are forwarded to ``DataLoader``.
    """
    logger.info("Get test data loader")
    pipeline = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    test_set = datasets.MNIST(
        training_dir,
        train=False,
        transform=pipeline,
        download=False,  # True sets a dependency on an external site for our canaries.
    )
    return torch.utils.data.DataLoader(test_set, batch_size=1000, shuffle=True, **kwargs)
Example #28
Source File: mpc_autograd_cnn.py From CrypTen with MIT License | 5 votes |
def preprocess_mnist(context_manager):
    """Load MNIST and split binarized, standardized training data between two parties.

    Labels are binarized: digit 0 -> 0, any other digit -> 1. Images are
    standardized with one mean/std computed over training and test images
    together. Only the training split is returned; the test split is used
    solely for the normalization statistics.

    Args:
        context_manager: context entered around the whole step; ``None``
            means a ``NoopContextManager()``.

    Returns:
        tuple: ``(data_alice, data_bob, train_labels)``. ``data_alice`` is
        ``[:, :, :20]`` and ``data_bob`` is ``[:, :, 20:]`` of the
        standardized training images, i.e. a split along the last axis.
    """
    if context_manager is None:
        context_manager = NoopContextManager()

    with context_manager:
        # each party gets a unique temp directory
        with tempfile.TemporaryDirectory() as data_dir:
            mnist_train = datasets.MNIST(data_dir, download=True, train=True)
            mnist_test = datasets.MNIST(data_dir, download=True, train=False)

        # modify labels so all non-zero digits have class label 1; zeros
        # already carry label 0, so the former "== 0 -> 0" assignments were
        # no-ops and have been dropped
        mnist_train.targets[mnist_train.targets != 0] = 1
        mnist_test.targets[mnist_test.targets != 0] = 1

        # compute normalization factors
        data_all = torch.cat([mnist_train.data, mnist_test.data]).float()
        data_mean, data_std = data_all.mean(), data_all.std()
        tensor_mean, tensor_std = data_mean.unsqueeze(0), data_std.unsqueeze(0)

        # normalize data
        data_train_norm = transforms.functional.normalize(
            mnist_train.data.float(), tensor_mean, tensor_std
        )

        # partition features between Alice and Bob
        data_alice = data_train_norm[:, :, :20]
        data_bob = data_train_norm[:, :, 20:]
        train_labels = mnist_train.targets

    return data_alice, data_bob, train_labels
Example #29
Source File: demo_hogwild.py From tensorboardX with MIT License | 5 votes |
def train(rank, args, model, device, dataloader_kwargs):
    """Train ``model`` on MNIST for ``args.epochs`` epochs in one worker process.

    The torch RNG is seeded with ``args.seed + rank``, so workers with
    different ranks start from different seeds.
    """
    torch.manual_seed(args.seed + rank)

    normalize = transforms.Normalize((0.1307,), (0.3081,))
    mnist_train = datasets.MNIST('../data', train=True, download=True,
                                 transform=transforms.Compose([transforms.ToTensor(), normalize]))
    train_loader = torch.utils.data.DataLoader(
        mnist_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=1,
        **dataloader_kwargs
    )

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    for epoch_no in range(1, args.epochs + 1):
        train_epoch(epoch_no, args, model, device, train_loader, optimizer)
Example #30
Source File: img_classification.py From neural-pipeline with MIT License | 5 votes |
def __init__(self, data_dir: str, is_train: bool):
    """Wrap the MNIST split selected by ``is_train``, downloading into ``data_dir`` if needed."""
    mnist_split = datasets.MNIST(data_dir, train=is_train, download=True)
    self.dataset = mnist_split