# coding: utf-8
"""JWT helpers for rest_framework_sso: build claim sets, sign/verify tokens, resolve the user."""
from __future__ import absolute_import, unicode_literals

import jwt
from datetime import datetime
from django.contrib.auth import get_user_model
from django.core.serializers.json import DjangoJSONEncoder
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
from jwt.exceptions import MissingRequiredClaimError, InvalidIssuerError, InvalidTokenError
from rest_framework import exceptions

from rest_framework_sso import claims
from rest_framework_sso.keys import get_private_key_and_key_id, get_public_key_and_key_id
from rest_framework_sso.settings import api_settings

import logging

logger = logging.getLogger(__name__)


def create_session_payload(session_token, user, **kwargs):
    """Build the claim set for a *session* token.

    :param session_token: session token instance; its ``pk`` becomes the session-id claim.
    :param user: user instance; its ``pk`` and ``email`` are embedded as claims.
    :param kwargs: ignored here; presumably accepted so alternative payload factories can
        receive extra context (TODO confirm against callers).
    :return: dict of claims with the token type set to ``claims.TOKEN_SESSION``.
    """
    return {
        claims.TOKEN: claims.TOKEN_SESSION,
        claims.SESSION_ID: session_token.pk,
        claims.USER_ID: user.pk,
        claims.EMAIL: user.email,
    }


def create_authorization_payload(session_token, user, **kwargs):
    """Build the claim set for an *authorization* token.

    Same identity claims as :func:`create_session_payload`, but typed as
    ``claims.TOKEN_AUTHORIZATION`` and carrying an (initially empty) scopes list.

    :param session_token: session token instance; its ``pk`` becomes the session-id claim.
    :param user: user instance; its ``pk`` and ``email`` are embedded as claims.
    :param kwargs: ignored here (see :func:`create_session_payload`).
    :return: dict of claims.
    """
    return {
        claims.TOKEN: claims.TOKEN_AUTHORIZATION,
        claims.SESSION_ID: session_token.pk,
        claims.USER_ID: user.pk,
        claims.EMAIL: user.email,
        claims.SCOPES: [],
    }


def encode_jwt_token(payload):
    """Fill in default registered claims and sign ``payload`` as a JWT.

    Missing issuer, audience, expiration and issued-at claims are filled from settings.
    Note that ``payload`` is modified in place.

    :param payload: claim dict whose ``claims.TOKEN`` is session or authorization.
    :return: the encoded token as a text string.
    :raises RuntimeError: on an unknown token type, when issuer/audience cannot be
        derived from settings, or when no private key is configured for the issuer.
    """
    token_type = payload.get(claims.TOKEN)
    if token_type not in (claims.TOKEN_SESSION, claims.TOKEN_AUTHORIZATION):
        raise RuntimeError("Unknown token type")

    # Resolve the per-token-type defaults once; token_type is one of the two values above.
    if token_type == claims.TOKEN_SESSION:
        default_audience = api_settings.SESSION_AUDIENCE
        default_expiration = api_settings.SESSION_EXPIRATION
        audience_setting_name = "SESSION_AUDIENCE"
    else:
        default_audience = api_settings.AUTHORIZATION_AUDIENCE
        default_expiration = api_settings.AUTHORIZATION_EXPIRATION
        audience_setting_name = "AUTHORIZATION_AUDIENCE"

    if not payload.get(claims.ISSUER):
        if api_settings.IDENTITY is None:
            raise RuntimeError("IDENTITY must be specified in settings")
        payload[claims.ISSUER] = api_settings.IDENTITY

    if not payload.get(claims.AUDIENCE):
        if default_audience is not None:
            payload[claims.AUDIENCE] = default_audience
        elif api_settings.IDENTITY is not None:
            # Fall back to a token addressed to this service itself.
            payload[claims.AUDIENCE] = [api_settings.IDENTITY]
        else:
            # Name the setting that actually applies to this token type.
            raise RuntimeError("%s must be specified in settings" % audience_setting_name)

    # No expiration setting for this token type means the token is issued without "exp".
    if not payload.get(claims.EXPIRATION_TIME) and default_expiration is not None:
        payload[claims.EXPIRATION_TIME] = datetime.utcnow() + default_expiration

    if not payload.get(claims.ISSUED_AT):
        payload[claims.ISSUED_AT] = datetime.utcnow()

    if payload[claims.ISSUER] not in api_settings.PRIVATE_KEYS:
        raise RuntimeError("Private key for specified issuer was not found in settings")
    private_key, key_id = get_private_key_and_key_id(issuer=payload[claims.ISSUER])

    # The key id goes in the header so verifiers can pick the matching public key.
    headers = {claims.KEY_ID: key_id}

    token = jwt.encode(
        payload=payload,
        key=private_key,
        algorithm=api_settings.ENCODE_ALGORITHM,
        headers=headers,
        json_encoder=DjangoJSONEncoder,
    )
    # PyJWT < 2.0 returns bytes, PyJWT >= 2.0 returns str; always hand back text.
    if isinstance(token, bytes):
        token = token.decode("utf-8")
    return token


