import asyncio
import functools

from ..abc import AbcPool
from ..errors import (
    RedisError,
    PipelineError,
    MultiExecError,
    ConnectionClosedError,
    )
from ..util import (
    wait_ok,
    _set_exception,
    get_event_loop,
    )


class TransactionsCommandsMixin:
    """Transaction commands mixin.

    For details on these commands see: http://redis.io/commands/#transactions

    Transactions HOWTO:

    >>> tr = redis.multi_exec()
    >>> result_future1 = tr.incr('foo')
    >>> result_future2 = tr.incr('bar')
    >>> try:
    ...     result = await tr.execute()
    ... except MultiExecError:
    ...     pass    # check what happened
    >>> result1 = await result_future1
    >>> result2 = await result_future2
    >>> assert result == [result1, result2]
    """

    def unwatch(self):
        """Forget about all watched keys."""
        fut = self._pool_or_conn.execute(b'UNWATCH')
        return wait_ok(fut)

    def watch(self, key, *keys):
        """Watch the given keys to determine execution of the MULTI/EXEC block.
        """
        # FIXME: when a pool is used, WATCH may be sent through one
        #   connection while the following MULTI/EXEC goes through another,
        #   silently breaking the optimistic lock.
        # Possible fix:
        #   "Remember" the connection that was used for the WATCH command
        #   and send MULTI/EXEC/DISCARD through it as well.
        fut = self._pool_or_conn.execute(b'WATCH', key, *keys)
        return wait_ok(fut)

    def multi_exec(self):
        """Returns MULTI/EXEC pipeline wrapper.

        Usage:

        >>> tr = redis.multi_exec()
        >>> fut1 = tr.incr('foo')   # NO `await` as it will block forever!
        >>> fut2 = tr.incr('bar')
        >>> result = await tr.execute()
        >>> result
        [1, 1]
        >>> await asyncio.gather(fut1, fut2)
        [1, 1]
        """
        return MultiExec(self._pool_or_conn, self.__class__)

    def pipeline(self):
        """Returns :class:`Pipeline` object to execute bulk of commands.

        It is provided for convenience.
        Commands can be pipelined without it.

        Example:

        >>> pipe = redis.pipeline()
        >>> fut1 = pipe.incr('foo') # NO `await` as it will block forever!
        >>> fut2 = pipe.incr('bar')
        >>> result = await pipe.execute()
        >>> result
        [1, 1]
        >>> await asyncio.gather(fut1, fut2)
        [1, 1]
        >>> #
        >>> # The same can be done without pipeline:
        >>> #
        >>> fut1 = redis.incr('foo')    # the 'INCR foo' command is already sent
        >>> fut2 = redis.incr('bar')
        >>> await asyncio.gather(fut1, fut2)
        [2, 2]
        """
        return Pipeline(self._pool_or_conn, self.__class__)


class _RedisBuffer:
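    """Command buffer used by :class:`Pipeline` and :class:`MultiExec`.

    Its `execute` does not talk to Redis: it appends the command, together
    with a fresh result future, to the pipeline list and returns that future.
    """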

    def __init__(self, pipeline, *, loop=None):
        # TODO: deprecation note
        # if loop is None:
        #     loop = asyncio.get_event_loop()
        self._pipeline = pipeline

    def execute(self, cmd, *args, **kw):
        fut = get_event_loop().create_future()
        self._pipeline.append((fut, cmd, args, kw))
        return fut

    # TODO: either add here, or remove from the connection, methods
    #   like `select`, `auth`, etc.


class Pipeline:
    """Commands pipeline.

    Usage:

    >>> pipe = redis.pipeline()
    >>> fut1 = pipe.incr('foo')
    >>> fut2 = pipe.incr('bar')
    >>> await pipe.execute()
    [1, 1]
    >>> await fut1
    1
    >>> await fut2
    1
    """
    error_class = PipelineError

    def __init__(self, pool_or_connection, commands_factory=lambda conn: conn,
                 *, loop=None):
        # TODO: deprecation note
        # if loop is None:
        #     loop = asyncio.get_event_loop()
        self._pool_or_conn = pool_or_connection
        self._pipeline = []
        self._results = []
        self._buffer = _RedisBuffer(self._pipeline)
        self._redis = commands_factory(self._buffer)
        self._done = False

    def __getattr__(self, name):
        assert not self._done, "Pipeline already executed. Create a new one."
        attr = getattr(self._redis, name)
        if callable(attr):

            @functools.wraps(attr)
            def wrapper(*args, **kw):
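                # `attr` belongs to the commands object built around
                # _RedisBuffer, so calling it only buffers the command;
                # errors raised at call time (e.g. bad arguments) are
                # captured into a future and surface on execute().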
                try:
                    task = asyncio.ensure_future(attr(*args, **kw))
                except Exception as exc:
                    task = get_event_loop().create_future()
                    task.set_exception(exc)
                self._results.append(task)
                return task
            return wrapper
        return attr

    async def execute(self, *, return_exceptions=False):
        """Execute all buffered commands.

        Any exception that is raised by any command is caught and
        raised later when processing results.

        Exceptions can also be returned in the result list if the
        `return_exceptions` flag is set to True.
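
        A sketch of the `return_exceptions` behaviour (it assumes `incr`
        accepts a single key, so the extra argument is an error captured
        at call time):

        >>> pipe = redis.pipeline()
        >>> fut = pipe.incr('foo', 'bar')   # wrong arity, captured as error
        >>> res = await pipe.execute(return_exceptions=True)
        >>> isinstance(res[0], TypeError)
        True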
        """
        assert not self._done, "Pipeline already executed. Create a new one."
        self._done = True

        if self._pipeline:
            if isinstance(self._pool_or_conn, AbcPool):
                async with self._pool_or_conn.get() as conn:
                    return await self._do_execute(
                        conn, return_exceptions=return_exceptions)
            else:
                return await self._do_execute(
                    self._pool_or_conn,
                    return_exceptions=return_exceptions)
        else:
            return await self._gather_result(return_exceptions)

    async def _do_execute(self, conn, *, return_exceptions=False):
        await asyncio.gather(*self._send_pipeline(conn),
                             return_exceptions=True)
        return await self._gather_result(return_exceptions)

    async def _gather_result(self, return_exceptions):
        errors = []
        results = []
        for fut in self._results:
            try:
                res = await fut
                results.append(res)
            except Exception as exc:
                errors.append(exc)
                results.append(exc)
        if errors and not return_exceptions:
            raise self.error_class(errors)
        return results

    def _send_pipeline(self, conn):
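        # conn._buffered() batches the writes so all queued commands are
        # flushed to the socket together.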
        with conn._buffered():
            for fut, cmd, args, kw in self._pipeline:
                try:
                    result_fut = conn.execute(cmd, *args, **kw)
                    result_fut.add_done_callback(
                        functools.partial(self._check_result, waiter=fut))
                except Exception as exc:
                    fut.set_exception(exc)
                else:
                    yield result_fut

    def _check_result(self, fut, waiter):
        if fut.cancelled():
            waiter.cancel()
        elif fut.exception():
            waiter.set_exception(fut.exception())
        else:
            waiter.set_result(fut.result())


class MultiExec(Pipeline):
    """Multi/Exec pipeline wrapper.

    Usage:

    >>> tr = redis.multi_exec()
    >>> f1 = tr.incr('foo')
    >>> f2 = tr.incr('bar')
    >>> # A)
    >>> await tr.execute()
    >>> res1 = await f1
    >>> res2 = await f2
    >>> # or B)
    >>> res1, res2 = await tr.execute()

    and of course try/except:

    >>> tr = redis.multi_exec()
    >>> f1 = tr.incr('1')   # won't raise here; errors surface on execute()
    >>> try:
    ...     res = await tr.execute()
    ... except RedisError:
    ...     pass
    >>> assert f1.done()
    >>> assert f1.result() == res[0]

    >>> tr = redis.multi_exec()
    >>> wait_ok_coro = tr.mset('foo', '1')
    >>> try:
    ...     ok1 = await tr.execute()
    ... except RedisError:
    ...     pass    # handle it
    >>> ok2 = await wait_ok_coro
    >>> # for this to work `wait_ok_coro` must be wrapped in a Future
    """
    error_class = MultiExecError

    async def _do_execute(self, conn, *, return_exceptions=False):
        self._waiters = waiters = []
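        # Frame the buffered commands with MULTI/EXEC so the server queues
        # and runs them as one transaction.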
        with conn._buffered():
            multi = conn.execute('MULTI')
            coros = list(self._send_pipeline(conn))
            exec_ = conn.execute('EXEC')
        gather = asyncio.gather(multi, *coros,
                                return_exceptions=True)
        last_error = None
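        # Shield the gather from outer cancellation: once MULTI is sent the
        # queued replies must still be consumed to keep the connection in a
        # consistent state, so on cancellation the gather is awaited anyway.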
        try:
            await asyncio.shield(gather)
        except asyncio.CancelledError:
            await gather
        except Exception as err:
            last_error = err
            raise
        finally:
            if conn.closed:
                if last_error is None:
                    last_error = ConnectionClosedError()
                for fut in waiters:
                    _set_exception(fut, last_error)
                    # fut.cancel()
                for fut in self._results:
                    if not fut.done():
                        fut.set_exception(last_error)
                        # fut.cancel()
            else:
                try:
                    results = await exec_
                except RedisError as err:
                    for fut in waiters:
                        fut.set_exception(err)
                else:
                    assert len(results) == len(waiters), (
                        "Results do not match waiters", results, waiters)
                    self._resolve_waiters(results, return_exceptions)
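            # Note: this return inside `finally` supersedes the `raise`
            # above; errors are instead reported through the result futures
            # and raised (if at all) by _gather_result.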
            return (await self._gather_result(return_exceptions))

    def _resolve_waiters(self, results, return_exceptions):
        errors = []
        for val, fut in zip(results, self._waiters):
            if isinstance(val, RedisError):
                fut.set_exception(val)
                errors.append(val)
            else:
                fut.set_result(val)
        if errors and not return_exceptions:
            raise MultiExecError(errors)

    def _check_result(self, fut, waiter):
        assert waiter not in self._waiters, (fut, waiter, self._waiters)
        assert not waiter.done(), waiter
        if fut.cancelled():     # await gather was cancelled
            waiter.cancel()
        elif fut.exception():   # server replied with error
            waiter.set_exception(fut.exception())
        elif fut.result() in {b'QUEUED', 'QUEUED'}:
            # got result, it should be QUEUED
            self._waiters.append(waiter)
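        # Any other reply (no error, not QUEUED) leaves the waiter out of
        # self._waiters, so it will not be matched against EXEC's results.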