import re
from enum import Enum
from functools import wraps
from logging import getLogger, basicConfig, INFO

from pymongo import MongoClient
from pymongo.errors import ServerSelectionTimeoutError, PyMongoError

# Logger used by this module (e.g. by with_retry to report failed writes).
logger = getLogger(__name__)

# NOTE(review): basicConfig() configures the root logger as a side effect of
# importing this module; consider moving it to the program entry point.
basicConfig(format="%(message)s", level=INFO)


def with_retry(tries):
    """Build a decorator that retries a call when it raises ``PyMongoError``.

    The decorated function is attempted up to ``tries`` times. Each
    ``PyMongoError`` triggers another attempt. If every attempt fails, or
    ``tries <= 0`` so no attempt is made, an error is logged and ``None`` is
    returned instead of raising. Exceptions that are not ``PyMongoError``
    propagate unchanged.

    :param tries: maximum number of attempts.
    :return: a decorator. The decorated function returns the wrapped
        function's result on success and ``None`` once all attempts failed.
    """

    def outer_wrapper(f):
        @wraps(f)
        def inner_wrapper(*args, **kwargs):
            last_error = None
            # Iterate instead of recursing so a large ``tries`` cannot hit the
            # interpreter's recursion limit.
            for _ in range(tries):
                try:
                    # Propagate the wrapped function's result to the caller.
                    return f(*args, **kwargs)
                except PyMongoError as err:
                    last_error = err
            # Best-effort semantics: report the failure (with the last driver
            # error, if any) rather than raising to the caller.
            logger.error("unable to write hit to database", exc_info=last_error)
            return None

        return inner_wrapper

    return outer_wrapper


class StringEnum(Enum):
    """Enum whose member values are mappings with ``"text"`` and ``"key"`` entries.

    ``str(member)`` renders the ``"text"`` entry and ``repr(member)`` renders
    the ``"key"`` entry, each converted with ``str()``.
    """

    def _field_as_str(self, field):
        # Look the entry up in the member's value and stringify it.
        return str(self.value[field])

    def __str__(self):
        return self._field_as_str("text")

    def __repr__(self):
        return self._field_as_str("key")


class Access(StringEnum):
    """Access level of a hit.

    Members render as their ``text`` via ``str()`` and as their short
    ``key`` via ``repr()``.
    """

    PUBLIC = dict(key="+", text="public")
    PRIVATE = dict(key="-", text="private")


class Hit:
    """A URL (expected under ``amazonaws.com``) paired with its ``Access`` level."""

    # Compiled once. The dots in the host suffix are escaped so only a literal
    # ".amazonaws.com/" matches; the original left them as "any character".
    _URL_PATTERN = re.compile(r"https?://.*\.amazonaws\.com/.*")

    def __init__(self, url: str, access: Access):
        """Store the URL and its access level as given (no validation here)."""
        self.url = url
        self.access = access

    def __iter__(self):
        """Yield ``(field, value)`` pairs so that ``dict(hit)`` builds a document.

        ``access`` is rendered with ``str()``.
        """
        yield from {"url": self.url, "access": str(self.access)}.items()

    def is_valid(self) -> bool:
        """Return True when ``url`` matches the amazonaws.com URL pattern in full
        and ``access`` is an ``Access`` member.
        """
        # isinstance() instead of `in Access`: containment raises TypeError for
        # non-members before Python 3.12 and matches raw values from 3.12 on.
        if not isinstance(self.access, Access) or not isinstance(self.url, str):
            return False
        # fullmatch() also rejects a trailing newline, which `^...$` accepted.
        return self._URL_PATTERN.fullmatch(self.url) is not None