def decode_jwt_token(token):
    """Verify ``token`` and return its claims.

    The issuer and key id are first read *unverified* to select the public key.
    The token is then fully decoded and verified according to the ``VERIFY_*`` settings.

    :param token: encoded JWT.
    :return: the verified claim dict.
    :raises MissingRequiredClaimError: if the issuer, session-id or user-id claim is missing.
    :raises InvalidIssuerError: if ``ACCEPTED_ISSUERS`` is set and the issuer is not in it.
    :raises InvalidTokenError: if the token type claim is unknown; errors from ``jwt.decode``
        and ``get_public_key_and_key_id`` also propagate.
    """
    # NOTE(review): ``verify=`` is the PyJWT 1.x API; PyJWT 2.x expects
    # ``options={"verify_signature": False}`` instead — confirm the pinned PyJWT version.
    unverified_header = jwt.get_unverified_header(token)
    unverified_claims = jwt.decode(token, verify=False)

    raw_key_id = unverified_header.get(claims.KEY_ID)
    unverified_key_id = str(raw_key_id) if raw_key_id else None

    if claims.ISSUER not in unverified_claims:
        raise MissingRequiredClaimError(claims.ISSUER)
    unverified_issuer = str(unverified_claims[claims.ISSUER])

    # Reject unknown issuers before any key lookup; ``None`` disables the allow-list.
    if api_settings.ACCEPTED_ISSUERS is not None and unverified_issuer not in api_settings.ACCEPTED_ISSUERS:
        raise InvalidIssuerError("Invalid issuer")

    public_key, _key_id = get_public_key_and_key_id(issuer=unverified_issuer, key_id=unverified_key_id)

    options = {
        "verify_exp": api_settings.VERIFY_EXPIRATION,
        "verify_iss": api_settings.VERIFY_ISSUER,
        "verify_aud": api_settings.VERIFY_AUDIENCE,
    }
    payload = jwt.decode(
        jwt=token,
        key=public_key,
        verify=api_settings.VERIFY_SIGNATURE,
        algorithms=api_settings.DECODE_ALGORITHMS or [api_settings.ENCODE_ALGORITHM],
        options=options,
        leeway=api_settings.EXPIRATION_LEEWAY,
        audience=api_settings.IDENTITY,
        issuer=unverified_issuer,
    )

    if payload.get(claims.TOKEN) not in (claims.TOKEN_SESSION, claims.TOKEN_AUTHORIZATION):
        raise InvalidTokenError("Unknown token type")
    # MissingRequiredClaimError expects the claim *name* and formats its own message,
    # exactly as used for the issuer above.
    if not payload.get(claims.SESSION_ID):
        raise MissingRequiredClaimError(claims.SESSION_ID)
    if not payload.get(claims.USER_ID):
        raise MissingRequiredClaimError(claims.USER_ID)
    return payload


def authenticate_payload(payload, request=None):
    """Resolve the active user for a verified token payload.

    When ``VERIFY_SESSION_TOKEN`` is enabled, the session token referenced by the payload
    must exist in ``SessionToken.objects.active()`` for the same user. Its usage
    attributes and ``last_used_at`` are then updated and saved. Otherwise the user is
    loaded directly by id.

    :param payload: claim dict, as returned by :func:`decode_jwt_token`.
    :param request: optional request passed to ``SessionToken.update_attributes``.
    :return: the user instance.
    :raises rest_framework.exceptions.AuthenticationFailed: if the session token/user is
        not found, or the user is not active.
    """
    # Imported lazily, presumably to avoid loading models at module import time.
    from rest_framework_sso.models import SessionToken

    if api_settings.VERIFY_SESSION_TOKEN:
        try:
            session_token = (
                SessionToken.objects.active()
                .select_related("user")
                .get(pk=payload.get(claims.SESSION_ID), user_id=payload.get(claims.USER_ID))
            )
        except SessionToken.DoesNotExist:
            raise exceptions.AuthenticationFailed(_("Invalid token."))

        if request is not None:
            session_token.update_attributes(request=request)
        session_token.last_used_at = timezone.now()
        session_token.save()
        user = session_token.user
    else:
        user_model = get_user_model()
        try:
            user = user_model.objects.get(pk=payload.get(claims.USER_ID))
        except user_model.DoesNotExist:
            raise exceptions.AuthenticationFailed(_("Invalid token."))

    if not user.is_active:
        raise exceptions.AuthenticationFailed(_("User inactive or deleted."))

    return user