import mongomock
import pymongo
import pymongo.errors
from mongomock.store import DatabaseStore


class FailingMongoClient(mongomock.MongoClient):
    """A ``mongomock.MongoClient`` whose databases are ``FailingDatabase`` objects.

    The failure settings given to the constructor are handed to every database
    this client creates, and from there to every collection.

    Args:
        max_calls_before_failure: forwarded to each ``FailingDatabase``; it is
            the number of write calls a collection serves before failing.
        exception_to_raise: exception class forwarded to each ``FailingDatabase``.
        **kwargs: forwarded unchanged to ``mongomock.MongoClient``.
            NOTE(review): the client's own options must be passed by keyword;
            a positional first argument binds to ``max_calls_before_failure``.
    """

    def __init__(
        self,
        max_calls_before_failure=2,
        exception_to_raise=pymongo.errors.AutoReconnect,
        **kwargs
    ):
        super().__init__(**kwargs)
        self._max_calls_before_failure = max_calls_before_failure
        # Kept under both names: the public attribute is available to callers,
        # while get_database() forwards the private one.
        self.exception_to_raise = exception_to_raise
        self._exception_to_raise = exception_to_raise

    def get_database(
        self, name=None, codec_options=None, read_preference=None, write_concern=None
    ):
        """Return database *name*, creating a ``FailingDatabase`` on first access.

        Args:
            name: database name; ``None`` delegates to ``get_default_database()``.
            codec_options: codec options for a newly created database; falls back
                to the client's codec options when not given.
            read_preference: read preference for a newly created database; falls
                back to the client's read preference when not given.
            write_concern: accepted for signature compatibility; not used here.

        Returns:
            The cached database object for *name*. Options only take effect when
            the database is first created.
        """
        if name is None:
            return self.get_default_database()

        db = self._database_accesses.get(name)
        if db is None:
            db_store = self._store[name]
            db = self._database_accesses[name] = FailingDatabase(
                max_calls_before_failure=self._max_calls_before_failure,
                exception_to_raise=self._exception_to_raise,
                client=self,
                name=name,
                read_preference=read_preference or self.read_preference,
                # Honour caller-supplied codec options, falling back to the
                # client's, in the same way as read_preference above.
                codec_options=codec_options or self._codec_options,
                _store=db_store,
            )
        return db


class FailingDatabase(mongomock.Database):
    """A ``mongomock.Database`` whose collections are ``FailingCollection`` objects.

    Args:
        max_calls_before_failure: forwarded to every ``FailingCollection``
            this database creates.
        exception_to_raise: exception class forwarded to every
            ``FailingCollection`` this database creates.
        **kwargs: forwarded unchanged to ``mongomock.Database``.
    """

    def __init__(self, max_calls_before_failure, exception_to_raise=None, **kwargs):
        super().__init__(**kwargs)
        self._max_calls_before_failure = max_calls_before_failure
        self._exception_to_raise = exception_to_raise

    def get_collection(
        self,
        name,
        codec_options=None,
        read_preference=None,
        write_concern=None,
        read_concern=None,
    ):
        """Return collection *name*, creating a ``FailingCollection`` on first access.

        On first access the new collection is cached in ``_collection_accesses``
        and returned as-is. Later accesses return the cached collection passed
        through ``with_options``. The codec options and read preference fall
        back to the database's own.

        NOTE(review): ``with_options`` belongs to mongomock. If it builds a fresh
        plain ``mongomock.Collection`` when the requested options differ, that
        result would not inject failures. Confirm against the mongomock version in use.

        ``_ensure_valid_collection_name`` presumably raises for an invalid *name*.
        """
        # Look up the cache explicitly instead of wrapping with_options() in
        # try/except KeyError. Otherwise a KeyError raised inside with_options()
        # would be misread as a cache miss and replace the cached collection.
        cached = self._collection_accesses.get(name)
        if cached is not None:
            return cached.with_options(
                codec_options=codec_options or self._codec_options,
                read_preference=read_preference or self.read_preference,
                read_concern=read_concern,
                write_concern=write_concern,
            )

        self._ensure_valid_collection_name(name)
        collection = self._collection_accesses[name] = FailingCollection(
            max_calls_before_failure=self._max_calls_before_failure,
            exception_to_raise=self._exception_to_raise,
            database=self,
            name=name,
            write_concern=write_concern,
            read_preference=read_preference or self.read_preference,
            codec_options=codec_options or self._codec_options,
            _db_store=self._store,
        )
        return collection


