"""
peewee-async tests
==================

Create tests.json file to configure tests.

The file, when present in the working directory, is loaded by
``setUpModule`` and its content is merged into ``DB_OVERRIDES``: a mapping
of database key (see ``DB_DEFAULTS``) to connection parameters.

"""
import asyncio
import contextlib
import json
import logging
import os
import sys
import unittest
import uuid

import peewee
import peewee_async
import peewee_asyncext

##########
# Config #
##########

# logging.basicConfig(level=logging.DEBUG)

# Default connection parameters, keyed by database "flavour".
DB_DEFAULTS = {
    'postgres': {
        'database': 'test',
        'host': '127.0.0.1',
        # 'port': 5432,
        'user': 'postgres',
    },
    'postgres-ext': {
        'database': 'test',
        'host': '127.0.0.1',
        # 'port': 5432,
        'user': 'postgres',
    },
    'postgres-pool': {
        'database': 'test',
        'host': '127.0.0.1',
        # 'port': 5432,
        'user': 'postgres',
        'max_connections': 4,
    },
    'postgres-pool-ext': {
        'database': 'test',
        'host': '127.0.0.1',
        # 'port': 5432,
        'user': 'postgres',
        'max_connections': 4,
    },
    'mysql': {
        'database': 'test',
        'host': '127.0.0.1',
        'port': 3306,
        'user': 'root',
    },
    'mysql-pool': {
        'database': 'test',
        'host': '127.0.0.1',
        'port': 3306,
        'user': 'root',
    }
}

# Per-key parameter overrides, filled from 'tests.json' by setUpModule().
DB_OVERRIDES = {}

# Database class to instantiate for every key. Entries are removed by
# setUpModule() when the driver is missing or the connection can't be made.
DB_CLASSES = {
    'postgres': peewee_async.PostgresqlDatabase,
    'postgres-ext': peewee_asyncext.PostgresqlExtDatabase,
    'postgres-pool': peewee_async.PooledPostgresqlDatabase,
    'postgres-pool-ext': peewee_asyncext.PooledPostgresqlExtDatabase,
    'mysql': peewee_async.MySQLDatabase,
    'mysql-pool': peewee_async.PooledMySQLDatabase
}

try:
    import aiopg
except ImportError:
    aiopg = None

try:
    import aiomysql
except ImportError:
    aiomysql = None


def _db_params(key):
    """Return a fresh dict of connection parameters for database `key`.

    Defaults from `DB_DEFAULTS` are merged with `DB_OVERRIDES` (overrides
    win). A new dict is built on every call so the module-level tables are
    never mutated by callers.
    """
    params = dict(DB_DEFAULTS.get(key) or {})
    params.update(DB_OVERRIDES.get(key) or {})
    return params


def setUpModule():
    """Load overrides and drop database keys that can't be tested.

    Reads 'tests.json' (optional) into `DB_OVERRIDES`, removes PostgreSQL
    or MySQL entries from `DB_CLASSES` when the matching driver is not
    importable, then tries to connect to every remaining database and
    removes those for which no connection was established.
    """
    try:
        with open('tests.json', 'r') as tests_fp:
            DB_OVERRIDES.update(json.load(tests_fp))
    except FileNotFoundError:
        print("'tests.json' file not found, will use defaults")

    if not aiopg:
        print("aiopg is not installed, ignoring PostgreSQL tests")
        # Iterate over a copy of the keys: entries are popped in the loop.
        for key in list(DB_CLASSES):
            if key.startswith('postgres'):
                DB_CLASSES.pop(key)

    if not aiomysql:
        print("aiomysql is not installed, ignoring MySQL tests")
        for key in list(DB_CLASSES):
            if key.startswith('mysql'):
                DB_CLASSES.pop(key)

    loop = asyncio.new_event_loop()
    try:
        # `all_databases` is a separate dict, so popping from DB_CLASSES
        # while iterating over it is safe.
        all_databases = load_databases(only=None)
        for key, database in all_databases.items():
            loop.run_until_complete(database.connect_async(loop=loop))
            # NOTE: "private" member access
            if database._async_conn is not None:
                loop.run_until_complete(database.close_async())
            else:
                print("Can't setup connection for %s" % key)
                DB_CLASSES.pop(key)
    finally:
        # The probing loop is not used anywhere else, release it.
        loop.close()


