"""Websocket streaming connections for the trading and data endpoints."""
import asyncio
import json
import logging
import os
import re
from typing import Callable, List, Optional

import websockets

from . import polygon
from .common import get_base_url, get_data_url, get_credentials, URL
from .entity import Account, Entity, trade_mapping, agg_mapping, quote_mapping
from .entity import Trade, Quote, Agg


def _parse_message(raw):
    """Decode one websocket frame (``bytes`` or ``str``) as JSON."""
    if isinstance(raw, bytes):
        raw = raw.decode('utf-8')
    return json.loads(raw)


class _StreamConn(object):
    """A single authenticated websocket connection to ``<base_url>/stream``.

    Handlers are coroutine functions registered against a channel-name
    regex; every incoming message whose ``stream`` field matches is cast to
    an entity and passed to the handler.
    """

    def __init__(self, key_id: str, secret_key: str, base_url: URL):
        self._key_id = key_id
        self._secret_key = secret_key
        # Turn http(s)://... into ws(s)://...
        self._base_url = re.sub(r'^http', 'ws', base_url)
        self._endpoint = self._base_url + '/stream'
        self._handlers = {}
        self._handler_symbols = {}
        # Channels currently subscribed; replayed after a reconnect.
        self._streams = set()
        self._ws = None
        self._retry = int(os.environ.get('APCA_RETRY_MAX', 3))
        self._retry_wait = int(os.environ.get('APCA_RETRY_WAIT', 3))
        self._retries = 0
        self._consume_task = None

    async def _connect(self):
        """Open the socket, authenticate, and start the consume task.

        Raises:
            ValueError: the server answered with a status other than
                ``authorized``.
            Exception: the server answered with an ``error`` field.
        """
        ws = await websockets.connect(self._endpoint)
        try:
            await ws.send(json.dumps({
                'action': 'authenticate',
                'data': {
                    'key_id': self._key_id,
                    'secret_key': self._secret_key,
                }
            }))
            msg = _parse_message(await ws.recv())
            data = msg.get('data', {})
            status = data.get('status')
            if status:
                if status != 'authorized':
                    raise ValueError(
                        (f"Invalid Alpaca API credentials, Failed to "
                         f"authenticate: {msg}")
                    )
            elif data.get('error'):
                raise Exception(f"Error while connecting to {self._endpoint}:"
                                f"{data.get('error')}")
        except Exception:
            # Do not leak the socket when the handshake fails.
            await ws.close()
            raise

        self._retries = 0
        self._ws = ws
        await self._dispatch('authorized', msg)
        self._consume_task = asyncio.ensure_future(self._consume_msg())

    async def consume(self):
        """Wait for the background consume task, if one was started."""
        if self._consume_task:
            await self._consume_task

    async def _consume_msg(self):
        """Receive messages forever and dispatch those carrying a stream.

        On a websocket error the connection is closed and a reconnect is
        scheduled.
        """
        ws = self._ws
        try:
            while True:
                msg = _parse_message(await ws.recv())
                stream = msg.get('stream')
                if stream is not None:
                    await self._dispatch(stream, msg)
        except websockets.WebSocketException as wse:
            logging.warning(wse)
            # NOTE(review): close() cancels self._consume_task, which is the
            # task running this coroutine -- confirm the reconnect below is
            # actually reached.
            await self.close()
            asyncio.ensure_future(self._ensure_ws())

    async def _ensure_ws(self):
        """Connect (with retries) if there is no open socket.

        After a successful connect, previously subscribed channels are
        subscribed again.

        Raises:
            ConnectionError: more than ``APCA_RETRY_MAX`` attempts failed.
        """
        if self._ws is not None:
            return

        while self._retries <= self._retry:
            try:
                await self._connect()
                if self._streams:
                    await self.subscribe(self._streams)
                break
            except websockets.WebSocketException as wse:
                logging.warning(wse)
                self._ws = None
                self._retries += 1
                # NOTE(review): the delay uses the retry limit, not the
                # attempt count, so it is constant -- confirm intended.
                await asyncio.sleep(self._retry_wait * self._retry)
        else:
            raise ConnectionError("Max Retries Exceeded")

    async def subscribe(self, channels):
        """Subscribe to one channel name or an iterable of channel names."""
        if isinstance(channels, str):
            channels = [channels]
        # _ensure_ws passes the internal set, which json cannot serialize.
        channels = list(channels)
        if channels:
            await self._ensure_ws()
            self._streams |= set(channels)
            await self._ws.send(json.dumps({
                'action': 'listen',
                'data': {
                    'streams': channels,
                }
            }))

    async def unsubscribe(self, channels):
        """Unsubscribe from one channel name or an iterable of names."""
        if isinstance(channels, str):
            channels = [channels]
        channels = list(channels)
        if not channels:
            return
        # Forget the channels so a reconnect does not subscribe them again.
        self._streams -= set(channels)
        if self._ws is None:
            # No open socket, so there is nothing to tell the server.
            return
        await self._ws.send(json.dumps({
            'action': 'unlisten',
            'data': {
                'streams': channels,
            }
        }))

    async def close(self):
        """Cancel the consume task and close the socket, if any."""
        if self._consume_task:
            self._consume_task.cancel()
        if self._ws:
            await self._ws.close()
            self._ws = None

    def _cast(self, channel, msg):
        """Wrap a raw message payload in the entity type for ``channel``."""
        if channel == 'account_updates':
            return Account(msg)
        if channel.startswith('T.'):
            return Trade({trade_mapping[k]: v for k,
                          v in msg.items() if k in trade_mapping})
        if channel.startswith('Q.'):
            return Quote({quote_mapping[k]: v for k,
                          v in msg.items() if k in quote_mapping})
        if channel.startswith('A.') or channel.startswith('AM.'):
            # to be compatible with REST Agg
            msg['t'] = msg['s']
            return Agg({agg_mapping[k]: v for k,
                        v in msg.items() if k in agg_mapping})
        return Entity(msg)

    async def _dispatch(self, channel, msg):
        """Call every handler whose pattern matches ``channel``."""
        for pat, handler in self._handlers.items():
            if pat.match(channel):
                ent = self._cast(channel, msg['data'])
                await handler(self, channel, ent)

    def on(self, channel_pat, symbols=None):
        """Decorator form of :meth:`register`."""
        def decorator(func):
            self.register(channel_pat, func, symbols)
            return func

        return decorator

    def register(self, channel_pat, func: Callable, symbols=None):
        """Register coroutine function ``func`` for a channel pattern.

        ``channel_pat`` may be a regex string or a compiled pattern.

        Raises:
            ValueError: ``func`` is not a coroutine function.
        """
        if not asyncio.iscoroutinefunction(func):
            raise ValueError('handler must be a coroutine function')
        if isinstance(channel_pat, str):
            channel_pat = re.compile(channel_pat)
        self._handlers[channel_pat] = func
        self._handler_symbols[func] = symbols

    def deregister(self, channel_pat):
        """Remove the handler registered for ``channel_pat``.

        Raises:
            KeyError: no handler is registered for the pattern.
        """
        if isinstance(channel_pat, str):
            channel_pat = re.compile(channel_pat)
        self._handler_symbols.pop(self._handlers[channel_pat], None)
        del self._handlers[channel_pat]


