import json
import logging
import uuid

from aiohttp import web
import aiohttp_cors
import sqlalchemy as sa
import trafaret as t
from typing import Any, Tuple
import yaml

from ai.backend.common import validators as tx
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import SessionTypes

from .auth import auth_required
from .exceptions import InvalidAPIParameters, TaskTemplateNotFound
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, Iterable, WebMiddleware
from .utils import check_api_params, get_access_key_scopes

from ..manager.models import (
    association_groups_users as agus, domains,
    groups, session_templates, keypairs, users, UserRole,
    query_accessible_session_templates, TemplateType,
    verify_vfolder_name
)

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.session_template'))


# Validation schema for version-1 task templates.
# The ``create`` and ``put`` handlers in this module run ``task_template_v1.check()``
# on the decoded JSON/YAML payload and store the *checked* result in the
# ``template`` column, with ``metadata.name`` also copied to the ``name`` column.
# Keys written as ``AliasedKey([...]) >> 'x'`` accept any of the listed spellings
# and come out renamed to ``x``; keys with ``default=`` are filled in when absent.
# ``allow_extra('*')`` is applied only to the outermost dict, so unknown top-level
# keys are kept, while the nested ``t.Dict``s do not opt in to extra keys.
task_template_v1 = t.Dict({
    # Either ``api_version`` or ``apiVersion``; any string is accepted.
    tx.AliasedKey(['api_version', 'apiVersion']): t.String,
    t.Key('kind'): t.Enum('taskTemplate', 'task_template'),
    t.Key('metadata'): t.Dict({
        # Copied into the ``name`` column by create()/put().
        t.Key('name'): t.String,
        t.Key('tag', default=None): t.Null | t.String,
    }),
    t.Key('spec'): t.Dict({
        # ``type``/``sessionType`` -> ``session_type``; defaults to 'interactive'.
        tx.AliasedKey(['type', 'sessionType'],
                      default='interactive') >> 'session_type': tx.Enum(SessionTypes),
        t.Key('kernel'): t.Dict({
            t.Key('image'): t.String,
            # String-to-string environment variables; defaults to an empty mapping.
            t.Key('environ', default={}): t.Null | t.Mapping(t.String, t.String),
            t.Key('run', default=None): t.Null | t.Dict({
                t.Key('bootstrap', default=None): t.Null | t.String,
                # Three accepted spellings, normalized to ``startup_command``.
                tx.AliasedKey(['startup', 'startup_command', 'startupCommand'],
                              default=None) >> 'startup_command': t.Null | t.String
            }),
            t.Key('git', default=None): t.Null | t.Dict({
                t.Key('repository'): t.String,
                t.Key('commit', default=None): t.Null | t.String,
                t.Key('branch', default=None): t.Null | t.String,
                # NOTE(review): the password is kept as-is in the checked body,
                # which is persisted in the ``template`` column.
                t.Key('credential', default=None): t.Null | t.Dict({
                    t.Key('username'): t.String,
                    t.Key('password'): t.String
                }),
                # ``destination_dir``/``destinationDir`` -> ``dest_dir``.
                tx.AliasedKey(['destination_dir', 'destinationDir'],
                              default=None) >> 'dest_dir': t.Null | t.String
            })
        }),
        # Values are unvalidated here (t.Any); create() additionally requires each
        # non-null value to be a path under /home/work/. Keys are presumably
        # vfolder names -- TODO confirm against the session launcher.
        t.Key('mounts', default=None): t.Null | t.Mapping(t.String, t.Any),
        # Presumably resource-slot name -> requested amount; values unvalidated.
        t.Key('resources', default=None): t.Null | t.Mapping(t.String, t.Any)
    })
}).allow_extra('*')


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(t.Dict(
    {
        tx.AliasedKey(['group', 'groupName', 'group_name'], default='default'): t.String,
        tx.AliasedKey(['domain', 'domainName', 'domain_name'], default='default'): t.String,
        t.Key('owner_access_key', default=None): t.Null | t.String,
        t.Key('payload'): t.String,
        t.Key('type') >> 'template_type': tx.Enum(TemplateType)
    }
))
async def create(request: web.Request, params: Any) -> web.Response:
    """
    Create a new task template.

    ``payload`` may be JSON or YAML and must validate against ``task_template_v1``.
    The requested domain/group is authorized against the role of the owner of
    ``owner_access_key``, or against the requester when it is omitted.

    Returns ``{"id": <template id hex>, "user": <requester uuid hex>}``.

    Raises InvalidAPIParameters when the owner access key is unknown, the
    domain/group is invalid or not permitted, or the payload is malformed or
    fails validation, including invalid mount paths.
    """
    # NOTE(review): ``params['template_type']`` is validated by the decorator but
    # never stored; the payload is always validated as a task template.
    if params['domain'] is None:
        # Unreachable with the current schema (default='default', t.String); kept
        # in case the schema is relaxed to allow null.
        params['domain'] = request['user']['domain_name']
    requester_access_key, owner_access_key = await get_access_key_scopes(request, params)
    requester_uuid = request['user']['uuid']
    log.info('CREATE (ak:{0}/{1})',
             requester_access_key, owner_access_key if owner_access_key != requester_access_key else '*')

    dbpool = request.app['dbpool']

    async with dbpool.acquire() as conn, conn.begin():
        if requester_access_key != owner_access_key:
            # Admin or superadmin is creating sessions for another user.
            # The check for admin privileges is already done in get_access_key_scope().
            query = (
                sa.select([keypairs.c.user, users.c.role, users.c.domain_name])
                .select_from(sa.join(keypairs, users, keypairs.c.user == users.c.uuid))
                .where(keypairs.c.access_key == owner_access_key)
            )
            result = await conn.execute(query)
            row = await result.fetchone()
            if row is None:
                # Avoid a TypeError (HTTP 500) on an unknown owner access key.
                raise InvalidAPIParameters('Invalid owner access key')
            owner_domain = row['domain_name']
            owner_uuid = row['user']
            owner_role = row['role']
        else:
            # Normal case when the user is creating her/his own session.
            owner_domain = request['user']['domain_name']
            owner_uuid = requester_uuid
            owner_role = UserRole.USER

        query = (
            sa.select([domains.c.name])
            .select_from(domains)
            .where(
                (domains.c.name == owner_domain) &
                (domains.c.is_active)
            )
        )
        qresult = await conn.execute(query)
        domain_name = await qresult.scalar()
        if domain_name is None:
            raise InvalidAPIParameters('Invalid domain')

        if owner_role == UserRole.SUPERADMIN:
            # superadmin can spawn container in any designated domain/group.
            query = (
                sa.select([groups.c.id])
                .select_from(groups)
                .where(
                    (groups.c.domain_name == params['domain']) &
                    (groups.c.name == params['group']) &
                    (groups.c.is_active)
                ))
            qresult = await conn.execute(query)
            group_id = await qresult.scalar()
        elif owner_role == UserRole.ADMIN:
            # domain-admin can spawn container in any group in the same domain.
            if params['domain'] != owner_domain:
                raise InvalidAPIParameters("You can only set the domain to the owner's domain.")
            query = (
                sa.select([groups.c.id])
                .select_from(groups)
                .where(
                    (groups.c.domain_name == owner_domain) &
                    (groups.c.name == params['group']) &
                    (groups.c.is_active)
                ))
            qresult = await conn.execute(query)
            group_id = await qresult.scalar()
        else:
            # normal users can spawn containers in their group and domain.
            if params['domain'] != owner_domain:
                raise InvalidAPIParameters("You can only set the domain to your domain.")
            query = (
                sa.select([agus.c.group_id])
                .select_from(agus.join(groups, agus.c.group_id == groups.c.id))
                .where(
                    (agus.c.user_id == owner_uuid) &
                    (groups.c.domain_name == owner_domain) &
                    (groups.c.name == params['group']) &
                    (groups.c.is_active)
                ))
            qresult = await conn.execute(query)
            group_id = await qresult.scalar()
        if group_id is None:
            raise InvalidAPIParameters('Invalid group')

        log.debug('Params: {0}', params)
        try:
            body = json.loads(params['payload'])
        except json.JSONDecodeError:
            # Not JSON: fall back to YAML. BaseLoader builds only plain
            # str/list/dict values, never arbitrary Python objects.
            try:
                body = yaml.load(params['payload'], Loader=yaml.BaseLoader)
            except yaml.YAMLError:
                raise InvalidAPIParameters('Malformed payload')
        try:
            body = task_template_v1.check(body)
        except t.DataError as e:
            # Report schema violations, including non-dict payloads, as a client
            # error instead of letting them escape as a server error.
            raise InvalidAPIParameters(f'Invalid task template: {e.as_dict()}') from e
        if mounts := body['spec'].get('mounts'):
            mount_prefix = '/home/work/'
            for p in mounts.values():
                if p is None:
                    continue
                if not isinstance(p, str):
                    # The schema allows t.Any here; reject non-paths explicitly.
                    raise InvalidAPIParameters(f'Path {p} should be a string')
                if not p.startswith(mount_prefix):
                    raise InvalidAPIParameters(f'Path {p} should start with /home/work/')
                # Strip only the leading prefix, not every occurrence of it.
                if not verify_vfolder_name(p[len(mount_prefix):]):
                    raise InvalidAPIParameters(f'Path {p} is reserved for internal operations.')

        template_id = uuid.uuid4().hex
        # NOTE(review): the template is recorded under the *requester's* UUID even
        # when an admin acts on behalf of ``owner_access_key``; ``owner_uuid`` is only
        # used for the group-membership check above. Confirm this is intended.
        query = session_templates.insert().values({
            'id': template_id,
            'domain_name': params['domain'],
            'group_id': group_id,
            'user_uuid': requester_uuid,
            'name': body['metadata']['name'],
            'template': body
        })
        result = await conn.execute(query)
        assert result.rowcount == 1
    return web.json_response({
        'id': template_id,
        'user': requester_uuid.hex,
    })


