import sys
import contextlib
import functools
import inspect
from mpi4py import MPI
import numpy as np

__all__ = ['wait_for_turn', 'mprint', 'meq_', 'assert_eq_across_ranks', 'mpitest']

#
# Printing routines
#


@contextlib.contextmanager
def wait_for_turn(comm):
    """Run the ``with`` body on each rank of `comm` in turn, in rank order.

    Every rank of `comm` must enter the context, because it performs one
    ``comm.Barrier()`` per rank.  An exception raised by the body is held
    back until all barriers have been passed (so no other rank is left
    waiting) and is then re-raised on the rank where it occurred.
    """
    assert isinstance(comm, MPI.Comm)
    rank = comm.Get_rank()
    caught = None
    for turn in range(comm.Get_size()):
        if turn == rank:
            try:
                yield
            except BaseException as exc:
                # Keep looping: leaving early would deadlock the other ranks
                # in the remaining barriers.
                caught = exc
        comm.Barrier()
    if caught is not None:
        # Re-raise the original exception object (type, args and traceback
        # intact) instead of constructing a new one from it.
        raise caught


def mprint(comm, *args):
    """Print `args` on every rank of `comm`, prefixed by the rank, in rank order."""
    assert isinstance(comm, MPI.Comm)
    with wait_for_turn(comm):
        # Flush during this rank's turn; otherwise buffered output can be
        # written after the barrier and the rank ordering is lost.
        print('%d:' % comm.Get_rank(), *args, flush=True)

# TODO: mpprint

#
# Assertions
#


def meq_(comm, expected, got):
    """Assert that `got` equals ``expected[rank]`` on each rank of `comm`.

    `expected` must be a list with exactly one entry per rank of `comm`.

    Raises AssertionError on each rank whose value differs.
    """
    assert type(expected) is list and len(expected) == comm.Get_size()
    rank = comm.Get_rank()
    if got != expected[rank]:
        raise AssertionError("Rank %d: Expected '%r' but got '%r'" % (
            rank, expected[rank], got))


def assert_eq_across_ranks(comm, x):
    """Assert that every rank of `comm` passed an equal `x`.

    Collective: all ranks must call it.  Values are gathered on rank 0 and
    compared against rank 0's value; only rank 0 raises AssertionError.
    NumPy arrays are compared with ``np.array_equal`` (shape and elements),
    so arrays of different shapes are never considered equal.
    """
    values = comm.gather(x, root=0)
    if comm.Get_rank() != 0:
        return
    for i, y in enumerate(values[1:], start=1):
        if isinstance(x, np.ndarray) or isinstance(y, np.ndarray):
            # ``np.all(x == y)`` would broadcast (e.g. shape (1,) vs (3,))
            # and report a false match.
            is_equal = np.array_equal(x, y)
        else:
            is_equal = (x == y)
        if not is_equal:
            raise AssertionError("Rank 0's and %d's result differ: '%r' vs '%r'" % (i, x, y))

#
# @mpitest decorator
#


def format_exc_info():
    """Return the exception currently being handled, formatted as a traceback string."""
    import traceback
    return traceback.format_exc()


def first_nonzero(arr):
    """
    Find index of first nonzero element in the 1D array `arr`, or raise
    IndexError if no such element exists.
    """
    hits = np.nonzero(arr)
    assert len(hits) == 1
    if len(hits[0]) == 0:
        raise IndexError("No non-zero elements")
    else:
        return hits[0][0]


