import base64 as _base64
import hashlib as _hashlib
import keyring as _keyring
import os as _os
import re as _re
import requests as _requests
import webbrowser as _webbrowser

from multiprocessing import Process as _Process, Queue as _Queue

try:  # Python 3.5+
    from http import HTTPStatus as _StatusCodes
except ImportError:
    try:  # Python 3
        from http import client as _StatusCodes
    except ImportError:  # Python 2
        import httplib as _StatusCodes
try:  # Python 3
    import http.server as _BaseHTTPServer
except ImportError:  # Python 2
    import BaseHTTPServer as _BaseHTTPServer

try:  # Python 3
    import urllib.parse as _urlparse
    from urllib.parse import urlencode as _urlencode
except ImportError:  # Python 2
    import urlparse as _urlparse
    from urllib import urlencode as _urlencode

# Number of random bytes fed into the PKCE code verifier / the state parameter.
_code_verifier_length = 64
_random_seed_length = 40
_utf_8 = 'utf-8'

# Identifies the service used for storing passwords in keyring
_keyring_service_name = "flyteauth"
# Identifies the keys used for storing and fetching from keyring. Instead of a username, as the keyring docs
# suggest, the tokens are stored under these fixed key names.
_keyring_access_token_storage_key = "access_token"
_keyring_refresh_token_storage_key = "refresh_token"


def _generate_code_verifier():
    """
    Generates a 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1
    Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py.

    :raises ValueError: if the resulting verifier is shorter than 43 or longer than 128 characters
    :return str:
    """
    code_verifier = _base64.urlsafe_b64encode(_os.urandom(_code_verifier_length)).decode(_utf_8)
    # Eliminate invalid characters (this drops the base64 '=' padding).
    code_verifier = _re.sub(r'[^a-zA-Z0-9_\-.~]+', '', code_verifier)
    if len(code_verifier) < 43:
        raise ValueError("Verifier too short. number of bytes must be > 30.")
    elif len(code_verifier) > 128:
        raise ValueError("Verifier too long. number of bytes must be < 97.")
    return code_verifier


def _generate_state_parameter():
    """
    Generates a random 'state' value that is sent with the authorization request and compared against the value
    returned on the callback.

    :return str:
    """
    state = _base64.urlsafe_b64encode(_os.urandom(_random_seed_length)).decode(_utf_8)
    # Eliminate invalid characters (this drops the base64 '=' padding).
    state = _re.sub(r'[^a-zA-Z0-9-_.,]+', '', state)
    return state


def _create_code_challenge(code_verifier):
    """
    Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py.

    :param str code_verifier: represents a code verifier generated by generate_code_verifier()
    :return str: urlsafe base64-encoded sha256 hash digest
    """
    code_challenge = _hashlib.sha256(code_verifier.encode(_utf_8)).digest()
    code_challenge = _base64.urlsafe_b64encode(code_challenge).decode(_utf_8)
    # Eliminate invalid characters
    code_challenge = code_challenge.replace('=', '')
    return code_challenge


class AuthorizationCode(object):
    """
    Read-only pair of the 'code' and 'state' query parameters received on the callback URL. Either value is None
    when the callback did not carry it.
    """

    def __init__(self, code, state):
        self._code = code
        self._state = state

    @property
    def code(self):
        return self._code

    @property
    def state(self):
        return self._state


class OAuthCallbackHandler(_BaseHTTPServer.BaseHTTPRequestHandler):
    """
    A simple wrapper around BaseHTTPServer.BaseHTTPRequestHandler that handles a callback URL that accepts an
    authorization token.
    """

    def do_GET(self):
        """
        Answers 200 and forwards the query parameters to handle_login when the request path equals the server's
        redirect path; answers 404 otherwise.
        """
        url = _urlparse.urlparse(self.path)
        if url.path != self.server.redirect_path:
            self.send_response(_StatusCodes.NOT_FOUND)
            # The header section must be terminated, otherwise the response is never completed for the client.
            self.end_headers()
            return

        self.send_response(_StatusCodes.OK)
        self.end_headers()
        self.handle_login(dict(_urlparse.parse_qsl(url.query)))

    def handle_login(self, data):
        """
        :param dict[str, str] data: query parameters of the callback request
        """
        # Use .get so that a callback lacking 'code'/'state' is still handed to the server instead of raising
        # KeyError here, which would leave whoever waits for the authorization code blocked.
        self.server.handle_authorization_code(AuthorizationCode(data.get('code'), data.get('state')))