@auth_required
@server_status_required(READ_ALLOWED)
@check_api_params(
    t.Dict({
        t.Key('all', default=False): t.ToBool,
        tx.AliasedKey(['group_id', 'groupId'], default=None): tx.UUID | t.String | t.Null,
    }),
)
async def list_template(request: web.Request, params: Any) -> web.Response:
    """
    List active task templates visible to the requester.

    When the requester is a superadmin and ``all`` is true, every active template
    is returned, outer-joined with the owner's email and the group's name.
    Otherwise entries come from ``query_accessible_session_templates`` for task
    templates of the ``user``/``group`` kinds, optionally filtered by ``group_id``.

    Each returned item has ``name``, ``id`` (hex), ``created_at`` (string),
    ``is_owner``, ``user``/``group`` (string UUIDs, or null when absent),
    ``user_email``, ``group_name`` and ``type``. ``type`` is ``'user'`` when the
    entry has an owning user, else ``'group'``.
    """
    dbpool = request.app['dbpool']
    access_key = request['keypair']['access_key']
    domain_name = request['user']['domain_name']
    user_role = request['user']['role']
    user_uuid = request['user']['uuid']

    log.info('LIST (ak:{})', access_key)
    async with dbpool.acquire() as conn:
        if request['is_superadmin'] and params['all']:
            j = (session_templates
                    .join(users, session_templates.c.user_uuid == users.c.uuid, isouter=True)
                    .join(groups, session_templates.c.group_id == groups.c.id, isouter=True))
            query = (sa.select([session_templates, users.c.email, groups.c.name], use_labels=True)
                       .select_from(j)
                       .where(session_templates.c.is_active))
            result = await conn.execute(query)
            entries = []
            async for row in result:
                entries.append({
                    'name': row.session_templates_name,
                    'id': row.session_templates_id,
                    'created_at': row.session_templates_created_at,
                    # The owner column is ``user_uuid``, so its label is
                    # ``session_templates_user_uuid``; there is no
                    # ``session_templates_user`` column in this result.
                    'is_owner': row.session_templates_user_uuid == user_uuid,
                    'user': (str(row.session_templates_user_uuid)
                             if row.session_templates_user_uuid else None),
                    'group': (str(row.session_templates_group_id)
                              if row.session_templates_group_id else None),
                    'user_email': row.users_email,
                    'group_name': row.groups_name,
                })
        else:
            extra_conds = None
            if params['group_id'] is not None:
                extra_conds = ((session_templates.c.group_id == params['group_id']))
            entries = await query_accessible_session_templates(
                        conn, user_uuid, TemplateType.TASK,
                        user_role=user_role, domain_name=domain_name,
                        allowed_types=['user', 'group'], extra_conds=extra_conds)

        # Missing user/group stay null instead of being stringified to 'None'.
        resp = [
            {
                'name': entry['name'],
                'id': entry['id'].hex,
                'created_at': str(entry['created_at']),
                'is_owner': entry['is_owner'],
                'user': str(entry['user']) if entry['user'] is not None else None,
                'group': str(entry['group']) if entry['group'] is not None else None,
                'user_email': entry['user_email'],
                'group_name': entry['group_name'],
                'type': 'user' if entry['user'] is not None else 'group',
            }
            for entry in entries
        ]
        return web.json_response(resp)


