from __future__ import annotations

import json
import logging
from typing import Callable

from django.core.exceptions import ImproperlyConfigured
from django.http import HttpResponseForbidden
from django.http.request import HttpRequest
from django.http.response import HttpResponse
from django.template import loader
from jwt.exceptions import InvalidTokenError

from .models import RequestToken
from .settings import FOUR03_TEMPLATE, JWT_QUERYSTRING_ARG
from .utils import decode

logger = logging.getLogger(__name__)


class RequestTokenMiddleware:
    """
    Extract and verify request tokens from incoming GET requests.

    This middleware is used to perform initial JWT verfication of
    link tokens.

    """

    def __init__(self, get_response: Callable):
        self.get_response = get_response

    def __call__(self, request: HttpRequest) -> HttpResponse:  # noqa: C901
        """
        Verify JWT request querystring arg.

        If a token is found (using JWT_QUERYSTRING_ARG), then it is decoded,
        which verifies the signature and expiry dates, and raises a 403 if
        the token is invalid.

        The decoded payload is then added to the request as the `token_payload`
        property - allowing it to be interrogated by the view function
        decorator when it gets there.

        We don't substitute in the user at this point, as we are not making
        any assumptions about the request path at this point - it's not until
        we get to the view function that we know where we are heading - at
        which point we verify that the scope matches, and only then do we
        use the token user.

        """
        if not hasattr(request, "session"):
            raise ImproperlyConfigured(
                "Request has no session attribute, please ensure that Django "
                "session middleware is installed."
            )
        if not hasattr(request, "user"):
            raise ImproperlyConfigured(
                "Request has no user attribute, please ensure that Django "
                "authentication middleware is installed."
            )

        if request.method == "GET" or request.method == "POST":
            token = request.GET.get(JWT_QUERYSTRING_ARG)
            if not token and request.method == "POST":
                if request.META.get("CONTENT_TYPE") == "application/json":
                    token = json.loads(request.body).get(JWT_QUERYSTRING_ARG)
                if not token:
                    token = request.POST.get(JWT_QUERYSTRING_ARG)
        else:
            token = None

        if token is None:
            return self.get_response(request)

        # in the event of an error we log it, but then let the request
        # continue - as the fact that the token cannot be decoded, or
        # no longer exists, may not invalidate the request itself.
        try:
            payload = decode(token)
            request.token = RequestToken.objects.get(id=payload["jti"])
        except RequestToken.DoesNotExist:
            request.token = None
            logger.exception("RequestToken no longer exists: %s", payload["jti"])
        except InvalidTokenError:
            request.token = None
            logger.exception("RequestToken cannot be decoded: %s", token)

        return self.get_response(request)

    def process_exception(
        self, request: HttpRequest, exception: Exception
    ) -> HttpResponse:
        """Handle all InvalidTokenErrors."""
        if isinstance(exception, InvalidTokenError):
            logger.exception("JWT request token error")
            response = _403(request, exception)
            if getattr(request, "token", None):
                request.token.log(request, response, error=exception)
            return response


def _403(request: HttpRequest, exception: Exception) -> HttpResponseForbidden:
    """Render HttpResponseForbidden for exception."""
    if FOUR03_TEMPLATE:
        html = loader.render_to_string(
            template_name=FOUR03_TEMPLATE,
            context={"token_error": str(exception), "exception": exception},
            request=request,
        )
        return HttpResponseForbidden(html, reason=str(exception))
    return HttpResponseForbidden(reason=str(exception))