from functools import wraps
from typing import Callable, List, Optional, Tuple

from sanic.request import Request

from sanic_jwt_extended.exceptions import (
    AccessDeniedError,
    ConfigurationConflictError,
    CSRFError,
    FreshTokenRequiredError,
    InvalidHeaderError,
    NoAuthorizationError,
    RevokedTokenError,
    WrongTokenError,
)
from sanic_jwt_extended.jwt_manager import JWT
from sanic_jwt_extended.tokens import Token

try:
    from hmac import compare_digest
except ImportError:  # pragma: no cover

    def compare_digest(a, b):
        if isinstance(a, str):
            a = a.encode("utf-8")
        if isinstance(b, str):
            b = b.encode("utf-8")

        if len(a) != len(b):
            return False

        r = 0
        for x, y in zip(a, b):
            r |= x ^ y

        return not r


jwt_get_function = Callable[[Request, bool], Tuple[str, Optional[str]]]


def _get_request(args) -> Request:
    if isinstance(args[0], Request):
        request = args[0]
    else:
        request = args[1]
    return request


def _get_raw_jwt_from_request(request, is_access=True):
    functions: List[jwt_get_function] = []

    for eligible_location in JWT.config.token_location:
        if eligible_location == "header":
            functions.append(_get_raw_jwt_from_headers)
        if eligible_location == "query":
            functions.append(_get_raw_jwt_from_query_params)
        if eligible_location == "cookies":
            functions.append(_get_raw_jwt_from_cookies)

    raw_jwt = None
    csrf_value = None
    errors = []

    for f in functions:
        try:
            raw_jwt, csrf_value = f(request, is_access)
            break
        except NoAuthorizationError as e:
            errors.append(str(e))

    if not raw_jwt:
        raise NoAuthorizationError(', '.join(errors))

    return raw_jwt, csrf_value


def _get_raw_jwt_from_headers(request, is_access):
    header_key = (
        JWT.config.jwt_header_key if is_access else JWT.config.refresh_jwt_header_key
    )
    header_prefix = JWT.config.jwt_header_prefix

    token_header = request.headers.get(header_key)

    if not token_header:
        raise NoAuthorizationError(f'Missing header "{header_key}"')

    parts: List[str] = token_header.split()

    if parts[0] != header_prefix or len(parts) != 2:
        raise InvalidHeaderError(
            f"Bad {header_key} header. Expected value '{header_prefix} <JWT>'"
        )

    encoded_token: str = parts[1]

    return encoded_token, None


def _get_raw_jwt_from_query_params(request, _):
    encoded_token = request.args.get(JWT.config.jwt_query_param_name)
    if not encoded_token:
        raise NoAuthorizationError(
            f'Missing query parameter "{JWT.config.jwt_query_param_name}"'
        )

    return encoded_token, None


def _get_raw_jwt_from_cookies(request, is_access):
    cookie_key = JWT.config.jwt_cookie if is_access else JWT.config.refresh_jwt_cookie
    csrf_header_key = (
        JWT.config.jwt_csrf_header if is_access else JWT.config.refresh_jwt_csrf_header
    )

    encoded_token = request.cookies.get(cookie_key)
    csrf_value = None

    if not encoded_token:
        raise NoAuthorizationError(f'Missing cookie "{cookie_key}"')

    if JWT.config.csrf_protect and request.method in JWT.config.csrf_request_methods:
        csrf_value = request.headers.get(csrf_header_key)

        if not csrf_value:
            raise CSRFError("Missing CSRF token")

    return encoded_token, csrf_value


def _csrf_check(csrf_from_request, csrf_from_jwt):
    if not csrf_from_jwt or not isinstance(csrf_from_jwt, str):
        raise CSRFError('Can not find valid CSRF data from token')
    if not compare_digest(csrf_from_request, csrf_from_jwt):
        raise CSRFError('CSRF double submit tokens do not match')


def jwt_required(
    function=None, *, allow=None, deny=None, fresh_required=False,
):
    def real(fn):
        @wraps(fn)
        async def wrapper(*args, **kwargs):
            request = _get_request(args)
            raw_jwt, csrf_value = _get_raw_jwt_from_request(request)

            token_obj = Token(raw_jwt)

            if csrf_value:
                _csrf_check(csrf_value, token_obj.csrf)

            if token_obj.type != "access":
                raise WrongTokenError("Only access tokens are allowed")

            if fresh_required and not token_obj.fresh:
                raise FreshTokenRequiredError("Only fresh access tokens are allowed")

            if allow and token_obj.role not in allow:
                raise AccessDeniedError("You are not allowed to access here")

            if deny and token_obj.role in deny:
                raise AccessDeniedError("You are not allowed to access here")

            if JWT.config.use_blacklist and await JWT.blacklist.is_blacklisted(
                token_obj
            ):
                raise RevokedTokenError("Token has been revoked")

            kwargs["token"] = token_obj

            return await fn(*args, **kwargs)

        return wrapper

    if function:
        return real(function)
    else:
        if allow and deny:
            raise ConfigurationConflictError(
                "Can not use 'deny' and 'allow' option together."
            )
        return real


def jwt_optional(function):
    @wraps(function)
    async def wrapper(*args, **kwargs):
        request = _get_request(args)
        token_obj: Optional[Token] = None

        try:
            raw_jwt, csrf_value = _get_raw_jwt_from_request(request)

            token_obj = Token(raw_jwt)

            if csrf_value:
                _csrf_check(csrf_value, token_obj.csrf)

            if token_obj.type != "access":
                raise WrongTokenError("Only access tokens are allowed")
        except (NoAuthorizationError, InvalidHeaderError):
            pass

        kwargs["token"] = token_obj

        return await function(*args, **kwargs)

    return wrapper


def refresh_jwt_required(function=None, *, allow=None, deny=None):
    def real(fn):
        @wraps(fn)
        async def wrapper(*args, **kwargs):
            request = _get_request(args)
            raw_jwt, csrf_value = _get_raw_jwt_from_request(request, is_access=False)

            token_obj = Token(raw_jwt)

            if csrf_value:
                _csrf_check(csrf_value, token_obj.csrf)

            if token_obj.type != "refresh":
                raise WrongTokenError("Only refresh tokens are allowed")

            if allow and token_obj.role not in allow:
                raise AccessDeniedError("You are not allowed to refresh in here")

            if deny and token_obj.role in deny:
                raise AccessDeniedError("You are not allowed to refresh in here")

            if JWT.config.use_blacklist and await JWT.blacklist.is_blacklisted(
                token_obj
            ):
                raise RevokedTokenError("Token has been revoked")

            kwargs["token"] = token_obj

            return await fn(*args, **kwargs)

        return wrapper

    if function:
        return real(function)
    else:
        if allow and deny:
            raise ConfigurationConflictError(
                "Can not use 'deny' and 'allow' option together."
            )
        return real