import logging
import os
import sys

import boto3


class Retry:
    """Moves messages from the dead-letter queue back to the compliance queue.

    The queue URLs are read from the ``DEADLETTERQUEUE`` and
    ``COMPLIANCEQUEUE`` environment variables. The boto3 clients are created
    lazily, on first access of the corresponding property.
    """

    def __init__(self, logging):
        """Store the logger and prepare the lazily created clients.

        Arguments:
            logging {object} -- Object exposing ``debug``/``info``/``error``
                (the ``logging`` module itself or a ``Logger`` instance)
        """
        # parameters
        self.logging = logging

        # boto3 clients, populated on first use by the properties below
        self._client_sts = None
        self._client_sqs = None

    @property
    def client_sts(self):
        """boto3 STS client, created on first access and reused afterwards."""
        if self._client_sts is None:
            self._client_sts = boto3.client("sts")
        return self._client_sts

    @property
    def region(self):
        """Region name reported by the STS client.

        When the STS client reports ``aws-global``, ``us-east-1`` is returned
        instead.
        """
        region_name = self.client_sts.meta.region_name
        return "us-east-1" if region_name == "aws-global" else region_name

    @property
    def client_sqs(self):
        """boto3 SQS client for ``self.region``, created on first access."""
        if self._client_sqs is None:
            self._client_sqs = boto3.client("sqs", self.region)
        return self._client_sqs

    def retry_security_events(self):
        """Retrieves messages from the DLQ and sends them back into the
        compliance SQS Queue for reprocessing.

        Messages are fetched in batches. A message is deleted from the DLQ
        only after it was sent to the compliance queue successfully; a message
        that could not be forwarded is left in the DLQ. The method returns
        once a receive call yields no messages (or fails).
        """
        queue_url = os.environ.get("DEADLETTERQUEUE")
        if not queue_url:
            self.logging.error("Environment variable 'DEADLETTERQUEUE' is not set.")
            return

        messages = self.receive_message(queue_url).get("Messages")

        while messages:
            for message in messages:
                receipt_handle = message.get("ReceiptHandle")
                body = message.get("Body")
                # Fall back to "1" when the message carries no try_count attribute
                try_count = (
                    message.get("MessageAttributes", {})
                    .get("try_count", {})
                    .get("StringValue", "1")
                )

                if self.send_to_compliance_queue(body, try_count):
                    self.delete_from_queue(queue_url, receipt_handle)

            # Fetch the next batch only after the whole current batch was
            # handled; fetching inside the for loop would receive messages
            # that are then never processed.
            messages = self.receive_message(queue_url).get("Messages")

    def delete_from_queue(self, queue_url, receipt_handle):
        """Delete a Message from an SQS Queue.

        Arguments:
            queue_url {string} -- URL of an SQS Queue
            receipt_handle {string} -- The receipt handle associated with the message to delete

        Returns:
            boolean -- True if the message was deleted, False if the call failed
        """
        try:
            self.client_sqs.delete_message(
                QueueUrl=queue_url, ReceiptHandle=receipt_handle
            )
        except Exception as error:
            self.logging.error(
                f"Could not delete Message '{receipt_handle}' from SQS Queue URL '{queue_url}'."
            )
            self.logging.error(error)
            return False

        self.logging.info(
            f"Deleted Message '{receipt_handle}' from SQS Queue URL '{queue_url}'."
        )
        return True

    def receive_message(self, queue_url):
        """Retrieves up to 10 messages from an SQS Queue.

        Arguments:
            queue_url {string} -- SQS Queue URL

        Returns:
            dictionary -- Response of the SQS receive call, or an empty
                dictionary if the call failed
        """
        try:
            return self.client_sqs.receive_message(
                QueueUrl=queue_url,
                MessageAttributeNames=["try_count"],
                MaxNumberOfMessages=10,
            )
        except Exception as error:
            self.logging.error(
                f"Could not retrieve Messages from SQS Queue URL '{queue_url}'."
            )
            self.logging.error(error)
            return {}

    def send_to_compliance_queue(self, config_payload, try_count):
        """Sends a message to the Config Compliance SQS Queue.

        Arguments:
            config_payload {string} -- AWS Config payload
            try_count {string} -- Number of attempted remediations for a given AWS Config Rule

        Returns:
            boolean -- True if sending message to SQS was successful
        """
        queue_url = os.environ.get("COMPLIANCEQUEUE")

        try:
            self.client_sqs.send_message(
                QueueUrl=queue_url,
                MessageBody=config_payload,
                MessageAttributes={
                    # The attribute value is passed as StringValue, so make
                    # sure it is a string even if a number was passed in
                    "try_count": {"StringValue": str(try_count), "DataType": "Number"}
                },
            )
        except Exception as error:
            self.logging.error(f"Could not send payload to SQS Queue '{queue_url}'.")
            self.logging.error(error)
            return False

        self.logging.debug(f"Message payload sent to SQS Queue '{queue_url}'.")
        return True

    @staticmethod
    def get_config_rule_compliance(record):
        """Retrieves the AWS Config rule compliance variable

        Arguments:
            record {dictionary} -- AWS Config payload; must contain
                ``detail.newEvaluationResult``

        Returns:
            string -- Value of ``detail.newEvaluationResult.complianceType``
                (presumably COMPLIANT | NON_COMPLIANT)
        """
        return record.get("detail").get("newEvaluationResult").get("complianceType")

    @staticmethod
    def get_config_rule_name(record):
        """Retrieves the AWS Config rule name variable.

        The name is returned exactly as found in the payload.

        Arguments:
            record {dictionary} -- AWS Config payload; must contain ``detail``

        Returns:
            string -- Value of ``detail.configRuleName``
        """
        return record.get("detail").get("configRuleName")


def lambda_handler(event, context):
    """Lambda entry point: configure logging and re-drive the dead-letter queue.

    Arguments:
        event {dictionary} -- Invocation event (not used)
        context {object} -- Invocation context (not used)
    """
    logger = logging.getLogger()

    # Remove handlers already attached to the root logger so that the
    # basicConfig() call below takes effect. Iterate over a copy:
    # removeHandler() mutates logger.handlers, and removing entries from the
    # list being iterated would skip every other handler.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)

    # change logging levels for boto and others
    logging.getLogger("boto3").setLevel(logging.ERROR)
    logging.getLogger("botocore").setLevel(logging.ERROR)
    logging.getLogger("urllib3").setLevel(logging.ERROR)

    # set logging format; the level comes from LOGLEVEL (default WARNING)
    logging.basicConfig(
        format="[%(levelname)s] %(message)s (%(filename)s, %(funcName)s(), line %(lineno)d)",
        level=os.environ.get("LOGLEVEL", "WARNING").upper(),
    )

    # instantiate class; the logging module itself is passed as the logger
    retry = Retry(logging)

    # run functions
    retry.retry_security_events()