import asyncio
import collections
import types
import warnings
import sys

from .connection import create_connection, _PUBSUB_COMMANDS
from .log import logger
from .util import parse_url, CloseEvent
from .errors import PoolClosedError
from .abc import AbcPool
from .locks import Lock


async def create_pool(address, *, db=None, password=None, ssl=None,
                      encoding=None, minsize=1, maxsize=10,
                      parser=None, loop=None, create_connection_timeout=None,
                      pool_cls=None, connection_cls=None):
    """Create a connections pool and pre-fill it up to *minsize*.

    *address* is handed to the pool as is, unless it is a ``str``: then it
    is parsed with ``parse_url`` and any ``db``, ``password``, ``encoding``
    and ``timeout`` options found in the URL take precedence over the
    corresponding keyword arguments.  An ``ssl`` option from the URL is
    only used when no truthy *ssl* argument was given; a falsy URL option
    combined with a truthy *ssl* argument is rejected by an assertion.

    *pool_cls*, when given, must be a subclass of ``AbcPool`` and is used
    instead of ``ConnectionsPool``.  The remaining keyword arguments are
    forwarded unchanged to the pool constructor.

    If the initial fill fails with an ``Exception`` the pool is closed
    (and the close awaited) before the error is re-raised.

    Returns the created pool instance.
    """
    if pool_cls:
        assert issubclass(pool_cls, AbcPool),\
                "pool_class does not meet the AbcPool contract"
    cls = pool_cls or ConnectionsPool

    if isinstance(address, str):
        address, url_options = parse_url(address)
        # Values present in the URL win over the keyword arguments.
        db = url_options.get('db', db)
        password = url_options.get('password', password)
        encoding = url_options.get('encoding', encoding)
        create_connection_timeout = url_options.get(
            'timeout', create_connection_timeout)
        if 'ssl' in url_options:
            url_ssl = url_options['ssl']
            assert url_ssl or not ssl, (
                "Conflicting ssl options are set", url_ssl, ssl)
            ssl = ssl or url_ssl
        # TODO: minsize/maxsize

    pool = cls(
        address, db, password, encoding,
        minsize=minsize,
        maxsize=maxsize,
        ssl=ssl,
        parser=parser,
        create_connection_timeout=create_connection_timeout,
        connection_cls=connection_cls,
        loop=loop,
    )
    try:
        await pool._fill_free(override_min=False)
    except Exception:
        # Do not leave a half-initialised pool behind.
        pool.close()
        await pool.wait_closed()
        raise
    else:
        return pool


