"""pytest fixtures and hooks for running the test-suite against real
redis-server / redis-sentinel (and optionally socat TLS proxy) processes.
"""
import asyncio
import atexit
import contextlib
import inspect
import os
import shutil
import socket
import ssl
import subprocess
import sys
import tempfile
import time
from collections import namedtuple
from urllib.parse import urlencode, urlunparse

import pytest
from async_timeout import timeout as async_timeout

import aioredis
import aioredis.sentinel


TCPAddress = namedtuple('TCPAddress', 'host port')

RedisServer = namedtuple('RedisServer',
                         'name tcp_address unixsocket version password')

SentinelServer = namedtuple('SentinelServer',
                            'name tcp_address unixsocket version masters')

# Seconds to sleep when a followed log file has no new line yet; keeps the
# startup polling loops from spinning at 100% CPU.
_LOG_POLL_INTERVAL = 0.01


# Public fixtures


@pytest.fixture
def loop():
    """Creates new event loop."""
    loop = asyncio.new_event_loop()
    if sys.version_info < (3, 8):
        asyncio.set_event_loop(loop)

    try:
        yield loop
    finally:
        if hasattr(loop, 'is_closed'):
            closed = loop.is_closed()
        else:
            closed = loop._closed   # XXX fallback for loops without is_closed
        if not closed:
            # Let already-scheduled callbacks run once before closing.
            loop.call_soon(loop.stop)
            loop.run_forever()
            loop.close()


@pytest.fixture(scope='session')
def unused_port():
    """Gets random free port.

    Returns a callable; each call binds a throw-away socket to port 0 on
    127.0.0.1 and returns the port the OS picked. The socket is closed before
    returning, so the port is only *likely* to still be free when used.
    """
    def fun():
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(('127.0.0.1', 0))
            return s.getsockname()[1]
    return fun


@pytest.fixture
def create_connection(_closable):
    """Wrapper around aioredis.create_connection.

    Every created connection is closed on test teardown.
    """

    async def f(*args, **kw):
        conn = await aioredis.create_connection(*args, **kw)
        _closable(conn)
        return conn
    return f


@pytest.fixture(params=[
    aioredis.create_redis,
    aioredis.create_redis_pool],
    ids=['single', 'pool'])
def create_redis(_closable, request):
    """Wrapper around aioredis.create_redis.

    Parametrized so each test runs with both a single-connection client and
    a pooled client; every created client is closed on teardown.
    """
    factory = request.param

    async def f(*args, **kw):
        redis = await factory(*args, **kw)
        _closable(redis)
        return redis
    return f


@pytest.fixture
def create_pool(_closable):
    """Wrapper around aioredis.create_pool.

    Every created pool is closed on test teardown.
    """

    async def f(*args, **kw):
        redis = await aioredis.create_pool(*args, **kw)
        _closable(redis)
        return redis
    return f


@pytest.fixture
def create_sentinel(_closable):
    """Helper instantiating RedisSentinel client."""

    async def f(*args, **kw):
        # make it fail fast on slow CIs (if timeout argument is omitted)
        kw.setdefault('timeout', .001)
        client = await aioredis.sentinel.create_sentinel(*args, **kw)
        _closable(client)
        return client
    return f


@pytest.fixture
def pool(create_pool, server, loop):
    """Returns RedisPool instance."""
    return loop.run_until_complete(create_pool(server.tcp_address))


@pytest.fixture
def redis(create_redis, server, loop):
    """Returns Redis client instance (after FLUSHALL on the server)."""
    redis = loop.run_until_complete(
        create_redis(server.tcp_address))

    async def clear():
        await redis.flushall()
    loop.run_until_complete(clear())
    return redis


@pytest.fixture
def redis_sentinel(create_sentinel, sentinel, loop):
    """Returns Redis Sentinel client instance (checked with PING)."""
    redis_sentinel = loop.run_until_complete(
        create_sentinel([sentinel.tcp_address], timeout=2))

    async def ping():
        return await redis_sentinel.ping()
    assert loop.run_until_complete(ping()) == b'PONG'
    return redis_sentinel


