import json
from typing import Any, Callable, List, Optional, Type, Union, cast

from django.conf import settings
from django.http import HttpRequest, HttpResponseBadRequest, JsonResponse
from django.shortcuts import render
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
from django.views.generic import TemplateView
from graphql import GraphQLSchema
from graphql.execution import MiddlewareManager

from ...constants import DATA_TYPE_JSON, DATA_TYPE_MULTIPART
from ...exceptions import HttpBadRequestError
from ...file_uploads import combine_multipart_data
from ...format_error import format_error
from ...graphql import graphql_sync
from ...types import (
    ContextValue,
    ErrorFormatter,
    Extension,
    GraphQLResult,
    RootValue,
    ValidationRules,
)

ExtensionList = Optional[List[Type[Extension]]]
Extensions = Union[
    Callable[[Any, Optional[ContextValue]], ExtensionList], ExtensionList
]

DEFAULT_PLAYGROUND_OPTIONS = {"request.credentials": "same-origin"}


@method_decorator(csrf_exempt, name="dispatch")
class GraphQLView(TemplateView):
    http_method_names = ["get", "post", "options"]
    template_name = "ariadne/graphql_playground.html"
    playground_options: Optional[dict] = None
    introspection: bool = True
    schema: Optional[GraphQLSchema] = None
    context_value: Optional[ContextValue] = None
    root_value: Optional[RootValue] = None
    logger = None
    validation_rules: Optional[ValidationRules] = None
    error_formatter: Optional[ErrorFormatter] = None
    extensions: Optional[Extensions] = None
    middleware: Optional[MiddlewareManager] = None

    def get(
        self, request: HttpRequest, *args, **kwargs
    ):  # pylint: disable=unused-argument
        options = DEFAULT_PLAYGROUND_OPTIONS.copy()
        if self.playground_options:
            options.update(self.playground_options)

        return render(
            request,
            self.get_template_names(),
            {"playground_options": json.dumps(options)},
        )

    def post(
        self, request: HttpRequest, *args, **kwargs
    ):  # pylint: disable=unused-argument
        if not self.schema:
            raise ValueError("GraphQLView was initialized without schema.")

        try:
            data = self.extract_data_from_request(request)
        except HttpBadRequestError as error:
            return HttpResponseBadRequest(error.message)

        success, result = self.execute_query(request, data)
        status_code = 200 if success else 400
        return JsonResponse(result, status=status_code)

    def extract_data_from_request(self, request: HttpRequest):
        content_type = request.content_type or ""
        content_type = content_type.split(";")[0]

        if content_type == DATA_TYPE_JSON:
            return self.extract_data_from_json_request(request)
        if content_type == DATA_TYPE_MULTIPART:
            return self.extract_data_from_multipart_request(request)

        raise HttpBadRequestError(
            "Posted content must be of type {} or {}".format(
                DATA_TYPE_JSON, DATA_TYPE_MULTIPART
            )
        )

    def extract_data_from_json_request(self, request: HttpRequest):
        try:
            return json.loads(request.body)
        except (TypeError, ValueError):
            raise HttpBadRequestError("Request body is not a valid JSON")

    def extract_data_from_multipart_request(self, request: HttpRequest):
        try:
            operations = json.loads(request.POST.get("operations"))
        except (TypeError, ValueError):
            raise HttpBadRequestError(
                "Request 'operations' multipart field is not a valid JSON"
            )
        try:
            files_map = json.loads(request.POST.get("map"))
        except (TypeError, ValueError):
            raise HttpBadRequestError(
                "Request 'map' multipart field is not a valid JSON"
            )

        return combine_multipart_data(operations, files_map, request.FILES)

    def execute_query(self, request: HttpRequest, data: dict) -> GraphQLResult:
        context_value = self.get_context_for_request(request)
        extensions = self.get_extensions_for_request(request, context_value)

        return graphql_sync(
            cast(GraphQLSchema, self.schema),
            data,
            context_value=context_value,
            root_value=self.root_value,
            validation_rules=self.validation_rules,
            debug=settings.DEBUG,
            introspection=self.introspection,
            logger=self.logger,
            error_formatter=self.error_formatter or format_error,
            extensions=extensions,
            middleware=self.middleware,
        )

    def get_context_for_request(self, request: HttpRequest) -> Optional[ContextValue]:
        if callable(self.context_value):
            return self.context_value(request)  # pylint: disable=not-callable
        return self.context_value or {"request": request}

    def get_extensions_for_request(
        self, request: HttpRequest, context: Optional[ContextValue]
    ) -> ExtensionList:
        if callable(self.extensions):
            return self.extensions(request, context)  # pylint: disable=not-callable
        return self.extensions