Python torch.distributed Examples

The following are 30 code examples of the torch.distributed package. The original project, author, source file, and license are listed above each example. You may also want to check out all available functions and classes of the torch module.
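Before the examples, a minimal sketch (not taken from any of the projects below) of the typical torch.distributed workflow: initialize a process group, run one collective, and tear it down. The backend, rendezvous address, and world size here are illustrative; in real training a launcher such as torchrun supplies them.

import torch
import torch.distributed as dist

def main():
    dist.init_process_group(
        backend="gloo",                       # CPU-friendly backend
        init_method="tcp://127.0.0.1:29500",  # rendezvous address (example value)
        world_size=1,                         # single process, for illustration only
        rank=0,
    )
    t = torch.ones(3)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)  # element-wise sum across all ranks
    print(f"rank {dist.get_rank()} / world_size {dist.get_world_size()}: {t}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()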
Example #1
Source Project: Doc2EDAG   Author: dolphin-zs   File: base_task.py   License: MIT License
def _init_device(self):
        self.logging('='*20 + 'Init Device' + '='*20)

        # set device
        if self.setting.local_rank == -1 or self.setting.no_cuda:
            self.device = torch.device("cuda" if torch.cuda.is_available() and not self.setting.no_cuda else "cpu")
            self.n_gpu = torch.cuda.device_count()
        else:
            self.device = torch.device("cuda", self.setting.local_rank)
            self.n_gpu = 1
            if self.setting.fp16:
                self.logging("16-bits training currently not supported in distributed training")
                self.setting.fp16 = False  # (see https://github.com/pytorch/pytorch/pull/13496)
        self.logging("device {} n_gpu {} distributed training {}".format(
            self.device, self.n_gpu, self.in_distributed_mode()
        )) 
Example #2
Source Project: Doc2EDAG   Author: dolphin-zs   File: base_task.py   License: MIT License
def _decorate_model(self, parallel_decorate=True):
        self.logging('='*20 + 'Decorate Model' + '='*20)

        if self.setting.fp16:
            self.model.half()

        self.model.to(self.device)
        self.logging('Set model device to {}'.format(str(self.device)))

        if parallel_decorate:
            if self.in_distributed_mode():
                self.model = para.DistributedDataParallel(self.model,
                                                          device_ids=[self.setting.local_rank],
                                                          output_device=self.setting.local_rank)
                self.logging('Wrap distributed data parallel')
                # self.logging('In Distributed Mode, but do not use DistributedDataParallel Wrapper')
            elif self.n_gpu > 1:
                self.model = para.DataParallel(self.model)
                self.logging('Wrap data parallel')
        else:
            self.logging('Do not wrap parallel layers') 
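Here `para` presumably refers to `torch.nn.parallel`. A condensed sketch of the same decision on a bare module, assuming the process group is already initialized and `local_rank` names this process's GPU (hypothetical helper, not part of the project):

import torch
import torch.nn as nn
import torch.nn.parallel as para

def wrap_model(model: nn.Module, local_rank: int, distributed: bool) -> nn.Module:
    if distributed:
        # one process per GPU: move the module to this rank's device, then wrap
        model = model.to(torch.device("cuda", local_rank))
        return para.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank
        )
    if torch.cuda.device_count() > 1:
        # single process driving several GPUs: replicate with DataParallel
        return para.DataParallel(model.cuda())
    return model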
Example #3
Source Project: Res2Net-maskrcnn   Author: Res2Net   File: distributed.py   License: MIT License
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle 
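This constructor mirrors `torch.utils.data.distributed.DistributedSampler`. For context, the companion `__iter__` in the upstream sampler shuffles deterministically by epoch, pads the index list to `total_size`, and hands each rank a strided slice; a paraphrased sketch (not the verbatim project code):

def __iter__(self):
    if self.shuffle:
        # deterministic shuffle keyed on the epoch (updated via set_epoch)
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()
    else:
        indices = list(range(len(self.dataset)))

    # pad so that every rank draws exactly num_samples indices
    indices += indices[: (self.total_size - len(indices))]
    assert len(indices) == self.total_size

    # each rank keeps every num_replicas-th index, starting at its own rank
    indices = indices[self.rank:self.total_size:self.num_replicas]
    assert len(indices) == self.num_samples
    return iter(indices)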
Example #4
Source Project: Parsing-R-CNN   Author: soeaver   File: repeat_factor.py   License: MIT License
def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = self._get_epoch_indices(g)
            randperm = torch.randperm(len(indices), generator=g).tolist()
            indices = indices[randperm]
        else:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = self._get_epoch_indices(g)
            # indices = torch.arange(len(self.dataset)).tolist()

        # when balancing, len(indices) may differ from the dataset image count
        self.total_size = len(indices)
        logging_rank('balance sample total_size: {}'.format(self.total_size), distributed=1, local_rank=self.rank)
        # subsample
        self.num_samples = int(len(indices) / self.num_replicas)
        offset = self.num_samples * self.rank
        indices = indices[offset: offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices) 
Example #5
Source Project: Parsing-R-CNN   Author: soeaver   File: distributed.py   License: MIT License
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle 
Example #6
Source Project: decaNLP   Author: salesforce   File: iterator.py   License: BSD 3-Clause "New" or "Revised" License
def init_epoch(self):
        """Set up the batch generator for a new epoch."""
        if not self.distributed:
            if self._restored_from_state:
                self.random_shuffler.random_state = self._random_state_this_epoch
            else:
                self._random_state_this_epoch = self.random_shuffler.random_state

        self.create_batches()

        if not self.distributed:
            if self._restored_from_state:
                self._restored_from_state = False
            else:
                self._iterations_this_epoch = 0
        else:
            self._iterations_this_epoch = 0


        if not self.repeat:
            self.iterations = 0
        self.epoch += 1
        if self.distributed:
            self.random_shuffler.set_epoch(self.epoch) 
Example #7
Source Project: R2CNN.pytorch   Author: Xiangyu-CAS   File: distributed.py   License: MIT License
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle 
Example #8
Source Project: CornerNet-Lite-Pytorch   Author: DataXujing   File: trainmyData.py   License: BSD 3-Clause "New" or "Revised" License
def parse_args():
    parser = argparse.ArgumentParser(description="Training Script")
    parser.add_argument("cfg_file", help="config file", type=str)
    parser.add_argument("--iter", dest="start_iter",
                        help="train at iteration i",
                        default=0, type=int)
    parser.add_argument("--workers", default=4, type=int)
    parser.add_argument("--initialize", action="store_true")

    parser.add_argument("--distributed", action="store_true")
    parser.add_argument("--world-size", default=-1, type=int,
                        help="number of nodes of distributed training")
    parser.add_argument("--rank", default=0, type=int,
                        help="node rank for distributed training")
    parser.add_argument("--dist-url", default=None, type=str,
                        help="url used to set up distributed training")
    parser.add_argument("--dist-backend", default="nccl", type=str)

    args = parser.parse_args()
    return args 
Example #9
Source Project: CornerNet-Lite-Pytorch   Author: DataXujing   File: train.py   License: BSD 3-Clause "New" or "Revised" License
def parse_args():
    parser = argparse.ArgumentParser(description="Training Script")
    parser.add_argument("cfg_file", help="config file", type=str) #训练用的配置文件
    parser.add_argument("--iter", dest="start_iter",
                        help="train at iteration i",
                        default=0, type=int)  #指定训练从第i次迭代开始
    parser.add_argument("--workers", default=4, type=int)
    parser.add_argument("--initialize", action="store_true")

    parser.add_argument("--distributed", action="store_true")  # 分布式训练
    parser.add_argument("--world-size", default=-1, type=int,
                        help="number of nodes of distributed training")  # 分布式节点的数量
    parser.add_argument("--rank", default=0, type=int,
                        help="node rank for distributed training")  # 分布式训练节点的等级
    parser.add_argument("--dist-url", default=None, type=str,
                        help="url used to set up distributed training")
    parser.add_argument("--dist-backend", default="nccl", type=str)

    args = parser.parse_args()
    return args 
Example #10
Source Project: LEDNet   Author: AceCoooool   File: sampler.py   License: MIT License
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle 
Example #11
Source Project: crosentgec   Author: nusnlp   File: distributed_utils.py   License: GNU General Public License v3.0
def distributed_init(args):
    if args.distributed_world_size == 1:
        raise ValueError('Cannot initialize distributed with distributed_world_size=1')

    print('| distributed init (rank {}): {}'.format(
        args.distributed_rank, args.distributed_init_method), flush=True)
    if args.distributed_init_method.startswith('tcp://'):
        torch.distributed.init_process_group(
            backend=args.distributed_backend, init_method=args.distributed_init_method,
            world_size=args.distributed_world_size, rank=args.distributed_rank)
    else:
        torch.distributed.init_process_group(
            backend=args.distributed_backend, init_method=args.distributed_init_method,
            world_size=args.distributed_world_size)

    args.distributed_rank = torch.distributed.get_rank()
    if not is_master(args):
        suppress_output()

    return args.distributed_rank 
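For a TCP-style `init_method` like the one above, every process must call the init routine with its own rank. One way to get there on a single machine is `torch.multiprocessing.spawn`; the glue code below is illustrative only and not part of crosentgec:

import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    # one process per rank; each joins the same TCP rendezvous
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:23456",
        world_size=world_size,
        rank=rank,
    )
    # ... build model / optimizer and train ...
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)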
Example #12
Source Project: SegmenTron   Author: LikeLy-Journey   File: distributed.py   License: Apache License 2.0
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle 
Example #13
Source Project: mars   Author: mars-project   File: sampler.py   License: Apache License 2.0
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        import torch.distributed as dist

        super().__init__(dataset)
        if num_replicas is None:  # pragma: no cover
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:  # pragma: no cover
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle 
Example #14
Source Project: DenseMatchingBenchmark   Author: DeepMotionAIResearch   File: test.py   License: MIT License
def parse_args():
    parser = argparse.ArgumentParser(description='Test dense matching benchmark')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--checkpoint', help='checkpoint file')
    parser.add_argument('--out_dir', help='output result directory')
    parser.add_argument('--show', type=str, default='False', help='show results in images')
    parser.add_argument('--validate', action='store_true', help='whether to evaluate the result')
    parser.add_argument('--gpus', type=int, default=1,
        help='number of gpus to use (only applicable to non-distributed training)')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='pytorch',
        help='job launcher'
    )
    parser.add_argument('--local_rank', type=int, default=0)

    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args 
Example #15
Source Project: sagemaker-pytorch-training-toolkit   Author: aws   File: distributed_operations.py   License: Apache License 2.0
def run(backend, rank, rows, columns, num_gpus):
    # https://pytorch.org/docs/master/distributed.html
    if backend == 'gloo':
        print('Run operations supported by \'gloo\' backend.')
        _broadcast(rank, rows, columns)
        _all_reduce(rank, rows, columns)
        _barrier(rank)

        # this operation is supported only on CPU
        if num_gpus == 0:
            _send_recv(rank, rows, columns)
    elif backend == 'nccl':
        print('Run operations supported by \'nccl\' backend.')
        # Note: nccl does not support gather or scatter either:
        # https://github.com/pytorch/pytorch/blob/v0.4.0/torch/lib/THD/base/data_channels/DataChannelNccl.cpp
        _broadcast(rank, rows, columns)
        _all_reduce(rank, rows, columns)
        _reduce(rank, rows, columns)
        _all_gather(rank, rows, columns) 
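The `_broadcast`, `_all_reduce`, and related helpers are defined elsewhere in that file; a plausible minimal version of one of them, written here only for illustration:

import torch
import torch.distributed as dist

def _all_reduce(rank, rows, columns):
    # every rank contributes a (rows x columns) tensor filled with its own rank id
    tensor = torch.full((rows, columns), float(rank))
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    # afterwards every rank holds the element-wise sum 0 + 1 + ... + (world_size - 1)
    print('rank {} all_reduce result:\n{}'.format(rank, tensor))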
Example #16
Source Project: Clothing-Detection   Author: simaiden   File: distributed.py   License: GNU General Public License v3.0
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle 
Example #17
Source Project: kaggle-kuzushiji-2019   Author: lopuhin   File: utils.py   License: MIT License
def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0) 
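The final call, `setup_for_distributed(args.rank == 0)`, follows the pattern used in the torchvision detection reference scripts, where it silences `print` on non-master ranks; a sketch of that idea (not the project's verbatim code):

import builtins

def setup_for_distributed(is_master):
    # replace print() so only the master rank logs, unless force=True is passed
    builtin_print = builtins.print

    def wrapped_print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    builtins.print = wrapped_print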
Example #18
Source Project: awesome-semantic-segmentation-pytorch   Author: Tramac   File: distributed.py   License: Apache License 2.0
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle 
Example #19
Source Project: sagemaker-python-sdk   Author: aws   File: mnist.py   License: Apache License 2.0
def _get_train_data_loader(training_dir, is_distributed, batch_size, **kwargs):
    logger.info("Get train data loader")
    dataset = datasets.MNIST(
        training_dir,
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
        download=False,  # True sets a dependency on an external site for our canaries.
    )
    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train_sampler is None,
        sampler=train_sampler,
        **kwargs
    )
    return train_sampler, train_loader 
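When the sampler is present, shuffling is delegated to it (note `shuffle=train_sampler is None`), and `set_epoch` should be called once per epoch so each epoch reshuffles differently. A sketch of the surrounding loop, with an illustrative epoch count:

def train(train_sampler, train_loader, num_epochs=3):
    # train_sampler is None in the non-distributed case
    for epoch in range(num_epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)  # vary the shuffle seed each epoch
        for data, target in train_loader:
            pass  # forward / backward / optimizer step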
Example #20
Source Project: sparktorch   Author: dmmiller612   File: distributed.py   License: MIT License
def process_generic_model(params: List, iters: int, has_early_stop: bool = False):
    """
    Runs a mock training with zero grads. This is due to a bug where the connection gets reset with custom new groups.
    :param params: The params of the model
    :param iters: Iterations.
    :param has_early_stop: Whether the training loop performs early-stopping reductions.
    """
    # Hopefully this function can go away in newer versions.
    for i in range(iters):
        for p in params:
            z = torch.zeros(p)
            dist.all_reduce(z, op=torch.distributed.ReduceOp.SUM)

        if has_early_stop:
            dist.all_reduce(torch.tensor(0.0), op=torch.distributed.ReduceOp.SUM)
            zeros = torch.zeros(1)
            dist.all_reduce(zeros, op=torch.distributed.ReduceOp.SUM)
            if zeros.item() > 0:
                break 
Example #21
Source Project: DetNAS   Author: megvii-model   File: distributed.py   License: MIT License
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle 
Example #22
Source Project: mmdetection   Author: open-mmlab   File: base.py   License: Apache License 2.0
def _parse_losses(self, losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contains
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
                which may be a weighted sum of all losses, log_vars contains
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars 
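The loop at the end averages each logged scalar over all ranks (divide locally, then all-reduce-sum). In isolation the pattern looks like this (a sketch, assuming the default process group is initialized):

import torch
import torch.distributed as dist

def average_across_ranks(value: torch.Tensor) -> float:
    if dist.is_available() and dist.is_initialized():
        value = value.detach().clone()
        # divide locally, then sum over ranks: the result is the mean
        dist.all_reduce(value.div_(dist.get_world_size()))
        return value.item()
    return value.item()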
Example #23
Source Project: Doc2EDAG   Author: dolphin-zs   File: base_task.py   License: MIT License
def prepare_dist_data_loader(self, dataset, batch_size, epoch=0):
        # prepare distributed data loader
        data_sampler = DistributedSampler(dataset)
        data_sampler.set_epoch(epoch)

        if self.custom_collate_fn is None:
            dataloader = DataLoader(dataset,
                                    batch_size=batch_size,
                                    sampler=data_sampler)
        else:
            dataloader = DataLoader(dataset,
                                    batch_size=batch_size,
                                    sampler=data_sampler,
                                    collate_fn=self.custom_collate_fn)
        return dataloader 
Example #24
Source Project: Res2Net-maskrcnn   Author: Res2Net   File: comm.py   License: MIT License
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier() 
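A typical use of `synchronize()` is to fence work that only rank 0 performs, such as writing a checkpoint. The helper below is illustrative only and assumes `synchronize` is importable from the module above:

import torch

def save_checkpoint(model, path, rank):
    synchronize()                              # wait until every rank reaches this point
    if rank == 0:
        torch.save(model.state_dict(), path)   # only the master rank writes the file
    synchronize()                              # keep other ranks from racing ahead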
Example #25
Source Project: Parsing-R-CNN   Author: soeaver   File: repeat_factor.py   License: MIT License
def _get_repeat_factors(self, dataset_dicts):
        """
        Compute (fractional) per-image repeat factors.

        Args:
            dataset_dicts (list) : per-image annotations

        Returns:
            torch.Tensor: the i-th element is the repeat factor for the dataset_dicts image
                at index i.
        """
        # 1. For each category c, compute the fraction of images that contain it: f(c)
        category_freq = defaultdict(int)
        for dataset_dict in dataset_dicts:  # For each image (without repeats)
            cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
            for cat_id in cat_ids:
                category_freq[cat_id] += 1
        num_images = len(dataset_dicts)
        for k, v in category_freq.items():
            category_freq[k] = v / num_images
        # 2. For each category c, compute the category-level repeat factor:
        #    lvis paper: r(c) = max(1, sqrt(t / f(c)))
        #    common: r(c) = max(min_times, min(max_times, pow(t / f(c), alpha)))
        category_rep = {
            cat_id: max(self.config.MIN_REPEAT_TIMES, min(self.config.MAX_REPEAT_TIMES, math.pow(
                (self.config.REPEAT_THRESHOLD / cat_freq), self.config.POW)))
            for cat_id, cat_freq in category_freq.items()
        }
        # 3. For each image I, compute the image-level repeat factor:
        #    r(I) = max_{c in I} r(c)
        rep_factors = []
        for dataset_dict in dataset_dicts:
            cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
            rep_factor = max({category_rep[cat_id] for cat_id in cat_ids})
            rep_factors.append(rep_factor)
        logging_rank('max(rep_factors): {} , min(rep_factors): {} , len(rep_factors): {}'.
                     format(max(rep_factors), min(rep_factors), len(rep_factors)),
                     distributed=1, local_rank=self.rank)

        return torch.tensor(rep_factors, dtype=torch.float32) 
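A quick numeric illustration of the category-level factor from step 2, with made-up values: for `REPEAT_THRESHOLD = 0.001`, `POW = 0.5`, and a category present in 0.01% of images, the unclipped factor is `(0.001 / 0.0001) ** 0.5 = sqrt(10) ≈ 3.16`.

import math

t, alpha = 0.001, 0.5         # REPEAT_THRESHOLD and POW (example config values)
f_c = 0.0001                  # fraction of images containing category c
min_times, max_times = 1, 10  # illustrative MIN/MAX_REPEAT_TIMES bounds
r_c = max(min_times, min(max_times, math.pow(t / f_c, alpha)))
print(r_c)                    # ≈ 3.162: images with this rare category repeat ~3x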
Example #26
Source Project: Parsing-R-CNN   Author: soeaver   File: misc.py   License: MIT License
def logging_rank(sstr, distributed=True, local_rank=0):
    if distributed and local_rank == 0:
        logger.info(sstr)
    elif not distributed:
        logger.info(sstr)
    return 0 
Example #27
Source Project: Parsing-R-CNN   Author: soeaver   File: misc.py   License: MIT License
def reduce_sum(tensor):
    if get_num_gpus() <= 1:
        return tensor
    import torch.distributed as dist
    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor 
Example #28
Source Project: OpenChem   Author: Mariewelt   File: openchem_model.py   License: MIT License
def print_logs(world_size):
    if world_size == 1:
        return True
    elif torch.distributed.get_rank() == 0:
        return True
    else:
        return False 
Example #29
Source Project: decaNLP   Author: salesforce   File: iterator.py   License: BSD 3-Clause "New" or "Revised" License
def __init__(self, dataset, batch_size, sort_key=None, device=None,
                 batch_size_fn=None, train=True,
                 repeat=None, shuffle=None, sort=None, reverse=False,
                 sort_within_batch=None, distributed=False, num_replicas=None, rank=None):
        self.batch_size, self.train, self.dataset = batch_size, train, dataset
        self.batch_size_fn = batch_size_fn
        self.iterations = 0
        self.epoch = 0
        self.reverse = reverse
        self.repeat = train if repeat is None else repeat
        self.shuffle = train if shuffle is None else shuffle
        self.sort = not train if sort is None else sort
        if sort_within_batch is None:
            self.sort_within_batch = self.sort
        else:
            self.sort_within_batch = sort_within_batch
        if sort_key is None:
            self.sort_key = dataset.sort_key
        else:
            self.sort_key = sort_key
        self.device = device

        self.distributed = distributed
        if distributed:
            self.random_shuffler = DistributedShuffler(num_replicas=num_replicas, rank=rank)
        else:
            self.random_shuffler = RandomShuffler() 

        # For state loading/saving only
        self._iterations_this_epoch = 0
        self._random_state_this_epoch = None
        self._restored_from_state = False 
Example #30
Source Project: decaNLP   Author: salesforce   File: iterator.py   License: BSD 3-Clause "New" or "Revised" License
def data(self):
        """Return the examples in the dataset in order, sorted, or shuffled."""
        if self.sort:
            xs = sorted(self.dataset, key=self.sort_key, reverse=self.reverse)
            if self.distributed:
                xs = [xs[i] for i in self.random_shuffler.subsample(list(range(len(xs))))]
        elif self.shuffle:
            xs = self.random_shuffler(list(self.dataset))
        else:
            xs = self.dataset
        self.extra = self.random_shuffler.extra
        return xs