class MpiWorkers:
    """A pool of ``max_nprocs`` MPI processes started with ``mpiexec``.

    Tests are dispatched to the pool over a ZeroMQ REQ/REP socket pair.  The
    terminals are used for lots of debug output, so they are not used as the
    communication channel.
    """

    def __init__(self, max_nprocs):
        import os
        import subprocess
        import zmq

        zctx = zmq.Context()
        socket = zctx.socket(zmq.REQ)
        # Bind to loopback only: the worker connects to 127.0.0.1 anyway, and
        # replies are unpickled, so an externally reachable port would let any
        # peer on the network run code in this process.
        port = socket.bind_to_random_port("tcp://127.0.0.1")
        cmd = 'import %s as mod; mod._mpi_worker("tcp://127.0.0.1:%d")' % (__name__, port)
        env = dict(os.environ)
        # Let the workers import the same modules (including the test module).
        env['PYTHONPATH'] = os.pathsep.join(sys.path)
        try:
            self.child = subprocess.Popen(
                ['mpiexec', '-np', str(max_nprocs), sys.executable, '-c', cmd],
                env=env)
        except BaseException:
            socket.close(linger=0)
            zctx.term()
            raise
        self.zctx = zctx
        self.socket = socket

    def stop(self):
        """Tell the workers to exit, wait for ``mpiexec`` and release the socket.

        Raises AssertionError if called more than once.
        """
        if self.socket is None:
            raise AssertionError('stopped multiple times')
        try:
            self.socket.send_pyobj('stop')
            self.socket.recv_pyobj()
            # TODO: If nose is capturing, gather output from child and forward to nose
            self.child.wait()
        finally:
            self.socket.close(linger=0)
            self.zctx.term()
            self.socket = self.child = self.zctx = None

    def run_and_raise_result(self, func):
        """Run `func` on the worker pool and raise if any worker rank failed or errored."""
        # Only module and function name are sent; the root worker imports the
        # (decorated) function itself and uses MPI to run it on all ranks.
        self.socket.send_pyobj((func.__module__, func.__name__))
        result = self.socket.recv_pyobj()
        _raise_condition(*result)


def _mpi_worker(addr):
    """Worker loop executed by every rank launched by `MpiWorkers`.

    Rank 0 receives requests on a ZeroMQ REP socket connected to `addr` and
    broadcasts them to all ranks.  A request is either ``'stop'`` or a
    ``(module_name, func_name)`` pair naming an @mpitest-decorated function,
    whose synchronized outcome tuple rank 0 sends back.
    """
    import importlib
    import zmq
    from pickle import loads

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    socket = None
    if rank == 0:
        zctx = zmq.Context()
        socket = zctx.socket(zmq.REP)
        socket.connect(addr)

    while True:
        pickled_msg = socket.recv() if rank == 0 else None
        pickled_msg = comm.bcast(pickled_msg, root=0)
        # NOTE: unpickling trusts the peer at `addr` (the parent process).
        msg = loads(pickled_msg)
        if msg == 'stop':
            if rank == 0:
                socket.send_pyobj('')
            break

        try:
            module_name, func_name = msg
            func = getattr(importlib.import_module(module_name), func_name)
        except Exception:
            # Report instead of crashing: a dead rank 0 would leave the parent
            # blocked forever in recv().
            status = ('ERROR', 0, format_exc_info())
        else:
            status = func(_return_status=True)
        if rank == 0:
            socket.send_pyobj(status)

    # All processes wait until they can terminate
    comm.barrier()


def _raise_condition(first_non_success_status, failing_rank, msg):
    """Raise the exception matching a synchronized test outcome.

    'ERROR' raises RuntimeError, 'FAILED' raises AssertionError, 'SUCCESS'
    returns; `msg` (a formatted traceback) is embedded in the message.
    Any other status raises ValueError.
    """
    fmt = '%s in MPI rank %d:\n\n"""\n%s"""\n'
    if first_non_success_status == 'SUCCESS':
        return
    if first_non_success_status == 'ERROR':
        raise RuntimeError(fmt % ('ERROR', failing_rank, msg))
    if first_non_success_status == 'FAILED':
        raise AssertionError(fmt % ('FAILURE', failing_rank, msg))
    raise ValueError('Unknown test status: %r' % (first_non_success_status,))


# Non-success statuses, most severe first.
_STATUS_PRIORITY = ('ERROR', 'FAILED')


