# # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import * import asyncio import atexit import collections import contextlib import decimal import functools import inspect import json import math import os import pprint import re import subprocess import sys import unittest import uuid from datetime import timedelta import click.testing import edgedb from edb import cli from edb.server import cluster as edgedb_cluster from edb.server import defines as edgedb_defines from edb.common import taskgroup from edb.testbase import serutils def get_test_cases(tests): result = collections.OrderedDict() for test in tests: if isinstance(test, unittest.TestSuite): result.update(get_test_cases(test._tests)) else: cls = type(test) try: methods = result[cls] except KeyError: methods = result[cls] = [] methods.append(test) return result class TestCaseMeta(type(unittest.TestCase)): _database_names = set() @staticmethod def _iter_methods(bases, ns): for base in bases: for methname in dir(base): if not methname.startswith('test_'): continue meth = getattr(base, methname) if not inspect.iscoroutinefunction(meth): continue yield methname, meth for methname, meth in ns.items(): if not methname.startswith('test_'): continue if not inspect.iscoroutinefunction(meth): continue yield methname, meth @classmethod def wrap(mcls, meth): 
@functools.wraps(meth) def wrapper(self, *args, __meth__=meth, **kwargs): try_no = 1 while True: try: # There might be unobvious serializability # anomalies across the test suite, so, rather # than hunting them down every time, simply # retry the test. self.loop.run_until_complete( __meth__(self, *args, **kwargs)) except edgedb.TransactionSerializationError: if try_no == 3: raise else: self.loop.run_until_complete(self.con.execute( 'ROLLBACK;' )) try_no += 1 else: break return wrapper @classmethod def add_method(mcls, methname, ns, meth): ns[methname] = mcls.wrap(meth) def __new__(mcls, name, bases, ns): for methname, meth in mcls._iter_methods(bases, ns.copy()): if methname in ns: del ns[methname] mcls.add_method(methname, ns, meth) cls = super().__new__(mcls, name, bases, ns) if not ns.get('BASE_TEST_CLASS') and hasattr(cls, 'get_database_name'): dbname = cls.get_database_name() if name in mcls._database_names: raise TypeError( f'{name} wants duplicate database name: {dbname}') mcls._database_names.add(name) return cls class TestCase(unittest.TestCase, metaclass=TestCaseMeta): @classmethod def setUpClass(cls): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) cls.loop = loop @classmethod def tearDownClass(cls): cls.loop.close() asyncio.set_event_loop(None) def add_fail_notes(self, **kwargs): if not hasattr(self, 'fail_notes'): self.fail_notes = {} self.fail_notes.update(kwargs) @contextlib.contextmanager def annotate(self, **kwargs): # Annotate the test in case the nested block of code fails. 
        try:
            yield
        except Exception:
            # Attach the annotations to the failure report, then re-raise.
            self.add_fail_notes(**kwargs)
            raise

    @contextlib.contextmanager
    def assertRaisesRegex(self, exception, regex, msg=None, **kwargs):
        """Like unittest's assertRaisesRegex, but also checks exception
        attributes passed as extra keyword arguments.
        """
        with super().assertRaisesRegex(exception, regex, msg=msg):
            try:
                yield
            except BaseException as e:
                if isinstance(e, exception):
                    for attr_name, expected_val in kwargs.items():
                        val = getattr(e, attr_name)
                        if val != expected_val:
                            raise self.failureException(
                                f'{exception.__name__} context attribute '
                                f'{attr_name!r} is {val} (expected '
                                f'{expected_val!r})') from e
                # Re-raise so the outer assertRaisesRegex sees it.
                raise


# Module-level singleton managed by _start_cluster()/_shutdown_cluster().
_default_cluster = None


