#!/usr/bin/env python
# -*- coding: utf8 -*-
#
# Amazon Web Services CLI - LastPass SAML integration
#
# This script uses LastPass Enterprise SAML-based login to authenticate
# with AWS and retrieve a session token that can then be used with the
# AWS cli tool.
#
# Copyright (c) 2016 LastPass
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import sys
import re
import requests
import hmac
import hashlib
import binascii
import logging
import xml.etree.ElementTree as ET
from base64 import b64decode, b64encode
from struct import pack
import os
import argparse

import boto3
from six.moves import input
from six.moves import html_parser
from six.moves import configparser

from getpass import getpass

LASTPASS_SERVER = 'https://lastpass.com'

# for debugging with proxy
PROXY_SERVER = 'https://127.0.0.1:8443'
# LASTPASS_SERVER = PROXY_SERVER

logging.basicConfig(level=logging.CRITICAL)
logger = logging.getLogger('lp-aws-saml')

class MfaRequiredException(Exception):
    pass

def should_verify():
    """ Disable SSL validation only when debugging via proxy """
    return LASTPASS_SERVER != PROXY_SERVER


def extract_form(html):
    """
    Retrieve the (first) form elements from an html page.
    """
    fields = {}
    matches = re.findall(r'name="([^"]*)" (id="([^"]*)" )?value="([^"]*)"',
                         html)
    for match in matches:
        if len(match) > 2:
            fields[match[0]] = match[3]

    action = ''
    match = re.search(r'action="([^"]*)"', html)
    if match:
        action = match.group(1)

    form = {
        'action': action,
        'fields': fields
    }
    return form


def xorbytes(a, b):
    """ xor all bytes in a string """
    return ''.join(map(lambda x, y: chr(ord(x) ^ ord(y)), a, b))


def prf(h, data):
    """ internal hash update for pbkdf2/hmac-sha256 """
    hm = h.copy()
    hm.update(data)
    return hm.digest()


def pbkdf2(password, salt, rounds, length):
    """
    PBKDF2-SHA256 password derivation.
    """
    key = ''
    h = hmac.new(password, None, hashlib.sha256)
    for block in range(0, (length + 31) / 32):
        ib = hval = prf(h, salt + pack(">I", block + 1))

        for i in range(1, rounds):
            hval = prf(h, hval)
            ib = xorbytes(ib, hval)

        key = key + ib
    return binascii.hexlify(key[0:length])


def lastpass_login_hash(username, password, iterations):
    """
    Determine the number of PBKDF2 iterations needed for a user.
    """
    key = binascii.unhexlify(pbkdf2(password, username, iterations, 32))
    result = pbkdf2(key, password, 1, 32)
    return result


def lastpass_iterations(session, username):
    """
    Determine the number of PBKDF2 iterations needed for a user.
    """
    iterations = 5000
    lp_iterations_page = '%s/iterations.php' % LASTPASS_SERVER
    params = {
        'email': username
    }
    r = session.post(lp_iterations_page, data=params, verify=should_verify())
    if r.status_code == 200:
        iterations = int(r.text)

    return iterations


def lastpass_login(session, username, password, otp = None):
    """
    Log into LastPass with a given username and password.
    """
    logger.debug("logging into lastpass as %s" % username)
    iterations = lastpass_iterations(session, username)

    lp_login_page = '%s/login.php' % LASTPASS_SERVER

    params = {
        'method': 'web',
        'xml': '1',
        'username': username,
        'hash': lastpass_login_hash(username, password, iterations),
        'iterations': iterations
    }
    if otp is not None:
        params['otp'] = otp

    r = session.post(lp_login_page, data=params, verify=should_verify())
    r.raise_for_status()

    doc = ET.fromstring(r.text)
    error = doc.find("error")
    if error is not None:
        cause = error.get('cause')
        if cause == 'googleauthrequired':
            raise MfaRequiredException('Need MFA for this login')
        else:
            reason = error.get('message')
            raise ValueError("Could not login to lastpass: %s" % reason)


