import threading
import warnings
from collections import defaultdict
from enum import Enum
from functools import partial
from typing import Type, Dict, Any, Set, overload
from uuid import uuid4

from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.db import transaction
from django.db.models import Model
from django.db.models.signals import post_delete, post_save, post_init

from djangochannelsrestframework.consumers import AsyncAPIConsumer
from djangochannelsrestframework.observer.base_observer import BaseObserver

"""
Transaction scenarios a model change can happen in:

1) not in a transaction
2) in a simple transaction with one operation
3) in a simple transaction with multiple operations
4) was in a transaction but rolled back, then saved outside of a transaction
5) was in a transaction but rolled back, then saved inside a new one
6) in a transaction's savepoint that is never saved
7) in a transaction and then in a savepoint

On each model instance we keep a mapping of `{observer id: tracking info}`
(stored as `instance._state._thread_local_observers`).
"""

class Action(Enum):
    """Kind of database change an observer reports.

    The member value is the string written to the ``"action"`` key of
    outgoing messages.
    """

    # a new row was inserted
    CREATE = "create"
    # an existing row was saved again
    UPDATE = "update"
    # the row was removed
    DELETE = "delete"

class UnsupportedWarning(Warning):
    """Warning category for model observation scenarios that are not supported."""


class ModelObserverInstanceState:
    """Tracking information one observer keeps for one model instance."""

    # Groups the instance was last known to belong to. This is set when the
    # instance is created and replaced after each observed change.
    current_groups: Set[str] = set()

    def __init__(self):
        # Give every state object its own set. Otherwise they would all share
        # the class-level default, so an in-place ``.add()`` on one would leak
        # into every other instance's state.
        self.current_groups = set()