class OAuthHTTPServer(_BaseHTTPServer.HTTPServer):
    """
    A simple wrapper around the BaseHTTPServer.HTTPServer implementation that binds an authorization_client for
    handling authorization code callbacks.
    """

    def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True,
                 redirect_path=None, queue=None):
        """
        :param tuple server_address: (host, port) passed through to HTTPServer
        :param RequestHandlerClass: handler class passed through to HTTPServer
        :param bool bind_and_activate: passed through to HTTPServer
        :param str redirect_path: URL path on which the authorization callback is expected
        :param queue: queue onto which received authorization codes are put
        """
        # Explicit base-class call (rather than super()) keeps the Python 2 fallback import usable.
        _BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate)
        self._redirect_path = redirect_path
        self._auth_code = None
        self._queue = queue

    @property
    def redirect_path(self):
        return self._redirect_path

    def handle_authorization_code(self, auth_code):
        """
        :param AuthorizationCode auth_code: value to publish on the queue given at construction
        """
        self._queue.put(auth_code)


class Credentials(object):
    """
    Read-only holder for an access token.
    """

    def __init__(self, access_token=None):
        self._access_token = access_token

    @property
    def access_token(self):
        return self._access_token


class AuthorizationClient(object):
    """
    Obtains an access token through the authorization code grant with PKCE, caching the access and refresh
    tokens in keyring.
    """

    def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redirect_uri=None):
        """
        Uses the access token stored in keyring when there is one. Otherwise starts a local callback server in a
        separate process, opens the authorization endpoint in a browser tab, blocks until the callback delivers
        an authorization code and exchanges it for an access token.

        :param str auth_endpoint: URL of the authorization endpoint
        :param str token_endpoint: URL of the token endpoint
        :param str client_id: client ID of the OAuth application
        :param str redirect_uri: callback URL; its host, port and path define the local callback server
        """
        self._auth_endpoint = auth_endpoint
        self._token_endpoint = token_endpoint
        self._client_id = client_id
        self._redirect_uri = redirect_uri
        self._code_verifier = _generate_code_verifier()
        code_challenge = _create_code_challenge(self._code_verifier)
        self._code_challenge = code_challenge
        state = _generate_state_parameter()
        self._state = state
        self._credentials = None
        self._refresh_token = None
        self._headers = {'content-type': "application/x-www-form-urlencoded"}
        self._expired = False

        self._params = {
            "client_id": client_id,  # This must match the Client ID of the OAuth application.
            "response_type": "code",  # Indicates the authorization code grant
            "scope": "openid offline_access",  # ensures that the /token endpoint returns an ID and refresh token
            # callback location where the user-agent will be directed to.
            "redirect_uri": self._redirect_uri,
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
        }

        # Prefer to use already-fetched token values when they've been set globally.
        self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key)
        access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key)
        if access_token:
            self._credentials = Credentials(access_token=access_token)
            return

        # In the absence of globally-set token values, initiate the token request flow
        q = _Queue()
        # First prepare the callback server in the background
        server = self._create_callback_server(q)
        server_process = _Process(target=server.handle_request)
        server_process.start()
        try:
            # Send the call to request the authorization code
            self._request_authorization_code()
            # Block until the callback server has delivered the authorization code.
            auth_code = q.get()
        finally:
            # Always stop the background process and release this process's handle on the listening socket,
            # even when opening the browser or waiting for the callback fails.
            server_process.terminate()
            server.server_close()

        # Request the access token once the auth code has been received.
        self.request_access_token(auth_code)

    def _create_callback_server(self, q):
        """
        :param q: queue on which the server publishes the received authorization code
        :return OAuthHTTPServer: server bound to the host and port of the redirect URI
        """
        server_url = _urlparse.urlparse(self._redirect_uri)
        server_address = (server_url.hostname, server_url.port)
        return OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path, queue=q)

    def _request_authorization_code(self):
        """
        Opens the authorization endpoint, with the request parameters as its query string, in a new browser tab.
        """
        scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint)
        query = _urlencode(self._params)
        endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None))
        _webbrowser.open_new_tab(endpoint)

    def _initialize_credentials(self, auth_token_resp):
        """
        Stores the tokens of a token-endpoint response in keyring and on this client.

        The auth_token_resp body is of the form:
        {
          "access_token": "foo",
          "refresh_token": "bar",
          "id_token": "baz",
          "token_type": "Bearer"
        }
        The "refresh_token" entry may be absent, in which case any previously known refresh token is kept.

        :param auth_token_resp: response object exposing a json() method
        :raises ValueError: if the response body has no "access_token"
        """
        response_body = auth_token_resp.json()
        if "access_token" not in response_body:
            raise ValueError('Expected "access_token" in response from oauth server')

        access_token = response_body["access_token"]
        _keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token)

        # Only touch the refresh token when the server actually returned one; indexing it unconditionally
        # raised KeyError for responses that omit it.
        if "refresh_token" in response_body:
            self._refresh_token = response_body["refresh_token"]
            _keyring.set_password(_keyring_service_name, _keyring_refresh_token_storage_key, self._refresh_token)

        self._credentials = Credentials(access_token=access_token)

    def request_access_token(self, auth_code):
        """
        Exchanges an authorization code for an access token and stores the resulting credentials.

        :param AuthorizationCode auth_code: code and state received on the callback
        :raises ValueError: if the received state differs from the one sent with the authorization request
        :raises Exception: if the token endpoint does not answer with status 200
        """
        if self._state != auth_code.state:
            raise ValueError("Unexpected state parameter [{}] passed".format(auth_code.state))
        self._params.update({
            "code": auth_code.code,
            "code_verifier": self._code_verifier,
            "grant_type": "authorization_code",
        })
        resp = _requests.post(
            url=self._token_endpoint,
            data=self._params,
            headers=self._headers,
            allow_redirects=False
        )
        if resp.status_code != _StatusCodes.OK:
            # TODO: handle expected (?) error cases:
            #  https://auth0.com/docs/flows/guides/device-auth/call-api-device-auth#token-responses
            raise Exception('Failed to request access token with response: [{}] {}'.format(
                resp.status_code, resp.content))
        self._initialize_credentials(resp)

    def refresh_access_token(self):
        """
        Requests a new access token using the refresh token. On a non-200 answer the client is marked as expired
        and both tokens are removed from keyring.

        :raises ValueError: if no refresh token is available
        """
        if self._refresh_token is None:
            raise ValueError("no refresh token available with which to refresh authorization credentials")

        resp = _requests.post(
            url=self._token_endpoint,
            data={'grant_type': 'refresh_token',
                  'client_id': self._client_id,
                  'refresh_token': self._refresh_token},
            headers=self._headers,
            allow_redirects=False
        )
        if resp.status_code != _StatusCodes.OK:
            self._expired = True
            # In the absence of a successful response, assume the refresh token is expired. This should indicate
            # to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized.
            _keyring.delete_password(_keyring_service_name, _keyring_access_token_storage_key)
            _keyring.delete_password(_keyring_service_name, _keyring_refresh_token_storage_key)
            return
        self._initialize_credentials(resp)

    @property
    def credentials(self):
        """
        :return flytekit.clis.auth.auth.Credentials:
        """
        return self._credentials

    @property
    def expired(self):
        """
        :return bool:
        """
        return self._expired