import functools
import logging
import uuid

from pymongo import MongoClient
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError

from ..common import (
    SessionChecker, create_resource, datetime_to_float, datetime_to_string,
    datetime_to_string_stix, determine_spec_version, determine_version,
    float_to_datetime, generate_status, generate_status_details,
    get_custom_headers, get_timestamp, parse_request_parameters,
    string_to_datetime
)
from ..exceptions import MongoBackendError, ProcessingError
from ..filters.mongodb_filter import MongoDBFilter
from .base import Backend

# Module-level logger
log = logging.getLogger(__name__)


def catch_mongodb_error(func):
    """Wrap *func* so MongoDB availability errors become ``MongoBackendError``.

    A ``ConnectionFailure`` or ``ServerSelectionTimeoutError`` raised by the
    wrapped callable is re-raised as
    ``MongoBackendError("Unable to connect to MongoDB", 500, e)``. Any other
    exception, and the normal return value, pass through unchanged.
    """

    # functools.wraps copies __name__/__doc__ from the wrapped callable, so
    # decorated backend methods remain identifiable when introspected.
    @functools.wraps(func)
    def api_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            raise MongoBackendError("Unable to connect to MongoDB", 500, e)

    return api_wrapper


class MongoBackend(Backend):

    # access control is handled at the views level

    def __init__(self, **kwargs):
        """Create the backend and start the paging-session checker.

        Keyword arguments read here:
            uri: handed to ``MongoClient``.
            session_timeout: value stored in ``self.timeout`` and compared
                against the age of paging records; defaults to 30.
            check_interval: first argument given to ``SessionChecker``;
                defaults to 10.
        """
        # Paging state is initialised before MongoDB is touched. The checker
        # started below is given ``_pop_expired_sessions``, which reads both
        # attributes, so they must exist even if the client was not created.
        self.pages = {}
        self.timeout = kwargs.get("session_timeout", 30)

        try:
            self.client = MongoClient(kwargs.get("uri"))
        except ConnectionFailure:
            # Best effort, as before: the failure is logged and construction
            # continues without ``self.client`` being set.
            log.error("Unable to establish a connection to MongoDB server {}".format(kwargs.get("uri")))

        checker = SessionChecker(kwargs.get("check_interval", 10), self._pop_expired_sessions)
        checker.start()

    def _process_params(self, filter_args, limit):
        """Look up or create the pagination record for a request.

        With a truthy ``limit`` and no ``next`` entry in ``filter_args``, a
        new record is stored in ``self.pages`` under a fresh UUID string.
        With a truthy ``limit`` and a truthy ``next`` value, the stored
        record is validated, given the new limit and request time, and
        returned. In every other case no record is used.

        Returns:
            A ``(next_id, record)`` tuple; ``record`` is ``{}`` when paging
            does not apply.

        Raises:
            ProcessingError: (400) when ``next`` is not a known page id, or
                when the parsed request parameters differ from those stored
                with the record.
        """
        page_id = filter_args.get("next")

        if limit and page_id is None:
            # First request of a paged exchange: open a new record.
            params = parse_request_parameters(filter_args)
            session = {
                "skip": 0,
                "limit": limit,
                "args": params,
                "request_time": datetime_to_float(get_timestamp()),
            }
            page_id = str(uuid.uuid4())
            self.pages[page_id] = session
            return page_id, session

        if limit and page_id:
            # Follow-up request: the id must be known and the parameters
            # must match the ones the record was created with.
            if page_id not in self.pages:
                raise ProcessingError("The server did not understand the request or filter parameters: 'next' not valid", 400)
            params = parse_request_parameters(filter_args)
            session = self.pages[page_id]
            if session["args"] != params:
                raise ProcessingError("The server did not understand the request or filter parameters: params changed over subsequent transaction", 400)
            session["limit"] = limit
            session["request_time"] = datetime_to_float(get_timestamp())
            return page_id, session

        return page_id, {}

    def _update_record(self, next_id, count, internal=False):
        more = False
        if next_id:
            if internal is False:
                self.pages[next_id]["skip"] += self.pages[next_id]["limit"]
            if self.pages[next_id]["skip"] >= count:
                self.pages.pop(next_id, None)
                next_id = None
            else:
                more = True
        return next_id, more

    def _validate_object_id(self, manifest_info, collection_id, object_id):
        """Ensure at least one document with ``object_id`` exists.

        ``manifest_info`` is the collection that is queried for a document
        whose ``_collection_id`` and ``id`` match the given values.

        Raises:
            ProcessingError: (404) when no such document is found.
        """
        query = {"_collection_id": collection_id, "id": object_id}
        matches = list(manifest_info.find(query).limit(1))
        if not matches:
            raise ProcessingError("Object '{}' not found".format(object_id), 404)

    def _pop_expired_sessions(self):
        """Remove every pagination record older than ``self.timeout``.

        A record is expired when the current timestamp minus its
        ``request_time`` is greater than ``self.timeout``.
        """
        boundary = datetime_to_float(get_timestamp())
        # Work on a snapshot of the items. This method is handed to
        # SessionChecker as a callback (presumably run on another thread -
        # confirm), while request handlers add and remove entries in
        # ``self.pages``; iterating the live dict during such a change
        # raises "dictionary changed size during iteration".
        expired_ids = [
            next_id
            for next_id, record in list(self.pages.items())
            if boundary - record["request_time"] > self.timeout
        ]

        for next_id in expired_ids:
            # ``_update_record`` may already have removed the entry.
            self.pages.pop(next_id, None)

    def _get_object_manifest(self, api_root, collection_id, filter_args, allowed_filters, limit, internal=False):
        """Build the manifest resource for a collection.

        The ``objects`` collection of the ``api_root`` database is filtered
        in "manifests" mode, and each returned entry has its ``date_added``
        and ``version`` floats converted to strings in place.

        Returns:
            The manifest resource alone when ``internal`` is truthy,
            otherwise a ``(resource, headers)`` tuple where the headers come
            from ``get_custom_headers``.
        """
        objects_collection = self.client[api_root]["objects"]
        next_id, record = self._process_params(filter_args, limit)

        base_query = {"_collection_id": {"$eq": collection_id}}
        manifest_filter = MongoDBFilter(filter_args, base_query, allowed_filters, record)
        count, manifests = manifest_filter.process_filter(objects_collection, allowed_filters, "manifests")

        for entry in manifests:
            entry["date_added"] = datetime_to_string(float_to_datetime(entry["date_added"]))
            entry["version"] = datetime_to_string_stix(float_to_datetime(entry["version"]))

        next_id, more = self._update_record(next_id, count, internal)
        resource = create_resource("objects", manifests, more, next_id)
        if internal:
            return resource
        return resource, get_custom_headers(resource)

    @catch_mongodb_error
    def _update_manifest(self, api_root, collection_id, media_type):
        """Record ``media_type`` on the collection document if it is new.

        The collection document with ``id == collection_id`` is read from
        the ``collections`` collection of the ``api_root`` database; when
        its ``media_types`` list lacks ``media_type`` the value is appended
        and the list is written back with ``$set``.
        """
        collections = self.client[api_root]["collections"]
        selector = {"id": collection_id}

        stored = collections.find_one(selector)
        known_types = stored["media_types"]
        if media_type in known_types:
            return

        known_types.append(media_type)
        collections.update_one(selector, {"$set": {"media_types": known_types}})

    @catch_mongodb_error
    def server_discovery(self):
        """Return the discovery document with API roots resolved to URLs.

        The aggregation joins ``discovery_information.api_roots`` against
        ``api_root_info._name``, replaces ``api_roots`` with the ``_url``
        values of the joined documents, and drops ``_roots`` and ``_id``.

        Returns:
            The first document produced by the pipeline, or ``None`` when
            the pipeline yields no document.
        """
        discovery_db = self.client["discovery_database"]
        discovery_info = discovery_db["discovery_information"]
        pipeline = [
            {
                "$lookup": {
                    "from": "api_root_info",
                    "localField": "api_roots",
                    "foreignField": "_name",
                    "as": "_roots",
                },
            },
            {
                "$addFields": {
                    "api_roots": "$_roots._url",
                },
            },
            {
                "$project": {
                    "_roots": 0,
                    "_id": 0,
                }
            }
        ]
        # An empty result is reported as ``None`` (as the find_one based
        # getters do) instead of letting StopIteration escape to the caller.
        return next(discovery_info.aggregate(pipeline), None)

    @catch_mongodb_error
    def get_collections(self, api_root):
        """Return the "collections" resource for ``api_root``.

        Returns ``None`` when no database named ``api_root`` exists;
        otherwise every document of its ``collections`` collection, without
        ``_id``, wrapped by ``create_resource``.
        """
        known_databases = self.client.list_database_names()
        if api_root not in known_databases:
            return None  # must return None, so 404 is raised

        found = list(self.client[api_root]["collections"].find({}, {"_id": 0}))
        return create_resource("collections", found)

    @catch_mongodb_error
    def get_collection(self, api_root, collection_id):
        """Return one collection document of ``api_root`` without ``_id``.

        Returns ``None`` when no database named ``api_root`` exists; the
        ``find_one`` result is returned as is otherwise.
        """
        known_databases = self.client.list_database_names()
        if api_root not in known_databases:
            return None  # must return None, so 404 is raised

        collections = self.client[api_root]["collections"]
        return collections.find_one({"id": collection_id}, {"_id": 0})

    @catch_mongodb_error
    def get_object_manifest(self, api_root, collection_id, filter_args, allowed_filters, limit):
        """Return ``(manifest_resource, headers)`` for a collection.

        Delegates to ``_get_object_manifest`` in its non-internal mode.
        """
        return self._get_object_manifest(
            api_root, collection_id, filter_args, allowed_filters, limit, internal=False,
        )

    @catch_mongodb_error
    def get_api_root_information(self, api_root_name):
        """Return the ``api_root_info`` document named ``api_root_name``.

        The document is read from ``discovery_database`` with the ``_id``,
        ``_url`` and ``_name`` fields excluded; the ``find_one`` result is
        returned unchanged.
        """
        hidden_fields = {"_id": 0, "_url": 0, "_name": 0}
        roots = self.client["discovery_database"]["api_root_info"]
        return roots.find_one({"_name": api_root_name}, hidden_fields)

    @catch_mongodb_error
    def get_status(self, api_root, status_id):
        """Return the status document ``status_id`` of ``api_root``.

        The document is read from the ``status`` collection without its
        ``_id`` field; the ``find_one`` result is returned unchanged.
        """
        statuses = self.client[api_root]["status"]
        return statuses.find_one({"id": status_id}, {"_id": 0})

    @catch_mongodb_error
    def get_objects(self, api_root, collection_id, filter_args, allowed_filters, limit):
        """Return the filtered objects of a collection and custom headers.

        Objects are read from the ``objects`` collection of ``api_root``;
        their ``modified`` and ``created`` floats, when present, are turned
        into STIX timestamp strings in place. The headers are derived from
        the matching manifest resource.

        Returns:
            ``(objects_resource, headers)``.
        """
        objects_collection = self.client[api_root]["objects"]
        next_id, record = self._process_params(filter_args, limit)

        base_query = {"_collection_id": {"$eq": collection_id}}
        object_filter = MongoDBFilter(filter_args, base_query, allowed_filters, record)
        # Note: no error handling around the following call, as mongo will
        # handle (user supplied) filters gracefully if they don't exist
        count, found = object_filter.process_filter(objects_collection, allowed_filters, "objects")

        for stix_obj in found:
            for stamp in ("modified", "created"):
                if stamp in stix_obj:
                    stix_obj[stamp] = datetime_to_string_stix(float_to_datetime(stix_obj[stamp]))

        # The manifest is fetched in internal mode only to compute headers.
        manifest = self._get_object_manifest(api_root, collection_id, filter_args, allowed_filters, limit, True)
        headers = get_custom_headers(manifest)

        next_id, more = self._update_record(next_id, count)
        resource = create_resource("objects", found, more, next_id)
        return resource, headers

    @catch_mongodb_error
    def add_objects(self, api_root, collection_id, objs, request_time):
        """Insert the objects of ``objs["objects"]`` and store a status.

        An object is rejected when a document already exists with the same
        collection id, object id and media type (and, when the object has a
        ``modified`` value, the same ``_manifest.version``). Accepted objects
        are modified in place before insertion: ``_collection_id`` and
        ``_manifest`` are added and ``modified``/``created`` become floats.

        Returns:
            The status resource produced by ``generate_status``, after it
            has been inserted into the ``status`` collection (``_id``
            removed again).

        Raises:
            ProcessingError: (422) when handling the supplied content fails.
            MongoBackendError: when MongoDB is unavailable.
        """
        api_root_db = self.client[api_root]
        objects_info = api_root_db["objects"]
        failed = 0
        succeeded = 0
        pending = 0
        successes = []
        failures = []
        media_fmt = "application/stix+json;version={}"

        try:
            for new_obj in objs["objects"]:
                media_type = media_fmt.format(determine_spec_version(new_obj))
                mongo_query = {"_collection_id": collection_id, "id": new_obj["id"], "_manifest.media_type": media_type}
                if "modified" in new_obj:
                    mongo_query["_manifest.version"] = datetime_to_float(string_to_datetime(new_obj["modified"]))
                existing_entry = objects_info.find_one(mongo_query)
                obj_version = determine_version(new_obj, request_time)
                if existing_entry:
                    status_detail = generate_status_details(
                        new_obj["id"], obj_version,
                        message="Unable to process object because an identical entry already exists in collection '{}'.".format(collection_id),
                    )
                    failures.append(status_detail)
                    failed += 1
                else:
                    new_obj.update({"_collection_id": collection_id})
                    if "modified" in new_obj:
                        new_obj["modified"] = datetime_to_float(string_to_datetime(new_obj["modified"]))
                    if "created" in new_obj:
                        new_obj["created"] = datetime_to_float(string_to_datetime(new_obj["created"]))
                    _manifest = {
                        "id": new_obj["id"],
                        "date_added": datetime_to_float(request_time),
                        "version": datetime_to_float(string_to_datetime(obj_version)),
                        "media_type": media_type,
                    }
                    new_obj.update({"_manifest": _manifest})
                    objects_info.insert_one(new_obj)
                    self._update_manifest(api_root, collection_id, media_type)
                    status_detail = generate_status_details(
                        new_obj["id"], obj_version,
                        message="Successfully added object to collection '{}'.".format(collection_id)
                    )
                    successes.append(status_detail)
                    succeeded += 1
        except (ConnectionFailure, ServerSelectionTimeoutError, MongoBackendError):
            # Database availability problems are not content errors: let
            # them reach ``catch_mongodb_error`` (or pass through, when
            # ``_update_manifest`` already converted them) instead of being
            # reported as a 422.
            raise
        except Exception as e:
            raise ProcessingError("While processing supplied content, an error occurred", 422, e)

        status = generate_status(
            datetime_to_string(request_time), "complete", succeeded, failed,
            pending, successes=successes, failures=failures,
        )
        api_root_db["status"].insert_one(status)
        # insert_one adds ``_id`` to the dict; it is not part of the resource.
        status.pop("_id", None)
        return status

    @catch_mongodb_error
    def get_object(self, api_root, collection_id, object_id, filter_args, allowed_filters, limit):
        """Return the stored versions of one object and custom headers.

        ``filter_args`` is modified: ``match[id]`` is set to ``object_id``
        before the pagination record is resolved. ``modified`` and
        ``created`` floats of the returned objects are converted to STIX
        timestamp strings in place.

        Returns:
            ``(objects_resource, headers)``.

        Raises:
            ProcessingError: (404) when the object is not in the collection.
        """
        objects_collection = self.client[api_root]["objects"]
        # set manually to properly retrieve manifests, and early to not break the pagination checks
        filter_args["match[id]"] = object_id
        next_id, record = self._process_params(filter_args, limit)

        self._validate_object_id(objects_collection, collection_id, object_id)

        base_query = {"_collection_id": {"$eq": collection_id}, "id": {"$eq": object_id}}
        object_filter = MongoDBFilter(filter_args, base_query, allowed_filters, record)
        count, found = object_filter.process_filter(objects_collection, allowed_filters, "objects")

        for stix_obj in found:
            for stamp in ("modified", "created"):
                if stamp in stix_obj:
                    stix_obj[stamp] = datetime_to_string_stix(float_to_datetime(stix_obj[stamp]))

        # The manifest is fetched in internal mode only to compute headers.
        manifest_filters = ("id", "type", "version", "spec_version")
        manifest = self._get_object_manifest(api_root, collection_id, filter_args, manifest_filters, limit, True)
        headers = get_custom_headers(manifest)

        next_id, more = self._update_record(next_id, count)
        resource = create_resource("objects", found, more, next_id)
        return resource, headers

    @catch_mongodb_error
    def delete_object(self, api_root, collection_id, object_id, filter_args, allowed_filters):
        """Delete the versions of an object selected by the filters.

        Each document returned by the "raw" filter run is removed with a
        ``delete_one`` keyed on collection id, object id and the document's
        ``_manifest.version``.

        Raises:
            ProcessingError: (404) when the object is not in the collection
                or the filters select nothing.
        """
        objects_collection = self.client[api_root]["objects"]

        self._validate_object_id(objects_collection, collection_id, object_id)

        # Currently it will delete the object and the matching manifest from the backend
        base_query = {"_collection_id": {"$eq": collection_id}, "id": {"$eq": object_id}}
        removal_filter = MongoDBFilter(filter_args, base_query, allowed_filters)
        _, matches = removal_filter.process_filter(objects_collection, allowed_filters, "raw")

        if not matches:
            raise ProcessingError("Object '{}' not found".format(object_id), 404)

        for match in matches:
            version = match["_manifest"]["version"]
            objects_collection.delete_one(
                {"_collection_id": collection_id, "id": object_id, "_manifest.version": version}
            )

    @catch_mongodb_error
    def get_object_versions(self, api_root, collection_id, object_id, filter_args, allowed_filters, limit):
        """Return the version strings of one object and custom headers.

        ``filter_args`` is modified: ``match[id]`` is set to ``object_id``
        and ``match[version]`` to ``"all"`` before the pagination record is
        resolved. Each manifest's ``version`` float is converted to a STIX
        timestamp string.

        Returns:
            ``(versions_resource, headers)``.

        Raises:
            ProcessingError: (404) when the object is not in the collection.
        """
        objects_collection = self.client[api_root]["objects"]
        # set manually to properly retrieve manifests, and early to not break the pagination checks
        filter_args["match[id]"] = object_id
        filter_args["match[version]"] = "all"
        next_id, record = self._process_params(filter_args, limit)

        self._validate_object_id(objects_collection, collection_id, object_id)

        base_query = {"_collection_id": {"$eq": collection_id}, "id": {"$eq": object_id}}
        version_filter = MongoDBFilter(filter_args, base_query, allowed_filters, record)
        count, manifests = version_filter.process_filter(objects_collection, allowed_filters, "manifests")

        # The manifest is fetched in internal mode only to compute headers.
        manifest_filters = ("id", "type", "version", "spec_version")
        manifest = self._get_object_manifest(api_root, collection_id, filter_args, manifest_filters, limit, True)
        headers = get_custom_headers(manifest)

        versions = [
            datetime_to_string_stix(float_to_datetime(entry["version"]))
            for entry in manifests
        ]
        next_id, more = self._update_record(next_id, count)
        resource = create_resource("versions", versions, more, next_id)
        return resource, headers