Python torch.utils.data.sampler.RandomSampler() Examples

The following are 30 code examples of torch.utils.data.sampler.RandomSampler(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.utils.data.sampler, or try the search function.
Example #1
Source File: test_data_samplers.py    From training with Apache License 2.0 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #2
Source File: test_data_samplers.py    From sampling-free with MIT License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #3
Source File: test_data_samplers.py    From HRNet-MaskRCNN-Benchmark with MIT License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #4
Source File: sampler.py    From deep-person-reid with MIT License 6 votes vote down vote up
def build_train_sampler(
    data_source, train_sampler, batch_size=32, num_instances=4, **kwargs
):
    """Builds a training sampler.

    Args:
        data_source (list): contains tuples of (img_path(s), pid, camid).
        train_sampler (str): sampler name; must be one of ``AVAI_SAMPLERS``.
            This argument is required (the signature defines no default).
        batch_size (int, optional): batch size. Default is 32. Only forwarded
            to ``RandomIdentitySampler``.
        num_instances (int, optional): number of instances per identity in a
            batch (when using ``RandomIdentitySampler``). Default is 4.
        **kwargs: accepted so callers can pass extra options; not used here.

    Returns:
        The sampler instance built for ``data_source``.

    Raises:
        AssertionError: if ``train_sampler`` is not listed in ``AVAI_SAMPLERS``.
        ValueError: if ``train_sampler`` is listed in ``AVAI_SAMPLERS`` but this
            function has no branch that knows how to build it.
    """
    assert train_sampler in AVAI_SAMPLERS, \
        'train_sampler must be one of {}, but got {}'.format(AVAI_SAMPLERS, train_sampler)

    if train_sampler == 'RandomIdentitySampler':
        return RandomIdentitySampler(data_source, batch_size, num_instances)

    if train_sampler == 'SequentialSampler':
        return SequentialSampler(data_source)

    if train_sampler == 'RandomSampler':
        return RandomSampler(data_source)

    # Reached only when AVAI_SAMPLERS names a sampler that has no branch above;
    # fail with a clear message instead of an UnboundLocalError at ``return``.
    raise ValueError(
        'train_sampler {!r} is listed in AVAI_SAMPLERS but is not handled by '
        'build_train_sampler'.format(train_sampler)
    )
Example #5
Source File: test_data_samplers.py    From maskrcnn-benchmark with MIT License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #6
Source File: test_data_samplers.py    From EmbedMask with MIT License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #7
Source File: test_data_samplers.py    From FreeAnchor with MIT License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #8
Source File: test_data_samplers.py    From retinamask with MIT License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #9
Source File: loader.py    From SlowFast with Apache License 2.0 6 votes vote down vote up
def shuffle_dataset(loader, cur_epoch):
    """
    Prepare the loader's sampler for shuffling in the given epoch.

    The sampler is taken from ``loader.batch_sampler.sampler`` when the batch
    sampler is a ``ShortCycleBatchSampler``, otherwise from ``loader.sampler``.
    Only a ``DistributedSampler`` is touched here (its epoch is set); a
    ``RandomSampler`` is left as is.

    Args:
        loader (loader): data loader to perform shuffle.
        cur_epoch (int): number of the current epoch.

    Raises:
        AssertionError: if the sampler is neither a ``RandomSampler`` nor a
            ``DistributedSampler``.
    """
    if isinstance(loader.batch_sampler, ShortCycleBatchSampler):
        sampler = loader.batch_sampler.sampler
    else:
        sampler = loader.sampler

    supported_types = (RandomSampler, DistributedSampler)
    assert isinstance(sampler, supported_types), \
        "Sampler type '{}' not supported".format(type(sampler))

    # Nothing to do for RandomSampler; DistributedSampler needs the epoch
    # passed in via set_epoch.
    if isinstance(sampler, DistributedSampler):
        sampler.set_epoch(cur_epoch)
Example #10
Source File: test_data_samplers.py    From R2CNN.pytorch with MIT License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #11
Source File: Dataloader.py    From SSD-Pytorch with Apache License 2.0 6 votes vote down vote up
def Our_Dataloader(dataset,batch_size,shuffle=True,num_workers=2,drop_last=True,max_iteration=100000000):
    """
    Build a nearly endless data loader that emits one batch per iteration.

    :param dataset:         the dataset; its ``is_train`` attribute is passed to ``BatchCollator``
    :param batch_size:      batch size
    :param max_iteration:   total number of iterations, 100 million by default; per the original
                            author's note, deciding the actual number of iterations while fetching
                            data is more flexible
    :param shuffle:         use a RandomSampler when true, otherwise a SequentialSampler
    :param num_workers:     forwarded to ``DataLoader``
    :param drop_last:       forwarded to ``BatchSampler_Our``
    :return:                the constructed ``DataLoader``
    """
    # Random sampler when shuffling, sequential sampler otherwise.
    sampler_cls = RandomSampler if shuffle else SequentialSampler
    index_sampler = sampler_cls(dataset)
    batches = BatchSampler_Our(
        sampler=index_sampler,
        batch_size=batch_size,
        max_iteration=max_iteration,
        drop_last=drop_last,
    )
    return DataLoader(
        dataset=dataset,
        batch_sampler=batches,
        num_workers=num_workers,
        collate_fn=BatchCollator(is_train=dataset.is_train),
    )
Example #12
Source File: test_data_samplers.py    From NAS-FCOS with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #13
Source File: test_data_samplers.py    From RRPN_pytorch with MIT License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #14
Source File: test_data_samplers.py    From DF-Traffic-Sign-Identification with MIT License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #15
Source File: test_data_samplers.py    From maskscoring_rcnn with MIT License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #16
Source File: train_para_encoder.py    From Multi-Step-Reasoning with Apache License 2.0 6 votes vote down vote up
def make_data_loader(args, corpus, train_time=False):
    """Build a DataLoader over a ``data.MultiCorpusDataset`` created from ``corpus``.

    Args:
        args: configuration object; this function reads ``word_dict``,
            ``feature_dict``, ``para_mode``, ``batch_size`` and ``data_workers``.
        corpus: corpus handed to ``data.MultiCorpusDataset``.
        train_time (bool): when true a ``RandomSampler`` is used, otherwise a
            ``SequentialSampler``; the flag is also forwarded to the dataset
            and to ``vector.batchify``.

    Returns:
        The constructed ``torch.utils.data.DataLoader``.
    """
    dataset = data.MultiCorpusDataset(
        args,
        corpus,
        args.word_dict,
        args.feature_dict,
        single_answer=False,
        para_mode=args.para_mode,
        train_time=train_time,
    )

    if train_time:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)

    loader_options = {
        "batch_size": args.batch_size,
        "sampler": sampler,
        "num_workers": args.data_workers,
        # batchify is called here and its return value is used as collate_fn.
        "collate_fn": vector.batchify(args, args.para_mode, train_time=train_time),
        "pin_memory": True,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)
