Python torchvision.datasets.LSUN Examples
The following are 13 code examples for showing how to use torchvision.datasets.LSUN(). These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
You may check out the related API usage on the sidebar.
You may also want to check out all available functions/classes of the module torchvision.datasets, or try the search function.
Example 1
Project: Self-Supervised-Gans-Pytorch Author: vandit15 File: dataloaders.py License: MIT License | 6 votes |
def get_lsun_dataloader(path_to_data='../lsun', dataset='bedroom_train',
                        batch_size=64):
    """Build a shuffled DataLoader over one LSUN class with (128, 128) images.

    Parameters
    ----------
    path_to_data : str
        Directory handed to ``datasets.LSUN`` as the dataset location.
    dataset : str
        Name of the LSUN class to load; it is passed as the single entry of
        ``classes`` (for example 'bedroom_val' or 'bedroom_train').
    batch_size : int
        Batch size handed to the DataLoader.

    Returns
    -------
    DataLoader
        Loader with ``shuffle=True`` over the transformed LSUN dataset.
    """
    # Resize, centre-crop to 128 and convert to a tensor.
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor()
    ])

    # The dataset location is passed positionally: the keyword for this
    # parameter was renamed from ``db_path`` to ``root`` across torchvision
    # releases, and the positional form works with either name.
    lsun_dset = datasets.LSUN(path_to_data, classes=[dataset],
                              transform=transform)

    return DataLoader(lsun_dset, batch_size=batch_size, shuffle=True)
Example 2
Project: jdit Author: dingguanglei File: dataset.py License: Apache License 2.0 | 6 votes |
def get_lsun_dataloader(path_to_data='/data/dgl/LSUN', dataset='bedroom_train', batch_size=64):
    """Return a shuffling DataLoader of (128, 128) LSUN images.

    path_to_data : str
        Root directory given to ``datasets.LSUN``.
    dataset : str
        LSUN class name, wrapped in a one-element ``classes`` list
        (e.g. 'bedroom_val' or 'bedroom_train').
    batch_size : int
        Number of samples per batch.
    """
    # Preprocessing steps, applied in order by Compose.
    pipeline = [
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
    ]
    lsun_images = datasets.LSUN(
        root=path_to_data,
        classes=[dataset],
        transform=transforms.Compose(pipeline),
    )
    loader = DataLoader(lsun_images, batch_size=batch_size, shuffle=True)
    return loader
Example 3
Project: BigGAN-pytorch Author: sxhxliang File: data_loader.py License: Apache License 2.0 | 5 votes |
def load_lsun(self, classes=None):
    """Create an LSUN dataset rooted at ``self.path``.

    Args:
        classes: LSUN class names to load. When ``None`` (the default), the
            classes 'church_outdoor_train' and 'classroom_train' are used.

    Returns:
        The ``dsets.LSUN`` dataset built with the transform returned by
        ``self.transform(True, True, True, False)``.
    """
    # Build the default per call so no list object is shared between calls
    # (a mutable default argument would be reused by every dataset).
    if classes is None:
        classes = ['church_outdoor_train', 'classroom_train']

    # Named ``lsun_transform`` so a module-level ``transforms`` import is not
    # shadowed inside this method.
    lsun_transform = self.transform(True, True, True, False)
    dataset = dsets.LSUN(self.path, classes=classes, transform=lsun_transform)
    return dataset
