# encoding: utf-8
"""Tormysql failure-handling tests run through a local TCP proxy.

A small sevent-based TCP proxy sits between the client and the database server
so each test can sever the client<->server link on demand and check that
tormysql surfaces the failure (``OperationalError`` / ``ConnectionNotFoundError``).
"""

import os
import socket
import threading
import time

import sevent
from pymysql import OperationalError
from tornado import gen
from tornado.testing import gen_test
from tormysql import Connection, ConnectionPool
from tormysql.pool import ConnectionNotFoundError

from tests import BaseTestCase


class Request(object):
    """One proxied session: relays bytes between a client socket and the upstream server.

    The proxy's connection callback registers instances in
    ``TestThroughProxy.proxys``; an instance removes itself when either side closes.
    """

    def __init__(self, conn, host, port):
        # conn: the accepted client-side socket.
        # pconn: a new outgoing socket to the real server at (host, port).
        self.conn = conn
        self.pconn = sevent.tcp.Socket()
        # Client data that arrived before the upstream connect completed.
        self.buffer = None
        self.connected = False

        self.conn.on("data", self.on_data)
        self.conn.on("close", self.on_close)

        self.pconn.on("connect", self.on_pconnect)
        self.pconn.on("data", self.on_pdata)
        self.pconn.on("close", self.on_pclose)
        self.pconn.connect((host, int(port)))

    def on_data(self, conn, data):
        """Forward client data upstream, or hold it until the upstream is connected."""
        if self.connected:
            self.pconn.write(data)
        else:
            # NOTE(review): a later pre-connect chunk replaces an earlier unsent
            # one. This presumably relies on the client not sending before the
            # server has replied; confirm before reusing this proxy elsewhere.
            self.buffer = data

    def on_pdata(self, conn, data):
        """Forward upstream data back to the client."""
        self.conn.write(data)

    def on_close(self, conn):
        """Client side closed: end the upstream side and unregister."""
        self.pconn.end()
        self._unregister()

    def on_pclose(self, conn):
        """Upstream side closed: end the client side and unregister."""
        self.conn.end()
        self._unregister()

    def on_pconnect(self, conn):
        """Upstream connected: flush any data buffered from the client."""
        self.connected = True
        if self.buffer:
            self.pconn.write(self.buffer)
            self.buffer = None

    def _unregister(self):
        """Remove this session from the shared registry, tolerating double removal."""
        # Both close handlers call this, so the second call finds it already gone.
        try:
            TestThroughProxy.proxys.remove(self)
        except ValueError:
            pass


