"""CloudFormation custom resource provider that manages an RSA private key.

The private key is stored as a ``SecureString`` in the SSM Parameter Store and
the matching public key is returned to CloudFormation as resource attributes.
"""
import hashlib
import logging

import boto3
from botocore.exceptions import ClientError
from cfn_resource_provider import ResourceProvider
from cryptography.hazmat.backends import default_backend as crypto_default_backend
from cryptography.hazmat.primitives import serialization as crypto_serialization
from cryptography.hazmat.primitives.asymmetric import rsa

import ssm_parameter_name

log = logging.getLogger()

request_schema = {
    "type": "object",
    "required": ["Name"],
    "properties": {
        "Name": {
            "type": "string",
            "minLength": 1,
            # NOTE(review): JSON Schema patterns are not anchored, so this only
            # requires one matching character somewhere in the name.
            "pattern": "[a-zA-Z0-9_/]+",
            "description": "the name of the private key in the parameters store",
        },
        "KeySize": {
            "type": "integer",
            "default": 2048,
            "description": "number of bits in the key",
        },
        "KeyFormat": {
            "type": "string",
            "enum": ["PKCS8", "TraditionalOpenSSL"],
            "default": "PKCS8",
            "description": "encoding type of the private key",
        },
        "Description": {
            "type": "string",
            "default": "",
            "description": "the description of the key in the parameter store",
        },
        "KeyAlias": {
            "type": "string",
            "default": "alias/aws/ssm",
            "description": "KMS key to use to encrypt the key",
        },
        "RefreshOnUpdate": {
            "type": "boolean",
            "default": False,
            "description": "generate a new secret on update",
        },
        "Version": {"type": "string", "description": "opaque string to force update"},
    },
}