Example 4
Project: self-attention-GAN-pytorch Author: voletiv File: utils.py License: MIT License | 5 votes |
def make_dataloader(batch_size, dataset_type, data_path, shuffle=True, drop_last=True, dataloader_args=None,
                    resize=True, imsize=128, centercrop=False, centercrop_size=128, totensor=True,
                    normalize=False, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):
    """Build a dataset of the requested type and wrap it in a DataLoader.

    Args:
        batch_size: Batch size for the DataLoader.
        dataset_type: One of 'folder', 'imagenet', 'lfw' (all loaded with
            ImageFolder), 'lsun', 'cifar10' or 'fake'.
        data_path: Root directory of the dataset. It must exist for the
            folder and LSUN types; CIFAR10 is downloaded into it.
        shuffle: Forwarded to the DataLoader.
        drop_last: Forwarded to the DataLoader.
        dataloader_args: Optional dict of extra keyword arguments for the
            DataLoader. ``None`` means no extra arguments.
        resize, imsize, centercrop, centercrop_size, totensor, normalize,
        norm_mean, norm_std: Forwarded unchanged to ``make_transform``.

    Returns:
        Tuple ``(dataloader, num_of_classes)``.

    Raises:
        ValueError: If ``dataset_type`` is not one of the supported values.
        AssertionError: If ``data_path`` does not exist for a folder or LSUN
            dataset, or if the resulting dataset is falsy.
    """
    # Avoid a shared mutable default; an empty dict adds no extra arguments.
    if dataloader_args is None:
        dataloader_args = {}

    # Make transform
    transform = make_transform(resize=resize, imsize=imsize,
                               centercrop=centercrop, centercrop_size=centercrop_size,
                               totensor=totensor,
                               normalize=normalize, norm_mean=norm_mean, norm_std=norm_std)

    # Make dataset
    if dataset_type in ['folder', 'imagenet', 'lfw']:
        # folder dataset
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.ImageFolder(root=data_path, transform=transform)
    elif dataset_type == 'lsun':
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.LSUN(root=data_path, classes=['bedroom_train'], transform=transform)
    elif dataset_type == 'cifar10':
        if not os.path.exists(data_path):
            print("data_path does not exist! Given: {}\nDownloading CIFAR10 dataset...".format(data_path))
        dataset = dset.CIFAR10(root=data_path, download=True, transform=transform)
    elif dataset_type == 'fake':
        # The fake dataset ignores the transform built above and only converts to tensors.
        dataset = dset.FakeData(image_size=(3, centercrop_size, centercrop_size), transform=transforms.ToTensor())
    else:
        # Without this branch an unknown type would fail later on an unbound ``dataset``.
        raise ValueError("Unknown dataset_type: {!r}".format(dataset_type))
    assert dataset

    # Not every dataset exposes ``classes`` (FakeData only carries
    # ``num_classes``), so fall back to integer labels in that case.
    class_names = getattr(dataset, 'classes', None)
    if class_names is None:
        num_of_classes = getattr(dataset, 'num_classes', 0)
        class_names = list(range(num_of_classes))
    else:
        num_of_classes = len(class_names)
    print("Data found! # of images =", len(dataset), ", # of classes =", num_of_classes, ", classes:", class_names)

    # Make dataloader from dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                             drop_last=drop_last, **dataloader_args)
    return dataloader, num_of_classes
Example 5
Project: ganzo Author: unicredit File: data.py License: Apache License 2.0 | 5 votes |
def __init__(self, options):
    """Create the training DataLoader described by ``options`` and an iterator over it.

    Reads ``image_size``, ``image_colors``, ``dataset``, ``data_dir``,
    ``image_class`` (EMNIST split / LSUN class), ``batch_size``,
    ``loader_workers`` and ``pin_memory`` from ``options``. The loader is
    stored in ``self.dataloader`` and ``iter(self.dataloader)`` in
    ``self.iterator``.
    """
    # --- preprocessing pipeline -------------------------------------------
    steps = []
    if options.image_size is not None:
        steps.append(transforms.Resize((options.image_size, options.image_size)))
        # A centre crop to options.image_size is intentionally not applied.
    steps.append(transforms.ToTensor())
    # Normalisation is only added for 1- or 3-channel images.
    if options.image_colors == 1:
        steps.append(transforms.Normalize(mean=[0.5], std=[0.5]))
    elif options.image_colors == 3:
        steps.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    transform = transforms.Compose(steps)

    # --- dataset selection ------------------------------------------------
    kind = options.dataset
    root = options.data_dir
    # Keyword arguments shared by the downloadable train-split datasets.
    train_kwargs = dict(train=True, download=True, transform=transform)
    if kind == 'mnist':
        dataset = datasets.MNIST(root, **train_kwargs)
    elif kind == 'emnist':
        # Updated URL from https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
        datasets.EMNIST.url = 'https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download'
        dataset = datasets.EMNIST(root, split=options.image_class, **train_kwargs)
    elif kind == 'fashion-mnist':
        dataset = datasets.FashionMNIST(root, **train_kwargs)
    elif kind == 'lsun':
        dataset = datasets.LSUN(root, classes=[options.image_class + '_train'], transform=transform)
    elif kind == 'cifar10':
        dataset = datasets.CIFAR10(root, **train_kwargs)
    elif kind == 'cifar100':
        dataset = datasets.CIFAR100(root, **train_kwargs)
    else:
        # Any other name is treated as a directory of images.
        dataset = datasets.ImageFolder(root=root, transform=transform)

    # --- loader -----------------------------------------------------------
    self.dataloader = DataLoader(
        dataset,
        batch_size=options.batch_size,
        num_workers=options.loader_workers,
        shuffle=True,
        drop_last=True,
        pin_memory=options.pin_memory,
    )
    self.iterator = iter(self.dataloader)