class ModelObserver(BaseObserver):
    """
    Observe create/update/delete events of a single Django model class.

    ``post_init``, ``post_save`` and ``post_delete`` receivers are connected for
    ``model_cls``. Each change is processed once the surrounding transaction
    commits. At that point the instance's groups before and after the change
    are compared, and a message is sent over the channel layer:

    - DELETE to every group the instance left;
    - UPDATE to every group it stayed in;
    - CREATE to every group it joined.
    """

    def __init__(self, func, model_cls: Type[Model], **kwargs):
        """
        :param func: the decorated handler; ``func.__name__`` is used to build
            the message ``type`` and the group name.
        :param model_cls: the Django model class to observe.
        :param kwargs: accepted for API compatibility and currently ignored.
        """
        # NOTE(review): assumes BaseObserver.__init__ stores ``func`` as
        # ``self.func`` and sets ``self._uuid`` (both are read below) - confirm.
        super().__init__(func)
        self._serializer = None
        # Key of this observer's per-instance state. It is assigned before the
        # signals get connected, so a receiver can never run without it.
        self.id = uuid4()
        self._model_cls = None
        # Goes through the property setter, which connects the signals.
        self.model_cls = model_cls  # type: Type[Model]

    @property
    def model_cls(self) -> Type[Model]:
        """The observed model class (``None`` until one is assigned)."""
        return self._model_cls

    @model_cls.setter
    def model_cls(self, value: Type[Model]):
        was_none = self._model_cls is None
        self._model_cls = value

        # Connect the signals only the first time a model class is assigned.
        if self._model_cls is not None and was_none:
            self._connect()

    def _connect(self):
        """
        Connect the signal listeners for ``self.model_cls``.

        ``dispatch_uid=id(self)`` makes Django ignore a repeated registration
        of the same receiver for the same sender.
        """
        # this is used to capture the current state for the model
        post_init.connect(
            self.post_init_receiver, sender=self.model_cls, dispatch_uid=id(self)
        )

        post_save.connect(
            self.post_save_receiver, sender=self.model_cls, dispatch_uid=id(self)
        )

        post_delete.connect(
            self.post_delete_receiver, sender=self.model_cls, dispatch_uid=id(self)
        )

    def post_init_receiver(self, instance: Model, **kwargs):
        """Record the groups a newly constructed/loaded instance belongs to."""
        if instance.pk is None:
            # not saved yet, so it cannot be in any group
            current_groups = set()
        else:
            current_groups = set(self.group_names_for_signal(instance=instance))

        self.get_observer_state(instance).current_groups = current_groups

    def get_observer_state(self, instance: Model) -> ModelObserverInstanceState:
        """
        Return this observer's tracking state for ``instance``.

        The state is created on first access.
        """
        # The states live on the instance's ``_state`` object, keyed by this
        # observer's id, so several observers of the same model do not clash.
        # NOTE: despite the attribute name this is a plain per-instance dict,
        # not thread-local storage.
        if not hasattr(instance._state, "_thread_local_observers"):
            instance._state._thread_local_observers = defaultdict(
                ModelObserverInstanceState
            )

        return instance._state._thread_local_observers[self.id]

    def post_save_receiver(self, instance: Model, created: bool, **kwargs):
        """
        Handle the post save.
        """
        action = Action.CREATE if created else Action.UPDATE
        # ``using`` is the database alias Django passes with the signal.
        self.database_event(instance, action, using=kwargs.get("using"))

    def post_delete_receiver(self, instance: Model, **kwargs):
        """Handle the post delete."""
        self.database_event(instance, Action.DELETE, using=kwargs.get("using"))

    def database_event(self, instance: Model, action: Action, using=None):
        """
        Schedule processing of ``action`` for when the transaction commits.

        :param using: alias of the database the change was written to;
            ``None`` selects the default database.
        """
        # Use the connection the write actually happened on, otherwise the
        # callback would be tied to the default database's transaction.
        connection = transaction.get_connection(using)

        if connection.in_atomic_block and connection.savepoint_ids:
            warnings.warn(
                "Model observation with save points is unsupported and will"
                " result in unexpected behavior.",
                UnsupportedWarning,
            )

        # Outside an atomic block Django runs the callback immediately.
        connection.on_commit(partial(self.post_change_receiver, instance, action))

    def post_change_receiver(self, instance: Model, action: Action, **kwargs):
        """
        Notify the groups affected by ``action`` on ``instance``.

        The groups recorded before the change are compared with the groups the
        instance belongs to now, and the new groups are stored. Messages then
        go out in this order:

        - DELETE to the groups it left;
        - UPDATE to the groups it stayed in;
        - CREATE to the groups it joined.
        """
        if action == Action.CREATE:
            # A just-created row had no subscribers yet, even when it was
            # constructed with an explicit pk. In that case post_init would
            # have recorded groups for it.
            old_group_names = set()
        else:
            old_group_names = self.get_observer_state(instance).current_groups

        if action == Action.DELETE:
            # a deleted row no longer belongs to any group
            new_group_names = set()
        else:
            new_group_names = set(self.group_names_for_signal(instance=instance))

        self.get_observer_state(instance).current_groups = new_group_names

        # Django DDP had used the ordering of DELETE, UPDATE then CREATE for good reasons.
        # groups the instance is no longer part of
        self.send_messages(
            instance, old_group_names - new_group_names, Action.DELETE, **kwargs
        )
        # groups the instance is still part of
        self.send_messages(
            instance, old_group_names & new_group_names, Action.UPDATE, **kwargs
        )
        # groups the instance has newly joined
        self.send_messages(
            instance, new_group_names - old_group_names, Action.CREATE, **kwargs
        )

    def send_messages(
        self, instance: Model, group_names: Set[str], action: Action, **kwargs
    ):
        """
        Send one serialized message about ``instance`` to each of ``group_names``.

        Nothing is serialized or sent when ``group_names`` is empty.
        """
        if not group_names:
            return
        message = self.serialize(instance, action, **kwargs)
        channel_layer = get_channel_layer()
        # wrap the coroutine function once rather than once per group
        group_send = async_to_sync(channel_layer.group_send)
        for group_name in group_names:
            group_send(group_name, message)

    def group_names(self, *args, **kwargs):
        """Yield the single group name used for all changes of this model."""
        # one channel for all updates.
        yield "{}-{}-model-{}".format(
            self._uuid, self.func.__name__.replace("_", "."), self.model_label,
        )

    def serialize(self, instance, action, **kwargs) -> Dict[str, Any]:
        """
        Build the message sent for ``action`` on ``instance``.

        If a custom serializer is set, it is called as
        ``serializer(self, instance, action, **kwargs)``. Otherwise the message
        holds only the primary key. ``type`` and ``action`` are always added.
        """
        if self._serializer:
            message = self._serializer(self, instance, action, **kwargs)
        else:
            message = {"pk": instance.pk}
        message["type"] = self.func.__name__.replace("_", ".")
        message["action"] = action.value
        return message

    @property
    def model_label(self):
        """
        Return the model's ``app_label.objectname`` label.

        The label is lower-cased and underscores are replaced by dots.
        """
        model_label = (
            "{}.{}".format(
                self.model_cls._meta.app_label, self.model_cls._meta.object_name
            )
            .lower()
            .replace("_", ".")
        )
        return model_label