"""Tests for ``aiotools.start_server`` in process- and thread-based modes.

Each test starts a server, arranges for a stop signal to be delivered shortly
afterwards (via an interval timer or ``loop.call_later``), and then checks
which parts of the worker / main / extra-proc functions were executed.
"""

import asyncio
import functools
import logging.config
import multiprocessing as mp
import os
import signal
import sys
import threading
import time

import pytest

import aiotools


@pytest.fixture
def restore_signal():
    """Save the handlers of the signals used by these tests and restore them
    on teardown."""
    watched = (signal.SIGALRM, signal.SIGINT, signal.SIGTERM, signal.SIGUSR1)
    # Keep each handler under its own signal number so that every signal gets
    # back exactly the handler it had before the test.
    saved = {signum: signal.getsignal(signum) for signum in watched}
    yield
    for signum, handler in saved.items():
        # getsignal() returns None for handlers that were not installed from
        # Python; signal.signal() rejects None, so leave those untouched.
        if handler is not None:
            signal.signal(signum, handler)


@pytest.fixture
def set_timeout():
    """Provide ``make_timeout(sec, callback)`` which invokes *callback* from a
    SIGALRM handler after *sec* seconds (fractions allowed)."""

    def make_timeout(sec, callback):

        def _callback(signum, frame):
            # Clear any pending alarm before running the user callback.
            signal.alarm(0)
            callback()

        signal.signal(signal.SIGALRM, _callback)
        signal.setitimer(signal.ITIMER_REAL, sec)

    yield make_timeout


def interrupt():
    """Send SIGINT to every process in the current process group."""
    os.kill(0, signal.SIGINT)


def interrupt_usr1():
    """Send SIGUSR1 to the current process only."""
    os.kill(os.getpid(), signal.SIGUSR1)


@aiotools.server  # type: ignore
async def myserver_singleproc(loop, proc_idx, args):
    """Worker that counts its startup and shutdown in shared counters."""
    started, terminated = args
    assert proc_idx == 0
    await asyncio.sleep(0)
    with started.get_lock():
        started.value += 1

    yield

    await asyncio.sleep(0)
    with terminated.get_lock():
        terminated.value += 1


@pytest.mark.parametrize('start_method', ['fork', 'spawn'])
def test_server_singleproc(mocker, set_timeout, restore_signal, start_method):
    """A default server interrupted by SIGINT runs the worker's startup and
    shutdown sections exactly once."""
    mpctx = mp.get_context(start_method)
    mocker.patch('aiotools.server.mp', mpctx)

    started = mpctx.Value('i', 0)
    terminated = mpctx.Value('i', 0)

    set_timeout(0.2, interrupt)
    aiotools.start_server(myserver_singleproc, args=(started, terminated))

    assert started.value == 1
    assert terminated.value == 1


def test_server_singleproc_threading(restore_signal):
    """Same as the single-process test but with ``use_threading=True``; the
    interrupt is scheduled from inside the worker's event loop."""
    started = 0
    terminated = 0
    value_lock = threading.Lock()

    @aiotools.server
    async def myserver(loop, proc_idx, args):
        nonlocal started, terminated
        assert proc_idx == 0
        assert len(args) == 0
        await asyncio.sleep(0)
        with value_lock:
            started += 1
        loop.call_later(0.2, interrupt)

        yield

        await asyncio.sleep(0)
        with value_lock:
            terminated += 1

    aiotools.start_server(myserver, use_threading=True)

    assert started == 1
    assert terminated == 1


@aiotools.server  # type: ignore
async def myserver_multiproc(loop, proc_idx, args):
    """Worker that counts startup/shutdown and records its own index."""
    started, terminated, proc_idxs = args
    await asyncio.sleep(0)
    with started.get_lock():
        started.value += 1
    proc_idxs[proc_idx] = proc_idx

    yield

    await asyncio.sleep(0)
    with terminated.get_lock():
        terminated.value += 1


