import asyncio
from concurrent.futures import ThreadPoolExecutor, Executor
from itertools import count
from threading import current_thread
from unittest.mock import patch

import pytest
from async_generator import yield_

from asphalt.core import (
    ResourceConflict, ResourceNotFound, Context, context_teardown, callable_name, executor)
from asphalt.core.context import ResourceContainer, TeardownError


@pytest.fixture
def context(event_loop):
    """
    Provide a fresh root context for a test.

    Requesting ``event_loop`` makes pytest set that fixture up before the context is created.

    """
    root_context = Context()
    return root_context


@pytest.fixture
def special_executor(context):
    """
    Provide a single-worker thread pool published in the context under the name "special".

    The pool is registered as an ``Executor`` resource and shut down after the test has run.

    """
    # Named "pool" so it does not shadow the ``executor`` decorator imported from asphalt.core
    pool = ThreadPoolExecutor(max_workers=1)
    context.add_resource(pool, 'special', types=[Executor])
    yield pool
    pool.shutdown()


class TestResourceContainer:
    """Tests for value generation and the repr() output of ResourceContainer."""

    @pytest.mark.parametrize('thread', [False, True], ids=['eventloop', 'worker'])
    @pytest.mark.parametrize('context_attr', [None, 'attrname'], ids=['no_attr', 'has_attr'])
    @pytest.mark.asyncio
    async def test_generate_value(self, thread, context_attr):
        """
        Test that generate_value() returns the factory's value and publishes it in the context,
        whether it is called directly or through call_in_executor().

        """
        container = ResourceContainer(lambda ctx: 'foo', (str,), 'default', context_attr, True)
        target_context = Context()
        generated = (await target_context.call_in_executor(container.generate_value,
                                                           target_context)
                     if thread else container.generate_value(target_context))

        assert generated == 'foo'
        assert target_context.get_resource(str) == 'foo'
        if not context_attr:
            return

        # The value must also be reachable through the requested context attribute
        assert getattr(target_context, context_attr) == 'foo'

    def test_repr(self):
        """Test the repr() of a container holding a plain value."""
        expected = ("ResourceContainer(value='foo', types=[str], name='default', "
                    "context_attr='attrname')")
        container = ResourceContainer('foo', (str,), 'default', 'attrname', False)
        assert repr(container) == expected

    def test_repr_factory(self):
        """Test the repr() of a container holding a factory callback."""
        # The lambda must be created right here: its qualified name is part of the expected repr
        container = ResourceContainer(lambda ctx: 'foo', (str,), 'default', 'attrname', True)
        expected = (
            "ResourceContainer(factory=test_context.TestResourceContainer.test_repr_factory."
            "<locals>.<lambda>, types=[str], name='default', context_attr='attrname')")
        assert repr(container) == expected