Example #17
Source File: train.py    From Multi-Step-Reasoning with Apache License 2.0 6 votes vote down vote up
def make_data_loader(args, exs, train_time=False):
    """Build a DataLoader over a ``data.ReaderDataset`` created from ``exs``.

    Args:
        args: configuration object; this function reads ``word_dict``,
            ``feature_dict``, ``batch_size`` and ``test_batch_size``.
        exs: examples handed to ``data.ReaderDataset``.
        train_time (bool): when true the loader uses a ``RandomSampler`` and
            ``args.batch_size``; otherwise a ``SequentialSampler`` and
            ``args.test_batch_size``.

    Returns:
        The constructed ``torch.utils.data.DataLoader``.
    """
    dataset = data.ReaderDataset(
        args,
        exs,
        args.word_dict,
        args.feature_dict,
        single_answer=False,
        train_time=train_time,
    )

    if train_time:
        sampler = RandomSampler(dataset)
        batch_size = args.batch_size
    else:
        sampler = SequentialSampler(dataset)
        batch_size = args.test_batch_size

    loader_options = {
        "batch_size": batch_size,
        "sampler": sampler,
        "num_workers": 0,
        # The batchify function itself (not a call result) is the collate_fn.
        "collate_fn": vector.batchify,
        "pin_memory": True,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)
