Python torchvision.datasets.FakeData() Examples

The following 10 code examples show how to use torchvision.datasets.FakeData(). They are extracted from open-source projects; the project and source file for each example are listed above it.

Related API usage is listed in the sidebar.

You may also want to browse all available functions and classes of the torchvision.datasets module, or use the search function.

Example 1
Project: ignite | Author: pytorch | File: neural_style.py | License: BSD 3-Clause "New" or "Revised" License | Votes: 6
def check_dataset(args):
    """Build a shuffling DataLoader for the dataset named by ``args.dataset``.

    Every image is resized and center-cropped to ``args.image_size``,
    converted with ``ToTensor`` and then multiplied by 255.

    Args:
        args: namespace providing ``dataset``, ``dataroot``, ``image_size``
            and ``batch_size``.

    Returns:
        DataLoader over the selected dataset (``shuffle=True``, ``num_workers=0``).

    Raises:
        RuntimeError: if ``args.dataset`` is not "folder", "mscoco" or "test".
    """
    side = args.image_size
    pipeline = [
        transforms.Resize(side),
        transforms.CenterCrop(side),
        transforms.ToTensor(),
        # ToTensor yields values in [0, 1]; rescale them by 255.
        transforms.Lambda(lambda t: t.mul(255)),
    ]
    transform = transforms.Compose(pipeline)

    name = args.dataset
    if name == "test":
        # Exactly one batch of single-class random images, for smoke tests.
        train_dataset = datasets.FakeData(
            size=args.batch_size, image_size=(3, 32, 32), num_classes=1, transform=transform
        )
    elif name in {"folder", "mscoco"}:
        train_dataset = datasets.ImageFolder(args.dataroot, transform)
    else:
        raise RuntimeError("Invalid dataset name: {}".format(name))

    return DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
Example 2
Project: AIX360 | Author: IBM | File: test_DIPVAE.py | License: Apache License 2.0 | Votes: 5
def __init__(self, batch_size=256, subset_size=256, test_batch_size=256):
    """Set up fake Fashion-MNIST-shaped (1x28x28) train/test loaders.

    Args:
        batch_size: batch size of the training loader.
        subset_size: number of randomly chosen FakeData samples kept for training.
        test_batch_size: batch size of the test loader.
    """
    trans = transforms.Compose([transforms.ToTensor()])

    train_set = dset.FakeData(image_size=(1, 28, 28), transform=trans)
    # FakeData seeds sample ``i`` from ``i + random_offset``. With the default
    # offset of 0 for both sets, the test split would contain exactly the
    # training images, so shift the test seeds past the training range.
    test_set = dset.FakeData(image_size=(1, 28, 28), transform=trans,
                             random_offset=len(train_set))

    indices = torch.randperm(len(train_set))[:subset_size]
    train_set = torch.utils.data.Subset(train_set, indices)

    self.train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        shuffle=True)
    self.test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=test_batch_size,
        shuffle=False)

    self.name = "fakemnist"
    self.data_dims = [28, 28, 1]
    # len(DataLoader) is the number of batches, not the number of samples.
    self.train_size = len(self.train_loader)
    self.test_size = len(self.test_loader)
    self.range = [0.0, 1.0]
    self.batch_size = batch_size
    # Sample counts of the (subsetted) train set and the test set.
    self.num_training_instances = len(train_set)
    self.num_test_instances = len(test_set)
    self.likelihood_type = 'gaussian'
    self.output_activation_type = 'sigmoid'
