import builtins
import contextlib
import copy
import inspect
import functools
import sys
import traceback
import types
from typing import *

from crosshair.condition_parser import Conditions, get_fn_conditions, ClassConditions, get_class_conditions, fn_globals
from crosshair.util import IdentityWrapper, AttributeHolder


class PreconditionFailed(BaseException):
    """
    Raised when a precondition evaluates falsy before the guarded call runs.

    Derives from BaseException rather than Exception — presumably so that
    ``except Exception`` handlers in the code under check do not swallow it.
    """


class PostconditionFailed(BaseException):
    """
    Raised when a postcondition evaluates falsy after the guarded call, or
    when a postcondition names a mutable argument the function does not have.

    Derives from BaseException rather than Exception — presumably so that
    ``except Exception`` handlers in the code under check do not swallow it.
    """


def is_singledispatcher(fn: Callable) -> bool:
    """
    Report whether ``fn`` looks like a ``functools.singledispatch`` function.

    The check is duck-typed: any object whose ``registry`` attribute is a
    Mapping qualifies; a missing attribute means "no".
    """
    registry = getattr(fn, 'registry', None)
    return isinstance(registry, Mapping)


def EnforcementWrapper(fn: Callable, conditions: Conditions, enforced: 'EnforcedConditions') -> Callable:
    """
    Build a callable that checks ``conditions`` around each call to ``fn``.

    The returned wrapper calls straight through to ``fn`` in two cases:
    when enforcement is disabled (``enforced.fns_enforcing is None``), and
    when ``fn`` is already being enforced. The second case means calls made
    while evaluating ``fn``'s own conditions are not re-checked.

    Otherwise, on each call, it:

    1. binds the arguments to ``conditions.sig`` and applies defaults;
    2. raises PostconditionFailed if ``conditions.mutable_args`` names a
       parameter that the signature does not have;
    3. evaluates each precondition against ``fn``'s globals overlaid with
       the bound arguments, raising PreconditionFailed on the first falsy
       one;
    4. calls ``fn``;
    5. evaluates each postcondition with ``__return__`` and ``_`` bound to
       the result and ``__old__`` holding shallow copies of the arguments
       taken before the call, raising PostconditionFailed on the first
       falsy one.

    Condition expressions are run with ``eval``. They are presumably parsed
    from the checked code's own contracts rather than from external input.

    :param fn: the callable to guard.
    :param conditions: parsed conditions for ``fn``; ``sig``, ``pre``,
        ``post`` and ``mutable_args`` are read.
    :param enforced: the EnforcedConditions instance tracking which
        functions are currently being enforced.
    :return: the wrapper function.
    """
    signature = conditions.sig

    def wrapper(*a, **kw):
        fns_enforcing = enforced.fns_enforcing
        if fns_enforcing is None or fn in fns_enforcing:
            return fn(*a, **kw)
        bound_args = signature.bind(*a, **kw)
        bound_args.apply_defaults()
        arguments = bound_args.arguments

        mutable_args = conditions.mutable_args
        if mutable_args is not None:
            # Sorted so the error message does not depend on set ordering.
            unknown = sorted(set(mutable_args).difference(arguments))
            if unknown:
                raise PostconditionFailed('Unrecognized mutable argument(s) in postcondition: "{}"'.format(
                    ','.join(unknown)))

        # __old__ is only visible to postconditions, so skip the snapshot
        # when there are none. copy.copy() raises TypeError for objects such
        # as generators, open files and locks. Without this guard, a
        # precondition-only function could not be called with such arguments.
        if conditions.post:
            old = {name: copy.copy(value) for name, value in arguments.items()}
        else:
            old = {}

        with enforced.currently_enforcing(fn):
            if conditions.pre:
                # One namespace shared by all preconditions (as for the
                # postconditions below): fn's globals overlaid with the
                # bound arguments.
                pre_namespace = {**fn_globals(fn), **arguments}
                for precondition in conditions.pre:
                    if not eval(precondition.expr, pre_namespace):
                        raise PreconditionFailed(
                            f'Precondition "{precondition.expr_source}" was not satisfied '
                            f'before calling "{fn.__name__}"')
        ret = fn(*a, **kw)
        with enforced.currently_enforcing(fn):
            lcls = {**arguments, '__return__': ret,
                    '_': ret, '__old__': AttributeHolder(old)}
            post_namespace = {**fn_globals(fn), **lcls}
            for postcondition in conditions.post:
                # Postconditions whose ``expr`` is empty/None are skipped.
                if postcondition.expr and not eval(postcondition.expr, post_namespace):
                    raise PostconditionFailed('Postcondition failed at {}:{}'.format(
                        postcondition.filename, postcondition.line))
        return ret
    return wrapper