class ConnectionsPool(AbcPool):
    """Redis connections pool.

    The pool tracks three groups of connections:

    * ``_pool`` -- a deque (bounded by *maxsize*) of free connections;
    * ``_used`` -- a set of connections handed out by :meth:`acquire`
      (plus a Pub/Sub connection picked from the free deque);
    * ``_pubsub_conn`` -- the connection dedicated to Pub/Sub commands,
      if one has been selected or created.

    ``_acquiring`` counts connections that are currently being created so
    that :attr:`size` accounts for them.
    """

    def __init__(self, address, db=None, password=None, encoding=None,
                 *, minsize, maxsize, ssl=None, parser=None,
                 create_connection_timeout=None,
                 connection_cls=None,
                 loop=None):
        """Initialise an empty pool; no connection is opened here.

        *minsize* must be an int >= 0, *maxsize* an int > 0 and
        ``minsize <= maxsize`` (checked with assertions).  *address*,
        *db*, *password*, *ssl*, *encoding*, *parser*,
        *create_connection_timeout* and *connection_cls* are stored and
        later forwarded to ``create_connection``.

        *loop* is not stored; passing a non-None value only emits a
        ``DeprecationWarning`` on Python 3.8+.
        """
        assert isinstance(minsize, int) and minsize >= 0, (
            "minsize must be int >= 0", minsize, type(minsize))
        assert maxsize is not None, "Arbitrary pool size is disallowed."
        assert isinstance(maxsize, int) and maxsize > 0, (
            "maxsize must be int > 0", maxsize, type(maxsize))
        assert minsize <= maxsize, (
            "Invalid pool min/max sizes", minsize, maxsize)
        if loop is not None and sys.version_info >= (3, 8):
            warnings.warn("The loop argument is deprecated",
                          DeprecationWarning)
        self._address = address
        self._db = db
        self._password = password
        self._ssl = ssl
        self._encoding = encoding
        self._parser_class = parser
        self._minsize = minsize
        self._create_connection_timeout = create_connection_timeout
        # maxlen doubles as the storage for the pool's maxsize.
        self._pool = collections.deque(maxlen=maxsize)
        self._used = set()
        self._acquiring = 0
        self._cond = asyncio.Condition(lock=Lock())
        self._close_state = CloseEvent(self._do_close)
        self._pubsub_conn = None
        self._connection_cls = connection_cls

    def __repr__(self):
        return '<{} [db:{}, size:[{}:{}], free:{}]>'.format(
            self.__class__.__name__, self.db,
            self.minsize, self.maxsize, self.freesize)

    @property
    def minsize(self):
        """Minimum pool size."""
        return self._minsize

    @property
    def maxsize(self):
        """Maximum pool size."""
        return self._pool.maxlen

    @property
    def size(self):
        """Current pool size: free + used + being-created connections."""
        return self.freesize + len(self._used) + self._acquiring

    @property
    def freesize(self):
        """Current number of free connections."""
        return len(self._pool)

    @property
    def address(self):
        """Address the pool was created with."""
        return self._address

    async def clear(self):
        """Clear pool connections.

        Close and remove all free connections.
        """
        async with self._cond:
            await self._do_clear()

    async def _do_clear(self):
        """Close all free connections and wait until they are closed.

        The caller is expected to hold ``self._cond``.
        """
        waiters = []
        while self._pool:
            conn = self._pool.popleft()
            conn.close()
            waiters.append(conn.wait_closed())
        await asyncio.gather(*waiters)

    async def _do_close(self):
        """Close free, used and Pub/Sub connections and wait for them.

        Passed to ``CloseEvent`` in ``__init__`` as the close callback.
        """
        async with self._cond:
            assert not self._acquiring, self._acquiring
            waiters = []
            while self._pool:
                conn = self._pool.popleft()
                conn.close()
                waiters.append(conn.wait_closed())
            for conn in self._used:
                conn.close()
                waiters.append(conn.wait_closed())
            # A Pub/Sub connection created by _wait_execute_pubsub() is
            # tracked only through _pubsub_conn (it is never added to
            # _used), so it must be closed explicitly or it would leak.
            pubsub_conn = self._pubsub_conn
            if (pubsub_conn is not None
                    and pubsub_conn not in self._used
                    and not pubsub_conn.closed):
                pubsub_conn.close()
                waiters.append(pubsub_conn.wait_closed())
            await asyncio.gather(*waiters)
            logger.debug("Closed %d connection(s)", len(waiters))

    def close(self):
        """Close all free and in-progress connections and mark pool as closed.
        """
        if not self._close_state.is_set():
            self._close_state.set()

    @property
    def closed(self):
        """True if pool is closed."""
        return self._close_state.is_set()

    async def wait_closed(self):
        """Wait until pool gets closed."""
        await self._close_state.wait()

    @property
    def db(self):
        """Currently selected db index (0 when none was set)."""
        return self._db or 0

    @property
    def encoding(self):
        """Current set codec or None."""
        return self._encoding

    def execute(self, command, *args, **kw):
        """Execute a Redis command on a free connection.

        Picks a connection from the free pool (without removing it) and
        sends the command through it.  If no usable free connection is
        found, a coroutine is created that acquires a connection, executes
        the command and releases the connection again.

        Returns the result of :meth:`_check_result` applied to either the
        value of ``conn.execute(...)`` or that coroutine.
        """
        conn, address = self.get_connection(command, args)
        if conn is not None:
            fut = conn.execute(command, *args, **kw)
            return self._check_result(fut, command, args, kw)
        else:
            coro = self._wait_execute(address, command, args, kw)
            return self._check_result(coro, command, args, kw)

    def execute_pubsub(self, command, *channels):
        """Executes Redis (p)subscribe/(p)unsubscribe commands.

        ConnectionsPool picks separate connection for pub/sub
        and uses it until explicitly closed or disconnected
        (unsubscribing from all channels/patterns will leave connection
         locked for pub/sub use).

        There is no auto-reconnect for this PUB/SUB connection.

        Returns the value of ``conn.execute_pubsub(...)`` when a connection
        is available, otherwise a coroutine that first creates the Pub/Sub
        connection.
        """
        conn, address = self.get_connection(command)
        if conn is not None:
            return conn.execute_pubsub(command, *channels)
        else:
            return self._wait_execute_pubsub(address, command, channels, {})

    def get_connection(self, command, args=()):
        """Look up a usable free connection for *command*.

        For Pub/Sub commands the dedicated Pub/Sub connection is returned
        if it exists and is open; otherwise the first suitable free
        connection becomes the Pub/Sub connection and is moved from the
        free deque into the used set.  For other commands the connection
        stays in the free deque.  Closed connections and connections in
        Pub/Sub mode are skipped.

        Returns a ``(connection, address)`` tuple, or
        ``(None, pool_address)`` when no suitable connection was found.
        *args* is accepted but not used here.
        """
        # TODO: find a better way to determine if connection is free
        #       and not heavily used.
        command = command.upper().strip()
        is_pubsub = command in _PUBSUB_COMMANDS
        if is_pubsub and self._pubsub_conn:
            if not self._pubsub_conn.closed:
                return self._pubsub_conn, self._pubsub_conn.address
            # Forget the dead Pub/Sub connection and pick a new one below.
            self._pubsub_conn = None
        for _ in range(self.freesize):
            conn = self._pool[0]
            # Rotate so that consecutive calls do not always hit the
            # same connection.
            self._pool.rotate(1)
            if conn.closed:  # or conn._waiters: (eg: busy connection)
                continue
            if conn.in_pubsub:
                continue
            if is_pubsub:
                self._pubsub_conn = conn
                self._pool.remove(conn)
                self._used.add(conn)
            return conn, conn.address
        return None, self._address  # figure out

    def _check_result(self, fut, *data):
        """Hook to check result or catch exception (like MovedError).

        This method can be coroutine.  The default implementation returns
        *fut* unchanged and ignores *data*.
        """
        return fut

    async def _wait_execute(self, address, command, args, kw):
        """Acquire connection, execute command and release the connection.

        *address* is accepted but not used here.
        """
        conn = await self.acquire(command, args)
        try:
            return (await conn.execute(command, *args, **kw))
        finally:
            self.release(conn)

    async def _wait_execute_pubsub(self, address, command, args, kw):
        """Create the Pub/Sub connection if needed and run *command* on it.

        Raises ``PoolClosedError`` if the pool is closed, either before or
        after the condition lock has been obtained.
        """
        if self.closed:
            raise PoolClosedError("Pool is closed")
        assert self._pubsub_conn is None or self._pubsub_conn.closed, (
            "Expected no or closed connection", self._pubsub_conn)
        async with self._cond:
            # Re-check: the pool may have been closed while waiting.
            if self.closed:
                raise PoolClosedError("Pool is closed")
            # Another task may have created the connection meanwhile.
            if self._pubsub_conn is None or self._pubsub_conn.closed:
                conn = await self._create_new_connection(address)
                self._pubsub_conn = conn
            conn = self._pubsub_conn
            return (await conn.execute_pubsub(command, *args, **kw))

    async def select(self, db):
        """Changes db index for all free connections.

        All previously acquired connections will be closed when released.

        Returns a truthy value only if every free connection reported a
        truthy result (``True`` when there are no free connections).
        """
        res = True
        async with self._cond:
            # Iterate over a snapshot: the deque can be rotated or changed
            # by get_connection()/release() while we are awaiting.
            for conn in tuple(self._pool):
                # Always issue SELECT, even after an earlier falsy result,
                # so that no free connection is left on the old db.
                selected = await conn.select(db)
                res = res and selected
            self._db = db
        return res

    async def auth(self, password):
        """Store *password* and authenticate all free connections with it."""
        self._password = password
        async with self._cond:
            # Snapshot for the same reason as in select().
            for conn in tuple(self._pool):
                await conn.auth(password)

    @property
    def in_pubsub(self):
        """``in_pubsub`` of the open Pub/Sub connection, otherwise 0."""
        if self._pubsub_conn and not self._pubsub_conn.closed:
            return self._pubsub_conn.in_pubsub
        return 0

    @property
    def pubsub_channels(self):
        """Channels of the open Pub/Sub connection or an empty mapping."""
        if self._pubsub_conn and not self._pubsub_conn.closed:
            return self._pubsub_conn.pubsub_channels
        return types.MappingProxyType({})

    @property
    def pubsub_patterns(self):
        """Patterns of the open Pub/Sub connection or an empty mapping."""
        if self._pubsub_conn and not self._pubsub_conn.closed:
            return self._pubsub_conn.pubsub_patterns
        return types.MappingProxyType({})

    async def acquire(self, command=None, args=()):
        """Acquires a connection from free pool.

        Creates new connection if needed; waits on the pool condition when
        the pool has no free connection and cannot grow.  *command* and
        *args* are accepted but not used here.

        Raises ``PoolClosedError`` if the pool is closed.
        """
        if self.closed:
            raise PoolClosedError("Pool is closed")
        async with self._cond:
            if self.closed:
                raise PoolClosedError("Pool is closed")
            while True:
                await self._fill_free(override_min=True)
                if self.freesize:
                    conn = self._pool.popleft()
                    assert not conn.closed, conn
                    assert conn not in self._used, (conn, self._used)
                    self._used.add(conn)
                    return conn
                else:
                    # Woken up by _wakeup(), scheduled from release().
                    await self._cond.wait()

    def release(self, conn):
        """Returns used connection back into pool.

        The connection is closed and dropped instead of being returned
        when it is in a transaction, in subscribe mode, has pending
        commands, has a db index that differs from the pool's one, or
        when the queue of free connections is full.

        A wake-up of one waiter in :meth:`acquire` is scheduled afterwards.
        """
        assert conn in self._used, (
            "Invalid connection, maybe from other pool", conn)
        self._used.remove(conn)
        if not conn.closed:
            if conn.in_transaction:
                logger.warning(
                    "Connection %r is in transaction, closing it.", conn)
                conn.close()
            elif conn.in_pubsub:
                logger.warning(
                    "Connection %r is in subscribe mode, closing it.", conn)
                conn.close()
            elif conn._waiters:
                logger.warning(
                    "Connection %r has pending commands, closing it.", conn)
                conn.close()
            elif conn.db == self.db:
                if self.maxsize and self.freesize < self.maxsize:
                    self._pool.append(conn)
                else:
                    # consider this connection as old and close it.
                    conn.close()
            else:
                conn.close()
        # FIXME: check event loop is not closed
        asyncio.ensure_future(self._wakeup())

    def _drop_closed(self):
        """Remove closed connections from the free deque, keeping order."""
        for _ in range(self.freesize):
            conn = self._pool[0]
            if conn.closed:
                self._pool.popleft()
            else:
                # Move the live connection to the back; after a full pass
                # the survivors are back in their original order.
                self._pool.rotate(-1)

    async def _fill_free(self, *, override_min):
        """Ensure the pool holds at least *minsize* connections.

        Closed free connections are dropped first.  New connections
        created to reach *minsize* are health-checked with ``ping``.  When
        *override_min* is true and there is still no free connection, one
        more connection is created as long as the pool is below *maxsize*.
        """
        # drop closed connections first
        self._drop_closed()
        while self.size < self.minsize:
            self._acquiring += 1
            try:
                conn = await self._create_new_connection(self._address)
                # check the health of that connection; if something went
                # wrong the exception propagates to the caller
                try:
                    await conn.execute('ping')
                except BaseException:
                    # The connection is not tracked anywhere yet, so
                    # close it here to avoid leaking it.
                    conn.close()
                    raise
                self._pool.append(conn)
            finally:
                self._acquiring -= 1
                # connection may be closed at yield point
                self._drop_closed()
        if self.freesize:
            return
        if override_min:
            while not self._pool and self.size < self.maxsize:
                self._acquiring += 1
                try:
                    conn = await self._create_new_connection(self._address)
                    self._pool.append(conn)
                finally:
                    self._acquiring -= 1
                    # connection may be closed at yield point
                    self._drop_closed()

    def _create_new_connection(self, address):
        """Return ``create_connection(...)`` for *address* with pool settings.
        """
        return create_connection(address,
                                 db=self._db,
                                 password=self._password,
                                 ssl=self._ssl,
                                 encoding=self._encoding,
                                 parser=self._parser_class,
                                 timeout=self._create_connection_timeout,
                                 connection_cls=self._connection_cls,
                                 )

    async def _wakeup(self, closing_conn=None):
        """Notify one waiter on the pool condition.

        If *closing_conn* is given, additionally wait until it is closed.
        """
        async with self._cond:
            self._cond.notify()
        if closing_conn is not None:
            await closing_conn.wait_closed()

    def __enter__(self):
        raise RuntimeError(
            "'await' should be used as a context manager expression")

    def __exit__(self, *args):
        pass    # pragma: nocover

    def __await__(self):
        # To make `with await pool` work
        conn = yield from self.acquire().__await__()
        return _ConnectionContextManager(self, conn)

    def get(self):
        '''Return async context manager for working with connection.

        async with pool.get() as conn:
            await conn.execute('get', 'my-key')
        '''
        return _AsyncConnectionContextManager(self)


