from peewee import Database, ImproperlyConfigured, OP, QueryCompiler, CompoundSelect, SQL, Clause, CommaClause
from playhouse.db_url import register_database

__version__ = "0.1.4"

try:
    import pymssql
except ImportError:
    pymssql = None

try:
    from playhouse.pool import PooledDatabase
except ImportError:
    PooledDatabase = None

class MssqlQueryCompiler(QueryCompiler):
    # TODO: implement limit and offset properly, we can use:
    # SELECT *
    # FROM (SELECT ROW_NUMBER() OVER(ORDER BY id) RowNr, id FROM tbl) t
    # WHERE RowNr BETWEEN 10 AND 20

    def generate_select(self, query, alias_map=None):
        model = query.model_class
        db = model._meta.database

        alias_map = self.calculate_alias_map(query, alias_map)

        if isinstance(query, CompoundSelect):
            clauses = [_StripParens(query)]
        else:
            if not query._distinct:
                clauses = [SQL('SELECT')]
            else:
                clauses = [SQL('SELECT DISTINCT')]
                if query._distinct not in (True, False):
                    clauses += [SQL('ON'), EnclosedClause(*query._distinct)]

            # basic support for query limit
            if query._limit is not None or (query._offset and db.limit_max):
                limit = query._limit if query._limit is not None else db.limit_max
                clauses.append(SQL('TOP %s' % limit))

            select_clause = Clause(*query._select)
            select_clause.glue = ', '

            clauses.extend((select_clause, SQL('FROM')))
            if query._from is None:
                clauses.append(model.as_entity().alias(alias_map[model]))
            else:
                clauses.append(CommaClause(*query._from))

        if query._windows is not None:
            clauses.append(SQL('WINDOW'))
            clauses.append(CommaClause(*[
                Clause(
                    SQL(window._alias),
                    SQL('AS'),
                    window.__sql__())
                for window in query._windows]))

        join_clauses = self.generate_joins(query._joins, model, alias_map)
        if join_clauses:
            clauses.extend(join_clauses)

        if query._where is not None:
            clauses.extend([SQL('WHERE'), query._where])

        if query._group_by:
            clauses.extend([SQL('GROUP BY'), CommaClause(*query._group_by)])

        if query._having:
            clauses.extend([SQL('HAVING'), query._having])

        if query._order_by:
            clauses.extend([SQL('ORDER BY'), CommaClause(*query._order_by)])

        # NO OFFSET SUPPORT

        if query._for_update:
            for_update, no_wait = query._for_update
            if for_update:
                stmt = 'FOR UPDATE NOWAIT' if no_wait else 'FOR UPDATE'
                clauses.append(SQL(stmt))

        return self.build_query(clauses, alias_map)

class MssqlDatabase(Database):
    compiler_class = MssqlQueryCompiler
    commit_select = False
    interpolation = '%s'
    quote_char = '"'

    field_overrides = {
        'bool': 'tinyint',
        'double': 'float(53)',
        'float': 'float',
        'int': 'int',
        'string': 'nvarchar',
        'fixed_char': 'nchar',
        'text': 'nvarchar(max)',
        'blob': 'varbinary',
        'uuid': 'nchar(40)',
        'primary_key': 'int identity',
        'datetime': 'datetime2',
        'date': 'date',
        'time': 'time',
    }

    op_overrides = {
        OP.LIKE: 'LIKE BINARY',
        OP.ILIKE: 'LIKE',
    }

    def _connect(self, database, **kwargs):
        if not pymssql:
            raise ImproperlyConfigured('pymssql must be installed')

        if kwargs.pop('use_legacy_datetime', False):
            self.field_overrides['datetime'] = 'datetime'
            self.field_overrides['date'] = 'nvarchar(15)'
            self.field_overrides['time'] = 'nvarchar(10)'

        return pymssql.connect(database=database, **kwargs)

    def get_tables(self, schema=None):
        # should I not be using sys.tables?
        if schema:
            query = ('SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE '
                     'TABLE_SCHEMA = %s AND TABLE_TYPE = %s ORDER BY TABLE_NAME')
            cursor = self.execute_sql(query, (schema, 'BASE TABLE',),
                                      require_commit=False)
        else:
            query = ('SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE '
                     'TABLE_TYPE = %s ORDER BY TABLE_NAME')
            cursor = self.execute_sql(query, ('BASE TABLE',),
                                      require_commit=False)

        return [row[0] for row in cursor.fetchall()]

    def execute_sql(self, sql, params, *args, **kwargs):
        # convert params to tuple
        params = tuple(params)

        return super(MssqlDatabase, self).execute_sql(sql, params, *args, **kwargs)

register_database(MssqlDatabase, 'mssql')

if PooledDatabase:
    class PooledMssqlDatabase(PooledDatabase, MssqlDatabase):
        pass # TODO: implement _is_closed()

    register_database(PooledMssqlDatabase, 'mssql+pool')