@auth_required
@server_status_required(READ_ALLOWED)
@check_api_params(
    t.Dict({
        t.Key('format', default='yaml'): t.Null | t.Enum('yaml', 'json'),
        t.Key('owner_access_key', default=None): t.Null | t.String,
    })
)
async def get(request: web.Request, params: Any) -> web.Response:
    """
    Fetch one active template by ``template_id``.

    The stored template is JSON-decoded and returned as YAML (``text/yaml``,
    the default) or as a JSON response when ``format`` is ``'json'``.

    Raises InvalidAPIParameters for a null/unsupported ``format`` and
    TaskTemplateNotFound when no active template has the given id.
    """
    if params['format'] not in ('yaml', 'json'):
        # Only reachable when the client explicitly sends a null format.
        raise InvalidAPIParameters('format should be "yaml" or "json"')
    requester_access_key, owner_access_key = await get_access_key_scopes(request, params)
    shown_owner = '*' if owner_access_key == requester_access_key else owner_access_key
    log.info('GET (ak:{0}/{1})', requester_access_key, shown_owner)

    template_id = request.match_info['template_id']
    dbpool = request.app['dbpool']
    # NOTE(review): the resolved access keys are only logged; nothing here checks
    # that the requester may read this template. Confirm this is intended.
    lookup = (
        sa.select([session_templates.c.template])
        .select_from(session_templates)
        .where(
            (session_templates.c.id == template_id) &
            session_templates.c.is_active
        )
    )

    async with dbpool.acquire() as conn, conn.begin():
        raw_template = await conn.scalar(lookup)
        if not raw_template:
            raise TaskTemplateNotFound

    template = json.loads(raw_template)
    if params['format'] != 'yaml':
        return web.json_response(template)
    return web.Response(text=yaml.dump(template), content_type='text/yaml')


