import inspect
from collections import namedtuple
from functools import partial
from json import JSONDecodeError
from json import loads as json_loads

from pydantic import ValidationError

from .base import BasePlugin, Context
from .page import PAGES

# Lower-case handler names looked up on class-based endpoints when
# collecting routes.
METHODS = {
    'get',
    'post',
    'put',
    'patch',
    'delete',
}
# Flat record describing one discovered route: its full path, the HTTP
# methods it answers to, and the callable that handles it.
Route = namedtuple('Route', ('path', 'methods', 'func'))


class StarlettePlugin(BasePlugin):
    """Backend plugin for Starlette applications.

    It registers the spec/documentation routes on the app, validates
    incoming requests and outgoing responses against pydantic models, and
    walks the app's routing table to collect the routes to document.

    Starlette is imported inside the methods rather than at module level
    (presumably so that it stays an optional dependency).
    """

    def __init__(self, spectree):
        super().__init__(spectree)
        from starlette.convertors import CONVERTOR_TYPES
        # ``CONVERTOR_TYPES`` maps a convertor name to a convertor object;
        # invert it so ``parse_path`` can recover the name from the object.
        self.conv2type = {
            conv: typ for typ, conv in CONVERTOR_TYPES.items()
        }

    def register_route(self, app):
        """Attach the spec endpoint and one page per entry in ``PAGES``.

        :param app: application object; it is stored on ``self.app`` and
            must provide ``add_route``.
        """
        self.app = app
        from starlette.responses import JSONResponse, HTMLResponse

        self.app.add_route(
            self.config.spec_url,
            lambda request: JSONResponse(self.spectree.spec),
        )

        for ui in PAGES:
            # ``ui=ui`` binds the current page name at definition time so
            # every lambda serves its own page (avoids late binding).
            self.app.add_route(
                f'/{self.config.PATH}/{ui}',
                lambda request, ui=ui: HTMLResponse(
                    PAGES[ui].format(self.config.spec_url)
                ),
            )

    async def request_validation(self, request, query, json, headers, cookies):
        """Parse the request parts and store them on ``request.context``.

        Each of ``query``, ``json``, ``headers`` and ``cookies`` is either a
        model exposing ``parse_obj`` or a falsy value, in which case the
        matching ``Context`` slot is ``None``.

        :raises ValidationError: when a model rejects its input.
        :raises JSONDecodeError: when a JSON model is given and the request
            body is not valid JSON.
        """
        request.context = Context(
            query.parse_obj(request.query_params) if query else None,
            # An empty body is treated as an empty JSON object.
            (
                json.parse_obj(json_loads(await request.body() or '{}'))
                if json else None
            ),
            headers.parse_obj(request.headers) if headers else None,
            cookies.parse_obj(request.cookies) if cookies else None,
        )

    async def validate(self,
                       func,
                       query, json, headers, cookies, resp,
                       before, after,
                       *args, **kwargs):
        """Validate the request, call ``func`` and validate its response.

        :param func: the endpoint (sync or async) to call with
            ``*args, **kwargs``.
        :param query, json, headers, cookies: request models, or falsy to
            skip that part (see ``request_validation``).
        :param resp: falsy, or an object whose ``find_model(status_code)``
            returns the model for that status code (or a falsy value).
        :param before: hook called as
            ``before(request, response, req_validation_error, instance)``
            after request validation. ``response`` is ``None`` unless
            validation failed.
        :param after: hook called as
            ``after(request, response, resp_validation_error, instance)``
            once the endpoint has returned.
        :returns: a 422 JSON response when request validation or JSON
            decoding fails, a 500 JSON response when the endpoint's
            response does not match its model, otherwise the endpoint's
            own response.
        """
        from starlette.requests import Request
        from starlette.responses import JSONResponse

        # A method on a class-based endpoint is called as ``(self, request)``
        # and a plain function endpoint as ``(request,)``. Inspect the actual
        # arguments instead of ``func.__qualname__``: functions defined
        # inside another function also have a '.' in their qualname
        # (``outer.<locals>.inner``) and would be mistaken for methods.
        # A single positional argument is always taken to be the request.
        if isinstance(args[0], Request) or len(args) < 2:
            instance, request = None, args[0]
        else:
            instance, request = args[0], args[1]

        response = None
        req_validation_error = None
        resp_validation_error = None
        json_decode_error = None

        try:
            await self.request_validation(request, query, json, headers, cookies)
        except ValidationError as err:
            req_validation_error = err
            response = JSONResponse(err.errors(), 422)
        except JSONDecodeError as err:
            json_decode_error = err
            self.logger.info(
                '422 Validation Error',
                extra={'spectree_json_decode_error': str(err)}
            )
            response = JSONResponse({'error_msg': str(err)}, 422)

        # The hook always runs. For a JSON decode failure it receives
        # ``None`` as the validation error and the 422 response.
        before(request, response, req_validation_error, instance)
        if req_validation_error or json_decode_error:
            return response

        if inspect.iscoroutinefunction(func):
            response = await func(*args, **kwargs)
        else:
            response = func(*args, **kwargs)

        if resp:
            model = resp.find_model(response.status_code)
            if model:
                try:
                    model.validate(json_loads(response.body))
                except ValidationError as err:
                    resp_validation_error = err
                    response = JSONResponse(err.errors(), 500)

        after(request, response, resp_validation_error, instance)

        return response

    def find_routes(self):
        """Collect ``Route`` records by walking ``self.app.routes``.

        Routes whose path starts with ``/{config.PATH}`` (the pages added by
        ``register_route``) are skipped.

        :returns: a list of ``Route(path, methods, func)`` tuples.
        """
        routes = []

        def parse_route(app, prefix=''):
            for route in app.routes:
                if route.path.startswith(f'/{self.config.PATH}'):
                    continue

                func = route.app
                if isinstance(func, partial):
                    try:
                        func = func.__wrapped__
                    except AttributeError:
                        # No ``__wrapped__`` on this partial: keep it as is.
                        pass

                full_path = f'{prefix}{route.path}'
                if inspect.isclass(func):
                    # Class-based endpoint: one record per handler method
                    # the class defines.
                    for method in METHODS:
                        handler = getattr(func, method, None)
                        if handler:
                            routes.append(Route(
                                full_path,
                                {method.upper()},
                                handler,
                            ))
                elif inspect.isfunction(func):
                    routes.append(Route(
                        full_path,
                        route.methods,
                        route.endpoint))
                else:
                    # Anything else is walked recursively with its path as
                    # prefix (presumably a mounted sub-application exposing
                    # ``.routes`` -- confirm against Starlette).
                    parse_route(route, prefix=full_path)

        parse_route(self.app)
        return routes

    def bypass(self, func, method):
        """Return ``True`` for ``HEAD`` and ``OPTIONS``, else ``False``."""
        return method in ('HEAD', 'OPTIONS')

    def parse_func(self, route):
        """Yield ``(method, func)`` for each method of ``route``.

        Falls back to ``'GET'`` when ``route.methods`` is empty or ``None``.
        """
        for method in route.methods or ['GET']:
            yield method, route.func

    def parse_path(self, route):
        """Split a route path into its path string and path parameters.

        :param route: object with a ``path`` attribute in Starlette's
            ``{name:convertor}`` syntax.
        :returns: ``(path, parameters)`` where ``path`` is the second item
            returned by Starlette's ``compile_path`` and ``parameters`` is a
            list of path-parameter dicts with ``name``, ``in``, ``required``
            and ``schema`` keys.
        """
        from starlette.routing import compile_path
        _, path, variables = compile_path(route.path)

        # Schema for each known convertor name.
        schemas = {
            'int': {'type': 'integer', 'format': 'int32'},
            'float': {'type': 'number', 'format': 'float'},
            'path': {'type': 'string', 'format': 'path'},
            'str': {'type': 'string'},
            'uuid': {'type': 'string', 'format': 'uuid'},
        }
        # Used for convertors without an entry above, including ones
        # registered after ``__init__`` captured ``conv2type``. A path
        # segment is always text, so a plain string schema is safe, whereas
        # a missing schema would produce an invalid parameter object.
        fallback = {'type': 'string'}

        parameters = []
        for name, conv in variables.items():
            typ = self.conv2type.get(conv)
            parameters.append({
                'name': name,
                'in': 'path',
                'required': True,
                # Copy so parameters never share one mutable schema dict.
                'schema': dict(schemas.get(typ, fallback)),
            })

        return path, parameters