import json from unittest import mock from django.contrib.auth import get_user_model from django.contrib.auth.models import AnonymousUser from django.core.exceptions import ImproperlyConfigured from django.http import HttpResponse from django.test import RequestFactory, TestCase from jwt import exceptions from request_token.middleware import RequestTokenMiddleware from request_token.models import RequestToken from request_token.settings import JWT_QUERYSTRING_ARG class MockSession(object): """Fake Session model used to support `session_key` property.""" @property def session_key(self): return "foobar" class MiddlewareTests(TestCase): """RequestTokenMiddleware tests.""" def setUp(self): self.user = get_user_model().objects.create_user("zoidberg") self.factory = RequestFactory() self.middleware = RequestTokenMiddleware(get_response=lambda r: HttpResponse()) self.token = RequestToken.objects.create_token(scope="foo") def get_request(self): request = self.factory.get("/?%s=%s" % (JWT_QUERYSTRING_ARG, self.token.jwt())) request.user = self.user request.session = MockSession() return request def post_request(self): request = self.factory.post("/", {JWT_QUERYSTRING_ARG: self.token.jwt()}) request.user = self.user request.session = MockSession() return request def post_request_with_JSON(self): data = json.dumps({JWT_QUERYSTRING_ARG: self.token.jwt()}) request = self.factory.post("/", data, "application/json") request.user = self.user request.session = MockSession() return request def test_process_request_assertions(self): request = self.factory.get("/") self.assertRaises(ImproperlyConfigured, self.middleware, request) request.user = AnonymousUser() self.assertRaises(ImproperlyConfigured, self.middleware, request) request.session = MockSession() self.middleware(request) self.assertFalse(hasattr(request, "token")) def test_process_request_without_token(self): request = self.factory.get("/") request.user = AnonymousUser() request.session = MockSession() self.middleware(request) self.assertFalse(hasattr(request, "token")) def test_process_GET_request_with_valid_token(self): request = self.get_request() self.middleware(request) self.assertEqual(request.token, self.token) def test_process_POST_request_with_valid_token(self): request = self.post_request() self.middleware(request) self.assertEqual(request.token, self.token) def test_process_POST_request_with_valid_token_with_json(self): request = self.post_request_with_JSON() self.middleware(request) self.assertEqual(request.token, self.token) def test_process_request_not_allowed(self): # PUT requests won't decode the token request = self.factory.put("/?rt=foo") request.user = self.user request.session = MockSession() response = self.middleware(request) self.assertFalse(hasattr(request, "token")) self.assertEqual(response.status_code, 200) @mock.patch("request_token.middleware.logger") def test_process_request_token_error(self, mock_logger): # token decode error - request passes through _without_ a token request = self.factory.get("/?rt=foo") request.user = self.user request.session = MockSession() self.middleware(request) self.assertIsNone(request.token) self.assertEqual(mock_logger.exception.call_count, 1) @mock.patch("request_token.middleware.logger") def test_process_request_token_does_not_exist(self, mock_logger): request = self.get_request() self.token.delete() self.middleware(request) self.assertIsNone(request.token) self.assertEqual(mock_logger.exception.call_count, 1) @mock.patch.object(RequestToken, "log") def test_process_exception(self, mock_log): request = self.get_request() request.token = self.token exception = exceptions.InvalidTokenError("bar") response = self.middleware.process_exception(request, exception) mock_log.assert_called_once_with(request, response, error=exception) self.assertEqual(response.status_code, 403) self.assertEqual(response.reason_phrase, str(exception)) # no request token = no error log del request.token mock_log.reset_mock() response = self.middleware.process_exception(request, exception) self.assertEqual(mock_log.call_count, 0) self.assertEqual(response.status_code, 403) self.assertEqual(response.reason_phrase, str(exception)) # round it out with a non-token error response = self.middleware.process_exception(request, Exception("foo")) self.assertIsNone(response)