# -*- coding: utf-8 -*-
"""Integration tests for pg_simple against a live PostgreSQL database.

The database is selected by ``TEST_DB_DSN``. When ``--interact`` appears on
the command line, ``test_basic_functions`` and ``test_connection_auto_commit``
open an interactive console instead of running their checks.
"""
__author__ = 'Masroor Ehsan'

import string
import threading
import time
import unittest

import pg_simple

TEST_DB_DSN = 'dbname=pg_simple user=masroor'


class AbstractPgSimpleTestCase(unittest.TestCase):
    """Database test cases shared by every connection-pool implementation.

    Concrete subclasses choose the pool class by overriding
    ``_get_pool_manager``; this base class itself runs no tests.
    """

    def __init__(self, *args, **kwargs):
        super(AbstractPgSimpleTestCase, self).__init__(*args, **kwargs)
        # Kludge alert: the test loader also collects this abstract class,
        # whose tests cannot work without a pool manager. Instances of the
        # abstract class get a do-nothing `run'; subclass instances get
        # TestCase.run explicitly bound back.
        if type(self) is not AbstractPgSimpleTestCase:
            self.run = unittest.TestCase.run.__get__(self, self.__class__)
        else:
            # The stub is stored on the instance, so it is NOT bound: it must
            # not declare `self', or calling run() with no arguments raises.
            self.run = lambda *args, **kwargs: None

    def setUp(self):
        super(AbstractPgSimpleTestCase, self).setUp()
        self.pool = pg_simple.config_pool(max_conn=25,
                                          expiration=5,
                                          pool_manager=self._get_pool_manager(),
                                          dsn=TEST_DB_DSN)
        # (table name, column definitions) pairs in creation order:
        # pg_t2 references pg_t1, so pg_t1 has to be created first.
        self.tables = (('pg_t1', '''id SERIAL PRIMARY KEY,
                                    name TEXT NOT NULL,
                                    count INTEGER NOT NULL DEFAULT 0,
                                    active BOOLEAN NOT NULL DEFAULT true'''),
                       ('pg_t2', '''id SERIAL PRIMARY KEY,
                                    value TEXT NOT NULL,
                                    pg_t1_id INTEGER NOT NULL REFERENCES pg_t1(id)'''))

    def _get_pool_manager(self):
        """Return the pool class to pass to ``pg_simple.config_pool``."""
        raise NotImplementedError()

    def _drop_tables(self, db):
        """Drop both test tables, the referencing table first."""
        # pg_t2 holds the foreign key, so it goes first; CASCADE on pg_t1 is
        # kept as a safety net for any other objects depending on it.
        db.drop('pg_t2')
        db.drop('pg_t1', True)

    def _truncate_tables(self, db):
        """Empty both test tables and reset their SERIAL sequences."""
        db.truncate('pg_t2', restart_identity=True)
        db.truncate('pg_t1', restart_identity=True, cascade=True)

    def _populate_tables(self, db):
        """Insert 26 rows into pg_t1, each with one child row in pg_t2.

        pg_t1 names are 'aaaaa' .. 'zzzzz'; pg_t2 values are 'aaaa' .. 'zzzz'.
        """
        for letter in string.ascii_lowercase:
            parent_id = db.insert('pg_t1', {'name': letter * 5}, returning='id')
            db.insert('pg_t2', {'value': letter * 4, 'pg_t1_id': parent_id})

    def _create_tables(self, db, fill=False):
        """Create the test tables, optionally populating them."""
        for (name, schema) in self.tables:
            db.create(name, schema)
        if fill:
            self._populate_tables(db)

    def test_basic_functions(self):
        import code
        import doctest
        import sys

        with pg_simple.PgSimple(self.pool) as db:
            if '--interact' in sys.argv:
                db.log = sys.stdout
                code.interact(local=locals())
                return
            try:
                # Setup tables
                self._drop_tables(db)
                self._create_tables(db, fill=True)
                # testmod() only prints failures; its result has to be
                # checked, otherwise a failing doctest still passes here.
                results = doctest.testmod(optionflags=doctest.ELLIPSIS)
            finally:
                # Drop tables
                self._drop_tables(db)

        self.assertEqual(results.failed, 0,
                         '%d doctest example(s) failed' % results.failed)

    def _check_table(self, db, table_name):
        """Assert that ``table_name`` exists in the ``public`` schema."""
        record = db.fetchone('pg_tables',
                             fields=['tablename', ],
                             where=('schemaname=%s AND tablename=%s', ['public', table_name]))
        self.assertIsNotNone(record, 'Table must exist, but was not found. Auto-commit fail.')
        self.assertEqual(record.tablename, table_name)

    def test_connection_auto_commit(self):
        import code
        import sys

        with pg_simple.PgSimple(self.pool) as db:
            if '--interact' in sys.argv:
                db.log = sys.stdout
                code.interact(local=locals())
                # No tables were created, so there is nothing to verify.
                return
            self._drop_tables(db)
            self._create_tables(db, fill=True)

        # From a second PgSimple session, verify that the tables created
        # above were committed when the previous `with' block exited.
        with pg_simple.PgSimple(self.pool) as db:
            try:
                self._check_table(db, 'pg_t1')
            finally:
                self._drop_tables(db)


