try:
    from sqlalchemy import create_engine, inspect
    from sqlalchemy.orm import scoped_session, sessionmaker
    from sqlalchemy.pool import StaticPool
except ImportError:
    # sqlalchemy is optional. Give every imported name a value so the module
    # stays importable and code that merely references these names (e.g. the
    # in-memory branch of SQLDatabase.__init__ reads StaticPool before it
    # calls create_engine) reaches the friendly error below, not a NameError.
    inspect = scoped_session = sessionmaker = StaticPool = None

    def create_engine(*args, **kwargs):
        """Stand-in used when sqlalchemy is not installed; always raises."""
        raise ImportError("You need to install sqlalchemy")

import logging
import json

from ..bson import ObjectId
from ..json import JSONDecoder, JSONEncoder
from .base import Database
from .listdb import ListTable

class SQLList():
    """A list-like view of a SQL table holding JSON-encoded documents.

    Each row of the backing table has a ``rowid`` (the list position, starting
    at 0), an ``oid`` (``str(document["_id"])``, written on insert) and a
    ``blob_data`` column (the document encoded with ``JSONEncoder``).
    ``append`` writes at position ``len(self)`` and ``__delitem__`` shifts
    every later row down by one, keeping positions contiguous so the object
    can be indexed, sliced, iterated and measured like a list.

    Args:
        database: object exposing ``execute(sql[, params])``, ``commit()`` and
            ``rollback()`` (for example ``SQLDatabase``).
        name: name of the backing table. It is interpolated directly into the
            SQL text, so it must be a trusted identifier.
    """

    def __init__(self, database, name):
        self.database = database
        self.name = name

    def _normalize_index(self, index):
        """Map a negative list position onto its non-negative ``rowid``.

        Raises:
            IndexError: if ``index`` is still negative after adding ``len(self)``.
        """
        if index < 0:
            index += len(self)
            if index < 0:
                raise IndexError("list index out of range")
        return index

    def _fetch_row(self, rowid):
        """Return the raw ``(blob_data,)`` row at ``rowid``, or None if absent."""
        result = self.database.execute(
            """SELECT blob_data FROM %s WHERE rowid = :rowid"""
            % (self.name), {"rowid": rowid})
        return result.fetchone()

    def _write(self, *statements):
        """Execute each ``(sql[, params])`` tuple, then commit them together.

        Any exception (including KeyboardInterrupt) rolls the transaction back
        and is then re-raised.
        """
        try:
            for statement in statements:
                self.database.execute(*statement)
            self.database.commit()
        except BaseException:
            self.database.rollback()
            raise

    def __getitem__(self, item):
        """Return the document at position ``item``, or a list for a slice.

        Raises:
            IndexError: if no row exists at that position.
            TypeError: if ``item`` is neither an int nor a slice.
        """
        if isinstance(item, slice):
            # slice.indices clamps the bounds to the current length and
            # resolves negative start/stop/step exactly like list slicing.
            return [self[i] for i in range(*item.indices(len(self)))]
        if not isinstance(item, int):
            raise TypeError(
                "list indices must be integers or slices, not %s"
                % type(item).__name__)
        row = self._fetch_row(self._normalize_index(item))
        # Check the row, not the decoded document: a stored falsy document
        # such as {} is still a valid element.
        if row is None:
            raise IndexError("list index out of range")
        return json.loads(row[0], cls=JSONDecoder)

    def __setitem__(self, item, value):
        """Store ``value`` at position ``item``.

        An existing row is overwritten in place (its ``oid`` is left as is).
        Otherwise a new row is inserted at that position, which requires
        ``value["_id"]``. ``append`` relies on the insert path.
        """
        index = self._normalize_index(item)
        s = json.dumps(value, cls=JSONEncoder)
        if self._fetch_row(index) is not None:
            self._write((
                """UPDATE %s SET blob_data = :s WHERE rowid = :rowid;"""
                % (self.name), {"s": s, "rowid": index}))
        else:
            oid = str(value["_id"])
            self._write((
                """INSERT INTO %s (blob_data, oid, rowid) VALUES (:s, :oid, :rowid);"""
                % (self.name), {"s": s, "rowid": index, "oid": oid}))

    def __delitem__(self, key):
        """Delete the row at position ``key`` and shift later rows down by one.

        Negative keys are normalised first. Otherwise ``del lst[-1]`` would
        shift every row (``rowid > -1``) and corrupt all positions.
        """
        index = self._normalize_index(key)
        self._write(
            ("""DELETE FROM %s WHERE rowid = :rowid;"""
             % (self.name), {"rowid": index}),
            ("""UPDATE %s SET rowid = (rowid - 1) WHERE rowid > :rowid;"""
             % self.name, {"rowid": index}),
        )

    def clear(self):
        """Delete every row of the backing table."""
        self._write(("DELETE from %s;" % self.name,))

    def append(self, item):
        """Add ``item`` at the end of the list (position ``len(self)``)."""
        pos = len(self)
        self[pos] = item

    def __len__(self):
        """Return the number of rows in the backing table."""
        result = self.database.execute("SELECT count(1) FROM %s" % self.name)
        row = result.fetchone()
        return row[0]