@pytest.fixture
def _closable(loop):
    """Yields a registration callable for objects to close on teardown.

    On teardown each registered object gets ``close()`` (in registration
    order) and then all ``wait_closed()`` coroutines are awaited together.
    """
    conns = []

    async def close():
        waiters = []
        while conns:
            conn = conns.pop(0)
            conn.close()
            waiters.append(conn.wait_closed())
        if waiters:
            await asyncio.gather(*waiters)
    try:
        yield conns.append
    finally:
        loop.run_until_complete(close())


@pytest.fixture(scope='session')
def server(start_server):
    """Starts redis-server instance."""
    return start_server('A')


@pytest.fixture(scope='session')
def serverB(start_server):
    """Starts a second, independent redis-server instance."""
    return start_server('B')


@pytest.fixture(scope='session')
def sentinel(start_sentinel, request, start_server):
    """Starts redis-sentinel instance monitoring two masters.

    ``master-no-fail`` (replicated by ``slave-no-fail``) is for normal,
    no-failover tests; ``masterA`` (replicated by ``slaveA``) is for the
    failover test.
    """
    # Adding master+slave for normal (no failover) tests:
    master_no_fail = start_server('master-no-fail')
    start_server('slave-no-fail', slaveof=master_no_fail)
    # Adding master+slave for failover test;
    masterA = start_server('masterA')
    start_server('slaveA', slaveof=masterA)
    return start_sentinel('main', masterA, master_no_fail)


@pytest.fixture(params=['path', 'query'])
def server_tcp_url(server, request):
    """Returns a factory of ``redis://`` URLs pointing at ``server``.

    In ``'path'`` mode ``password`` goes into the netloc and ``db`` into the
    path; in ``'query'`` mode (and for all other keyword arguments) values
    are passed as query parameters.
    """
    def make(**kwargs):
        netloc = '{0.host}:{0.port}'.format(server.tcp_address)
        path = ''
        if request.param == 'path':
            if 'password' in kwargs:
                netloc = ':{0}@{1.host}:{1.port}'.format(
                    kwargs.pop('password'), server.tcp_address)
            if 'db' in kwargs:
                path = '/{}'.format(kwargs.pop('db'))
        query = urlencode(kwargs)
        return urlunparse(('redis', netloc, path, '', query, ''))
    return make


@pytest.fixture
def server_unix_url(server):
    """Returns a factory of ``unix://`` URLs for ``server``'s unix socket.

    Keyword arguments become query parameters.
    """
    def make(**kwargs):
        query = urlencode(kwargs)
        return urlunparse(('unix', '', server.unixsocket, '', query, ''))
    return make


# Internal stuff #


def pytest_addoption(parser):
    parser.addoption('--redis-server', default=[],
                     action="append",
                     help="Path to redis-server executable (may be given"
                          " several times), defaults to the redis-server"
                          " found on PATH")
    parser.addoption('--ssl-cafile', default='tests/ssl/cafile.crt',
                     help="Path to testing SSL CA file")
    parser.addoption('--ssl-dhparam', default='tests/ssl/dhparam.pem',
                     help="Path to testing SSL DH params file")
    parser.addoption('--ssl-cert', default='tests/ssl/cert.pem',
                     help="Path to testing SSL CERT file")
    parser.addoption('--uvloop', default=False,
                     action='store_true',
                     help="Run tests with uvloop")


def _read_server_version(redis_bin):
    """Return the version of *redis_bin* as a tuple of ints.

    Parses the ``v=X.Y.Z`` token from the first line printed by
    ``<redis_bin> --version``.

    Raises:
        RuntimeError: if no ``v=`` token is present.
    """
    args = [redis_bin, '--version']
    with subprocess.Popen(args, stdout=subprocess.PIPE) as proc:
        version = proc.stdout.readline().decode('utf-8')
    for part in version.split():
        if part.startswith('v='):
            break
    else:
        raise RuntimeError(
            "No version info can be found in {}".format(version))
    return tuple(map(int, part[2:].split('.')))