@pytest.mark.parametrize('start_method', ['fork', 'spawn'])
def test_server_multiproc(mocker, set_timeout, restore_signal, start_method):
    """Three workers all start, all terminate, and see indices 0, 1, 2."""
    mpctx = mp.get_context(start_method)
    mocker.patch('aiotools.server.mp', mpctx)

    started = mpctx.Value('i', 0)
    terminated = mpctx.Value('i', 0)
    proc_idxs = mpctx.Array('i', 3)

    set_timeout(0.2, interrupt)
    aiotools.start_server(myserver_multiproc, num_workers=3,
                          args=(started, terminated, proc_idxs))

    assert started.value == 3
    assert terminated.value == 3
    assert list(proc_idxs) == [0, 1, 2]
    assert len(mp.active_children()) == 0


@aiotools.server  # type: ignore
async def myserver_multiproc_custom_stop_signals(loop, proc_idx, args):
    """Worker that additionally records the value sent into its ``yield``."""
    started, terminated, received_signals, proc_idxs = args
    await asyncio.sleep(0)
    with started.get_lock():
        started.value += 1
    proc_idxs[proc_idx] = proc_idx

    received_signals[proc_idx] = yield

    await asyncio.sleep(0)
    with terminated.get_lock():
        terminated.value += 1


@pytest.mark.skipif(os.environ.get('TRAVIS', '') == 'true', reason='on Travis CI')
@pytest.mark.parametrize('start_method', ['fork', 'spawn'])
def test_server_multiproc_custom_stop_signals(
        mocker, set_timeout, restore_signal, start_method):
    """With ``stop_signals={SIGUSR1}``, each worker receives SIGUSR1 as the
    value of its ``yield`` expression."""
    mpctx = mp.get_context(start_method)
    mocker.patch('aiotools.server.mp', mpctx)

    started = mpctx.Value('i', 0)
    terminated = mpctx.Value('i', 0)
    received_signals = mpctx.Array('i', 2)
    proc_idxs = mpctx.Array('i', 2)

    set_timeout(0.2, interrupt_usr1)
    aiotools.start_server(myserver_multiproc_custom_stop_signals,
                          num_workers=2,
                          stop_signals={signal.SIGUSR1},
                          args=(started, terminated, received_signals, proc_idxs))

    assert started.value == 2
    assert terminated.value == 2
    assert list(received_signals) == [signal.SIGUSR1, signal.SIGUSR1]
    assert list(proc_idxs) == [0, 1]
    assert len(mpctx.active_children()) == 0


@aiotools.server  # type: ignore
async def myserver_worker_init_error(loop, proc_idx, args):
    """Worker whose instance 0 raises during initialization, after the other
    workers have had time to start."""
    started, terminated, log_queue = args
    # Route "aiotools" log records both to the shared queue (inspected by the
    # test) and to stderr.
    logging.config.dictConfig({
        'version': 1,
        'handlers': {
            'q': {
                'class': 'logging.handlers.QueueHandler',
                'queue': log_queue,
                'level': 'DEBUG',
            },
            'console': {
                'class': 'logging.StreamHandler',
                'stream': 'ext://sys.stderr',
                'level': 'DEBUG',
            },
        },
        'loggers': {
            'aiotools': {
                'handlers': ['q', 'console'],
                'level': 'DEBUG',
            },
        },
    })
    with started.get_lock():
        started.value += 1
    if proc_idx == 0:
        # delay until other workers start normally.
        await asyncio.sleep(0.2)
        raise ZeroDivisionError('oops')

    yield

    # should not be reached if errored.
    await asyncio.sleep(0)
    with terminated.get_lock():
        terminated.value += 1


