Python torch.utils.data.distributed.DistributedSampler() Examples

The following are 30 code examples of torch.utils.data.distributed.DistributedSampler(), drawn from open source projects. You can go to the original project or source file via the links above each example. You may also want to check out all available functions and classes of the torch.utils.data.distributed module.
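Before the project-specific examples, a minimal generic sketch of the usual DistributedSampler pattern may help (this is not taken from any project below; the dataset and process-group setup are placeholders): each process builds its own sampler, shuffle is left disabled on the DataLoader because it is mutually exclusive with a sampler, and sampler.set_epoch() is called every epoch so each epoch uses a different shuffle.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Placeholder dataset; any map-style Dataset works.
dataset = TensorDataset(torch.randn(1000, 16), torch.randint(0, 10, (1000,)))

# Assumes the default process group is already initialized, e.g. via
# torch.distributed.init_process_group(backend="nccl") under torchrun;
# DistributedSampler then infers num_replicas and rank automatically.
sampler = DistributedSampler(dataset)

loader = DataLoader(dataset, batch_size=32, shuffle=False, sampler=sampler,
                    num_workers=2, pin_memory=True)

for epoch in range(10):
    sampler.set_epoch(epoch)  # re-seed the per-epoch shuffle
    for inputs, targets in loader:
        pass  # training step goes here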
Example #1
Source File: train.py    From tn2-wg with BSD 3-Clause "New" or "Revised" License
def prepare_dataloaders(hparams):
    # Get data, data loaders and collate function ready
    trainset = TextMelLoader(hparams.training_files, hparams)
    valset = TextMelLoader(hparams.validation_files, hparams)
    collate_fn = TextMelCollate(hparams.n_frames_per_step)

    if hparams.distributed_run:
        train_sampler = DistributedSampler(trainset)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True

    train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=hparams.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)
    return train_loader, valset, collate_fn 
Example #2
Source File: dataloader.py    From imagenet18_old with The Unlicense
def get_loaders(traindir, valdir, sz, bs, fp16=True, val_bs=None, workers=8, rect_val=False, min_scale=0.08, distributed=False):
    val_bs = val_bs or bs
    train_tfms = [
            transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
            transforms.RandomHorizontalFlip()
    ]
    train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_tfms))
    train_sampler = (DistributedSampler(train_dataset, num_replicas=env_world_size(), rank=env_rank()) if distributed else None)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=(train_sampler is None),
        num_workers=workers, pin_memory=True, collate_fn=fast_collate, 
        sampler=train_sampler)

    val_dataset, val_sampler = create_validation_set(valdir, val_bs, sz, rect_val=rect_val, distributed=distributed)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        num_workers=workers, pin_memory=True, collate_fn=fast_collate, 
        batch_sampler=val_sampler)

    train_loader = BatchTransformDataLoader(train_loader, fp16=fp16)
    val_loader = BatchTransformDataLoader(val_loader, fp16=fp16)

    return train_loader, val_loader, train_sampler, val_sampler 
Example #3
Source File: data_loading.py    From pytorch-lightning with Apache License 2.0
def _get_distributed_sampler(self, dataloader):
        if self.use_tpu:
            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.use_horovod:
            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
        else:
            world_size = {
                'ddp': self.num_nodes * self.num_processes,
                'ddp_spawn': self.num_nodes * self.num_processes,
                'ddp2': self.num_nodes,
                'ddp_cpu': self.num_processes * self.num_nodes
            }
            assert self.distributed_backend is not None
            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
        sampler = DistributedSampler(dataloader.dataset, **kwargs)
        return sampler 
Example #4
Source File: run_joint_span.py    From SpanABSA with Apache License 2.0
def read_eval_data(args, tokenizer, logger):
    eval_path = os.path.join(args.data_dir, args.predict_file)
    eval_set = read_absa_data(eval_path)
    eval_examples = convert_absa_data(dataset=eval_set, verbose_logging=args.verbose_logging)

    eval_features = convert_examples_to_features(eval_examples, tokenizer, args.max_seq_length,
                                                 args.verbose_logging, logger)

    logger.info("Num orig examples = %d", len(eval_examples))
    logger.info("Num split features = %d", len(eval_features))
    logger.info("Batch size = %d", args.predict_batch_size)
    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
    if args.local_rank == -1:
        eval_sampler = SequentialSampler(eval_data)
    else:
        eval_sampler = DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)
    return eval_examples, eval_features, eval_dataloader 
Example #5
Source File: runners.py    From bert_on_stilts with Apache License 2.0
def get_train_dataloader(self, train_examples, verbose=True):
        train_features = convert_examples_to_features(
            train_examples, self.label_map, self.rparams.max_seq_length, self.tokenizer,
            verbose=verbose,
        )
        train_data, train_tokens = convert_to_dataset(
            train_features, label_mode=get_label_mode(self.label_map),
        )
        if self.rparams.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(
            train_data, sampler=train_sampler, batch_size=self.rparams.train_batch_size,
        )
        return HybridLoader(train_dataloader, train_tokens) 
Example #6
Source File: setup_utils.py    From tape with BSD 3-Clause "New" or "Revised" License
def setup_loader(dataset: Dataset,
                 batch_size: int,
                 local_rank: int,
                 n_gpu: int,
                 gradient_accumulation_steps: int,
                 num_workers: int) -> DataLoader:
    sampler = DistributedSampler(dataset) if local_rank != -1 else RandomSampler(dataset)
    batch_size = get_effective_batch_size(
        batch_size, local_rank, n_gpu, gradient_accumulation_steps) * n_gpu
    # WARNING: this will fail if the primary sequence is not the first thing the dataset returns
    batch_sampler = BucketBatchSampler(
        sampler, batch_size, False, lambda x: len(x[0]), dataset)

    loader = DataLoader(
        dataset,
        num_workers=num_workers,
        collate_fn=dataset.collate_fn,  # type: ignore
        batch_sampler=batch_sampler)

    return loader 
Example #7
Source File: loader.py    From SlowFast with Apache License 2.0
def shuffle_dataset(loader, cur_epoch):
    """"
    Shuffles the data.
    Args:
        loader (loader): data loader to perform shuffle.
        cur_epoch (int): number of the current epoch.
    """
    sampler = (
        loader.batch_sampler.sampler
        if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
        else loader.sampler
    )
    assert isinstance(
        sampler, (RandomSampler, DistributedSampler)
    ), "Sampler type '{}' not supported".format(type(sampler))
    # RandomSampler handles shuffling automatically
    if isinstance(sampler, DistributedSampler):
        # DistributedSampler shuffles data based on epoch
        sampler.set_epoch(cur_epoch) 
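The comment above is the reason helpers like this exist: a DistributedSampler only draws a new permutation when set_epoch() is given a new value. A hedged sketch of the calling pattern (not SlowFast's actual training loop; train_one_epoch is a placeholder):

def train(loader, model, optimizer, num_epochs):
    for cur_epoch in range(num_epochs):
        # Must run before iterating the loader each epoch,
        # otherwise every epoch sees the same ordering.
        shuffle_dataset(loader, cur_epoch)
        train_one_epoch(loader, model, optimizer, cur_epoch)  # placeholder training step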
Example #8
Source File: loader.py    From pycls with MIT License
def _construct_loader(dataset_name, split, batch_size, shuffle, drop_last):
    """Constructs the data loader for the given dataset."""
    err_str = "Dataset '{}' not supported".format(dataset_name)
    assert dataset_name in _DATASETS and dataset_name in _PATHS, err_str
    # Retrieve the data path for the dataset
    data_path = os.path.join(_DATA_DIR, _PATHS[dataset_name])
    # Construct the dataset
    dataset = _DATASETS[dataset_name](data_path, split)
    # Create a sampler for multi-process training
    sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
    # Create a loader
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(False if sampler else shuffle),
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_WORKERS,
        pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
        drop_last=drop_last,
    )
    return loader 
Example #9
Source File: dataset.py    From jdit with Apache License 2.0
def convert_to_distributed(self, which_dataset=None, num_replicas=None, rank=None):
        samplers = {}
        if which_dataset is None:
            samplers["train"] = DistributedSampler(self.dataset_train, num_replicas=None, rank=None)
            self.loader_train = DataLoader(self.dataset_train, self.batch_size, False, sampler=samplers["train"])

        else:
            if which_dataset == "train":
                samplers["train"] = DistributedSampler(self.dataset_train, num_replicas=num_replicas, rank=rank)
                self.loader_train = DataLoader(self.dataset_train, self.batch_size, False,
                                               sampler=samplers["train"])
            elif which_dataset == "valid":
                samplers["valid"] = DistributedSampler(self.dataset_valid, num_replicas=num_replicas, rank=rank)
                self.loader_valid = DataLoader(self.dataset_valid, self.batch_size, False,
                                               sampler=samplers["valid"])
            elif which_dataset == "test":
                self.loader_test.sampler = samplers["test"]
                self.loader_test = DataLoader(self.dataset_test, self.batch_size, False,
                                              sampler=samplers["test"])
            else:
                raise ValueError(
                    "param `which_dataset` can only be set to 'train', 'valid' or 'test'. Got %s instead" % which_dataset)
        return samplers 
Example #10
Source File: data.py    From convNet.pytorch with MIT License
def get_loader(self, force_update=False, override_settings=None, subset_indices=None):
        if force_update or self.regime.update(self.epoch, self.steps):
            setting = self.get_setting()
            if override_settings is not None:
                setting.update(override_settings)
            self._transform = get_transform(**setting['transform'])
            setting['data'].setdefault('transform', self._transform)
            self._data = get_dataset(**setting['data'])
            if subset_indices is not None:
                self._data = Subset(self._data, subset_indices)
            if setting['other'].get('distributed', False):
                setting['loader']['sampler'] = DistributedSampler(self._data)
                setting['loader']['shuffle'] = None
                # pin-memory currently broken for distributed
                setting['loader']['pin_memory'] = False
            self._sampler = setting['loader'].get('sampler', None)
            self._loader = torch.utils.data.DataLoader(
                self._data, **setting['loader'])
        return self._loader 
Example #11
Source File: dataloader.py    From pykaldi2 with MIT License
def __init__(self, dataset, batch_size, distributed=False, num_workers=0, timeout=1000):
 
        if not distributed: 
            super(ChunkDataloader, self).__init__(dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers,
                                              collate_fn=self.collate_fn)
        else:
            import horovod.torch as hvd
            sampler = DistributedSampler(dataset, num_replicas=hvd.size(), rank=hvd.rank())
            super(ChunkDataloader, self).__init__(dataset,
                                           batch_size=batch_size,
                                           sampler=sampler,
                                           num_workers=num_workers,
                                           collate_fn=self.collate_fn,
                                           drop_last=False,
                                           timeout=timeout) 
Example #12
Source File: dataloader.py    From pykaldi2 with MIT License
def __init__(self, dataset, batch_size, num_workers=0, distributed=False, test_only=False, timeout=1000):
        
        self.test_only = test_only
 
        # now decide on a sampler
        #base_sampler = torch.utils.data.SequentialSampler(self.dataset)
        base_sampler = torch.utils.data.RandomSampler(dataset)
        
        if not distributed:
            sampler = torch.utils.data.BatchSampler(base_sampler, batch_size, False)
            super(SeqDataloader, self).__init__(dataset,
                                           batch_sampler=sampler,
                                           num_workers=num_workers,
                                           collate_fn=self.collate_fn)
        else:
            import horovod.torch as hvd
            sampler = DistributedSampler(dataset, num_replicas=hvd.size(), rank=hvd.rank())
            super(SeqDataloader, self).__init__(dataset,
                                           batch_size=batch_size, 
                                           sampler=sampler, 
                                           num_workers=num_workers, 
                                           collate_fn=self.collate_fn, 
                                           drop_last=False,
                                           timeout=timeout) 
Example #13
Source File: classy_dataset.py    From ClassyVision with MIT License
def _get_sampler(self, epoch: int):
        """
        Return a :class:`torch.utils.data.sampler.Sampler` to sample the data.

        This is used to distribute the data across the replicas. If shuffling
        is enabled, every epoch will have a different shuffle.

        Args:
            epoch: The epoch being fetched.

        Returns:
            A sampler which tells the data loader which sample to load next.
        """
        world_size = get_world_size()
        rank = get_rank()
        sampler = DistributedSampler(
            self, num_replicas=world_size, rank=rank, shuffle=self.shuffle
        )
        sampler.set_epoch(epoch)
        return sampler 
Example #14
Source File: train_ppg2mel.py    From fac-via-ppg with Apache License 2.0
def prepare_dataloaders(hparams):
    # Get data, data loaders and collate function ready
    trainset = PPGMelLoader(hparams.training_files, hparams)
    hparams.load_feats_from_disk = False
    hparams.is_cache_feats = False
    hparams.feats_cache_path = ''
    valset = PPGMelLoader(hparams.validation_files, hparams)

    collate_fn = ppg_acoustics_collate

    train_sampler = DistributedSampler(trainset) \
        if hparams.distributed_run else None

    # shuffle is mutually exclusive with a sampler; only shuffle when not distributed
    train_loader = DataLoader(trainset, num_workers=1, shuffle=train_sampler is None,
                              sampler=train_sampler,
                              batch_size=hparams.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)
    return train_loader, valset, collate_fn 
Example #15
Source File: data_loading.py    From pytorch-lightning with Apache License 2.0
def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:

        # don't do anything if it's not a dataloader
        is_dataloader = isinstance(dataloader, DataLoader)
        # don't manipulate iterable datasets
        is_iterable_ds = _has_iterable_dataset(dataloader)

        if not is_dataloader or is_iterable_ds:
            return dataloader
        need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)

        if self.replace_sampler_ddp and need_dist_sampler:
            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
                    'You seem to have configured a sampler in your DataLoader. This will be replaced '
                    ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
                    ' distributed training. Either remove the sampler from your DataLoader or set'
                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.')

            # replace with distributed sampler
            sampler = self._get_distributed_sampler(dataloader)
            dataloader = self.replace_sampler(dataloader, sampler)

        return dataloader 
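The error message above points at the escape hatch: in the Lightning versions these examples come from, the Trainer exposes a replace_sampler_ddp flag, and setting it to False keeps a user-supplied sampler. A minimal sketch of building such a DataLoader yourself, assuming that flag and a plain map-style dataset (the dataset, world size, and rank below are placeholders):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Placeholders; in a real job these come from the launch environment.
dataset = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))
world_size, global_rank = 2, 0

sampler = DistributedSampler(dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler, shuffle=False)

# Passing this loader to a Trainer constructed with replace_sampler_ddp=False
# keeps this sampler instead of swapping in a new DistributedSampler.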
Example #16
Source File: __init__.py    From ignite with BSD 3-Clause "New" or "Revised" License
def setup_sampler(sampler_type, num_iters, batch_size):
    if sampler_type is None:
        return None, batch_size

    if sampler_type == "weighted":
        from torch.utils.data.sampler import WeightedRandomSampler

        w = torch.ones(num_iters * batch_size, dtype=torch.float)
        for i in range(num_iters):
            w[batch_size * i : batch_size * (i + 1)] += i * 1.0
        return WeightedRandomSampler(w, num_samples=num_iters * batch_size, replacement=True), batch_size

    if sampler_type == "distributed":
        from torch.utils.data.distributed import DistributedSampler
        import torch.distributed as dist

        num_replicas = 1
        rank = 0
        if dist.is_available() and dist.is_initialized():
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()

        dataset = torch.zeros(num_iters * batch_size)
        return DistributedSampler(dataset, num_replicas=num_replicas, rank=rank), batch_size // num_replicas 
Example #17
Source File: test_common.py    From ignite with BSD 3-Clause "New" or "Revised" License
def test_asserts_setup_common_training_handlers():
    trainer = Engine(lambda e, b: None)

    with pytest.raises(
        ValueError,
        match=r"If to_save argument is provided then output_path or save_handler arguments should be also defined",
    ):
        setup_common_training_handlers(trainer, to_save={})

    with pytest.raises(ValueError, match=r"Arguments output_path and save_handler are mutually exclusive"):
        setup_common_training_handlers(trainer, to_save={}, output_path="abc", save_handler=lambda c, f, m: None)

    with pytest.warns(UserWarning, match=r"Argument train_sampler is a distributed sampler"):
        train_sampler = MagicMock(spec=DistributedSampler)
        setup_common_training_handlers(trainer, train_sampler=train_sampler)

    with pytest.warns(UserWarning, match=r"Argument device is unused and deprecated"):
        setup_common_training_handlers(trainer, device="cpu") 
Example #18
Source File: run_extract_span.py    From SpanABSA with Apache License 2.0
def read_eval_data(args, tokenizer, logger):
    eval_path = os.path.join(args.data_dir, args.predict_file)
    eval_set = read_absa_data(eval_path)
    eval_examples = convert_absa_data(dataset=eval_set, verbose_logging=args.verbose_logging)

    eval_features = convert_examples_to_features(eval_examples, tokenizer, args.max_seq_length,
                                                 args.verbose_logging, logger)

    logger.info("Num orig examples = %d", len(eval_examples))
    logger.info("Num split features = %d", len(eval_features))
    logger.info("Batch size = %d", args.predict_batch_size)
    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
    if args.local_rank == -1:
        eval_sampler = SequentialSampler(eval_data)
    else:
        eval_sampler = DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)
    return eval_examples, eval_features, eval_dataloader 
Example #19
Source File: train.py    From nonparaSeq2seqVC_code with MIT License
def prepare_dataloaders(hparams):
    # Get data, data loaders and collate function ready
    #pids = [id2sp[hparams.speaker_A], id2sp[hparams.speaker_B]]
    trainset = TextMelIDLoader(hparams.training_list, hparams.mel_mean_std, 
        hparams.speaker_A, hparams.speaker_B, pids=None)
    valset = TextMelIDLoader(hparams.validation_list, hparams.mel_mean_std,
        hparams.speaker_A, hparams.speaker_B, pids=None)
    collate_fn = TextMelIDCollate(lcm(hparams.n_frames_per_step_encoder,
                                      hparams.n_frames_per_step_decoder))

    train_sampler = DistributedSampler(trainset) \
        if hparams.distributed_run else None

    # shuffle is mutually exclusive with a sampler; only shuffle when not distributed
    train_loader = DataLoader(trainset, num_workers=1, shuffle=train_sampler is None,
                              sampler=train_sampler,
                              batch_size=hparams.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)
    return train_loader, valset, collate_fn 
Example #20
Source File: runners.py    From bert_on_stilts with Apache License 2.0
def get_train_dataloader(self, train_examples, verbose=True):
        train_features = convert_examples_to_features(
            train_examples, self.label_map, self.rparams.max_seq_length, self.tokenizer,
            verbose=verbose,
        )
        train_data, train_tokens = convert_to_dataset(
            train_features, label_mode=get_label_mode(self.label_map),
        )
        if self.rparams.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(
            train_data, sampler=train_sampler, batch_size=self.rparams.train_batch_size,
        )
        return HybridLoader(train_dataloader, train_tokens) 
Example #21
Source File: runners.py    From bert_on_stilts with Apache License 2.0
def get_train_dataloader(self, train_examples, verbose=True):
        train_features = convert_examples_to_features(
            examples=train_examples,
            max_seq_length=self.rparams.max_seq_length,
            tokenizer=self.tokenizer,
            select_prob=self.rparams.select_prob,
            verbose=verbose,
        )
        train_data, train_tokens = convert_to_dataset(train_features)
        if self.rparams.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(
            train_data, sampler=train_sampler, batch_size=self.rparams.train_batch_size,
        )
        return HybridLoader(train_dataloader, train_tokens) 
Example #22
Source File: test_auto.py    From ignite with BSD 3-Clause "New" or "Revised" License
def _test_auto_dataloader(ws, nproc, batch_size, num_workers=1, sampler_name=None, dl_type=DataLoader):

    data = torch.rand(100, 3, 12, 12)

    if sampler_name is None:
        sampler = None
    elif sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    else:
        raise RuntimeError("Unknown sampler name: {}".format(sampler_name))

    # Test auto_dataloader
    assert idist.get_world_size() == ws
    dataloader = auto_dataloader(
        data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, shuffle=sampler is None
    )

    assert isinstance(dataloader, dl_type)
    if hasattr(dataloader, "_loader"):
        dataloader = dataloader._loader
    if ws < batch_size:
        assert dataloader.batch_size == batch_size // ws
    else:
        assert dataloader.batch_size == batch_size
    if ws <= num_workers:
        assert dataloader.num_workers == (num_workers + nproc - 1) // nproc
    else:
        assert dataloader.num_workers == num_workers

    if ws < 2:
        sampler_type = RandomSampler if sampler is None else type(sampler)
        assert isinstance(dataloader.sampler, sampler_type)
    else:
        sampler_type = DistributedSampler if sampler is None else DistributedProxySampler
        assert isinstance(dataloader.sampler, sampler_type)
    if isinstance(dataloader, DataLoader):
        assert dataloader.pin_memory == ("cuda" in idist.device().type) 
Example #23
Source File: run_triviaqa_wiki_full_e2e.py    From RE3QA with Apache License 2.0
def build_eval_data(args, eval_examples, eval_features, filtered_eval_features, filtered_rank_logits, logger):
    predict_batch_size_for_rank = 2 * args.predict_batch_size

    logger.info("Num orig examples = %d", len(eval_examples))
    logger.info("Num split features = %d", len(eval_features))
    logger.info("Num split filtered features = %d", len(filtered_eval_features))
    logger.info("Batch size for ranker = %d", predict_batch_size_for_rank)
    logger.info("Batch size for reader = %d", args.predict_batch_size)

    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    eval_rank_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
    if args.local_rank == -1:
        eval_rank_sampler = SequentialSampler(eval_rank_data)
    else:
        eval_rank_sampler = DistributedSampler(eval_rank_data)
    eval_rank_dataloader = DataLoader(eval_rank_data, sampler=eval_rank_sampler, batch_size=predict_batch_size_for_rank)

    all_input_ids = torch.tensor([f.input_ids for f in filtered_eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in filtered_eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in filtered_eval_features], dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    eval_read_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
    if args.local_rank == -1:
        eval_read_sampler = SequentialSampler(eval_read_data)
    else:
        eval_read_sampler = DistributedSampler(eval_read_data)
    eval_read_dataloader = DataLoader(eval_read_data, sampler=eval_read_sampler, batch_size=args.predict_batch_size)
    return eval_examples, eval_features, filtered_eval_features, filtered_rank_logits, eval_rank_dataloader, \
           eval_read_dataloader 
Example #24
Source File: run_swag.py    From KagNet with MIT License
def get_train_dataloader(train_features, args):
    all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
    all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
    all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
    all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
    return train_dataloader 
Example #25
Source File: run_joint_span.py    From SpanABSA with Apache License 2.0
def read_train_data(args, tokenizer, logger):
    train_path = os.path.join(args.data_dir, args.train_file)
    train_set = read_absa_data(train_path)
    train_examples = convert_absa_data(dataset=train_set, verbose_logging=args.verbose_logging)
    train_features = convert_examples_to_features(train_examples, tokenizer, args.max_seq_length,
                                                  args.verbose_logging, logger)

    num_train_steps = int(
        len(train_features) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
    logger.info("Num orig examples = %d", len(train_examples))
    logger.info("Num split features = %d", len(train_features))
    logger.info("Batch size = %d", args.train_batch_size)
    logger.info("Num steps = %d", num_train_steps)
    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
    all_start_positions = torch.tensor([f.start_positions for f in train_features], dtype=torch.long)
    all_end_positions = torch.tensor([f.end_positions for f in train_features], dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)

    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions, all_example_index)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
    return train_examples, train_features, train_dataloader, num_train_steps 
Example #26
Source File: run_sequence_level_classification.py    From ZEN with Apache License 2.0
def evaluate(args, model, tokenizer, ngram_dict, processor, label_list):
    eval_dataset = load_examples(args, tokenizer, ngram_dict, processor, label_list, mode="test")
    # Run prediction for full data
    if args.local_rank == -1:
        eval_sampler = SequentialSampler(eval_dataset)
    else:
        eval_sampler = DistributedSampler(eval_dataset)  # Note that this sampler samples randomly
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Eval!
    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    model.eval()
    preds = []
    out_label_ids = None

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids, input_ngram_ids, ngram_position_matrix, \
        ngram_lengths, ngram_seg_ids, ngram_masks = batch

        with torch.no_grad():
            logits = model(input_ids=input_ids,
                           input_ngram_ids=input_ngram_ids,
                           ngram_position_matrix=ngram_position_matrix,
                           labels=None, head_mask=None)

        if len(preds) == 0:
            preds.append(logits.detach().cpu().numpy())
            out_label_ids = label_ids.detach().cpu().numpy()
        else:
            preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)
            out_label_ids = np.append(out_label_ids, label_ids.detach().cpu().numpy(), axis=0)

    preds = np.argmax(preds[0], axis=1)
    return compute_metrics(args.task_name, preds, out_label_ids) 
Example #27
Source File: train.py    From pysot with Apache License 2.0
def build_data_loader():
    logger.info("build train dataset")
    # train_dataset
    train_dataset = TrkDataset()
    logger.info("build dataset done")

    train_sampler = None
    if get_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset)
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg.TRAIN.BATCH_SIZE,
                              num_workers=cfg.TRAIN.NUM_WORKERS,
                              pin_memory=True,
                              sampler=train_sampler)
    return train_loader 
Example #28
Source File: run_csqa_bert.py    From KagNet with MIT License
def get_train_dataloader(train_features, args):
    all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
    all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
    all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
    all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
    return train_dataloader 
Example #29
Source File: extract_csqa_bert.py    From KagNet with MIT License
def get_train_dataloader(train_features, args):
    all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
    all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
    all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
    all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
    return train_dataloader 
Example #30
Source File: base_task.py    From Doc2EDAG with MIT License
def prepare_dist_data_loader(self, dataset, batch_size, epoch=0):
        # prepare distributed data loader
        data_sampler = DistributedSampler(dataset)
        data_sampler.set_epoch(epoch)

        if self.custom_collate_fn is None:
            dataloader = DataLoader(dataset,
                                    batch_size=batch_size,
                                    sampler=data_sampler)
        else:
            dataloader = DataLoader(dataset,
                                    batch_size=batch_size,
                                    sampler=data_sampler,
                                    collate_fn=self.custom_collate_fn)
        return dataloader