import inspect import typing from datetime import timedelta, datetime from typing import Callable, Awaitable, Union import jwt from fastapi.security import OAuth2PasswordBearer from passlib.context import CryptContext from starlette.datastructures import Secret from starlette.requests import Request from fastapi_login.exceptions import InvalidCredentialsException class LoginManager: def __init__(self, secret: str, tokenUrl: str, algorithm="HS256"): """ :param str secret: Secret key used to sign and decrypt the JWT :param str algorithm: Should be "HS256" or "RS256" used to decrypt the JWT :param str tokenUrl: the url where the user can login to get the token """ self.secret = Secret(secret) self._user_callback = None self.algorithm = algorithm self.pwd_context = CryptContext(schemes=["bcrypt"]) # this is not mandatory as they user may want to user their own # function to get the token and pass it to the get_current_user method self.tokenUrl = tokenUrl self.oauth_scheme = None self._not_authenticated_exception = None def user_loader(self, callback: Union[Callable, Awaitable]) -> Union[Callable, Awaitable]: """ This sets the callback to retrieve the user. The function should take an unique identifier like an email and return the user object or None. Basic usage:: >>> from fastapi import FastAPI >>> from fastapi_login import LoginManager >>> app = FastAPI() >>> # use import os; print(os.urandom(24).hex()) to get a true secret key >>> SECRET = "super-secret" >>> manager = LoginManager(SECRET, app) >>> manager.user_loader(get_user) >>> # this is the preferred way >>> @manager.user_loader >>> def get_user(): ... # get user logic here :param Callable or Awaitable callback: The callback which returns the user :return: The callback """ self._user_callback = callback return callback @property def not_authenticated_exception(self): return self._not_authenticated_exception @not_authenticated_exception.setter def not_authenticated_exception(self, value: Exception): """ Setter for the Exception which raises when the user is not authenticated :param Exception value: The Exception you want to raise """ assert issubclass(value, Exception) self._not_authenticated_exception = value async def get_current_user(self, token: str): """ This decodes the jwt based on the secret and on the algorithm set on the LoginManager. If the token is correctly formatted and the user is found the user is returned else this raises a `fastapi.HTTPException` :param str token: The encoded jwt token :return: The user object returned by `self._user_callback` :raise: HTTPException if the token is invalid or the user is not found """ try: payload = jwt.decode( token, str(self.secret), algorithms=[self.algorithm] ) # the identifier should be stored under the sub (subject) key user_identifier = payload.get('sub') if user_identifier is None: raise InvalidCredentialsException except jwt.PyJWTError: raise InvalidCredentialsException user = await self._load_user(user_identifier) if user is None: raise InvalidCredentialsException return user async def _load_user(self, identifier: typing.Any): """ This loads the user using the user_callback :param typing.Any identifier: The identifier the user callback takes :return: The user object or None :raises: Exception if the user_back has not been set """ if self._user_callback is None: raise Exception( "Missing user_loader callback" ) if inspect.iscoroutinefunction(self._user_callback): user = await self._user_callback(identifier) else: user = self._user_callback(identifier) return user def create_access_token(self, *, data: dict, expires_delta: timedelta = None) -> str: """ Helper function to create the encoded access token using the provided secret and the algorithm of the LoginManager instance :param dict data: The data which should be stored in the token :param timedelta expires_delta: An optional timedelta in which the token expires. Defaults to 15 minutes :return: The encoded JWT with the data and the expiry. The expiry is available under the 'exp' key """ to_encode = data.copy() if expires_delta: expires_in = datetime.utcnow() + expires_delta else: # default to 15 minutes expiry times expires_in = datetime.utcnow() + timedelta(minutes=15) to_encode.update({'exp': expires_in}) encoded_jwt = jwt.encode(to_encode, str(self.secret), self.algorithm) return encoded_jwt.decode() async def __call__(self, request: Request): """ Provides the functionality to act as a Dependency :param Request request: The incoming request, this is set automatically by FastAPI :return: The user object or None :raises: The not_authenticated_exception if set by the user """ if self.not_authenticated_exception is None: self.oauth_scheme = OAuth2PasswordBearer(tokenUrl=self.tokenUrl) else: # we handle Exception raising self.oauth_scheme = OAuth2PasswordBearer(tokenUrl=self.tokenUrl, auto_error=False) token = await self.oauth_scheme(request) if token is not None: return await self.get_current_user(token) # No token is present in the request and no Exception has been raised yet raise self.not_authenticated_exception