@contextlib.contextmanager
def config_writer(path):
    """Open *path* for writing and yield a ``write(*args)`` callable.

    Each ``write`` call emits one line: the arguments separated by single
    spaces (``print`` semantics).
    """
    with open(path, 'wt') as f:
        def write(*args):
            print(*args, file=f)
        yield write


REDIS_SERVERS = []
VERSIONS = {}


def format_version(srv):
    """Test id for a server binary, e.g. ``redis_v5.0.7``."""
    return 'redis_v{}'.format('.'.join(map(str, VERSIONS[srv])))


def _startup_timeout(seconds):
    """Yield ``True`` repeatedly until *seconds* elapse, then raise.

    Meant to bound polling loops: ``for _ in _startup_timeout(t): ...``.

    Raises:
        RuntimeError: once the deadline passes without the caller breaking
            out of the loop.
    """
    end = time.time() + seconds
    while time.time() <= end:
        yield True
    raise RuntimeError("Redis startup timeout expired")


def _wait_for_log_line(log_file, proc, name, needle, *, verbose, seconds):
    """Follow *log_file* until a line containing *needle* is read.

    Lines are echoed (prefixed by *name*) when *verbose* is true.

    Raises:
        AssertionError: if *proc* exits while waiting.
        RuntimeError: if *needle* is not seen within *seconds*.
    """
    for _ in _startup_timeout(seconds):
        assert proc.poll() is None, (
            "Process terminated", proc.returncode)
        log = log_file.readline()
        if not log:
            # Nothing new yet; back off instead of busy-spinning.
            time.sleep(_LOG_POLL_INTERVAL)
            continue
        if verbose:
            print(name, ":", log, end='')
        if needle in log:
            return


@pytest.fixture(scope='session')
def start_server(_proc, request, unused_port, server_bin):
    """Starts Redis server instance.

    Caches instances by name.
    ``name`` param -- instance alias
    ``config_lines`` -- optional list of config directives to put in config
        (if no config_lines passed -- no config will be generated,
        for backward compatibility).
    ``slaveof`` -- optional ``RedisServer`` to replicate from; startup then
        also waits for the initial sync to finish.
    ``password`` -- optional ``requirepass`` (and ``masterauth`` for slaves).
    """

    version = _read_server_version(server_bin)
    verbose = request.config.getoption('-v') > 3

    servers = {}

    def maker(name, config_lines=None, *, slaveof=None, password=None):
        assert slaveof is None or isinstance(slaveof, RedisServer), slaveof
        if name in servers:
            return servers[name]

        port = unused_port()
        tcp_address = TCPAddress('localhost', port)
        if sys.platform == 'win32':
            unixsocket = None
        else:
            unixsocket = '/tmp/aioredis.{}.sock'.format(port)
        dumpfile = 'dump-{}.rdb'.format(port)
        data_dir = tempfile.gettempdir()
        dumpfile_path = os.path.join(data_dir, dumpfile)
        stdout_file = os.path.join(data_dir, 'aioredis.{}.stdout'.format(port))
        tmp_files = [dumpfile_path, stdout_file]
        if unixsocket:
            # Clean the socket up however the server was configured.
            tmp_files.append(unixsocket)
        if config_lines:
            config = os.path.join(data_dir, 'aioredis.{}.conf'.format(port))
            with config_writer(config) as write:
                write('daemonize no')
                write('save ""')
                write('dir ', data_dir)
                write('dbfilename', dumpfile)
                write('port', port)
                if unixsocket:
                    write('unixsocket', unixsocket)
                if password:
                    write('requirepass "{}"'.format(password))
                write('# extra config')
                for line in config_lines:
                    write(line)
                if slaveof is not None:
                    write("slaveof {0.tcp_address.host} {0.tcp_address.port}"
                          .format(slaveof))
                    if password:
                        write('masterauth "{}"'.format(password))
            args = [config]
            tmp_files.append(config)
        else:
            args = ['--daemonize', 'no',
                    '--save', '""',
                    '--dir', data_dir,
                    '--dbfilename', dumpfile,
                    '--port', str(port),
                    ]
            if unixsocket:
                args += ['--unixsocket', unixsocket]
            if password:
                args += ['--requirepass "{}"'.format(password)]
            if slaveof is not None:
                args += ['--slaveof',
                         str(slaveof.tcp_address.host),
                         str(slaveof.tcp_address.port),
                         ]
                if password:
                    args += ['--masterauth "{}"'.format(password)]

        # The child process gets its own copy of the descriptor, so the
        # parent's handle can be closed as soon as the process is spawned.
        with open(stdout_file, 'w') as out:
            proc = _proc(server_bin, *args,
                         stdout=out,
                         stderr=subprocess.STDOUT,
                         _clear_tmp_files=tmp_files)

        with open(stdout_file, 'rt') as log_file:
            _wait_for_log_line(
                log_file, proc, name,
                'The server is now ready to accept connections ',
                verbose=verbose, seconds=10)
            if slaveof is not None:
                _wait_for_log_line(
                    log_file, proc, name, 'sync: Finished with success',
                    verbose=verbose, seconds=10)
        info = RedisServer(name, tcp_address, unixsocket, version, password)
        servers.setdefault(name, info)
        return info

    return maker