def load_managers(*, loop, only):
    """Create a `peewee_async.Manager` for every available database.

    loop -- event loop passed to every Manager
    only -- optional collection of database keys to restrict to; a falsy
            value means "all keys in DB_CLASSES"

    Returns a dict mapping database key to Manager instance.
    """
    managers = {}
    for key, database_class in DB_CLASSES.items():
        if only and key not in only:
            continue
        database = database_class(**_db_params(key))
        managers[key] = peewee_async.Manager(database, loop=loop)
    return managers


def load_databases(*, only):
    """Instantiate every available database.

    only -- optional collection of database keys to restrict to; a falsy
            value means "all keys in DB_CLASSES"

    Returns a dict mapping database key to database instance.
    """
    databases = {}
    for key, database_class in DB_CLASSES.items():
        if only and key not in only:
            continue
        databases[key] = database_class(**_db_params(key))
    return databases


##########
# Models #
##########


class TestModel(peewee.Model):
    """Basic model with a unique text column and free-form data."""

    text = peewee.CharField(max_length=100, unique=True)
    data = peewee.TextField(default='')

    def __str__(self):
        return '<%s id=%s> %s' % (self.__class__.__name__,
                                  self.id, self.text)


class TestModelAlpha(peewee.Model):
    """Root of the Alpha -> Beta -> Gamma foreign key chain."""

    text = peewee.CharField()

    def __str__(self):
        return '<%s id=%s> %s' % (self.__class__.__name__,
                                  self.id, self.text)


class TestModelBeta(peewee.Model):
    """Child of `TestModelAlpha`, reachable as `alpha.betas`."""

    alpha = peewee.ForeignKeyField(TestModelAlpha, backref='betas')
    text = peewee.CharField()

    def __str__(self):
        return '<%s id=%s> %s' % (self.__class__.__name__,
                                  self.id, self.text)


class TestModelGamma(peewee.Model):
    """Child of `TestModelBeta`, reachable as `beta.gammas`."""

    text = peewee.CharField()
    beta = peewee.ForeignKeyField(TestModelBeta, backref='gammas')

    def __str__(self):
        return '<%s id=%s> %s' % (self.__class__.__name__,
                                  self.id, self.text)


class UUIDTestModel(peewee.Model):
    """Model with a UUID primary key generated by `uuid.uuid4`."""

    id = peewee.UUIDField(primary_key=True, default=uuid.uuid4)
    text = peewee.CharField()

    def __str__(self):
        return '<%s id=%s> %s' % (self.__class__.__name__,
                                  self.id, self.text)


class CompositeTestModel(peewee.Model):
    """A simple "through" table for many-to-many relationship."""

    uuid = peewee.ForeignKeyField(UUIDTestModel)
    alpha = peewee.ForeignKeyField(TestModelAlpha)

    class Meta:
        primary_key = peewee.CompositeKey('uuid', 'alpha')


####################
# Base tests class #
####################


