import io
import ast
import sys
import traceback
import linecache
from importlib.abc import InspectLoader
from typing import Callable, Optional, Any
import IPython.core.ultratb as ultratb
from .user_namespace import BuiltInManager

_assign_nodes = (ast.AugAssign, ast.AnnAssign, ast.Assign)
_single_targets_nodes = (ast.AugAssign, ast.AnnAssign)


class IncompleteExecutionResultException(Exception):
    """The result was not completely filled with the appropriate data"""
    pass


class ExecutionResult:
    def __init__(self):
        self.stdout = None
        self.stderr = None
        self.has_output = False
        self.output = None
        self.target_id = None
        self.has_exception = False
        self.exception = None

    def is_complete(self):
        stdout_fulfilled = self.stdout is not None
        stderr_fulfilled = self.stderr is not None
        output_fulfilled = (
            self.output is not None and self.target_id is not None) if self.has_output else True
        exception_fulfilled = (self.exception is not None and len(
            self.exception) == 3) if self.has_exception else True
        return stdout_fulfilled and stderr_fulfilled and output_fulfilled and exception_fulfilled

    def capture_io(self, stdout, stderr):
        self.stdout = stdout
        self.stderr = stderr

    def displayhook(self, result=None):
        self.output = result
        if self.output is not None:
            self.has_output = True


class CapturedIOCtx(object):
    def __init__(self, container_func: Callable[[str, str], Any], capture_stdout=True,
                 capture_stderr=True):
        self.capture_stdout = capture_stdout
        self.capture_stderr = capture_stderr
        self.container_func: Callable[[str, str], Any] = container_func

    def __enter__(self):
        self.sys_stdout = sys.stdout
        self.sys_stderr = sys.stderr

        stdout = stderr = None
        if self.capture_stdout:
            stdout = sys.stdout = io.StringIO()
        if self.capture_stderr:
            stderr = sys.stderr = io.StringIO()

        return self.container_func(stdout, stderr)

    def __exit__(self, exc_type, exc_value, traceback):
        sys.stdout = self.sys_stdout
        sys.stderr = self.sys_stderr


class CapturedDisplayCtx(object):
    def __init__(self, capture_func: Callable[[Any], Any]):
        self.capture_func: Callable[[Any], Any] = capture_func
        self.sys_displayhook: Optional[Callable[[Any], Any]] = None

    def __enter__(self):
        self.sys_displayhook = sys.displayhook

        displayhook = sys.displayhook = self.capture_func

        return displayhook

    def __exit__(self, exc_type, exc_value, traceback):
        sys.displayhook = self.sys_displayhook


class Executor:
    """The executor handles executing snippets of code and managing the user namespace accordingly.

    It also can run asynchronous functions (coroutines)
    """

    def __init__(self, loader: InspectLoader, ns_manager: BuiltInManager):
        self.loader = loader
        self.excepthook = sys.excepthook
        self.InteractiveTB = ultratb.AutoFormattedTB(mode='Plain',
                                                     color_scheme='LightBG',
                                                     tb_offset=1,
                                                     debugger_cls=None)
        self.SyntaxTB = ultratb.SyntaxTB(color_scheme='NoColor')
        self.ns_manager = ns_manager

    async def run_coroutine(self, coroutine, variable_name, nohandle_exceptions=()):
        exec_result = ExecutionResult()
        exec_result.target_id = variable_name

        with CapturedIOCtx(exec_result.capture_io), CapturedDisplayCtx(exec_result.displayhook):
            exec_result.has_exception = True
            try:
                exec_result.output = await coroutine
            except nohandle_exceptions as e:
                raise e
            except BaseException as e:
                try:
                    etype, value, tb = sys.exc_info()
                    stb = self.InteractiveTB.structured_traceback(
                        etype, value, tb
                    )

                    if issubclass(etype, SyntaxError):
                        # If the error occurred when executing compiled code, we
                        # should provide full stacktrace
                        elist = traceback.extract_tb(tb)
                        stb = self.SyntaxTB.structured_traceback(
                            etype, value, elist)
                        print(
                            self.InteractiveTB.stb2text(stb),
                            file=sys.stderr)
                    else:
                        # Actually show the traceback
                        print(
                            self.InteractiveTB.stb2text(stb),
                            file=sys.stderr)
                except BaseException as e:
                    print(e)
            else:
                exec_result.has_exception = False
                self.update_ns({exec_result.target_id: exec_result.output})

        return exec_result

    def update_ns(self, *args, **kwargs):
        self.ns_manager.update(*args, **kwargs)

    def run_cell(self, code, name):
        linecache.lazycache(
            name, {
                '__name__': name, '__loader__': self.loader})

        exec_result = ExecutionResult()

        with CapturedIOCtx(exec_result.capture_io), CapturedDisplayCtx(exec_result.displayhook):
            code_ast = ast.parse(code, filename=name, mode='exec')
            run_failed, output_name = self._run_ast_nodes(code_ast.body, name)

            exec_result.has_exception = run_failed
            exec_result.target_id = output_name

        return exec_result

    def _run_ast_nodes(self, nodelist, name):
        output_name = None
        if not nodelist:
            return True, output_name

        if isinstance(nodelist[-1], _assign_nodes):
            asg = nodelist[-1]
            if isinstance(asg, ast.Assign) and len(asg.targets) == 1:
                target = asg.targets[0]
            elif isinstance(asg, _single_targets_nodes):
                target = asg.target
            else:
                target = None
            if isinstance(target, ast.Name):
                output_name = target.id
                nnode = ast.Expr(ast.Name(target.id, ast.Load()))
                ast.fix_missing_locations(nnode)
                nodelist.append(nnode)

        if isinstance(nodelist[-1], ast.Expr):
            to_run_exec, to_run_interactive = nodelist[:-1], nodelist[-1:]
        else:
            to_run_exec, to_run_interactive = nodelist, []

        try:
            mod = ast.Module(to_run_exec)
            code = compile(mod, name, 'exec')
            if self._run_code(code):
                return True, output_name

            for node in to_run_interactive:
                mod = ast.Interactive([node])
                code = compile(mod, name, 'single')
                if self._run_code(code):
                    return True, output_name
        except BaseException:
            return True, output_name

        return False, output_name

    def _run_code(self, code_obj):
        old_excepthook, sys.excepthook = sys.excepthook, self.excepthook
        outflag = True  # happens in more places, so it's easier as default
        try:
            try:
                exec(code_obj, self.ns_manager.global_ns,
                     self.ns_manager.global_ns)
            finally:
                # Reset our crash handler in place
                sys.excepthook = old_excepthook
        except BaseException:
            try:
                etype, value, tb = sys.exc_info()
                stb = self.InteractiveTB.structured_traceback(
                    etype, value, tb
                )

                if issubclass(etype, SyntaxError):
                    # If the error occurred when executing compiled code, we
                    # should provide full stacktrace
                    elist = traceback.extract_tb(tb)
                    stb = self.SyntaxTB.structured_traceback(
                        etype, value, elist)
                    print(self.InteractiveTB.stb2text(stb), file=sys.stderr)
                else:
                    # Actually show the traceback
                    print(self.InteractiveTB.stb2text(stb), file=sys.stderr)

            except BaseException as e:
                print(e)
        else:
            outflag = False

        return outflag