@auth_required
@server_status_required(READ_ALLOWED)
@check_api_params(
    t.Dict({
        t.Key('payload'): t.String,
        t.Key('owner_access_key', default=None): t.Null | t.String,
    })
)
async def put(request: web.Request, params: Any) -> web.Response:
    """
    Replace the body (and name) of an existing active template.

    ``payload`` may be JSON or YAML. It must pass ``task_template_v1`` and the
    same mount-path rules that ``create`` enforces.

    Raises TaskTemplateNotFound when no active template has ``template_id``,
    and InvalidAPIParameters for a malformed or invalid payload.
    """
    dbpool = request.app['dbpool']
    template_id = request.match_info['template_id']

    requester_access_key, owner_access_key = await get_access_key_scopes(request, params)
    log.info('PUT (ak:{0}/{1})',
             requester_access_key, owner_access_key if owner_access_key != requester_access_key else '*')

    # NOTE(review): nothing here checks that the requester owns or may modify this
    # template; the resolved access keys are only logged. Confirm this is intended.
    async with dbpool.acquire() as conn, conn.begin():
        query = (sa.select([session_templates.c.id])
                   .select_from(session_templates)
                   .where((session_templates.c.id == template_id) &
                          (session_templates.c.is_active)
                          ))
        result = await conn.scalar(query)
        if not result:
            raise TaskTemplateNotFound
        try:
            body = json.loads(params['payload'])
        except json.JSONDecodeError:
            # The YAML fallback needs its own try: a sibling ``except`` clause
            # would not catch errors raised inside this handler.
            try:
                body = yaml.load(params['payload'], Loader=yaml.BaseLoader)
            except yaml.YAMLError:
                raise InvalidAPIParameters('Malformed payload')
        try:
            body = task_template_v1.check(body)
        except t.DataError as e:
            raise InvalidAPIParameters(f'Invalid task template: {e.as_dict()}') from e
        # Enforce the same mount rules as create(); otherwise they could be
        # bypassed by updating an existing template.
        if mounts := body['spec'].get('mounts'):
            mount_prefix = '/home/work/'
            for p in mounts.values():
                if p is None:
                    continue
                if not isinstance(p, str):
                    raise InvalidAPIParameters(f'Path {p} should be a string')
                if not p.startswith(mount_prefix):
                    raise InvalidAPIParameters(f'Path {p} should start with /home/work/')
                if not verify_vfolder_name(p[len(mount_prefix):]):
                    raise InvalidAPIParameters(f'Path {p} is reserved for internal operations.')
        query = (sa.update(session_templates)
                   .values(template=body, name=body['metadata']['name'])
                   .where((session_templates.c.id == template_id)))
        result = await conn.execute(query)
        assert result.rowcount == 1

        return web.json_response({'success': True})


