import threading
import warnings
from collections import defaultdict
from enum import Enum
from functools import partial
from typing import Type, Dict, Any, Set, overload
from uuid import uuid4

from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.db import transaction
from django.db.models import Model
from django.db.models.signals import post_delete, post_save, post_init

from djangochannelsrestframework.consumers import AsyncAPIConsumer
from djangochannelsrestframework.observer.base_observer import BaseObserver

"""
1) not in transaction
2) in a simple transation with one operations
3) in a simple transation with mutliple operations
4) was in a transation but rolled back then outside of a transation saved
5) was in a transatino but rolled back then inside a new one saved
6) in a transations savepoint that is never saved
7) in a transation and then in a save point


On each model instance we add a `__observers = {self(id): {tracking info}}`
"""


class Action(Enum):
    CREATE = "create"
    UPDATE = "update"
    DELETE = "delete"


class UnsupportedWarning(Warning):
    """

    """


class ModelObserverInstanceState:
    # this is set when the instance is created
    current_groups: Set[str] = set()


class ModelObserver(BaseObserver):
    def __init__(self, func, model_cls: Type[Model], **kwargs):
        super().__init__(func)
        self._model_cls = None
        self.model_cls = model_cls  # type: Type[Model]
        self.id = uuid4()

    @property
    def model_cls(self) -> Type[Model]:
        return self._model_cls

    @model_cls.setter
    def model_cls(self, value: Type[Model]):
        was_none = self._model_cls is None
        self._model_cls = value

        if self._model_cls is not None and was_none:
            self._connect()

    def _connect(self):
        """
        Connect the signal listing.
        """

        # this is used to capture the current state for the model
        post_init.connect(
            self.post_init_receiver, sender=self.model_cls, dispatch_uid=id(self)
        )

        post_save.connect(
            self.post_save_receiver, sender=self.model_cls, dispatch_uid=id(self)
        )

        post_delete.connect(
            self.post_delete_receiver, sender=self.model_cls, dispatch_uid=id(self)
        )

    def post_init_receiver(self, instance: Model, **kwargs):

        if instance.pk is None:
            current_groups = set()
        else:
            current_groups = set(self.group_names_for_signal(instance=instance))

        self.get_observer_state(instance).current_groups = current_groups

    def get_observer_state(self, instance: Model) -> ModelObserverInstanceState:
        # use a thread local dict to be safe...
        if not hasattr(instance._state, "_thread_local_observers"):
            instance._state._thread_local_observers = defaultdict(
                ModelObserverInstanceState
            )

        return instance._state._thread_local_observers[self.id]

    def post_save_receiver(self, instance: Model, created: bool, **kwargs):
        """
        Handle the post save.
        """
        if created:
            self.database_event(instance, Action.CREATE)
        else:
            self.database_event(instance, Action.UPDATE)

    def post_delete_receiver(self, instance: Model, **kwargs):
        self.database_event(instance, Action.DELETE)

    def database_event(self, instance: Model, action: Action):

        connection = transaction.get_connection()

        if connection.in_atomic_block:
            if len(connection.savepoint_ids) > 0:
                warnings.warn(
                    "Model observation with save points is unsupported and will"
                    " result in unexpected beauvoir.",
                    UnsupportedWarning,
                )

        connection.on_commit(partial(self.post_change_receiver, instance, action))

    def post_change_receiver(self, instance: Model, action: Action, **kwargs):
        """
        Triggers the old_binding to possibly send to its group.
        """

        old_group_names = self.get_observer_state(instance).current_groups

        if action == Action.DELETE:
            new_group_names = set()
        else:
            new_group_names = set(self.group_names_for_signal(instance=instance))

        self.get_observer_state(instance).current_groups = new_group_names

        # if post delete, new_group_names should be []

        # Django DDP had used the ordering of DELETE, UPDATE then CREATE for good reasons.
        self.send_messages(
            instance, old_group_names - new_group_names, Action.DELETE, **kwargs
        )
        # the object has been updated so that its groups are not the same.
        self.send_messages(
            instance, old_group_names & new_group_names, Action.UPDATE, **kwargs
        )

        #
        self.send_messages(
            instance, new_group_names - old_group_names, Action.CREATE, **kwargs
        )

    def send_messages(
        self, instance: Model, group_names: Set[str], action: Action, **kwargs
    ):
        if not group_names:
            return
        message = self.serialize(instance, action, **kwargs)
        channel_layer = get_channel_layer()
        for group_name in group_names:
            async_to_sync(channel_layer.group_send)(group_name, message)

    def group_names(self, *args, **kwargs):
        # one channel for all updates.
        yield "{}-{}-model-{}".format(
            self._uuid, self.func.__name__.replace("_", "."), self.model_label,
        )

    def serialize(self, instance, action, **kwargs) -> Dict[str, Any]:
        message = {}
        if self._serializer:
            message = self._serializer(self, instance, action, **kwargs)
        else:
            message["pk"] = instance.pk
        message["type"] = self.func.__name__.replace("_", ".")
        message["action"] = action.value
        return message

    @property
    def model_label(self):
        model_label = (
            "{}.{}".format(
                self.model_cls._meta.app_label.lower(),
                self.model_cls._meta.object_name.lower(),
            )
            .lower()
            .replace("_", ".")
        )
        return model_label