from __future__ import annotations

import logging
import re
from typing import TYPE_CHECKING, Any

from django.conf import settings as django_settings
from django.contrib.auth.models import AnonymousUser
from django.core.cache import cache
from django.core.exceptions import ValidationError
from django.db import connection, models
from django.db.models.query import QuerySet
from django.http import HttpRequest, HttpResponse
from django.utils import timezone

from . import settings

if TYPE_CHECKING:
    # NOTE(review): ``django_settings`` above is an alias for
    # ``django.conf.settings``, so this module path looks wrong — confirm.
    # It is only evaluated by type checkers (TYPE_CHECKING is False at
    # runtime), so it has no runtime effect.
    from django_settings import AUTH_USER_MODEL

logger = logging.getLogger(__name__)


class RuleSetQuerySet(models.query.QuerySet):
    """Custom QuerySet for RuleSet instances."""

    def live_rules(self) -> QuerySet:
        """
        Return enabled rules, served from the cache when available.

        On a cache miss the enabled rules are fetched and stored under
        ``settings.RULESET_CACHE_KEY`` for ``settings.RULESET_CACHE_TIMEOUT``.
        """
        rulesets = cache.get(settings.RULESET_CACHE_KEY)
        if rulesets is None:
            rulesets = self.filter(enabled=True)
            cache.set(
                settings.RULESET_CACHE_KEY, rulesets, settings.RULESET_CACHE_TIMEOUT
            )
        return rulesets


class RuleSet(models.Model):
    """Set of rules to match a URI and/or User."""

    # property used to determine how to filter users
    USER_FILTER_ALL = 0
    USER_FILTER_AUTH = 1
    USER_FILTER_GROUP = 2
    USER_FILTER_CHOICES = (
        (USER_FILTER_ALL, "All users (inc. None)"),
        (USER_FILTER_AUTH, "Authenticated users only"),
        (USER_FILTER_GROUP, "Users in a named group"),
    )

    enabled = models.BooleanField(default=True, db_index=True)
    uri_regex = models.CharField(
        blank=True,
        default="",
        max_length=100,
        help_text="Regex used to filter by request URI.",
        verbose_name="Request path regex",
    )
    user_filter_type = models.IntegerField(
        default=0,
        choices=USER_FILTER_CHOICES,
        help_text="Filter requests by type of user.",
        verbose_name="User type filter",
    )
    user_group_filter = models.CharField(
        blank=True,
        default="",
        max_length=100,
        help_text="Group used to filter users.",
        verbose_name="User group filter",
    )

    # use the custom model manager
    objects = RuleSetQuerySet.as_manager()

    def __str__(self) -> str:
        return "Profiling rule #{}".format(self.pk)

    @property
    def has_group_filter(self) -> bool:
        """Return True if user_group_filter contains non-whitespace text."""
        return len(self.user_group_filter.strip()) > 0

    def clean(self) -> None:
        """
        Ensure that user_group_filter, user_filter_type and uri_regex are valid.

        Raises:
            ValidationError: if a group is given but the filter type is not
                'group'; if the filter type is 'group' but no group is given;
                or if uri_regex is not a valid regular expression.
        """
        if self.has_group_filter and self.user_filter_type != RuleSet.USER_FILTER_GROUP:
            raise ValidationError(
                "User filter type must be 'group' if you specify a group."
            )
        if (
            self.user_filter_type == RuleSet.USER_FILTER_GROUP
            and not self.has_group_filter
        ):
            raise ValidationError(
                "You must specify a group if the filter type is 'group'."
            )
        # match_uri() passes uri_regex straight to re.search, so an invalid
        # pattern would raise re.error every time the rule is evaluated;
        # reject it at validation time instead.
        try:
            re.compile(self.uri_regex.strip())
        except re.error as ex:
            raise ValidationError(f"Invalid request path regex: {ex}") from ex

    def match_uri(self, request_uri: str) -> bool:
        """
        Return True if there is a uri_regex and it matches.

        Args:
            request_uri: the request URI/path supplied by the caller, which
                is searched (not fully matched) with uri_regex.

        Returns True if there is a uri_regex and it matches, or if there is
        no uri_regex, in which case the match is implicit.
        """
        regex = self.uri_regex.strip()
        if regex == "":
            return True
        return re.search(regex, request_uri) is not None

    def match_user(self, user: AUTH_USER_MODEL) -> bool:
        """Return True if the user passes the various user filters."""
        # treat no user (i.e. has not been added) as AnonymousUser()
        user = user or AnonymousUser()

        if self.user_filter_type == RuleSet.USER_FILTER_ALL:
            return True

        if self.user_filter_type == RuleSet.USER_FILTER_AUTH:
            return user.is_authenticated

        if self.user_filter_type == RuleSet.USER_FILTER_GROUP:
            group = self.user_group_filter.strip()
            return user.groups.filter(name__iexact=group).exists()

        # if we're still going, then it's a no. it's also an invalid
        # user_filter_type, so we may want to think about a warning
        return False


