# coding=utf-8
"""
    Copyright (c) 2018-present, Ant Financial Service Group

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
   ------------------------------------------------------
   File Name : aio_client
   Author : jiaqi.hjq
"""
# Needs python >= 3.4
import asyncio
import functools
import logging
import threading
import traceback
from contextlib import suppress

try:
    import uvloop

    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    pass

from concurrent.futures import TimeoutError, CancelledError

from anthunder.command.heartbeat import HeartbeatRequest
from anthunder.exceptions import PyboltError, ServerError
from anthunder.helpers.singleton import Singleton
from anthunder.protocol import SofaHeader, BoltRequest, BoltResponse
from anthunder.protocol.constants import PTYPE, CMDCODE, RESPSTATUS
from .base import _BaseClient

logger = logging.getLogger(__name__)


class AioClient(_BaseClient):
    """bolt client implemented with asyncio

    All network I/O runs on a private event loop (``self._loop``) that is
    driven by a daemon thread. The public ``invoke_*`` methods hand work to
    that loop with ``asyncio.run_coroutine_threadsafe``.
    """
    # NOTE(review): Python 3 ignores a ``__metaclass__`` class attribute (only
    # the ``metaclass=`` keyword in the class statement is honoured), so this
    # line does not apply ``Singleton``. It is kept unchanged so that
    # instantiation semantics stay the same; confirm the intent before
    # switching to ``class AioClient(_BaseClient, metaclass=Singleton)``.
    __metaclass__ = Singleton

    def __init__(self, app_name, **kwargs):
        super(AioClient, self).__init__(app_name, **kwargs)
        self._loop = asyncio.new_event_loop()
        self.request_mapping = dict()  # request_id: event
        self.response_mapping = dict()  # request_id: response_pkg
        self.connection_mapping = dict()  # address: (reader_coro, writer)
        # Must exist before _init() starts the loop thread. The heartbeat
        # coroutine scheduled there eventually dials through _get_connection,
        # which reads this dict.
        self._pending_dial = dict()  # address: (asyncio.Lock)
        self.loop_thread = self._init()

    def _init(self):
        """Start the daemon thread that runs ``self._loop`` forever.

        When ``self._mesh_client`` is truthy, also schedule the periodic
        heartbeat against ``self._get_address(None)``.

        :return: the started ``threading.Thread``.
        """
        def _t():
            asyncio.set_event_loop(self._loop)
            self._loop.run_forever()

        t = threading.Thread(target=_t, daemon=True)
        t.start()
        if self._mesh_client:
            # ensure mesh client init success, and we will connect to mesh_service_address
            logger.debug("has mesh client, start a heartbeat thread")
            asyncio.run_coroutine_threadsafe(self._heartbeat_timer(self._get_address(None)), self._loop)
        logger.debug("client coro thread started")
        return t

    def invoke_oneway(self, interface, method, content, *, spanctx, **headers):
        """Fire-and-forget call.

        The request is built with ``timeout_ms=-1`` and scheduled on the loop.
        This method returns None without waiting for the send to complete.
        """
        header = SofaHeader.build_header(spanctx, interface, method, **headers)
        pkg = BoltRequest.new_request(header, content, timeout_ms=-1)
        asyncio.run_coroutine_threadsafe(self.invoke(pkg), loop=self._loop)

    def invoke_sync(self, interface, method, content, *, spanctx, timeout_ms, **headers):
        """blocking call to interface, returns responsepkg.content(as bytes)

        :param timeout_ms: how long to block waiting for the response, in ms.
        :raises TimeoutError, CancelledError: (from ``concurrent.futures``) when
            no response arrives in time or the call is cancelled.
        :raises PyboltError: when the request could not be sent at all.
        """
        assert isinstance(timeout_ms, (int, float))
        header = SofaHeader.build_header(spanctx, interface, method, **headers)
        pkg = BoltRequest.new_request(header, content, timeout_ms=timeout_ms)
        fut = asyncio.run_coroutine_threadsafe(self.invoke(pkg), loop=self._loop)
        try:
            ret = fut.result(timeout=timeout_ms / 1000)
        except (TimeoutError, CancelledError) as e:
            logger.error("call to [{}:{}] timeout/cancelled. {}".format(interface, method, e))
            # A timeout here does not stop the coroutine on the loop thread.
            # Cancel it explicitly so it stops waiting and releases its
            # request bookkeeping instead of leaking it.
            fut.cancel()
            raise
        if ret is None:
            # invoke() yields None for a two-way request only when sending
            # failed (see _send_request). Report that instead of crashing
            # with AttributeError on ``None.content``.
            raise PyboltError("call to [{}:{}] failed: request was not sent".format(interface, method))
        return ret.content

    def invoke_async(self, interface, method, content, *, spanctx, callback=None, timeout_ms=None, **headers):
        """
        call callback if callback is a callable, otherwise return a future

        Callback should recv a bytes object as the only argument, which is the response pkg's content

        :return: the ``concurrent.futures.Future`` returned by
            ``run_coroutine_threadsafe``, in both cases.
        """
        header = SofaHeader.build_header(spanctx, interface, method, **headers)
        pkg = BoltRequest.new_request(header, content, timeout_ms=timeout_ms or -1)
        fut = asyncio.run_coroutine_threadsafe(self.invoke(pkg), loop=self._loop)
        if callable(callback):
            fut.add_done_callback(self.callback_wrapper(callback, timeout_ms / 1000 if timeout_ms else None))
        return fut

    @staticmethod
    def callback_wrapper(callback, timeout=None):
        """get future's result, then feed to callback

        The returned function is used as a done-callback, so the future is
        already finished when it runs and ``timeout`` never actually waits.
        """

        @functools.wraps(callback)
        def _inner(fut):
            try:
                ret = fut.result(timeout)
            except (CancelledError, TimeoutError):
                logger.error("Failed to get result")
                return
            return callback(ret.content)

        return _inner

    async def _heartbeat_timer(self, address, interval=30):
        """Invoke heartbeat periodly

        Sleeps ``interval`` seconds between heartbeats. A failed heartbeat is
        logged and does not stop the timer.
        """
        while True:
            await asyncio.sleep(interval)
            try:
                await self.invoke_heartbeat(address)
            except asyncio.CancelledError:
                raise
            except Exception:
                # Without this guard, one failing heartbeat (e.g. a send error
                # or a missing response) would end this task and silently
                # stop all future heartbeats.
                logger.error("heartbeat to {} failed: {}".format(address, traceback.format_exc()))

    async def invoke_heartbeat(self, address):
        """
        Send heartbeat to server

        :return bool, if the server response properly.

        TODO: to break the connection if server response wrongly
        """
        pkg = HeartbeatRequest.new_request()
        resp = await self.invoke(pkg, address=address)
        if resp is None:
            # invoke() returns None when the request could not be sent
            logger.error("heartbeat request({}) was not sent.".format(pkg.request_id))
            return False
        if resp.request_id != pkg.request_id:
            logger.error("heartbeat response request_id({}) mismatch with request({}).".format(resp.request_id,
                                                                                                pkg.request_id))
            return False
        if resp.respstatus != RESPSTATUS.SUCCESS:
            logger.error("heartbeat response status ({}) on request({}).".format(resp.respstatus, resp.request_id))
            return False
        return True

    async def _get_connection(self, address):
        """Return the cached ``(read_task, writer)`` for ``address``, dialing if needed.

        A per-address ``asyncio.Lock`` in ``self._pending_dial`` serialises
        dialing, so concurrent callers share one connection.
        """
        try:
            # fast path return existed connection
            if address in self.connection_mapping:
                return self.connection_mapping[address]

            async with self._pending_dial.setdefault(address, asyncio.Lock()):
                # re-check: another coroutine may have dialed while we waited
                if address in self.connection_mapping:
                    return self.connection_mapping[address]

                reader, writer = await asyncio.open_connection(*address)
                task = asyncio.ensure_future(self._recv_response(reader, writer))
                return self.connection_mapping.setdefault(address, (task, writer))
        except Exception as e:
            logger.error("Get connection of {} failed: {}".format(address, e))
            raise

    async def invoke(self, request: BoltRequest, *, address=None):
        """
        A request response wrapper

        :param request: the request package to send.
        :param address: a inet address, currently only for heartbeat request;
            defaults to ``self._get_address(request.header['service'])``.
        :return: the matching response package. Returns None when no response
            is expected (oneway request) or when the request could not be sent.
        :raises KeyError: if the response event fired without a stored
            response (``_recv_response`` skipped or rejected the package).
        """
        address = address or self._get_address(request.header['service'])
        logger.debug("invoke to address: {}".format(address))
        event = await self._send_request(request, address=address)
        if event is None:
            logger.debug("no related event, should be a async/oneway call, return now")
            return
        try:
            await event.wait()
        except asyncio.CancelledError:
            # The caller gave up (e.g. invoke_sync timed out and cancelled this
            # coroutine). Nothing else will clean up, so drop the bookkeeping
            # here. A late response is then ignored by _recv_response.
            self.request_mapping.pop(request.request_id, None)
            self.response_mapping.pop(request.request_id, None)
            raise
        return self.response_mapping.pop(request.request_id)

    async def _send_request(self, request: BoltRequest, *, address):
        """
        send request and put request_id in request_mapping for response match

        :param request: the request package to send.
        :param address: a inet address, currently only for heartbeat request
        :return: the ``asyncio.Event`` that is set when the response arrives.
            Returns None for a oneway request, or when sending failed after
            retries.
        """
        assert isinstance(request, BoltRequest)

        async def _send(retry=3):
            if retry <= 0:
                raise PyboltError("send request failed.")
            readtask, writer = await self._get_connection(address)
            try:
                await writer.drain()  # avoid back pressure
                writer.write(request.to_stream())
                await writer.drain()
            except asyncio.CancelledError:
                # cancellation is not a send failure; never retry on it
                raise
            except Exception as e:
                logger.error("Request sent to {} failed: {}, may try again.".format(address, e))
                readtask.cancel()
                # A concurrent sender may already have dropped this broken
                # connection, so tolerate missing keys instead of raising KeyError.
                self.connection_mapping.pop(address, None)
                self._pending_dial.pop(address, None)
                await _send(retry - 1)

        # generate event object first, ensure every successfully sent request has a event.
        # Keep a local reference: the response may arrive (and _recv_response
        # may pop the mapping entry) before this coroutine resumes after _send().
        event = asyncio.Event()
        self.request_mapping[request.request_id] = event
        try:
            await _send()
        except PyboltError:
            logger.error("failed to send request {}".format(request.request_id))
            self.request_mapping.pop(request.request_id, None)
            return
        except asyncio.CancelledError:
            self.request_mapping.pop(request.request_id, None)
            raise
        except Exception:
            logger.error(traceback.format_exc())
            self.request_mapping.pop(request.request_id, None)
            return

        if request.ptype == PTYPE.ONEWAY:
            self.request_mapping.pop(request.request_id, None)
            return
        return event

    async def _recv_response(self, reader, writer):
        """
        wait response and put it in response_mapping, than notify the invoke coro

        Runs as one task per connection until the connection fails or the task
        is cancelled. The connection is then closed.

        NOTE: ``continue`` inside the ``try`` still runs the ``finally`` clause.
        A package that is skipped (HEARTBEAT cmdcode) or rejected (an exception
        after ``pkg`` is parsed) therefore still wakes its waiting request
        without storing a response. That waiter then sees ``KeyError`` in
        ``invoke``.
        NOTE(review): requests still pending on this connection when the loop
        breaks are never woken. They rely on the caller's timeout.

        :param reader: the connection's stream reader.
        :param writer: the connection's stream writer, closed on fatal errors.
        """
        while True:
            pkg = None
            try:
                fixed_header_bs = await reader.readexactly(BoltResponse.bolt_header_size())
                header = BoltResponse.bolt_header_from_stream(fixed_header_bs)
                bs = await reader.readexactly(header['class_len'] + header['header_len'] + header['content_len'])
                pkg = BoltResponse.bolt_content_from_stream(bs, header)
                if pkg.class_name != BoltResponse.class_name:
                    raise ServerError("wrong class_name:[{}]".format(pkg.class_name))
                if pkg.cmdcode == CMDCODE.HEARTBEAT:
                    continue
                elif pkg.cmdcode == CMDCODE.REQUEST:
                    # raise error, the connection will be dropped
                    # NOTE(review): if ServerError subclasses PyboltError, it is
                    # only logged below and the connection is kept.
                    raise ServerError("wrong cmdcode:[{}]".format(pkg.cmdcode))
                if pkg.respstatus != RESPSTATUS.SUCCESS:
                    raise ServerError.from_statuscode(pkg.respstatus)
                if pkg.request_id not in self.request_mapping:
                    # nobody waits for it (oneway, or the caller gave up)
                    continue
                self.response_mapping[pkg.request_id] = pkg
            except PyboltError as e:
                logger.error(e)
            except (asyncio.CancelledError, EOFError, ConnectionResetError) as e:
                logger.error(e)
                writer.close()
                break
            except Exception:
                logger.error(traceback.format_exc())
                writer.close()
                break
            finally:
                # AttributeError: pkg is None (nothing parsed);
                # KeyError: no request is waiting on this id
                with suppress(AttributeError, KeyError):
                    # wake up the coro
                    event = self.request_mapping.pop(pkg.request_id)
                    event.set()