from sqlalchemy.orm import joinedload
from sqlalchemy.sql import func

from FlaskRTBCTF.utils.models import db, TimeMixin, ReprMixin
from FlaskRTBCTF.utils.cache import cache


# Machine Table
class Machine(TimeMixin, ReprMixin, db.Model):
    """ORM model for the ``machine`` table.

    Stores a machine's unique name, its user/root hash strings with the
    point values attached to each, and its OS, IP and difficulty labels.
    """

    __tablename__ = "machine"
    __repr_fields__ = ("name", "os")

    id = db.Column(db.Integer, primary_key=True, index=True)
    name = db.Column(db.String(64), nullable=False, unique=True)
    # Required 32-character-max hash strings, one per ownership level.
    user_hash = db.Column(db.String(32), nullable=False)
    root_hash = db.Column(db.String(32), nullable=False)
    # Point values paired with the user/root hashes; both default to 0.
    user_points = db.Column(db.Integer, default=0)
    root_points = db.Column(db.Integer, default=0)
    os = db.Column(db.String, nullable=False, default="linux")
    ip = db.Column(db.String(64), nullable=False)
    difficulty = db.Column(db.String, nullable=False, default="Easy")

    @staticmethod
    @cache.memoize(timeout=3600)
    def avg_rating(id):
        """Return the average non-zero rating recorded for machine ``id``.

        Rows whose rating is 0 (the column default) are left out of the
        average. The result is rounded to one decimal place; ``0`` is
        returned when the query yields no average (or a falsy one).
        The value is memoized per ``id`` through ``cache.memoize``.
        """
        mean_query = UserMachine.query.with_entities(
            func.avg(UserMachine.rating)
        ).filter(UserMachine.machine_id == id, UserMachine.rating != 0)
        mean = mean_query.scalar()
        if not mean:
            return 0
        return round(mean, 1)

    @staticmethod
    @cache.cached(timeout=3600 * 3, key_prefix="machines")
    def get_all():
        """Return every ``Machine`` row, cached under the "machines" key prefix."""
        machines = Machine.query.all()
        return machines


# UserMachine: N to N relationship
class UserMachine(TimeMixin, db.Model):
    """ORM model for ``user_machine``: one row per (user, machine) pair.

    Each row records whether the user owns the machine's user and/or root
    level and the rating the user gave it.
    """

    __tablename__ = "user_machine"

    # Composite primary key: (user_id, machine_id).
    user_id = db.Column(
        db.Integer,
        db.ForeignKey("user.id"),
        nullable=False,
        primary_key=True,
        index=True,
    )
    machine_id = db.Column(
        db.Integer,
        db.ForeignKey("machine.id"),
        nullable=False,
        primary_key=True,
        index=True,
    )
    owned_user = db.Column(db.Boolean, nullable=False, default=False)
    owned_root = db.Column(db.Boolean, nullable=False, default=False)
    # Defaults to 0; the rating queries treat 0 as "not rated".
    rating = db.Column(db.Integer, nullable=False, default=0)

    @classmethod
    @cache.memoize(timeout=3600 * 3)
    def completed_machines(cls, user_id):
        """Return the machine ids owned by ``user_id``, split by level.

        The result is a dict with two keys: ``"user"`` lists the ids of
        machines where ``owned_user`` is true, ``"root"`` those where
        ``owned_root`` is true. Memoized per ``user_id`` through
        ``cache.memoize``.
        """
        machine_ids = cls.query.with_entities(cls.machine_id)
        user_rows = machine_ids.filter_by(user_id=user_id, owned_user=True).all()
        root_rows = machine_ids.filter_by(user_id=user_id, owned_root=True).all()
        return {
            "user": [int(row[0]) for row in user_rows],
            "root": [int(row[0]) for row in root_rows],
        }

    @classmethod
    @cache.memoize(timeout=3600 * 3)
    def rated_machines(cls, user_id):
        """Return the ids of machines that ``user_id`` rated (rating != 0).

        Memoized per ``user_id`` through ``cache.memoize``.
        """
        rated = cls.query.with_entities(cls.machine_id).filter(
            cls.user_id == user_id, cls.rating != 0
        )
        return [int(row[0]) for row in rated.all()]


# Tag Model
class Tag(ReprMixin, db.Model):
    """ORM model for the ``tag`` table.

    Tags are attached to challenges through the ``tags`` association table
    (see ``Challenge.tags``, whose backref adds a ``challenges`` attribute
    here with ``lazy="noload"``).
    """

    __tablename__ = "tag"
    __repr_fields__ = ("label",)
    id = db.Column(db.Integer, primary_key=True)
    # Required tag text, at most 32 characters.
    label = db.Column(db.String(32), nullable=False)
    # Required string of at most 16 characters.
    # NOTE(review): presumably a color name/code used when rendering the tag — confirm.
    color = db.Column(db.String(16), nullable=False)