Example 3
Project: self-attention-GAN-pytorch | Author: voletiv | File: utils.py | License: MIT License | Votes: 5
def make_dataloader(batch_size, dataset_type, data_path, shuffle=True, drop_last=True, dataloader_args=None,
                    resize=True, imsize=128, centercrop=False, centercrop_size=128, totensor=True,
                    normalize=False, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):
    """Build a DataLoader for one of the supported image datasets.

    Args:
        batch_size: samples per batch.
        dataset_type: 'folder', 'imagenet', 'lfw', 'lsun', 'cifar10' or 'fake'.
        data_path: dataset root. It must exist for the folder-style and 'lsun'
            datasets; 'cifar10' is downloaded into it; 'fake' ignores it.
        shuffle, drop_last: forwarded to ``torch.utils.data.DataLoader``.
        dataloader_args: optional dict of extra DataLoader keyword arguments.
        resize, imsize, centercrop, centercrop_size, totensor, normalize,
        norm_mean, norm_std: forwarded to ``make_transform``.

    Returns:
        ``(dataloader, num_of_classes)``.

    Raises:
        AssertionError: if ``data_path`` is missing for folder/lsun datasets,
            or the dataset is empty.
        ValueError: if ``dataset_type`` is not recognised.
    """
    # Avoid a shared mutable default argument.
    if dataloader_args is None:
        dataloader_args = {}
    # Make transform
    transform = make_transform(resize=resize, imsize=imsize,
                               centercrop=centercrop, centercrop_size=centercrop_size,
                               totensor=totensor,
                               normalize=normalize, norm_mean=norm_mean, norm_std=norm_std)
    # Make dataset
    if dataset_type in ['folder', 'imagenet', 'lfw']:
        # folder dataset
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.ImageFolder(root=data_path, transform=transform)
    elif dataset_type == 'lsun':
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.LSUN(root=data_path, classes=['bedroom_train'], transform=transform)
    elif dataset_type == 'cifar10':
        if not os.path.exists(data_path):
            print("data_path does not exist! Given: {}\nDownloading CIFAR10 dataset...".format(data_path))
        dataset = dset.CIFAR10(root=data_path, download=True, transform=transform)
    elif dataset_type == 'fake':
        # Note: only ToTensor is applied here, not the transform built above.
        dataset = dset.FakeData(image_size=(3, centercrop_size, centercrop_size), transform=transforms.ToTensor())
    else:
        # Previously fell through and failed with UnboundLocalError on `dataset`.
        raise ValueError("Unknown dataset_type: {}".format(dataset_type))
    assert dataset
    # FakeData exposes `num_classes` but no `classes` list, so the original
    # `len(dataset.classes)` crashed for 'fake'.
    classes = getattr(dataset, 'classes', None)
    num_of_classes = len(classes) if classes is not None else dataset.num_classes
    print("Data found!  # of images =", len(dataset), ", # of classes =", num_of_classes, ", classes:", classes)
    # Make dataloader from dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **dataloader_args)
    return dataloader, num_of_classes
Example 4
Project: lale | Author: IBM | File: test_interoperability.py | License: Apache License 2.0 | Votes: 5
def test_init_fit_predict(self):
    """Smoke test: fit a 1-epoch lale ResNet50 on 50 fake images, then predict on them."""
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms
    from lale.lib.pytorch import ResNet50

    to_tensor = transforms.Compose([transforms.ToTensor()])

    fake_train = datasets.FakeData(
        size=50,
        num_classes=2,
        transform=to_tensor,
    )
    model = ResNet50(num_classes=2, num_epochs=1)
    model.fit(fake_train)
    predictions = model.predict(fake_train)
Example 5
Project: pytorch-dp | Author: facebookresearch | File: utils_test.py | License: Apache License 2.0 | Votes: 5
def genFakeData(
    self, imgSize: Tuple[int, int, int], batch_size: int = 1, num_batches: int = 1
) -> DataLoader:
    """Create ``self.ds`` and return a loader yielding ``num_batches`` batches.

    Args:
        imgSize: image shape passed to FakeData as ``image_size``.
        batch_size: samples per batch.
        num_batches: number of batches the returned loader yields.

    Returns:
        Unshuffled DataLoader over ``self.ds`` (2 fake classes).
    """
    self.ds = FakeData(
        # FakeData's `size` counts samples, not batches. Scale by batch_size so
        # the loader really yields `num_batches` full batches; previously it
        # yielded ceil(num_batches / batch_size).
        size=num_batches * batch_size,
        image_size=imgSize,
        num_classes=2,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    return DataLoader(self.ds, batch_size=batch_size)
Example 6
Project: pytorch-dp | Author: facebookresearch | File: per_sample_gradient_clip_test.py | License: Apache License 2.0 | Votes: 5
def setUp_data(self):
    """Create ``self.ds`` and ``self.dl``.

    ``self.ds`` holds DATA_SIZE fake 1x35x35 images with 10 classes.
    ``self.dl`` returns the whole dataset as a single batch.
    """
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])
    n = self.DATA_SIZE
    self.ds = FakeData(size=n, image_size=(1, 35, 35), num_classes=10, transform=pipeline)
    self.dl = DataLoader(self.ds, batch_size=n)
Example 7
Project: pytorch-dp | Author: facebookresearch | File: virtual_step_test.py | License: Apache License 2.0 | Votes: 5
def setUp_data(self):
    """Create ``self.ds`` and ``self.dl``.

    ``self.ds`` holds DATA_SIZE fake 1x35x35 images with 10 classes.
    ``self.dl`` batches it in chunks of BATCH_SIZE.
    """
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])
    self.ds = FakeData(size=self.DATA_SIZE, image_size=(1, 35, 35), num_classes=10, transform=pipeline)
    self.dl = DataLoader(self.ds, batch_size=self.BATCH_SIZE)