Example #18
Source File: test_data_samplers.py    From Res2Net-maskrcnn with MIT License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #19
Source File: test_data_samplers.py    From remote_sensing_object_detection_2019 with MIT License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #20
Source File: test_data_samplers.py    From DetNAS with MIT License 6 votes vote down vote up
def test_len(self):
        """Verify that ``len()`` of a GroupedBatchSampler equals the number of batches it yields.

        Two fresh batch samplers are built over the same RandomSampler: the
        first is iterated before its length is read (twice in a row), the
        second has its length read before iteration and once more afterwards.
        """
        items = list(range(10))
        # Each element is assigned to group 0 or 1 at random.
        groups = [random.randint(0, 1) for _ in items]
        base_sampler = RandomSampler(items)
        per_batch, drop = 3, True

        # Iterate first, then read the length two times.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        batches = list(grouped)
        for _ in range(2):
            self.assertEqual(len(batches), len(grouped))

        # Read the length up front, iterate, then read it again.
        grouped = GroupedBatchSampler(base_sampler, groups, per_batch, drop)
        length_before = len(grouped)
        batches = list(grouped)
        self.assertEqual(len(batches), length_before)
        self.assertEqual(len(batches), len(grouped))
Example #21
Source File: loader.py    From Dialog with MIT License 6 votes vote down vote up
def __iter__(self):
        """Yield padded ``(src, tgt)`` batches in the order produced by ``self.sampler``.

        Each index from ``self.sampler`` is looked up in
        ``self.sampler.data_source`` to get a ``(src, tgt)`` pair. Once
        ``self.batch_size`` pairs have been collected, both sides are padded
        with ``self.pad_id`` (batch-first) and yielded. A trailing group that is
        smaller than ``self.batch_size`` is not yielded.

        ``self.count`` is still incremented once per sample, but batch
        boundaries are decided by the number of samples collected in the
        current pass. The counter is never reset, so keying on it made a second
        pass start mid-batch and trip an assertion whenever the dataset size
        was not a multiple of the batch size.

        Yields:
            tuple: ``(src, tgt)`` tensors as returned by ``rnn.pad_sequence``.
        """
        src_list = list()
        tgt_list = list()
        # The sampler must expose ``data_source`` (as RandomSampler does).
        for i in self.sampler:
            self.count += 1
            src, tgt = self.sampler.data_source[i]
            src_list.append(src)
            tgt_list.append(tgt)
            if len(src_list) == self.batch_size:
                src = rnn.pad_sequence(src_list, batch_first=True, padding_value=self.pad_id)
                tgt = rnn.pad_sequence(tgt_list, batch_first=True, padding_value=self.pad_id)
                # Padded tensors are already built, so the buffers can be reused.
                src_list.clear()
                tgt_list.clear()
                yield src, tgt