# Tags table: association table between Tag and Challenge
# Each row pairs one tag with one challenge. Both foreign-key columns are
# part of the primary key, so a given (tag, challenge) pair is stored once.
# Used as the ``secondary`` table of the ``Challenge.tags`` relationship.
tags = db.Table(
    "tags",
    db.Column("tag_id", db.Integer, db.ForeignKey("tag.id"), primary_key=True),
    db.Column(
        "challenge_id", db.Integer, db.ForeignKey("challenge.id"), primary_key=True
    ),
)


# Challenge Model
class Challenge(TimeMixin, ReprMixin, db.Model):
    """ORM model for the ``challenge`` table.

    A challenge has a unique title, a required flag and point value, and
    optional description, URL and difficulty. It belongs to exactly one
    category and can carry any number of tags.
    """

    __tablename__ = "challenge"
    __repr_fields__ = ("title", "category")

    id = db.Column(db.Integer, primary_key=True, index=True)
    title = db.Column(db.String(64), nullable=False, unique=True)
    description = db.Column(db.TEXT, nullable=True)
    flag = db.Column(db.TEXT, nullable=False)
    points = db.Column(db.Integer, nullable=False, default=0)
    url = db.Column(db.TEXT, nullable=True)
    difficulty = db.Column(db.String, nullable=True)

    # Many-to-one link to Category; the backref exposes Category.challenges.
    category_id = db.Column(db.Integer, db.ForeignKey("category.id"), nullable=False)
    category = db.relationship("Category", backref=db.backref("challenges", lazy=True))

    # Many-to-many link to Tag through the module-level ``tags`` table.
    # Tags are loaded with a subquery; the reverse Tag.challenges is "noload".
    tags = db.relationship(
        "Tag",
        secondary=tags,
        lazy="subquery",
        backref=db.backref("challenges", lazy="noload"),
    )

    @staticmethod
    @cache.memoize(timeout=3600)
    def avg_rating(id):
        """Return the average non-zero rating recorded for challenge ``id``.

        Rows whose rating is 0 (the column default) are left out of the
        average. The result is rounded to one decimal place; ``0`` is
        returned when the query yields no average (or a falsy one).
        The value is memoized per ``id`` through ``cache.memoize``.
        """
        mean_query = UserChallenge.query.with_entities(
            func.avg(UserChallenge.rating)
        ).filter(UserChallenge.challenge_id == id, UserChallenge.rating != 0)
        mean = mean_query.scalar()
        if not mean:
            return 0
        return round(mean, 1)


# UserChallenge: N to N relationship
class UserChallenge(TimeMixin, db.Model):
    """ORM model for ``user_challenge``: one row per (user, challenge) pair.

    Each row records whether the user completed the challenge and the
    rating the user gave it.
    """

    __tablename__ = "user_challenge"

    # Composite primary key: (user_id, challenge_id).
    user_id = db.Column(
        db.Integer,
        db.ForeignKey("user.id"),
        nullable=False,
        primary_key=True,
        index=True,
    )
    challenge_id = db.Column(
        db.Integer,
        db.ForeignKey("challenge.id"),
        nullable=False,
        primary_key=True,
        index=True,
    )
    completed = db.Column(db.Boolean, nullable=False, default=False)
    # Defaults to 0; the rating queries treat 0 as "not rated".
    rating = db.Column(db.Integer, nullable=False, default=0)

    @classmethod
    @cache.memoize(timeout=3600 * 3)
    def completed_challenges(cls, user_id):
        """Return the ids of challenges that ``user_id`` has completed.

        Memoized per ``user_id`` through ``cache.memoize``.
        """
        finished = cls.query.with_entities(cls.challenge_id).filter_by(
            user_id=user_id, completed=True
        )
        return [int(row[0]) for row in finished.all()]

    @classmethod
    @cache.memoize(timeout=3600 * 3)
    def rated_challenges(cls, user_id):
        """Return the ids of challenges that ``user_id`` rated (rating != 0).

        Memoized per ``user_id`` through ``cache.memoize``.
        """
        rated = cls.query.with_entities(cls.challenge_id).filter(
            cls.user_id == user_id, cls.rating != 0
        )
        return [int(row[0]) for row in rated.all()]


# Category Model
class Category(ReprMixin, db.Model):
    """ORM model for the ``category`` table.

    The ``challenges`` attribute used below is not declared here: it is the
    backref created by the ``Challenge.category`` relationship.
    """

    __tablename__ = "category"
    __repr_fields__ = ("name",)
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(32), nullable=False)

    @staticmethod
    @cache.cached(timeout=3600 * 3, key_prefix="challenges")
    def get_challenges():
        """Return every category that has at least one challenge.

        Each returned ``Category`` has its ``challenges`` collection eagerly
        loaded with a JOIN, so iterating it triggers no further queries.
        The list is cached under the "challenges" key prefix.
        """
        # Pass the class-bound attribute rather than the string "challenges":
        # string arguments to loader options are deprecated in SQLAlchemy 1.4
        # and removed in 2.0.
        eager_challenges = joinedload(Category.challenges)
        # EXISTS sub-select keeps only categories with a challenge without
        # multiplying result rows the way an implicit join would.
        has_challenges = Category.challenges.any()
        return Category.query.options(eager_challenges).filter(has_challenges).all()