import multiprocessing
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from typing import Callable, Iterator, Optional

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session
from tornado.concurrent import Future, chain_future
from tornado.ioloop import IOLoop
from tornado.web import Application

__all__ = ('as_future', 'SessionMixin', 'set_max_workers', 'SQLAlchemy')


class MissingFactoryError(Exception):
    """Raised when a session is requested but no ``application`` is set."""


class MissingDatabaseSettingError(Exception):
    """Raised when the database configuration needed for a session or
    engine (the ``db`` application setting, or a URL/binds) is missing."""


class _AsyncExecution:
    """Tiny wrapper around ThreadPoolExecutor. This class is not meant to be
    instantiated externally, but internally we just use it as a wrapper around
    ThreadPoolExecutor so we can control the pool size and make the
    `as_future` function public.
    """

    def __init__(self, max_workers: Optional[int] = None):
        self._max_workers = (
            max_workers or multiprocessing.cpu_count()
        )  # type: int
        self._pool = None  # type: Optional[Executor]

    def set_max_workers(self, count: int):
        if self._pool:
            self._pool.shutdown(wait=True)

        self._max_workers = count
        self._pool = ThreadPoolExecutor(max_workers=self._max_workers)

    def as_future(self, query: Callable) -> Future:
        # concurrent.futures.Future is not compatible with the "new style"
        # asyncio Future, and awaiting on such "old-style" futures does not
        # work.
        #
        # tornado includes a `run_in_executor` function to help with this
        # problem, but it's only included in version 5+. Hence, we copy a
        # little bit of code here to handle this incompatibility.

        if not self._pool:
            self._pool = ThreadPoolExecutor(max_workers=self._max_workers)

        old_future = self._pool.submit(query)
        new_future = Future()  # type: Future

        IOLoop.current().add_future(
            old_future, lambda f: chain_future(f, new_future)
        )

        return new_future


class SessionMixin:
    """Request-handler mixin that hands out SQLAlchemy sessions.

    Sessions are produced by ``self.application.settings['db'].sessionmaker``.
    ``self.session`` is a lazily created, request-scoped session that is
    committed and closed in ``on_finish``; ``make_session()`` yields an
    independent, explicitly scoped session.
    """

    _session = None  # type: Optional[Session]
    application = None  # type: Optional[Application]

    @contextmanager
    def make_session(self) -> Iterator[Session]:
        """Yield a new session scoped to a ``with`` block.

        The session is committed if the block completes normally and rolled
        back if it raises (the exception is re-raised). It is always closed.

        :raises MissingFactoryError: if no application is attached.
        :raises MissingDatabaseSettingError: if the ``db`` setting is absent.
        """
        session = None

        try:
            session = self._make_session()

            yield session
        except Exception:
            if session is not None:
                session.rollback()
            raise
        else:
            session.commit()
        finally:
            if session is not None:
                session.close()

    def on_finish(self):
        """Commit and close the request-scoped session, then chain upward.

        The session is closed, and the next ``on_finish`` in the MRO (if any)
        is called, even when the commit fails; the commit error still
        propagates.
        """
        next_on_finish = getattr(super(SessionMixin, self), 'on_finish', None)

        session = self._session
        # Drop the reference so a later ``self.session`` access builds a new
        # session instead of handing back the closed one.
        self._session = None

        try:
            if session is not None:
                try:
                    session.commit()
                finally:
                    # Previously a failing commit skipped close(), leaking
                    # the session's connection.
                    session.close()
        finally:
            if next_on_finish is not None:
                next_on_finish()

    @property
    def session(self) -> Session:
        """The request-scoped session, created on first access."""
        if self._session is None:
            self._session = self._make_session()
        return self._session

    def _make_session(self) -> Session:
        """Create a session from the application's ``db`` setting.

        :raises MissingFactoryError: if no application is attached.
        :raises MissingDatabaseSettingError: if the ``db`` setting is absent.
        """
        if not self.application:
            raise MissingFactoryError()

        db = self.application.settings.get('db')
        if not db:
            raise MissingDatabaseSettingError()
        return db.sessionmaker()


# One executor shared module-wide; its bound methods are exposed as the
# public ``as_future`` / ``set_max_workers`` functions listed in __all__.
_async_exec = _AsyncExecution()
as_future, set_max_workers = (
    _async_exec.as_future,
    _async_exec.set_max_workers,
)


class SessionEx(Session):
    """The SessionEx extends the default session system with bind selection.

    Tables whose ``info`` dict carries a ``bind_key`` are routed to
    ``db.get_engine(bind=<key>)``; everything else falls back to the base
    :class:`Session` resolution.
    """

    def __init__(self, db, autocommit=False, autoflush=True, **options):
        """
        :param db: owner object providing ``engine``, ``get_binds()`` and
            ``get_engine(bind=...)``.
        :param options: forwarded to :class:`Session`. ``bind`` defaults to
            ``db.engine``; ``binds`` defaults to ``db.get_binds()``.
        """
        self.db = db
        bind = options.pop('bind', None) or db.engine
        # Only build the default table->engine map when the caller did not
        # supply one; db.get_binds() may create engines as a side effect.
        if 'binds' in options:
            binds = options.pop('binds')
        else:
            binds = db.get_binds()

        super().__init__(
            autocommit=autocommit,
            autoflush=autoflush,
            bind=bind,
            binds=binds,
            **options
        )

    def get_bind(self, mapper=None, clause=None, **kw):
        """Return the engine or connection for a given model or
        table, using the `__bind_key__` if it is set.

        Extra keyword arguments (newer SQLAlchemy versions pass e.g. ``bind``)
        are forwarded to :meth:`Session.get_bind` unchanged.
        """
        # An explicitly requested bind takes precedence over the bind key.
        if kw.get('bind') is not None:
            return super().get_bind(mapper, clause, **kw)

        if mapper is not None:
            try:
                # SA >= 1.3
                persist_selectable = mapper.persist_selectable
            except AttributeError:
                # SA < 1.3
                persist_selectable = mapper.mapped_table

            info = getattr(persist_selectable, 'info', {})
            bind_key = info.get('bind_key')

            if bind_key is not None:
                return self.db.get_engine(bind=bind_key)

        return super().get_bind(mapper, clause, **kw)