@pytest.mark.parametrize('use_threading,start_method', [
    (True, 'fork'),
    (False, 'fork'),
    (False, 'spawn'),
])
def test_server_worker_init_error(
        mocker, restore_signal, use_threading, start_method):
    """An initialization error in one worker shuts the whole server down and
    is reported as an ERROR log record."""
    mpctx = mp.get_context(start_method)
    mocker.patch('aiotools.server.mp', mpctx)

    started = mpctx.Value('i', 0)
    terminated = mpctx.Value('i', 0)
    log_queue = mpctx.Queue()

    aiotools.start_server(myserver_worker_init_error,
                          num_workers=3,
                          use_threading=use_threading,
                          args=(started, terminated, log_queue))
    # it should automatically shut down!

    # reset logging
    logging.shutdown()

    assert started.value == 3
    # workers who did not raise errors have already started,
    # and they should have terminated normally
    # when the erroneous worker interrupted the main loop.
    assert terminated.value == 2
    assert len(mp.active_children()) == 0
    assert not log_queue.empty()
    has_error_log = False
    while not log_queue.empty():
        rec = log_queue.get()
        if rec.levelname == 'ERROR':
            has_error_log = True
            assert 'initialization' in rec.message
            # exception info is logged to the console,
            # but we cannot access it here because exceptions
            # are not picklable.
            assert rec.exc_info is None
    assert has_error_log


@aiotools.server  # type: ignore
async def myserver_worker_init_error_multi(loop, proc_idx, args):
    """Worker whose startup is staggered by index; instance 1 raises during
    initialization."""
    started, terminated, log_queue = args
    # Route "aiotools" log records both to the shared queue (inspected by the
    # test) and to stderr.
    logging.config.dictConfig({
        'version': 1,
        'handlers': {
            'q': {
                'class': 'logging.handlers.QueueHandler',
                'queue': log_queue,
                'level': 'DEBUG',
            },
            'console': {
                'class': 'logging.StreamHandler',
                'stream': 'ext://sys.stderr',
                'level': 'DEBUG',
            },
        },
        'loggers': {
            'aiotools': {
                'handlers': ['q', 'console'],
                'level': 'DEBUG',
            },
        },
    })
    # make the error timing to spread over some time
    await asyncio.sleep(0.2 * proc_idx)
    if proc_idx == 1:
        raise ZeroDivisionError('oops')
    with started.get_lock():
        started.value += 1

    yield

    # should not be reached if errored.
    await asyncio.sleep(0)
    with terminated.get_lock():
        terminated.value += 1


@pytest.mark.parametrize('use_threading,start_method', [
    (True, 'fork'),
    (False, 'fork'),
    (False, 'spawn'),
])
def test_server_worker_init_error_multi(
        mocker, restore_signal, use_threading, start_method):
    """An initialization error raised while workers are still starting up
    shuts the server down and leaves no child processes behind."""
    mpctx = mp.get_context(start_method)
    mocker.patch('aiotools.server.mp', mpctx)

    started = mpctx.Value('i', 0)
    terminated = mpctx.Value('i', 0)
    log_queue = mpctx.Queue()

    aiotools.start_server(myserver_worker_init_error_multi,
                          num_workers=3,
                          use_threading=use_threading,
                          args=(started, terminated, log_queue))
    # it should automatically shut down!

    # reset logging
    logging.shutdown()

    assert started.value >= 1
    # non-errored workers should have been terminated normally.
    assert terminated.value >= 1
    # there is one worker remaining -- which is "cancelled"!
    # just ensure that all workers have terminated now.
    assert len(mpctx.active_children()) == 0
    assert not log_queue.empty()
    has_error_log = False
    while not log_queue.empty():
        rec = log_queue.get()
        if rec.levelname == 'ERROR':
            has_error_log = True
            assert 'initialization' in rec.message
            # exception info is logged to the console,
            # but we cannot access it here because exceptions
            # are not picklable.
            assert rec.exc_info is None
    assert has_error_log


