import copy

from django.db.models import Prefetch
from django.db.models.fields.related import ManyToManyRel, ManyToOneRel
from django.http import QueryDict
from django.utils.functional import cached_property
from rest_framework.fields import empty
from rest_framework.serializers import (
    ListSerializer, Serializer, ValidationError
)

from .exceptions import FieldNotFound, QueryFormatError
from .fields import (
    BaseReplaceableNestedField, BaseRESTQLNestedField,
    BaseWritableNestedField, DynamicSerializerMethodField
)
from .operations import ADD, CREATE, REMOVE, UPDATE
from .parser import Parser
from .settings import restql_settings


class RequestQueryParserMixin(object):
    """
    Mixin for parsing restql query from request.

    NOTE: We are using `request.GET` instead of
    `request.query_params` because this might be
    called before DRF request is created(i.e from dispatch).
    This means `request.query_params` might not be available
    when this mixin is used.
    """

    @staticmethod
    def get_restql_query_param_name():
        """Return the query-string parameter name carrying the restql query.

        Reads `QUERY_PARAM_NAME` from `restql_settings` and falls back to
        ``'query'`` when that setting is not present.
        """
        DEFAULT_QUERY_PARAM_NAME = 'query'
        query_param_name = getattr(
            restql_settings, "QUERY_PARAM_NAME", DEFAULT_QUERY_PARAM_NAME
        )
        return query_param_name

    @classmethod
    def has_restql_query_param(cls, request):
        """Return True if `request.GET` contains the restql query parameter."""
        query_param_name = cls.get_restql_query_param_name()
        return query_param_name in request.GET

    @classmethod
    def get_raw_restql_query(cls, request):
        """Return the unparsed restql query string from `request.GET`.

        `request.GET` is indexed directly, so a missing parameter raises the
        mapping's KeyError; check `has_restql_query_param` first.
        """
        query_param_name = cls.get_restql_query_param_name()
        return request.GET[query_param_name]

    @classmethod
    def get_parsed_restql_query_from_req(cls, request):
        """Parse the request's restql query, caching the result on `request`.

        The parsed query is stored as `request.parsed_restql_query`, so later
        calls for the same request reuse it instead of parsing again.
        Errors raised by `Parser` propagate unchanged (callers in this module
        handle `SyntaxError` and `QueryFormatError`); nothing is cached when
        parsing fails.
        """
        if hasattr(request, 'parsed_restql_query'):
            # Use cached parsed restql query
            return request.parsed_restql_query
        raw_query = cls.get_raw_restql_query(request)
        parser = Parser(raw_query)
        parsed_restql_query = parser.get_parsed()

        # Save parsed restql query to the request so that
        # we won't need to parse it again if needed later
        request.parsed_restql_query = parsed_restql_query
        return parsed_restql_query


class QueryArgumentsMixin(RequestQueryParserMixin):
    """Mixin for converting query arguments into query parameters"""

    def get_parsed_restql_query(self, request):
        """Return the parsed restql query, or an include-everything default.

        A malformed query is deliberately ignored here: `DynamicFieldsMixin`
        parses it again later and reports a helpful `ValidationError`.
        """
        if self.has_restql_query_param(request):
            try:
                return self.get_parsed_restql_query_from_req(request)
            except (SyntaxError, QueryFormatError):
                # Let `DynamicFieldsMixin` handle this for a user
                # to get a helpful error message
                pass

        query = {
            "include": ["*"],
            "exclude": [],
            "arguments": {}
        }
        return query

    def build_query_params(self, parsed_query, parent=None):
        """Flatten the arguments of a parsed query into a dict of params.

        Top-level arguments keep their names; arguments of nested fields are
        prefixed with the path of nested field names joined by ``__``
        (e.g. ``course__books__title``).

        Args:
            parsed_query: parsed restql query with `arguments` and `include`.
            parent: path prefix of the nested field being processed, or None
                for the top-level query.
        """
        prefix = '' if parent is None else parent + '__'

        query_params = {
            prefix + argument: value
            for argument, value in parsed_query['arguments'].items()
        }

        for field in parsed_query['include']:
            if isinstance(field, dict):
                # Nested field: {name: parsed sub-query}
                for sub_field, sub_parsed_query in field.items():
                    nested_query_params = self.build_query_params(
                        sub_parsed_query,
                        parent=prefix + sub_field
                    )
                    query_params.update(nested_query_params)
        return query_params

    def get_query_params(self, request):
        """Return the flattened query params derived from the request query."""
        parsed = self.get_parsed_restql_query(request)
        query_params = self.build_query_params(parsed)
        return query_params

    def dispatch(self, request, *args, **kwargs):
        """Merge query arguments into `request.GET`, then dispatch normally."""
        query_params = self.get_query_params(request)

        # We are using `request.GET` instead of `request.query_params`
        # because at this point DRF request is not yet created so
        # `request.query_params` is not yet available
        params = request.GET.copy()
        params.update(query_params)

        # Make QueryDict immutable after updating
        request.GET = QueryDict(params.urlencode(), mutable=False)
        return super().dispatch(request, *args, **kwargs)


