from __future__ import annotations import datetime import logging from typing import Any, Optional from django.conf import settings from django.contrib.auth import login from django.contrib.postgres.fields import JSONField from django.core.exceptions import ValidationError from django.db import models, transaction from django.http import HttpRequest from django.http.response import HttpResponse from django.utils.timezone import now as tz_now from jwt.exceptions import InvalidAudienceError, InvalidTokenError from .exceptions import MaxUseError from .settings import DEFAULT_MAX_USES, JWT_SESSION_TOKEN_EXPIRY, LOG_TOKEN_ERRORS from .utils import encode, to_seconds logger = logging.getLogger(__name__) class RequestTokenQuerySet(models.query.QuerySet): """Custom QuerySet for RquestToken objects.""" def create_token(self, scope: str, **kwargs: Any) -> RequestToken: """Create a new RequestToken.""" return RequestToken(scope=scope, **kwargs).save() class RequestToken(models.Model): """ A link token, targeted for use by a known Django User. A RequestToken contains information that can be encoded as a JWT (JSON Web Token). It is designed to be used in conjunction with the RequestTokenMiddleware (responsible for JWT verification) and the @use_request_token decorator (responsible for validating the token and setting the request.user correctly). Each token must have a 'scope', which is used to tie it to a view function that is decorated with the `use_request_token` decorator. The token can only be used by functions with matching scopes. The token may be set to a specific User, in which case, if the existing request is unauthenticated, it will use that user as the `request.user` property, allowing access to authenticated views. The token may be timebound by the `not_before_time` and `expiration_time` properties, which are registered JWT 'claims'. The token may be restricted by the number of times it can be used, through the `max_use` property, which is incremented each time it's used (NB *not* thread-safe). The token may also store arbitrary serializable data, which can be used by the view function if the request token is valid. JWT spec: https://tools.ietf.org/html/rfc7519 """ # do not login the user on the request LOGIN_MODE_NONE = "None" # login the user, but only for the original request LOGIN_MODE_REQUEST = "Request" # login the user fully, but only for single-use short-duration links LOGIN_MODE_SESSION = "Session" LOGIN_MODE_CHOICES = ( (LOGIN_MODE_NONE, "Do not authenticate"), (LOGIN_MODE_REQUEST, "Authenticate a single request"), (LOGIN_MODE_SESSION, "Authenticate for the entire session"), ) login_mode = models.CharField( max_length=10, default=LOGIN_MODE_NONE, choices=LOGIN_MODE_CHOICES, help_text="How should the request be authenticated?", ) user = models.ForeignKey( settings.AUTH_USER_MODEL, related_name="request_tokens", blank=True, null=True, on_delete=models.CASCADE, help_text="Intended recipient of the JWT (can be used by anyone if not set).", ) scope = models.CharField( max_length=100, help_text="Label used to match request to view function in decorator.", ) expiration_time = models.DateTimeField( blank=True, null=True, help_text="Token will expire at this time (raises ExpiredSignatureError).", ) not_before_time = models.DateTimeField( blank=True, null=True, help_text=( "Token cannot be used before this time (raises ImmatureSignatureError)." ), ) data = JSONField( help_text=( "Custom data add to the token, but not encoded (must be fetched from DB)." ), blank=True, null=True, default=dict, ) issued_at = models.DateTimeField( blank=True, null=True, help_text="Time the token was created (set in the initial save).", ) max_uses = models.IntegerField( default=DEFAULT_MAX_USES, help_text="The maximum number of times the token can be used.", ) used_to_date = models.IntegerField( default=0, help_text=( "Number of times the token has been used to date (raises MaxUseError)." ), ) objects = RequestTokenQuerySet.as_manager() class Meta: verbose_name = "Token" verbose_name_plural = "Tokens" def __str__(self) -> str: return "Request token #%s" % (self.id) def __repr__(self) -> str: return "<RequestToken id=%s scope=%s login_mode='%s'>" % ( self.id, self.scope, self.login_mode, ) @property def aud(self) -> Optional[int]: """Return 'aud' claim, mapped to user.id.""" return self.claims.get("aud") @property def exp(self) -> Optional[datetime.datetime]: """Return 'exp' claim, mapped to expiration_time.""" return self.claims.get("exp") @property def nbf(self) -> Optional[datetime.datetime]: """Return the 'nbf' claim, mapped to not_before_time.""" return self.claims.get("nbf") @property def iat(self) -> Optional[datetime.datetime]: """Return the 'iat' claim, mapped to issued_at.""" return self.claims.get("iat") @property def jti(self) -> Optional[int]: """Return the 'jti' claim, mapped to id.""" return self.claims.get("jti") @property def max(self) -> int: """Return the 'max' claim, mapped to max_uses.""" return self.claims["max"] @property def sub(self) -> str: """Return the 'sub' claim, mapped to scope.""" return self.claims["sub"] @property def claims(self) -> dict: """Return dict containing all of the DEFAULT_CLAIMS (where values exist).""" claims = { "max": self.max_uses, "sub": self.scope, "mod": self.login_mode[:1].lower(), } if self.id is not None: claims["jti"] = self.id if self.user is not None: claims["aud"] = self.user.id if self.expiration_time is not None: claims["exp"] = to_seconds(self.expiration_time) if self.issued_at is not None: claims["iat"] = to_seconds(self.issued_at) if self.not_before_time is not None: claims["nbf"] = to_seconds(self.not_before_time) return claims def clean(self) -> None: """Ensure that login_mode setting is valid.""" if self.login_mode == RequestToken.LOGIN_MODE_NONE: pass if self.login_mode == RequestToken.LOGIN_MODE_SESSION: if self.user is None: raise ValidationError({"user": "Session token must have a user."}) if self.expiration_time is None: raise ValidationError( {"expiration_time": "Session token must have an expiration_time."} ) if self.login_mode == RequestToken.LOGIN_MODE_REQUEST: if self.user is None: raise ValidationError( {"expiration_time": "Request token must have a user."} ) def save(self, *args: Any, **kwargs: Any) -> RequestToken: if "update_fields" not in kwargs: self.issued_at = self.issued_at or tz_now() if self.login_mode == RequestToken.LOGIN_MODE_SESSION: self.expiration_time = self.expiration_time or ( self.issued_at + datetime.timedelta(minutes=JWT_SESSION_TOKEN_EXPIRY) ) self.clean() super(RequestToken, self).save(*args, **kwargs) return self def jwt(self) -> str: """Encode the token claims into a JWT.""" return encode(self.claims).decode() def validate_max_uses(self) -> None: """ Check the token max_uses is still valid. Raises MaxUseError if invalid. """ if self.used_to_date >= self.max_uses: raise MaxUseError("RequestToken [%s] has exceeded max uses" % self.id) def _auth_is_anonymous(self, request: HttpRequest) -> HttpRequest: """Authenticate anonymous requests.""" if request.user.is_authenticated: raise InvalidAudienceError("Token requires anonymous user.") if self.login_mode == RequestToken.LOGIN_MODE_NONE: pass if self.login_mode == RequestToken.LOGIN_MODE_REQUEST: logger.debug( "Setting request.user to %r from token %i.", self.user, self.id ) request.user = self.user if self.login_mode == RequestToken.LOGIN_MODE_SESSION: logger.debug( "Authenticating request.user as %r from token %i.", self.user, self.id ) # I _think_ we can get away with this as we are pulling the # user out of the DB, and we are explicitly authenticating # the user. self.user.backend = "django.contrib.auth.backends.ModelBackend" login(request, self.user) return request def _auth_is_authenticated(self, request: HttpRequest) -> HttpRequest: """Authenticate requests with existing users.""" if request.user.is_anonymous: raise InvalidAudienceError("Token requires authenticated user.") if self.login_mode == RequestToken.LOGIN_MODE_NONE: return request if request.user == self.user: return request raise InvalidAudienceError( "RequestToken [%i] audience mismatch: '%s' != '%s'" % (self.id, request.user, self.user) ) def authenticate(self, request: HttpRequest) -> HttpRequest: """ Authenticate an HttpRequest with the token user. This method encapsulates the request handling - if the token has a user assigned, then this will be added to the request. """ if request.user.is_anonymous: return self._auth_is_anonymous(request) else: return self._auth_is_authenticated(request) @transaction.atomic def log( self, request: HttpRequest, response: HttpResponse, error: Optional[InvalidTokenError] = None, ) -> RequestTokenLog: """ Record the use of a token. This is used by the decorator to log each time someone uses the token, or tries to. Used for reporting, diagnostics. Args: request: the HttpRequest object that used the token, from which the user, ip and user-agenct are extracted. response: the corresponding HttpResponse object, from which the status code is extracted. error: an InvalidTokenError that gets logged as a RequestTokenError. Returns a RequestTokenUse object. """ def rmg(key: str, default: Any = None) -> Any: return request.META.get(key, default) log = RequestTokenLog( token=self, user=None if request.user.is_anonymous else request.user, user_agent=rmg("HTTP_USER_AGENT", "unknown"), client_ip=parse_xff(rmg("HTTP_X_FORWARDED_FOR")) or rmg("REMOTE_ADDR", None), status_code=response.status_code, ).save() if error and LOG_TOKEN_ERRORS: RequestTokenErrorLog.objects.create_error_log(log, error) # NB this will include all error logs - which means that an error log # may prohibit further use of the token. Is there a scenario in which # this would be the wrong outcome? self.used_to_date = self.logs.filter(error__isnull=True).count() self.save() return log def expire(self) -> None: """Mark the token as expired immediately, effectively killing the token.""" self.expiration_time = tz_now() - datetime.timedelta(microseconds=1) self.save() def parse_xff(header_value: str) -> Optional[str]: """ Parse out the X-Forwarded-For request header. This handles the bug that blows up when multiple IP addresses are specified in the header. The docs state that the header contains "The originating IP address", but in reality it contains a list of all the intermediate addresses. The first item is the original client, and then any intermediate proxy IPs. We want the original. Returns the first IP in the list, else None. """ try: return header_value.split(",")[0].strip() except (KeyError, AttributeError): return None class RequestTokenLog(models.Model): """Used to log the use of a RequestToken.""" token = models.ForeignKey( RequestToken, related_name="logs", help_text="The RequestToken that was used.", on_delete=models.CASCADE, db_index=True, ) user = models.ForeignKey( settings.AUTH_USER_MODEL, blank=True, null=True, on_delete=models.CASCADE, help_text="The user who made the request (None if anonymous).", ) user_agent = models.TextField( blank=True, help_text="User-agent of client used to make the request." ) client_ip = models.GenericIPAddressField( blank=True, null=True, unpack_ipv4=True, help_text="Client IP of device used to make the request.", ) status_code = models.IntegerField( blank=True, null=True, help_text="Response status code associated with this use of the token.", ) timestamp = models.DateTimeField( blank=True, help_text="Time the request was logged." ) class Meta: verbose_name = "Log" verbose_name_plural = "Logs" def __str__(self) -> str: if self.user is None: return "%s used %s" % (self.token, self.timestamp) else: return "%s used by %s at %s" % (self.token, self.user, self.timestamp) def __repr__(self) -> str: return "<RequestTokenLog id=%s token=%s timestamp='%s'>" % ( self.id, self.token.id, self.timestamp, ) def save(self, *args: Any, **kwargs: Any) -> RequestToken: if "update_fields" not in kwargs: self.timestamp = self.timestamp or tz_now() super(RequestTokenLog, self).save(*args, **kwargs) return self class RequestTokenErrorLogQuerySet(models.query.QuerySet): def create_error_log( self, log: RequestTokenLog, error: Exception ) -> RequestTokenErrorLog: return RequestTokenErrorLog( token=log.token, log=log, error_type=type(error).__name__, error_message=str(error), ) class RequestTokenErrorLog(models.Model): """Used to log errors that occur with the use of a RequestToken.""" token = models.ForeignKey( RequestToken, related_name="errors", on_delete=models.CASCADE, help_text="The RequestToken that was used.", db_index=True, ) log = models.OneToOneField( RequestTokenLog, related_name="error", on_delete=models.CASCADE, help_text="The token use against which the error occurred.", db_index=True, ) error_type = models.CharField( max_length=50, help_text="The underlying type of error raised." ) error_message = models.CharField( max_length=200, help_text="The error message supplied." ) objects = RequestTokenErrorLogQuerySet().as_manager() class Meta: verbose_name = "Error" verbose_name_plural = "Errors" def __str__(self) -> str: return self.error_message def save(self, *args: Any, **kwargs: Any) -> RequestTokenErrorLog: super(RequestTokenErrorLog, self).save(*args, **kwargs) return self