import asyncio
import json
import time
import re
import os
import websockets
from .entity import (
    Quote, Trade, Agg, Entity,
    trade_mapping, quote_mapping, agg_mapping
)
from alpaca_trade_api.common import get_polygon_credentials, URL
import logging


class StreamConn(object):
    '''Asyncio client for the Polygon websocket stream.

    Handlers are coroutine functions registered against a channel regex
    via register(); every message received from the socket is matched
    against those patterns using its 'ev' field and handed to the
    matching handlers.
    '''

    def __init__(self, key_id: str = None):
        '''Prepare the connection state without opening a socket.

        key_id is passed through get_polygon_credentials(). The endpoint
        is read from the POLYGON_WS_URL environment variable (trailing
        slashes stripped), and the retry settings from APCA_RETRY_MAX and
        APCA_RETRY_WAIT (both default to 3).
        '''
        self._key_id = get_polygon_credentials(key_id)
        self._endpoint: URL = URL(os.environ.get(
            'POLYGON_WS_URL',
            'wss://alpaca.socket.polygon.io/stocks'
        ).rstrip('/'))
        # compiled channel pattern -> handler coroutine function
        self._handlers = {}
        # handler coroutine function -> symbols filter (None = no filter)
        self._handler_symbols = {}
        # channels subscribed so far; re-sent after a reconnect
        self._streams = set()
        self._ws = None
        # async generator created by connect(); yields parsed messages
        self._stream = None
        self._retry = int(os.environ.get('APCA_RETRY_MAX', 3))
        self._retry_wait = int(os.environ.get('APCA_RETRY_WAIT', 3))
        self._retries = 0
        self.loop = asyncio.get_event_loop()
        self._consume_task = None

    async def connect(self):
        '''Open the websocket, authenticate and start consuming messages.

        Raises ValueError when the first message received does not carry
        the 'connected' status, or when authentication is rejected.
        '''
        await self._dispatch({'ev': 'status',
                              'status': 'connecting',
                              'message': 'Connecting to Polygon'})
        self._ws = await websockets.connect(self._endpoint)
        self._stream = self._recv()

        msg = await self._next()
        if msg.get('status') != 'connected':
            raise ValueError(
                ("Invalid response on Polygon websocket connection: {}"
                    .format(msg))
            )
        await self._dispatch(msg)
        if await self.authenticate():
            self._consume_task = asyncio.ensure_future(self._consume_msg())
        else:
            await self.close()

    async def authenticate(self):
        '''Send the auth request and check the reply.

        Returns False when there is no open websocket and True on
        success. Raises ValueError when the reply is not the expected
        'auth_success' status message.
        '''
        ws = self._ws
        if not ws:
            return False

        await ws.send(json.dumps({
            'action': 'auth',
            'params': self._key_id
        }))
        data = await self._next()
        stream = data.get('ev')
        msg = data.get('message')
        status = data.get('status')
        if (stream == 'status'
                and msg == 'authenticated'
                and status == 'auth_success'):
            # reset retries only after we successfully authenticated
            self._retries = 0
            await self._dispatch(data)
            return True
        else:
            raise ValueError('Invalid Polygon credentials, '
                             f'Failed to authenticate: {data}')

    async def _next(self):
        '''Returns the next message available
        '''
        return await self._stream.__anext__()

    async def _recv(self):
        '''Function used to receive and parse all messages from websocket stream.

        Each websocket frame is decoded as JSON and iterated, so this
        generator yields one message per each call. On any Exception a
        'disconnected' status is dispatched, the connection is closed and
        a reconnect through _ensure_ws() is scheduled.
        '''
        try:
            while True:
                r = await self._ws.recv()
                if isinstance(r, bytes):
                    r = r.decode('utf-8')
                msg = json.loads(r)
                for update in msg:
                    yield update
        except Exception as e:
            await self._dispatch({'ev': 'status',
                                  'status': 'disconnected',
                                  'message':
                                  f'Polygon Disconnected Unexpectedly ({e})'})
            await self.close()
            asyncio.ensure_future(self._ensure_ws())

    async def consume(self):
        '''Wait for the background consume task, if one has been started.
        '''
        if self._consume_task:
            await self._consume_task

    async def _consume_msg(self):
        '''Dispatch every message of the stream to the registered handlers.

        Raises ConnectionResetError when a message without an 'ev' field
        reports the 'disconnected' status.
        '''
        async for data in self._stream:
            stream = data.get('ev')
            if stream:
                await self._dispatch(data)
            elif data.get('status') == 'disconnected':
                # Polygon returns this on an empty 'ev' id..
                data['ev'] = 'status'
                await self._dispatch(data)
                raise ConnectionResetError(
                    'Polygon terminated connection: '
                    f'({data.get("message")})')

    async def _ensure_ws(self):
        '''Make sure a websocket is open, connecting with retries if not.

        After a successful connect the channels recorded in self._streams
        are subscribed again. Each failure is reported to the handlers as
        a 'connect failed' status message. Raises ConnectionError once
        the retry counter exceeds the configured maximum.
        '''
        if self._ws is not None:
            return

        while self._retries <= self._retry:
            try:
                await self.connect()
                if self._streams:
                    await self.subscribe(self._streams)
                break
            except Exception as e:
                await self._dispatch({'ev': 'status',
                                      'status': 'connect failed',
                                      'message':
                                      f'Polygon Connection Failed ({e})'})
                self._ws = None
                self._retries += 1
                # Yield to the event loop while waiting; a blocking
                # time.sleep() here would stall every other coroutine.
                await asyncio.sleep(self._retry_wait * self._retry)
        else:
            raise ConnectionError("Max Retries Exceeded")

    async def subscribe(self, channels):
        '''Subscribe to channels.
        Note: This is cumulative, meaning you can add channels at runtime,
        and you do not need to specify all the channels.

        To remove channels see unsubscribe().

        If the necessary connection isn't open yet, it opens now.
        An empty or None channels argument is a no-op.
        '''
        if not channels:
            return
        await self._ensure_ws()
        # Join channel list to string
        streams = ','.join(channels)
        self._streams |= set(channels)
        await self._ws.send(json.dumps({
            'action': 'subscribe',
            'params': streams
        }))

    async def unsubscribe(self, channels):
        '''Unsubscribe from channels

        Does nothing when no websocket is open or when channels is empty
        or None.
        '''
        if not self._ws or not channels:
            return
        # Join channel list to string
        streams = ','.join(channels)
        self._streams -= set(channels)
        await self._ws.send(json.dumps({
            'action': 'unsubscribe',
            'params': streams
        }))

    def run(self, initial_channels=None):
        '''Run forever and block until exception is raised.
        initial_channels is the channels to start with; None (the
        default) means no initial channels.

        The connection is closed and the event loop is closed on exit.
        '''
        if initial_channels is None:
            initial_channels = []
        loop = self.loop
        try:
            loop.run_until_complete(self.subscribe(initial_channels))
            loop.run_forever()
        except KeyboardInterrupt:
            logging.info("Exiting on Interrupt")
        finally:
            loop.run_until_complete(self.close())
            loop.close()

    async def close(self):
        '''Close any open connections'''
        # NOTE(review): when this is reached from _recv()'s error path it
        # may be running inside the consume task it cancels; confirm the
        # awaits below (and the reconnect scheduled by _recv) still run.
        if self._consume_task:
            self._consume_task.cancel()
        if self._ws is not None:
            await self._ws.close()
        self._ws = None

    def _cast(self, subject, data):
        '''Wrap a raw message in the entity class matching its channel.

        'T' gives a Trade, 'Q' a Quote, 'AM' and 'A' an Agg; keys are
        renamed through the corresponding mapping and keys absent from
        the mapping are dropped. Any other subject gives a plain Entity
        of the unmodified data.
        '''
        if subject == 'T':
            return Trade({trade_mapping[k]: v for k,
                          v in data.items() if k in trade_mapping})
        if subject == 'Q':
            return Quote({quote_mapping[k]: v for k,
                          v in data.items() if k in quote_mapping})
        if subject == 'AM' or subject == 'A':
            return Agg({agg_mapping[k]: v for k,
                        v in data.items() if k in agg_mapping})
        return Entity(data)

    async def _dispatch(self, msg):
        '''Hand msg to every handler whose pattern matches its 'ev' field.

        A handler registered with a symbols filter is only called when
        the message's 'sym' value is in that filter, so messages without
        a 'sym' key (such as status events) are skipped for it. Messages
        without a string 'ev' field match no handler.
        '''
        channel = msg.get('ev')
        if not isinstance(channel, str):
            # Nothing to match the patterns against.
            return
        for pat, handler in self._handlers.items():
            if not pat.match(channel):
                continue
            handled_symbols = self._handler_symbols.get(handler)
            if (handled_symbols is not None
                    and msg.get('sym') not in handled_symbols):
                continue
            ent = self._cast(channel, msg)
            await handler(self, channel, ent)

    def register(self, channel_pat, func, symbols=None):
        '''Register coroutine function func for channels matching channel_pat.

        channel_pat may be a regex string or an already compiled pattern.
        symbols, when given, restricts the handler to messages whose
        'sym' value is contained in it. Raises ValueError when func is
        not a coroutine function.
        '''
        if not asyncio.iscoroutinefunction(func):
            raise ValueError('handler must be a coroutine function')
        if isinstance(channel_pat, str):
            channel_pat = re.compile(channel_pat)
        self._handlers[channel_pat] = func
        self._handler_symbols[func] = symbols

    def deregister(self, channel_pat):
        '''Remove the handler registered for channel_pat and its symbols
        filter.

        Raises KeyError when no handler is registered for the pattern.
        '''
        if isinstance(channel_pat, str):
            channel_pat = re.compile(channel_pat)
        self._handler_symbols.pop(self._handlers[channel_pat], None)
        del self._handlers[channel_pat]