def get_saml_token(session, username, password, saml_cfg_id):
    """
    Log into LastPass and retrieve a SAML token for a given
    SAML configuration.
    """
    logger.debug("Getting SAML token")

    # now logged in, grab the SAML token from the IdP-initiated login
    idp_login = '%s/saml/launch/cfg/%d' % (LASTPASS_SERVER, saml_cfg_id)

    r = session.get(idp_login, verify=should_verify())

    form = extract_form(r.text)
    if not form['action']:
        # try to scrape the error message just to make it more user friendly
        error = ""
        for l in r.text.splitlines():
            match = re.search(r'<h2>(.*)</h2>', l)
            if match:
                msg = html_parser.HTMLParser().unescape(match.group(1))
                msg = msg.replace("<br/>", "\n")
                msg = msg.replace("<b>", "")
                msg = msg.replace("</b>", "")
                error = "\n" + msg

        raise ValueError("Unable to find SAML ACS" + error)

    return b64decode(form['fields']['SAMLResponse'])


def get_saml_aws_roles(assertion):
    """
    Get the AWS roles contained in the assertion.  This returns a list of
    RoleARN, PrincipalARN (IdP) pairs.
    """
    doc = ET.fromstring(assertion)

    role_attrib = 'https://aws.amazon.com/SAML/Attributes/Role'
    xpath = ".//saml:Attribute[@Name='%s']/saml:AttributeValue" % role_attrib
    ns = {'saml': 'urn:oasis:names:tc:SAML:2.0:assertion'}

    attribs = doc.findall(xpath, ns)
    return [a.text.split(",", 2) for a in attribs]


def get_saml_nameid(assertion):
    """
    Get the AWS roles contained in the assertion.  This returns a list of
    RoleARN, PrincipalARN (IdP) pairs.
    """
    doc = ET.fromstring(assertion)

    ns = {'saml': 'urn:oasis:names:tc:SAML:2.0:assertion'}
    return doc.find(".//saml:NameID", ns).text


def prompt_for_role(roles):
    """
    Ask user which role to assume.
    """
    if len(roles) == 1:
        return roles[0]

    print 'Please select a role:'
    count = 1
    for r in roles:
        print '  %d) %s' % (count, r[0])
        count = count + 1

    choice = 0
    while choice < 1 or choice > len(roles) + 1:
        try:
            choice = int(input("Choice: "))
        except ValueError:
            choice = 0

    return roles[choice - 1]


def aws_assume_role(session, assertion, role_arn, principal_arn):
    client = boto3.client('sts')
    return client.assume_role_with_saml(
                RoleArn=role_arn,
                PrincipalArn=principal_arn,
                SAMLAssertion=b64encode(assertion))


def aws_set_profile(profile_name, response):
    """
    Save AWS credentials returned from Assume Role operation in
    ~/.aws/credentials INI file.  The credentials are saved in
    a profile with [profile_name].
    """
    config_fn = os.path.expanduser("~/.aws/credentials")

    config = configparser.ConfigParser()
    config.read(config_fn)

    section = profile_name
    try:
        config.add_section(section)
    except configparser.DuplicateSectionError:
        pass

    try:
        os.makedirs(os.path.dirname(config_fn))
    except OSError:
        pass

    config.set(section, 'aws_access_key_id',
               response['Credentials']['AccessKeyId'])
    config.set(section, 'aws_secret_access_key',
               response['Credentials']['SecretAccessKey'])
    config.set(section, 'aws_session_token',
               response['Credentials']['SessionToken'])
    with open(config_fn, 'w') as out:
        config.write(out)


def main():
    parser = argparse.ArgumentParser(description='Get temporary AWS access credentials using LastPass SAML Login')
    parser.add_argument('username', type=str,
                    help='the lastpass username')
    parser.add_argument('saml_config_id', type=int,
                    help='the lastpass SAML config id')
    parser.add_argument('--profile-name', dest='profile_name',
                    help='the name of AWS profile to save the data in (default username)')

    args = parser.parse_args()
    
    username = args.username
    saml_cfg_id = args.saml_config_id

    if args.profile_name is not None:
        profile_name = args.profile_name
    else:
        profile_name = username
    
    password = getpass()

    session = requests.Session()
    try:
      lastpass_login(session, username, password)
    except MfaRequiredException:
      otp = input("OTP: ")
      lastpass_login(session, username, password, otp)

    assertion = get_saml_token(session, username, password, saml_cfg_id)
    roles = get_saml_aws_roles(assertion)
    user = get_saml_nameid(assertion)

    role = prompt_for_role(roles)

    response = aws_assume_role(session, assertion, role[0], role[1])
    aws_set_profile(profile_name, response)

    print "A new AWS CLI profile '%s' has been added." % profile_name
    print "You may now invoke the aws CLI tool as follows:"
    print
    print "    aws --profile %s [...] " % profile_name
    print
    print "This token expires in one hour."


if __name__ == "__main__":
    main()