Example 8
Project: pytorch-dp | Author: facebookresearch | File: privacy_engine_test.py | License: Apache License 2.0 | Votes: 5
def setUp_data(self):
    """Create ``self.ds`` and ``self.dl``.

    ``self.ds`` holds DATA_SIZE fake single-channel 35x35 images with 10 classes.
    ``self.dl`` is an unshuffled loader with batch size BATCH_SIZE.
    """
    steps = [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    self.ds = FakeData(
        size=self.DATA_SIZE,
        image_size=(1, 35, 35),
        num_classes=10,
        transform=transforms.Compose(steps),
    )
    self.dl = DataLoader(self.ds, batch_size=self.BATCH_SIZE)
Example 9
Project: ignite | Author: pytorch | File: dcgan.py | License: BSD 3-Clause "New" or "Revised" License | Votes: 4
def check_dataset(dataset, dataroot):
    """Build the torchvision dataset selected by name, with its channel count.

    Args:
        dataset (str): Name of the dataset to use. See CLI help for details.
        dataroot (str): root directory where the dataset will be stored.

    Returns:
        tuple: ``(dataset, nc)``, where ``dataset`` is the torchvision Dataset
        object and ``nc`` is the number of image channels (1 for MNIST, 3 otherwise).

    Raises:
        RuntimeError: if ``dataset`` is not a recognised name.
    """
    resize = transforms.Resize(64)
    crop = transforms.CenterCrop(64)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # MNIST tensors are 1xHxW. Applying the 3-channel mean/std above fails with
    # a broadcast-shape error, so grayscale data gets its own Normalize.
    normalize_gray = transforms.Normalize((0.5,), (0.5,))

    if dataset in {"imagenet", "folder", "lfw"}:
        dataset = dset.ImageFolder(root=dataroot, transform=transforms.Compose([resize, crop, to_tensor, normalize]))
        nc = 3

    elif dataset == "lsun":
        dataset = dset.LSUN(
            root=dataroot, classes=["bedroom_train"], transform=transforms.Compose([resize, crop, to_tensor, normalize])
        )
        nc = 3

    elif dataset == "cifar10":
        dataset = dset.CIFAR10(
            root=dataroot, download=True, transform=transforms.Compose([resize, to_tensor, normalize])
        )
        nc = 3

    elif dataset == "mnist":
        dataset = dset.MNIST(
            root=dataroot, download=True, transform=transforms.Compose([resize, to_tensor, normalize_gray])
        )
        nc = 1

    elif dataset == "fake":
        dataset = dset.FakeData(size=256, image_size=(3, 64, 64), transform=to_tensor)
        nc = 3

    else:
        raise RuntimeError("Invalid dataset name: {}".format(dataset))

    return dataset, nc
Example 10
Project: metropolis-hastings-gans | Author: uber-research | File: dcgan_loader.py | License: Apache License 2.0 | Votes: 4
def get_data_loader(dataset, dataroot, workers, image_size, batch_size):
    """Build a shuffling DataLoader for the dataset selected by name.

    Args:
        dataset (str): 'imagenet', 'folder', 'lfw', 'lsun', 'cifar10', 'mnist' or 'fake'.
        dataroot (str): dataset root directory ('cifar10'/'mnist' are downloaded here).
        workers: number of DataLoader worker processes (converted with ``int``).
        image_size (int): target size for Resize/CenterCrop and for FakeData images.
        batch_size (int): samples per batch.

    Returns:
        torch.utils.data.DataLoader with ``shuffle=True``.

    Raises:
        ValueError: if ``dataset`` is not a recognised name.
    """
    resize = transforms.Resize(image_size)
    crop = transforms.CenterCrop(image_size)
    to_tensor = transforms.ToTensor()
    normalize_rgb = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # MNIST is single-channel; a 3-element mean/std cannot be applied in place
    # to a 1xHxW tensor (broadcast-shape error).
    normalize_gray = transforms.Normalize((0.5,), (0.5,))

    if dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=dataroot,
                                   transform=transforms.Compose([resize, crop, to_tensor, normalize_rgb]))
    elif dataset == 'lsun':
        dataset = dset.LSUN(root=dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([resize, crop, to_tensor, normalize_rgb]))
    elif dataset == 'cifar10':
        dataset = dset.CIFAR10(root=dataroot, download=True,
                               transform=transforms.Compose([resize, to_tensor, normalize_rgb]))
    elif dataset == 'mnist':
        dataset = dset.MNIST(root=dataroot, train=True, download=True,
                             transform=transforms.Compose([resize, crop, to_tensor, normalize_gray]))
    elif dataset == 'fake':
        dataset = dset.FakeData(image_size=(3, image_size, image_size),
                                transform=to_tensor)
    else:
        # `assert False` is stripped under `python -O`, which would let the name
        # string reach DataLoader; raise explicitly instead.
        raise ValueError("Unknown dataset: {}".format(dataset))
    # Rejects an empty dataset (truthiness goes through __len__).
    assert dataset

    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=int(workers))
    return data_loader