import inspect
import re
from functools import partial

from pydantic import ValidationError

from .base import BasePlugin
from .page import PAGES


class OpenAPI:
    """Falcon resource that answers GET requests with the stored spec object."""

    def __init__(self, spec):
        # Kept untouched; it is handed back verbatim as the response media.
        self.spec = spec

    def on_get(self, req, resp):
        """Set the stored spec as the media of ``resp``."""
        document = self.spec
        resp.media = document


class DocPage:
    """Falcon resource that serves one pre-rendered HTML documentation page."""

    def __init__(self, html, spec_url):
        # `html` is a str.format template whose single positional slot
        # receives the URL of the OpenAPI document.
        self.page = html.format(spec_url)

    def on_get(self, req, resp):
        """Return the rendered page as ``text/html``."""
        resp.content_type = 'text/html'
        # Falcon 3 renamed `Response.body` to `Response.text` (the old name is
        # deprecated there and gone in Falcon 4); keep `body` for Falcon < 3,
        # whose responses have no `text` attribute.
        if hasattr(resp, 'text'):
            resp.text = self.page
        else:
            resp.body = self.page


# Names of the plugin's own documentation resources; routes backed by these
# classes are skipped by `FlaconPlugin.find_routes`.
DOC_CLASS = [DocPage.__name__, OpenAPI.__name__]


