Python torchvision.datasets.STL10 Examples

The following are 9 code examples showing how to use torchvision.datasets.STL10(). They are extracted from open-source projects; the project, author, source file, and license are noted above each example.

You may also want to check out the other functions and classes available in the torchvision.datasets module.
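
For orientation, here is a minimal, hedged sketch of the pattern the examples below build on: construct an STL10 dataset for the labeled training split and wrap it in a DataLoader. The root path and batch size are placeholders, not values taken from any of the projects.

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

# Download (if missing) and load the labeled STL10 training split.
# './data/stl10' and batch_size=32 are placeholder values.
train_set = datasets.STL10(root='./data/stl10', split='train',
                           download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([32, 3, 96, 96]) -- STL10 images are 96x96 RGB
print(labels.shape)  # torch.Size([32])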

Example 1
Project: bigBatch   Author: eladhoffer   File: data.py   License: MIT License
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path=__DATASETS_DEFAULT_PATH):
    train = (split == 'train')
    root = os.path.join(datasets_path, name)
    if name == 'cifar10':
        return datasets.CIFAR10(root=root,
                                train=train,
                                transform=transform,
                                target_transform=target_transform,
                                download=download)
    elif name == 'cifar100':
        return datasets.CIFAR100(root=root,
                                 train=train,
                                 transform=transform,
                                 target_transform=target_transform,
                                 download=download)
    elif name == 'mnist':
        return datasets.MNIST(root=root,
                              train=train,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'stl10':
        return datasets.STL10(root=root,
                              split=split,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'imagenet':
        if train:
            root = os.path.join(root, 'train')
        else:
            root = os.path.join(root, 'val')
        return datasets.ImageFolder(root=root,
                                    transform=transform,
                                    target_transform=target_transform) 
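
The dispatcher above forwards the split string directly to STL10 (which has no train flag), unlike the boolean train argument used for CIFAR and MNIST. A hedged usage sketch, assuming get_dataset and __DATASETS_DEFAULT_PATH from this file are in scope:

from torchvision import transforms

# Hypothetical call: load the STL10 test split through the dispatcher above.
test_set = get_dataset('stl10', split='test', transform=transforms.ToTensor())
print(len(test_set))  # 8000 -- the STL10 test split has 800 images per class
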
Example 2
Project: amdim-public   Author: Philip-Bachman   File: datasets.py   License: MIT License
def get_encoder_size(dataset):
    if dataset in [Dataset.C10, Dataset.C100]:
        return 32
    if dataset == Dataset.STL10:
        return 64
    if dataset in [Dataset.IN128, Dataset.PLACES205]:
        return 128
    raise RuntimeError("Couldn't get encoder size, unknown dataset: {}".format(dataset)) 
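
Note that get_encoder_size returns the input resolution the encoder expects, not the native image size: STL10 images are 96x96, so the project presumably rescales or crops them to 64x64 before encoding. A minimal resize sketch (the exact augmentation pipeline is not taken from this project):

from torchvision import transforms

# Scale 96x96 STL10 images down to a 64x64 encoder input.
stl10_to_encoder = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
])
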
Example 3
Project: amdim-public   Author: Philip-Bachman   File: datasets.py   License: MIT License
def _get_directories(dataset, input_dir):
    if dataset in [Dataset.C10, Dataset.C100, Dataset.STL10]:
        # PyTorch will download these datasets automatically
        return None, None
    if dataset == Dataset.IN128:
        train_dir = os.path.join(input_dir, 'ILSVRC2012_img_train/')
        val_dir = os.path.join(input_dir, 'ILSVRC2012_img_val/')
    elif dataset == Dataset.PLACES205:
        train_dir = os.path.join(input_dir, 'places205_256_train/')
        val_dir = os.path.join(input_dir, 'places205_256_val/')
    else:
        raise RuntimeError('Data directories for dataset ' + str(dataset) + ' are not defined')
    return train_dir, val_dir 
Example 4
Project: quantized.pytorch   Author: eladhoffer   File: data.py   License: MIT License
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path=__DATASETS_DEFAULT_PATH):
    train = (split == 'train')
    root = os.path.join(datasets_path, name)
    if name == 'cifar10':
        return datasets.CIFAR10(root=root,
                                train=train,
                                transform=transform,
                                target_transform=target_transform,
                                download=download)
    elif name == 'cifar100':
        return datasets.CIFAR100(root=root,
                                 train=train,
                                 transform=transform,
                                 target_transform=target_transform,
                                 download=download)
    elif name == 'mnist':
        return datasets.MNIST(root=root,
                              train=train,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'stl10':
        return datasets.STL10(root=root,
                              split=split,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'imagenet':
        if train:
            root = os.path.join(root, 'train')
        else:
            root = os.path.join(root, 'val')
        return datasets.ImageFolder(root=root,
                                    transform=transform,
                                    target_transform=target_transform) 
Example 5
Project: pytorch-playground   Author: aaron-xichen   File: dataset.py   License: MIT License
def get(batch_size, data_root='/mnt/local0/public_dataset/pytorch/', train=True, val=True, **kwargs):
    data_root = os.path.expanduser(os.path.join(data_root, 'stl10-data'))
    num_workers = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)
    print("Building STL10 data loader with {} workers".format(num_workers))
    ds = []
    if train:
        train_loader = torch.utils.data.DataLoader(
            datasets.STL10(
                root=data_root, split='train', download=True,
                transform=transforms.Compose([
                    transforms.Pad(4),
                    transforms.RandomCrop(96),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])),
            batch_size=batch_size, shuffle=True, **kwargs)
        ds.append(train_loader)

    if val:
        test_loader = torch.utils.data.DataLoader(
            datasets.STL10(
                root=data_root, split='test', download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])),
            batch_size=batch_size, shuffle=False, **kwargs)
        ds.append(test_loader)

    ds = ds[0] if len(ds) == 1 else ds
    return ds 
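
A hedged usage sketch of the factory above; the batch size, data_root, and num_workers here are placeholders, and data_root must point at a writable directory:

# Hypothetical call: build both the train and test loaders.
train_loader, test_loader = get(batch_size=128, data_root='./data', num_workers=2)
for images, labels in train_loader:
    print(images.shape)  # torch.Size([128, 3, 96, 96]) after Pad(4) + RandomCrop(96)
    break
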
Example 6
Project: Deep_Openset_Recognition_through_Uncertainty   Author: MrtnMndt   File: datasets.py   License: MIT License
def get_dataset(self):
        """
        Uses torchvision.datasets.STL10 to load the dataset.
        Downloads the dataset if it doesn't already exist.
        Returns:
             torchvision.datasets.STL10: trainset, valset
        """

        trainset = datasets.STL10('datasets/STL10/train/', split='train', transform=self.train_transforms,
                                  target_transform=None, download=True)
        valset = datasets.STL10('datasets/STL10/test/', split='test', transform=self.val_transforms,
                                target_transform=None, download=True)

        return trainset, valset 
Example 7
Project: AutoGAN   Author: TAMU-VITA   File: datasets.py   License: MIT License
def __init__(self, args, cur_img_size=None):
        img_size = cur_img_size if cur_img_size else args.img_size
        if args.dataset.lower() == 'cifar10':
            Dt = datasets.CIFAR10
            transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            args.n_classes = 10
        elif args.dataset.lower() == 'stl10':
            Dt = datasets.STL10
            transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
        else:
            raise NotImplementedError('Unknown dataset: {}'.format(args.dataset))

        if args.dataset.lower() == 'stl10':
            self.train = torch.utils.data.DataLoader(
                Dt(root=args.data_path, split='train+unlabeled', transform=transform, download=True),
                batch_size=args.dis_batch_size, shuffle=True,
                num_workers=args.num_workers, pin_memory=True)

            self.valid = torch.utils.data.DataLoader(
                Dt(root=args.data_path, split='test', transform=transform),
                batch_size=args.dis_batch_size, shuffle=False,
                num_workers=args.num_workers, pin_memory=True)

            self.test = self.valid
        else:
            self.train = torch.utils.data.DataLoader(
                Dt(root=args.data_path, train=True, transform=transform, download=True),
                batch_size=args.dis_batch_size, shuffle=True,
                num_workers=args.num_workers, pin_memory=True)

            self.valid = torch.utils.data.DataLoader(
                Dt(root=args.data_path, train=False, transform=transform),
                batch_size=args.dis_batch_size, shuffle=False,
                num_workers=args.num_workers, pin_memory=True)

            self.test = self.valid 
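
The 'train+unlabeled' split used for the discriminator combines the 5,000 labeled training images with the 100,000 unlabeled STL10 images; in torchvision the unlabeled examples carry a target of -1, so the labels can simply be ignored during GAN training. A quick check (the root path is a placeholder):

from torchvision import datasets, transforms

ds = datasets.STL10(root='./data/stl10', split='train+unlabeled',
                    download=True, transform=transforms.ToTensor())
print(len(ds))     # 105000 = 5000 labeled + 100000 unlabeled
_, target = ds[-1]
print(target)      # -1 marks an unlabeled example
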
Example 8
Project: convNet.pytorch   Author: eladhoffer   File: data.py   License: MIT License
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path='~/Datasets'):
    train = (split == 'train')
    root = os.path.join(os.path.expanduser(datasets_path), name)
    if name == 'cifar10':
        return datasets.CIFAR10(root=root,
                                train=train,
                                transform=transform,
                                target_transform=target_transform,
                                download=download)
    elif name == 'cifar100':
        return datasets.CIFAR100(root=root,
                                 train=train,
                                 transform=transform,
                                 target_transform=target_transform,
                                 download=download)
    elif name == 'mnist':
        return datasets.MNIST(root=root,
                              train=train,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'stl10':
        return datasets.STL10(root=root,
                              split=split,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'imagenet':
        if train:
            root = os.path.join(root, 'train')
        else:
            root = os.path.join(root, 'val')
        return datasets.ImageFolder(root=root,
                                    transform=transform,
                                    target_transform=target_transform)
    elif name == 'imagenet_tar':
        if train:
            root = os.path.join(root, 'imagenet_train.tar')
        else:
            root = os.path.join(root, 'imagenet_validation.tar')
        return IndexedFileDataset(root, extract_target_fn=(
            lambda fname: fname.split('/')[0]),
            transform=transform,
            target_transform=target_transform) 
Example 9
Project: NeMo   Author: NVIDIA   File: stl10_datalayer.py   License: Apache License 2.0
def __init__(
        self,
        height: int = 96,
        width: int = 96,
        data_folder: str = "~/data/stl10",
        split: str = "train",
        name: Optional[str] = None,
        batch_size: int = 64,
        shuffle: bool = True,
    ):
        """
        Initializes the STL10 datalayer.

        Args:
            height: image height (DEFAULT: 96)
            width: image width (DEFAULT: 96)
            data_folder: path to the folder with data, can be relative to user (DEFAULT: "~/data/stl10")
            split: One of 4 splits {‘train’, ‘test’, ‘unlabeled’, ‘train+unlabeled’} (DEFAULT: "train")
            name: Name of the module (DEFAULT: None)
            batch_size: size of batch (DEFAULT: 64) [PARAMETER OF DATALOADER]
            shuffle: shuffle data (DEFAULT: True) [PARAMETER OF DATALOADER]
        """
        # Call the base class constructor of DataLayer.
        DataLayerNM.__init__(self, name=name)

        # Store height and width.
        self._height = height
        self._width = width

        # Create transformations: up-scale and transform to tensors.
        STL10_transforms = Compose([Resize((self._height, self._width)), ToTensor()])

        # Get absolute path.
        abs_data_folder = expanduser(data_folder)

        # Create the STL10 dataset object.
        self._dataset = STL10(root=abs_data_folder, split=split, download=True, transform=STL10_transforms)

        # Remember the params passed to DataLoader. :]
        self._batch_size = batch_size
        self._shuffle = shuffle

        # Class names.
        labels = 'airplane bird car cat deer dog horse monkey ship truck'.split(' ')
        word_to_ix = {labels[i]: i for i in range(10)}

        # Reverse mapping.
        self._ix_to_word = {value: key for (key, value) in word_to_ix.items()}
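
The two dictionaries built at the end provide a name-to-index mapping (and its reverse) for the ten STL10 classes; a small illustration of the reverse lookup, with an arbitrary predicted index:

labels = 'airplane bird car cat deer dog horse monkey ship truck'.split(' ')
word_to_ix = {labels[i]: i for i in range(10)}
ix_to_word = {value: key for (key, value) in word_to_ix.items()}

predicted_class = 3                 # e.g. the argmax of a model's output
print(ix_to_word[predicted_class])  # 'cat'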