@auth_required
@server_status_required(READ_ALLOWED)
@check_api_params(
    t.Dict({
        t.Key('owner_access_key', default=None): t.Null | t.String,
    })
)
async def delete(request: web.Request, params: Any) -> web.Response:
    """
    Soft-delete a template by setting its ``is_active`` flag to false.

    Raises TaskTemplateNotFound when no active template has ``template_id``.
    Responds with ``{"success": true}``.
    """
    template_id = request.match_info['template_id']
    dbpool = request.app['dbpool']

    requester_access_key, owner_access_key = await get_access_key_scopes(request, params)
    shown_owner = '*' if owner_access_key == requester_access_key else owner_access_key
    log.info('DELETE (ak:{0}/{1})', requester_access_key, shown_owner)

    # NOTE(review): nothing here checks that the requester owns or may delete this
    # template; the resolved access keys are only logged. Confirm this is intended.
    matches_id = (session_templates.c.id == template_id)
    async with dbpool.acquire() as conn, conn.begin():
        found = await conn.scalar(
            sa.select([session_templates.c.id])
            .select_from(session_templates)
            .where(matches_id & session_templates.c.is_active)
        )
        if not found:
            raise TaskTemplateNotFound

        deactivation = await conn.execute(
            sa.update(session_templates)
            .values(is_active=False)
            .where(matches_id)
        )
        assert deactivation.rowcount == 1

        return web.json_response({'success': True})


async def init(app: web.Application) -> None:
    """Startup hook for this sub-application; it currently does nothing."""


async def shutdown(app: web.Application) -> None:
    """Shutdown hook for this sub-application; it currently does nothing."""


def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]:
    """
    Build the ``template/session`` sub-application (API versions 4 and 5).

    Routes, all registered with CORS using ``default_cors_options``:
      POST   ''               -> create
      GET    ''               -> list_template
      GET    /{template_id}   -> get
      PUT    /{template_id}   -> put
      DELETE /{template_id}   -> delete

    Returns the application and an empty list of extra middlewares.
    """
    app = web.Application()
    app.on_startup.append(init)
    app.on_shutdown.append(shutdown)
    app['api_versions'] = (4, 5)
    app['prefix'] = 'template/session'
    cors = aiohttp_cors.setup(app, defaults=default_cors_options)

    collection_routes = (('POST', create), ('GET', list_template))
    for method, handler in collection_routes:
        cors.add(app.router.add_route(method, '', handler))

    template_resource = cors.add(app.router.add_resource(r'/{template_id}'))
    item_routes = (('GET', get), ('PUT', put), ('DELETE', delete))
    for method, handler in item_routes:
        cors.add(template_resource.add_route(method, handler))

    return app, []