class DynamicFieldsMixin(RequestQueryParserMixin):
    """Serializer mixin that restricts fields according to a restql query."""

    def __init__(self, *args, **kwargs):
        # Don't pass 'query', 'fields', 'exclude', 'return_pk'
        # and 'disable_dynamic_fields' kwargs to the superclass
        self.parsed_restql_query = kwargs.pop('query', None)
        self.allowed_fields = kwargs.pop('fields', None)
        self.excluded_fields = kwargs.pop('exclude', None)
        self.return_pk = kwargs.pop('return_pk', False)
        self.disable_dynamic_fields = kwargs.pop(
            'disable_dynamic_fields', False
        )

        is_field_kwarg_set = self.allowed_fields is not None
        is_exclude_kwarg_set = self.excluded_fields is not None
        msg = "May not set both `fields` and `exclude`"
        assert not(is_field_kwarg_set and is_exclude_kwarg_set), msg

        # flag to toggle using restql fields
        self._use_restql_fields = False

        # Instantiate the superclass normally
        super().__init__(*args, **kwargs)

    def to_representation(self, instance):
        """Serialize `instance` using restql fields (or just its pk)."""
        # Activate to use restql fields
        self._use_restql_fields = True

        if self.return_pk:
            return instance.pk
        return super().to_representation(instance)

    def get_allowed_fields(self):
        """Return the serializer fields after applying `fields`/`exclude` kwargs.

        Note: the returned mapping is `self._all_fields` itself, so fields are
        removed from it in place.

        Raises:
            FieldNotFound: if a name given in `fields` or `exclude` is not a
                field of this serializer.
        """
        fields = self._all_fields
        if self.allowed_fields is not None:
            # Drop all fields which are not specified on the `fields` kwarg.
            # The symmetric difference also contains names listed in `fields`
            # that don't exist; popping those raises KeyError, which is
            # reported as FieldNotFound.
            allowed = set(self.allowed_fields)
            existing = set(fields)
            not_allowed = existing.symmetric_difference(allowed)
            for field_name in not_allowed:
                try:
                    fields.pop(field_name)
                except KeyError:
                    msg = "Field `%s` is not found" % field_name
                    raise FieldNotFound(msg) from None

        if self.excluded_fields is not None:
            # Drop all fields specified on the `exclude` kwarg.
            not_allowed = set(self.excluded_fields)
            for field_name in not_allowed:
                try:
                    fields.pop(field_name)
                except KeyError:
                    msg = "Field `%s` is not found" % field_name
                    raise FieldNotFound(msg) from None
        return fields

    @staticmethod
    def is_field_found(field_name, all_field_names, raise_exception=False):
        """Return whether `field_name` is in `all_field_names`.

        Raises ValidationError instead of returning False when
        `raise_exception` is true.
        """
        if field_name in all_field_names:
            return True
        else:
            if raise_exception:
                msg = "'%s' field is not found" % field_name
                raise ValidationError(msg)
            return False

    @staticmethod
    def is_nested_field(field_name, field, raise_exception=False):
        """Return whether `field` can take a sub-query (is a nested field).

        Serializers, list serializers and `DynamicSerializerMethodField`
        count as nested. Raises ValidationError instead of returning False
        when `raise_exception` is true.
        """
        nested_classes = (
            Serializer, ListSerializer,
            DynamicSerializerMethodField
        )
        if isinstance(field, nested_classes):
            return True
        else:
            if raise_exception:
                msg = "'%s' is not a nested field" % field_name
                raise ValidationError(msg)
            return False

    def include_fields(self):
        """Keep only the fields listed in the query's `include` list.

        Sub-queries of nested fields are recorded on `self.nested_fields` so
        child serializers can look them up in `restql_fields`.

        Raises:
            ValidationError: if a listed field does not exist, or a field
                given a sub-query is not a nested field.
        """
        all_fields = self.get_allowed_fields()
        all_field_names = list(all_fields.keys())

        allowed_flat_fields = []

        # The format is {nested_field: <parsed sub-query>, ...}
        allowed_nested_fields = {}

        # The self.parsed_restql_query["include"]
        # contains a list of allowed fields,
        # The format is [field, {nested_field: <parsed sub-query>} ...]
        included_fields = self.parsed_restql_query["include"]
        include_all_fields = False
        for field in included_fields:
            if field == "*":
                # Include all fields
                include_all_fields = True
                continue
            if isinstance(field, dict):
                # Nested field
                for nested_field in field:
                    self.is_field_found(
                        nested_field, all_field_names, raise_exception=True
                    )
                    self.is_nested_field(
                        nested_field, all_fields[nested_field],
                        raise_exception=True
                    )
                allowed_nested_fields.update(field)
            else:
                # Flat field
                self.is_field_found(field, all_field_names, raise_exception=True)
                allowed_flat_fields.append(field)
        self.nested_fields = allowed_nested_fields

        if include_all_fields:
            # Return all fields
            return all_fields

        # A set gives O(1) membership tests in the loop below
        all_allowed_fields = set(allowed_flat_fields)
        all_allowed_fields.update(allowed_nested_fields)
        for field in all_field_names:
            if field not in all_allowed_fields:
                all_fields.pop(field)

        return all_fields

    def exclude_fields(self):
        """Drop the fields listed in the query's `exclude` list.

        In exclude mode the `include` list only carries the `*` flag and the
        sub-queries of expanded nested fields; those sub-queries are recorded
        on `self.nested_fields`.

        Raises:
            ValidationError: if a referenced field does not exist, or a field
                given a sub-query is not a nested field.
        """
        all_fields = self.get_allowed_fields()
        all_field_names = list(all_fields.keys())

        # The format is {nested_field: <parsed sub-query>, ...}
        allowed_nested_fields = {}

        # The self.parsed_restql_query["include"]
        # contains a list of expanded nested fields
        # The format is [{nested_field: <parsed sub-query>} ...]
        nested_fields = self.parsed_restql_query["include"]
        for field in nested_fields:
            if field == "*":
                # Ignore this since it's not an actual field(it's just a flag)
                continue
            for nested_field in field:
                self.is_field_found(
                    nested_field, all_field_names, raise_exception=True
                )
                self.is_nested_field(
                    nested_field, all_fields[nested_field], raise_exception=True
                )
            allowed_nested_fields.update(field)

        # self.parsed_restql_query["exclude"]
        # is a list of names of excluded fields
        excluded_fields = self.parsed_restql_query["exclude"]
        for field in excluded_fields:
            self.is_field_found(field, all_field_names, raise_exception=True)
            all_fields.pop(field)

        self.nested_fields = allowed_nested_fields
        return all_fields

    @cached_property
    def restql_fields(self):
        """Compute (once per serializer) the fields selected by the query.

        Where the query comes from depends on the serializer's position:
        - top-level retrieve/list serializer: the `query` kwarg if given,
          otherwise the request's query (parse errors become
          `ValidationError`);
        - child of a `ListSerializer`: the grandparent's `nested_fields`
          entry for the list's field name;
        - child of a `Serializer`: the parent's `nested_fields` entry for
          this field's name.
        With no applicable query, all allowed fields are returned.
        """
        request = self.context.get('request')

        is_not_a_request_to_process = (
            request is None or
            self.disable_dynamic_fields or
            not self.has_restql_query_param(request)
        )

        if is_not_a_request_to_process:
            return self.get_allowed_fields()

        is_top_retrieve_request = (
            self.field_name is None and
            self.parent is None
        )
        is_top_list_request = (
            isinstance(self.parent, ListSerializer) and
            self.parent.parent is None and
            self.parent.field_name is None
        )

        if is_top_retrieve_request or is_top_list_request:
            if self.parsed_restql_query is None:
                # Use a parsed query from the request
                try:
                    self.parsed_restql_query = \
                        self.get_parsed_restql_query_from_req(request)
                except SyntaxError as e:
                    # `msg`/`text` may be None; format instead of `+` so the
                    # user still gets a ValidationError, not a TypeError
                    msg = "QuerySyntaxError: %s on %s" % (e.msg, e.text)
                    raise ValidationError(msg) from None
                except QueryFormatError as e:
                    msg = "QueryFormatError: " + str(e)
                    raise ValidationError(msg) from None

        elif isinstance(self.parent, ListSerializer):
            field_name = self.parent.field_name
            parent = self.parent.parent
            if hasattr(parent, "nested_fields"):
                parent_nested_fields = parent.nested_fields
                self.parsed_restql_query = \
                    parent_nested_fields.get(field_name, None)
        elif isinstance(self.parent, Serializer):
            field_name = self.field_name
            parent = self.parent
            if hasattr(parent, "nested_fields"):
                parent_nested_fields = parent.nested_fields
                self.parsed_restql_query = \
                    parent_nested_fields.get(field_name, None)

        if self.parsed_restql_query is None:
            # No filtering on nested fields
            # Retrieve all nested fields
            return self.get_allowed_fields()

        # NOTE: self.parsed_restql_query["include"] not being empty
        # is not a guarantee that the exclude operator(-) has not been
        # used because the same self.parsed_restql_query["include"]
        # is used to store nested fields when the exclude operator(-) is used
        if self.parsed_restql_query["exclude"]:
            # Exclude fields from a query
            return self.exclude_fields()
        elif self.parsed_restql_query["include"]:
            # Here we are sure that self.parsed_restql_query["exclude"]
            # is empty which means the exclude operator(-) is not used,
            # so self.parsed_restql_query["include"] contains only fields
            # to include
            return self.include_fields()
        else:
            # The query is empty i.e query={}
            # return nothing
            return {}

    @cached_property
    def _all_fields(self):
        """All fields declared by the underlying serializer (cached)."""
        return super().fields

    @property
    def fields(self):
        """Restql-filtered fields once representation has started, else all."""
        if self._use_restql_fields:
            # Use restql fields
            return self.restql_fields
        return self._all_fields


