from typing import Optional, Tuple, Any, Union

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy import Column, String, Integer, BigInteger, LargeBinary, orm, func, select, and_
import sqlalchemy as sql

from .orm import AlchemySession
from .core import AlchemyCoreSession
from .core_mysql import AlchemyMySQLCoreSession
from .core_sqlite import AlchemySQLiteCoreSession
from .core_postgres import AlchemyPostgresCoreSession

# Schema version written to the version table when the tables are first created,
# and the value check_and_upgrade_database() compares the stored version against.
LATEST_VERSION = 2


class AlchemySessionContainer:
    """Owns the engine, the optional ORM session and the table classes shared by sessions.

    The container builds the declarative table classes (optionally prefixed), can create
    or upgrade the tables on construction, and hands out per-session storage objects
    through :meth:`new_session`. Which storage class is used depends on ``core_mode``.
    """

    def __init__(self, engine: Optional[Union[sql.engine.Engine, str]] = None,
                 session: Optional[Union[orm.Session, scoped_session, bool]] = None,
                 table_prefix: str = "", table_base: Optional[declarative_base] = None,
                 manage_tables: bool = True) -> None:
        """Set up the engine, the ORM session and the table classes.

        Args:
            engine: An engine instance, or a database URL string that is passed to
                ``sqlalchemy.create_engine``.
            session: ``None`` creates a scoped session bound to ``engine``; any other
                falsy value (e.g. ``False``) disables the ORM session, which implies
                core mode; a truthy value is used as the ORM session as-is.
            table_prefix: String prepended to every table name.
            table_base: Declarative base to attach the table classes to. A new base is
                created when omitted.
            manage_tables: When true, create the tables if the version table does not
                exist yet, otherwise run :meth:`check_and_upgrade_database`.

        Raises:
            ValueError: If ``manage_tables`` is true but there is no ORM session.
        """
        if isinstance(engine, str):
            engine = sql.create_engine(engine)

        self.db_engine = engine
        if session is None:
            db_factory = orm.sessionmaker(bind=self.db_engine)
            self.db = orm.scoping.scoped_session(db_factory)
        elif not session:
            self.db = None
        else:
            self.db = session

        table_base = table_base or declarative_base()
        (self.Version, self.Session, self.Entity,
         self.SentFile, self.UpdateState) = self.create_table_classes(self.db, table_prefix,
                                                                      table_base)
        self.alchemy_session_class = AlchemySession
        if not self.db:
            # Implicit core mode if there's no ORM session.
            self.core_mode = True

        if manage_tables:
            if not self.db:
                raise ValueError("Can't manage tables without an ORM session.")
            table_base.metadata.bind = self.db_engine
            # The version table doubles as the "schema already exists" marker.
            if not self.db_engine.dialect.has_table(self.db_engine,
                                                    self.Version.__tablename__):
                table_base.metadata.create_all()
                self.db.add(self.Version(version=LATEST_VERSION))
                self.db.commit()
            else:
                self.check_and_upgrade_database()

    @property
    def core_mode(self) -> bool:
        """Whether sessions are created with one of the core (non-ORM) session classes."""
        return self.alchemy_session_class != AlchemySession

    @core_mode.setter
    def core_mode(self, val: bool) -> None:
        """Switch between ORM mode and core mode.

        Enabling core mode picks the session class by the engine's dialect name, falling
        back to the generic core session class for other dialects.

        Raises:
            ValueError: If ORM mode is requested but there is no ORM session.
        """
        if val:
            dialect_name = self.db_engine.dialect.name
            if dialect_name == "mysql":
                self.alchemy_session_class = AlchemyMySQLCoreSession
            elif dialect_name == "postgresql":
                self.alchemy_session_class = AlchemyPostgresCoreSession
            elif dialect_name == "sqlite":
                self.alchemy_session_class = AlchemySQLiteCoreSession
            else:
                self.alchemy_session_class = AlchemyCoreSession
        else:
            if not self.db:
                raise ValueError("Can't use ORM mode without an ORM session.")
            self.alchemy_session_class = AlchemySession

    @staticmethod
    def create_table_classes(db: Optional[scoped_session], prefix: str, base: declarative_base
                             ) -> Tuple[Any, Any, Any, Any, Any]:
        """Declare the table classes on ``base`` with ``prefix`` applied to table names.

        Args:
            db: Scoped session used to provide each class's ``query`` property. When it
                is falsy, ``query`` is set to ``None`` on every class.
            prefix: String prepended to every table name.
            base: Declarative base the classes inherit from.

        Returns:
            The ``(Version, Session, Entity, SentFile, UpdateState)`` classes.
        """
        qp = db.query_property() if db else None

        # Single-column table holding the schema version.
        class Version(base):
            query = qp
            __tablename__ = "{prefix}version".format(prefix=prefix)
            version = Column(Integer, primary_key=True)

            def __str__(self):
                return "Version('{}')".format(self.version)

        # Connection details and auth key, keyed by (session_id, dc_id).
        class Session(base):
            query = qp
            __tablename__ = '{prefix}sessions'.format(prefix=prefix)

            session_id = Column(String(255), primary_key=True)
            dc_id = Column(Integer, primary_key=True)
            server_address = Column(String(255))
            port = Column(Integer)
            auth_key = Column(LargeBinary)

            def __str__(self):
                return "Session('{}', {}, '{}', {}, {})".format(self.session_id, self.dc_id,
                                                                self.server_address, self.port,
                                                                self.auth_key)

        # Entity rows, keyed by (session_id, id).
        class Entity(base):
            query = qp
            __tablename__ = '{prefix}entities'.format(prefix=prefix)

            session_id = Column(String(255), primary_key=True)
            id = Column(BigInteger, primary_key=True)
            hash = Column(BigInteger, nullable=False)
            username = Column(String(32))
            phone = Column(BigInteger)
            name = Column(String(255))

            def __str__(self):
                return "Entity('{}', {}, {}, '{}', '{}', '{}')".format(self.session_id, self.id,
                                                                       self.hash, self.username,
                                                                       self.phone, self.name)

        # Sent-file rows, keyed by (session_id, md5_digest, file_size, type).
        class SentFile(base):
            query = qp
            __tablename__ = '{prefix}sent_files'.format(prefix=prefix)

            session_id = Column(String(255), primary_key=True)
            md5_digest = Column(LargeBinary, primary_key=True)
            file_size = Column(Integer, primary_key=True)
            type = Column(Integer, primary_key=True)
            id = Column(BigInteger)
            hash = Column(BigInteger)

            def __str__(self):
                return "SentFile('{}', {}, {}, {}, {}, {})".format(self.session_id,
                                                                   self.md5_digest, self.file_size,
                                                                   self.type, self.id, self.hash)

        # Update-state rows, keyed by (session_id, entity_id).
        class UpdateState(base):
            query = qp
            __tablename__ = "{prefix}update_state".format(prefix=prefix)

            session_id = Column(String(255), primary_key=True)
            entity_id = Column(BigInteger, primary_key=True)
            pts = Column(BigInteger)
            qts = Column(BigInteger)
            date = Column(BigInteger)
            seq = Column(BigInteger)
            unread_count = Column(Integer)

        return Version, Session, Entity, SentFile, UpdateState

    def _add_column(self, table: Any, column: Column) -> None:
        """Add ``column`` to the existing table behind ``table`` using ``ALTER TABLE``.

        The column name and type are rendered with the engine's dialect. Only the name
        and type are emitted; constraints and defaults on ``column`` are not.
        """
        column_name = column.compile(dialect=self.db_engine.dialect)
        column_type = column.type.compile(self.db_engine.dialect)
        self.db_engine.execute("ALTER TABLE {} ADD COLUMN {} {}".format(
            table.__tablename__, column_name, column_type))

    def check_and_upgrade_database(self) -> None:
        """Bring an existing database up to ``LATEST_VERSION``.

        A database without a version row is treated as version 1. Databases that are
        already at ``LATEST_VERSION`` or newer are left untouched.
        """
        row = self.Version.query.all()
        version = row[0].version if row else 1
        if version >= LATEST_VERSION:
            # Up to date, or stamped by a newer schema: never downgrade the stamp.
            return

        # Apply the schema change before touching the version row, so a failed
        # migration does not leave a pending delete of the stored version behind.
        if version == 1:
            # The table is created from the current model, so it already contains
            # every column a later migration would add.
            self.UpdateState.__table__.create(self.db_engine)
        elif version == 2:
            self._add_column(self.UpdateState, Column("unread_count", Integer))

        # Either branch leaves the schema matching the current models, so stamp the
        # latest version rather than a hard-coded number.
        self.Version.query.delete()
        self.db.add(self.Version(version=LATEST_VERSION))
        self.db.commit()

    def new_session(self, session_id: str) -> 'AlchemySession':
        """Create a session object for ``session_id`` using the current session class."""
        return self.alchemy_session_class(self, session_id)

    def has_session(self, session_id: str) -> bool:
        """Return whether a stored session exists for ``session_id``.

        NOTE(review): core mode only counts rows whose auth key is not empty, while ORM
        mode counts every row with a matching session ID. Confirm which is intended.
        """
        if self.core_mode:
            t = self.Session.__table__
            rows = self.db_engine.execute(select([func.count(t.c.auth_key)])
                                          .where(and_(t.c.session_id == session_id,
                                                      t.c.auth_key != b'')))
            try:
                count, = next(rows)
                return count > 0
            except StopIteration:
                return False
        else:
            return self.Session.query.filter(self.Session.session_id == session_id).count() > 0

    def list_sessions(self):
        """Placeholder: not implemented, currently always returns ``None``."""
        return

    def save(self) -> None:
        """Commit the ORM session, if there is one."""
        if self.db:
            self.db.commit()