def test_server_multiproc_threading(set_timeout, restore_signal):
    """Three thread-based workers all start, all terminate, and see indices
    0, 1, 2."""
    started = 0
    terminated = 0
    proc_idxs = [0, 0, 0]
    value_lock = threading.Lock()

    @aiotools.server
    async def myserver(loop, proc_idx, args):
        nonlocal started, terminated, proc_idxs
        await asyncio.sleep(0)
        with value_lock:
            started += 1
            proc_idxs[proc_idx] = proc_idx

        yield

        await asyncio.sleep(0)
        with value_lock:
            terminated += 1

    def interrupt():
        os.kill(os.getpid(), signal.SIGINT)

    set_timeout(0.2, interrupt)
    aiotools.start_server(myserver, num_workers=3, use_threading=True)

    assert started == 3
    assert terminated == 3
    assert list(proc_idxs) == [0, 1, 2]


@pytest.mark.parametrize('start_method', ['fork'])
def test_server_user_main(mocker, set_timeout, restore_signal, start_method):
    """A user-provided main function runs both before and after its ``yield``,
    and the yielded value is passed to workers ahead of ``args``."""
    mpctx = mp.get_context(start_method)
    mocker.patch('aiotools.server.mp', mpctx)
    main_enter = False
    main_exit = False

    # FIXME: This should work with start_method = "spawn", but to test with it
    #        we need to allow passing arguments to user-provided main functions.

    @aiotools.main
    def mymain_user_main():
        nonlocal main_enter, main_exit
        main_enter = True
        yield 987
        main_exit = True

    @aiotools.server  # type: ignore
    async def myworker_user_main(loop, proc_idx, args):
        assert args[0] == 987  # first arg from user main
        assert args[1] == 123  # second arg from start_server args
        yield

    set_timeout(0.2, interrupt)
    aiotools.start_server(myworker_user_main, mymain_user_main,
                          num_workers=3, args=(123, ))

    assert main_enter
    assert main_exit


@pytest.mark.skipif(os.environ.get('TRAVIS', '') == 'true', reason='on Travis CI')
def test_server_user_main_custom_stop_signals(set_timeout, restore_signal):
    """With ``stop_signals={SIGUSR1}``, both the user main function and every
    worker receive SIGUSR1 as the value of their ``yield`` expression."""
    main_enter = False
    main_exit = False
    main_signal = None
    worker_signals = mp.Array('i', 3)

    @aiotools.main
    def mymain():
        nonlocal main_enter, main_exit, main_signal
        main_enter = True
        main_signal = yield
        main_exit = True

    @aiotools.server
    async def myworker(loop, proc_idx, args):
        worker_signals = args[0]
        worker_signals[proc_idx] = yield

    def interrupt():
        os.kill(os.getpid(), signal.SIGUSR1)

    set_timeout(0.2, interrupt)
    aiotools.start_server(myworker, mymain, num_workers=3,
                          stop_signals={signal.SIGUSR1},
                          args=(worker_signals, ))

    assert main_enter
    assert main_exit
    assert main_signal == signal.SIGUSR1
    assert list(worker_signals) == [signal.SIGUSR1] * 3


def test_server_user_main_tuple(set_timeout, restore_signal):
    """A tuple yielded by the user main function is spread into the leading
    worker arguments."""
    main_enter = False
    main_exit = False

    @aiotools.main
    def mymain():
        nonlocal main_enter, main_exit
        main_enter = True
        yield 987, 654
        main_exit = True

    @aiotools.server
    async def myworker(loop, proc_idx, args):
        assert args[0] == 987  # first arg from user main
        assert args[1] == 654  # second arg from user main
        assert args[2] == 123  # third arg from start_server args
        yield

    def interrupt():
        os.kill(os.getpid(), signal.SIGINT)

    set_timeout(0.2, interrupt)
    aiotools.start_server(myworker, mymain, num_workers=3, args=(123, ))

    assert main_enter
    assert main_exit