class BaseManagerTestCase(unittest.TestCase):
    """Base class for tests running against every available Manager.

    Subclasses may set `only` to a list of database keys to restrict the
    set of databases being tested.
    """

    only = None

    # Ordered so that referenced tables are created before referencing
    # ones; the reversed order is used for deleting and dropping.
    models = [TestModel, UUIDTestModel, TestModelAlpha,
              TestModelBeta, TestModelGamma, CompositeTestModel]

    @classmethod
    @contextlib.contextmanager
    def manager(cls, objects, allow_sync=False):
        """Bind all test models to the database of the `objects` manager.

        objects -- Manager instance
        allow_sync -- when true, the body runs inside
                      `objects.allow_sync()`
        """
        for model in cls.models:
            model._meta.database = objects.database
        if allow_sync:
            with objects.allow_sync():
                yield
        else:
            yield

    def setUp(self):
        """Setup the new event loop, and database configs, reset counter.
        """
        self.run_count = 0
        self.loop = asyncio.new_event_loop()
        self.managers = load_managers(loop=self.loop, only=self.only)

        # Clean up before tests
        for objects in self.managers.values():
            objects.database.set_allow_sync(False)
            with self.manager(objects, allow_sync=True):
                for model in self.models:
                    model.create_table(True)
                for model in reversed(self.models):
                    model.delete().execute()

    def tearDown(self):
        """Check if test was actually passed by counter, clean up.
        """
        self.assertEqual(len(self.managers), self.run_count)

        for objects in self.managers.values():
            self.loop.run_until_complete(objects.close())
        self.loop.close()

        for objects in self.managers.values():
            with self.manager(objects, allow_sync=True):
                for model in reversed(self.models):
                    model.drop_table(fail_silently=True)

        self.managers = None

    def run_with_managers(self, test, exclude=None):
        """Run test coroutine against available Manager instances.

            test -- coroutine with single parameter, Manager instance
            exclude -- exclude list or string for manager key

        Example:

            async def test(objects):
                # ...

            run_with_managers(test, exclude=['mysql', 'mysql-pool'])
        """
        for key, objects in self.managers.items():
            if exclude is None or (key not in exclude):
                with self.manager(objects, allow_sync=False):
                    self.loop.run_until_complete(test(objects))

                with self.manager(objects, allow_sync=True):
                    for model in reversed(self.models):
                        model.delete().execute()

            # Excluded managers are counted as well, tearDown() compares
            # the counter with the total number of managers.
            self.run_count += 1


################
# Common tests #
################


class DatabaseTestCase(unittest.TestCase):
    """Tests for database initialization variants."""

    def test_deferred_init(self):
        """Database created with `None` is deferred until `init()`."""
        for key, database_class in DB_CLASSES.items():
            params = _db_params(key)

            database = database_class(None)
            self.assertTrue(database.deferred)

            database.init(**params)
            self.assertTrue(not database.deferred)

            TestModel._meta.database = database
            TestModel.create_table(True)
            TestModel.drop_table(True)

    def test_proxy_database(self):
        """Manager works with `peewee.Proxy` initialized afterwards."""
        loop = asyncio.new_event_loop()
        database = peewee.Proxy()
        TestModel._meta.database = database
        objects = peewee_async.Manager(database, loop=loop)

        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            await objects.create(TestModel, text=text)
            await objects.get(TestModel, text=text)

        for key, database_class in DB_CLASSES.items():
            database.initialize(database_class(**_db_params(key)))

            TestModel.create_table(True)
            loop.run_until_complete(test(objects))
            loop.run_until_complete(objects.close())
            TestModel.drop_table(True)

        loop.close()