class FailingCollection(mongomock.Collection):
    """A ``mongomock.Collection`` that fails permanently after a write quota.

    ``insert_one`` and ``update_one`` share one call counter. While the counter
    is at most ``max_calls_before_failure`` the write is delegated to mongomock.
    After that, every call raises ``pymongo.errors.ConnectionFailure`` and no
    write is performed.

    Args:
        max_calls_before_failure: number of write calls served before failing.
        exception_to_raise: stored as ``_exception_to_raise`` for subclasses.
            This class itself always raises ``ConnectionFailure``.
        **kwargs: forwarded unchanged to ``mongomock.Collection``.
    """

    def __init__(self, max_calls_before_failure, exception_to_raise, **kwargs):
        super().__init__(**kwargs)
        self._calls = 0
        self._exception_to_raise = exception_to_raise
        self._max_calls_before_failure = max_calls_before_failure

    def insert_one(self, document, session=None):
        """Count the call and insert *document* while under quota.

        *session* is accepted for signature compatibility but not forwarded.
        """
        self._calls += 1
        within_quota = self._calls <= self._max_calls_before_failure
        if not within_quota:
            raise pymongo.errors.ConnectionFailure
        return super().insert_one(document)

    def update_one(self, filter, update, upsert=False, session=None):
        """Count the call and apply *update* while under quota.

        *session* is accepted for signature compatibility but not forwarded.
        """
        self._calls += 1
        within_quota = self._calls <= self._max_calls_before_failure
        if not within_quota:
            raise pymongo.errors.ConnectionFailure
        return super().update_one(filter, update, upsert)


class ReconnectingMongoClient(FailingMongoClient):
    """A ``FailingMongoClient`` whose databases are ``ReconnectingDatabase`` objects.

    Args:
        max_calls_before_reconnect: forwarded to each ``ReconnectingDatabase``.
        **kwargs: forwarded to ``FailingMongoClient`` (e.g.
            ``max_calls_before_failure``, ``exception_to_raise``).
    """

    def __init__(self, max_calls_before_reconnect, **kwargs):
        super().__init__(**kwargs)
        self._max_calls_before_reconnect = max_calls_before_reconnect

    def get_database(
        self, name=None, codec_options=None, read_preference=None, write_concern=None
    ):
        """Return database *name*, creating a ``ReconnectingDatabase`` on first access.

        Args:
            name: database name; ``None`` delegates to ``get_default_database()``.
            codec_options: codec options for a newly created database; falls back
                to the client's codec options when not given.
            read_preference: read preference for a newly created database; falls
                back to the client's read preference when not given.
            write_concern: accepted for signature compatibility; not used here.

        Returns:
            The cached database object for *name*. Options only take effect when
            the database is first created.
        """
        if name is None:
            return self.get_default_database()

        db = self._database_accesses.get(name)
        if db is None:
            db_store = self._store[name]
            db = self._database_accesses[name] = ReconnectingDatabase(
                max_calls_before_reconnect=self._max_calls_before_reconnect,
                max_calls_before_failure=self._max_calls_before_failure,
                exception_to_raise=self._exception_to_raise,
                client=self,
                name=name,
                read_preference=read_preference or self.read_preference,
                # Honour caller-supplied codec options, falling back to the
                # client's, in the same way as read_preference above.
                codec_options=codec_options or self._codec_options,
                _store=db_store,
            )
        return db


