Python torch.utils.data.IterableDataset() Examples

The following are 7 code examples of torch.utils.data.IterableDataset(), taken from open-source projects. The source file, project, and license are noted above each example. You may also want to check out all available functions and classes of the module torch.utils.data.
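Before the project examples, here is a minimal, self-contained sketch of a custom IterableDataset that shards its stream across DataLoader workers via get_worker_info(); the class name and the numeric range are illustrative only, not taken from any of the projects below.

from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class RangeIterableDataset(IterableDataset):
    """Yields integers in [start, end), split across DataLoader workers."""

    def __init__(self, start, end):
        super().__init__()
        self.start, self.end = start, end

    def __iter__(self):
        info = get_worker_info()
        if info is None:  # single-process data loading
            return iter(range(self.start, self.end))
        # strided split: worker i yields start+i, start+i+num_workers, ...
        return iter(range(self.start + info.id, self.end, info.num_workers))

loader = DataLoader(RangeIterableDataset(0, 8), batch_size=4, num_workers=2)
for batch in loader:
    print(batch)  # every element appears exactly once across all batches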
Example #1
Source File: data_silo.py    From FARM with Apache License 2.0
def __iter__(self):
        #  With an IterableDataset, the same __iter__ is copied over to the multiple workers of
        #  a DataLoader. Hence, we need to configure __iter__ to not yield duplicated data
        #  when more than one worker is used.
        #
        #  To avoid duplicates, we need to split the input dicts between the workers.
        #  grouper() takes a dict generator as input and yields only the
        #  dicts that are to be processed by the given worker_id.
        #
        #  For instance, given the input [dictA, dictB, dictC, ...], grouper (with n=2)
        #  will return [[dictA, dictB], [dictE, dictF], ...] for worker 1 and
        #  [[dictC, dictD], [dictG, dictH], ...] for worker 2.

        worker_info = torch.utils.data.get_worker_info()
        if self.distributed:
            worker_id = self.rank * worker_info.num_workers + worker_info.id
            total_workers = self.world_size * worker_info.num_workers
        else:
            worker_id = worker_info.id
            total_workers = self.dataloader_workers

        dicts = grouper(self.file_to_dicts_generator, n=10, worker_id=worker_id, total_workers=total_workers)
        results = map(self._dataset_from_chunk, dicts)

        batch = []
        for datasets, tensor_names in results:
            if not datasets:
                continue
            self.tensor_names = tensor_names
            for ds in datasets:
                batch.append(ds)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
        if batch:
            yield batch 
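The grouper used above is a FARM helper. As a rough illustration of the round-robin chunking described in the comment (a sketch under that assumption, not FARM's actual implementation), it could look like:

from itertools import islice

def grouper_sketch(dicts, n, worker_id, total_workers):
    """Yield every total_workers-th chunk of n dicts, starting at chunk index worker_id."""
    iterator = iter(dicts)
    chunk_idx = 0
    while True:
        chunk = list(islice(iterator, n))
        if not chunk:
            return
        if chunk_idx % total_workers == worker_id:
            yield chunk
        chunk_idx += 1

# worker_id=0, total_workers=2, n=2 over [A, B, C, D, E, F, G, H] -> [A, B], [E, F]
# worker_id=1, total_workers=2, n=2 over the same input           -> [C, D], [G, H]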
Example #2
Source File: data_loading.py    From pytorch-lightning with Apache License 2.0
def _has_len(dataloader: DataLoader) -> bool:
    """ Checks if a given Dataloader has __len__ method implemented i.e. if
    it is a finite dataloader or infinite dataloader. """

    try:
        # try getting the length
        if len(dataloader) == 0:
            raise ValueError('`Dataloader` returned 0 length.'
                             ' Please make sure that your Dataloader at least returns 1 batch')
        has_len = True
    except TypeError:
        has_len = False
    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
        has_len = False

    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
        rank_zero_warn(
            'Your `IterableDataset` has `__len__` defined.'
            ' In combination with multi-processing data loading (e.g. batch size > 1),'
            ' this can lead to unintended side effects since the samples will be duplicated.'
        )
    return has_len 
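As an illustration (not part of pytorch-lightning), the two cases _has_len distinguishes can be reproduced with a map-style dataset versus a plain IterableDataset that defines no __len__:

from torch.utils.data import DataLoader, IterableDataset

class StreamDataset(IterableDataset):
    def __iter__(self):
        return iter(range(10))  # no __len__ defined

finite_loader = DataLoader(list(range(10)), batch_size=2)   # len() works -> _has_len returns True
stream_loader = DataLoader(StreamDataset(), batch_size=2)   # len() raises TypeError -> _has_len returns False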
Example #3
Source File: custom_dataset.py    From BiaffineDependencyParsing with MIT License
def __init__(self, datasets: List[TensorDataset], probs: List[float] = None, exp: float = None, mode: str = 'exp'):
        """

        :param datasets: 各个源本身的Data Set
        :param probs: 按照概率采样,对应每个源的概率,长度等于datasets的数量
        :param exp: 按照指数平滑采样,0<exp<1
        :param mode:指示是采用概率采样还是采用指数平滑采样
        """
        super().__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        assert mode in ['prob', 'exp'], "ConcatTensorRandomDataset mode must be 'prob' or 'exp'"
        if mode == 'prob':
            assert probs and len(probs) == len(datasets) and sum(probs) == 1
        else:
            assert exp and 0 < exp < 1
        self.datasets = list(datasets)
        self.dataset_idxs = list(range(len(self.datasets)))
        self.dataset_lens = [len(x) for x in self.datasets]
        self.original_lengths = []  # record the original data length of each source
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
            self.original_lengths.append(len(d))
        if mode == 'exp':
            original_probs = self.original_lengths / np.sum(self.original_lengths)
            # exponential weighting
            probs_exp = original_probs ** exp
            # softmax over the exponentially weighted probabilities
            pes = np.exp(probs_exp)
            self.probs = pes / np.sum(pes)
        else:
            assert isinstance(probs, list) and probs
            self.probs = np.array(probs)
        self.sample_total_length = np.sum(self.original_lengths * self.probs) 
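A worked example of the exp-mode computation above, with made-up lengths: two sources of 100 and 900 examples and exp=0.5 give the smaller source a noticeably larger share than its raw 10%.

import numpy as np

lengths = np.array([100, 900])
original_probs = lengths / lengths.sum()             # [0.1, 0.9]
probs_exp = original_probs ** 0.5                    # [0.316..., 0.948...]
probs = np.exp(probs_exp) / np.exp(probs_exp).sum()
print(probs)                                         # roughly [0.35, 0.65]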
Example #4
Source File: distributed_torch_runner.py    From ray with Apache License 2.0
def _wrap_dataloaders(self):
        def with_sampler(loader):
            # Automatically set the DistributedSampler
            data_loader_args = {
                "dataset": loader.dataset,
                "batch_size": loader.batch_size,
                "shuffle": False,
                "num_workers": loader.num_workers,
                "collate_fn": loader.collate_fn,
                "pin_memory": loader.pin_memory,
                "drop_last": loader.drop_last,
                "timeout": loader.timeout,
                "worker_init_fn": loader.worker_init_fn,
                "sampler": DistributedSampler(loader.dataset)
            }
            return DataLoader(**data_loader_args)

        def should_wrap_dataloader(loader):
            return (isinstance(loader, DataLoader)
                    and not isinstance(loader.dataset, IterableDataset))

        if should_wrap_dataloader(self.train_loader):
            if self.add_dist_sampler:
                self.train_loader = with_sampler(self.train_loader)

        if self.validation_loader and should_wrap_dataloader(
                self.validation_loader):
            if self.add_dist_sampler:
                self.validation_loader = with_sampler(self.validation_loader) 
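For reference (illustrative, not Ray code), the DistributedSampler injected above gives each rank a disjoint, strided slice of a map-style dataset, which is exactly what cannot be done generically for an IterableDataset; passing num_replicas and rank explicitly avoids the need for an initialized process group:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(sampler))  # rank 0 -> [0, 2, 4, 6], rank 1 -> [1, 3, 5, 7]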
Example #5
Source File: data_loading.py    From pytorch-lightning with Apache License 2.0
def _has_iterable_dataset(dataloader: DataLoader):
    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
        and isinstance(dataloader.dataset, IterableDataset) 
Example #6
Source File: iterators.py    From ReAgent with BSD 3-Clause "New" or "Revised" License
def __init__(self, dataloader: IterableDataset, dataloader_size: int):
        """ Wraps around an Iterable Dataloader to report progress bars and
        increase global step of SummaryWriter. At last iteration, will call
        dataloader.__exit__ if needed (e.g. Petastorm DataLoader).

        Args:
            dataloader: the iterable dataloader to wrap around
            dataloader_size: size of the dataset we're iterating over
        """

        self.dataloader = dataloader
        self.dataloader_iter = iter(dataloader)
        self.dataloader_size = dataloader_size 
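A hypothetical sketch of the iteration side such a wrapper needs (this is not ReAgent's actual code, just one way to honor the docstring's promise to call dataloader.__exit__ at the last iteration):

def __next__(self):
        try:
            return next(self.dataloader_iter)
        except StopIteration:
            # e.g. Petastorm's DataLoader is a context manager and needs cleanup
            if hasattr(self.dataloader, "__exit__"):
                self.dataloader.__exit__(None, None, None)
            raise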
Example #7
Source File: data_loading.py    From pytorch-lightning with Apache License 2.0
def reset_train_dataloader(self, model: LightningModule) -> None:
        """Resets the train dataloader and initialises required variables
        (number of batches, when to validate, etc.).

        Args:
            model: The current `LightningModule`
        """
        self.train_dataloader = self.request_dataloader(model.train_dataloader)

        self.num_training_batches = 0

        # automatically add samplers
        self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

        self._worker_check(self.train_dataloader, 'train dataloader')
        self._check_batch_limits('limit_train_batches')

        if not _has_len(self.train_dataloader):
            self.num_training_batches = float('inf')
        else:
            # try getting the length
            if isinstance(self.limit_train_batches, float):
                self.num_training_batches = len(self.train_dataloader)
                self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
            else:
                self.num_training_batches = self.limit_train_batches

        # determine when to check validation
        # if int passed in, val checks that often
        # otherwise, it checks in [0, 1.0] % range of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                    f'to the number of the training batches ({self.num_training_batches}). '
                    'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
        else:
            if not _has_len(self.train_dataloader):
                if self.val_check_interval == 1.0:
                    self.val_check_batch = float('inf')
                else:
                    raise MisconfigurationException(
                        'When using an infinite DataLoader (e.g. with an IterableDataset'
                        ' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
                        ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                        ' checking validation every k training batches.')
            else:
                self._check_batch_limits('val_check_interval')

                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
                self.val_check_batch = max(1, self.val_check_batch)
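An illustrative walk-through of the float branches above (the numbers are made up): with 1000 raw training batches, limit_train_batches=0.1 and val_check_interval=0.25, validation is run every 25 training batches.

num_training_batches = int(1000 * 0.1)                      # limit_train_batches=0.1 -> 100
val_check_batch = max(1, int(num_training_batches * 0.25))  # val_check_interval=0.25 -> 25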