""" Base classes for Custom Authenticator to use OAuth with JupyterHub Most of the code c/o Kyle Kelley (@rgbkrk) """ import base64 import json import os from urllib.parse import quote, urlparse import uuid from tornado import web from tornado.auth import OAuth2Mixin from tornado.log import app_log from jupyterhub.handlers import BaseHandler from jupyterhub.auth import Authenticator from jupyterhub.utils import url_path_join from traitlets import Unicode, Bool, List, Dict, default def guess_callback_uri(protocol, host, hub_server_url): return '{proto}://{host}{path}'.format( proto=protocol, host=host, path=url_path_join(hub_server_url, 'oauth_callback') ) STATE_COOKIE_NAME = 'oauthenticator-state' def _serialize_state(state): """Serialize OAuth state to a base64 string after passing through JSON""" json_state = json.dumps(state) return base64.urlsafe_b64encode(json_state.encode('utf8')).decode('ascii') def _deserialize_state(b64_state): """Deserialize OAuth state as serialized in _serialize_state""" if isinstance(b64_state, str): b64_state = b64_state.encode('ascii') try: json_state = base64.urlsafe_b64decode(b64_state).decode('utf8') except ValueError: app_log.error("Failed to b64-decode state: %r", b64_state) return {} try: return json.loads(json_state) except ValueError: app_log.error("Failed to json-decode state: %r", json_state) return {} class OAuthLoginHandler(OAuth2Mixin, BaseHandler): """Base class for OAuth login handler Typically subclasses will need """ # these URLs are part of the OAuth2Mixin API # get them from the Authenticator object @property def _OAUTH_AUTHORIZE_URL(self): return self.authenticator.authorize_url @property def _OAUTH_ACCESS_TOKEN_URL(self): return self.authenticator.token_url @property def _OAUTH_USERINFO_URL(self): return self.authenticator.userdata_url def set_state_cookie(self, state): self.set_secure_cookie(STATE_COOKIE_NAME, state, expires_days=1, httponly=True) _state = None def get_state(self): next_url = original_next_url = self.get_argument('next', None) if next_url: # avoid browsers treating \ as / next_url = next_url.replace('\\', quote('\\')) # disallow hostname-having urls, # force absolute path redirect urlinfo = urlparse(next_url) next_url = urlinfo._replace( scheme='', netloc='', path='/' + urlinfo.path.lstrip('/') ).geturl() if next_url != original_next_url: self.log.warning( "Ignoring next_url %r, using %r", original_next_url, next_url ) if self._state is None: self._state = _serialize_state( {'state_id': uuid.uuid4().hex, 'next_url': next_url} ) return self._state def get(self): redirect_uri = self.authenticator.get_callback_url(self) extra_params = self.authenticator.extra_authorize_params.copy() self.log.info('OAuth redirect: %r', redirect_uri) state = self.get_state() self.set_state_cookie(state) extra_params['state'] = state self.authorize_redirect( redirect_uri=redirect_uri, client_id=self.authenticator.client_id, scope=self.authenticator.scope, extra_params=extra_params, response_type='code', ) class OAuthCallbackHandler(BaseHandler): """Basic handler for OAuth callback. Calls authenticator to verify username.""" _state_cookie = None def get_state_cookie(self): """Get OAuth state from cookies To be compared with the value in redirect URL """ if self._state_cookie is None: self._state_cookie = ( self.get_secure_cookie(STATE_COOKIE_NAME) or b'' ).decode('utf8', 'replace') self.clear_cookie(STATE_COOKIE_NAME) return self._state_cookie def get_state_url(self): """Get OAuth state from URL parameters to be compared with the value in cookies """ return self.get_argument("state") def check_state(self): """Verify OAuth state compare value in cookie with redirect url param """ cookie_state = self.get_state_cookie() url_state = self.get_state_url() if not cookie_state: raise web.HTTPError(400, "OAuth state missing from cookies") if not url_state: raise web.HTTPError(400, "OAuth state missing from URL") if cookie_state != url_state: self.log.warning("OAuth state mismatch: %s != %s", cookie_state, url_state) raise web.HTTPError(400, "OAuth state mismatch") def check_error(self): """Check the OAuth code""" error = self.get_argument("error", False) if error: message = self.get_argument("error_description", error) raise web.HTTPError(400, "OAuth error: %s" % message) def check_code(self): """Check the OAuth code""" if not self.get_argument("code", False): raise web.HTTPError(400, "OAuth callback made without a code") def check_arguments(self): """Validate the arguments of the redirect Default: - check for oauth-standard error, error_description arguments - check that there's a code - check that state matches """ self.check_error() self.check_code() self.check_state() def get_next_url(self, user=None): """Get the redirect target from the state field""" state = self.get_state_url() if state: next_url = _deserialize_state(state).get('next_url') if next_url: return next_url # JupyterHub 0.8 adds default .get_next_url for a fallback if hasattr(BaseHandler, 'get_next_url'): return super().get_next_url(user) return url_path_join(self.hub.server.base_url, 'home') async def _login_user_pre_08(self): """login_user simplifies the login+cookie+auth_state process in JupyterHub 0.8 _login_user_07 is for backward-compatibility with JupyterHub 0.7 """ user_info = await self.authenticator.get_authenticated_user(self, None) if user_info is None: return if isinstance(user_info, dict): username = user_info['name'] else: username = user_info user = self.user_from_username(username) self.set_login_cookie(user) return user if not hasattr(BaseHandler, 'login_user'): # JupyterHub 0.7 doesn't have .login_user login_user = _login_user_pre_08 async def get(self): self.check_arguments() user = await self.login_user() if user is None: # todo: custom error page? raise web.HTTPError(403) self.redirect(self.get_next_url(user)) class OAuthenticator(Authenticator): """Base class for OAuthenticators Subclasses must override: login_service (string identifying the service provider) authenticate (method takes one arg - the request handler handling the oauth callback) """ login_handler = OAuthLoginHandler callback_handler = OAuthCallbackHandler authorize_url = Unicode( config=True, help="""The authenticate url for initiating oauth""" ) @default("authorize_url") def _authorize_url_default(self): return os.environ.get("OAUTH2_AUTHORIZE_URL", "") token_url = Unicode( config=True, help="""The url retrieving an access token at the completion of oauth""", ) @default("token_url") def _token_url_default(self): return os.environ.get("OAUTH2_TOKEN_URL", "") userdata_url = Unicode( config=True, help="""The url for retrieving user data with a completed access token""", ) @default("userdata_url") def _userdata_url_default(self): return os.environ.get("OAUTH2_USERDATA_URL", "") scope = List( Unicode(), config=True, help="""The OAuth scopes to request. See the OAuth documentation of your OAuth provider for options. For GitHub in particular, you can see github_scopes.md in this repo. """, ) extra_authorize_params = Dict( config=True, help="""Extra GET params to send along with the initial OAuth request to the OAuth provider.""", ) login_service = 'override in subclass' oauth_callback_url = Unicode( os.getenv('OAUTH_CALLBACK_URL', ''), config=True, help="""Callback URL to use. Typically `https://{host}/hub/oauth_callback`""", ) client_id_env = '' client_id = Unicode(config=True) def _client_id_default(self): if self.client_id_env: client_id = os.getenv(self.client_id_env, '') if client_id: return client_id return os.getenv('OAUTH_CLIENT_ID', '') client_secret_env = '' client_secret = Unicode(config=True) def _client_secret_default(self): if self.client_secret_env: client_secret = os.getenv(self.client_secret_env, '') if client_secret: return client_secret return os.getenv('OAUTH_CLIENT_SECRET', '') validate_server_cert_env = 'OAUTH_TLS_VERIFY' validate_server_cert = Bool(config=True) def _validate_server_cert_default(self): env_value = os.getenv(self.validate_server_cert_env, '') if env_value == '0': return False else: return True def login_url(self, base_url): return url_path_join(base_url, 'oauth_login') def get_callback_url(self, handler=None): """Get my OAuth redirect URL Either from config or guess based on the current request. """ if self.oauth_callback_url: return self.oauth_callback_url elif handler: return guess_callback_uri( handler.request.protocol, handler.request.host, handler.hub.server.base_url, ) else: raise ValueError( "Specify callback oauth_callback_url or give me a handler to guess with" ) def get_handlers(self, app): return [ (r'/oauth_login', self.login_handler), (r'/oauth_callback', self.callback_handler), ] async def authenticate(self, handler, data=None): raise NotImplementedError()