class EagerLoadingMixin(RequestQueryParserMixin):
    """View mixin applying `select_related`/`prefetch_related` per query."""

    @property
    def parsed_restql_query(self):
        """
        Gets parsed query for use in eager loading.
        Defaults to the serializer parsed query assuming
        using django-restql DynamicsFieldMixin.
        """
        if self.has_restql_query_param(self.request):
            try:
                return self.get_parsed_restql_query_from_req(self.request)
            except (SyntaxError, QueryFormatError):
                # Let `DynamicFieldsMixin` handle this for a user
                # to get a helpful error message
                pass

        # Else include all fields
        query = {
            "include": ["*"],
            "exclude": [],
            "arguments": {}
        }
        return query

    @property
    def should_auto_apply_eager_loading(self):
        """View's `auto_apply_eager_loading` if set, else the global setting."""
        if hasattr(self, 'auto_apply_eager_loading'):
            return self.auto_apply_eager_loading
        return getattr(
            restql_settings, "AUTO_APPLY_EAGER_LOADING", True
        )

    def get_select_related_mapping(self):
        """Return the view's `select_related` mapping (empty if undefined)."""
        if hasattr(self, "select_related"):
            return self.select_related
        # Else select nothing
        return {}

    def get_prefetch_related_mapping(self):
        """Return the view's `prefetch_related` mapping (empty if undefined)."""
        if hasattr(self, "prefetch_related"):
            return self.prefetch_related
        # Else prefetch nothing
        return {}

    @classmethod
    def get_dict_parsed_restql_query(cls, parsed_restql_query):
        """
        Returns the parsed query as a dict.

        Each key is a field name mapped to True (included), False
        (excluded) or, for nested fields, the same kind of dict built
        recursively from the nested sub-query.
        """
        keys = {}
        include = parsed_restql_query.get("include", [])
        exclude = parsed_restql_query.get("exclude", [])

        for item in include:
            if isinstance(item, str):
                keys[item] = True
            elif isinstance(item, dict):
                for key, nested_items in item.items():
                    keys[key] = cls.get_dict_parsed_restql_query(nested_items)

        for item in exclude:
            if isinstance(item, str):
                keys[item] = False
            elif isinstance(item, dict):
                for key, nested_items in item.items():
                    keys[key] = cls.get_dict_parsed_restql_query(nested_items)

        return keys

    @staticmethod
    def get_related_fields(related_fields_mapping, dict_parsed_restql_query):
        """
        Returns only whitelisted related fields from a query to be used on
        `select_related` and `prefetch_related`

        Args:
            related_fields_mapping: {"dotted.query.path": related_field(s)},
                where the value is a str, a `Prefetch`, or a list of them.
            dict_parsed_restql_query: output of
                `get_dict_parsed_restql_query`.
        """
        related_fields = []
        for key, related_field in related_fields_mapping.items():
            fields = key.split(".")
            if isinstance(related_field, (str, Prefetch)):
                related_field = [related_field]

            query_node = dict_parsed_restql_query
            for field in fields:
                if not isinstance(query_node, dict):
                    # A True/False leaf has no finer-grained query, so its
                    # value decides the whole remaining path
                    break
                if field in query_node:
                    # Get a more specific query node
                    query_node = query_node[field]
                elif "*" in query_node:
                    # Not mentioned, but the wildcard includes this field
                    # with all of its sub fields, so the rest of the path is
                    # included too. (Descending no further prevents later
                    # segments from matching keys at this shallower level.)
                    query_node = True
                    break
                else:
                    # The field is not included in a query so
                    # don't include this field in `related_fields`
                    query_node = False
                    break

            if isinstance(query_node, dict) or query_node:
                related_fields.extend(related_field)

        return related_fields

    def apply_eager_loading(self, queryset):
        """
        Applies appropriate select_related and prefetch_related calls on a
        queryset
        """
        query = self.get_dict_parsed_restql_query(self.parsed_restql_query)
        select_mapping = self.get_select_related_mapping()
        prefetch_mapping = self.get_prefetch_related_mapping()

        to_select = self.get_related_fields(select_mapping, query)
        to_prefetch = self.get_related_fields(prefetch_mapping, query)

        if to_select:
            queryset = queryset.select_related(*to_select)
        if to_prefetch:
            queryset = queryset.prefetch_related(*to_prefetch)
        return queryset

    def get_eager_queryset(self, queryset):
        """Return `queryset` with eager loading applied."""
        return self.apply_eager_loading(queryset)

    def get_queryset(self):
        """
        Override for DRF's get_queryset on the view.
        If get_queryset is not present, we don't try to run this.
        Instead, this can still be used by manually calling
        self.get_eager_queryset and passing in the queryset desired.
        """
        if hasattr(super(), "get_queryset"):
            queryset = super().get_queryset()
            if self.should_auto_apply_eager_loading:
                queryset = self.get_eager_queryset(queryset)
            return queryset