Example 6
Project: RL-GAN-Net Author: iSarmad File: data_loader.py License: MIT License | 5 votes |
def load_lsun(self, classes='church_outdoor_train'):
    """Return an LSUN dataset for a single class located at ``self.path``.

    ``classes`` is one class name; it is wrapped in a one-element list for
    ``dsets.LSUN``. The transform comes from
    ``self.transform(True, True, True, False)``.
    """
    pipeline = self.transform(True, True, True, False)
    return dsets.LSUN(self.path, classes=[classes], transform=pipeline)
Example 7
Project: vgan Author: akanazawa File: inputs.py License: MIT License | 5 votes |
def get_dataset(name, data_dir, size=64, lsun_categories=None):
    """Build the dataset selected by ``name`` and report its label count.

    Args:
        name: One of 'image', 'npy', 'cifar10', 'lsun' or 'lsun_class'.
        data_dir: Root directory of the dataset.
        size: Target size used for the resize and centre crop.
        lsun_categories: Classes passed to ``datasets.LSUN``; defaults to
            'train' when ``None``. Only used for ``name == 'lsun'``.

    Returns:
        Tuple ``(dataset, nlabels)``.

    Raises:
        NotImplementedError: If ``name`` is not a supported dataset name.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # Add a small random offset (scaled by 1/128) to every value.
        transforms.Lambda(lambda x: x + 1./128 * torch.rand(x.size())),
    ])

    if name == 'image':
        dataset = datasets.ImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'npy':
        # Only support normalization for now
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy'])
        nlabels = len(dataset.classes)
    elif name == 'cifar10':
        dataset = datasets.CIFAR10(root=data_dir, train=True, download=True,
                                   transform=transform)
        nlabels = 10
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
        nlabels = len(dataset.classes)
    elif name == 'lsun_class':
        # Every sample of a single LSUN class is mapped to label 0.
        dataset = datasets.LSUNClass(data_dir, transform,
                                     target_transform=(lambda t: 0))
        nlabels = 1
    else:
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # raising it produces a confusing TypeError.
        raise NotImplementedError("Unsupported dataset name: {!r}".format(name))

    return dataset, nlabels
Example 8
Project: ffjord Author: rtqichen File: viz_multiscale.py License: MIT License | 5 votes |
def get_dataset(args):
    """Build the training set selected by ``args.data`` and its data shape.

    Args:
        args: Object with the attributes ``data`` ("mnist", "cifar10" or
            'lsun_church'), ``imagesize`` (``None`` to use the per-dataset
            default size) and ``conv``.

    Returns:
        Tuple ``(train_set, data_shape)``. ``data_shape`` is
        ``(im_dim, im_size, im_size)`` when ``args.conv`` is truthy, otherwise
        the flattened one-element tuple ``(im_dim * im_size * im_size,)``.

    Raises:
        ValueError: If ``args.data`` is not a supported dataset name.
    """
    def trans(im_size):
        # Resize, convert to a tensor, then apply the module's add_noise step.
        return tforms.Compose([tforms.Resize(im_size), tforms.ToTensor(), add_noise])

    if args.data == "mnist":
        im_dim = 1
        im_size = 28 if args.imagesize is None else args.imagesize
        train_set = dset.MNIST(root="./data", train=True, transform=trans(im_size), download=True)
    elif args.data == "cifar10":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.CIFAR10(
            root="./data", train=True, transform=tforms.Compose([
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
                tforms.ToTensor(),
                add_noise,
            ]), download=True
        )
    elif args.data == 'lsun_church':
        im_dim = 3
        im_size = 64 if args.imagesize is None else args.imagesize
        train_set = dset.LSUN(
            'data', ['church_outdoor_train'], transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
                tforms.ToTensor(),
                add_noise,
            ])
        )
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError("Unsupported dataset: {!r}".format(args.data))

    data_shape = (im_dim, im_size, im_size)
    if not args.conv:
        data_shape = (im_dim * im_size * im_size,)

    return train_set, data_shape
Example 9
Project: improved-wgan-pytorch Author: jalola File: training_utils.py License: MIT License | 5 votes |
def load_data(image_data_type, path_to_folder, data_transform, batch_size, classes=None, num_workers=5):
    """Return a shuffling DataLoader over an LSUN or image-folder dataset.

    Args:
        image_data_type: 'lsun' or "image_folder".
        path_to_folder: Root directory of the dataset.
        data_transform: Transform applied to every image.
        batch_size: Batch size for the loader.
        classes: LSUN classes; only used when ``image_data_type`` is 'lsun'.
        num_workers: Number of loader worker processes.

    Raises:
        ValueError: For any other ``image_data_type``.
    """
    # Limit torch to one thread; see https://github.com/pytorch/pytorch/issues/22866
    torch.set_num_threads(1)

    if image_data_type == 'lsun':
        images = datasets.LSUN(path_to_folder, classes=classes, transform=data_transform)
    elif image_data_type == "image_folder":
        images = datasets.ImageFolder(root=path_to_folder, transform=data_transform)
    else:
        raise ValueError("Invalid image data type")

    loader_options = {
        'batch_size': batch_size,
        'shuffle': True,
        'num_workers': num_workers,
        'drop_last': True,
        'pin_memory': True,
    }
    return torch.utils.data.DataLoader(images, **loader_options)
Example 10
Project: improved-wgan-pytorch Author: jalola File: congan_train.py License: MIT License | 5 votes |
def load_data(path_to_folder, classes):
    """Return a shuffling DataLoader of 64x64 normalised images.

    The module-level ``IMAGE_DATA_SET`` selects the dataset ('lsun' uses
    ``datasets.LSUN``; any other value uses ``datasets.ImageFolder``) and
    ``BATCH_SIZE`` sets the batch size.

    Args:
        path_to_folder: Root directory of the dataset.
        classes: LSUN classes; ignored for image-folder datasets.

    Returns:
        A DataLoader with shuffling, 5 workers, ``drop_last=True`` and
        ``pin_memory=True``.
    """
    data_transform = transforms.Compose([
        # ``transforms.Scale`` is the deprecated (and since removed) former
        # name of ``transforms.Resize``.
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    if IMAGE_DATA_SET == 'lsun':
        dataset = datasets.LSUN(path_to_folder, classes=classes, transform=data_transform)
    else:
        dataset = datasets.ImageFolder(root=path_to_folder, transform=data_transform)
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                                                 num_workers=5, drop_last=True, pin_memory=True)
    return dataset_loader
Example 11
Project: ignite Author: pytorch File: dcgan.py License: BSD 3-Clause "New" or "Revised" License | 4 votes |
def check_dataset(dataset, dataroot):
    """
    Args:
        dataset (str): Name of the dataset to use. See CLI help for details
        dataroot (str): root directory where the dataset will be stored.

    Returns:
        dataset (data.Dataset): torchvision Dataset object
        nc (int): number of image channels (1 for "mnist", 3 otherwise)

    Raises:
        RuntimeError: if ``dataset`` is not a supported dataset name.
    """
    resize = transforms.Resize(64)
    crop = transforms.CenterCrop(64)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if dataset in {"imagenet", "folder", "lfw"}:
        dataset = dset.ImageFolder(root=dataroot, transform=transforms.Compose([resize, crop, to_tensor, normalize]))
        nc = 3

    elif dataset == "lsun":
        dataset = dset.LSUN(
            root=dataroot, classes=["bedroom_train"], transform=transforms.Compose([resize, crop, to_tensor, normalize])
        )
        nc = 3

    elif dataset == "cifar10":
        dataset = dset.CIFAR10(
            root=dataroot, download=True, transform=transforms.Compose([resize, to_tensor, normalize])
        )
        nc = 3

    elif dataset == "mnist":
        # MNIST is single-channel (nc = 1), so mean/std need one entry each;
        # three-entry statistics do not match a one-channel tensor.
        gray_normalize = transforms.Normalize((0.5,), (0.5,))
        dataset = dset.MNIST(
            root=dataroot, download=True, transform=transforms.Compose([resize, to_tensor, gray_normalize])
        )
        nc = 1

    elif dataset == "fake":
        dataset = dset.FakeData(size=256, image_size=(3, 64, 64), transform=to_tensor)
        nc = 3

    else:
        raise RuntimeError("Invalid dataset name: {}".format(dataset))

    return dataset, nc
Example 12
Project: CapsGAN Author: raeidsaqur File: main.py License: MIT License | 4 votes |
def __getDataSet(opt):
    """Build the torchvision dataset selected by ``opt.dataset``.

    Reads ``opt.dataset``, ``opt.dataroot``, ``opt.imageSize`` and
    ``opt.load_dict``. For 'cifar10' and 'mnist' with ``opt.load_dict`` set,
    ``opt.netD`` / ``opt.netG`` are pointed at the matching pre-trained
    checkpoints. For 'mnist', ``opt.nc`` is set to 1 and ``opt.imageSize``
    to 32.

    Returns:
        The dataset, or ``None`` when ``opt.dataset`` is not one of
        'imagenet', 'folder', 'lfw', 'lsun', 'cifar10' or 'mnist'.
    """
    if isDebug:
        print(f"Getting dataset: {opt.dataset} ... ")

    dataset = None
    if opt.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=opt.dataroot,
                                   transform=transforms.Compose([
                                       # Resize replaces the removed transforms.Scale.
                                       transforms.Resize(opt.imageSize),
                                       transforms.CenterCrop(opt.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
    elif opt.dataset == 'lsun':
        # The dataset location is positional because its keyword was renamed
        # from ``db_path`` to ``root`` across torchvision releases.
        dataset = dset.LSUN(opt.dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(opt.imageSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
    elif opt.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
        # Load pre-trained state dict
        if opt.load_dict:
            opt.netD = NETD_CIFAR10
            opt.netG = NETG_CIFAR10

    elif opt.dataset == 'mnist':
        opt.nc = 1
        opt.imageSize = 32
        dataset = dset.MNIST(root=opt.dataroot, download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(opt.imageSize),
                                 transforms.ToTensor(),
                                 # One channel (opt.nc = 1) needs one-entry mean/std.
                                 transforms.Normalize((0.5,), (0.5,)),
                             ]))
        # Update opt params for mnist
        if opt.load_dict:
            opt.netD = NETD_MNIST
            opt.netG = NETG_MNIST

    return dataset
Example 13
Project: metropolis-hastings-gans Author: uber-research File: dcgan_loader.py License: Apache License 2.0 | 4 votes |
def get_data_loader(dataset, dataroot, workers, image_size, batch_size):
    """Build the named dataset and wrap it in a shuffling DataLoader.

    Args:
        dataset: Dataset name: 'imagenet', 'folder' or 'lfw' (all loaded with
            ImageFolder), 'lsun', 'cifar10', 'mnist' or 'fake'.
        dataroot: Root directory of the dataset.
        workers: Number of loader workers; converted with ``int``.
        image_size: Target size for the resize (and centre crop, where used).
        batch_size: Batch size for the loader.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True``.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    if dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(image_size),
                                       transforms.CenterCrop(image_size),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
    elif dataset == 'lsun':
        dataset = dset.LSUN(root=dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(image_size),
                                transforms.CenterCrop(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
    elif dataset == 'cifar10':
        dataset = dset.CIFAR10(root=dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    elif dataset == 'mnist':
        dataset = dset.MNIST(root=dataroot, train=True, download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(image_size),
                                 transforms.CenterCrop(image_size),
                                 transforms.ToTensor(),
                                 # MNIST images are single-channel, so the
                                 # statistics carry one entry each.
                                 transforms.Normalize((0.5,), (0.5,)),
                             ]))
    elif dataset == 'fake':
        dataset = dset.FakeData(image_size=(3, image_size, image_size),
                                transform=transforms.ToTensor())
    else:
        # An explicit exception survives ``python -O``, unlike ``assert False``.
        raise ValueError('Unknown dataset: {!r}'.format(dataset))
    assert dataset

    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=int(workers))
    return data_loader