diff --git a/awslambdaric/bootstrap.py b/awslambdaric/bootstrap.py index a3da58c..f87ee1b 100644 --- a/awslambdaric/bootstrap.py +++ b/awslambdaric/bootstrap.py @@ -462,8 +462,14 @@ def run(app_root, handler, lambda_runtime_api_addr): sys.stdout = Unbuffered(sys.stdout) sys.stderr = Unbuffered(sys.stderr) + use_thread_for_polling_next = ( + os.environ.get("AWS_EXECUTION_ENV") == "AWS_Lambda_python3.12" + ) + with create_log_sink() as log_sink: - lambda_runtime_client = LambdaRuntimeClient(lambda_runtime_api_addr) + lambda_runtime_client = LambdaRuntimeClient( + lambda_runtime_api_addr, use_thread_for_polling_next + ) try: _setup_logging(_AWS_LAMBDA_LOG_FORMAT, _AWS_LAMBDA_LOG_LEVEL, log_sink) diff --git a/awslambdaric/lambda_runtime_client.py b/awslambdaric/lambda_runtime_client.py index b05918b..91ebd4c 100644 --- a/awslambdaric/lambda_runtime_client.py +++ b/awslambdaric/lambda_runtime_client.py @@ -3,8 +3,8 @@ """ import sys -from concurrent.futures import ThreadPoolExecutor from awslambdaric import __version__ +from .lambda_runtime_exception import FaultException def _user_agent(): @@ -49,8 +49,9 @@ class LambdaRuntimeClient(object): and response. It allows for function authors to override the the default implementation, LambdaMarshaller which unmarshals and marshals JSON, to an instance of a class that implements the same interface.""" - def __init__(self, lambda_runtime_address): + def __init__(self, lambda_runtime_address, use_thread_for_polling_next=False): self.lambda_runtime_address = lambda_runtime_address + self.use_thread_for_polling_next = use_thread_for_polling_next def post_init_error(self, error_response_data): # These imports are heavy-weight. They implicitly trigger `import ssl, hashlib`. @@ -69,9 +70,23 @@ def post_init_error(self, error_response_data): raise LambdaRuntimeClientError(endpoint, response.code, response_body) def wait_next_invocation(self): - with ThreadPoolExecutor() as e: - fut = e.submit(runtime_client.next) - response_body, headers = fut.result() + # Calling runtime_client.next() from a separate thread unblocks the main thread, + # which can then process signals. + if self.use_thread_for_polling_next: + try: + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(runtime_client.next) + response_body, headers = future.result() + except Exception as e: + raise FaultException( + FaultException.LAMBDA_RUNTIME_CLIENT_ERROR, + "LAMBDA_RUNTIME Failed to get next invocation: {}".format(str(e)), + None, + ) + else: + response_body, headers = runtime_client.next() return InvocationRequest( invoke_id=headers.get("Lambda-Runtime-Aws-Request-Id"), x_amzn_trace_id=headers.get("Lambda-Runtime-Trace-Id"), diff --git a/awslambdaric/lambda_runtime_exception.py b/awslambdaric/lambda_runtime_exception.py index 416327e..e09af70 100644 --- a/awslambdaric/lambda_runtime_exception.py +++ b/awslambdaric/lambda_runtime_exception.py @@ -12,6 +12,7 @@ class FaultException(Exception): BUILT_IN_MODULE_CONFLICT = "Runtime.BuiltInModuleConflict" MALFORMED_HANDLER_NAME = "Runtime.MalformedHandlerName" LAMBDA_CONTEXT_UNMARSHAL_ERROR = "Runtime.LambdaContextUnmarshalError" + LAMBDA_RUNTIME_CLIENT_ERROR = "Runtime.LambdaRuntimeClientError" def __init__(self, exception_type, msg, trace=None): self.msg = msg diff --git a/tests/test_lambda_runtime_client.py b/tests/test_lambda_runtime_client.py index 47d95cf..b0eae4a 100644 --- a/tests/test_lambda_runtime_client.py +++ b/tests/test_lambda_runtime_client.py @@ -84,6 +84,21 @@ def test_wait_next_invocation(self, mock_runtime_client): self.assertEqual(event_request.content_type, "application/json") self.assertEqual(event_request.event_body, response_body) + # Using ThreadPoolExecutor to polling next() + runtime_client = LambdaRuntimeClient("localhost:1234", True) + + event_request = runtime_client.wait_next_invocation() + + self.assertIsNotNone(event_request) + self.assertEqual(event_request.invoke_id, "RID1234") + self.assertEqual(event_request.x_amzn_trace_id, "TID1234") + self.assertEqual(event_request.invoked_function_arn, "FARN1234") + self.assertEqual(event_request.deadline_time_in_ms, 12) + self.assertEqual(event_request.client_context, "client_context") + self.assertEqual(event_request.cognito_identity, "cognito_identity") + self.assertEqual(event_request.content_type, "application/json") + self.assertEqual(event_request.event_body, response_body) + @patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection) def test_post_init_error(self, MockHTTPConnection): mock_conn = MockHTTPConnection.return_value