"""This module contains custom field classes.""" import importlib import pickle from django.utils import six from django.utils.functional import cached_property from rest_framework import fields from rest_framework.exceptions import ValidationError, ParseError from rest_framework.serializers import SerializerMethodField from dynamic_rest.bases import ( CacheableFieldMixin, DynamicSerializerBase, resettable_cached_property ) from dynamic_rest.conf import settings from dynamic_rest.fields.common import WithRelationalFieldMixin from dynamic_rest.meta import is_field_remote, get_model_field class DynamicField(CacheableFieldMixin, fields.Field): """ Generic field base to capture additional custom field attributes. """ def __init__( self, requires=None, deferred=None, field_type=None, immutable=False, **kwargs ): """ Arguments: deferred: Whether or not this field is deferred. Deferred fields are not included in the response, unless explicitly requested. field_type: Field data type, if not inferrable from model. requires: List of fields that this field depends on. Processed by the view layer during queryset build time. """ self.requires = requires self.deferred = deferred self.field_type = field_type self.immutable = immutable self.kwargs = kwargs super(DynamicField, self).__init__(**kwargs) def to_representation(self, value): return value def to_internal_value(self, data): return data class DynamicComputedField(DynamicField): pass class DynamicMethodField(SerializerMethodField, DynamicField): def reset(self): super(DynamicMethodField, self).reset() if self.method_name == 'get_' + self.field_name: self.method_name = None class DynamicRelationField(WithRelationalFieldMixin, DynamicField): """Field proxy for a nested serializer. Supports passing in the child serializer as a class or string, and resolves to the class after binding to the parent serializer. Will proxy certain arguments to the child serializer. Attributes: SERIALIZER_KWARGS: list of arguments that are passed to the child serializer. """ SERIALIZER_KWARGS = set(('many', 'source')) def __init__( self, serializer_class, many=False, queryset=None, embed=False, sideloading=None, debug=False, **kwargs ): """ Arguments: serializer_class: Serializer class (or string representation) to proxy. many: Boolean, if relation is to-many. queryset: Default queryset to apply when filtering for related objects. sideloading: if True, force sideloading all the way down. if False, force embedding all the way down. This overrides the "embed" option if set. embed: If True, always embed related object(s). Will not sideload, and will include the full object unless specifically excluded. """ self._serializer_class = serializer_class self.bound = False self.queryset = queryset self.sideloading = sideloading self.debug = debug self.embed = embed if sideloading is None else not sideloading if '.' in kwargs.get('source', ''): raise Exception('Nested relationships are not supported') if 'link' in kwargs: self.link = kwargs.pop('link') super(DynamicRelationField, self).__init__(**kwargs) self.kwargs['many'] = self.many = many def get_model(self): """Get the child serializer's model.""" return getattr(self.serializer_class.Meta, 'model', None) def bind(self, *args, **kwargs): """Bind to the parent serializer.""" if self.bound: # Prevent double-binding return super(DynamicRelationField, self).bind(*args, **kwargs) self.bound = True parent_model = getattr(self.parent.Meta, 'model', None) remote = is_field_remote(parent_model, self.source) try: model_field = get_model_field(parent_model, self.source) except: # model field may not be available for m2o fields with no # related_name model_field = None # Infer `required` and `allow_null` if 'required' not in self.kwargs and ( remote or ( model_field and ( model_field.has_default() or model_field.null ) ) ): self.required = False if 'allow_null' not in self.kwargs and getattr( model_field, 'null', False ): self.allow_null = True self.model_field = model_field @resettable_cached_property def root_serializer(self): """Return the root serializer (serializer for the primary resource).""" if not self.parent: # Don't cache, so that we'd recompute if parent is set. return None node = self seen = set() while True: seen.add(node) if getattr(node, 'parent', None): node = node.parent if node in seen: return None else: return node def _get_cached_serializer(self, args, init_args): enabled = settings.ENABLE_SERIALIZER_CACHE root = self.root_serializer if not root or not self.field_name or not enabled: # Not enough info to use cache. return self.serializer_class(*args, **init_args) if not hasattr(root, '_descendant_serializer_cache'): # Initialize dict to use as cache on root serializer. # Arguably this is a Serializer concern, but we'll do it # here so it's agnostic to the exact type of the root # serializer (i.e. it could be a DRF serializer). root._descendant_serializer_cache = {} key_dict = { 'parent': self.parent.__class__.__name__, 'field': self.field_name, 'args': args, 'init_args': init_args } cache_key = hash(pickle.dumps(key_dict)) if cache_key not in root._descendant_serializer_cache: szr = self.serializer_class( *args, **init_args ) root._descendant_serializer_cache[cache_key] = szr else: root._descendant_serializer_cache[cache_key].reset() return root._descendant_serializer_cache[cache_key] def _inherit_parent_kwargs(self, kwargs): """Extract any necessary attributes from parent serializer to propagate down to child serializer. """ if not self.parent or not self._is_dynamic: return kwargs if 'request_fields' not in kwargs: # If 'request_fields' isn't explicitly set, pull it from the # parent serializer. request_fields = self._get_request_fields_from_parent() if request_fields is None: # Default to 'id_only' for nested serializers. request_fields = True kwargs['request_fields'] = request_fields if self.embed and kwargs.get('request_fields') is True: # If 'embed' then make sure we fetch the full object. kwargs['request_fields'] = {} if hasattr(self.parent, 'sideloading'): kwargs['sideloading'] = self.parent.sideloading if hasattr(self.parent, 'debug'): kwargs['debug'] = self.parent.debug return kwargs def get_serializer(self, *args, **kwargs): """Get an instance of the child serializer.""" init_args = { k: v for k, v in six.iteritems(self.kwargs) if k in self.SERIALIZER_KWARGS } kwargs = self._inherit_parent_kwargs(kwargs) init_args.update(kwargs) if self.embed and self._is_dynamic: init_args['embed'] = True serializer = self._get_cached_serializer(args, init_args) serializer.parent = self return serializer @resettable_cached_property def serializer(self): return self.get_serializer() @cached_property def _is_dynamic(self): """Return True if the child serializer is dynamic.""" return issubclass( self.serializer_class, DynamicSerializerBase ) def get_attribute(self, instance): serializer = self.serializer model = serializer.get_model() # attempt to optimize by reading the related ID directly # from the current instance rather than from the related object if not self.kwargs['many'] and serializer.id_only(): return instance elif model is not None: try: return getattr(instance, self.source) except model.DoesNotExist: return None else: return instance def to_representation(self, instance): """Represent the relationship, either as an ID or object.""" serializer = self.serializer model = serializer.get_model() source = self.source if not self.kwargs['many'] and serializer.id_only(): # attempt to optimize by reading the related ID directly # from the current instance rather than from the related object source_id = '%s_id' % source # try the faster way first: if hasattr(instance, source_id): return getattr(instance, source_id) elif model is not None: # this is probably a one-to-one field, or a reverse related # lookup, so let's look it up the slow way and let the # serializer handle the id dereferencing try: instance = getattr(instance, source) except model.DoesNotExist: instance = None # dereference ephemeral objects if model is None: instance = getattr(instance, source) if instance is None: return None return serializer.to_representation(instance) def to_internal_value_single(self, data, serializer): """Return the underlying object, given the serialized form.""" related_model = serializer.Meta.model if isinstance(data, related_model): return data try: instance = related_model.objects.get(pk=data) except related_model.DoesNotExist: raise ValidationError( "Invalid value for '%s': %s object with ID=%s not found" % (self.field_name, related_model.__name__, data) ) return instance def to_internal_value(self, data): """Return the underlying object(s), given the serialized form.""" if self.kwargs['many']: serializer = self.serializer.child if not isinstance(data, list): raise ParseError("'%s' value must be a list" % self.field_name) return [ self.to_internal_value_single( instance, serializer ) for instance in data ] return self.to_internal_value_single(data, self.serializer) @property def serializer_class(self): """Get the class of the child serializer. Resolves string imports. """ serializer_class = self._serializer_class if not isinstance(serializer_class, six.string_types): return serializer_class parts = serializer_class.split('.') module_path = '.'.join(parts[:-1]) if not module_path: if getattr(self, 'parent', None) is None: raise Exception( "Can not load serializer '%s'" % serializer_class + ' before binding or without specifying full path') # try the module of the parent class module_path = self.parent.__module__ module = importlib.import_module(module_path) serializer_class = getattr(module, parts[-1]) self._serializer_class = serializer_class return serializer_class class CountField(DynamicComputedField): """ Computed field that counts the number of elements in another field. """ def __init__(self, serializer_source, *args, **kwargs): """ Arguments: serializer_source: A serializer field. unique: Whether or not to perform a count of distinct elements. """ self.field_type = int # Use `serializer_source`, which indicates a field at the API level, # instead of `source`, which indicates a field at the model level. self.serializer_source = serializer_source # Set `source` to an empty value rather than the field name to avoid # an attempt to look up this field. kwargs['source'] = '' self.unique = kwargs.pop('unique', True) return super(CountField, self).__init__(*args, **kwargs) def get_attribute(self, obj): source = self.serializer_source if source not in self.parent.fields: return None value = self.parent.fields[source].get_attribute(obj) data = self.parent.fields[source].to_representation(value) # How to count None is undefined... let the consumer decide. if data is None: return None # Check data type. Technically len() works on dicts, strings, but # since this is a "count" field, we'll limit to list, set, tuple. if not isinstance(data, (list, set, tuple)): raise TypeError( "'%s' is %s. Must be list, set or tuple to be countable." % ( source, type(data) ) ) if self.unique: # Try to create unique set. This may fail if `data` contains # non-hashable elements (like dicts). try: data = set(data) except TypeError: pass return len(data)