from collections import defaultdict
from typing import List, Optional, Dict, Any
from schema import JobSchema, MatchesSchema, AgentSpecSchema, ConfigSchema
from time import time
import json
import random
import string
from redis import StrictRedis
from enum import Enum
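
# Redis key layout used by this module (see the methods below):
#   job:<hash>                - hash with job metadata and progress counters
#   meta:<hash>               - list of JSON-encoded match entries for a job
#   agent:<id>:queue-search   - search tasks queued for an agent
#   agent:<id>:queue-yara     - yara tasks queued for an agent
#   agent:<id>:queue-command  - commands queued for an agent
#   job-ds:<agent>:<hash>     - datasets left to search for a job on an agent
#   agents                    - hash of registered agents and their specs
#   plugin:<name>             - hash with a single plugin's configuration
#   plugin-version            - counter bumped on every configuration change
#   config-reload:<version>   - notification queue for configuration reloads
#   cached:<key>              - cached values with a sliding expiration time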


class TaskType(Enum):
    SEARCH = "search"
    YARA = "yara"
    RELOAD = "reload"
    COMMAND = "command"


class AgentTask:
    def __init__(self, type: TaskType, data: str) -> None:
        self.type = type
        self.data = data


class JobId:
    """ Represents a unique job ID in redis. Looks like this: `job:IU32AD3` """

    def __init__(self, key: str) -> None:
        """ Creates a new JobId object. Can take both key and raw hash. """
        if not key.startswith("job:"):
            key = f"job:{key}"
        self.key = key
        self.hash = key[4:]

    @property
    def meta_key(self) -> str:
        """ Every job has exactly one related meta key"""
        return f"meta:{self.hash}"

    def __repr__(self) -> str:
        return self.key


class MatchInfo:
    """ Represents information about a single match """

    def __init__(
        self, file: str, meta: Dict[str, Any], matches: List[str]
    ) -> None:
        self.file = file
        self.meta = meta
        self.matches = matches

    def to_json(self) -> str:
        """ Converts match info to json """
        return json.dumps(
            {"file": self.file, "meta": self.meta, "matches": self.matches}
        )


class Database:
    def __init__(self, redis_host: str, redis_port: int) -> None:
        self.redis = StrictRedis(
            host=redis_host, port=redis_port, decode_responses=True
        )

    def get_yara_by_job(self, job: JobId) -> str:
        """ Gets yara rule associated with job """
        return self.redis.hget(job.key, "raw_yara")

    def get_job_status(self, job: JobId) -> str:
        """ Gets status of the specified job """
        return self.redis.hget(job.key, "status")

    def get_job_ids(self) -> List[JobId]:
        """ Gets IDs of all jobs in the database """
        return [JobId(key) for key in self.redis.keys("job:*")]

    def cancel_job(self, job: JobId) -> None:
        """ Sets the job status to cancelled """
        self.redis.hmset(
            job.key, {"status": "cancelled", "finished": int(time())}
        )

    def fail_job(self, job: JobId, message: str) -> None:
        """ Sets the job status to failed. """
        self.redis.hmset(
            job.key,
            {"status": "failed", "error": message, "finished": int(time())},
        )

    def get_job(self, job: JobId) -> JobSchema:
        data = self.redis.hgetall(job.key)
        return JobSchema(
            id=job.hash,
            status=data.get("status", "ERROR"),
            error=data.get("error", None),
            rule_name=data.get("rule_name", "ERROR"),
            rule_author=data.get("rule_author", None),
            raw_yara=data.get("raw_yara", "ERROR"),
            submitted=data.get("submitted", 0),
            finished=data.get("finished", None),
            priority=data.get("priority", "medium"),
            files_limit=data.get("files_limit", 0),
            files_processed=int(data.get("files_processed", 0)),
            files_matched=int(data.get("files_matched", 0)),
            files_in_progress=int(data.get("files_in_progress", 0)),
            total_files=int(data.get("total_files", 0)),
            files_errored=int(data.get("files_errored", 0)),
            reference=data.get("reference", ""),
            iterator=data.get("iterator", None),
            taints=json.loads(data.get("taints", "[]")),
            total_datasets=data.get("total_datasets", 0),
            datasets_left=data.get("datasets_left", 0),
        )

    def remove_query(self, job: JobId) -> None:
        """ Sets the job status to removed """
        self.redis.hmset(job.key, {"status": "removed"})

    def add_match(self, job: JobId, match: MatchInfo) -> None:
        self.redis.rpush(job.meta_key, match.to_json())

    def job_contains(self, job: JobId, ordinal: int, file_path: str) -> bool:
        """ Checks if the match at `ordinal` refers to `file_path` """
        file_list = self.redis.lrange(job.meta_key, ordinal, ordinal)
        if not file_list:
            return False
        return file_path == json.loads(file_list[0])["file"]

    def job_start_work(self, job: JobId, files_in_progress: int) -> None:
        """ Updates the number of files being processed right now.
        :param job: ID of the job being updated.
        :type job: JobId
        :param files_in_progress: Number of files in the current work unit.
        :type files_in_progress: int
        """
        self.redis.hincrby(job.key, "files_in_progress", files_in_progress)

    def job_update_work(
        self, job: JobId, files_processed: int, files_matched: int
    ) -> None:
        """ Update progress for the job. This will increment number of files processed
        and matched, and if as a result all files are processed, will change the job
        status to `done`
        """
        self.redis.hincrby(job.key, "files_processed", files_processed)
        self.redis.hincrby(job.key, "files_in_progress", -files_processed)
        self.redis.hincrby(job.key, "files_matched", files_matched)

    def job_update_error(self, job: JobId, files_errored: int) -> None:
        """ Update error for the job if it appears during agents' work.
        This will increment number of files errored and write them to the variable.
        """
        self.redis.hincrby(job.key, "files_errored", files_errored)

    def create_search_task(
        self,
        rule_name: str,
        rule_author: str,
        raw_yara: str,
        priority: Optional[str],
        files_limit: int,
        reference: str,
        taints: List[str],
        agents: List[str],
    ) -> JobId:
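        """ Creates a new job object in Redis and queues a search task
        for every agent. Returns the ID of the new job. """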
        job = JobId(
            "".join(
                random.SystemRandom().choice(
                    string.ascii_uppercase + string.digits
                )
                for _ in range(12)
            )
        )
        job_obj = {
            "status": "new",
            "rule_name": rule_name,
            "rule_author": rule_author,
            "raw_yara": raw_yara,
            "submitted": int(time()),
            "priority": priority or "medium",
            "files_limit": files_limit,
            "reference": reference,
            "files_in_progress": 0,
            "files_processed": 0,
            "files_matched": 0,
            "total_files": 0,
            "files_errored": 0,
            "agents_left": len(agents),
            "datasets_left": 0,
            "total_datasets": 0,
        }

        job_obj["taints"] = json.dumps(taints)

        self.redis.hmset(job.key, job_obj)
        for agent in agents:
            self.redis.rpush(f"agent:{agent}:queue-search", job.hash)
        return job

    def broadcast_command(self, command: str) -> None:
        for agent in self.get_active_agents().keys():
            self.redis.rpush(f"agent:{agent}:queue-command", command)

    def init_job_datasets(
        self, agent_id: str, job: JobId, datasets: List[str]
    ) -> None:
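        """ Assigns datasets to the agent's queue for this job and marks
        the job as processing. """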
        if datasets:
            self.redis.lpush(f"job-ds:{agent_id}:{job.hash}", *datasets)
            self.redis.hincrby(job.key, "total_datasets", len(datasets))
            self.redis.hincrby(job.key, "datasets_left", len(datasets))
        self.redis.hset(job.key, "status", "processing")

    def get_next_search_dataset(
        self, agent_id: str, job: JobId
    ) -> Optional[str]:
        return self.redis.lpop(f"job-ds:{agent_id}:{job.hash}")

    def dataset_query_done(self, job: JobId) -> None:
        self.redis.hincrby(job.key, "datasets_left", -1)

    def job_datasets_left(self, agent_id: str, job: JobId) -> int:
        return self.redis.llen(f"job-ds:{agent_id}:{job.hash}")

    def agent_continue_search(self, agent_id: str, job: JobId) -> None:
        self.redis.rpush(f"agent:{agent_id}:queue-search", job.hash)

    def get_job_matches(
        self, job: JobId, offset: int = 0, limit: Optional[int] = None
    ) -> MatchesSchema:
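        """ Returns matches for the job, starting at `offset`. When `limit`
        is None, all remaining matches are returned. """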
        if limit is None:
            end = -1
        else:
            end = offset + limit - 1
        meta = self.redis.lrange("meta:" + job.hash, offset, end)
        matches = [json.loads(m) for m in meta]
        for match in matches:
            # Compatibility fix for old jobs, without sha256 metadata key.
            if "sha256" not in match["meta"]:
                match["meta"]["sha256"] = {
                    "display_text": "0" * 64,
                    "hidden": True,
                }
        return MatchesSchema(job=self.get_job(job), matches=matches)

    def reload_configuration(self, config_version: int) -> None:
        # Notify agents that the configuration changed and must be reloaded
        self.redis.lpush(f"config-reload:{config_version}", "reload")
        # The reload request expires after 300 seconds if nobody consumes it
        self.redis.expire(f"config-reload:{config_version}", 300)

    def agent_get_task(self, agent_id: str, config_version: int) -> AgentTask:
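        """ Blocks until a task for the agent is available and returns it. """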
        agent_prefix = f"agent:{agent_id}"
        # config-reload is a notification queue populated by the web
        # component to tell agents that the configuration has changed
        task_queues = [
            f"config-reload:{config_version}",
            f"{agent_prefix}:queue-command",
            f"{agent_prefix}:queue-search",
            f"{agent_prefix}:queue-yara",
        ]
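        # BLPOP checks the queues in the order given above, so configuration
        # reloads take priority over commands, then searches, then yara tasks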
        queue_task: Any = self.redis.blpop(task_queues)
        queue, task = queue_task

        if queue == f"config-reload:{config_version}":
            return AgentTask(TaskType.RELOAD, task)

        if queue.endswith(":queue-command"):
            return AgentTask(TaskType.COMMAND, task)

        if queue.endswith(":queue-search"):
            return AgentTask(TaskType.SEARCH, task)

        if queue.endswith(":queue-yara"):
            return AgentTask(TaskType.YARA, task)

        raise RuntimeError("Unexpected queue")

    def update_job_files(self, job: JobId, total_files: int) -> int:
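        """ Adds `total_files` to the job's file counter and returns the
        new total. """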
        return self.redis.hincrby(job.key, "total_files", total_files)

    def agent_start_job(
        self, agent_id: str, job: JobId, iterator: str
    ) -> None:
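        """ Pushes a yara task (job key and iterator) onto the agent's
        yara queue. """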
        job_data = json.dumps({"job": job.key, "iterator": iterator})
        self.redis.rpush(f"agent:{agent_id}:queue-yara", job_data)

    def agent_finish_job(self, job: JobId) -> None:
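        """ Decrements the number of agents still working on the job and,
        when the last one finishes, marks the job as done. """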
        new_agents = self.redis.hincrby(job.key, "agents_left", -1)
        if new_agents <= 0:
            self.redis.hmset(
                job.key, {"status": "done", "finished": int(time())}
            )

    def has_pending_search_tasks(self, agent_id: str, job: JobId) -> bool:
        return self.redis.llen(f"job-ds:{agent_id}:{job.hash}") > 0

    def register_active_agent(
        self,
        agent_id: str,
        ursadb_url: str,
        plugins_spec: Dict[str, Dict[str, str]],
        active_plugins: List[str],
    ) -> None:
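        """ Stores the agent's specification (ursadb URL and plugin info)
        in the `agents` hash. """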
        self.redis.hset(
            "agents",
            agent_id,
            AgentSpecSchema(
                ursadb_url=ursadb_url,
                plugins_spec=plugins_spec,
                active_plugins=active_plugins,
            ).json(),
        )

    def get_active_agents(self) -> Dict[str, AgentSpecSchema]:
        return {
            name: AgentSpecSchema.parse_raw(spec)
            for name, spec in self.redis.hgetall("agents").items()
        }

    def get_plugins_config(self) -> List[ConfigSchema]:
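        """ Merges the configuration fields declared by all active agents
        with their stored values and returns them as a flat, sorted list. """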
        # { plugin_name: { field: description } }
        config_fields: Dict[str, Dict[str, str]] = defaultdict(dict)
        # Merge all config fields
        for agent_spec in self.get_active_agents().values():
            for name, fields in agent_spec.plugins_spec.items():
                config_fields[name].update(fields)
        # Transform fields into ConfigSchema
        # { plugin_name: { field: ConfigSchema } }
        plugin_configs = {
            plugin: {
                key: ConfigSchema(
                    plugin=plugin, key=key, value="", description=description
                )
                for key, description in spec.items()
            }
            for plugin, spec in config_fields.items()
        }
        # Get configuration values for each plugin
        for plugin, plugin_config in plugin_configs.items():
            config = self.get_plugin_configuration(plugin)
            for key, value in config.items():
                if key in plugin_config:
                    plugin_config[key].value = value
        # Flatten to the target form
        return [
            plugin_configs[plugin][key]
            for plugin in sorted(plugin_configs.keys())
            for key in sorted(plugin_configs[plugin].keys())
        ]

    def get_plugin_config_version(self) -> int:
        return int(self.redis.get("plugin-version") or 0)

    def get_plugin_configuration(self, plugin_name: str) -> Dict[str, str]:
        return self.redis.hgetall(f"plugin:{plugin_name}")

    def set_plugin_configuration_key(
        self, plugin_name: str, key: str, value: str
    ) -> None:
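        """ Stores a configuration value for the plugin, bumps the plugin
        version and notifies agents still on the previous version that
        they should reload. """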
        self.redis.hset(f"plugin:{plugin_name}", key, value)
        prev_version = self.redis.incrby("plugin-version", 1) - 1
        self.reload_configuration(prev_version)

    def cache_get(self, key: str, expire: int) -> Optional[str]:
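        """ Gets a cached value and refreshes its expiration time. """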
        value = self.redis.get(f"cached:{key}")
        if value is not None:
            self.redis.expire(f"cached:{key}", expire)
        return value

    def cache_store(self, key: str, value: str, expire: int) -> None:
        self.redis.setex(f"cached:{key}", expire, value)