class BindMeta(DeclarativeMeta):
    """Declarative metaclass that records a model's ``__bind_key__``.

    The key may be declared in the class body or inherited from a base. When
    the resulting class has a ``__table__``, the key is stored as
    ``__table__.info['bind_key']``.
    """

    def __init__(cls, name, bases, d):
        # Remove the key from the namespace before declarative sees it, and
        # fall back to an inherited value when the body does not set one.
        own_key = d.pop('__bind_key__', None)
        bind_key = own_key if own_key else getattr(cls, '__bind_key__', None)

        super().__init__(name, bases, d)

        if bind_key is None:
            return

        table = getattr(cls, '__table__', None)
        if table is not None:
            table.info['bind_key'] = bind_key


class SQLAlchemy:
    """Holds a declarative ``Model`` base, cached engines and a session factory.

    ``url`` configures the default (unnamed) engine. ``binds`` maps bind
    names to additional URLs, which models select through ``__bind_key__``.
    ``sessionmaker`` produces :class:`SessionEx` sessions bound to this object.
    """

    def __init__(
        self, url=None, binds=None, session_options=None, engine_options=None
    ):
        self.Model = self.make_declarative_base()
        self._engines = {}

        self.configure(
            url=url,
            binds=binds,
            session_options=session_options,
            engine_options=engine_options,
        )

    def configure(
        self, url=None, binds=None, session_options=None, engine_options=None
    ):
        """(Re)configure connection settings and rebuild the session factory.

        Engines cached under a previous configuration are forgotten, so later
        lookups use the new URL(s) and engine options.
        """
        self.url = url
        self.binds = binds or {}
        self._engine_options = engine_options or {}
        # Without this, get_engine() kept returning engines built for the
        # old URL/options after a reconfiguration.
        self._engines = {}

        self.sessionmaker = sessionmaker(
            class_=SessionEx, db=self, **(session_options or {})
        )

    @property
    def engine(self):
        """The engine for the default (unnamed) bind."""
        return self.get_engine()

    @property
    def metadata(self):
        """The ``MetaData`` shared by all models of ``Model``."""
        return self.Model.metadata

    def create_engine(self, bind=None):
        """Build a new (uncached) engine for ``bind``.

        :param bind: a name from ``binds``, or ``None`` for the default ``url``.
        :raises MissingDatabaseSettingError: if nothing is configured, or the
            default engine is requested while no ``url`` is set.
        :raises RuntimeError: if ``bind`` is not a configured bind name.
        """
        if not self.url and not self.binds:
            raise MissingDatabaseSettingError()

        if bind is None:
            if not self.url:
                # Only named binds are configured; previously None was passed
                # straight to sqlalchemy's create_engine.
                raise MissingDatabaseSettingError()
            url = self.url
        else:
            if bind not in self.binds:
                raise RuntimeError('bind {} undefined.'.format(bind))
            url = self.binds[bind]

        return create_engine(url, **self._engine_options)

    def get_engine(self, bind=None):
        """Returns a specific engine. cached in self._engines """
        engine = self._engines.get(bind)

        if engine is None:
            engine = self.create_engine(bind)
            self._engines[bind] = engine

        return engine

    def get_tables_for_bind(self, bind=None):
        """Returns a list of all tables relevant for a bind."""
        return [
            table
            for table in self.Model.metadata.tables.values()
            if table.info.get('bind_key') == bind
        ]

    def get_binds(self):
        """Returns a dictionary with a table->engine mapping.

        This is suitable for use of sessionmaker(binds=db.get_binds()).
        Engines are only looked up (and thus created) for binds that actually
        have tables.
        """
        result = {}

        for bind in [None] + list(self.binds):
            tables = self.get_tables_for_bind(bind)
            if not tables:
                continue
            engine = self.get_engine(bind)
            result.update({table: engine for table in tables})

        return result

    def _execute_for_all_tables(self, bind, operation, skip_tables=False):
        """Invoke ``MetaData.<operation>`` against each selected bind's engine.

        :param bind: ``'__all__'`` (default bind plus every named bind), a
            single bind name or ``None``, or an iterable of bind names.
        :param operation: name of the ``MetaData`` method, e.g. ``'create_all'``.
        :param skip_tables: if true, do not restrict the call to the bind's
            tables.
        """
        if bind == '__all__':
            selected = [None] + list(self.binds)
        elif isinstance(bind, str) or bind is None:
            selected = [bind]
        else:
            selected = bind

        op = getattr(self.Model.metadata, operation)

        for key in selected:
            extra = {}

            if not skip_tables:
                extra['tables'] = self.get_tables_for_bind(key)

            op(bind=self.get_engine(key), **extra)

    def create_all(self, bind='__all__'):
        """Creates all tables.
        """
        self._execute_for_all_tables(bind, 'create_all')

    def drop_all(self, bind='__all__'):
        """Drops all tables.
        """
        self._execute_for_all_tables(bind, 'drop_all')

    def make_declarative_base(self):
        """Return a declarative base whose classes use :class:`BindMeta`."""
        return declarative_base(metaclass=BindMeta)