class ReconnectingDatabase(FailingDatabase):
    """A ``FailingDatabase`` whose collections are ``ReconnectingCollection`` objects.

    Args:
        max_calls_before_reconnect: forwarded to every ``ReconnectingCollection``
            this database creates.
        **kwargs: forwarded to ``FailingDatabase`` (``max_calls_before_failure``,
            ``exception_to_raise``, and the ``mongomock.Database`` arguments).
    """

    def __init__(self, max_calls_before_reconnect, **kwargs):
        super().__init__(**kwargs)
        self._max_calls_before_reconnect = max_calls_before_reconnect

    def get_collection(
        self,
        name,
        codec_options=None,
        read_preference=None,
        write_concern=None,
        read_concern=None,
    ):
        """Return collection *name*, creating a ``ReconnectingCollection`` on first access.

        On first access the new collection is cached in ``_collection_accesses``
        and returned as-is. Later accesses return the cached collection passed
        through ``with_options``. The codec options and read preference fall
        back to the database's own.

        NOTE(review): ``with_options`` belongs to mongomock. If it builds a fresh
        plain ``mongomock.Collection`` when the requested options differ, that
        result would not inject failures. Confirm against the mongomock version in use.
        """
        # Look up the cache explicitly instead of wrapping with_options() in
        # try/except KeyError. Otherwise a KeyError raised inside with_options()
        # would be misread as a cache miss and replace the cached collection.
        cached = self._collection_accesses.get(name)
        if cached is not None:
            return cached.with_options(
                codec_options=codec_options or self._codec_options,
                read_preference=read_preference or self.read_preference,
                read_concern=read_concern,
                write_concern=write_concern,
            )

        self._ensure_valid_collection_name(name)
        collection = self._collection_accesses[name] = ReconnectingCollection(
            max_calls_before_reconnect=self._max_calls_before_reconnect,
            max_calls_before_failure=self._max_calls_before_failure,
            exception_to_raise=self._exception_to_raise,
            database=self,
            name=name,
            write_concern=write_concern,
            read_preference=read_preference or self.read_preference,
            codec_options=codec_options or self._codec_options,
            _db_store=self._store,
        )
        return collection


class ReconnectingCollection(FailingCollection):
    """A ``FailingCollection`` whose outage is temporary.

    ``insert_one`` and ``update_one`` share the inherited call counter. A call
    whose count ``c`` satisfies
    ``max_calls_before_failure < c <= max_calls_before_reconnect`` raises
    ``_exception_to_raise``. Every other call is delegated directly to
    ``mongomock.Collection``, which bypasses ``FailingCollection``'s permanent
    failure. Each call prints a short trace line to stdout.

    Args:
        max_calls_before_reconnect: last call count that still fails.
        **kwargs: forwarded to ``FailingCollection``.
    """

    def __init__(self, max_calls_before_reconnect, **kwargs):
        super().__init__(**kwargs)
        self._max_calls_before_reconnect = max_calls_before_reconnect

    def insert_one(self, document, session=None):
        """Count the call; insert *document* unless inside the failure window.

        *session* is accepted for signature compatibility but not forwarded.
        """
        self._calls += 1
        if not self._is_in_failure_range():
            print(self.name, "insert connection reestablished")
            return mongomock.Collection.insert_one(self, document)
        print(self.name, "insert no connection")
        raise self._exception_to_raise

    def update_one(self, filter, update, upsert=False, session=None):
        """Count the call; apply *update* unless inside the failure window.

        *session* is accepted for signature compatibility but not forwarded.
        """
        self._calls += 1
        if not self._is_in_failure_range():
            print(self.name, "update connection reestablished")
            return mongomock.Collection.update_one(self, filter, update, upsert)
        print(self.name, "update no connection")
        raise self._exception_to_raise

    def _is_in_failure_range(self):
        """Return True while the call count is past failure and not yet reconnected."""
        calls = self._calls
        return (
            calls > self._max_calls_before_failure
            and calls <= self._max_calls_before_reconnect
        )