class StreamConn(object):
    """Pair of streaming connections: one for trading, one for data.

    Channel names are routed to the trading connection
    (``trade_updates``/``account_updates``) or to the data connection
    (names starting with one of the data prefixes).
    """

    def __init__(
            self,
            key_id: str = None,
            secret_key: str = None,
            base_url: URL = None,
            data_url: URL = None,
            data_stream: str = None):
        """Create both connections; nothing is opened until subscribe.

        Raises:
            ValueError: ``data_stream`` is neither ``'alpacadatav1'`` nor
                ``'polygon'``.
        """
        self._key_id, self._secret_key, _ = get_credentials(key_id, secret_key)
        self._base_url = base_url or get_base_url()
        self._data_url = data_url or get_data_url()
        if data_stream is None:
            data_stream = 'alpacadatav1'
        elif data_stream not in ('alpacadatav1', 'polygon'):
            raise ValueError('invalid data_stream name {}'.format(
                data_stream))
        self._data_stream = data_stream

        self.trading_ws = self._new_trading_ws()
        self.data_ws = self._new_data_ws()
        if self._data_stream == 'polygon':
            self._data_prefixes = ('Q.', 'T.', 'A.', 'AM.')
        else:
            self._data_prefixes = ('Q.', 'T.', 'AM.', 'alpacadatav1/')

        self._handlers = {}
        self._handler_symbols = {}

        try:
            self.loop = asyncio.get_event_loop()
        except RuntimeError as err:
            # get_event_loop() raises RuntimeError when the current thread
            # has no event loop; create and install one instead.
            logging.warning(err)
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)

    def _new_trading_ws(self):
        """Build a fresh, unopened trading connection."""
        return _StreamConn(self._key_id, self._secret_key, self._base_url)

    def _new_data_ws(self):
        """Build a fresh, unopened data connection for the chosen stream."""
        if self._data_stream == 'polygon':
            return polygon.StreamConn(
                self._key_id + '-staging' if 'staging' in self._base_url
                else self._key_id)
        return _StreamConn(self._key_id, self._secret_key, self._data_url)

    async def _ensure_ws(self, conn):
        """Copy handlers onto ``conn`` and connect it, unless it has any."""
        if conn._handlers:
            return
        conn._handlers = self._handlers.copy()
        conn._handler_symbols = self._handler_symbols.copy()
        if isinstance(conn, _StreamConn):
            await conn._connect()
        else:
            await conn.connect()

    async def subscribe(self, channels: List[str]):
        '''Start subscribing to channels.

        If the necessary connection isn't open yet, it opens now.

        Raises:
            ValueError: a channel name is not recognized.
        '''
        trading_channels, data_channels = [], []

        for c in channels:
            if c in ('trade_updates', 'account_updates'):
                trading_channels.append(c)
            elif c.startswith(self._data_prefixes):
                data_channels.append(c)
            else:
                raise ValueError(
                    ('unknown channel {} (you may need to specify ' +
                     'the right data_stream)').format(c))

        if trading_channels:
            await self._ensure_ws(self.trading_ws)
            await self.trading_ws.subscribe(trading_channels)
        if data_channels:
            await self._ensure_ws(self.data_ws)
            await self.data_ws.subscribe(data_channels)

    async def unsubscribe(self, channels: List[str]):
        '''Unsubscribe from data channels; other names are ignored.'''
        data_channels = [
            c for c in channels
            if c.startswith(self._data_prefixes)
        ]

        if data_channels:
            await self.data_ws.unsubscribe(data_channels)

    async def consume(self):
        """Wait on the consume tasks of both connections."""
        await asyncio.gather(
            self.trading_ws.consume(),
            self.data_ws.consume(),
        )

    def run(self, initial_channels: Optional[List[str]] = None):
        '''Run forever and block until exception is raised.

        initial_channels is the channels to start with.
        '''
        if initial_channels is None:
            initial_channels = []
        loop = self.loop
        should_renew = True  # should renew connection if it disconnects
        while should_renew:
            try:
                if loop.is_closed():
                    self.loop = asyncio.new_event_loop()
                    loop = self.loop
                loop.run_until_complete(self.subscribe(initial_channels))
                loop.run_until_complete(self.consume())
            except KeyboardInterrupt:
                logging.info("Exiting on Interrupt")
                should_renew = False
            except Exception as e:
                logging.error(f"error while consuming ws messages: {e}")
            # Close the current connections; re-create them only when the
            # loop is going around again.
            loop.run_until_complete(self.close(should_renew))
            # NOTE(review): the loop is not running once run_until_complete
            # has returned, so this looks like it never closes the loop --
            # confirm intent before changing.
            if loop.is_running():
                loop.close()

    async def close(self, renew):
        """
        Close any of open connections
        :param renew: should re-open connection?
        """
        if self.trading_ws is not None:
            await self.trading_ws.close()
            self.trading_ws = None
        if self.data_ws is not None:
            await self.data_ws.close()
            self.data_ws = None
        if renew:
            # New connection objects start without handlers; they are
            # copied over by _ensure_ws on the next subscribe.
            self.trading_ws = self._new_trading_ws()
            self.data_ws = self._new_data_ws()

    def on(self, channel_pat, symbols=None):
        """Decorator form of :meth:`register`."""
        def decorator(func):
            self.register(channel_pat, func, symbols)
            return func

        return decorator

    def register(self, channel_pat, func: Callable, symbols=None):
        """Register ``func`` here and on both underlying connections.

        Raises:
            ValueError: ``func`` is not a coroutine function.
        """
        if not asyncio.iscoroutinefunction(func):
            raise ValueError('handler must be a coroutine function')

        if isinstance(channel_pat, str):
            channel_pat = re.compile(channel_pat)
        self._handlers[channel_pat] = func
        self._handler_symbols[func] = symbols

        if self.trading_ws:
            self.trading_ws.register(channel_pat, func, symbols)
        if self.data_ws:
            self.data_ws.register(channel_pat, func, symbols)

    def deregister(self, channel_pat):
        """Remove the handler for ``channel_pat`` here and on both sockets.

        Raises:
            KeyError: no handler is registered for the pattern.
        """
        if isinstance(channel_pat, str):
            channel_pat = re.compile(channel_pat)
        self._handler_symbols.pop(self._handlers[channel_pat], None)
        del self._handlers[channel_pat]

        if self.trading_ws:
            self.trading_ws.deregister(channel_pat)
        if self.data_ws:
            self.data_ws.deregister(channel_pat)