""" asyncio compatibility shims. """ import asyncio import ssl import sys from asyncio.sslproto import SSLProtocol # type: ignore from typing import Optional, Union __all__ = ( "PY36_OR_LATER", "PY37_OR_LATER", "all_tasks", "get_running_loop", "start_tls", ) PY36_OR_LATER = sys.version_info[:2] >= (3, 6) PY37_OR_LATER = sys.version_info[:2] >= (3, 7) def get_running_loop() -> asyncio.AbstractEventLoop: if PY37_OR_LATER: return asyncio.get_running_loop() loop = asyncio.get_event_loop() if not loop.is_running(): raise RuntimeError("no running event loop") return loop def all_tasks(loop: asyncio.AbstractEventLoop = None): if PY37_OR_LATER: return asyncio.all_tasks(loop=loop) return asyncio.Task.all_tasks(loop=loop) async def start_tls( loop: asyncio.AbstractEventLoop, transport: asyncio.Transport, protocol: asyncio.Protocol, sslcontext: ssl.SSLContext, server_side: bool = False, server_hostname: Optional[str] = None, ssl_handshake_timeout: Optional[Union[float, int]] = None, ) -> asyncio.Transport: # We use hasattr here, as uvloop also supports start_tls. if hasattr(loop, "start_tls"): return await loop.start_tls( # type: ignore transport, protocol, sslcontext, server_side=server_side, server_hostname=server_hostname, ssl_handshake_timeout=ssl_handshake_timeout, ) waiter = loop.create_future() ssl_protocol = SSLProtocol( loop, protocol, sslcontext, waiter, server_side, server_hostname ) # Pause early so that "ssl_protocol.data_received()" doesn't # have a chance to get called before "ssl_protocol.connection_made()". transport.pause_reading() # Use set_protocol if we can if hasattr(transport, "set_protocol"): transport.set_protocol(ssl_protocol) else: transport._protocol = ssl_protocol # type: ignore conmade_cb = loop.call_soon(ssl_protocol.connection_made, transport) resume_cb = loop.call_soon(transport.resume_reading) try: await asyncio.wait_for(waiter, timeout=ssl_handshake_timeout) except Exception: transport.close() conmade_cb.cancel() resume_cb.cancel() raise return ssl_protocol._app_transport def create_connection(loop: asyncio.AbstractEventLoop, *args, **kwargs): if not PY37_OR_LATER: kwargs.pop("ssl_handshake_timeout") return loop.create_connection(*args, **kwargs) def create_unix_connection(loop: asyncio.AbstractEventLoop, *args, **kwargs): if not PY37_OR_LATER: kwargs.pop("ssl_handshake_timeout") return loop.create_unix_connection(*args, **kwargs)