import logging import os import random import subprocess import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp from mmcv.runner import get_dist_info def init_dist(launcher, backend='nccl', **kwargs): if mp.get_start_method(allow_none=True) is None: mp.set_start_method('spawn') if launcher == 'pytorch': _init_dist_pytorch(backend, **kwargs) elif launcher == 'mpi': _init_dist_mpi(backend, **kwargs) elif launcher == 'slurm': _init_dist_slurm(backend, **kwargs) else: raise ValueError('Invalid launcher type: {}'.format(launcher)) def _init_dist_pytorch(backend, **kwargs): # TODO: use local_rank instead of rank % num_gpus rank = int(os.environ['RANK']) num_gpus = torch.cuda.device_count() torch.cuda.set_device(rank % num_gpus) dist.init_process_group(backend=backend, **kwargs) def _init_dist_mpi(backend, **kwargs): raise NotImplementedError def _init_dist_slurm(backend, port=29500, **kwargs): proc_id = int(os.environ['SLURM_PROCID']) ntasks = int(os.environ['SLURM_NTASKS']) node_list = os.environ['SLURM_NODELIST'] num_gpus = torch.cuda.device_count() torch.cuda.set_device(proc_id % num_gpus) addr = subprocess.getoutput( 'scontrol show hostname {} | head -n1'.format(node_list)) os.environ['MASTER_PORT'] = str(port) os.environ['MASTER_ADDR'] = addr os.environ['WORLD_SIZE'] = str(ntasks) os.environ['RANK'] = str(proc_id) dist.init_process_group(backend=backend) def set_random_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def get_root_logger(log_level=logging.INFO): logger = logging.getLogger() if not logger.hasHandlers(): logging.basicConfig( format='%(asctime)s - %(levelname)s - %(message)s', level=log_level) rank, _ = get_dist_info() if rank != 0: logger.setLevel('ERROR') return logger