@pytest.fixture(scope='session')
def start_sentinel(_proc, request, unused_port, server_bin):
    """Starts Redis Sentinel instances.

    Returns ``maker(name, *masters, ...)``; instances are cached by
    ``(name, *masters)``. Startup waits until the sentinel log reports every
    master as monitored and (unless ``noslaves``) a slave for each master.
    """
    version = _read_server_version(server_bin)
    verbose = request.config.getoption('-v') > 3

    sentinels = {}

    def maker(name,
              *masters,
              quorum=1,
              noslaves=False,
              down_after_milliseconds=3000,
              failover_timeout=1000):
        key = (name,) + masters
        if key in sentinels:
            return sentinels[key]
        port = unused_port()
        tcp_address = TCPAddress('localhost', port)
        data_dir = tempfile.gettempdir()
        config = os.path.join(
            data_dir, 'aioredis-sentinel.{}.conf'.format(port))
        stdout_file = os.path.join(
            data_dir, 'aioredis-sentinel.{}.stdout'.format(port))
        tmp_files = [config, stdout_file]
        if sys.platform == 'win32':
            unixsocket = None
        else:
            unixsocket = os.path.join(
                data_dir, 'aioredis-sentinel.{}.sock'.format(port))
            tmp_files.append(unixsocket)

        with config_writer(config) as write:
            write('daemonize no')
            write('save ""')
            write('port', port)
            if unixsocket:
                write('unixsocket', unixsocket)
            write('loglevel debug')
            for master in masters:
                write('sentinel monitor', master.name,
                      '127.0.0.1', master.tcp_address.port, quorum)
                write('sentinel down-after-milliseconds', master.name,
                      down_after_milliseconds)
                write('sentinel failover-timeout', master.name,
                      failover_timeout)
                # Only configure auth when the master actually has a
                # password; otherwise the literal "None" would be written.
                if master.password:
                    write('sentinel auth-pass', master.name, master.password)

        # The child keeps its own descriptor; close the parent's copy now.
        with open(stdout_file, 'w') as out:
            proc = _proc(server_bin,
                         config,
                         '--sentinel',
                         stdout=out,
                         stderr=subprocess.STDOUT,
                         _clear_tmp_files=tmp_files)
        # XXX: wait sentinel see all masters and slaves;
        all_masters = {m.name for m in masters}
        if noslaves:
            all_slaves = set()
        else:
            all_slaves = {m.name for m in masters}
        with open(stdout_file, 'rt') as log_file:
            for _ in _startup_timeout(30):
                assert proc.poll() is None, (
                    "Process terminated", proc.returncode)
                log = log_file.readline()
                if log:
                    if verbose:
                        print(name, ":", log, end='')
                    for m in masters:
                        if '# +monitor master {}'.format(m.name) in log:
                            all_masters.discard(m.name)
                        if '* +slave slave' in log and \
                                '@ {}'.format(m.name) in log:
                            all_slaves.discard(m.name)
                else:
                    # Nothing new yet; back off instead of busy-spinning.
                    time.sleep(_LOG_POLL_INTERVAL)
                if not all_masters and not all_slaves:
                    break
            else:
                # Defensive: _startup_timeout raises before exhausting.
                raise RuntimeError("Could not start Sentinel")

        masters = {m.name: m for m in masters}
        info = SentinelServer(name, tcp_address, unixsocket, version, masters)
        sentinels.setdefault(key, info)
        return info
    return maker


