# -*- coding: utf-8 -*-
# NOTE(review): distutils is deprecated (PEP 632) and is no longer part of the standard library on
# Python 3.12+; the import is kept as-is here so existing deployments are unaffected.
from distutils.version import LooseVersion

from django import get_version
from drf_tweaks.serializers import ContextPassing
from rest_framework.serializers import ListSerializer, Serializer

try:
    from django.db.models.fields import related_descriptors
except ImportError:
    # Fallback location of the relation descriptors (used together with the pre-1.9 class names below).
    from django.db.models.fields import related as related_descriptors


def _django_has_renamed_descriptors():
    """Return True when the installed Django reports version 1.9 or newer."""
    return LooseVersion(get_version()) >= LooseVersion("1.9")


def check_if_related_object(model_field):
    """Tell whether ``model_field`` is a descriptor that resolves to a single related object.

    :param model_field: attribute read from a model class (e.g. ``getattr(Model, "author")``)
    :return: True if the attribute is one of the single-object relation descriptors of the
        installed Django version, False otherwise
    """
    if _django_has_renamed_descriptors():
        descriptor_classes = (
            related_descriptors.ForwardManyToOneDescriptor,
            related_descriptors.ReverseOneToOneDescriptor,
        )
    else:
        descriptor_classes = (
            related_descriptors.SingleRelatedObjectDescriptor,
            related_descriptors.ReverseSingleRelatedObjectDescriptor,
        )
    return isinstance(model_field, descriptor_classes)


def check_if_prefetch_object(model_field):
    """Tell whether ``model_field`` is a descriptor that resolves to many related objects.

    :param model_field: attribute read from a model class (e.g. ``getattr(Model, "tags")``)
    :return: True if the attribute is one of the to-many relation descriptors of the installed
        Django version, False otherwise
    """
    if _django_has_renamed_descriptors():
        descriptor_classes = (
            related_descriptors.ManyToManyDescriptor,
            related_descriptors.ReverseManyToOneDescriptor,
        )
    else:
        descriptor_classes = (
            related_descriptors.ManyRelatedObjectsDescriptor,
            related_descriptors.ForeignRelatedObjectsDescriptor,
            related_descriptors.ReverseManyRelatedObjectsDescriptor,
        )
    return isinstance(model_field, descriptor_classes)


def _narrow_fields(field_name, fields_to_serialize):
    """Return the field selection that applies to the serializer nested under ``field_name``.

    ``None`` is passed through unchanged; anything else is delegated to
    ``ContextPassing.filter_fields``.
    """
    if fields_to_serialize is None:
        return None
    return ContextPassing.filter_fields(field_name, fields_to_serialize)


def run_autooptimization_discovery(serializer, prefix, select_related_set, prefetch_related_set, is_prefetch,
                                   only_fields, include_fields, force_prefetch=False):
    """Walk ``serializer`` recursively and collect the ORM lookups its fields will need.

    Nothing is returned: the lookups are added to ``select_related_set`` and
    ``prefetch_related_set`` in place. Serializers without ``Meta.model`` are ignored.

    :param serializer: serializer instance to inspect
    :param prefix: lookup path leading to this serializer's model ("" at the top level,
        otherwise ending with "__")
    :param select_related_set: set receiving lookups intended for ``select_related``
    :param prefetch_related_set: set receiving lookups intended for ``prefetch_related``
    :param is_prefetch: True when ``prefix`` already goes through a prefetched relation
    :param only_fields: field selection handed to ``check_if_needs_serialization`` (when the
        serializer defines it) and narrowed for nested serializers
    :param include_fields: same as ``only_fields``, for the include selection
    :param force_prefetch: when True, single-object relations are also collected as prefetch
        lookups; the flag is propagated to every nesting level
    """
    if not hasattr(serializer, "Meta") or not hasattr(serializer.Meta, "model"):
        return
    model_class = serializer.Meta.model

    if hasattr(serializer, "get_on_demand_fields"):
        on_demand_fields = serializer.get_on_demand_fields()
    else:
        on_demand_fields = set()

    can_skip_fields = hasattr(serializer, "check_if_needs_serialization")
    # Below a prefetched relation (or when forced) every lookup has to be a prefetch lookup:
    # select_related() cannot follow a path that crosses a to-many relation.
    use_prefetch = is_prefetch or force_prefetch

    for field_name, field in serializer.fields.items():
        if can_skip_fields and not serializer.check_if_needs_serialization(
                field_name, only_fields, include_fields, on_demand_fields):
            continue

        source = field.source
        if isinstance(field, ListSerializer):
            if "." in source or not hasattr(model_class, source):
                continue
            if not check_if_prefetch_object(getattr(model_class, source)):
                continue
            lookup = prefix + source
            prefetch_related_set.add(lookup)
            run_autooptimization_discovery(
                field.child, lookup + "__", select_related_set, prefetch_related_set, True,
                _narrow_fields(field_name, only_fields), _narrow_fields(field_name, include_fields),
                force_prefetch=force_prefetch
            )
        elif isinstance(field, Serializer):
            if "." in source or not hasattr(model_class, source):
                continue
            if not check_if_related_object(getattr(model_class, source)):
                continue
            lookup = prefix + source
            if use_prefetch:
                prefetch_related_set.add(lookup)
            else:
                select_related_set.add(lookup)
            run_autooptimization_discovery(
                field, lookup + "__", select_related_set, prefetch_related_set, is_prefetch,
                _narrow_fields(field_name, only_fields), _narrow_fields(field_name, include_fields),
                force_prefetch=force_prefetch
            )
        elif "." in source:
            # Plain field reading through a relation ("author.name"): only the first hop is optimized.
            related_name = source.split(".", 1)[0]
            if not hasattr(model_class, related_name):
                continue
            if not check_if_related_object(getattr(model_class, related_name)):
                continue
            if use_prefetch:
                prefetch_related_set.add(prefix + related_name)
            else:
                select_related_set.add(prefix + related_name)


class AutoOptimizeMixin(object):
    """View mixin that adds select_related / prefetch_related lookups discovered from the serializer.

    Set ``AUTOOPTIMIZE_FORCE_PREFETCH = True`` on the view to pass ``force_prefetch=True`` to the
    discovery (it defaults to False when the attribute is absent).
    """

    def get_queryset(self):
        """Return the parent queryset extended with the discovered related lookups."""
        # discover select/prefetch related structure
        serializer = self.get_serializer_class()(context=self.get_serializer_context())
        if hasattr(serializer, "get_only_fields_and_include_fields"):
            only_fields, include_fields = serializer.get_only_fields_and_include_fields()
        else:
            only_fields, include_fields = set(), set()

        select_related_set = set()
        prefetch_related_set = set()
        run_autooptimization_discovery(
            serializer, "", select_related_set, prefetch_related_set, False, only_fields, include_fields,
            force_prefetch=getattr(self, "AUTOOPTIMIZE_FORCE_PREFETCH", False)
        )

        # amending queryset; lookups are sorted so the call arguments are deterministic
        queryset = super(AutoOptimizeMixin, self).get_queryset()
        if select_related_set:
            queryset = queryset.select_related(*sorted(select_related_set))
        if prefetch_related_set:
            queryset = queryset.prefetch_related(*sorted(prefetch_related_set))
        return queryset