def _init_cluster(data_dir=None, *, cleanup_atexit=True, init_settings=None):
    """Create and start a test cluster.

    With no *data_dir* a temporary cluster is created (and destroyed on
    shutdown); otherwise an existing data directory is used.  Registers
    an atexit shutdown hook when *cleanup_atexit* is true.
    """
    if init_settings is None:
        init_settings = {}
    if (not os.environ.get('EDGEDB_DEBUG_SERVER') and
            not os.environ.get('EDGEDB_LOG_LEVEL')):
        # Keep server output quiet unless debugging was requested.
        _env = {'EDGEDB_LOG_LEVEL': 'silent'}
    else:
        _env = {}

    if data_dir is None:
        cluster = edgedb_cluster.TempCluster(env=_env, testmode=True)
        destroy = True
    else:
        cluster = edgedb_cluster.Cluster(data_dir=data_dir, env=_env)
        destroy = False

    if cluster.get_status() == 'not-initialized':
        cluster.init(server_settings=init_settings)

    cluster.start(port='dynamic')
    cluster.set_superuser_password('test')

    if cleanup_atexit:
        atexit.register(_shutdown_cluster, cluster, destroy=destroy)

    return cluster


def _start_cluster(*, cleanup_atexit=True):
    """Return the shared test cluster, starting it on first use.

    Honours EDGEDB_TEST_CLUSTER_ADDR (connect to a running cluster) and
    EDGEDB_TEST_DATA_DIR (use an existing data directory).
    """
    global _default_cluster

    if _default_cluster is None:
        cluster_addr = os.environ.get('EDGEDB_TEST_CLUSTER_ADDR')
        if cluster_addr:
            conn_spec = json.loads(cluster_addr)
            _default_cluster = edgedb_cluster.RunningCluster(**conn_spec)
        else:
            data_dir = os.environ.get('EDGEDB_TEST_DATA_DIR')
            _default_cluster = _init_cluster(
                data_dir=data_dir, cleanup_atexit=cleanup_atexit)

    return _default_cluster


def _shutdown_cluster(cluster, *, destroy=True):
    """Stop *cluster* (and optionally destroy its data directory)."""
    global _default_cluster
    _default_cluster = None
    if cluster is not None:
        cluster.stop()
        if destroy:
            cluster.destroy()


class ClusterTestCase(TestCase):
    """Test case that ensures the shared test cluster is running."""

    BASE_TEST_CLASS = True

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.cluster = _start_cluster(cleanup_atexit=True)
    @classmethod
    def get_connect_args(cls, *,
                         cluster=None,
                         database=edgedb_defines.EDGEDB_SUPERUSER_DB,
                         user=edgedb_defines.EDGEDB_SUPERUSER,
                         password='test'):
        """Return connection arguments for *cluster* (default: class cluster),
        with user/password/database overrides applied.
        """
        if cluster is None:
            cluster = cls.cluster
        conargs = cluster.get_connect_args().copy()
        conargs.update(dict(user=user,
                            password=password,
                            database=database))
        return conargs


class RollbackChanges:
    """Async context manager that wraps a block in a transaction and
    always rolls it back on exit (even on success)."""

    def __init__(self, test):
        self._conn = test.con

    async def __aenter__(self):
        self._tx = self._conn.transaction()
        await self._tx.start()

    async def __aexit__(self, exc_type, exc, tb):
        await self._tx.rollback()


class ConnectedTestCaseMixin:

    @classmethod
    async def connect(cls, *,
                      cluster=None,
                      database=edgedb_defines.EDGEDB_SUPERUSER_DB,
                      user=edgedb_defines.EDGEDB_SUPERUSER,
                      password='test'):
        """Open a new async connection to the test cluster."""
        conargs = cls.get_connect_args(
            cluster=cluster, database=database, user=user, password=password)
        return await edgedb.async_connect(**conargs)

    def repl(self):
        """Open interactive EdgeQL REPL right in the test.

        This is obviously only for debugging purposes.  Just add
        `self.repl()` at any point in your test.
        """
