import asyncio
import time
import logging
from itertools import chain

from elasticsearch import Transport, TransportError, ConnectionTimeout, ConnectionError, SerializationError
from elasticsearch.connection_pool import DummyConnectionPool

from .connection import AIOHttpConnection
from .connection_pool import AsyncConnectionPool, AsyncDummyConnectionPool
from .helpers import ensure_future

logger = logging.getLogger('elasticsearch')


class AsyncTransport(Transport):
    """
    ``asyncio`` flavour of :class:`elasticsearch.Transport`.

    Connection I/O (``perform_request``/``close``) is awaited with
    ``yield from``, sniffing runs as a background task scheduled on
    ``self.loop``, and :meth:`perform_request` returns a future wrapping
    :meth:`main_loop` instead of the decoded response.
    """

    def __init__(self, hosts, connection_class=AIOHttpConnection, loop=None,
                 connection_pool_class=AsyncConnectionPool,
                 sniff_on_start=False, raise_on_sniff_error=True, **kwargs):
        """
        :arg hosts: passed through to :class:`elasticsearch.Transport`
        :arg connection_class: connection class (``AIOHttpConnection`` by default)
        :arg loop: event loop to use; ``asyncio.get_event_loop()`` when ``None``
        :arg connection_pool_class: connection pool class
            (``AsyncConnectionPool`` by default)
        :arg sniff_on_start: schedule a background sniff right after
            construction
        :arg raise_on_sniff_error: re-raise the error of a failed background
            sniff when :meth:`initiate_sniff` collects it
        """
        self.raise_on_sniff_error = raise_on_sniff_error
        self.loop = asyncio.get_event_loop() if loop is None else loop
        # forwarded to the base class with the other kwargs (presumably so the
        # connections share this loop -- TODO confirm in Transport/connection)
        kwargs['loop'] = self.loop
        # the base class must not sniff on start: that would be a synchronous
        # call, we schedule an asynchronous sniff below instead
        super().__init__(hosts, connection_class=connection_class,
                         sniff_on_start=False,
                         connection_pool_class=connection_pool_class,
                         **kwargs)

        self.sniffing_task = None

        if sniff_on_start:
            # schedule sniff on start
            self.initiate_sniff(True)

    def initiate_sniff(self, initial=False):
        """
        Initiate a sniffing task.

        Make sure we only have one sniff request running at any given time.
        If a finished sniffing request is around, collect its result (which
        can raise its exception when ``raise_on_sniff_error`` is set; in that
        case no new sniff is started by this call).

        :arg initial: forwarded to :meth:`sniff_hosts`
        """
        task = self.sniffing_task
        if task is not None and task.done():
            # a finished task is always discarded, even if its error is raised
            self.sniffing_task = None
            try:
                task.result()
            except BaseException:
                if self.raise_on_sniff_error:
                    raise

        if self.sniffing_task is None:
            self.sniffing_task = ensure_future(self.sniff_hosts(initial), loop=self.loop)

    @asyncio.coroutine
    def close(self):
        """Cancel a running sniffing task (if any) and close the connection pool."""
        if self.sniffing_task:
            self.sniffing_task.cancel()
        yield from self.connection_pool.close()

    def set_connections(self, hosts):
        """
        Replace the connections via the base class, then swap a plain
        ``DummyConnectionPool`` for its async counterpart built from the same
        ``connection_opts``.
        """
        super().set_connections(hosts)

        if isinstance(self.connection_pool, DummyConnectionPool):
            self.connection_pool = AsyncDummyConnectionPool(self.connection_pool.connection_opts)

    def get_connection(self):
        """
        Return a connection from the pool, first starting a background sniff
        if ``sniffer_timeout`` is set and has elapsed since ``last_sniff``.
        """
        if self.sniffer_timeout:
            if time.time() >= self.last_sniff + self.sniffer_timeout:
                self.initiate_sniff()
        return self.connection_pool.get_connection()

    def mark_dead(self, connection):
        """
        Mark ``connection`` dead in the pool and, if
        ``sniff_on_connection_fail`` is set, start a background sniff.
        """
        self.connection_pool.mark_dead(connection)
        if self.sniff_on_connection_fail:
            self.initiate_sniff()

    @asyncio.coroutine
    def _get_sniff_data(self, initial=False):
        """
        Query every known connection for ``/_nodes/_all/http`` in parallel and
        return the node list from the first usable answer.

        ``last_sniff`` is set to the current time up front and restored to its
        previous value if this coroutine raises (including on cancellation).
        Requests still pending when this returns or raises are cancelled.

        :arg initial: when ``True`` no timeout is applied to the requests,
            otherwise ``self.sniff_timeout`` is used
        :returns: list of the values of the response's ``nodes`` mapping
        :raises TransportError: when no request produced a usable answer;
            errors other than ``ConnectionError``/``SerializationError``
            raised by a request propagate unchanged
        """
        previous_sniff = self.last_sniff
        # reset last_sniff timestamp
        self.last_sniff = time.time()

        # use small timeout for the sniffing request, should be a fast api call
        timeout = self.sniff_timeout if not initial else None

        # go through all current connections as well as the
        # seed_connections for good measure
        connections = chain(
            self.connection_pool.connections,
            (c for c in self.seed_connections
             if c not in self.connection_pool.connections))

        # Wrap the requests in futures right away. asyncio.wait() would do it
        # internally, but if the first wait is interrupted (e.g. close()
        # cancels the sniffing task) ``tasks`` would still hold bare coroutine
        # objects: the ``finally`` below could not cancel them, and the real
        # request tasks would keep running.
        tasks = [
            ensure_future(c.perform_request('GET', '/_nodes/_all/http', timeout=timeout),
                          loop=self.loop)
            for c in connections
        ]

        done = ()
        try:
            while tasks:
                # execute sniff requests in parallel, wait for first to return
                done, tasks = yield from asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED, loop=self.loop)

                # go through all the finished tasks
                for t in done:
                    try:
                        _, headers, node_info = t.result()
                        node_info = self.deserializer.loads(node_info, headers.get('content-type'))
                    except (ConnectionError, SerializationError) as e:
                        logger.warning('Sniffing request failed with %r', e)
                        continue
                    return list(node_info['nodes'].values())

            # no task has finished completely
            raise TransportError("N/A", "Unable to sniff hosts.")
        except BaseException:
            # keep the previous value on error
            self.last_sniff = previous_sniff
            raise
        finally:
            # clean up pending futures (cancelling a finished one is a no-op)
            for t in chain(done, tasks):
                t.cancel()

    @asyncio.coroutine
    def sniff_hosts(self, initial=False):
        """
        Obtain a list of nodes from the cluster and create a new connection
        pool using the information retrieved.

        To extract the node connection parameters use the
        ``nodes_to_host_callback``. Connections that are not part of the new
        pool are closed.

        :arg initial: flag indicating if this is during startup
            (``sniff_on_start``), ignore the ``sniff_timeout`` if ``True``
        :raises TransportError: when no viable host could be extracted
        """
        node_info = yield from self._get_sniff_data(initial)

        hosts = list(filter(None, (self._get_host_info(n) for n in node_info)))

        # we weren't able to get any nodes, maybe using an incompatible
        # transport_schema or host_info_callback blocked all - raise error.
        if not hosts:
            raise TransportError("N/A", "Unable to sniff hosts - no viable hosts found.")

        # remember current live connections
        orig_connections = self.connection_pool.connections[:]
        self.set_connections(hosts)
        # close those connections that are not in use any more
        for c in orig_connections:
            if c not in self.connection_pool.connections:
                yield from c.close()

    @asyncio.coroutine
    def main_loop(self, method, url, params, body, headers=None, ignore=(), timeout=None):
        """
        Perform the request, retrying on another connection when appropriate.

        A failed attempt is retried (at most ``max_retries`` times) on
        ``ConnectionError``, on ``ConnectionTimeout`` when ``retry_on_timeout``
        is set, and on status codes listed in ``retry_on_status``; the failing
        connection is marked dead whenever such a retry condition applies.

        :arg headers: request headers passed to the connection
        :returns: for ``HEAD`` requests a bool (``True`` for a 2xx status,
            ``False`` for 404 or another non-2xx status that did not raise);
            otherwise the deserialized response body, or the raw body
            unchanged when it is empty/falsy
        :raises TransportError: when the error is not retried, or when the
            last attempt fails
        """
        for attempt in range(self.max_retries + 1):
            connection = self.get_connection()

            try:
                status, response_headers, data = yield from connection.perform_request(
                    method, url, params, body, headers=headers, ignore=ignore, timeout=timeout)
            except TransportError as e:
                if method == 'HEAD' and e.status_code == 404:
                    return False

                retry = False
                if isinstance(e, ConnectionTimeout):
                    retry = self.retry_on_timeout
                elif isinstance(e, ConnectionError):
                    retry = True
                elif e.status_code in self.retry_on_status:
                    retry = True

                if retry:
                    # only mark as dead if we are retrying
                    self.mark_dead(connection)
                    # raise exception on last retry
                    if attempt == self.max_retries:
                        raise
                else:
                    raise

            else:
                # connection didn't fail, confirm it's live status
                # (also for HEAD requests, which previously skipped this)
                self.connection_pool.mark_live(connection)

                if method == 'HEAD':
                    return 200 <= status < 300

                if data:
                    data = self.deserializer.loads(data, response_headers.get('content-type'))
                return data

    def perform_request(self, method, url, headers=None, params=None, body=None):
        """
        Serialize ``body``, extract per-request options from ``params`` and
        schedule :meth:`main_loop` on ``self.loop``.

        For ``GET``/``HEAD`` requests with a body, ``send_get_body_as``
        decides whether the method becomes ``POST`` or the serialized body
        is sent as the ``source`` query parameter. ``request_timeout`` and
        ``ignore`` are popped from ``params`` (the caller's dict is mutated)
        and passed on as ``timeout`` and ``ignore``.

        :returns: the result of ``ensure_future`` wrapping :meth:`main_loop`
        """
        if body is not None:
            body = self.serializer.dumps(body)

            # some clients or environments don't support sending GET with body
            if method in ('HEAD', 'GET') and self.send_get_body_as != 'GET':
                # send it as post instead
                if self.send_get_body_as == 'POST':
                    method = 'POST'

                # or as source parameter
                elif self.send_get_body_as == 'source':
                    if params is None:
                        params = {}
                    params['source'] = body
                    body = None

        if body is not None:
            try:
                body = body.encode('utf-8')
            except (UnicodeDecodeError, AttributeError):
                # bytes/str - no need to re-encode
                pass

        ignore = ()
        timeout = None
        if params:
            timeout = params.pop('request_timeout', None)
            ignore = params.pop('ignore', ())
            if isinstance(ignore, int):
                ignore = (ignore, )

        return ensure_future(self.main_loop(method, url, params, body,
                                            headers=headers,
                                            ignore=ignore,
                                            timeout=timeout),
                             loop=self.loop)