class SQLTable(ListTable):
    """A ``ListTable`` whose rows are stored in a SQL table.

    The backing table is created on first use, and ``self.data`` is an
    ``SQLList`` over it, so the inherited ``ListTable`` operations read and
    write the database. The query-translation helpers (``build_compare``,
    ``build_query``) are work in progress and are not yet used by ``find``.
    """

    def __init__(self, database, name):
        super().__init__(database, name)
        if not self.database.table_exists(name):
            self.database.build_table(name)
        self.data = SQLList(database, name)

    def get_schema(self):
        """Return SQLAlchemy's column descriptions for this table."""
        ins = inspect(self.database.engine)
        return ins.get_columns(self.name)

    def get_columns(self):
        """Return the names of this table's columns."""
        return [d["name"] for d in self.get_schema()]

    @staticmethod
    def _sql_value(value):
        """Render ``value`` as SQL literal text.

        Lists/tuples become a parenthesised, comma-separated list (for
        ``IN``). Everything else uses ``repr``, matching the equality branch
        of ``build_compare``.
        NOTE: this is NOT proper escaping -- never feed it untrusted input.
        """
        if isinstance(value, (list, tuple)):
            return "(" + ", ".join(repr(v) for v in value) + ")"
        return repr(value)

    def build_compare(self, lhs, rhs):
        """Translate one ``field: condition`` pair into a SQL boolean expression.

        WIP: not yet used by ``find``.

        Args:
            lhs: the field (column) name.
            rhs: a plain value (equality test) or a dict of operators:
                ``$lt``, ``$gt`` and ``$in`` are supported.

        Raises:
            NotImplementedError: for ``$regex``, which has no SQL translation yet.
            Exception: for any other unknown operator.
        """
        if isinstance(rhs, dict):
            q = []
            for op, value in rhs.items():
                if op == "$regex":
                    # Emitting placeholder text here would produce invalid
                    # SQL; fail loudly until regex support exists.
                    raise NotImplementedError(
                        "$regex is not supported in SQL queries yet")
                elif op == "$lt":
                    q.append("(%s < %s)" % (lhs, self._sql_value(value)))
                elif op == "$gt":
                    q.append("(%s > %s)" % (lhs, self._sql_value(value)))
                elif op == "$in":
                    q.append("(%s IN %s)" % (lhs, self._sql_value(value)))
                else:
                    raise Exception("unknown operator: %s" % op)
            return "(" + (" AND ".join(q)) + ")"
        if isinstance(lhs, list):
            # FIXME? list-valued lhs is not produced by build_query (keys are
            # field names); kept for direct callers.
            if isinstance(rhs, list):
                return "(%s = %s)" % (lhs, repr(rhs))
            return "(%s IN %s)" % (rhs, lhs)
        return "(%s = %s)" % (lhs, repr(rhs))

    def build_query(self, query, limit=None):
        """Translate a query dict into a SQL boolean expression.

        WIP: not yet used by ``find``; ``limit`` is accepted but ignored.
        Top-level entries are ANDed together; ``$or``/``$and`` take a list of
        sub-queries. An empty query matches every row.
        """
        q = []
        for key, value in query.items():
            if key == "$or":
                expr = "(" + " OR ".join(self.build_query(each) for each in value) + ")"
            elif key == "$and":
                expr = "(" + " AND ".join(self.build_query(each) for each in value) + ")"
            else:
                expr = self.build_compare(key, value)
            q.append(expr)
        if not q:
            # "()" is not valid SQL; an empty query should match everything.
            return "(1 = 1)"
        return "(" + " AND ".join(q) + ")"

    def find(self, query=None, limit=None, enumerated=False):
        """Return matching rows by delegating to ``ListTable.find``.

        This goes through all of the items in ``self.data``.
        TODO: select the SQL-expressible portion of ``query`` in the database
        via ``build_query``; ``limit`` can only be pushed down when the whole
        query translates.
        """
        return super().find(query, limit, enumerated)

    def sort(self, sort_key, sort_order):
        """Return a new in-memory ``ListTable`` holding the rows sorted.

        Args:
            sort_key: key looked up in each row via ``get_item_in_dict``
                (e.g. ``"_id"``).
            sort_order: 1 for ascending, -1 for descending.
        """
        return ListTable(data=sorted(
            self.data,
            key=lambda row: self.get_item_in_dict(row, sort_key),
            reverse=(sort_order == -1)))

