# -*- coding: utf-8 -*- import json import time import logging import threading from boto3.session import Session from botocore.client import Config logger = logging.getLogger(__name__) logging.getLogger("botocore").setLevel(logging.CRITICAL) session = Session() config = Config(connect_timeout=10, read_timeout=310) client = session.client("lambda", config=config) class LambdaLoadTest(object): """ An object to run and collect statistics and results from multiple parallel locust load tests running on AWS Lambda """ def __init__( self, lambda_function_name, threads, ramp_time, time_limit, lambda_payload, lambda_timeout=300000, ): self.lock = threading.Lock() self.start_time = time.time() self.logger = logging.getLogger() self.threads = threads self.ramp_time = ramp_time self.time_limit = ( time_limit # don't start new threads after {time_limit} seconds ) self.lambda_function_name = lambda_function_name self.lambda_payload = lambda_payload self.lambda_invocation_errors = 0 self.lambda_invocation_count = 0 self.lambda_invocation_error_threshold = 20 self.lambda_total_execution_time = 0 self.requests_fail = 0 self.request_fail_ratio_threshold = 0.5 self.requests_total = 0 self.locust_results = [] self.thread_data = {} self.print_stats_delay = 3 self.exit_threads = False self.lambda_timeout = lambda_timeout def update_thread_data(self, thread_id, key, value): """ Receives data from threads and stores in the thread_data dict """ with self.lock: if thread_id not in self.thread_data: self.thread_data[thread_id] = {} self.thread_data[thread_id][key] = value def get_thread_count(self): """ Returns number of load test threads running """ return len( [t for t in threading.enumerate() if t.getName() is not "MainThread"] ) def get_time_elapsed(self): """ Returns elapsed time in seconds since starting the load test """ return round(time.time() - self.start_time) def increase_lambda_invocation_error(self): """ Increases Lambda invocation error count """ with self.lock: self.lambda_invocation_errors += 1 def increase_lambda_invocation_count(self): """ Increases Lambda invocation count """ with self.lock: self.lambda_invocation_count += 1 def get_invocation_error_ratio(self): """ Returns ratio of Lambda invocations to invocation errors """ try: return self.lambda_invocation_errors / float(self.lambda_invocation_count) except ZeroDivisionError: return 0 def increase_requests_total(self, requests): """ Increases total request count """ with self.lock: self.requests_total += requests def increase_requests_fail(self, requests): """ Increases total request fail count """ with self.lock: self.requests_fail += requests def get_request_fail_ratio(self): """ Returns ratio of failed to total requests """ try: return self.requests_fail / float(self.requests_total) except ZeroDivisionError: return 0 def append_locust_results(self, results): """ Logs results from a locust execution. All results needs to be aggregated in order to show meaningful statistics of the whole load test """ with self.lock: self.locust_results.append(results) def get_summary_stats(self): """ Returns summary statistics in a dict """ return { "lambda_invocation_count": self.lambda_invocation_count, "total_lambda_execution_time": self.lambda_total_execution_time, "requests_total": self.requests_total, "request_fail_ratio": self.get_request_fail_ratio(), "invocation_error_ratio": self.get_invocation_error_ratio(), } def get_stats(self): """ Returns current statistics in a dict """ return { "thread_count": self.get_thread_count(), "rpm": self.calculate_rpm(), "time_elapsed": self.get_time_elapsed(), "requests_total": self.requests_total, "request_fail_ratio": self.get_request_fail_ratio(), "invocation_error_ratio": self.get_invocation_error_ratio(), } def get_locust_results(self): """ Returns a list of locust results """ return self.locust_results def increase_lambda_execution_time(self, time): """ Add Lambda execution time to the total """ with self.lock: self.lambda_total_execution_time += time def calculate_rpm(self): """ Returns current total request per minute across all threads """ return round( sum( self.thread_data[thread_id]["rpm"] for thread_id in self.thread_data if "rpm" in self.thread_data[thread_id] ) ) def check_error_threshold(self): """ Checks if the current Lambda and request fail ratios are within thresholds """ if self.lambda_invocation_errors > self.lambda_invocation_error_threshold: self.logger.error( f"Error limit reached. invocation error count/threshold: " f"{self.lambda_invocation_errors}/{self.lambda_invocation_error_threshold}" ) return True elif self.get_request_fail_ratio() > self.request_fail_ratio_threshold: self.logger.error( f"Error limit reached. requests failed ratio/threshold: " f"{self.get_request_fail_ratio()}/{self.request_fail_ratio_threshold}" ) return True else: return False def thread_required(self): """ Returns True if a new thread should be started when ramping up over time """ result = False if self.get_thread_count() < self.threads: next_thread_interval = ( self.ramp_time / self.threads ) * self.get_thread_count() if self.get_time_elapsed() > next_thread_interval: result = True return result def stop_threads(self): """ Sets a boolean to stop threads """ with self.lock: self.exit_threads = True def start_new_thread(self): """ Creates a new load test thread """ t_name = "thread_{0}".format(threading.activeCount()) t = threading.Thread(name=t_name, target=self.thread) t.daemon = True t.start() def thread(self): """ This method is a single thread and performs the actual execution of the Lambda function and logs the statistics/results """ self.logger.info("thread started") thread_start_time = time.time() thread_id = threading.current_thread().getName() self.update_thread_data(thread_id, "start_time", thread_start_time) while True: thread_run_time = time.time() - thread_start_time if self.exit_threads: break if self.ramp_time in [0.0, 0]: sleep_time = 0 else: sleep_time = round(max(0, self.ramp_time - thread_run_time) / 30) function_start_time = time.time() try: self.logger.info("Invoking lambda...") response = client.invoke( FunctionName=self.lambda_function_name, Payload=json.dumps(self.lambda_payload), ) except Exception as e: self.logger.critical("Lambda invocation failed: {0}".format(repr(e))) time.sleep(2) continue function_end_time = time.time() self.increase_lambda_invocation_count() if "FunctionError" in response: logger.error( "error {0}: {1}".format( response["FunctionError"], response["Payload"].read() ) ) self.increase_lambda_invocation_error() time.sleep(2) continue payload = response["Payload"].read() payload_json_str = json.loads(payload.decode("utf-8")) if not payload_json_str: logger.error("No results in payload") self.increase_lambda_invocation_error() time.sleep(2) continue results = json.loads(payload_json_str) function_duration = function_end_time - function_start_time total_rpm = results["num_requests"] / (function_duration / 60) lambda_execution_time = self.lambda_timeout - results["remaining_time"] self.append_locust_results(results) self.increase_requests_fail(results["num_requests_fail"]) self.increase_requests_total(results["num_requests"]) self.update_thread_data(thread_id, "rpm", total_rpm) self.update_thread_data( thread_id, "lambda_execution_time", lambda_execution_time ) self.increase_lambda_execution_time(lambda_execution_time) logger.info( "Lambda invocation complete. Requests (errors): {0} ({1}), execution time: {2}ms, sleeping: {3}s".format( results["num_requests"], results["num_requests_fail"], lambda_execution_time, sleep_time, ) ) time.sleep(sleep_time) self.logger.info("thread finished") def run(self): """ Starts the load test, periodically prints statistics and starts new threads """ self.logger.info( "\nStarting load test..." f"\nFunction name: {self.lambda_function_name}" f"\nRamp time: {self.ramp_time}s" f"\nThreads: {self.threads}" f"\nLambda payload: {self.lambda_payload}" f"\nStart ramping down after: {self.time_limit}s" ) self.start_new_thread() while True: self.logger.info( "threads: {thread_count}, rpm: {rpm}, time elapsed: {time_elapsed}s, total requests from finished threads: {requests_total}, " "request fail ratio: {request_fail_ratio}, invocation error ratio: {invocation_error_ratio}".format( **self.get_stats() ) ) if self.thread_required(): self.start_new_thread() if self.check_error_threshold(): self.stop_threads() self.logger.info("Waiting for threads to exit...") while self.get_thread_count() > 0: time.sleep(1) else: break if self.time_limit and self.get_time_elapsed() > self.time_limit: self.logger.info("Time limit reached. Starting ramp down...") self.stop_threads() self.logger.info( "Waiting for all Lambdas to return. This may take up to {0}.".format( self.lambda_payload["run_time"] ) ) while self.get_thread_count() > 0: time.sleep(1) else: break time.sleep(self.print_stats_delay)