import time
import json
import asyncio
from ssl import SSLContext
from functools import partial
from typing import Optional, Callable, NoReturn

import jwt
import OpenSSL
from h2.connection import H2Connection
from h2.events import ResponseReceived, DataReceived, RemoteSettingsChanged,\
    StreamEnded, ConnectionTerminated, WindowUpdated, SettingsAcknowledged
from h2.exceptions import NoAvailableStreamIDError, FlowControlError
from h2.settings import SettingCodes

from aioapns.common import NotificationResult, DynamicBoundedSemaphore,\
    APNS_RESPONSE_CODE
from aioapns.exceptions import ConnectionClosed, ConnectionError
from aioapns.logging import logger


class ChannelPool(DynamicBoundedSemaphore):
    """Bounded semaphore that also hands out HTTP/2 stream identifiers.

    Each successful :meth:`acquire` reserves one slot and returns the next
    odd (client-initiated) stream ID: 1, 3, 5, ...
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Starting at -1 makes the first ID handed out equal to 1.
        self._stream_id = -1

    async def acquire(self):
        """Reserve a slot and return a fresh stream ID.

        Raises:
            NoAvailableStreamIDError: the ID space of the connection is
                exhausted (the slot stays reserved).
        """
        await super().acquire()
        next_id = self._stream_id + 2
        self._stream_id = next_id
        if next_id > H2Connection.HIGHEST_ALLOWED_STREAM_ID:
            raise NoAvailableStreamIDError()
        return next_id

    @property
    def is_busy(self):
        """True when no slot is currently free."""
        return not self._value > 0


class AuthorizationHeaderProvider:
    """Interface for objects that supply the ``authorization`` header value."""

    def get_header(self):
        """Return the header value; concrete providers must override this."""
        raise NotImplementedError


class JWTAuthorizationHeaderProvider(AuthorizationHeaderProvider):
    """Provide a cached ``bearer`` JWT for token-based authentication.

    The token is signed with ES256 using ``key`` and carries the team ID as
    ``iss`` and the key ID as the ``kid`` header. It is reused until it is
    older than ``TOKEN_TTL`` seconds, then a new one is signed.
    """

    # Seconds a signed token is reused before a new one is issued.
    TOKEN_TTL = 30 * 60

    def __init__(self, key, key_id, team_id):
        self.key = key
        self.key_id = key_id
        self.team_id = team_id

        self.__issued_at = None
        self.__header = None

    def get_header(self):
        """Return ``"bearer <token>"``, signing a fresh token when needed."""
        now = time.time()
        if not self.__header or self.__issued_at < now - self.TOKEN_TTL:
            self.__issued_at = int(now)
            token = jwt.encode(
                payload={'iss': self.team_id, 'iat': self.__issued_at},
                key=self.key,
                algorithm='ES256',
                headers={'kid': self.key_id},
            )
            # PyJWT < 2.0 returns bytes, PyJWT >= 2.0 returns str; calling
            # .decode() unconditionally breaks on the newer releases.
            if isinstance(token, bytes):
                token = token.decode('ascii')
            self.__header = f"bearer {token}"
        return self.__header


class H2Protocol(asyncio.Protocol):
    """Minimal HTTP/2 client protocol built on an ``H2Connection``.

    Incoming bytes are fed to the h2 state machine, the resulting events are
    dispatched to ``on_*`` hooks that subclasses override, and whatever h2
    has queued is written back to the transport.
    """

    def __init__(self):
        self.transport = None
        self.conn = H2Connection()
        # Limits in-flight streams; the bound is updated from the server's
        # MAX_CONCURRENT_STREAMS setting in on_remote_settings_changed().
        self.free_channels = ChannelPool(1000)

    def connection_made(self, transport):
        self.transport = transport
        self.conn.initiate_connection()
        self.flush()

    def data_received(self, data):
        for event in self.conn.receive_data(data):
            if isinstance(event, ResponseReceived):
                headers = dict(event.headers)
                self.on_response_received(headers)
            elif isinstance(event, DataReceived):
                # h2 does not replenish inbound flow-control windows by
                # itself. Without acknowledging the consumed bytes, the
                # connection window eventually reaches zero and the server
                # can no longer send response bodies.
                if event.flow_controlled_length:
                    self.conn.acknowledge_received_data(
                        event.flow_controlled_length, event.stream_id)
                self.on_data_received(event.data, event.stream_id)
            elif isinstance(event, RemoteSettingsChanged):
                self.on_remote_settings_changed(event.changed_settings)
            elif isinstance(event, StreamEnded):
                self.on_stream_ended(event.stream_id)
            elif isinstance(event, ConnectionTerminated):
                self.on_connection_terminated(event)
            elif isinstance(event, WindowUpdated):
                pass
            elif isinstance(event, SettingsAcknowledged):
                pass
            else:
                logger.warning('Unknown event: %s', event)
        self.flush()

    def flush(self):
        """Write all bytes h2 has queued to the transport."""
        self.transport.write(self.conn.data_to_send())

    def on_response_received(self, headers):
        """Hook: response headers arrived (``headers`` is a bytes->bytes dict)."""
        pass

    def on_data_received(self, data, stream_id):
        """Hook: a chunk of response body arrived on ``stream_id``."""
        pass

    def on_remote_settings_changed(self, changed_settings):
        for setting in changed_settings.values():
            logger.debug('Remote setting changed: %s', setting)
            if setting.setting == SettingCodes.MAX_CONCURRENT_STREAMS:
                self.free_channels.bound = setting.new_value

    def on_stream_ended(self, stream_id):
        """Release the channel held by the finished stream.

        Even IDs belong to server-initiated streams, which this client never
        opens, so they are logged as unexpected.
        """
        if stream_id % 2 == 0:
            logger.warning('End stream: %d', stream_id)
        self.free_channels.release()

    def on_connection_terminated(self, event):
        """Hook: the server sent GOAWAY."""
        pass


class APNsBaseClientProtocol(H2Protocol):
    """HTTP/2 client protocol that delivers notifications to APNs.

    Every notification is sent on its own stream. The sending coroutine waits
    on a future that is resolved when APNs answers on that stream, or failed
    with ``ConnectionClosed`` when the connection is lost.
    Subclasses must implement :meth:`close`.
    """

    APNS_SERVER = 'api.push.apple.com'
    # Seconds without any flush (data sent or received) before close().
    INACTIVITY_TIME = 10

    def __init__(self,
                 apns_topic: str,
                 loop: Optional[asyncio.AbstractEventLoop] = None,
                 on_connection_lost: Optional[
                     Callable[['APNsBaseClientProtocol'], NoReturn]] = None,
                 auth_provider: Optional[AuthorizationHeaderProvider] = None):
        super().__init__()
        self.apns_topic = apns_topic
        self.loop = loop or asyncio.get_event_loop()
        # Called with this protocol once the connection is lost.
        self.on_connection_lost = on_connection_lost
        # When set, supplies the ``authorization`` header of every request.
        self.auth_provider = auth_provider

        # notification_id -> future awaited by send_notification()
        self.requests = {}
        # stream_id -> notification_id
        self.request_streams = {}
        # notification_id -> status of a non-success response whose body
        # (which carries the failure reason) has not been processed yet
        self.request_statuses = {}
        self.inactivity_timer = None

    def connection_made(self, transport):
        super().connection_made(transport)
        self.refresh_inactivity_timer()

    async def send_notification(self, request):
        """Send ``request`` on a new stream and wait for the APNs response.

        Returns:
            The ``NotificationResult`` built from the response.

        Raises:
            NoAvailableStreamIDError: the connection ran out of stream IDs.
            FlowControlError: the payload does not fit the current
                flow-control window; the stream is reset and its channel
                released so the call can be retried.
            ConnectionClosed: the connection was lost before a response.
        """
        # Serialise first so that an unencodable payload neither opens a
        # stream nor takes a channel from the pool.
        data = json.dumps(request.message, ensure_ascii=False).encode()

        stream_id = await self.free_channels.acquire()

        headers = [
            (':method', 'POST'),
            (':scheme', 'https'),
            (':path', '/3/device/%s' % request.device_token),
            ('host', self.APNS_SERVER),
            ('apns-id', request.notification_id),
            ('apns-topic', self.apns_topic)
        ]
        if request.time_to_live is not None:
            expiration = int(time.time()) + request.time_to_live
            headers.append(('apns-expiration', str(expiration)))
        if request.priority is not None:
            headers.append(('apns-priority', str(request.priority)))
        if request.collapse_key is not None:
            headers.append(('apns-collapse-id', request.collapse_key))
        if request.push_type is not None:
            headers.append(('apns-push-type', request.push_type.value))
        if self.auth_provider:
            headers.append(('authorization', self.auth_provider.get_header()))

        self.conn.send_headers(stream_id=stream_id, headers=headers)
        try:
            self.conn.send_data(
                stream_id=stream_id,
                data=data,
                end_stream=True,
            )
        except FlowControlError:
            # send_headers() already opened the stream, but it will never be
            # completed, so no StreamEnded event will free its channel. Reset
            # the stream and release the channel ourselves; otherwise every
            # retry permanently loses one slot of the pool.
            self.conn.reset_stream(stream_id)
            self.flush()
            self.free_channels.release()
            raise

        future_response = self.loop.create_future()
        self.requests[request.notification_id] = future_response
        self.request_streams[stream_id] = request.notification_id
        self.flush()

        try:
            return await future_response
        finally:
            # If the caller was cancelled, drop our entry so it does not
            # linger; entries already resolved have been popped elsewhere.
            if self.requests.get(request.notification_id) is future_response:
                del self.requests[request.notification_id]

    def flush(self):
        self.refresh_inactivity_timer()
        self.transport.write(self.conn.data_to_send())

    def refresh_inactivity_timer(self):
        """(Re)start the timer that calls close() after INACTIVITY_TIME."""
        if self.inactivity_timer:
            self.inactivity_timer.cancel()
        self.inactivity_timer = self.loop.call_later(
            self.INACTIVITY_TIME, self.close)

    @property
    def is_busy(self):
        return self.free_channels.is_busy

    def close(self):
        raise NotImplementedError

    def connection_lost(self, exc):
        logger.debug('Connection %s lost!', self)

        if self.inactivity_timer:
            self.inactivity_timer.cancel()

        if self.on_connection_lost:
            self.on_connection_lost(self)

        closed_connection = ConnectionClosed()
        for request in list(self.requests.values()):
            # A waiter may already have been cancelled; set_exception() on a
            # finished future raises InvalidStateError.
            if not request.done():
                request.set_exception(closed_connection)
        self.requests.clear()
        self.request_streams.clear()
        self.request_statuses.clear()
        self.free_channels.destroy(closed_connection)

    @staticmethod
    def _resolve(future, result):
        """Set ``result`` on ``future`` unless it is already done (cancelled)."""
        if not future.done():
            future.set_result(result)

    def on_response_received(self, headers):
        notification_id = headers.get(b'apns-id').decode('utf8')
        status = headers.get(b':status').decode('utf8')
        if status == APNS_RESPONSE_CODE.SUCCESS:
            request = self.requests.pop(notification_id, None)
            if request is not None:
                result = NotificationResult(notification_id, status)
                self._resolve(request, result)
            else:
                logger.warning(
                    'Got response for unknown notification request %s',
                    notification_id)
        else:
            # The failure reason arrives in the body; finish the request in
            # on_data_received() (or on_stream_ended() if there is none).
            self.request_statuses[notification_id] = status

    def on_data_received(self, data, stream_id):
        data = json.loads(data.decode())
        reason = data.get('reason', '')
        if not reason:
            return

        notification_id = self.request_streams.pop(stream_id, None)
        if notification_id is None:
            logger.warning('Could not find notification by stream %s',
                           stream_id)
            return

        # Pop the status even if the request is gone so it cannot leak.
        # A missing status would mean no response headers were seen; report
        # None instead of raising KeyError inside data_received().
        status = self.request_statuses.pop(notification_id, None)
        request = self.requests.pop(notification_id, None)
        if request is None:
            logger.warning('Could not find request %s', notification_id)
            return

        result = NotificationResult(notification_id, status,
                                    description=reason)
        self._resolve(request, result)

    def on_stream_ended(self, stream_id):
        """Drop the stream's bookkeeping, then release its channel.

        Successful responses leave their stream_id -> notification_id entry
        behind, and error responses without a ``reason`` body are still
        pending here. The former are cleaned up; the latter are resolved
        with their HTTP status so callers do not wait for connection loss.
        """
        notification_id = self.request_streams.pop(stream_id, None)
        if notification_id is not None:
            status = self.request_statuses.pop(notification_id, None)
            if status is not None:
                request = self.requests.pop(notification_id, None)
                if request is not None:
                    self._resolve(
                        request, NotificationResult(notification_id, status))
        super().on_stream_ended(stream_id)

    def on_connection_terminated(self, event):
        logger.warning(
            'Connection %s terminated: code=%s, additional_data=%s, '
            'last_stream_id=%s', self, event.error_code,
            event.additional_data, event.last_stream_id)
        self.close()


class APNsTLSClientProtocol(APNsBaseClientProtocol):
    """APNs client protocol whose connections target port 443."""

    APNS_PORT = 443

    def close(self):
        """Stop the inactivity timer and close the transport."""
        timer = self.inactivity_timer
        if timer:
            timer.cancel()
        logger.debug('Closing connection %s', self)
        self.transport.close()


class APNsProductionClientProtocol(APNsTLSClientProtocol):
    """Protocol bound to the production APNs host."""

    APNS_SERVER = 'api.push.apple.com'


class APNsDevelopmentClientProtocol(APNsTLSClientProtocol):
    """Protocol bound to the development (sandbox) APNs host."""

    APNS_SERVER = 'api.development.push.apple.com'


class APNsBaseConnectionPool:
    def __init__(self,
                 topic: Optional[str] = None,
                 max_connections: int = 10,
                 max_connection_attempts: Optional[int] = None,
                 loop: Optional[asyncio.AbstractEventLoop] = None,
                 use_sandbox: bool = False):

        self.apns_topic = topic
        self.max_connections = max_connections
        if use_sandbox:
            self.protocol_class = APNsDevelopmentClientProtocol
        else:
            self.protocol_class = APNsProductionClientProtocol

        self.loop = loop or asyncio.get_event_loop()
        self.connections = []
        self._lock = asyncio.Lock(loop=self.loop)
        self.max_connection_attempts = max_connection_attempts

    async def create_connection(self):
        raise NotImplementedError

    def close(self):
        for connection in self.connections:
            connection.close()

    def discard_connection(self, connection):
        logger.debug('Connection %s discarded', connection)
        self.connections.remove(connection)
        logger.info('Connection released (total: %d)',
                    len(self.connections))

    async def acquire(self):
        for connection in self.connections:
            if not connection.is_busy:
                return connection
        else:
            await self._lock.acquire()
            for connection in self.connections:
                if not connection.is_busy:
                    self._lock.release()
                    return connection
            if len(self.connections) < self.max_connections:
                try:
                    connection = await self.create_connection()
                except Exception as e:
                    logger.error('Could not connect to server: %s', str(e))
                    self._lock.release()
                    raise ConnectionError()
                self.connections.append(connection)
                logger.info('Connection established (total: %d)',
                            len(self.connections))
                self._lock.release()
                return connection
            else:
                self._lock.release()
                logger.warning('Pool is busy, wait...')
                while True:
                    await asyncio.sleep(0.01)
                    for connection in self.connections:
                        if not connection.is_busy:
                            return connection

    async def send_notification(self, request):
        failed_attempts = 0
        while True:
            logger.debug('Notification %s: waiting for connection',
                         request.notification_id)
            try:
                connection = await self.acquire()
            except ConnectionError:
                failed_attempts += 1
                logger.warning('Could not send notification %s: '
                               'ConnectionError', request.notification_id)

                if self.max_connection_attempts \
                        and failed_attempts > self.max_connection_attempts:
                    logger.error('Failed to connect after %d attempts.',
                                 failed_attempts)
                    raise

                await asyncio.sleep(1)
                continue
            logger.debug('Notification %s: connection %s acquired',
                         request.notification_id, connection)
            try:
                response = await connection.send_notification(request)
                return response
            except NoAvailableStreamIDError:
                connection.close()
            except ConnectionClosed:
                logger.warning('Could not send notification %s: '
                               'ConnectionClosed', request.notification_id)
            except FlowControlError:
                logger.debug('Got FlowControlError for notification %s',
                             request.notification_id)
                await asyncio.sleep(1)


class APNsCertConnectionPool(APNsBaseConnectionPool):
    """Connection pool that authenticates with a TLS client certificate.

    ``cert_file`` must contain the PEM certificate (and its key) accepted by
    ``SSLContext.load_cert_chain``. When ``topic`` is not given, it is read
    from the UID field of the certificate's subject.
    """

    def __init__(self,
                 cert_file: str,
                 topic: Optional[str] = None,
                 max_connections: int = 10,
                 max_connection_attempts: Optional[int] = None,
                 loop: Optional[asyncio.AbstractEventLoop] = None,
                 use_sandbox: bool = False):

        super(APNsCertConnectionPool, self).__init__(
            topic=topic,
            max_connections=max_connections,
            max_connection_attempts=max_connection_attempts,
            loop=loop,
            use_sandbox=use_sandbox,
        )

        import ssl

        self.cert_file = cert_file
        # A bare SSLContext() neither verifies the server certificate nor
        # checks its hostname, and is deprecated since Python 3.10.
        # create_default_context() does both checks, which also matches the
        # ``ssl=True`` default context used by the token-based pool.
        self.ssl_context = ssl.create_default_context()
        self.ssl_context.load_cert_chain(cert_file)

        if not self.apns_topic:
            with open(self.cert_file, 'rb') as f:
                body = f.read()
            cert = OpenSSL.crypto.load_certificate(
                OpenSSL.crypto.FILETYPE_PEM, body
            )
            self.apns_topic = cert.get_subject().UID

    async def create_connection(self):
        """Open a new TLS connection to APNs and return its protocol."""
        _, protocol = await self.loop.create_connection(
            protocol_factory=partial(
                self.protocol_class,
                self.apns_topic,
                self.loop,
                self.discard_connection
            ),
            host=self.protocol_class.APNS_SERVER,
            port=self.protocol_class.APNS_PORT,
            ssl=self.ssl_context
        )
        return protocol


class APNsKeyConnectionPool(APNsBaseConnectionPool):
    """Connection pool that authenticates with a signed JWT (token auth).

    ``key_file`` is read as text and used as the ES256 signing key; the
    resulting token identifies ``team_id`` and ``key_id``.
    """

    def __init__(self,
                 key_file: str,
                 key_id: str,
                 team_id: str,
                 topic: str,
                 max_connections: int = 10,
                 max_connection_attempts: Optional[int] = None,
                 loop: Optional[asyncio.AbstractEventLoop] = None,
                 use_sandbox: bool = False):

        super(APNsKeyConnectionPool, self).__init__(
            topic=topic,
            max_connections=max_connections,
            max_connection_attempts=max_connection_attempts,
            loop=loop,
            use_sandbox=use_sandbox,
        )

        self.key_id = key_id
        self.team_id = team_id

        with open(key_file) as f:
            self.key = f.read()

        # One provider for the whole pool. Previously each connection got
        # its own provider and signed its own token, multiplying token
        # refreshes by the number of connections. Now all connections share
        # one cached token that is re-signed once per TOKEN_TTL.
        self.auth_provider = JWTAuthorizationHeaderProvider(
            key=self.key,
            key_id=self.key_id,
            team_id=self.team_id
        )

    async def create_connection(self):
        """Open a new TLS connection to APNs and return its protocol."""
        _, protocol = await self.loop.create_connection(
            protocol_factory=partial(
                self.protocol_class,
                self.apns_topic,
                self.loop,
                self.discard_connection,
                self.auth_provider,
            ),
            host=self.protocol_class.APNS_SERVER,
            port=self.protocol_class.APNS_PORT,
            ssl=True,
        )
        return protocol