Example #22
Source File: dataloader.py    From keyphrase-gan with MIT License 5 votes vote down vote up
def __init__(self, dataset, max_batch_example=5, max_batch_pair=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False):
        """Set up a loader that batches examples with a ``One2ManyBatchSampler``.

        Args:
            dataset: dataset exposing ``examples``, an iterable of mappings with
                a ``'trg'`` entry whose length is recorded per example.
            max_batch_example: forwarded to ``One2ManyBatchSampler``; also
                stored as ``self.max_example_number``.
            max_batch_pair: forwarded to ``One2ManyBatchSampler``; also stored
                as ``self.batch_size``.
            shuffle: use a ``RandomSampler`` when no sampler is supplied.
            sampler: optional index sampler; mutually exclusive with ``shuffle``.
            batch_sampler: optional ready-made batch sampler. When given it is
                used as is; previously it was validated and then silently
                replaced by a new ``One2ManyBatchSampler``.
            num_workers, collate_fn, pin_memory, drop_last: stored on the
                instance; ``drop_last`` is also forwarded to the batch sampler.

        Raises:
            ValueError: if ``batch_sampler`` is combined with
                ``max_batch_pair > 1``, ``shuffle``, ``sampler`` or
                ``drop_last``, or if ``sampler`` is combined with ``shuffle``.
        """
        self.dataset     = dataset
        # used for generating one2many batches
        self.num_trgs           = [len(e['trg']) for e in dataset.examples]
        self.batch_size         = max_batch_pair
        self.max_example_number = max_batch_example
        self.num_workers        = num_workers
        self.collate_fn         = collate_fn
        self.pin_memory         = pin_memory
        self.drop_last          = drop_last

        if batch_sampler is not None:
            if max_batch_pair > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)

            # Only build the default batch sampler when the caller did not
            # provide one; a supplied batch_sampler must not be overwritten.
            batch_sampler = One2ManyBatchSampler(sampler, self.num_trgs, max_batch_example=max_batch_example, max_batch_pair=max_batch_pair, drop_last=drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
Example #23
Source File: data_silo.py    From FARM with Apache License 2.0 5 votes vote down vote up
def __init__(self, origsilo, trainset, devset, testset):
        """Create train/dev/test loaders for the given datasets.

        ``tensor_names``, ``processor`` and ``batch_size`` are copied from
        ``origsilo``. The train loader samples randomly; the dev and test
        loaders sample sequentially.

        Args:
            origsilo: object providing ``tensor_names``, ``processor`` and
                ``batch_size``.
            trainset: dataset for the ``"train"`` loader.
            devset: dataset for the ``"dev"`` loader.
            testset: dataset for the ``"test"`` loader.
        """
        self.tensor_names = origsilo.tensor_names
        self.data = {"train": trainset, "dev": devset, "test": testset}
        self.processor = origsilo.processor
        self.batch_size = origsilo.batch_size

        # Options shared by all three loaders.
        shared = {"batch_size": self.batch_size, "tensor_names": self.tensor_names}

        # A DistributedSampler is deliberately not used for training here; the
        # original note says cross-validation makes no sense with huge data.
        train_sampler = RandomSampler(trainset)
        self.data_loader_train = NamedDataLoader(dataset=trainset, sampler=train_sampler, **shared)
        self.data_loader_dev = NamedDataLoader(dataset=devset, sampler=SequentialSampler(devset), **shared)
        self.data_loader_test = NamedDataLoader(dataset=testset, sampler=SequentialSampler(testset), **shared)

        self.loaders = dict(
            train=self.data_loader_train,
            dev=self.data_loader_dev,
            test=self.data_loader_test,
        )
Example #24
Source File: dataloader.py    From weakalign with MIT License 5 votes vote down vote up
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False):
        """Store the loader settings and resolve the sampler and batch sampler.

        When no ``batch_sampler`` is given, a ``BatchSampler`` is built around
        ``sampler``; if that is missing too, a ``RandomSampler`` (``shuffle``
        true) or ``SequentialSampler`` over ``dataset`` is created first.

        Raises:
            ValueError: if ``batch_sampler`` is combined with ``batch_size > 1``,
                ``shuffle``, ``sampler`` or ``drop_last``, or if ``sampler`` is
                combined with ``shuffle``.
        """
        self.dataset, self.batch_size = dataset, batch_size
        self.num_workers, self.collate_fn = num_workers, collate_fn
        self.pin_memory, self.drop_last = pin_memory, drop_last

        sampler_given = sampler is not None

        if batch_sampler is not None and (batch_size > 1 or shuffle or sampler_given or drop_last):
            raise ValueError('batch_sampler is mutually exclusive with '
                             'batch_size, shuffle, sampler, and drop_last')

        if sampler_given and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if not sampler_given:
                sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
Example #25
Source File: custom_data.py    From DCC with MIT License 5 votes vote down vote up
def __init__(self, pairs, shuffle=False, batch_size=1, drop_last=False):
        """Pick a sampler over ``pairs`` and remember the batching options.

        Args:
            pairs: data source handed to the sampler.
            shuffle: use a ``RandomSampler`` when true, otherwise a
                ``SequentialSampler``.
            batch_size: stored as ``self.batch_size``.
            drop_last: stored as ``self.drop_last``.
        """
        sampler_cls = RandomSampler if shuffle else SequentialSampler
        self.sampler = sampler_cls(pairs)
        self.batch_size, self.drop_last = batch_size, drop_last