class SQLDatabase(Database):
    """Database backed by a SQLAlchemy engine; each table is an ``SQLTable``.

    Constructor arguments are forwarded to ``create_engine``, whose first
    positional argument is the database URL. In-memory SQLite URLs
    (``sqlite://``, ``sqlite://:memory:``, ``sqlite:///:memory:``) are
    normalised to ``sqlite://``. They get a single shared connection
    (``StaticPool``) and a plain session, so the in-memory database persists
    and is shared.
    """

    Table = SQLTable

    def __init__(self, *args, **kwargs):
        """Create the engine and the session.

        Args:
            *args: forwarded to ``create_engine``; ``args[0]`` is the URL.
            **kwargs: forwarded to ``create_engine``. For in-memory SQLite,
                ``poolclass`` is forced to ``StaticPool`` and
                ``connect_args["check_same_thread"]`` to False; any other
                ``connect_args`` entries are preserved.
        """
        super().__init__()
        args = list(args)
        if args and isinstance(args[0], str) and args[0].endswith(":memory:"):
            args[0] = args[0][:-len(":memory:")]
        # Both "sqlite://" and "sqlite:///" (what "sqlite:///:memory:" becomes
        # above) name SQLite's in-memory database.
        if args and args[0] in ("sqlite://", "sqlite:///"):  # in-memory
            args[0] = "sqlite://"
            connect_args = dict(kwargs.get("connect_args") or {})
            connect_args["check_same_thread"] = False
            kwargs["connect_args"] = connect_args
            kwargs["poolclass"] = StaticPool
            self.engine = create_engine(*args, **kwargs)
            self.session = sessionmaker(bind=self.engine)()
        else:
            self.engine = create_engine(*args, **kwargs)
            self.session = scoped_session(sessionmaker(bind=self.engine))

    def commit(self):
        """Commit the current session transaction."""
        self.session.commit()

    def rollback(self):
        """Roll back the current session transaction."""
        self.session.rollback()

    def execute(self, *args, **kwargs):
        """Execute a statement on the session and return its result.

        A plain SQL string as the first argument is wrapped in
        ``sqlalchemy.text`` so ``:name`` placeholders bind from the params
        dict. SQLAlchemy 2.x rejects bare strings in ``Session.execute``.
        """
        if args and isinstance(args[0], str):
            from sqlalchemy import text
            args = (text(args[0]),) + tuple(args[1:])
        return self.session.execute(*args, **kwargs)

    def table_exists(self, table):
        """Return True if a table named ``table`` exists in the database."""
        ins = inspect(self.engine)
        return table in ins.get_table_names()

    def build_table(self, name):
        """Create table ``name`` with rowid/oid/blob_data columns and commit.

        ``name`` is interpolated into the DDL unescaped, so it must be a
        trusted identifier. On any error the session is rolled back and the
        exception re-raised.
        """
        try:
            self.execute(
                """CREATE TABLE %s (
                    rowid INTEGER PRIMARY KEY ASC,
                    oid CHAR(24),
                    blob_data BLOB
                )""" % name)
            self.commit()
        except BaseException:
            self.rollback()
            raise