from copy import deepcopy
from functools import partial
from inspect import isawaitable
from typing import Any, Callable, Dict, Optional

from graphql import GraphQLResolveInfo
from opentracing import Scope, Tracer, global_tracer
from opentracing.ext import tags

from ...types import ContextValue, Extension, Resolver
from .utils import format_path, should_trace

# Signature of the optional argument filter accepted by the extensions in this
# module: it is called with a resolver's keyword arguments and the field's
# GraphQLResolveInfo, and returns the dict whose items are recorded on the
# span as "graphql.param.<name>" tags.
ArgFilter = Callable[[Dict[str, Any], GraphQLResolveInfo], Dict[str, Any]]


class OpenTracingExtension(Extension):
    """Extension that records OpenTracing spans for a GraphQL request.

    ``request_started`` opens a root span named "GraphQL Query" on the tracer
    returned by ``global_tracer()`` and ``request_finished`` closes it.
    ``resolve`` wraps every field for which ``should_trace(info)`` is truthy
    in an active span named after the field and tags it with the parent type,
    the field path and the (optionally filtered) resolver arguments.
    """

    _arg_filter: Optional[ArgFilter]
    # None until request_started() has run.
    _root_scope: Optional[Scope]
    _tracer: Tracer

    def __init__(self, *, arg_filter: Optional[ArgFilter] = None):
        """Store the optional argument filter and grab the global tracer.

        Args:
            arg_filter: Callable applied to a copy of the resolver arguments
                before they are written to the span as tags. When omitted,
                the arguments are recorded unfiltered.
        """
        self._arg_filter = arg_filter
        self._tracer = global_tracer()
        self._root_scope = None

    def request_started(self, context: ContextValue):
        """Open the root "GraphQL Query" span and tag its component."""
        root_scope = self._tracer.start_active_span("GraphQL Query")
        self._root_scope = root_scope
        root_scope.span.set_tag(tags.COMPONENT, "graphql")

    def request_finished(self, context: ContextValue):
        """Close the root span, if one was opened by ``request_started``."""
        # The root scope is still None when request_started() never ran;
        # calling .close() on it would raise AttributeError.
        if self._root_scope is not None:
            self._root_scope.close()

    async def resolve(
        self, next_: Resolver, parent: Any, info: GraphQLResolveInfo, **kwargs
    ):
        """Call the next resolver, wrapping it in a span when traced.

        Args:
            next_: The next resolver in the chain.
            parent: Parent value passed through to ``next_``.
            info: Resolve info for the field being resolved.
            **kwargs: Field arguments passed through to ``next_``.

        Returns:
            The value produced by ``next_``; awaited first if it is awaitable.
        """
        if not should_trace(info):
            result = next_(parent, info, **kwargs)
            if isawaitable(result):
                result = await result
            return result

        with self._tracer.start_active_span(info.field_name) as scope:
            span = scope.span
            span.set_tag(tags.COMPONENT, "graphql")
            span.set_tag("graphql.parentType", info.parent_type.name)

            graphql_path = ".".join(
                map(str, format_path(info.path))  # pylint: disable=bad-builtin
            )
            span.set_tag("graphql.path", graphql_path)

            if kwargs:
                filtered_kwargs = self.filter_resolver_args(kwargs, info)
                for kwarg, value in filtered_kwargs.items():
                    span.set_tag(f"graphql.param.{kwarg}", value)

            # The resolver itself always receives the original, unfiltered
            # kwargs; filtering only affects what is written to the span.
            result = next_(parent, info, **kwargs)
            if isawaitable(result):
                # Await inside the ``with`` block so the span stays open
                # until the awaitable has completed.
                result = await result
            return result

    def filter_resolver_args(
        self, args: Dict[str, Any], info: GraphQLResolveInfo
    ) -> Dict[str, Any]:
        """Return the resolver arguments that should be recorded on the span.

        Without an ``arg_filter`` the given ``args`` dict is returned as is.
        Otherwise the filter is called with a deep copy of ``args``, so
        changes it makes do not reach the arguments passed to the resolver.
        """
        if not self._arg_filter:
            return args

        return self._arg_filter(deepcopy(args), info)


class OpenTracingExtensionSync(OpenTracingExtension):
    """Variant of ``OpenTracingExtension`` whose ``resolve`` is not async.

    The resolver's return value is handed back as is; it is never awaited.
    """

    def resolve(
        self, next_: Resolver, parent: Any, info: GraphQLResolveInfo, **kwargs
    ):  # pylint: disable=invalid-overridden-method
        """Call the next resolver, wrapping it in a span when traced.

        Fields for which ``should_trace(info)`` is falsy are resolved without
        a span. All other fields get an active span named after the field,
        tagged with the component, parent type, path and resolver arguments.
        """
        if not should_trace(info):
            return next_(parent, info, **kwargs)

        with self._tracer.start_active_span(info.field_name) as field_scope:
            field_span = field_scope.span
            field_span.set_tag(tags.COMPONENT, "graphql")
            field_span.set_tag("graphql.parentType", info.parent_type.name)
            field_span.set_tag(
                "graphql.path",
                ".".join(str(segment) for segment in format_path(info.path)),
            )

            if kwargs:
                # Only the tags use the filtered arguments; the resolver
                # below still receives the original kwargs.
                traced_args = self.filter_resolver_args(kwargs, info)
                for arg_name, arg_value in traced_args.items():
                    field_span.set_tag(f"graphql.param.{arg_name}", arg_value)

            return next_(parent, info, **kwargs)


def opentracing_extension(*, arg_filter: Optional[ArgFilter] = None):
    """Return a ``partial`` of ``OpenTracingExtension`` bound to ``arg_filter``.

    Calling the returned object with no arguments creates a new
    ``OpenTracingExtension`` instance that uses the given filter.
    """
    extension_factory = partial(OpenTracingExtension, arg_filter=arg_filter)
    return extension_factory


def opentracing_extension_sync(*, arg_filter: Optional[ArgFilter] = None):
    """Return a ``partial`` of ``OpenTracingExtensionSync`` bound to ``arg_filter``.

    Calling the returned object with no arguments creates a new
    ``OpenTracingExtensionSync`` instance that uses the given filter.
    """
    extension_factory = partial(OpenTracingExtensionSync, arg_filter=arg_filter)
    return extension_factory