from __future__ import absolute_import
from __future__ import print_function

import json
import time
from numbers import Number
from typing import Any
from typing import Callable
from typing import cast
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Union

from google.protobuf.timestamp_pb2 import Timestamp
from graphql.pyutils import Path
from tornado.httputil import HTTPServerRequest

from graphene_tornado.ext.apollo_engine_reporting.engine_agent import (
    EngineReportingOptions,
)
from graphene_tornado.ext.apollo_engine_reporting.reports_pb2 import Trace
from graphene_tornado.graphql_extension import GraphQLExtension

# Names of the HTTP request headers that generate_client_info() reads to
# identify the calling client (name, reference id and version respectively).
# A missing header is reported as an empty string by that function.
CLIENT_NAME_HEADER = "apollographql-client-name"
CLIENT_REFERENCE_HEADER_KEY = "apollographql-client-reference-id"
CLIENT_VERSION_HEADER_KEY = "apollographql-client-version"


# Identity of the client that issued a request, as reported in the trace.
# The typename passed to NamedTuple must match the variable name so that
# repr() and pickling refer to "ClientInfo" rather than an unrelated type.
ClientInfo = NamedTuple(
    "ClientInfo",
    [("client_name", str), ("client_reference_id", str), ("client_version", str)],
)


def generate_client_info(request: HTTPServerRequest) -> ClientInfo:
    """Build a ``ClientInfo`` from the client identification headers of *request*.

    Each field is read from its corresponding ``apollographql-client-*``
    header; a header that is absent yields an empty string.
    """
    headers = request.headers
    name = headers.get(CLIENT_NAME_HEADER, "")
    reference_id = headers.get(CLIENT_REFERENCE_HEADER_KEY, "")
    version = headers.get(CLIENT_VERSION_HEADER_KEY, "")
    return ClientInfo(
        client_name=name,
        client_reference_id=reference_id,
        client_version=version,
    )


def response_path_as_string(path: Optional[List[Union[str, int]]]) -> str:
    """Render a response path as a dot-separated string.

    ``None`` and the empty list both map to ``""``; otherwise every
    segment (field name or list index) is converted with ``str`` and the
    segments are joined with ``"."``.
    """
    if path:
        return ".".join(map(str, path))
    return ""


def now_ns() -> int:
    """Return the current time from ``time.time_ns()`` as integer nanoseconds."""
    current = time.time_ns()
    return current