class _ConnectionContextManager:
    """Synchronous context manager around an already acquired connection.

    Entering yields the connection; leaving hands it back to the pool via
    ``pool.release`` and drops both references, even if release fails.
    """

    __slots__ = ('_pool', '_conn')

    def __init__(self, pool, conn):
        self._pool, self._conn = pool, conn

    def __enter__(self):
        return self._conn

    def __exit__(self, exc_type, exc_value, tb):
        # Detach the references first so they are cleared whether or not
        # release() raises.
        pool, conn = self._pool, self._conn
        self._pool = self._conn = None
        pool.release(conn)


class _AsyncConnectionContextManager:
    """Async context manager that acquires a connection on entry.

    ``__aenter__`` awaits ``pool.acquire()`` and yields the connection;
    ``__aexit__`` passes it to ``pool.release`` and drops both references,
    even if release fails.
    """

    __slots__ = ('_pool', '_conn')

    def __init__(self, pool):
        self._pool, self._conn = pool, None

    async def __aenter__(self):
        self._conn = await self._pool.acquire()
        return self._conn

    async def __aexit__(self, exc_type, exc_value, tb):
        # Detach the references first so they are cleared whether or not
        # release() raises.
        pool, conn = self._pool, self._conn
        self._pool = self._conn = None
        pool.release(conn)