""" Connect to an EC2 instance via SSH, by name or instance ID. Security groups, network ACLs, interfaces, VPC routing tables, VPC Internet Gateways, and internal firewalls for the instance must be configured to allow SSH connections. To facilitate SSH connections, ``aegea ssh`` resolves instance names to public DNS names assigned by AWS, and securely retrieves SSH host public keys from instance metadata before connecting. This avoids both the prompt to save the instance public key and the resulting transient MITM vulnerability. ``aegea ssh`` also supports Bless, via the --bless-config CONFIG_FILE option or the BLESS_CONFIG environment variable. This should point to a YAML file with the format described in https://github.com/chanzuckerberg/blessclient/blob/master/examples/config.yml. """ import os, sys, argparse, string, datetime, json, base64, time, fnmatch import boto3, yaml from . import register_parser, logger from .util.aws import resolve_instance_id, resources, clients, ARN from .util.crypto import (add_ssh_host_key_to_known_hosts, ensure_local_ssh_key, get_public_key_from_pair, add_ssh_key_to_agent, get_ssh_key_path) from .util.printing import BOLD from .util.exceptions import AegeaException from .util.compat import lru_cache from .util.aws.ssm import ensure_session_manager_plugin, run_command opts_by_nargs = { "ssh": {0: "46AaCfGgKkMNnqsTtVvXxYy", 1: "BbcDEeFIiJLlmOopQRSW"}, "scp": {0: "346BCpqrv", 1: "cFiloPS"} } def add_bless_and_passthrough_opts(parser, program): parser.add_argument("--bless-config", default=os.environ.get("BLESS_CONFIG"), help="Path to a Bless configuration file (or pass via the BLESS_CONFIG environment variable)") parser.add_argument("--use-kms-auth", help=argparse.SUPPRESS) for opt in opts_by_nargs[program][0]: parser.add_argument("-" + opt, action="store_true", help=argparse.SUPPRESS) for opt in opts_by_nargs[program][1]: parser.add_argument("-" + opt, action="append", help=argparse.SUPPRESS) def extract_passthrough_opts(args, program): opts = [] for opt in opts_by_nargs[program][0]: if getattr(args, opt): opts.append("-" + opt) for opt in opts_by_nargs[program][1]: for value in getattr(args, opt) or []: opts.extend(["-" + opt, value]) return opts @lru_cache(8) def get_instance(name): return resources.ec2.Instance(resolve_instance_id(name)) def save_instance_public_key(name, use_ssm=False): instance = get_instance(name) tags = {tag["Key"]: tag["Value"] for tag in instance.tags or []} ssh_host_key = tags.get("SSHHostPublicKeyPart1", "") + tags.get("SSHHostPublicKeyPart2", "") if ssh_host_key: # FIXME: this results in duplicates. # Use paramiko to detect if the key is already listed and not insert it then (or only insert if different) hostname = instance.id if use_ssm else instance.public_dns_name add_ssh_host_key_to_known_hosts(hostname + " " + ssh_host_key + "\n") def resolve_instance_public_dns(name): instance = get_instance(name) if not getattr(instance, "public_dns_name", None): msg = "Unable to resolve public DNS name for {} (state: {})" raise AegeaException(msg.format(instance, getattr(instance, "state", {}).get("Name"))) return instance.public_dns_name def get_linux_username(): username = ARN.get_iam_username() assert username != "unknown" username, at, domain = username.partition("@") return username def get_kms_auth_token(session, bless_config, lambda_regional_config): logger.info("Requesting new KMS auth token in %s", lambda_regional_config["aws_region"]) token_not_before = datetime.datetime.utcnow() - datetime.timedelta(minutes=1) token_not_after = token_not_before + datetime.timedelta(hours=1) token = dict(not_before=token_not_before.strftime("%Y%m%dT%H%M%SZ"), not_after=token_not_after.strftime("%Y%m%dT%H%M%SZ")) encryption_context = { "from": session.resource("iam").CurrentUser().user_name, "to": bless_config["lambda_config"]["function_name"], "user_type": "user" } kms = session.client('kms', region_name=lambda_regional_config["aws_region"]) res = kms.encrypt(KeyId=lambda_regional_config["kms_auth_key_id"], Plaintext=json.dumps(token), EncryptionContext=encryption_context) return base64.b64encode(res["CiphertextBlob"]).decode() def ensure_bless_ssh_cert(ssh_key_name, bless_config, use_kms_auth, max_cert_age=1800): ssh_key = ensure_local_ssh_key(ssh_key_name) ssh_key_filename = get_ssh_key_path(ssh_key_name) ssh_cert_filename = ssh_key_filename + "-cert.pub" if os.path.exists(ssh_cert_filename) and time.time() - os.stat(ssh_cert_filename).st_mtime < max_cert_age: logger.info("Using cached Bless SSH certificate %s", ssh_cert_filename) return ssh_cert_filename logger.info("Requesting new Bless SSH certificate") for lambda_regional_config in bless_config["lambda_config"]["regions"]: if lambda_regional_config["aws_region"] == clients.ec2.meta.region_name: break session = boto3.Session(profile_name=bless_config["client_config"]["aws_user_profile"]) iam = session.resource("iam") sts = session.client("sts") assume_role_res = sts.assume_role(RoleArn=bless_config["lambda_config"]["role_arn"], RoleSessionName=__name__) awslambda = boto3.client('lambda', region_name=lambda_regional_config["aws_region"], aws_access_key_id=assume_role_res['Credentials']['AccessKeyId'], aws_secret_access_key=assume_role_res['Credentials']['SecretAccessKey'], aws_session_token=assume_role_res['Credentials']['SessionToken']) bless_input = dict(bastion_user=iam.CurrentUser().user_name, bastion_user_ip="0.0.0.0/0", bastion_ips=",".join(bless_config["client_config"]["bastion_ips"]), remote_usernames=",".join(bless_config["client_config"]["remote_users"]), public_key_to_sign=get_public_key_from_pair(ssh_key), command="*") if use_kms_auth: bless_input["kmsauth_token"] = get_kms_auth_token(session=session, bless_config=bless_config, lambda_regional_config=lambda_regional_config) res = awslambda.invoke(FunctionName=bless_config["lambda_config"]["function_name"], Payload=json.dumps(bless_input)) bless_output = json.loads(res["Payload"].read().decode()) if "certificate" not in bless_output: raise AegeaException("Error while requesting Bless SSH certificate: {}".format(bless_output)) with open(ssh_cert_filename, "w") as fh: fh.write(bless_output["certificate"]) return ssh_cert_filename def match_instance_to_bastion(instance, bastions): for bastion_config in bastions: for ipv4_pattern in bastion_config["hosts"]: if fnmatch.fnmatch(instance.private_ip_address, ipv4_pattern["pattern"]): logger.info("Using %s to connect to %s", bastion_config["pattern"], instance) return bastion_config def prepare_ssh_host_opts(username, hostname, bless_config_filename=None, ssh_key_name=__name__, use_kms_auth=True, use_ssm=True): if bless_config_filename: with open(bless_config_filename) as fh: bless_config = yaml.safe_load(fh) ensure_bless_ssh_cert(ssh_key_name=ssh_key_name, bless_config=bless_config, use_kms_auth=use_kms_auth) add_ssh_key_to_agent(ssh_key_name) instance = get_instance(hostname) if not username: username = bless_config["client_config"]["remote_users"][0] bastion_config = match_instance_to_bastion(instance=instance, bastions=bless_config["ssh_config"]["bastions"]) if use_ssm: return [], username + "@" + instance.id elif bastion_config: jump_host = bastion_config["user"] + "@" + bastion_config["pattern"] return ["-o", "ProxyJump=" + jump_host], username + "@" + instance.private_ip_address elif instance.public_dns_name: logger.warn("No bastion host found for %s, trying direct connection", instance.private_ip_address) return [], username + "@" + instance.public_dns_name else: raise AegeaException("No bastion host or public route found for {}".format(instance)) else: if get_instance(hostname).key_name is not None: add_ssh_key_to_agent(get_instance(hostname).key_name) if not username: username = get_linux_username() save_instance_public_key(hostname, use_ssm=use_ssm) return [], username + "@" + (get_instance(hostname).id if use_ssm else resolve_instance_public_dns(hostname)) def init_ssm(instance_id): ssm_plugin_path = ensure_session_manager_plugin() os.environ["PATH"] = os.environ["PATH"] + ":" + os.path.dirname(ssm_plugin_path) return ["-o", "ProxyCommand=aws ssm start-session --document-name AWS-StartSSHSession --target " + instance_id] def ssh(args): ssh_opts = ["-o", "ServerAliveInterval={}".format(args.server_alive_interval)] ssh_opts += ["-o", "ServerAliveCountMax={}".format(args.server_alive_count_max)] ssh_opts += extract_passthrough_opts(args, "ssh") prefix, at, name = args.name.rpartition("@") if args.use_ssm: ssh_opts += init_ssm(get_instance(name).id) host_opts, hostname = prepare_ssh_host_opts(username=prefix, hostname=name, bless_config_filename=args.bless_config, use_kms_auth=args.use_kms_auth, use_ssm=args.use_ssm) os.execvp("ssh", ["ssh"] + ssh_opts + host_opts + [hostname] + args.ssh_args) ssh_parser = register_parser(ssh, help="Connect to an EC2 instance", description=__doc__) ssh_parser.add_argument("name") ssh_parser.add_argument("ssh_args", nargs=argparse.REMAINDER, help="Arguments to pass to ssh; please see " + BOLD("man ssh") + " for details") ssh_parser.add_argument("--server-alive-interval", help=argparse.SUPPRESS) ssh_parser.add_argument("--server-alive-count-max", help=argparse.SUPPRESS) ssh_parser.add_argument("--no-ssm", action="store_false", dest="use_ssm") add_bless_and_passthrough_opts(ssh_parser, "ssh") def scp(args): """ Transfer files to or from EC2 instance. """ scp_opts, host_opts = extract_passthrough_opts(args, "scp"), [] user_or_hostname_chars = string.ascii_letters + string.digits ssm_init_complete = False for i, arg in enumerate(args.scp_args): if arg[0] in user_or_hostname_chars and ":" in arg: hostname, colon, path = arg.partition(":") username, at, hostname = hostname.rpartition("@") if args.use_ssm and not ssm_init_complete: scp_opts += init_ssm(get_instance(hostname).id) ssm_init_complete = True host_opts, hostname = prepare_ssh_host_opts(username=username, hostname=hostname, bless_config_filename=args.bless_config, use_kms_auth=args.use_kms_auth, use_ssm=args.use_ssm) args.scp_args[i] = hostname + colon + path os.execvp("scp", ["scp"] + scp_opts + host_opts + args.scp_args) scp_parser = register_parser(scp, help="Transfer files to or from EC2 instance", description=scp.__doc__) scp_parser.add_argument("scp_args", nargs=argparse.REMAINDER, help="Arguments to pass to scp; please see " + BOLD("man scp") + " for details") scp_parser.add_argument("--no-ssm", action="store_false", dest="use_ssm") add_bless_and_passthrough_opts(scp_parser, "scp") def run(args): run_command(args.command, instance_ids=[get_instance(args.instance).id]) run_parser = register_parser(run, help="Run a command on an EC2 instance", description=run_command.__doc__) run_parser.add_argument("instance") run_parser.add_argument("command")