#
# Copyright 2017 Chef Software
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import base64
import getpass
import html
import logging
import os
import re
import shutil
import subprocess
import sys
import xml.etree.ElementTree as ET

import requests
import toml
import boto3

from okta_aws import exceptions


# Version of okta_aws; this is the string printed by the --version flag.
__VERSION__ = '0.7.0'


class OktaAWS(object):
    def __init__(self, argv=None):
        """Parse the command line and remember the requested profile.

        The parsed options are kept in self.args and the profile name they
        contain is copied to self.profile.

        argv - command line arguments (or None to use sys.argv)
        """
        parsed = self.parse_args(argv)
        self.args = parsed
        self.profile = parsed.profile

    def parse_args(self, argv):
        """Builds the argparse parser and parses the command line with it.

        Returns the argparse namespace containing the parsed options.

        argv - command line arguments (or None to use sys.argv)
        """
        # Each entry is (option names, keyword arguments for add_argument).
        # They are registered in this order, which is also the order they
        # are shown in the help output.
        option_specs = [
            (('profile',), {
                'nargs': '?',
                'default': os.getenv("AWS_PROFILE") or "default",
                'help': 'The AWS profile you want credentials for',
            }),
            (('--config', '-c'), {
                'default': '~/.okta_aws.toml',
                'help': 'Path to the configuration file',
            }),
            (('--no-cookies', '-n'), {
                'action': 'store_true',
                'help': "Don't use or save okta session cookie",
            }),
            (('--debug', '-d'), {
                'action': 'store_true',
                'help': 'Show debug output',
            }),
            (('--quiet', '-q'), {
                'action': 'store_true',
                'help': 'Only show error messages',
            }),
            (('--list', '-l'), {
                'action': 'store_true',
                'help': "Don't assume a role, list assigned "
                        "applications in okta",
            }),
            (('--all', '-a'), {
                'action': 'store_true',
                'help': 'Assume a role in all assigned accounts',
            }),
            (('--role_arn', '-r'), {
                'help': 'Role name or ARN to assume',
            }),
            (('--version', '-v'), {
                'action': 'version',
                'version': __VERSION__,
                'help': 'Show version of okta_aws and exit',
            }),
            (('--setup', '-s'), {
                'action': 'store_true',
                'help': "Set up a config file for okta_aws",
            }),
        ]

        parser = argparse.ArgumentParser(
            description='Generates temporary AWS credentials for an AWS'
            ' account you access through okta.')
        for names, options in option_specs:
            parser.add_argument(*names, **options)
        return parser.parse_args(argv)

    def setup_logging(self):
        """Sets up logging based on whether debugging is enabled or not"""
        if self.args.debug:
            logging.basicConfig(
                format='%(asctime)s %(levelname)s %(message)s',
                level=logging.DEBUG)
        elif self.args.quiet:
            logging.basicConfig(format='%(message)s', level=logging.ERROR)
        else:
            logging.basicConfig(format='%(message)s', level=logging.INFO)

        if not self.args.debug:
            logging.getLogger('boto3').setLevel(logging.ERROR)
            logging.getLogger('botocore').setLevel(logging.ERROR)

    def preflight_checks(self):
        """Checks that everything the program depends on is available.

        Currently this only checks that the 'aws' command can be found. If
        any check fails, each problem is printed together with a hint for
        fixing it and the program exits with status 1.
        """
        problems = []

        # AWS cli
        aws_path = shutil.which("aws")
        if aws_path is None:
            problems.append(
                "The AWS CLI (the 'aws' command) cannot be found. "
                "see http://docs.aws.amazon.com/cli/latest/"
                "userguide/installing.html for information on "
                "installing it.")

        if not problems:
            return

        print("Preflight check failed")
        print("======================")
        for problem in problems:
            print("* %s" % problem)
        sys.exit(1)

    def interactive_setup(self, config_file):
        """Runs the first-time setup: asks for the okta username and okta
        domain and writes them to the 'general' section of the config file.

        Values already present in an existing config file are offered as
        the defaults, and pressing enter at a prompt accepts the default.

        config_file - path to the configuration file to create or update
        """
        path = os.path.expanduser(config_file)
        try:
            toml_config = toml.load(path)
        except FileNotFoundError:
            toml_config = {}
        general = toml_config.setdefault('general', {})

        default_username = general.get('username', getpass.getuser())
        default_server = general.get('okta_server', 'example.okta.com')

        print("Okta AWS initial setup")
        print("======================")
        print()
        # An empty answer means "use the default shown in the prompt".
        username = input(
            "Enter your okta username [%s]: " % default_username
        ) or default_username
        okta_server = input(
            "Enter your okta domain [%s]: " % default_server
        ) or default_server

        print()
        print("Creating/updating %s" % config_file)
        general['username'] = username
        general['okta_server'] = okta_server
        with open(path, "w") as fh:
            toml.dump(toml_config, fh)

        print("Setup complete. You can now log in with okta_aws PROFILENAME.")
        print("Hint: you can use 'okta_aws --list' to see which profiles "
              "you can use.")

    def load_config(self, config_file):
        """Reads the configuration file and returns its contents as a
        dictionary with default values filled in.

        If the file doesn't exist the interactive setup is run and the
        program exits with status 0. If a required setting is missing from
        the 'general' section an error is logged and the program exits with
        status 1.

        config_file - path to the configuration file to load
        """
        path = os.path.expanduser(config_file)
        try:
            config = toml.load(path)
        except FileNotFoundError:
            self.interactive_setup(config_file)
            sys.exit(0)

        for section in ('general', 'aliases'):
            config.setdefault(section, {})
        general = config['general']

        missing_options = [option for option in ('username', 'okta_server')
                           if option not in general]
        if missing_options:
            logging.error("Missing required configuration settings: %s",
                          ', '.join(missing_options))
            sys.exit(1)

        # Default configuration values
        defaults = (
            ('cookie_file', '~/.okta_aws_cookie'),
            ('short_profile_names', True),
            ('session_duration', 3600),
        )
        for option, value in defaults:
            general.setdefault(option, value)

        # The cookie file path may start with ~, so expand it once here.
        general['cookie_file'] = os.path.expanduser(general['cookie_file'])

        return config

    def get_config(self, key, default=None):
        """Looks up a configuration value for the current profile.

        The value is taken from the first of these places that has it: the
        config section named after the profile, the section of the profile
        that the current profile is an alias for, the 'general' section,
        and finally the supplied default.

        key     - the configuration option to look up
        default - value to return if the option isn't set anywhere
        """
        # Section named directly after the profile.
        try:
            return self.config[self.profile][key]
        except KeyError:
            pass

        # Section of the profile this one is an alias for.
        try:
            aliased_profile = self.config['aliases'][self.profile]
            return self.config[aliased_profile][key]
        except KeyError:
            pass

        return self.config['general'].get(key, default)

    def choose_from_menu(self, choices, prompt="Select an option: "):
        """Present an interactive menu of choices for the user to pick from.

        Returns the index into the list of the selected item.

        choices - a list of options to choose from
        prompt  - the prompt to show to the user
        """
        for idx, value in enumerate(choices):
            print("%2d) %s" % (idx + 1, value))
        response = 0
        while response < 1 or response > len(choices):
            try:
                response = int(input(prompt))
            except ValueError:
                # If we enter something invalid, just go through the
                # loop again.
                pass
        return response - 1

    def select_role(self, arns):
        """Returns the role to use from a list of principal/role arn pairs,
        based on either a configuration option, user selecting from a menu,
        or simply returning the only arn pair in the list if there is only
        one.

        arns - a list of arn pairs (each pair should be a princpal/role arn)
        """
        selected = None
        if len(arns) > 1:
            # Get role via config, but allow commandline override
            role_arn = self.get_config('role_arn') or self.args.role_arn
            if role_arn is not None:
                # First check to see if we configured a default role
                logging.debug("Looking for configured role: %s", role_arn)
                for arn in arns:
                    # Use endswith here so we can provide just a role name
                    # instead of the full ARN.
                    if arn[1].endswith(role_arn):
                        selected = arn
            if selected is None:
                # We either didn't configure a default role or the configured
                # default role didn't match any available roles. Ask the user
                # to pick one.
                print("Available roles")
                response = self.choose_from_menu(
                    [arn[1].split('/')[-1] for arn in arns],
                    "Select role to log in with: ")
                selected = arns[response]
        else:
            selected = arns[0]
        return selected

    def get_arns(self, saml_assertion):
        """Extracts the available principal/role ARNS for a user given a
        base64 encoded SAML assertion returned by okta.

        saml_assertion - the saml asssertion given by okta, base64 encoded.
        """
        parsed = ET.fromstring(base64.b64decode(saml_assertion))
        # Horrible xpath expression to dig into the ARNs
        elems = parsed.findall(
            ".//{urn:oasis:names:tc:SAML:2.0:assertion}Attribute["
            "@Name='https://aws.amazon.com/SAML/Attributes/Role']//*")
        # text contains Principal ARN, Role ARN separated by a comma
        arns = [e.text.split(",", 1) for e in elems]
        selected = self.select_role(arns)
        # Returns principal_arn, role_arn
        logging.debug("Principal ARN: %s", selected[0])
        logging.debug("Role ARN: %s", selected[1])
        return selected

    def aws_assume_role(self, principal_arn, role_arn, assertion, duration):
        """Gets temporary credentials from aws. Returns a dictionary
        containing the temporary credentials.

        principal_arn - the principal_arn (obtained from saml assertion)
        role_arn  - the arn of the role to assume (obtained from saml
                    assertion)
        assertion - the saml assertion itself (base64 encoded)
        duration  - how long to request the credentials be valid for in
                    seconds. This can't be longer than AWS allows (3600 by
                    default, may be configured to be as long as 43200)
        """

        # Override AWS_PROFILE so boto3 doesn't complain if we have it set
        # to a new profile that doesn't yet exist. This is needed because
        # boto3 will use environment variables if you don't pass in a profile,
        # but will complain if you do pass in a profile that doesn't exist.
        oldenv = os.environ
        if 'AWS_PROFILE' in os.environ:
            del os.environ['AWS_PROFILE']
        if 'AWS_DEFAULT_PROFILE' in os.environ:
            del os.environ['AWS_DEFAULT_PROFILE']
        if 'govcloud' in self.profile:
            region_name = 'us-gov-east-1'
        else:
            region_name = 'us-east-1'

        client = boto3.client('sts', region_name=region_name)

        # And restore them once more
        if 'AWS_PROFILE' in oldenv:
            os.environ['AWS_PROFILE'] = oldenv['AWS_PROFILE']
        if 'AWS_DEFAULT_PROFILE' in oldenv:
            os.environ['AWS_DEFAULT_PROFILE'] = oldenv['AWS_DEFAULT_PROFILE']

        aws_creds = client.assume_role_with_saml(
            RoleArn=role_arn,
            PrincipalArn=principal_arn,
            SAMLAssertion=assertion,
            DurationSeconds=duration
            )

        if 'Credentials' not in aws_creds:
            logging.debug("aws_creds json is: %s" % aws_creds)
            raise exceptions.AssumeRoleError("Credentials key not in returned"
                                             " json")

        return aws_creds['Credentials']

    def set_aws_config(self, profile, key, value):
        """Sets a single AWS configuration option by running
        'aws configure set'. Used to store the temporary credentials in
        ~/.aws/credentials.

        If the aws command exits with a non-zero status an error is logged
        (without the value, which may be a secret).

        profile - the profile to set the configuration option under
        key     - the option to change
        value   - the value to change it to
        """
        # Override AWS_PROFILE so the aws command doesn't complain if we
        # have it set to a new profile that doesn't yet exist. Only the
        # copy handed to the subprocess is changed.
        newenv = os.environ.copy()
        newenv.pop('AWS_PROFILE', None)
        newenv.pop('AWS_DEFAULT_PROFILE', None)

        returncode = subprocess.call(
            [shutil.which("aws"), "configure", "set",
             "--profile", profile, key, value],
            env=newenv)
        if returncode != 0:
            logging.error("Unable to set %s for profile %s ('aws configure "
                          "set' exited with status %s)",
                          key, profile, returncode)

    def store_aws_creds_in_profile(self, profile, aws_creds):
        """Saves the temporary AWS credentials under the given profile, one
        option at a time, using set_aws_config.

        profile - the profile to store the credentials under
        aws_creds - a dictionary containing the credentials returned from AWS
        """
        # (aws configuration option, key in the credentials dictionary)
        settings = (
            ("aws_access_key_id", 'AccessKeyId'),
            ("aws_secret_access_key", 'SecretAccessKey'),
            ("aws_session_token", 'SessionToken'),
        )
        for option, creds_key in settings:
            self.set_aws_config(profile, option, aws_creds[creds_key])

    def is_logged_in(self, session_id):
        """Asks okta whether a session ID is still valid by requesting the
        current session with it. Returns True if okta answers with HTTP
        status 200, and False otherwise (for example when the session has
        expired and we are no longer logged in to okta).

        session_id - the session token that we are verifying
        """
        logging.debug("Verifying if we are already logged in")
        url = "https://%s/api/v1/sessions/me" % self.get_config('okta_server')
        response = requests.get(url, cookies={"sid": session_id})
        logged_in = response.status_code == 200
        logging.debug("Logged in: %s", logged_in)
        return logged_in

    def verify_totp_factor(self, url, statetoken):
        """Asks the user for their totp passcode and sends it to okta for
        verification, returning the decoded JSON response.

        Raises exceptions.LoginError if okta rejects the passcode (HTTP
        status 403) or answers with any other status than 200.

        url - the totp factor verification url
        statetoken - the state token provided when verifying totp factor
        """
        passcode = input("Enter your passcode: ")
        payload = {
            "stateToken": statetoken,
            "passCode": passcode,
        }
        response = requests.post(url, json=payload)
        status = response.status_code
        if status == 403:
            raise exceptions.LoginError("Incorrect passcode")
        if status != 200:
            logging.debug(response.text)
            raise exceptions.LoginError(
                "Login request returned HTTP status %s" % status)
        return response.json()

    def log_in_to_okta(self, password):
        """Logs in to okta using the authn API, returning a single use session
        token that can be exchanged for a long lived session ID.

        If okta asks for MFA, the first enrolled totp factor is verified
        with verify_totp_factor. Raises exceptions.LoginError if the
        password is wrong, if okta returns an unexpected HTTP status, or if
        the login doesn't end with a SUCCESS status and a session token.

        password - the user's okta password
        """
        r = requests.post(
            "https://%s/api/v1/authn" % self.get_config('okta_server'),
            json={
                "username": self.get_config('username'),
                "password": password
            })
        if r.status_code == 401:
            raise exceptions.LoginError("Incorrect password")
        if r.status_code != 200:
            logging.debug(r.text)
            raise exceptions.LoginError(
                "Login request returned HTTP status %s" % r.status_code)
        session_data = r.json()
        if 'status' not in session_data:
            logging.error(session_data)
            raise exceptions.LoginError(
                "Unknown error (missing status field in response)")
        if session_data['status'] == 'MFA_REQUIRED':
            logging.debug('MFA Required')
            statetoken = session_data["stateToken"]
            for factor in session_data["_embedded"]["factors"]:
                # TODO - Add other factors
                if factor["factorType"] == "token:software:totp":
                    url = factor["_links"]["verify"]["href"]
                    session_data = self.verify_totp_factor(url, statetoken)
                    # Stop after the first totp factor: without this, a
                    # user with several totp factors enrolled would be
                    # asked for another passcode and the same state token
                    # would be submitted a second time.
                    break
            # The verification response replaces session_data, so check
            # again that it has a status before reading it below.
            if 'status' not in session_data:
                logging.error(session_data)
                raise exceptions.LoginError(
                    "Unknown error (missing status field in response)")
        if session_data['status'] != 'SUCCESS':
            # e.g. MFA_REQUIRED is reported as "Mfa Required"
            raise exceptions.LoginError(
                session_data['status'].title().replace('_', ' '))
        if 'sessionToken' not in session_data:
            logging.debug(session_data)
            raise exceptions.LoginError("Missing session token")
        return session_data['sessionToken']

    def get_session_id(self, session_token):
        """Exchanges a (single use) session token for a (long lived) session
        ID using okta's sessions API.

        Returns the session ID, or None if okta answers with any HTTP status
        other than 200.

        session_token - the single use token returned when logging in to okta
        """
        url = "https://%s/api/v1/sessions" % self.get_config('okta_server')
        response = requests.post(url, json={"sessionToken": session_token})
        if response.status_code == 200:
            return response.json()['id']
        logging.debug(response.text)
        return None

    def get_assigned_applications(self, session_id):
        """Asks okta for the application links assigned to the user and
        keeps the ones whose appName is 'amazon_aws'.

        Returns a dictionary mapping each application's label to its log in
        URL, or None (after logging an error) if okta answers with any HTTP
        status other than 200.

        session_id - the okta session ID needed to make api calls
        """
        # TODO - proper pagination on this
        logging.debug("Getting assigned application links from okta")
        url = ("https://%s/api/v1/users/me/appLinks?limit=1000" %
               self.get_config('okta_server'))
        response = requests.get(url, cookies={"sid": session_id})
        if response.status_code != 200:
            logging.error("Error getting assigned application list")
            logging.debug(response.text)
            return None

        applinks = {}
        for link in response.json():
            if link['appName'] == 'amazon_aws':
                applinks[link['label']] = link['linkUrl']
        return applinks

    def shorten_appnames(self, applinks):
        """Converts long application names such as
        'Company Engineering AWS (dev use)' to something suitable for use in
        an aws profile such as 'company-engineering'.

        Anything in parentheses is removed first, then a trailing 'AWS',
        and what is left is lowercased with runs of spaces turned into a
        single '-'. Returns a new dictionary mapping the shortened names to
        the same links. If two applications end up with the same shortened
        name the later one wins and a warning is logged.

        applinks - a dictionary mapping application names to application links.
        """
        logging.debug("Shortening application names")
        newapplinks = {}
        for k, v in applinks.items():
            # Remove anything in parens. This has to happen before the AWS
            # suffix is removed, otherwise a name like 'Foo AWS (dev)'
            # doesn't end in AWS yet and would keep its suffix.
            newk = re.sub(r" *\(.*\)", "", k).strip()
            newk = re.sub(r" *AWS$", "", newk)  # Remove AWS suffix
            newk = newk.lower()
            newk = re.sub(" +", "-", newk)
            if newk in newapplinks:
                logging.warning("More than one application shortens to %s; "
                                "using %s", newk, k)
            newapplinks[newk] = v
            logging.debug("%s => %s", k, newk)
        return newapplinks

    def get_saml_assertion(self, session_id, app_url):
        """Sends a request to the application link, and extracts a SAML
        assertion from the response.

        Returns the (html unescaped) value of the SAMLResponse input field,
        or None if the request doesn't return HTTP status 200 or the page
        doesn't contain such a field.

        session_id - okta session ID needed to make api calls
        app_url    - The URL used to log in to the okta application
        """
        r = requests.post(app_url, cookies={"sid": session_id})

        if r.status_code != 200:
            logging.error("Error getting saml assertion. HTTP status %s",
                          r.status_code)
            return None

        # [^>]* keeps the search inside the SAMLResponse input tag, so the
        # value of a different input field later on the same line can
        # never be picked up by mistake.
        match = re.search(
            r'<input name="SAMLResponse"[^>]*?value="([^"]*)"', r.text)
        if not match:
            return None
        return html.unescape(match.group(1))

    def friendly_interval(self, seconds):
        """Describes a number of seconds in hours or minutes, for example
        '15 minutes', '1 hour' or '1.5 hours'.

        Intervals of an hour or more are given in hours and anything
        shorter in minutes, using at most two significant digits.

        seconds - an integer number of seconds to convert
        """
        if seconds >= 3600:
            if seconds == 3600:
                return "1 hour"
            hours = seconds / 3600.0
            return "%.2g hours" % hours

        if seconds == 60:
            return "1 minute"
        minutes = seconds / 60.0
        return "%.2g minutes" % minutes

    def fetch_credentials(self, applinks, session_id):
        """Gets a set of temporary credentials for the current profile and
        stores them: resolves the profile alias, fetches the SAML assertion,
        selects a role, assumes it and saves the resulting credentials under
        the profile. Doesn't return anything.

        The program exits with status 1 if the profile isn't one of the
        available applications, if no SAML assertion could be obtained, or
        if assuming the role raises exceptions.AssumeRoleError.

        applinks   - a mapping of profile names to application links
        session_id - okta session ID needed to make API calls
        """
        profile = self.profile
        # Resolve any profile alias and store it in real_profile
        real_profile = self.config['aliases'].get(profile, profile)

        if real_profile not in applinks:
            if real_profile == profile:
                alias_msg = ""
            else:
                alias_msg = " (an alias that resolved to %s)" % real_profile
            print("ERROR: %s%s isn't a valid profile name" % (
                profile, alias_msg))
            print("Valid profiles:", ', '.join(applinks.keys()))
            sys.exit(1)

        app_url = applinks[real_profile]
        saml_assertion = self.get_saml_assertion(session_id, app_url)
        if saml_assertion is None:
            logging.error("Problem getting SAML assertion")
            sys.exit(1)

        principal_arn, role_arn = self.get_arns(saml_assertion)
        role_name = role_arn.split("/")[-1]

        logging.info("Assuming AWS role %s...", role_name)
        session_duration = self.get_config('session_duration')
        try:
            aws_creds = self.aws_assume_role(principal_arn, role_arn,
                                             saml_assertion, session_duration)
        except exceptions.AssumeRoleError as e:
            logging.error("Unable to get temporary credentials: %s", e)
            sys.exit(1)

        # The credentials are stored under the name the user asked for,
        # not under the name the alias resolved to.
        self.store_aws_creds_in_profile(self.profile, aws_creds)
        logging.info("Temporary credentials stored in profile %s",
                     self.profile)
        logging.info("Credentials expire in %s",
                     self.friendly_interval(session_duration))

    def run(self):
        """Main entry point for the application after parsing command line
        arguments.

        Sets up logging, runs the preflight checks and loads the
        configuration, then obtains an okta session ID (reusing the saved
        one unless --no-cookies is given or it is no longer valid) and
        fetches the list of assigned applications. Depending on the command
        line it then fetches credentials for every profile (--all), lists
        the available profiles (--list), or fetches credentials for the
        selected profile.

        Exits the program with status 0 after --setup, --all and --list,
        and with status 1 if logging in fails or the application list can't
        be fetched.
        """
        self.setup_logging()

        self.preflight_checks()

        if self.args.setup:
            self.interactive_setup(self.args.config)
            sys.exit(0)

        self.config = self.load_config(self.args.config)

        session_id = None
        if not self.args.no_cookies:
            session_id = self._load_saved_session_id()
        if session_id is None:
            session_id = self._log_in_interactively()

        applinks = self.get_assigned_applications(session_id)
        if applinks is None:
            # get_assigned_applications has already logged the error.
            sys.exit(1)
        if self.get_config('short_profile_names'):
            applinks = self.shorten_appnames(applinks)

        if self.args.all:
            for profile in applinks:
                print("Fetching credentials for:", profile)
                self.profile = profile
                self.fetch_credentials(applinks, session_id)
            sys.exit(0)

        if self.args.list:
            self._print_profiles(applinks)
            sys.exit(0)

        self.fetch_credentials(applinks, session_id)

    def _load_saved_session_id(self):
        """Returns the okta session ID saved in the cookie file, or None if
        there is no cookie file, no session ID can be found in it, or okta
        no longer accepts the saved session.
        """
        cookie_file = self.get_config('cookie_file')
        if not os.path.exists(cookie_file):
            return None

        logging.debug("Loading session ID from %s", cookie_file)
        with open(cookie_file) as fh:
            session_id = fh.read().rstrip("\n")

        # Support old cookie file format
        if session_id.startswith('#LWP-Cookies-2.0'):
            logging.debug("Converting old cookie file format")
            m = re.search(r'sid="([^"]*)"', session_id)
            if not m:
                logging.debug("Didn't find session ID in cookies")
                return None
            logging.debug("Found session ID in old cookies")
            session_id = m.group(1)

        # An empty cookie file can't hold a valid session, so don't bother
        # asking okta about it.
        if not session_id or not self.is_logged_in(session_id):
            return None
        return session_id

    def _log_in_interactively(self):
        """Asks for the okta password, logs in, and returns the new session
        ID, saving it to the cookie file unless --no-cookies was given.
        Exits the program with status 1 if logging in fails.
        """
        print("Okta Username:", self.get_config('username'))
        password = ""
        while password == "":
            password = getpass.getpass("Okta Password: ")
        sys.stdout.flush()

        try:
            onetimetoken = self.log_in_to_okta(password)
        except exceptions.LoginError as e:
            # Fall back to the exception itself in case it doesn't carry a
            # message attribute.
            logging.error("Error logging into okta: %s",
                          getattr(e, 'message', e))
            sys.exit(1)

        session_id = self.get_session_id(onetimetoken)
        if session_id is None:
            # get_session_id returns None when okta rejects the token;
            # without this check None would be written to the cookie file.
            logging.error("Error logging into okta: unable to obtain a "
                          "session ID")
            sys.exit(1)

        if not self.args.no_cookies:
            cookie_file = self.get_config('cookie_file')
            logging.debug("Saving session cookie to %s", cookie_file)
            # The session ID gives access to the okta account, so a newly
            # created cookie file is only made readable by its owner (the
            # mode has no effect on a file that already exists).
            fd = os.open(cookie_file,
                         os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, 'w') as fh:
                fh.write(session_id)
        return session_id

    def _print_profiles(self, applinks):
        """Prints the available profile names, each followed by any aliases
        configured for it.
        """
        print("Available profiles:")
        reverse_aliases = {}
        for alias, target in self.config['aliases'].items():
            reverse_aliases.setdefault(target, []).append(alias)
        for profile in applinks:
            if profile in reverse_aliases:
                print("%s (Aliases: %s)" % (profile, ', '.join(
                    reverse_aliases[profile])))
            else:
                print(profile)