class PgSimpleTestCase(AbstractPgSimpleTestCase):
    """Runs the shared test cases with ``pg_simple.SimpleConnectionPool``."""

    def _get_pool_manager(self):
        return pg_simple.SimpleConnectionPool


class PgSimpleThread(threading.Thread):
    """Worker thread that re-populates the test tables through a shared pool.

    ``test_cls`` is the running test case instance; its ``pool`` and table
    helpers are used. An exception raised inside a thread never reaches the
    test runner, so it is stored in ``error`` for the main thread to check.
    """

    def __init__(self, thread_id, name, counter, test_cls):
        super(PgSimpleThread, self).__init__(name=name)
        self.thread_id = thread_id
        self.counter = counter
        self.test_cls = test_cls
        self.error = None

    def run(self):
        print('Starting %s' % self.name)
        try:
            self.database_operations()
        except Exception as exc:
            self.error = exc
            raise  # still let threading print the traceback for diagnosis
        print('Exiting %s' % self.name)

    def database_operations(self):
        with pg_simple.PgSimple(self.test_cls.pool) as db:
            self.test_cls._check_table(db, 'pg_t1')
            self.test_cls._truncate_tables(db)
            self.test_cls._populate_tables(db)
            # Keeps this connection checked out of the pool for a while.
            # NOTE(review): the TRUNCATE locks are presumably held until the
            # `with' block exits as well, so workers largely serialize here.
            time.sleep(1)


class PgSimpleThreadedTestCase(AbstractPgSimpleTestCase):
    """Runs the shared test cases, plus a multi-threaded one, with
    ``pg_simple.ThreadedConnectionPool``."""

    def _get_pool_manager(self):
        return pg_simple.ThreadedConnectionPool

    def test_threaded_connections(self):
        with pg_simple.PgSimple(self.pool) as db:
            self._drop_tables(db)
            self._create_tables(db, fill=True)

        threads = [PgSimpleThread(i, 'thread-' + str(i), i, self)
                   for i in range(20)]
        started = []
        try:
            for t in threads:
                t.start()
                started.append(t)
        finally:
            # Wait for every thread that actually started before dropping
            # the tables they operate on.
            for t in started:
                t.join()
            with pg_simple.PgSimple(self.pool) as db:
                self._drop_tables(db)

        failed = [t for t in threads if t.error is not None]
        if failed:
            self.fail('%d of %d worker threads failed; first error in %s: %r'
                      % (len(failed), len(threads), failed[0].name, failed[0].error))
        print("Exiting Main Thread \n")


if __name__ == '__main__':
    unittest.main()