class OlderTestCase(unittest.TestCase):
    """Tests for the module-level `peewee_async` functions API."""

    # only = ['postgres', 'postgres-ext', 'postgres-pool', 'postgres-pool-ext']
    only = None

    models = [TestModel, UUIDTestModel, TestModelAlpha,
              TestModelBeta, TestModelGamma]

    @classmethod
    @contextlib.contextmanager
    def current_database(cls, database, allow_sync=False):
        """Bind all test models to `database`.

        `allow_sync` is accepted but not used here; callers toggle
        `database.set_allow_sync()` themselves.
        """
        for model in cls.models:
            model._meta.database = database
        yield

    @classmethod
    def setUpClass(cls, *args, **kwargs):
        """Configure database managers, create test tables.
        """
        cls.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(cls.loop)

        cls.databases = load_databases(only=cls.only)

        for database in cls.databases.values():
            database.set_allow_sync(True)
            with cls.current_database(database):
                for model in cls.models:
                    model.create_table(True)
            database.set_allow_sync(False)

    @classmethod
    def tearDownClass(cls, *args, **kwargs):
        """Remove all test tables and close connections.
        """
        for database in cls.databases.values():
            cls.loop.run_until_complete(database.close_async())

        cls.loop.close()

        for database in cls.databases.values():
            database.set_allow_sync(True)
            with cls.current_database(database):
                for model in reversed(cls.models):
                    model.drop_table(fail_silently=True)
            database.set_allow_sync(False)

        cls.databases = None

    def setUp(self):
        """Reset all data.
        """
        self.run_count = 0
        for database in self.databases.values():
            with self.current_database(database):
                database.set_allow_sync(True)
                for model in reversed(self.models):
                    model.delete().execute()
                database.set_allow_sync(False)

    def tearDown(self):
        """Check if test was actually passed by counter.
        """
        self.assertEqual(len(self.databases), self.run_count)

    def run_with_databases(self, test, exclude=None):
        """Run test coroutine against available databases.

            test -- coroutine with single parameter, database instance
            exclude -- exclude list or string for database key
        """
        for key, database in self.databases.items():
            if exclude is None or (key not in exclude):
                with self.current_database(database):
                    database.set_allow_sync(False)
                    self.loop.run_until_complete(test(database))

                    database.set_allow_sync(True)
                    for model in reversed(self.models):
                        model.delete().execute()
                    database.set_allow_sync(False)

            # Excluded databases are counted as well, see tearDown().
            self.run_count += 1

    def test_create_obj(self):
        async def test(database):
            text = "Test %s" % uuid.uuid4()
            obj = await peewee_async.create_object(TestModel, text=text)
            self.assertTrue(obj is not None)
            self.assertEqual(obj.text, text)

        self.run_with_databases(test)

    def test_get_and_delete_obj(self):
        async def test(database):
            text = "Test %s" % uuid.uuid4()
            obj1 = await peewee_async.create_object(
                TestModel, text=text)
            obj2 = await peewee_async.get_object(
                TestModel, TestModel.id == obj1.id)

            await peewee_async.delete_object(obj2)

            try:
                obj3 = await peewee_async.get_object(
                    TestModel, TestModel.id == obj1.id)
            except TestModel.DoesNotExist:
                obj3 = None
            self.assertTrue(obj3 is None, "Error, object wasn't deleted")

        self.run_with_databases(test)

    def test_get_and_update_obj(self):
        async def test(database):
            text = "Test %s" % uuid.uuid4()
            obj1 = await peewee_async.create_object(
                TestModel, text=text)

            obj1.text = "Test update object"
            await peewee_async.update_object(obj1)

            obj2 = await peewee_async.get_object(
                TestModel, TestModel.id == obj1.id)
            self.assertEqual(obj2.text, "Test update object")

        self.run_with_databases(test)