""" conargs = self.get_connect_args() cmd = [ # TODO: switch to 'edgedb' when it understands EDGEDB_PASSWORD 'python', '-m', 'edb.cli', '--host', conargs['host'], '--port', str(conargs['port']), '--database', self.con.dbname, '--user', conargs['user'], ] env = os.environ.copy() if password := conargs.get('password'): env['EDGEDB_PASSWORD'] = password proc = subprocess.Popen( cmd, stdin=sys.stdin, stdout=sys.stdout, env=env) while proc.returncode is None: try: proc.wait() except KeyboardInterrupt: pass def _run_and_rollback(self): return RollbackChanges(self) async def assert_query_result(self, query, exp_result_json, exp_result_binary=..., *, msg=None, sort=None, variables=None): fetch_args = variables if isinstance(variables, tuple) else () fetch_kw = variables if isinstance(variables, dict) else {} try: tx = self.con.transaction() await tx.start() try: res = await self.con.fetchall_json(query, *fetch_args, **fetch_kw) finally: await tx.rollback() res = json.loads(res) if sort is not None: self._sort_results(res, sort) self._assert_data_shape(res, exp_result_json, message=msg) except Exception: self.add_fail_notes(serialization='json') raise if exp_result_binary is ...: # The expected result is the same exp_result_binary = exp_result_json try: res = await self.con.fetchall(query, *fetch_args, **fetch_kw) res = serutils.serialize(res) if sort is not None: self._sort_results(res, sort) self._assert_data_shape(res, exp_result_binary, message=msg) except Exception: self.add_fail_notes(serialization='binary') raise def _sort_results(self, results, sort): if sort is True: sort = lambda x: x # don't bother sorting empty things if results: # sort can be either a key function or a dict if isinstance(sort, dict): # the keys in the dict indicate the fields that # actually must be sorted for key, val in sort.items(): # '.' 
is a special key referring to the base object if key == '.': self._sort_results(results, val) else: if isinstance(results, list): for r in results: self._sort_results(r[key], val) else: self._sort_results(results[key], val) else: results.sort(key=sort) def _assert_data_shape(self, data, shape, message=None): _void = object() def _format_path(path): if path: return 'PATH: ' + ''.join(str(p) for p in path) else: return 'PATH: <top-level>' def _assert_type_shape(path, data, shape): if shape in (int, float): if not isinstance(data, shape): self.fail( f'{message}: expected {shape}, got {data!r} ' f'{_format_path(path)}') else: try: shape(data) except (ValueError, TypeError): self.fail( f'{message}: expected {shape}, got {data!r} ' f'{_format_path(path)}') def _assert_dict_shape(path, data, shape): for sk, sv in shape.items(): if not data or sk not in data: self.fail( f'{message}: key {sk!r} ' f'is missing\n{pprint.pformat(data)} ' f'{_format_path(path)}') _assert_generic_shape(path + (f'["{sk}"]',), data[sk], sv) def _list_shape_iter(shape): last_shape = _void for item in shape: if item is Ellipsis: if last_shape is _void: raise ValueError( 'invalid shape spec: Ellipsis cannot be the' 'first element') while True: yield last_shape last_shape = item yield item def _assert_list_shape(path, data, shape): if not isinstance(data, list): self.fail( f'{message}: expected list ' f'{_format_path(path)}') if not data and shape: self.fail( f'{message}: expected non-empty list ' f'{_format_path(path)}') shape_iter = _list_shape_iter(shape) _data_count = 0 for _data_count, el in enumerate(data): try: el_shape = next(shape_iter) except StopIteration: self.fail( f'{message}: unexpected trailing elements in list ' f'{_format_path(path)}') _assert_generic_shape( path + (f'[{_data_count}]',), el, el_shape) if len(shape) > _data_count + 1: if shape[_data_count + 1] is not Ellipsis: self.fail( f'{message}: expecting more elements in list ' f'{_format_path(path)}') def 
            if not isinstance(data, (list, set)):
                self.fail(
                    f'{message}: expected list or set '
                    f'{_format_path(path)}')

            if not data and shape:
                self.fail(
                    f'{message}: expected non-empty set '
                    f'{_format_path(path)}')

            # Sets are order-insensitive: compare both sides sorted.
            shape_iter = _list_shape_iter(sorted(shape))

            _data_count = 0
            for _data_count, el in enumerate(sorted(data)):
                try:
                    el_shape = next(shape_iter)
                except StopIteration:
                    self.fail(
                        f'{message}: unexpected trailing elements in set '
                        f'[path {_format_path(path)}]')

                _assert_generic_shape(
                    path + (f'{{{_data_count}}}',), el, el_shape)

            if len(shape) > _data_count + 1:
                if Ellipsis not in shape:
                    self.fail(
                        f'{message}: expecting more elements in set '
                        f'{_format_path(path)}')

        def _assert_generic_shape(path, data, shape):
            # Dispatch on the type of the expected shape element.
            if isinstance(shape, nullable):
                if data is None:
                    return
                else:
                    shape = shape.value

            if isinstance(shape, list):
                return _assert_list_shape(path, data, shape)
            elif isinstance(shape, set):
                return _assert_set_shape(path, data, shape)
            elif isinstance(shape, dict):
                return _assert_dict_shape(path, data, shape)
            elif isinstance(shape, type):
                return _assert_type_shape(path, data, shape)
            elif isinstance(shape, float):
                if not math.isclose(data, shape, rel_tol=1e-04):
                    self.fail(
                        f'{message}: not isclose({data}, {shape}) '
                        f'{_format_path(path)}')
            elif isinstance(shape, uuid.UUID):
                # since the data comes from JSON, it will only have a str
                if data != str(shape):
                    self.fail(
                        f'{message}: {data!r} != {shape!r} '
                        f'{_format_path(path)}')
            elif isinstance(shape, (str, int, timedelta, decimal.Decimal)):
                if data != shape:
                    self.fail(
                        f'{message}: {data!r} != {shape!r} '
                        f'{_format_path(path)}')
            elif shape is None:
                if data is not None:
                    self.fail(
                        f'{message}: {data!r} is expected to be None '
                        f'{_format_path(path)}')
            else:
                raise ValueError(f'unsupported shape type {shape}')

        message = message or 'data shape differs'
        return _assert_generic_shape((), data, shape)


class OldCLITestCaseMixin:
    # Drives the Python-based CLI via click's test runner.

    def run_cli(self, *args, input: Optional[str] = None):
conn_args = self.get_connect_args() cmd_args = ( '--host', conn_args['host'], '--port', conn_args['port'], '--user', conn_args['user'], ) + args if conn_args['password']: cmd_args = ('--password-from-stdin',) + cmd_args if input is not None: input = f"{conn_args['password']}\n{input}" else: input = f"{conn_args['password']}\n" runner = click.testing.CliRunner() return runner.invoke( cli.cli, args=cmd_args, input=input, catch_exceptions=False) class CLITestCaseMixin: def run_cli(self, *args, input: Optional[str]=None): conn_args = self.get_connect_args() cmd_args = ( '--host', conn_args['host'], '--port', str(conn_args['port']), '--user', conn_args['user'], ) + args if conn_args['password']: cmd_args = ('--password-from-stdin',) + cmd_args if input is not None: input = f"{conn_args['password']}\n{input}" else: input = f"{conn_args['password']}\n" subprocess.run( ('edgedb',) + cmd_args, input=input.encode() if input else None, check=True, capture_output=True, ) class ConnectedTestCase(ClusterTestCase, ConnectedTestCaseMixin): BASE_TEST_CLASS = True @classmethod def setUpClass(cls): super().setUpClass() cls.con = cls.loop.run_until_complete(cls.connect()) @classmethod def tearDownClass(cls): try: cls.loop.run_until_complete(cls.con.aclose()) # Give event loop another iteration so that connection # transport has a chance to properly close. cls.loop.run_until_complete(asyncio.sleep(0)) cls.con = None finally: super().tearDownClass() class DatabaseTestCase(ClusterTestCase, ConnectedTestCaseMixin): SETUP = None TEARDOWN = None SCHEMA = None SETUP_METHOD = None TEARDOWN_METHOD = None # Some tests may want to manage transactions manually, # in which case ISOLATED_METHODS will be False. ISOLATED_METHODS = True # Turns on "EdgeDB developer" mode which allows using restricted # syntax like USING SQL and similar. It allows modifying standard # library (e.g. declaring casts). 
    INTERNAL_TESTMODE = True

    BASE_TEST_CLASS = True

    def setUp(self):
        if self.INTERNAL_TESTMODE:
            self.loop.run_until_complete(
                self.con.execute(
                    'CONFIGURE SESSION SET __internal_testmode := true;'))

        if self.ISOLATED_METHODS:
            # Wrap the whole test method in a transaction that tearDown
            # rolls back, isolating tests from one another.
            self.xact = self.con.transaction()
            self.loop.run_until_complete(self.xact.start())

        if self.SETUP_METHOD:
            self.loop.run_until_complete(
                self.con.execute(self.SETUP_METHOD))

        super().setUp()

    def tearDown(self):
        try:
            if self.TEARDOWN_METHOD:
                self.loop.run_until_complete(
                    self.con.execute(self.TEARDOWN_METHOD))
        finally:
            try:
                if self.ISOLATED_METHODS:
                    self.loop.run_until_complete(self.xact.rollback())
                    del self.xact

                if self.con.is_in_transaction():
                    # A test that manages transactions manually left one
                    # open: clean up, then flag it as a test bug.
                    self.loop.run_until_complete(
                        self.con.execute('ROLLBACK'))
                    raise AssertionError(
                        'test connection is still in transaction '
                        '*after* the test')

                if not self.ISOLATED_METHODS:
                    # No transaction wrapper: reset session aliases that
                    # the test may have set.
                    self.loop.run_until_complete(
                        self.con.execute('RESET ALIAS *;'))

            finally:
                super().tearDown()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        dbname = cls.get_database_name()

        cls.admin_conn = None
        cls.con = None

        # If set, the databases were pre-created by the test runner and
        # per-class CREATE DATABASE / setup must be skipped.
        class_set_up = os.environ.get('EDGEDB_TEST_CASES_SET_UP')

        # Only open an extra admin connection if necessary.
        if not class_set_up:
            script = f'CREATE DATABASE {dbname};'
            cls.admin_conn = cls.loop.run_until_complete(cls.connect())
            cls.loop.run_until_complete(cls.admin_conn.execute(script))

        cls.con = cls.loop.run_until_complete(cls.connect(database=dbname))

        if not class_set_up:
            script = cls.get_setup_script()
            if script:
                cls.loop.run_until_complete(cls.con.execute(script))

    @classmethod
    def get_database_name(cls):
        """Derive the test database name from the class name
        (strips a leading 'TestEdgeQL' or 'Test' prefix, lowercases).
        """
        if cls.__name__.startswith('TestEdgeQL'):
            dbname = cls.__name__[len('TestEdgeQL'):]
        elif cls.__name__.startswith('Test'):
            dbname = cls.__name__[len('Test'):]
        else:
            dbname = cls.__name__

        return dbname.lower()

    @classmethod
    def get_setup_script(cls):
        """Build the one-shot EdgeQL script that creates the schema
        modules (from SCHEMA* attributes) and runs SETUP.
        """
        script = ''

        # allow the setup script to also run in test mode
        if cls.INTERNAL_TESTMODE:
            script += '\nCONFIGURE SESSION SET __internal_testmode := true;'

        # Look at all SCHEMA entries and potentially create multiple
        # modules, but always create the 'test' module.
        schema = ['\nmodule test {}']
        for name in dir(cls):
            m = re.match(r'^SCHEMA(?:_(\w+))?', name)
            if m:
                # SCHEMA -> module 'test'; SCHEMA_FOO__BAR -> module
                # 'foo.bar' (double underscore maps to a dot).
                module_name = (m.group(1) or 'test').lower().replace(
                    '__', '.')
                schema_fn = getattr(cls, name)
                if schema_fn is not None:
                    with open(schema_fn, 'r') as sf:
                        module = sf.read()

                    schema.append(f'\nmodule {module_name} {{ {module} }}')

        script += f'\nSTART MIGRATION'
        script += f' TO {{ {"".join(schema)} }};'
        script += f'\nPOPULATE MIGRATION;'
        script += f'\nCOMMIT MIGRATION;'

        if cls.SETUP:
            if not isinstance(cls.SETUP, (list, tuple)):
                scripts = [cls.SETUP]
            else:
                scripts = cls.SETUP

            for scr in scripts:
                # A single-line SETUP value that names an existing file
                # is read from disk; anything else is inline EdgeQL.
                if '\n' not in scr and os.path.exists(scr):
                    with open(scr, 'rt') as f:
                        setup = f.read()
                else:
                    setup = scr

                script += '\n' + setup

        # turn the test mode off again so that the session the tests
        # run in starts clean
        if cls.INTERNAL_TESTMODE:
            script += '\nCONFIGURE SESSION SET __internal_testmode := false;'

        return script.strip(' \n')

    @classmethod
    def tearDownClass(cls):
        script = ''

        class_set_up = os.environ.get('EDGEDB_TEST_CASES_SET_UP')

        if cls.TEARDOWN and not class_set_up:
            script = cls.TEARDOWN.strip()

        try:
            if script:
                cls.loop.run_until_complete(
                    cls.con.execute(script))
        finally:
            # Close connections and drop the class database regardless of
            # whether the TEARDOWN script succeeded; each cleanup step is
            # guarded so later steps still run.
            try:
                cls.loop.run_until_complete(cls.con.aclose())

                if not class_set_up:
                    dbname = cls.get_database_name()
                    script = f'DROP DATABASE {dbname};'

                    cls.loop.run_until_complete(
                        cls.admin_conn.execute(script))

            finally:
                try:
                    if cls.admin_conn is not None:
                        cls.loop.run_until_complete(
                            cls.admin_conn.aclose())
                finally:
                    super().tearDownClass()

    @contextlib.asynccontextmanager
    async def assertRaisesRegexTx(self, exception, regex, msg=None, **kwargs):
        """A version of assertRaisesRegex with automatic transaction recovery
        """

        with super().assertRaisesRegex(exception, regex, msg=msg):
            try:
                # NOTE(review): if transaction() itself raised, `tx` would
                # be unbound in the finally clause — confirm this cannot
                # happen in practice.
                tx = self.con.transaction()
                await tx.start()
                yield
            except BaseException as e:
                if isinstance(e, exception):
                    for attr_name, expected_val in kwargs.items():
                        val = getattr(e, attr_name)
                        if val != expected_val:
                            raise self.failureException(
                                f'{exception.__name__} context attribute '
                                f'{attr_name!r} is {val} (expected '
                                f'{expected_val!r})') from e
                raise
            finally:
                # Always roll back so the connection is usable afterwards.
                await tx.rollback()

    async def migrate(self, migration, *, module: str = 'test'):
        """Apply *migration* (SDL body) to *module* in one transaction."""
        async with self.con.transaction():
            await self.con.execute(f"""
                START MIGRATION TO {{
                    module {module} {{
                        {migration}
                    }}
                }};
                POPULATE MIGRATION;
                COMMIT MIGRATION;
            """)


class nullable:
    """Shape-spec wrapper marking a value that may also be None."""

    def __init__(self, value):
        self.value = value


class Error:
    """Expected-error descriptor (class, message regex, and shape)."""

    def __init__(self, cls, message, shape):
        self._message = message
        self._class = cls
        self._shape = shape

    @property
    def message(self):
        return self._message

    @property
    def cls(self):
        return self._class

    @property
    def shape(self):
        return self._shape


class BaseQueryTestCase(DatabaseTestCase):

    BASE_TEST_CLASS = True


class DDLTestCase(BaseQueryTestCase):
    # DDL test cases generally need to be serialized
    # to avoid deadlocks in parallel execution.
    SERIALIZED = True


class NonIsolatedDDLTestCase(DDLTestCase):
    # Tests manage transactions themselves; no per-test wrapper xact.
    ISOLATED_METHODS = False

    BASE_TEST_CLASS = True


class QueryTestCase(BaseQueryTestCase):

    BASE_TEST_CLASS = True


def get_test_cases_setup(cases):
    """Return [(case, dbname, setup_script)] for cases that need a DB."""
    result = []

    for case in cases:
        if not hasattr(case, 'get_setup_script'):
            continue

        setup_script = case.get_setup_script()
        if not setup_script:
            continue

        dbname = case.get_database_name()
        result.append((case, dbname, setup_script))

    return result


def setup_test_cases(cases, conn, num_jobs, verbose=False):
    """Pre-create and populate the databases for *cases*.

    Runs setup sequentially for --jobs=1, otherwise concurrently with at
    most *num_jobs* bootstrap tasks in flight.
    """
    setup = get_test_cases_setup(cases)

    async def _run():
        if num_jobs == 1:
            # Special case for --jobs=1
            for _case, dbname, setup_script in setup:
                await _setup_database(dbname, setup_script, conn)
                if verbose:
                    print(f' -> {dbname}: OK', flush=True)
        else:
            async with taskgroup.TaskGroup(name='setup test cases') as g:
                # Use a semaphore to limit the concurrency of bootstrap
                # tasks to the number of jobs (bootstrap is heavy, having
                # more tasks than `--jobs` won't necessarily make
                # things faster.)
                sem = asyncio.BoundedSemaphore(num_jobs)

                async def controller(coro, dbname, *args):
                    async with sem:
                        await coro(dbname, *args)
                        if verbose:
                            print(f' -> {dbname}: OK', flush=True)

                for _case, dbname, setup_script in setup:
                    g.create_task(controller(
                        _setup_database, dbname, setup_script, conn))

    return asyncio.run(_run())


async def _setup_database(dbname, setup_script, conn_args):
    """Create database *dbname* and run its setup script inside a
    transaction.  Returns the database name on success.
    """
    default_args = {
        'user': edgedb_defines.EDGEDB_SUPERUSER,
        'password': 'test',
    }

    default_args.update(conn_args)

    admin_conn = await edgedb.async_connect(
        database=edgedb_defines.EDGEDB_SUPERUSER_DB,
        **default_args)

    try:
        await admin_conn.execute(f'CREATE DATABASE {dbname};')
    finally:
        await admin_conn.aclose()

    dbconn = await edgedb.async_connect(database=dbname, **default_args)
    try:
        async with dbconn.transaction():
            await dbconn.execute(setup_script)
    except Exception as ex:
        # Wrap so the failing database is identifiable in the output.
        raise RuntimeError(
            f'exception during initialization of {dbname!r} test DB: {ex}'
        ) from ex
    finally:
        await dbconn.aclose()

    return dbname
# Per-process sequence number backing gen_lock_key().
_lock_cnt = 0


def gen_lock_key():
    """Produce a fresh integer key for advisory locking.

    Combines the current process ID with a monotonically increasing
    per-process counter, so repeated calls within one process yield
    distinct values.
    """
    global _lock_cnt
    _lock_cnt += 1
    return _lock_cnt + os.getpid() * 1000