def test_server_user_main_threading(set_timeout, restore_signal):
    """The user main function also works with ``use_threading=True``."""
    main_enter = False
    main_exit = False

    @aiotools.main
    def mymain():
        nonlocal main_enter, main_exit
        main_enter = True
        yield 987
        main_exit = True

    @aiotools.server
    async def myworker(loop, proc_idx, args):
        assert args[0] == 987  # first arg from user main
        assert args[1] == 123  # second arg from start_server args
        yield

    def interrupt():
        os.kill(os.getpid(), signal.SIGINT)

    set_timeout(0.2, interrupt)
    aiotools.start_server(myworker, mymain, num_workers=3,
                          use_threading=True, args=(123, ))

    assert main_enter
    assert main_exit


def test_server_extra_proc(set_timeout, restore_signal):
    """Extra procs run alongside the workers and reach their ``finally``
    clause when the server is interrupted."""
    extras = mp.Array('i', [0, 0])

    def extra_proc(key, _, pidx, args):
        assert _ is None
        extras[key] = 980 + key
        try:
            while True:
                time.sleep(0.1)
        except KeyboardInterrupt:
            print(f'extra[{key}] interrupted', file=sys.stderr)
        except Exception as e:
            print(f'extra[{key}] exception', e, file=sys.stderr)
        finally:
            print(f'extra[{key}] finish', file=sys.stderr)
            extras[key] = 990 + key

    @aiotools.server
    async def myworker(loop, pidx, args):
        yield

    def interrupt():
        os.kill(os.getpid(), signal.SIGINT)

    set_timeout(0.2, interrupt)
    aiotools.start_server(myworker, extra_procs=[
                              functools.partial(extra_proc, 0),
                              functools.partial(extra_proc, 1)],
                          num_workers=3, args=(123, ))

    assert extras[0] == 990
    assert extras[1] == 991


@pytest.mark.skipif(os.environ.get('TRAVIS', '') == 'true', reason='on Travis CI')
def test_server_extra_proc_custom_stop_signal(set_timeout, restore_signal):
    """With ``stop_signals={SIGUSR1}``, extra procs are interrupted by an
    ``aiotools.InterruptedBySignal`` whose first argument is SIGUSR1."""
    received_signals = mp.Array('i', [0, 0])

    def extra_proc(key, _, pidx, args):
        received_signals = args[0]
        try:
            while True:
                time.sleep(0.1)
        except aiotools.InterruptedBySignal as e:
            received_signals[key] = e.args[0]

    @aiotools.server
    async def myworker(loop, pidx, args):
        yield

    def interrupt():
        os.kill(os.getpid(), signal.SIGUSR1)

    set_timeout(0.3, interrupt)
    aiotools.start_server(myworker, extra_procs=[
                              functools.partial(extra_proc, 0),
                              functools.partial(extra_proc, 1)],
                          stop_signals={signal.SIGUSR1},
                          args=(received_signals, ),
                          num_workers=3)

    assert received_signals[0] == signal.SIGUSR1
    assert received_signals[1] == signal.SIGUSR1


def test_server_extra_proc_threading(set_timeout, restore_signal):
    """Extra procs in threading mode stop once their interruption event is
    set."""
    # When using extra_procs with threading, you need to provide a way to
    # explicitly interrupt your synchronous loop.
    # Here, we use a threading.Event object to signal interruption.

    extras = [0, 0]
    value_lock = threading.Lock()

    def extra_proc(key, intr_event, pidx, args):
        assert isinstance(intr_event, threading.Event)
        with value_lock:
            extras[key] = 980 + key
        try:
            while not intr_event.is_set():
                time.sleep(0.1)
        except Exception as e:
            print(f'extra[{key}] exception', e)
        finally:
            with value_lock:
                extras[key] = 990 + key

    @aiotools.server
    async def myworker(loop, pidx, args):
        yield

    def interrupt():
        os.kill(os.getpid(), signal.SIGINT)

    set_timeout(0.2, interrupt)
    aiotools.start_server(myworker, extra_procs=[
                              functools.partial(extra_proc, 0),
                              functools.partial(extra_proc, 1)],
                          use_threading=True,
                          num_workers=3, args=(123, ))

    assert extras[0] == 990
    assert extras[1] == 991