import base64 from datetime import datetime, timedelta import json from unittest.mock import patch from unittest import TestCase as PythonTestCase from django.core.exceptions import ImproperlyConfigured from django.test import override_settings import jwt from oauth2_provider_jwt import utils class GeneratePayloadTest(PythonTestCase): def _get_payload_args(self): issuer = 'activityapi' expires_in = 36000 return issuer, expires_in @patch('oauth2_provider_jwt.utils.datetime') def test_generate_payload_no_extra_data(self, mock_datetime): now = datetime.utcnow() mock_datetime.utcnow.return_value = now issuer, expires_in = self._get_payload_args() expiration = now + timedelta(seconds=expires_in) self.assertEqual( utils.generate_payload(issuer, expires_in), { 'iss': issuer, 'exp': expiration, 'iat': now, } ) @patch('oauth2_provider_jwt.utils.datetime') def test_generate_payload_with_extra_data(self, mock_datetime): now = datetime.utcnow() mock_datetime.utcnow.return_value = now issuer, expires_in = self._get_payload_args() expiration = now + timedelta(seconds=expires_in) extra_data = { 'usr': 'some_usr', 'org': 'some_org', 'sub': 'subject', } self.assertEqual( utils.generate_payload(issuer, expires_in, **extra_data), { 'iss': issuer, 'exp': expiration, 'iat': now, 'sub': 'subject', 'usr': 'some_usr', 'org': 'some_org', } ) class EncodeJWTTest(PythonTestCase): def _get_payload(self): now = datetime.utcnow() return { 'iss': 'issuer', 'exp': now + timedelta(seconds=10), 'iat': now, 'sub': 'subject', 'usr': 'some_usr', 'org': 'some_org', } @override_settings(JWT_PRIVATE_KEY_ISSUER='') def test_encode_jwt_no_private_key_in_setting(self): payload = self._get_payload() self.assertRaises(ImproperlyConfigured, utils.encode_jwt, payload) def test_encode_jwt_rs256(self): payload_in = self._get_payload() encoded = utils.encode_jwt(payload_in) self.assertIn(type(encoded).__name__, ('unicode', 'str')) headers, payload, verify_signature = encoded.split(".") self.assertDictEqual( json.loads(base64.b64decode(headers)), {"typ": "JWT", "alg": "RS256"}) payload += '=' * (-len(payload) % 4) # add padding self.assertEqual( json.loads(base64.b64decode(payload).decode("utf-8")), payload_in) @override_settings(JWT_PRIVATE_KEY_ISSUER='test') @override_settings(JWT_ENC_ALGORITHM='HS256') def test_encode_jwt_hs256(self): payload_in = self._get_payload() encoded = utils.encode_jwt(payload_in) self.assertIn(type(encoded).__name__, ('unicode', 'str')) headers, payload, verify_signature = encoded.split('.') self.assertDictEqual( json.loads(base64.b64decode(headers)), {'typ': 'JWT', 'alg': 'HS256'}) payload += '=' * (-len(payload) % 4) self.assertEqual( json.loads(base64.b64decode(payload).decode('utf-8')), payload_in) class DecodeJWTTest(PythonTestCase): def _get_payload(self): now = datetime.utcnow() return { 'iss': 'issuer', 'exp': now + timedelta(seconds=10), 'iat': now, 'sub': 'subject', 'usr': 'some_usr', 'org': 'some_org', } def test_decode_jwt_invalid(self): self.assertRaises(jwt.InvalidTokenError, utils.decode_jwt, 'abc') @override_settings(JWT_PUBLIC_KEY_ISSUER='') def test_decode_jwt_public_key_not_found(self): payload = self._get_payload() jwt_value = utils.encode_jwt(payload) self.assertRaises(ImproperlyConfigured, utils.decode_jwt, jwt_value) def test_decode_jwt_expired(self): payload = self._get_payload() now = datetime.utcnow() payload['exp'] = now - timedelta(seconds=1) payload['iat'] = now jwt_value = utils.encode_jwt(payload) self.assertRaises(jwt.ExpiredSignatureError, utils.decode_jwt, jwt_value) def test_decode_jwt_rs256(self): payload = self._get_payload() jwt_value = utils.encode_jwt(payload) payload_out = utils.decode_jwt(jwt_value) self.assertDictEqual(payload, payload_out) @override_settings(JWT_PRIVATE_KEY_ISSUER='test') @override_settings(JWT_PUBLIC_KEY_ISSUER='test') @override_settings(JWT_ENC_ALGORITHM='HS256') def test_decode_jwt_hs256(self): payload = self._get_payload() jwt_value = utils.encode_jwt(payload) payload_out = utils.decode_jwt(jwt_value) self.assertDictEqual(payload, payload_out)