@pytest.fixture(scope='session')
def ssl_proxy(_proc, request, unused_port):
    """Starts ``socat`` TLS proxies in front of plain TCP ports.

    Returns ``sockat(unsecure_port) -> (secure_port, ssl_ctx)``; proxies are
    cached per plain port. ``ssl_ctx`` is a client context with hostname
    checking and certificate verification disabled.
    """
    by_port = {}

    cafile = os.path.abspath(request.config.getoption('--ssl-cafile'))
    certfile = os.path.abspath(request.config.getoption('--ssl-cert'))
    dhfile = os.path.abspath(request.config.getoption('--ssl-dhparam'))
    assert os.path.exists(cafile), \
        "Missing SSL CA file, run `make certificate` to generate new one"
    assert os.path.exists(certfile), \
        "Missing SSL CERT file, run `make certificate` to generate new one"
    assert os.path.exists(dhfile), \
        "Missing SSL DH params, run `make certificate` to generate new one"

    ssl_ctx = ssl.create_default_context(cafile=cafile)
    ssl_ctx.check_hostname = False
    ssl_ctx.verify_mode = ssl.CERT_NONE
    ssl_ctx.load_dh_params(dhfile)

    def sockat(unsecure_port):
        if unsecure_port in by_port:
            return by_port[unsecure_port]

        secure_port = unused_port()
        _proc('/usr/bin/socat',
              'openssl-listen:{port},'
              'dhparam={param},'
              'cert={cert},verify=0,fork'
              .format(port=secure_port, param=dhfile, cert=certfile),
              'tcp-connect:localhost:{}'
              .format(unsecure_port)
              )
        time.sleep(1)   # XXX fixed delay hoping socat is listening by now
        by_port[unsecure_port] = secure_port, ssl_ctx
        return secure_port, ssl_ctx
    return sockat


@pytest.fixture(scope='session')
def _proc():
    """Session-wide subprocess launcher.

    Yields ``run(*commandline, _clear_tmp_files=(), **popen_kwargs)``, which
    starts ``subprocess.Popen(commandline, **popen_kwargs)`` and remembers
    the process. On session teardown every process is terminated and waited
    for, then the registered temporary files are removed (errors ignored).
    """
    processes = []
    tmp_files = set()

    def run(*commandline, _clear_tmp_files=(), **kwargs):
        proc = subprocess.Popen(commandline, **kwargs)
        processes.append(proc)
        tmp_files.update(_clear_tmp_files)
        return proc

    try:
        yield run
    finally:
        while processes:
            proc = processes.pop(0)
            proc.terminate()
            proc.wait()
        for path in tmp_files:
            with contextlib.suppress(OSError):
                os.remove(path)