class ProfilingRecord(models.Model):
    """Record of a request and its response."""

    user = models.ForeignKey(
        django_settings.AUTH_USER_MODEL, on_delete=models.CASCADE, null=True, blank=True
    )
    session_key = models.CharField(blank=True, max_length=40)
    start_ts = models.DateTimeField(verbose_name="Request started at")
    end_ts = models.DateTimeField(verbose_name="Request ended at")
    duration = models.FloatField(verbose_name="Request duration (sec)")
    http_method = models.CharField(max_length=10)
    request_uri = models.URLField(verbose_name="Request path")
    query_string = models.TextField(null=False, blank=True, verbose_name="Query string")
    remote_addr = models.CharField(max_length=100)
    http_user_agent = models.CharField(max_length=400)
    http_referer = models.CharField(max_length=400, default="")
    view_func_name = models.CharField(max_length=100, verbose_name="View function")
    response_status_code = models.IntegerField()
    response_content_length = models.IntegerField()
    query_count = models.IntegerField(
        help_text="Number of database queries logged during request.",
        blank=True,
        null=True,
    )

    def __str__(self) -> str:
        return "Profiling record #{}".format(self.pk)

    def save(self, *args: Any, **kwargs: Any) -> ProfilingRecord:
        """Save the record and return self so calls can be chained."""
        super().save(*args, **kwargs)
        return self

    def start(self) -> ProfilingRecord:
        """
        Set start_ts from current datetime and begin counting queries.

        Resets end_ts, duration and query_count. Remembers how many queries
        the connection has already logged, then sets the connection's
        force_debug_cursor to settings.FORCE_DEBUG_CURSOR. The previous value
        is kept so that stop() or cancel() can restore it.
        """
        self.start_ts = timezone.now()
        self.end_ts = None
        self.duration = None
        self.query_count = 0
        self._query_count = len(connection.queries)
        self._force_debug_cursor = connection.force_debug_cursor
        connection.force_debug_cursor = settings.FORCE_DEBUG_CURSOR
        return self

    @property
    def elapsed(self) -> float:
        """
        Time (in seconds) elapsed so far.

        Raises:
            ValueError: if start() has not been called.
        """
        if self.start_ts is None:
            raise ValueError("You must 'start' before you can get elapsed time.")
        return (timezone.now() - self.start_ts).total_seconds()

    def set_request(self, request: HttpRequest) -> ProfilingRecord:
        """Extract values from HttpRequest and store locally."""
        self.request = request
        self.http_method = request.method
        self.request_uri = request.path
        self.query_string = request.META.get("QUERY_STRING", "")
        self.http_user_agent = request.META.get("HTTP_USER_AGENT", "")[:400]
        # we care about the domain more than the URL itself, so truncating
        # doesn't lose much useful information
        self.http_referer = request.META.get("HTTP_REFERER", "")[:400]
        # X-Forwarded-For is used by convention when passing through
        # load balancers etc., as the REMOTE_ADDR is rewritten in transit
        remote_addr = (
            request.META.get("HTTP_X_FORWARDED_FOR")
            if "HTTP_X_FORWARDED_FOR" in request.META
            else request.META.get("REMOTE_ADDR")
        )
        # X-Forwarded-For may carry a long comma-separated proxy chain, and
        # REMOTE_ADDR may be absent; coerce to a string that fits the
        # max_length=100 column so save() cannot fail on this field.
        self.remote_addr = (remote_addr or "")[:100]
        # these two require middleware, so may not exist
        if hasattr(request, "session"):
            self.session_key = request.session.session_key or ""
        # NB you can't store AnonymousUsers, so don't bother trying
        if hasattr(request, "user") and request.user.is_authenticated:
            self.user = request.user
        return self

    def set_response(self, response: HttpResponse) -> ProfilingRecord:
        """
        Extract values from HttpResponse and store locally.

        For streaming responses the body length cannot be determined without
        consuming the stream, so response_content_length is recorded as -1.
        """
        self.response = response
        self.response_status_code = response.status_code
        if getattr(response, "streaming", False):
            # StreamingHttpResponse (and FileResponse) raise AttributeError
            # when .content is accessed; record -1 as "unknown" instead.
            self.response_content_length = -1
        else:
            self.response_content_length = len(response.content)
        return self

    def stop(self) -> ProfilingRecord:
        """
        Set end_ts and duration from current datetime.

        Also computes query_count as the number of queries logged since
        start() and restores the connection's previous force_debug_cursor.
        If a response has been set, its duration is exposed via the
        X-Profiler-Duration header.

        Raises:
            ValueError: if start() has not been called.
        """
        if self.start_ts is None:
            raise ValueError("You must 'start' before you can 'stop'")
        self.end_ts = timezone.now()
        self.duration = (self.end_ts - self.start_ts).total_seconds()
        self.query_count = len(connection.queries) - self._query_count
        connection.force_debug_cursor = self._force_debug_cursor
        if hasattr(self, "response"):
            self.response["X-Profiler-Duration"] = self.duration
        return self

    def cancel(self) -> ProfilingRecord:
        """
        Cancel the profile by setting is_cancelled to True.

        capture() never calls stop() on a cancelled record, so the
        connection's force_debug_cursor (changed by start()) is restored here.
        Otherwise the connection would keep forced query logging enabled.
        """
        self.start_ts = None
        self.end_ts = None
        self.duration = None
        self.is_cancelled = True
        if hasattr(self, "_force_debug_cursor"):
            connection.force_debug_cursor = self._force_debug_cursor
        return self

    def capture(self) -> ProfilingRecord:
        """Call stop() and save() on the profile if is_cancelled is False."""
        if getattr(self, "is_cancelled", False) is True:
            logger.debug("%r has been cancelled.", self)
            return self
        self.stop().save()
        return self