import functools
import collections
import subprocess
import multiprocessing
import multiprocessing.connection
import tempfile
import shutil
import os
import sys
import inspect
import importlib
import concurrent.futures
import contextlib
import time
import random
import string

from tblib import pickling_support
pickling_support.install()

import forge
import anyio
from async_generator import aclosing


@contextlib.contextmanager
def compile_temp_proto(*relative_proto_paths):
    """Compile ``.proto`` files in a temporary directory and import the results.

    Each path is resolved relative to the directory of the module whose code
    entered this context manager (found by walking two frames up from this
    generator and reading that frame's ``__file__`` global). All protos are
    copied into one temporary directory before any is compiled. Each is then
    compiled with ``grpc_tools.protoc`` (``--python_out``, ``--purerpc_out`` and
    ``--grpc_python_out``), and the generated modules are imported. The
    temporary directory is prepended to ``sys.path`` for the duration of the
    ``with`` block and removed from it afterwards.

    Yields:
        A list containing, for each proto in argument order, its
        ``<name>_pb2``, ``<name>_pb2_grpc`` and ``<name>_grpc`` modules.

    Raises:
        subprocess.CalledProcessError: if ``protoc`` exits with an error.
    """
    modules = []
    with tempfile.TemporaryDirectory() as temp_dir:
        sys.path.insert(0, temp_dir)
        try:
            if relative_proto_paths:
                # Frames: this generator <- contextmanager.__enter__ <- the ``with`` caller.
                caller_globals = inspect.currentframe().f_back.f_back.f_globals
                caller_dir = os.path.dirname(caller_globals['__file__'])
            # Copy every proto before compiling any, so protos that import each other resolve.
            for relative_proto_path in relative_proto_paths:
                proto_path = os.path.join(caller_dir, relative_proto_path)
                proto_temp_path = os.path.join(temp_dir, os.path.basename(proto_path))
                shutil.copyfile(proto_path, proto_temp_path)
            for relative_proto_path in relative_proto_paths:
                proto_filename = os.path.basename(relative_proto_path)
                proto_temp_path = os.path.join(temp_dir, proto_filename)
                cmdline = [sys.executable, '-m', 'grpc_tools.protoc',
                           '--python_out=.', '--purerpc_out=.', '--grpc_python_out=.',
                           '-I' + temp_dir, proto_temp_path]
                subprocess.check_call(cmdline, cwd=temp_dir)
                # protoc just created files in a directory that is already on sys.path.
                # The path finders may hold a stale listing of it (e.g. from importing
                # the previous proto's modules), so refresh them before importing.
                importlib.invalidate_caches()

                for suffix in ("_pb2", "_pb2_grpc", "_grpc"):
                    module_name = proto_filename.replace(".proto", suffix)
                    modules.append(importlib.import_module(module_name))
            yield modules
        finally:
            sys.path.remove(temp_dir)


# Message sent from a worker process to its parent over a pipe: either one value
# yielded by the worker's generator (``exc_info`` is None) or the ``sys.exc_info()``
# triple of the exception the worker raised (``result`` is None).
_WrappedResult = collections.namedtuple("_WrappedResult", ["result", "exc_info"])


def _wrap_gen_in_process(conn: multiprocessing.connection.Connection):
    """Return a decorator that streams a generator function's output over ``conn``.

    The decorated function runs the generator to completion and sends each
    yielded value as ``_WrappedResult(result=value, exc_info=None)``. If the
    generator raises anything (including ``BaseException`` subclasses), it
    instead sends one ``_WrappedResult(result=None, exc_info=...)``. ``conn``
    is closed afterwards in every case.

    If the original exception cannot be sent (``conn.send`` pickles its
    argument, and some exceptions or tracebacks are not picklable), a
    ``RuntimeError`` carrying the formatted traceback is sent instead. This way
    the receiving side still learns that the generator failed.
    """
    def decorator(gen):
        @functools.wraps(gen)
        def new_func(*args, **kwargs):
            try:
                for elem in gen(*args, **kwargs):
                    conn.send(_WrappedResult(result=elem, exc_info=None))
            except BaseException:  # deliberately forward everything to the receiver
                exc_info = sys.exc_info()
                try:
                    conn.send(_WrappedResult(result=None, exc_info=exc_info))
                except Exception:
                    import traceback
                    message = "".join(traceback.format_exception(*exc_info))
                    fallback = RuntimeError(message)
                    conn.send(_WrappedResult(result=None,
                                             exc_info=(RuntimeError, fallback, None)))
            finally:
                conn.close()
        return new_func
    return decorator


async def async_iterable_to_list(async_iterable):
    """Collect every item of an async generator into a list.

    The iterable is wrapped in ``aclosing`` so that its ``aclose()`` is awaited
    when the block exits, even if iteration stops with an exception.
    """
    async with aclosing(async_iterable) as closing_iter:
        return [item async for item in closing_iter]