class FlaconPlugin(BasePlugin):
    """Plugin that connects the spec generator to a Falcon application.

    It registers the OpenAPI document and the documentation pages as Falcon
    routes, discovers the application's own routes by walking Falcon's
    compiled router, turns Falcon URI templates into OpenAPI paths plus path
    parameters, and runs pydantic validation around responders.
    """

    def __init__(self, spectree):
        super().__init__(spectree)
        # Imported here rather than at module level (presumably so that falcon
        # is only needed once this plugin is instantiated).
        from falcon.routing.compiled import _FIELD_PATTERN

        # Falcon's (private) regex for one `{field:converter(args)}` field.
        self.FIELD_PATTERN = _FIELD_PATTERN
        # NOTE from `falcon.routing.compiled.CompiledRouterNode`
        # These regex-escaping helpers are kept for backward compatibility
        # only: an OpenAPI path is a literal template, not a regex, so
        # `parse_path` does not apply them.
        self.ESCAPE = r'[\.\(\)\[\]\?\$\*\+\^\|]'
        self.ESCAPE_TO = r'\\\g<0>'
        # Replacement reducing a whole field to `{<group 2>}`; group 2 is
        # assumed to be the field-name group of falcon's _FIELD_PATTERN.
        self.EXTRACT = r'{\2}'
        # NOTE this regex is copied from werkzeug.routing._converter_args_re and
        # modified to support only int args. The optional leading `-` lets
        # negative bounds such as `min=-5` parse as a bound instead of the
        # bare digits being picked up as a positional `num_digits`.
        self.INT_ARGS = re.compile(r'''
            ((?P<name>\w+)\s*=\s*)?
            (?P<value>-?\d+)\s*
        ''', re.VERBOSE)
        # Positional order of the `int` converter arguments.
        self.INT_ARGS_NAMES = ('num_digits', 'min', 'max')

    def register_route(self, app):
        """Add the OpenAPI document route and one route per doc page to ``app``."""
        self.app = app
        self.app.add_route(
            self.config.spec_url, OpenAPI(self.spectree.spec)
        )
        for ui in PAGES:
            self.app.add_route(
                f'/{self.config.PATH}/{ui}',
                DocPage(PAGES[ui], self.config.spec_url),
            )

    def find_routes(self):
        """Return every router node that has a user resource attached.

        Nodes whose resource class name is listed in ``DOC_CLASS`` (the
        plugin's own documentation resources) are skipped. This walks the
        private ``app._router._roots`` tree of Falcon's compiled router.
        """
        routes = []

        def find_node(node):
            if node.resource and node.resource.__class__.__name__ not in DOC_CLASS:
                routes.append(node)

            for child in node.children:
                find_node(child)

        for route in self.app._router._roots:
            find_node(route)

        return routes

    def parse_func(self, route):
        """Return the ``(method, responder)`` pairs of a router node's method map."""
        return route.method_map.items()

    def parse_path(self, route):
        """Convert a Falcon URI template into an OpenAPI path and its parameters.

        Every ``{name:converter(args)}`` field is reduced to ``{name}`` and
        described by a required ``in: path`` parameter whose schema is
        derived from the converter. Literal text is kept verbatim.

        Returns:
            tuple: ``(path, parameters)`` where ``path`` starts with ``/``.
        """
        subs, parameters = [], []
        for segment in route.uri_template.strip('/').split('/'):
            # finditer() returns a lazy iterator, which is always truthy;
            # materialize it so literal-only segments are really detected.
            matches = list(self.FIELD_PATTERN.finditer(segment))
            if not matches:
                subs.append(segment)
                continue

            subs.append(self.FIELD_PATTERN.sub(self.EXTRACT, segment))

            for field in matches:
                parameters.append({
                    'name': field.group('fname'),
                    'in': 'path',
                    'required': True,
                    'schema': self._converter_schema(
                        field.group('cname'), field.group('argstr')
                    ),
                })

        return f'/{"/".join(subs)}', parameters

    def _converter_schema(self, converter, argstr):
        """Map a Falcon field converter name (and its args) to a JSON schema."""
        if converter == 'int':
            return self._int_schema(argstr)
        if converter == 'uuid':
            return {
                'type': 'string',
                'format': 'uuid'
            }
        if converter == 'dt':
            return {
                'type': 'string',
                'format': 'date-time',
            }
        # no converter specified or customized converters
        return {'type': 'string'}

    def _int_schema(self, argstr):
        """Build an integer schema from an ``int`` converter argument string.

        ``argstr`` may be ``None`` (no arguments). Positional values fill
        ``INT_ARGS_NAMES`` in order; keyword values go to their named slot.
        """
        arg_values = [None, None, None]
        for position, match in enumerate(self.INT_ARGS.finditer(argstr or '')):
            name, value = match.group('name'), match.group('value')
            index = self.INT_ARGS_NAMES.index(name) if name else position
            arg_values[index] = value

        num_digits, minimum, maximum = arg_values
        schema = {
            'type': 'integer',
            # NOTE(review): falcon's IntConverter documents num_digits as a
            # required digit count, not a bit width, so `int<N>` is probably
            # not a meaningful OpenAPI format; left unchanged pending review.
            'format': f'int{num_digits}' if num_digits else 'int32',
        }
        # OpenAPI requires numeric bounds, so convert the matched digit strings.
        if minimum is not None:
            schema['minimum'] = int(minimum)
        if maximum is not None:
            schema['maximum'] = int(maximum)
        return schema

    def request_validation(self, req, query, json, headers, cookies):
        """Parse the declared request parts into ``req.context``.

        Each of ``query``/``json``/``headers``/``cookies`` is a model class
        or a falsy value; only the parts with a model are parsed. Raises
        ``pydantic.ValidationError`` from ``parse_obj`` on invalid data.
        """
        if query:
            req.context.query = query.parse_obj(req.params)
        if headers:
            req.context.headers = headers.parse_obj(req.headers)
        if cookies:
            req.context.cookies = cookies.parse_obj(req.cookies)
        if json:
            # Only touch `req.media` (which deserializes the body) when a body
            # model is declared; the result is unused otherwise.
            req.context.json = json.parse_obj(req.media or {})

    @staticmethod
    def _status_code(status):
        """Return the 3-digit status code of a Falcon ``resp.status`` as ``str``.

        Handles status lines given as ``str`` or ``bytes`` as well as ``int``
        values (including ``http.HTTPStatus``, an ``int`` subclass).
        """
        if isinstance(status, int):
            return str(int(status))
        if isinstance(status, bytes):
            status = status.decode('ascii')
        return status[:3]

    def validate(self,
                 func,
                 query, json, headers, cookies, resp,
                 before, after,
                 *args, **kwargs):
        """Validate the request, call the responder, then validate its response.

        ``args`` are the responder's positional arguments; the first three are
        the resource instance, the request and the response. On request
        validation failure the response becomes a 422 carrying the pydantic
        errors, ``before`` is called and the responder is skipped. Otherwise
        the responder runs; if ``resp`` has a model for the produced status
        code and ``resp.media`` does not validate, the response becomes a 500
        carrying the errors. ``after`` is then called. Hooks receive
        ``(req, resp, validation_error_or_None, resource)``.
        """
        # falcon endpoint method arguments: (self, req, resp)
        _self, _req, _resp = args[:3]
        req_validation_error, resp_validation_error = None, None
        try:
            self.request_validation(_req, query, json, headers, cookies)

        except ValidationError as err:
            req_validation_error = err
            _resp.status = '422 Unprocessable Entity'
            _resp.media = err.errors()

        before(_req, _resp, req_validation_error, _self)
        if req_validation_error:
            return

        func(*args, **kwargs)
        if resp and resp.has_model():
            model = resp.find_model(self._status_code(_resp.status))
            if model:
                try:
                    model.validate(_resp.media)
                except ValidationError as err:
                    resp_validation_error = err
                    _resp.status = '500 Internal Service Response Validation Error'
                    _resp.media = err.errors()

        after(_req, _resp, resp_validation_error, _self)

    def bypass(self, func, method):
        """Return True when ``func`` should be excluded from validation/docs.

        That is the case for ``functools.partial`` objects whose wrapped
        callable is not a bound method. ``method`` is unused here.
        """
        if not isinstance(func, partial):
            return False
        if inspect.ismethod(func.func):
            return False
        # others are <cyfunction>
        return True