from __future__ import annotations import logging import time import warnings from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union from django.conf import settings from django.contrib.postgres.fields import JSONField from django.core.cache import cache from django.core.serializers.json import DjangoJSONEncoder from django.db import models from django.db.models.expressions import RawSQL from django.db.models.fields import CharField from django.db.models.query import QuerySet from django.utils.timezone import now as tz_now from elasticsearch_dsl import Search from .settings import ( get_client, get_model_index_properties, get_model_indexes, get_setting, ) if TYPE_CHECKING: from django.contrib.auth.models import AbstractBaseUser logger = logging.getLogger(__name__) UPDATE_STRATEGY_FULL = "full" UPDATE_STRATEGY_PARTIAL = "partial" UPDATE_STRATEGY = get_setting("update_strategy", UPDATE_STRATEGY_FULL) class SearchDocumentManagerMixin(models.Manager): """ Model manager mixin that adds search document methods. There is one method in this class that must implemented - `get_search_queryset`. This must return a queryset that is the set of objects to be indexed. This queryset is then converted into a generator that emits the objects as JSON documents. """ def get_search_queryset(self, index: str = "_all") -> QuerySet: """ Return the dataset used to populate the search index. Kwargs: index: string, the name of the index we are interested in - this allows us to have different sets of objects in different indexes. Defaults to '_all', in which case all indexes index the same set of objects. This must return a queryset object. """ raise NotImplementedError( "{} does not implement 'get_search_queryset'.".format( self.__class__.__name__ ) ) def in_search_queryset(self, instance_id: int, index: str = "_all") -> bool: """ Return True if an object is part of the search index queryset. Sometimes it's useful to know if an object _should_ be indexed. If an object is saved, how do you know if you should push that change to the search index? The simplest (albeit not most efficient) way is to check if it appears in the underlying search queryset. NB this method doesn't evaluate the entire dataset, it chains an additional queryset filter expression on the end. That's why it's important that the `get_search_queryset` method returns a queryset. Args: instance_id: the id of model object that we are looking for. Kwargs: index: string, the name of the index in which to check. Defaults to '_all'. """ return self.get_search_queryset(index=index).filter(pk=instance_id).exists() def from_search_query(self, search_query: SearchQuery) -> QuerySet: """ Return queryset of objects from SearchQuery.results, **in order**. EXPERIMENTAL: this will only work with results from a single index, with a single doc_type - as we are returning a single QuerySet. This method takes the hits JSON and converts that into a queryset of all the relevant objects. The key part of this is the ordering - the order in which search results are returned is based on relevance, something that only ES can calculate, and that cannot be replicated in the database. It does this by adding custom SQL which annotates each record with the score from the search 'hit'. This is brittle, caveat emptor. The RawSQL clause is in the form: SELECT CASE {{model}}.id WHEN {{id}} THEN {{score}} END The "WHEN x THEN y" is repeated for every hit. The resulting SQL, in full is like this: SELECT "freelancer_freelancerprofile"."id", (SELECT CASE freelancer_freelancerprofile.id WHEN 25 THEN 1.0 WHEN 26 THEN 1.0 [...] ELSE 0 END) AS "search_score" FROM "freelancer_freelancerprofile" WHERE "freelancer_freelancerprofile"."id" IN (25, 26, [...]) ORDER BY "search_score" DESC It should be very fast, as there is no table lookup, but there is an assumption at the heart of this, which is that the search query doesn't contain the entire database - i.e. that it has been paged. (ES itself caps the results at 10,000.) """ hits = search_query.hits score_sql = self._raw_sql([(h["id"], h["score"] or 0) for h in hits]) rank_sql = self._raw_sql([(hits[i]["id"], i) for i in range(len(hits))]) return ( self.get_queryset() .filter(pk__in=[h["id"] for h in hits]) # add the query relevance score .annotate(search_score=RawSQL(score_sql, ())) # noqa: S611 # add the ordering number (0-based) .annotate(search_rank=RawSQL(rank_sql, ())) # noqa: S611 .order_by("search_rank") ) def _when(self, x: Union[str, int], y: Union[str, int]) -> str: return "WHEN {} THEN {}".format(x, y) def _raw_sql(self, values: List[Tuple[Union[str, int], Union[str, int]]]) -> str: """Prepare SQL statement consisting of a sequence of WHEN .. THEN statements.""" if isinstance(self.model._meta.pk, CharField): when_clauses = " ".join( [self._when("'{}'".format(x), y) for (x, y) in values] ) else: when_clauses = " ".join([self._when(x, y) for (x, y) in values]) table_name = self.model._meta.db_table primary_key = self.model._meta.pk.column return 'SELECT CASE {}."{}" {} ELSE 0 END'.format( table_name, primary_key, when_clauses ) class SearchDocumentMixin(object): """ Mixin used by models that are indexed for ES. This mixin defines the interface exposed by models that are indexed ready for ES. The only method that needs implementing is `as_search_document`. """ # Django model field types that can be serialized directly into # a known format. All other types will need custom serialization. # Used by as_search_document_update method SIMPLE_UPDATE_FIELD_TYPES = [ "AutoField", "BooleanField", "CharField", "DateField", "DateTimeField", "DecimalField", "EmailField", "FloatField", "IntegerField", "TextField", "URLField", "UUIDField", ] @property def search_indexes(self) -> List[str]: """Return the list of indexes for which this model is configured.""" return get_model_indexes(self.__class__) @property def search_document_cache_key(self) -> str: """Key used for storing search docs in local cache.""" return "elasticsearch_django:{}.{}.{}".format( self._meta.app_label, self._meta.model_name, self.pk # type: ignore ) @property def search_doc_type(self) -> str: """Return the doc_type used for the model.""" return self._meta.model_name # type: ignore def as_search_document(self, *, index: str) -> dict: """ Return the object as represented in a named index. This is named to avoid confusion - if it was `get_search_document`, which would be the logical name, it would not be clear whether it referred to getting the local representation of the search document, or actually fetching it from the index. Kwargs: index: string, the name of the index in which the object is to appear - this allows different representations in different indexes. Defaults to '_all', in which case all indexes use the same search document structure. Returns a dictionary. """ raise NotImplementedError( "{} does not implement 'as_search_document'.".format( self.__class__.__name__ ) ) def _is_field_serializable(self, field_name: str) -> bool: """Return True if the field can be serialized into a JSON doc.""" return ( self._meta.get_field(field_name).get_internal_type() # type: ignore in self.SIMPLE_UPDATE_FIELD_TYPES ) def clean_update_fields(self, index: str, update_fields: List[str]) -> List[str]: """ Clean the list of update_fields based on the index being updated. If any field in the update_fields list is not in the set of properties defined by the index mapping for this model, then we ignore it. If a field _is_ in the mapping, but the underlying model field is a related object, and thereby not directly serializable, then this method will raise a ValueError. """ search_fields = get_model_index_properties(self, index) clean_fields = [f for f in update_fields if f in search_fields] ignore = [f for f in update_fields if f not in search_fields] if ignore: logger.debug( "Ignoring fields from partial update: %s", [f for f in update_fields if f not in search_fields], ) for f in clean_fields: if not self._is_field_serializable(f): raise ValueError( "'%s' cannot be automatically serialized into a search " "document property. Please override as_search_document_update.", f, ) return clean_fields def as_search_document_update( self, *, index: str, update_fields: List[str] ) -> dict: """ Return a partial update document based on which fields have been updated. If an object is saved with the `update_fields` argument passed through, then it is assumed that this is a 'partial update'. In this scenario we need a {property: value} dictionary containing just the fields we want to update. This method handles two possible update strategies - 'full' or 'partial'. The default 'full' strategy simply returns the value of `as_search_document` - thereby replacing the entire document each time. The 'partial' strategy is more intelligent - it will determine whether the fields passed are in the search document mapping, and return a partial update document that contains only those that are. In addition, if any field that _is_ included cannot be automatically serialized (e.g. a RelatedField object), then this method will raise a ValueError. In this scenario, you should override this method in your subclass. >>> def as_search_document_update(self, index, update_fields): ... if 'user' in update_fields: ... update_fields.remove('user') ... doc = super().as_search_document_update(index, update_fields) ... doc['user'] = self.user.get_full_name() ... return doc ... return super().as_search_document_update(index, update_fields) You may also wish to subclass this method to perform field-specific logic - in this example if only the timestamp is being saved, then ignore the update if the timestamp is later than a certain time. >>> def as_search_document_update(self, index, update_fields): ... if update_fields == ['timestamp']: ... if self.timestamp > today(): ... return {} ... return super().as_search_document_update(index, update_fields) """ if UPDATE_STRATEGY == UPDATE_STRATEGY_FULL: return self.as_search_document(index=index) if UPDATE_STRATEGY == UPDATE_STRATEGY_PARTIAL: # in partial mode we update the intersection of update_fields and # properties found in the mapping file. return { k: getattr(self, k) for k in self.clean_update_fields( index=index, update_fields=update_fields ) } raise ValueError("Invalid update strategy.") def as_search_action(self, *, index: str, action: str) -> dict: """ Return an object as represented in a bulk api operation. Bulk API operations have a very specific format. This function will call the standard `as_search_document` method on the object and then wrap that up in the correct format for the action specified. https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html Args: index: string, the name of the index in which the action is to be taken. Bulk operations are only every carried out on a single index at a time. action: string ['index' | 'update' | 'delete'] - this decides how the final document is formatted. Returns a dictionary. """ if action not in ("index", "update", "delete"): raise ValueError("Action must be 'index', 'update' or 'delete'.") document = { "_index": index, "_type": self.search_doc_type, "_op_type": action, "_id": self.pk, # type: ignore } if action == "index": document["_source"] = self.as_search_document(index=index) elif action == "update": document["doc"] = self.as_search_document(index=index) return document def fetch_search_document(self, *, index: str) -> dict: """Fetch the object's document from a search index by id.""" if not self.pk: # type: ignore raise ValueError("Object must have a primary key before being indexed.") client = get_client() return client.get( index=index, doc_type=self.search_doc_type, id=self.pk # type: ignore ) def index_search_document(self, *, index: str) -> None: """ Create or replace search document in named index. Checks the local cache to see if the document has changed, and if not aborts the update, else pushes to ES, and then resets the local cache. Cache timeout is set as "cache_expiry" in the settings, and defaults to 60s. """ cache_key = self.search_document_cache_key new_doc = self.as_search_document(index=index) cached_doc = cache.get(cache_key) if new_doc == cached_doc: logger.debug("Search document for %r is unchanged, ignoring update.", self) return cache.set(cache_key, new_doc, timeout=get_setting("cache_expiry", 60)) get_client().index( index=index, doc_type=self.search_doc_type, body=new_doc, id=self.pk, # type: ignore ) def update_search_document(self, *, index: str, update_fields: List[str]) -> None: """ Partial update of a document in named index. Partial updates are invoked via a call to save the document with 'update_fields'. These fields are passed to the as_search_document method so that it can build a partial document. NB we don't just call as_search_document and then strip the fields _not_ in update_fields as we are trying to avoid possibly expensive operations in building the source document. The canonical example for this method is updating a single timestamp on a model - we don't want to have to walk the model relations and build a document in this case - we just want to push the timestamp. When POSTing a partial update the `as_search_document` doc must be passed to the `client.update` wrapped in a "doc" node, # noqa: E501, see: https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-update.html """ doc = self.as_search_document_update(index=index, update_fields=update_fields) if not doc: logger.debug("Ignoring object update as document is empty.") return get_client().update( index=index, doc_type=self.search_doc_type, body={"doc": doc}, id=self.pk, # type: ignore ) def delete_search_document(self, *, index: str) -> None: """Delete document from named index.""" cache.delete(self.search_document_cache_key) get_client().delete( index=index, doc_type=self.search_doc_type, id=self.pk # type: ignore ) class SearchQuery(models.Model): """ Model used to capture ES queries and responses. For low-traffic sites it's useful to be able to replay searches, and to track how a user filtered and searched. This model can be used to store a search query and meta information about the results (document type, id and score). >>> from elasticsearch_dsl import Search >>> search = Search(using=client) >>> sq = SearchQuery.execute(search).save() """ # whether this is a search query (returns results), or a count API # query (returns the number of results, but no detail), QUERY_TYPE_SEARCH = "SEARCH" QUERY_TYPE_COUNT = "COUNT" QUERY_TYPE_CHOICES = ( (QUERY_TYPE_SEARCH, "Search results"), (QUERY_TYPE_COUNT, "Count only"), ) user = models.ForeignKey( settings.AUTH_USER_MODEL, related_name="search_queries", blank=True, null=True, help_text="The user who made the search query (nullable).", on_delete=models.SET_NULL, ) index = models.CharField( max_length=100, default="_all", help_text="The name of the ElasticSearch index(es) being queried.", ) # The query property contains the raw DSL query, which can be arbitrarily complex - # there is no one way of mapping input text to the query itself. However, it's # often helpful to have the terms that the user themselves typed easily accessible # without having to parse JSON. search_terms = models.CharField( max_length=400, default="", blank=True, help_text=( "Free text search terms used in the query, stored for easy reference." ), ) query = JSONField( help_text="The raw ElasticSearch DSL query.", encoder=DjangoJSONEncoder ) query_type = CharField( help_text="Does this query return results, or just the hit count?", choices=QUERY_TYPE_CHOICES, default=QUERY_TYPE_SEARCH, max_length=10, ) hits = JSONField( help_text="The list of meta info for each of the query matches returned.", encoder=DjangoJSONEncoder, ) total_hits = models.IntegerField( default=0, help_text="Total number of matches found for the query (!= the hits returned).", ) reference = models.CharField( max_length=100, default="", blank=True, help_text="Custom reference used to identify and group related searches.", ) executed_at = models.DateTimeField( help_text="When the search was executed - set via execute() method." ) duration = models.FloatField( help_text="Time taken to execute the search itself, in seconds." ) class Meta: app_label = "elasticsearch_django" verbose_name = "Search query" verbose_name_plural = "Search queries" def __str__(self) -> str: return f"Query (id={self.pk}) run against index '{self.index}'" def __repr__(self) -> str: return ( f"<SearchQuery id={self.pk} user={self.user} " f"index='{self.index}' total_hits={self.total_hits} >" ) def save(self, *args: Any, **kwargs: Any) -> SearchQuery: """Save and return the object (for chaining).""" if self.search_terms is None: self.search_terms = "" super().save(**kwargs) return self def _extract_set(self, _property: str) -> List[Union[str, int]]: return ( [] if self.hits is None else (list(set([h[_property] for h in self.hits]))) ) @property def doc_types(self) -> List[str]: """List of doc_types extracted from hits.""" return [str(x) for x in self._extract_set("doc_type")] @property def max_score(self) -> int: """Max relevance score in the returned page.""" return int(max(self._extract_set("score") or [0])) @property def min_score(self) -> int: """Min relevance score in the returned page.""" return int(min(self._extract_set("score") or [0])) @property def object_ids(self) -> List[int]: """List of model ids extracted from hits.""" return [int(x) for x in self._extract_set("id")] @property def page_slice(self) -> Optional[Tuple[int, int]]: """Return the query from:size tuple (0-based).""" return ( None if self.query is None else (self.query.get("from", 0), self.query.get("size", 10)) ) @property def page_from(self) -> int: """1-based index of the first hit in the returned page.""" if self.page_size == 0: return 0 if not self.page_slice: return 0 return self.page_slice[0] + 1 @property def page_to(self) -> int: """1-based index of the last hit in the returned page.""" return 0 if self.page_size == 0 else self.page_from + self.page_size - 1 @property def page_size(self) -> int: """Return number of hits returned in this specific page.""" return 0 if self.hits is None else len(self.hits) @classmethod def execute( cls, search: Search, search_terms: str = "", user: Optional[AbstractBaseUser] = None, reference: Optional[str] = "", save: bool = True, ) -> SearchQuery: """Create a new SearchQuery instance and execute a search against ES.""" warnings.warn( "Deprecated - please use `execute_search` function instead.", DeprecationWarning, ) return execute_search( search, search_terms=search_terms, user=user, reference=reference, save=save ) def execute_search( search: Search, search_terms: str = "", user: Optional[AbstractBaseUser] = None, reference: Optional[str] = "", save: bool = True, query_type: str = SearchQuery.QUERY_TYPE_SEARCH, ) -> SearchQuery: """ Create a new SearchQuery instance and execute a search against ES. Args: search: elasticsearch.search.Search object, that internally contains the connection and query; this is the query that is executed. All we are doing is logging the input and parsing the output. search_terms: raw end user search terms input - what they typed into the search box. user: Django User object, the person making the query - used for logging purposes. Can be null. reference: string, can be anything you like, used for identification, grouping purposes. save: bool, if True then save the new object immediately, can be overridden to False to prevent logging absolutely everything. Defaults to True query_type: string, used to determine whether to run a search query or a count query (returns hit count, but no results). """ start = time.time() if query_type == SearchQuery.QUERY_TYPE_SEARCH: response = search.execute() hits = [h.meta.to_dict() for h in response.hits] total_hits = response.hits.total elif query_type == SearchQuery.QUERY_TYPE_COUNT: response = total_hits = search.count() hits = [] else: raise ValueError(f"Invalid SearchQuery.query_type value: '{query_type}'") duration = time.time() - start search_query = SearchQuery( user=user, search_terms=search_terms, index=", ".join(search._index or ["_all"])[:100], # field length restriction query=search.to_dict(), query_type=query_type, hits=hits, total_hits=total_hits, reference=reference or "", executed_at=tz_now(), duration=duration, ) search_query.response = response return search_query.save() if save else search_query