# Copyright (C) 2019 Pierre Letessier
# This source code is licensed under the BSD 3 license found in the
# LICENSE file in the root directory of this source tree.
"""
MongoDB Repository
"""
import numpy as np
import pymongo
from pymongo import MongoClient
from pymongo.errors import DuplicateKeyError
from pymongo.results import DeleteResult

from errors.taranis_error import TaranisAlreadyExistsError
from repositories.db_repository import AbstractDatabaseRepository


class MongoDBDatabaseRepository(AbstractDatabaseRepository):
    """MongoDB-backed implementation of ``AbstractDatabaseRepository``.

    Everything is stored in the ``taranis`` MongoDB database, split across
    three collections:

    * ``databases`` -- one document per database, unique on ``name``;
    * ``vector``    -- one document per vector, unique on ``(db_name, id)``;
    * ``indices``   -- one document per index, unique on ``(index_name, db_name)``.
    """

    def __init__(self, host='localhost', port=27017, username='root', password='password'):
        """Connect to MongoDB and make sure the unique indexes exist.

        :param host: MongoDB host name.
        :param port: MongoDB port.
        :param username: user to authenticate as.
        :param password: password for ``username``.
        """
        # NOTE(review): the default credentials look like development
        # placeholders; production callers should pass real ones.
        self.mongo_client = MongoClient(host, port,
                                        username=username,
                                        password=password)
        self.db_mongo = self.mongo_client['taranis']

        # Database names must be unique.
        self.databases_collection = self.db_mongo["databases"]
        self.databases_collection.create_index(
            [("name", pymongo.DESCENDING)],
            unique=True
        )

        # A vector id must be unique within its database.
        self.vector_collection = self.db_mongo["vector"]
        self.vector_collection.create_index(
            [("db_name", pymongo.ASCENDING), ("id", pymongo.ASCENDING)],
            unique=True
        )

        # An index name must be unique within its database.
        self.indices_collection = self.db_mongo["indices"]
        self.indices_collection.create_index(
            [("index_name", pymongo.DESCENDING), ("db_name", pymongo.DESCENDING)],
            unique=True
        )

    def get_all_databases(self):
        """Return every document of the ``databases`` collection as a list."""
        return list(self.databases_collection.find())

    def create_one_database(self, database):
        """Insert a database document and return its ``inserted_id``.

        :param database: document to insert; must carry a ``name`` key.
        :raises TaranisAlreadyExistsError: if a database with the same name
            already exists (unique index on ``name``).
        """
        try:
            res = self.databases_collection.insert_one(database)
        except DuplicateKeyError as err:
            raise TaranisAlreadyExistsError("Database {} already exists".format(database['name'])) from err
        return res.inserted_id

    def find_one_database_by_name(self, name):
        """Return the database document named ``name``, or ``None``."""
        return self.databases_collection.find_one(dict(name=name))

    def delete_one_database_by_name(self, name):
        """Delete the database document named ``name``.

        This only removes the ``databases`` entry. Vectors are removed
        separately by ``delete_vectors_by_database_name``.

        :return: True if exactly one document was deleted.
        """
        res: DeleteResult = self.databases_collection.delete_one(dict(name=name))
        return res.deleted_count == 1

    def delete_vectors_by_database_name(self, name: str) -> bool:
        """Delete every vector belonging to database ``name``.

        :return: True when the server acknowledged the delete. Deleting zero
            vectors is still a success. False only happens with an
            unacknowledged write concern.
        """
        res: DeleteResult = self.vector_collection.delete_many(dict(db_name=name))
        return res.acknowledged

    def create_vectors(self, vectors: list) -> bool:
        """Bulk-insert vector documents.

        :param vectors: documents to insert. Note that pymongo adds an
            ``_id`` key to each one in place.
        :return: True if every document was inserted. An empty input is a
            no-op and returns True, because ``insert_many`` rejects empty
            lists.
        """
        if not vectors:
            return True
        res = self.vector_collection.insert_many(vectors)
        return len(res.inserted_ids) == len(vectors)

    def create_one_index(self, index) -> str:
        """Insert an index document and return its ``inserted_id``.

        NOTE(review): unlike ``create_one_database``, a duplicate
        ``(index_name, db_name)`` surfaces as pymongo's ``DuplicateKeyError``
        rather than ``TaranisAlreadyExistsError``. Confirm which one callers
        expect before unifying.
        """
        res = self.indices_collection.insert_one(index)
        return res.inserted_id

    def delete_one_index(self, index) -> bool:
        """Delete the index document matching ``index.index_name``/``index.db_name``.

        NOTE(review): this reads attributes, while ``create_one_index`` passes
        ``index`` to ``insert_one``, which requires a mapping. The two methods
        presumably receive different types; verify against the callers.

        :return: True if exactly one document was deleted.
        """
        res: DeleteResult = self.indices_collection.delete_one(dict(index_name=index.index_name, db_name=index.db_name))
        return res.deleted_count == 1

    def find_one_index_by_index_name_and_db_name(self, index_name: str, db_name: str) -> object:
        """Return the index document for ``(index_name, db_name)``, or ``None``."""
        return self.indices_collection.find_one(dict(index_name=index_name, db_name=db_name))

    def find_vectors_by_database_name(self, name: str, limit=100000, skip=0,
                                      dimension=128) -> (np.ndarray, int, np.ndarray):
        """Load one page of the vectors stored for database ``name``.

        :param name: database the vectors belong to.
        :param limit: maximum number of vectors to read. Passed to pymongo's
            ``Cursor.limit``, where 0 means no limit.
        :param skip: number of matching vectors to skip first.
        :param dimension: expected number of float32 components per vector.
            It is also the width of the empty array returned when nothing
            matches.
        :return: ``(vectors, count, ids)`` where ``vectors`` is a
            ``(count, dimension)`` float32 array, ``count`` is the number of
            vectors read, and ``ids`` is a ``(count,)`` int64 array of their
            ``id`` fields, in the same order.
        :raises ValueError: if a stored vector does not decode to exactly
            ``dimension`` float32 components.
        """
        cursor = self.vector_collection.find(
            dict(db_name=name),
            projection={"_id": False, "data": True, "id": True},
        ).skip(skip).limit(limit)

        # Size the result from the documents actually read, not from a
        # separate count query. Cursor.count() was removed in PyMongo 4, and a
        # count taken before iterating can disagree with what the cursor
        # yields if the collection changes in between.
        rows = []
        ids = []
        for doc in cursor:
            # The frombuffer result is a read-only view; np.stack copies it below.
            row = np.frombuffer(doc["data"], dtype=np.float32)
            if row.shape != (dimension,):
                raise ValueError(
                    "Vector {} of database {} has {} components, expected {}".format(
                        doc["id"], name, row.size, dimension))
            rows.append(row)
            ids.append(doc["id"])

        count = len(ids)
        if rows:
            vectors = np.stack(rows)
        else:
            vectors = np.empty((0, dimension), dtype=np.float32)
        return vectors, count, np.asarray(ids, dtype=np.int64)

    def get_vectors(self, db_name: str, ids: list, limit=100000, skip=0) -> list:
        """Fetch vector documents of ``db_name`` by id, in the order of ``ids``.

        :param db_name: database the vectors belong to.
        :param ids: vector ids to look up (used in a ``$in`` query).
        :param limit: maximum number of documents read from the query.
        :param skip: number of matching documents skipped first.
        :return: a list the same length as ``ids``. Slot ``i`` holds the
            document whose ``id`` equals ``ids[i]``, or ``None`` if that
            document was not returned (missing, or cut off by
            ``skip``/``limit``). Repeated ids share the same document object.
        """
        vectors = [None] * len(ids)

        # Record every position each id occupies. An id -> single-index map
        # would keep only the last occurrence and leave earlier duplicate
        # slots as None.
        positions = {}
        for idx, idv in enumerate(ids):
            positions.setdefault(idv, []).append(idx)

        cursor = self.vector_collection.find(dict(db_name=db_name, id={"$in": ids})).skip(skip).limit(limit)
        for v in cursor:
            for idx in positions[v['id']]:
                vectors[idx] = v
        return vectors