"""GraphQL subscription server integration for the ``websockets`` library."""
from inspect import isawaitable
from asyncio import ensure_future, wait, shield

from websockets import ConnectionClosed
from graphql.execution.executors.asyncio import AsyncioExecutor

from .base import (
    ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer)
from .observable_aiter import setup_observable_extension
from .constants import (
    GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, GQL_COMPLETE)

# NOTE(review): presumably installs ``__aiter__`` on observables so that
# subscription results can be consumed with ``async for`` in ``on_start``;
# confirm against ``.observable_aiter``.
setup_observable_extension()


class WsLibConnectionContext(BaseConnectionContext):
    """Connection context backed by a ``websockets`` connection (``self.ws``)."""

    async def receive(self):
        """Return the next message received on the websocket.

        Raises:
            ConnectionClosedException: when ``websockets`` reports the
                connection as closed (``ConnectionClosed``).
        """
        try:
            msg = await self.ws.recv()
            return msg
        except ConnectionClosed:
            raise ConnectionClosedException()

    async def send(self, data):
        """Send ``data`` over the websocket; silently drop it if closed."""
        if self.closed:
            return
        await self.ws.send(data)

    @property
    def closed(self):
        """True when the underlying websocket reports ``open`` as False."""
        return self.ws.open is False

    async def close(self, code):
        """Close the websocket with the given close ``code``."""
        await self.ws.close(code)


class WsLibSubscriptionServer(BaseSubscriptionServer):
    """Subscription server that serves ``websockets`` connections.

    Args:
        schema: GraphQL schema, forwarded to ``BaseSubscriptionServer``.
        keep_alive: forwarded to ``BaseSubscriptionServer``.
        loop: event loop used for the GraphQL executor and for the
            per-message tasks; ``None`` uses the current loop.
    """

    def __init__(self, schema, keep_alive=True, loop=None):
        self.loop = loop
        super().__init__(schema, keep_alive)

    def get_graphql_params(self, *args, **kwargs):
        """Return the base GraphQL params plus asyncio execution settings.

        Adds ``return_promise=True`` and an ``AsyncioExecutor`` bound to
        ``self.loop`` on top of whatever the base class returns.
        """
        params = super().get_graphql_params(*args, **kwargs)
        return dict(params, return_promise=True,
                    executor=AsyncioExecutor(loop=self.loop))

    async def _handle(self, ws, request_context):
        """Run the receive loop for a single websocket connection.

        Each incoming message is dispatched to ``on_message`` in its own
        task, so a long-running operation does not block reading further
        messages. Once the connection is closed, ``on_close`` runs and any
        message tasks that are still pending are cancelled.
        """
        connection_context = WsLibConnectionContext(ws, request_context)
        await self.on_open(connection_context)
        pending = set()
        while True:
            try:
                if connection_context.closed:
                    raise ConnectionClosedException()
                message = await connection_context.receive()
            except ConnectionClosedException:
                break
            finally:
                # Prune finished tasks without blocking (timeout=0). The
                # guard is required: asyncio.wait raises ValueError on an
                # empty set. ``loop=`` is not passed because asyncio.wait
                # no longer accepts it (removed in Python 3.10); the tasks
                # are already bound to their loop by ensure_future below.
                # Finished tasks are discarded here without retrieving
                # their results or exceptions.
                if pending:
                    _, pending = await wait(pending, timeout=0)

            task = ensure_future(
                self.on_message(connection_context, message), loop=self.loop)
            pending.add(task)

        self.on_close(connection_context)
        for task in pending:
            task.cancel()

    async def handle(self, ws, request_context=None):
        """Serve ``ws`` until the connection closes.

        ``_handle`` is scheduled as a task on ``self.loop`` and wrapped in
        ``shield``, so cancelling the awaiting caller does not cancel the
        connection loop itself.
        """
        # ``shield`` no longer accepts ``loop=`` (removed in Python 3.10), so
        # bind the task to self.loop via ensure_future; shield then uses the
        # task's own loop, matching the old ``shield(coro, loop=...)``.
        handler = ensure_future(
            self._handle(ws, request_context), loop=self.loop)
        await shield(handler)

    async def on_open(self, connection_context):
        """Hook called when a connection is opened; no-op by default."""
        pass

    def on_close(self, connection_context):
        """Unsubscribe every operation registered on the connection."""
        # Snapshot the ids first: unsubscribe presumably removes entries
        # from ``operations`` while we iterate.
        remove_operations = list(connection_context.operations.keys())
        for op_id in remove_operations:
            self.unsubscribe(connection_context, op_id)

    async def on_connect(self, connection_context, payload):
        """Hook for validating the connection_init payload; no-op here.

        Raising from an override makes ``on_connection_init`` reject the
        connection.
        """
        pass

    async def on_connection_init(self, connection_context, op_id, payload):
        """Acknowledge a connection_init, or reject it and close.

        Calls ``on_connect``. On success, sends ``GQL_CONNECTION_ACK``. If
        anything raises, sends the error as ``GQL_CONNECTION_ERROR`` and
        closes the socket with code 1011.
        """
        try:
            await self.on_connect(connection_context, payload)
            await self.send_message(
                connection_context, op_type=GQL_CONNECTION_ACK)
        except Exception as e:
            await self.send_error(
                connection_context, op_id, e, GQL_CONNECTION_ERROR)
            await connection_context.close(1011)

    async def on_start(self, connection_context, op_id, params):
        """Execute operation ``op_id`` and send its result(s) to the client.

        A result without ``__aiter__`` is sent once. A result with
        ``__aiter__`` (e.g. a subscription) is registered under ``op_id``,
        and each yielded item is sent until the iterator is exhausted or the
        operation is no longer registered. ``GQL_COMPLETE`` is sent after
        that in either case.
        """
        execution_result = self.execute(
            connection_context.request_context, params)

        if isawaitable(execution_result):
            execution_result = await execution_result

        if not hasattr(execution_result, '__aiter__'):
            await self.send_execution_result(
                connection_context, op_id, execution_result)
            return

        # NOTE(review): awaiting ``__aiter__()`` assumes the observable
        # extension makes ``__aiter__`` return an awaitable - confirm.
        iterator = await execution_result.__aiter__()
        connection_context.register_operation(op_id, iterator)
        async for single_result in iterator:
            # Stop streaming once the operation was unregistered
            # (presumably via unsubscribe, e.g. from on_stop).
            if not connection_context.has_operation(op_id):
                break
            await self.send_execution_result(
                connection_context, op_id, single_result)
        await self.send_message(connection_context, op_id, GQL_COMPLETE)

    async def on_stop(self, connection_context, op_id):
        """Unsubscribe operation ``op_id`` on client request."""
        self.unsubscribe(connection_context, op_id)