def random_payload(min_size=1000, max_size=100000):
    """Return a random string of ASCII letters.

    The length is drawn with ``random.randint(min_size, max_size)`` (both ends
    inclusive). Each character is then picked independently from
    ``string.ascii_letters``. Both steps use the global ``random`` state.
    """
    length = random.randint(min_size, max_size)
    alphabet = string.ascii_letters
    chars = [random.choice(alphabet) for _ in range(length)]
    return "".join(chars)


@contextlib.contextmanager
def _run_context_manager_generator_in_process(cm_gen):
    """Run generator function ``cm_gen`` in a child process; yield its first value.

    The child's output is streamed back over a one-way pipe via
    ``_wrap_gen_in_process``. Entering the context blocks until the first
    message arrives. If it carries an exception, that exception is re-raised
    here; otherwise its value is yielded. On exit, an already-pending message
    carrying an exception is re-raised (the check uses a non-blocking
    ``poll()``). The child is then terminated and joined, and the pipe closed.

    Raises:
        RuntimeError: if the child exits before sending anything.
    """
    parent_conn, child_conn = multiprocessing.Pipe(duplex=False)
    target_fn = _wrap_gen_in_process(child_conn)(cm_gen)

    process = multiprocessing.Process(target=target_fn)
    process.start()
    # The child now holds its own copy of the write end. Dropping ours means
    # recv() raises EOFError if the child dies without reporting, instead of
    # blocking forever.
    child_conn.close()
    try:
        try:
            wrapped_result = parent_conn.recv()
        except EOFError:
            raise RuntimeError("child process exited before producing a value") from None
        if wrapped_result.exc_info is not None:
            raise wrapped_result.exc_info[0].with_traceback(*wrapped_result.exc_info[1:])
        yield wrapped_result.result
    finally:
        try:
            if parent_conn.poll():
                try:
                    exc_info = parent_conn.recv().exc_info
                except EOFError:
                    # The child closed its end without sending anything more.
                    exc_info = None
                if exc_info is not None:
                    raise exc_info[0].with_traceback(*exc_info[1:])
        finally:
            process.terminate()
            process.join()
            parent_conn.close()


def run_purerpc_service_in_process(service):
    """Serve ``service`` with a purerpc server running in a child process.

    Returns the context manager built by
    ``_run_context_manager_generator_in_process``. Entering it yields the port
    number of the server's listening socket.
    """
    def target_fn():
        import purerpc
        # port=0 presumably lets the OS choose a free port; the real one is read back below.
        server = purerpc.Server(port=0)
        server.add_service(service)
        # Create the socket before yielding so the reported port is already bound.
        listen_socket = server._create_socket_and_listen()
        yield listen_socket.getsockname()[1]
        anyio.run(server._run_async_server, listen_socket)
    return _run_context_manager_generator_in_process(target_fn)


def run_grpc_service_in_process(add_handler_fn):
    """Serve gRPC handlers with a ``grpc`` server running in a child process.

    ``add_handler_fn`` is called with the server to register handlers. The
    server uses a single-thread executor and an insecure port on ``[::]:0``.
    Returns the context manager built by
    ``_run_context_manager_generator_in_process``. Entering it yields the port
    reported by ``add_insecure_port``.
    """
    def serve():
        import grpc
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        grpc_server = grpc.server(executor)
        bound_port = grpc_server.add_insecure_port('[::]:0')
        add_handler_fn(grpc_server)
        grpc_server.start()
        yield bound_port
        # Block forever so the server keeps running until this process is terminated.
        while 1:
            time.sleep(60)

    return _run_context_manager_generator_in_process(serve)


def run_tests_in_workers(*, target, num_workers):
    """Call ``target()`` concurrently in ``num_workers`` child processes.

    All workers report over one shared one-way pipe. The parent waits for one
    message per worker and re-raises the first exception a worker reports.
    Every worker that was started is joined before returning.

    Raises:
        Whatever ``target`` raised in the first worker to report a failure.
        RuntimeError: if workers exit without reporting a result.
    """
    parent_conn, child_conn = multiprocessing.Pipe(duplex=False)

    @_wrap_gen_in_process(child_conn)
    def target_fn():
        target()
        yield

    processes = [multiprocessing.Process(target=target_fn) for _ in range(num_workers)]
    started = []
    try:
        for process in processes:
            process.start()
            started.append(process)
        # Drop the parent's copy of the write end. Once every worker has closed
        # its copy, recv() then raises EOFError instead of hanging forever on
        # a worker that died without reporting.
        child_conn.close()
        for _ in range(num_workers):
            try:
                wrapped_result = parent_conn.recv()
            except EOFError:
                raise RuntimeError("worker process exited without reporting a result") from None
            if wrapped_result.exc_info is not None:
                raise wrapped_result.exc_info[0].with_traceback(*wrapped_result.exc_info[1:])
    finally:
        child_conn.close()  # no-op if already closed above
        parent_conn.close()
        # Only join processes that actually started; join() on an unstarted one raises.
        for process in started:
            process.join()


