from __future__ import unicode_literals import json import datetime from django.core.exceptions import FieldDoesNotExist from django.db import models from django.db.models.fields.related import ForeignObjectRel from django.utils.encoding import is_protected_type from django.core.serializers.json import DjangoJSONEncoder from django.conf import settings from django.utils import timezone from modelcluster.fields import ParentalKey, ParentalManyToManyField def get_field_value(field, model): if field.remote_field is None: value = field.pre_save(model, add=model.pk is None) # Make datetimes timezone aware # https://github.com/django/django/blob/master/django/db/models/fields/__init__.py#L1394-L1403 if isinstance(value, datetime.datetime) and settings.USE_TZ: if timezone.is_naive(value): default_timezone = timezone.get_default_timezone() value = timezone.make_aware(value, default_timezone).astimezone(timezone.utc) # convert to UTC value = timezone.localtime(value, timezone.utc) if is_protected_type(value): return value else: return field.value_to_string(model) else: return getattr(model, field.get_attname()) def get_serializable_data_for_fields(model): """ Return a serialised version of the model's fields which exist as local database columns (i.e. excluding m2m and incoming foreign key relations) """ pk_field = model._meta.pk # If model is a child via multitable inheritance, use parent's pk while pk_field.remote_field and pk_field.remote_field.parent_link: pk_field = pk_field.remote_field.model._meta.pk obj = {'pk': get_field_value(pk_field, model)} for field in model._meta.fields: if field.serialize: obj[field.name] = get_field_value(field, model) return obj def model_from_serializable_data(model, data, check_fks=True, strict_fks=False): pk_field = model._meta.pk # If model is a child via multitable inheritance, use parent's pk while pk_field.remote_field and pk_field.remote_field.parent_link: pk_field = pk_field.remote_field.model._meta.pk kwargs = {pk_field.attname: data['pk']} for field_name, field_value in data.items(): try: field = model._meta.get_field(field_name) except FieldDoesNotExist: continue # Filter out reverse relations if isinstance(field, ForeignObjectRel): continue if field.remote_field and isinstance(field.remote_field, models.ManyToManyRel): related_objects = field.remote_field.model._default_manager.filter(pk__in=field_value) kwargs[field.attname] = list(related_objects) elif field.remote_field and isinstance(field.remote_field, models.ManyToOneRel): if field_value is None: kwargs[field.attname] = None else: clean_value = field.remote_field.model._meta.get_field(field.remote_field.field_name).to_python(field_value) kwargs[field.attname] = clean_value if check_fks: try: field.remote_field.model._default_manager.get(**{field.remote_field.field_name: clean_value}) except field.remote_field.model.DoesNotExist: if field.remote_field.on_delete == models.DO_NOTHING: pass elif field.remote_field.on_delete == models.CASCADE: if strict_fks: return None else: kwargs[field.attname] = None elif field.remote_field.on_delete == models.SET_NULL: kwargs[field.attname] = None else: raise Exception("can't currently handle on_delete types other than CASCADE, SET_NULL and DO_NOTHING") else: value = field.to_python(field_value) # Make sure datetimes are converted to localtime if isinstance(field, models.DateTimeField) and settings.USE_TZ and value is not None: default_timezone = timezone.get_default_timezone() if timezone.is_aware(value): value = timezone.localtime(value, default_timezone) else: value = timezone.make_aware(value, default_timezone) kwargs[field.name] = value obj = model(**kwargs) if data['pk'] is not None: # Set state to indicate that this object has come from the database, so that # ModelForm validation doesn't try to enforce a uniqueness check on the primary key obj._state.adding = False return obj def get_all_child_relations(model): """ Return a list of RelatedObject records for child relations of the given model, including ones attached to ancestors of the model """ return [ field for field in model._meta.get_fields() if isinstance(field.remote_field, ParentalKey) ] def get_all_child_m2m_relations(model): """ Return a list of ParentalManyToManyFields on the given model, including ones attached to ancestors of the model """ return [ field for field in model._meta.get_fields() if isinstance(field, ParentalManyToManyField) ] class ClusterableModel(models.Model): def __init__(self, *args, **kwargs): """ Extend the standard model constructor to allow child object lists to be passed in via kwargs """ child_relation_names = ( [rel.get_accessor_name() for rel in get_all_child_relations(self)] + [field.name for field in get_all_child_m2m_relations(self)] ) if any(name in kwargs for name in child_relation_names): # One or more child relation values is being passed in the constructor; need to # separate these from the standard field kwargs to be passed to 'super' kwargs_for_super = kwargs.copy() relation_assignments = {} for rel_name in child_relation_names: if rel_name in kwargs: relation_assignments[rel_name] = kwargs_for_super.pop(rel_name) super(ClusterableModel, self).__init__(*args, **kwargs_for_super) for (field_name, related_instances) in relation_assignments.items(): setattr(self, field_name, related_instances) else: super(ClusterableModel, self).__init__(*args, **kwargs) def save(self, **kwargs): """ Save the model and commit all child relations. """ child_relation_names = [rel.get_accessor_name() for rel in get_all_child_relations(self)] child_m2m_field_names = [field.name for field in get_all_child_m2m_relations(self)] update_fields = kwargs.pop('update_fields', None) if update_fields is None: real_update_fields = None relations_to_commit = child_relation_names m2m_fields_to_commit = child_m2m_field_names else: real_update_fields = [] relations_to_commit = [] m2m_fields_to_commit = [] for field in update_fields: if field in child_relation_names: relations_to_commit.append(field) elif field in child_m2m_field_names: m2m_fields_to_commit.append(field) else: real_update_fields.append(field) super(ClusterableModel, self).save(update_fields=real_update_fields, **kwargs) for relation in relations_to_commit: getattr(self, relation).commit() for field in m2m_fields_to_commit: getattr(self, field).commit() def serializable_data(self): obj = get_serializable_data_for_fields(self) for rel in get_all_child_relations(self): rel_name = rel.get_accessor_name() children = getattr(self, rel_name).all() if hasattr(rel.related_model, 'serializable_data'): obj[rel_name] = [child.serializable_data() for child in children] else: obj[rel_name] = [get_serializable_data_for_fields(child) for child in children] for field in get_all_child_m2m_relations(self): if field.serialize: children = getattr(self, field.name).all() obj[field.name] = [child.pk for child in children] return obj def to_json(self): return json.dumps(self.serializable_data(), cls=DjangoJSONEncoder) @classmethod def from_serializable_data(cls, data, check_fks=True, strict_fks=False): """ Build an instance of this model from the JSON-like structure passed in, recursing into related objects as required. If check_fks is true, it will check whether referenced foreign keys still exist in the database. - dangling foreign keys on related objects are dealt with by either nullifying the key or dropping the related object, according to the 'on_delete' setting. - dangling foreign keys on the base object will be nullified, unless strict_fks is true, in which case any dangling foreign keys with on_delete=CASCADE will cause None to be returned for the entire object. """ obj = model_from_serializable_data(cls, data, check_fks=check_fks, strict_fks=strict_fks) if obj is None: return None child_relations = get_all_child_relations(cls) for rel in child_relations: rel_name = rel.get_accessor_name() try: child_data_list = data[rel_name] except KeyError: continue related_model = rel.related_model if hasattr(related_model, 'from_serializable_data'): children = [ related_model.from_serializable_data(child_data, check_fks=check_fks, strict_fks=True) for child_data in child_data_list ] else: children = [ model_from_serializable_data(related_model, child_data, check_fks=check_fks, strict_fks=True) for child_data in child_data_list ] children = filter(lambda child: child is not None, children) setattr(obj, rel_name, children) return obj @classmethod def from_json(cls, json_data, check_fks=True, strict_fks=False): return cls.from_serializable_data(json.loads(json_data), check_fks=check_fks, strict_fks=strict_fks) class Meta: abstract = True