class BaseNestedMixin(object):
    """Shared setup for the writable nested serializer mixins."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The order in which these two methods are called is important
        self.build_restql_nested_fields()
        self.build_restql_source_field_map()

    def build_restql_nested_fields(self):
        """Map field_name -> field for every restql nested field."""
        self.restql_nested_fields = {}
        for name, field in self.fields.items():
            if isinstance(field, BaseRESTQLNestedField):
                self.restql_nested_fields.update({name: field})

    def build_restql_source_field_map(self):
        """Map field source -> field for every restql nested field.

        Must run after `build_restql_nested_fields`, which it reads.
        """
        self.restql_source_field_map = {}
        for field in self.restql_nested_fields.values():
            # Get the actual source of the field
            self.restql_source_field_map.update({field.source: field})

    def to_internal_value(self, data):
        """Validate `data`; on partial updates drop nested fields left `empty`."""
        validated_data = super().to_internal_value(data)
        if self.partial:
            # `empty` is a sentinel, so compare by identity
            empty_fields = [
                field for field in self.restql_source_field_map
                if field in validated_data and validated_data[field] is empty
            ]
            for field in empty_fields:
                # Ignore empty fields for partial update
                validated_data.pop(field)
        return validated_data


class NestedCreateMixin(BaseNestedMixin):
    """ Create Mixin """

    def create_writable_foreignkey_related(self, data):
        """Create objects for writable foreign-key nested fields.

        Args:
            data: {field: {sub_field: value}} (a value may be None).
        Returns:
            {field: created object or None}
        """
        objs = {}
        nested_fields = self.restql_source_field_map
        for field, value in data.items():
            # Get nested field serializer
            serializer = nested_fields[field]
            serializer_class = type(serializer)
            kwargs = serializer.validation_kwargs
            serializer = serializer_class(
                **kwargs,
                data=value,
                context=self.context,
                partial=serializer.is_partial
            )
            serializer.is_valid()
            if value is None:
                objs.update({field: None})
            else:
                obj = serializer.save()
                objs.update({field: obj})
        return objs

    def bulk_create_objs(self, field, data):
        """Create one object per dict in `data` via `field`'s child serializer.

        Returns the list of created primary keys.
        """
        nested_fields = self.restql_source_field_map

        # Get nested field serializer
        serializer = nested_fields[field].child
        serializer_class = type(serializer)
        kwargs = serializer.validation_kwargs
        pks = []
        for values in data:
            serializer = serializer_class(
                **kwargs,
                data=values,
                context=self.context,
                partial=serializer.is_partial
            )
            serializer.is_valid()
            obj = serializer.save()
            pks.append(obj.pk)
        return pks

    def create_many_to_one_related(self, instance, data):
        """Attach/create reverse foreign-key objects pointing at `instance`.

        Args:
            data: {field: {ADD: [pks], CREATE: [{sub_field: value}]}}
        Returns:
            {field: [pks]} with the pks of every added AND created object.
        """
        field_pks = {}
        model = self.Meta.model
        nested_fields = self.restql_source_field_map
        for field, values in data.items():
            foreignkey = getattr(model, field).field.name
            for operation in values:
                if operation == ADD:
                    pks = values[operation]
                    related_model = nested_fields[field].child.Meta.model
                    qs = related_model.objects.filter(pk__in=pks)
                    qs.update(**{foreignkey: instance.pk})
                    field_pks.setdefault(field, []).extend(pks)
                elif operation == CREATE:
                    for v in values[operation]:
                        v.update({foreignkey: instance.pk})
                    pks = self.bulk_create_objs(field, values[operation])
                    field_pks.setdefault(field, []).extend(pks)
        return field_pks

    def create_many_to_many_related(self, instance, data):
        """Add existing / newly created objects to `instance`'s m2m fields.

        Args:
            data: {field: {ADD: [pks], CREATE: [{sub_field: value}]}}
        Returns:
            {field: [pks]} with the pks of every added AND created object.
        """
        field_pks = {}
        for field, values in data.items():
            obj = getattr(instance, field)
            for operation in values:
                if operation == ADD:
                    pks = values[operation]
                    obj.add(*pks)
                    field_pks.setdefault(field, []).extend(pks)
                elif operation == CREATE:
                    pks = self.bulk_create_objs(field, values[operation])
                    obj.add(*pks)
                    field_pks.setdefault(field, []).extend(pks)
        return field_pks

    def create(self, validated_data):
        """Create the instance, handling restql nested fields.

        Nested values are removed from `validated_data` first. Foreign-key
        objects are resolved/created before the instance (so they can be
        passed to `super().create`); m2m and reverse-FK objects are handled
        after it exists.
        """
        fields = {
            "foreignkey_related": {
                "replaceable": {},
                "writable": {}
            },
            "many_to": {
                "many_related": {},
                "one_related": {}
            }
        }

        # Make a shallow copy of validated_data so that we can
        # iterate and alter it
        data = copy.copy(validated_data)
        nested_fields = self.restql_source_field_map
        for field in data:
            if field not in nested_fields:
                # Not a nested field
                continue
            else:
                field_serializer = nested_fields[field]

            if isinstance(field_serializer, Serializer):
                if isinstance(field_serializer, BaseReplaceableNestedField):
                    value = validated_data.pop(field)
                    fields["foreignkey_related"]["replaceable"] \
                        .update({field: value})
                elif isinstance(field_serializer, BaseWritableNestedField):
                    value = validated_data.pop(field)
                    fields["foreignkey_related"]["writable"]\
                        .update({field: value})
            elif isinstance(field_serializer, ListSerializer) and \
                    isinstance(field_serializer, BaseWritableNestedField):

                model = self.Meta.model
                rel = getattr(model, field).rel

                if isinstance(rel, ManyToOneRel):
                    value = validated_data.pop(field)
                    fields["many_to"]["one_related"].update({field: value})
                elif isinstance(rel, ManyToManyRel):
                    value = validated_data.pop(field)
                    fields["many_to"]["many_related"].update({field: value})
            else:
                pass

        foreignkey_related = {
            **fields["foreignkey_related"]["replaceable"],
            **self.create_writable_foreignkey_related(
                fields["foreignkey_related"]["writable"]
            )
        }

        instance = super().create({**validated_data, **foreignkey_related})

        self.create_many_to_many_related(
            instance,
            fields["many_to"]["many_related"]
        )

        self.create_many_to_one_related(
            instance,
            fields["many_to"]["one_related"]
        )

        return instance


class NestedUpdateMixin(BaseNestedMixin):
    """ Update Mixin """

    @staticmethod
    def constrain_error_prefix(field):
        """Prefix used for errors raised while changing `field`'s relations."""
        return "Error on %s field: " % (field,)

    @staticmethod
    def update_replaceable_foreignkey_related(instance, data):
        """Assign replacement objects to FK fields, saving `instance` each time.

        Args:
            data: {field: obj}
        """
        objs = {}
        for field, nested_obj in data.items():
            setattr(instance, field, nested_obj)
            instance.save()
            objs.update({field: instance})
        return objs

    def update_writable_foreignkey_related(self, instance, data):
        """Update (or create) the objects behind writable FK nested fields.

        Args:
            data: {field: {sub_field: value}} (a value may be None, which
                clears the relation).
        """
        objs = {}
        nested_fields = self.restql_source_field_map
        for field, values in data.items():
            # Get nested field serializer
            serializer = nested_fields[field]
            serializer_class = type(serializer)
            kwargs = serializer.validation_kwargs
            nested_obj = getattr(instance, field)
            serializer = serializer_class(
                nested_obj,
                **kwargs,
                data=values,
                context=self.context,
                partial=serializer.is_partial
            )
            serializer.is_valid()
            if values is None:
                setattr(instance, field, None)
                objs.update({field: None})
            else:
                obj = serializer.save()
                if nested_obj is None:
                    # Patch back newly created object to instance
                    setattr(instance, field, obj)
                    objs.update({field: obj})
                else:
                    objs.update({field: nested_obj})
        return objs

    def bulk_create_many_to_many_related(self, field, nested_obj, data):
        """Create objects from `data` and add them to m2m manager `nested_obj`."""
        # Get nested field serializer
        serializer = self.restql_source_field_map[field].child
        serializer_class = type(serializer)
        kwargs = serializer.validation_kwargs
        pks = []
        for values in data:
            serializer = serializer_class(
                **kwargs,
                data=values,
                context=self.context,
                partial=serializer.is_partial
            )
            serializer.is_valid()
            obj = serializer.save()
            pks.append(obj.pk)
        nested_obj.add(*pks)
        return pks

    def bulk_create_many_to_one_related(self, field, nested_obj, data):
        """Create objects from `data` (already carrying the FK); return pks."""
        # Get nested field serializer
        serializer = self.restql_source_field_map[field].child
        serializer_class = type(serializer)
        kwargs = serializer.validation_kwargs
        pks = []
        for values in data:
            serializer = serializer_class(
                **kwargs,
                data=values,
                context=self.context,
                partial=serializer.is_partial
            )
            serializer.is_valid()
            obj = serializer.save()
            pks.append(obj.pk)
        return pks

    def bulk_update_many_to_many_related(self, field, nested_obj, data):
        """Update m2m-related objects.

        Args:
            data: {pk: {sub_field: values}}; each pk is looked up through
                `nested_obj`.
        """
        objs = []

        # Get nested field serializer
        serializer = self.restql_source_field_map[field].child
        serializer_class = type(serializer)
        kwargs = serializer.validation_kwargs
        for pk, values in data.items():
            obj = nested_obj.get(pk=pk)
            serializer = serializer_class(
                obj,
                **kwargs,
                data=values,
                context=self.context,
                partial=serializer.is_partial
            )
            serializer.is_valid()
            obj = serializer.save()
            objs.append(obj)
        return objs

    def bulk_update_many_to_one_related(self, field, instance, data):
        """Update reverse-FK objects of `instance`, keeping the FK set to it.

        Args:
            data: {pk: {sub_field: values}}
        """
        objs = []

        # Get nested field serializer
        serializer = self.restql_source_field_map[field].child
        serializer_class = type(serializer)
        kwargs = serializer.validation_kwargs
        model = self.Meta.model
        foreignkey = getattr(model, field).field.name
        nested_obj = getattr(instance, field)
        for pk, values in data.items():
            obj = nested_obj.get(pk=pk)
            values.update({foreignkey: instance.pk})
            serializer = serializer_class(
                obj,
                **kwargs,
                data=values,
                context=self.context,
                partial=serializer.is_partial
            )
            serializer.is_valid()
            obj = serializer.save()
            objs.append(obj)
        return objs

    def update_many_to_one_related(self, instance, data):
        """Apply ADD/CREATE/REMOVE/UPDATE operations to reverse-FK fields.

        Args:
            data: {field: {
                ADD: [pks],
                CREATE: [{sub_field: value}],
                REMOVE: [pks],
                UPDATE: {pk: {sub_field: value}}
            }}
        REMOVE deletes the matching related objects (not just unlinks them).

        Raises:
            ValidationError: on an unknown operation.
        """
        for field, values in data.items():
            nested_obj = getattr(instance, field)
            model = self.Meta.model
            foreignkey = getattr(model, field).field.name
            nested_fields = self.restql_source_field_map

            for operation in values:
                if operation == ADD:
                    pks = values[operation]
                    model = nested_fields[field].child.Meta.model
                    qs = model.objects.filter(pk__in=pks)
                    qs.update(**{foreignkey: instance.pk})
                elif operation == CREATE:
                    for v in values[operation]:
                        v.update({foreignkey: instance.pk})
                    self.bulk_create_many_to_one_related(
                        field,
                        nested_obj,
                        values[operation]
                    )
                elif operation == REMOVE:
                    qs = nested_obj.all()
                    qs.filter(pk__in=values[operation]).delete()
                elif operation == UPDATE:
                    self.bulk_update_many_to_one_related(
                        field,
                        instance,
                        values[operation]
                    )
                else:
                    message = (
                        "%s is an invalid operation, " % (operation,)
                    )
                    raise ValidationError(message)
        return instance

    def update_many_to_many_related(self, instance, data):
        """Apply ADD/CREATE/REMOVE/UPDATE operations to m2m fields.

        Args:
            data: {field: {
                ADD: [pks],
                CREATE: [{sub_field: value}],
                REMOVE: [pks],
                UPDATE: {pk: {sub_field: value}}
            }}

        Raises:
            ValidationError: on an unknown operation, or when adding/removing
                pks fails.
        """
        for field, values in data.items():
            nested_obj = getattr(instance, field)
            for operation in values:
                if operation == ADD:
                    pks = values[operation]
                    try:
                        nested_obj.add(*pks)
                    except Exception as e:
                        msg = self.constrain_error_prefix(field) + str(e)
                        raise ValidationError(msg) from None
                elif operation == CREATE:
                    self.bulk_create_many_to_many_related(
                        field,
                        nested_obj,
                        values[operation]
                    )
                elif operation == REMOVE:
                    pks = values[operation]
                    try:
                        nested_obj.remove(*pks)
                    except Exception as e:
                        msg = self.constrain_error_prefix(field) + str(e)
                        raise ValidationError(msg) from None
                elif operation == UPDATE:
                    self.bulk_update_many_to_many_related(
                        field,
                        nested_obj,
                        values[operation]
                    )
                else:
                    message = (
                        "%s is an invalid operation, " % (operation,)
                    )
                    raise ValidationError(message)
        return instance

    def update(self, instance, validated_data):
        """Update `instance`, handling restql nested fields before the rest.

        Nested values are removed from `validated_data`, applied via the
        helpers above, and the remaining data goes to `super().update`.
        """
        fields = {
            "foreignkey_related": {
                "replaceable": {},
                "writable": {}
            },
            "many_to": {
                "many_related": {},
                "one_related": {}
            }
        }

        # Make a shallow copy of validated_data so that we can
        # iterate and alter it
        data = copy.copy(validated_data)
        nested_fields = self.restql_source_field_map
        for field in data:
            # Not a nested field
            if field not in nested_fields:
                continue
            else:
                field_serializer = nested_fields[field]

            if isinstance(field_serializer, Serializer):
                if isinstance(field_serializer, BaseReplaceableNestedField):
                    value = validated_data.pop(field)
                    fields["foreignkey_related"]["replaceable"] \
                        .update({field: value})
                elif isinstance(field_serializer, BaseWritableNestedField):
                    value = validated_data.pop(field)
                    fields["foreignkey_related"]["writable"] \
                        .update({field: value})
            elif (isinstance(field_serializer, ListSerializer) and
                    (isinstance(field_serializer, BaseWritableNestedField) or
                     isinstance(field_serializer, BaseReplaceableNestedField))):
                model = self.Meta.model
                rel = getattr(model, field).rel

                if isinstance(rel, ManyToOneRel):
                    value = validated_data.pop(field)
                    fields["many_to"]["one_related"].update({field: value})
                elif isinstance(rel, ManyToManyRel):
                    value = validated_data.pop(field)
                    fields["many_to"]["many_related"].update({field: value})
            else:
                pass

        self.update_replaceable_foreignkey_related(
            instance,
            fields["foreignkey_related"]["replaceable"]
        )

        self.update_writable_foreignkey_related(
            instance,
            fields["foreignkey_related"]["writable"]
        )

        self.update_many_to_many_related(
            instance,
            fields["many_to"]["many_related"]
        )

        self.update_many_to_one_related(
            instance,
            fields["many_to"]["one_related"]
        )

        return super().update(instance, validated_data)