class RSAKeyProvider(ResourceProvider):
    """Create, update and delete an RSA key kept in the SSM Parameter Store.

    Attributes set on success:
      * ``Arn`` - ARN of the SSM parameter holding the private key.
      * ``PublicKey`` - public key in OpenSSH format.
      * ``PublicKeyPEM`` - public key as a SubjectPublicKeyInfo PEM.
      * ``Hash`` - MD5 hex digest of the OpenSSH public key string.
      * ``Version`` - SSM parameter version returned by ``put_parameter``.
    """

    def __init__(self):
        super(RSAKeyProvider, self).__init__()
        self.request_schema = request_schema
        self.ssm = boto3.client("ssm")
        self.iam = boto3.client("iam")
        self.region = boto3.session.Session().region_name
        # The account id is needed to build the parameter ARN (see ``arn``).
        self.account_id = boto3.client("sts").get_caller_identity()["Account"]

    def convert_property_types(self):
        """Coerce incoming property values toward their schema types.

        CloudFormation delivers scalar properties as strings; the conversion
        itself is delegated to the base-class heuristic.
        """
        self.heuristic_convert_property_types(self.properties)

    @property
    def allow_overwrite(self):
        """True when the current physical id refers to the parameter in ``Name``."""
        return ssm_parameter_name.equals(self.physical_resource_id, self.arn)

    @property
    def arn(self):
        """ARN of the SSM parameter named by the ``Name`` property."""
        return ssm_parameter_name.to_arn(self.region, self.account_id, self.get("Name"))

    def name_from_physical_resource_id(self):
        """Return the SSM parameter name encoded in the physical resource id."""
        return ssm_parameter_name.from_arn(self.physical_resource_id)

    @property
    def key_format(self):
        """Private key serialization format selected by ``KeyFormat`` (default PKCS8)."""
        if self.get("KeyFormat", "") == "TraditionalOpenSSL":
            return crypto_serialization.PrivateFormat.TraditionalOpenSSL
        return crypto_serialization.PrivateFormat.PKCS8

    def _key_pair_strings(self, key):
        """Return ``(private PEM, OpenSSH public key)`` as ASCII strings for ``key``."""
        private_key = key.private_bytes(
            crypto_serialization.Encoding.PEM,
            self.key_format,
            crypto_serialization.NoEncryption(),
        )
        public_key = key.public_key().public_bytes(
            crypto_serialization.Encoding.OpenSSH,
            crypto_serialization.PublicFormat.OpenSSH,
        )
        return private_key.decode("ascii"), public_key.decode("ascii")

    def get_key(self):
        """Load the stored private key and re-serialize it in the requested format.

        Returns:
            ``(private_key_pem, openssh_public_key)`` strings.

        Raises:
            ClientError: if the SSM parameter cannot be read.
        """
        response = self.ssm.get_parameter(
            Name=self.name_from_physical_resource_id(), WithDecryption=True
        )
        key = crypto_serialization.load_pem_private_key(
            response["Parameter"]["Value"].encode("ascii"),
            password=None,
            backend=crypto_default_backend(),
        )
        return self._key_pair_strings(key)

    def create_key(self):
        """Generate a new RSA key of ``KeySize`` bits with public exponent 65537.

        Returns:
            ``(private_key_pem, openssh_public_key)`` strings.
        """
        key = rsa.generate_private_key(
            backend=crypto_default_backend(),
            public_exponent=65537,
            key_size=self.get("KeySize"),
        )
        return self._key_pair_strings(key)

    def public_key_to_pem(self, private_key):
        """Derive the SubjectPublicKeyInfo PEM public key from a PEM private key string."""
        key = crypto_serialization.load_pem_private_key(
            private_key.encode("ascii"), password=None, backend=crypto_default_backend()
        )
        public_key = key.public_key().public_bytes(
            crypto_serialization.Encoding.PEM,
            crypto_serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        return public_key.decode("ascii")

    def create_or_update_secret(self, overwrite=False, new_secret=True):
        """Store a private key in SSM and publish the resource attributes.

        Args:
            overwrite: passed to ``put_parameter`` as ``Overwrite``.
            new_secret: generate a fresh key when true; otherwise reuse the key
                stored under the current physical resource id.

        Any ``ClientError`` marks the request as failed; the physical resource
        id is then set to ``"could-not-create"``.
        """
        try:
            if new_secret:
                private_key, public_key = self.create_key()
            else:
                private_key, public_key = self.get_key()

            kwargs = {
                "Name": self.get("Name"),
                "KeyId": self.get("KeyAlias"),
                "Type": "SecureString",
                "Overwrite": overwrite,
                "Value": private_key,
            }
            description = self.get("Description")
            if description:
                kwargs["Description"] = description

            response = self.ssm.put_parameter(**kwargs)
            version = response.get("Version", 1)

            self.set_attribute("Arn", self.arn)
            self.set_attribute("PublicKey", public_key)
            self.set_attribute("PublicKeyPEM", self.public_key_to_pem(private_key))
            self.set_attribute(
                "Hash", hashlib.md5(public_key.encode("utf-8")).hexdigest()
            )
            self.set_attribute("Version", version)

            if not ssm_parameter_name.equals(self.physical_resource_id, self.arn):
                # prevent CFN deleting a resource with identical Arns in different formats.
                self.physical_resource_id = self.arn
        except ClientError as e:
            self.physical_resource_id = "could-not-create"
            self.fail(str(e))

    def create(self):
        """Handle a Create request: always generate a new key."""
        self.create_or_update_secret(overwrite=False, new_secret=True)

    def update(self):
        """Handle an Update request.

        Regenerates the key only when ``RefreshOnUpdate`` is set. Overwrites
        the parameter only when it is the same one the resource already owns.
        """
        self.create_or_update_secret(
            overwrite=self.allow_overwrite, new_secret=self.get("RefreshOnUpdate")
        )

    def delete(self):
        """Handle a Delete request by removing the SSM parameter.

        A physical id without a "/" (such as ``"could-not-create"``) never
        named a parameter and is ignored. A parameter that is already gone
        (``ParameterNotFound``) is reported as deleted.
        """
        # Parameter name = everything after the first "/" of the physical id.
        # NOTE(review): for hierarchical names this drops the leading "/";
        # confirm this matches ssm_parameter_name.from_arn.
        _, separator, name = self.physical_resource_id.partition("/")
        if not separator:
            self.success(
                "System Parameter with the name %s is ignored" % self.physical_resource_id
            )
            return

        try:
            self.ssm.delete_parameter(Name=name)
        except ClientError as e:
            if e.response["Error"]["Code"] != "ParameterNotFound":
                return self.fail(str(e))
        # Format the name string itself; previously the split() list was
        # interpolated, producing its repr in the message.
        self.success("System Parameter with the name %s is deleted" % name)


provider = RSAKeyProvider()


def handler(request, context):
    """AWS Lambda entry point: dispatch the custom resource request to the provider."""
    return provider.handle(request, context)