from __future__ import annotations

import asyncio
from collections import defaultdict
import functools
import logging
import json
from typing import (
    Any, Union, Final,
    Iterable, AsyncIterator,
    Sequence, Tuple,
    Mapping, MutableMapping,
    Set,
    Protocol,
)
import uuid

from aiohttp import web
import aiohttp_cors
from aiohttp_sse import sse_response
import aioredis
from aiojobs.aiohttp import get_scheduler_from_app
from aiotools import adefer
import attr
import sqlalchemy as sa
import trafaret as t

from ai.backend.common import msgpack, redis
from ai.backend.common import validators as tx
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import (
    aobject,
    AgentId,
)
from ai.backend.common.utils import current_loop
from .auth import auth_required
from .defs import REDIS_STREAM_DB
from .exceptions import GenericNotFound, GenericForbidden, GroupNotFound
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, WebMiddleware
from .utils import check_api_params
from ..manager.models import kernels, groups, UserRole
from ..manager.types import BackgroundTaskEventArgs, Sentinel

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.events'))

sentinel: Final = Sentinel.token


class EventCallback(Protocol):
    async def __call__(self,
                       context: Any,
                       agent_id: AgentId,
                       event_name: str,
                       *args) -> None:
        ...


@attr.s(auto_attribs=True, slots=True, frozen=True, eq=False, order=False)
class EventHandler:
    context: Any
    callback: EventCallback


class EventDispatcher(aobject):
    '''
    We have two types of event handlers: consumers and subscribers.

    Consumers use the distribution pattern: each event is delivered to exactly
    one consumer among the manager worker processes.

    Consumer example: database updates upon specific events.

    Subscribers use the broadcast pattern: all subscribers across all manager
    worker processes receive the same event.

    Subscriber example: enqueuing events to the queues of event streaming API
    handlers.
    '''
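
    # A minimal usage sketch (handler names hypothetical).  A handler receives
    # ``(context, agent_id, event_name, *args)`` as declared by EventCallback:
    #
    #     dispatcher = app['event_dispatcher']
    #
    #     async def on_kernel_terminated(context, agent_id, event_name,
    #                                    kernel_id, reason=None):
    #         ...  # e.g. update the kernel row in the database
    #
    #     handler = dispatcher.consume('kernel_terminated', app, on_kernel_terminated)
    #     await dispatcher.produce_event('kernel_terminated', (kernel_id, 'self-terminated'))
    #     dispatcher.unconsume('kernel_terminated', handler)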

    loop: asyncio.AbstractEventLoop
    root_app: web.Application
    consumers: MutableMapping[str, Set[EventHandler]]
    subscribers: MutableMapping[str, Set[EventHandler]]
    redis_producer: aioredis.Redis
    redis_consumer: aioredis.Redis
    redis_subscriber: aioredis.Redis
    consumer_task: asyncio.Task
    subscriber_task: asyncio.Task
    producer_lock: asyncio.Lock

    def __init__(self, app: web.Application) -> None:
        self.loop = current_loop()
        self.root_app = app
        self.consumers = defaultdict(set)
        self.subscribers = defaultdict(set)

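    # __ainit__ is aobject's asynchronous construction hook; presumably it runs
    # right after __init__ when the dispatcher is created asynchronously
    # (e.g. ``await EventDispatcher.new(app)``).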
    async def __ainit__(self) -> None:
        self.redis_producer = await self._create_redis()
        self.redis_consumer = await self._create_redis()
        self.redis_subscriber = await self._create_redis()
        self.consumer_task = self.loop.create_task(self._consume())
        self.subscriber_task = self.loop.create_task(self._subscribe())
        self.producer_lock = asyncio.Lock()

    async def _create_redis(self) -> aioredis.Redis:
        config = self.root_app['config']
        return await redis.connect_with_retries(
            config['redis']['addr'].as_sockaddr(),
            db=REDIS_STREAM_DB,
            password=(config['redis']['password']
                      if config['redis']['password'] else None),
            encoding=None,
        )

    async def close(self) -> None:
        self.consumer_task.cancel()
        await self.consumer_task
        self.subscriber_task.cancel()
        await self.subscriber_task
        self.redis_producer.close()
        self.redis_consumer.close()
        self.redis_subscriber.close()
        await self.redis_producer.wait_closed()
        await self.redis_consumer.wait_closed()
        await self.redis_subscriber.wait_closed()

    def consume(self, event_name: str, context: Any, callback: EventCallback) -> EventHandler:
        handler = EventHandler(context, callback)
        self.consumers[event_name].add(handler)
        return handler

    def unconsume(self, event_name: str, handler: EventHandler) -> None:
        self.consumers[event_name].discard(handler)

    def subscribe(self, event_name: str, context: Any, callback: EventCallback) -> EventHandler:
        handler = EventHandler(context, callback)
        self.subscribers[event_name].add(handler)
        return handler

    def unsubscribe(self, event_name: str, handler: EventHandler) -> None:
        self.subscribers[event_name].discard(handler)

    async def produce_event(self, event_name: str,
                            args: Sequence[Any] = tuple(), *,
                            agent_id: str = 'manager') -> None:
        raw_msg = msgpack.packb({
            'event_name': event_name,
            'agent_id': agent_id,
            'args': args,
        })
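        # Deliver through both channels: the Redis list feeds exactly one
        # consumer via BLPOP, while the pub/sub channel reaches every subscriber.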
        async with self.producer_lock:
            def _pipe_builder():
                pipe = self.redis_producer.pipeline()
                pipe.rpush('events.prodcons', raw_msg)
                pipe.publish('events.pubsub', raw_msg)
                return pipe
            await redis.execute_with_retries(_pipe_builder)

    async def dispatch_consumers(self, event_name: str, agent_id: AgentId,
                                 args: Tuple[Any, ...] = tuple()) -> None:
        log_fmt = 'DISPATCH_CONSUMERS(ev:{}, ag:{})'
        log_args = (event_name, agent_id)
        if self.root_app['config']['debug']['log-events']:
            log.debug(log_fmt, *log_args)
        scheduler = get_scheduler_from_app(self.root_app)
        for consumer in self.consumers[event_name]:
            cb = consumer.callback
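            # NOTE: a bare coroutine object can be awaited only once, so the
            # iscoroutine() branch effectively supports one-shot callbacks;
            # persistent handlers should be coroutine functions or plain
            # callables.  (The same applies in dispatch_subscribers() below.)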
            try:
                if asyncio.iscoroutine(cb):
                    await scheduler.spawn(cb)
                elif asyncio.iscoroutinefunction(cb):
                    await scheduler.spawn(cb(consumer.context, agent_id, event_name, *args))
                else:
                    cb = functools.partial(cb, consumer.context, agent_id, event_name, *args)
                    self.loop.call_soon(cb)
            except asyncio.CancelledError:
                raise
            except Exception:
                log.exception(log_fmt + ': unexpected-error', *log_args)

    async def dispatch_subscribers(self, event_name: str, agent_id: AgentId,
                                   args: Tuple[Any, ...] = tuple()) -> None:
        log_fmt = 'DISPATCH_SUBSCRIBERS(ev:{}, ag:{})'
        log_args = (event_name, agent_id)
        if self.root_app['config']['debug']['log-events']:
            log.debug(log_fmt, *log_args)
        scheduler = get_scheduler_from_app(self.root_app)
        for subscriber in self.subscribers[event_name]:
            cb = subscriber.callback
            try:
                if asyncio.iscoroutine(cb):
                    await scheduler.spawn(cb)
                elif asyncio.iscoroutinefunction(cb):
                    await scheduler.spawn(cb(subscriber.context, agent_id, event_name, *args))
                else:
                    cb = functools.partial(cb, subscriber.context, agent_id, event_name, *args)
                    self.loop.call_soon(cb)
            except asyncio.CancelledError:
                raise
            except Exception:
                log.exception(log_fmt + ': unexpected-error', *log_args)

    async def _consume(self) -> None:
        while True:
            try:
                key, raw_msg = await redis.execute_with_retries(
                    lambda: self.redis_consumer.blpop('events.prodcons'))
                msg = msgpack.unpackb(raw_msg)
                await self.dispatch_consumers(msg['event_name'],
                                              msg['agent_id'],
                                              msg['args'])
            except asyncio.CancelledError:
                break
            except Exception:
                log.exception('EventDispatcher.consume(): unexpected-error')

    async def _subscribe(self) -> None:

        async def _subscribe_impl():
            channels = await self.redis_subscriber.subscribe('events.pubsub')
            async for raw_msg in channels[0].iter():
                msg = msgpack.unpackb(raw_msg)
                await self.dispatch_subscribers(msg['event_name'],
                                                msg['agent_id'],
                                                msg['args'])

        while True:
            try:
                await redis.execute_with_retries(_subscribe_impl)
            except asyncio.CancelledError:
                break
            except Exception:
                log.exception('EventDispatcher.subscribe(): unexpected-error')


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
    t.Dict({
        tx.AliasedKey(['name', 'sessionName'], default='*') >> 'session_name': t.String,
        t.Key('ownerAccessKey', default=None) >> 'owner_access_key': t.Null | t.String,
        tx.AliasedKey(['group', 'groupName'], default='*') >> 'group_name': t.String,
    }))
@adefer
async def push_session_events(
    defer,
    request: web.Request,
    params: Mapping[str, Any],
) -> web.StreamResponse:
    app = request.app
    session_name = params['session_name']
    user_role = request['user']['role']
    user_uuid = request['user']['uuid']
    access_key = params['owner_access_key']
    if access_key is None:
        access_key = request['keypair']['access_key']
    if user_role == UserRole.USER:
        if access_key != request['keypair']['access_key']:
            raise GenericForbidden
    group_name = params['group_name']
    session_event_queues = app['session_event_queues']  # type: Set[asyncio.Queue]
    my_queue = asyncio.Queue()  # type: asyncio.Queue[Union[Sentinel, Tuple[str, dict, str]]]
    log.info('PUSH_SESSION_EVENTS (ak:{}, s:{}, g:{})', access_key, session_name, group_name)
    if group_name == '*':
        group_id = '*'
    else:
        async with app['dbpool'].acquire() as conn, conn.begin():
            query = (
                sa.select([groups.c.id])
                .select_from(groups)
                .where(groups.c.name == group_name)
            )
            row = await conn.first(query)
            if row is None:
                raise GroupNotFound
            group_id = row['id']
    session_event_queues.add(my_queue)
    defer(lambda: session_event_queues.remove(my_queue))
    try:
        async with sse_response(request) as resp:
            while True:
                evdata = await my_queue.get()
                try:
                    if evdata is sentinel:
                        break
                    event_name, row, reason = evdata
                    if user_role in (UserRole.USER, UserRole.ADMIN):
                        if row['domain_name'] != request['user']['domain_name']:
                            continue
                    if user_role == UserRole.USER:
                        if row['user_uuid'] != user_uuid:
                            continue
                    if group_id != '*' and row['group_id'] != group_id:
                        continue
                    if session_name != '*' and not (
                        (row['sess_id'] == session_name) and
                        (row['access_key'] == access_key)):
                        continue
                    await resp.send(json.dumps({
                        'sessionName': str(row['sess_id']),
                        'ownerAccessKey': row['access_key'],
                        'reason': reason,
                    }), event=event_name)
                finally:
                    my_queue.task_done()
    finally:
        # Returning inside ``finally`` also covers errors raised while sending
        # (e.g. client disconnects), so the stream terminates cleanly.
        return resp
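
# An SSE frame emitted by push_session_events() looks like the following
# (values hypothetical):
#
#     event: kernel_terminated
#     data: {"sessionName": "mysess", "ownerAccessKey": "AKIA...", "reason": "self-terminated"}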


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(t.Dict({
    tx.AliasedKey(['task_id', 'taskId']): tx.UUID,
}))
@adefer
async def push_background_task_events(
    defer,
    request: web.Request,
    params: Mapping[str, Any],
) -> web.StreamResponse:
    app = request.app
    task_id = params['task_id']
    access_key = request['keypair']['access_key']
    log.info('PUSH_BACKGROUND_TASK_EVENTS (ak:{}, t:{})', access_key, task_id)

    tracker_key = f'bgtask.{task_id}'
    task_info = await app['redis_stream'].hgetall(tracker_key)
    if not task_info:
        # hgetall() returns an empty mapping (not None) for a missing key.
        # The task ID is invalid or represents a task completed more than 24 hours ago.
        raise GenericNotFound('No such background task.')

    if task_info['status'] != 'started':
        # It is an already finished task!
        async with sse_response(request) as resp:
            try:
                body = {
                    'task_id': str(task_id),
                    'status': task_info['status'],
                    'current_progress': task_info['current'],
                    'total_progress': task_info['total'],
                    'message': task_info['msg'],
                }
                await resp.send(json.dumps(body), event=f"task_{task_info['status']}")
            finally:
                await resp.send('{}', event="server_close")
        return resp

    # It is an ongoing task.
    task_update_queues = app['task_update_queues']  # type: Set[asyncio.Queue]
    my_queue = asyncio.Queue()  # type: asyncio.Queue[Union[Sentinel, Tuple[str, Any]]]
    task_update_queues.add(my_queue)
    defer(lambda: task_update_queues.remove(my_queue))
    try:
        async with sse_response(request) as resp:
            while True:
                event_args = await my_queue.get()
                try:
                    if event_args is sentinel:
                        break
                    event_name = event_args[0]
                    event_data = BackgroundTaskEventArgs(**event_args[1])
                    if task_id != uuid.UUID(event_data.task_id):
                        continue
                    body = {
                        'task_id': str(task_id),
                        'message': event_data.message,
                    }
                    if event_data.current_progress is not None:
                        body['current_progress'] = event_data.current_progress
                    if event_data.total_progress is not None:
                        body['total_progress'] = event_data.total_progress
                    await resp.send(json.dumps(body), event=event_name, retry=5)
                    if event_name in ('task_done', 'task_failed', 'task_cancelled'):
                        await resp.send('{}', event="server_close")
                        break
                finally:
                    my_queue.task_done()
    finally:
        # As above, returning inside ``finally`` terminates the stream cleanly
        # even if sending fails (e.g. client disconnects).
        return resp
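
# For an ongoing task, push_background_task_events() emits frames such as the
# following (values hypothetical), ending with a ``server_close`` frame that
# carries an empty JSON object:
#
#     event: task_updated
#     data: {"task_id": "...", "message": "...", "current_progress": 3, "total_progress": 10}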


async def enqueue_session_status_update(
    app: web.Application,
    agent_id: AgentId,
    event_name: str,
    raw_kernel_id: str,
    reason: str = None,
    exit_code: int = None,
) -> None:
    if raw_kernel_id is None:
        return
    kernel_id = uuid.UUID(raw_kernel_id)
    # TODO: when event_name == 'kernel_started', read the service port data.
    async with app['dbpool'].acquire() as conn, conn.begin():
        query = (
            sa.select([
                kernels.c.role,
                kernels.c.sess_id,
                kernels.c.access_key,
                kernels.c.domain_name,
                kernels.c.group_id,
                kernels.c.user_uuid,
            ])
            .select_from(kernels)
            .where(
                (kernels.c.id == kernel_id)
            )
        )
        result = await conn.execute(query)
        row = await result.first()
        if row is None:
            return
        if row['role'] != 'master':
            return
    for q in app['session_event_queues']:
        q.put_nowait((event_name, row, reason))
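
# Per-user/domain/group filtering of the rows enqueued above happens on the
# consumer side, in push_session_events().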


async def enqueue_batch_session_result_update(
    app: web.Application,
    agent_id: AgentId,
    event_name: str,
    raw_kernel_id: str,
    exit_code: int = None,
) -> None:
    kernel_id = uuid.UUID(raw_kernel_id)
    # TODO: when event_name == 'kernel_started', read the service port data.
    async with app['dbpool'].acquire() as conn, conn.begin():
        query = (
            sa.select([
                kernels.c.role,
                kernels.c.sess_id,
                kernels.c.access_key,
                kernels.c.domain_name,
                kernels.c.group_id,
                kernels.c.user_uuid,
            ])
            .select_from(kernels)
            .where(
                (kernels.c.id == kernel_id)
            )
        )
        result = await conn.execute(query)
        row = await result.first()
        if row is None:
            return
        if row['role'] != 'master':
            return
    if event_name == 'kernel_success':
        reason = 'task-success'
    else:
        reason = 'task-failure'
    for q in app['session_event_queues']:
        q.put_nowait((event_name, row, reason))


async def enqueue_task_status_update(
    app: web.Application,
    agent_id: AgentId,
    event_name: str,
    raw_task_id: str,
    current_progress: Union[int, float] = None,
    total_progress: Union[int, float] = None,
    message: str = None,
) -> None:
    # The payload layout must match what push_background_task_events() unpacks
    # into BackgroundTaskEventArgs.
    for q in app['task_update_queues']:
        q.put_nowait((event_name, {
            'task_id': raw_task_id,
            'current_progress': current_progress,
            'total_progress': total_progress,
            'message': message,
        }))


async def events_app_ctx(app: web.Application) -> AsyncIterator[None]:
    app['session_event_queues'] = set()
    app['task_update_queues'] = set()
    event_dispatcher = app['event_dispatcher']
    event_dispatcher.subscribe('kernel_enqueued', app, enqueue_session_status_update)
    event_dispatcher.subscribe('kernel_preparing', app, enqueue_session_status_update)
    event_dispatcher.subscribe('kernel_pulling', app, enqueue_session_status_update)
    event_dispatcher.subscribe('kernel_creating', app, enqueue_session_status_update)
    event_dispatcher.subscribe('kernel_started', app, enqueue_session_status_update)
    event_dispatcher.subscribe('kernel_terminating', app, enqueue_session_status_update)
    event_dispatcher.subscribe('kernel_terminated', app, enqueue_session_status_update)
    event_dispatcher.subscribe('kernel_cancelled', app, enqueue_session_status_update)
    event_dispatcher.subscribe('kernel_success', app, enqueue_batch_session_result_update)
    event_dispatcher.subscribe('kernel_failure', app, enqueue_batch_session_result_update)
    event_dispatcher.subscribe('task_updated', app, enqueue_task_status_update)
    event_dispatcher.subscribe('task_done', app, enqueue_task_status_update)
    event_dispatcher.subscribe('task_cancelled', app, enqueue_task_status_update)
    event_dispatcher.subscribe('task_failed', app, enqueue_task_status_update)
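    # All registrations above use subscribe() (broadcast) so that every manager
    # worker process can feed its own set of SSE connections.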

    yield


async def events_shutdown(app: web.Application) -> None:
    # The shutdown handler is invoked before the server waits for the active
    # connections to close, so we enqueue the sentinels here to ensure that
    # they reach the active SSE connections.
    for q in app['session_event_queues']:
        q.put_nowait(sentinel)
    for q in app['task_update_queues']:
        q.put_nowait(sentinel)


def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]:
    app = web.Application()
    app['prefix'] = 'events'
    app['api_versions'] = (3, 4)
    app.on_shutdown.append(events_shutdown)
    cors = aiohttp_cors.setup(app, defaults=default_cors_options)
    add_route = app.router.add_route
    app.cleanup_ctx.append(events_app_ctx)
    cors.add(add_route('GET', r'/background-task', push_background_task_events))
    cors.add(add_route('GET', r'/session', push_session_events))
    return app, []