def async_test(corofunc):
    """Turn a coroutine test function into a synchronous one.

    The returned function accepts keyword arguments only (e.g. injected
    fixtures), binds them to ``corofunc`` and drives it with ``anyio.run``.
    ``functools.wraps`` keeps the original name and ``__wrapped__``.

    Raises:
        TypeError: if ``corofunc`` is not a coroutine function.
    """
    if inspect.iscoroutinefunction(corofunc):
        @functools.wraps(corofunc)
        def func(**kwargs):
            bound = functools.partial(corofunc, **kwargs)
            return anyio.run(bound)
        return func
    raise TypeError("Expected coroutine function")


def grpc_client_parallelize(num_workers):
    """Decorator factory: run a synchronous test body in ``num_workers`` processes.

    Each call to the decorated function runs it, with the same arguments, in
    ``num_workers`` worker processes via ``run_tests_in_workers``. The wrapper
    is marked with ``__parallelized__ = True``, and ``grpc_channel`` refuses to
    wrap a function carrying that marker.
    """
    def decorator(func):
        @functools.wraps(func)
        def new_func(*args, **kwargs):
            run_tests_in_workers(
                target=functools.partial(func, *args, **kwargs),
                num_workers=num_workers,
            )

        new_func.__parallelized__ = True
        return new_func
    return decorator


def purerpc_client_parallelize(num_tasks):
    """Decorator factory: run a coroutine test body as ``num_tasks`` concurrent tasks.

    The decorated coroutine function accepts keyword arguments only. It spawns
    ``num_tasks`` copies of the original coroutine, all with the same
    arguments, in one anyio task group, and returns when the group finishes.

    Raises:
        TypeError: at decoration time if the target is not a coroutine function.
    """
    def decorator(corofunc):
        if not inspect.iscoroutinefunction(corofunc):
            raise TypeError("Expected coroutine function")

        @functools.wraps(corofunc)
        async def new_corofunc(**kwargs):
            task = functools.partial(corofunc, **kwargs)
            async with anyio.create_task_group() as task_group:
                for _ in range(num_tasks):
                    await task_group.spawn(task)
        return new_corofunc
    return decorator


def grpc_channel(port_fixture_name, channel_arg_name="channel"):
    """Decorator factory: inject an insecure gRPC channel into a synchronous test.

    In the decorated function's public signature (copied via ``forge``), the
    ``channel_arg_name`` parameter is renamed to ``port_fixture_name``. That
    name is presumably a pytest fixture supplying a port. At call time a
    channel to ``127.0.0.1:<port>`` is opened with ``grpc.insecure_channel``
    and passed to the test under ``channel_arg_name``. It is closed when the
    test returns.

    Raises:
        TypeError: if the test was already wrapped by ``grpc_client_parallelize``.
    """
    def decorator(func):
        if getattr(func, "__parallelized__", False):
            raise TypeError("Cannot pass gRPC channel to already parallelized test, grpc_client_parallelize should "
                            "be the last decorator in chain")

        @forge.compose(
            forge.copy(func),
            forge.modify(channel_arg_name, name=port_fixture_name, interface_name="port_fixture_value"),
        )
        def new_func(*, port_fixture_value, **kwargs):
            import grpc
            with grpc.insecure_channel('127.0.0.1:{}'.format(port_fixture_value)) as channel:
                # Pass the channel under the configured parameter name, not a hard-coded
                # ``channel`` keyword, so a custom ``channel_arg_name`` works.
                func(**kwargs, **{channel_arg_name: channel})

        return new_func
    return decorator


def purerpc_channel(port_fixture_name, channel_arg_name="channel"):
    """Decorator factory: inject an insecure purerpc channel into a coroutine test.

    In the decorated coroutine's public signature (copied via ``forge``), the
    ``channel_arg_name`` parameter is renamed to ``port_fixture_name``. That
    name is presumably a pytest fixture supplying a port. At call time a
    channel to ``127.0.0.1`` on that port is opened with
    ``purerpc.insecure_channel`` and passed to the test under
    ``channel_arg_name``. The channel's ``async with`` block exits when the
    test finishes.

    Raises:
        TypeError: if the decorated object is not a coroutine function.
    """
    def decorator(corofunc):
        if not inspect.iscoroutinefunction(corofunc):
            raise TypeError("Expected coroutine function")

        @forge.compose(
            forge.copy(corofunc),
            forge.modify(channel_arg_name, name=port_fixture_name, interface_name="port_fixture_value"),
        )
        async def new_corofunc(*, port_fixture_value, **kwargs):
            import purerpc
            async with purerpc.insecure_channel("127.0.0.1", port_fixture_value) as channel:
                # Pass the channel under the configured parameter name, not a hard-coded
                # ``channel`` keyword, so a custom ``channel_arg_name`` works.
                await corofunc(**kwargs, **{channel_arg_name: channel})

        return new_corofunc
    return decorator