Example #26
Source File: test_auto.py    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def _test_auto_dataloader(ws, nproc, batch_size, num_workers=1, sampler_name=None, dl_type=DataLoader):
    """Check the loader returned by ``auto_dataloader`` against the expected settings.

    Args:
        ws: world size expected from ``idist.get_world_size()``; also used to
            derive the expected per-loader batch size.
        nproc: divisor used to derive the expected number of workers.
        batch_size: batch size requested from ``auto_dataloader``.
        num_workers: worker count requested from ``auto_dataloader``.
        sampler_name: ``None`` for no sampler, or ``"WeightedRandomSampler"``.
        dl_type: type the returned loader is expected to be an instance of.

    Raises:
        RuntimeError: if ``sampler_name`` is not one of the supported values.
    """
    data = torch.rand(100, 3, 12, 12)

    if sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    elif sampler_name is None:
        sampler = None
    else:
        raise RuntimeError("Unknown sampler name: {}".format(sampler_name))
    no_sampler = sampler is None

    # Test auto_dataloader
    assert idist.get_world_size() == ws
    dataloader = auto_dataloader(
        data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, shuffle=no_sampler
    )
    assert isinstance(dataloader, dl_type)

    # Some loader types wrap the real loader in a ``_loader`` attribute.
    dataloader = getattr(dataloader, "_loader", dataloader)

    expected_batch_size = batch_size // ws if ws < batch_size else batch_size
    assert dataloader.batch_size == expected_batch_size

    expected_workers = (num_workers + nproc - 1) // nproc if ws <= num_workers else num_workers
    assert dataloader.num_workers == expected_workers

    if ws < 2:
        expected_sampler_type = RandomSampler if no_sampler else type(sampler)
    elif no_sampler:
        expected_sampler_type = DistributedSampler
    else:
        expected_sampler_type = DistributedProxySampler
    assert isinstance(dataloader.sampler, expected_sampler_type)

    if isinstance(dataloader, DataLoader):
        assert dataloader.pin_memory == ("cuda" in idist.device().type)
Example #27
Source File: test_common.py    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_assert_setup_common_training_handlers_wrong_train_sampler(distributed_context_single_node_gloo):
    """Expect a ``TypeError`` when a plain ``RandomSampler`` is passed as the train sampler."""
    from torch.utils.data.sampler import RandomSampler

    engine = Engine(lambda e, b: None)
    expected_message = r"Train sampler should be torch DistributedSampler"

    with pytest.raises(TypeError, match=expected_message):
        plain_sampler = RandomSampler([0, 1, 2, 3])
        setup_common_training_handlers(engine, plain_sampler)
Example #28
Source File: test_common.py    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_no_warning_with_train_sampler(recwarn):
    """Check that passing a ``RandomSampler`` as ``train_sampler`` records no warnings."""
    from torch.utils.data import RandomSampler

    plain_sampler = RandomSampler([0, 1, 2])
    engine = Engine(lambda e, b: None)
    setup_common_training_handlers(engine, train_sampler=plain_sampler)

    # On failure the popped warning is shown as the assertion message.
    assert len(recwarn) == 0, recwarn.pop()
Example #29
Source File: my_data_loader.py    From ps_pytorch with MIT License 5 votes vote down vote up
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False):
        """Store the loader settings, resolve the samplers and create the data iterator.

        When no ``batch_sampler`` is given, a ``BatchSampler`` is built around
        ``sampler``; if that is missing too, a ``RandomSampler`` (``shuffle``
        true) or ``SequentialSampler`` over ``dataset`` is created first.
        A ``DataLoaderIter`` over this loader is created last and kept as
        ``self.data_iterator``.

        Raises:
            ValueError: if ``batch_sampler`` is combined with ``batch_size > 1``,
                ``shuffle``, ``sampler`` or ``drop_last``, or if ``sampler`` is
                combined with ``shuffle``.
        """
        self.dataset, self.batch_size = dataset, batch_size
        self.num_workers, self.collate_fn = num_workers, collate_fn
        self.pin_memory, self.drop_last = pin_memory, drop_last

        sampler_given = sampler is not None

        if batch_sampler is not None and (batch_size > 1 or shuffle or sampler_given or drop_last):
            raise ValueError('batch_sampler is mutually exclusive with '
                             'batch_size, shuffle, sampler, and drop_last')

        if sampler_given and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if not sampler_given:
                sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        # Built after every attribute above is set, since it receives ``self``.
        self.data_iterator = DataLoaderIter(self)
Example #30
Source File: loader.py    From Dialog with MIT License 5 votes vote down vote up
def __init__(self, data: Dataset, pad_id: int):
        """Initialise the parent with a ``RandomSampler`` over ``data`` and reset the counter.

        Args:
            data: dataset wrapped in a ``RandomSampler``.
            pad_id: value stored as ``self.pad_id``.
        """
        shuffled_indices = RandomSampler(data)
        # NOTE(review): the trailing ``True`` is presumably the parent's
        # drop_last flag - confirm against the base class.
        super().__init__(shuffled_indices, Config.batch_size, True)
        self.count = 0
        self.pad_id = pad_id