class ManagerTestCase(BaseManagerTestCase):
    """Tests for the `peewee_async.Manager` high-level API."""

    # only = ['postgres', 'postgres-ext', 'postgres-pool', 'postgres-pool-ext']
    only = None

    def test_connect_close(self):
        async def get_conn(objects):
            await objects.connect()
            # NOTE: "private" member access
            return objects.database._async_conn

        async def test(objects):
            c1 = await get_conn(objects)
            c2 = await get_conn(objects)
            self.assertEqual(c1, c2)
            self.assertTrue(objects.is_connected)

            await objects.close()
            self.assertTrue(not objects.is_connected)

            # asyncio.wait() no longer accepts a `loop` argument nor bare
            # coroutines on recent Python versions, so wrap them first.
            tasks = [
                asyncio.ensure_future(get_conn(objects), loop=self.loop)
                for _ in range(3)
            ]
            done, not_done = await asyncio.wait(tasks)

            conn = next(iter(done)).result()
            self.assertEqual(len(done), 3)
            self.assertTrue(objects.is_connected)
            self.assertTrue(all(map(lambda t: t.result() == conn, done)))

            await objects.close()
            self.assertTrue(not objects.is_connected)

        self.run_with_managers(test)

    def test_many_requests(self):
        async def test(objects):
            max_connections = getattr(objects.database, 'max_connections', 1)
            text = "Test %s" % uuid.uuid4()
            obj = await objects.create(TestModel, text=text)
            n = 2 * max_connections  # number of requests

            # See test_connect_close() on why requests are wrapped.
            tasks = [
                asyncio.ensure_future(
                    objects.get(TestModel, id=obj.id), loop=self.loop)
                for _ in range(n)
            ]
            done, not_done = await asyncio.wait(tasks)
            self.assertEqual(len(done), n)

        self.run_with_managers(test)

    def test_create_obj(self):
        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            obj = await objects.create(TestModel, text=text)
            self.assertTrue(obj is not None)
            self.assertEqual(obj.text, text)

        self.run_with_managers(test)

    def test_create_or_get(self):
        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            obj1, created1 = await objects.create_or_get(
                TestModel, text=text, data="Data 1")
            obj2, created2 = await objects.create_or_get(
                TestModel, text=text, data="Data 2")

            self.assertTrue(created1)
            self.assertTrue(not created2)
            self.assertEqual(obj1, obj2)
            self.assertEqual(obj1.data, "Data 1")
            self.assertEqual(obj2.data, "Data 1")

        self.run_with_managers(test)

    def test_get_or_create(self):
        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            obj1, created1 = await objects.get_or_create(
                TestModel, text=text, defaults={'data': "Data 1"})
            obj2, created2 = await objects.get_or_create(
                TestModel, text=text, defaults={'data': "Data 2"})

            self.assertTrue(created1)
            self.assertTrue(not created2)
            self.assertEqual(obj1, obj2)
            self.assertEqual(obj1.data, "Data 1")
            self.assertEqual(obj2.data, "Data 1")

        self.run_with_managers(test)

    def test_create_uuid_obj(self):
        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            obj = await objects.create(UUIDTestModel, text=text)
            self.assertEqual(len(str(obj.id)), 36)

        self.run_with_managers(test, exclude=['mysql', 'mysql-pool'])

    def test_get_obj_by_id(self):
        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            obj1 = await objects.create(TestModel, text=text)
            obj2 = await objects.get(TestModel, id=obj1.id)
            self.assertEqual(obj1, obj2)
            self.assertEqual(obj1.id, obj2.id)

        self.run_with_managers(test)

    def test_get_obj_by_uuid(self):
        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            obj1 = await objects.create(UUIDTestModel, text=text)
            obj2 = await objects.get(UUIDTestModel, id=obj1.id)
            self.assertEqual(obj1, obj2)
            self.assertEqual(len(str(obj1.id)), 36)

        self.run_with_managers(test)

    def test_raw_query(self):
        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            await objects.create(TestModel, text=text)

            result1 = await objects.execute(TestModel.raw(
                'select id, text from testmodel'))
            result1 = list(result1)
            self.assertEqual(len(result1), 1)
            self.assertTrue(isinstance(result1[0], TestModel))

            result2 = await objects.execute(TestModel.raw(
                'select id, text from testmodel').tuples())
            result2 = list(result2)
            self.assertEqual(len(result2), 1)
            self.assertTrue(isinstance(result2[0], tuple))

            result3 = await objects.execute(TestModel.raw(
                'select id, text from testmodel').dicts())
            result3 = list(result3)
            self.assertEqual(len(result3), 1)
            self.assertTrue(isinstance(result3[0], dict))

        self.run_with_managers(test)

    def test_select_many_objects(self):
        async def test(objects):
            text = "Test 1"
            obj1 = await objects.create(TestModel, text=text)
            text = "Test 2"
            obj2 = await objects.create(TestModel, text=text)

            select1 = [obj1, obj2]
            len1 = len(select1)

            select2 = await objects.execute(
                TestModel.select().order_by(TestModel.text))
            len2 = len(list(select2))

            self.assertEqual(len1, len2)
            for o1, o2 in zip(select1, select2):
                self.assertEqual(o1, o2)

        self.run_with_managers(test)

    def test_indexing_result(self):
        async def test(objects):
            await objects.create(TestModel, text="Test 1")
            obj = await objects.create(TestModel, text="Test 2")
            result = await objects.execute(
                TestModel.select().order_by(TestModel.text))
            self.assertEqual(obj, result[1])

        self.run_with_managers(test)

    def test_multiple_iterate_over_result(self):
        async def test(objects):
            obj1 = await objects.create(TestModel, text="Test 1")
            obj2 = await objects.create(TestModel, text="Test 2")
            result = await objects.execute(
                TestModel.select().order_by(TestModel.text))
            # The same result must be iterable more than once.
            self.assertEqual(list(result), [obj1, obj2])
            self.assertEqual(list(result), [obj1, obj2])

        self.run_with_managers(test)

    def test_insert_many_rows_query(self):
        async def test(objects):
            select1 = await objects.execute(TestModel.select())
            self.assertEqual(len(select1), 0)

            query = TestModel.insert_many([
                {'text': "Test %s" % uuid.uuid4()},
                {'text': "Test %s" % uuid.uuid4()},
            ])
            last_id = await objects.execute(query)
            self.assertTrue(last_id is not None)

            select2 = await objects.execute(TestModel.select())
            self.assertEqual(len(select2), 2)

        self.run_with_managers(test)

    def test_insert_one_row_query(self):
        async def test(objects):
            query = TestModel.insert(text="Test %s" % uuid.uuid4())
            last_id = await objects.execute(query)
            self.assertTrue(last_id is not None)

            select1 = await objects.execute(TestModel.select())
            self.assertEqual(len(select1), 1)

        self.run_with_managers(test)

    def test_insert_one_row_uuid_query(self):
        async def test(objects):
            query = UUIDTestModel.insert(text="Test %s" % uuid.uuid4())
            last_id = await objects.execute(query)
            self.assertEqual(len(str(last_id)), 36)

        self.run_with_managers(test, exclude=['mysql', 'mysql-pool'])

    def test_update_query(self):
        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            obj1 = await objects.create(TestModel, text=text)

            query = (TestModel.update(text="Test update query")
                     .where(TestModel.id == obj1.id))

            upd1 = await objects.execute(query)
            self.assertEqual(upd1, 1)

            obj2 = await objects.get(TestModel, id=obj1.id)
            self.assertEqual(obj2.text, "Test update query")

        self.run_with_managers(test)

    def test_update_obj(self):
        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            obj1 = await objects.create(TestModel, text=text)

            obj1.text = "Test update object"
            await objects.update(obj1)

            obj2 = await objects.get(TestModel, id=obj1.id)
            self.assertEqual(obj2.text, "Test update object")

        self.run_with_managers(test)

    def test_delete_obj(self):
        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            obj1 = await objects.create(TestModel, text=text)
            obj2 = await objects.get(TestModel, id=obj1.id)

            await objects.delete(obj2)

            try:
                obj3 = await objects.get(TestModel, id=obj1.id)
            except TestModel.DoesNotExist:
                obj3 = None
            self.assertTrue(obj3 is None, "Error, object wasn't deleted")

        self.run_with_managers(test)

    def test_scalar_query(self):
        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            await objects.create(TestModel, text=text)

            text = "Test %s" % uuid.uuid4()
            await objects.create(TestModel, text=text)

            fn = peewee.fn.Count(TestModel.id)
            count = await objects.scalar(TestModel.select(fn))

            self.assertEqual(count, 2)

        self.run_with_managers(test)

    def test_count_query(self):
        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            await objects.create(TestModel, text=text)

            text = "Test %s" % uuid.uuid4()
            await objects.create(TestModel, text=text)

            text = "Test %s" % uuid.uuid4()
            await objects.create(TestModel, text=text)

            count = await objects.count(TestModel.select())
            self.assertEqual(count, 3)

        self.run_with_managers(test)

    def test_count_query_with_limit(self):
        async def test(objects):
            text = "Test %s" % uuid.uuid4()
            await objects.create(TestModel, text=text)

            text = "Test %s" % uuid.uuid4()
            await objects.create(TestModel, text=text)

            text = "Test %s" % uuid.uuid4()
            await objects.create(TestModel, text=text)

            count = await objects.count(TestModel.select().limit(1))
            self.assertEqual(count, 1)

        self.run_with_managers(test)

    def test_prefetch(self):
        async def test(objects):
            alpha_1 = await objects.create(
                TestModelAlpha, text='Alpha 1')
            alpha_2 = await objects.create(
                TestModelAlpha, text='Alpha 2')

            beta_11 = await objects.create(
                TestModelBeta, alpha=alpha_1, text='Beta 11')
            beta_12 = await objects.create(
                TestModelBeta, alpha=alpha_1, text='Beta 12')
            _ = await objects.create(
                TestModelBeta, alpha=alpha_2, text='Beta 21')
            _ = await objects.create(
                TestModelBeta, alpha=alpha_2, text='Beta 22')

            gamma_111 = await objects.create(
                TestModelGamma, beta=beta_11, text='Gamma 111')
            gamma_112 = await objects.create(
                TestModelGamma, beta=beta_11, text='Gamma 112')

            result = await objects.prefetch(
                TestModelAlpha.select(),
                TestModelBeta.select(),
                TestModelGamma.select())

            self.assertEqual(tuple(result),
                             (alpha_1, alpha_2))

            self.assertEqual(tuple(result[0].betas),
                             (beta_11, beta_12))

            self.assertEqual(tuple(result[0].betas[0].gammas),
                             (gamma_111, gamma_112))

        self.run_with_managers(test)

    def test_composite_key(self):
        async def test(objects):
            obj_uuid = await objects.create(UUIDTestModel, text='UUID')
            obj_alpha = await objects.create(TestModelAlpha, text='Alpha')
            comp = await objects.create(CompositeTestModel,
                                        uuid=obj_uuid,
                                        alpha=obj_alpha)
            self.assertEqual((obj_uuid, obj_alpha), (comp.uuid, comp.alpha))

        self.run_with_managers(test)


