import time
from abc import abstractmethod
from functools import partial

from prometheus_client import Summary

from ..graph import GraphTransformer
from ..engine import pass_context, _do_pass_context
from ..sources.graph import CheckedExpr


_METRIC = None


def _get_default_metric():
    global _METRIC
    if _METRIC is None:
        _METRIC = Summary(
            'graph_field_time',
            'Graph field time (seconds)',
            ['graph', 'node', 'field'],
        )
    return _METRIC


def _func_field_names(func):
    fields_pos = 1 if _do_pass_context(func) else 0

    def wrapper(*args):
        return func([f.name for f in args[fields_pos]], *args)
    return wrapper


def _subquery_field_names(func):
    def wrapper(fields, *args):
        return func([f.name for _, f in fields], fields, *args)
    return wrapper


class GraphMetricsBase(GraphTransformer):
    root_name = 'Root'

    def __init__(self, name, *, metric=None):
        self._name = name
        self._metric = metric or _get_default_metric()
        self._node = None
        self._wrappers = {}

    @abstractmethod
    def field_wrapper(self, observe, func):
        raise NotImplementedError

    @abstractmethod
    def link_wrapper(self, observe, func):
        raise NotImplementedError

    @abstractmethod
    def subquery_wrapper(self, observe, subquery):
        raise NotImplementedError

    def _observe_fields(self, node_name):
        by_field = {}

        def observe(start_time, field_names):
            duration = time.perf_counter() - start_time
            for name in field_names:
                try:
                    field_metric = by_field[name]
                except KeyError:
                    field_metric = by_field[name] = \
                        self._metric.labels(self._name, node_name, name)
                field_metric.observe(duration)
        return observe

    def _wrap_field(self, node_name, func):
        observe = self._observe_fields(node_name)
        wrapper = self.field_wrapper(observe, func)
        if _do_pass_context(func):
            wrapper = pass_context(wrapper)
        wrapper = _func_field_names(wrapper)
        if _do_pass_context(func):
            wrapper = pass_context(wrapper)
        return wrapper

    def _wrap_link(self, node_name, link_name, func):
        observe = self._observe_fields(node_name)
        wrapper = self.link_wrapper(observe, func)
        if _do_pass_context(func):
            wrapper = pass_context(wrapper)
        wrapper = partial(wrapper, link_name)
        if _do_pass_context(func):
            wrapper = pass_context(wrapper)
        return wrapper

    def _wrap_subquery(self, node_name, subquery):
        observe = self._observe_fields(node_name)
        wrapper = self.subquery_wrapper(observe, subquery)
        wrapper = _subquery_field_names(wrapper)
        wrapper.__subquery__ = lambda: wrapper
        return wrapper

    def visit_node(self, obj):
        self._node = obj
        try:
            return super().visit_node(obj)
        finally:
            self._node = None

    def visit_field(self, obj):
        obj = super().visit_field(obj)
        node_name = self.root_name if self._node is None else self._node.name
        if isinstance(obj.func, CheckedExpr):
            func = obj.func.__subquery__
        else:
            func = obj.func

        wrapper = self._wrappers.get(func)
        if wrapper is None:
            if isinstance(obj.func, CheckedExpr):
                wrapper = self._wrappers[func] = self._wrap_subquery(
                    node_name, func,
                )
            else:
                wrapper = self._wrappers[func] = self._wrap_field(
                    node_name, func,
                )

        if isinstance(obj.func, CheckedExpr):
            obj.func = CheckedExpr(
                wrapper,
                obj.func.expr,
                obj.func.reqs,
                obj.func.proc,
            )
        else:
            obj.func = wrapper
        return obj

    def visit_link(self, obj):
        obj = super().visit_link(obj)
        node_name = self.root_name if self._node is None else self._node.name
        obj.func = self._wrap_link(node_name, obj.name, obj.func)
        return obj


class _SubqueryMixin:

    def subquery_wrapper(self, observe, subquery):
        def wrapper(field_names, *args):
            start_time = time.perf_counter()
            result_proc = subquery(*args)

            def proc_wrapper():
                result = result_proc()
                observe(start_time, field_names)
                return result
            return proc_wrapper
        return wrapper


class GraphMetrics(_SubqueryMixin, GraphMetricsBase):

    def field_wrapper(self, observe, func):
        def wrapper(field_names, *args):
            start_time = time.perf_counter()
            result = func(*args)
            observe(start_time, field_names)
            return result
        return wrapper

    def link_wrapper(self, observe, func):
        def wrapper(link_name, *args):
            start_time = time.perf_counter()
            result = func(*args)
            observe(start_time, [link_name])
            return result
        return wrapper


class AsyncGraphMetrics(_SubqueryMixin, GraphMetricsBase):

    def field_wrapper(self, observe, func):
        async def wrapper(field_names, *args):
            start_time = time.perf_counter()
            result = await func(*args)
            observe(start_time, field_names)
            return result
        return wrapper

    def link_wrapper(self, observe, func):
        async def wrapper(link_name, *args):
            start_time = time.perf_counter()
            result = await func(*args)
            observe(start_time, [link_name])
            return result
        return wrapper