class EngineReportingExtension(GraphQLExtension):
    """GraphQL extension that builds one ``Trace`` for a single request.

    A tree of ``Trace.Node`` messages is grown as fields are resolved; nodes
    are indexed in ``self.nodes`` by their dot-separated response path (the
    root node lives under ``""``).  When the request ends the finished trace
    is handed to the ``add_trace`` callback.
    """

    def __init__(self, options: EngineReportingOptions, add_trace: Callable) -> None:
        """Create the extension.

        Args:
            options: Reporting options.  ``options.generate_client_info`` may
                supply a custom client-info generator; when it is falsy the
                module-level ``generate_client_info`` is used.
            add_trace: Awaitable callback invoked at request end as
                ``add_trace(operation_name, document, query_string, trace)``.

        Raises:
            ValueError: If ``add_trace`` is ``None``.
        """
        if add_trace is None:
            raise ValueError("add_trace must be defined")

        self.add_trace = add_trace
        self.operation_name = None

        self.options = options
        self.start_time = now_ns()

        root = Trace.Node()
        root.start_time = self.start_time
        self.trace = Trace(root=root)
        # ``root`` (not ``self.trace.root``) is the node children are attached
        # to; it is merged into ``self.trace.root`` when the request ends.
        self.nodes = {response_path_as_string(None): root}
        self.generate_client_info = options.generate_client_info or generate_client_info
        self.resolver_stats: List[Any] = list()

    async def request_started(
        self,
        request,
        query_string,
        parsed_query,
        operation_name,
        variables,
        context,
        request_context,
    ):
        """Record request-level trace data and return the end-of-request hook.

        Stamps the trace start time, stores the query string and parsed
        document, and fills in the HTTP method and client identification.

        Returns:
            An async callable ``on_request_ended(errors)`` that finalises the
            trace timing and passes the trace to ``add_trace``.
        """
        self.trace.start_time.GetCurrentTime()
        self.query_string = query_string
        self.document = parsed_query
        self.trace.http.method = self._get_http_method(request)

        # Use the generator selected in __init__ so that a custom
        # ``options.generate_client_info`` is honoured.
        client_info = self.generate_client_info(request)
        if client_info:
            self.trace.client_version = client_info.client_version or ""
            self.trace.client_reference_id = client_info.client_reference_id or ""
            self.trace.client_name = client_info.client_name or ""

        async def on_request_ended(errors):
            start_nanos = self.trace.start_time.ToNanoseconds()
            now = Timestamp()
            now.GetCurrentTime()
            self.trace.duration_ns = now.ToNanoseconds() - start_nanos
            self.trace.end_time.GetCurrentTime()

            op_name = self.operation_name or ""
            # Copy the node tree accumulated during execution into the trace.
            self.trace.root.MergeFrom(self.nodes[""])
            await self.add_trace(
                op_name,
                request_context.get("document", None),
                self.query_string,
                self.trace,
            )

        return on_request_ended

    async def parsing_started(self, query_string):
        """No-op hook; nothing is recorded for parsing."""
        return None

    async def validation_started(self):
        """No-op hook; nothing is recorded for validation."""
        return None

    async def execution_started(
        self,
        schema,
        document,
        root,
        context,
        variables,
        operation_name,
        request_context,
    ):
        """Remember the operation name (if given) and stash the document.

        The document is stored in ``request_context["document"]`` so that the
        end-of-request hook can pass it to ``add_trace``.
        """
        if operation_name:
            self.operation_name = operation_name
        request_context["document"] = document

    async def will_resolve_field(self, root, info, **args):
        """Open a trace node for the field at ``info.path``.

        Returns:
            An async callable ``on_end(errors=None, result=None)`` that stamps
            the node's end time.
        """
        if not self.operation_name:
            # Fall back to the name declared on the operation itself.
            name_node = info.operation.name
            self.operation_name = name_node.value if name_node else ""

        node = self._new_node(info.path)
        # Field timings are offsets from the moment the extension was created.
        node.start_time = now_ns() - self.start_time
        node.type = str(info.return_type)
        node.parent_type = str(info.parent_type)

        async def on_end(errors=None, result=None):
            node.end_time = now_ns() - self.start_time

        return on_end

    async def will_send_response(self, response, context):
        """Close the root node and attach response errors to the trace.

        Each error is attached to the node matching its ``path`` when such a
        node exists, otherwise to the root node.  When error masking is
        enabled in the options only a ``"<masked>"`` message is recorded.
        """
        root = self.nodes[""]
        # NOTE(review): the root node stores absolute now_ns() values while
        # field nodes store offsets from self.start_time - confirm intended.
        root.end_time = now_ns()

        # ``errors`` may be missing or None when the response has no errors.
        errors = getattr(response, "errors", None) or ()
        mask_details = self._mask_error_details()
        for error in errors:
            node = self._node_for_error(error, root)
            if mask_details:
                error_info = {"message": "<masked>"}
            else:
                error_info = {
                    "message": str(error),
                    "json": self._error_as_json(error),
                }
            # ``add`` takes the new element's own fields as keyword arguments.
            node.error.add(**error_info)

    def _mask_error_details(self) -> bool:
        """Return whether error details must be hidden in the trace.

        Both spellings of the option are accepted because the original code
        referred to it as ``mask_error_details`` and ``mask_errors_details``.
        """
        options = self.options
        return bool(
            getattr(options, "mask_errors_details", False)
            or getattr(options, "mask_error_details", False)
        )

    def _node_for_error(self, error, root):
        """Return the trace node for ``error.path``, or *root* if none matches."""
        path = getattr(error, "path", None)
        if path:
            specific_node = self.nodes.get(response_path_as_string(list(path)))
            if specific_node is not None:
                return specific_node
        return root

    @staticmethod
    def _error_as_json(error) -> str:
        """Serialise *error* to a JSON string.

        A dict is dumped as-is.  Otherwise a dict-valued ``formatted``
        attribute is used when present (presumably graphql-core's
        ``GraphQLError.formatted`` - verify against the installed version),
        falling back to ``{"message": str(error)}``.
        """
        if isinstance(error, dict):
            payload = error
        else:
            formatted = getattr(error, "formatted", None)
            payload = formatted if isinstance(formatted, dict) else {"message": str(error)}
        # default=str is deliberate: reporting is best-effort and must not
        # fail on values that JSON cannot represent natively.
        return json.dumps(payload, default=str)

    def _get_http_method(self, request):
        """Map ``request.method`` to a ``Trace.HTTP`` method enum value.

        Returns ``Trace.HTTP.UNKNOWN`` when the request has no string method
        or the method has no matching enum value.
        """
        method = getattr(request, "method", None)
        if not isinstance(method, str):
            return Trace.HTTP.UNKNOWN
        value = getattr(Trace.HTTP, method.upper(), None)
        # Guard against names that resolve to non-enum attributes.
        if isinstance(value, int):
            return value
        return Trace.HTTP.UNKNOWN

    def _new_node(self, path: Path):
        """Create, register and return the trace node for *path*.

        The node is appended to its parent's children (creating missing
        ancestors first) and stored in ``self.nodes`` under the path string.
        """
        path_list = path.as_list()
        parent_node = self._ensure_parent_node(path)
        node = parent_node.child.add()

        key = path_list[-1]
        if isinstance(key, int):
            # Integer segments are list indices rather than field names.
            node.index = key
        else:
            node.response_name = cast(str, key)

        self.nodes[response_path_as_string(path_list)] = node
        return node

    def _ensure_parent_node(self, path: Path):
        """Return the node for the parent of *path*, creating it if needed.

        A path without a parent belongs directly under the root node.
        """
        parent = path.prev
        if parent is None:
            return self.nodes[""]
        # Look up with the same key format that _new_node stores under.
        parent_node = self.nodes.get(response_path_as_string(parent.as_list()))
        if parent_node is not None:
            return parent_node
        return self._new_node(parent)