######################
# Transactions tests #
######################


class FakeUpdateError(Exception):
    """Fake error while updating database.
    """


class ManagerTransactionsTestCase(BaseManagerTestCase):
    """Tests for `Manager.atomic()` transactions."""

    # only = ['postgres', 'postgres-ext', 'postgres-pool', 'postgres-pool-ext']
    only = None

    def test_atomic_success(self):
        """Successful update in transaction.
        """
        async def test(objects):
            obj = await objects.create(TestModel, text='FOO')
            obj_id = obj.id

            async with objects.atomic():
                obj.text = 'BAR'
                await objects.update(obj)

            res = await objects.get(TestModel, id=obj_id)
            self.assertEqual(res.text, 'BAR')

        self.run_with_managers(test)

    def test_atomic_failed(self):
        """Failed update in transaction.
        """
        async def test(objects):
            obj = await objects.create(TestModel, text='FOO')
            obj_id = obj.id

            # Initialized up front so a missing exception is reported as
            # an assertion failure, not as a NameError.
            error = False
            try:
                async with objects.atomic():
                    obj.text = 'BAR'
                    await objects.update(obj)
                    raise FakeUpdateError()
            except FakeUpdateError:
                error = True
                res = await objects.get(TestModel, id=obj_id)

            self.assertTrue(error)
            self.assertEqual(res.text, 'FOO')

        self.run_with_managers(test)

    def test_several_transactions(self):
        """Run several transactions in parallel tasks.
        """
        def wait(coroutines):
            # Wrap coroutines into tasks explicitly: asyncio.wait() does
            # not accept bare coroutines or `loop` on recent Pythons.
            tasks = [self.loop.create_task(coro) for coro in coroutines]
            return self.loop.run_until_complete(asyncio.wait(tasks))

        async def t1(objects):
            async with objects.atomic():
                self.assertEqual(
                    objects.database.transaction_depth_async(), 1)
                await asyncio.sleep(0.25)

        async def t2(objects):
            async with objects.atomic():
                self.assertEqual(
                    objects.database.transaction_depth_async(), 1)
                await asyncio.sleep(0.0625)

        async def t3(objects):
            async with objects.atomic():
                self.assertEqual(
                    objects.database.transaction_depth_async(), 1)
                await asyncio.sleep(0.125)

        for objects in self.managers.values():
            wait([
                t1(objects),
                t2(objects),
                t3(objects),
            ])

            with self.manager(objects, allow_sync=True):
                for model in reversed(self.models):
                    model.delete().execute()

            self.run_count += 1

    def test_atomic_fail_with_disconnect(self):
        """Database gone in transaction.
        """
        async def test(objects):
            error = False
            try:
                async with objects.atomic():
                    await objects.database.close_async()
                    raise FakeUpdateError()
            except FakeUpdateError:
                error = True

            self.assertTrue(error)

        self.run_with_managers(test)