import threading
from peewee import Database, ExceptionWrapper, basestring
from peewee import sort_models_topologically, merge_dict
from peewee import OperationalError
from peewee import (RESULTS_NAIVE, RESULTS_TUPLES, RESULTS_DICTS,
                    RESULTS_AGGREGATE_MODELS, RESULTS_MODELS)
from peewee import SQL, R, Clause, fn, binary_construct
from peewee import logger

from .context import _aio_atomic, aio_transaction, aio_savepoint
from .result import (AioNaiveQueryResultWrapper, AioModelQueryResultWrapper,
                     AioTuplesQueryResultWrapper, AioDictQueryResultWrapper,
                     AioAggregateQueryResultWrapper)


# TODO: remove this class and rely on the autocommit argument of
# db.execute_sql instead; for transactions the connection should be bound
# to the atomic/transaction context manager.
class AioConnection(object):
    """Async wrapper around a single connection acquired from the pool.

    Mirrors the connection-handling parts of peewee's Database
    (execute_sql, commit/rollback, transactions, savepoints) in awaitable
    form.
    """

    def __init__(self, acquirer, exception_wrapper,
                 autocommit=None, autorollback=None):
        self.autocommit = autocommit
        self.autorollback = autorollback
        self.acquirer = acquirer
        self.closed = True
        self.conn = None
        self.context_stack = []
        self.transactions = []
        self.exception_wrapper = exception_wrapper  # TODO: remove

    def transaction_depth(self):
        return len(self.transactions)

    def push_transaction(self, transaction):
        self.transactions.append(transaction)

    def pop_transaction(self):
        return self.transactions.pop()

    async def execute_sql(self, sql, params=None, require_commit=True):
        logger.debug((sql, params))
        with self.exception_wrapper:
            cursor = await self.conn.cursor()
            try:
                await cursor.execute(sql, params or ())
            except Exception:
                if self.autorollback and self.autocommit:
                    await self.rollback()
                raise
            else:
                if require_commit and self.autocommit:
                    await self.commit()
            return cursor

    async def __aenter__(self):
        self.conn = await self.acquirer.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.acquirer.__aexit__(exc_type, exc_val, exc_tb)

    async def begin(self):
        # Explicitly starting a transaction is a no-op at this level.
        pass

    async def commit(self):
        # Await inside the wrapper so driver errors raised during the actual
        # commit are translated, not just errors creating the coroutine.
        with self.exception_wrapper:
            return await self.conn.commit()

    async def rollback(self):
        with self.exception_wrapper:
            return await self.conn.rollback()

    # def close(self):
    #     # self.conn_pool.release(conn)
    #     return self.conn.close()

    def transaction(self, transaction_type=None):
        return aio_transaction(self, transaction_type)
    commit_on_success = property(transaction)

    def savepoint(self, sid=None):
        # ``savepoints`` mirrors peewee's Database flag; AioConnection never
        # sets it itself, so treat a missing attribute as "supported".
        if not getattr(self, 'savepoints', True):
            raise NotImplementedError
        return aio_savepoint(self, sid)
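
# Rough usage sketch for AioConnection (assumes an aiopg/aiomysql style pool
# whose ``acquire()`` returns an async context manager yielding a connection
# with awaitable ``cursor()``/``commit()``/``rollback()``; ``pool`` and ``db``
# below are placeholders):
#
#     conn = AioConnection(pool.acquire(),
#                          exception_wrapper=db.exception_wrapper,
#                          autocommit=True, autorollback=True)
#     async with conn:
#         cursor = await conn.execute_sql('SELECT 1')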


class AioDatabase(Database):
    """Async counterpart of peewee's Database: connections are acquired
    from an async pool and statement execution is awaitable."""

    # The synchronous parts of peewee's Database API (explicit transactions,
    # cursor access, schema introspection) are not supported.
    def begin(self):
        raise NotImplementedError

    def commit(self):
        raise NotImplementedError

    def rollback(self):
        raise NotImplementedError

    def get_cursor(self):
        raise NotImplementedError

    def get_tables(self, schema=None):
        raise NotImplementedError

    def get_indexes(self, table, schema=None):
        raise NotImplementedError

    def get_columns(self, table, schema=None):
        raise NotImplementedError

    def get_primary_keys(self, table, schema=None):
        raise NotImplementedError

    def get_foreign_keys(self, table, schema=None):
        raise NotImplementedError

    def sequence_exists(self, seq):
        raise NotImplementedError

    def transaction_depth(self):
        raise NotImplementedError

    def __init__(self, database, threadlocals=True, autocommit=True,
                 fields=None, ops=None, autorollback=False,
                 **connect_kwargs):
        self.connect_kwargs = {}
        self.closed = True
        self.init(database, **connect_kwargs)

        self.pool = None  # created by connect()

        self.autocommit = autocommit
        self.autorollback = autorollback
        self.use_speedups = False

        self.field_overrides = merge_dict(self.field_overrides, fields or {})
        self.op_overrides = merge_dict(self.op_overrides, ops or {})
        self.exception_wrapper = ExceptionWrapper(self.exceptions)

    def is_closed(self):
        return self.closed

    def get_conn(self):
        if self.closed:
            raise OperationalError('Database pool has not been initialized')

        return AioConnection(self.pool.acquire(),
                             autocommit=self.autocommit,
                             autorollback=self.autorollback,
                             exception_wrapper=self.exception_wrapper)

    async def close(self):
        if self.deferred:
            raise Exception('Error, database not properly initialized '
                            'before closing connection')
        with self.exception_wrapper:
            if not self.closed and self.pool:
                self.pool.close()
                self.closed = True
                await self.pool.wait_closed()

    async def connect(self, safe=True):
        if self.deferred:
            raise OperationalError('Database has not been initialized')
        if not self.closed:
            if safe:
                return
            raise OperationalError('Connection already open')

        with self.exception_wrapper:
            self.pool = await self._connect(self.database,
                                            **self.connect_kwargs)
            self.closed = False

    def get_result_wrapper(self, wrapper_type):
        if wrapper_type == RESULTS_NAIVE:
            return AioNaiveQueryResultWrapper
        elif wrapper_type == RESULTS_MODELS:
            return AioModelQueryResultWrapper
        elif wrapper_type == RESULTS_TUPLES:
            return AioTuplesQueryResultWrapper
        elif wrapper_type == RESULTS_DICTS:
            return AioDictQueryResultWrapper
        elif wrapper_type == RESULTS_AGGREGATE_MODELS:
            return AioAggregateQueryResultWrapper
        else:
            return AioNaiveQueryResultWrapper

    def atomic(self, transaction_type=None):
        return _aio_atomic(self.get_conn(), transaction_type)

    def transaction(self, transaction_type=None):
        return aio_transaction(self, transaction_type)
    commit_on_success = property(transaction)
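
    # Rough transaction sketch: ``atomic()`` / ``transaction()`` delegate to
    # the ``_aio_atomic`` / ``aio_transaction`` helpers from ``.context``,
    # which are assumed to be async context managers:
    #
    #     async with db.atomic():
    #         await db.execute_sql('SELECT 1')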

    # def savepoint(self, sid=None):
    #     if not self.savepoints:
    #         raise NotImplementedError
    #     return aio_savepoint(self, sid)

    async def create_table(self, model_class, safe=False):
        qc = self.compiler()
        async with self.get_conn() as conn:
            args = qc.create_table(model_class, safe)
            return await conn.execute_sql(*args)

    async def create_tables(self, models, safe=False):
        await create_model_tables(models, fail_silently=safe)

    async def create_index(self, model_class, fields, unique=False):
        qc = self.compiler()
        if not isinstance(fields, (list, tuple)):
            raise ValueError('Fields passed to "create_index" must be a list '
                             'or tuple: "%s"' % fields)
        fobjs = [model_class._meta.fields[f]
                 if isinstance(f, basestring) else f
                 for f in fields]
        async with self.get_conn() as conn:
            args = qc.create_index(model_class, fobjs, unique)
            return await conn.execute_sql(*args)

    async def drop_index(self, model_class, fields, safe=False):
        qc = self.compiler()
        if not isinstance(fields, (list, tuple)):
            raise ValueError('Fields passed to "drop_index" must be a list '
                             'or tuple: "%s"' % fields)
        fobjs = [model_class._meta.fields[f]
                 if isinstance(f, basestring) else f
                 for f in fields]
        async with self.get_conn() as conn:
            args = qc.drop_index(model_class, fobjs, safe)
            return await conn.execute_sql(*args)

    async def create_foreign_key(self, model_class, field, constraint=None):
        qc = self.compiler()
        async with self.get_conn() as conn:
            args = qc.create_foreign_key(model_class, field, constraint)
            return await conn.execute_sql(*args)

    async def create_sequence(self, seq):
        if self.sequences:
            qc = self.compiler()
            async with self.get_conn() as conn:
                return await conn.execute_sql(*qc.create_sequence(seq))

    async def drop_table(self, model_class, fail_silently=False, cascade=False):
        qc = self.compiler()
        if cascade and not self.drop_cascade:
            raise ValueError('Database does not support DROP TABLE..CASCADE.')
        async with self.get_conn() as conn:
            args = qc.drop_table(model_class, fail_silently, cascade)
            return await conn.execute_sql(*args)

    async def drop_tables(self, models, safe=False, cascade=False):
        await drop_model_tables(models, fail_silently=safe, cascade=cascade)

    async def truncate_table(self, model_class, restart_identity=False,
                             cascade=False):
        qc = self.compiler()
        async with self.get_conn() as conn:
            args = qc.truncate_table(model_class, restart_identity, cascade)
            return await conn.execute_sql(*args)

    async def truncate_tables(self, models, restart_identity=False,
                              cascade=False):
        for model in reversed(sort_models_topologically(models)):
            await model.truncate_table(restart_identity, cascade)

    async def drop_sequence(self, seq):
        if self.sequences:
            qc = self.compiler()
            async with self.get_conn() as conn:
                return await conn.execute_sql(*qc.drop_sequence(seq))

    async def execute_sql(self, sql, params=None, require_commit=True):
        async with self.get_conn() as conn:
            return await conn.execute_sql(sql, params,
                                          require_commit=require_commit)

    def extract_date(self, date_part, date_field):
        return fn.EXTRACT(Clause(date_part, R('FROM'), date_field))

    def truncate_date(self, date_part, date_field):
        return fn.DATE_TRUNC(date_part, date_field)

    def default_insert_clause(self, model_class):
        return SQL('DEFAULT VALUES')

    def get_noop_sql(self):
        return 'SELECT 0 WHERE 0'

    def get_binary_type(self):
        return binary_construct


async def create_model_tables(models, **create_table_kwargs):
    """Create tables for all given models (in the right order)."""
    for m in sort_models_topologically(models):
        await m.create_table(**create_table_kwargs)


async def drop_model_tables(models, **drop_table_kwargs):
    """Drop tables for all given models (in the right order)."""
    for m in reversed(sort_models_topologically(models)):
        await m.drop_table(**drop_table_kwargs)
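

# Rough end-to-end sketch (hypothetical backend): a concrete subclass is
# expected to provide ``_connect()`` returning a pool object exposing
# ``acquire()``, ``close()`` and ``wait_closed()``, as used by get_conn()
# and close() above. ``AioMySQLDatabase`` and ``create_pool`` are
# illustrative names only:
#
#     class AioMySQLDatabase(AioDatabase):
#         async def _connect(self, database, **kwargs):
#             return await create_pool(db=database, **kwargs)
#
#     db = AioMySQLDatabase('test')
#     await db.connect()
#     cursor = await db.execute_sql('SELECT 1')
#     await db.close()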