class TestContext:
    """Tests for Context: hierarchy, closing, resources and the thread/event loop helpers."""

    def test_parent(self):
        """Test that the parent property points to the parent context instance, if any."""
        parent = Context()
        child = Context(parent)
        assert parent.parent is None
        assert child.parent is parent

    @pytest.mark.parametrize('exception', [None, Exception('foo')],
                             ids=['noexception', 'exception'])
    @pytest.mark.asyncio
    async def test_close(self, context, exception):
        """
        Test that teardown callbacks are called in reverse order when a context is closed.

        """
        def callback(exception=None):
            called_functions.append((callback, exception))

        async def async_callback(exception=None):
            called_functions.append((async_callback, exception))

        called_functions = []
        context.add_teardown_callback(callback, pass_exception=True)
        context.add_teardown_callback(async_callback, pass_exception=True)
        await context.close(exception)

        assert called_functions == [(async_callback, exception), (callback, exception)]

    @pytest.mark.asyncio
    async def test_teardown_callback_exception(self, context):
        """
        Test that all callbacks are called even when some teardown callbacks raise exceptions,
        and that a TeardownError is raised in such a case, containing the exception objects.

        """
        def callback1():
            items.append(1)

        def callback2():
            raise Exception('foo')

        context.add_teardown_callback(callback1)
        context.add_teardown_callback(callback2)
        context.add_teardown_callback(callback1)
        context.add_teardown_callback(callback2)
        items = []
        with pytest.raises(TeardownError) as exc:
            await context.close()

        assert 'foo' in str(exc.value)
        assert items == [1, 1]
        assert len(exc.value.exceptions) == 2

    @pytest.mark.asyncio
    async def test_close_closed(self, context):
        """Test that closing an already closed context raises a RuntimeError."""
        assert not context.closed
        await context.close()
        assert context.closed

        with pytest.raises(RuntimeError) as exc:
            await context.close()

        exc.match('this context has already been closed')

    def test_contextmanager_exception(self, context, event_loop):
        """Test that "with context:" calls close() with the exception raised in the block."""
        close_future = event_loop.create_future()
        close_future.set_result(None)
        exception = Exception('foo')
        with patch.object(context, 'close', return_value=close_future) as close:
            with pytest.raises(Exception) as exc:
                with context:
                    raise exception

        close.assert_called_once_with(exception)
        assert exc.value is exception

    @pytest.mark.asyncio
    async def test_async_contextmanager_exception(self, event_loop, context):
        """Test that "async with context:" calls close() with the exception raised in the block."""
        close_future = event_loop.create_future()
        close_future.set_result(None)
        exception = Exception('foo')
        with patch.object(context, 'close', return_value=close_future) as close:
            with pytest.raises(Exception) as exc:
                async with context:
                    raise exception

        close.assert_called_once_with(exception)
        assert exc.value is exception

    @pytest.mark.parametrize('types', [int, (int,), ()], ids=['type', 'tuple', 'empty'])
    @pytest.mark.asyncio
    async def test_add_resource(self, context, event_loop, types):
        """Test that a resource is properly added in the context and listeners are notified."""
        event_loop.call_soon(context.add_resource, 6, 'foo', 'foo.bar', types)
        event = await context.resource_added.wait_event()

        assert event.resource_types == (int,)
        assert event.resource_name == 'foo'
        assert not event.is_factory
        assert context.get_resource(int, 'foo') == 6

    @pytest.mark.asyncio
    async def test_add_resource_name_conflict(self, context):
        """Test that adding a resource won't replace any existing resources."""
        context.add_resource(5, 'foo')
        with pytest.raises(ResourceConflict) as exc:
            context.add_resource(4, 'foo')

        exc.match("this context already contains a resource of type int using the name 'foo'")

    @pytest.mark.asyncio
    async def test_add_resource_none_value(self, context):
        """Test that None is not accepted as a resource value."""
        exc = pytest.raises(ValueError, context.add_resource, None)
        exc.match('"value" must not be None')

    @pytest.mark.asyncio
    async def test_add_resource_context_attr(self, context):
        """Test that when resources are added, they are also set as properties of the context."""
        context.add_resource(1, context_attr='foo')
        assert context.foo == 1

    def test_add_resource_context_attr_conflict(self, context):
        """
        Test that the context won't allow adding a resource with an attribute name that conflicts
        with an existing attribute.

        """
        context.a = 2
        with pytest.raises(ResourceConflict) as exc:
            context.add_resource(2, context_attr='a')

        exc.match("this context already has an attribute 'a'")
        assert context.get_resource(int) is None

    @pytest.mark.asyncio
    async def test_add_resource_type_conflict(self, context):
        """Test that a second resource of the same type and default name is rejected."""
        context.add_resource(5)
        with pytest.raises(ResourceConflict) as exc:
            # add_resource() is a plain synchronous call, so it must not be awaited
            context.add_resource(6)

        exc.match("this context already contains a resource of type int using the name 'default'")

    @pytest.mark.parametrize('name', ['a.b', 'a:b', 'a b'], ids=['dot', 'colon', 'space'])
    @pytest.mark.asyncio
    async def test_add_resource_bad_name(self, context, name):
        """Test that resource names containing a dot, a colon or a space are rejected."""
        with pytest.raises(ValueError) as exc:
            context.add_resource(1, name)

        exc.match('"name" must be a nonempty string consisting only of alphanumeric characters '
                  'and underscores')

    @pytest.mark.asyncio
    async def test_add_resource_factory(self, context):
        """Test that resources factory callbacks are only called once for each context."""
        def factory(ctx):
            assert ctx is context
            return next(counter)

        counter = count(1)
        context.add_resource_factory(factory, int, context_attr='foo')
        assert context.foo == 1
        assert context.foo == 1
        assert context.__dict__['foo'] == 1

    @pytest.mark.parametrize('name', ['a.b', 'a:b', 'a b'], ids=['dot', 'colon', 'space'])
    @pytest.mark.asyncio
    async def test_add_resource_factory_bad_name(self, context, name):
        """Test that factory names containing a dot, a colon or a space are rejected."""
        with pytest.raises(ValueError) as exc:
            context.add_resource_factory(lambda ctx: 1, int, name)

        exc.match('"name" must be a nonempty string consisting only of alphanumeric characters '
                  'and underscores')

    @pytest.mark.asyncio
    async def test_add_resource_factory_coroutine_callback(self, context):
        """Test that a coroutine function is rejected as a resource factory callback."""
        async def factory(ctx):
            return 1

        with pytest.raises(TypeError) as exc:
            context.add_resource_factory(factory, int)

        exc.match('"factory_callback" must not be a coroutine function')

    @pytest.mark.asyncio
    async def test_add_resource_factory_empty_types(self, context):
        """Test that a resource factory cannot be added with an empty collection of types."""
        with pytest.raises(ValueError) as exc:
            context.add_resource_factory(lambda ctx: 1, ())

        exc.match('"types" must not be empty')

    @pytest.mark.asyncio
    async def test_add_resource_factory_context_attr_conflict(self, context):
        """Test that two resource factories cannot claim the same context attribute."""
        context.add_resource_factory(lambda ctx: None, str, context_attr='foo')
        with pytest.raises(ResourceConflict) as exc:
            # add_resource_factory() is a plain synchronous call, so it must not be awaited
            context.add_resource_factory(lambda ctx: None, str, context_attr='foo')

        exc.match(
            "this context already contains a resource factory for the context attribute 'foo'")

    @pytest.mark.asyncio
    async def test_add_resource_factory_type_conflict(self, context):
        """Test that two resource factories cannot be registered for the same type."""
        context.add_resource_factory(lambda ctx: None, (str, int))
        with pytest.raises(ResourceConflict) as exc:
            # add_resource_factory() is a plain synchronous call, so it must not be awaited
            context.add_resource_factory(lambda ctx: None, int)

        exc.match('this context already contains a resource factory for the type int')

    @pytest.mark.asyncio
    async def test_add_resource_factory_no_inherit(self, context):
        """
        Test that a subcontext gets its own version of a factory-generated resource even if a
        parent context has one already.

        """
        context.add_resource_factory(id, int, context_attr='foo')
        subcontext = Context(context)
        assert context.foo == id(context)
        assert subcontext.foo == id(subcontext)

    def test_getattr_attribute_error(self, context):
        """Test that accessing an attribute missing from the whole chain raises AttributeError."""
        child_context = Context(context)
        pytest.raises(AttributeError, getattr, child_context, 'foo').\
            match('no such context variable: foo')

    def test_getattr_parent(self, context):
        """
        Test that accessing a nonexistent attribute on a context retrieves the value from parent.

        """
        child_context = Context(context)
        context.a = 2
        assert child_context.a == 2

    def test_get_resources(self, context):
        """
        Test that get_resources() returns the matching resources visible from a subcontext.

        The subcontext's own 'foo' resource (4) is returned instead of the parent's 'foo' (9),
        together with the 'bar' value requested from the parent before the subcontext existed.

        """
        context.add_resource(9, 'foo')
        context.add_resource_factory(lambda ctx: len(ctx.context_chain), int, 'bar')
        context.require_resource(int, 'bar')
        subctx = Context(context)
        subctx.add_resource(4, 'foo')
        assert subctx.get_resources(int) == {1, 4}

    def test_require_resource(self, context):
        """Test that require_resource() returns a resource that has been added."""
        context.add_resource(1)
        assert context.require_resource(int) == 1

    def test_require_resource_not_found(self, context):
        """Test that ResourceNotFound is raised when a required resource is not found."""
        exc = pytest.raises(ResourceNotFound, context.require_resource, int, 'foo')
        exc.match("no matching resource was found for type=int name='foo'")
        assert exc.value.type == int
        assert exc.value.name == 'foo'

    @pytest.mark.asyncio
    async def test_request_resource_parent_add(self, context, event_loop):
        """
        Test that adding a resource to the parent context will satisfy a resource request in a
        child context.

        """
        child_context = Context(context)
        task = event_loop.create_task(child_context.request_resource(int))
        event_loop.call_soon(context.add_resource, 6)
        resource = await task
        assert resource == 6

    @pytest.mark.asyncio
    async def test_request_resource_factory_context_attr(self, context):
        """Test that requesting a factory-generated resource also sets the context variable."""
        context.add_resource_factory(lambda ctx: 6, int, context_attr='foo')
        await context.request_resource(int)
        assert context.__dict__['foo'] == 6

    @pytest.mark.asyncio
    async def test_call_async_plain(self, context):
        """
        Test that call_async(), invoked from a worker thread, runs a plain function outside that
        worker thread and passes back its return value.

        """
        def runs_in_event_loop(worker_thread, x, y):
            assert current_thread() is not worker_thread
            return x + y

        def runs_in_worker_thread():
            worker_thread = current_thread()
            return context.call_async(runs_in_event_loop, worker_thread, 1, y=2)

        assert await context.call_in_executor(runs_in_worker_thread) == 3

    @pytest.mark.asyncio
    async def test_call_async_coroutine(self, context):
        """
        Test that call_async(), invoked from a worker thread, runs a coroutine function outside
        that worker thread and passes back its return value.

        """
        async def runs_in_event_loop(worker_thread, x, y):
            assert current_thread() is not worker_thread
            await asyncio.sleep(0.1)
            return x + y

        def runs_in_worker_thread():
            worker_thread = current_thread()
            return context.call_async(runs_in_event_loop, worker_thread, 1, y=2)

        assert await context.call_in_executor(runs_in_worker_thread) == 3

    @pytest.mark.asyncio
    async def test_call_async_exception(self, context):
        """Test that an exception raised by the target of call_async() reaches the caller."""
        def runs_in_event_loop():
            raise ValueError('foo')

        with pytest.raises(ValueError) as exc:
            await context.call_in_executor(context.call_async, runs_in_event_loop)

        exc.match('foo')

    @pytest.mark.asyncio
    async def test_call_in_executor(self, context):
        """Test that call_in_executor actually runs the target in a worker thread."""
        worker_thread = await context.call_in_executor(current_thread)
        assert worker_thread is not current_thread()

    # The ids follow the parameter values: True selects the executor by its resource name,
    # False passes the executor object directly
    @pytest.mark.parametrize('use_resource_name', [True, False], ids=['resource', 'direct'])
    @pytest.mark.asyncio
    async def test_call_in_executor_explicit(self, context, use_resource_name):
        """
        Test that call_in_executor() accepts an explicit executor, given either as the name of
        an Executor resource or as the executor object itself.

        """
        # Named "pool" so it does not shadow the ``executor`` decorator imported from asphalt.core
        pool = ThreadPoolExecutor(1)
        context.add_resource(pool, types=[Executor])
        context.add_teardown_callback(pool.shutdown)
        executor_arg = 'default' if use_resource_name else pool
        worker_thread = await context.call_in_executor(current_thread, executor=executor_arg)
        assert worker_thread is not current_thread()

    @pytest.mark.asyncio
    async def test_threadpool(self, context):
        """Test that the body of "async with context.threadpool():" leaves the calling thread."""
        event_loop_thread = current_thread()
        async with context.threadpool():
            assert current_thread() is not event_loop_thread

    @pytest.mark.asyncio
    async def test_threadpool_named_executor(self, context, special_executor):
        """Test that threadpool('special') runs the block in the "special" executor's thread."""
        special_executor_thread = special_executor.submit(current_thread).result()
        async with context.threadpool('special'):
            assert current_thread() is special_executor_thread