class EnforcedConditions:
    """
    Context manager that installs contract-enforcing wrappers into namespaces.

    ``__enter__`` walks each mapping in ``envs``:

    - functions that have conditions are replaced by enforcement wrappers;
    - singledispatch functions are rebuilt with wrapped overloads, but only
      when at least one overload actually changes;
    - classes with conditions have their methods wrapped in place.

    ``__exit__`` restores the originals.

    :param envs: mutable mappings to patch (presumably module ``__dict__``s;
        confirm against callers). Only ``copy``, ``items`` and ``update``
        are used on them.
    :param interceptor: applied to each original function before it is
        wrapped; defaults to the identity function.
    """

    def __init__(self, *envs, interceptor=lambda x: x):
        self.envs = envs
        self.interceptor = interceptor
        # Functions whose conditions are being evaluated right now; calls to
        # them pass straight through their wrappers. None means enforcement
        # is disabled altogether.
        self.fns_enforcing: Optional[Set[Callable]] = set()
        # original function -> its wrapper (or the function itself when it
        # has no conditions)
        self.wrapper_map: Dict[Callable, Callable] = {}
        # identity of a wrapper -> the original it replaces (functions with
        # no conditions are recorded as mapping to themselves)
        self.original_map: Dict[IdentityWrapper[Callable], Callable] = {}

    def _wrap_class(self, cls: type, class_conditions: ClassConditions) -> None:
        """Replace, in place on ``cls``, each method that has conditions with its wrapper."""
        method_conditions = dict(class_conditions.methods)
        # NOTE(review): getmembers() also yields inherited methods, and
        # wrapper_map is keyed only by the function. A method inherited by
        # classes with different conditions therefore reuses whichever
        # wrapper was built first.
        # NOTE(review): a staticmethod is also seen here as a plain function.
        # If one is ever reported with conditions, setattr() would turn it
        # into an instance method. Confirm both points against condition_parser.
        for method_name, method in list(inspect.getmembers(cls, inspect.isfunction)):
            conditions = method_conditions.get(method_name)
            if conditions is None:
                continue
            wrapper = self._wrap_fn(method, conditions)
            setattr(cls, method_name, wrapper)

    def _transform_singledispatch(self, fn, transformer):
        """
        Return a singledispatch function whose overloads are ``transformer``
        applied to each overload registered on ``fn``.

        ``fn`` itself is returned unchanged when its registry is empty, or
        when ``transformer`` returns every overload as-is. This way,
        dispatchers without contracts keep their identity instead of being
        rebuilt (and silently replaced in the env) on every enter/exit.
        """
        overloads = list(fn.registry.items())
        if not overloads:
            return fn
        originals = [impl for _, impl in overloads]
        replacements = [transformer(impl) for impl in originals]
        if all(new is old for new, old in zip(replacements, originals)):
            return fn
        # For dispatchers built by functools.singledispatch, the first
        # registry entry is the default implementation (registered for
        # ``object``).
        wrapped = functools.singledispatch(replacements[0])
        for (overload_typ, _), overload_impl in zip(overloads[1:], replacements[1:]):
            wrapped.register(overload_typ)(overload_impl)
        return wrapped

    def is_enforcement_wrapper(self, value):
        """
        Return whether ``value`` has been recorded by ``_wrap_fn`` as a wrapper.

        NOTE(review): a function without conditions is recorded as its own
        "wrapper", so once seen, this is True for it as well.
        """
        return IdentityWrapper(value) in self.original_map

    @contextlib.contextmanager
    def currently_enforcing(self, fn: Callable):
        """Mark ``fn`` as being enforced for the duration of the block (no-op when disabled)."""
        enforcing = self.fns_enforcing
        if enforcing is None:
            yield None
        else:
            enforcing.add(fn)
            try:
                yield None
            finally:
                # Remove from the same set we added to, even if
                # fns_enforcing was swapped while the block ran.
                enforcing.remove(fn)

    @contextlib.contextmanager
    def disabled_enforcement(self):
        """Turn enforcement off for the block; it must currently be on."""
        prev = self.fns_enforcing
        assert prev is not None
        self.fns_enforcing = None
        try:
            yield None
        finally:
            self.fns_enforcing = prev

    @contextlib.contextmanager
    def enabled_enforcement(self):
        """Turn enforcement on (with nothing in progress) for the block; it must currently be off."""
        prev = self.fns_enforcing
        assert prev is None
        self.fns_enforcing = set()
        try:
            yield None
        finally:
            self.fns_enforcing = prev

    def __enter__(self):
        """Install enforcement wrappers into every env and return ``self``."""
        next_envs = [env.copy() for env in self.envs]
        for env, next_env in zip(self.envs, next_envs):
            for (k, v) in env.items():
                if isinstance(v, (types.FunctionType, types.BuiltinFunctionType)):
                    if is_singledispatcher(v):
                        wrapper = self._transform_singledispatch(
                            v, self._wrap_fn)
                    else:
                        wrapper = self._wrap_fn(v)
                        if wrapper is v:
                            continue
                    next_env[k] = wrapper
                elif isinstance(v, type):
                    # Classes are patched in place rather than replaced.
                    conditions = get_class_conditions(v)
                    if conditions.has_any():
                        self._wrap_class(v, conditions)
        # Replacements are collected above and applied after the walk.
        for env, next_env in zip(self.envs, next_envs):
            env.update(next_env)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the original objects in every env; never suppresses exceptions."""
        next_envs = [env.copy() for env in self.envs]
        for env, next_env in zip(self.envs, next_envs):
            for (k, v) in list(env.items()):
                next_env[k] = self._unwrap(v)
        for env, next_env in zip(self.envs, next_envs):
            env.update(next_env)
        return False

    def _unwrap(self, value):
        """
        Undo enforcement for one env value.

        Recorded wrappers map back to their originals. Singledispatch
        functions are rebuilt with unwrapped overloads, and classes have
        their methods restored in place. Anything else is returned unchanged.
        """
        if self.is_enforcement_wrapper(value):
            return self.original_map[IdentityWrapper(value)]
        elif is_singledispatcher(value):
            return self._transform_singledispatch(value, self._unwrap)
        elif isinstance(value, type):
            self._unwrap_class(value)
        return value

    def _unwrap_class(self, cls: type):
        """Put the original function back for every recorded wrapper visible on ``cls``."""
        # NOTE(review): getmembers() includes inherited members, so this can
        # leave an own attribute on a subclass that previously only
        # inherited the method.
        for method_name, method in list(inspect.getmembers(cls, inspect.isfunction)):
            if self.is_enforcement_wrapper(method):
                setattr(cls, method_name,
                        self.original_map[IdentityWrapper(method)])

    def _wrap_fn(self, fn: Callable, conditions: Optional[Conditions] = None) -> Callable:
        """
        Return the (cached) enforcement wrapper for ``fn``.

        Returns ``fn`` itself when it has no conditions or is already a
        recorded wrapper.

        :param fn: function to wrap.
        :param conditions: conditions to enforce; looked up with
            ``get_fn_conditions`` when omitted (or falsy).
        :return: the wrapper, or ``fn`` unchanged.
        """
        wrapper = self.wrapper_map.get(fn)
        if wrapper is not None:
            return wrapper
        if self.is_enforcement_wrapper(fn):
            return fn

        conditions = conditions or get_fn_conditions(fn)
        if conditions and conditions.has_any():
            wrapper = EnforcementWrapper(
                self.interceptor(fn), conditions, self)
            functools.update_wrapper(wrapper, fn)
        else:
            wrapper = fn
        self.wrapper_map[fn] = wrapper
        self.original_map[IdentityWrapper(wrapper)] = fn
        return wrapper