class TestThroughProxy(BaseTestCase):
    """Connection-failure tests that route tormysql traffic through a local proxy."""

    # Live proxy sessions. The server's connection callback appends to this
    # list, and Request removes itself when either side closes.
    proxys = []

    def setUp(self):
        # NOTE(review): super(BaseTestCase, ...) bypasses BaseTestCase.setUp
        # itself (tearDown does the same). This looks deliberate; confirm.
        super(BaseTestCase, self).setUp()
        # Per-test copy so init_proxy's rewrites don't leak into the class-level PARAMS.
        self.PARAMS = dict(self.PARAMS)
        # The real server address that the proxy forwards to.
        self.host, self.port = self.PARAMS['host'], self.PARAMS['port']

    def init_proxy(self):
        """Start a proxy on a free loopback port and point ``self.PARAMS`` at it."""
        # Ask the OS for a free port, then release it so the proxy can bind it.
        # Another process could grab it in between; acceptable for a test.
        s = socket.socket()
        try:
            s.bind(('127.0.0.1', 0))
            _, self.pport = s.getsockname()
        finally:
            s.close()

        def on_connect(server, conn):
            TestThroughProxy.proxys.append(Request(conn, self.host, self.port))

        self.proxy_server = sevent.tcp.Server()
        self.proxy_server.on("connection", on_connect)
        # Listen on loopback only. This matches the probed address and the host
        # the client is pointed at below, and does not expose the proxied
        # database on other interfaces.
        self.proxy_server.listen(("127.0.0.1", self.pport))

        self.PARAMS['port'] = self.pport
        self.PARAMS['host'] = '127.0.0.1'

    def _close_proxy_sessions(self):
        """End the client side of every live proxy session, severing the links."""
        # Iterate a snapshot. Close handlers remove entries from ``proxys``
        # (and the sevent thread may append) while this loop runs.
        for request in list(TestThroughProxy.proxys):
            request.conn.end()

    def tearDown(self):
        # Best-effort cleanup that must not mask the test's own outcome. The
        # steps are independent so a failure in one doesn't skip the other.
        try:
            self._close_proxy_sessions()
        except Exception:
            pass

        # init_proxy may never have run (e.g. the test failed early).
        proxy_server = getattr(self, 'proxy_server', None)
        if proxy_server is not None:
            try:
                proxy_server.close()
            except Exception:
                pass

        super(BaseTestCase, self).tearDown()

    @gen.coroutine
    def _execute_test_connection_closing(self):
        """A query on a connection whose proxy link was cut must raise OperationalError."""
        self.init_proxy()
        connection = yield Connection(**self.PARAMS)
        cursor = connection.cursor()
        self._close_proxy_sessions()
        try:
            yield cursor.execute('SELECT 1')
            yield cursor.close()
        except OperationalError:
            pass
        else:
            raise AssertionError("Unexpected normal situation")

        self.proxy_server.close()

    @gen.coroutine
    def _execute_test_connection_closed(self):
        """Connecting after the proxy server has been closed must raise OperationalError."""
        self.init_proxy()
        conn = yield Connection(**self.PARAMS)
        yield conn.close()
        self.proxy_server.close()

        try:
            yield Connection(**self.PARAMS)
        except OperationalError:
            pass
        else:
            raise AssertionError("Unexpected normal situation")

    @gen.coroutine
    def _execute_test_remote_closing(self):
        """Once the proxy is gone, the pool must fail to hand out a new connection."""
        self.init_proxy()
        pool = ConnectionPool(
            max_connections=int(os.getenv("MYSQL_POOL", "5")),
            idle_seconds=7200,
            **self.PARAMS
        )
        try:
            # Setup must succeed. An OperationalError here is a real failure,
            # not the expected outcome, so it is outside the expected-failure try.
            conn = yield pool.Connection()
            yield conn.do_close()
            self.proxy_server.close()

            try:
                yield pool.Connection()
            except OperationalError:
                pass
            else:
                raise AssertionError("Unexpected normal situation")
        finally:
            yield pool.close()

    @gen.coroutine
    def _execute_test_pool_closing(self):
        """A pooled connection whose proxy link was cut must fail its next query."""
        self.init_proxy()
        pool = ConnectionPool(
            max_connections=int(os.getenv("MYSQL_POOL", "5")),
            idle_seconds=7200,
            **self.PARAMS
        )
        try:
            with (yield pool.Connection()) as connect:
                with connect.cursor() as cursor:
                    self._close_proxy_sessions()
                    yield cursor.execute("SELECT 1 as test")
        except (OperationalError, ConnectionNotFoundError):
            pass
        else:
            raise AssertionError("Unexpected normal situation")
        finally:
            yield pool.close()

        self.proxy_server.close()

    @gen_test
    def test(self):
        # Drive the sevent loop (which runs the proxy) on a background daemon
        # thread, alongside the Tornado IOLoop that runs this coroutine test.
        loop = sevent.instance()
        self.proxy_thread = threading.Thread(target=loop.start)
        # Attribute form; Thread.setDaemon() is deprecated since Python 3.10.
        self.proxy_thread.daemon = True
        self.proxy_thread.start()

        yield self._execute_test_connection_closing()
        yield self._execute_test_connection_closed()
        yield self._execute_test_remote_closing()
        yield self._execute_test_pool_closing()