class MongoDB:
    """Wrapper around one MongoDB collection (``db_name.col_name``).

    On construction it connects to ``host:port`` and creates the requested
    indexes. The write methods are decorated with ``with_retry(3)``: a
    ``PyMongoError`` is retried and, after the final failure, logged instead
    of raised.
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 27017,
        db_name: str = "s3recon",
        col_name: str = "hits",
        unique_indicies: tuple = ("url",),
        indicies: tuple = ("access",),
        timeout: int = 10,
    ):
        """Connect to MongoDB and create the configured indexes.

        :param host: MongoDB host name or address.
        :param port: MongoDB port.
        :param db_name: database holding the collection.
        :param col_name: collection documents are written to.
        :param unique_indicies: field names that each get a unique index.
        :param indicies: field names that each get a non-unique index.
        :param timeout: forwarded as ``serverSelectionTimeoutMS``, so it is
            interpreted in milliseconds.
        """
        # NOTE(review): the default of 10 means 10 *milliseconds* of server
        # selection time, which is very short; confirm whether seconds were meant.
        self.client = MongoClient(host, port, serverSelectionTimeoutMS=timeout)
        self.db_name = db_name
        self.col_name = col_name
        self.index(unique_indicies, unique=True)
        self.index(indicies)

    def __del__(self):
        # If MongoClient() raised inside __init__, ``client`` was never set;
        # avoid an AttributeError during garbage collection.
        client = getattr(self, "client", None)
        if client is not None:
            client.close()

    @property
    def _collection(self):
        """The target collection, ``client[db_name][col_name]``."""
        return self.client[self.db_name][self.col_name]

    def index(self, indicies=(), **kwargs):
        """Create a single-field index for each field name in ``indicies``.

        Extra keyword arguments (e.g. ``unique=True``) are forwarded to
        ``Collection.create_index``.
        """
        # Collection.ensure_index() was removed in PyMongo 4.0.
        for field in indicies:
            self._collection.create_index(field, **kwargs)

    @staticmethod
    def normalize(item):
        """Convert one hit-like object, or a list/set of them, to plain dicts.

        Each element may be anything ``dict()`` accepts, e.g. a ``Hit``
        (which iterates as key/value pairs) or a mapping. A list or set
        becomes a list of dicts; anything else becomes a single dict.
        """
        if isinstance(item, (list, set)):
            return [dict(element) for element in item]
        return dict(item)

    @with_retry(3)
    def insert_many(self, items):
        """Insert each element of ``items`` as its own document."""
        documents = self.normalize(items)
        if not documents:
            # PyMongo rejects an empty insert_many() batch with an exception
            # that with_retry does not catch; there is nothing to write anyway.
            return
        self._collection.insert_many(documents)

    @with_retry(3)
    def insert(self, item):
        """Insert ``item`` as one document, or each element if it is a list/set."""
        documents = self.normalize(item)
        # Collection.insert() was removed in PyMongo 4.0. It accepted either a
        # single document or a list of them, so dispatch the same way.
        if isinstance(documents, list):
            if documents:
                self._collection.insert_many(documents)
        else:
            self._collection.insert_one(documents)

    @with_retry(3)
    def update_many(self, filter, items):
        """Call ``Collection.update_many(filter, normalize(items), upsert=True)``.

        NOTE(review): ``update_many`` expects an update document of ``$``
        operators, or a pipeline when given a list. A list of plain documents
        is presumably rejected by the server; confirm intended semantics with
        callers.
        """
        self._collection.update_many(filter, self.normalize(items), upsert=True)

    @with_retry(3)
    def update(self, filter, item):
        """Upsert a single document matching ``filter``.

        This mirrors the removed ``Collection.update(filter, doc, upsert=True)``.
        A document whose first key is a ``$`` operator is applied with
        ``update_one``. Otherwise it replaces the matched document via
        ``replace_one``.
        """
        document = self.normalize(item)
        uses_operators = (
            isinstance(document, dict)
            and bool(document)
            and next(iter(document)).startswith("$")
        )
        if uses_operators:
            self._collection.update_one(filter, document, upsert=True)
        else:
            self._collection.replace_one(filter, document, upsert=True)

    def is_connected(self):
        """Return True if ``server_info()`` succeeds, False on a selection timeout."""
        try:
            self.client.server_info()
        except ServerSelectionTimeoutError:
            return False
        return True