class TestExecutor:
    """Tests for the ``executor`` decorator."""

    @pytest.mark.asyncio
    async def test_no_arguments(self, context):
        """Test that a function decorated with bare @executor runs outside the calling thread."""
        calling_thread = current_thread()

        @executor
        def runs_in_worker():
            assert current_thread() is not calling_thread

        await runs_in_worker()

    @pytest.mark.asyncio
    async def test_named_executor(self, context, special_executor):
        """Test that @executor('special') runs the function in the "special" executor's thread."""
        # The fixture's pool has a single worker, so this is the only thread it can use
        special_thread = special_executor.submit(current_thread).result()

        @executor('special')
        def runs_in_special_worker(ctx):
            assert current_thread() is special_thread

        await runs_in_special_worker(context)

    @pytest.mark.asyncio
    async def test_executor_missing_context(self, event_loop, context):
        """Test that a named executor function called without a Context raises RuntimeError."""
        @executor('special')
        def runs_in_default_worker():
            pass

        expected_message = (
            r'the first positional argument to %s\(\) has to be a Context instance' %
            callable_name(runs_in_default_worker))
        with pytest.raises(RuntimeError) as exc_info:
            await runs_in_default_worker()

        exc_info.match(expected_message)


class TestContextTeardown:
    """Tests for the ``context_teardown`` decorator."""

    @pytest.mark.parametrize('expected_exc', [
        None, Exception('foo')
    ], ids=['no_exception', 'exception'])
    @pytest.mark.asyncio
    async def test_function(self, expected_exc):
        """
        Test that a decorated function runs up to its yield point when called, and that the rest
        runs when the context is closed, receiving the value passed to close().

        """
        phase = received_exception = None

        @context_teardown
        async def start(ctx: Context):
            nonlocal phase, received_exception
            phase = 'started'
            received_exception = await yield_()
            phase = 'finished'

        root_context = Context()
        await start(root_context)
        assert phase == 'started'

        await root_context.close(expected_exc)
        assert phase == 'finished'
        assert received_exception == expected_exc

    @pytest.mark.parametrize('expected_exc', [
        None, Exception('foo')
    ], ids=['no_exception', 'exception'])
    @pytest.mark.asyncio
    async def test_method(self, expected_exc):
        """
        Test that a decorated method runs up to its yield point when called, and that the rest
        runs when the context is closed, receiving the value passed to close().

        """
        phase = received_exception = None

        class SomeComponent:
            @context_teardown
            async def start(self, ctx: Context):
                nonlocal phase, received_exception
                phase = 'started'
                received_exception = await yield_()
                phase = 'finished'

        root_context = Context()
        component = SomeComponent()
        await component.start(root_context)
        assert phase == 'started'

        await root_context.close(expected_exc)
        assert phase == 'finished'
        assert received_exception == expected_exc

    def test_plain_function(self):
        """Test that decorating a plain (non-async) function raises TypeError."""
        def start(ctx):
            pass

        with pytest.raises(TypeError) as exc_info:
            context_teardown(start)

        exc_info.match(' must be an async generator function')

    @pytest.mark.asyncio
    async def test_bad_args(self):
        """Test that calling a decorated function without a Context raises RuntimeError."""
        @context_teardown
        async def start(ctx):
            pass

        expected_message = (
            r'the first positional argument to %s\(\) has to be a Context instance' %
            callable_name(start))
        with pytest.raises(RuntimeError) as exc_info:
            await start(None)

        exc_info.match(expected_message)

    @pytest.mark.asyncio
    async def test_exception(self):
        """Test that an exception raised by the decorated function reaches the caller."""
        @context_teardown
        async def start(ctx):
            raise Exception('dummy error')

        root_context = Context()
        with pytest.raises(Exception) as raised:
            await start(root_context)

        raised.match('dummy error')

    @pytest.mark.asyncio
    async def test_missing_yield(self):
        """Test that a decorated function that never yields can still be called without error."""
        @context_teardown
        async def start(ctx: Context):
            pass

        root_context = Context()
        await start(root_context)