@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem):
    """
    Run asyncio marked test functions in an event loop instead of a normal
    function call.

    The per-test timeout defaults to 15 seconds and can be overridden with
    ``@pytest.mark.timeout(seconds)``.
    """
    if inspect.iscoroutinefunction(pyfuncitem.obj):
        marker = pyfuncitem.get_closest_marker('timeout')
        if marker is not None and marker.args:
            timeout = marker.args[0]
        else:
            timeout = 15

        funcargs = pyfuncitem.funcargs
        loop = funcargs['loop']
        testargs = {arg: funcargs[arg]
                    for arg in pyfuncitem._fixtureinfo.argnames}

        loop.run_until_complete(
            _wait_coro(pyfuncitem.obj, testargs, timeout=timeout))
        return True


async def _wait_coro(corofunc, kwargs, timeout):
    """Await ``corofunc(**kwargs)``, cancelling it after *timeout* seconds."""
    # ``async with`` is the supported form; the sync ``with`` form is
    # deprecated in async-timeout 4.x.
    async with async_timeout(timeout):
        return (await corofunc(**kwargs))


def pytest_runtest_setup(item):
    """Make sure every coroutine test function requests the ``loop``."""
    is_coro = inspect.iscoroutinefunction(item.obj)
    if is_coro and 'loop' not in item.fixturenames:
        # inject an event loop fixture for all async tests
        item.fixturenames.append('loop')


def pytest_collection_modifyitems(session, config, items):
    """Apply ``redis_version`` and socat-availability skips.

    Items whose ``server_bin`` version is lower than
    ``@pytest.mark.redis_version(version=..., reason=...)`` are marked as
    skipped and dropped from the run -- unless that would drop every item,
    in which case they are kept and reported as skipped. Versions are
    compared as tuples (presumably the marker's ``version`` is a tuple such
    as ``(2, 8, 0)`` -- TODO confirm against callers).
    """
    skip_by_version = []
    for item in items:
        marker = item.get_closest_marker('redis_version')
        if marker is not None:
            try:
                version = VERSIONS[item.callspec.getparam('server_bin')]
            except (KeyError, ValueError, AttributeError):
                # TODO: throw noisy warning
                pass
            else:
                if version < marker.kwargs['version']:
                    skip_by_version.append(item)
                    item.add_marker(pytest.mark.skip(
                        reason=marker.kwargs['reason']))
        # Runs for every item, including ones whose version lookup failed.
        if 'ssl_proxy' in item.fixturenames:
            item.add_marker(pytest.mark.skipif(
                "not os.path.exists('/usr/bin/socat')",
                reason="socat package required (apt-get install socat)"))
    if len(items) != len(skip_by_version):
        skipped = {id(item) for item in skip_by_version}
        items[:] = [item for item in items if id(item) not in skipped]


def pytest_configure(config):
    """Detect redis-server binaries and register the ``server_bin`` fixture.

    Uses every ``--redis-server`` path given, or the ``redis-server`` found
    on PATH. Also installs the uvloop event-loop policy with ``--uvloop``.
    """
    bins = config.getoption('--redis-server')[:]
    if not bins:
        path = shutil.which('redis-server')
        assert path, (
            "There is no redis-server on your computer."
            " Please install it first")
        REDIS_SERVERS[:] = [path]
    else:
        REDIS_SERVERS[:] = bins

    VERSIONS.update({srv: _read_server_version(srv)
                     for srv in REDIS_SERVERS})
    assert VERSIONS, ("Expected to detect redis versions", REDIS_SERVERS)

    class DynamicFixturePlugin:
        @pytest.fixture(scope='session',
                        params=REDIS_SERVERS,
                        ids=format_version)
        def server_bin(self, request):
            """Common for start_server and start_sentinel
            server bin path parameter.
            """
            return request.param
    config.pluginmanager.register(DynamicFixturePlugin(), 'server-bin-fixture')

    if config.getoption('--uvloop'):
        try:
            import uvloop
        except ImportError as err:
            raise RuntimeError(
                "Can not import uvloop, make sure it is installed") from err
        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())