def _synchronize_outcome(status, exc_msg):
    """Agree across COMM_WORLD on the outcome to report for one test.

    Collective: every rank of COMM_WORLD must call it with its own `status`
    ('SUCCESS', 'FAILED' or 'ERROR') and formatted traceback `exc_msg`.
    Returns ``(status, reporting_rank, msg)`` taken from the lowest rank that
    errored or, if none errored, the lowest rank that failed; returns
    ``('SUCCESS', -1, '')`` when every rank succeeded.
    """
    comm = MPI.COMM_WORLD
    statuses = comm.allgather(status)
    for wanted in _STATUS_PRIORITY:
        if wanted in statuses:
            reporting_rank = statuses.index(wanted)
            break
    else:
        return ('SUCCESS', -1, '')
    # `statuses` is identical on all ranks, so they all agree on the root.
    reported_status, msg = comm.bcast((status, exc_msg), root=reporting_rank)
    return (reported_status, reporting_rank, msg)


def mpitest(nprocs):
    """Decorator running a test case on an `nprocs`-sized subset of COMM_WORLD.

    The decorated function is called as ``func(sub_comm)`` on ranks
    ``0 .. nprocs-1``, where `sub_comm` contains exactly those ranks.  Results
    are synchronized, so that a failure or error in one process causes all
    ranks to fail or error.  The algorithm is:

    - If a process fails (AssertionError) or errors (any other exception),
      it propagates its own exception.
    - If a process succeeds, it reports the error of the lowest-ranking
      process that errored (by raising a RuntimeError containing that
      process's stack trace as a string).  If no process errored, the same
      is repeated with failures (raising an AssertionError).  Otherwise the
      process succeeds.

    When COMM_WORLD has size 1 (no ``mpiexec``), the test is run on a pool of
    `MpiWorkers` shared by the module's @mpitest tests, sized for the
    largest `nprocs` among them.
    """
    def dec(func):
        mod = inspect.getmodule(func)
        mod.max_mpi_comm_size = max(nprocs, getattr(mod, 'max_mpi_comm_size', 0))
        # Number of decorated tests; the worker pool is stopped when it drops
        # to zero.  NOTE(review): if only a subset of the module's tests is
        # run, the count never reaches zero and the workers are not stopped.
        mod.mpi_test_count = getattr(mod, 'mpi_test_count', 0) + 1

        @functools.wraps(func)
        def replacement_func(_return_status=False):
            n = MPI.COMM_WORLD.Get_size()
            rank = MPI.COMM_WORLD.Get_rank()

            if n == 1:
                # spawn workers for module if not done already
                mpi_workers = getattr(mod, 'mpi_workers', None)
                if mpi_workers is None:
                    mod.mpi_workers = mpi_workers = MpiWorkers(mod.max_mpi_comm_size)
                try:
                    mpi_workers.run_and_raise_result(func)
                finally:
                    mod.mpi_test_count -= 1
                    if mod.mpi_test_count == 0:
                        del mod.mpi_workers
                        mpi_workers.stop()
                return

            if n < nprocs:
                raise RuntimeError('Number of available MPI processes (%d) '
                                   'too small' % n)

            sub_comm = MPI.COMM_WORLD.Split(0 if rank < nprocs else 1, 0)
            status = 'SUCCESS'
            exc_msg = ''
            try:
                if rank < nprocs:
                    try:
                        func(sub_comm)
                    except AssertionError:
                        status = 'FAILED'
                        exc_msg = format_exc_info()
                        if not _return_status:
                            raise
                    except BaseException:
                        status = 'ERROR'
                        exc_msg = format_exc_info()
                        if not _return_status:
                            raise
            finally:
                # Communicate results in `finally`, so that erring/failing
                # ranks also take part in the collective calls below.
                outcome = _synchronize_outcome(status, exc_msg)
                # The test may have freed the communicator itself.
                if sub_comm != MPI.COMM_NULL:
                    sub_comm.Free()
                # A non-succeeding rank is already propagating its own
                # exception; succeeding ranks report the agreed outcome.
                if not _return_status and status == 'SUCCESS':
                    _raise_condition(*outcome)
            if _return_status:
                return outcome

        return replacement_func
    return dec