import aiohttp import asyncio import time import urllib.parse import pytest from aioes.transport import Endpoint, Transport @pytest.fixture def make_transport(loop, es_params): tr = None def maker(endpoints=[{'host': es_params['host']}], sniffer_interval=None): nonlocal tr tr = Transport(endpoints, loop=loop, sniffer_interval=sniffer_interval) return tr yield maker if tr is not None: tr.close() def test_ctor(make_transport, es_params): tr = make_transport() assert 3 == tr.max_retries assert time.monotonic() >= tr.last_sniff assert tr.sniffer_interval is None assert 0.1 == tr.sniffer_timeout assert [Endpoint('http', es_params['host'], 9200)] == tr.endpoints assert 1 == len(tr._pool.connections) @asyncio.coroutine def test_connector_factory(es_params, loop): class TCPConnector(aiohttp.TCPConnector): used = False def __init__(self, *args, **kwargs): TCPConnector.used = True super(TCPConnector, self).__init__(*args, **kwargs) tr = Transport( endpoints=[{'host': es_params['host']}], sniffer_interval=None, loop=loop, connector_factory=lambda: TCPConnector(loop=loop) ) assert 1 == len(tr._pool.connections) assert TCPConnector.used tr.close() @asyncio.coroutine def test_simple(make_transport): tr = make_transport() status, data = yield from tr.perform_request( 'GET', '/_nodes/_all/clear') assert 200 == status assert 'nodes' in data # self.assertEqual( # {'nodes': # {'kagIbHGHS3a0dcyPmp0Jkw': # {'version': '1.3.1', # 'ip': '127.0.1.1', # 'build': '2de6dc5', # 'name': 'Mandrill', # 'transport_address': # 'inet[/192.168.0.183:9300]', # 'http_address': 'inet[/192.168.0.183:9200]', # 'host': 'andrew-levelup'}}, # 'cluster_name': 'elasticsearch'}, data) def test_set_endpoints(make_transport): tr = make_transport([]) assert [] == tr.endpoints tr.endpoints = [{'host': 'localhost'}] assert [Endpoint('http', 'localhost', 9200)] == tr.endpoints assert 1 == len(tr._pool.connections) def test_set_endpoints_Endpoint(make_transport): tr = make_transport([]) assert [] == tr.endpoints tr.endpoints = [Endpoint('http', 'localhost', 9200)] assert [Endpoint('http', 'localhost', 9200)] == tr.endpoints assert 1 == len(tr._pool.connections) def test_dont_recreate_existing_connections(make_transport): tr = make_transport() tr.endpoints = [{'host': 'localhost'}] assert [Endpoint('http', 'localhost', 9200)] == tr.endpoints def test_set_malformed_endpoints(make_transport, es_params): tr = make_transport() with pytest.raises(RuntimeError): tr.endpoints = [123] assert [Endpoint('http', es_params['host'], 9200)] == tr.endpoints assert 1 == len(tr._pool.connections) def test_set_host_only_string(make_transport): tr = make_transport() tr.endpoints = ['host'] assert [Endpoint('http', 'host', 9200)] == tr.endpoints assert 1 == len(tr._pool.connections) def test_set_host_port_string(make_transport): tr = make_transport() tr.endpoints = ['host:123'] assert [Endpoint('http', 'host', 123)] == tr.endpoints assert 1 == len(tr._pool.connections) def test_set_host_port_string_invalid(make_transport, es_params): tr = make_transport() with pytest.raises(RuntimeError): tr.endpoints = ['host:123:abc'] assert [Endpoint('http', es_params['host'], 9200)] == tr.endpoints assert 1 == len(tr._pool.connections) def test_set_host_dict_invalid(make_transport, es_params): tr = make_transport() with pytest.raises(RuntimeError): tr.endpoints = [{'a': 'b'}] assert [Endpoint('http', es_params['host'], 9200)] == tr.endpoints assert 1 == len(tr._pool.connections) def test_username_password_endpoints_with_port(make_transport): tr = make_transport(endpoints=['john:doe@localhost:9200']) assert [Endpoint('http', 'john:doe@localhost', 9200)] == tr.endpoints def test_username_password_endpoints_without_port(make_transport): tr = make_transport(endpoints=['john:doe@localhost']) assert [Endpoint('http', 'john:doe@localhost', 9200)] == tr.endpoints def test_username_password_endpoints_with_port_https(make_transport): tr = make_transport(endpoints=['https://john:doe@localhost:9200']) assert [Endpoint('https', 'john:doe@localhost', 9200)] == tr.endpoints assert ('https', 'john:doe@localhost:9200', '/', '', '', '') == \ tuple(urllib.parse.urlparse(str(tr._pool.connections[0]._base_url))) def test_bad_schema(make_transport): with pytest.raises(RuntimeError): make_transport(endpoints=['s3://john:doe@localhost:9200']) def test_default_port_https(make_transport): tr = make_transport(endpoints=['https://localhost']) assert [Endpoint('https', 'localhost', 443)] == tr.endpoints def test_default_port_http(make_transport): tr = make_transport(endpoints=['http://localhost']) assert [Endpoint('http', 'localhost', 9200)] == tr.endpoints @asyncio.coroutine def test_sniff(make_transport, loop): tr = make_transport(sniffer_interval=0.001) t0 = time.monotonic() yield from asyncio.sleep(0.001, loop=loop) yield from tr.get_connection() assert tr.last_sniff > t0 @asyncio.coroutine def test_get_connection_without_sniffing(make_transport): tr = make_transport(sniffer_interval=1000) t0 = tr.last_sniff yield from tr.get_connection() assert t0 == tr.last_sniff @asyncio.coroutine def test_perform_request_body_bytes(make_transport): tr = make_transport() status, data = yield from tr.perform_request( 'GET', '/_nodes/_all', body=b'') assert status == 200