From 6257398da6d0699f1646dc02aa211bfa5c879177 Mon Sep 17 00:00:00 2001 From: vicheey <181402101+vicheey@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:26:44 -0800 Subject: [PATCH] add support for Lambda Managed Instance (LMI) functions --- THIRD-PARTY-LICENSES.md | 617 +++++++++-- cmd/aws-lambda-rie/main.go | 8 + .../lambda-managed-instances/agents/agent.go | 50 + .../agents/agent_test.go | 225 ++++ .../lambda-managed-instances/appctx/appctx.go | 79 ++ .../appctx/appctxutil.go | 150 +++ .../appctx/appctxutil_test.go | 186 ++++ .../aws-lambda-rie/internal/app.go | 97 ++ .../aws-lambda-rie/internal/app_test.go | 143 +++ .../aws-lambda-rie/internal/init.go | 147 +++ .../aws-lambda-rie/internal/init_test.go | 306 ++++++ .../internal/invoke/responder.go | 85 ++ .../internal/invoke/responder_test.go | 216 ++++ .../internal/invoke/rie_invoke_request.go | 148 +++ .../invoke/rie_invoke_request_test.go | 80 ++ .../mock_init_request_message_factory.go | 57 ++ .../internal/mock_raptor_app.go | 79 ++ .../aws-lambda-rie/internal/run.go | 96 ++ .../internal/telemetry/events_api.go | 101 ++ .../internal/telemetry/events_api_test.go | 157 +++ .../internal/telemetry/internal/batch.go | 41 + .../internal/telemetry/internal/batch_test.go | 71 ++ .../internal/telemetry/internal/client.go | 126 +++ .../telemetry/internal/client_test.go | 263 +++++ .../telemetry/internal/mock_client.go | 43 + .../internal/mock_logs_dropped_event_api.go | 39 + .../internal/telemetry/internal/subscriber.go | 174 ++++ .../telemetry/internal/subscriber_test.go | 61 ++ .../internal/telemetry/internal/types.go | 49 + .../internal/telemetry/logs_egress.go | 63 ++ .../internal/telemetry/logs_egress_test.go | 69 ++ .../internal/telemetry/mock_relay.go | 34 + .../internal/telemetry/mock_sub.go | 52 + .../telemetry/mock_subscription_store.go | 43 + .../mock_telemetry_subscription_event_api.go | 39 + .../internal/telemetry/relay.go | 113 +++ .../internal/telemetry/relay_test.go | 45 + .../schema/telemetry-subscription-schema.json | 81 ++ .../internal/telemetry/subscription_api.go | 188 ++++ .../telemetry/subscription_api_test.go | 380 +++++++ .../aws-lambda-rie/internal/utils.go | 67 ++ .../aws-lambda-rie/internal/utils_test.go | 211 ++++ .../aws-lambda-rie/run/run.go | 32 + .../aws-lambda-rie/test/rie_test.go | 488 +++++++++ .../core/agent_state_names.go | 16 + .../core/agentsmap.go | 72 ++ .../core/agentsmap_test.go | 75 ++ .../core/agentutil.go | 28 + .../core/bandwidthlimiter/bandwidthlimiter.go | 62 ++ .../bandwidthlimiter/bandwidthlimiter_test.go | 106 ++ .../core/bandwidthlimiter/throttler.go | 156 +++ .../core/bandwidthlimiter/throttler_test.go | 215 ++++ .../core/bandwidthlimiter/util.go | 32 + .../core/bandwidthlimiter/util_test.go | 45 + .../core/directinvoke/customerheaders.go | 41 + .../core/directinvoke/customerheaders_test.go | 26 + .../core/directinvoke/util.go | 79 ++ internal/lambda-managed-instances/core/doc.go | 4 + .../core/externalagent.go | 192 ++++ .../core/externalagent_states.go | 189 ++++ .../core/externalagent_states_test.go | 189 ++++ .../lambda-managed-instances/core/flow.go | 88 ++ .../lambda-managed-instances/core/gates.go | 144 +++ .../core/gates_test.go | 136 +++ .../core/internalagent.go | 133 +++ .../core/internalagent_states.go | 137 +++ .../core/internalagent_states_test.go | 119 +++ .../core/registrations.go | 350 +++++++ .../core/registrations_test.go | 200 ++++ .../core/runtime_state_names.go | 11 + .../core/statejson/description.go | 65 ++ .../lambda-managed-instances/core/states.go | 215 ++++ .../core/states_test.go | 148 +++ .../interop/cancellable_request.go | 27 + .../interop/error_utils.go | 23 + .../interop/error_utils_test.go | 48 + .../interop/events_api.go | 205 ++++ .../interop/events_api_test.go | 884 ++++++++++++++++ .../interop/mock_duration_metric_timer.go | 26 + .../interop/mock_events_api.go | 149 +++ .../interop/mock_health_check_response.go | 26 + .../interop/mock_init_metrics.go | 93 ++ .../interop/mock_init_response.go | 26 + .../interop/mock_init_static_data_provider.go | 248 +++++ .../interop/mock_internal_state_getter.go | 42 + .../interop/mock_invoke_metrics.go | 149 +++ .../interop/mock_invoke_request.go | 304 ++++++ .../interop/mock_invoke_response.go | 26 + .../interop/mock_invoke_response_sender.go | 80 ++ .../interop/mock_message.go | 22 + .../interop/mock_rapid_context.go | 131 +++ .../interop/mock_server.go | 51 + .../interop/mock_shutdown_metrics.go | 57 ++ .../interop/mock_shutdown_response.go | 26 + .../lambda-managed-instances/interop/model.go | 291 ++++++ .../interop/model_test.go | 85 ++ .../interop/response_status.go | 13 + .../interop/sandbox_model.go | 344 +++++++ .../interop/service_log_values.go | 37 + .../lambda-managed-instances/invoke/consts.go | 21 + .../invoke/invoke_router.go | 201 ++++ .../invoke/invoke_router_test.go | 345 +++++++ .../invoke/metrics.go | 382 +++++++ .../invoke/metrics_test.go | 654 ++++++++++++ .../invoke/mock_counter.go | 26 + .../invoke/mock_error_for_invoker.go | 93 ++ .../invoke/mock_invoke_response_sender.go | 79 ++ .../invoke/mock_responder_factory_func.go | 46 + .../invoke/mock_running_invoke.go | 109 ++ .../invoke/mock_runtime_error_request.go | 184 ++++ .../invoke/mock_runtime_response_request.go | 135 +++ .../invoke/mock_timeout_cache.go | 43 + .../invoke/running_invoke.go | 329 ++++++ .../invoke/running_invoke_test.go | 478 +++++++++ .../invoke/runtime_error_request.go | 110 ++ .../invoke/runtime_error_request_test.go | 154 +++ .../invoke/runtime_response_request.go | 122 +++ .../invoke/runtime_response_request_test.go | 81 ++ .../invoke/runtime_response_sender.go | 89 ++ .../invoke/runtime_response_sender_test.go | 136 +++ .../invoke/timeout/timeout_cache.go | 65 ++ .../invoke/timeout/timeout_cache_test.go | 79 ++ .../lambda-managed-instances/invoke/utils.go | 30 + .../logging/contextual_logger.go | 82 ++ .../logging/contextual_logger_test.go | 91 ++ .../lambda-managed-instances/model/init.go | 141 +++ .../model/init_test.go | 171 ++++ .../lambda-managed-instances/model/model.go | 18 + internal/lambda-managed-instances/ptr/ptr.go | 8 + .../rapi/extensions_fuzz_test.go | 311 ++++++ .../rapi/handler/agentexiterror.go | 72 ++ .../rapi/handler/agentiniterror.go | 69 ++ .../rapi/handler/agentiniterror_test.go | 120 +++ .../rapi/handler/agentnext.go | 66 ++ .../rapi/handler/agentnext_test.go | 289 ++++++ .../rapi/handler/agentregister.go | 224 ++++ .../rapi/handler/agentregister_test.go | 314 ++++++ .../rapi/handler/constants.go | 22 + .../rapi/handler/initerror.go | 54 + .../rapi/handler/initerror_test.go | 48 + .../rapi/handler/invocationerror.go | 68 ++ .../rapi/handler/invocationnext.go | 58 ++ .../rapi/handler/invocationresponse.go | 68 ++ .../rapi/handler/ping.go | 23 + .../rapi/handler/runtimelogs.go | 138 +++ .../rapi/handler/runtimelogs_stub.go | 49 + .../rapi/handler/runtimelogs_stub_test.go | 25 + .../rapi/handler/runtimelogs_test.go | 356 +++++++ .../rapi/middleware/middleware.go | 66 ++ .../rapi/middleware/middleware_test.go | 94 ++ .../rapi/model/agentevent.go | 28 + .../rapi/model/agentregisterresponse.go | 11 + .../rapi/model/cognitoidentity.go | 9 + .../rapi/model/constants.go | 16 + .../rapi/model/error_cause.go | 98 ++ .../rapi/model/error_cause_compactor.go | 69 ++ .../rapi/model/error_cause_compactor_test.go | 83 ++ .../rapi/model/error_cause_test.go | 142 +++ .../rapi/model/errorresponse.go | 10 + .../rapi/model/statusresponse.go | 8 + .../rapi/model/tracing.go | 35 + .../rapi/rapi_fuzz_test.go | 98 ++ .../rapi/rendering/doc.go | 4 + .../rapi/rendering/render_error.go | 93 ++ .../rapi/rendering/render_json.go | 28 + .../rapi/rendering/rendering.go | 259 +++++ .../lambda-managed-instances/rapi/router.go | 79 ++ .../lambda-managed-instances/rapi/server.go | 122 +++ .../rapi/telemetry_logs_fuzz_test.go | 168 +++ .../rapid/handlers.go | 379 +++++++ .../rapid/handlers_test.go | 958 ++++++++++++++++++ .../rapid/init_metrics.go | 158 +++ .../rapid/init_metrics_test.go | 213 ++++ .../rapid/model/client_error.go | 36 + .../rapid/model/client_error_test.go | 33 + .../rapid/model/customer_error.go | 94 ++ .../rapid/model/customer_error_test.go | 36 + .../rapid/model/error_types.go | 144 +++ .../rapid/model/error_types_test.go | 62 ++ .../rapid/model/exec.go | 55 + .../rapid/model/function_metadata.go | 24 + .../rapid/model/interfaces.go | 10 + .../rapid/model/platform_error.go | 51 + .../rapid/model/platform_error_test.go | 31 + .../lambda-managed-instances/rapid/sandbox.go | 137 +++ .../rapid/shutdown.go | 391 +++++++ .../rapid/shutdown_metrics.go | 149 +++ .../rapid/shutdown_metrics_test.go | 222 ++++ .../rapid/shutdown_test.go | 393 +++++++ .../rapidcore/env/environment.go | 179 ++++ .../rapidcore/env/environment_test.go | 238 +++++ .../rapidcore/env/util.go | 43 + .../rapidcore/env/util_test.go | 36 + .../rapidcore/errors.go | 27 + .../rapidcore/runtime_release.go | 97 ++ .../rapidcore/runtime_release_test.go | 151 +++ .../lambda-managed-instances/raptor/app.go | 209 ++++ .../raptor/app_test.go | 263 +++++ .../raptor/internal/raptor_state.go | 93 ++ .../raptor/internal/raptor_state_test.go | 239 +++++ .../raptor/mock_address.go | 64 ++ .../raptor/mock_raptor_logger.go | 54 + .../raptor/mock_shutdown_handler.go | 29 + .../raptor/raptor_utils.go | 74 ++ .../lambda-managed-instances/raptor/server.go | 136 +++ .../raptor/server_test.go | 86 ++ .../servicelogs/logger.go | 54 + .../servicelogs/mock_logger.go | 47 + .../supervisor/local/process.go | 324 ++++++ .../supervisor/local/process_test.go | 228 +++++ .../supervisor/model/mock_lock_hard_error.go | 90 ++ .../model/mock_process_supervisor.go | 106 ++ .../model/mock_process_supervisor_client.go | 106 ++ .../supervisor/model/process.go | 230 +++++ .../supervisor/model/process_test.go | 186 ++++ .../telemetry/constants.go | 17 + .../telemetry/events.go | 142 +++ .../telemetry/events_api.go | 90 ++ .../telemetry/events_api_test.go | 32 + .../telemetry/logs_egress_api.go | 28 + .../telemetry/logs_subscription_api.go | 48 + .../telemetry/xray/tracer.go | 109 ++ .../telemetry/xray/tracer_test.go | 126 +++ .../testdata/agents/bash_true.sh | 1 + .../testdata/async_assertion_utils.go | 32 + .../testdata/bash_function.sh | 7 + .../testdata/bash_runtime.sh | 22 + .../testdata/bash_script_with_child_proc.sh | 14 + .../testdata/env_setup_helpers.go | 108 ++ .../testdata/flowtesting.go | 102 ++ .../testdata/mockcommand.go | 35 + .../testdata/mockthread/mockthread.go | 14 + .../testdata/mocktracer/mocktracer.go | 93 ++ .../testdata/parametrization.go | 12 + .../testutils/functional/chunked.go | 45 + .../testutils/functional/doc.go | 4 + .../testutils/functional/extension_actions.go | 397 ++++++++ .../testutils/functional/extensions_client.go | 254 +++++ .../testutils/functional/fluxpump_server.go | 153 +++ .../testutils/functional/httputils.go | 56 + .../functional/in_memory_events_api.go | 350 +++++++ .../functional/process_supervisor.go | 285 ++++++ .../testutils/functional/runtime_actions.go | 238 +++++ .../testutils/functional/runtime_client.go | 149 +++ .../testutils/functional/supv.go | 88 ++ .../testutils/mocks/http_handler_mock.go | 25 + .../testutils/mocks/http_mock.go | 56 + .../testutils/socket_utils.go | 49 + .../testutils/socket_utils_test.go | 28 + .../testutils/test_data.go | 133 +++ .../utils/buffer_pool.go | 22 + .../utils/buffer_pool_test.go | 44 + .../utils/file_utils.go | 49 + .../utils/file_utils_test.go | 39 + .../utils/invariant/invariant.go | 46 + .../utils/invariant/invariant_test.go | 76 ++ .../invariant/mock_violation_executor.go | 26 + .../utils/invariant/model.go | 16 + .../utils/invariant/panic.go | 16 + .../utils/invariant/panic_test.go | 20 + internal/lambda-managed-instances/utils/io.go | 58 ++ .../lambda-managed-instances/utils/io_test.go | 95 ++ .../utils/mock_file_util.go | 101 ++ .../lambda-managed-instances/utils/mocks.go | 63 ++ internal/lambda/appctx/appctxutil_test.go | 2 +- .../lambda/core/directinvoke/directinvoke.go | 2 +- .../core/directinvoke/directinvoke_test.go | 6 +- .../lambda/core/externalagent_states_test.go | 2 +- .../lambda/core/internalagent_states_test.go | 2 +- internal/lambda/core/states_test.go | 4 +- internal/lambda/interop/events_api_test.go | 2 +- internal/lambda/rapi/extensions_fuzz_test.go | 4 +- .../lambda/rapi/handler/agentexiterror.go | 2 +- .../lambda/rapi/handler/agentiniterror.go | 2 +- .../rapi/handler/agentiniterror_test.go | 4 +- internal/lambda/rapi/handler/agentnext.go | 4 +- .../lambda/rapi/handler/agentnext_test.go | 4 +- internal/lambda/rapi/handler/agentregister.go | 2 +- .../lambda/rapi/handler/agentregister_test.go | 2 +- .../lambda/rapi/handler/credentials_test.go | 2 +- .../lambda/rapi/handler/initerror_test.go | 2 +- .../rapi/handler/invocationresponse_test.go | 2 +- internal/lambda/rapi/handler/restoreerror.go | 2 +- .../lambda/rapi/handler/restoreerror_test.go | 2 +- internal/lambda/rapi/handler/restorenext.go | 2 +- .../lambda/rapi/handler/restorenext_test.go | 2 +- .../lambda/rapi/handler/runtimelogs_stub.go | 2 +- internal/lambda/rapi/middleware/middleware.go | 4 +- .../lambda/rapi/middleware/middleware_test.go | 6 +- internal/lambda/rapi/rapi_fuzz_test.go | 2 +- .../lambda/rapi/rendering/render_error.go | 2 +- internal/lambda/rapi/server.go | 2 +- .../lambda/rapi/telemetry_logs_fuzz_test.go | 4 +- internal/lambda/rapidcore/server_test.go | 2 +- .../standalone/directInvokeHandler.go | 2 +- .../rapidcore/standalone/executeHandler.go | 2 +- .../rapidcore/standalone/reserveHandler.go | 2 +- .../rapidcore/standalone/restoreHandler.go | 2 +- internal/lambda/rapidcore/standalone/util.go | 2 +- internal/lambda/rie/http.go | 2 +- internal/lambda/rie/run.go | 2 +- .../lambda/supervisor/local_supervisor.go | 2 +- .../supervisor/local_supervisor_test.go | 2 +- internal/lambda/telemetry/events_api_test.go | 2 +- licenses.tpl | 11 + 305 files changed, 31823 insertions(+), 149 deletions(-) create mode 100644 internal/lambda-managed-instances/agents/agent.go create mode 100644 internal/lambda-managed-instances/agents/agent_test.go create mode 100644 internal/lambda-managed-instances/appctx/appctx.go create mode 100644 internal/lambda-managed-instances/appctx/appctxutil.go create mode 100644 internal/lambda-managed-instances/appctx/appctxutil_test.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/app.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/app_test.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/init.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/init_test.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/responder.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/responder_test.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request_test.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/mock_init_request_message_factory.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/mock_raptor_app.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/run.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/events_api.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/events_api_test.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/batch.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/batch_test.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/client.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/client_test.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/mock_client.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/mock_logs_dropped_event_api.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/subscriber.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/subscriber_test.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/types.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/logs_egress.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/logs_egress_test.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_relay.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_sub.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_subscription_store.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_telemetry_subscription_event_api.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/relay.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/relay_test.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/schema/telemetry-subscription-schema.json create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/subscription_api.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/subscription_api_test.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/utils.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/utils_test.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/run/run.go create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/test/rie_test.go create mode 100644 internal/lambda-managed-instances/core/agent_state_names.go create mode 100644 internal/lambda-managed-instances/core/agentsmap.go create mode 100644 internal/lambda-managed-instances/core/agentsmap_test.go create mode 100644 internal/lambda-managed-instances/core/agentutil.go create mode 100644 internal/lambda-managed-instances/core/bandwidthlimiter/bandwidthlimiter.go create mode 100644 internal/lambda-managed-instances/core/bandwidthlimiter/bandwidthlimiter_test.go create mode 100644 internal/lambda-managed-instances/core/bandwidthlimiter/throttler.go create mode 100644 internal/lambda-managed-instances/core/bandwidthlimiter/throttler_test.go create mode 100644 internal/lambda-managed-instances/core/bandwidthlimiter/util.go create mode 100644 internal/lambda-managed-instances/core/bandwidthlimiter/util_test.go create mode 100644 internal/lambda-managed-instances/core/directinvoke/customerheaders.go create mode 100644 internal/lambda-managed-instances/core/directinvoke/customerheaders_test.go create mode 100644 internal/lambda-managed-instances/core/directinvoke/util.go create mode 100644 internal/lambda-managed-instances/core/doc.go create mode 100644 internal/lambda-managed-instances/core/externalagent.go create mode 100644 internal/lambda-managed-instances/core/externalagent_states.go create mode 100644 internal/lambda-managed-instances/core/externalagent_states_test.go create mode 100644 internal/lambda-managed-instances/core/flow.go create mode 100644 internal/lambda-managed-instances/core/gates.go create mode 100644 internal/lambda-managed-instances/core/gates_test.go create mode 100644 internal/lambda-managed-instances/core/internalagent.go create mode 100644 internal/lambda-managed-instances/core/internalagent_states.go create mode 100644 internal/lambda-managed-instances/core/internalagent_states_test.go create mode 100644 internal/lambda-managed-instances/core/registrations.go create mode 100644 internal/lambda-managed-instances/core/registrations_test.go create mode 100644 internal/lambda-managed-instances/core/runtime_state_names.go create mode 100644 internal/lambda-managed-instances/core/statejson/description.go create mode 100644 internal/lambda-managed-instances/core/states.go create mode 100644 internal/lambda-managed-instances/core/states_test.go create mode 100644 internal/lambda-managed-instances/interop/cancellable_request.go create mode 100644 internal/lambda-managed-instances/interop/error_utils.go create mode 100644 internal/lambda-managed-instances/interop/error_utils_test.go create mode 100644 internal/lambda-managed-instances/interop/events_api.go create mode 100644 internal/lambda-managed-instances/interop/events_api_test.go create mode 100644 internal/lambda-managed-instances/interop/mock_duration_metric_timer.go create mode 100644 internal/lambda-managed-instances/interop/mock_events_api.go create mode 100644 internal/lambda-managed-instances/interop/mock_health_check_response.go create mode 100644 internal/lambda-managed-instances/interop/mock_init_metrics.go create mode 100644 internal/lambda-managed-instances/interop/mock_init_response.go create mode 100644 internal/lambda-managed-instances/interop/mock_init_static_data_provider.go create mode 100644 internal/lambda-managed-instances/interop/mock_internal_state_getter.go create mode 100644 internal/lambda-managed-instances/interop/mock_invoke_metrics.go create mode 100644 internal/lambda-managed-instances/interop/mock_invoke_request.go create mode 100644 internal/lambda-managed-instances/interop/mock_invoke_response.go create mode 100644 internal/lambda-managed-instances/interop/mock_invoke_response_sender.go create mode 100644 internal/lambda-managed-instances/interop/mock_message.go create mode 100644 internal/lambda-managed-instances/interop/mock_rapid_context.go create mode 100644 internal/lambda-managed-instances/interop/mock_server.go create mode 100644 internal/lambda-managed-instances/interop/mock_shutdown_metrics.go create mode 100644 internal/lambda-managed-instances/interop/mock_shutdown_response.go create mode 100644 internal/lambda-managed-instances/interop/model.go create mode 100644 internal/lambda-managed-instances/interop/model_test.go create mode 100644 internal/lambda-managed-instances/interop/response_status.go create mode 100644 internal/lambda-managed-instances/interop/sandbox_model.go create mode 100644 internal/lambda-managed-instances/interop/service_log_values.go create mode 100644 internal/lambda-managed-instances/invoke/consts.go create mode 100644 internal/lambda-managed-instances/invoke/invoke_router.go create mode 100644 internal/lambda-managed-instances/invoke/invoke_router_test.go create mode 100644 internal/lambda-managed-instances/invoke/metrics.go create mode 100644 internal/lambda-managed-instances/invoke/metrics_test.go create mode 100644 internal/lambda-managed-instances/invoke/mock_counter.go create mode 100644 internal/lambda-managed-instances/invoke/mock_error_for_invoker.go create mode 100644 internal/lambda-managed-instances/invoke/mock_invoke_response_sender.go create mode 100644 internal/lambda-managed-instances/invoke/mock_responder_factory_func.go create mode 100644 internal/lambda-managed-instances/invoke/mock_running_invoke.go create mode 100644 internal/lambda-managed-instances/invoke/mock_runtime_error_request.go create mode 100644 internal/lambda-managed-instances/invoke/mock_runtime_response_request.go create mode 100644 internal/lambda-managed-instances/invoke/mock_timeout_cache.go create mode 100644 internal/lambda-managed-instances/invoke/running_invoke.go create mode 100644 internal/lambda-managed-instances/invoke/running_invoke_test.go create mode 100644 internal/lambda-managed-instances/invoke/runtime_error_request.go create mode 100644 internal/lambda-managed-instances/invoke/runtime_error_request_test.go create mode 100644 internal/lambda-managed-instances/invoke/runtime_response_request.go create mode 100644 internal/lambda-managed-instances/invoke/runtime_response_request_test.go create mode 100644 internal/lambda-managed-instances/invoke/runtime_response_sender.go create mode 100644 internal/lambda-managed-instances/invoke/runtime_response_sender_test.go create mode 100644 internal/lambda-managed-instances/invoke/timeout/timeout_cache.go create mode 100644 internal/lambda-managed-instances/invoke/timeout/timeout_cache_test.go create mode 100644 internal/lambda-managed-instances/invoke/utils.go create mode 100644 internal/lambda-managed-instances/logging/contextual_logger.go create mode 100644 internal/lambda-managed-instances/logging/contextual_logger_test.go create mode 100644 internal/lambda-managed-instances/model/init.go create mode 100644 internal/lambda-managed-instances/model/init_test.go create mode 100644 internal/lambda-managed-instances/model/model.go create mode 100644 internal/lambda-managed-instances/ptr/ptr.go create mode 100644 internal/lambda-managed-instances/rapi/extensions_fuzz_test.go create mode 100644 internal/lambda-managed-instances/rapi/handler/agentexiterror.go create mode 100644 internal/lambda-managed-instances/rapi/handler/agentiniterror.go create mode 100644 internal/lambda-managed-instances/rapi/handler/agentiniterror_test.go create mode 100644 internal/lambda-managed-instances/rapi/handler/agentnext.go create mode 100644 internal/lambda-managed-instances/rapi/handler/agentnext_test.go create mode 100644 internal/lambda-managed-instances/rapi/handler/agentregister.go create mode 100644 internal/lambda-managed-instances/rapi/handler/agentregister_test.go create mode 100644 internal/lambda-managed-instances/rapi/handler/constants.go create mode 100644 internal/lambda-managed-instances/rapi/handler/initerror.go create mode 100644 internal/lambda-managed-instances/rapi/handler/initerror_test.go create mode 100644 internal/lambda-managed-instances/rapi/handler/invocationerror.go create mode 100644 internal/lambda-managed-instances/rapi/handler/invocationnext.go create mode 100644 internal/lambda-managed-instances/rapi/handler/invocationresponse.go create mode 100644 internal/lambda-managed-instances/rapi/handler/ping.go create mode 100644 internal/lambda-managed-instances/rapi/handler/runtimelogs.go create mode 100644 internal/lambda-managed-instances/rapi/handler/runtimelogs_stub.go create mode 100644 internal/lambda-managed-instances/rapi/handler/runtimelogs_stub_test.go create mode 100644 internal/lambda-managed-instances/rapi/handler/runtimelogs_test.go create mode 100644 internal/lambda-managed-instances/rapi/middleware/middleware.go create mode 100644 internal/lambda-managed-instances/rapi/middleware/middleware_test.go create mode 100644 internal/lambda-managed-instances/rapi/model/agentevent.go create mode 100644 internal/lambda-managed-instances/rapi/model/agentregisterresponse.go create mode 100644 internal/lambda-managed-instances/rapi/model/cognitoidentity.go create mode 100644 internal/lambda-managed-instances/rapi/model/constants.go create mode 100644 internal/lambda-managed-instances/rapi/model/error_cause.go create mode 100644 internal/lambda-managed-instances/rapi/model/error_cause_compactor.go create mode 100644 internal/lambda-managed-instances/rapi/model/error_cause_compactor_test.go create mode 100644 internal/lambda-managed-instances/rapi/model/error_cause_test.go create mode 100644 internal/lambda-managed-instances/rapi/model/errorresponse.go create mode 100644 internal/lambda-managed-instances/rapi/model/statusresponse.go create mode 100644 internal/lambda-managed-instances/rapi/model/tracing.go create mode 100644 internal/lambda-managed-instances/rapi/rapi_fuzz_test.go create mode 100644 internal/lambda-managed-instances/rapi/rendering/doc.go create mode 100644 internal/lambda-managed-instances/rapi/rendering/render_error.go create mode 100644 internal/lambda-managed-instances/rapi/rendering/render_json.go create mode 100644 internal/lambda-managed-instances/rapi/rendering/rendering.go create mode 100644 internal/lambda-managed-instances/rapi/router.go create mode 100644 internal/lambda-managed-instances/rapi/server.go create mode 100644 internal/lambda-managed-instances/rapi/telemetry_logs_fuzz_test.go create mode 100644 internal/lambda-managed-instances/rapid/handlers.go create mode 100644 internal/lambda-managed-instances/rapid/handlers_test.go create mode 100644 internal/lambda-managed-instances/rapid/init_metrics.go create mode 100644 internal/lambda-managed-instances/rapid/init_metrics_test.go create mode 100644 internal/lambda-managed-instances/rapid/model/client_error.go create mode 100644 internal/lambda-managed-instances/rapid/model/client_error_test.go create mode 100644 internal/lambda-managed-instances/rapid/model/customer_error.go create mode 100644 internal/lambda-managed-instances/rapid/model/customer_error_test.go create mode 100644 internal/lambda-managed-instances/rapid/model/error_types.go create mode 100644 internal/lambda-managed-instances/rapid/model/error_types_test.go create mode 100644 internal/lambda-managed-instances/rapid/model/exec.go create mode 100644 internal/lambda-managed-instances/rapid/model/function_metadata.go create mode 100644 internal/lambda-managed-instances/rapid/model/interfaces.go create mode 100644 internal/lambda-managed-instances/rapid/model/platform_error.go create mode 100644 internal/lambda-managed-instances/rapid/model/platform_error_test.go create mode 100644 internal/lambda-managed-instances/rapid/sandbox.go create mode 100644 internal/lambda-managed-instances/rapid/shutdown.go create mode 100644 internal/lambda-managed-instances/rapid/shutdown_metrics.go create mode 100644 internal/lambda-managed-instances/rapid/shutdown_metrics_test.go create mode 100644 internal/lambda-managed-instances/rapid/shutdown_test.go create mode 100644 internal/lambda-managed-instances/rapidcore/env/environment.go create mode 100644 internal/lambda-managed-instances/rapidcore/env/environment_test.go create mode 100644 internal/lambda-managed-instances/rapidcore/env/util.go create mode 100644 internal/lambda-managed-instances/rapidcore/env/util_test.go create mode 100644 internal/lambda-managed-instances/rapidcore/errors.go create mode 100644 internal/lambda-managed-instances/rapidcore/runtime_release.go create mode 100644 internal/lambda-managed-instances/rapidcore/runtime_release_test.go create mode 100644 internal/lambda-managed-instances/raptor/app.go create mode 100644 internal/lambda-managed-instances/raptor/app_test.go create mode 100644 internal/lambda-managed-instances/raptor/internal/raptor_state.go create mode 100644 internal/lambda-managed-instances/raptor/internal/raptor_state_test.go create mode 100644 internal/lambda-managed-instances/raptor/mock_address.go create mode 100644 internal/lambda-managed-instances/raptor/mock_raptor_logger.go create mode 100644 internal/lambda-managed-instances/raptor/mock_shutdown_handler.go create mode 100644 internal/lambda-managed-instances/raptor/raptor_utils.go create mode 100644 internal/lambda-managed-instances/raptor/server.go create mode 100644 internal/lambda-managed-instances/raptor/server_test.go create mode 100644 internal/lambda-managed-instances/servicelogs/logger.go create mode 100644 internal/lambda-managed-instances/servicelogs/mock_logger.go create mode 100644 internal/lambda-managed-instances/supervisor/local/process.go create mode 100644 internal/lambda-managed-instances/supervisor/local/process_test.go create mode 100644 internal/lambda-managed-instances/supervisor/model/mock_lock_hard_error.go create mode 100644 internal/lambda-managed-instances/supervisor/model/mock_process_supervisor.go create mode 100644 internal/lambda-managed-instances/supervisor/model/mock_process_supervisor_client.go create mode 100644 internal/lambda-managed-instances/supervisor/model/process.go create mode 100644 internal/lambda-managed-instances/supervisor/model/process_test.go create mode 100644 internal/lambda-managed-instances/telemetry/constants.go create mode 100644 internal/lambda-managed-instances/telemetry/events.go create mode 100644 internal/lambda-managed-instances/telemetry/events_api.go create mode 100644 internal/lambda-managed-instances/telemetry/events_api_test.go create mode 100644 internal/lambda-managed-instances/telemetry/logs_egress_api.go create mode 100644 internal/lambda-managed-instances/telemetry/logs_subscription_api.go create mode 100644 internal/lambda-managed-instances/telemetry/xray/tracer.go create mode 100644 internal/lambda-managed-instances/telemetry/xray/tracer_test.go create mode 100755 internal/lambda-managed-instances/testdata/agents/bash_true.sh create mode 100644 internal/lambda-managed-instances/testdata/async_assertion_utils.go create mode 100755 internal/lambda-managed-instances/testdata/bash_function.sh create mode 100755 internal/lambda-managed-instances/testdata/bash_runtime.sh create mode 100755 internal/lambda-managed-instances/testdata/bash_script_with_child_proc.sh create mode 100644 internal/lambda-managed-instances/testdata/env_setup_helpers.go create mode 100644 internal/lambda-managed-instances/testdata/flowtesting.go create mode 100644 internal/lambda-managed-instances/testdata/mockcommand.go create mode 100644 internal/lambda-managed-instances/testdata/mockthread/mockthread.go create mode 100644 internal/lambda-managed-instances/testdata/mocktracer/mocktracer.go create mode 100644 internal/lambda-managed-instances/testdata/parametrization.go create mode 100644 internal/lambda-managed-instances/testutils/functional/chunked.go create mode 100644 internal/lambda-managed-instances/testutils/functional/doc.go create mode 100644 internal/lambda-managed-instances/testutils/functional/extension_actions.go create mode 100644 internal/lambda-managed-instances/testutils/functional/extensions_client.go create mode 100644 internal/lambda-managed-instances/testutils/functional/fluxpump_server.go create mode 100644 internal/lambda-managed-instances/testutils/functional/httputils.go create mode 100644 internal/lambda-managed-instances/testutils/functional/in_memory_events_api.go create mode 100644 internal/lambda-managed-instances/testutils/functional/process_supervisor.go create mode 100644 internal/lambda-managed-instances/testutils/functional/runtime_actions.go create mode 100644 internal/lambda-managed-instances/testutils/functional/runtime_client.go create mode 100644 internal/lambda-managed-instances/testutils/functional/supv.go create mode 100644 internal/lambda-managed-instances/testutils/mocks/http_handler_mock.go create mode 100644 internal/lambda-managed-instances/testutils/mocks/http_mock.go create mode 100644 internal/lambda-managed-instances/testutils/socket_utils.go create mode 100644 internal/lambda-managed-instances/testutils/socket_utils_test.go create mode 100644 internal/lambda-managed-instances/testutils/test_data.go create mode 100644 internal/lambda-managed-instances/utils/buffer_pool.go create mode 100644 internal/lambda-managed-instances/utils/buffer_pool_test.go create mode 100644 internal/lambda-managed-instances/utils/file_utils.go create mode 100644 internal/lambda-managed-instances/utils/file_utils_test.go create mode 100644 internal/lambda-managed-instances/utils/invariant/invariant.go create mode 100644 internal/lambda-managed-instances/utils/invariant/invariant_test.go create mode 100644 internal/lambda-managed-instances/utils/invariant/mock_violation_executor.go create mode 100644 internal/lambda-managed-instances/utils/invariant/model.go create mode 100644 internal/lambda-managed-instances/utils/invariant/panic.go create mode 100644 internal/lambda-managed-instances/utils/invariant/panic_test.go create mode 100644 internal/lambda-managed-instances/utils/io.go create mode 100644 internal/lambda-managed-instances/utils/io_test.go create mode 100644 internal/lambda-managed-instances/utils/mock_file_util.go create mode 100644 internal/lambda-managed-instances/utils/mocks.go create mode 100644 licenses.tpl diff --git a/THIRD-PARTY-LICENSES.md b/THIRD-PARTY-LICENSES.md index 406e06d..6852840 100644 --- a/THIRD-PARTY-LICENSES.md +++ b/THIRD-PARTY-LICENSES.md @@ -1,49 +1,134 @@ -** Go3p-Github-Pkg-Errors; version 1.x -- -https://godoc.org/github.com/pkg/errors -Copyright (c) 2015, Dave Cheney -All rights reserved. + +## github.com/davecgh/go-spew/spew + +* Name: github.com/davecgh/go-spew/spew +* Version: v1.1.1 +* License: [ISC](https://github.com/davecgh/go-spew/blob/v1.1.1/LICENSE) + +``` +ISC License + +Copyright (c) 2012-2016 Dave Collins + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +``` + +## github.com/go-chi/chi + +* Name: github.com/go-chi/chi +* Version: v1.5.5 +* License: [MIT](https://github.com/go-chi/chi/blob/v1.5.5/LICENSE) + +``` +Copyright (c) 2015-present Peter Kieltyka (https://github.com/pkieltyka), Google Inc. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +``` + +## github.com/go-chi/chi/v5 + +* Name: github.com/go-chi/chi/v5 +* Version: v5.2.2 +* License: [MIT](https://github.com/go-chi/chi/blob/v5.2.2/LICENSE) + +``` +Copyright (c) 2015-present Peter Kieltyka (https://github.com/pkieltyka), Google Inc. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +``` + +## github.com/google/uuid + +* Name: github.com/google/uuid +* Version: v1.6.0 +* License: [BSD-3-Clause](https://github.com/google/uuid/blob/v1.6.0/LICENSE) + +``` +Copyright (c) 2009,2014 Google Inc. All rights reserved. Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ------- +``` -** Go3p-Golang-X-Sys; version 1.x -- https://github.com/golang/sys -Copyright (c) 2009 The Go Authors. All rights reserved. -** go; version 1.13.8 -- https://github.com/golang/go/ -Copyright (c) 2009 The Go Authors. All rights reserved. -** Go3p-Github-Jessevdk-GoFlags; version 0.1.0 -- -https://github.com/jessevdk/go-flags -Copyright (c) 2012 Jesse van den Kieboom. All rights reserved. -** Go3p-Golang-X-Crypto; version 20180728-614d502 -- -https://tip.golang.org/pkg/crypto/ -Copyright (c) 2009 The Go Authors. All rights reserved. -** Go3p-Golang-X-Net; version 20180521-5706520 -- -https://tip.golang.org/pkg/net/ -Copyright (c) 2009 The Go Authors. All rights reserved. -** Go3p-Golang-X-Text; version 1.x -- https://tip.golang.org/pkg/text -Copyright (c) 2009 The Go Authors. All rights reserved. -** google-uuid; version 1.0 -- https://github.com/google/uuid -Copyright (c) 2009,2014 Google Inc. All rights reserved. +## github.com/jessevdk/go-flags + +* Name: github.com/jessevdk/go-flags +* Version: v1.5.0 +* License: [BSD-3-Clause](https://github.com/jessevdk/go-flags/blob/v1.5.0/LICENSE) +``` +Copyright (c) 2012 Jesse van den Kieboom. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -70,38 +155,362 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ------- +``` + +## github.com/orcaman/concurrent-map + +* Name: github.com/orcaman/concurrent-map +* Version: v1.0.0 +* License: [MIT](https://github.com/orcaman/concurrent-map/blob/v1.0.0/LICENSE) + +``` +The MIT License (MIT) + +Copyright (c) 2014 streamrail + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +``` + +## github.com/pmezard/go-difflib/difflib + +* Name: github.com/pmezard/go-difflib/difflib +* Version: v1.0.0 +* License: [BSD-3-Clause](https://github.com/pmezard/go-difflib/blob/v1.0.0/LICENSE) + +``` +Copyright (c) 2013, Patrick Mezard +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + The names of its contributors may not be used to endorse or promote +products derived from this software without specific prior written +permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +``` + +## github.com/santhosh-tekuri/jsonschema/v5 + +* Name: github.com/santhosh-tekuri/jsonschema/v5 +* Version: v5.3.1 +* License: [Apache-2.0](https://github.com/santhosh-tekuri/jsonschema/blob/v5.3.1/LICENSE) + +``` + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ -** Go3p-Github-GoChi-Chi; version 3.3.2 -- https://github.com/go-chi/chi -Copyright (c) 2015-present Peter Kieltyka (https://github.com/pkieltyka), -Google Inc. -** Go3p-Github-GoChi-Render; version 1.0.0 -- https://github.com/go-chi/render -Copyright (c) 2016-Present https://github.com/go-chi authors + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. +``` + +## github.com/sirupsen/logrus + +* Name: github.com/sirupsen/logrus +* Version: v1.9.3 +* License: [MIT](https://github.com/sirupsen/logrus/blob/v1.9.3/LICENSE) + +``` +The MIT License (MIT) + +Copyright (c) 2014 Simon Eskildsen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +``` + +## github.com/stretchr/objx + +* Name: github.com/stretchr/objx +* Version: v0.5.2 +* License: [MIT](https://github.com/stretchr/objx/blob/v0.5.2/LICENSE) + +``` +The MIT License + +Copyright (c) 2014 Stretchr, Inc. +Copyright (c) 2017-2018 objx contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +``` + +## github.com/stretchr/testify + +* Name: github.com/stretchr/testify +* Version: v1.9.0 +* License: [MIT](https://github.com/stretchr/testify/blob/v1.9.0/LICENSE) + +``` MIT License -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies -of -the Software, and to permit persons to whom the Software is furnished to do so, -subject to the following conditions: +Copyright (c) 2012-2020 Mat Ryer, Tyler Bunnell and contributors. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS -FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +``` ------- +## golang.org/x/sys/unix -** Go3p-Golang-X-Sync; version 1.x -- https://github.com/golang/sync/ +* Name: golang.org/x/sys/unix +* Version: v0.14.0 +* License: [BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.14.0:LICENSE) + +``` Copyright (c) 2009 The Go Authors. All rights reserved. Redistribution and use in source and binary forms, with or without @@ -130,52 +539,64 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ------- +``` -** sirupsen-logrus; version 1.0.6 -- https://github.com/sirupsen/logrus -Copyright (c) 2014 Simon Eskildsen +## gopkg.in/yaml.v3 -The MIT License (MIT) +* Name: gopkg.in/yaml.v3 +* Version: v3.0.1 +* License: [MIT](https://github.com/go-yaml/yaml/blob/v3.0.1/LICENSE) -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: +``` -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. +This project is covered by two different licenses: MIT and Apache. + +#### MIT License #### + +The following files were ported to Go from C files of libyaml, and thus +are still covered by their original MIT license, with the additional +copyright staring in 2011 when the project was ported over: + + apic.go emitterc.go parserc.go readerc.go scannerc.go + writerc.go yamlh.go yamlprivateh.go + +Copyright (c) 2006-2010 Kirill Simonov +Copyright (c) 2006-2011 Kirill Simonov + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +### Apache License ### + +All the remaining project files are covered by the Apache license: + +Copyright (c) 2011-2019 Canonical Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. ------- - -** Go3p-Github-Satori-GoUUID; version 1.1.0 -- -https://github.com/satori/go.uuid -Copyright (C) 2013-2016 by Maxim Bublis - -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +``` diff --git a/cmd/aws-lambda-rie/main.go b/cmd/aws-lambda-rie/main.go index d564334..f18b734 100644 --- a/cmd/aws-lambda-rie/main.go +++ b/cmd/aws-lambda-rie/main.go @@ -4,9 +4,17 @@ package main import ( + "os" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/run" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapidcore/env" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rie" ) func main() { + if _, ok := os.LookupEnv(env.AWS_LAMBDA_MAX_CONCURRENCY); ok { + run.Run() + return + } rie.Run() } diff --git a/internal/lambda-managed-instances/agents/agent.go b/internal/lambda-managed-instances/agents/agent.go new file mode 100644 index 0000000..b2eaa32 --- /dev/null +++ b/internal/lambda-managed-instances/agents/agent.go @@ -0,0 +1,50 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package agents + +import ( + "log/slog" + "path" + "path/filepath" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils" +) + +const ( + ExtensionsDir = "/opt/extensions" +) + +func ListExternalAgentPaths(fileutils utils.FileUtil, dir string, root string) []string { + var agentPaths []string + if !isCanonical(dir) || !isCanonical(root) { + slog.Warn("Agents base paths are not absolute and in canonical form", "dir", dir, "root", root) + return agentPaths + } + fullDir := path.Join(root, dir) + files, err := fileutils.ReadDirectory(fullDir) + if err != nil { + if fileutils.IsNotExist(err) { + slog.Info("The extension's directory does not exist, assuming no extensions to be loaded", "fullDir", fullDir) + } else { + + slog.Error("Cannot list external agents", "err", err) + } + + return agentPaths + } + + for _, file := range files { + if !file.IsDir() { + + p := path.Join("/", dir, file.Name()) + agentPaths = append(agentPaths, p) + } + } + return agentPaths +} + +func isCanonical(path string) bool { + absPath, err := filepath.Abs(path) + return err == nil && absPath == path +} diff --git a/internal/lambda-managed-instances/agents/agent_test.go b/internal/lambda-managed-instances/agents/agent_test.go new file mode 100644 index 0000000..d31ddfc --- /dev/null +++ b/internal/lambda-managed-instances/agents/agent_test.go @@ -0,0 +1,225 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package agents + +import ( + "os" + "path" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils" +) + +type fileInfo struct { + name string + mode os.FileMode + size int64 + target string +} + +func mkFile(name string, size int64, perm os.FileMode) fileInfo { + return fileInfo{ + name: name, + mode: perm, + size: size, + target: "", + } +} + +func mkDir(name string, perm os.FileMode) fileInfo { + return fileInfo{ + name: name, + mode: perm | os.ModeDir, + size: 0, + target: "", + } +} + +func mkLink(name, target string) fileInfo { + return fileInfo{ + name: name, + mode: os.ModeSymlink, + size: 0, + target: target, + } +} + +func createFileTree(root string, fs []fileInfo) error { + for _, info := range fs { + filename := info.name + dir := path.Join(root, path.Dir(filename)) + name := path.Base(filename) + err := os.MkdirAll(dir, 0o775) + if err != nil && !os.IsExist(err) { + return err + } + switch { + case os.ModeDir == info.mode&os.ModeDir: + err := os.Mkdir(path.Join(dir, name), info.mode&os.ModePerm) + if err != nil { + return err + } + case os.ModeSymlink == info.mode&os.ModeSymlink: + target := path.Join(root, info.target) + _, err = os.Stat(target) + if err != nil { + return err + } + err := os.Symlink(target, path.Join(dir, name)) + if err != nil { + return err + } + default: + file, err := os.OpenFile(path.Join(dir, name), os.O_RDWR|os.O_CREATE, info.mode&os.ModePerm) + if err != nil { + return err + } + if err := file.Truncate(info.size); err != nil { + return err + } + if err := file.Close(); err != nil { + return err + } + } + } + + return nil +} + +func TestBaseEmpty(t *testing.T) { + assert := assert.New(t) + + fs := []fileInfo{ + mkDir("/opt/extensions", 0o777), + } + + tmpDir, err := os.MkdirTemp("", "ext-") + require.NoError(t, err) + + require.NoError(t, createFileTree(tmpDir, fs)) + defer func() { require.NoError(t, os.RemoveAll(tmpDir)) }() + + fileUtils := utils.NewFileUtil() + + agents := ListExternalAgentPaths(fileUtils, path.Join(tmpDir, "/opt/extensions"), "/") + assert.Equal(0, len(agents)) +} + +func TestBaseNotExist(t *testing.T) { + assert := assert.New(t) + + fileUtils := utils.NewFileUtil() + + agents := ListExternalAgentPaths(fileUtils, "/path/which/does/not/exist", "/") + assert.Equal(0, len(agents)) +} + +func TestChrootNotExist(t *testing.T) { + assert := assert.New(t) + + fileUtils := utils.NewFileUtil() + + agents := ListExternalAgentPaths(fileUtils, "/bin", "/does/not/exist") + assert.Equal(0, len(agents)) +} + +func TestBaseNotDir(t *testing.T) { + assert := assert.New(t) + + fs := []fileInfo{ + mkFile("/opt/extensions", 1, 0o777), + } + tmpDir, err := os.MkdirTemp("", "ext-") + require.NoError(t, err) + + require.NoError(t, createFileTree(tmpDir, fs)) + defer func() { require.NoError(t, os.RemoveAll(tmpDir)) }() + + path := path.Join(tmpDir, "/opt/extensions") + + fileUtils := utils.NewFileUtil() + agents := ListExternalAgentPaths(fileUtils, path, "/") + assert.Equal(0, len(agents)) +} + +func TestFindAgentMixed(t *testing.T) { + assert := assert.New(t) + + listed := []fileInfo{ + mkFile("/opt/extensions/ok2", 1, 0o777), + mkFile("/opt/extensions/ok1", 1, 0o777), + mkFile("/opt/extensions/not_exec", 1, 0o666), + mkFile("/opt/extensions/not_read", 1, 0o333), + mkFile("/opt/extensions/empty_file", 0, 0o777), + mkLink("/opt/extensions/link", "/opt/extensions/ok1"), + } + + unlisted := []fileInfo{ + mkDir("/opt/extensions/empty_dir", 0o777), + mkDir("/opt/extensions/nonempty_dir", 0o777), + mkFile("/opt/extensions/nonempty_dir/notok", 1, 0o777), + } + + fs := append([]fileInfo{}, listed...) + fs = append(fs, unlisted...) + + tmpDir, err := os.MkdirTemp("", "ext-") + require.NoError(t, err) + + require.NoError(t, createFileTree(tmpDir, fs)) + defer func() { require.NoError(t, os.RemoveAll(tmpDir)) }() + + path := path.Join(tmpDir, "/opt/extensions") + fileUtils := utils.NewFileUtil() + agentPaths := ListExternalAgentPaths(fileUtils, path, "/") + assert.Equal(len(listed), len(agentPaths)) + last := "" + for index := range listed { + if len(last) > 0 { + assert.GreaterOrEqual(agentPaths[index], last) + } + last = agentPaths[index] + } +} + +func TestFindAgentMixedInChroot(t *testing.T) { + assert := assert.New(t) + + listed := []fileInfo{ + mkFile("/opt/extensions/ok2", 1, 0o777), + mkFile("/opt/extensions/ok1", 1, 0o777), + mkFile("/opt/extensions/not_exec", 1, 0o666), + mkFile("/opt/extensions/not_read", 1, 0o333), + mkFile("/opt/extensions/empty_file", 0, 0o777), + mkLink("/opt/extensions/link", "/opt/extensions/ok1"), + } + + unlisted := []fileInfo{ + mkDir("/opt/extensions/empty_dir", 0o777), + mkDir("/opt/extensions/nonempty_dir", 0o777), + mkFile("/opt/extensions/nonempty_dir/notok", 1, 0o777), + } + + fs := append([]fileInfo{}, listed...) + fs = append(fs, unlisted...) + + rootDir, err := os.MkdirTemp("", "rootfs") + require.NoError(t, err) + + require.NoError(t, createFileTree(rootDir, fs)) + defer func() { require.NoError(t, os.RemoveAll(rootDir)) }() + fileUtils := utils.NewFileUtil() + agentPaths := ListExternalAgentPaths(fileUtils, "/opt/extensions", rootDir) + assert.Equal(len(listed), len(agentPaths)) + last := "" + for index := range listed { + if len(last) > 0 { + assert.GreaterOrEqual(agentPaths[index], last) + } + last = agentPaths[index] + } +} diff --git a/internal/lambda-managed-instances/appctx/appctx.go b/internal/lambda-managed-instances/appctx/appctx.go new file mode 100644 index 0000000..cadabfb --- /dev/null +++ b/internal/lambda-managed-instances/appctx/appctx.go @@ -0,0 +1,79 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package appctx + +import ( + "sync" +) + +type Key int + +const ( + AppCtxInvokeErrorTraceDataKey Key = iota + + AppCtxRuntimeReleaseKey + + AppCtxInteropServerKey + + AppCtxResponseSenderKey + + AppCtxFirstFatalErrorKey +) + +type ApplicationContext interface { + Store(key Key, value interface{}) + Load(key Key) (value interface{}, ok bool) + Delete(key Key) + GetOrDefault(key Key, defaultValue interface{}) interface{} + StoreIfNotExists(key Key, value interface{}) interface{} +} + +type applicationContext struct { + mux *sync.Mutex + m map[Key]interface{} +} + +func (appCtx *applicationContext) Store(key Key, value interface{}) { + appCtx.mux.Lock() + defer appCtx.mux.Unlock() + appCtx.m[key] = value +} + +func (appCtx *applicationContext) StoreIfNotExists(key Key, value interface{}) interface{} { + appCtx.mux.Lock() + defer appCtx.mux.Unlock() + existing, found := appCtx.m[key] + if found { + return existing + } + appCtx.m[key] = value + return nil +} + +func (appCtx *applicationContext) Load(key Key) (value interface{}, ok bool) { + appCtx.mux.Lock() + defer appCtx.mux.Unlock() + value, ok = appCtx.m[key] + return value, ok +} + +func (appCtx *applicationContext) Delete(key Key) { + appCtx.mux.Lock() + defer appCtx.mux.Unlock() + delete(appCtx.m, key) +} + +func (appCtx *applicationContext) GetOrDefault(key Key, defaultValue interface{}) interface{} { + if value, ok := appCtx.Load(key); ok { + return value + } + return defaultValue +} + +func NewApplicationContext() ApplicationContext { + return &applicationContext{ + mux: &sync.Mutex{}, + m: make(map[Key]interface{}), + } +} diff --git a/internal/lambda-managed-instances/appctx/appctxutil.go b/internal/lambda-managed-instances/appctx/appctxutil.go new file mode 100644 index 0000000..d211f1e --- /dev/null +++ b/internal/lambda-managed-instances/appctx/appctxutil.go @@ -0,0 +1,150 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package appctx + +import ( + "context" + "log/slog" + "net/http" + "strings" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type ReqCtxKey int + +const ReqCtxApplicationContextKey ReqCtxKey = iota + +const MaxRuntimeReleaseLength = 128 + +func FromRequest(request *http.Request) ApplicationContext { + return request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) +} + +func RequestWithAppCtx(request *http.Request, appCtx ApplicationContext) *http.Request { + return request.WithContext(context.WithValue(request.Context(), ReqCtxApplicationContextKey, appCtx)) +} + +func GetRuntimeRelease(appCtx ApplicationContext) string { + return appCtx.GetOrDefault(AppCtxRuntimeReleaseKey, "").(string) +} + +func GetUserAgentFromRequest(request *http.Request) string { + runtimeRelease := "" + userAgent := request.Header.Get("User-Agent") + + if fields := strings.Fields(userAgent); len(fields) > 0 && len(fields[0]) > 0 { + runtimeRelease = fields[0] + } + return runtimeRelease +} + +func CreateRuntimeReleaseFromRequest(request *http.Request, runtimeRelease string) string { + lambdaRuntimeFeaturesHeader := request.Header.Get("Lambda-Runtime-Features") + + lambdaRuntimeFeaturesHeader = strings.ReplaceAll(lambdaRuntimeFeaturesHeader, "(", "") + lambdaRuntimeFeaturesHeader = strings.ReplaceAll(lambdaRuntimeFeaturesHeader, ")", "") + + numberOfAppendedFeatures := 0 + + runtimeReleaseLength := len(runtimeRelease) + if runtimeReleaseLength == 0 { + runtimeReleaseLength = len("Unknown") + } + availableLength := MaxRuntimeReleaseLength - runtimeReleaseLength - 3 + var lambdaRuntimeFeatures []string + + for _, feature := range strings.Fields(lambdaRuntimeFeaturesHeader) { + featureLength := len(feature) + + if featureLength <= availableLength-numberOfAppendedFeatures { + availableLength -= featureLength + lambdaRuntimeFeatures = append(lambdaRuntimeFeatures, feature) + numberOfAppendedFeatures++ + } + } + + if len(lambdaRuntimeFeatures) > 0 { + if runtimeRelease == "" { + runtimeRelease = "Unknown" + } + runtimeRelease += " (" + strings.Join(lambdaRuntimeFeatures, " ") + ")" + } + + return runtimeRelease +} + +func UpdateAppCtxWithRuntimeRelease(request *http.Request, appCtx ApplicationContext) bool { + + if appCtxRuntimeRelease := GetRuntimeRelease(appCtx); len(appCtxRuntimeRelease) > 0 { + + if runtimeReleaseWithFeatures := CreateRuntimeReleaseFromRequest(request, appCtxRuntimeRelease); len(runtimeReleaseWithFeatures) > len(appCtxRuntimeRelease) && + appCtxRuntimeRelease[len(appCtxRuntimeRelease)-1] != ')' { + appCtx.Store(AppCtxRuntimeReleaseKey, runtimeReleaseWithFeatures) + return true + } + return false + } + + if runtimeReleaseWithFeatures := CreateRuntimeReleaseFromRequest(request, + GetUserAgentFromRequest(request)); runtimeReleaseWithFeatures != "" { + appCtx.Store(AppCtxRuntimeReleaseKey, runtimeReleaseWithFeatures) + return true + } + return false +} + +func StoreInvokeErrorTraceData(appCtx ApplicationContext, invokeError *interop.InvokeErrorTraceData) { + appCtx.Store(AppCtxInvokeErrorTraceDataKey, invokeError) +} + +func LoadInvokeErrorTraceData(appCtx ApplicationContext) *interop.InvokeErrorTraceData { + v, ok := appCtx.Load(AppCtxInvokeErrorTraceDataKey) + if ok { + return v.(*interop.InvokeErrorTraceData) + } + return nil +} + +func StoreInteropServer(appCtx ApplicationContext, server interop.Server) { + appCtx.Store(AppCtxInteropServerKey, server) +} + +func LoadInteropServer(appCtx ApplicationContext) interop.Server { + v, ok := appCtx.Load(AppCtxInteropServerKey) + if ok { + return v.(interop.Server) + } + return nil +} + +func StoreResponseSender(appCtx ApplicationContext, server interop.InvokeResponseSender) { + appCtx.Store(AppCtxResponseSenderKey, server) +} + +func LoadResponseSender(appCtx ApplicationContext) interop.InvokeResponseSender { + v, ok := appCtx.Load(AppCtxResponseSenderKey) + if ok { + return v.(interop.InvokeResponseSender) + } + return nil +} + +func StoreFirstFatalError(appCtx ApplicationContext, err model.CustomerError) { + if existing := appCtx.StoreIfNotExists(AppCtxFirstFatalErrorKey, err); existing != nil { + slog.Warn("Omitting fatal error: already stored", "err", err, "existing", existing.(model.CustomerError)) + return + } + + slog.Warn("First fatal error stored in appctx", "errorType", err.ErrorType()) +} + +func LoadFirstFatalError(appCtx ApplicationContext) (customerError model.CustomerError, found bool) { + v, found := appCtx.Load(AppCtxFirstFatalErrorKey) + if !found { + return model.CustomerError{}, false + } + return v.(model.CustomerError), true +} diff --git a/internal/lambda-managed-instances/appctx/appctxutil_test.go b/internal/lambda-managed-instances/appctx/appctxutil_test.go new file mode 100644 index 0000000..fab0a03 --- /dev/null +++ b/internal/lambda-managed-instances/appctx/appctxutil_test.go @@ -0,0 +1,186 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package appctx + +import ( + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func runTestRequestWithUserAgent(t *testing.T, userAgent string, expectedRuntimeRelease string) { + + req := httptest.NewRequest("", "/", nil) + req.Header.Set("User-Agent", userAgent) + request := RequestWithAppCtx(req, NewApplicationContext()) + appCtx := request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) + + ok := UpdateAppCtxWithRuntimeRelease(request, appCtx) + + assert.True(t, ok) + ctxRuntimeRelease, ok := appCtx.Load(AppCtxRuntimeReleaseKey) + assert.True(t, ok) + assert.Equal(t, expectedRuntimeRelease, ctxRuntimeRelease, "failed to extract runtime_release token") +} + +func TestCreateRuntimeReleaseFromRequest(t *testing.T) { + tests := map[string]struct { + userAgentHeader string + lambdaRuntimeFeaturesHeader string + expectedRuntimeRelease string + }{ + "No User-Agent header": { + userAgentHeader: "", + lambdaRuntimeFeaturesHeader: "httpcl/2.0 execwr", + expectedRuntimeRelease: "Unknown (httpcl/2.0 execwr)", + }, + "No Lambda-Runtime-Features header": { + userAgentHeader: "Node.js/14.16.0", + lambdaRuntimeFeaturesHeader: "", + expectedRuntimeRelease: "Node.js/14.16.0", + }, + "Lambda-Runtime-Features header with additional spaces": { + userAgentHeader: "Node.js/14.16.0", + lambdaRuntimeFeaturesHeader: "httpcl/2.0 execwr", + expectedRuntimeRelease: "Node.js/14.16.0 (httpcl/2.0 execwr)", + }, + "Lambda-Runtime-Features header with special characters": { + userAgentHeader: "Node.js/14.16.0", + lambdaRuntimeFeaturesHeader: "httpcl/2.0@execwr-1 abcd?efg nodewr/(4.33)) nodewr/4.3", + expectedRuntimeRelease: "Node.js/14.16.0 (httpcl/2.0@execwr-1 abcd?efg nodewr/4.33 nodewr/4.3)", + }, + "Lambda-Runtime-Features header with long Lambda-Runtime-Features header": { + userAgentHeader: "Node.js/14.16.0", + lambdaRuntimeFeaturesHeader: strings.Repeat("abcdef ", MaxRuntimeReleaseLength/7), + expectedRuntimeRelease: "Node.js/14.16.0 (" + strings.Repeat("abcdef ", (MaxRuntimeReleaseLength-18-6)/7) + "abcdef)", + }, + "Lambda-Runtime-Features header with long Lambda-Runtime-Features header with UTF-8 characters": { + userAgentHeader: "Node.js/14.16.0", + lambdaRuntimeFeaturesHeader: strings.Repeat("我爱亚马逊 ", MaxRuntimeReleaseLength/16), + expectedRuntimeRelease: "Node.js/14.16.0 (" + strings.Repeat("我爱亚马逊 ", (MaxRuntimeReleaseLength-18-15)/16) + "我爱亚马逊)", + }, + } + + for _, tc := range tests { + req := httptest.NewRequest("", "/", nil) + if tc.userAgentHeader != "" { + req.Header.Set("User-Agent", tc.userAgentHeader) + } + if tc.lambdaRuntimeFeaturesHeader != "" { + req.Header.Set("Lambda-Runtime-Features", tc.lambdaRuntimeFeaturesHeader) + } + appCtx := NewApplicationContext() + request := RequestWithAppCtx(req, appCtx) + + UpdateAppCtxWithRuntimeRelease(request, appCtx) + runtimeRelease := GetRuntimeRelease(appCtx) + + assert.LessOrEqual(t, len(runtimeRelease), MaxRuntimeReleaseLength) + assert.Equal(t, tc.expectedRuntimeRelease, runtimeRelease) + } +} + +func TestUpdateAppCtxWithRuntimeRelease(t *testing.T) { + type pair struct { + in, wanted string + } + pairs := []pair{ + {"Mozilla/5.0", "Mozilla/5.0"}, + {"Mozilla/6.0 (Windows NT 6.1; Win64; x64; rv:47.0) Gecko/20100101 Firefox/47.0", "Mozilla/6.0"}, + } + for _, p := range pairs { + runTestRequestWithUserAgent(t, p.in, p.wanted) + } +} + +func TestUpdateAppCtxWithRuntimeReleaseWithoutUserAgent(t *testing.T) { + + request := RequestWithAppCtx(httptest.NewRequest("", "/", nil), NewApplicationContext()) + appCtx := request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) + + ok := UpdateAppCtxWithRuntimeRelease(request, appCtx) + + assert.False(t, ok) + _, ok = appCtx.Load(AppCtxRuntimeReleaseKey) + assert.False(t, ok) +} + +func TestUpdateAppCtxWithRuntimeReleaseWithBlankUserAgent(t *testing.T) { + + req := httptest.NewRequest("", "/", nil) + req.Header.Set("User-Agent", " ") + request := RequestWithAppCtx(req, NewApplicationContext()) + appCtx := request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) + + ok := UpdateAppCtxWithRuntimeRelease(request, appCtx) + + assert.False(t, ok) + _, ok = appCtx.Load(AppCtxRuntimeReleaseKey) + assert.False(t, ok) +} + +func TestUpdateAppCtxWithRuntimeReleaseWithLambdaRuntimeFeatures(t *testing.T) { + + req := httptest.NewRequest("", "/", nil) + req.Header.Set("User-Agent", "Node.js/14.16.0") + req.Header.Set("Lambda-Runtime-Features", "httpcl/2.0 execwr nodewr/4.3") + request := RequestWithAppCtx(req, NewApplicationContext()) + appCtx := request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) + + ok := UpdateAppCtxWithRuntimeRelease(request, appCtx) + + assert.True(t, ok, "runtime_release updated based only on User-Agent and valid features") + ctxRuntimeRelease, ok := appCtx.Load(AppCtxRuntimeReleaseKey) + assert.True(t, ok) + assert.Equal(t, "Node.js/14.16.0 (httpcl/2.0 execwr nodewr/4.3)", ctxRuntimeRelease) +} + +func TestUpdateAppCtxWithRuntimeReleaseMultipleTimes(t *testing.T) { + + firstValue := "Value1" + secondValue := "Value2" + + req := httptest.NewRequest("", "/", nil) + req.Header.Set("User-Agent", firstValue) + request := RequestWithAppCtx(req, NewApplicationContext()) + appCtx := request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) + + ok := UpdateAppCtxWithRuntimeRelease(request, appCtx) + + assert.True(t, ok) + ctxRuntimeRelease, ok := appCtx.Load(AppCtxRuntimeReleaseKey) + assert.True(t, ok) + assert.Equal(t, firstValue, ctxRuntimeRelease) + + req.Header.Set("User-Agent", secondValue) + + ok = UpdateAppCtxWithRuntimeRelease(request, appCtx) + + assert.False(t, ok, "failed to prevent second update of runtime_release") + ctxRuntimeRelease, ok = appCtx.Load(AppCtxRuntimeReleaseKey) + assert.True(t, ok) + assert.Equal(t, firstValue, ctxRuntimeRelease, "failed to prevent second update of runtime_release") +} + +func TestFirstFatalError(t *testing.T) { + appCtx := NewApplicationContext() + + _, found := LoadFirstFatalError(appCtx) + require.False(t, found) + + StoreFirstFatalError(appCtx, model.WrapErrorIntoCustomerFatalError(nil, model.ErrorAgentCrash)) + v, found := LoadFirstFatalError(appCtx) + require.True(t, found) + require.Equal(t, model.ErrorAgentCrash, v.ErrorType()) + + StoreFirstFatalError(appCtx, model.WrapErrorIntoCustomerFatalError(nil, model.ErrorAgentExit)) + v, found = LoadFirstFatalError(appCtx) + require.True(t, found) + require.Equal(t, model.ErrorAgentCrash, v.ErrorType()) +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/app.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/app.go new file mode 100644 index 0000000..a25091e --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/app.go @@ -0,0 +1,97 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "context" + "log/slog" + "net/http" + "sync" + "time" + + "github.com/go-chi/chi" + + rieinvoke "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/invoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + intmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func NewHTTPHandler(raptorApp raptorApp, initMsg intmodel.InitRequestMessage) *HTTPHandler { + h := &HTTPHandler{ + app: raptorApp, + initMsg: initMsg, + } + + h.initOnceValue = sync.OnceValue(func() rapidmodel.AppError { + initCtx, cancel := context.WithTimeout(context.Background(), time.Duration(h.initMsg.InitTimeout)) + defer cancel() + + dummyInitMetrics := rapid.NewInitMetrics(nil) + res := h.app.Init(initCtx, &h.initMsg, dummyInitMetrics) + return res + }) + + router := chi.NewRouter() + router.Post("/2015-03-31/functions/function/invocations", h.invoke) + h.router = router + + return h +} + +type HTTPHandler struct { + router *chi.Mux + app raptorApp + initOnceValue func() rapidmodel.AppError + initMsg intmodel.InitRequestMessage +} + +func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.router.ServeHTTP(w, r) +} + +func (h *HTTPHandler) invoke(w http.ResponseWriter, r *http.Request) { + if err := h.initOnceValue(); err != nil { + h.respondWithError(w, err) + return + } + + invokeReq := rieinvoke.NewRieInvokeRequest(r, w) + ctx := logging.WithInvokeID(r.Context(), invokeReq.InvokeID()) + + metrics := invoke.NewInvokeMetrics(nil, &noOpCounter{}) + metrics.AttachInvokeRequest(invokeReq) + if err, responseSent := h.app.Invoke(ctx, invokeReq, metrics); err != nil { + logging.Err(ctx, "invoke failed", err) + if !responseSent { + h.respondWithError(w, err) + } + } +} + +func (h *HTTPHandler) Init() rapidmodel.AppError { + return h.initOnceValue() +} + +func (h *HTTPHandler) respondWithError(w http.ResponseWriter, err rapidmodel.AppError) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Error-Type", string(err.ErrorType())) + w.WriteHeader(err.ReturnCode()) + + if _, encodeErr := w.Write([]byte(err.ErrorDetails())); encodeErr != nil { + slog.Error("could not encode error response", "err", encodeErr) + } +} + +type raptorApp interface { + Init(ctx context.Context, req *intmodel.InitRequestMessage, metrics interop.InitMetrics) rapidmodel.AppError + Invoke(ctx context.Context, msg interop.InvokeRequest, metrics interop.InvokeMetrics) (err rapidmodel.AppError, responseSent bool) +} + +type noOpCounter struct{} + +func (c *noOpCounter) AddInvoke(_ uint64) {} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/app_test.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/app_test.go new file mode 100644 index 0000000..31144c8 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/app_test.go @@ -0,0 +1,143 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "errors" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + intmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func TestApp_ServeHTTP(t *testing.T) { + tests := []struct { + name string + initResponse rapidmodel.AppError + invokeErr rapidmodel.AppError + responseSent bool + expectedStatus int + expectedError bool + expectedErrorType string + expectedJSON string + }{ + { + name: "successful invocation", + initResponse: nil, + invokeErr: nil, + responseSent: false, + expectedStatus: 200, + expectedError: false, + }, + { + name: "init failure", + initResponse: rapidmodel.NewCustomerError(rapidmodel.ErrorRuntimeUnknown), + expectedStatus: 200, + expectedError: true, + expectedErrorType: "Runtime.Unknown", + expectedJSON: `{"errorType":"Runtime.Unknown"}`, + }, + { + name: "platform error during init", + initResponse: rapidmodel.NewPlatformError(errors.New("platform error"), rapidmodel.ErrorReasonUnknownError), + expectedStatus: 500, + expectedError: true, + expectedErrorType: "UnknownError", + expectedJSON: `{"errorType":"UnknownError"}`, + }, + { + name: "invoke error with response not sent", + initResponse: nil, + invokeErr: rapidmodel.NewCustomerError(rapidmodel.ErrorRuntimeUnknown), + responseSent: false, + expectedStatus: 200, + expectedError: true, + expectedErrorType: "Runtime.Unknown", + expectedJSON: `{"errorType":"Runtime.Unknown"}`, + }, + { + name: "invoke error with response already sent", + initResponse: nil, + invokeErr: rapidmodel.NewCustomerError(rapidmodel.ErrorRuntimeUnknown), + responseSent: true, + expectedStatus: 200, + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockApp := newMockRaptorApp(t) + defer mockApp.AssertExpectations(t) + mockApp.On("Init", mock.Anything, mock.Anything, mock.Anything).Return(tt.initResponse) + if tt.initResponse == nil { + mockApp.On("Invoke", mock.Anything, mock.Anything, mock.Anything).Return(tt.invokeErr, tt.responseSent) + } + + initMsg := intmodel.InitRequestMessage{ + InitTimeout: intmodel.DurationMS(10 * time.Second), + Handler: "test.handler", + } + app := NewHTTPHandler(mockApp, initMsg) + + req := httptest.NewRequest("POST", "/2015-03-31/functions/function/invocations", strings.NewReader("{}")) + w := httptest.NewRecorder() + + app.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, w.Code) + + if tt.expectedError { + assert.Equal(t, "application/json", w.Header().Get("Content-Type")) + assert.Equal(t, tt.expectedErrorType, w.Header().Get("Error-Type")) + assert.JSONEq(t, tt.expectedJSON, w.Body.String()) + } else if tt.responseSent { + assert.Empty(t, w.Header().Get("Content-Type")) + assert.Empty(t, w.Header().Get("Error-Type")) + assert.Empty(t, w.Body.String()) + } + }) + } +} + +func TestApp_ServeHTTP_Concurrent(t *testing.T) { + mockApp := &mockRaptorApp{} + defer mockApp.AssertExpectations(t) + mockApp.On("Init", mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + time.Sleep(100 * time.Millisecond) + }). + Return(rapidmodel.NewCustomerError(rapidmodel.ErrorRuntimeUnknown)). + Once() + + initMsg := intmodel.InitRequestMessage{ + InitTimeout: intmodel.DurationMS(10 * time.Second), + Handler: "test.handler", + } + app := NewHTTPHandler(mockApp, initMsg) + + var wg sync.WaitGroup + const invokes = 10 + wg.Add(invokes) + for i := 0; i < invokes; i++ { + go func() { + req := httptest.NewRequest("POST", "/2015-03-31/functions/function/invocations", strings.NewReader("{}")) + w := httptest.NewRecorder() + app.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) + assert.Equal(t, "application/json", w.Header().Get("Content-Type")) + assert.Equal(t, "Runtime.Unknown", w.Header().Get("Error-Type")) + assert.JSONEq(t, `{"errorType":"Runtime.Unknown"}`, w.Body.String()) + wg.Done() + }() + } + wg.Wait() +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/init.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/init.go new file mode 100644 index 0000000..667554a --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/init.go @@ -0,0 +1,147 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "fmt" + "log/slog" + "net/netip" + "os" + "strconv" + "strings" + "time" + + intmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapidcore/env" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils" +) + +type InitRequestMessageFactory func(fileUtil utils.FileUtil, args []string) (intmodel.InitRequestMessage, model.AppError) + +func GetInitRequestMessage(fileUtil utils.FileUtil, args []string) (intmodel.InitRequestMessage, model.AppError) { + accountID := getEnvOrDefault("AWS_ACCOUNT_ID", "123456789012") + invokeTimeout := intmodel.DurationMS(getEnvOrDefaultInt("AWS_LAMBDA_FUNCTION_TIMEOUT", 300) * int(time.Second)) + functionName := getEnvOrDefault(env.AWS_LAMBDA_FUNCTION_NAME, "test_function") + region := getEnvOrDefault(env.AWS_REGION, "us-east-1") + cwd := getCwd() + + cmd, err := getBootstrap(fileUtil, args, cwd) + if err != nil { + return intmodel.InitRequestMessage{}, err + } + + return intmodel.InitRequestMessage{ + AccountID: accountID, + AwsKey: os.Getenv(env.AWS_ACCESS_KEY_ID), + AwsSecret: os.Getenv(env.AWS_SECRET_ACCESS_KEY), + AwsSession: os.Getenv(env.AWS_SESSION_TOKEN), + AwsRegion: region, + EnvVars: env.KVPairStringsToMap(os.Environ()), + MemorySizeBytes: getEnvOrDefaultInt(env.AWS_LAMBDA_FUNCTION_MEMORY_SIZE, 3008) * 1024 * 1024, + FunctionARN: fmt.Sprintf("arn:aws:lambda:%s:%s:function:%s", region, accountID, functionName), + FunctionVersion: getEnvOrDefault(env.AWS_LAMBDA_FUNCTION_VERSION, "$LATEST"), + FunctionVersionID: "", + ArtefactType: intmodel.ArtefactTypeZIP, + TaskName: functionName, + Handler: getHandler(args), + InvokeTimeout: invokeTimeout, + InitTimeout: invokeTimeout, + RuntimeVersion: "", + RuntimeArn: "", + RuntimeWorkerCount: getEnvOrDefaultInt(env.AWS_LAMBDA_MAX_CONCURRENCY, 1), + LogFormat: getEnvOrDefault(env.AWS_LAMBDA_LOG_FORMAT, "json"), + LogLevel: os.Getenv(env.AWS_LAMBDA_LOG_LEVEL), + LogGroupName: getEnvOrDefault(env.AWS_LAMBDA_LOG_GROUP_NAME, "/aws/lambda/Functions"), + LogStreamName: getEnvOrDefault(env.AWS_LAMBDA_LOG_STREAM_NAME, "$LATEST"), + TelemetryAPIAddress: intmodel.TelemetryAddr(netip.MustParseAddrPort("127.0.0.1:0")), + TelemetryPassphrase: "", + XRayDaemonAddress: "", + XrayTracingMode: intmodel.XRayTracingModePassThrough, + CurrentWorkingDir: cwd, + RuntimeBinaryCommand: cmd, + AvailabilityZoneId: "", + AmiId: "", + }, nil +} + +func getHandler(args []string) string { + handler := getEnvOrDefault("AWS_LAMBDA_FUNCTION_HANDLER", os.Getenv(env.HANDLER)) + if handler != "" { + return handler + } + + if len(args) > 2 { + return args[len(args)-1] + } + return "" +} + +func getEnvOrDefault(key, defaultVal string) string { + if val, ok := os.LookupEnv(key); ok { + return val + } + return defaultVal +} + +func getEnvOrDefaultInt(key string, defaultVal int) int { + val, ok := os.LookupEnv(key) + if !ok { + return defaultVal + } + + valInt, err := strconv.Atoi(val) + if err != nil { + slog.Warn("Failed to convert environment variable to integer", + "key", key, + "value", val, + "err", err) + return defaultVal + } + return valInt +} + +func getBootstrap(fileUtil utils.FileUtil, args []string, cwd string) (cmd []string, err model.AppError) { + + if len(args) > 1 { + slog.Info("executing bootstrap", "command", args[1]) + return args[1:], nil + } + + candidates := []string{ + cwd + "/bootstrap", + "/var/runtime/bootstrap", + "/var/task/bootstrap", + "/opt/bootstrap", + } + for _, c := range candidates { + stat, err := fileUtil.Stat(c) + if fileUtil.IsNotExist(err) { + slog.Warn("could not find bootstrap file", "path", c) + continue + } + if stat.IsDir() { + slog.Warn("bootstrap file is a directory", "path", c) + continue + } + + return []string{c}, nil + } + + err = model.WrapErrorIntoCustomerInvalidError( + fmt.Errorf("could not find runtime entrypoint in CLI args and in predefined locations: %s", strings.Join(candidates, ", ")), + model.ErrorRuntimeInvalidEntryPoint, + ) + return nil, err +} + +func getCwd() string { + cwd, err := os.Getwd() + if err != nil { + slog.Warn("could not find current working directory. Using default /var/task instead", "err", err) + return "/var/task" + } + + return cwd +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/init_test.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/init_test.go new file mode 100644 index 0000000..f19912d --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/init_test.go @@ -0,0 +1,306 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "net/netip" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + intmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapidcore/env" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils" +) + +func Test_getInitRequestMessage(t *testing.T) { + + for k := range env.Defined { + require.NoError(t, os.Unsetenv(k)) + } + + tests := []struct { + name string + args []string + env map[string]string + want intmodel.InitRequestMessage + }{ + { + name: "default_values", + args: []string{"aws-lambda-rie", "/path/to/bootstrap"}, + env: map[string]string{}, + want: intmodel.InitRequestMessage{ + AccountID: "123456789012", + AwsKey: "", + AwsSecret: "", + AwsSession: "", + AwsRegion: "us-east-1", + EnvVars: map[string]string{}, + MemorySizeBytes: 3008 * 1024 * 1024, + FunctionARN: "arn:aws:lambda:us-east-1:123456789012:function:test_function", + FunctionVersion: "$LATEST", + FunctionVersionID: "", + ArtefactType: intmodel.ArtefactTypeZIP, + TaskName: "test_function", + Handler: "", + InvokeTimeout: intmodel.DurationMS(300 * time.Second), + InitTimeout: intmodel.DurationMS(300 * time.Second), + RuntimeVersion: "", + RuntimeArn: "", + RuntimeWorkerCount: 1, + LogFormat: "json", + LogLevel: "", + LogGroupName: "/aws/lambda/Functions", + LogStreamName: "$LATEST", + TelemetryAPIAddress: intmodel.TelemetryAddr(netip.MustParseAddrPort("127.0.0.1:0")), + TelemetryPassphrase: "", + XRayDaemonAddress: "", + XrayTracingMode: intmodel.XRayTracingModePassThrough, + CurrentWorkingDir: "REPLACE", + RuntimeBinaryCommand: []string{"/path/to/bootstrap"}, + AvailabilityZoneId: "", + AmiId: "", + }, + }, + { + name: "all_env_vars_and_args", + args: []string{"app", "/custom/bootstrap", "custom_handler"}, + env: map[string]string{ + "AWS_ACCOUNT_ID": "987654321098", + "AWS_ACCESS_KEY_ID": "test_key", + "AWS_SECRET_ACCESS_KEY": "test_secret", + "AWS_SESSION_TOKEN": "test_session", + "AWS_REGION": "eu-west-1", + "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "1024", + "AWS_LAMBDA_FUNCTION_NAME": "custom_function", + "AWS_LAMBDA_FUNCTION_VERSION": "2", + "AWS_LAMBDA_FUNCTION_TIMEOUT": "60", + "AWS_LAMBDA_MAX_CONCURRENCY": "5", + "AWS_LAMBDA_LOG_FORMAT": "JSON", + "AWS_LAMBDA_LOG_LEVEL": "DEBUG", + "AWS_LAMBDA_LOG_GROUP_NAME": "/aws/lambda/custom", + "AWS_LAMBDA_LOG_STREAM_NAME": "custom-stream", + "AWS_LAMBDA_FUNCTION_HANDLER": "custom_handler_from_env", + "_HANDLER": "lower_priority_custom_handler_from_env", + }, + want: intmodel.InitRequestMessage{ + AccountID: "987654321098", + AwsKey: "test_key", + AwsSecret: "test_secret", + AwsSession: "test_session", + AwsRegion: "eu-west-1", + EnvVars: map[string]string{}, + MemorySizeBytes: 1024 * 1024 * 1024, + FunctionARN: "arn:aws:lambda:eu-west-1:987654321098:function:custom_function", + FunctionVersion: "2", + FunctionVersionID: "", + ArtefactType: intmodel.ArtefactTypeZIP, + TaskName: "custom_function", + Handler: "custom_handler_from_env", + InvokeTimeout: intmodel.DurationMS(60 * time.Second), + InitTimeout: intmodel.DurationMS(60 * time.Second), + RuntimeVersion: "", + RuntimeArn: "", + RuntimeWorkerCount: 5, + LogFormat: "JSON", + LogLevel: "DEBUG", + LogGroupName: "/aws/lambda/custom", + LogStreamName: "custom-stream", + TelemetryAPIAddress: intmodel.TelemetryAddr(netip.MustParseAddrPort("127.0.0.1:0")), + TelemetryPassphrase: "", + XRayDaemonAddress: "", + XrayTracingMode: intmodel.XRayTracingModePassThrough, + CurrentWorkingDir: "/var/task", + RuntimeBinaryCommand: []string{"/custom/bootstrap", "custom_handler"}, + AvailabilityZoneId: "", + AmiId: "", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + for k, v := range tt.env { + t.Setenv(k, v) + } + tt.want.EnvVars = env.KVPairStringsToMap(os.Environ()) + cwd, err := os.Getwd() + require.NoError(t, err) + tt.want.CurrentWorkingDir = cwd + + got, err := GetInitRequestMessage(&utils.MockFileUtil{}, tt.args) + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_getBootstrap(t *testing.T) { + type args struct { + fileUtil utils.FileUtil + args []string + cwd string + } + tests := []struct { + name string + args args + wantCmd []string + wantErr bool + }{ + { + name: "with_bootstrap_arg", + args: args{ + fileUtil: &utils.MockFileUtil{}, + args: []string{"aws-lambda-rie", "/path/to/bootstrap", "handler"}, + cwd: "/test/cwd", + }, + wantCmd: []string{"/path/to/bootstrap", "handler"}, + wantErr: false, + }, + { + name: "find_bootstrap_in_cwd", + args: args{ + fileUtil: func() *utils.MockFileUtil { + m := &utils.MockFileUtil{} + mockInfo := utils.NewMockFileInfo() + mockInfo.On("IsDir").Return(false) + m.On("Stat", "/test/cwd/bootstrap").Return(mockInfo, nil) + m.On("IsNotExist", nil).Return(false) + return m + }(), + args: []string{"aws-lambda-rie"}, + cwd: "/test/cwd", + }, + wantCmd: []string{"/test/cwd/bootstrap"}, + wantErr: false, + }, + { + name: "find_bootstrap_in_var_runtime", + args: args{ + fileUtil: func() *utils.MockFileUtil { + m := &utils.MockFileUtil{} + + m.On("Stat", "/test/cwd/bootstrap").Return(nil, os.ErrNotExist) + m.On("IsNotExist", os.ErrNotExist).Return(true) + + mockInfo := utils.NewMockFileInfo() + mockInfo.On("IsDir").Return(false) + m.On("Stat", "/var/runtime/bootstrap").Return(mockInfo, nil) + m.On("IsNotExist", nil).Return(false) + return m + }(), + args: []string{"aws-lambda-rie"}, + cwd: "/test/cwd", + }, + wantCmd: []string{"/var/runtime/bootstrap"}, + wantErr: false, + }, + { + name: "bootstrap_is_directory", + args: args{ + fileUtil: func() *utils.MockFileUtil { + m := &utils.MockFileUtil{} + mockInfo := utils.NewMockFileInfo() + mockInfo.On("IsDir").Return(true) + m.On("Stat", "/test/cwd/bootstrap").Return(mockInfo, nil) + m.On("IsNotExist", nil).Return(false) + + m.On("Stat", "/var/runtime/bootstrap").Return(nil, os.ErrNotExist) + m.On("Stat", "/var/task/bootstrap").Return(nil, os.ErrNotExist) + m.On("Stat", "/opt/bootstrap").Return(nil, os.ErrNotExist) + m.On("IsNotExist", os.ErrNotExist).Return(true) + return m + }(), + args: []string{"aws-lambda-rie"}, + cwd: "/test/cwd", + }, + wantCmd: nil, + wantErr: true, + }, + { + name: "no_bootstrap_found", + args: args{ + fileUtil: func() *utils.MockFileUtil { + m := &utils.MockFileUtil{} + m.On("Stat", "/test/cwd/bootstrap").Return(nil, os.ErrNotExist) + m.On("Stat", "/var/runtime/bootstrap").Return(nil, os.ErrNotExist) + m.On("Stat", "/var/task/bootstrap").Return(nil, os.ErrNotExist) + m.On("Stat", "/opt/bootstrap").Return(nil, os.ErrNotExist) + m.On("IsNotExist", os.ErrNotExist).Return(true) + return m + }(), + args: []string{"aws-lambda-rie"}, + cwd: "/test/cwd", + }, + wantCmd: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCmd, gotErr := getBootstrap(tt.args.fileUtil, tt.args.args, tt.args.cwd) + if tt.wantErr { + assert.Error(t, gotErr) + } else { + assert.NoError(t, gotErr) + } + assert.Equal(t, tt.wantCmd, gotCmd) + }) + } +} + +func Test_getHandler(t *testing.T) { + tests := []struct { + name string + args []string + envVars map[string]string + want string + }{ + { + name: "AWS_LAMBDA_FUNCTION_HANDLER_takes_precedence", + args: []string{"aws-lambda-rie", "/path/to/bootstrap", "handler_from_args"}, + envVars: map[string]string{ + "AWS_LAMBDA_FUNCTION_HANDLER": "handler_from_aws_lambda_function_handler", + "_HANDLER": "handler_from__handler", + }, + want: "handler_from_aws_lambda_function_handler", + }, + { + name: "_HANDLER_takes_precedence", + args: []string{"aws-lambda-rie", "/path/to/bootstrap", "handler_from_args"}, + envVars: map[string]string{ + "_HANDLER": "handler_from__handler", + }, + want: "handler_from__handler", + }, + { + name: "handler_from_args", + args: []string{"aws-lambda-rie", "/path/to/bootstrap", "handler_from_args"}, + envVars: map[string]string{}, + want: "handler_from_args", + }, + + { + name: "no_handler_specified", + args: []string{"aws-lambda-rie", "/path/to/bootstrap"}, + envVars: map[string]string{}, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + for k, v := range tt.envVars { + t.Setenv(k, v) + } + + assert.Equalf(t, tt.want, getHandler(tt.args), "getHandler(%v)", tt.args) + }) + } +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/responder.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/responder.go new file mode 100644 index 0000000..2192ee5 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/responder.go @@ -0,0 +1,85 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "context" + "io" + "log/slog" + "net/http" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/invoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type Responder struct { + invokeReq interop.InvokeRequest + body []byte + rw http.ResponseWriter +} + +func NewResponder(invokeReq interop.InvokeRequest) *Responder { + return &Responder{ + invokeReq: invokeReq, + rw: invokeReq.ResponseWriter(), + } +} + +func (s *Responder) SendRuntimeResponseHeaders(_ interop.InitStaticDataProvider, _, _ string) { + +} + +func (s *Responder) SendRuntimeResponseBody(_ context.Context, runtimeResp invoke.RuntimeResponseRequest, _ time.Duration) invoke.SendResponseBodyResult { + runtimeBodyReader := io.LimitReader(runtimeResp.BodyReader(), s.invokeReq.MaxPayloadSize()+1) + b, err := io.ReadAll(runtimeBodyReader) + if err != nil { + return invoke.SendResponseBodyResult{ + Err: model.NewCustomerError(model.ErrorRuntimeTruncatedResponse, model.WithCause(err)), + } + } + if len(b) > int(s.invokeReq.MaxPayloadSize()) { + errorResponseTooLarge := interop.ErrorResponseTooLarge{ + ResponseSize: len(b), + MaxResponseSize: int(s.invokeReq.MaxPayloadSize()), + } + return invoke.SendResponseBodyResult{ + Err: model.NewCustomerError(model.ErrorFunctionOversizedResponse, model.WithCause(&errorResponseTooLarge)), + } + } + s.body = b + + return invoke.SendResponseBodyResult{} +} + +func (s *Responder) SendRuntimeResponseTrailers(request invoke.RuntimeResponseRequest) { + trailerError := request.TrailerError() + if trailerError != nil { + s.SendErrorTrailers(trailerError, "") + return + } + s.rw.Header().Set(invoke.СontentTypeHeader, request.ContentType()) + s.rw.Header().Set(invoke.RuntimeResponseModeHeader, request.ResponseMode()) + if _, err := s.rw.Write(s.body); err != nil { + slog.Error("could not write invoke response", "err", err) + } +} + +func (s *Responder) SendError(err invoke.ErrorForInvoker, _ interop.InitStaticDataProvider) { + s.SendErrorTrailers(err, "") +} + +func (s *Responder) SendErrorTrailers(err invoke.ErrorForInvoker, _ invoke.InvokeBodyResponseStatus) { + s.rw.Header().Set("Error-Type", err.ErrorType().String()) + + s.rw.WriteHeader(err.ReturnCode()) + if _, err := s.rw.Write([]byte(err.ErrorDetails())); err != nil { + slog.Error("could not write invoke error response", "err", err) + } +} + +func (s *Responder) ErrorPayloadSizeBytes() int { + return 0 +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/responder_test.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/responder_test.go new file mode 100644 index 0000000..1e3e147 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/responder_test.go @@ -0,0 +1,216 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/invoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func TestResponder_CompleteSuccessFlow(t *testing.T) { + + recorder := httptest.NewRecorder() + mockInvokeReq := interop.NewMockInvokeRequest(t) + mockRuntimeReq := invoke.NewMockRuntimeResponseRequest(t) + mockInitData := interop.NewMockInitStaticDataProvider(t) + + mockInvokeReq.On("ResponseWriter").Return(recorder) + mockInvokeReq.On("MaxPayloadSize").Return(int64(1024)) + + responseBody := "test response body" + mockRuntimeReq.On("BodyReader").Return(strings.NewReader(responseBody)) + mockRuntimeReq.On("ContentType").Return("application/json") + mockRuntimeReq.On("ResponseMode").Return("buffered") + mockRuntimeReq.On("TrailerError").Return(nil) + + responder := NewResponder(mockInvokeReq) + responder.SendRuntimeResponseHeaders(mockInitData, "", "") + result := responder.SendRuntimeResponseBody(context.Background(), mockRuntimeReq, 0) + assert.NoError(t, result.Err) + responder.SendRuntimeResponseTrailers(mockRuntimeReq) + + assert.Equal(t, "application/json", recorder.Header().Get(invoke.СontentTypeHeader)) + assert.Equal(t, "buffered", recorder.Header().Get(invoke.RuntimeResponseModeHeader)) + assert.Equal(t, responseBody, recorder.Body.String()) + assert.Equal(t, http.StatusOK, recorder.Code) + + mockInvokeReq.AssertExpectations(t) + mockRuntimeReq.AssertExpectations(t) +} + +func TestResponder_SendErrorFlow(t *testing.T) { + + recorder := httptest.NewRecorder() + mockInvokeReq := interop.NewMockInvokeRequest(t) + mockInitData := interop.NewMockInitStaticDataProvider(t) + + mockInvokeReq.On("ResponseWriter").Return(recorder) + + baseErr := io.ErrUnexpectedEOF + appError := model.NewCustomerError("Function.TestError", model.WithCause(baseErr), model.WithErrorMessage("test error")) + + responder := NewResponder(mockInvokeReq) + responder.SendError(appError, mockInitData) + + assert.Equal(t, "Function.TestError", recorder.Header().Get("Error-Type")) + assert.Equal(t, `{"errorType":"Function.TestError","errorMessage":"test error"}`, recorder.Body.String()) + assert.Equal(t, http.StatusOK, recorder.Code) + + mockInvokeReq.AssertExpectations(t) +} + +func TestResponder_RuntimeInvocationErrorFlow(t *testing.T) { + + recorder := httptest.NewRecorder() + mockInvokeReq := interop.NewMockInvokeRequest(t) + mockInitData := interop.NewMockInitStaticDataProvider(t) + + mockInvokeReq.On("ResponseWriter").Return(recorder) + + responder := NewResponder(mockInvokeReq) + responder.SendRuntimeResponseHeaders(mockInitData, "", "") + responder.SendErrorTrailers(model.NewCustomerError("Runtime.TestError", model.WithErrorMessage("trailer error")), "") + + assert.Equal(t, "Runtime.TestError", recorder.Header().Get("Error-Type")) + assert.Equal(t, `{"errorType":"Runtime.TestError","errorMessage":"trailer error"}`, recorder.Body.String()) + assert.Equal(t, http.StatusOK, recorder.Code) + + mockInvokeReq.AssertExpectations(t) +} + +func TestResponder_ErrorInTheMiddleOfResponse(t *testing.T) { + + recorder := httptest.NewRecorder() + mockInvokeReq := interop.NewMockInvokeRequest(t) + mockRuntimeReq := invoke.NewMockRuntimeResponseRequest(t) + mockInitData := interop.NewMockInitStaticDataProvider(t) + + mockInvokeReq.On("ResponseWriter").Return(recorder) + mockInvokeReq.On("MaxPayloadSize").Return(int64(1024)) + + responseBody := "test response body" + mockRuntimeReq.On("BodyReader").Return(strings.NewReader(responseBody)) + + responder := NewResponder(mockInvokeReq) + responder.SendRuntimeResponseHeaders(mockInitData, "", "") + result := responder.SendRuntimeResponseBody(context.Background(), mockRuntimeReq, 0) + assert.NoError(t, result.Err) + responder.SendErrorTrailers(model.NewCustomerError("Sandbox.TestError", model.WithSeverity(model.ErrorSeverityFatal), model.WithErrorMessage("error after body")), "") + + assert.Equal(t, "Sandbox.TestError", recorder.Header().Get("Error-Type")) + assert.Equal(t, `{"errorType":"Sandbox.TestError","errorMessage":"error after body"}`, recorder.Body.String()) + assert.Equal(t, http.StatusOK, recorder.Code) + + mockInvokeReq.AssertExpectations(t) + mockRuntimeReq.AssertExpectations(t) +} + +func TestResponder_RuntimeResponseTrailerError(t *testing.T) { + + recorder := httptest.NewRecorder() + mockInvokeReq := interop.NewMockInvokeRequest(t) + mockRuntimeReq := invoke.NewMockRuntimeResponseRequest(t) + mockInitData := interop.NewMockInitStaticDataProvider(t) + + mockInvokeReq.On("ResponseWriter").Return(recorder) + mockInvokeReq.On("MaxPayloadSize").Return(int64(1024)) + + errorType := model.ErrorType("Function.TrailerError") + errorBody := []byte(`trailer error`) + + trailerError := invoke.NewMockErrorForInvoker(t) + trailerError.On("ErrorType").Return(errorType) + trailerError.On("ReturnCode").Return(http.StatusOK) + trailerError.On("ErrorDetails").Return("trailer error") + + responseBody := "test response body" + mockRuntimeReq.On("BodyReader").Return(strings.NewReader(responseBody)) + mockRuntimeReq.On("TrailerError").Return(trailerError) + + responder := NewResponder(mockInvokeReq) + responder.SendRuntimeResponseHeaders(mockInitData, "", "") + result := responder.SendRuntimeResponseBody(context.Background(), mockRuntimeReq, 0) + assert.NoError(t, result.Err) + responder.SendRuntimeResponseTrailers(mockRuntimeReq) + + assert.Equal(t, "Function.TrailerError", recorder.Header().Get("Error-Type")) + assert.Equal(t, string(errorBody), recorder.Body.String()) + assert.Equal(t, http.StatusOK, recorder.Code) + + mockInvokeReq.AssertExpectations(t) + mockRuntimeReq.AssertExpectations(t) +} + +func TestResponder_SendRuntimeResponseBody(t *testing.T) { + tests := []struct { + name string + setupMocks func(*interop.MockInvokeRequest, *invoke.MockRuntimeResponseRequest) + expectedError model.ErrorType + expectError bool + }{ + { + name: "OversizedResponse", + setupMocks: func(mockInvokeReq *interop.MockInvokeRequest, mockRuntimeReq *invoke.MockRuntimeResponseRequest) { + mockInvokeReq.On("MaxPayloadSize").Return(int64(10)) + largeResponse := strings.Repeat("x", 20) + mockRuntimeReq.On("BodyReader").Return(strings.NewReader(largeResponse)) + }, + expectedError: model.ErrorFunctionOversizedResponse, + expectError: true, + }, + { + name: "ReadError", + setupMocks: func(mockInvokeReq *interop.MockInvokeRequest, mockRuntimeReq *invoke.MockRuntimeResponseRequest) { + mockInvokeReq.On("MaxPayloadSize").Return(int64(1024)) + errorReader := &errorReader{err: io.ErrUnexpectedEOF} + mockRuntimeReq.On("BodyReader").Return(errorReader) + }, + expectedError: model.ErrorRuntimeTruncatedResponse, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + recorder := httptest.NewRecorder() + mockInvokeReq := interop.NewMockInvokeRequest(t) + mockRuntimeReq := invoke.NewMockRuntimeResponseRequest(t) + + mockInvokeReq.On("ResponseWriter").Return(recorder) + tt.setupMocks(mockInvokeReq, mockRuntimeReq) + + responder := NewResponder(mockInvokeReq) + result := responder.SendRuntimeResponseBody(context.Background(), mockRuntimeReq, 0) + + if tt.expectError { + assert.Error(t, result.Err) + assert.Equal(t, tt.expectedError, result.Err.ErrorType()) + } else { + assert.NoError(t, result.Err) + } + + mockInvokeReq.AssertExpectations(t) + mockRuntimeReq.AssertExpectations(t) + }) + } +} + +type errorReader struct { + err error +} + +func (e *errorReader) Read(p []byte) (n int, err error) { + return 0, e.err +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request.go new file mode 100644 index 0000000..26a1cae --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request.go @@ -0,0 +1,148 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "errors" + "io" + "net/http" + "time" + + "github.com/google/uuid" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/invoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type rieInvokeRequest struct { + request *http.Request + writer http.ResponseWriter + + contentType string + maxPayloadSize int64 + responseBandwidthRate int64 + responseBandwidthBurstSize int64 + invokeID interop.InvokeID + deadline time.Time + traceId string + cognitoIdentityId string + cognitoIdentityPoolId string + clientContext string + responseMode string + + functionVersionID string +} + +func NewRieInvokeRequest(request *http.Request, writer http.ResponseWriter) *rieInvokeRequest { + + contentType := request.Header.Get(invoke.СontentTypeHeader) + if contentType == "" { + contentType = "application/json" + } + + invokeID := request.Header.Get("X-Amzn-RequestId") + if invokeID == "" { + invokeID = uuid.New().String() + } + + req := &rieInvokeRequest{ + request: request, + writer: writer, + invokeID: invokeID, + contentType: contentType, + maxPayloadSize: 6*1024*1024 + 100, + responseBandwidthRate: 2 * 1024 * 1024, + responseBandwidthBurstSize: 6 * 1024 * 1024, + traceId: request.Header.Get(invoke.TraceIdHeader), + cognitoIdentityId: "", + cognitoIdentityPoolId: "", + clientContext: request.Header.Get("X-Amz-Client-Context"), + responseMode: request.Header.Get(invoke.ResponseModeHeader), + } + + return req +} + +func (r *rieInvokeRequest) ContentType() string { + return r.contentType +} + +func (r *rieInvokeRequest) InvokeID() interop.InvokeID { + return r.invokeID +} + +func (r *rieInvokeRequest) Deadline() time.Time { + return r.deadline +} + +func (r *rieInvokeRequest) TraceId() string { + return r.traceId +} + +func (r *rieInvokeRequest) ClientContext() string { + return r.clientContext +} + +func (r *rieInvokeRequest) CognitoId() string { + return r.cognitoIdentityId +} + +func (r *rieInvokeRequest) CognitoPoolId() string { + return r.cognitoIdentityPoolId +} + +func (r *rieInvokeRequest) ResponseBandwidthRate() int64 { + return r.responseBandwidthRate +} + +func (r *rieInvokeRequest) ResponseBandwidthBurstRate() int64 { + return r.responseBandwidthBurstSize +} + +func (r *rieInvokeRequest) MaxPayloadSize() int64 { + return r.maxPayloadSize +} + +func (r *rieInvokeRequest) BodyReader() io.Reader { + return r.request.Body +} + +func (r *rieInvokeRequest) ResponseWriter() http.ResponseWriter { + return r.writer +} + +func (r *rieInvokeRequest) SetResponseHeader(key string, val string) { + r.writer.Header().Set(key, val) +} + +func (r *rieInvokeRequest) AddResponseHeader(key string, val string) { + r.writer.Header().Add(key, val) +} + +func (r *rieInvokeRequest) WriteResponseHeaders(status int) { + r.writer.WriteHeader(status) +} + +func (r *rieInvokeRequest) ResponseMode() string { + return r.responseMode +} + +func (r *rieInvokeRequest) UpdateFromInitData(initData interop.InitStaticDataProvider) model.AppError { + if initData == nil { + return model.NewClientError(errors.New("sandbox is not initialized"), model.ErrorSeverityError, model.ErrorInitIncomplete) + } + + r.deadline = time.Now().Add(time.Duration(initData.FunctionTimeout()) * time.Millisecond) + + if r.functionVersionID != initData.FunctionVersionID() { + return model.NewClientError(nil, model.ErrorSeverityInvalid, model.ErrorInvalidFunctionVersion) + } + + return nil +} + +func (r *rieInvokeRequest) FunctionVersionID() string { + return r.functionVersionID +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request_test.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request_test.go new file mode 100644 index 0000000..f817b38 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request_test.go @@ -0,0 +1,80 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewRieInvokeRequest(t *testing.T) { + tests := []struct { + name string + request func() *http.Request + writer http.ResponseWriter + want *rieInvokeRequest + }{ + { + name: "no_headers_in_request", + request: func() *http.Request { + r, err := http.NewRequest("GET", "http://localhost/", nil) + require.NoError(t, err) + return r + }, + writer: httptest.NewRecorder(), + want: &rieInvokeRequest{ + contentType: "application/json", + maxPayloadSize: 6*1024*1024 + 100, + responseBandwidthRate: 2 * 1024 * 1024, + responseBandwidthBurstSize: 6 * 1024 * 1024, + traceId: "", + cognitoIdentityId: "", + cognitoIdentityPoolId: "", + clientContext: "", + }, + }, + { + name: "all_headers_present_in_request", + request: func() *http.Request { + r, err := http.NewRequest("GET", "http://localhost/", nil) + r.Header.Set("Content-Type", "text/plain") + r.Header.Set("X-Amzn-Trace-Id", "Root=1-5e1b4151-5ac6c58f3375aa3c7c6b73c9") + r.Header.Set("X-Amz-Client-Context", "eyJjdXN0b20iOnsidGVzdCI6InZhbHVlIn19") + r.Header.Set("X-Amzn-RequestId", "test-invoke-id") + require.NoError(t, err) + return r + }, + writer: httptest.NewRecorder(), + want: &rieInvokeRequest{ + invokeID: "test-invoke-id", + contentType: "text/plain", + maxPayloadSize: 6*1024*1024 + 100, + responseBandwidthRate: 2 * 1024 * 1024, + responseBandwidthBurstSize: 6 * 1024 * 1024, + traceId: "Root=1-5e1b4151-5ac6c58f3375aa3c7c6b73c9", + cognitoIdentityId: "", + cognitoIdentityPoolId: "", + clientContext: "eyJjdXN0b20iOnsidGVzdCI6InZhbHVlIn19", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := tt.request() + got := NewRieInvokeRequest(r, tt.writer) + + tt.want.request = r + tt.want.writer = tt.writer + if tt.want.invokeID == "" { + tt.want.invokeID = got.invokeID + } + + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/mock_init_request_message_factory.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/mock_init_request_message_factory.go new file mode 100644 index 0000000..46c294d --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/mock_init_request_message_factory.go @@ -0,0 +1,57 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + mock "github.com/stretchr/testify/mock" + model "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + + utils "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils" +) + +type MockInitRequestMessageFactory struct { + mock.Mock +} + +func (_m *MockInitRequestMessageFactory) Execute(fileUtil utils.FileUtil, args []string) (model.InitRequestMessage, rapidmodel.AppError) { + ret := _m.Called(fileUtil, args) + + if len(ret) == 0 { + panic("no return value specified for Execute") + } + + var r0 model.InitRequestMessage + var r1 rapidmodel.AppError + if rf, ok := ret.Get(0).(func(utils.FileUtil, []string) (model.InitRequestMessage, rapidmodel.AppError)); ok { + return rf(fileUtil, args) + } + if rf, ok := ret.Get(0).(func(utils.FileUtil, []string) model.InitRequestMessage); ok { + r0 = rf(fileUtil, args) + } else { + r0 = ret.Get(0).(model.InitRequestMessage) + } + + if rf, ok := ret.Get(1).(func(utils.FileUtil, []string) rapidmodel.AppError); ok { + r1 = rf(fileUtil, args) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(rapidmodel.AppError) + } + } + + return r0, r1 +} + +func NewMockInitRequestMessageFactory(t interface { + mock.TestingT + Cleanup(func()) +}) *MockInitRequestMessageFactory { + mock := &MockInitRequestMessageFactory{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/mock_raptor_app.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/mock_raptor_app.go new file mode 100644 index 0000000..1132f8d --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/mock_raptor_app.go @@ -0,0 +1,79 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + interop "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + + model "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type mockRaptorApp struct { + mock.Mock +} + +func (_m *mockRaptorApp) Init(ctx context.Context, req *model.InitRequestMessage, metrics interop.InitMetrics) rapidmodel.AppError { + ret := _m.Called(ctx, req, metrics) + + if len(ret) == 0 { + panic("no return value specified for Init") + } + + var r0 rapidmodel.AppError + if rf, ok := ret.Get(0).(func(context.Context, *model.InitRequestMessage, interop.InitMetrics) rapidmodel.AppError); ok { + r0 = rf(ctx, req, metrics) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(rapidmodel.AppError) + } + } + + return r0 +} + +func (_m *mockRaptorApp) Invoke(ctx context.Context, msg interop.InvokeRequest, metrics interop.InvokeMetrics) (rapidmodel.AppError, bool) { + ret := _m.Called(ctx, msg, metrics) + + if len(ret) == 0 { + panic("no return value specified for Invoke") + } + + var r0 rapidmodel.AppError + var r1 bool + if rf, ok := ret.Get(0).(func(context.Context, interop.InvokeRequest, interop.InvokeMetrics) (rapidmodel.AppError, bool)); ok { + return rf(ctx, msg, metrics) + } + if rf, ok := ret.Get(0).(func(context.Context, interop.InvokeRequest, interop.InvokeMetrics) rapidmodel.AppError); ok { + r0 = rf(ctx, msg, metrics) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(rapidmodel.AppError) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, interop.InvokeRequest, interop.InvokeMetrics) bool); ok { + r1 = rf(ctx, msg, metrics) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +func newMockRaptorApp(t interface { + mock.TestingT + Cleanup(func()) +}) *mockRaptorApp { + mock := &mockRaptorApp{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/run.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/run.go new file mode 100644 index 0000000..712204e --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/run.go @@ -0,0 +1,96 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "context" + "fmt" + "log/slog" + "os" + "time" + + rieinvoke "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/invoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/invoke/timeout" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/raptor" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/servicelogs" + supvmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/supervisor/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils" +) + +func Run(supv supvmodel.ProcessSupervisor, args []string, fileUtil utils.FileUtil, sigCh chan os.Signal) (*raptor.Server, *HTTPHandler, *raptor.App, error) { + + opts, args, err := ParseCLIArgs(args) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to parse command line arguments: %w", err) + } + + ConfigureLogging(opts.LogLevel) + + runtimeAPIAddr, err := ParseAddr(opts.RuntimeAddress, "127.0.0.1:9001") + if err != nil { + return nil, nil, nil, fmt.Errorf("invalid runtime API address: %w", err) + } + + rieAddr, err := ParseAddr(opts.RIEAddress, "0.0.0.0:8080") + if err != nil { + return nil, nil, nil, fmt.Errorf("invalid RIE address: %w", err) + } + + telemetryAPIRelay := telemetry.NewRelay() + eventsAPI := telemetry.NewEventsAPI(telemetryAPIRelay) + + responderFactoryFunc := func(_ context.Context, invokeReq interop.InvokeRequest) invoke.InvokeResponseSender { + return rieinvoke.NewResponder(invokeReq) + } + invokeRouter := invoke.NewInvokeRouter(rapid.MaxIdleRuntimesQueueSize, eventsAPI, responderFactoryFunc, timeout.NewRecentCache()) + + deps := rapid.Dependencies{ + EventsAPI: eventsAPI, + LogsEgressAPI: telemetry.NewLogsEgress(telemetryAPIRelay, os.Stdout), + TelemetrySubscriptionAPI: telemetry.NewSubscriptionAPI(telemetryAPIRelay, eventsAPI, eventsAPI), + Supervisor: supv, + RuntimeAPIAddrPort: runtimeAPIAddr, + FileUtils: fileUtil, + InvokeRouter: invokeRouter, + } + + raptorApp, err := raptor.StartApp(deps, "", noOpLogger{}) + if err != nil { + return nil, nil, nil, fmt.Errorf("could not start runtime api server: %w", err) + } + + initMsg, err := GetInitRequestMessage(fileUtil, args) + if err != nil { + return nil, nil, nil, fmt.Errorf("could not build initialization parameters: %w", err) + } + + rieApp := NewHTTPHandler(raptorApp, initMsg) + s, err := raptor.StartServer(raptorApp, rieApp, &raptor.TCPAddress{AddrPort: rieAddr}) + if err != nil { + return nil, nil, nil, fmt.Errorf("could not start RIE server: %w", err) + } + slog.Debug("RIE server started") + + go func() { + <-raptorApp.Done() + s.Shutdown(raptorApp.Err()) + }() + + s.AttachShutdownSignalHandler(sigCh) + + return s, rieApp, raptorApp, nil +} + +type noOpLogger struct{} + +func (n noOpLogger) Log(_ servicelogs.Operation, _ time.Time, _ []servicelogs.Property, _ []servicelogs.Dimension, _ []servicelogs.Metric) { +} + +func (n noOpLogger) SetInitData(_ interop.InitStaticDataProvider) {} + +func (n noOpLogger) Close() error { return nil } diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/events_api.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/events_api.go new file mode 100644 index 0000000..893b371 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/events_api.go @@ -0,0 +1,101 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "context" + "fmt" + "io" + "log/slog" + "os" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" +) + +type EventsAPI struct { + eventRelay relay + stdout io.Writer +} + +func NewEventsAPI(eventRelay relay) *EventsAPI { + return &EventsAPI{ + eventRelay: eventRelay, + stdout: os.Stdout, + } +} + +func (api *EventsAPI) SendInitStart(data interop.InitStartData) error { + api.eventRelay.broadcast(data, internal.CategoryPlatform, internal.TypePlatformInitStart) + return nil +} + +func (api *EventsAPI) SendInitRuntimeDone(data interop.InitRuntimeDoneData) error { + api.eventRelay.broadcast(data, internal.CategoryPlatform, internal.TypePlatformInitRuntimeDone) + return nil +} + +func (api *EventsAPI) SendInitReport(data interop.InitReportData) error { + api.eventRelay.broadcast(data, internal.CategoryPlatform, internal.TypePlatformInitReport) + return nil +} + +func (api *EventsAPI) SendExtensionInit(data interop.ExtensionInitData) error { + api.eventRelay.broadcast(data, internal.CategoryPlatform, internal.TypePlatformExtension) + return nil +} + +func (api *EventsAPI) SendImageError(errLog interop.ImageErrorLogData) { + slog.Error(telemetry.FormatImageError(errLog)) +} + +func (api *EventsAPI) SendInvokeStart(data interop.InvokeStartData) error { + _, _ = fmt.Fprintf(api.stdout, "START RequestId: %s\tVersion: %s\n", data.InvokeID, data.Version) + + api.eventRelay.broadcast(data, internal.CategoryPlatform, internal.TypePlatformStart) + return nil +} + +func (api *EventsAPI) SendInternalXRayErrorCause(data interop.InternalXRayErrorCauseData) error { + return nil +} + +func (api *EventsAPI) SendReport(data interop.ReportData) error { + _, _ = fmt.Fprintf(api.stdout, "END RequestId: %s\n", data.InvokeID) + _, _ = fmt.Fprintf(api.stdout, "REPORT RequestId: %s\tDuration: %.2f ms\n", data.InvokeID, float64(data.Metrics.DurationMs)) + + api.eventRelay.broadcast(data, internal.CategoryPlatform, internal.TypePlatformReport) + return nil +} + +func (api *EventsAPI) SendPlatformLogsDropped(droppedBytes, droppedRecords int, reason string) error { + record := map[string]any{ + "droppedBytes": droppedBytes, + "droppedRecords": droppedRecords, + "reason": reason, + } + + api.eventRelay.broadcast(record, internal.CategoryPlatform, internal.TypePlatformLogsDropped) + return nil +} + +func (api *EventsAPI) sendTelemetrySubscription(agentName, state string, types []internal.EventCategory) error { + record := map[string]any{ + "name": agentName, + "state": state, + "types": types, + } + + api.eventRelay.broadcast(record, internal.CategoryPlatform, internal.TypePlatformTelemetrySubscription) + return nil +} + +func (api *EventsAPI) Flush() { + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + api.eventRelay.flush(ctx) +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/events_api_test.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/events_api_test.go new file mode 100644 index 0000000..bfd5233 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/events_api_test.go @@ -0,0 +1,157 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func TestEventsAPI_SendInitStart(t *testing.T) { + m := newMockRelay(t) + defer m.AssertExpectations(t) + api := &EventsAPI{ + eventRelay: m, + stdout: &bytes.Buffer{}, + } + + data := interop.InitStartData{} + m.On("broadcast", data, internal.CategoryPlatform, internal.TypePlatformInitStart).Once() + assert.NoError(t, api.SendInitStart(data)) +} + +func TestEventsAPI_SendInitRuntimeDone(t *testing.T) { + m := newMockRelay(t) + defer m.AssertExpectations(t) + api := &EventsAPI{ + eventRelay: m, + stdout: &bytes.Buffer{}, + } + + data := interop.InitRuntimeDoneData{} + m.On("broadcast", data, internal.CategoryPlatform, internal.TypePlatformInitRuntimeDone).Once() + assert.NoError(t, api.SendInitRuntimeDone(data)) +} + +func TestEventsAPI_SendInitReport(t *testing.T) { + m := newMockRelay(t) + defer m.AssertExpectations(t) + api := &EventsAPI{ + eventRelay: m, + stdout: &bytes.Buffer{}, + } + + data := interop.InitReportData{} + m.On("broadcast", data, internal.CategoryPlatform, internal.TypePlatformInitReport).Once() + assert.NoError(t, api.SendInitReport(data)) +} + +func TestEventsAPI_SendExtensionInit(t *testing.T) { + m := newMockRelay(t) + defer m.AssertExpectations(t) + api := &EventsAPI{ + eventRelay: m, + stdout: &bytes.Buffer{}, + } + + data := interop.ExtensionInitData{} + m.On("broadcast", data, internal.CategoryPlatform, internal.TypePlatformExtension).Once() + assert.NoError(t, api.SendExtensionInit(data)) +} + +func TestEventsAPI_SendInvokeStart(t *testing.T) { + m := newMockRelay(t) + defer m.AssertExpectations(t) + stdout := &bytes.Buffer{} + api := &EventsAPI{ + eventRelay: m, + stdout: stdout, + } + + data := interop.InvokeStartData{ + InvokeID: "test-invoke-id-123", + Version: "$LATEST", + FunctionARN: "arn:aws:lambda:us-east-1:123456789012:function:test-function", + } + m.On("broadcast", data, internal.CategoryPlatform, internal.TypePlatformStart).Once() + assert.NoError(t, api.SendInvokeStart(data)) + + expectedOutput := "START RequestId: test-invoke-id-123\tVersion: $LATEST\n" + assert.Equal(t, expectedOutput, stdout.String()) +} + +func TestEventsAPI_SendReport(t *testing.T) { + m := newMockRelay(t) + defer m.AssertExpectations(t) + stdout := &bytes.Buffer{} + api := &EventsAPI{ + eventRelay: m, + stdout: stdout, + } + + errorType := rapidmodel.ErrorRuntimeExit + data := interop.ReportData{ + InvokeID: "test-invoke-id-123", + Status: "success", + Metrics: interop.ReportMetrics{ + DurationMs: 125.456, + }, + ErrorType: &errorType, + } + m.On("broadcast", data, internal.CategoryPlatform, internal.TypePlatformReport).Once() + assert.NoError(t, api.SendReport(data)) + + expectedOutput := "END RequestId: test-invoke-id-123\nREPORT RequestId: test-invoke-id-123\tDuration: 125.46 ms\n" + assert.Equal(t, expectedOutput, stdout.String()) +} + +func TestEventsAPI_SendPlatformLogsDropped(t *testing.T) { + m := newMockRelay(t) + defer m.AssertExpectations(t) + api := &EventsAPI{ + eventRelay: m, + stdout: &bytes.Buffer{}, + } + + droppedBytes := 1024 + droppedRecords := 5 + reason := "buffer overflow" + + expectedRecord := map[string]any{ + "droppedBytes": droppedBytes, + "droppedRecords": droppedRecords, + "reason": reason, + } + + m.On("broadcast", expectedRecord, internal.CategoryPlatform, internal.TypePlatformLogsDropped).Once() + assert.NoError(t, api.SendPlatformLogsDropped(droppedBytes, droppedRecords, reason)) +} + +func TestEventsAPI_sendTelemetrySubscription(t *testing.T) { + m := newMockRelay(t) + defer m.AssertExpectations(t) + api := &EventsAPI{ + eventRelay: m, + stdout: &bytes.Buffer{}, + } + + agentName := "test-agent" + state := "subscribed" + types := []internal.EventCategory{internal.CategoryPlatform, internal.CategoryFunction} + + expectedRecord := map[string]any{ + "name": agentName, + "state": state, + "types": types, + } + + m.On("broadcast", expectedRecord, internal.CategoryPlatform, internal.TypePlatformTelemetrySubscription).Once() + assert.NoError(t, api.sendTelemetrySubscription(agentName, state, types)) +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/batch.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/batch.go new file mode 100644 index 0000000..06f4261 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/batch.go @@ -0,0 +1,41 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "encoding/json" + "time" +) + +type batch struct { + events []json.RawMessage + sizeBytes int + flushAt <-chan time.Time + bufCfg BufferingConfig + doneCh chan struct{} +} + +func newBatch(bufCfg BufferingConfig) *batch { + return &batch{ + bufCfg: bufCfg, + flushAt: time.After(time.Duration(bufCfg.Timeout)), + doneCh: make(chan struct{}), + } +} + +func (b *batch) addEvent(event json.RawMessage) (full bool) { + b.events = append(b.events, event) + b.sizeBytes += len(event) + return b.isFull() +} + +func (b *batch) isFull() bool { + if len(b.events) >= b.bufCfg.MaxItems { + return true + } + if b.sizeBytes >= b.bufCfg.MaxBytes { + return true + } + return false +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/batch_test.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/batch_test.go new file mode 100644 index 0000000..d82b540 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/batch_test.go @@ -0,0 +1,71 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" +) + +func TestBatch(t *testing.T) { + tests := []struct { + name string + bufCfg BufferingConfig + events []string + expectedFull []bool + description string + }{ + { + name: "sequential_event_addition_until_full", + bufCfg: BufferingConfig{ + MaxItems: 3, + MaxBytes: 1000, + Timeout: model.DurationMS(1 * time.Second), + }, + events: []string{ + `{"event":1}`, + `{"event":2}`, + `{"event":3}`, + `{"event":4}`, + }, + expectedFull: []bool{false, false, true, true}, + description: "Sequential addition should properly track fullness", + }, + { + name: "mixed_size_events_reach_byte_limit", + bufCfg: BufferingConfig{ + MaxItems: 10, + MaxBytes: 50, + Timeout: model.DurationMS(1 * time.Second), + }, + events: []string{ + `{"a":1}`, + `{"data":"test"}`, + `{"largeData":"this is a long string"}`, + }, + expectedFull: []bool{false, false, true}, + description: "Mixed size events should trigger byte limit", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + batch := newBatch(tt.bufCfg) + var totalSize int + for i, eventJSON := range tt.events { + totalSize += len(eventJSON) + expectedFull := tt.expectedFull[i] + assert.Equal(t, expectedFull, batch.addEvent(json.RawMessage(eventJSON)), "Event %d: addEvent return value should match expected fullness", i+1) + assert.Equal(t, expectedFull, batch.isFull(), "Event %d: addEvent return value should match isFull()", i+1) + } + assert.Equal(t, len(tt.events), len(batch.events), "Batch should contain events") + assert.Equal(t, totalSize, batch.sizeBytes, "Batch sizeBytes should match sum of event lengths") + }) + } +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/client.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/client.go new file mode 100644 index 0000000..76f502e --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/client.go @@ -0,0 +1,126 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "bufio" + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/url" +) + +type Client interface { + send(ctx context.Context, b batch) error +} + +const sandboxLocalDomain = "sandbox.localdomain" + +func NewClient(dst SubscriptionDestination) (Client, error) { + switch dst.Protocol { + case protocolTCP: + return newTCPClient(dst.Port) + case protocolHTTP: + return newHTTPClient(dst.URI) + default: + return nil, fmt.Errorf("unknown protocol: %s. Only tcp and http are supported", dst.Protocol) + } +} + +type tcpClient struct { + conn *bufio.Writer +} + +func newTCPClient(port uint16) (*tcpClient, error) { + address := fmt.Sprintf("127.0.0.1:%d", port) + conn, err := net.Dial("tcp", address) + if err != nil { + return nil, fmt.Errorf("could not TCP dial provided address %s: %w", address, err) + } + return &tcpClient{conn: bufio.NewWriter(conn)}, nil +} + +func (c *tcpClient) send(ctx context.Context, b batch) error { + for _, ev := range b.events { + select { + case <-ctx.Done(): + return fmt.Errorf("sending event to TCP subscriber was interrupted: %w", ctx.Err()) + default: + } + if _, err := c.conn.Write(ev); err != nil { + return fmt.Errorf("could not write event: %w", err) + } + if _, err := c.conn.Write([]byte("\n")); err != nil { + return fmt.Errorf("could not write event: %w", err) + } + } + if err := c.conn.Flush(); err != nil { + return fmt.Errorf("could not write events: %w", err) + } + return nil +} + +type httpClient struct { + addr string +} + +func newHTTPClient(uri string) (*httpClient, error) { + u, err := url.Parse(uri) + if err != nil { + return nil, fmt.Errorf("could not parse destination.URI: %w", err) + } + if u.Hostname() != sandboxLocalDomain && u.Hostname() != "sandbox" { + return nil, fmt.Errorf("destination.URI host must be %s", sandboxLocalDomain) + } + return &httpClient{addr: "http://127.0.0.1:" + u.Port()}, nil +} + +func (c *httpClient) send(ctx context.Context, batch batch) error { + b, err := json.Marshal(batch.events) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.addr, bytes.NewReader(b)) + if err != nil { + return fmt.Errorf("could not create HTTP request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Sequence-Id", sequenceId(b)) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer func() { + if err := resp.Body.Close(); err != nil { + slog.Error("could not close response body", "err", err) + } + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("could not read response body: %w", err) + } + + if resp.StatusCode >= 400 { + return fmt.Errorf("http request failed with status %s: %s", resp.Status, string(body)) + } + + slog.Debug("telemetry HTTP request completed", "extension_response", string(body)) + return nil +} + +func sequenceId(data []byte) string { + hash := sha256.Sum256(data) + return base64.StdEncoding.EncodeToString(hash[:]) +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/client_test.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/client_test.go new file mode 100644 index 0000000..9286112 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/client_test.go @@ -0,0 +1,263 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "context" + "encoding/json" + "net" + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewClient(t *testing.T) { + tests := []struct { + name string + destination SubscriptionDestination + setupServer func() net.Listener + expectError bool + errorMsg string + }{ + + { + name: "creates_tcp_client_successfully", + destination: SubscriptionDestination{ + Protocol: protocolTCP, + Port: 0, + }, + setupServer: func() net.Listener { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + return listener + }, + expectError: false, + }, + { + name: "tcp_client_fails_no_server", + destination: SubscriptionDestination{ + Protocol: protocolTCP, + Port: 65432, + }, + expectError: true, + errorMsg: "could not TCP dial", + }, + { + name: "tcp_client_with_zero_port", + destination: SubscriptionDestination{ + Protocol: protocolTCP, + Port: 0, + }, + expectError: true, + errorMsg: "could not TCP dial", + }, + + { + name: "creates_http_client_valid_sandbox_domain", + destination: SubscriptionDestination{ + Protocol: protocolHTTP, + URI: "http://sandbox.localdomain:8080/telemetry", + }, + expectError: false, + }, + + { + name: "http_client_with_https", + destination: SubscriptionDestination{ + Protocol: protocolHTTP, + URI: "https://sandbox.localdomain:9090/events", + }, + expectError: false, + }, + { + name: "http_client_missing_port", + destination: SubscriptionDestination{ + Protocol: protocolHTTP, + URI: "http://sandbox.localdomain/path", + }, + expectError: false, + }, + { + name: "http_client_with_path_and_query", + destination: SubscriptionDestination{ + Protocol: protocolHTTP, + URI: "http://sandbox.localdomain:8080/telemetry?param=value", + }, + expectError: false, + }, + { + name: "http_client_invalid_uri_format", + destination: SubscriptionDestination{ + Protocol: protocolHTTP, + URI: "not-a-valid-uri", + }, + expectError: true, + errorMsg: "destination.URI host must be sandbox.localdomain", + }, + { + name: "http_client_invalid_hostname", + destination: SubscriptionDestination{ + Protocol: protocolHTTP, + URI: "http://invalid.domain:8080/telemetry", + }, + expectError: true, + errorMsg: "destination.URI host must be sandbox.localdomain", + }, + + { + name: "invalid_protocol", + destination: SubscriptionDestination{ + Protocol: "INVALID", + URI: "http://localhost:8080", + }, + expectError: true, + errorMsg: "unknown protocol: INVALID", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setupServer != nil { + listener := tt.setupServer() + defer func() { require.NoError(t, listener.Close()) }() + + if tt.destination.Protocol == protocolTCP && tt.destination.Port == 0 { + _, portStr, err := net.SplitHostPort(listener.Addr().String()) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + tt.destination.Port = uint16(port) + } + } + + _, err := NewClient(tt.destination) + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestHTTPClient_send(t *testing.T) { + tests := []struct { + name string + setupServer func() *httptest.Server + batch batch + setupContext func() context.Context + expectError bool + errorMsg string + validateReq func(t *testing.T, body []byte) + }{ + { + name: "sends_batch_successfully", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("OK")) + assert.NoError(t, err) + })) + }, + batch: batch{ + events: []json.RawMessage{ + json.RawMessage(`{"type":"platform.start","timestamp":"2023-01-01T00:00:00Z"}`), + json.RawMessage(`{"type":"platform.report","timestamp":"2023-01-01T00:01:00Z"}`), + }, + }, + setupContext: context.Background, + expectError: false, + validateReq: func(t *testing.T, body []byte) { + var events []json.RawMessage + err := json.Unmarshal(body, &events) + assert.NoError(t, err) + assert.Len(t, events, 2) + }, + }, + { + name: "handles_http_error_response", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, err := w.Write([]byte("Bad Request")) + assert.NoError(t, err) + })) + }, + batch: batch{ + events: []json.RawMessage{ + json.RawMessage(`{"type":"test"}`), + }, + }, + setupContext: context.Background, + expectError: true, + errorMsg: "http request failed with status 400 Bad Request", + }, + { + name: "handles_context_cancellation", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + }, + batch: batch{ + events: []json.RawMessage{ + json.RawMessage(`{"type":"test"}`), + }, + }, + setupContext: func() context.Context { + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + return ctx + }, + expectError: true, + errorMsg: "context canceled", + }, + { + name: "sends_empty_batch", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("OK")) + assert.NoError(t, err) + })) + }, + batch: batch{events: []json.RawMessage{}}, + setupContext: context.Background, + expectError: false, + validateReq: func(t *testing.T, body []byte) { + var events []json.RawMessage + err := json.Unmarshal(body, &events) + assert.NoError(t, err) + assert.Len(t, events, 0) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := tt.setupServer() + defer server.Close() + + client := &httpClient{addr: server.URL} + err := client.send(tt.setupContext(), tt.batch) + + if tt.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/mock_client.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/mock_client.go new file mode 100644 index 0000000..1eab3ce --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/mock_client.go @@ -0,0 +1,43 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +type MockClient struct { + mock.Mock +} + +func (_m *MockClient) send(ctx context.Context, b batch) error { + ret := _m.Called(ctx, b) + + if len(ret) == 0 { + panic("no return value specified for send") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, batch) error); ok { + r0 = rf(ctx, b) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func NewMockClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockClient { + mock := &MockClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/mock_logs_dropped_event_api.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/mock_logs_dropped_event_api.go new file mode 100644 index 0000000..2973889 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/mock_logs_dropped_event_api.go @@ -0,0 +1,39 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import mock "github.com/stretchr/testify/mock" + +type MockLogsDroppedEventAPI struct { + mock.Mock +} + +func (_m *MockLogsDroppedEventAPI) SendPlatformLogsDropped(droppedBytes int, droppedRecords int, reason string) error { + ret := _m.Called(droppedBytes, droppedRecords, reason) + + if len(ret) == 0 { + panic("no return value specified for SendPlatformLogsDropped") + } + + var r0 error + if rf, ok := ret.Get(0).(func(int, int, string) error); ok { + r0 = rf(droppedBytes, droppedRecords, reason) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func NewMockLogsDroppedEventAPI(t interface { + mock.TestingT + Cleanup(func()) +}) *MockLogsDroppedEventAPI { + mock := &MockLogsDroppedEventAPI{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/subscriber.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/subscriber.go new file mode 100644 index 0000000..aeca875 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/subscriber.go @@ -0,0 +1,174 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "sync" + "time" +) + +type Subscriber struct { + agentName string + categories map[EventCategory]struct{} + client Client + curBatch *batch + curBatchMu sync.Mutex + eventCh chan json.RawMessage + batchSenderCh chan batch + bufCfg BufferingConfig + logsDroppedEventAPI LogsDroppedEventAPI + flushCh chan struct{} +} + +func NewSubscriber( + agentName string, + categories map[EventCategory]struct{}, + bufCfg BufferingConfig, + client Client, + logsDroppedEventAPI LogsDroppedEventAPI, +) *Subscriber { + s := &Subscriber{ + agentName: agentName, + categories: categories, + client: client, + eventCh: make(chan json.RawMessage), + batchSenderCh: make(chan batch), + bufCfg: bufCfg, + logsDroppedEventAPI: logsDroppedEventAPI, + flushCh: make(chan struct{}), + } + + go s.batchSenderLoop() + go s.eventConsumerLoop() + return s +} + +func (s *Subscriber) AgentName() string { + return s.agentName +} + +func (s *Subscriber) SendAsync(event json.RawMessage, category EventCategory) { + if _, ok := s.categories[category]; !ok { + return + } + s.eventCh <- event +} + +func (s *Subscriber) Flush(ctx context.Context) { + s.curBatchMu.Lock() + b := s.curBatch + s.curBatchMu.Unlock() + if b == nil { + + return + } + + s.flushCh <- struct{}{} + select { + case <-b.doneCh: + + case <-ctx.Done(): + slog.Warn("could not flush telemetry api subscriber", "agent_name", s.agentName) + return + } +} + +func (s *Subscriber) eventConsumerLoop() { + var curBatchTimer <-chan time.Time + + for { + select { + + case <-s.flushCh: + outBatch := s.takeCurrentBatch() + if outBatch == nil { + + continue + } + s.sendCurrentBatchAsync(*outBatch) + curBatchTimer = nil + case <-curBatchTimer: + slog.Debug("sending batch after batch timer expired", "agent_name", s.agentName) + outBatch := s.takeCurrentBatch() + + s.sendCurrentBatchAsync(*outBatch) + curBatchTimer = nil + + case event := <-s.eventCh: + slog.Debug("processing event in eventConsumerLoop", "agent_name", s.agentName) + s.curBatchMu.Lock() + + if s.curBatch == nil { + s.curBatch = newBatch(s.bufCfg) + curBatchTimer = s.curBatch.flushAt + } + + if isFull := s.curBatch.addEvent(event); isFull { + + outBatch := s.curBatch + s.curBatch = nil + + s.sendCurrentBatchAsync(*outBatch) + curBatchTimer = nil + + } + s.curBatchMu.Unlock() + } + } +} + +func (s *Subscriber) takeCurrentBatch() *batch { + s.curBatchMu.Lock() + defer s.curBatchMu.Unlock() + b := s.curBatch + s.curBatch = nil + return b +} + +func (s *Subscriber) sendCurrentBatchAsync(batch batch) { + select { + case s.batchSenderCh <- batch: + slog.Debug("sent batch to batchSenderLoop", "agent_name", s.agentName) + default: + slog.Warn("could not send batch to telemetry Subscriber as previous batch hasn't processed yet", + "Subscriber", s.agentName) + + close(batch.doneCh) + if err := s.logsDroppedEventAPI.SendPlatformLogsDropped( + batch.sizeBytes, + len(batch.events), + "Some logs were dropped because the downstream consumer is slower than the logs production rate", + ); err != nil { + slog.Error("could not send platform.logsDropped event", "err", err) + } + } +} + +func (s *Subscriber) batchSenderLoop() { + for outBatch := range s.batchSenderCh { + slog.Debug("sending batch to client", "agent_name", s.agentName) + err := s.client.send(context.Background(), outBatch) + close(outBatch.doneCh) + if err != nil { + slog.Warn("could not send batch to telemetry api Subscriber", + "Subscriber", s.agentName, + "err", err) + if err := s.logsDroppedEventAPI.SendPlatformLogsDropped( + outBatch.sizeBytes, + len(outBatch.events), + fmt.Sprintf("could not send events: %s", err), + ); err != nil { + slog.Error("could not send platform.logsDropped event", "err", err) + } + } + } +} + +type LogsDroppedEventAPI interface { + SendPlatformLogsDropped(droppedBytes, droppedRecords int, reason string) error +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/subscriber_test.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/subscriber_test.go new file mode 100644 index 0000000..1950070 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/subscriber_test.go @@ -0,0 +1,61 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "math" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestSubscriber(t *testing.T) { + t.Parallel() + + client := NewMockClient(t) + logsDroppedEventAPI := NewMockLogsDroppedEventAPI(t) + + agentName := fmt.Sprintf("test-name-%d", rand.Uint32()) + sub := NewSubscriber(agentName, map[EventCategory]struct{}{CategoryPlatform: {}}, BufferingConfig{MaxItems: 2, MaxBytes: math.MaxInt, Timeout: math.MaxInt}, client, logsDroppedEventAPI) + time.Sleep(100 * time.Millisecond) + assert.Equal(t, agentName, sub.AgentName()) + + sub.Flush(context.Background()) + client.AssertExpectations(t) + + event := json.RawMessage("data") + sub.SendAsync(event, CategoryFunction) + sub.SendAsync(event, CategoryExtension) + sub.Flush(context.Background()) + client.AssertExpectations(t) + + sub.SendAsync(event, CategoryPlatform) + client.AssertExpectations(t) + + client.On("send", mock.Anything, mock.Anything).Return(nil) + sub.SendAsync(event, CategoryPlatform) + + require.Eventually(t, func() bool { + return client.AssertNumberOfCalls(t, "send", 1) + }, time.Second, 10*time.Millisecond) + + sub.SendAsync(event, CategoryPlatform) + assert.Eventually( + t, + func() bool { + + sub.Flush(context.Background()) + return client.AssertNumberOfCalls(t, "send", 2) + }, + time.Second, + 10*time.Millisecond, + ) +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/types.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/types.go new file mode 100644 index 0000000..f727d1e --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal/types.go @@ -0,0 +1,49 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + +type Protocol = string + +const ( + protocolTCP Protocol = "TCP" + protocolHTTP Protocol = "HTTP" +) + +type SubscriptionDestination struct { + Protocol Protocol `json:"protocol"` + URI string `json:"URI,omitempty"` + Port uint16 `json:"port,omitempty"` +} + +type BufferingConfig struct { + MaxItems int `json:"maxItems"` + MaxBytes int `json:"maxBytes"` + Timeout model.DurationMS `json:"timeoutMs"` +} + +type EventCategory = string + +const ( + CategoryPlatform EventCategory = "platform" + CategoryFunction EventCategory = "function" + CategoryExtension EventCategory = "extension" +) + +type EventType = string + +const ( + TypePlatformInitStart EventType = "platform.initStart" + TypePlatformInitRuntimeDone EventType = "platform.initRuntimeDone" + TypePlatformExtension EventType = "platform.extension" + TypePlatformInitReport EventType = "platform.initReport" + TypePlatformStart EventType = "platform.start" + TypePlatformRuntimeDone EventType = "platform.runtimeDone" + TypePlatformReport EventType = "platform.report" + TypePlatformTelemetrySubscription EventType = "platform.telemetrySubscription" + TypePlatformLogsDropped EventType = "platform.logsDropped" + TypeFunction EventType = CategoryFunction + TypeExtension EventType = CategoryExtension +) diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/logs_egress.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/logs_egress.go new file mode 100644 index 0000000..b9523d8 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/logs_egress.go @@ -0,0 +1,63 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "bufio" + "context" + "fmt" + "io" + "log/slog" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal" +) + +type LogsEgress struct { + eventRelay relay + w io.Writer +} + +func NewLogsEgress(eventRelay relay, w io.Writer) *LogsEgress { + return &LogsEgress{ + eventRelay: eventRelay, + w: w, + } +} + +func (e *LogsEgress) startWriter(category internal.EventCategory, typ internal.EventType) io.Writer { + pipeReader, pipeWriter := io.Pipe() + scanner := bufio.NewScanner(pipeReader) + + go func() { + for scanner.Scan() { + line := scanner.Text() + _, _ = fmt.Fprintln(e.w, line) + e.eventRelay.broadcast(line, category, typ) + } + if err := scanner.Err(); err != nil { + slog.Error("scanner failed", "err", err) + } else { + slog.Debug("log scanner reached EOF", "category", category, "type", typ) + } + }() + + return pipeWriter +} + +func (e *LogsEgress) GetExtensionSockets() (io.Writer, io.Writer, error) { + stdout := e.startWriter(internal.CategoryExtension, internal.TypeExtension) + stderr := e.startWriter(internal.CategoryExtension, internal.TypeExtension) + return stdout, stderr, nil +} + +func (e *LogsEgress) GetRuntimeSockets() (io.Writer, io.Writer, error) { + stdout := e.startWriter(internal.CategoryFunction, internal.TypeFunction) + stderr := e.startWriter(internal.CategoryFunction, internal.TypeFunction) + return stdout, stderr, nil +} + +type relay interface { + broadcast(record any, category internal.EventCategory, typ internal.EventType) + flush(ctx context.Context) +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/logs_egress_test.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/logs_egress_test.go new file mode 100644 index 0000000..bc78ff4 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/logs_egress_test.go @@ -0,0 +1,69 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal" +) + +func TestLogsEgress(t *testing.T) { + tests := []struct { + name string + socketType string + expectedCategory internal.EventCategory + }{ + { + name: "extension_sockets_use_extension_category", + socketType: "extension", + expectedCategory: internal.CategoryExtension, + }, + { + name: "runtime_sockets_use_function_category", + socketType: "runtime", + expectedCategory: internal.CategoryFunction, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + relay := newMockRelay(t) + defer relay.AssertExpectations(t) + + egress := NewLogsEgress(relay, io.Discard) + + var stdout, stderr io.Writer + var err error + + switch tt.socketType { + case "extension": + stdout, stderr, err = egress.GetExtensionSockets() + case "runtime": + stdout, stderr, err = egress.GetRuntimeSockets() + } + assert.NoError(t, err) + require.NotNil(t, stdout) + require.NotNil(t, stderr) + + line := []byte("test\n") + relay.On("broadcast", "test", tt.expectedCategory, tt.expectedCategory).Twice() + n, err := stdout.Write(line) + assert.NoError(t, err) + assert.Len(t, line, n) + n, err = stderr.Write(line) + assert.NoError(t, err) + assert.Len(t, line, n) + + assert.Eventually(t, func() bool { + return relay.AssertNumberOfCalls(t, "broadcast", 2) + }, 1*time.Second, 10*time.Millisecond) + }) + } +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_relay.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_relay.go new file mode 100644 index 0000000..eabb894 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_relay.go @@ -0,0 +1,34 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +type mockRelay struct { + mock.Mock +} + +func (_m *mockRelay) broadcast(record interface{}, category string, typ string) { + _m.Called(record, category, typ) +} + +func (_m *mockRelay) flush(ctx context.Context) { + _m.Called(ctx) +} + +func newMockRelay(t interface { + mock.TestingT + Cleanup(func()) +}) *mockRelay { + mock := &mockRelay{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_sub.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_sub.go new file mode 100644 index 0000000..91e9124 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_sub.go @@ -0,0 +1,52 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + context "context" + json "encoding/json" + + mock "github.com/stretchr/testify/mock" +) + +type mockSub struct { + mock.Mock +} + +func (_m *mockSub) AgentName() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for AgentName") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *mockSub) Flush(ctx context.Context) { + _m.Called(ctx) +} + +func (_m *mockSub) SendAsync(event json.RawMessage, cat string) { + _m.Called(event, cat) +} + +func newMockSub(t interface { + mock.TestingT + Cleanup(func()) +}) *mockSub { + mock := &mockSub{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_subscription_store.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_subscription_store.go new file mode 100644 index 0000000..61448d3 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_subscription_store.go @@ -0,0 +1,43 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import mock "github.com/stretchr/testify/mock" + +type mockSubscriptionStore struct { + mock.Mock +} + +func (_m *mockSubscriptionStore) addSubscriber(subscriber sub) error { + ret := _m.Called(subscriber) + + if len(ret) == 0 { + panic("no return value specified for addSubscriber") + } + + var r0 error + if rf, ok := ret.Get(0).(func(sub) error); ok { + r0 = rf(subscriber) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *mockSubscriptionStore) disableAddSubscriber() { + _m.Called() +} + +func newMockSubscriptionStore(t interface { + mock.TestingT + Cleanup(func()) +}) *mockSubscriptionStore { + mock := &mockSubscriptionStore{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_telemetry_subscription_event_api.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_telemetry_subscription_event_api.go new file mode 100644 index 0000000..a54fd90 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/mock_telemetry_subscription_event_api.go @@ -0,0 +1,39 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import mock "github.com/stretchr/testify/mock" + +type mockTelemetrySubscriptionEventAPI struct { + mock.Mock +} + +func (_m *mockTelemetrySubscriptionEventAPI) sendTelemetrySubscription(agentName string, state string, categories []string) error { + ret := _m.Called(agentName, state, categories) + + if len(ret) == 0 { + panic("no return value specified for sendTelemetrySubscription") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, string, []string) error); ok { + r0 = rf(agentName, state, categories) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func newMockTelemetrySubscriptionEventAPI(t interface { + mock.TestingT + Cleanup(func()) +}) *mockTelemetrySubscriptionEventAPI { + mock := &mockTelemetrySubscriptionEventAPI{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/relay.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/relay.go new file mode 100644 index 0000000..b324a0f --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/relay.go @@ -0,0 +1,113 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "context" + "encoding/json" + "errors" + "sync" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils/invariant" +) + +type Relay struct { + mu sync.Mutex + subs map[string]sub + + initEventsBuffer []bufferedEvent + disabled bool +} + +type bufferedEvent struct { + event json.RawMessage + cat internal.EventCategory +} + +func NewRelay() *Relay { + return &Relay{ + subs: make(map[string]sub), + } +} + +var errSubscriberAlreadyExist = errors.New("subscriber already exists") + +func (r *Relay) addSubscriber(subscriber sub) error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.disabled { + return telemetry.ErrTelemetryServiceOff + } + + if _, ok := r.subs[subscriber.AgentName()]; ok { + return errSubscriberAlreadyExist + } + r.subs[subscriber.AgentName()] = subscriber + + for _, be := range r.initEventsBuffer { + subscriber.SendAsync(be.event, be.cat) + } + + return nil +} + +func (r *Relay) disableAddSubscriber() { + r.mu.Lock() + defer r.mu.Unlock() + r.disabled = true + r.initEventsBuffer = nil +} + +func (r *Relay) broadcast(record any, category internal.EventCategory, typ internal.EventType) { + recordJSON, err := json.Marshal(record) + invariant.Checkf(err == nil, "could not marshal record to json: %s", err) + + event := telemetry.Event{ + Time: time.Now().UTC().Format(telemetry.TimeFormat), + Type: typ, + Record: recordJSON, + } + + b, err := json.Marshal(event) + invariant.Checkf(err == nil, "could not marshal json telemetry event: %s", err) + + r.mu.Lock() + defer r.mu.Unlock() + + for _, sub := range r.subs { + sub.SendAsync(b, category) + } + + if !r.disabled { + r.initEventsBuffer = append(r.initEventsBuffer, bufferedEvent{ + event: b, + cat: category, + }) + } +} + +func (r *Relay) flush(ctx context.Context) { + r.mu.Lock() + defer r.mu.Unlock() + + var wg sync.WaitGroup + wg.Add(len(r.subs)) + for _, s := range r.subs { + go func(s sub) { + s.Flush(ctx) + wg.Done() + }(s) + } + wg.Wait() +} + +type sub interface { + AgentName() string + SendAsync(event json.RawMessage, cat internal.EventCategory) + Flush(ctx context.Context) +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/relay_test.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/relay_test.go new file mode 100644 index 0000000..bb124cc --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/relay_test.go @@ -0,0 +1,45 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" +) + +func TestRelay(t *testing.T) { + r := NewRelay() + r.broadcast("buffer_before_first", internal.CategoryFunction, internal.TypeFunction) + r.flush(context.Background()) + + sub1 := newMockSub(t) + defer sub1.AssertExpectations(t) + sub1.On("AgentName").Return("sub1") + sub1.On("SendAsync", mock.Anything, mock.Anything).Times(4) + sub1.On("Flush", mock.Anything).Return().Once() + require.NoError(t, r.addSubscriber(sub1)) + require.Error(t, errSubscriberAlreadyExist, r.addSubscriber(sub1)) + + r.broadcast("buffer_before_second", internal.CategoryFunction, internal.TypeFunction) + sub2 := newMockSub(t) + defer sub2.AssertExpectations(t) + sub2.On("AgentName").Return("sub2") + sub2.On("SendAsync", mock.Anything, mock.Anything).Times(4) + sub2.On("Flush", mock.Anything).Return().Once() + require.NoError(t, r.addSubscriber(sub2)) + + r.disableAddSubscriber() + require.Error(t, telemetry.ErrTelemetryServiceOff, r.addSubscriber(newMockSub(t))) + + r.broadcast("test_function", internal.CategoryFunction, internal.TypeFunction) + r.broadcast("test_extension", internal.CategoryExtension, internal.TypeExtension) + + r.flush(context.Background()) +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/schema/telemetry-subscription-schema.json b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/schema/telemetry-subscription-schema.json new file mode 100644 index 0000000..24617ba --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/schema/telemetry-subscription-schema.json @@ -0,0 +1,81 @@ +{ + "type": "object", + "required": [ + "schemaVersion", + "types", + "destination" + ], + "properties": { + "schemaVersion": { + "type": "string", + "enum": [ + "2025-01-29" + ] + }, + "types": { + "type": "array", + "minItems": 1, + "items": { + "type": "string", + "enum": [ + "platform", + "function", + "extension" + ] + } + }, + "destination": { + "oneOf": [ + { + "type": "object", + "required": ["protocol", "URI"], + "properties": { + "protocol": { + "type": "string", + "enum": ["HTTP"] + }, + "URI": { + "type": "string", + "minLength": 1 + } + } + }, + { + "type": "object", + "required": ["protocol", "port"], + "properties": { + "protocol": { + "type": "string", + "enum": ["TCP"] + }, + "port": { + "type": "integer", + "minimum": 1, + "maximum": 65535 + } + } + } + ] + }, + "buffering": { + "type": "object", + "properties": { + "maxItems": { + "type": "integer", + "minimum": 1000, + "maximum": 10000 + }, + "maxBytes": { + "type": "integer", + "minimum": 262144, + "maximum": 1048576 + }, + "timeoutMs": { + "type": "integer", + "minimum": 25, + "maximum": 30000 + } + } + } + } +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/subscription_api.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/subscription_api.go new file mode 100644 index 0000000..b16a8bb --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/subscription_api.go @@ -0,0 +1,188 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "bytes" + _ "embed" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/netip" + "time" + + jsonschema "github.com/santhosh-tekuri/jsonschema/v5" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" +) + +//go:embed schema/telemetry-subscription-schema.json +var subscriptionSchema []byte +var telemetrySchema *jsonschema.Schema + +func init() { + compiler := jsonschema.NewCompiler() + if err := compiler.AddResource("telemetry-subscription-schema.json", bytes.NewReader(subscriptionSchema)); err != nil { + slog.Error("error adding telemetry schema resource", "err", err) + panic(err) + } + + var err error + telemetrySchema, err = compiler.Compile("telemetry-subscription-schema.json") + if err != nil { + slog.Error("error compiling telemetry schema", "err", err) + panic(err) + } +} + +type SubscriptionAPI struct { + store subscriptionStore + logsDroppedEventAPI internal.LogsDroppedEventAPI + telemetrySubscriptionEventAPI telemetrySubscriptionEventAPI +} + +type SubscriptionRequest struct { + SchemaVersion string `json:"schemaVersion"` + Categories []string `json:"types"` + Destination internal.SubscriptionDestination `json:"destination"` + Buffering *internal.BufferingConfig `json:"buffering,omitempty"` +} + +func NewSubscriptionAPI(store subscriptionStore, logsDroppedEventAPI internal.LogsDroppedEventAPI, telemetrySubscriptionEventAPI telemetrySubscriptionEventAPI) *SubscriptionAPI { + return &SubscriptionAPI{ + store: store, + logsDroppedEventAPI: logsDroppedEventAPI, + telemetrySubscriptionEventAPI: telemetrySubscriptionEventAPI, + } +} + +func (api *SubscriptionAPI) Subscribe(agentName string, body io.Reader, _ map[string][]string, _ string) (resp []byte, status int, respHeaders map[string][]string, err error) { + + bodyData, err := io.ReadAll(body) + if err != nil { + return []byte(fmt.Sprintf(`{"errorType": "ValidationError", "errorMessage": "Failed to read request body: %s"}`, err.Error())), + http.StatusBadRequest, + map[string][]string{}, + nil + } + + if err := api.validateSubscriptionJSON(bodyData); err != nil { + return []byte(fmt.Sprintf(`{"errorType": "ValidationError", "errorMessage": "%s"}`, err.Error())), + http.StatusBadRequest, + map[string][]string{}, + nil + } + + var req SubscriptionRequest + if err := json.Unmarshal(bodyData, &req); err != nil { + return []byte(fmt.Sprintf(`{"errorType": "ValidationError", "errorMessage": "Invalid JSON: %s"}`, err.Error())), + http.StatusBadRequest, + map[string][]string{}, + nil + } + + buffering := internal.BufferingConfig{ + MaxItems: 10000, + MaxBytes: 256 * 1024, + Timeout: model.DurationMS(1 * time.Second), + } + + if req.Buffering != nil { + if req.Buffering.MaxItems > 0 { + buffering.MaxItems = req.Buffering.MaxItems + } + if req.Buffering.MaxBytes > 0 { + buffering.MaxBytes = req.Buffering.MaxBytes + } + if req.Buffering.Timeout > 0 { + buffering.Timeout = req.Buffering.Timeout + } + } + + c, err := internal.NewClient(req.Destination) + if err != nil { + return []byte(fmt.Sprintf(`{"errorType": "ValidationError", "errorMessage": "Invalid destination: %s"}`, err.Error())), + http.StatusBadRequest, + map[string][]string{}, + nil + } + + categories := make(map[internal.EventCategory]struct{}, len(req.Categories)) + for _, category := range req.Categories { + categories[category] = struct{}{} + } + + subscriber := internal.NewSubscriber(agentName, categories, buffering, c, api.logsDroppedEventAPI) + + if err := api.store.addSubscriber(subscriber); err != nil { + return nil, 0, nil, err + } + + slog.Info("Telemetry subscription created", + "agentName", agentName, + "destinationURI", req.Destination.URI, + "destinationPort", req.Destination.Port, + "protocol", req.Destination.Protocol, + "categories", req.Categories) + + if err := api.telemetrySubscriptionEventAPI.sendTelemetrySubscription(agentName, "Subscribed", req.Categories); err != nil { + slog.Error("Failed to send platform.telemetrySubscription event", "err", err) + } + + return []byte(`"OK"`), http.StatusOK, map[string][]string{}, nil +} + +func (api *SubscriptionAPI) validateSubscriptionJSON(jsonData []byte) error { + var rawData map[string]any + if err := json.Unmarshal(jsonData, &rawData); err != nil { + return fmt.Errorf("failed to parse JSON: %w", err) + } + + if err := telemetrySchema.Validate(rawData); err != nil { + return fmt.Errorf("schema validation error: %w", err) + } + + return nil +} + +func (api *SubscriptionAPI) RecordCounterMetric(metricName string, count int) {} + +func (api *SubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { + return interop.TelemetrySubscriptionMetrics{} +} + +func (api *SubscriptionAPI) Clear() { + panic("not implemented") +} + +func (api *SubscriptionAPI) TurnOff() { + api.store.disableAddSubscriber() +} + +func (api *SubscriptionAPI) GetEndpointURL() string { + panic("not implemented") +} + +func (api *SubscriptionAPI) GetServiceClosedErrorMessage() string { + return "Telemetry API subscription is closed" +} + +func (api *SubscriptionAPI) GetServiceClosedErrorType() string { + return "Telemetry.SubscriptionClosed" +} + +func (api *SubscriptionAPI) Configure(passphrase string, addr netip.AddrPort) {} + +type telemetrySubscriptionEventAPI interface { + sendTelemetrySubscription(agentName, state string, categories []internal.EventCategory) error +} + +type subscriptionStore interface { + addSubscriber(subscriber sub) error + disableAddSubscriber() +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/subscription_api_test.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/subscription_api_test.go new file mode 100644 index 0000000..a1b472b --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/subscription_api_test.go @@ -0,0 +1,380 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "fmt" + "io" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal/telemetry/internal" +) + +func TestSubscriptionAPI_Subscribe(t *testing.T) { + tests := []struct { + name string + agentName string + body io.Reader + expectedStatus int + expectedError bool + setupMocks func(*mockSubscriptionStore, internal.LogsDroppedEventAPI, *mockTelemetrySubscriptionEventAPI) + validateResp func(t *testing.T, resp []byte) + }{ + { + name: "valid_http_subscription_request", + agentName: "test-agent", + body: strings.NewReader(`{ + "schemaVersion": "2025-01-29", + "types": ["platform", "function"], + "destination": { + "protocol": "HTTP", + "URI": "http://sandbox.localdomain:8080/telemetry" + }, + "buffering": { + "maxItems": 5000, + "maxBytes": 262144, + "timeoutMs": 500 + } + }`), + expectedStatus: 200, + expectedError: false, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + store.On("addSubscriber", mock.MatchedBy(func(s *internal.Subscriber) bool { + return s.AgentName() == "test-agent" + })).Return(nil).Once() + telemetryAPI.On("sendTelemetrySubscription", "test-agent", "Subscribed", []string{"platform", "function"}).Return(nil).Once() + }, + validateResp: func(t *testing.T, resp []byte) { + assert.Equal(t, `"OK"`, string(resp)) + }, + }, + { + name: "tcp_connection_refused", + agentName: "tcp-agent", + body: strings.NewReader(`{ + "schemaVersion": "2025-01-29", + "types": ["extension"], + "destination": { + "protocol": "TCP", + "port": 8081 + } + }`), + expectedStatus: 400, + expectedError: false, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + + }, + validateResp: func(t *testing.T, resp []byte) { + assert.Contains(t, string(resp), "ValidationError") + assert.Contains(t, string(resp), "Invalid destination") + }, + }, + { + name: "subscription_with_default_buffering", + agentName: "default-agent", + body: strings.NewReader(`{ + "schemaVersion": "2025-01-29", + "types": ["platform"], + "destination": { + "protocol": "HTTP", + "URI": "http://sandbox.localdomain:8080/events" + } + }`), + expectedStatus: 200, + expectedError: false, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + store.On("addSubscriber", mock.MatchedBy(func(s *internal.Subscriber) bool { + return s.AgentName() == "default-agent" + + })).Return(nil).Once() + telemetryAPI.On("sendTelemetrySubscription", "default-agent", "Subscribed", []string{"platform"}).Return(nil).Once() + }, + validateResp: func(t *testing.T, resp []byte) { + assert.Equal(t, `"OK"`, string(resp)) + }, + }, + { + name: "invalid_json_request", + agentName: "invalid-agent", + body: strings.NewReader(`{"invalid": json}`), + expectedStatus: 400, + expectedError: false, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + + }, + validateResp: func(t *testing.T, resp []byte) { + assert.Contains(t, string(resp), "ValidationError") + assert.Contains(t, string(resp), "failed to parse JSON") + }, + }, + { + name: "schema_validation_error_missing_required_fields", + agentName: "schema-error-agent", + body: strings.NewReader(`{ + "schemaVersion": "2025-01-29" + }`), + expectedStatus: 400, + expectedError: false, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + + }, + validateResp: func(t *testing.T, resp []byte) { + assert.Contains(t, string(resp), "ValidationError") + assert.Contains(t, string(resp), "schema validation error") + }, + }, + { + name: "invalid_destination_protocol", + agentName: "invalid-dest-agent", + body: strings.NewReader(`{ + "schemaVersion": "2025-01-29", + "types": ["platform"], + "destination": { + "protocol": "INVALID", + "URI": "http://localhost:8080" + } + }`), + expectedStatus: 400, + expectedError: false, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + + }, + validateResp: func(t *testing.T, resp []byte) { + assert.Contains(t, string(resp), "ValidationError") + assert.Contains(t, string(resp), "schema validation error") + }, + }, + { + name: "telemetry_subscription_event_api_error", + agentName: "event-error-agent", + body: strings.NewReader(`{ + "schemaVersion": "2025-01-29", + "types": ["platform"], + "destination": { + "protocol": "HTTP", + "URI": "http://sandbox.localdomain:8080/events" + } + }`), + expectedStatus: 200, + expectedError: false, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + store.On("addSubscriber", mock.MatchedBy(func(s *internal.Subscriber) bool { + return s.AgentName() == "event-error-agent" + })).Return(nil).Once() + telemetryAPI.On("sendTelemetrySubscription", "event-error-agent", "Subscribed", []string{"platform"}).Return(fmt.Errorf("event API error")).Once() + }, + validateResp: func(t *testing.T, resp []byte) { + + assert.Equal(t, `"OK"`, string(resp)) + }, + }, + { + name: "partial_buffering_config", + agentName: "partial-buffer-agent", + body: strings.NewReader(`{ + "schemaVersion": "2025-01-29", + "types": ["function"], + "destination": { + "protocol": "HTTP", + "URI": "http://sandbox.localdomain:8080/events" + }, + "buffering": { + "maxItems": 2000 + } + }`), + expectedStatus: 200, + expectedError: false, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + store.On("addSubscriber", mock.MatchedBy(func(s *internal.Subscriber) bool { + return s.AgentName() == "partial-buffer-agent" + + })).Return(nil).Once() + telemetryAPI.On("sendTelemetrySubscription", "partial-buffer-agent", "Subscribed", []string{"function"}).Return(nil).Once() + }, + validateResp: func(t *testing.T, resp []byte) { + assert.Equal(t, `"OK"`, string(resp)) + }, + }, + { + name: "empty_types_array", + agentName: "empty-types-agent", + body: strings.NewReader(`{ + "schemaVersion": "2025-01-29", + "types": [], + "destination": { + "protocol": "HTTP", + "URI": "http://sandbox.localdomain:8080/events" + } + }`), + expectedStatus: 400, + expectedError: false, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + + }, + validateResp: func(t *testing.T, resp []byte) { + assert.Contains(t, string(resp), "ValidationError") + assert.Contains(t, string(resp), "schema validation error") + }, + }, + { + name: "missing_required_field_destination", + agentName: "missing-dest-agent", + body: strings.NewReader(`{ + "schemaVersion": "2025-01-29", + "types": ["platform"] + }`), + expectedStatus: 400, + expectedError: false, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + + }, + validateResp: func(t *testing.T, resp []byte) { + assert.Contains(t, string(resp), "ValidationError") + assert.Contains(t, string(resp), "schema validation error") + }, + }, + { + name: "invalid_schema_version", + agentName: "invalid-version-agent", + body: strings.NewReader(`{ + "schemaVersion": "invalid-version", + "types": ["platform"], + "destination": { + "protocol": "HTTP", + "URI": "http://sandbox.localdomain:8080/events" + } + }`), + expectedStatus: 400, + expectedError: false, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + + }, + validateResp: func(t *testing.T, resp []byte) { + assert.Contains(t, string(resp), "ValidationError") + assert.Contains(t, string(resp), "schema validation error") + }, + }, + { + name: "tcp_destination_missing_port", + agentName: "tcp-missing-port-agent", + body: strings.NewReader(`{ + "schemaVersion": "2025-01-29", + "types": ["platform"], + "destination": { + "protocol": "TCP" + } + }`), + expectedStatus: 400, + expectedError: false, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + + }, + validateResp: func(t *testing.T, resp []byte) { + assert.Contains(t, string(resp), "ValidationError") + assert.Contains(t, string(resp), "schema validation error") + }, + }, + { + name: "http_destination_missing_uri", + agentName: "http-missing-uri-agent", + body: strings.NewReader(`{ + "schemaVersion": "2025-01-29", + "types": ["platform"], + "destination": { + "protocol": "HTTP" + } + }`), + expectedStatus: 400, + expectedError: false, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + + }, + validateResp: func(t *testing.T, resp []byte) { + assert.Contains(t, string(resp), "ValidationError") + assert.Contains(t, string(resp), "schema validation error") + }, + }, + { + name: "relay_add_subscriber_error", + agentName: "store-error-agent", + body: strings.NewReader(`{ + "schemaVersion": "2025-01-29", + "types": ["platform"], + "destination": { + "protocol": "HTTP", + "URI": "http://sandbox.localdomain:8080/events" + } + }`), + expectedStatus: 0, + expectedError: true, + setupMocks: func(store *mockSubscriptionStore, logsAPI internal.LogsDroppedEventAPI, telemetryAPI *mockTelemetrySubscriptionEventAPI) { + store.On("addSubscriber", mock.MatchedBy(func(s *internal.Subscriber) bool { + return s.AgentName() == "store-error-agent" + })).Return(fmt.Errorf("subscription store disabled")).Once() + }, + validateResp: func(t *testing.T, resp []byte) { + + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := newMockSubscriptionStore(t) + defer store.AssertExpectations(t) + logsDroppedAPI := internal.NewMockLogsDroppedEventAPI(t) + telemetryAPI := newMockTelemetrySubscriptionEventAPI(t) + + tt.setupMocks(store, logsDroppedAPI, telemetryAPI) + + api := NewSubscriptionAPI(store, logsDroppedAPI, telemetryAPI) + + resp, status, headers, err := api.Subscribe(tt.agentName, tt.body, nil, "") + + if tt.expectedError { + assert.Error(t, err) + assert.Nil(t, headers) + } else { + assert.NoError(t, err) + assert.NotNil(t, headers) + } + + assert.Equal(t, tt.expectedStatus, status) + + if tt.validateResp != nil { + tt.validateResp(t, resp) + } + }) + } +} + +func TestSubscriptionAPI_Subscribe_ReadError(t *testing.T) { + relay := NewRelay() + logsDroppedAPI := internal.NewMockLogsDroppedEventAPI(t) + telemetryAPI := newMockTelemetrySubscriptionEventAPI(t) + + api := NewSubscriptionAPI(relay, logsDroppedAPI, telemetryAPI) + + failingReader := &failingReader{err: fmt.Errorf("read failed")} + + resp, status, headers, err := api.Subscribe("test-agent", failingReader, nil, "") + + assert.NoError(t, err) + assert.Equal(t, 400, status) + assert.NotNil(t, headers) + assert.Contains(t, string(resp), "ValidationError") + assert.Contains(t, string(resp), "Failed to read request body") +} + +type failingReader struct { + err error +} + +func (r *failingReader) Read(p []byte) (n int, err error) { + return 0, r.err +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/utils.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/utils.go new file mode 100644 index 0000000..4604836 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/utils.go @@ -0,0 +1,67 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "fmt" + "log/slog" + "net" + "net/netip" + "os" + + "github.com/jessevdk/go-flags" +) + +type Options struct { + LogLevel string `long:"log-level" description:"Log level (default: info). Can also be set via LOG_LEVEL env."` + RuntimeAddress string `long:"runtime-api-address" description:"Address of the Lambda Runtime API."` + RIEAddress string `long:"runtime-interface-emulator-address" default:"0.0.0.0:8080" description:"Address for RIE to accept HTTP requests."` +} + +func ParseCLIArgs(args []string) (Options, []string, error) { + var opts Options + parser := flags.NewParser(&opts, flags.IgnoreUnknown) + args, err := parser.ParseArgs(args) + + if opts.LogLevel == "" { + opts.LogLevel = os.Getenv("LOG_LEVEL") + if opts.LogLevel == "" { + opts.LogLevel = "info" + } + } + + return opts, args, err +} + +func ConfigureLogging(levelStr string) { + var lvl slog.Level + if err := lvl.UnmarshalText([]byte(levelStr)); err != nil { + lvl = slog.LevelInfo + } + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: lvl})) + slog.SetDefault(logger) +} + +func ParseAddr(addrStr, defaultAddr string) (netip.AddrPort, error) { + if addrStr == "" { + addrStr = defaultAddr + } + + host, portStr, err := net.SplitHostPort(addrStr) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("invalid address: %w", err) + } + + port, err := net.LookupPort("tcp", portStr) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err) + } + + ip, err := netip.ParseAddr(host) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("invalid IP: %w", err) + } + + return netip.AddrPortFrom(ip, uint16(port)), nil +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/utils_test.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/utils_test.go new file mode 100644 index 0000000..edbaf0b --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/utils_test.go @@ -0,0 +1,211 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "context" + "log/slog" + "net/netip" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseCLIArgs(t *testing.T) { + + originalArgs := os.Args + originalLogLevel := os.Getenv("LOG_LEVEL") + defer func() { + os.Args = originalArgs + require.NoError(t, os.Setenv("LOG_LEVEL", originalLogLevel)) + }() + + tests := []struct { + name string + args []string + envLogLevel string + expectedOpts Options + expectError bool + }{ + { + name: "Default values", + args: []string{"cmd"}, + envLogLevel: "", + expectedOpts: Options{ + LogLevel: "info", + RuntimeAddress: "", + RIEAddress: "0.0.0.0:8080", + }, + expectError: false, + }, + { + name: "Command line arguments", + args: []string{"cmd", "--log-level=debug", "--runtime-api-address=127.0.0.1:8000", "--runtime-interface-emulator-address=0.0.0.0:9000"}, + envLogLevel: "", + expectedOpts: Options{ + LogLevel: "debug", + RuntimeAddress: "127.0.0.1:8000", + RIEAddress: "0.0.0.0:9000", + }, + expectError: false, + }, + { + name: "Environment variable override", + args: []string{"cmd"}, + envLogLevel: "warn", + expectedOpts: Options{ + LogLevel: "warn", + RuntimeAddress: "", + RIEAddress: "0.0.0.0:8080", + }, + expectError: false, + }, + { + name: "Command line takes precedence over env", + args: []string{"cmd", "--log-level=debug"}, + envLogLevel: "warn", + expectedOpts: Options{ + LogLevel: "debug", + RuntimeAddress: "", + RIEAddress: "0.0.0.0:8080", + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + os.Args = tt.args + if tt.envLogLevel != "" { + require.NoError(t, os.Setenv("LOG_LEVEL", tt.envLogLevel)) + } else { + require.NoError(t, os.Unsetenv("LOG_LEVEL")) + } + + opts, _, err := ParseCLIArgs(os.Args) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, tt.expectedOpts.LogLevel, opts.LogLevel) + assert.Equal(t, tt.expectedOpts.RuntimeAddress, opts.RuntimeAddress) + assert.Equal(t, tt.expectedOpts.RIEAddress, opts.RIEAddress) + }) + } +} + +func TestConfigureLogging(t *testing.T) { + + originalLogger := slog.Default() + defer slog.SetDefault(originalLogger) + + tests := []struct { + name string + logLevel string + wantLevel slog.Level + }{ + { + name: "Debug level", + logLevel: "debug", + wantLevel: slog.LevelDebug, + }, + { + name: "Info level", + logLevel: "info", + wantLevel: slog.LevelInfo, + }, + { + name: "Warn level", + logLevel: "warn", + wantLevel: slog.LevelWarn, + }, + { + name: "Error level", + logLevel: "error", + wantLevel: slog.LevelError, + }, + { + name: "Invalid level defaults to Info", + logLevel: "invalid", + wantLevel: slog.LevelInfo, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + ConfigureLogging(tt.logLevel) + + logger := slog.Default() + ctx := context.TODO() + assert.True(t, logger.Enabled(ctx, tt.wantLevel), + "Logger should have %v level enabled", tt.wantLevel) + }) + } +} + +func TestParseAddr(t *testing.T) { + tests := []struct { + name string + addrStr string + defaultAddr string + want netip.AddrPort + wantErr bool + }{ + { + name: "Valid IPv6 address", + addrStr: "[::1]:8080", + defaultAddr: "0.0.0.0:9000", + want: netip.AddrPortFrom(netip.MustParseAddr("::1"), 8080), + wantErr: false, + }, + { + name: "Use default address", + addrStr: "", + defaultAddr: "0.0.0.0:9000", + want: netip.AddrPortFrom(netip.MustParseAddr("0.0.0.0"), 9000), + wantErr: false, + }, + { + name: "Invalid address format", + addrStr: "127.0.0.1", + defaultAddr: "0.0.0.0:9000", + want: netip.AddrPort{}, + wantErr: true, + }, + { + name: "Invalid port", + addrStr: "127.0.0.1:invalid", + defaultAddr: "0.0.0.0:9000", + want: netip.AddrPort{}, + wantErr: true, + }, + { + name: "Invalid IP", + addrStr: "invalid:8080", + defaultAddr: "0.0.0.0:9000", + want: netip.AddrPort{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseAddr(tt.addrStr, tt.defaultAddr) + + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/run/run.go b/internal/lambda-managed-instances/aws-lambda-rie/run/run.go new file mode 100644 index 0000000..bf61100 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/run/run.go @@ -0,0 +1,32 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package run + +import ( + "log/slog" + "os" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/supervisor/local" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils" +) + +func Run() { + server, rieHandler, _, err := internal.Run(local.NewProcessSupervisor(local.WithLowerPriorities(false)), os.Args, utils.NewFileUtil(), make(chan os.Signal, 1)) + if err != nil { + slog.Error("rie failed", "err", err) + os.Exit(1) + } + + if err := rieHandler.Init(); err != nil { + slog.Warn("INIT failed", "err", err) + + } + + <-server.Done() + if err := server.Err(); err != nil { + slog.Warn("rie server stopped", "err", err) + os.Exit(1) + } +} diff --git a/internal/lambda-managed-instances/aws-lambda-rie/test/rie_test.go b/internal/lambda-managed-instances/aws-lambda-rie/test/rie_test.go new file mode 100644 index 0000000..4992ec9 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/test/rie_test.go @@ -0,0 +1,488 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//go:build test + +package test + +import ( + "io" + "net/http" + "os" + "strings" + "sync" + "syscall" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/aws-lambda-rie/internal" + rmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testutils" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testutils/functional" +) + +func TestRie_SingleInvoke(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + expectedStatus int + expectedPayload string + expectedHeaders map[string]string + runtimeEnv functional.RuntimeEnv + }{ + { + name: "simple_invoke", + expectedStatus: http.StatusOK, + expectedPayload: "test response", + expectedHeaders: map[string]string{ + "Content-Type": "application/json", + }, + runtimeEnv: functional.RuntimeEnv{ + Workers: []functional.RuntimeExecutionEnvironment{ + { + Actions: []functional.ExecutionEnvironmentAction{ + functional.NextAction{}, + functional.InvocationResponseAction{ + Payload: strings.NewReader("test response"), + ContentType: "application/json", + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + tt.runtimeEnv.T = t + supv := functional.NewMockSupervisor(t, &tt.runtimeEnv, nil, nil) + + sigCh := make(chan os.Signal, 1) + + mockFileUtil := functional.MakeMockFileUtil(nil) + + args := []string{"--runtime-api-address", "127.0.0.1:0", "--runtime-interface-emulator-address", "127.0.0.1:0", "echo", "hello"} + server, rieHandler, _, err := internal.Run(supv, args, mockFileUtil, sigCh) + require.NoError(t, err) + + require.NoError(t, rieHandler.Init()) + + resp, err := http.Post("http://"+server.Addr.String()+"/2015-03-31/functions/function/invocations", "application/json", strings.NewReader("{}")) + require.NoError(t, err) + defer func() { require.NoError(t, resp.Body.Close()) }() + + assert.Equal(t, tt.expectedStatus, resp.StatusCode) + + for key, expectedValue := range tt.expectedHeaders { + assert.Equal(t, expectedValue, resp.Header.Get(key)) + } + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, tt.expectedPayload, string(body)) + + tt.runtimeEnv.Done.Wait() + + sigCh <- syscall.SIGTERM + <-server.Done() + assert.NoError(t, server.Err()) + }) + } +} + +func TestRie_SigtermDuringInvoke(t *testing.T) { + t.Parallel() + + runtime := &functional.RuntimeEnv{ + Workers: []functional.RuntimeExecutionEnvironment{ + { + Actions: []functional.ExecutionEnvironmentAction{ + functional.NextAction{}, + }, + }, + }, + ForcedError: nil, + T: t, + } + supv := functional.NewMockSupervisor(t, runtime, nil, nil) + + sigCh := make(chan os.Signal, 1) + + mockFileUtil := functional.MakeMockFileUtil(nil) + + args := []string{"--runtime-api-address", "127.0.0.1:0", "--runtime-interface-emulator-address", "127.0.0.1:0", "echo", "hello"} + server, rieHandler, _, err := internal.Run(supv, args, mockFileUtil, sigCh) + require.NoError(t, err) + + require.NoError(t, rieHandler.Init()) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + resp, err := http.Post("http://"+server.Addr.String()+"/2015-03-31/functions/function/invocations", "application/json", strings.NewReader("{}")) + require.NoError(t, err) + defer func() { require.NoError(t, resp.Body.Close()) }() + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + + assert.Equal(t, "Client.ExecutionEnvironmentShutDown", resp.Header.Get("Error-Type")) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.JSONEq(t, `{"errorType":"Client.ExecutionEnvironmentShutDown"}`, string(body)) + + wg.Done() + }() + + time.Sleep(200 * time.Millisecond) + + sigCh <- syscall.SIGTERM + <-server.Done() + assert.NoError(t, server.Err()) + + wg.Wait() +} + +func TestRie_InitError(t *testing.T) { + t.Parallel() + + runtime := &functional.RuntimeEnv{ + Workers: []functional.RuntimeExecutionEnvironment{ + { + Actions: []functional.ExecutionEnvironmentAction{ + functional.InitErrorAction{ + ErrorType: "Function.TestError", + ExpectedStatus: http.StatusAccepted, + }, + functional.ExitAction{}, + }, + }, + }, + ForcedError: nil, + T: t, + } + supv := functional.NewMockSupervisor(t, runtime, nil, nil) + + sigCh := make(chan os.Signal, 1) + + args := []string{"--runtime-api-address", "127.0.0.1:0", "--runtime-interface-emulator-address", "127.0.0.1:0", "echo", "hello"} + server, rieHandler, _, err := internal.Run(supv, args, functional.MakeMockFileUtil(nil), sigCh) + require.NoError(t, err) + + initErr := rieHandler.Init() + assert.Error(t, initErr) + assert.Contains(t, []rmodel.ErrorType{"Function.TestError", "Runtime.ExitError"}, initErr.ErrorType()) + + <-server.Done() + serverErr := server.Err() + assert.Error(t, serverErr) + assert.Contains(t, []rmodel.ErrorType{"Function.TestError", "Runtime.ExitError"}, initErr.(rmodel.CustomerError).ErrorType()) +} + +func TestRie_InvokeWaitingForInitError(t *testing.T) { + t.Parallel() + + runtime := &functional.RuntimeEnv{ + Workers: []functional.RuntimeExecutionEnvironment{ + { + Actions: []functional.ExecutionEnvironmentAction{ + functional.SleepAction{Duration: 100 * time.Millisecond}, + functional.ExitAction{}, + }, + }, + }, + ForcedError: nil, + T: t, + } + supv := functional.NewMockSupervisor(t, runtime, nil, nil) + + sigCh := make(chan os.Signal, 1) + + mockFileUtil := functional.MakeMockFileUtil(nil) + + args := []string{"--runtime-api-address", "127.0.0.1:0", "--runtime-interface-emulator-address", "127.0.0.1:0", "echo", "hello"} + server, rieHandler, _, err := internal.Run(supv, args, mockFileUtil, sigCh) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + resp, err := http.Post("http://"+server.Addr.String()+"/2015-03-31/functions/function/invocations", "application/json", strings.NewReader("{}")) + require.NoError(t, err) + defer func() { require.NoError(t, resp.Body.Close()) }() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + expectedHeaders := map[string]string{ + "Content-Type": "application/json", + "Error-Type": "Runtime.ExitError", + } + for key, expectedValue := range expectedHeaders { + assert.Equal(t, expectedValue, resp.Header.Get(key)) + } + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.JSONEq(t, `{"errorType":"Runtime.ExitError"}`, string(body)) + + wg.Done() + }() + + time.Sleep(200 * time.Millisecond) + initErr := rieHandler.Init() + require.Error(t, initErr) + + <-server.Done() + serverErr := server.Err() + assert.Error(t, serverErr) + assert.Equal(t, initErr, serverErr) + + wg.Wait() +} + +func TestRie_InvokeFatalError(t *testing.T) { + t.Parallel() + + runtime := &functional.RuntimeEnv{ + Workers: []functional.RuntimeExecutionEnvironment{ + { + Actions: []functional.ExecutionEnvironmentAction{ + functional.NextAction{}, + functional.ExitAction{}, + }, + }, + }, + ForcedError: nil, + T: t, + } + supv := functional.NewMockSupervisor(t, runtime, nil, nil) + + sigCh := make(chan os.Signal, 1) + + mockFileUtil := functional.MakeMockFileUtil(nil) + + args := []string{"--runtime-api-address", "127.0.0.1:0", "--runtime-interface-emulator-address", "127.0.0.1:0", "echo", "hello"} + server, rieHandler, _, err := internal.Run(supv, args, mockFileUtil, sigCh) + require.NoError(t, err) + + require.NoError(t, rieHandler.Init()) + + resp, err := http.Post("http://"+server.Addr.String()+"/2015-03-31/functions/function/invocations", "application/json", strings.NewReader("{}")) + require.NoError(t, err) + defer func() { require.NoError(t, resp.Body.Close()) }() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + assert.Equal(t, "Runtime.ExitError", resp.Header.Get("Error-Type")) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.JSONEq(t, `{"errorType":"Runtime.ExitError"}`, string(body)) + + <-server.Done() + assert.Error(t, server.Err()) +} + +func TestRIE_TelemetryAPI(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + expectedStatus int + expectedPayload string + expectedHeaders map[string]string + runtimeEnv functional.RuntimeEnv + extensionEnv functional.ExtensionsEnv + }{ + { + name: "simple_invoke", + expectedStatus: http.StatusOK, + expectedPayload: "test response", + expectedHeaders: map[string]string{ + "Content-Type": "application/json", + }, + runtimeEnv: functional.RuntimeEnv{ + Workers: []functional.RuntimeExecutionEnvironment{ + { + Actions: []functional.ExecutionEnvironmentAction{ + functional.NextAction{}, + functional.StdoutAction{Payload: "runtime: test stdout log\n"}, + functional.StderrAction{Payload: "runtime: test stderr log\n"}, + functional.InvocationResponseAction{ + Payload: strings.NewReader("test response"), + ContentType: "application/json", + }, + }, + }, + }, + }, + extensionEnv: functional.ExtensionsEnv{ + "http": &functional.ExtensionsExecutionEnvironment{ + Actions: []functional.ExecutionEnvironmentAction{ + functional.ExtensionsRegisterAction{ + AgentUniqueName: "http", + ExpectedStatus: http.StatusOK, + }, + functional.ExtensionsTelemetryAPIHTTPSubscriberAction{ + InMemoryEventsApi: functional.NewInMemoryEventsApi(t), + Subscription: functional.ExtensionTelemetrySubscribeAction{ + AgentName: "http", + ExpectedStatus: http.StatusOK, + }, + }, + functional.StdoutAction{Payload: "extension http: test stdout log\n"}, + functional.StderrAction{Payload: "extension http: test stderr log\n"}, + functional.ExtensionsNextAction{ + ExpectedStatus: http.StatusOK, + }, + }, + }, + "tcp": &functional.ExtensionsExecutionEnvironment{ + Actions: []functional.ExecutionEnvironmentAction{ + functional.ExtensionsRegisterAction{ + AgentUniqueName: "tcp", + ExpectedStatus: http.StatusOK, + }, + functional.ExtensionsTelemetryAPITCPSubscriberAction{ + InMemoryEventsApi: functional.NewInMemoryEventsApi(t), + Subscription: functional.ExtensionTelemetrySubscribeAction{ + AgentName: "tcp", + ExpectedStatus: http.StatusOK, + }, + }, + functional.StdoutAction{Payload: "extension tcp: test stdout log\n"}, + functional.StderrAction{Payload: "extension tcp: test stderr log\n"}, + functional.ExtensionsNextAction{ + ExpectedStatus: http.StatusOK, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + tt.runtimeEnv.T = t + for _, ext := range tt.extensionEnv { + ext.T = t + } + + var httpEventsApi, tcpEventsApi *functional.InMemoryEventsApi + + if httpExt, ok := tt.extensionEnv["http"]; ok { + for _, action := range httpExt.Actions { + if httpAction, ok := action.(functional.ExtensionsTelemetryAPIHTTPSubscriberAction); ok { + httpEventsApi = httpAction.InMemoryEventsApi + break + } + } + } + if tcpExt, ok := tt.extensionEnv["tcp"]; ok { + for _, action := range tcpExt.Actions { + if tcpAction, ok := action.(functional.ExtensionsTelemetryAPITCPSubscriberAction); ok { + tcpEventsApi = tcpAction.InMemoryEventsApi + break + } + } + } + + supv := functional.NewMockSupervisor(t, &tt.runtimeEnv, tt.extensionEnv, nil) + + sigCh := make(chan os.Signal, 1) + + mockFileUtil := functional.MakeMockFileUtil(tt.extensionEnv) + + args := []string{"--runtime-api-address", "127.0.0.1:0", "--runtime-interface-emulator-address", "127.0.0.1:0", "echo", "hello"} + server, rieHandler, _, err := internal.Run(supv, args, mockFileUtil, sigCh) + require.NoError(t, err) + + initStartTime := time.Now() + require.NoError(t, rieHandler.Init()) + initFinishTime := time.Now() + + req, err := http.NewRequest(http.MethodPost, "http://"+server.Addr.String()+"/2015-03-31/functions/function/invocations", strings.NewReader("{}")) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + invokeID := uuid.NewString() + req.Header.Set("x-amzn-RequestId", invokeID) + + invokeStartTime := time.Now() + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer func() { require.NoError(t, resp.Body.Close()) }() + + assert.Equal(t, tt.expectedStatus, resp.StatusCode) + + for key, expectedValue := range tt.expectedHeaders { + assert.Equal(t, expectedValue, resp.Header.Get(key)) + } + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, tt.expectedPayload, string(body)) + + invokeFinishTime := time.Now() + + sigCh <- syscall.SIGTERM + <-server.Done() + assert.NoError(t, server.Err()) + + initPayload := testutils.MakeInitPayload() + + expectedInitEvents := []functional.ExpectedInitEvent{ + { + EventType: functional.PlatformInitStart, + Status: "success", + }, + { + EventType: functional.PlatformInitRuntimeDone, + Status: "success", + }, + { + EventType: functional.PlatformInitReport, + Status: "success", + }, + } + + expectedExtensionEvents := []functional.ExpectedExtensionEvents{ + { + ExtensionName: "http", + State: "Ready", + }, + { + ExtensionName: "tcp", + State: "Ready", + }, + } + + expectedInvokeEvents := []functional.ExpectedInvokeEvents{ + { + EventType: functional.PlatformRuntimeStart, + }, + { + EventType: functional.PlatformReport, + Status: "success", + Spans: []string{"responseLatency", "responseDuration"}, + }, + } + + for _, mock := range []*functional.InMemoryEventsApi{httpEventsApi, tcpEventsApi} { + mock.CheckSimpleInitExpectations(initStartTime, initFinishTime, expectedInitEvents, initPayload) + mock.CheckSimpleExtensionExpectations(expectedExtensionEvents) + mock.CheckSimpleInvokeExpectations(invokeStartTime, invokeFinishTime, invokeID, expectedInvokeEvents, initPayload) + assert.Len(t, mock.LogLines(), 6) + } + }) + } +} diff --git a/internal/lambda-managed-instances/core/agent_state_names.go b/internal/lambda-managed-instances/core/agent_state_names.go new file mode 100644 index 0000000..9a4766b --- /dev/null +++ b/internal/lambda-managed-instances/core/agent_state_names.go @@ -0,0 +1,16 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +const ( + AgentStartedStateName = "Started" + AgentRegisteredStateName = "Registered" + AgentReadyStateName = "Ready" + AgentRunningStateName = "Running" + AgentInitErrorStateName = "InitError" + AgentExitErrorStateName = "ExitError" + AgentShutdownFailedStateName = "ShutdownFailed" + AgentExitedStateName = "Exited" + AgentLaunchErrorName = "LaunchError" +) diff --git a/internal/lambda-managed-instances/core/agentsmap.go b/internal/lambda-managed-instances/core/agentsmap.go new file mode 100644 index 0000000..7919059 --- /dev/null +++ b/internal/lambda-managed-instances/core/agentsmap.go @@ -0,0 +1,72 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "errors" + + "github.com/google/uuid" +) + +type Agent interface { + *ExternalAgent | *InternalAgent + Name() string + ID() uuid.UUID +} + +var ErrAgentNameCollision = errors.New("ErrAgentNameCollision") + +var ErrAgentIDCollision = errors.New("ErrAgentIDCollision") + +type AgentsMap[T Agent] struct { + byName map[string]T + byID map[string]T +} + +func NewAgentsMap[T Agent]() AgentsMap[T] { + return AgentsMap[T]{ + byName: make(map[string]T), + byID: make(map[string]T), + } +} + +func (m *AgentsMap[T]) Insert(a T) error { + if _, nameCollision := m.FindByName(a.Name()); nameCollision { + return ErrAgentNameCollision + } + + if _, idCollision := m.FindByID(a.ID()); idCollision { + return ErrAgentIDCollision + } + + m.byName[a.Name()] = a + m.byID[a.ID().String()] = a + + return nil +} + +func (m *AgentsMap[T]) FindByName(name string) (agent T, found bool) { + agent, found = m.byName[name] + return agent, found +} + +func (m *AgentsMap[T]) FindByID(id uuid.UUID) (agent T, found bool) { + agent, found = m.byID[id.String()] + return agent, found +} + +func (m *AgentsMap[T]) Visit(cb func(T)) { + for _, a := range m.byName { + cb(a) + } +} + +func (m *AgentsMap[T]) Size() int { + return len(m.byName) +} + +func (m *AgentsMap[T]) Clear() { + m.byName = make(map[string]T) + m.byID = make(map[string]T) +} diff --git a/internal/lambda-managed-instances/core/agentsmap_test.go b/internal/lambda-managed-instances/core/agentsmap_test.go new file mode 100644 index 0000000..f95a19a --- /dev/null +++ b/internal/lambda-managed-instances/core/agentsmap_test.go @@ -0,0 +1,75 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExternalAgentsMapLookupByName(t *testing.T) { + m := NewAgentsMap[*ExternalAgent]() + + err := m.Insert(&ExternalAgent{name: "a", id: uuid.New()}) + require.NoError(t, err) + agentIn := &ExternalAgent{name: "b", id: uuid.New()} + err = m.Insert(agentIn) + require.NoError(t, err) + err = m.Insert(&ExternalAgent{name: "c", id: uuid.New()}) + require.NoError(t, err) + + agentOut, found := m.FindByName(agentIn.Name()) + require.True(t, found) + require.Equal(t, agentIn, agentOut) + + assert.Equal(t, m.Size(), 3) +} + +func TestExternalAgentsMapLookupByID(t *testing.T) { + m := NewAgentsMap[*ExternalAgent]() + + err := m.Insert(&ExternalAgent{name: "a", id: uuid.New()}) + require.NoError(t, err) + agentIn := &ExternalAgent{name: "b", id: uuid.New()} + err = m.Insert(agentIn) + require.NoError(t, err) + err = m.Insert(&ExternalAgent{name: "c", id: uuid.New()}) + require.NoError(t, err) + + agentOut, found := m.FindByID(agentIn.ID()) + require.True(t, found) + require.Equal(t, agentIn, agentOut) + + m.Clear() + assert.Equal(t, m.Size(), 0) +} + +func TestExternalAgentsMapInsertNameCollision(t *testing.T) { + m := NewAgentsMap[*ExternalAgent]() + + err := m.Insert(&ExternalAgent{name: "a", id: uuid.New()}) + require.NoError(t, err) + + err = m.Insert(&ExternalAgent{name: "a", id: uuid.New()}) + require.Equal(t, err, ErrAgentNameCollision) + + assert.Equal(t, m.Size(), 1) +} + +func TestExternalAgentsMapInsertIDCollision(t *testing.T) { + m := NewAgentsMap[*ExternalAgent]() + + id := uuid.New() + + err := m.Insert(&ExternalAgent{name: "a", id: id}) + require.NoError(t, err) + + err = m.Insert(&ExternalAgent{name: "b", id: id}) + require.Equal(t, err, ErrAgentIDCollision) + + assert.Equal(t, m.Size(), 1) +} diff --git a/internal/lambda-managed-instances/core/agentutil.go b/internal/lambda-managed-instances/core/agentutil.go new file mode 100644 index 0000000..39dcc14 --- /dev/null +++ b/internal/lambda-managed-instances/core/agentutil.go @@ -0,0 +1,28 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "errors" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +var errInvalidEventType = errors.New("ErrorInvalidEventType") + +type disallowEverything struct{} + +func (s *disallowEverything) Register(events []Event) error { return ErrNotAllowed } + +func (s *disallowEverything) Ready() error { return ErrNotAllowed } + +func (s *disallowEverything) InitError(errorType model.ErrorType) error { return ErrNotAllowed } + +func (s *disallowEverything) ExitError(errorType model.ErrorType) error { return ErrNotAllowed } + +func (s *disallowEverything) ShutdownFailed() error { return ErrNotAllowed } + +func (s *disallowEverything) Exited() error { return ErrNotAllowed } + +func (s *disallowEverything) LaunchError(model.ErrorType) error { return ErrNotAllowed } diff --git a/internal/lambda-managed-instances/core/bandwidthlimiter/bandwidthlimiter.go b/internal/lambda-managed-instances/core/bandwidthlimiter/bandwidthlimiter.go new file mode 100644 index 0000000..f2be6ed --- /dev/null +++ b/internal/lambda-managed-instances/core/bandwidthlimiter/bandwidthlimiter.go @@ -0,0 +1,62 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "io" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils" +) + +func BandwidthLimitingCopy(dst *BandwidthLimitingWriter, src io.Reader) (written int64, err error) { + written, err = utils.CopyWithPool(dst, src) + _ = dst.Close() + return written, err +} + +func NewBandwidthLimitingWriter(w io.Writer, bucket *Bucket) (*BandwidthLimitingWriter, error) { + throttler, err := NewThrottler(bucket) + if err != nil { + return nil, err + } + return &BandwidthLimitingWriter{w: w, th: throttler}, nil +} + +type BandwidthLimitingWriter struct { + w io.Writer + th *Throttler +} + +func (w *BandwidthLimitingWriter) ChunkedWrite(p []byte) (n int, err error) { + i := NewChunkIterator(p, int(w.th.b.capacity)) + for { + buf := i.Next() + if buf == nil { + return n, err + } + written, writeErr := w.th.bandwidthLimitingWrite(w.w, buf) + n += written + if writeErr != nil { + return n, writeErr + } + } +} + +func (w *BandwidthLimitingWriter) Write(p []byte) (n int, err error) { + w.th.start() + if int64(len(p)) > w.th.b.capacity { + return w.ChunkedWrite(p) + } + return w.th.bandwidthLimitingWrite(w.w, p) +} + +func (w *BandwidthLimitingWriter) Close() (err error) { + w.th.stop() + return err +} + +func (w *BandwidthLimitingWriter) GetMetrics() (metrics *interop.InvokeResponseMetrics) { + return w.th.metrics +} diff --git a/internal/lambda-managed-instances/core/bandwidthlimiter/bandwidthlimiter_test.go b/internal/lambda-managed-instances/core/bandwidthlimiter/bandwidthlimiter_test.go new file mode 100644 index 0000000..0b933b4 --- /dev/null +++ b/internal/lambda-managed-instances/core/bandwidthlimiter/bandwidthlimiter_test.go @@ -0,0 +1,106 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "bytes" + "io" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBandwidthLimitingCopy(t *testing.T) { + var size10mb int64 = 10 * 1024 * 1024 + + inputBuffer := []byte(strings.Repeat("a", int(size10mb))) + reader := bytes.NewReader(inputBuffer) + + bucket, err := NewBucket(size10mb/2, size10mb/4, size10mb/2, time.Millisecond/2) + assert.NoError(t, err) + + internalWriter := bytes.NewBuffer(make([]byte, 0, size10mb)) + writer, err := NewBandwidthLimitingWriter(internalWriter, bucket) + assert.NoError(t, err) + + n, err := BandwidthLimitingCopy(writer, reader) + assert.Equal(t, size10mb, n) + assert.Equal(t, nil, err) + assert.Equal(t, inputBuffer, internalWriter.Bytes()) +} + +type ErrorBufferWriter struct { + w ByteBufferWriter + failAfter int +} + +func (w *ErrorBufferWriter) Write(p []byte) (n int, err error) { + if w.failAfter >= 1 { + w.failAfter-- + } + n, err = w.w.Write(p) + if w.failAfter == 0 { + return n, io.ErrUnexpectedEOF + } + return n, err +} + +func (w *ErrorBufferWriter) Bytes() []byte { + return w.w.Bytes() +} + +func TestNewBandwidthLimitingWriter(t *testing.T) { + type testCase struct { + refillNumber int64 + internalWriter ByteBufferWriter + inputBuffer []byte + expectedN int + expectedError error + } + testCases := []testCase{ + { + refillNumber: 2, + internalWriter: bytes.NewBuffer(make([]byte, 0, 36)), + inputBuffer: []byte(strings.Repeat("a", 36)), + expectedN: 36, + expectedError: nil, + }, + { + refillNumber: 2, + internalWriter: bytes.NewBuffer(make([]byte, 0, 12)), + inputBuffer: []byte(strings.Repeat("a", 12)), + expectedN: 12, + expectedError: nil, + }, + { + + refillNumber: 2, + internalWriter: &ErrorBufferWriter{w: bytes.NewBuffer(make([]byte, 0, 36)), failAfter: 2}, + inputBuffer: []byte(strings.Repeat("a", 36)), + expectedN: 32, + expectedError: io.ErrUnexpectedEOF, + }, + } + + for _, test := range testCases { + bucket, err := NewBucket(16, 8, test.refillNumber, 100*time.Millisecond) + assert.NoError(t, err) + + writer, err := NewBandwidthLimitingWriter(test.internalWriter, bucket) + assert.NoError(t, err) + assert.False(t, writer.th.running) + + n, err := writer.Write(test.inputBuffer) + assert.True(t, writer.th.running) + assert.Equal(t, test.expectedN, n) + assert.Equal(t, test.expectedError, err) + assert.Equal(t, test.inputBuffer[:n], test.internalWriter.Bytes()) + + err = writer.Close() + assert.Nil(t, err) + assert.False(t, writer.th.running) + } +} diff --git a/internal/lambda-managed-instances/core/bandwidthlimiter/throttler.go b/internal/lambda-managed-instances/core/bandwidthlimiter/throttler.go new file mode 100644 index 0000000..6b5c681 --- /dev/null +++ b/internal/lambda-managed-instances/core/bandwidthlimiter/throttler.go @@ -0,0 +1,156 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "errors" + "fmt" + "io" + "log/slog" + "sync" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" +) + +var ErrBufferSizeTooLarge = errors.New("buffer size cannot be greater than bucket size") + +func NewBucket(capacity int64, initialTokenCount int64, refillNumber int64, refillInterval time.Duration) (*Bucket, error) { + if capacity <= 0 || initialTokenCount < 0 || refillNumber <= 0 || refillInterval <= 0 || + capacity < initialTokenCount { + errorMsg := fmt.Sprintf("invalid bucket parameters (capacity: %d, initialTokenCount: %d, refillNumber: %d,"+ + "refillInterval: %d)", capacity, initialTokenCount, refillInterval, refillInterval) + slog.Error(errorMsg) + return nil, errors.New(errorMsg) + } + return &Bucket{ + capacity: capacity, + tokenCount: initialTokenCount, + refillNumber: refillNumber, + refillInterval: refillInterval, + mutex: sync.Mutex{}, + }, nil +} + +type Bucket struct { + capacity int64 + tokenCount int64 + refillNumber int64 + refillInterval time.Duration + mutex sync.Mutex +} + +func (b *Bucket) produceTokens() { + b.mutex.Lock() + defer b.mutex.Unlock() + if b.tokenCount < b.capacity { + b.tokenCount = min(b.tokenCount+b.refillNumber, b.capacity) + } +} + +func (b *Bucket) consumeTokens(n int64) bool { + b.mutex.Lock() + defer b.mutex.Unlock() + if n <= b.tokenCount { + b.tokenCount -= n + return true + } + return false +} + +func (b *Bucket) getTokenCount() int64 { + b.mutex.Lock() + defer b.mutex.Unlock() + return b.tokenCount +} + +func NewThrottler(bucket *Bucket) (*Throttler, error) { + if bucket == nil { + errorMsg := "cannot create a throttler with nil bucket" + slog.Error(errorMsg) + return nil, errors.New(errorMsg) + } + + now := time.Now() + + return &Throttler{ + b: bucket, + running: false, + produced: make(chan time.Time), + done: make(chan struct{}), + + metrics: &interop.InvokeResponseMetrics{ + StartReadingResponseTime: now, + FinishReadingResponseTime: now, + OutboundThroughputBps: -1, + FunctionResponseMode: interop.FunctionResponseModeStreaming, + }, + }, nil +} + +type Throttler struct { + b *Bucket + running bool + produced chan time.Time + done chan struct{} + metrics *interop.InvokeResponseMetrics +} + +func (th *Throttler) start() { + if th.running { + return + } + th.running = true + th.metrics.StartReadingResponseTime = time.Now() + go func() { + ticker := time.NewTicker(th.b.refillInterval) + for { + select { + case <-ticker.C: + th.b.produceTokens() + select { + case th.produced <- time.Now(): + default: + } + case <-th.done: + ticker.Stop() + return + } + } + }() +} + +func (th *Throttler) stop() { + if !th.running { + return + } + th.running = false + th.metrics.FinishReadingResponseTime = time.Now() + duration := th.metrics.StartReadingResponseTime.Sub(th.metrics.FinishReadingResponseTime) + if duration > 0 { + th.metrics.OutboundThroughputBps = (th.metrics.ProducedBytes / duration.Milliseconds()) * int64(time.Second/time.Millisecond) + } else { + th.metrics.OutboundThroughputBps = -1 + } + th.done <- struct{}{} +} + +func (th *Throttler) bandwidthLimitingWrite(w io.Writer, p []byte) (written int, err error) { + n := int64(len(p)) + if n > th.b.capacity { + return 0, ErrBufferSizeTooLarge + } + for { + if th.b.consumeTokens(n) { + written, err = w.Write(p) + th.metrics.ProducedBytes += int64(written) + return written, err + } + waitStart := time.Now() + elapsed := (<-th.produced).Sub(waitStart) + if elapsed > 0 { + th.metrics.TimeShaped += elapsed + } + } +} diff --git a/internal/lambda-managed-instances/core/bandwidthlimiter/throttler_test.go b/internal/lambda-managed-instances/core/bandwidthlimiter/throttler_test.go new file mode 100644 index 0000000..0210334 --- /dev/null +++ b/internal/lambda-managed-instances/core/bandwidthlimiter/throttler_test.go @@ -0,0 +1,215 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "bytes" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewBucket(t *testing.T) { + type testCase struct { + capacity int64 + initialTokenCount int64 + refillNumber int64 + refillInterval time.Duration + bucketCreated bool + } + testCases := []testCase{ + {capacity: 8, initialTokenCount: 6, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: true}, + {capacity: 8, initialTokenCount: 6, refillNumber: 2, refillInterval: -100 * time.Millisecond, bucketCreated: false}, + {capacity: 8, initialTokenCount: 6, refillNumber: -5, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + {capacity: 8, initialTokenCount: -2, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + {capacity: -2, initialTokenCount: 6, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + {capacity: 8, initialTokenCount: 10, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + } + + for _, test := range testCases { + bucket, err := NewBucket(test.capacity, test.initialTokenCount, test.refillNumber, test.refillInterval) + if test.bucketCreated { + assert.NoError(t, err) + assert.NotNil(t, bucket) + } else { + assert.Error(t, err) + assert.Nil(t, bucket) + } + } +} + +func TestBucket_produceTokens_consumeTokens(t *testing.T) { + var consumed bool + bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) + assert.NoError(t, err) + assert.Equal(t, int64(8), bucket.getTokenCount()) + + consumed = bucket.consumeTokens(5) + assert.Equal(t, int64(3), bucket.getTokenCount()) + assert.True(t, consumed) + + bucket.produceTokens() + assert.Equal(t, int64(9), bucket.getTokenCount()) + + bucket.produceTokens() + assert.Equal(t, int64(15), bucket.getTokenCount()) + + bucket.produceTokens() + assert.Equal(t, int64(16), bucket.getTokenCount()) + + bucket.produceTokens() + assert.Equal(t, int64(16), bucket.getTokenCount()) + + consumed = bucket.consumeTokens(18) + assert.Equal(t, int64(16), bucket.getTokenCount()) + assert.False(t, consumed) + + consumed = bucket.consumeTokens(16) + assert.Equal(t, int64(0), bucket.getTokenCount()) + assert.True(t, consumed) +} + +func TestNewThrottler(t *testing.T) { + bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) + assert.NoError(t, err) + + throttler, err := NewThrottler(bucket) + assert.NoError(t, err) + assert.NotNil(t, throttler) + + throttler, err = NewThrottler(nil) + assert.Error(t, err) + assert.Nil(t, throttler) +} + +func TestNewThrottler_start_stop(t *testing.T) { + bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) + assert.NoError(t, err) + + throttler, err := NewThrottler(bucket) + assert.NoError(t, err) + + assert.False(t, throttler.running) + + throttler.start() + assert.True(t, throttler.running) + + <-time.Tick(2 * throttler.b.refillInterval) + assert.LessOrEqual(t, int64(14), throttler.b.getTokenCount()) + assert.True(t, throttler.running) + + throttler.start() + assert.True(t, throttler.running) + <-time.Tick(2 * throttler.b.refillInterval) + assert.Equal(t, int64(16), throttler.b.getTokenCount()) + assert.True(t, throttler.running) + + throttler.stop() + assert.False(t, throttler.running) + + throttler.stop() + assert.False(t, throttler.running) + + throttler.start() + assert.True(t, throttler.running) + + throttler.stop() + assert.False(t, throttler.running) +} + +type ByteBufferWriter interface { + Write(p []byte) (n int, err error) + Bytes() []byte +} + +type FixedSizeBufferWriter struct { + buf []byte +} + +func (w *FixedSizeBufferWriter) Write(p []byte) (n int, err error) { + n = copy(w.buf, p) + return n, err +} + +func (w *FixedSizeBufferWriter) Bytes() []byte { + return w.buf +} + +func TestNewThrottler_bandwidthLimitingWrite(t *testing.T) { + var size10mb int64 = 10 * 1024 * 1024 + + type testCase struct { + capacity int64 + initialTokenCount int64 + writer ByteBufferWriter + inputBuffer []byte + expectedN int + expectedError error + } + testCases := []testCase{ + { + capacity: 16, + initialTokenCount: 8, + writer: bytes.NewBuffer(make([]byte, 0, 14)), + inputBuffer: []byte(strings.Repeat("a", 12)), + expectedN: 12, + expectedError: nil, + }, + { + capacity: 16, + initialTokenCount: 8, + writer: bytes.NewBuffer(make([]byte, 0, 12)), + inputBuffer: []byte(strings.Repeat("a", 14)), + expectedN: 14, + expectedError: nil, + }, + { + capacity: size10mb, + initialTokenCount: size10mb, + writer: bytes.NewBuffer(make([]byte, 0, size10mb)), + inputBuffer: []byte(strings.Repeat("a", int(size10mb))), + expectedN: int(size10mb), + expectedError: nil, + }, + { + capacity: 16, + initialTokenCount: 8, + writer: bytes.NewBuffer(make([]byte, 0, 18)), + inputBuffer: []byte(strings.Repeat("a", 18)), + expectedN: 0, + expectedError: ErrBufferSizeTooLarge, + }, + { + capacity: 16, + initialTokenCount: 8, + writer: &FixedSizeBufferWriter{buf: make([]byte, 12)}, + inputBuffer: []byte(strings.Repeat("a", 14)), + expectedN: 12, + expectedError: nil, + }, + } + + for _, test := range testCases { + bucket, err := NewBucket(test.capacity, test.initialTokenCount, 2, 100*time.Millisecond) + assert.NoError(t, err) + + throttler, err := NewThrottler(bucket) + assert.NoError(t, err) + + writer := test.writer + throttler.start() + n, err := throttler.bandwidthLimitingWrite(writer, test.inputBuffer) + assert.Equal(t, test.expectedN, n) + assert.Equal(t, test.expectedError, err) + + if test.expectedError == nil { + assert.Equal(t, test.inputBuffer[:n], test.writer.Bytes()) + } else { + assert.Equal(t, []byte{}, test.writer.Bytes()) + } + throttler.stop() + } +} diff --git a/internal/lambda-managed-instances/core/bandwidthlimiter/util.go b/internal/lambda-managed-instances/core/bandwidthlimiter/util.go new file mode 100644 index 0000000..059f0e4 --- /dev/null +++ b/internal/lambda-managed-instances/core/bandwidthlimiter/util.go @@ -0,0 +1,32 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +func NewChunkIterator(buf []byte, chunkSize int) *ChunkIterator { + if buf == nil { + return nil + } + return &ChunkIterator{ + buf: buf, + chunkSize: chunkSize, + offset: 0, + } +} + +type ChunkIterator struct { + buf []byte + chunkSize int + offset int +} + +func (i *ChunkIterator) Next() []byte { + begin := i.offset + end := min(i.offset+i.chunkSize, len(i.buf)) + i.offset = end + + if begin == end { + return nil + } + return i.buf[begin:end] +} diff --git a/internal/lambda-managed-instances/core/bandwidthlimiter/util_test.go b/internal/lambda-managed-instances/core/bandwidthlimiter/util_test.go new file mode 100644 index 0000000..ed93c77 --- /dev/null +++ b/internal/lambda-managed-instances/core/bandwidthlimiter/util_test.go @@ -0,0 +1,45 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewChunkIterator(t *testing.T) { + buf := []byte("abcdefghijk") + + type testCase struct { + buf []byte + chunkSize int + expectedResult [][]byte + } + testCases := []testCase{ + {buf: nil, chunkSize: 0, expectedResult: [][]byte{}}, + {buf: nil, chunkSize: 1, expectedResult: [][]byte{}}, + {buf: buf, chunkSize: 0, expectedResult: [][]byte{}}, + {buf: buf, chunkSize: 1, expectedResult: [][]byte{ + []byte("a"), []byte("b"), []byte("c"), []byte("d"), []byte("e"), []byte("f"), []byte("g"), []byte("h"), + []byte("i"), []byte("j"), []byte("k"), + }}, + {buf: buf, chunkSize: 4, expectedResult: [][]byte{[]byte("abcd"), []byte("efgh"), []byte("ijk")}}, + {buf: buf, chunkSize: 5, expectedResult: [][]byte{[]byte("abcde"), []byte("fghij"), []byte("k")}}, + {buf: buf, chunkSize: 11, expectedResult: [][]byte{[]byte("abcdefghijk")}}, + {buf: buf, chunkSize: 12, expectedResult: [][]byte{[]byte("abcdefghijk")}}, + } + + for _, test := range testCases { + iterator := NewChunkIterator(test.buf, test.chunkSize) + if test.buf == nil { + assert.Nil(t, iterator) + } else { + for _, expectedChunk := range test.expectedResult { + assert.Equal(t, expectedChunk, iterator.Next()) + } + assert.Nil(t, iterator.Next()) + } + } +} diff --git a/internal/lambda-managed-instances/core/directinvoke/customerheaders.go b/internal/lambda-managed-instances/core/directinvoke/customerheaders.go new file mode 100644 index 0000000..fd0e4ad --- /dev/null +++ b/internal/lambda-managed-instances/core/directinvoke/customerheaders.go @@ -0,0 +1,41 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package directinvoke + +import ( + "bytes" + "encoding/base64" + "encoding/json" +) + +type CustomerHeaders struct { + CognitoIdentityID string `json:"Cognito-Identity-Id"` + CognitoIdentityPoolID string `json:"Cognito-Identity-Pool-Id"` + ClientContext string `json:"Client-Context"` +} + +func (s CustomerHeaders) Dump() string { + if (s == CustomerHeaders{}) { + return "" + } + + custHeadersJSON, err := json.Marshal(&s) + if err != nil { + panic(err) + } + + return base64.StdEncoding.EncodeToString(custHeadersJSON) +} + +func (s *CustomerHeaders) Load(in string) error { + *s = CustomerHeaders{} + + if in == "" { + return nil + } + + base64Decoder := base64.NewDecoder(base64.StdEncoding, bytes.NewReader([]byte(in))) + + return json.NewDecoder(base64Decoder).Decode(s) +} diff --git a/internal/lambda-managed-instances/core/directinvoke/customerheaders_test.go b/internal/lambda-managed-instances/core/directinvoke/customerheaders_test.go new file mode 100644 index 0000000..924df8d --- /dev/null +++ b/internal/lambda-managed-instances/core/directinvoke/customerheaders_test.go @@ -0,0 +1,26 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package directinvoke + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCustomerHeadersEmpty(t *testing.T) { + in := CustomerHeaders{} + out := CustomerHeaders{} + + require.NoError(t, out.Load(in.Dump())) + require.Equal(t, in, out) +} + +func TestCustomerHeaders(t *testing.T) { + in := CustomerHeaders{CognitoIdentityID: "asd"} + out := CustomerHeaders{} + + require.NoError(t, out.Load(in.Dump())) + require.Equal(t, in, out) +} diff --git a/internal/lambda-managed-instances/core/directinvoke/util.go b/internal/lambda-managed-instances/core/directinvoke/util.go new file mode 100644 index 0000000..6eba642 --- /dev/null +++ b/internal/lambda-managed-instances/core/directinvoke/util.go @@ -0,0 +1,79 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package directinvoke + +import ( + "context" + "io" + "net/http" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core/bandwidthlimiter" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils/invariant" +) + +const DefaultRefillIntervalMs = 125 + +func NewStreamedResponseWriter(w io.Writer, responseBandwidthRate int64, responseBandwidthBurstSize int64) (*bandwidthlimiter.BandwidthLimitingWriter, context.CancelFunc, error) { + ctx, cancel := context.WithCancel(context.Background()) + cancellableWriter := NewCancellableWriter(ctx, w) + + refillNumber := responseBandwidthRate * DefaultRefillIntervalMs / 1000 + refillInterval := DefaultRefillIntervalMs * time.Millisecond + + bucket, err := bandwidthlimiter.NewBucket(responseBandwidthBurstSize, responseBandwidthBurstSize, refillNumber, refillInterval) + if err != nil { + cancel() + return nil, nil, err + } + + bandwidthLimitingWriter, err := bandwidthlimiter.NewBandwidthLimitingWriter(cancellableWriter, bucket) + if err != nil { + cancel() + return nil, nil, err + } + + return bandwidthLimitingWriter, cancel, nil +} + +func NewFlushingWriter(w io.Writer) *FlushingWriter { + flusher, ok := w.(http.Flusher) + invariant.Checkf(ok, "writer must implement http.Flusher interface") + + return &FlushingWriter{ + w: w, + flusher: flusher, + } +} + +type FlushingWriter struct { + w io.Writer + flusher http.Flusher +} + +func (w *FlushingWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + w.Flush() + return n, err +} + +func (w *FlushingWriter) Flush() { + w.flusher.Flush() +} + +func NewCancellableWriter(ctx context.Context, w io.Writer) io.Writer { + return &CancellableWriter{w: w, ctx: ctx} +} + +type CancellableWriter struct { + w io.Writer + ctx context.Context +} + +func (w *CancellableWriter) Write(p []byte) (int, error) { + if err := w.ctx.Err(); err != nil { + return 0, err + } + return w.w.Write(p) +} diff --git a/internal/lambda-managed-instances/core/doc.go b/internal/lambda-managed-instances/core/doc.go new file mode 100644 index 0000000..c3498dd --- /dev/null +++ b/internal/lambda-managed-instances/core/doc.go @@ -0,0 +1,4 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core diff --git a/internal/lambda-managed-instances/core/externalagent.go b/internal/lambda-managed-instances/core/externalagent.go new file mode 100644 index 0000000..7abbdcd --- /dev/null +++ b/internal/lambda-managed-instances/core/externalagent.go @@ -0,0 +1,192 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "fmt" + "time" + + "github.com/google/uuid" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core/statejson" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type ExternalAgent struct { + name string + id uuid.UUID + events map[Event]struct{} + + ManagedThread Suspendable + + currentState ExternalAgentState + stateLastModified time.Time + + StartedState ExternalAgentState + RegisteredState ExternalAgentState + ReadyState ExternalAgentState + RunningState ExternalAgentState + InitErrorState ExternalAgentState + ExitErrorState ExternalAgentState + ShutdownFailedState ExternalAgentState + ExitedState ExternalAgentState + LaunchErrorState ExternalAgentState + + errorType model.ErrorType +} + +func NewExternalAgent(name string, initFlow InitFlowSynchronization) *ExternalAgent { + agent := &ExternalAgent{ + name: name, + id: uuid.New(), + ManagedThread: NewManagedThread(), + events: make(map[Event]struct{}), + } + + agent.StartedState = &ExternalAgentStartedState{agent: agent, initFlow: initFlow} + agent.RegisteredState = &ExternalAgentRegisteredState{agent: agent, initFlow: initFlow} + agent.ReadyState = &ExternalAgentReadyState{agent: agent} + agent.RunningState = &ExternalAgentRunningState{agent: agent} + agent.InitErrorState = &ExternalAgentInitErrorState{} + agent.ExitErrorState = &ExternalAgentExitErrorState{} + agent.ShutdownFailedState = &ExternalAgentShutdownFailedState{} + agent.ExitedState = &ExternalAgentExitedState{} + agent.LaunchErrorState = &ExternalAgentLaunchErrorState{} + + agent.setStateUnsafe(agent.StartedState) + + return agent +} + +func (s *ExternalAgent) Name() string { + return s.name +} + +func (s *ExternalAgent) ID() uuid.UUID { + return s.id +} + +func (s *ExternalAgent) String() string { + return fmt.Sprintf("%s (%s)", s.name, s.id) +} + +func (s *ExternalAgent) SuspendUnsafe() { + s.ManagedThread.SuspendUnsafe() +} + +func (s *ExternalAgent) Release() { + s.ManagedThread.Release() +} + +func (s *ExternalAgent) SetState(state ExternalAgentState) { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + s.setStateUnsafe(state) +} + +func (s *ExternalAgent) setStateUnsafe(state ExternalAgentState) { + s.currentState = state + s.stateLastModified = time.Now() +} + +func ValidateExternalAgentEvent(e Event) error { + if e == ShutdownEvent { + return nil + } + return errInvalidEventType +} + +func (s *ExternalAgent) subscribeUnsafe(e Event) error { + if err := ValidateExternalAgentEvent(e); err != nil { + return err + } + s.events[e] = struct{}{} + return nil +} + +func (s *ExternalAgent) IsSubscribed(e Event) bool { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + _, found := s.events[e] + return found +} + +func (s *ExternalAgent) SubscribedEvents() []string { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + + events := []string{} + for event := range s.events { + events = append(events, string(event)) + } + return events +} + +func (s *ExternalAgent) GetState() ExternalAgentState { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState +} + +func (s *ExternalAgent) Register(events []Event) error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.Register(events) +} + +func (s *ExternalAgent) Ready() error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.Ready() +} + +func (s *ExternalAgent) InitError(errorType model.ErrorType) error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.InitError(errorType) +} + +func (s *ExternalAgent) ExitError(errorType model.ErrorType) error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.ExitError(errorType) +} + +func (s *ExternalAgent) ShutdownFailed() error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.ShutdownFailed() +} + +func (s *ExternalAgent) Exited() error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.Exited() +} + +func (s *ExternalAgent) ErrorType() model.ErrorType { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.errorType +} + +func (s *ExternalAgent) LaunchError(err model.ErrorType) error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.LaunchError(err) +} + +func (s *ExternalAgent) GetAgentDescription() statejson.ExtensionDescription { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return statejson.ExtensionDescription{ + Name: s.name, + ID: s.id.String(), + State: statejson.StateDescription{ + Name: s.currentState.Name(), + LastModified: s.stateLastModified, + }, + ErrorType: string(s.errorType), + } +} diff --git a/internal/lambda-managed-instances/core/externalagent_states.go b/internal/lambda-managed-instances/core/externalagent_states.go new file mode 100644 index 0000000..ae80c6a --- /dev/null +++ b/internal/lambda-managed-instances/core/externalagent_states.go @@ -0,0 +1,189 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "log/slog" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type ExternalAgentState interface { + Register([]Event) error + Ready() error + InitError(errorType model.ErrorType) error + ExitError(errorType model.ErrorType) error + ShutdownFailed() error + Exited() error + LaunchError(model.ErrorType) error + Name() string +} + +type ExternalAgentStartedState struct { + disallowEverything + agent *ExternalAgent + initFlow InitFlowSynchronization +} + +func (s *ExternalAgentStartedState) Register(events []Event) error { + for _, e := range events { + if err := s.agent.subscribeUnsafe(e); err != nil { + return err + } + } + s.agent.setStateUnsafe(s.agent.RegisteredState) + if err := s.initFlow.ExternalAgentRegistered(); err != nil { + slog.Error("External agent registration failed", "err", err) + } + return nil +} + +func (s *ExternalAgentStartedState) LaunchError(err model.ErrorType) error { + s.agent.setStateUnsafe(s.agent.LaunchErrorState) + s.agent.errorType = err + return nil +} + +func (s *ExternalAgentStartedState) Name() string { + return AgentStartedStateName +} + +type ExternalAgentRegisteredState struct { + disallowEverything + agent *ExternalAgent + initFlow InitFlowSynchronization +} + +func (s *ExternalAgentRegisteredState) Ready() error { + s.agent.setStateUnsafe(s.agent.ReadyState) + if err := s.initFlow.AgentReady(); err != nil { + slog.Error("Agent ready failed", "err", err) + } + s.agent.ManagedThread.SuspendUnsafe() + + if s.agent.currentState != s.agent.ReadyState { + return ErrConcurrentStateModification + } + s.agent.setStateUnsafe(s.agent.RunningState) + + return nil +} + +func (s *ExternalAgentRegisteredState) InitError(errorType model.ErrorType) error { + s.agent.setStateUnsafe(s.agent.InitErrorState) + s.agent.errorType = errorType + return nil +} + +func (s *ExternalAgentRegisteredState) ExitError(errorType model.ErrorType) error { + s.agent.setStateUnsafe(s.agent.ExitErrorState) + s.agent.errorType = errorType + return nil +} + +func (s *ExternalAgentRegisteredState) Name() string { + return AgentRegisteredStateName +} + +type ExternalAgentReadyState struct { + disallowEverything + agent *ExternalAgent +} + +func (s *ExternalAgentReadyState) ExitError(errorType model.ErrorType) error { + s.agent.setStateUnsafe(s.agent.ExitErrorState) + s.agent.errorType = errorType + return nil +} + +func (s *ExternalAgentReadyState) Name() string { + return AgentReadyStateName +} + +type ExternalAgentRunningState struct { + disallowEverything + agent *ExternalAgent +} + +func (s *ExternalAgentRunningState) Ready() error { + s.agent.setStateUnsafe(s.agent.ReadyState) + s.agent.ManagedThread.SuspendUnsafe() + + if s.agent.currentState != s.agent.ReadyState { + return ErrConcurrentStateModification + } + s.agent.setStateUnsafe(s.agent.RunningState) + + return nil +} + +func (s *ExternalAgentRunningState) ExitError(errorType model.ErrorType) error { + s.agent.setStateUnsafe(s.agent.ExitErrorState) + s.agent.errorType = errorType + return nil +} + +func (s *ExternalAgentRunningState) ShutdownFailed() error { + s.agent.setStateUnsafe(s.agent.ShutdownFailedState) + return nil +} + +func (s *ExternalAgentRunningState) Exited() error { + s.agent.setStateUnsafe(s.agent.ExitedState) + return nil +} + +func (s *ExternalAgentRunningState) Name() string { + return AgentRunningStateName +} + +type ExternalAgentInitErrorState struct { + disallowEverything +} + +func (s *ExternalAgentInitErrorState) Name() string { + return AgentInitErrorStateName +} + +func (s *ExternalAgentInitErrorState) InitError(errorType model.ErrorType) error { + + return nil +} + +type ExternalAgentExitErrorState struct { + disallowEverything +} + +func (s *ExternalAgentExitErrorState) Name() string { + return AgentExitErrorStateName +} + +func (s *ExternalAgentExitErrorState) ExitError(errorType model.ErrorType) error { + + return nil +} + +type ExternalAgentShutdownFailedState struct { + disallowEverything +} + +func (s *ExternalAgentShutdownFailedState) Name() string { + return AgentShutdownFailedStateName +} + +type ExternalAgentExitedState struct { + disallowEverything +} + +func (s *ExternalAgentExitedState) Name() string { + return AgentExitedStateName +} + +type ExternalAgentLaunchErrorState struct { + disallowEverything +} + +func (s *ExternalAgentLaunchErrorState) Name() string { + return AgentLaunchErrorName +} diff --git a/internal/lambda-managed-instances/core/externalagent_states_test.go b/internal/lambda-managed-instances/core/externalagent_states_test.go new file mode 100644 index 0000000..429f0f0 --- /dev/null +++ b/internal/lambda-managed-instances/core/externalagent_states_test.go @@ -0,0 +1,189 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testdata/mockthread" +) + +func TestExternalAgentStateUnknownEventType(t *testing.T) { + agent := NewExternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + require.Equal(t, agent.StartedState, agent.GetState()) + require.Equal(t, errInvalidEventType, agent.Register([]Event{"foo"})) + require.Equal(t, agent.StartedState, agent.GetState()) +} + +func TestExternalAgentStateTransitionsFromStartedState(t *testing.T) { + agent := NewExternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + + require.Equal(t, agent.StartedState, agent.GetState()) + + require.NoError(t, agent.Register([]Event{})) + require.Equal(t, agent.RegisteredState, agent.GetState()) + agent.SetState(agent.StartedState) + + require.NoError(t, agent.LaunchError(model.ErrorAgentPermissionDenied)) + require.Equal(t, agent.LaunchErrorState, agent.GetState()) + agent.SetState(agent.StartedState) + + require.Equal(t, ErrNotAllowed, agent.Ready()) + require.Equal(t, agent.StartedState, agent.GetState()) + + require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) + require.Equal(t, agent.StartedState, agent.GetState()) + + require.Equal(t, ErrNotAllowed, agent.ExitError("Extension.TestError")) + require.Equal(t, agent.StartedState, agent.GetState()) + + require.Equal(t, ErrNotAllowed, agent.ShutdownFailed()) + require.Equal(t, agent.StartedState, agent.GetState()) + + require.Equal(t, ErrNotAllowed, agent.Exited()) + require.Equal(t, agent.StartedState, agent.GetState()) +} + +func TestExternalAgentStateTransitionsFromRegisteredState(t *testing.T) { + agent := NewExternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + agent.SetState(agent.RegisteredState) + + require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) + require.Equal(t, agent.RegisteredState, agent.GetState()) + + require.NoError(t, agent.Ready()) + require.Equal(t, agent.RunningState, agent.GetState()) + + agent.SetState(agent.RegisteredState) + require.NoError(t, agent.InitError("Extension.TestError")) + require.Equal(t, agent.InitErrorState, agent.GetState()) + require.Equal(t, model.ErrorType("Extension.TestError"), agent.errorType) + + agent.SetState(agent.RegisteredState) + require.NoError(t, agent.ExitError("Extension.TestError")) + require.Equal(t, agent.ExitErrorState, agent.GetState()) + require.Equal(t, model.ErrorType("Extension.TestError"), agent.errorType) +} + +func TestExternalAgentStateTransitionsFromReadyState(t *testing.T) { + agent := NewExternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + agent.SetState(agent.ReadyState) + + require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) + require.Equal(t, agent.ReadyState, agent.GetState()) + + require.Equal(t, ErrNotAllowed, agent.Ready()) + require.Equal(t, agent.ReadyState, agent.GetState()) + + require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) + require.Equal(t, agent.ReadyState, agent.GetState()) + + agent.SetState(agent.ReadyState) + require.NoError(t, agent.ExitError("Extension.TestError")) + require.Equal(t, agent.ExitErrorState, agent.GetState()) + require.Equal(t, model.ErrorType("Extension.TestError"), agent.errorType) + + agent.SetState(agent.ReadyState) + require.Equal(t, ErrNotAllowed, agent.Exited()) + require.Equal(t, agent.ReadyState, agent.GetState()) + + require.Equal(t, ErrNotAllowed, agent.ShutdownFailed()) + require.Equal(t, agent.ReadyState, agent.GetState()) +} + +func assertAgentIsInFinalState(t *testing.T, agent *ExternalAgent) { + initialState := agent.GetState() + require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) + require.Equal(t, initialState, agent.GetState()) + require.Equal(t, ErrNotAllowed, agent.Ready()) + require.Equal(t, initialState, agent.GetState()) + require.Equal(t, ErrNotAllowed, agent.ShutdownFailed()) + require.Equal(t, initialState, agent.GetState()) + require.Equal(t, ErrNotAllowed, agent.Exited()) + require.Equal(t, initialState, agent.GetState()) + require.Equal(t, ErrNotAllowed, agent.LaunchError(model.ErrorAgentExtensionLaunch)) + require.Equal(t, initialState, agent.GetState()) + + if agent.InitErrorState == initialState { + require.Equal(t, nil, agent.InitError("Extension.TestError")) + } else { + require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) + } + + require.Equal(t, initialState, agent.GetState()) + + if agent.ExitErrorState == initialState { + require.Equal(t, nil, agent.ExitError("Extension.TestError")) + } else { + require.Equal(t, ErrNotAllowed, agent.ExitError("Extension.TestError")) + } + + require.Equal(t, initialState, agent.GetState()) +} + +func TestExternalAgentStateTransitionsFromInitErrorState(t *testing.T) { + agent := NewExternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + agent.SetState(agent.InitErrorState) + assertAgentIsInFinalState(t, agent) +} + +func TestExternalAgentStateTransitionsFromExitErrorState(t *testing.T) { + agent := NewExternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + agent.SetState(agent.ExitErrorState) + assertAgentIsInFinalState(t, agent) +} + +func TestExternalAgentStateTransitionsFromShutdownFailedState(t *testing.T) { + agent := NewExternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + agent.SetState(agent.ShutdownFailedState) + assertAgentIsInFinalState(t, agent) +} + +func TestExternalAgentStateTransitionsFromExitedState(t *testing.T) { + agent := NewExternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + agent.SetState(agent.ExitedState) + assertAgentIsInFinalState(t, agent) +} + +func TestExternalAgentStateTransitionsFromRunningState(t *testing.T) { + agent := NewExternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + agent.SetState(agent.RunningState) + require.Equal(t, agent.RunningState, agent.GetState()) + + require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) + require.Equal(t, agent.RunningState, agent.GetState()) + + require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) + require.Equal(t, agent.RunningState, agent.GetState()) + + require.NoError(t, agent.ShutdownFailed()) + require.Equal(t, agent.ShutdownFailedState, agent.GetState()) + + agent.SetState(agent.RunningState) + require.NoError(t, agent.Exited()) + require.Equal(t, agent.ExitedState, agent.GetState()) + + agent.SetState(agent.RunningState) + require.NoError(t, agent.Ready()) + require.Equal(t, agent.RunningState, agent.GetState()) +} + +func TestExternalAgentStateTransitionsFromLaunchErrorState(t *testing.T) { + agent := NewExternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + agent.SetState(agent.LaunchErrorState) + assertAgentIsInFinalState(t, agent) +} diff --git a/internal/lambda-managed-instances/core/flow.go b/internal/lambda-managed-instances/core/flow.go new file mode 100644 index 0000000..c332bfb --- /dev/null +++ b/internal/lambda-managed-instances/core/flow.go @@ -0,0 +1,88 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "context" +) + +type InitFlowSynchronization interface { + SetExternalAgentsRegisterCount(uint16) error + SetAgentsReadyCount(uint16) error + + ExternalAgentRegistered() error + AwaitExternalAgentsRegistered(context.Context) error + + RuntimeReady() error + AwaitRuntimeReady(context.Context) error + + AgentReady() error + AwaitAgentsReady(context.Context) error + + CancelWithError(error) + + Clear() +} + +type initFlowSynchronizationImpl struct { + externalAgentsRegisteredGate Gate + runtimeReadyGate Gate + agentReadyGate Gate +} + +func (s *initFlowSynchronizationImpl) SetExternalAgentsRegisterCount(externalAgentsNumber uint16) error { + return s.externalAgentsRegisteredGate.SetCount(externalAgentsNumber) +} + +func (s *initFlowSynchronizationImpl) SetAgentsReadyCount(agentCount uint16) error { + return s.agentReadyGate.SetCount(agentCount) +} + +func (s *initFlowSynchronizationImpl) AwaitRuntimeReady(ctx context.Context) error { + _, err := s.runtimeReadyGate.AwaitGateCondition(ctx) + return err +} + +func (s *initFlowSynchronizationImpl) AwaitExternalAgentsRegistered(ctx context.Context) error { + _, err := s.externalAgentsRegisteredGate.AwaitGateCondition(ctx) + return err +} + +func (s *initFlowSynchronizationImpl) AwaitAgentsReady(ctx context.Context) error { + _, err := s.agentReadyGate.AwaitGateCondition(ctx) + return err +} + +func (s *initFlowSynchronizationImpl) RuntimeReady() error { + return s.runtimeReadyGate.WalkThrough(nil) +} + +func (s *initFlowSynchronizationImpl) AgentReady() error { + return s.agentReadyGate.WalkThrough(nil) +} + +func (s *initFlowSynchronizationImpl) ExternalAgentRegistered() error { + return s.externalAgentsRegisteredGate.WalkThrough(nil) +} + +func (s *initFlowSynchronizationImpl) CancelWithError(err error) { + s.externalAgentsRegisteredGate.CancelWithError(err) + s.runtimeReadyGate.CancelWithError(err) + s.agentReadyGate.CancelWithError(err) +} + +func (s *initFlowSynchronizationImpl) Clear() { + s.externalAgentsRegisteredGate.Clear() + s.runtimeReadyGate.Clear() + s.agentReadyGate.Clear() +} + +func NewInitFlowSynchronization() InitFlowSynchronization { + initFlow := &initFlowSynchronizationImpl{ + runtimeReadyGate: NewGate(1), + externalAgentsRegisteredGate: NewGate(0), + agentReadyGate: NewGate(maxAgentsLimit), + } + return initFlow +} diff --git a/internal/lambda-managed-instances/core/gates.go b/internal/lambda-managed-instances/core/gates.go new file mode 100644 index 0000000..a00bdb6 --- /dev/null +++ b/internal/lambda-managed-instances/core/gates.go @@ -0,0 +1,144 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "context" + "errors" + "math" + "sync" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" +) + +const maxAgentsLimit uint16 = math.MaxUint16 + +type Gate interface { + Register(count uint16) + Reset() + SetCount(uint16) error + WalkThrough(interface{}) error + AwaitGateCondition(ctx context.Context) (interface{}, error) + CancelWithError(error) + Clear() +} + +type gateImpl struct { + gateCondition *sync.Cond + count uint16 + arrived uint16 + value interface{} + canceled bool + err error +} + +func (g *gateImpl) Register(count uint16) { + g.gateCondition.L.Lock() + defer g.gateCondition.L.Unlock() + g.count += count +} + +func (g *gateImpl) SetCount(count uint16) error { + g.gateCondition.L.Lock() + defer g.gateCondition.L.Unlock() + + if count > maxAgentsLimit || count < g.arrived { + return ErrGateIntegrity + } + g.count = count + return nil +} + +func (g *gateImpl) Reset() { + g.gateCondition.L.Lock() + defer g.gateCondition.L.Unlock() + if !g.canceled { + g.arrived = 0 + } +} + +var ErrGateIntegrity = errors.New("ErrGateIntegrity") + +var ErrGateCanceled = errors.New("ErrGateCanceled") + +func (g *gateImpl) WalkThrough(value interface{}) error { + g.gateCondition.L.Lock() + defer g.gateCondition.L.Unlock() + + if g.arrived == g.count { + return ErrGateIntegrity + } + + g.arrived++ + g.value = value + + if g.arrived == g.count { + g.gateCondition.Broadcast() + } + + return nil +} + +func (g *gateImpl) awaitGateCondition() error { + g.gateCondition.L.Lock() + defer g.gateCondition.L.Unlock() + + for g.arrived != g.count && !g.canceled { + g.gateCondition.Wait() + } + + if g.canceled { + if g.err != nil { + return g.err + } + return ErrGateCanceled + } + + return nil +} + +func (g *gateImpl) AwaitGateCondition(ctx context.Context) (interface{}, error) { + var err error + errorChan := make(chan error) + + go func() { + errorChan <- g.awaitGateCondition() + }() + + select { + case err = <-errorChan: + break + case <-ctx.Done(): + err = interop.ErrTimeout + g.CancelWithError(err) + break + } + + return g.value, err +} + +func (g *gateImpl) CancelWithError(err error) { + g.gateCondition.L.Lock() + defer g.gateCondition.L.Unlock() + g.canceled = true + g.err = err + g.gateCondition.Broadcast() +} + +func (g *gateImpl) Clear() { + g.gateCondition.L.Lock() + defer g.gateCondition.L.Unlock() + + g.canceled = false + g.arrived = 0 + g.value = nil + g.err = nil +} + +func NewGate(count uint16) Gate { + return &gateImpl{ + count: count, + gateCondition: sync.NewCond(&sync.Mutex{}), + } +} diff --git a/internal/lambda-managed-instances/core/gates_test.go b/internal/lambda-managed-instances/core/gates_test.go new file mode 100644 index 0000000..95c25a6 --- /dev/null +++ b/internal/lambda-managed-instances/core/gates_test.go @@ -0,0 +1,136 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" +) + +func TestWalkThrough(t *testing.T) { + g := NewGate(1) + assert.NoError(t, g.WalkThrough(nil)) +} + +func TestWalkThroughTwice(t *testing.T) { + g := NewGate(1) + assert.NoError(t, g.WalkThrough(nil)) + assert.Equal(t, ErrGateIntegrity, g.WalkThrough(nil)) +} + +func TestSetCount(t *testing.T) { + g := NewGate(2) + assert.NoError(t, g.WalkThrough(nil)) + assert.NoError(t, g.WalkThrough(nil)) + assert.Equal(t, ErrGateIntegrity, g.SetCount(1)) + assert.NoError(t, g.SetCount(2)) + assert.Equal(t, ErrGateIntegrity, g.WalkThrough(nil)) + assert.NoError(t, g.SetCount(3)) + assert.NoError(t, g.WalkThrough(nil)) +} + +func TestReset(t *testing.T) { + g := NewGate(1) + assert.NoError(t, g.WalkThrough(nil)) + g.Reset() + assert.NoError(t, g.WalkThrough(nil)) +} + +func TestCancel(t *testing.T) { + g := NewGate(1) + ch := make(chan error) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + wg.Done() + _, err := g.AwaitGateCondition(context.Background()) + ch <- err + }() + + wg.Wait() + g.CancelWithError(nil) + err := <-ch + + assert.Equal(t, ErrGateCanceled, err) +} + +func TestCancelWithError(t *testing.T) { + g := NewGate(1) + ch := make(chan error) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + wg.Done() + _, err := g.AwaitGateCondition(context.Background()) + ch <- err + }() + + wg.Wait() + err := errors.New("MyErr") + g.CancelWithError(err) + receivedErr := <-ch + + assert.Equal(t, err, receivedErr) +} + +func TestUseAfterCancel(t *testing.T) { + g := NewGate(1) + err := errors.New("MyErr") + g.CancelWithError(err) + _, awaitErr := g.AwaitGateCondition(context.Background()) + assert.Equal(t, err, awaitErr) + g.Reset() + _, awaitErr = g.AwaitGateCondition(context.Background()) + assert.Equal(t, err, awaitErr) +} + +func TestAwaitGateConditionWithDeadlineWalkthrough(t *testing.T) { + g := NewGate(1) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + require.NoError(t, g.WalkThrough(nil)) + + _, err := g.AwaitGateCondition(ctx) + + assert.Equal(t, nil, err) +} + +func TestAwaitGateConditionWithDeadlineTimeout(t *testing.T) { + g := NewGate(1) + + ctx, cancel := context.WithCancel(context.Background()) + + cancel() + + require.NoError(t, g.WalkThrough(nil)) + + _, err := g.AwaitGateCondition(ctx) + + assert.Equal(t, interop.ErrTimeout, err) +} + +func BenchmarkAwaitGateCondition(b *testing.B) { + g := NewGate(1) + + for n := 0; n < b.N; n++ { + go func() { require.NoError(b, g.WalkThrough(nil)) }() + _, err := g.AwaitGateCondition(context.Background()) + if err != nil { + panic(err) + } + g.Reset() + } +} diff --git a/internal/lambda-managed-instances/core/internalagent.go b/internal/lambda-managed-instances/core/internalagent.go new file mode 100644 index 0000000..50c3666 --- /dev/null +++ b/internal/lambda-managed-instances/core/internalagent.go @@ -0,0 +1,133 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "fmt" + "time" + + "github.com/google/uuid" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core/statejson" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type InternalAgent struct { + name string + id uuid.UUID + + ManagedThread Suspendable + + currentState InternalAgentState + stateLastModified time.Time + + StartedState InternalAgentState + RegisteredState InternalAgentState + RunningState InternalAgentState + ReadyState InternalAgentState + InitErrorState InternalAgentState + ExitErrorState InternalAgentState + + errorType model.ErrorType +} + +func NewInternalAgent(name string, initFlow InitFlowSynchronization) *InternalAgent { + agent := &InternalAgent{ + name: name, + id: uuid.New(), + ManagedThread: NewManagedThread(), + } + + agent.StartedState = &InternalAgentStartedState{agent: agent} + agent.RegisteredState = &InternalAgentRegisteredState{agent: agent, initFlow: initFlow} + agent.RunningState = &InternalAgentRunningState{agent: agent} + agent.ReadyState = &InternalAgentReadyState{agent: agent} + agent.InitErrorState = &InternalAgentInitErrorState{} + agent.ExitErrorState = &InternalAgentExitErrorState{} + + agent.setStateUnsafe(agent.StartedState) + + return agent +} + +func (s *InternalAgent) Name() string { + return s.name +} + +func (s *InternalAgent) ID() uuid.UUID { + return s.id +} + +func (s InternalAgent) String() string { + return fmt.Sprintf("%s (%s)", s.name, s.id) +} + +func (s *InternalAgent) SuspendUnsafe() { + s.ManagedThread.SuspendUnsafe() +} + +func (s *InternalAgent) Release() { + s.ManagedThread.Release() +} + +func (s *InternalAgent) SetState(state InternalAgentState) { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + s.setStateUnsafe(state) +} + +func (s *InternalAgent) setStateUnsafe(state InternalAgentState) { + s.currentState = state + s.stateLastModified = time.Now() +} + +func (s *InternalAgent) GetState() InternalAgentState { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState +} + +func (s *InternalAgent) Register(events []Event) error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.Register(events) +} + +func (s *InternalAgent) Ready() error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.Ready() +} + +func (s *InternalAgent) InitError(errorType model.ErrorType) error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.InitError(errorType) +} + +func (s *InternalAgent) ExitError(errorType model.ErrorType) error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.ExitError(errorType) +} + +func (s *InternalAgent) ErrorType() model.ErrorType { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.errorType +} + +func (s *InternalAgent) GetAgentDescription() statejson.ExtensionDescription { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return statejson.ExtensionDescription{ + Name: s.name, + ID: s.id.String(), + State: statejson.StateDescription{ + Name: s.currentState.Name(), + LastModified: s.stateLastModified, + }, + ErrorType: string(s.errorType), + } +} diff --git a/internal/lambda-managed-instances/core/internalagent_states.go b/internal/lambda-managed-instances/core/internalagent_states.go new file mode 100644 index 0000000..86e7b71 --- /dev/null +++ b/internal/lambda-managed-instances/core/internalagent_states.go @@ -0,0 +1,137 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "log/slog" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type InternalAgentState interface { + Register([]Event) error + Ready() error + InitError(errorType model.ErrorType) error + ExitError(errorType model.ErrorType) error + Name() string +} + +type InternalAgentStartedState struct { + disallowEverything + agent *InternalAgent +} + +func (s *InternalAgentStartedState) Register(events []Event) error { + s.agent.setStateUnsafe(s.agent.RegisteredState) + return nil +} + +func (s *InternalAgentStartedState) Name() string { + return AgentStartedStateName +} + +type InternalAgentRegisteredState struct { + disallowEverything + agent *InternalAgent + initFlow InitFlowSynchronization +} + +func (s *InternalAgentRegisteredState) Ready() error { + s.agent.setStateUnsafe(s.agent.ReadyState) + if err := s.initFlow.AgentReady(); err != nil { + slog.Error("Agent ready failed", "err", err) + } + s.agent.ManagedThread.SuspendUnsafe() + + if s.agent.currentState != s.agent.ReadyState { + return ErrConcurrentStateModification + } + s.agent.setStateUnsafe(s.agent.RunningState) + + return nil +} + +func (s *InternalAgentRegisteredState) InitError(errorType model.ErrorType) error { + s.agent.setStateUnsafe(s.agent.InitErrorState) + s.agent.errorType = errorType + return nil +} + +func (s *InternalAgentRegisteredState) ExitError(errorType model.ErrorType) error { + s.agent.setStateUnsafe(s.agent.ExitErrorState) + s.agent.errorType = errorType + return nil +} + +func (s *InternalAgentRegisteredState) Name() string { + return AgentRegisteredStateName +} + +type InternalAgentReadyState struct { + disallowEverything + agent *InternalAgent +} + +func (s *InternalAgentReadyState) ExitError(errorType model.ErrorType) error { + s.agent.setStateUnsafe(s.agent.ExitErrorState) + s.agent.errorType = errorType + return nil +} + +func (s *InternalAgentReadyState) Name() string { + return AgentReadyStateName +} + +type InternalAgentRunningState struct { + disallowEverything + agent *InternalAgent +} + +func (s *InternalAgentRunningState) Ready() error { + s.agent.setStateUnsafe(s.agent.ReadyState) + s.agent.ManagedThread.SuspendUnsafe() + + if s.agent.currentState != s.agent.ReadyState { + return ErrConcurrentStateModification + } + s.agent.setStateUnsafe(s.agent.RunningState) + + return nil +} + +func (s *InternalAgentRunningState) ExitError(errorType model.ErrorType) error { + s.agent.setStateUnsafe(s.agent.ExitErrorState) + s.agent.errorType = errorType + return nil +} + +func (s *InternalAgentRunningState) Name() string { + return AgentRunningStateName +} + +type InternalAgentInitErrorState struct { + disallowEverything +} + +func (s *InternalAgentInitErrorState) Name() string { + return AgentInitErrorStateName +} + +func (s *InternalAgentInitErrorState) InitError(errorType model.ErrorType) error { + + return nil +} + +type InternalAgentExitErrorState struct { + disallowEverything +} + +func (s *InternalAgentExitErrorState) Name() string { + return AgentExitErrorStateName +} + +func (s *InternalAgentExitErrorState) ExitError(errorType model.ErrorType) error { + + return nil +} diff --git a/internal/lambda-managed-instances/core/internalagent_states_test.go b/internal/lambda-managed-instances/core/internalagent_states_test.go new file mode 100644 index 0000000..89060d1 --- /dev/null +++ b/internal/lambda-managed-instances/core/internalagent_states_test.go @@ -0,0 +1,119 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testdata/mockthread" +) + +func TestInternalAgentStateTransitionsFromStartedState(t *testing.T) { + agent := NewInternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + + require.Equal(t, agent.StartedState, agent.GetState()) + require.NoError(t, agent.Register([]Event{})) + require.Equal(t, agent.RegisteredState, agent.GetState()) + + agent.SetState(agent.StartedState) + require.Equal(t, ErrNotAllowed, agent.Ready()) + require.Equal(t, agent.StartedState, agent.GetState()) + + require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) + require.Equal(t, agent.StartedState, agent.GetState()) + + require.Equal(t, ErrNotAllowed, agent.ExitError("Extension.TestError")) + require.Equal(t, agent.StartedState, agent.GetState()) +} + +func TestInternalAgentStateTransitionsFromRegisteredState(t *testing.T) { + agent := NewInternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + agent.SetState(agent.RegisteredState) + + require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) + require.Equal(t, agent.RegisteredState, agent.GetState()) + + require.NoError(t, agent.Ready()) + require.Equal(t, agent.RunningState, agent.GetState()) + + agent.SetState(agent.RegisteredState) + require.NoError(t, agent.InitError("Extension.TestError")) + require.Equal(t, agent.InitErrorState, agent.GetState()) + require.Equal(t, model.ErrorType("Extension.TestError"), agent.errorType) + + agent.SetState(agent.RegisteredState) + require.NoError(t, agent.ExitError("Extension.TestError")) + require.Equal(t, agent.ExitErrorState, agent.GetState()) + require.Equal(t, model.ErrorType("Extension.TestError"), agent.errorType) +} + +func TestInternalAgentStateTransitionsFromReadyState(t *testing.T) { + agent := NewInternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + agent.SetState(agent.ReadyState) + + require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) + require.Equal(t, agent.ReadyState, agent.GetState()) + + require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) + require.Equal(t, agent.ReadyState, agent.GetState()) + + agent.SetState(agent.ReadyState) + require.NoError(t, agent.ExitError("Extension.TestError")) + require.Equal(t, agent.ExitErrorState, agent.GetState()) + require.Equal(t, model.ErrorType("Extension.TestError"), agent.errorType) + + agent.SetState(agent.ReadyState) + require.Equal(t, ErrNotAllowed, agent.Ready()) + require.Equal(t, agent.ReadyState, agent.GetState()) +} + +func TestInternalAgentStateTransitionsFromInitErrorState(t *testing.T) { + agent := NewInternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + agent.SetState(agent.InitErrorState) + + require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) + require.Equal(t, agent.InitErrorState, agent.GetState()) + require.Equal(t, nil, agent.InitError("Extension.TestError")) + require.Equal(t, agent.InitErrorState, agent.GetState()) + require.Equal(t, ErrNotAllowed, agent.ExitError("Extension.TestError")) + require.Equal(t, agent.InitErrorState, agent.GetState()) + require.Equal(t, ErrNotAllowed, agent.Ready()) + require.Equal(t, agent.InitErrorState, agent.GetState()) +} + +func TestInternalAgentStateTransitionsFromExitErrorState(t *testing.T) { + agent := NewInternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + agent.SetState(agent.ExitErrorState) + + require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) + require.Equal(t, agent.ExitErrorState, agent.GetState()) + require.Equal(t, nil, agent.ExitError("Extension.TestError")) + require.Equal(t, agent.ExitErrorState, agent.GetState()) + require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) + require.Equal(t, agent.ExitErrorState, agent.GetState()) + require.Equal(t, ErrNotAllowed, agent.Ready()) + require.Equal(t, agent.ExitErrorState, agent.GetState()) +} + +func TestInternalAgentStateTransitionsFromRunningState(t *testing.T) { + agent := NewInternalAgent("name", &mockInitFlowSynchronization{}) + agent.ManagedThread = &mockthread.MockManagedThread{} + agent.SetState(agent.RunningState) + require.Equal(t, agent.RunningState, agent.GetState()) + + require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) + require.Equal(t, agent.RunningState, agent.GetState()) + + agent.SetState(agent.RunningState) + require.NoError(t, agent.Ready()) + require.Equal(t, agent.RunningState, agent.GetState()) +} diff --git a/internal/lambda-managed-instances/core/registrations.go b/internal/lambda-managed-instances/core/registrations.go new file mode 100644 index 0000000..5506a71 --- /dev/null +++ b/internal/lambda-managed-instances/core/registrations.go @@ -0,0 +1,350 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "errors" + "log/slog" + "os" + "sync" + + "github.com/google/uuid" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core/statejson" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type registrationServiceState int + +const ( + registrationServiceOn registrationServiceState = iota + registrationServiceOff +) + +const MaxAgentsAllowed = 10 + +type Event string + +const ( + ShutdownEvent Event = "SHUTDOWN" +) + +var ErrRegistrationServiceOff = errors.New("ErrRegistrationServiceOff") + +var ErrTooManyExtensions = errors.New("ErrTooManyExtensions") + +func MapErrorToAgentInfoErrorType(err error) model.ErrorType { + if errors.Is(err, os.ErrPermission) { + return model.ErrorAgentPermissionDenied + } + if err == ErrTooManyExtensions { + return model.ErrorAgentTooManyExtensions + } + return model.ErrorAgentExtensionLaunch +} + +type AgentInfo struct { + Name string + State string + Subscriptions []string + ErrorType model.ErrorType +} + +type RegistrationService interface { + CreateExternalAgent(agentName string) (*ExternalAgent, error) + CreateInternalAgent(agentName string) (*InternalAgent, error) + PreregisterRuntime(r *Runtime) error + SetFunctionMetadata(metadata model.FunctionMetadata) + GetFunctionMetadata() model.FunctionMetadata + GetRuntime() *Runtime + GetRegisteredAgentsSize() uint16 + FindExternalAgentByName(agentName string) (*ExternalAgent, bool) + FindInternalAgentByName(agentName string) (*InternalAgent, bool) + FindExternalAgentByID(agentID uuid.UUID) (*ExternalAgent, bool) + FindInternalAgentByID(agentID uuid.UUID) (*InternalAgent, bool) + TurnOff() + InitFlow() InitFlowSynchronization + GetInternalStateDescriptor(appCtx appctx.ApplicationContext) func() statejson.InternalStateDescription + GetInternalAgents() []*InternalAgent + GetExternalAgents() []*ExternalAgent + GetSubscribedExternalAgents(eventType Event) []*ExternalAgent + CountAgents() int + Clear() + AgentsInfo() []AgentInfo + CancelFlows(err error) +} + +type registrationServiceImpl struct { + runtime *Runtime + internalAgents AgentsMap[*InternalAgent] + externalAgents AgentsMap[*ExternalAgent] + state registrationServiceState + mutex *sync.Mutex + initFlow InitFlowSynchronization + functionMetadata model.FunctionMetadata + cancelOnce sync.Once +} + +func (s *registrationServiceImpl) Clear() { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.runtime = nil + s.internalAgents.Clear() + s.externalAgents.Clear() + s.state = registrationServiceOn + s.cancelOnce = sync.Once{} +} + +func (s *registrationServiceImpl) InitFlow() InitFlowSynchronization { + return s.initFlow +} + +func (s *registrationServiceImpl) GetInternalStateDescriptor(appCtx appctx.ApplicationContext) func() statejson.InternalStateDescription { + return func() statejson.InternalStateDescription { + return s.getInternalStateDescription(appCtx) + } +} + +func (s *registrationServiceImpl) getInternalStateDescription(appCtx appctx.ApplicationContext) statejson.InternalStateDescription { + isd := statejson.InternalStateDescription{ + Extensions: []statejson.ExtensionDescription{}, + } + + if s.runtime != nil { + + rtdesc := s.runtime.GetRuntimeDescription() + isd.Runtime = &rtdesc + } + + s.mutex.Lock() + defer s.mutex.Unlock() + + s.internalAgents.Visit(func(agent *InternalAgent) { + isd.Extensions = append(isd.Extensions, agent.GetAgentDescription()) + }) + + s.externalAgents.Visit(func(agent *ExternalAgent) { + isd.Extensions = append(isd.Extensions, agent.GetAgentDescription()) + }) + + if fatalerror, found := appctx.LoadFirstFatalError(appCtx); found { + isd.FirstFatalError = string(fatalerror.ErrorType()) + } + + return isd +} + +func (s *registrationServiceImpl) CountAgents() int { + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.countAgentsUnsafe() +} + +func (s *registrationServiceImpl) countAgentsUnsafe() int { + res := 0 + s.externalAgents.Visit(func(a *ExternalAgent) { + res++ + }) + s.internalAgents.Visit(func(a *InternalAgent) { + res++ + }) + return res +} + +func (s *registrationServiceImpl) GetExternalAgents() []*ExternalAgent { + agents := []*ExternalAgent{} + s.externalAgents.Visit(func(a *ExternalAgent) { + agents = append(agents, a) + }) + return agents +} + +func (s *registrationServiceImpl) GetInternalAgents() []*InternalAgent { + agents := []*InternalAgent{} + s.internalAgents.Visit(func(a *InternalAgent) { + agents = append(agents, a) + }) + return agents +} + +func (s *registrationServiceImpl) GetSubscribedExternalAgents(eventType Event) []*ExternalAgent { + agents := []*ExternalAgent{} + s.externalAgents.Visit(func(a *ExternalAgent) { + if a.IsSubscribed(eventType) { + agents = append(agents, a) + } + }) + return agents +} + +func (s *registrationServiceImpl) CreateExternalAgent(agentName string) (*ExternalAgent, error) { + agent := NewExternalAgent(agentName, s.initFlow) + + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state != registrationServiceOn { + return nil, ErrRegistrationServiceOff + } + + if _, found := s.internalAgents.FindByName(agentName); found { + return nil, ErrAgentNameCollision + } + + if err := s.externalAgents.Insert(agent); err != nil { + return nil, err + } + + return agent, nil +} + +func (s *registrationServiceImpl) CreateInternalAgent(agentName string) (*InternalAgent, error) { + agent := NewInternalAgent(agentName, s.initFlow) + + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state != registrationServiceOn { + return nil, ErrRegistrationServiceOff + } + + if s.countAgentsUnsafe() >= MaxAgentsAllowed { + return nil, ErrTooManyExtensions + } + + if _, found := s.externalAgents.FindByName(agentName); found { + return nil, ErrAgentNameCollision + } + + if err := s.internalAgents.Insert(agent); err != nil { + return nil, err + } + + return agent, nil +} + +func (s *registrationServiceImpl) PreregisterRuntime(r *Runtime) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state != registrationServiceOn { + return ErrRegistrationServiceOff + } + + s.runtime = r + + return nil +} + +func (s *registrationServiceImpl) SetFunctionMetadata(metadata model.FunctionMetadata) { + s.functionMetadata = metadata +} + +func (s *registrationServiceImpl) GetFunctionMetadata() model.FunctionMetadata { + return s.functionMetadata +} + +func (s *registrationServiceImpl) GetRuntime() *Runtime { + s.mutex.Lock() + defer s.mutex.Unlock() + return s.runtime +} + +func (s *registrationServiceImpl) GetRegisteredAgentsSize() uint16 { + s.mutex.Lock() + defer s.mutex.Unlock() + return uint16(s.externalAgents.Size()) + uint16(s.internalAgents.Size()) +} + +func (s *registrationServiceImpl) FindExternalAgentByName(name string) (agent *ExternalAgent, found bool) { + s.mutex.Lock() + defer s.mutex.Unlock() + if agent, found = s.externalAgents.FindByName(name); found { + return agent, found + } + return agent, found +} + +func (s *registrationServiceImpl) FindInternalAgentByName(name string) (agent *InternalAgent, found bool) { + s.mutex.Lock() + defer s.mutex.Unlock() + if agent, found = s.internalAgents.FindByName(name); found { + return agent, found + } + return agent, found +} + +func (s *registrationServiceImpl) FindExternalAgentByID(agentID uuid.UUID) (agent *ExternalAgent, found bool) { + s.mutex.Lock() + defer s.mutex.Unlock() + if agent, found = s.externalAgents.FindByID(agentID); found { + return agent, found + } + return agent, found +} + +func (s *registrationServiceImpl) FindInternalAgentByID(agentID uuid.UUID) (agent *InternalAgent, found bool) { + s.mutex.Lock() + defer s.mutex.Unlock() + if agent, found = s.internalAgents.FindByID(agentID); found { + return agent, found + } + return agent, found +} + +func (s *registrationServiceImpl) AgentsInfo() []AgentInfo { + s.mutex.Lock() + defer s.mutex.Unlock() + + agentsInfo := []AgentInfo{} + for _, agent := range s.GetExternalAgents() { + agentsInfo = append(agentsInfo, AgentInfo{ + agent.Name(), + agent.GetState().Name(), + agent.SubscribedEvents(), + agent.ErrorType(), + }) + } + + for _, agent := range s.GetInternalAgents() { + agentsInfo = append(agentsInfo, AgentInfo{ + agent.Name(), + agent.GetState().Name(), + []string{}, + agent.ErrorType(), + }) + } + + return agentsInfo +} + +func (s *registrationServiceImpl) TurnOff() { + s.mutex.Lock() + defer s.mutex.Unlock() + s.state = registrationServiceOff +} + +func (s *registrationServiceImpl) CancelFlows(err error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.cancelOnce.Do(func() { + slog.Debug("Canceling flows", "err", err) + s.initFlow.CancelWithError(err) + }) +} + +func NewRegistrationService(initFlow InitFlowSynchronization) RegistrationService { + return ®istrationServiceImpl{ + mutex: &sync.Mutex{}, + state: registrationServiceOn, + internalAgents: NewAgentsMap[*InternalAgent](), + externalAgents: NewAgentsMap[*ExternalAgent](), + initFlow: initFlow, + cancelOnce: sync.Once{}, + } +} diff --git a/internal/lambda-managed-instances/core/registrations_test.go b/internal/lambda-managed-instances/core/registrations_test.go new file mode 100644 index 0000000..ffde84e --- /dev/null +++ b/internal/lambda-managed-instances/core/registrations_test.go @@ -0,0 +1,200 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "context" + "errors" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func TestRegistrationServiceHappyPathDuringInit(t *testing.T) { + + initFlow := NewInitFlowSynchronization() + registrationService := NewRegistrationService(initFlow) + + registrationService.SetFunctionMetadata(model.FunctionMetadata{ + FunctionName: "AWS_LAMBDA_FUNCTION_NAME", + FunctionVersion: "AWS_LAMBDA_FUNCTION_VERSION", + Handler: "_HANDLER", + }) + + extAgentNames := []string{"agentName1", "agentName2"} + assert.NoError(t, initFlow.SetExternalAgentsRegisterCount(uint16(len(extAgentNames)))) + + extAgent1, err := registrationService.CreateExternalAgent(extAgentNames[0]) + assert.NoError(t, err) + assert.Equal(t, extAgent1.String(), fmt.Sprintf("agentName1 (%s)", extAgent1.id)) + + extAgent2, err := registrationService.CreateExternalAgent(extAgentNames[1]) + assert.NoError(t, err) + + go func() { + for _, agentName := range extAgentNames { + agent, found := registrationService.FindExternalAgentByName(agentName) + assert.True(t, found) + + assert.NoError(t, agent.Register([]Event{ShutdownEvent})) + } + }() + + assert.NoError(t, initFlow.AwaitExternalAgentsRegistered(context.Background())) + + runtime := NewRuntime(initFlow) + assert.NoError(t, registrationService.PreregisterRuntime(runtime)) + + intAgentNames := []string{"intAgentName1", "intAgentName2"} + + intAgent1, err := registrationService.CreateInternalAgent(intAgentNames[0]) + assert.NoError(t, err) + assert.Equal(t, intAgent1.String(), fmt.Sprintf("intAgentName1 (%s)", intAgent1.id)) + + intAgent2, err := registrationService.CreateInternalAgent(intAgentNames[1]) + assert.NoError(t, err) + + go func() { + for _, agentName := range intAgentNames { + agent, found := registrationService.FindInternalAgentByName(agentName) + assert.True(t, found) + + assert.NoError(t, agent.Register([]Event{})) + } + assert.NoError(t, runtime.Ready()) + }() + + assert.NoError(t, initFlow.AwaitRuntimeReady(context.Background())) + registrationService.TurnOff() + + extAgent1Description := extAgent1.GetAgentDescription() + assert.Equal(t, extAgent1Description.Name, "agentName1") + assert.Equal(t, extAgent1Description.State.Name, "Registered") + assert.Equal(t, extAgent1Description.ErrorType, "") + + intAgent1Description := intAgent1.GetAgentDescription() + assert.Equal(t, intAgent1Description.Name, "intAgentName1") + assert.Equal(t, intAgent1Description.State.Name, "Registered") + assert.Equal(t, intAgent1Description.ErrorType, "") + + assert.NoError(t, initFlow.SetAgentsReadyCount(registrationService.GetRegisteredAgentsSize())) + go func() { + for _, agentName := range intAgentNames { + agent, found := registrationService.FindInternalAgentByName(agentName) + assert.True(t, found) + go func() { assert.NoError(t, agent.Ready()) }() + } + + for _, agentName := range extAgentNames { + agent, found := registrationService.FindExternalAgentByName(agentName) + assert.True(t, found) + go func() { assert.NoError(t, agent.Ready()) }() + } + }() + + assert.NoError(t, initFlow.AwaitAgentsReady(context.Background())) + + expectedAgents := []AgentInfo{ + {extAgent1.Name(), "Ready", []string{"SHUTDOWN"}, ""}, + {extAgent2.Name(), "Ready", []string{"SHUTDOWN"}, ""}, + {intAgent1.Name(), "Ready", []string{}, ""}, + {intAgent2.Name(), "Ready", []string{}, ""}, + } + + assert.Len(t, registrationService.AgentsInfo(), len(expectedAgents)) + + actualAgents := map[string]AgentInfo{} + for _, agentInfo := range registrationService.AgentsInfo() { + actualAgents[agentInfo.Name] = agentInfo + } + + for _, agentInfo := range expectedAgents { + assert.Contains(t, actualAgents, agentInfo.Name) + assert.Equal(t, actualAgents[agentInfo.Name].Name, agentInfo.Name) + assert.Equal(t, actualAgents[agentInfo.Name].State, agentInfo.State) + for _, event := range agentInfo.Subscriptions { + assert.Contains(t, actualAgents[agentInfo.Name].Subscriptions, event) + } + } +} + +func TestGetAgents(t *testing.T) { + initFlow := NewInitFlowSynchronization() + registrationService := NewRegistrationService(initFlow) + + externalAgentNames := map[string]bool{ + "external/1": true, + "external/2": true, + } + + for agentName := range externalAgentNames { + _, err := registrationService.CreateExternalAgent(agentName) + require.NoError(t, err) + } + + internalAgentNames := map[string]bool{ + "internal/1": true, + "internal/2": true, + "internal/3": true, + } + + for agentName := range internalAgentNames { + _, err := registrationService.CreateInternalAgent(agentName) + require.NoError(t, err) + } + + actualExternalAgents := registrationService.GetExternalAgents() + actualInternalAgents := registrationService.GetInternalAgents() + + assert.Equal(t, len(actualExternalAgents), 2) + assert.Equal(t, len(actualInternalAgents), 3) + + for _, agent := range actualExternalAgents { + delete(externalAgentNames, agent.Name()) + } + + assert.Equal(t, 0, len(externalAgentNames), "external agents: %v are not retrieved back from registration service", externalAgentNames) + + for _, agent := range actualInternalAgents { + delete(internalAgentNames, agent.Name()) + } + + assert.Equal(t, 0, len(internalAgentNames), "internal agents: %v are not retrieved back from registration service", internalAgentNames) +} + +func TestMapErrorToAgentInfoErrorType(t *testing.T) { + tests := []struct { + name string + err error + expectedErr model.ErrorType + }{ + { + name: "ErrTooManyExtensions returns too many extensions error", + err: ErrTooManyExtensions, + expectedErr: model.ErrorAgentTooManyExtensions, + }, + { + name: "wrapped permission error with fmt.Errorf returns permission denied error", + err: fmt.Errorf("failed to start command: %w", os.ErrPermission), + expectedErr: model.ErrorAgentPermissionDenied, + }, + { + name: "random error returns extension launch error", + err: errors.New("some random error"), + expectedErr: model.ErrorAgentExtensionLaunch, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := MapErrorToAgentInfoErrorType(tt.err) + assert.Equal(t, tt.expectedErr, result, "Expected %s but got %s for error: %v", tt.expectedErr, result, tt.err) + }) + } +} diff --git a/internal/lambda-managed-instances/core/runtime_state_names.go b/internal/lambda-managed-instances/core/runtime_state_names.go new file mode 100644 index 0000000..23f7cfe --- /dev/null +++ b/internal/lambda-managed-instances/core/runtime_state_names.go @@ -0,0 +1,11 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +const ( + RuntimeStartedStateName = "Started" + RuntimeInitErrorStateName = "InitError" + RuntimeReadyStateName = "Ready" + RuntimeRunningStateName = "Running" +) diff --git a/internal/lambda-managed-instances/core/statejson/description.go b/internal/lambda-managed-instances/core/statejson/description.go new file mode 100644 index 0000000..f265d05 --- /dev/null +++ b/internal/lambda-managed-instances/core/statejson/description.go @@ -0,0 +1,65 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package statejson + +import ( + "encoding/json" + "log/slog" + "time" +) + +type ResponseMode string + +const ( + ResponseModeBuffered = "Buffered" + ResponseModeStreaming = "Streaming" +) + +type InvokeResponseMode string + +const ( + InvokeResponseModeBuffered InvokeResponseMode = ResponseModeBuffered + InvokeResponseModeStreaming InvokeResponseMode = ResponseModeStreaming +) + +type StateDescription struct { + Name string `json:"name"` + LastModified time.Time `json:"lastModified"` + ResponseTime time.Time `json:"responseTime"` +} + +type RuntimeDescription struct { + State StateDescription `json:"state"` +} + +type ExtensionDescription struct { + Name string `json:"name"` + ID string + State StateDescription `json:"state"` + ErrorType string `json:"errorType"` +} + +type InternalStateDescription struct { + Runtime *RuntimeDescription `json:"runtime"` + Extensions []ExtensionDescription `json:"extensions"` + FirstFatalError string `json:"firstFatalError"` +} + +type ResponseMetricsDimensions struct { + InvokeResponseMode InvokeResponseMode `json:"invokeResponseMode"` +} + +type ResponseMetrics struct { + RuntimeResponseLatencyMs float64 `json:"runtimeResponseLatencyMs"` + Dimensions ResponseMetricsDimensions `json:"dimensions"` +} + +func (s *InternalStateDescription) AsJSON() []byte { + bytes, err := json.Marshal(s) + if err != nil { + slog.Error("Failed to marshall internal states", "err", err) + panic(err) + } + return bytes +} diff --git a/internal/lambda-managed-instances/core/states.go b/internal/lambda-managed-instances/core/states.go new file mode 100644 index 0000000..0bb93a8 --- /dev/null +++ b/internal/lambda-managed-instances/core/states.go @@ -0,0 +1,215 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "errors" + "sync" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core/statejson" +) + +type Suspendable interface { + SuspendUnsafe() + Release() + Lock() + Unlock() +} + +type ManagedThread struct { + operatorCondition *sync.Cond + operatorConditionValue bool +} + +func (s *ManagedThread) SuspendUnsafe() { + for !s.operatorConditionValue { + s.operatorCondition.Wait() + } + s.operatorConditionValue = false +} + +func (s *ManagedThread) Release() { + s.operatorCondition.L.Lock() + defer s.operatorCondition.L.Unlock() + s.operatorConditionValue = true + s.operatorCondition.Signal() +} + +func (s *ManagedThread) Lock() { + s.operatorCondition.L.Lock() +} + +func (s *ManagedThread) Unlock() { + s.operatorCondition.L.Unlock() +} + +func NewManagedThread() *ManagedThread { + return &ManagedThread{ + operatorCondition: sync.NewCond(&sync.Mutex{}), + operatorConditionValue: false, + } +} + +var ErrNotAllowed = errors.New("state transition is not allowed") + +var ErrConcurrentStateModification = errors.New("concurrent state modification") + +type RuntimeState interface { + InitError() error + Ready() error + Name() string +} + +type disallowEveryTransitionByDefault struct{} + +func (s *disallowEveryTransitionByDefault) InitError() error { return ErrNotAllowed } +func (s *disallowEveryTransitionByDefault) Ready() error { return ErrNotAllowed } + +type Runtime struct { + ManagedThread Suspendable + + currentState RuntimeState + stateLastModified time.Time + responseTime time.Time + + RuntimeStartedState RuntimeState + RuntimeInitErrorState RuntimeState + RuntimeReadyState RuntimeState + RuntimeRunningState RuntimeState +} + +func (s *Runtime) Release() { + s.ManagedThread.Release() +} + +func (s *Runtime) SetState(state RuntimeState) { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + s.setStateUnsafe(state) +} + +func (s *Runtime) setStateUnsafe(state RuntimeState) { + s.currentState = state + s.stateLastModified = time.Now() +} + +func (s *Runtime) GetState() RuntimeState { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState +} + +func (s *Runtime) Ready() error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.Ready() +} + +func (s *Runtime) InitError() error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.InitError() +} + +func (s *Runtime) GetRuntimeDescription() statejson.RuntimeDescription { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + res := statejson.RuntimeDescription{ + State: statejson.StateDescription{ + Name: s.currentState.Name(), + LastModified: s.stateLastModified, + }, + } + if !s.responseTime.IsZero() { + res.State.ResponseTime = s.responseTime + } + return res +} + +func NewRuntime(initFlow InitFlowSynchronization) *Runtime { + runtime := &Runtime{ + ManagedThread: NewManagedThread(), + } + + runtime.RuntimeStartedState = &RuntimeStartedState{runtime: runtime, initFlow: initFlow} + runtime.RuntimeInitErrorState = &RuntimeInitErrorState{runtime: runtime, initFlow: initFlow} + runtime.RuntimeReadyState = &RuntimeReadyState{runtime: runtime} + runtime.RuntimeRunningState = &RuntimeRunningState{runtime: runtime} + runtime.setStateUnsafe(runtime.RuntimeStartedState) + return runtime +} + +type RuntimeStartedState struct { + runtime *Runtime + initFlow InitFlowSynchronization +} + +func (s *RuntimeStartedState) Ready() error { + s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) + + err := s.initFlow.RuntimeReady() + if err != nil { + return err + } + + s.runtime.ManagedThread.SuspendUnsafe() + if s.runtime.currentState != s.runtime.RuntimeReadyState && s.runtime.currentState != s.runtime.RuntimeRunningState { + return ErrConcurrentStateModification + } + + s.runtime.setStateUnsafe(s.runtime.RuntimeRunningState) + return nil +} + +func (s *RuntimeStartedState) InitError() error { + s.runtime.setStateUnsafe(s.runtime.RuntimeInitErrorState) + return nil +} + +func (s *RuntimeStartedState) Name() string { + return RuntimeStartedStateName +} + +type RuntimeInitErrorState struct { + disallowEveryTransitionByDefault + runtime *Runtime + initFlow InitFlowSynchronization +} + +func (s *RuntimeInitErrorState) Name() string { + return RuntimeInitErrorStateName +} + +type RuntimeReadyState struct { + disallowEveryTransitionByDefault + runtime *Runtime +} + +func (s *RuntimeReadyState) Ready() error { + s.runtime.ManagedThread.SuspendUnsafe() + if s.runtime.currentState != s.runtime.RuntimeReadyState && s.runtime.currentState != s.runtime.RuntimeRunningState { + return ErrConcurrentStateModification + } + + s.runtime.setStateUnsafe(s.runtime.RuntimeRunningState) + return nil +} + +func (s *RuntimeReadyState) Name() string { + return RuntimeReadyStateName +} + +type RuntimeRunningState struct { + disallowEveryTransitionByDefault + runtime *Runtime +} + +func (s *RuntimeRunningState) Ready() error { + return nil +} + +func (s *RuntimeRunningState) Name() string { + return RuntimeRunningStateName +} diff --git a/internal/lambda-managed-instances/core/states_test.go b/internal/lambda-managed-instances/core/states_test.go new file mode 100644 index 0000000..2066034 --- /dev/null +++ b/internal/lambda-managed-instances/core/states_test.go @@ -0,0 +1,148 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testdata/mockthread" +) + +func TestRuntimeInitErrorAfterReady(t *testing.T) { + initFlow := &mockInitFlowSynchronization{} + initFlow.ReadyCond = sync.NewCond(&sync.Mutex{}) + runtime := NewRuntime(initFlow) + + readyChan := make(chan struct{}) + runtime.SetState(runtime.RuntimeStartedState) + go func() { + assert.NoError(t, runtime.Ready()) + readyChan <- struct{}{} + }() + + initFlow.ReadyCond.L.Lock() + for !initFlow.ReadyCalled { + initFlow.ReadyCond.Wait() + } + initFlow.ReadyCond.L.Unlock() + assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) + + assert.Equal(t, ErrNotAllowed, runtime.InitError()) + runtime.Release() + <-readyChan + assert.Equal(t, ErrNotAllowed, runtime.InitError()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) +} + +func TestRuntimeStateTransitionsFromStartedState(t *testing.T) { + runtime := newRuntime() + + assert.Equal(t, runtime.RuntimeStartedState, runtime.GetState()) + + runtime.SetState(runtime.RuntimeStartedState) + assert.NoError(t, runtime.InitError()) + assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) + + runtime.SetState(runtime.RuntimeStartedState) + assert.NoError(t, runtime.Ready()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) +} + +func TestRuntimeStateTransitionsFromInitErrorState(t *testing.T) { + runtime := newRuntime() + + runtime.SetState(runtime.RuntimeInitErrorState) + assert.Equal(t, ErrNotAllowed, runtime.InitError()) + assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) + + runtime.SetState(runtime.RuntimeInitErrorState) + assert.Equal(t, ErrNotAllowed, runtime.Ready()) + assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) +} + +func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { + runtime := newRuntime() + + runtime.SetState(runtime.RuntimeReadyState) + assert.Equal(t, ErrNotAllowed, runtime.InitError()) + assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) + + runtime.SetState(runtime.RuntimeReadyState) + assert.NoError(t, runtime.Ready()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) +} + +func TestRuntimeStateTransitionsFromRunningState(t *testing.T) { + runtime := newRuntime() + + runtime.SetState(runtime.RuntimeRunningState) + assert.Equal(t, ErrNotAllowed, runtime.InitError()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + + runtime.SetState(runtime.RuntimeRunningState) + assert.NoError(t, runtime.Ready()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) +} + +func newRuntime() *Runtime { + initFlow := &mockInitFlowSynchronization{} + runtime := NewRuntime(initFlow) + runtime.ManagedThread = &mockthread.MockManagedThread{} + + return runtime +} + +type mockInitFlowSynchronization struct { + mock.Mock + ReadyCond *sync.Cond + ReadyCalled bool +} + +func (s *mockInitFlowSynchronization) SetExternalAgentsRegisterCount(agentCount uint16) error { + return nil +} + +func (s *mockInitFlowSynchronization) SetAgentsReadyCount(agentCount uint16) error { + return nil +} + +func (s *mockInitFlowSynchronization) AwaitExternalAgentsRegistered(ctx context.Context) error { + return nil +} + +func (s *mockInitFlowSynchronization) ExternalAgentRegistered() error { + return nil +} + +func (s *mockInitFlowSynchronization) AwaitRuntimeReady(ctx context.Context) error { + return nil +} + +func (s *mockInitFlowSynchronization) AwaitAgentsReady(ctx context.Context) error { + return nil +} + +func (s *mockInitFlowSynchronization) RuntimeReady() error { + if s.ReadyCond != nil { + s.ReadyCond.L.Lock() + defer s.ReadyCond.L.Unlock() + s.ReadyCalled = true + s.ReadyCond.Signal() + } + return nil +} + +func (s *mockInitFlowSynchronization) AgentReady() error { + return nil +} + +func (s *mockInitFlowSynchronization) CancelWithError(err error) { + s.Called(err) +} +func (s *mockInitFlowSynchronization) Clear() {} diff --git a/internal/lambda-managed-instances/interop/cancellable_request.go b/internal/lambda-managed-instances/interop/cancellable_request.go new file mode 100644 index 0000000..7e8fca5 --- /dev/null +++ b/internal/lambda-managed-instances/interop/cancellable_request.go @@ -0,0 +1,27 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "net" + "net/http" +) + +type key int + +const ( + HTTPConnKey key = iota +) + +func GetConn(r *http.Request) net.Conn { + return r.Context().Value(HTTPConnKey).(net.Conn) +} + +type CancellableRequest struct { + Request *http.Request +} + +func (c *CancellableRequest) Cancel() error { + return GetConn(c.Request).Close() +} diff --git a/internal/lambda-managed-instances/interop/error_utils.go b/internal/lambda-managed-instances/interop/error_utils.go new file mode 100644 index 0000000..ac977e0 --- /dev/null +++ b/internal/lambda-managed-instances/interop/error_utils.go @@ -0,0 +1,23 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + +func BuildStatusFromError(err model.AppError) ResponseStatus { + if err == nil { + return Success + } + + if err.ErrorType() == model.ErrorSandboxTimedout { + return Timeout + } + + switch err.(type) { + case model.CustomerError: + return Error + default: + return Failure + } +} diff --git a/internal/lambda-managed-instances/interop/error_utils_test.go b/internal/lambda-managed-instances/interop/error_utils_test.go new file mode 100644 index 0000000..e6153a0 --- /dev/null +++ b/internal/lambda-managed-instances/interop/error_utils_test.go @@ -0,0 +1,48 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func TestBuildStatusFromError(t *testing.T) { + testCases := []struct { + name string + err model.AppError + expected ResponseStatus + }{ + { + name: "nil error", + err: nil, + expected: Success, + }, + { + name: "sandbox timeout error", + err: model.NewCustomerError(model.ErrorSandboxTimedout), + expected: Timeout, + }, + { + name: "customer error", + err: model.NewCustomerError(model.ErrorFunctionUnknown), + expected: Error, + }, + { + name: "runtime error", + err: model.NewPlatformError(nil, model.ErrorReasonUnknownError), + expected: Failure, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := BuildStatusFromError(tc.err) + assert.Equal(t, tc.expected, actual) + }) + } +} diff --git a/internal/lambda-managed-instances/interop/events_api.go b/internal/lambda-managed-instances/interop/events_api.go new file mode 100644 index 0000000..4ec9d25 --- /dev/null +++ b/internal/lambda-managed-instances/interop/events_api.go @@ -0,0 +1,205 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "fmt" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type InitPhase string + +type InitStartData struct { + InitializationType string `json:"initializationType"` + RuntimeVersion string `json:"runtimeVersion"` + RuntimeVersionArn string `json:"runtimeVersionArn"` + FunctionName string `json:"functionName"` + FunctionVersion string `json:"functionVersion"` + InstanceID string `json:"instanceId"` + InstanceMaxMemory uint64 `json:"instanceMaxMemory"` + Phase InitPhase `json:"phase"` + Tracing *TracingCtx `json:"tracing,omitempty"` +} + +func (d *InitStartData) String() string { + return fmt.Sprintf("INIT START(initType: %s, phase: %s)", d.InitializationType, d.Phase) +} + +type InitRuntimeDoneData struct { + InitializationType string `json:"initializationType"` + Status ResponseStatus `json:"status"` + Phase InitPhase `json:"phase"` + ErrorType *string `json:"errorType,omitempty"` + Tracing *TracingCtx `json:"tracing,omitempty"` +} + +func (d *InitRuntimeDoneData) String() string { + errorType := "nil" + if d.ErrorType != nil { + errorType = *d.ErrorType + } + return fmt.Sprintf("INIT RTDONE(initType: %s, status: %s, phase: %s, errorType: %s)", d.InitializationType, d.Status, d.Phase, errorType) +} + +type InitReportMetrics struct { + DurationMs float64 `json:"durationMs"` +} + +type InitReportData struct { + InitializationType string `json:"initializationType"` + Metrics InitReportMetrics `json:"metrics"` + Phase InitPhase `json:"phase"` + Tracing *TracingCtx `json:"tracing,omitempty"` + Status ResponseStatus `json:"status"` + ErrorType *string `json:"errorType,omitempty"` +} + +func (d *InitReportData) String() string { + errorType := "nil" + if d.ErrorType != nil { + errorType = *d.ErrorType + } + + return fmt.Sprintf("INIT REPORT(initType: %s, durationMs: %.2f, status: %s, phase: %s, errorType: %s)", d.InitializationType, d.Metrics.DurationMs, d.Status, d.Phase, errorType) +} + +type TracingCtx struct { + SpanID string `json:"spanId,omitempty"` + Type model.TracingType `json:"type"` + Value string `json:"value"` +} + +type InvokeStartData struct { + InvokeID InvokeID `json:"requestId"` + Version string `json:"version,omitempty"` + FunctionARN string `json:"functionArn,omitempty"` + Tracing *TracingCtx `json:"tracing,omitempty"` +} + +func (d *InvokeStartData) String() string { + return fmt.Sprintf("INVOKE START(requestId: %s)", d.InvokeID) +} + +type RuntimeDoneInvokeMetrics struct { + ProducedBytes int64 `json:"producedBytes"` + DurationMs float64 `json:"durationMs"` +} + +type Span struct { + Name string `json:"name"` + Start string `json:"start"` + DurationMs float64 `json:"durationMs"` +} + +func (s *Span) String() string { + return fmt.Sprintf("SPAN(name: %s)", s.Name) +} + +type InvokeRuntimeDoneData struct { + InvokeID InvokeID `json:"requestId"` + Status ResponseStatus `json:"status"` + Metrics *RuntimeDoneInvokeMetrics `json:"metrics,omitempty"` + Tracing *TracingCtx `json:"tracing,omitempty"` + Spans []Span `json:"spans,omitempty"` + ErrorType *string `json:"errorType,omitempty"` + InternalMetrics *InvokeResponseMetrics `json:"-"` +} + +func (d *InvokeRuntimeDoneData) String() string { + errorType := "nil" + if d.ErrorType != nil { + errorType = *d.ErrorType + } + return fmt.Sprintf("INVOKE RTDONE(status: %s, producedBytes: %d, durationMs: %.2f, spans: %d, errorType: %s)", d.Status, d.Metrics.ProducedBytes, d.Metrics.DurationMs, len(d.Spans), errorType) +} + +type ExtensionInitData struct { + AgentName string `json:"name"` + State string `json:"state"` + Subscriptions []string `json:"events"` + ErrorType string `json:"errorType,omitempty"` +} + +func (d *ExtensionInitData) String() string { + return fmt.Sprintf("EXTENSION INIT(agentName: %s, state: %s, errorType: %s)", d.AgentName, d.State, d.ErrorType) +} + +type ReportDurationMs float64 + +func (d ReportDurationMs) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf("%.3f", d)), nil +} + +type ReportMetrics struct { + DurationMs ReportDurationMs `json:"durationMs"` +} + +type ReportData struct { + InvokeID InvokeID `json:"requestId"` + Status ResponseStatus `json:"status"` + Metrics ReportMetrics `json:"metrics"` + Tracing *TracingCtx `json:"tracing,omitempty"` + Spans []Span `json:"spans,omitempty"` + ErrorType *rapidmodel.ErrorType `json:"errorType,omitempty"` +} + +func (d *ReportData) String() string { + errorType := "nil" + if d.ErrorType != nil { + errorType = string(*d.ErrorType) + } + return fmt.Sprintf("REPORT(status: %s, durationMs: %.2f, errorType: %s)", d.Status, d.Metrics.DurationMs, errorType) +} + +type EndData struct { + InvokeID InvokeID `json:"requestId"` +} + +func (d *EndData) String() string { + return "END" +} + +type InternalXRayErrorCauseData struct { + InvokeID InvokeID `json:"requestId"` + Cause string `json:"cause"` +} + +func (d *InternalXRayErrorCauseData) String() string { + return fmt.Sprintf("XRAY_ERROR_CAUSE(len: %d)", len(d.Cause)) +} + +type InvokeID = string + +type FaultData struct { + InvokeID InvokeID + Status ResponseStatus + ErrorType *rapidmodel.ErrorType +} + +func (d *FaultData) RenderFluxpumpMsg() string { + var errtype string + if d.ErrorType != nil { + errtype = fmt.Sprintf("\tErrorType: %s", *d.ErrorType) + } + return fmt.Sprintf("RequestId: %s\tStatus: %s%s\n", d.InvokeID, d.Status, errtype) +} + +type ImageErrorLogData struct { + ExecError rapidmodel.RuntimeExecError + ExecConfig rapidmodel.RuntimeExec +} + +type EventsAPI interface { + SendInitStart(InitStartData) error + SendInitRuntimeDone(InitRuntimeDoneData) error + SendInitReport(InitReportData) error + SendExtensionInit(ExtensionInitData) error + SendImageError(ImageErrorLogData) + SendInternalXRayErrorCause(InternalXRayErrorCauseData) error + SendInvokeStart(InvokeStartData) error + SendReport(ReportData) error + Flush() +} diff --git a/internal/lambda-managed-instances/interop/events_api_test.go b/internal/lambda-managed-instances/interop/events_api_test.go new file mode 100644 index 0000000..0a9ea00 --- /dev/null +++ b/internal/lambda-managed-instances/interop/events_api_test.go @@ -0,0 +1,884 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +const ( + initializationType = "lambda-managed-instances" + invokeID InvokeID = "REQUEST_ID" +) + +func TestJsonMarshalInvokeRuntimeDone(t *testing.T) { + data := InvokeRuntimeDoneData{ + InvokeID: invokeID, + Status: "success", + Metrics: &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(100), + DurationMs: float64(52.56), + }, + Spans: []Span{ + { + Name: "responseLatency", + Start: "2022-04-11T15:01:28.543Z", + DurationMs: float64(23.02), + }, + { + Name: "responseDuration", + Start: "2022-04-11T15:00:00.000Z", + DurationMs: float64(20), + }, + }, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "spans": [ + { + "name": "responseLatency", + "start": "2022-04-11T15:01:28.543Z", + "durationMs": 23.02 + }, + { + "name": "responseDuration", + "start": "2022-04-11T15:00:00.000Z", + "durationMs": 20 + } + ], + "metrics": { + "producedBytes": 100, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneNoTracing(t *testing.T) { + data := InvokeRuntimeDoneData{ + InvokeID: invokeID, + Status: "success", + Metrics: &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(100), + DurationMs: float64(52.56), + }, + Spans: []Span{ + { + Name: "responseLatency", + Start: "2022-04-11T15:01:28.543Z", + DurationMs: float64(23.02), + }, + { + Name: "responseDuration", + Start: "2022-04-11T15:00:00.000Z", + DurationMs: float64(20), + }, + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "spans": [ + { + "name": "responseLatency", + "start": "2022-04-11T15:01:28.543Z", + "durationMs": 23.02 + }, + { + "name": "responseDuration", + "start": "2022-04-11T15:00:00.000Z", + "durationMs": 20 + } + ], + "metrics": { + "producedBytes": 100, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneNoMetrics(t *testing.T) { + data := InvokeRuntimeDoneData{ + InvokeID: invokeID, + Status: "success", + Spans: []Span{ + { + Name: "responseLatency", + Start: "2022-04-11T15:01:28.543Z", + DurationMs: float64(23.02), + }, + { + Name: "responseDuration", + Start: "2022-04-11T15:00:00.000Z", + DurationMs: float64(20), + }, + }, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "spans": [ + { + "name": "responseLatency", + "start": "2022-04-11T15:01:28.543Z", + "durationMs": 23.02 + }, + { + "name": "responseDuration", + "start": "2022-04-11T15:00:00.000Z", + "durationMs": 20 + } + ] + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneWithProducedBytesEqualToZero(t *testing.T) { + data := InvokeRuntimeDoneData{ + InvokeID: invokeID, + Status: "success", + Metrics: &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(52.56), + }, + Spans: []Span{ + { + Name: "responseLatency", + Start: "2022-04-11T15:01:28.543Z", + DurationMs: float64(23.02), + }, + { + Name: "responseDuration", + Start: "2022-04-11T15:00:00.000Z", + DurationMs: float64(20), + }, + }, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "spans": [ + { + "name": "responseLatency", + "start": "2022-04-11T15:01:28.543Z", + "durationMs": 23.02 + }, + { + "name": "responseDuration", + "start": "2022-04-11T15:00:00.000Z", + "durationMs": 20 + } + ], + "metrics": { + "producedBytes": 0, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneWithNoSpans(t *testing.T) { + data := InvokeRuntimeDoneData{ + InvokeID: invokeID, + Status: "success", + Metrics: &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(100), + DurationMs: float64(52.56), + }, + Spans: []Span{}, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "metrics": { + "producedBytes": 100, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneTimeout(t *testing.T) { + data := InvokeRuntimeDoneData{ + InvokeID: invokeID, + Status: "timeout", + Metrics: &RuntimeDoneInvokeMetrics{ + DurationMs: float64(52.56), + }, + Spans: []Span{}, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "timeout", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "metrics": { + "producedBytes": 0, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneFailure(t *testing.T) { + errorType := "Runtime.ExitError" + data := InvokeRuntimeDoneData{ + InvokeID: invokeID, + Status: "failure", + ErrorType: &errorType, + Metrics: &RuntimeDoneInvokeMetrics{ + DurationMs: float64(52.56), + }, + Spans: []Span{}, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "failure", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "metrics": { + "producedBytes": 0, + "durationMs": 52.56 + }, + "errorType": "Runtime.ExitError" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneWithEmptyErrorType(t *testing.T) { + errorType := "" + data := InvokeRuntimeDoneData{ + InvokeID: invokeID, + Status: "failure", + ErrorType: &errorType, + Metrics: &RuntimeDoneInvokeMetrics{ + DurationMs: float64(52.56), + }, + Spans: []Span{}, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "failure", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "metrics": { + "producedBytes": 0, + "durationMs": 52.56 + }, + "errorType": "" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInitRuntimeDoneSuccess(t *testing.T) { + var errorType *string + data := InitRuntimeDoneData{ + InitializationType: initializationType, + Phase: "init", + Status: "success", + ErrorType: errorType, + } + + expected := ` + { + "initializationType": "lambda-managed-instances", + "phase": "init", + "status": "success" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInitRuntimeDoneError(t *testing.T) { + errorType := "Runtime.ExitError" + data := InitRuntimeDoneData{ + InitializationType: initializationType, + Phase: "init", + Status: "error", + ErrorType: &errorType, + } + + expected := ` + { + "initializationType": "lambda-managed-instances", + "phase": "init", + "status": "error", + "errorType": "Runtime.ExitError" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInitRuntimeDoneFailureWithEmptyErrorType(t *testing.T) { + errorType := "" + data := InitRuntimeDoneData{ + InitializationType: initializationType, + Phase: "init", + Status: "error", + ErrorType: &errorType, + } + + expected := ` + { + "initializationType": "lambda-managed-instances", + "phase": "init", + "status": "error", + "errorType": "" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInitReportSuccess(t *testing.T) { + var errorType *string + data := InitReportData{ + InitializationType: initializationType, + Phase: "init", + Status: "success", + ErrorType: errorType, + Metrics: InitReportMetrics{DurationMs: 5}, + } + + expected := ` + { + "initializationType": "lambda-managed-instances", + "metrics": {"durationMs": 5}, + "phase": "init", + "status": "success" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInitReportError(t *testing.T) { + errorType := "Runtime.ExitError" + data := InitReportData{ + InitializationType: initializationType, + Phase: "init", + Status: "error", + ErrorType: &errorType, + Metrics: InitReportMetrics{DurationMs: 18}, + } + + expected := ` + { + "initializationType": "lambda-managed-instances", + "metrics": {"durationMs": 18}, + "phase": "init", + "status": "error", + "errorType": "Runtime.ExitError" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInitReportTimeout(t *testing.T) { + var errorType *string + + data := InitReportData{ + InitializationType: initializationType, + Phase: "init", + Status: "timeout", + ErrorType: errorType, + Metrics: InitReportMetrics{DurationMs: 17}, + } + + expected := ` + { + "initializationType": "lambda-managed-instances", + "metrics": {"durationMs": 17}, + "phase": "init", + "status": "timeout" + } + ` + + actual, err := json.Marshal(data) + + t.Log(string(actual)) + + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInitReportErrorWithEmptyErrorType(t *testing.T) { + errorType := "" + data := InitReportData{ + InitializationType: initializationType, + Phase: "init", + Status: "error", + ErrorType: &errorType, + Metrics: InitReportMetrics{DurationMs: 23}, + } + + expected := ` + { + "initializationType": "lambda-managed-instances", + "metrics": {"durationMs": 23}, + "phase": "init", + "status": "error", + "errorType": "" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalExtensionInit(t *testing.T) { + data := ExtensionInitData{ + AgentName: "agentName", + State: "Registered", + ErrorType: "", + Subscriptions: []string{"INVOKE", "SHUTDOWN"}, + } + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, `{"name":"agentName","state":"Registered","events":["INVOKE","SHUTDOWN"]}`, string(actual)) +} + +func TestJsonMarshalExtensionInitWithError(t *testing.T) { + data := ExtensionInitData{ + AgentName: "agentName", + State: "Registered", + ErrorType: "Extension.FooBar", + Subscriptions: []string{"INVOKE", "SHUTDOWN"}, + } + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, `{"name":"agentName","state":"Registered","events":["INVOKE","SHUTDOWN"],"errorType":"Extension.FooBar"}`, string(actual)) +} + +func TestJsonMarshalExtensionInitEmptyEvents(t *testing.T) { + data := ExtensionInitData{ + AgentName: "agentName", + State: "Registered", + ErrorType: "Extension.FooBar", + Subscriptions: []string{}, + } + + actual, err := json.Marshal(data) + require.NoError(t, err) + require.JSONEq(t, `{"name":"agentName","state":"Registered","events":[],"errorType":"Extension.FooBar"}`, string(actual)) +} + +func TestJsonMarshalReportWithTracing(t *testing.T) { + errorType := rapidmodel.ErrorRuntimeExit + data := ReportData{ + InvokeID: invokeID, + Status: "error", + ErrorType: &errorType, + Metrics: ReportMetrics{ + DurationMs: 52.56, + }, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "error", + "errorType": "Runtime.ExitError", + "metrics": { + "durationMs": 52.56 + }, + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + } + } + ` + + actual, err := json.Marshal(data) + require.NoError(t, err) + require.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalReportWithoutErrorSpansAndTracing(t *testing.T) { + data := ReportData{ + InvokeID: invokeID, + Status: "timeout", + Metrics: ReportMetrics{ + DurationMs: 52.56, + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "timeout", + "metrics": { + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + require.NoError(t, err) + require.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalReportWithInit(t *testing.T) { + data := ReportData{ + InvokeID: invokeID, + Status: "success", + Metrics: ReportMetrics{ + DurationMs: 52.56, + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "metrics": { + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + require.NoError(t, err) + require.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalReportMetrics(t *testing.T) { + testCases := []struct { + name string + actual ReportData + expected string + }{ + { + "Report metrics with lower precision than reqd.", + ReportData{ + InvokeID: invokeID, + Status: "success", + Metrics: ReportMetrics{ + DurationMs: 12.3, + }, + }, + `{"requestId":"REQUEST_ID","status":"success","metrics":{"durationMs":12.300}}`, + }, + { + "Report metrics with enough or higher precision than reqd.", + ReportData{ + InvokeID: invokeID, + Status: "success", + Metrics: ReportMetrics{ + DurationMs: 1.234567, + }, + }, + `{"requestId":"REQUEST_ID","status":"success","metrics":{"durationMs":1.235}}`, + }, + { + "`DurationMs` of integer type, `InitDuration` absent", + ReportData{ + InvokeID: invokeID, + Status: "success", + Metrics: ReportMetrics{ + DurationMs: 10, + }, + }, + `{"requestId":"REQUEST_ID","status":"success","metrics":{"durationMs":10.000}}`, + }, + { + "Report metrics with zero value", + ReportData{ + InvokeID: invokeID, + Status: "success", + Metrics: ReportMetrics{ + DurationMs: 0, + }, + }, + `{"requestId":"REQUEST_ID","status":"success","metrics":{"durationMs":0.000}}`, + }, + { + "Report metrics not explicitly provided", + ReportData{ + InvokeID: invokeID, + Status: "success", + Metrics: ReportMetrics{}, + }, + `{"requestId":"REQUEST_ID","status":"success","metrics":{"durationMs":0.000}}`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual, err := json.Marshal(tc.actual) + require.NoError(t, err) + require.Equal(t, tc.expected, string(actual)) + }) + } +} + +func TestFaultDataRenderFluxpumpMsg(t *testing.T) { + getPointerFromErrorType := func(e rapidmodel.ErrorType) *rapidmodel.ErrorType { + return &e + } + + testCases := []struct { + name string + expectLog string + actualLog *FaultData + }{ + { + name: "TimeoutDataString", + expectLog: "RequestId: dbe05a36-624e-4924-84d9-1c196aa21733\tStatus: timeout\n", + actualLog: &FaultData{ + "dbe05a36-624e-4924-84d9-1c196aa21733", + Timeout, + nil, + }, + }, + { + name: "ErrorDataString", + expectLog: "RequestId: 34359edf-cda1-4088-a74d-74f37ef686b6\tStatus: error\tErrorType: Runtime.InvalidEntrypoint\n", + actualLog: &FaultData{ + "34359edf-cda1-4088-a74d-74f37ef686b6", + Error, + getPointerFromErrorType(rapidmodel.ErrorRuntimeInvalidEntryPoint), + }, + }, + { + name: "FailureDataString", + expectLog: "RequestId: ff612654-10be-4311-8576-ab5af830d402\tStatus: failure\tErrorType: Sandbox.Failure\n", + actualLog: &FaultData{ + "ff612654-10be-4311-8576-ab5af830d402", + Failure, + getPointerFromErrorType(rapidmodel.ErrorSandboxFailure), + }, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expectLog, tt.actualLog.RenderFluxpumpMsg()) + }) + } +} + +func TestTelemetryEventDataString(t *testing.T) { + runtimeExit := rapidmodel.ErrorRuntimeExit + runtimeUnknown := string(rapidmodel.ErrorRuntimeUnknown) + + testCases := []struct { + name string + expectLog string + actualLog fmt.Stringer + }{ + { + name: "InitStartDataString", + expectLog: "INIT START(initType: lambda-managed-instances, phase: invoke)", + actualLog: &InitStartData{ + InitializationType: initializationType, + Phase: InitPhase("invoke"), + }, + }, + { + name: "InitRuntimeDoneDataString", + expectLog: "INIT RTDONE(initType: lambda-managed-instances, status: error, phase: init, errorType: Runtime.Unknown)", + actualLog: &InitRuntimeDoneData{ + InitializationType: initializationType, + Status: "error", + Phase: InitPhase("init"), + ErrorType: &runtimeUnknown, + }, + }, + { + name: "InitReportDataString", + expectLog: "INIT REPORT(initType: lambda-managed-instances, durationMs: 40.00, status: error, phase: init, errorType: Runtime.Unknown)", + actualLog: &InitReportData{ + InitializationType: initializationType, + Metrics: InitReportMetrics{DurationMs: 40}, + Phase: InitPhase("init"), + Status: "error", + ErrorType: &runtimeUnknown, + }, + }, + { + name: "InvokeRuntimeDoneDataString", + expectLog: "INVOKE RTDONE(status: success, producedBytes: 100, durationMs: 52.56, spans: 0, errorType: nil)", + actualLog: &InvokeRuntimeDoneData{ + Status: "success", + Metrics: &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(100), + DurationMs: float64(52.56), + }, + Spans: []Span{}, + }, + }, + { + name: "ExtensionInitDataString", + expectLog: "EXTENSION INIT(agentName: Amazon Cloudfront, state: Registered, errorType: )", + actualLog: &ExtensionInitData{ + AgentName: "Amazon Cloudfront", + State: "Registered", + }, + }, + { + name: "ReportDataString", + expectLog: "REPORT(status: error, durationMs: 27.80, errorType: Runtime.ExitError)", + actualLog: &ReportData{ + InvokeID: "75c6a56e-385d-4686-8114-ae6fe457e397", + Status: "error", + Metrics: ReportMetrics{ + DurationMs: 27.799, + }, + ErrorType: &runtimeExit, + }, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expectLog, tt.actualLog.String()) + }) + } +} diff --git a/internal/lambda-managed-instances/interop/mock_duration_metric_timer.go b/internal/lambda-managed-instances/interop/mock_duration_metric_timer.go new file mode 100644 index 0000000..8418b50 --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_duration_metric_timer.go @@ -0,0 +1,26 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import mock "github.com/stretchr/testify/mock" + +type MockDurationMetricTimer struct { + mock.Mock +} + +func (_m *MockDurationMetricTimer) Done() { + _m.Called() +} + +func NewMockDurationMetricTimer(t interface { + mock.TestingT + Cleanup(func()) +}) *MockDurationMetricTimer { + mock := &MockDurationMetricTimer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_events_api.go b/internal/lambda-managed-instances/interop/mock_events_api.go new file mode 100644 index 0000000..5a2223e --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_events_api.go @@ -0,0 +1,149 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import mock "github.com/stretchr/testify/mock" + +type MockEventsAPI struct { + mock.Mock +} + +func (_m *MockEventsAPI) Flush() { + _m.Called() +} + +func (_m *MockEventsAPI) SendExtensionInit(_a0 ExtensionInitData) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SendExtensionInit") + } + + var r0 error + if rf, ok := ret.Get(0).(func(ExtensionInitData) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockEventsAPI) SendImageError(_a0 ImageErrorLogData) { + _m.Called(_a0) +} + +func (_m *MockEventsAPI) SendInitReport(_a0 InitReportData) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SendInitReport") + } + + var r0 error + if rf, ok := ret.Get(0).(func(InitReportData) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockEventsAPI) SendInitRuntimeDone(_a0 InitRuntimeDoneData) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SendInitRuntimeDone") + } + + var r0 error + if rf, ok := ret.Get(0).(func(InitRuntimeDoneData) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockEventsAPI) SendInitStart(_a0 InitStartData) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SendInitStart") + } + + var r0 error + if rf, ok := ret.Get(0).(func(InitStartData) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockEventsAPI) SendInternalXRayErrorCause(_a0 InternalXRayErrorCauseData) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SendInternalXRayErrorCause") + } + + var r0 error + if rf, ok := ret.Get(0).(func(InternalXRayErrorCauseData) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockEventsAPI) SendInvokeStart(_a0 InvokeStartData) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SendInvokeStart") + } + + var r0 error + if rf, ok := ret.Get(0).(func(InvokeStartData) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockEventsAPI) SendReport(_a0 ReportData) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SendReport") + } + + var r0 error + if rf, ok := ret.Get(0).(func(ReportData) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func NewMockEventsAPI(t interface { + mock.TestingT + Cleanup(func()) +}) *MockEventsAPI { + mock := &MockEventsAPI{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_health_check_response.go b/internal/lambda-managed-instances/interop/mock_health_check_response.go new file mode 100644 index 0000000..8591604 --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_health_check_response.go @@ -0,0 +1,26 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import mock "github.com/stretchr/testify/mock" + +type MockHealthCheckResponse struct { + mock.Mock +} + +func (_m *MockHealthCheckResponse) healthCheckResponse() { + _m.Called() +} + +func NewMockHealthCheckResponse(t interface { + mock.TestingT + Cleanup(func()) +}) *MockHealthCheckResponse { + mock := &MockHealthCheckResponse{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_init_metrics.go b/internal/lambda-managed-instances/interop/mock_init_metrics.go new file mode 100644 index 0000000..39f582f --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_init_metrics.go @@ -0,0 +1,93 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + time "time" + + mock "github.com/stretchr/testify/mock" + model "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type MockInitMetrics struct { + mock.Mock +} + +func (_m *MockInitMetrics) RunDuration() time.Duration { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for RunDuration") + } + + var r0 time.Duration + if rf, ok := ret.Get(0).(func() time.Duration); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Duration) + } + + return r0 +} + +func (_m *MockInitMetrics) SendMetrics() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for SendMetrics") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockInitMetrics) SetExtensionsNumber(internal int, external int) { + _m.Called(internal, external) +} + +func (_m *MockInitMetrics) SetLogsAPIMetrics(_a0 TelemetrySubscriptionMetrics) { + _m.Called(_a0) +} + +func (_m *MockInitMetrics) TriggerGetRequest() { + _m.Called() +} + +func (_m *MockInitMetrics) TriggerInitCustomerPhaseDone() { + _m.Called() +} + +func (_m *MockInitMetrics) TriggerInitDone(_a0 model.AppError) { + _m.Called(_a0) +} + +func (_m *MockInitMetrics) TriggerRuntimeDone() { + _m.Called() +} + +func (_m *MockInitMetrics) TriggerStartRequest() { + _m.Called() +} + +func (_m *MockInitMetrics) TriggerStartingRuntime() { + _m.Called() +} + +func NewMockInitMetrics(t interface { + mock.TestingT + Cleanup(func()) +}) *MockInitMetrics { + mock := &MockInitMetrics{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_init_response.go b/internal/lambda-managed-instances/interop/mock_init_response.go new file mode 100644 index 0000000..d1a4989 --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_init_response.go @@ -0,0 +1,26 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import mock "github.com/stretchr/testify/mock" + +type MockInitResponse struct { + mock.Mock +} + +func (_m *MockInitResponse) initResponse() { + _m.Called() +} + +func NewMockInitResponse(t interface { + mock.TestingT + Cleanup(func()) +}) *MockInitResponse { + mock := &MockInitResponse{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_init_static_data_provider.go b/internal/lambda-managed-instances/interop/mock_init_static_data_provider.go new file mode 100644 index 0000000..bebc4b8 --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_init_static_data_provider.go @@ -0,0 +1,248 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + time "time" + + mock "github.com/stretchr/testify/mock" + model "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" +) + +type MockInitStaticDataProvider struct { + mock.Mock +} + +func (_m *MockInitStaticDataProvider) AmiId() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for AmiId") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInitStaticDataProvider) ArtefactType() model.ArtefactType { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ArtefactType") + } + + var r0 model.ArtefactType + if rf, ok := ret.Get(0).(func() model.ArtefactType); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(model.ArtefactType) + } + + return r0 +} + +func (_m *MockInitStaticDataProvider) AvailabilityZoneId() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for AvailabilityZoneId") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInitStaticDataProvider) FunctionARN() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for FunctionARN") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInitStaticDataProvider) FunctionTimeout() time.Duration { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for FunctionTimeout") + } + + var r0 time.Duration + if rf, ok := ret.Get(0).(func() time.Duration); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Duration) + } + + return r0 +} + +func (_m *MockInitStaticDataProvider) FunctionVersion() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for FunctionVersion") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInitStaticDataProvider) FunctionVersionID() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for FunctionVersionID") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInitStaticDataProvider) InitTimeout() time.Duration { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for InitTimeout") + } + + var r0 time.Duration + if rf, ok := ret.Get(0).(func() time.Duration); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Duration) + } + + return r0 +} + +func (_m *MockInitStaticDataProvider) LogGroup() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for LogGroup") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInitStaticDataProvider) LogStream() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for LogStream") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInitStaticDataProvider) MemorySizeMB() uint64 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for MemorySizeMB") + } + + var r0 uint64 + if rf, ok := ret.Get(0).(func() uint64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} + +func (_m *MockInitStaticDataProvider) RuntimeVersion() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for RuntimeVersion") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInitStaticDataProvider) XRayTracingMode() model.XrayTracingMode { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for XRayTracingMode") + } + + var r0 model.XrayTracingMode + if rf, ok := ret.Get(0).(func() model.XrayTracingMode); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(model.XrayTracingMode) + } + + return r0 +} + +func NewMockInitStaticDataProvider(t interface { + mock.TestingT + Cleanup(func()) +}) *MockInitStaticDataProvider { + mock := &MockInitStaticDataProvider{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_internal_state_getter.go b/internal/lambda-managed-instances/interop/mock_internal_state_getter.go new file mode 100644 index 0000000..86716fa --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_internal_state_getter.go @@ -0,0 +1,42 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + mock "github.com/stretchr/testify/mock" + statejson "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core/statejson" +) + +type MockInternalStateGetter struct { + mock.Mock +} + +func (_m *MockInternalStateGetter) Execute() statejson.InternalStateDescription { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Execute") + } + + var r0 statejson.InternalStateDescription + if rf, ok := ret.Get(0).(func() statejson.InternalStateDescription); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(statejson.InternalStateDescription) + } + + return r0 +} + +func NewMockInternalStateGetter(t interface { + mock.TestingT + Cleanup(func()) +}) *MockInternalStateGetter { + mock := &MockInternalStateGetter{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_invoke_metrics.go b/internal/lambda-managed-instances/interop/mock_invoke_metrics.go new file mode 100644 index 0000000..68cbede --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_invoke_metrics.go @@ -0,0 +1,149 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + json "encoding/json" + time "time" + + mock "github.com/stretchr/testify/mock" + model "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type MockInvokeMetrics struct { + mock.Mock +} + +func (_m *MockInvokeMetrics) AttachDependencies(_a0 InitStaticDataProvider, _a1 EventsAPI) { + _m.Called(_a0, _a1) +} + +func (_m *MockInvokeMetrics) AttachInvokeRequest(_a0 InvokeRequest) { + _m.Called(_a0) +} + +func (_m *MockInvokeMetrics) SendInvokeFinishedEvent(tracingCtx *TracingCtx, xrayErrorCause json.RawMessage) error { + ret := _m.Called(tracingCtx, xrayErrorCause) + + if len(ret) == 0 { + panic("no return value specified for SendInvokeFinishedEvent") + } + + var r0 error + if rf, ok := ret.Get(0).(func(*TracingCtx, json.RawMessage) error); ok { + r0 = rf(tracingCtx, xrayErrorCause) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockInvokeMetrics) SendInvokeStartEvent(_a0 *TracingCtx) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SendInvokeStartEvent") + } + + var r0 error + if rf, ok := ret.Get(0).(func(*TracingCtx) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockInvokeMetrics) SendMetrics(_a0 model.AppError) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SendMetrics") + } + + var r0 error + if rf, ok := ret.Get(0).(func(model.AppError) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockInvokeMetrics) TriggerGetRequest() { + _m.Called() +} + +func (_m *MockInvokeMetrics) TriggerGetResponse() { + _m.Called() +} + +func (_m *MockInvokeMetrics) TriggerInvokeDone() (time.Duration, *time.Duration, InitStaticDataProvider) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for TriggerInvokeDone") + } + + var r0 time.Duration + var r1 *time.Duration + var r2 InitStaticDataProvider + if rf, ok := ret.Get(0).(func() (time.Duration, *time.Duration, InitStaticDataProvider)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() time.Duration); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Duration) + } + + if rf, ok := ret.Get(1).(func() *time.Duration); ok { + r1 = rf() + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*time.Duration) + } + } + + if rf, ok := ret.Get(2).(func() InitStaticDataProvider); ok { + r2 = rf() + } else { + if ret.Get(2) != nil { + r2 = ret.Get(2).(InitStaticDataProvider) + } + } + + return r0, r1, r2 +} + +func (_m *MockInvokeMetrics) TriggerSentRequest(bytes int64, requestPayloadReadDuration time.Duration, requestPayloadWriteDuration time.Duration) { + _m.Called(bytes, requestPayloadReadDuration, requestPayloadWriteDuration) +} + +func (_m *MockInvokeMetrics) TriggerSentResponse(runtimeResponseSent bool, responseErr model.AppError, streamingMetrics *InvokeResponseMetrics, errorPayloadSizeBytes int) { + _m.Called(runtimeResponseSent, responseErr, streamingMetrics, errorPayloadSizeBytes) +} + +func (_m *MockInvokeMetrics) TriggerStartRequest() { + _m.Called() +} + +func (_m *MockInvokeMetrics) UpdateConcurrencyMetrics(inflightInvokes int, idleRuntimesCount int) { + _m.Called(inflightInvokes, idleRuntimesCount) +} + +func NewMockInvokeMetrics(t interface { + mock.TestingT + Cleanup(func()) +}) *MockInvokeMetrics { + mock := &MockInvokeMetrics{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_invoke_request.go b/internal/lambda-managed-instances/interop/mock_invoke_request.go new file mode 100644 index 0000000..84e0a1a --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_invoke_request.go @@ -0,0 +1,304 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + io "io" + http "net/http" + + mock "github.com/stretchr/testify/mock" + + time "time" + + model "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type MockInvokeRequest struct { + mock.Mock +} + +func (_m *MockInvokeRequest) AddResponseHeader(_a0 string, _a1 string) { + _m.Called(_a0, _a1) +} + +func (_m *MockInvokeRequest) BodyReader() io.Reader { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for BodyReader") + } + + var r0 io.Reader + if rf, ok := ret.Get(0).(func() io.Reader); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.Reader) + } + } + + return r0 +} + +func (_m *MockInvokeRequest) ClientContext() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ClientContext") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInvokeRequest) CognitoId() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for CognitoId") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInvokeRequest) CognitoPoolId() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for CognitoPoolId") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInvokeRequest) ContentType() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ContentType") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInvokeRequest) Deadline() time.Time { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Deadline") + } + + var r0 time.Time + if rf, ok := ret.Get(0).(func() time.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Time) + } + + return r0 +} + +func (_m *MockInvokeRequest) FunctionVersionID() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for FunctionVersionID") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInvokeRequest) InvokeID() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for InvokeID") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInvokeRequest) MaxPayloadSize() int64 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for MaxPayloadSize") + } + + var r0 int64 + if rf, ok := ret.Get(0).(func() int64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int64) + } + + return r0 +} + +func (_m *MockInvokeRequest) ResponseBandwidthBurstRate() int64 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ResponseBandwidthBurstRate") + } + + var r0 int64 + if rf, ok := ret.Get(0).(func() int64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int64) + } + + return r0 +} + +func (_m *MockInvokeRequest) ResponseBandwidthRate() int64 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ResponseBandwidthRate") + } + + var r0 int64 + if rf, ok := ret.Get(0).(func() int64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int64) + } + + return r0 +} + +func (_m *MockInvokeRequest) ResponseMode() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ResponseMode") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInvokeRequest) ResponseWriter() http.ResponseWriter { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ResponseWriter") + } + + var r0 http.ResponseWriter + if rf, ok := ret.Get(0).(func() http.ResponseWriter); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(http.ResponseWriter) + } + } + + return r0 +} + +func (_m *MockInvokeRequest) SetResponseHeader(_a0 string, _a1 string) { + _m.Called(_a0, _a1) +} + +func (_m *MockInvokeRequest) TraceId() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for TraceId") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockInvokeRequest) UpdateFromInitData(_a0 InitStaticDataProvider) model.AppError { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for UpdateFromInitData") + } + + var r0 model.AppError + if rf, ok := ret.Get(0).(func(InitStaticDataProvider) model.AppError); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(model.AppError) + } + } + + return r0 +} + +func (_m *MockInvokeRequest) WriteResponseHeaders(_a0 int) { + _m.Called(_a0) +} + +func NewMockInvokeRequest(t interface { + mock.TestingT + Cleanup(func()) +}) *MockInvokeRequest { + mock := &MockInvokeRequest{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_invoke_response.go b/internal/lambda-managed-instances/interop/mock_invoke_response.go new file mode 100644 index 0000000..589075b --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_invoke_response.go @@ -0,0 +1,26 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import mock "github.com/stretchr/testify/mock" + +type MockInvokeResponse struct { + mock.Mock +} + +func (_m *MockInvokeResponse) invokeResponse() { + _m.Called() +} + +func NewMockInvokeResponse(t interface { + mock.TestingT + Cleanup(func()) +}) *MockInvokeResponse { + mock := &MockInvokeResponse{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_invoke_response_sender.go b/internal/lambda-managed-instances/interop/mock_invoke_response_sender.go new file mode 100644 index 0000000..abea58e --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_invoke_response_sender.go @@ -0,0 +1,80 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import mock "github.com/stretchr/testify/mock" + +type MockInvokeResponseSender struct { + mock.Mock +} + +func (_m *MockInvokeResponseSender) SendErrorResponse(invokeID string, response *ErrorInvokeResponse) (*InvokeResponseMetrics, error) { + ret := _m.Called(invokeID, response) + + if len(ret) == 0 { + panic("no return value specified for SendErrorResponse") + } + + var r0 *InvokeResponseMetrics + var r1 error + if rf, ok := ret.Get(0).(func(string, *ErrorInvokeResponse) (*InvokeResponseMetrics, error)); ok { + return rf(invokeID, response) + } + if rf, ok := ret.Get(0).(func(string, *ErrorInvokeResponse) *InvokeResponseMetrics); ok { + r0 = rf(invokeID, response) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*InvokeResponseMetrics) + } + } + + if rf, ok := ret.Get(1).(func(string, *ErrorInvokeResponse) error); ok { + r1 = rf(invokeID, response) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +func (_m *MockInvokeResponseSender) SendResponse(invokeID string, response *StreamableInvokeResponse) (*InvokeResponseMetrics, error) { + ret := _m.Called(invokeID, response) + + if len(ret) == 0 { + panic("no return value specified for SendResponse") + } + + var r0 *InvokeResponseMetrics + var r1 error + if rf, ok := ret.Get(0).(func(string, *StreamableInvokeResponse) (*InvokeResponseMetrics, error)); ok { + return rf(invokeID, response) + } + if rf, ok := ret.Get(0).(func(string, *StreamableInvokeResponse) *InvokeResponseMetrics); ok { + r0 = rf(invokeID, response) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*InvokeResponseMetrics) + } + } + + if rf, ok := ret.Get(1).(func(string, *StreamableInvokeResponse) error); ok { + r1 = rf(invokeID, response) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +func NewMockInvokeResponseSender(t interface { + mock.TestingT + Cleanup(func()) +}) *MockInvokeResponseSender { + mock := &MockInvokeResponseSender{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_message.go b/internal/lambda-managed-instances/interop/mock_message.go new file mode 100644 index 0000000..4c720e1 --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_message.go @@ -0,0 +1,22 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import mock "github.com/stretchr/testify/mock" + +type MockMessage struct { + mock.Mock +} + +func NewMockMessage(t interface { + mock.TestingT + Cleanup(func()) +}) *MockMessage { + mock := &MockMessage{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_rapid_context.go b/internal/lambda-managed-instances/interop/mock_rapid_context.go new file mode 100644 index 0000000..49a616b --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_rapid_context.go @@ -0,0 +1,131 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + context "context" + netip "net/netip" + + mock "github.com/stretchr/testify/mock" + model "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type MockRapidContext struct { + mock.Mock +} + +func (_m *MockRapidContext) HandleInit(ctx context.Context, initData InitExecutionData, initMetrics InitMetrics) model.AppError { + ret := _m.Called(ctx, initData, initMetrics) + + if len(ret) == 0 { + panic("no return value specified for HandleInit") + } + + var r0 model.AppError + if rf, ok := ret.Get(0).(func(context.Context, InitExecutionData, InitMetrics) model.AppError); ok { + r0 = rf(ctx, initData, initMetrics) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(model.AppError) + } + } + + return r0 +} + +func (_m *MockRapidContext) HandleInvoke(ctx context.Context, invokeRequest InvokeRequest, invokeMetrics InvokeMetrics) (model.AppError, bool) { + ret := _m.Called(ctx, invokeRequest, invokeMetrics) + + if len(ret) == 0 { + panic("no return value specified for HandleInvoke") + } + + var r0 model.AppError + var r1 bool + if rf, ok := ret.Get(0).(func(context.Context, InvokeRequest, InvokeMetrics) (model.AppError, bool)); ok { + return rf(ctx, invokeRequest, invokeMetrics) + } + if rf, ok := ret.Get(0).(func(context.Context, InvokeRequest, InvokeMetrics) model.AppError); ok { + r0 = rf(ctx, invokeRequest, invokeMetrics) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(model.AppError) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, InvokeRequest, InvokeMetrics) bool); ok { + r1 = rf(ctx, invokeRequest, invokeMetrics) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +func (_m *MockRapidContext) HandleShutdown(shutdownCause model.AppError, metrics ShutdownMetrics) model.AppError { + ret := _m.Called(shutdownCause, metrics) + + if len(ret) == 0 { + panic("no return value specified for HandleShutdown") + } + + var r0 model.AppError + if rf, ok := ret.Get(0).(func(model.AppError, ShutdownMetrics) model.AppError); ok { + r0 = rf(shutdownCause, metrics) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(model.AppError) + } + } + + return r0 +} + +func (_m *MockRapidContext) ProcessTerminationNotifier() <-chan model.AppError { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ProcessTerminationNotifier") + } + + var r0 <-chan model.AppError + if rf, ok := ret.Get(0).(func() <-chan model.AppError); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan model.AppError) + } + } + + return r0 +} + +func (_m *MockRapidContext) RuntimeAPIAddrPort() netip.AddrPort { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for RuntimeAPIAddrPort") + } + + var r0 netip.AddrPort + if rf, ok := ret.Get(0).(func() netip.AddrPort); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(netip.AddrPort) + } + + return r0 +} + +func NewMockRapidContext(t interface { + mock.TestingT + Cleanup(func()) +}) *MockRapidContext { + mock := &MockRapidContext{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_server.go b/internal/lambda-managed-instances/interop/mock_server.go new file mode 100644 index 0000000..5535817 --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_server.go @@ -0,0 +1,51 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import mock "github.com/stretchr/testify/mock" + +type MockServer struct { + mock.Mock +} + +func (_m *MockServer) SendInitErrorResponse(response *ErrorInvokeResponse) (*InvokeResponseMetrics, error) { + ret := _m.Called(response) + + if len(ret) == 0 { + panic("no return value specified for SendInitErrorResponse") + } + + var r0 *InvokeResponseMetrics + var r1 error + if rf, ok := ret.Get(0).(func(*ErrorInvokeResponse) (*InvokeResponseMetrics, error)); ok { + return rf(response) + } + if rf, ok := ret.Get(0).(func(*ErrorInvokeResponse) *InvokeResponseMetrics); ok { + r0 = rf(response) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*InvokeResponseMetrics) + } + } + + if rf, ok := ret.Get(1).(func(*ErrorInvokeResponse) error); ok { + r1 = rf(response) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +func NewMockServer(t interface { + mock.TestingT + Cleanup(func()) +}) *MockServer { + mock := &MockServer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_shutdown_metrics.go b/internal/lambda-managed-instances/interop/mock_shutdown_metrics.go new file mode 100644 index 0000000..fbedf53 --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_shutdown_metrics.go @@ -0,0 +1,57 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + mock "github.com/stretchr/testify/mock" + model "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + servicelogs "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/servicelogs" +) + +type MockShutdownMetrics struct { + mock.Mock +} + +func (_m *MockShutdownMetrics) AddMetric(metric servicelogs.Metric) { + _m.Called(metric) +} + +func (_m *MockShutdownMetrics) CreateDurationMetric(name string) DurationMetricTimer { + ret := _m.Called(name) + + if len(ret) == 0 { + panic("no return value specified for CreateDurationMetric") + } + + var r0 DurationMetricTimer + if rf, ok := ret.Get(0).(func(string) DurationMetricTimer); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(DurationMetricTimer) + } + } + + return r0 +} + +func (_m *MockShutdownMetrics) SendMetrics(error model.AppError) { + _m.Called(error) +} + +func (_m *MockShutdownMetrics) SetAgentCount(internal int, external int) { + _m.Called(internal, external) +} + +func NewMockShutdownMetrics(t interface { + mock.TestingT + Cleanup(func()) +}) *MockShutdownMetrics { + mock := &MockShutdownMetrics{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/mock_shutdown_response.go b/internal/lambda-managed-instances/interop/mock_shutdown_response.go new file mode 100644 index 0000000..867f23c --- /dev/null +++ b/internal/lambda-managed-instances/interop/mock_shutdown_response.go @@ -0,0 +1,26 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import mock "github.com/stretchr/testify/mock" + +type MockShutdownResponse struct { + mock.Mock +} + +func (_m *MockShutdownResponse) shutdownResponse() { + _m.Called() +} + +func NewMockShutdownResponse(t interface { + mock.TestingT + Cleanup(func()) +}) *MockShutdownResponse { + mock := &MockShutdownResponse{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/interop/model.go b/internal/lambda-managed-instances/interop/model.go new file mode 100644 index 0000000..6c43dcd --- /dev/null +++ b/internal/lambda-managed-instances/interop/model.go @@ -0,0 +1,291 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "strings" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core/statejson" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils/invariant" +) + +const ( + MaxPayloadSize = 6*1024*1024 + 100 + + ResponseBandwidthRate = 2 * 1024 * 1024 + ResponseBandwidthBurstSize = 6 * 1024 * 1024 + + MinResponseBandwidthRate = 32 * 1024 + MaxResponseBandwidthRate = 64 * 1024 * 1024 + + MinResponseBandwidthBurstSize = 32 * 1024 + MaxResponseBandwidthBurstSize = 64 * 1024 * 1024 + + InitializationType = "lambda-managed-instances" +) + +type ResponseMode string + +const ( + ResponseModeBuffered = "Buffered" + ResponseModeStreaming = "Streaming" +) + +type InvokeResponseMode string + +const ( + InvokeResponseModeBuffered InvokeResponseMode = ResponseModeBuffered + InvokeResponseModeStreaming InvokeResponseMode = ResponseModeStreaming +) + +var AllInvokeResponseModes = []string{ + string(InvokeResponseModeBuffered), string(InvokeResponseModeStreaming), +} + +func ConvertInvokeResponseModeToString(invokeResponseMode InvokeResponseMode) string { + if invokeResponseMode == "" { + return "" + } + return "invoke_response_mode=" + strings.ToLower(string(invokeResponseMode)) +} + +type FunctionResponseMode string + +const ( + FunctionResponseModeBuffered FunctionResponseMode = ResponseModeBuffered + FunctionResponseModeStreaming FunctionResponseMode = ResponseModeStreaming +) + +var AllFunctionResponseModes = []string{ + string(FunctionResponseModeBuffered), string(FunctionResponseModeStreaming), +} + +func ConvertToFunctionResponseMode(value string) (FunctionResponseMode, error) { + + if strings.EqualFold(value, string(FunctionResponseModeBuffered)) { + return FunctionResponseModeBuffered, nil + } + + if strings.EqualFold(value, string(FunctionResponseModeStreaming)) { + return FunctionResponseModeStreaming, nil + } + + allowedValues := strings.Join(AllFunctionResponseModes, ", ") + slog.Error("Unable to map value to allowed values", "value", value, "allowedValues", allowedValues) + return "", ErrInvalidFunctionResponseMode +} + +type Message interface{} + +type InvokeMetadataHeader string + +type Invoke struct { + TraceID string + LambdaSegmentID string + ID string + InvokedFunctionArn string + CognitoIdentityID string + CognitoIdentityPoolID string + Deadline time.Time + FunctionTimeout time.Duration + ClientContext string + ContentType string + Payload io.Reader + VersionID string + InvokeReceivedTime time.Time + InvokeResponseMetrics *InvokeResponseMetrics + InvokeResponseMode InvokeResponseMode +} + +func (i Invoke) GetDeadlineMs(ctx context.Context) int64 { + if deadline, hasDeadline := ctx.Deadline(); hasDeadline { + + return deadline.UnixMilli() + } + + if !i.Deadline.IsZero() { + + return i.Deadline.UnixMilli() + } + + return 0 +} + +type InvokeErrorTraceData struct { + InvokeID InvokeID `json:"requestId,omitempty"` + + ErrorCause json.RawMessage `json:"ErrorCause,omitempty"` +} + +func GetErrorResponseWithFormattedErrorMessage(errorType model.ErrorType, err error, invokeID InvokeID) *ErrorInvokeResponse { + var errorMessage string + if invokeID != "" { + errorMessage = fmt.Sprintf("RequestId: %s Error: %v", invokeID, err) + } else { + errorMessage = fmt.Sprintf("Error: %v", err) + } + + jsonPayload, err := json.Marshal(model.FunctionError{ + Type: errorType, + Message: errorMessage, + }) + if err != nil { + return &ErrorInvokeResponse{ + Headers: InvokeResponseHeaders{}, + FunctionError: model.FunctionError{ + Type: model.ErrorSandboxFailure, + Message: errorMessage, + }, + Payload: []byte{}, + } + } + + headers := InvokeResponseHeaders{} + functionError := model.FunctionError{ + Type: errorType, + Message: errorMessage, + } + + return &ErrorInvokeResponse{Headers: headers, FunctionError: functionError, Payload: jsonPayload} +} + +type Shutdown struct { + DeadlineNs int64 +} + +type TelemetrySubscriptionMetrics map[string]int + +type InvokeResponseMetrics struct { + StartReadingResponseTime time.Time + + FinishReadingResponseTime time.Time + TimeShaped time.Duration + ProducedBytes int64 + OutboundThroughputBps int64 + FunctionResponseMode FunctionResponseMode + RuntimeCalledResponse bool + + TransferError error + + Interrupted bool + + ResponsePayloadReadDuration time.Duration + + ResponsePayloadWriteDuration time.Duration +} + +func IsResponseStreamingMetrics(metrics *InvokeResponseMetrics) bool { + if metrics == nil { + return false + } + return metrics.FunctionResponseMode == FunctionResponseModeStreaming +} + +type DoneMetadataMetricsDimensions struct { + InvokeResponseMode InvokeResponseMode +} + +func (dimensions DoneMetadataMetricsDimensions) String() string { + return ConvertInvokeResponseModeToString(dimensions.InvokeResponseMode) +} + +type DoneMetadata struct { + NumActiveExtensions int + ExtensionNames string + RuntimeRelease string + + InvokeRequestReadTimeNs int64 + InvokeRequestSizeBytes int64 + InvokeCompletionTimeNs int64 + InvokeReceivedTime int64 + RuntimeReadyTime int64 + RuntimeResponseLatencyMs float64 + RuntimeTimeThrottledMs int64 + RuntimeProducedBytes int64 + RuntimeOutboundThroughputBps int64 + MetricsDimensions DoneMetadataMetricsDimensions +} + +type Done struct { + WaitForExit bool + ErrorType model.ErrorType + Meta DoneMetadata +} + +type DoneFail struct { + ErrorType model.ErrorType + Meta DoneMetadata +} + +var ErrInvalidFunctionResponseMode = fmt.Errorf("ErrInvalidFunctionResponseMode") + +var ErrInvalidInvokeResponseMode = fmt.Errorf("ErrInvalidInvokeResponseMode") + +var ErrInvalidMaxPayloadSize = fmt.Errorf("ErrInvalidMaxPayloadSize") + +var ErrInvalidResponseBandwidthRate = fmt.Errorf("ErrInvalidResponseBandwidthRate") + +var ErrInvalidResponseBandwidthBurstSize = fmt.Errorf("ErrInvalidResponseBandwidthBurstSize") + +var ErrMalformedCustomerHeaders = fmt.Errorf("ErrMalformedCustomerHeaders") + +var ErrResponseSent = fmt.Errorf("ErrResponseSent") + +var ErrReservationExpired = fmt.Errorf("ErrReservationExpired") + +type ErrInternalPlatformError struct{} + +func (s *ErrInternalPlatformError) Error() string { + return "ErrInternalPlatformError" +} + +type ErrTruncatedResponse struct{} + +func (s *ErrTruncatedResponse) Error() string { + return "ErrTruncatedResponse" +} + +type ErrorResponseTooLarge struct { + MaxResponseSize int + ResponseSize int +} + +type ErrorResponseTooLargeDI struct { + ErrorResponseTooLarge +} + +func (s *ErrorResponseTooLarge) Error() string { + return fmt.Sprintf("Response payload size (%d bytes) exceeded maximum allowed payload size (%d bytes).", s.ResponseSize, s.MaxResponseSize) +} + +func (s *ErrorResponseTooLarge) AsErrorResponse() *ErrorInvokeResponse { + functionError := model.FunctionError{ + Type: model.ErrorFunctionOversizedResponse, + Message: s.Error(), + } + jsonPayload, err := json.Marshal(functionError) + invariant.Check(err == nil, "Failed to marshal interop.FunctionError") + + headers := InvokeResponseHeaders{ContentType: "application/json"} + return &ErrorInvokeResponse{Headers: headers, FunctionError: functionError, Payload: jsonPayload} +} + +type Server interface { + SendInitErrorResponse(response *ErrorInvokeResponse) (*InvokeResponseMetrics, error) +} + +type InternalStateGetter func() statejson.InternalStateDescription + +var ( + ErrTimeout = errors.New("errTimeout") + ErrPlatformError = errors.New("ErrPlatformError") +) diff --git a/internal/lambda-managed-instances/interop/model_test.go b/internal/lambda-managed-instances/interop/model_test.go new file mode 100644 index 0000000..37334a8 --- /dev/null +++ b/internal/lambda-managed-instances/interop/model_test.go @@ -0,0 +1,85 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func TestGetErrorResponseWithFormattedErrorMessageWithoutInvokeRequestId(t *testing.T) { + errorType := model.ErrorRuntimeExit + errorMessage := fmt.Errorf("Divided by 0") + expectedMsg := fmt.Sprintf(`Error: %s`, errorMessage) + expectedJSON := fmt.Sprintf(`{"errorType": "%s", "errorMessage": "%s"}`, string(errorType), expectedMsg) + + actual := GetErrorResponseWithFormattedErrorMessage(errorType, errorMessage, "") + assert.Equal(t, errorType, actual.FunctionError.Type) + assert.Equal(t, expectedMsg, actual.FunctionError.Message) + assert.JSONEq(t, expectedJSON, string(actual.Payload)) +} + +func TestGetErrorResponseWithFormattedErrorMessageWithInvokeRequestId(t *testing.T) { + errorType := model.ErrorRuntimeExit + errorMessage := fmt.Errorf("Divided by 0") + invokeID := "invoke-id" + expectedMsg := fmt.Sprintf(`RequestId: %s Error: %s`, invokeID, errorMessage) + expectedJSON := fmt.Sprintf(`{"errorType": "%s", "errorMessage": "%s"}`, string(errorType), expectedMsg) + + actual := GetErrorResponseWithFormattedErrorMessage(errorType, errorMessage, invokeID) + assert.Equal(t, errorType, actual.FunctionError.Type) + assert.Equal(t, expectedMsg, actual.FunctionError.Message) + assert.JSONEq(t, expectedJSON, string(actual.Payload)) +} + +func TestDoneMetadataMetricsDimensionsStringWhenInvokeResponseModeIsPresent(t *testing.T) { + testcase := []struct { + name string + expectedDim string + dim DoneMetadataMetricsDimensions + }{ + { + name: "invoke response mode is streaming", + expectedDim: "invoke_response_mode=streaming", + dim: DoneMetadataMetricsDimensions{ + InvokeResponseMode: InvokeResponseModeStreaming, + }, + }, + { + name: "invoke response mode is buffered", + expectedDim: "invoke_response_mode=buffered", + dim: DoneMetadataMetricsDimensions{ + InvokeResponseMode: InvokeResponseModeBuffered, + }, + }, + } + for _, tt := range testcase { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expectedDim, tt.dim.String()) + }) + } +} + +func TestDoneMetadataMetricsDimensionsStringWhenEmpty(t *testing.T) { + dimensions := DoneMetadataMetricsDimensions{} + assert.Equal(t, "", dimensions.String()) +} + +func TestGetDeadlineMs(t *testing.T) { + invoke := Invoke{Deadline: time.Now()} + + nowCtx, cancel := context.WithDeadline(context.Background(), time.Now()) + defer cancel() + assert.Equal(t, invoke.GetDeadlineMs(context.Background()), invoke.GetDeadlineMs(nowCtx)) + + notNowCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Second)) + defer cancel() + assert.Less(t, invoke.GetDeadlineMs(context.Background()), invoke.GetDeadlineMs(notNowCtx)) +} diff --git a/internal/lambda-managed-instances/interop/response_status.go b/internal/lambda-managed-instances/interop/response_status.go new file mode 100644 index 0000000..d65cdd1 --- /dev/null +++ b/internal/lambda-managed-instances/interop/response_status.go @@ -0,0 +1,13 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +type ResponseStatus = string + +const ( + Success ResponseStatus = "success" + Timeout ResponseStatus = "timeout" + Error ResponseStatus = "error" + Failure ResponseStatus = "failure" +) diff --git a/internal/lambda-managed-instances/interop/sandbox_model.go b/internal/lambda-managed-instances/interop/sandbox_model.go new file mode 100644 index 0000000..23fb84b --- /dev/null +++ b/internal/lambda-managed-instances/interop/sandbox_model.go @@ -0,0 +1,344 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/netip" + "time" + + intmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/servicelogs" + supvmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/supervisor/model" +) + +type InitSuccess struct { + NumActiveExtensions int + ExtensionNames string + RuntimeRelease string + Ack chan struct{} + NumInternalExtensions int + NumExternalExtensions int + RuntimeInitDuration time.Duration +} + +type InitFailure struct { + Error model.AppError + ErrorType model.ErrorType + ErrorCategory model.ErrorCategory + ErrorMessage error + + NumActiveExtensions int + RuntimeRelease string + Ack chan struct{} + NumInternalExtensions int + NumExternalExtensions int + RuntimeInitDuration time.Duration +} + +type InitResponse interface { + initResponse() +} + +func (s InitSuccess) initResponse() {} +func (f InitFailure) initResponse() {} + +type PlatformError struct { + model.PlatformError + RuntimeRelease string +} + +type ClientError struct { + model.ClientError +} + +func (PlatformError) initResponse() {} +func (ClientError) initResponse() {} + +var ( + _ InitResponse = PlatformError{} + _ InitResponse = ClientError{} +) + +type ErrorInvokeResponse struct { + Headers InvokeResponseHeaders + Payload []byte + FunctionError model.FunctionError +} + +type StreamableInvokeResponse struct { + Headers map[string]string + Payload io.Reader + Trailers http.Header + Request *CancellableRequest +} + +type InvokeResponseHeaders struct { + ContentType string + FunctionResponseMode string +} + +type InvokeResponseSender interface { + SendResponse(invokeID InvokeID, response *StreamableInvokeResponse) (*InvokeResponseMetrics, error) + + SendErrorResponse(invokeID InvokeID, response *ErrorInvokeResponse) (*InvokeResponseMetrics, error) +} + +type ResponseMetrics struct { + RuntimeOutboundThroughputBps int64 + RuntimeProducedBytes int64 + RuntimeResponseLatency time.Duration + RuntimeTimeThrottled time.Duration +} + +type RuntimeInitMetrics struct { + Duration time.Duration + Error error +} + +type InvokeSuccess struct { + RuntimeRelease string + NumActiveExtensions int + ExtensionNames string + ExtensionsOverhead time.Duration + InvokeCompletionTime time.Duration + InvokeReceivedTime time.Time + ResponseMetrics ResponseMetrics + InvokeResponseMode InvokeResponseMode +} + +type InvokeFailure struct { + ErrorType model.ErrorType + ErrorMessage error + RuntimeRelease string + NumActiveExtensions int + ExtensionsOverhead time.Duration + InvokeReceivedTime time.Time + ResponseMetrics ResponseMetrics + ExtensionNames string + DefaultErrorResponse *ErrorInvokeResponse + InvokeResponseMode InvokeResponseMode +} + +type InvokeResponse interface{ invokeResponse() } + +func (InvokeSuccess) invokeResponse() {} +func (InvokeFailure) invokeResponse() {} +func (PlatformError) invokeResponse() {} + +type ShutdownSuccess struct{} + +type ShutdownResponse interface{ shutdownResponse() } + +func (ShutdownSuccess) shutdownResponse() {} +func (PlatformError) shutdownResponse() {} + +type HealthyContainerResponse struct{} + +type UnhealthyContainerResponse struct { + ErrorType model.ErrorType +} + +type HealthCheckResponse interface { + healthCheckResponse() +} + +func (HealthyContainerResponse) healthCheckResponse() {} +func (UnhealthyContainerResponse) healthCheckResponse() {} + +type InitExecutionData struct { + ExtensionEnv intmodel.KVMap + Runtime model.Runtime + Credentials model.Credentials + LogGroupName string + LogStreamName string + FunctionMetadata model.FunctionMetadata + RuntimeManagedLoggingFormats []supvmodel.ManagedLoggingFormat + StaticData EEStaticData + TelemetrySubscriptionConfig TelemetrySubscriptionConfig +} + +func (i *InitExecutionData) FunctionARN() string { + return i.StaticData.FunctionARN +} + +func (i *InitExecutionData) FunctionVersion() string { + return i.FunctionMetadata.FunctionVersion +} + +func (i *InitExecutionData) MemorySizeMB() uint64 { + return i.FunctionMetadata.MemorySizeBytes / (1024 * 1024) +} + +func (i *InitExecutionData) FunctionVersionID() string { + return i.StaticData.FunctionVersionID +} + +func (i *InitExecutionData) FunctionTimeout() time.Duration { + return i.StaticData.FunctionTimeout +} + +func (i *InitExecutionData) InitTimeout() time.Duration { + return i.StaticData.InitTimeout +} + +func (i *InitExecutionData) LogGroup() string { + return i.StaticData.LogGroupName +} + +func (i *InitExecutionData) LogStream() string { + return i.StaticData.LogStreamName +} + +func (i *InitExecutionData) XRayTracingMode() intmodel.XrayTracingMode { + return i.StaticData.XRayTracingMode +} + +func (i *InitExecutionData) TelemetryPassphrase() string { + return i.TelemetrySubscriptionConfig.Passphrase +} + +func (i *InitExecutionData) TelemetryAPIAddr() netip.AddrPort { + return i.TelemetrySubscriptionConfig.APIAddr +} + +func (i *InitExecutionData) ArtefactType() intmodel.ArtefactType { + return i.StaticData.ArtefactType +} + +func (i *InitExecutionData) AmiId() string { + return i.StaticData.AmiId +} + +func (i *InitExecutionData) RuntimeVersion() string { + return i.StaticData.RuntimeVersion +} + +func (i *InitExecutionData) AvailabilityZoneId() string { + return i.StaticData.AvailabilityZoneId +} + +type EEStaticData struct { + FunctionARN string + FunctionVersionID string + InitTimeout time.Duration + FunctionTimeout time.Duration + LogGroupName string + LogStreamName string + XRayTracingMode intmodel.XrayTracingMode + ArtefactType intmodel.ArtefactType + RuntimeVersion string + AmiId string + AvailabilityZoneId string +} + +type TelemetrySubscriptionConfig struct { + Passphrase string + APIAddr netip.AddrPort +} + +type RapidContext interface { + HandleInit(ctx context.Context, initData InitExecutionData, initMetrics InitMetrics) (err model.AppError) + + HandleShutdown(shutdownCause model.AppError, metrics ShutdownMetrics) model.AppError + HandleInvoke(ctx context.Context, invokeRequest InvokeRequest, invokeMetrics InvokeMetrics) (err model.AppError, wasResponseSent bool) + RuntimeAPIAddrPort() netip.AddrPort + + ProcessTerminationNotifier() <-chan model.AppError +} + +type LifecyclePhase int + +const ( + LifecyclePhaseInit LifecyclePhase = iota + 1 + LifecyclePhaseInvoke +) + +type InvokeRequest interface { + ContentType() string + InvokeID() InvokeID + Deadline() time.Time + TraceId() string + ClientContext() string + CognitoId() string + CognitoPoolId() string + ResponseBandwidthRate() int64 + ResponseBandwidthBurstRate() int64 + MaxPayloadSize() int64 + ResponseMode() string + + BodyReader() io.Reader + ResponseWriter() http.ResponseWriter + + SetResponseHeader(string, string) + AddResponseHeader(string, string) + WriteResponseHeaders(int) + + UpdateFromInitData(InitStaticDataProvider) model.AppError + FunctionVersionID() string +} + +type InitStaticDataProvider interface { + FunctionARN() string + FunctionVersion() string + FunctionVersionID() string + InitTimeout() time.Duration + FunctionTimeout() time.Duration + LogGroup() string + LogStream() string + XRayTracingMode() intmodel.XrayTracingMode + MemorySizeMB() uint64 + ArtefactType() intmodel.ArtefactType + AmiId() string + RuntimeVersion() string + AvailabilityZoneId() string +} + +type InvokeMetrics interface { + TriggerGetRequest() + AttachInvokeRequest(InvokeRequest) + AttachDependencies(InitStaticDataProvider, EventsAPI) + UpdateConcurrencyMetrics(inflightInvokes, idleRuntimesCount int) + TriggerStartRequest() + TriggerSentRequest(bytes int64, requestPayloadReadDuration, requestPayloadWriteDuration time.Duration) + TriggerGetResponse() + TriggerSentResponse(runtimeResponseSent bool, responseErr model.AppError, streamingMetrics *InvokeResponseMetrics, errorPayloadSizeBytes int) + + TriggerInvokeDone() (totalMs time.Duration, runMs *time.Duration, initData InitStaticDataProvider) + + SendInvokeStartEvent(*TracingCtx) error + SendInvokeFinishedEvent(tracingCtx *TracingCtx, xrayErrorCause json.RawMessage) error + SendMetrics(model.AppError) error +} + +type InitMetrics interface { + TriggerGetRequest() + SetLogsAPIMetrics(TelemetrySubscriptionMetrics) + SetExtensionsNumber(internal, external int) + TriggerStartRequest() + TriggerStartingRuntime() + TriggerRuntimeDone() + TriggerInitCustomerPhaseDone() + TriggerInitDone(model.AppError) + + RunDuration() time.Duration + SendMetrics() error +} + +type ShutdownMetrics interface { + CreateDurationMetric(name string) DurationMetricTimer + AddMetric(metric servicelogs.Metric) + + SetAgentCount(internal, external int) + + SendMetrics(error model.AppError) +} + +type DurationMetricTimer interface { + Done() +} diff --git a/internal/lambda-managed-instances/interop/service_log_values.go b/internal/lambda-managed-instances/interop/service_log_values.go new file mode 100644 index 0000000..cb6d1a5 --- /dev/null +++ b/internal/lambda-managed-instances/interop/service_log_values.go @@ -0,0 +1,37 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +const ( + RequestIdProperty = "RequestId" + MemorySizeMbProperty = "MemorySizeMB" + FunctionArnProperty = "FunctionArn" + FunctionVersionIdProperty = "FunctionVersionId" + RuntimeVersionProperty = "RuntimeVersion" + + ArtefactTypeDimension = "ArtefactType" + AvailabilityZoneDimension = "AvailabilityZoneId" + WorkerAmiIdDimension = "WorkerAmiId" + + TotalDurationMetric = "TotalDuration" + PlatformOverheadDurationMetric = "PlatformOverheadDuration" + TotalExtensionsCountMetric = "TotalExtensionsCount" + InternalExtensionsCountMetric = "InternalExtensionsCount" + ExternalExtensionsCountMetric = "ExternalExtensionsCount" + + ShutdownAbortInvokesDurationMetric = "AbortInvokeDuration" + ShutdownKillProcessDurationMetricTemplate = "Kill%sDuration" + ShutdownRuntimeDuration = "StopRuntimeDuration" + ShutdownExtensionsDuration = "StopExtensionsDuration" + ShutdownWaitAllProcessesDuration = "WaitCustomerProcessesExitDuration" + ShutdownRuntimeServerDuration = "StopRuntimeServerDuration" + + ClientErrorMetric = "ClientError" + ClientErrorReasonTemplate = "ClientErrorReason-%s" + CustomerErrorMetric = "CustomerError" + CustomerErrorReasonTemplate = "CustomerErrorReason-%s" + PlatformErrorMetric = "PlatformError" + PlatformErrorReasonTemplate = "PlatformErrorReason-%s" + NonCustomerErrorMetric = "NonCustomerError" +) diff --git a/internal/lambda-managed-instances/invoke/consts.go b/internal/lambda-managed-instances/invoke/consts.go new file mode 100644 index 0000000..1632a4f --- /dev/null +++ b/internal/lambda-managed-instances/invoke/consts.go @@ -0,0 +1,21 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +const ( + СontentTypeHeader = "content-type" + FunctionErrorTypeTrailer = "lambda-runtime-function-error-type" + FunctionErrorBodyTrailer = "lambda-runtime-function-error-body" + ResponseModeHeader = "invoke-response-mode" + TraceIdHeader = "x-amzn-trace-id" +) + +type InvokeBodyResponseStatus string + +const ( + InvokeBodyResponseComplete InvokeBodyResponseStatus = "Complete" + InvokeBodyResponseTruncated InvokeBodyResponseStatus = "Truncated" + invokeBodyResponseOversized InvokeBodyResponseStatus = "Oversized" + invokeBodyResponseTimeout InvokeBodyResponseStatus = "Timeout" +) diff --git a/internal/lambda-managed-instances/invoke/invoke_router.go b/internal/lambda-managed-instances/invoke/invoke_router.go new file mode 100644 index 0000000..f67bf95 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/invoke_router.go @@ -0,0 +1,201 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "context" + "encoding/json" + "errors" + "io" + "log/slog" + "net/http" + "sync" + + cmap "github.com/orcaman/concurrent-map" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +var ( + ErrInvokeIdAlreadyExists = errors.New("invoke ID already exists") + ErrInvokeNoReadyRuntime = errors.New("no idle runtimes") +) + +type RuntimeResponseRequest interface { + ParsingError() model.AppError + + InvokeID() interop.InvokeID + ContentType() string + ResponseMode() string + + BodyReader() io.Reader + + TrailerError() ErrorForInvoker +} + +type RuntimeErrorRequest interface { + InvokeID() interop.InvokeID + ContentType() string + ErrorType() model.ErrorType + ErrorCategory() model.ErrorCategory + GetError() model.AppError + IsRuntimeError(model.AppError) bool + ReturnCode() int + ErrorDetails() string + GetXrayErrorCause() json.RawMessage +} + +type runningInvoke interface { + RunInvokeAndSendResult(context.Context, interop.InitStaticDataProvider, interop.InvokeRequest, interop.InvokeMetrics) model.AppError + RuntimeNextWait(context.Context) model.AppError + RuntimeResponse(context.Context, RuntimeResponseRequest) model.AppError + RuntimeError(context.Context, RuntimeErrorRequest) model.AppError + CancelAsync(model.AppError) +} + +type timeoutCache interface { + Register(invokeID interop.InvokeID) + Consume(invokeID interop.InvokeID) (consumed bool) +} + +type InvokeRouter struct { + eventsApi interop.EventsAPI + + idleRuntimes chan runningInvoke + + runningInvokes cmap.ConcurrentMap + + wg sync.WaitGroup + + createRunningInvoke func(http.ResponseWriter) runningInvoke + + timeoutCache timeoutCache +} + +func NewInvokeRouter( + maxIdleRuntimesQueueSize int, + telemetryEventsApi interop.EventsAPI, + responderFactoryFunc ResponderFactoryFunc, + timeoutCache timeoutCache, +) *InvokeRouter { + return &InvokeRouter{ + + idleRuntimes: make(chan runningInvoke, maxIdleRuntimesQueueSize), + runningInvokes: cmap.New(), + eventsApi: telemetryEventsApi, + timeoutCache: timeoutCache, + createRunningInvoke: func(runtimeNext http.ResponseWriter) runningInvoke { + r := newRunningInvoke(runtimeNext, responderFactoryFunc, timeoutCache) + return &r + }, + } +} + +func (ir *InvokeRouter) Invoke(ctx context.Context, initData interop.InitStaticDataProvider, invokeReq interop.InvokeRequest, metrics interop.InvokeMetrics) (err model.AppError, wasResponseSent bool) { + logging.Debug(ctx, "InvokeRouter: received Invoke") + ir.wg.Add(1) + defer ir.wg.Done() + + var idleRuntime runningInvoke + + metrics.UpdateConcurrencyMetrics(ir.runningInvokes.Count(), len(ir.idleRuntimes)) + + if !ir.runningInvokes.SetIfAbsent(invokeReq.InvokeID(), idleRuntime) { + logging.Warn(ctx, "InvokeRouter error: duplicated invokeId") + return model.NewClientError(ErrInvokeIdAlreadyExists, model.ErrorSeverityError, model.ErrorDublicatedInvokeId), false + } + + defer ir.runningInvokes.Remove(invokeReq.InvokeID()) + + select { + case idleRuntime = <-ir.idleRuntimes: + + ir.runningInvokes.Set(invokeReq.InvokeID(), idleRuntime) + default: + logging.Warn(ctx, "InvokeRouter: no ready runtimes") + return model.NewClientError(ErrInvokeNoReadyRuntime, model.ErrorSeverityError, model.ErrorRuntimeUnavailable), false + } + + return idleRuntime.RunInvokeAndSendResult(ctx, initData, invokeReq, metrics), true +} + +func (ir *InvokeRouter) RuntimeNext(ctx context.Context, runtimeReq http.ResponseWriter) (model.RuntimeNextWaiter, model.AppError) { + logging.Debug(ctx, "InvokeRouter: received runtime /next") + + newRunningInvoke := ir.createRunningInvoke(runtimeReq) + + if err := ir.addIdleRuntimeToQueue(newRunningInvoke); err != nil { + logging.Error(ctx, "InvokeRouter: failed to add idle runtime to the queue", "err", err) + return nil, err + } + + return newRunningInvoke, nil +} + +func (ir *InvokeRouter) RuntimeResponse(ctx context.Context, runtimeRespReq RuntimeResponseRequest) model.AppError { + logging.Debug(ctx, "InvokeRouter: received runtime response") + + invoke, ok := ir.runningInvokes.Get(runtimeRespReq.InvokeID()) + if !ok { + if ir.timeoutCache.Consume(runtimeRespReq.InvokeID()) { + logging.Warn(ctx, "InvokeRouter: response is too late for timed out invoke") + return model.NewCustomerError(model.ErrorRuntimeInvokeTimeout) + } + logging.Warn(ctx, "InvokeRouter: invoke id not found") + return model.NewCustomerError(model.ErrorRuntimeInvalidInvokeId) + } + + return invoke.(runningInvoke).RuntimeResponse(ctx, runtimeRespReq) +} + +func (ir *InvokeRouter) RuntimeError(ctx context.Context, runtimeErrReq RuntimeErrorRequest) model.AppError { + invoke, ok := ir.runningInvokes.Get(runtimeErrReq.InvokeID()) + if !ok { + if ir.timeoutCache.Consume(runtimeErrReq.InvokeID()) { + logging.Warn(ctx, "InvokeRouter: error is too late for timed out invoke") + return model.NewCustomerError(model.ErrorRuntimeInvokeTimeout) + } + logging.Warn(ctx, "InvokeRouter: invoke id not found") + return model.NewCustomerError(model.ErrorRuntimeInvalidInvokeId) + } + + logging.Warn(ctx, "InvokeRouter: received Runtime error", "err", runtimeErrReq.GetError()) + return invoke.(runningInvoke).RuntimeError(ctx, runtimeErrReq) +} + +func (ir *InvokeRouter) AbortRunningInvokes(metrics interop.ShutdownMetrics, err model.AppError) { + duration := metrics.CreateDurationMetric(interop.ShutdownAbortInvokesDurationMetric) + defer duration.Done() + + slog.Info("InvokeRouter: Aborting running invokes", "reason", err) + + ir.runningInvokes.IterCb(func(key string, v interface{}) { + if runningInvoke, ok := v.(runningInvoke); ok { + runningInvoke.CancelAsync(err) + } + }) + + slog.Debug("InvokeRouter: Waiting for invokes to be aborted") + ir.wg.Wait() + +} + +func (ir *InvokeRouter) addIdleRuntimeToQueue(invoke runningInvoke) model.AppError { + select { + case ir.idleRuntimes <- invoke: + return nil + default: + return model.NewCustomerError(model.ErrorRuntimeTooManyIdleRuntimes) + } +} + +func (ir *InvokeRouter) GetRunningInvokesCount() int { + return ir.runningInvokes.Count() +} + +func (ir *InvokeRouter) GetIdleRuntimesCount() int { + return len(ir.idleRuntimes) +} diff --git a/internal/lambda-managed-instances/invoke/invoke_router_test.go b/internal/lambda-managed-instances/invoke/invoke_router_test.go new file mode 100644 index 0000000..03e8bf6 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/invoke_router_test.go @@ -0,0 +1,345 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "context" + "net/http" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" +) + +const ( + testInvokeRouterMaxIdleRuntime = 5 +) + +type invokeRouterMocks struct { + ctx context.Context + staticData interop.MockInitStaticDataProvider + eaInvokeRequest interop.MockInvokeRequest + runnningInvoke mockRunningInvoke + runtimeRespReq MockRuntimeResponseRequest + runtimeErrorReq MockRuntimeErrorRequest + invokeMetrics interop.MockInvokeMetrics + shutdownMetrics interop.MockShutdownMetrics + durationMetricTimer *interop.MockDurationMetricTimer + timeoutCache *mockTimeoutCache + + runtimeNextRequest http.ResponseWriter +} + +func newInvokeRouterMocks() invokeRouterMocks { + return invokeRouterMocks{ + ctx: context.TODO(), + staticData: interop.MockInitStaticDataProvider{}, + eaInvokeRequest: interop.MockInvokeRequest{}, + runnningInvoke: mockRunningInvoke{}, + runtimeRespReq: MockRuntimeResponseRequest{}, + runtimeErrorReq: MockRuntimeErrorRequest{}, + invokeMetrics: interop.MockInvokeMetrics{}, + shutdownMetrics: interop.MockShutdownMetrics{}, + durationMetricTimer: &interop.MockDurationMetricTimer{}, + timeoutCache: &mockTimeoutCache{}, + } +} + +func hijackInvokeRouterDeps(router *InvokeRouter, mocks *invokeRouterMocks) { + router.createRunningInvoke = func(runtimeNext http.ResponseWriter) runningInvoke { + return &mocks.runnningInvoke + } +} + +func createMocksAndInitRouter() (*invokeRouterMocks, *InvokeRouter) { + mocks := newInvokeRouterMocks() + router := NewInvokeRouter(testInvokeRouterMaxIdleRuntime, &telemetry.NoOpEventsAPI{}, nil, mocks.timeoutCache) + hijackInvokeRouterDeps(router, &mocks) + + return &mocks, router +} + +func checkMockExpectations(t *testing.T, mocks *invokeRouterMocks) { + mocks.staticData.AssertExpectations(t) + mocks.eaInvokeRequest.AssertExpectations(t) + mocks.runnningInvoke.AssertExpectations(t) + mocks.runtimeRespReq.AssertExpectations(t) + mocks.runtimeErrorReq.AssertExpectations(t) + mocks.shutdownMetrics.AssertExpectations(t) + mocks.durationMetricTimer.AssertExpectations(t) + mocks.timeoutCache.AssertExpectations(t) +} + +func TestInvokeSuccess(t *testing.T) { + t.Parallel() + + mocks, router := createMocksAndInitRouter() + defer checkMockExpectations(t, mocks) + + respChannel := make(chan time.Time) + syncChan := make(chan struct{}) + + mocks.runnningInvoke.On("RuntimeNextWait", mock.Anything).Return(nil).Once() + runningInvoke, err := router.RuntimeNext(mocks.ctx, mocks.runtimeNextRequest) + require.NoError(t, err) + require.NoError(t, runningInvoke.RuntimeNextWait(mocks.ctx)) + + mocks.invokeMetrics.On("UpdateConcurrencyMetrics", 0, 1) + + mocks.eaInvokeRequest.On("InvokeID").Return("123456") + mocks.runnningInvoke.On("RunInvokeAndSendResult", mock.Anything, &mocks.staticData, &mocks.eaInvokeRequest, mock.Anything).Run(func(args mock.Arguments) { + + close(syncChan) + + <-respChannel + }).Return(nil) + + mocks.runtimeRespReq.On("InvokeID").Return("123456") + mocks.runnningInvoke.On("RuntimeResponse", mock.Anything, &mocks.runtimeRespReq).Return(nil) + + wg := new(sync.WaitGroup) + wg.Add(1) + go func() { + defer wg.Done() + err, wasResponseSent := router.Invoke(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.invokeMetrics) + assert.NoError(t, err) + assert.True(t, wasResponseSent) + }() + + <-syncChan + + err = router.RuntimeResponse(mocks.ctx, &mocks.runtimeRespReq) + assert.NoError(t, err) + + close(respChannel) + wg.Wait() + + checkMockExpectations(t, mocks) +} + +func TestInvokeFailure_NoIdleRuntime(t *testing.T) { + t.Parallel() + + mocks, router := createMocksAndInitRouter() + defer checkMockExpectations(t, mocks) + + mocks.invokeMetrics.On("UpdateConcurrencyMetrics", 0, 0) + + mocks.eaInvokeRequest.On("InvokeID").Return("123456") + + err, wasResponseSent := router.Invoke(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.invokeMetrics) + assert.Error(t, err) + assert.False(t, wasResponseSent) + assert.Equal(t, model.ErrorRuntimeUnavailable, err.ErrorType()) +} + +func TestInvokeFailure_DublicatedInvokeId(t *testing.T) { + t.Parallel() + + mocks, router := createMocksAndInitRouter() + defer checkMockExpectations(t, mocks) + + mocks.runnningInvoke.On("RuntimeNextWait", mock.Anything).Return(nil).Twice() + runningInvoke, err := router.RuntimeNext(mocks.ctx, mocks.runtimeNextRequest) + require.NoError(t, err) + require.NoError(t, runningInvoke.RuntimeNextWait(mocks.ctx)) + + runningInvoke, err = router.RuntimeNext(mocks.ctx, mocks.runtimeNextRequest) + require.NoError(t, err) + require.NoError(t, runningInvoke.RuntimeNextWait(mocks.ctx)) + + respChannel := make(chan time.Time) + + mocks.invokeMetrics.On("UpdateConcurrencyMetrics", mock.AnythingOfType("int"), mock.AnythingOfType("int")).Twice() + + mocks.eaInvokeRequest.On("InvokeID").Return("123456") + mocks.runnningInvoke.On("RunInvokeAndSendResult", mock.Anything, &mocks.staticData, &mocks.eaInvokeRequest, mock.Anything).Return(nil).WaitUntil(respChannel).Once() + + wg := new(sync.WaitGroup) + ch := make(chan model.AppError, 2) + var wasResponseSentCnt atomic.Uint32 + + wg.Add(1) + go func() { + defer wg.Done() + err, wasResponseSent := router.Invoke(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.invokeMetrics) + if wasResponseSent { + wasResponseSentCnt.Add(1) + } + ch <- err + }() + + wg.Add(1) + go func() { + defer wg.Done() + err, wasResponseSent := router.Invoke(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.invokeMetrics) + if wasResponseSent { + wasResponseSentCnt.Add(1) + } + ch <- err + }() + + err = <-ch + assert.Error(t, err) + assert.Equal(t, model.ErrorDublicatedInvokeId, err.ErrorType()) + + close(respChannel) + err = <-ch + assert.NoError(t, err) + + wg.Wait() + + assert.Equal(t, uint32(1), wasResponseSentCnt.Load()) + checkMockExpectations(t, mocks) +} + +func TestRuntimeNextFailure_TooManyIdleInvokes(t *testing.T) { + t.Parallel() + + mocks, router := createMocksAndInitRouter() + defer checkMockExpectations(t, mocks) + + mocks.runnningInvoke.On("RuntimeNextWait", mock.Anything).Return(nil).Times(testInvokeRouterMaxIdleRuntime) + for range testInvokeRouterMaxIdleRuntime { + runningInvoke, err := router.RuntimeNext(mocks.ctx, mocks.runtimeNextRequest) + require.NoError(t, err) + assert.NoError(t, runningInvoke.RuntimeNextWait(mocks.ctx)) + } + + runningInvoke, err := router.RuntimeNext(mocks.ctx, mocks.runtimeNextRequest) + require.Error(t, err) + require.Nil(t, runningInvoke) + assert.Equal(t, model.ErrorRuntimeTooManyIdleRuntimes, err.ErrorType()) + + checkMockExpectations(t, mocks) +} + +func TestRuntimeResponseFailure(t *testing.T) { + tests := []struct { + name string + inTimeoutCache bool + wantErrorType model.ErrorType + }{ + { + name: "InvokeIdNotFound", + wantErrorType: model.ErrorRuntimeInvalidInvokeId, + }, + { + name: "InvokeTimeout", + inTimeoutCache: true, + wantErrorType: model.ErrorRuntimeInvokeTimeout, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mocks, router := createMocksAndInitRouter() + defer checkMockExpectations(t, mocks) + mocks.runtimeRespReq.On("InvokeID").Return("123456").Twice() + mocks.timeoutCache.On("Consume", "123456").Return(tt.inTimeoutCache).Once() + + err := router.RuntimeResponse(context.Background(), &mocks.runtimeRespReq) + + assert.Equal(t, tt.wantErrorType, err.ErrorType()) + }) + } +} + +func TestRuntimeErrorFailure(t *testing.T) { + tests := []struct { + name string + inTimeoutCache bool + wantErrorType model.ErrorType + }{ + { + name: "InvokeIdNotFound", + wantErrorType: model.ErrorRuntimeInvalidInvokeId, + }, + { + name: "InvokeTimeout", + inTimeoutCache: true, + wantErrorType: model.ErrorRuntimeInvokeTimeout, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mocks, router := createMocksAndInitRouter() + defer checkMockExpectations(t, mocks) + mocks.runtimeErrorReq.On("InvokeID").Return("123456").Twice() + mocks.timeoutCache.On("Consume", "123456").Return(tt.inTimeoutCache).Once() + + err := router.RuntimeError(mocks.ctx, &mocks.runtimeErrorReq) + + assert.Equal(t, tt.wantErrorType, err.ErrorType()) + }) + } +} + +func TestAbortRunningInvokes(t *testing.T) { + t.Parallel() + + mocks, router := createMocksAndInitRouter() + defer checkMockExpectations(t, mocks) + + eaGracefulShutdownErr := model.NewClientError(nil, model.ErrorSeverityFatal, model.ErrorExecutionEnvironmentShutdown) + mocks.shutdownMetrics.On("CreateDurationMetric", interop.ShutdownAbortInvokesDurationMetric).Return(mocks.durationMetricTimer) + mocks.durationMetricTimer.On("Done").Return() + + idleRuntime1 := newMockRunningInvoke(t) + idleRuntime2 := newMockRunningInvoke(t) + + mockRunningInvoke1 := newMockRunningInvoke(t) + mockRunningInvoke2 := newMockRunningInvoke(t) + + router.idleRuntimes <- idleRuntime1 + router.idleRuntimes <- idleRuntime2 + + mockRunningInvoke1.On("CancelAsync", eaGracefulShutdownErr).Return() + mockRunningInvoke2.On("CancelAsync", eaGracefulShutdownErr).Return() + + router.runningInvokes.Set("1", mockRunningInvoke1) + router.runningInvokes.Set("2", mockRunningInvoke2) + + router.AbortRunningInvokes(&mocks.shutdownMetrics, eaGracefulShutdownErr) + + mockRunningInvoke1.AssertExpectations(t) + mockRunningInvoke2.AssertExpectations(t) + + idleRuntime1.AssertNumberOfCalls(t, "CancelAsync", 0) + idleRuntime2.AssertNumberOfCalls(t, "CancelAsync", 0) + + mockRunningInvoke1.AssertNumberOfCalls(t, "CancelAsync", 1) + mockRunningInvoke2.AssertNumberOfCalls(t, "CancelAsync", 1) +} + +func TestInvokeRouter_Counters(t *testing.T) { + t.Parallel() + + mocks, router := createMocksAndInitRouter() + defer checkMockExpectations(t, mocks) + + idleRuntime1 := newMockRunningInvoke(t) + idleRuntime2 := newMockRunningInvoke(t) + + mockRunningInvoke1 := newMockRunningInvoke(t) + + router.idleRuntimes <- idleRuntime1 + router.idleRuntimes <- idleRuntime2 + + router.runningInvokes.Set("1", mockRunningInvoke1) + + assert.Equal(t, 2, router.GetIdleRuntimesCount()) + assert.Equal(t, 1, router.GetRunningInvokesCount()) +} diff --git a/internal/lambda-managed-instances/invoke/metrics.go b/internal/lambda-managed-instances/invoke/metrics.go new file mode 100644 index 0000000..e7ecbc1 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/metrics.go @@ -0,0 +1,382 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "encoding/json" + "fmt" + "log/slog" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/ptr" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/servicelogs" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils/invariant" +) + +const ( + ResponseLatencySpanName = "responseLatency" + ResponseDurationSpanName = "responseDuration" +) + +const ( + InvokeTimeoutProperty = "InvokeTimeoutSeconds" + + RequestResponseModeDimension = "RequestMode" + ResponseModeDimension = "ResponseMode" + + RequestSendDurationMetric = "RequestSendDuration" + + RequestPayloadReadDurationMetric = "RequestPayloadReadDuration" + + RequestPayloadWriteDurationMetric = "RequestPayloadWriteDuration" + ResponseLatencyMetric = "ResponseLatency" + ResponseDurationMetric = "ResponseDuration" + FunctionDurationMetric = "FunctionDuration" + RequestPayloadSizeBytesMetric = "RequestPayloadSizeBytes" + ResponsePayloadSizeBytesMetric = "ResponsePayloadSizeBytes" + + ResponsePayloadReadDurationMetric = "ResponsePayloadReadDuration" + + ResponsePayloadWriteDurationMetric = "ResponsePayloadWriteDuration" + ErrorPayloadSizeBytesMetric = "ErrorPayloadSizeBytes" + ResponseThrottledDurationMetric = "ResponseThrottledDuration" + ResponseThroughputMetric = "ResponseThroughput" + InflightRequestCountMetric = "InflightRequestCount" + IdleRuntimesCountMetric = "IdleRuntimesCount" +) + +var invokeMetricsMissDepError = "Invoke metrics miss dependencies" + +type Counter interface { + AddInvoke(proxiedBytes uint64) +} + +type invokeMetrics struct { + telemetryEventsAPI interop.EventsAPI + logger servicelogs.Logger + + initData interop.InitStaticDataProvider + invokeReq interop.InvokeRequest + + runtimeResponseSent bool + error model.AppError + + timeGetRequest time.Time + timeStartRequest time.Time + timeSentRequest time.Time + timeGetResponse time.Time + timeSentResponse time.Time + timeInvokeDone time.Time + + responseMetrics *interop.InvokeResponseMetrics + + requestPayloadBytes int64 + requestPayloadReadDuration time.Duration + requestPayloadWriteDuration time.Duration + errorPayloadSizeBytes int + + inflightInvokes int + idleRuntimesCount int + + counter Counter + + getCurrentTime func() time.Time +} + +func NewInvokeMetrics(logger servicelogs.Logger, counter Counter) *invokeMetrics { + return &invokeMetrics{ + getCurrentTime: time.Now, + logger: logger, + counter: counter, + } +} + +func (e *invokeMetrics) AttachInvokeRequest(req interop.InvokeRequest) { + e.invokeReq = req +} + +func (e *invokeMetrics) AttachDependencies(initData interop.InitStaticDataProvider, telemetryEventsAPI interop.EventsAPI) { + invariant.Check(e.invokeReq != nil, invokeMetricsMissDepError) + e.initData = initData + e.telemetryEventsAPI = telemetryEventsAPI +} + +func (e *invokeMetrics) TriggerGetRequest() { + e.timeGetRequest = e.getCurrentTime() +} + +func (e *invokeMetrics) UpdateConcurrencyMetrics(inflightInvokes, idleRuntimesCount int) { + e.inflightInvokes = inflightInvokes + e.idleRuntimesCount = idleRuntimesCount +} + +func (e *invokeMetrics) TriggerStartRequest() { + e.timeStartRequest = e.getCurrentTime() +} + +func (e *invokeMetrics) TriggerSentRequest(requestBytes int64, requestPayloadReadDuration, requestPayloadWriteDuration time.Duration) { + e.timeSentRequest = e.getCurrentTime() + e.requestPayloadBytes = requestBytes + e.requestPayloadReadDuration = requestPayloadReadDuration + e.requestPayloadWriteDuration = requestPayloadWriteDuration +} + +func (e *invokeMetrics) TriggerGetResponse() { + e.timeGetResponse = e.getCurrentTime() +} + +func (e *invokeMetrics) TriggerSentResponse(runtimeResponseSent bool, responseErr model.AppError, streamingMetrics *interop.InvokeResponseMetrics, errorPayloadSizeBytes int) { + e.timeSentResponse = e.getCurrentTime() + e.runtimeResponseSent = runtimeResponseSent + e.error = responseErr + e.responseMetrics = streamingMetrics + e.errorPayloadSizeBytes = errorPayloadSizeBytes +} + +func (e *invokeMetrics) TriggerInvokeDone() (totalMs time.Duration, runMs *time.Duration, initData interop.InitStaticDataProvider) { + e.timeInvokeDone = e.getCurrentTime() + + totalMs = e.timeInvokeDone.Sub(e.timeGetRequest) + + if !e.timeStartRequest.IsZero() { + runMs = ptr.To(e.timeInvokeDone.Sub(e.timeStartRequest)) + } + return totalMs, runMs, e.initData +} + +func (e *invokeMetrics) SendInvokeStartEvent(tracing *interop.TracingCtx) error { + invariant.Check(e.telemetryEventsAPI != nil, invokeMetricsMissDepError) + + return e.telemetryEventsAPI.SendInvokeStart(interop.InvokeStartData{ + InvokeID: e.invokeReq.InvokeID(), + Version: e.initData.FunctionVersion(), + FunctionARN: e.initData.FunctionARN(), + Tracing: tracing, + }) +} + +func (e *invokeMetrics) SendInvokeFinishedEvent(tracing *interop.TracingCtx, xrayErrorCause json.RawMessage) error { + invariant.Check(e.telemetryEventsAPI != nil, invokeMetricsMissDepError) + + if xrayErrorCause != nil { + + err := e.telemetryEventsAPI.SendInternalXRayErrorCause(interop.InternalXRayErrorCauseData{InvokeID: e.invokeReq.InvokeID(), Cause: string(xrayErrorCause)}) + if err != nil { + slog.Error("Failed to send xray error cause", "err", err, "invokeId", e.invokeReq.InvokeID()) + } + } + + spans := []interop.Span{} + spans = e.addLatencySpan(spans) + spans = e.addDurationSpan(spans) + + return e.telemetryEventsAPI.SendReport(interop.ReportData{ + InvokeID: e.invokeReq.InvokeID(), + Status: interop.BuildStatusFromError(e.error), + Metrics: interop.ReportMetrics{ + DurationMs: interop.ReportDurationMs(buildDuration(e.timeStartRequest, e.timeSentResponse)), + }, + Tracing: tracing, + Spans: spans, + ErrorType: buildErrorTypePointer(e.error), + }) +} + +func (e *invokeMetrics) addLatencySpan(spans []interop.Span) []interop.Span { + if e.timeSentRequest.IsZero() || e.timeGetResponse.IsZero() { + return spans + } + + latencySpan := interop.Span{ + Name: ResponseLatencySpanName, + Start: e.timeSentRequest.UTC().Format(telemetry.TimeFormat), + DurationMs: buildDuration(e.timeSentRequest, e.timeGetResponse), + } + + return append(spans, latencySpan) +} + +func (e *invokeMetrics) addDurationSpan(spans []interop.Span) []interop.Span { + if !e.runtimeResponseSent || e.timeGetResponse.IsZero() || e.timeSentResponse.IsZero() { + return spans + } + + durationSpan := interop.Span{ + Name: ResponseDurationSpanName, + Start: e.timeGetResponse.UTC().Format(telemetry.TimeFormat), + DurationMs: buildDuration(e.timeGetResponse, e.timeSentResponse), + } + + return append(spans, durationSpan) +} + +func buildDuration(start time.Time, finish time.Time) float64 { + return float64(finish.Sub(start).Microseconds()) / 1000.0 +} + +func buildErrorTypePointer(err model.AppError) *model.ErrorType { + if err == nil { + return nil + } + + errorType := err.ErrorType() + return &errorType +} + +func (e *invokeMetrics) SendMetrics(invokeErr model.AppError) error { + invariant.Check(e.logger != nil, invokeMetricsMissDepError) + + if e.error != nil && invokeErr == nil { + invariant.Violate("Empty error in SendMetrics after previous error isn't nil") + } + + e.error = invokeErr + + props := e.buildProperties() + dims := e.buildDimensions() + metrics := e.buildMetrics() + + e.logger.Log(servicelogs.InvokeOp, e.timeGetRequest, props, dims, metrics) + + e.updateCounter() + + return nil +} + +func (e *invokeMetrics) updateCounter() { + var proxiedBytes uint64 + proxiedBytes += uint64(e.requestPayloadBytes) + proxiedBytes += uint64(e.errorPayloadSizeBytes) + if e.responseMetrics != nil { + proxiedBytes += uint64(e.responseMetrics.ProducedBytes) + } + + e.counter.AddInvoke(proxiedBytes) +} + +func (e *invokeMetrics) buildProperties() []servicelogs.Property { + var props []servicelogs.Property + + if e.invokeReq != nil { + props = append(props, + servicelogs.Property{ + Name: interop.RequestIdProperty, + Value: e.invokeReq.InvokeID(), + }, + ) + } + + return props +} + +func (e *invokeMetrics) buildDimensions() []servicelogs.Dimension { + var dim []servicelogs.Dimension + + if e.invokeReq != nil { + dim = append(dim, + servicelogs.Dimension{ + Name: RequestResponseModeDimension, + Value: e.invokeReq.ResponseMode(), + }, + ) + } + + if e.responseMetrics != nil { + dim = append(dim, servicelogs.Dimension{ + Name: ResponseModeDimension, + Value: string(e.responseMetrics.FunctionResponseMode), + }) + } + + return dim +} + +func (e *invokeMetrics) buildMetrics() []servicelogs.Metric { + totalDuration := e.timeInvokeDone.Sub(e.timeGetRequest) + runDuration := time.Duration(0) + if !e.timeSentResponse.IsZero() { + runDuration = e.timeSentResponse.Sub(e.timeStartRequest) + } + platformOverhead := totalDuration - runDuration + + metrics := []servicelogs.Metric{ + servicelogs.Timer(interop.TotalDurationMetric, totalDuration), + servicelogs.Timer(interop.PlatformOverheadDurationMetric, platformOverhead), + servicelogs.Counter(InflightRequestCountMetric, float64(e.inflightInvokes)), + servicelogs.Counter(IdleRuntimesCountMetric, float64(e.idleRuntimesCount)), + } + + if e.responseMetrics != nil { + metrics = append(metrics, + servicelogs.Counter(ResponsePayloadSizeBytesMetric, float64(e.responseMetrics.ProducedBytes)), + servicelogs.Timer(ResponseThrottledDurationMetric, e.responseMetrics.TimeShaped), + servicelogs.Counter(ResponseThroughputMetric, float64(e.responseMetrics.OutboundThroughputBps)), + servicelogs.Timer(ResponsePayloadReadDurationMetric, e.responseMetrics.ResponsePayloadReadDuration), + servicelogs.Timer(ResponsePayloadWriteDurationMetric, e.responseMetrics.ResponsePayloadWriteDuration), + ) + } + + if e.errorPayloadSizeBytes != 0 { + metrics = append(metrics, + servicelogs.Counter(ErrorPayloadSizeBytesMetric, float64(e.errorPayloadSizeBytes)), + ) + } + + if runDuration > 0 { + metrics = append(metrics, servicelogs.Timer(FunctionDurationMetric, runDuration)) + } + + if !e.timeSentRequest.IsZero() { + metrics = append(metrics, + servicelogs.Timer(RequestSendDurationMetric, e.timeSentRequest.Sub(e.timeStartRequest)), + servicelogs.Counter(RequestPayloadSizeBytesMetric, float64(e.requestPayloadBytes)), + servicelogs.Timer(RequestPayloadReadDurationMetric, e.requestPayloadReadDuration), + servicelogs.Timer(RequestPayloadWriteDurationMetric, e.requestPayloadWriteDuration), + ) + } + + if !e.timeGetResponse.IsZero() { + metrics = append(metrics, servicelogs.Timer(ResponseLatencyMetric, e.timeGetResponse.Sub(e.timeSentRequest))) + } + + if e.runtimeResponseSent { + metrics = append(metrics, servicelogs.Timer(ResponseDurationMetric, e.timeSentResponse.Sub(e.timeGetResponse))) + } + + var clientErrCnt, customerErrCnt, platformErrCnt, nonCustomerErrCnt float64 + + switch e.error.(type) { + case model.ClientError: + clientErrCnt = 1 + if e.error.ErrorType() != model.ErrorRuntimeUnavailable { + + nonCustomerErrCnt = 1 + } + metrics = append(metrics, + servicelogs.Counter(fmt.Sprintf(interop.ClientErrorReasonTemplate, e.error.ErrorType()), 1.0), + ) + case model.CustomerError: + customerErrCnt = 1 + metrics = append(metrics, + servicelogs.Counter(fmt.Sprintf(interop.CustomerErrorReasonTemplate, e.error.ErrorType()), 1.0), + ) + case model.PlatformError: + platformErrCnt = 1 + nonCustomerErrCnt = 1 + metrics = append(metrics, + servicelogs.Counter(fmt.Sprintf(interop.PlatformErrorReasonTemplate, e.error.ErrorType()), 1.0), + ) + } + + metrics = append(metrics, + servicelogs.Counter(interop.ClientErrorMetric, clientErrCnt), + servicelogs.Counter(interop.CustomerErrorMetric, customerErrCnt), + servicelogs.Counter(interop.PlatformErrorMetric, platformErrCnt), + servicelogs.Counter(interop.NonCustomerErrorMetric, nonCustomerErrCnt), + ) + return metrics +} diff --git a/internal/lambda-managed-instances/invoke/metrics_test.go b/internal/lambda-managed-instances/invoke/metrics_test.go new file mode 100644 index 0000000..d146845 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/metrics_test.go @@ -0,0 +1,654 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "cmp" + "encoding/json" + "fmt" + "slices" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + intmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/ptr" + rapimodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/servicelogs" +) + +var ( + eventInvokeId = "invoke-id" + eventFunctionVersion = "function-version" + eventFunctionArn = "function-arn" + eventTraceId = "Root=12345;Parent=67890;Sampled=1;Lineage=22222" + eventRuntimeErr = model.NewCustomerError(model.ErrorRuntimeUnknown) + eventTimeoutErr = model.NewCustomerError(model.ErrorSandboxTimedout) + eventInternalErr = model.NewPlatformError(nil, model.ErrorReasonRuntimeExecFailed) + + invokeRequestSize int64 = 100 + + dummyTracingCtx = &interop.TracingCtx{ + SpanID: "", + Type: rapimodel.XRayTracingType, + Value: eventTraceId, + } + + dummyExpectedReportData = interop.ReportData{ + InvokeID: eventInvokeId, + Metrics: interop.ReportMetrics{ + DurationMs: interop.ReportDurationMs(3000), + }, + Tracing: &interop.TracingCtx{ + SpanID: "", + Type: rapimodel.XRayTracingType, + Value: eventTraceId, + }, + Spans: []interop.Span{ + { + Name: "responseLatency", + Start: "0001-01-01T00:00:01.000Z", + DurationMs: 1000.0, + }, + { + Name: "responseDuration", + Start: "0001-01-01T00:00:02.000Z", + DurationMs: 1000.0, + }, + }, + ErrorType: nil, + } +) + +type invokeMetricsMocks struct { + eventsApi interop.MockEventsAPI + initData interop.MockInitStaticDataProvider + invokeReq interop.MockInvokeRequest + logger servicelogs.MockLogger + counter MockCounter + responseMetrics interop.InvokeResponseMetrics + timeStamp time.Time + error model.AppError +} + +func createInvokeEventsMocks(t *testing.T) *invokeMetricsMocks { + mocks := invokeMetricsMocks{ + eventsApi: interop.MockEventsAPI{}, + initData: interop.MockInitStaticDataProvider{}, + invokeReq: interop.MockInvokeRequest{}, + logger: servicelogs.MockLogger{}, + counter: MockCounter{}, + } + + mocks.invokeReq.On("InvokeID").Return(eventInvokeId) + mocks.invokeReq.On("ResponseMode").Return("Streaming").Maybe() + + return &mocks +} + +func checkMocksExpectations(t *testing.T, mocks *invokeMetricsMocks) { + mocks.eventsApi.AssertExpectations(t) + mocks.initData.AssertExpectations(t) + mocks.invokeReq.AssertExpectations(t) + mocks.logger.AssertExpectations(t) +} + +func createInvokeEventsAndHijackGetTime(mocks *invokeMetricsMocks) *invokeMetrics { + ev := NewInvokeMetrics(&mocks.logger, &mocks.counter) + ev.AttachInvokeRequest(&mocks.invokeReq) + ev.AttachDependencies(&mocks.initData, &mocks.eventsApi) + ev.getCurrentTime = func() time.Time { + return mocks.timeStamp + } + + return ev +} + +func Test_invokeMetrics_SendInvokeStart(t *testing.T) { + mocks := createInvokeEventsMocks(t) + ev := createInvokeEventsAndHijackGetTime(mocks) + + ev.TriggerStartRequest() + + mocks.initData.On("FunctionVersion").Return(eventFunctionVersion) + mocks.initData.On("FunctionARN").Return(eventFunctionArn) + mocks.eventsApi.On("SendInvokeStart", mock.MatchedBy(func(arg interop.InvokeStartData) bool { + expected := interop.InvokeStartData{ + InvokeID: eventInvokeId, + Version: eventFunctionVersion, + FunctionARN: eventFunctionArn, + Tracing: &interop.TracingCtx{ + SpanID: "", + Type: rapimodel.XRayTracingType, + Value: eventTraceId, + }, + } + + return assert.Equal(t, expected, arg) + })).Return(nil) + + err := ev.SendInvokeStartEvent(dummyTracingCtx) + assert.NoError(t, err) + checkMocksExpectations(t, mocks) +} + +func Test_invokeMetrics_SendReport_FullCycle(t *testing.T) { + tests := []struct { + name string + runtimeError model.AppError + status interop.ResponseStatus + }{ + { + name: "InvokeResponse", + runtimeError: nil, + status: interop.Success, + }, + { + name: "InvokeError", + runtimeError: eventRuntimeErr, + status: interop.Error, + }, + { + name: "InvokeFailure", + runtimeError: eventInternalErr, + status: interop.Failure, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mocks := createInvokeEventsMocks(t) + ev := createInvokeEventsAndHijackGetTime(mocks) + + ev.TriggerStartRequest() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerSentRequest(invokeRequestSize, 11*time.Microsecond, 12*time.Microsecond) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerGetResponse() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerSentResponse(true, tt.runtimeError, nil, 0) + + mocks.eventsApi.On("SendReport", mock.MatchedBy(func(arg interop.ReportData) bool { + expected := dummyExpectedReportData + expected.Status = tt.status + if tt.runtimeError != nil { + errType := tt.runtimeError.ErrorType() + expected.ErrorType = &errType + } + + return assert.Equal(t, expected, arg) + })).Return(nil) + + err := ev.SendInvokeFinishedEvent(dummyTracingCtx, nil) + assert.NoError(t, err) + checkMocksExpectations(t, mocks) + }) + } +} + +func Test_invokeMetrics_SendReport_NoResponse(t *testing.T) { + mocks := createInvokeEventsMocks(t) + ev := createInvokeEventsAndHijackGetTime(mocks) + + ev.TriggerStartRequest() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerSentRequest(invokeRequestSize, 11*time.Microsecond, 12*time.Microsecond) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerSentResponse(false, eventTimeoutErr, nil, 0) + + mocks.eventsApi.On("SendReport", mock.MatchedBy(func(arg interop.ReportData) bool { + expected := dummyExpectedReportData + expected.Status = interop.Timeout + errType := eventTimeoutErr.ErrorType() + expected.ErrorType = &errType + + expected.Metrics.DurationMs = interop.ReportDurationMs(2000) + expected.Spans = []interop.Span{} + + return assert.Equal(t, expected, arg) + })).Return(nil) + + err := ev.SendInvokeFinishedEvent(dummyTracingCtx, nil) + assert.NoError(t, err) + checkMocksExpectations(t, mocks) +} + +func Test_invokeMetrics_SendReport_ResponseWithUnfinishedBody(t *testing.T) { + mocks := createInvokeEventsMocks(t) + ev := createInvokeEventsAndHijackGetTime(mocks) + + ev.TriggerStartRequest() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerSentRequest(invokeRequestSize, 11*time.Microsecond, 12*time.Microsecond) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerGetResponse() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerSentResponse(false, eventTimeoutErr, nil, 0) + + mocks.eventsApi.On("SendReport", mock.MatchedBy(func(arg interop.ReportData) bool { + expected := dummyExpectedReportData + expected.Status = interop.Timeout + errType := eventTimeoutErr.ErrorType() + expected.ErrorType = &errType + + expected.Metrics.DurationMs = interop.ReportDurationMs(3000) + expected.Spans = []interop.Span{ + { + Name: "responseLatency", + Start: "0001-01-01T00:00:01.000Z", + DurationMs: 1000.0, + }, + } + + return assert.Equal(t, expected, arg) + })).Return(nil) + + err := ev.SendInvokeFinishedEvent(dummyTracingCtx, nil) + assert.NoError(t, err) + checkMocksExpectations(t, mocks) +} + +func Test_invokeMetrics_SendReport_WithXrayErrorCause(t *testing.T) { + mocks := createInvokeEventsMocks(t) + ev := createInvokeEventsAndHijackGetTime(mocks) + + ev.TriggerStartRequest() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerSentRequest(invokeRequestSize, 11*time.Microsecond, 12*time.Microsecond) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerGetResponse() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerSentResponse(true, eventRuntimeErr, nil, 0) + + xrayErrorCause := json.RawMessage(`{"exceptions":[{"message":"Null pointer exception","type":"RuntimeError"}],"working_directory":"","paths":[]}`) + mocks.eventsApi.On("SendInternalXRayErrorCause", mock.MatchedBy(func(arg interop.InternalXRayErrorCauseData) bool { + expected := interop.InternalXRayErrorCauseData{ + InvokeID: eventInvokeId, + Cause: string(xrayErrorCause), + } + return assert.Equal(t, expected, arg) + })).Return(nil) + + mocks.eventsApi.On("SendReport", mock.MatchedBy(func(arg interop.ReportData) bool { + expected := dummyExpectedReportData + expected.Status = interop.Error + errType := eventRuntimeErr.ErrorType() + expected.ErrorType = &errType + + return assert.Equal(t, expected, arg) + })).Return(nil) + + err := ev.SendInvokeFinishedEvent(dummyTracingCtx, xrayErrorCause) + assert.NoError(t, err) + checkMocksExpectations(t, mocks) +} + +func Test_invokeMetrics_TriggerInvokeDone(t *testing.T) { + tests := []struct { + name string + setTimeStartRequest bool + expectedRunMs *time.Duration + }{ + { + name: "timeStartRequest_empty", + setTimeStartRequest: false, + expectedRunMs: nil, + }, + { + name: "timeStartRequest_non_empty", + setTimeStartRequest: true, + expectedRunMs: ptr.To(3 * time.Second), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + im := NewInvokeMetrics(nil, NewMockCounter(t)) + now := time.Now() + im.getCurrentTime = func() time.Time { + return now + } + im.TriggerGetRequest() + + now = now.Add(2 * time.Second) + if tt.setTimeStartRequest { + im.TriggerStartRequest() + } + + now = now.Add(3 * time.Second) + totalMs, runMs, initData := im.TriggerInvokeDone() + + assert.Equal(t, 5*time.Second, totalMs) + assert.Equal(t, tt.expectedRunMs, runMs) + assert.Nil(t, initData) + }) + } +} + +func Test_invokeMetrics_ServiceLogs(t *testing.T) { + tests := []struct { + name string + expectedBytes uint64 + metricFlow func(ev *invokeMetrics, mocks *invokeMetricsMocks) + expectedProps []servicelogs.Property + expectedDims []servicelogs.Dimension + expectedMetrics []servicelogs.Metric + }{ + { + name: "minimal_invoke_flow", + expectedBytes: 0, + metricFlow: func(ev *invokeMetrics, mocks *invokeMetricsMocks) { + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + mocks.error = model.NewClientError(nil, model.ErrorSeverityInvalid, model.ErrorInvalidFunctionVersion) + }, + expectedProps: []servicelogs.Property{}, + expectedDims: []servicelogs.Dimension{}, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "InflightRequestCount", Value: 0}, + {Type: servicelogs.CounterType, Key: "IdleRuntimesCount", Value: 0}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "ClientError", Value: 1}, + {Type: servicelogs.CounterType, Key: "CustomerError", Value: 0}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 0}, + {Type: servicelogs.CounterType, Key: "ClientErrorReason-ErrInvalidFunctionVersion", Value: 1}, + {Type: servicelogs.CounterType, Key: "NonCustomerError", Value: 1}, + }, + }, + { + name: "bad_request_invoke_flow", + expectedBytes: 0, + metricFlow: func(ev *invokeMetrics, mocks *invokeMetricsMocks) { + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.AttachInvokeRequest(&mocks.invokeReq) + mocks.error = model.NewClientError(nil, model.ErrorSeverityError, model.ErrorInitIncomplete) + }, + expectedProps: []servicelogs.Property{ + {Name: "RequestId", Value: "invoke-id"}, + }, + expectedDims: []servicelogs.Dimension{ + {Name: "RequestMode", Value: "Streaming"}, + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "InflightRequestCount", Value: 0}, + {Type: servicelogs.CounterType, Key: "IdleRuntimesCount", Value: 0}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "ClientError", Value: 1}, + {Type: servicelogs.CounterType, Key: "CustomerError", Value: 0}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 0}, + {Type: servicelogs.CounterType, Key: "ClientErrorReason-Client.InitIncomplete", Value: 1}, + {Type: servicelogs.CounterType, Key: "NonCustomerError", Value: 1}, + }, + }, + { + name: "runtime_unavailable_error", + expectedBytes: 0, + metricFlow: func(ev *invokeMetrics, mocks *invokeMetricsMocks) { + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.AttachInvokeRequest(&mocks.invokeReq) + ev.AttachDependencies(&mocks.initData, &mocks.eventsApi) + ev.UpdateConcurrencyMetrics(5, 3) + mocks.error = model.NewClientError(nil, model.ErrorSeverityError, model.ErrorRuntimeUnavailable) + }, + expectedProps: []servicelogs.Property{ + {Name: "RequestId", Value: "invoke-id"}, + }, + expectedDims: []servicelogs.Dimension{ + {Name: "RequestMode", Value: "Streaming"}, + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "InflightRequestCount", Value: 5}, + {Type: servicelogs.CounterType, Key: "IdleRuntimesCount", Value: 3}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "ClientError", Value: 1}, + {Type: servicelogs.CounterType, Key: "CustomerError", Value: 0}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 0}, + {Type: servicelogs.CounterType, Key: "ClientErrorReason-Runtime.Unavailable", Value: 1}, + {Type: servicelogs.CounterType, Key: "NonCustomerError", Value: 0}, + }, + }, + { + name: "runtime_timeout_flow", + expectedBytes: 100, + metricFlow: func(ev *invokeMetrics, mocks *invokeMetricsMocks) { + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.AttachInvokeRequest(&mocks.invokeReq) + ev.AttachDependencies(&mocks.initData, &mocks.eventsApi) + ev.UpdateConcurrencyMetrics(5, 3) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerStartRequest() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerSentRequest(100, 11*time.Microsecond, 12*time.Microsecond) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + mocks.error = model.NewCustomerError(model.ErrorSandboxTimedout) + ev.TriggerSentResponse(false, mocks.error, nil, 0) + }, + expectedProps: []servicelogs.Property{ + {Name: "RequestId", Value: "invoke-id"}, + }, + expectedDims: []servicelogs.Dimension{ + {Name: "RequestMode", Value: "Streaming"}, + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 4000000}, + {Type: servicelogs.CounterType, Key: "InflightRequestCount", Value: 5}, + {Type: servicelogs.CounterType, Key: "IdleRuntimesCount", Value: 3}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 2000000}, + {Type: servicelogs.TimerType, Key: "FunctionDuration", Value: 2000000}, + {Type: servicelogs.TimerType, Key: "RequestSendDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "RequestPayloadSizeBytes", Value: 100}, + {Type: servicelogs.TimerType, Key: "RequestPayloadReadDuration", Value: 11}, + {Type: servicelogs.TimerType, Key: "RequestPayloadWriteDuration", Value: 12}, + {Type: servicelogs.CounterType, Key: "ClientError", Value: 0}, + {Type: servicelogs.CounterType, Key: "CustomerError", Value: 1}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 0}, + {Type: servicelogs.CounterType, Key: "CustomerErrorReason-Sandbox.Timedout", Value: 1}, + {Type: servicelogs.CounterType, Key: "NonCustomerError", Value: 0}, + }, + }, + { + name: "full_invoke_flow", + expectedBytes: 200, + metricFlow: func(ev *invokeMetrics, mocks *invokeMetricsMocks) { + ev.AttachInvokeRequest(&mocks.invokeReq) + ev.AttachDependencies(&mocks.initData, &mocks.eventsApi) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.UpdateConcurrencyMetrics(5, 3) + ev.TriggerStartRequest() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerSentRequest(100, 11*time.Microsecond, 12*time.Microsecond) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerGetResponse() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerSentResponse(true, nil, &mocks.responseMetrics, 0) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + }, + expectedProps: []servicelogs.Property{ + {Name: "RequestId", Value: "invoke-id"}, + }, + expectedDims: []servicelogs.Dimension{ + {Name: "RequestMode", Value: "Streaming"}, + {Name: "ResponseMode", Value: "Streaming"}, + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 5000000}, + {Type: servicelogs.CounterType, Key: "InflightRequestCount", Value: 5}, + {Type: servicelogs.CounterType, Key: "IdleRuntimesCount", Value: 3}, + {Type: servicelogs.CounterType, Key: "ResponsePayloadSizeBytes", Value: 100}, + {Type: servicelogs.TimerType, Key: "ResponseThrottledDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "ResponseThroughput", Value: 100}, + {Type: servicelogs.TimerType, Key: "ResponsePayloadReadDuration", Value: 13}, + {Type: servicelogs.TimerType, Key: "ResponsePayloadWriteDuration", Value: 14}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 2000000}, + {Type: servicelogs.TimerType, Key: "FunctionDuration", Value: 3000000}, + {Type: servicelogs.TimerType, Key: "RequestSendDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "RequestPayloadSizeBytes", Value: 100}, + {Type: servicelogs.TimerType, Key: "RequestPayloadReadDuration", Value: 11}, + {Type: servicelogs.TimerType, Key: "RequestPayloadWriteDuration", Value: 12}, + {Type: servicelogs.TimerType, Key: "ResponseLatency", Value: 1000000}, + {Type: servicelogs.TimerType, Key: "ResponseDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "ClientError", Value: 0}, + {Type: servicelogs.CounterType, Key: "CustomerError", Value: 0}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 0}, + {Type: servicelogs.CounterType, Key: "NonCustomerError", Value: 0}, + }, + }, + { + name: "full_invoke_error_flow", + expectedBytes: 300, + metricFlow: func(ev *invokeMetrics, mocks *invokeMetricsMocks) { + ev.AttachInvokeRequest(&mocks.invokeReq) + ev.AttachDependencies(&mocks.initData, &mocks.eventsApi) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.UpdateConcurrencyMetrics(2, 1) + ev.TriggerStartRequest() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerSentRequest(100, 11*time.Microsecond, 12*time.Microsecond) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.TriggerGetResponse() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + mocks.error = model.NewCustomerError(model.ErrorRuntimeUnknown) + ev.TriggerSentResponse(true, mocks.error, &interop.InvokeResponseMetrics{ProducedBytes: 100, FunctionResponseMode: runtimeResponseModeStreaming}, 100) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + }, + expectedProps: []servicelogs.Property{ + {Name: "RequestId", Value: "invoke-id"}, + }, + expectedDims: []servicelogs.Dimension{ + {Name: "RequestMode", Value: "Streaming"}, + {Name: "ResponseMode", Value: "streaming"}, + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 5000000}, + {Type: servicelogs.CounterType, Key: "InflightRequestCount", Value: 2}, + {Type: servicelogs.CounterType, Key: "IdleRuntimesCount", Value: 1}, + {Type: servicelogs.CounterType, Key: "ResponsePayloadSizeBytes", Value: 100}, + {Type: servicelogs.TimerType, Key: "ResponseThrottledDuration", Value: 0}, + {Type: servicelogs.CounterType, Key: "ResponseThroughput", Value: 0}, + {Type: servicelogs.TimerType, Key: "ResponsePayloadReadDuration", Value: 0}, + {Type: servicelogs.TimerType, Key: "ResponsePayloadWriteDuration", Value: 0}, + {Type: servicelogs.CounterType, Key: "ErrorPayloadSizeBytes", Value: 100}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 2000000}, + {Type: servicelogs.TimerType, Key: "FunctionDuration", Value: 3000000}, + {Type: servicelogs.TimerType, Key: "RequestSendDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "RequestPayloadSizeBytes", Value: 100}, + {Type: servicelogs.TimerType, Key: "RequestPayloadReadDuration", Value: 11}, + {Type: servicelogs.TimerType, Key: "RequestPayloadWriteDuration", Value: 12}, + {Type: servicelogs.TimerType, Key: "ResponseLatency", Value: 1000000}, + {Type: servicelogs.TimerType, Key: "ResponseDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "ClientError", Value: 0}, + {Type: servicelogs.CounterType, Key: "CustomerError", Value: 1}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 0}, + {Type: servicelogs.CounterType, Key: "CustomerErrorReason-Runtime.Unknown", Value: 1}, + {Type: servicelogs.CounterType, Key: "NonCustomerError", Value: 0}, + }, + }, + } + + tupleSortFunc := func(a, b servicelogs.Tuple) int { + return cmp.Compare(a.Name, b.Name) + } + + metricsSortFunc := func(a, b servicelogs.Metric) int { + return cmp.Compare(a.Key, b.Key) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mocks := &invokeMetricsMocks{ + eventsApi: interop.MockEventsAPI{}, + initData: interop.MockInitStaticDataProvider{}, + invokeReq: interop.MockInvokeRequest{}, + logger: servicelogs.MockLogger{}, + counter: MockCounter{}, + } + mocks.counter.On("AddInvoke", tt.expectedBytes) + + ev := NewInvokeMetrics(&mocks.logger, &mocks.counter) + ev.getCurrentTime = func() time.Time { + return mocks.timeStamp + } + + mocks.responseMetrics = interop.InvokeResponseMetrics{ + TimeShaped: time.Second, + ProducedBytes: 100, + OutboundThroughputBps: 100, + FunctionResponseMode: "Streaming", + ResponsePayloadReadDuration: 13 * time.Microsecond, + ResponsePayloadWriteDuration: 14 * time.Microsecond, + } + + mocks.invokeReq.On("InvokeID").Return("invoke-id").Maybe() + mocks.invokeReq.On("ResponseMode").Return("Streaming").Maybe().Maybe() + mocks.invokeReq.On("TraceId").Return("Root=12345;Parent=67890;Sampled=1;Lineage=22222").Maybe() + + mocks.initData.On("XRayTracingMode").Return(intmodel.XRayTracingModePassThrough).Maybe() + mocks.initData.On("MemorySizeMB").Return(uint64(128)).Maybe() + mocks.initData.On("FunctionARN").Return("function-arn").Maybe() + mocks.initData.On("FunctionVersionID").Return("function-version-id").Maybe() + mocks.initData.On("FunctionTimeout").Return(time.Second).Maybe() + mocks.initData.On("RuntimeVersion").Return("python3.9").Maybe() + mocks.initData.On("ArtefactType").Return(intmodel.ArtefactTypeOCI).Maybe() + mocks.initData.On("AmiId").Return("ami-1234567").Maybe() + mocks.initData.On("AvailabilityZoneId").Return("us-west-2").Maybe() + + mocks.logger.On("Log", + mock.MatchedBy(func(op servicelogs.Operation) bool { + return assert.Equal(t, servicelogs.InvokeOp, op) + }), + mock.AnythingOfType("time.Time"), + mock.MatchedBy(func(props []servicelogs.Property) bool { + slices.SortFunc(props, tupleSortFunc) + slices.SortFunc(tt.expectedProps, tupleSortFunc) + assert.Equal(t, len(tt.expectedProps), len(props)) + for i := range len(tt.expectedProps) { + assert.Equal(t, tt.expectedProps[i].Name, props[i].Name) + assert.Equal(t, tt.expectedProps[i].Value, props[i].Value) + } + + return true + }), + mock.MatchedBy(func(dims []servicelogs.Dimension) bool { + slices.SortFunc(dims, tupleSortFunc) + slices.SortFunc(tt.expectedDims, tupleSortFunc) + assert.Equal(t, len(tt.expectedDims), len(dims)) + for i := range len(tt.expectedDims) { + assert.Equal(t, tt.expectedDims[i].Name, dims[i].Name) + assert.Equal(t, tt.expectedDims[i].Value, dims[i].Value) + } + + return true + }), + mock.MatchedBy(func(metrics []servicelogs.Metric) bool { + slices.SortFunc(metrics, metricsSortFunc) + slices.SortFunc(tt.expectedMetrics, metricsSortFunc) + assert.Equal(t, len(tt.expectedMetrics), len(metrics)) + for i := range len(tt.expectedMetrics) { + require.Equal(t, tt.expectedMetrics[i].Key, metrics[i].Key) + require.Equal(t, tt.expectedMetrics[i].Type, metrics[i].Type, fmt.Sprintf("wrong format for %s", metrics[i].Key)) + require.Equal(t, tt.expectedMetrics[i].Value, metrics[i].Value, fmt.Sprintf("wrong value for %s", metrics[i].Key)) + } + + return true + }), + ).Once() + + mocks.timeStamp = time.Now() + ev.TriggerGetRequest() + + tt.metricFlow(ev, mocks) + + ev.TriggerInvokeDone() + require.NoError(t, ev.SendMetrics(mocks.error)) + checkMocksExpectations(t, mocks) + }) + } +} diff --git a/internal/lambda-managed-instances/invoke/mock_counter.go b/internal/lambda-managed-instances/invoke/mock_counter.go new file mode 100644 index 0000000..5500d4e --- /dev/null +++ b/internal/lambda-managed-instances/invoke/mock_counter.go @@ -0,0 +1,26 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import mock "github.com/stretchr/testify/mock" + +type MockCounter struct { + mock.Mock +} + +func (_m *MockCounter) AddInvoke(proxiedBytes uint64) { + _m.Called(proxiedBytes) +} + +func NewMockCounter(t interface { + mock.TestingT + Cleanup(func()) +}) *MockCounter { + mock := &MockCounter{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/invoke/mock_error_for_invoker.go b/internal/lambda-managed-instances/invoke/mock_error_for_invoker.go new file mode 100644 index 0000000..1b5b3f7 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/mock_error_for_invoker.go @@ -0,0 +1,93 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + mock "github.com/stretchr/testify/mock" + model "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type MockErrorForInvoker struct { + mock.Mock +} + +func (_m *MockErrorForInvoker) ErrorCategory() model.ErrorCategory { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ErrorCategory") + } + + var r0 model.ErrorCategory + if rf, ok := ret.Get(0).(func() model.ErrorCategory); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(model.ErrorCategory) + } + + return r0 +} + +func (_m *MockErrorForInvoker) ErrorDetails() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ErrorDetails") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockErrorForInvoker) ErrorType() model.ErrorType { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ErrorType") + } + + var r0 model.ErrorType + if rf, ok := ret.Get(0).(func() model.ErrorType); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(model.ErrorType) + } + + return r0 +} + +func (_m *MockErrorForInvoker) ReturnCode() int { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ReturnCode") + } + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +func NewMockErrorForInvoker(t interface { + mock.TestingT + Cleanup(func()) +}) *MockErrorForInvoker { + mock := &MockErrorForInvoker{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/invoke/mock_invoke_response_sender.go b/internal/lambda-managed-instances/invoke/mock_invoke_response_sender.go new file mode 100644 index 0000000..c2974ae --- /dev/null +++ b/internal/lambda-managed-instances/invoke/mock_invoke_response_sender.go @@ -0,0 +1,79 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + interop "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + + time "time" +) + +type MockInvokeResponseSender struct { + mock.Mock +} + +func (_m *MockInvokeResponseSender) ErrorPayloadSizeBytes() int { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ErrorPayloadSizeBytes") + } + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +func (_m *MockInvokeResponseSender) SendError(_a0 ErrorForInvoker, _a1 interop.InitStaticDataProvider) { + _m.Called(_a0, _a1) +} + +func (_m *MockInvokeResponseSender) SendErrorTrailers(_a0 ErrorForInvoker, _a1 InvokeBodyResponseStatus) { + _m.Called(_a0, _a1) +} + +func (_m *MockInvokeResponseSender) SendRuntimeResponseBody(ctx context.Context, runtimeResp RuntimeResponseRequest, functionTimeout time.Duration) SendResponseBodyResult { + ret := _m.Called(ctx, runtimeResp, functionTimeout) + + if len(ret) == 0 { + panic("no return value specified for SendRuntimeResponseBody") + } + + var r0 SendResponseBodyResult + if rf, ok := ret.Get(0).(func(context.Context, RuntimeResponseRequest, time.Duration) SendResponseBodyResult); ok { + r0 = rf(ctx, runtimeResp, functionTimeout) + } else { + r0 = ret.Get(0).(SendResponseBodyResult) + } + + return r0 +} + +func (_m *MockInvokeResponseSender) SendRuntimeResponseHeaders(initData interop.InitStaticDataProvider, contentType string, responseMode string) { + _m.Called(initData, contentType, responseMode) +} + +func (_m *MockInvokeResponseSender) SendRuntimeResponseTrailers(_a0 RuntimeResponseRequest) { + _m.Called(_a0) +} + +func NewMockInvokeResponseSender(t interface { + mock.TestingT + Cleanup(func()) +}) *MockInvokeResponseSender { + mock := &MockInvokeResponseSender{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/invoke/mock_responder_factory_func.go b/internal/lambda-managed-instances/invoke/mock_responder_factory_func.go new file mode 100644 index 0000000..a1d020d --- /dev/null +++ b/internal/lambda-managed-instances/invoke/mock_responder_factory_func.go @@ -0,0 +1,46 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + interop "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" +) + +type MockResponderFactoryFunc struct { + mock.Mock +} + +func (_m *MockResponderFactoryFunc) Execute(_a0 context.Context, _a1 interop.InvokeRequest) InvokeResponseSender { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Execute") + } + + var r0 InvokeResponseSender + if rf, ok := ret.Get(0).(func(context.Context, interop.InvokeRequest) InvokeResponseSender); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(InvokeResponseSender) + } + } + + return r0 +} + +func NewMockResponderFactoryFunc(t interface { + mock.TestingT + Cleanup(func()) +}) *MockResponderFactoryFunc { + mock := &MockResponderFactoryFunc{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/invoke/mock_running_invoke.go b/internal/lambda-managed-instances/invoke/mock_running_invoke.go new file mode 100644 index 0000000..8b40260 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/mock_running_invoke.go @@ -0,0 +1,109 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + interop "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + + model "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type mockRunningInvoke struct { + mock.Mock +} + +func (_m *mockRunningInvoke) CancelAsync(_a0 model.AppError) { + _m.Called(_a0) +} + +func (_m *mockRunningInvoke) RunInvokeAndSendResult(_a0 context.Context, _a1 interop.InitStaticDataProvider, _a2 interop.InvokeRequest, _a3 interop.InvokeMetrics) model.AppError { + ret := _m.Called(_a0, _a1, _a2, _a3) + + if len(ret) == 0 { + panic("no return value specified for RunInvokeAndSendResult") + } + + var r0 model.AppError + if rf, ok := ret.Get(0).(func(context.Context, interop.InitStaticDataProvider, interop.InvokeRequest, interop.InvokeMetrics) model.AppError); ok { + r0 = rf(_a0, _a1, _a2, _a3) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(model.AppError) + } + } + + return r0 +} + +func (_m *mockRunningInvoke) RuntimeError(_a0 context.Context, _a1 RuntimeErrorRequest) model.AppError { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for RuntimeError") + } + + var r0 model.AppError + if rf, ok := ret.Get(0).(func(context.Context, RuntimeErrorRequest) model.AppError); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(model.AppError) + } + } + + return r0 +} + +func (_m *mockRunningInvoke) RuntimeNextWait(_a0 context.Context) model.AppError { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for RuntimeNextWait") + } + + var r0 model.AppError + if rf, ok := ret.Get(0).(func(context.Context) model.AppError); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(model.AppError) + } + } + + return r0 +} + +func (_m *mockRunningInvoke) RuntimeResponse(_a0 context.Context, _a1 RuntimeResponseRequest) model.AppError { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for RuntimeResponse") + } + + var r0 model.AppError + if rf, ok := ret.Get(0).(func(context.Context, RuntimeResponseRequest) model.AppError); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(model.AppError) + } + } + + return r0 +} + +func newMockRunningInvoke(t interface { + mock.TestingT + Cleanup(func()) +}) *mockRunningInvoke { + mock := &mockRunningInvoke{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/invoke/mock_runtime_error_request.go b/internal/lambda-managed-instances/invoke/mock_runtime_error_request.go new file mode 100644 index 0000000..09d09cf --- /dev/null +++ b/internal/lambda-managed-instances/invoke/mock_runtime_error_request.go @@ -0,0 +1,184 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + json "encoding/json" + + mock "github.com/stretchr/testify/mock" + model "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type MockRuntimeErrorRequest struct { + mock.Mock +} + +func (_m *MockRuntimeErrorRequest) ContentType() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ContentType") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockRuntimeErrorRequest) ErrorCategory() model.ErrorCategory { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ErrorCategory") + } + + var r0 model.ErrorCategory + if rf, ok := ret.Get(0).(func() model.ErrorCategory); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(model.ErrorCategory) + } + + return r0 +} + +func (_m *MockRuntimeErrorRequest) ErrorDetails() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ErrorDetails") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockRuntimeErrorRequest) ErrorType() model.ErrorType { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ErrorType") + } + + var r0 model.ErrorType + if rf, ok := ret.Get(0).(func() model.ErrorType); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(model.ErrorType) + } + + return r0 +} + +func (_m *MockRuntimeErrorRequest) GetError() model.AppError { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetError") + } + + var r0 model.AppError + if rf, ok := ret.Get(0).(func() model.AppError); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(model.AppError) + } + } + + return r0 +} + +func (_m *MockRuntimeErrorRequest) GetXrayErrorCause() json.RawMessage { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetXrayErrorCause") + } + + var r0 json.RawMessage + if rf, ok := ret.Get(0).(func() json.RawMessage); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(json.RawMessage) + } + } + + return r0 +} + +func (_m *MockRuntimeErrorRequest) InvokeID() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for InvokeID") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockRuntimeErrorRequest) IsRuntimeError(_a0 model.AppError) bool { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for IsRuntimeError") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(model.AppError) bool); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +func (_m *MockRuntimeErrorRequest) ReturnCode() int { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ReturnCode") + } + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +func NewMockRuntimeErrorRequest(t interface { + mock.TestingT + Cleanup(func()) +}) *MockRuntimeErrorRequest { + mock := &MockRuntimeErrorRequest{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/invoke/mock_runtime_response_request.go b/internal/lambda-managed-instances/invoke/mock_runtime_response_request.go new file mode 100644 index 0000000..5da7b50 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/mock_runtime_response_request.go @@ -0,0 +1,135 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + io "io" + + mock "github.com/stretchr/testify/mock" + model "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type MockRuntimeResponseRequest struct { + mock.Mock +} + +func (_m *MockRuntimeResponseRequest) BodyReader() io.Reader { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for BodyReader") + } + + var r0 io.Reader + if rf, ok := ret.Get(0).(func() io.Reader); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.Reader) + } + } + + return r0 +} + +func (_m *MockRuntimeResponseRequest) ContentType() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ContentType") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockRuntimeResponseRequest) InvokeID() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for InvokeID") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockRuntimeResponseRequest) ParsingError() model.AppError { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ParsingError") + } + + var r0 model.AppError + if rf, ok := ret.Get(0).(func() model.AppError); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(model.AppError) + } + } + + return r0 +} + +func (_m *MockRuntimeResponseRequest) ResponseMode() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ResponseMode") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockRuntimeResponseRequest) TrailerError() ErrorForInvoker { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for TrailerError") + } + + var r0 ErrorForInvoker + if rf, ok := ret.Get(0).(func() ErrorForInvoker); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(ErrorForInvoker) + } + } + + return r0 +} + +func NewMockRuntimeResponseRequest(t interface { + mock.TestingT + Cleanup(func()) +}) *MockRuntimeResponseRequest { + mock := &MockRuntimeResponseRequest{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/invoke/mock_timeout_cache.go b/internal/lambda-managed-instances/invoke/mock_timeout_cache.go new file mode 100644 index 0000000..bdae680 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/mock_timeout_cache.go @@ -0,0 +1,43 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import mock "github.com/stretchr/testify/mock" + +type mockTimeoutCache struct { + mock.Mock +} + +func (_m *mockTimeoutCache) Consume(invokeID string) bool { + ret := _m.Called(invokeID) + + if len(ret) == 0 { + panic("no return value specified for Consume") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(invokeID) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +func (_m *mockTimeoutCache) Register(invokeID string) { + _m.Called(invokeID) +} + +func newMockTimeoutCache(t interface { + mock.TestingT + Cleanup(func()) +}) *mockTimeoutCache { + mock := &mockTimeoutCache{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/invoke/running_invoke.go b/internal/lambda-managed-instances/invoke/running_invoke.go new file mode 100644 index 0000000..feec349 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/running_invoke.go @@ -0,0 +1,329 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "sync/atomic" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + intmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry/xray" +) + +const ( + stateNoResponse uint32 = 0 + stateGotResponse uint32 = 1 + stateGotError uint32 = 2 +) + +type ErrorForInvoker interface { + ReturnCode() int + ErrorCategory() model.ErrorCategory + ErrorType() model.ErrorType + ErrorDetails() string +} + +type InvokeResponseSender interface { + SendRuntimeResponseHeaders(initData interop.InitStaticDataProvider, contentType, responseMode string) + + SendRuntimeResponseBody(ctx context.Context, runtimeResp RuntimeResponseRequest, functionTimeout time.Duration) SendResponseBodyResult + + SendRuntimeResponseTrailers(RuntimeResponseRequest) + + SendError(ErrorForInvoker, interop.InitStaticDataProvider) + + SendErrorTrailers(ErrorForInvoker, InvokeBodyResponseStatus) + + ErrorPayloadSizeBytes() int +} + +type ResponderFactoryFunc func(context.Context, interop.InvokeRequest) InvokeResponseSender + +type SendResponseBodyResult struct { + Metrics interop.InvokeResponseMetrics + Err model.AppError +} + +type runningInvokeImpl struct { + timeoutCache timeoutCache + + cancelAsyncCtx context.Context + cancelAsyncCtxCancel context.CancelCauseFunc + + responseState atomic.Uint32 + invokeSentChan chan struct{} + responseSentChan chan model.AppError + errorSentChan chan model.AppError + runtimeResponseChan chan RuntimeResponseRequest + runtimeErrorChan chan RuntimeErrorRequest + + invokeRespSender InvokeResponseSender + runtimeNext http.ResponseWriter + + responderFactoryFunc ResponderFactoryFunc + sendInvokeToRuntime func(context.Context, interop.InitStaticDataProvider, interop.InvokeRequest, http.ResponseWriter, string) (int64, time.Duration, time.Duration, model.AppError) + createTracingData func(traceId string, tracingMode intmodel.XrayTracingMode, segmentIDGenerator func() string) (downstreamTraceId string, tracingCtx *interop.TracingCtx) +} + +func newRunningInvoke( + runtimeNext http.ResponseWriter, + responderFactoryFunc ResponderFactoryFunc, + timeoutCache timeoutCache, +) runningInvokeImpl { + ctx, cancel := context.WithCancelCause(context.Background()) + return runningInvokeImpl{ + timeoutCache: timeoutCache, + cancelAsyncCtx: ctx, + cancelAsyncCtxCancel: cancel, + invokeSentChan: make(chan struct{}), + responseSentChan: make(chan model.AppError, 1), + errorSentChan: make(chan model.AppError, 1), + runtimeResponseChan: make(chan RuntimeResponseRequest, 1), + runtimeErrorChan: make(chan RuntimeErrorRequest, 1), + runtimeNext: runtimeNext, + + responderFactoryFunc: responderFactoryFunc, + sendInvokeToRuntime: sendInvokeToRuntime, + createTracingData: xray.CreateTracingData, + } +} + +func (r *runningInvokeImpl) RunInvokeAndSendResult(ctx context.Context, initData interop.InitStaticDataProvider, invokeReq interop.InvokeRequest, metrics interop.InvokeMetrics) model.AppError { + downstreamTraceId, tracingCtx := r.createTracingData(invokeReq.TraceId(), initData.XRayTracingMode(), xray.GenerateSegmentID) + + metrics.TriggerStartRequest() + if err := metrics.SendInvokeStartEvent(tracingCtx); err != nil { + logging.Error(ctx, "Failed to send InvokeStartEvent", "err", err) + } + + r.invokeRespSender = r.responderFactoryFunc(ctx, invokeReq) + + ctx, cancel := r.getInvokeCtx(ctx, initData.FunctionTimeout()) + defer cancel() + + logging.Debug(ctx, "Sending Invoke to Runtime") + written, requestPayloadReadDuration, requestPayloadWriteDuration, err := r.sendInvokeToRuntime(ctx, initData, invokeReq, r.runtimeNext, downstreamTraceId) + if err != nil { + logging.Err(ctx, "Failed to send Invoke to Runtime", err) + + r.cancelAsyncCtxCancel(err) + if err.ErrorType() == model.ErrorSandboxTimedout { + + r.timeoutCache.Register(invokeReq.InvokeID()) + } + r.invokeRespSender.SendError(err, initData) + metrics.TriggerSentResponse(false, err, nil, r.invokeRespSender.ErrorPayloadSizeBytes()) + if err := metrics.SendInvokeFinishedEvent(tracingCtx, nil); err != nil { + logging.Error(ctx, "Failed to send InvokeFinishedEvent", "err", err) + } + return err + } + + logging.Debug(ctx, "Waiting for Runtime response") + metrics.TriggerSentRequest(written, requestPayloadReadDuration, requestPayloadWriteDuration) + close(r.invokeSentChan) + + runtimeAnswerSent := false + var invokeResponseMetrics *interop.InvokeResponseMetrics + var xrayErrorCause json.RawMessage + select { + case runtimeResp := <-r.runtimeResponseChan: + metrics.TriggerGetResponse() + logging.Debug(ctx, "Received Runtime response") + if err = runtimeResp.ParsingError(); err != nil { + r.responseState.Store(stateGotError) + r.invokeRespSender.SendError(err, initData) + break + } + runtimeAnswerSent, invokeResponseMetrics, err, xrayErrorCause = r.sendRuntimeResponse(ctx, initData, runtimeResp) + case runtimeErr := <-r.runtimeErrorChan: + metrics.TriggerGetResponse() + err = runtimeErr.GetError() + + logging.Warn(ctx, "Received Runtime error", "err", err) + runtimeAnswerSent = r.sendRuntimeError(initData, runtimeErr) + xrayErrorCause = runtimeErr.GetXrayErrorCause() + case <-ctx.Done(): + + r.responseState.Store(stateGotError) + err = BuildInvokeAppError(context.Cause(ctx), initData.FunctionTimeout()) + logging.Info(ctx, "Received ctx cancellation", "err", err) + + if err.ErrorType() == model.ErrorSandboxTimedout { + + r.timeoutCache.Register(invokeReq.InvokeID()) + } + + r.invokeRespSender.SendError(err, initData) + } + + metrics.TriggerSentResponse(runtimeAnswerSent, err, invokeResponseMetrics, r.invokeRespSender.ErrorPayloadSizeBytes()) + if err := metrics.SendInvokeFinishedEvent(tracingCtx, xrayErrorCause); err != nil { + logging.Error(ctx, "Failed to send InvokeFinishedEvent", "err", err) + } + + logging.Debug(ctx, "Notify response and error channels we sent", "sent_err", err) + r.responseSentChan <- err + r.errorSentChan <- err + return err +} + +func (r *runningInvokeImpl) getInvokeCtx(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + ctx, invokeCtxCancel1 := context.WithCancelCause(ctx) + go func() { + + <-r.cancelAsyncCtx.Done() + invokeCtxCancel1(context.Cause(r.cancelAsyncCtx)) + }() + + ctx, invokeCtxCancel2 := context.WithTimeout(ctx, timeout) + + return ctx, func() { + + invokeCtxCancel2() + invokeCtxCancel1(nil) + r.cancelAsyncCtxCancel(nil) + } +} + +func (r *runningInvokeImpl) sendRuntimeResponse(ctx context.Context, initData interop.InitStaticDataProvider, runtimeResp RuntimeResponseRequest) (bool, *interop.InvokeResponseMetrics, model.AppError, json.RawMessage) { + logging.Debug(ctx, "Sending Runtime response headers") + r.invokeRespSender.SendRuntimeResponseHeaders(initData, runtimeResp.ContentType(), runtimeResp.ResponseMode()) + + childCtx, cancelFunc := context.WithCancel(ctx) + defer cancelFunc() + + sendBodyRes := make(chan SendResponseBodyResult) + logging.Debug(ctx, "Sending Runtime response body") + go func() { + sendBodyRes <- r.invokeRespSender.SendRuntimeResponseBody(childCtx, runtimeResp, initData.FunctionTimeout()) + }() + + select { + + case res := <-sendBodyRes: + if res.Err != nil { + logging.Err(ctx, "Failed sending body", res.Err) + r.invokeRespSender.SendErrorTrailers(res.Err, buildEndOfResponse(res.Err)) + return false, &res.Metrics, res.Err, nil + } + + trailerErr := runtimeResp.TrailerError() + if trailerErr != nil { + logging.Warn(ctx, "Runtime sent error trailers after response body", "err", trailerErr) + r.invokeRespSender.SendErrorTrailers(trailerErr, InvokeBodyResponseComplete) + + return true, &res.Metrics, model.NewCustomerError(trailerErr.ErrorType()), nil + } + + logging.Debug(ctx, "Sending response trailers") + r.invokeRespSender.SendRuntimeResponseTrailers(runtimeResp) + return true, &res.Metrics, nil, nil + case runtimeErr := <-r.runtimeErrorChan: + logging.Debug(ctx, "Received Runtime Error while sending response body") + cancelFunc() + res := <-sendBodyRes + + logging.Debug(ctx, "Sending error trailers") + r.invokeRespSender.SendErrorTrailers(runtimeErr, InvokeBodyResponseTruncated) + return true, &res.Metrics, runtimeErr.GetError(), runtimeErr.GetXrayErrorCause() + } +} + +func buildEndOfResponse(err model.AppError) InvokeBodyResponseStatus { + if err == nil { + return InvokeBodyResponseComplete + } + + switch err.ErrorType() { + case model.ErrorSandboxTimedout: + return invokeBodyResponseTimeout + case model.ErrorRuntimeTruncatedResponse: + return InvokeBodyResponseTruncated + case model.ErrorFunctionOversizedResponse: + return invokeBodyResponseOversized + default: + slog.Error("Unrecognized error", "err", err) + + return InvokeBodyResponseTruncated + } +} + +func (r *runningInvokeImpl) sendRuntimeError(initData interop.InitStaticDataProvider, runtimeErr RuntimeErrorRequest) bool { + slog.Debug("Sending Runtime error headers and trailers", "err", runtimeErr.GetError()) + r.invokeRespSender.SendRuntimeResponseHeaders(initData, "", "") + r.invokeRespSender.SendErrorTrailers(runtimeErr, InvokeBodyResponseComplete) + + return true +} + +func (r *runningInvokeImpl) RuntimeNextWait(ctx context.Context) model.AppError { + select { + case <-r.invokeSentChan: + return nil + case <-ctx.Done(): + err := context.Cause(ctx) + logging.Warn(ctx, "Runtime Context expired", "err", err) + + return model.NewCustomerError(model.ErrorRuntimeUnknown, model.WithCause(err)) + case <-r.cancelAsyncCtx.Done(): + err := context.Cause(r.cancelAsyncCtx) + logging.Warn(r.cancelAsyncCtx, "Aborting blocking /next's", "err", err) + + return model.NewCustomerError(model.ErrorRuntimeUnknown, model.WithCause(err)) + } +} + +func (r *runningInvokeImpl) RuntimeResponse(ctx context.Context, runtimeRespReq RuntimeResponseRequest) model.AppError { + logging.Debug(ctx, "Recevied runtime response") + if !r.responseState.CompareAndSwap(stateNoResponse, stateGotResponse) { + logging.Warn(ctx, "Invalid Invoke state : Response in progress") + + return model.NewCustomerError(model.ErrorRuntimeInvokeResponseInProgress) + } + + r.runtimeResponseChan <- runtimeRespReq + return <-r.responseSentChan +} + +func (r *runningInvokeImpl) RuntimeError(ctx context.Context, runtimeErrReq RuntimeErrorRequest) model.AppError { + oldState := r.responseState.Swap(stateGotError) + if oldState == stateGotError { + logging.Warn(ctx, "Invalid invoke state : error in progress") + return model.NewCustomerError(model.ErrorRuntimeInvokeErrorInProgress) + } + + r.runtimeErrorChan <- runtimeErrReq + + err := <-r.errorSentChan + + ctx = logging.WithFields(ctx, "err", err) + + logging.Debug(ctx, "Received Runtime error") + + if err == nil { + logging.Warn(ctx, "Sent Runtime response instead of error") + + return model.NewCustomerError(model.ErrorRuntimeInvokeResponseWasSent) + } else if runtimeErrReq.IsRuntimeError(err) { + + logging.Debug(ctx, "Error was successfully sent") + return nil + } + + logging.Warn(ctx, "Sent another error", "err", err) + return err +} + +func (r *runningInvokeImpl) CancelAsync(err model.AppError) { + slog.Info("cancel invoke", "reason", err) + r.cancelAsyncCtxCancel(err) +} diff --git a/internal/lambda-managed-instances/invoke/running_invoke_test.go b/internal/lambda-managed-instances/invoke/running_invoke_test.go new file mode 100644 index 0000000..951f650 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/running_invoke_test.go @@ -0,0 +1,478 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "context" + "encoding/json" + "net/http" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + intmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + rapimodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type runningInvokeMocks struct { + ctx context.Context + staticData interop.MockInitStaticDataProvider + eaInvokeRequest interop.MockInvokeRequest + eaInvokeResponder MockInvokeResponseSender + runtimeRespReq MockRuntimeResponseRequest + runtimeErrorReq MockRuntimeErrorRequest + timeoutCache *mockTimeoutCache + metrics interop.MockInvokeMetrics + errorPayloadSizeBytes int + + runtimeNextRequest http.ResponseWriter +} + +func newRunningInvokeMocks(t *testing.T) runningInvokeMocks { + return runningInvokeMocks{ + ctx: context.TODO(), + staticData: interop.MockInitStaticDataProvider{}, + eaInvokeRequest: interop.MockInvokeRequest{}, + runtimeRespReq: MockRuntimeResponseRequest{}, + runtimeErrorReq: MockRuntimeErrorRequest{}, + timeoutCache: newMockTimeoutCache(t), + metrics: interop.MockInvokeMetrics{}, + errorPayloadSizeBytes: 0, + } +} + +func hijackRunningInvokeDeps(ri *runningInvokeImpl, mocks *runningInvokeMocks) { + ri.sendInvokeToRuntime = func(context.Context, interop.InitStaticDataProvider, interop.InvokeRequest, http.ResponseWriter, string) (int64, time.Duration, time.Duration, model.AppError) { + return 0, 0, 0, nil + } + + ri.createTracingData = func(string, intmodel.XrayTracingMode, func() string) (string, *interop.TracingCtx) { + traceId := "Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535" + tracingCtx := &interop.TracingCtx{ + SpanID: "", + Type: rapimodel.XRayTracingType, + Value: traceId, + } + return traceId, tracingCtx + } +} + +func createMocksAndInitRunningInvoke(t *testing.T) (*runningInvokeMocks, *runningInvokeImpl) { + mocks := newRunningInvokeMocks(t) + + ri := newRunningInvoke( + mocks.runtimeNextRequest, + func(ctx context.Context, ir interop.InvokeRequest) InvokeResponseSender { + return &mocks.eaInvokeResponder + }, + mocks.timeoutCache, + ) + hijackRunningInvokeDeps(&ri, &mocks) + + return &mocks, &ri +} + +func mockMetrics(mocks *runningInvokeMocks, invokeRes interface{}) { + mockMetricsBeforeResponse(mocks) + mocks.metrics.On("TriggerGetResponse").Return() + mocks.metrics.On("TriggerSentResponse", true, invokeRes, mock.Anything, mocks.errorPayloadSizeBytes).Return() +} + +func mockMetricsAnswerNotFinished(mocks *runningInvokeMocks, invokeRes interface{}) { + mockMetricsBeforeResponse(mocks) + mocks.metrics.On("TriggerGetResponse").Return() + mocks.metrics.On("TriggerSentResponse", false, invokeRes, mock.Anything, mocks.errorPayloadSizeBytes).Return() +} + +func mockMetricsUnfinished(mocks *runningInvokeMocks, invokeRes interface{}) { + mockMetricsBeforeResponse(mocks) + mocks.metrics.On("TriggerSentResponse", false, invokeRes, mock.Anything, mocks.errorPayloadSizeBytes).Return() +} + +func mockMetricsBeforeResponse(mocks *runningInvokeMocks) { + mocks.metrics.On("TriggerStartRequest") + mocks.metrics.On("SendInvokeStartEvent", mock.AnythingOfType("*interop.TracingCtx")).Return(nil) + mocks.metrics.On("TriggerSentRequest", mock.AnythingOfType("int64"), mock.AnythingOfType("time.Duration"), mock.AnythingOfType("time.Duration")).Return() + mocks.metrics.On("SendInvokeFinishedEvent", mock.AnythingOfType("*interop.TracingCtx"), mock.AnythingOfType("json.RawMessage")).Return(nil) +} + +func checkRunningInvokeMockExpectations(t *testing.T, mocks *runningInvokeMocks) { + mocks.staticData.AssertExpectations(t) + mocks.eaInvokeRequest.AssertExpectations(t) + mocks.eaInvokeResponder.AssertExpectations(t) + mocks.runtimeRespReq.AssertExpectations(t) + mocks.runtimeErrorReq.AssertExpectations(t) + mocks.metrics.AssertExpectations(t) +} + +func TestRunInvokeAndSendResultSuccess_RuntimeResponse(t *testing.T) { + t.Parallel() + + mocks, runInvoke := createMocksAndInitRunningInvoke(t) + + mocks.staticData.On("FunctionTimeout").Return(time.Second) + mocks.staticData.On("XRayTracingMode").Return(intmodel.XRayTracingModePassThrough) + + mocks.eaInvokeRequest.On("TraceId").Return("Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535") + + mocks.runtimeRespReq.On("ParsingError").Return(nil) + mocks.runtimeRespReq.On("ContentType").Return("") + mocks.runtimeRespReq.On("ResponseMode").Return("") + mocks.runtimeRespReq.On("TrailerError").Return(nil) + + mocks.eaInvokeResponder.On("SendRuntimeResponseHeaders", &mocks.staticData, mock.Anything, mock.Anything).Return() + mocks.eaInvokeResponder.On("SendRuntimeResponseBody", mock.Anything, &mocks.runtimeRespReq, mock.Anything).Return(SendResponseBodyResult{}) + mocks.eaInvokeResponder.On("SendRuntimeResponseTrailers", &mocks.runtimeRespReq).Return() + mocks.eaInvokeResponder.On("ErrorPayloadSizeBytes").Return(mocks.errorPayloadSizeBytes) + mockMetrics(mocks, nil) + + wg := new(sync.WaitGroup) + wg.Add(1) + + go func() { + defer wg.Done() + err := runInvoke.RuntimeResponse(mocks.ctx, &mocks.runtimeRespReq) + assert.NoError(t, err) + }() + + err := runInvoke.RunInvokeAndSendResult(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.metrics) + assert.NoError(t, err) + + wg.Wait() + checkRunningInvokeMockExpectations(t, mocks) +} + +func TestRunInvokeAndSendResultSuccess_RuntimeError(t *testing.T) { + t.Parallel() + + mocks, runInvoke := createMocksAndInitRunningInvoke(t) + err := model.NewCustomerError(model.ErrorFunctionUnknown) + + mocks.staticData.On("FunctionTimeout").Return(time.Second) + mocks.staticData.On("XRayTracingMode").Return(intmodel.XRayTracingModePassThrough) + + mocks.eaInvokeRequest.On("TraceId").Return("Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535") + mocks.errorPayloadSizeBytes = 100 + + mocks.runtimeErrorReq.On("GetError").Return(err) + mocks.runtimeErrorReq.On("GetXrayErrorCause").Return(json.RawMessage(nil)) + mocks.eaInvokeResponder.On("SendRuntimeResponseHeaders", &mocks.staticData, mock.Anything, mock.Anything).Return().Once() + mocks.eaInvokeResponder.On("SendErrorTrailers", mock.Anything, InvokeBodyResponseComplete).Return().Once() + mocks.eaInvokeResponder.On("ErrorPayloadSizeBytes").Return(mocks.errorPayloadSizeBytes) + + mocks.runtimeErrorReq.On("IsRuntimeError", err).Return(true) + mockMetrics(mocks, err) + + wg := new(sync.WaitGroup) + wg.Add(1) + + go func() { + defer wg.Done() + err := runInvoke.RuntimeError(mocks.ctx, &mocks.runtimeErrorReq) + assert.NoError(t, err) + }() + + invokeErr := runInvoke.RunInvokeAndSendResult(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.metrics) + assert.Error(t, invokeErr) + + wg.Wait() + checkRunningInvokeMockExpectations(t, mocks) +} + +func TestRunInvokeAndSendResultSuccess_RuntimeTrailerError(t *testing.T) { + t.Parallel() + + mocks, runInvoke := createMocksAndInitRunningInvoke(t) + + trailerErrorType := model.ErrorType("Function.Unknown") + expectedTrailerErr := model.NewCustomerError(trailerErrorType) + + mocks.staticData.On("FunctionTimeout").Return(time.Second) + mocks.staticData.On("XRayTracingMode").Return(intmodel.XRayTracingModePassThrough) + + mocks.eaInvokeRequest.On("TraceId").Return("Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535") + + mocks.runtimeRespReq.On("ParsingError").Return(nil) + mocks.runtimeRespReq.On("ContentType").Return("") + mocks.runtimeRespReq.On("ResponseMode").Return("") + mocks.runtimeRespReq.On("TrailerError").Return(expectedTrailerErr, NewMockErrorForInvoker(t)) + + mocks.eaInvokeResponder.On("SendRuntimeResponseHeaders", &mocks.staticData, mock.Anything, mock.Anything).Return() + mocks.eaInvokeResponder.On("SendRuntimeResponseBody", mock.Anything, &mocks.runtimeRespReq, mock.Anything).Return(SendResponseBodyResult{}) + mocks.eaInvokeResponder.On("SendErrorTrailers", expectedTrailerErr, InvokeBodyResponseComplete).Return().Once() + mocks.eaInvokeResponder.On("ErrorPayloadSizeBytes").Return(mocks.errorPayloadSizeBytes) + mockMetrics(mocks, mock.MatchedBy(func(err model.AppError) bool { + return err != nil && err.ErrorType() == trailerErrorType + })) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + err := runInvoke.RuntimeResponse(mocks.ctx, &mocks.runtimeRespReq) + assert.Error(t, err) + assert.Equal(t, trailerErrorType, err.ErrorType()) + }() + + err := runInvoke.RunInvokeAndSendResult(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.metrics) + assert.Error(t, err) + assert.Equal(t, trailerErrorType, err.ErrorType()) + + wg.Wait() + checkRunningInvokeMockExpectations(t, mocks) +} + +func TestRuntimeErrorFailure_SendInvokeToRuntime_Error(t *testing.T) { + t.Parallel() + + mocks, runInvoke := createMocksAndInitRunningInvoke(t) + + err := model.NewCustomerError(model.ErrorRuntimeUnknown) + runInvoke.sendInvokeToRuntime = func(context.Context, interop.InitStaticDataProvider, interop.InvokeRequest, http.ResponseWriter, string) (int64, time.Duration, time.Duration, model.AppError) { + return 0, 0, 0, err + } + + mocks.staticData.On("FunctionTimeout").Return(time.Second) + mocks.staticData.On("XRayTracingMode").Return(intmodel.XRayTracingModePassThrough) + + mocks.eaInvokeRequest.On("TraceId").Return("Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535") + mocks.eaInvokeResponder.On("SendError", err, &mocks.staticData, mock.Anything).Return() + mocks.eaInvokeResponder.On("ErrorPayloadSizeBytes").Return(mocks.errorPayloadSizeBytes) + + mocks.metrics.On("TriggerStartRequest") + mocks.metrics.On("SendInvokeStartEvent", mock.AnythingOfType("*interop.TracingCtx")).Return(nil) + mocks.metrics.On("TriggerSentResponse", false, err, mock.Anything, 0).Return() + mocks.metrics.On("SendInvokeFinishedEvent", mock.AnythingOfType("*interop.TracingCtx"), mock.AnythingOfType("json.RawMessage")).Return(nil) + + invokeErr := runInvoke.RunInvokeAndSendResult(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.metrics) + assert.Error(t, invokeErr) + + checkRunningInvokeMockExpectations(t, mocks) +} + +func TestRuntimeErrorFailure_SendInvokeToRuntime_Timeout(t *testing.T) { + t.Parallel() + + mocks, runInvoke := createMocksAndInitRunningInvoke(t) + + invokeID := "invoke-0" + mocks.timeoutCache.On("Register", invokeID) + mocks.eaInvokeRequest.On("InvokeID").Return(invokeID) + + err := BuildInvokeAppError(context.DeadlineExceeded, time.Second) + runInvoke.sendInvokeToRuntime = func(context.Context, interop.InitStaticDataProvider, interop.InvokeRequest, http.ResponseWriter, string) (int64, time.Duration, time.Duration, model.AppError) { + return 0, 0, 0, err + } + + mocks.staticData.On("FunctionTimeout").Return(time.Second) + mocks.staticData.On("XRayTracingMode").Return(intmodel.XRayTracingModePassThrough) + + mocks.eaInvokeRequest.On("TraceId").Return("Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535") + mocks.eaInvokeResponder.On("SendError", err, &mocks.staticData, mock.Anything).Return() + mocks.eaInvokeResponder.On("ErrorPayloadSizeBytes").Return(mocks.errorPayloadSizeBytes) + + mocks.metrics.On("TriggerStartRequest") + mocks.metrics.On("SendInvokeStartEvent", mock.AnythingOfType("*interop.TracingCtx")).Return(nil) + mocks.metrics.On("TriggerSentResponse", false, err, mock.Anything, 0).Return() + mocks.metrics.On("SendInvokeFinishedEvent", mock.AnythingOfType("*interop.TracingCtx"), mock.AnythingOfType("json.RawMessage")).Return(nil) + + invokeErr := runInvoke.RunInvokeAndSendResult(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.metrics) + assert.Error(t, invokeErr) + + checkRunningInvokeMockExpectations(t, mocks) +} + +func TestRunInvokeAndSendResultFailure_Timeout(t *testing.T) { + t.Parallel() + + mocks, runInvoke := createMocksAndInitRunningInvoke(t) + + invokeID := "invoke-0" + mocks.timeoutCache.On("Register", invokeID) + mocks.eaInvokeRequest.On("InvokeID").Return(invokeID) + + mocks.staticData.On("FunctionTimeout").Return(time.Millisecond) + mocks.staticData.On("XRayTracingMode").Return(intmodel.XRayTracingModePassThrough) + + mocks.eaInvokeRequest.On("TraceId").Return("Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535") + mocks.eaInvokeResponder.On("SendError", mock.Anything, &mocks.staticData, mock.Anything).Return() + mocks.eaInvokeResponder.On("ErrorPayloadSizeBytes").Return(mocks.errorPayloadSizeBytes) + mockMetricsUnfinished(mocks, mock.Anything) + + invokeErr := runInvoke.RunInvokeAndSendResult(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.metrics) + assert.Error(t, invokeErr) + + checkRunningInvokeMockExpectations(t, mocks) +} + +func TestRunInvokeAndSendResultFailure_TimeoutWhileResponse(t *testing.T) { + t.Parallel() + + mocks, runInvoke := createMocksAndInitRunningInvoke(t) + timeoutErr := model.NewCustomerError(model.ErrorSandboxTimedout) + + mocks.staticData.On("FunctionTimeout").Return(time.Second) + mocks.staticData.On("XRayTracingMode").Return(intmodel.XRayTracingModePassThrough) + + mocks.eaInvokeRequest.On("TraceId").Return("Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535") + + mocks.runtimeRespReq.On("ParsingError").Return(nil) + mocks.runtimeRespReq.On("ContentType").Return("") + mocks.runtimeRespReq.On("ResponseMode").Return("") + + mocks.eaInvokeResponder.On("SendRuntimeResponseHeaders", &mocks.staticData, mock.Anything, mock.Anything).Return() + mocks.eaInvokeResponder.On("SendRuntimeResponseBody", mock.Anything, &mocks.runtimeRespReq, mock.Anything).Return(SendResponseBodyResult{Err: timeoutErr}) + mocks.eaInvokeResponder.On("SendErrorTrailers", mock.Anything, invokeBodyResponseTimeout).Return() + mocks.eaInvokeResponder.On("ErrorPayloadSizeBytes").Return(mocks.errorPayloadSizeBytes) + mockMetricsAnswerNotFinished(mocks, mock.Anything) + + wg := new(sync.WaitGroup) + wg.Add(1) + + go func() { + defer wg.Done() + err := runInvoke.RuntimeResponse(mocks.ctx, &mocks.runtimeRespReq) + assert.Error(t, err) + }() + + err := runInvoke.RunInvokeAndSendResult(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.metrics) + assert.Error(t, err) + + wg.Wait() + + checkRunningInvokeMockExpectations(t, mocks) +} + +func TestRunInvokeAndSendResultFailure_ContextCancelled(t *testing.T) { + t.Parallel() + + mocks, runInvoke := createMocksAndInitRunningInvoke(t) + err := model.NewPlatformError(nil, "test fatal error") + + mocks.staticData.On("FunctionTimeout").Return(time.Second) + mocks.staticData.On("XRayTracingMode").Return(intmodel.XRayTracingModePassThrough) + + mocks.eaInvokeRequest.On("TraceId").Return("Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535") + mocks.eaInvokeResponder.On("SendError", err, &mocks.staticData, mock.Anything).Return(nil) + mocks.eaInvokeResponder.On("ErrorPayloadSizeBytes").Return(mocks.errorPayloadSizeBytes) + mockMetricsUnfinished(mocks, mock.Anything) + + ch := make(chan struct{}) + + go func() { + <-ch + runInvoke.CancelAsync(err) + }() + + close(ch) + invokeErr := runInvoke.RunInvokeAndSendResult(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.metrics) + assert.Error(t, invokeErr) + + checkRunningInvokeMockExpectations(t, mocks) +} + +func TestRuntimeResponseFailure_ResponseWhileResponse(t *testing.T) { + t.Parallel() + + mocks, runInvoke := createMocksAndInitRunningInvoke(t) + syncChan := make(chan time.Time) + + mocks.staticData.On("FunctionTimeout").Return(5 * time.Second) + mocks.staticData.On("XRayTracingMode").Return(intmodel.XRayTracingModePassThrough) + + mocks.eaInvokeRequest.On("TraceId").Return("Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535") + + mocks.runtimeRespReq.On("ParsingError").Return(nil) + mocks.runtimeRespReq.On("ContentType").Return("") + mocks.runtimeRespReq.On("ResponseMode").Return("") + mocks.runtimeRespReq.On("TrailerError").Return(nil) + + mocks.eaInvokeResponder.On("SendRuntimeResponseHeaders", &mocks.staticData, mock.Anything, mock.Anything).Return().Once() + mocks.eaInvokeResponder.On("SendRuntimeResponseBody", mock.Anything, &mocks.runtimeRespReq, mock.Anything).Return(SendResponseBodyResult{}).WaitUntil(syncChan).Once() + mocks.eaInvokeResponder.On("SendRuntimeResponseTrailers", &mocks.runtimeRespReq).Return().Once() + mocks.eaInvokeResponder.On("ErrorPayloadSizeBytes").Return(mocks.errorPayloadSizeBytes) + mockMetrics(mocks, nil) + + wg := new(sync.WaitGroup) + ch := make(chan model.AppError, 2) + + wg.Add(1) + go func() { + defer wg.Done() + ch <- runInvoke.RuntimeResponse(mocks.ctx, &mocks.runtimeRespReq) + }() + + wg.Add(1) + go func() { + defer wg.Done() + ch <- runInvoke.RuntimeResponse(mocks.ctx, &mocks.runtimeRespReq) + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := runInvoke.RunInvokeAndSendResult(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.metrics) + assert.NoError(t, err) + }() + + err := <-ch + assert.Equal(t, model.ErrorRuntimeInvokeResponseInProgress, err.ErrorType()) + + close(syncChan) + err = <-ch + assert.NoError(t, err) + + wg.Wait() + checkRunningInvokeMockExpectations(t, mocks) +} + +func TestRuntimeErrorFailure_ErrorWhileError(t *testing.T) { + t.Parallel() + + mocks, runInvoke := createMocksAndInitRunningInvoke(t) + defer checkRunningInvokeMockExpectations(t, mocks) + + err := model.NewCustomerError(model.ErrorFunctionUnknown) + + mocks.staticData.On("FunctionTimeout").Return(time.Second) + mocks.staticData.On("XRayTracingMode").Return(intmodel.XRayTracingModePassThrough) + + mocks.eaInvokeRequest.On("TraceId").Return("Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535") + + mocks.errorPayloadSizeBytes = 100 + + mocks.runtimeErrorReq.On("GetError").Return(err) + mocks.runtimeErrorReq.On("GetXrayErrorCause").Return(json.RawMessage(nil)) + + mocks.eaInvokeResponder.On("SendRuntimeResponseHeaders", &mocks.staticData, mock.Anything, mock.Anything).Return().Once() + mocks.eaInvokeResponder.On("SendErrorTrailers", mock.Anything, InvokeBodyResponseComplete).Return().Once() + mocks.eaInvokeResponder.On("ErrorPayloadSizeBytes").Return(mocks.errorPayloadSizeBytes) + + mocks.runtimeErrorReq.On("IsRuntimeError", err).Return(true) + mockMetrics(mocks, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + err := runInvoke.RunInvokeAndSendResult(mocks.ctx, &mocks.staticData, &mocks.eaInvokeRequest, &mocks.metrics) + assert.Error(t, err) + assert.Equal(t, model.ErrorFunctionUnknown, err.ErrorType()) + }() + + assert.NoError(t, runInvoke.RuntimeError(mocks.ctx, &mocks.runtimeErrorReq)) + + runtimeErrorErr := runInvoke.RuntimeError(mocks.ctx, &mocks.runtimeErrorReq) + assert.Error(t, runtimeErrorErr) + assert.Equal(t, model.ErrorRuntimeInvokeErrorInProgress, runtimeErrorErr.ErrorType()) + + customerErr := runInvoke.RuntimeError(mocks.ctx, &mocks.runtimeErrorReq) + assert.Error(t, customerErr) + assert.Equal(t, model.ErrorRuntimeInvokeErrorInProgress, customerErr.ErrorType()) +} diff --git a/internal/lambda-managed-instances/invoke/runtime_error_request.go b/internal/lambda-managed-instances/invoke/runtime_error_request.go new file mode 100644 index 0000000..afe5edc --- /dev/null +++ b/internal/lambda-managed-instances/invoke/runtime_error_request.go @@ -0,0 +1,110 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +const ( + RuntimeErrorTypeHeader = "Lambda-Runtime-Function-Error-Type" + RuntimeErrorCategory = "Error.Runtime" + LambdaXRayErrorCauseHeader = "Lambda-Runtime-Function-XRay-Error-Cause" +) + +var runtimeDefaultSeverity = model.ErrorSeverityError + +type runtimeError struct { + request *http.Request + + contentType string + invokeID interop.InvokeID + errorType model.ErrorType + errorCategory model.ErrorCategory + + errorDetails string + xrayErrorCause json.RawMessage +} + +func NewRuntimeError(ctx context.Context, request *http.Request, invokeID interop.InvokeID, errorDetails string) runtimeError { + + return runtimeError{ + request: request, + contentType: request.Header.Get(RuntimeContentTypeHeader), + invokeID: invokeID, + errorType: model.GetValidRuntimeOrFunctionErrorType(request.Header.Get(RuntimeErrorTypeHeader)), + errorCategory: RuntimeErrorCategory, + errorDetails: errorDetails, + xrayErrorCause: getValidatedErrorCause(ctx, request.Header.Get(LambdaXRayErrorCauseHeader)), + } +} + +func (r *runtimeError) InvokeID() interop.InvokeID { + return r.invokeID +} + +func (r *runtimeError) ContentType() string { + return r.contentType +} + +func (r *runtimeError) ErrorType() model.ErrorType { + return r.errorType +} + +func (r *runtimeError) ErrorCategory() model.ErrorCategory { + return r.errorCategory +} + +func (r *runtimeError) GetError() model.AppError { + + return model.NewCustomerError(r.errorType) +} + +func (r *runtimeError) IsRuntimeError(err model.AppError) bool { + + if r.errorType != err.ErrorType() { + return false + } + + if err.Severity() != runtimeDefaultSeverity { + return false + } + + return true +} + +func (r *runtimeError) ErrorDetails() string { + return r.errorDetails +} + +func (r *runtimeError) GetXrayErrorCause() json.RawMessage { + return r.xrayErrorCause +} + +func (r *runtimeError) ReturnCode() int { + return http.StatusOK +} + +func getValidatedErrorCause(ctx context.Context, errorCauseHeader string) json.RawMessage { + if len(errorCauseHeader) == 0 { + logging.Debug(ctx, "errorCause has not been set") + return nil + } + errorCauseJSON := json.RawMessage(errorCauseHeader) + + validErrorCauseJSON, err := rapidmodel.ValidatedErrorCauseJSON(errorCauseJSON) + if err != nil { + logging.Warn(ctx, "errorCause JSON validation failed", "err", err) + return nil + } + + return validErrorCauseJSON +} diff --git a/internal/lambda-managed-instances/invoke/runtime_error_request_test.go b/internal/lambda-managed-instances/invoke/runtime_error_request_test.go new file mode 100644 index 0000000..aca8307 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/runtime_error_request_test.go @@ -0,0 +1,154 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "context" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func TestNewRuntimeError(t *testing.T) { + tests := []struct { + name string + errorTypeHeader string + xrayErrorCauseHeader string + contentTypeHeader string + expectedErrorType model.ErrorType + expectedContentType string + expectedXrayJSON string + }{ + { + name: "validRuntimeErrorWithValidXrayCause", + errorTypeHeader: "Runtime.CustomError", + xrayErrorCauseHeader: `{"message":"test error"}`, + contentTypeHeader: "application/json", + expectedErrorType: "Runtime.CustomError", + expectedContentType: "application/json", + expectedXrayJSON: `{"exceptions":null,"message":"test error","paths":null,"working_directory":""}`, + }, + { + name: "invalidErrorTypeDefaultsToRuntimeUnknown", + errorTypeHeader: "InvalidType", + xrayErrorCauseHeader: `{"paths":["index.js"]}`, + contentTypeHeader: "text/plain", + expectedErrorType: "Runtime.Unknown", + expectedContentType: "text/plain", + expectedXrayJSON: `{"exceptions":null,"paths":["index.js"],"working_directory":""}`, + }, + { + name: "emptyErrorTypeHeader", + errorTypeHeader: "", + xrayErrorCauseHeader: `{"working_directory":"/app"}`, + contentTypeHeader: "", + expectedErrorType: "Runtime.Unknown", + expectedContentType: "", + expectedXrayJSON: `{"exceptions":null,"paths":null,"working_directory":"/app"}`, + }, + { + name: "invalidXrayErrorCauseReturnsNil", + errorTypeHeader: "Function.CustomError", + xrayErrorCauseHeader: `{invalid json}`, + contentTypeHeader: "application/json", + expectedErrorType: "Function.CustomError", + expectedContentType: "application/json", + expectedXrayJSON: "", + }, + { + name: "emptyXrayErrorCauseHeader", + errorTypeHeader: "Runtime.ImportModuleError", + xrayErrorCauseHeader: "", + contentTypeHeader: "application/octet-stream", + expectedErrorType: "Runtime.ImportModuleError", + expectedContentType: "application/octet-stream", + expectedXrayJSON: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + + req, err := http.NewRequest("POST", "/runtime/invocation/test-id/error", strings.NewReader("error details")) + require.NoError(t, err) + + req.Header.Set(RuntimeErrorTypeHeader, tc.errorTypeHeader) + req.Header.Set(LambdaXRayErrorCauseHeader, tc.xrayErrorCauseHeader) + req.Header.Set(RuntimeContentTypeHeader, tc.contentTypeHeader) + + ctx := context.Background() + invokeID := interop.InvokeID("test-invoke-id") + + runtimeErr := NewRuntimeError(ctx, req, invokeID, "error details") + + assert.Equal(t, tc.expectedErrorType, runtimeErr.ErrorType()) + assert.Equal(t, tc.expectedContentType, runtimeErr.ContentType()) + assert.Equal(t, model.ErrorCategory(RuntimeErrorCategory), runtimeErr.ErrorCategory()) + assert.Equal(t, "error details", runtimeErr.ErrorDetails()) + assert.Equal(t, invokeID, runtimeErr.InvokeID()) + + xrayErrorCause := runtimeErr.GetXrayErrorCause() + if tc.expectedXrayJSON == "" { + assert.Nil(t, xrayErrorCause) + } else { + assert.NotNil(t, xrayErrorCause) + assert.JSONEq(t, tc.expectedXrayJSON, string(xrayErrorCause)) + } + }) + } +} + +func TestGetValidatedErrorCause(t *testing.T) { + tests := []struct { + name string + errorCauseHeader string + expectedJSON string + }{ + { + name: "emptyHeaderReturnsNil", + errorCauseHeader: "", + expectedJSON: "", + }, + { + name: "validErrorCauseWithMessage", + errorCauseHeader: `{"message":"test error"}`, + expectedJSON: `{"exceptions":null,"message":"test error","paths":null,"working_directory":""}`, + }, + { + name: "invalidJSONReturnsNil", + errorCauseHeader: `{invalid json}`, + expectedJSON: "", + }, + { + name: "invalidErrorCauseFormatReturnsNil", + errorCauseHeader: `{}`, + expectedJSON: "", + }, + { + name: "validErrorCauseWithPaths", + errorCauseHeader: `{"paths":["test.js"]}`, + expectedJSON: `{"exceptions":null,"paths":["test.js"],"working_directory":""}`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + result := getValidatedErrorCause(ctx, tc.errorCauseHeader) + + if tc.expectedJSON == "" { + assert.Nil(t, result) + } else { + assert.NotNil(t, result) + assert.JSONEq(t, tc.expectedJSON, string(result)) + } + }) + } +} diff --git a/internal/lambda-managed-instances/invoke/runtime_response_request.go b/internal/lambda-managed-instances/invoke/runtime_response_request.go new file mode 100644 index 0000000..b764af6 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/runtime_response_request.go @@ -0,0 +1,122 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "context" + "encoding/base64" + "io" + "log/slog" + "net/http" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +const ( + RuntimeResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" + runtimeResponseModeStreaming = "streaming" + runtimeResponseModeBuffered = "buffered" +) + +type runtimeResponse struct { + request *http.Request + + parsingErr model.AppError + + contentType string + invokeID interop.InvokeID + responseMode string +} + +func NewRuntimeResponse(ctx context.Context, request *http.Request, invokeID interop.InvokeID) runtimeResponse { + contentType := request.Header.Get(RuntimeContentTypeHeader) + if contentType == "" { + + contentType = "application/octet-stream" + } + resp := runtimeResponse{ + request: request, + contentType: contentType, + invokeID: invokeID, + } + + switch mode := request.Header.Get(RuntimeResponseModeHeader); mode { + case runtimeResponseModeStreaming: + resp.responseMode = runtimeResponseModeStreaming + case "": + resp.responseMode = runtimeResponseModeBuffered + default: + logging.Error(ctx, "invalid response mode from runtime", "mode", mode) + resp.responseMode = "" + + resp.parsingErr = model.NewCustomerError(model.ErrorRuntimeInvalidResponseModeHeader) + } + + return resp +} + +func (r *runtimeResponse) ParsingError() model.AppError { + return r.parsingErr +} + +func (r *runtimeResponse) InvokeID() interop.InvokeID { + return r.invokeID +} + +func (r *runtimeResponse) ContentType() string { + return r.contentType +} + +func (r *runtimeResponse) BodyReader() io.Reader { + return r.request.Body +} + +func (r *runtimeResponse) ResponseMode() string { + return r.responseMode +} + +func (r *runtimeResponse) TrailerError() ErrorForInvoker { + typ := r.request.Trailer.Get(FunctionErrorTypeTrailer) + if typ == "" { + return nil + } + + te := trailerError{ + typ: model.GetValidRuntimeOrFunctionErrorType(typ), + details: "", + } + + base64EncodedBody := r.request.Trailer.Get(FunctionErrorBodyTrailer) + decoded, err := base64.StdEncoding.DecodeString(base64EncodedBody) + if err != nil { + slog.Warn("could not base64 decode lambda-runtime-function-error-body trailer", "err", err) + return te + } + + te.details = string(decoded) + return te +} + +type trailerError struct { + typ model.ErrorType + details string +} + +func (t trailerError) ReturnCode() int { + return http.StatusOK +} + +func (t trailerError) ErrorCategory() model.ErrorCategory { + return RuntimeErrorCategory +} + +func (t trailerError) ErrorType() model.ErrorType { + return t.typ +} + +func (t trailerError) ErrorDetails() string { + return t.details +} diff --git a/internal/lambda-managed-instances/invoke/runtime_response_request_test.go b/internal/lambda-managed-instances/invoke/runtime_response_request_test.go new file mode 100644 index 0000000..cb11bf6 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/runtime_response_request_test.go @@ -0,0 +1,81 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func TestRuntimeResponse_TrailerError(t *testing.T) { + tests := []struct { + name string + errorTypeTrailer string + errorBodyTrailer string + expectedErrorType model.ErrorType + expectedErrorBody string + }{ + { + name: "empty error type returns empty values", + errorTypeTrailer: "", + errorBodyTrailer: "dGVzdA==", + expectedErrorType: "", + expectedErrorBody: "", + }, + { + name: "valid runtime error with valid base64 body", + errorTypeTrailer: "Runtime.CustomError", + errorBodyTrailer: "ZXhpdCBjb2RlIDE=", + expectedErrorType: "Runtime.CustomError", + expectedErrorBody: "exit code 1", + }, + { + name: "valid function error with empty base64 body", + errorTypeTrailer: "Function.CustomError", + errorBodyTrailer: "", + expectedErrorType: "Function.CustomError", + expectedErrorBody: "", + }, + { + name: "invalid error type with valid base64 body", + errorTypeTrailer: "InvalidType", + errorBodyTrailer: "c29tZSBlcnJvcg==", + expectedErrorType: "Runtime.Unknown", + expectedErrorBody: "some error", + }, + { + name: "valid error type with invalid base64 body", + errorTypeTrailer: "Function.CustomError", + errorBodyTrailer: "invalid-base64!@#", + expectedErrorType: "Function.CustomError", + expectedErrorBody: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + + req, err := http.NewRequest("POST", "/test", nil) + require.NoError(t, err) + req.Trailer = make(http.Header) + req.Trailer.Set(FunctionErrorTypeTrailer, tc.errorTypeTrailer) + req.Trailer.Set(FunctionErrorBodyTrailer, tc.errorBodyTrailer) + resp := NewRuntimeResponse(req.Context(), req, "test-invoke-id") + + actualTrailerError := resp.TrailerError() + + if tc.expectedErrorType == "" { + assert.Nil(t, actualTrailerError) + } else { + assert.Equal(t, tc.expectedErrorType, actualTrailerError.ErrorType()) + assert.Equal(t, tc.expectedErrorBody, actualTrailerError.ErrorDetails()) + } + }) + } +} diff --git a/internal/lambda-managed-instances/invoke/runtime_response_sender.go b/internal/lambda-managed-instances/invoke/runtime_response_sender.go new file mode 100644 index 0000000..9604a8d --- /dev/null +++ b/internal/lambda-managed-instances/invoke/runtime_response_sender.go @@ -0,0 +1,89 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "strconv" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core/directinvoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + rapiModel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils" +) + +const ( + RuntimeContentTypeHeader = "Content-Type" + RuntimeRequestIdHeader = "Lambda-Runtime-Aws-Request-Id" + RuntimeDeadlineHeader = "Lambda-Runtime-Deadline-Ms" + RuntimeFunctionArnHeader = "Lambda-Runtime-Invoked-Function-Arn" + RuntimeTraceIdHeader = "Lambda-Runtime-Trace-Id" + RuntimeClientContextHeader = "Lambda-Runtime-Client-Context" + RuntimeCognitoIdentifyHeader = "Lambda-Runtime-Cognito-Identity" +) + +func sendInvokeToRuntime(ctx context.Context, initData interop.InitStaticDataProvider, invokeReq interop.InvokeRequest, runtimeReq http.ResponseWriter, traceId string) (int64, time.Duration, time.Duration, model.AppError) { + runtimeReq.Header().Set(RuntimeContentTypeHeader, invokeReq.ContentType()) + runtimeReq.Header().Set(RuntimeRequestIdHeader, invokeReq.InvokeID()) + runtimeReq.Header().Set(RuntimeDeadlineHeader, strconv.FormatInt(invokeReq.Deadline().UnixMilli(), 10)) + runtimeReq.Header().Set(RuntimeFunctionArnHeader, initData.FunctionARN()) + runtimeReq.Header().Set(RuntimeTraceIdHeader, traceId) + runtimeReq.Header().Set(RuntimeClientContextHeader, invokeReq.ClientContext()) + runtimeReq.Header().Set(RuntimeCognitoIdentifyHeader, buildCognitoIdentifyHeader(invokeReq)) + runtimeReq.WriteHeader(http.StatusOK) + + timedReader := &utils.TimedReader{ + Reader: invokeReq.BodyReader(), + Name: "request", + Ctx: ctx, + } + timedWriter := &utils.TimedWriter{ + Writer: directinvoke.NewCancellableWriter(ctx, runtimeReq), + Name: "request", + Ctx: ctx, + } + + resChan := make(chan error) + var written int64 + go func() { + var err error + written, err = utils.CopyWithPool(timedWriter, timedReader) + resChan <- err + }() + + select { + case <-ctx.Done(): + + return 0, 0, 0, BuildInvokeAppError(context.Cause(ctx), initData.FunctionTimeout()) + case err := <-resChan: + if err != nil { + return 0, timedWriter.TotalDuration, timedWriter.TotalDuration, model.NewCustomerError(model.ErrorRuntimeUnknown, model.WithCause(err)) + } + return written, timedReader.TotalDuration, timedWriter.TotalDuration, nil + } +} + +func buildCognitoIdentifyHeader(invokeReq interop.InvokeRequest) string { + cognitoIdentityJSON := "" + if len(invokeReq.CognitoId()) != 0 || len(invokeReq.CognitoPoolId()) != 0 { + cognitoJSON, err := json.Marshal(rapiModel.CognitoIdentity{ + CognitoIdentityID: invokeReq.CognitoId(), + CognitoIdentityPoolID: invokeReq.CognitoPoolId(), + }) + + if err != nil { + slog.Error("Marshal cognitoIdentity returns error", "err", err) + return "" + } + + cognitoIdentityJSON = string(cognitoJSON) + } + + return cognitoIdentityJSON +} diff --git a/internal/lambda-managed-instances/invoke/runtime_response_sender_test.go b/internal/lambda-managed-instances/invoke/runtime_response_sender_test.go new file mode 100644 index 0000000..9727e5b --- /dev/null +++ b/internal/lambda-managed-instances/invoke/runtime_response_sender_test.go @@ -0,0 +1,136 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + internalMocks "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testutils/mocks" +) + +var responseSenderTestPayloadSize = 100 + +const traceIdTestHeader string = "Root=12345;Parent=67890;Sampeld=11111;Lineage=22222" + +type runtimeResponseSenderMocks struct { + ctx context.Context + initData interop.MockInitStaticDataProvider + invokeReq interop.MockInvokeRequest + + runtimeReq http.ResponseWriter + reader io.Reader + writer io.Writer +} + +func createMocksAndRuntimeResponder() *runtimeResponseSenderMocks { + mocks := runtimeResponseSenderMocks{ + ctx: context.TODO(), + initData: interop.MockInitStaticDataProvider{}, + invokeReq: interop.MockInvokeRequest{}, + runtimeReq: httptest.NewRecorder(), + reader: &internalMocks.ReaderMock{ + WaitBeforeRead: time.Nanosecond, + PayloadSize: responseSenderTestPayloadSize, + }, + writer: io.Discard, + } + + mocks.initData.On("FunctionTimeout").Return(time.Duration(0)).Maybe() + + return &mocks +} + +func checkResponseSenderExpectations(t *testing.T, mocks *runtimeResponseSenderMocks) { + mocks.initData.AssertExpectations(t) + mocks.invokeReq.AssertExpectations(t) +} + +func buildInvokeReqMocks(invokeReq *interop.MockInvokeRequest) { + invokeReq.On("ContentType").Return("application/json") + invokeReq.On("InvokeID").Return("123456") + invokeReq.On("Deadline").Return(time.Now().Add(time.Second)) + invokeReq.On("ClientContext").Return("client-context-example") + invokeReq.On("CognitoId").Return("cognito_id_12345") + invokeReq.On("CognitoPoolId").Return("cognito_pool_id_6789") +} + +func buildInitDataMocks(initData *interop.MockInitStaticDataProvider) { + initData.On("FunctionARN").Return("function-arn") +} + +func TestSendResponseSuccess(t *testing.T) { + t.Parallel() + + mocks := createMocksAndRuntimeResponder() + buildInvokeReqMocks(&mocks.invokeReq) + buildInitDataMocks(&mocks.initData) + + mocks.invokeReq.On("BodyReader").Return(mocks.reader) + + written, readerDuration, writerDuration, err := sendInvokeToRuntime(mocks.ctx, &mocks.initData, &mocks.invokeReq, mocks.runtimeReq, traceIdTestHeader) + assert.NoError(t, err) + assert.Equal(t, int64(responseSenderTestPayloadSize), written) + assert.Greater(t, readerDuration, time.Duration(0)) + assert.Greater(t, writerDuration, time.Duration(0)) + + checkResponseSenderExpectations(t, mocks) +} + +func TestSendResponseFailure_Timeout(t *testing.T) { + t.Parallel() + + mocks := createMocksAndRuntimeResponder() + buildInvokeReqMocks(&mocks.invokeReq) + buildInitDataMocks(&mocks.initData) + + mocks.invokeReq.On("BodyReader").Maybe().Return(&internalMocks.ReaderMock{ + PayloadSize: 100, + WaitBeforeRead: time.Second, + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + + _, readerDuration, writerDuration, err := sendInvokeToRuntime(ctx, &mocks.initData, &mocks.invokeReq, mocks.runtimeReq, traceIdTestHeader) + assert.Error(t, err) + assert.Equal(t, model.ErrorSandboxTimedout, err.ErrorType()) + assert.Zero(t, readerDuration) + assert.Zero(t, writerDuration) + + checkResponseSenderExpectations(t, mocks) +} + +func TestSendResponseFailure_CtxCancelled(t *testing.T) { + t.Parallel() + + mocks := createMocksAndRuntimeResponder() + buildInvokeReqMocks(&mocks.invokeReq) + buildInitDataMocks(&mocks.initData) + + mocks.invokeReq.On("BodyReader").Maybe().Return(&internalMocks.ReaderMock{ + PayloadSize: 100, + WaitBeforeRead: time.Second, + }) + + ctx, cancel := context.WithCancelCause(context.Background()) + cancel(model.NewCustomerError(model.ErrorReasonExtensionExecFailed)) + + _, readerDuration, writerDuration, err := sendInvokeToRuntime(ctx, &mocks.initData, &mocks.invokeReq, mocks.runtimeReq, traceIdTestHeader) + + assert.Error(t, err) + assert.Equal(t, model.ErrorReasonExtensionExecFailed, err.ErrorType()) + assert.Zero(t, readerDuration) + assert.Zero(t, writerDuration) + + checkResponseSenderExpectations(t, mocks) +} diff --git a/internal/lambda-managed-instances/invoke/timeout/timeout_cache.go b/internal/lambda-managed-instances/invoke/timeout/timeout_cache.go new file mode 100644 index 0000000..0d59eae --- /dev/null +++ b/internal/lambda-managed-instances/invoke/timeout/timeout_cache.go @@ -0,0 +1,65 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package timeout + +import ( + "container/list" + "log/slog" + "sync" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" +) + +const timeoutCacheSize = 1000 + +type RecentCache struct { + entries map[interop.InvokeID]*list.Element + order *list.List + mu sync.Mutex +} + +func NewRecentCache() *RecentCache { + return &RecentCache{ + entries: make(map[interop.InvokeID]*list.Element, timeoutCacheSize), + order: list.New(), + } +} + +func (tc *RecentCache) Register(invokeID interop.InvokeID) { + tc.mu.Lock() + defer tc.mu.Unlock() + + if elem, ok := tc.entries[invokeID]; ok { + + tc.order.MoveToBack(elem) + return + } + + if len(tc.entries) == timeoutCacheSize { + + oldestInvokeID := tc.order.Front().Value.(interop.InvokeID) + _ = tc.tryDelete(oldestInvokeID) + slog.Warn("evicted invokeID from full timeout cache", "invokeID", oldestInvokeID) + } + + elem := tc.order.PushBack(invokeID) + tc.entries[invokeID] = elem +} + +func (tc *RecentCache) Consume(invokeID interop.InvokeID) (consumed bool) { + tc.mu.Lock() + defer tc.mu.Unlock() + + return tc.tryDelete(invokeID) +} + +func (tc *RecentCache) tryDelete(invokeID interop.InvokeID) (deleted bool) { + elem, ok := tc.entries[invokeID] + if !ok { + return false + } + tc.order.Remove(elem) + delete(tc.entries, invokeID) + return true +} diff --git a/internal/lambda-managed-instances/invoke/timeout/timeout_cache_test.go b/internal/lambda-managed-instances/invoke/timeout/timeout_cache_test.go new file mode 100644 index 0000000..ef17ce6 --- /dev/null +++ b/internal/lambda-managed-instances/invoke/timeout/timeout_cache_test.go @@ -0,0 +1,79 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package timeout_test + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/invoke/timeout" +) + +func TestRecentCache(t *testing.T) { + cache := timeout.NewRecentCache() + + assert.False(t, cache.Consume("invoke-0")) + + cache.Register("invoke-0") + assert.True(t, cache.Consume("invoke-0")) + + assert.False(t, cache.Consume("invoke-0")) + + for i := range 1000 { + cache.Register(fmt.Sprintf("invoke-%d", i)) + } + + cache.Register("invoke-0") + + cache.Register("invoke-1000") + assert.False(t, cache.Consume("invoke-1")) + + assert.True(t, cache.Consume("invoke-0")) +} + +func TestRecentCache_Concurrency(t *testing.T) { + cache := timeout.NewRecentCache() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + var wg sync.WaitGroup + + for i := range 2000 { + invokeID := fmt.Sprintf("invoke-%d", i) + + wg.Go(func() { + for { + select { + case <-ctx.Done(): + return + default: + + } + + cache.Register(invokeID) + } + }) + + wg.Go(func() { + for { + select { + case <-ctx.Done(): + return + default: + + } + + cache.Consume(invokeID) + } + }) + } + + wg.Wait() +} diff --git a/internal/lambda-managed-instances/invoke/utils.go b/internal/lambda-managed-instances/invoke/utils.go new file mode 100644 index 0000000..a04538f --- /dev/null +++ b/internal/lambda-managed-instances/invoke/utils.go @@ -0,0 +1,30 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func BuildInvokeAppError(err error, invokeTimeout time.Duration) model.AppError { + var appError model.AppError + + switch { + case errors.As(err, &appError): + return appError + case errors.Is(err, context.DeadlineExceeded): + return model.NewCustomerError( + model.ErrorSandboxTimedout, + model.WithCause(err), + model.WithErrorMessage(fmt.Sprintf("Task timed out after %.2f seconds", invokeTimeout.Seconds())), + ) + default: + return model.NewPlatformError(err, "BuildInvokeAppError doesn't know this error") + } +} diff --git a/internal/lambda-managed-instances/logging/contextual_logger.go b/internal/lambda-managed-instances/logging/contextual_logger.go new file mode 100644 index 0000000..0580ff0 --- /dev/null +++ b/internal/lambda-managed-instances/logging/contextual_logger.go @@ -0,0 +1,82 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package logging + +import ( + "context" + "io" + "log/slog" + "os" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func init() { + w := io.Writer(os.Stderr) + + slog.SetDefault(CreateNewLogger(slog.LevelInfo, w)) +} + +type loggerKey int + +const ( + ctxLoggerKey loggerKey = iota +) + +func CreateNewLogger(level slog.Level, w io.Writer) *slog.Logger { + return slog.New(slog.NewTextHandler(w, &slog.HandlerOptions{ + Level: level, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == slog.TimeKey { + return slog.Attr{ + Key: slog.TimeKey, + Value: slog.StringValue(a.Value.Time().UTC().Format("2006-01-02T15:04:05.0000Z")), + } + } + return a + }, + })) +} + +func WithFields(ctx context.Context, args ...any) context.Context { + logger := FromContext(ctx) + logger = logger.With(args...) + return context.WithValue(ctx, ctxLoggerKey, logger) +} + +func WithInvokeID(ctx context.Context, invokeID interop.InvokeID) context.Context { + return WithFields(ctx, interop.RequestIdProperty, invokeID) +} + +func Debug(ctx context.Context, msg string, args ...any) { + FromContext(ctx).Debug(msg, args...) +} + +func Info(ctx context.Context, msg string, args ...any) { + FromContext(ctx).Info(msg, args...) +} + +func Warn(ctx context.Context, msg string, args ...any) { + FromContext(ctx).Warn(msg, args...) +} + +func Error(ctx context.Context, msg string, args ...any) { + FromContext(ctx).Error(msg, args...) +} + +func Err(ctx context.Context, msg string, err model.AppError) { + level := slog.LevelWarn + if err.Source() != model.ErrorSourceRuntime { + level = slog.LevelError + } + FromContext(ctx).Log(ctx, level, msg, "err", err) +} + +func FromContext(ctx context.Context) *slog.Logger { + if logger, ok := ctx.Value(ctxLoggerKey).(*slog.Logger); ok { + return logger + } + return slog.Default() +} diff --git a/internal/lambda-managed-instances/logging/contextual_logger_test.go b/internal/lambda-managed-instances/logging/contextual_logger_test.go new file mode 100644 index 0000000..aa0d8d4 --- /dev/null +++ b/internal/lambda-managed-instances/logging/contextual_logger_test.go @@ -0,0 +1,91 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package logging + +import ( + "context" + "log/slog" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" +) + +func getEmptyContext() context.Context { + return context.Background() +} + +func TestContextualLogger(t *testing.T) { + tests := []struct { + name string + setupContext func() context.Context + logFunc func(ctx context.Context, msg string, args ...any) + message string + args []any + }{ + { + name: "Default_Info", + setupContext: getEmptyContext, + logFunc: Info, + message: "info message", + args: []any{"key", "value"}, + }, + { + name: "Default_Debug", + setupContext: getEmptyContext, + logFunc: Debug, + message: "debug message", + args: []any{"someData", "123"}, + }, + { + name: "Default_Warn", + setupContext: getEmptyContext, + logFunc: Warn, + message: "warn message", + args: []any{"warning", "something"}, + }, + { + name: "Default_Error", + setupContext: getEmptyContext, + logFunc: Error, + message: "error message", + args: []any{"error", "something failed"}, + }, + { + name: "Context_InvokeId", + setupContext: func() context.Context { + return WithInvokeID(context.Background(), interop.InvokeID("12345")) + }, + logFunc: Info, + message: "processing request", + args: []any{"error", "something"}, + }, + { + name: "Context_Multiple_Values", + setupContext: func() context.Context { + ctx := WithInvokeID(context.Background(), interop.InvokeID("11111")) + return WithFields(ctx, "userID", "user-789") + }, + logFunc: Info, + message: "chained context logging", + args: []any{"error", "something"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testCtx := tt.setupContext() + logger := FromContext(testCtx) + + tt.logFunc(testCtx, tt.message, tt.args...) + + if testCtx == getEmptyContext() { + assert.Equal(t, logger, slog.Default(), "Should return default logger for empty context") + } else { + assert.NotEqual(t, logger, slog.Default(), "Should return contextual logger, not default, when With functions are used") + } + }) + } +} diff --git a/internal/lambda-managed-instances/model/init.go b/internal/lambda-managed-instances/model/init.go new file mode 100644 index 0000000..5c25ad1 --- /dev/null +++ b/internal/lambda-managed-instances/model/init.go @@ -0,0 +1,141 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "encoding/json" + "fmt" + "net/netip" + "time" +) + +type InitRequestMessage struct { + AccountID string `json:"account_id"` + AwsKey string `json:"aws_key"` + AwsSecret string `json:"aws_secret"` + AwsSession string `json:"aws_session"` + AwsRegion string `json:"aws_region"` + EnvVars KVMap `json:"env_vars"` + MemorySizeBytes int `json:"ram_limit"` + FunctionARN string `json:"function_arn"` + FunctionVersion string `json:"function_version"` + FunctionVersionID string `json:"version_id"` + ArtefactType ArtefactType `json:"artefact_type"` + TaskName string `json:"task_name"` + Handler string `json:"handler,omitempty"` + InvokeTimeout DurationMS `json:"invoke_timeout_ms"` + InitTimeout DurationMS `json:"init_timeout_ms"` + RuntimeVersion string `json:"runtime_version,omitempty"` + RuntimeArn string `json:"runtime_arn,omitempty"` + RuntimeWorkerCount int `json:"runtime_worker_count"` + LogFormat string `json:"log_format"` + LogLevel string `json:"log_level"` + LogGroupName string `json:"log_group_name"` + LogStreamName string `json:"log_stream_name"` + TelemetryAPIAddress TelemetryAddr `json:"telemetry_api_address"` + TelemetryPassphrase string `json:"telemetry_passphrase"` + XRayDaemonAddress string `json:"xray_daemon_address"` + XrayTracingMode XrayTracingMode `json:"xray_tracing_mode"` + CurrentWorkingDir string `json:"cwd"` + RuntimeBinaryCommand []string `json:"cmd"` + + AvailabilityZoneId string `json:"aws_availability_zone_id"` + + AmiId string `json:"ami_id"` +} + +type DurationMS time.Duration + +func (d *DurationMS) UnmarshalJSON(b []byte) error { + var ms int + if err := json.Unmarshal(b, &ms); err != nil { + return err + } + *d = DurationMS(ms * int(time.Millisecond)) + return nil +} + +func (d DurationMS) MarshalJSON() ([]byte, error) { + ms := time.Duration(d).Milliseconds() + return json.Marshal(ms) +} + +type KVSlice []string + +type KVMap map[string]string + +type TelemetryAddr netip.AddrPort + +func (t *TelemetryAddr) UnmarshalJSON(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + addrPort, err := netip.ParseAddrPort(s) + if err != nil { + return err + } + *t = TelemetryAddr(addrPort) + return nil +} + +func (t TelemetryAddr) MarshalJSON() ([]byte, error) { + return json.Marshal(netip.AddrPort(t).String()) +} + +func (i InitRequestMessage) String() string { + return fmt.Sprintf("InitRequestMessage{"+ + "AccountID=%s, "+ + "AwsRegion=%s, "+ + "FunctionARN=%s, "+ + "FunctionVersion=%s, "+ + "FunctionVersionID=%s, "+ + "ArtefactType=%s, "+ + "TaskName=%s, "+ + "Handler=%s, "+ + "InvokeTimeout=%dms, "+ + "InitTimeout=%dms, "+ + "RuntimeVersion=%s, "+ + "RuntimeArn=%s, "+ + "RuntimeWorkerCount=%d, "+ + "LogFormat=%s, "+ + "LogLevel=%s, "+ + "LogGroupName=%s, "+ + "LogStreamName=%s, "+ + "TelemetryAPIAddress=%s, "+ + "XRayDaemonAddress=%s, "+ + "XrayTracingMode=%s, "+ + "CurrentWorkingDir=%s, "+ + "AvailabilityZoneId=%s, "+ + "AmiId=%s, "+ + "MemorySizeBytes=%d, "+ + "RuntimeBinaryCommand=%v"+ + "}", + i.AccountID, + i.AwsRegion, + i.FunctionARN, + i.FunctionVersion, + i.FunctionVersionID, + i.ArtefactType, + i.TaskName, + i.Handler, + time.Duration(i.InvokeTimeout).Milliseconds(), + time.Duration(i.InitTimeout).Milliseconds(), + i.RuntimeVersion, + i.RuntimeArn, + i.RuntimeWorkerCount, + i.LogFormat, + i.LogLevel, + i.LogGroupName, + i.LogStreamName, + netip.AddrPort(i.TelemetryAPIAddress).String(), + i.XRayDaemonAddress, + i.XrayTracingMode, + i.CurrentWorkingDir, + i.AvailabilityZoneId, + i.AmiId, + i.MemorySizeBytes, + i.RuntimeBinaryCommand, + ) +} diff --git a/internal/lambda-managed-instances/model/init_test.go b/internal/lambda-managed-instances/model/init_test.go new file mode 100644 index 0000000..16b39b4 --- /dev/null +++ b/internal/lambda-managed-instances/model/init_test.go @@ -0,0 +1,171 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "encoding/json" + "net/netip" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDurationMS_UnmarshalJSON(t *testing.T) { + tests := []struct { + input string + expected time.Duration + hasError bool + }{ + {`1000`, 1 * time.Second, false}, + {`1500`, 1500 * time.Millisecond, false}, + {`0`, 0, false}, + {`-500`, -500 * time.Millisecond, false}, + {`"invalid"`, 0, true}, + {`null`, 0, false}, + {`123.456`, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + var d DurationMS + err := json.Unmarshal([]byte(tt.input), &d) + if (err != nil) != tt.hasError { + t.Errorf("input %s: unexpected error status: %v", tt.input, err) + return + } + if !tt.hasError && time.Duration(d) != tt.expected { + t.Errorf("input %s: expected %v, got %v", tt.input, tt.expected, d) + } + }) + } +} + +func TestDurationMS_MarshalJSON(t *testing.T) { + tests := []struct { + name string + d DurationMS + want []byte + wantErr bool + }{ + { + name: "1_second", + d: DurationMS(1 * time.Second), + want: []byte("1000"), + wantErr: false, + }, + { + name: "1_5_seconds", + d: DurationMS(1500 * time.Millisecond), + want: []byte("1500"), + wantErr: false, + }, + { + name: "zero_duration", + d: DurationMS(0), + want: []byte("0"), + wantErr: false, + }, + { + name: "negative_duration", + d: DurationMS(-500 * time.Millisecond), + want: []byte("-500"), + wantErr: false, + }, + { + name: "large_duration", + d: DurationMS(10 * time.Minute), + want: []byte("600000"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.d) + if (err != nil) != tt.wantErr { + t.Errorf("json.Marshal() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("json.Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestInitRequestMessage_String(t *testing.T) { + + awsKey := "AKIAIOSFODNN7EXAMPLE" + awsSecret := "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + awsSession := "FwoGZXIvYXdzEMj//////////wEaDM1Qz0oN8BNwV9GqyyLVAebxhwq9ZGqojXZe1UTJkzK6F9V+VZHhT5JSWYzJUKEwOqOkQyQXJpfJsYHfkJEXtR6Kh9mXnEbqKi" + telemetryPassphrase := "hello" + envVarKey := "CUSTOMER_ENV_VAR_1" + envVarValue := "customer_env_value_1" + + sensitiveValues := []string{ + awsKey, + awsSecret, + awsSession, + telemetryPassphrase, + envVarKey, + envVarValue, + } + + tests := []struct { + name string + msg InitRequestMessage + expected string + }{ + { + name: "Test Init Message", + msg: InitRequestMessage{ + AccountID: "123456789012", + AwsKey: awsKey, + AwsSecret: awsSecret, + AwsSession: awsSession, + AwsRegion: "us-west-2", + EnvVars: map[string]string{ + envVarKey: envVarValue, + }, + ArtefactType: ArtefactTypeZIP, + MemorySizeBytes: 3008 * 1024 * 1024, + FunctionARN: "arn:aws:lambda:us-east-1:123456789012:function:test_function", + FunctionVersion: "$LATEST", + FunctionVersionID: "test-function-version-id", + TaskName: "test_function", + InvokeTimeout: DurationMS(3 * time.Second), + InitTimeout: DurationMS(10 * time.Second), + RuntimeWorkerCount: 1, + LogFormat: "json", + LogLevel: "info", + LogGroupName: "/aws/lambda/test_function", + LogStreamName: "$LATEST", + TelemetryAPIAddress: TelemetryAddr(netip.MustParseAddrPort("1.1.1.1:1234")), + TelemetryPassphrase: telemetryPassphrase, + XRayDaemonAddress: "2.2.2.2:2345", + XrayTracingMode: XRayTracingModeActive, + RuntimeBinaryCommand: []string{"cmd", "arg1", "arg2"}, + CurrentWorkingDir: "/", + AmiId: "ami-12345", + AvailabilityZoneId: "az-1", + Handler: "lambda_function.lambda_handler", + }, + expected: "InitRequestMessage{AccountID=123456789012, AwsRegion=us-west-2, FunctionARN=arn:aws:lambda:us-east-1:123456789012:function:test_function, FunctionVersion=$LATEST, FunctionVersionID=test-function-version-id, ArtefactType=zip, TaskName=test_function, Handler=lambda_function.lambda_handler, InvokeTimeout=3000ms, InitTimeout=10000ms, RuntimeVersion=, RuntimeArn=, RuntimeWorkerCount=1, LogFormat=json, LogLevel=info, LogGroupName=/aws/lambda/test_function, LogStreamName=$LATEST, TelemetryAPIAddress=1.1.1.1:1234, XRayDaemonAddress=2.2.2.2:2345, XrayTracingMode=Active, CurrentWorkingDir=/, AvailabilityZoneId=az-1, AmiId=ami-12345, MemorySizeBytes=3154116608, RuntimeBinaryCommand=[cmd arg1 arg2]}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.msg.String() + if result != tt.expected { + assert.Equal(t, tt.expected, result, "Incorrect InitRequestMessage string representation") + } + + for _, sensitiveValue := range sensitiveValues { + assert.NotContains(t, result, sensitiveValue, "String() output should not contain sensitive value: %s", sensitiveValue) + } + }) + } +} diff --git a/internal/lambda-managed-instances/model/model.go b/internal/lambda-managed-instances/model/model.go new file mode 100644 index 0000000..2574fce --- /dev/null +++ b/internal/lambda-managed-instances/model/model.go @@ -0,0 +1,18 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +type XrayTracingMode string + +const ( + XRayTracingModeActive XrayTracingMode = "Active" + XRayTracingModePassThrough XrayTracingMode = "PassThrough" +) + +type ArtefactType string + +const ( + ArtefactTypeOCI ArtefactType = "oci" + ArtefactTypeZIP ArtefactType = "zip" +) diff --git a/internal/lambda-managed-instances/ptr/ptr.go b/internal/lambda-managed-instances/ptr/ptr.go new file mode 100644 index 0000000..8950239 --- /dev/null +++ b/internal/lambda-managed-instances/ptr/ptr.go @@ -0,0 +1,8 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package ptr + +func To[T any](v T) *T { + return &v +} diff --git a/internal/lambda-managed-instances/rapi/extensions_fuzz_test.go b/internal/lambda-managed-instances/rapi/extensions_fuzz_test.go new file mode 100644 index 0000000..1d2f728 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/extensions_fuzz_test.go @@ -0,0 +1,311 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapi + +import ( + "bytes" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/handler" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testdata" +) + +func FuzzAgentRegisterHandler(f *testing.F) { + registerReq := handler.RegisterRequest{ + Events: []core.Event{core.ShutdownEvent}, + } + regReqBytes, err := json.Marshal(®isterReq) + if err != nil { + f.Errorf("failed to marshal register request: %v", err) + } + f.Add("agent", "accountId", true, regReqBytes) + f.Add("agent", "accountId", false, regReqBytes) + + f.Fuzz(func(t *testing.T, + agentName string, + featuresHeader string, + external bool, + payload []byte, + ) { + flowTest := testdata.NewFlowTest() + + if external { + _, err := flowTest.RegistrationService.CreateExternalAgent(agentName) + require.NoError(t, err) + } + + functionMetadata := createDummyFunctionMetadata() + flowTest.RegistrationService.SetFunctionMetadata(functionMetadata) + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL("/extension/register", version20200101) + request := httptest.NewRequest("POST", target, bytes.NewReader(payload)) + request.Header.Add(handler.LambdaAgentName, agentName) + request.Header.Add("Lambda-Extension-Accept-Feature", featuresHeader) + + responseRecorder := serveTestRequest(rapiServer, request) + + if agentName == "" { + assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidExtensionName") + return + } + + regReqStruct := struct { + handler.RegisterRequest + ConfigurationKeys []string `json:"configurationKeys"` + }{} + if err := json.Unmarshal(payload, ®ReqStruct); err != nil { + assertForbiddenErrorType(t, responseRecorder, "InvalidRequestFormat") + return + } + + if containsInvalidEvent(external, regReqStruct.Events) { + assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidEventType") + return + } + + assert.Equal(t, http.StatusOK, responseRecorder.Code) + + respBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + expectedResponse := map[string]interface{}{ + "functionName": functionMetadata.FunctionName, + "functionVersion": functionMetadata.FunctionVersion, + "handler": functionMetadata.Handler, + } + if featuresHeader == "accountId" && functionMetadata.AccountID != "" { + expectedResponse["accountId"] = functionMetadata.AccountID + } + + expectedRespBytes, err := json.Marshal(expectedResponse) + assert.NoError(t, err) + assert.JSONEq(t, string(expectedRespBytes), string(respBody)) + + if external { + agent, found := flowTest.RegistrationService.FindExternalAgentByName(agentName) + assert.True(t, found) + assert.Equal(t, agent.RegisteredState, agent.GetState()) + } else { + agent, found := flowTest.RegistrationService.FindInternalAgentByName(agentName) + assert.True(t, found) + assert.Equal(t, agent.RegisteredState, agent.GetState()) + } + }) +} + +func FuzzAgentNextHandler(f *testing.F) { + regService := core.NewRegistrationService(core.NewInitFlowSynchronization()) + testAgent := makeExternalAgent(regService) + f.Add(testAgent.ID().String(), true) + f.Add(testAgent.ID().String(), false) + + f.Fuzz(func(t *testing.T, + agentIdentifierHeader string, + registered bool, + ) { + flowTest := testdata.NewFlowTest() + agent := makeExternalAgent(flowTest.RegistrationService) + + if registered { + agent.SetState(agent.RegisteredState) + agent.Release() + } + + configureRendererForEvent(flowTest) + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL("/extension/event/next", version20200101) + request := httptest.NewRequest("GET", target, nil) + request.Header.Set(model.LambdaAgentIdentifier, agentIdentifierHeader) + + responseRecorder := serveTestRequest(rapiServer, request) + + if agentIdentifierHeader == "" { + assertForbiddenErrorType(t, responseRecorder, model.ErrAgentIdentifierMissing) + return + } + if _, err := uuid.Parse(agentIdentifierHeader); err != nil { + assertForbiddenErrorType(t, responseRecorder, model.ErrAgentIdentifierInvalid) + return + } + if agentIdentifierHeader != agent.ID().String() { + assertForbiddenErrorType(t, responseRecorder, "Extension.UnknownExtensionIdentifier") + return + } + if !registered { + assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidExtensionState") + return + } + + assert.Equal(t, http.StatusOK, responseRecorder.Code) + + assertResponseEventType(t, responseRecorder) + + assert.Equal(t, agent.RunningState, agent.GetState()) + }) +} + +func FuzzAgentInitErrorHandler(f *testing.F) { + fuzzErrorHandler(f, "/extension/init/error", rapidmodel.ErrorAgentInit) +} + +func FuzzAgentExitErrorHandler(f *testing.F) { + fuzzErrorHandler(f, "/extension/exit/error", rapidmodel.ErrorAgentExit) +} + +func fuzzErrorHandler(f *testing.F, handlerPath string, fatalErrorType rapidmodel.ErrorType) { + regService := core.NewRegistrationService(core.NewInitFlowSynchronization()) + testAgent := makeExternalAgent(regService) + f.Add(true, testAgent.ID().String(), "Extension.SomeError") + f.Add(false, testAgent.ID().String(), "Extension.SomeError") + + f.Fuzz(func(t *testing.T, + agentRegistered bool, + agentIdentifierHeader string, + errorType string, + ) { + flowTest := testdata.NewFlowTest() + + agent := makeExternalAgent(flowTest.RegistrationService) + + if agentRegistered { + agent.SetState(agent.RegisteredState) + } + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL(handlerPath, version20200101) + + request := httptest.NewRequest("POST", target, nil) + request = appctx.RequestWithAppCtx(request, flowTest.AppCtx) + request.Header.Set(model.LambdaAgentIdentifier, agentIdentifierHeader) + request.Header.Set(handler.LambdaAgentFunctionErrorType, errorType) + + responseRecorder := serveTestRequest(rapiServer, request) + + if agentIdentifierHeader == "" { + assertForbiddenErrorType(t, responseRecorder, model.ErrAgentIdentifierMissing) + return + } + + if _, e := uuid.Parse(agentIdentifierHeader); e != nil { + assertForbiddenErrorType(t, responseRecorder, model.ErrAgentIdentifierInvalid) + return + } + + if errorType == "" { + assertForbiddenErrorType(t, responseRecorder, "Extension.MissingHeader") + return + } + if agentIdentifierHeader != agent.ID().String() { + assertForbiddenErrorType(t, responseRecorder, "Extension.UnknownExtensionIdentifier") + return + } + if !agentRegistered { + assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidExtensionState") + } else { + assertErrorAgentRegistered(t, responseRecorder, flowTest, fatalErrorType) + } + }) +} + +func assertErrorAgentRegistered(t *testing.T, responseRecorder *httptest.ResponseRecorder, flowTest *testdata.FlowTest, expectedErrType rapidmodel.ErrorType) { + var response model.StatusResponse + + respBody, _ := io.ReadAll(responseRecorder.Body) + err := json.Unmarshal(respBody, &response) + assert.NoError(t, err) + + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + assert.Equal(t, "OK", response.Status) + + v, found := appctx.LoadFirstFatalError(flowTest.AppCtx) + assert.True(t, found) + assert.Equal(t, expectedErrType, v) +} + +func assertForbiddenErrorType(t *testing.T, responseRecorder *httptest.ResponseRecorder, errType string) { + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + + var errorResponse model.ErrorResponse + + respBody, _ := io.ReadAll(responseRecorder.Body) + err := json.Unmarshal(respBody, &errorResponse) + assert.NoError(t, err) + + assert.Equal(t, errType, errorResponse.ErrorType) +} + +func createDummyFunctionMetadata() rapidmodel.FunctionMetadata { + return rapidmodel.FunctionMetadata{ + AccountID: "accID", + FunctionName: "myFunc", + FunctionVersion: "1.0", + Handler: "myHandler", + } +} + +func makeExternalAgent(registrationService core.RegistrationService) *core.ExternalAgent { + agent, err := registrationService.CreateExternalAgent("agent") + if err != nil { + slog.Error("failed to create external agent", "error", err) + panic(err) + } + + return agent +} + +func configureRendererForEvent(flowTest *testdata.FlowTest) { + flowTest.RenderingService.SetRenderer( + &rendering.ShutdownRenderer{ + AgentEvent: model.AgentShutdownEvent{ + AgentEvent: &model.AgentEvent{ + EventType: "SHUTDOWN", + DeadlineMs: int64(10000), + }, + ShutdownReason: "spindown", + }, + }) +} + +func assertResponseEventType(t *testing.T, responseRecorder *httptest.ResponseRecorder) { + var response model.AgentShutdownEvent + + respBody, _ := io.ReadAll(responseRecorder.Body) + err := json.Unmarshal(respBody, &response) + assert.NoError(t, err) + + assert.Equal(t, "SHUTDOWN", response.EventType) +} + +func containsInvalidEvent(external bool, events []core.Event) bool { + for _, e := range events { + if external { + if err := core.ValidateExternalAgentEvent(e); err != nil { + return true + } + } else if len(events) > 0 { + return true + } + } + + return false +} diff --git a/internal/lambda-managed-instances/rapi/handler/agentexiterror.go b/internal/lambda-managed-instances/rapi/handler/agentexiterror.go new file mode 100644 index 0000000..18415f4 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/agentexiterror.go @@ -0,0 +1,72 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "net/http" + + "github.com/google/uuid" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type agentExitErrorHandler struct { + registrationService core.RegistrationService +} + +func (h *agentExitErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + agentID, ok := request.Context().Value(model.AgentIDCtxKey).(uuid.UUID) + if !ok { + rendering.RenderInternalServerError(writer, request) + return + } + + ctx := logging.WithFields(request.Context(), "agentID", agentID.String()) + + var rawErrorType string + if rawErrorType = request.Header.Get(LambdaAgentFunctionErrorType); rawErrorType == "" { + logging.Warn(ctx, "Extension exit error missing header", "header", LambdaAgentFunctionErrorType) + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentMissingHeader, "%s not found", LambdaAgentFunctionErrorType) + return + } + + errorType := rapidmodel.GetValidExtensionErrorType(rawErrorType, rapidmodel.ErrorAgentExit) + logging.Warn(ctx, "Received Extension exit error request", "errorType", errorType) + + if externalAgent, found := h.registrationService.FindExternalAgentByID(agentID); found { + ctx = logging.WithFields(ctx, "extension", externalAgent.Name()) + if err := externalAgent.ExitError(errorType); err != nil { + logging.Warn(ctx, "Extension exit error transition failed", "err", err, "currentState", externalAgent.GetState().Name()) + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, + externalAgent.GetState().Name(), core.AgentExitedStateName, agentID.String(), err) + return + } + } else if internalAgent, found := h.registrationService.FindInternalAgentByID(agentID); found { + ctx = logging.WithFields(ctx, "extension", internalAgent.Name()) + if err := internalAgent.ExitError(errorType); err != nil { + logging.Warn(ctx, "Extension exit error transition failed", "err", err, "currentState", internalAgent.GetState().Name()) + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, + internalAgent.GetState().Name(), core.AgentExitedStateName, agentID.String(), err) + return + } + } else { + logging.Warn(ctx, "Unknown extension exit error request") + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown "+model.LambdaAgentIdentifier) + return + } + + appctx.StoreFirstFatalError(appctx.FromRequest(request), rapidmodel.WrapErrorIntoCustomerFatalError(nil, errorType)) + rendering.RenderAccepted(writer, request) +} + +func NewAgentExitErrorHandler(registrationService core.RegistrationService) http.Handler { + return &agentExitErrorHandler{ + registrationService: registrationService, + } +} diff --git a/internal/lambda-managed-instances/rapi/handler/agentiniterror.go b/internal/lambda-managed-instances/rapi/handler/agentiniterror.go new file mode 100644 index 0000000..3b51d93 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/agentiniterror.go @@ -0,0 +1,69 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "net/http" + + "github.com/google/uuid" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type agentInitErrorHandler struct { + registrationService core.RegistrationService +} + +func (h *agentInitErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + agentID, ok := request.Context().Value(model.AgentIDCtxKey).(uuid.UUID) + if !ok { + rendering.RenderInternalServerError(writer, request) + return + } + ctx := logging.WithFields(request.Context(), "agentID", agentID.String()) + + var rawErrorType string + if rawErrorType = request.Header.Get(LambdaAgentFunctionErrorType); rawErrorType == "" { + logging.Warn(ctx, "Invalid /extension/init/error: missing header", "header", LambdaAgentFunctionErrorType) + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentMissingHeader, "%s not found", LambdaAgentFunctionErrorType) + return + } + + errorType := rapidmodel.GetValidExtensionErrorType(rawErrorType, rapidmodel.ErrorAgentInit) + logging.Warn(ctx, "Received extension Init error", "errorType", errorType) + + if externalAgent, found := h.registrationService.FindExternalAgentByID(agentID); found { + if err := externalAgent.InitError(errorType); err != nil { + logging.Warn(ctx, "InitError() failed for external agent", "agent", externalAgent.String(), "err", err, "state", externalAgent.GetState().Name()) + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, + externalAgent.GetState().Name(), core.AgentInitErrorStateName, agentID.String(), err) + return + } + } else if internalAgent, found := h.registrationService.FindInternalAgentByID(agentID); found { + if err := internalAgent.InitError(errorType); err != nil { + logging.Warn(ctx, "InitError() failed for internal agent", "agent", internalAgent.String(), "err", err, "state", internalAgent.GetState().Name()) + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, + internalAgent.GetState().Name(), core.AgentInitErrorStateName, agentID.String(), err) + return + } + } else { + logging.Warn(ctx, "Unknown agent tried to call /extension/init/error") + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown "+model.LambdaAgentIdentifier) + return + } + + appctx.StoreFirstFatalError(appctx.FromRequest(request), rapidmodel.WrapErrorIntoCustomerFatalError(nil, errorType)) + rendering.RenderAccepted(writer, request) +} + +func NewAgentInitErrorHandler(registrationService core.RegistrationService) http.Handler { + return &agentInitErrorHandler{ + registrationService: registrationService, + } +} diff --git a/internal/lambda-managed-instances/rapi/handler/agentiniterror_test.go b/internal/lambda-managed-instances/rapi/handler/agentiniterror_test.go new file mode 100644 index 0000000..755df1b --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/agentiniterror_test.go @@ -0,0 +1,120 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func newRequest(appCtx appctx.ApplicationContext, agentID uuid.UUID) *http.Request { + request := httptest.NewRequest("POST", "/", nil) + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, agentID)) + request = appctx.RequestWithAppCtx(request, appCtx) + request.Header.Set(LambdaAgentFunctionErrorType, "Extension.TestError") + return request +} + +func TestAgentInitErrorInternalError(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + handler := NewAgentInitErrorHandler(registrationService) + request := httptest.NewRequest("POST", "/", nil) + + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) +} + +func TestAgentInitErrorMissingErrorHeader(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + + appCtx := appctx.NewApplicationContext() + agent, err := registrationService.CreateExternalAgent("dummyName") + agent.SetState(agent.RegisteredState) + assert.NoError(t, err) + handler := NewAgentInitErrorHandler(registrationService) + responseRecorder := httptest.NewRecorder() + + req := newRequest(appCtx, uuid.New()) + req.Header.Del(LambdaAgentFunctionErrorType) + handler.ServeHTTP(responseRecorder, req) + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + + var errorResponse model.ErrorResponse + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &errorResponse)) + assert.Equal(t, errAgentMissingHeader, errorResponse.ErrorType) +} + +func TestAgentInitErrorUnknownAgent(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + handler := NewAgentInitErrorHandler(registrationService) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, newRequest(appctx.NewApplicationContext(), uuid.New())) + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + + var errorResponse model.ErrorResponse + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &errorResponse)) + assert.Equal(t, errAgentIdentifierUnknown, errorResponse.ErrorType) +} + +func TestAgentInitErrorAgentInvalidState(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + + agent, err := registrationService.CreateExternalAgent("dummyName") + assert.NoError(t, err) + handler := NewAgentInitErrorHandler(registrationService) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, newRequest(appctx.NewApplicationContext(), agent.ID())) + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + + var errorResponse model.ErrorResponse + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &errorResponse)) + assert.Equal(t, errAgentInvalidState, errorResponse.ErrorType) +} + +func TestAgentInitErrorRequestAccepted(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + appCtx := appctx.NewApplicationContext() + agent, err := registrationService.CreateExternalAgent("dummyName") + agent.SetState(agent.RegisteredState) + assert.NoError(t, err) + handler := NewAgentInitErrorHandler(registrationService) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, newRequest(appCtx, agent.ID())) + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + + var response model.StatusResponse + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &response)) + assert.Equal(t, "OK", response.Status) + + v, found := appctx.LoadFirstFatalError(appCtx) + assert.True(t, found) + assert.Equal(t, rapidmodel.ErrorType("Extension.TestError"), v.ErrorType()) +} diff --git a/internal/lambda-managed-instances/rapi/handler/agentnext.go b/internal/lambda-managed-instances/rapi/handler/agentnext.go new file mode 100644 index 0000000..78d947e --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/agentnext.go @@ -0,0 +1,66 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "net/http" + + "github.com/google/uuid" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" +) + +type agentNextHandler struct { + registrationService core.RegistrationService + renderingService *rendering.EventRenderingService +} + +func (h *agentNextHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + agentID, ok := request.Context().Value(model.AgentIDCtxKey).(uuid.UUID) + if !ok { + rendering.RenderInternalServerError(writer, request) + return + } + + ctx := logging.WithFields(request.Context(), "agentID", agentID.String()) + logging.Debug(ctx, "Received Extension /next") + + if externalAgent, found := h.registrationService.FindExternalAgentByID(agentID); found { + ctx = logging.WithFields(ctx, "agent", externalAgent.Name()) + if err := externalAgent.Ready(); err != nil { + logging.Warn(ctx, "Extension ready failed", "err", err, "state", externalAgent.GetState().Name()) + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, + externalAgent.GetState().Name(), core.AgentReadyStateName, agentID.String(), err) + return + } + } else if internalAgent, found := h.registrationService.FindInternalAgentByID(agentID); found { + ctx = logging.WithFields(ctx, "agent", internalAgent.Name()) + if err := internalAgent.Ready(); err != nil { + logging.Warn(ctx, "Extension ready failed", "err", err, "state", internalAgent.GetState().Name()) + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, + internalAgent.GetState().Name(), core.AgentReadyStateName, agentID.String(), err) + return + } + } else { + logging.Warn(ctx, "Unknown extension /next request") + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension %s", agentID.String()) + return + } + + if err := h.renderingService.RenderAgentEvent(writer, request); err != nil { + logging.Error(ctx, "Render agent event failed", "err", err) + rendering.RenderInternalServerError(writer, request) + return + } +} + +func NewAgentNextHandler(registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { + return &agentNextHandler{ + registrationService: registrationService, + renderingService: renderingService, + } +} diff --git a/internal/lambda-managed-instances/rapi/handler/agentnext_test.go b/internal/lambda-managed-instances/rapi/handler/agentnext_test.go new file mode 100644 index 0000000..1506b07 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/agentnext_test.go @@ -0,0 +1,289 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" +) + +func TestRenderAgentInternalError(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + handler := NewAgentNextHandler(registrationService, rendering.NewRenderingService()) + request := httptest.NewRequest("GET", "/", nil) + + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) +} + +func TestRenderAgentInvokeUnknownAgent(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + request := httptest.NewRequest("GET", "/", nil) + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, uuid.New())) + responseRecorder := httptest.NewRecorder() + + handler := NewAgentNextHandler(registrationService, rendering.NewRenderingService()) + handler.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + + var errorResponse model.ErrorResponse + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &errorResponse)) + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + assert.Equal(t, errAgentIdentifierUnknown, errorResponse.ErrorType) +} + +func TestRenderAgentInvokeInvalidAgentState(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + + agent, err := registrationService.CreateExternalAgent("dummyName") + assert.NoError(t, err) + handler := NewAgentNextHandler(registrationService, rendering.NewRenderingService()) + request := httptest.NewRequest("GET", "/", nil) + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, agent.ID())) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + + var errorResponse model.ErrorResponse + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &errorResponse)) + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + assert.Equal(t, errAgentInvalidState, errorResponse.ErrorType) +} + +func TestRenderAgentInvokeNextHappy(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + agent, err := registrationService.CreateExternalAgent("dummyName") + assert.NoError(t, err) + + agent.SetState(agent.RegisteredState) + agent.Release() + + deadlineNs := time.Now().Add(100 * time.Millisecond) + invokeID, functionArn := "ID", "InvokedFunctionArn" + traceID := "Root=RootID;Parent=LambdaFrontend;Sampled=1" + invoke := &interop.Invoke{ + TraceID: traceID, + ID: invokeID, + InvokedFunctionArn: functionArn, + CognitoIdentityID: "CognitoIdentityId1", + CognitoIdentityPoolID: "CognitoIdentityPoolId1", + ClientContext: "ClientContext", + Deadline: deadlineNs, + ContentType: "image/png", + Payload: strings.NewReader("Payload"), + } + + renderingService := rendering.NewRenderingService() + var buf bytes.Buffer + renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, &buf, func(context.Context) string { return "" })) + + handler := NewAgentNextHandler(registrationService, renderingService) + request := httptest.NewRequest("GET", "/", nil) + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, agent.ID())) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusOK, responseRecorder.Code) + var response model.AgentInvokeEvent + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &response)) + + assert.Equal(t, agent.RunningState, agent.GetState()) + assert.Equal(t, "INVOKE", response.EventType) + assert.InDelta(t, time.Now().Add(100*time.Millisecond).UnixNano()/int64(time.Millisecond), response.DeadlineMs, 5) + assert.Equal(t, invokeID, response.RequestID) + assert.Equal(t, functionArn, response.InvokedFunctionArn) + assert.Equal(t, model.XRayTracingType, response.Tracing.Type) + assert.Equal(t, traceID, response.Tracing.Value) +} + +func TestRenderAgentInternalInvokeNextHappy(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + agent, err := registrationService.CreateInternalAgent("dummyName") + assert.NoError(t, err) + + agent.SetState(agent.RegisteredState) + agent.Release() + + deadlineNs := time.Now().Add(100 * time.Millisecond) + invokeID, functionArn := "ID", "InvokedFunctionArn" + traceID := "Root=RootID;Parent=LambdaFrontend;Sampled=1" + invoke := &interop.Invoke{ + TraceID: traceID, + ID: invokeID, + InvokedFunctionArn: functionArn, + CognitoIdentityID: "CognitoIdentityId1", + CognitoIdentityPoolID: "CognitoIdentityPoolId1", + ClientContext: "ClientContext", + Deadline: deadlineNs, + ContentType: "image/png", + Payload: strings.NewReader("Payload"), + } + + renderingService := rendering.NewRenderingService() + var buf bytes.Buffer + renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, &buf, func(context.Context) string { return "" })) + + handler := NewAgentNextHandler(registrationService, renderingService) + request := httptest.NewRequest("GET", "/", nil) + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, agent.ID())) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusOK, responseRecorder.Code) + var response model.AgentInvokeEvent + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &response)) + + assert.Equal(t, agent.RunningState, agent.GetState()) + assert.Equal(t, "INVOKE", response.EventType) + assert.InDelta(t, time.Now().Add(100*time.Millisecond).UnixNano()/int64(time.Millisecond), response.DeadlineMs, 5) + assert.Equal(t, invokeID, response.RequestID) + assert.Equal(t, functionArn, response.InvokedFunctionArn) + assert.Equal(t, model.XRayTracingType, response.Tracing.Type) + assert.Equal(t, traceID, response.Tracing.Value) +} + +func TestRenderAgentInternalShutdownEvent(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + agent, err := registrationService.CreateInternalAgent("dummyName") + assert.NoError(t, err) + + agent.SetState(agent.RegisteredState) + agent.Release() + + renderingService := rendering.NewRenderingService() + deadlineMs := time.Now().UnixNano() / (1000 * 1000) + renderingService.SetRenderer( + &rendering.ShutdownRenderer{ + AgentEvent: model.AgentShutdownEvent{ + AgentEvent: &model.AgentEvent{ + EventType: "SHUTDOWN", + DeadlineMs: int64(deadlineMs), + }, + ShutdownReason: model.Spindown, + }, + }) + + handler := NewAgentNextHandler(registrationService, renderingService) + request := httptest.NewRequest("GET", "/", nil) + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, agent.ID())) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusOK, responseRecorder.Code) + var response model.AgentShutdownEvent + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &response)) + + assert.Equal(t, agent.RunningState, agent.GetState()) + assert.Equal(t, "SHUTDOWN", response.EventType) + assert.Equal(t, int64(deadlineMs), response.DeadlineMs) + assert.Equal(t, model.Spindown, response.ShutdownReason) +} + +func TestRenderAgentExternalShutdownEvent(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + agent, err := registrationService.CreateExternalAgent("dummyName") + assert.NoError(t, err) + + agent.SetState(agent.RegisteredState) + agent.Release() + + renderingService := rendering.NewRenderingService() + deadlineMs := time.Now().UnixNano() / (1000 * 1000) + renderingService.SetRenderer( + &rendering.ShutdownRenderer{ + AgentEvent: model.AgentShutdownEvent{ + AgentEvent: &model.AgentEvent{ + EventType: "SHUTDOWN", + DeadlineMs: int64(deadlineMs), + }, + ShutdownReason: model.Spindown, + }, + }) + + handler := NewAgentNextHandler(registrationService, renderingService) + request := httptest.NewRequest("GET", "/", nil) + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, agent.ID())) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusOK, responseRecorder.Code) + var response model.AgentShutdownEvent + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &response)) + + assert.Equal(t, agent.RunningState, agent.GetState()) + assert.Equal(t, "SHUTDOWN", response.EventType) + assert.Equal(t, int64(deadlineMs), response.DeadlineMs) + assert.Equal(t, model.Spindown, response.ShutdownReason) +} + +func TestRenderAgentInvokeNextHappyEmptyTraceID(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + agent, err := registrationService.CreateExternalAgent("dummyName") + assert.NoError(t, err) + + agent.SetState(agent.RegisteredState) + agent.Release() + + deadlineNs := time.Now().Add(100 * time.Millisecond) + invokeID, functionArn := "ID", "InvokedFunctionArn" + traceID := "" + invoke := &interop.Invoke{ + TraceID: traceID, + ID: invokeID, + InvokedFunctionArn: functionArn, + Deadline: deadlineNs, + ContentType: "image/png", + Payload: strings.NewReader("Payload"), + } + + renderingService := rendering.NewRenderingService() + var buf bytes.Buffer + renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, &buf, func(context.Context) string { return "" })) + + handler := NewAgentNextHandler(registrationService, renderingService) + request := httptest.NewRequest("GET", "/", nil) + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, agent.ID())) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusOK, responseRecorder.Code) + var response model.AgentInvokeEvent + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &response)) + + assert.Nil(t, response.Tracing) +} diff --git a/internal/lambda-managed-instances/rapi/handler/agentregister.go b/internal/lambda-managed-instances/rapi/handler/agentregister.go new file mode 100644 index 0000000..52edf56 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/agentregister.go @@ -0,0 +1,224 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "context" + "encoding/json" + "errors" + "io" + "log/slog" + "net/http" + "strings" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" +) + +type agentRegisterHandler struct { + registrationService core.RegistrationService +} + +type RegisterRequest struct { + Events []core.Event `json:"events"` +} + +const featuresHeader = "Lambda-Extension-Accept-Feature" + +type registrationFeature int + +const ( + accountFeature registrationFeature = iota + 1 +) + +var allowedFeatures = map[string]registrationFeature{ + "accountId": accountFeature, +} + +type responseModifier func(*model.ExtensionRegisterResponse) + +func parseRegister(request *http.Request) (*RegisterRequest, error) { + body, err := io.ReadAll(request.Body) + if err != nil { + return nil, err + } + + req := struct { + RegisterRequest + ConfigurationKeys []string `json:"configurationKeys"` + }{} + + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + if len(req.ConfigurationKeys) != 0 { + return nil, errors.New("configurationKeys are deprecated; use environment variables instead") + } + + return &req.RegisterRequest, nil +} + +func (h *agentRegisterHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + agentName := request.Header.Get(LambdaAgentName) + ctx := logging.WithFields(request.Context(), "agentName", agentName) + if agentName == "" { + logging.Warn(ctx, "Empty extension name") + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentNameInvalid, "Empty extension name") + return + } + + var responseModifiers []responseModifier + for _, f := range parseRegistrationFeatures(request) { + if f == accountFeature { + responseModifiers = append(responseModifiers, h.respondWithAccountID()) + } + } + + registerRequest, err := parseRegister(request) + if err != nil { + logging.Warn(ctx, "Invalid Register request format", "err", err) + rendering.RenderForbiddenWithTypeMsg(writer, request, errInvalidRequestFormat, "%s", err.Error()) + return + } + + agent, found := h.registrationService.FindExternalAgentByName(agentName) + if found { + h.registerExternalAgent(ctx, agent, registerRequest, writer, request, responseModifiers...) + } else { + h.registerInternalAgent(ctx, agentName, registerRequest, writer, request, responseModifiers...) + } +} + +func (h *agentRegisterHandler) respondWithAccountID() responseModifier { + return func(resp *model.ExtensionRegisterResponse) { + resp.AccountID = h.registrationService.GetFunctionMetadata().AccountID + } +} + +func parseRegistrationFeatures(request *http.Request) []registrationFeature { + rawFeatures := strings.Split(request.Header.Get(featuresHeader), ",") + + var features []registrationFeature + for _, feature := range rawFeatures { + feature = strings.TrimSpace(feature) + if v, found := allowedFeatures[feature]; found { + features = append(features, v) + } + } + + return features +} + +func (h *agentRegisterHandler) renderResponse( + agentID string, + writer http.ResponseWriter, + request *http.Request, + respModifiers ...responseModifier, +) { + writer.Header().Set(model.LambdaAgentIdentifier, agentID) + + metadata := h.registrationService.GetFunctionMetadata() + resp := &model.ExtensionRegisterResponse{ + FunctionVersion: metadata.FunctionVersion, + FunctionName: metadata.FunctionName, + Handler: metadata.Handler, + } + + for _, mod := range respModifiers { + mod(resp) + } + + if err := rendering.RenderJSON(http.StatusOK, writer, request, resp); err != nil { + slog.Warn("Error while rendering response", "err", err) + http.Error(writer, err.Error(), http.StatusInternalServerError) + } +} + +func (h *agentRegisterHandler) registerExternalAgent( + ctx context.Context, + agent *core.ExternalAgent, + registerRequest *RegisterRequest, + writer http.ResponseWriter, + request *http.Request, + respModifiers ...responseModifier, +) { + ctx = logging.WithFields(ctx, "agent", agent.String()) + for _, e := range registerRequest.Events { + if err := core.ValidateExternalAgentEvent(e); err != nil { + logging.Warn(ctx, "Failed to register agent event", "event", e, "err", err) + rendering.RenderForbiddenWithTypeMsg(writer, request, errInvalidEventType, "%s: %s", e, err) + return + } + } + + if err := agent.Register(registerRequest.Events); err != nil { + logging.Warn(ctx, "Failed to register agent", "err", err) + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, + agent.GetState().Name(), core.AgentRegisteredStateName, agent.Name(), err) + return + } + + h.renderResponse(agent.ID().String(), writer, request, respModifiers...) + logging.Debug(ctx, "External agent registered", "events", registerRequest.Events) +} + +func (h *agentRegisterHandler) registerInternalAgent( + ctx context.Context, + agentName string, + registerRequest *RegisterRequest, + writer http.ResponseWriter, + request *http.Request, + respModifiers ...responseModifier, +) { + if len(registerRequest.Events) != 0 { + logging.Warn(ctx, "No events allowed for internal extensions") + rendering.RenderForbiddenWithTypeMsg(writer, request, errInvalidEventType, "No events allowed for internal extensions") + return + } + + agent, err := h.registrationService.CreateInternalAgent(agentName) + if err != nil { + logging.Warn(ctx, "Failed to create internal agent", "err", err) + + switch err { + case core.ErrRegistrationServiceOff: + logging.Warn(ctx, "Extension registration closed already") + rendering.RenderForbiddenWithTypeMsg(writer, request, + errAgentRegistrationClosed, "Extension registration closed already") + case core.ErrAgentNameCollision: + logging.Warn(ctx, "Extension with this name already registered") + rendering.RenderForbiddenWithTypeMsg(writer, request, + errAgentInvalidState, "Extension with this name already registered") + case core.ErrTooManyExtensions: + logging.Warn(ctx, "Extension limit reached", "limit", core.MaxAgentsAllowed) + rendering.RenderForbiddenWithTypeMsg(writer, request, + errTooManyExtensions, "Extension limit (%d) reached", core.MaxAgentsAllowed) + default: + rendering.RenderInternalServerError(writer, request) + } + + return + } + + ctx = logging.WithFields(ctx, "agent", agent.String()) + + if err := agent.Register(registerRequest.Events); err != nil { + logging.Warn(ctx, "Failed to register agent", "err", err) + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, + agent.GetState().Name(), core.AgentRegisteredStateName, agent.Name(), err) + return + } + + h.renderResponse(agent.ID().String(), writer, request, respModifiers...) + logging.Info(ctx, "Internal agent registered") +} + +func NewAgentRegisterHandler(registrationService core.RegistrationService) http.Handler { + return &agentRegisterHandler{ + registrationService: registrationService, + } +} diff --git a/internal/lambda-managed-instances/rapi/handler/agentregister_test.go b/internal/lambda-managed-instances/rapi/handler/agentregister_test.go new file mode 100644 index 0000000..62b35f4 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/agentregister_test.go @@ -0,0 +1,314 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func registerRequestReader(req RegisterRequest) io.Reader { + body, err := json.Marshal(req) + if err != nil { + panic(err) + } + return bytes.NewReader(body) +} + +func TestRenderAgentRegisterInvalidAgentName(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + + handler := NewAgentRegisterHandler(registrationService) + request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(RegisterRequest{})) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + require.Equal(t, http.StatusForbidden, responseRecorder.Code) + + var errorResponse model.ErrorResponse + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &errorResponse)) + require.Equal(t, http.StatusForbidden, responseRecorder.Code) + require.Equal(t, errAgentNameInvalid, errorResponse.ErrorType) +} + +func TestRenderAgentRegisterRegistrationClosed(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + registrationService.TurnOff() + + handler := NewAgentRegisterHandler(registrationService) + request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(RegisterRequest{})) + request.Header.Add(LambdaAgentName, "dummyName") + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + require.Equal(t, http.StatusForbidden, responseRecorder.Code) + + var errorResponse model.ErrorResponse + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &errorResponse)) + require.Equal(t, http.StatusForbidden, responseRecorder.Code) + require.Equal(t, errAgentRegistrationClosed, errorResponse.ErrorType) +} + +func TestRenderAgentRegisterInvalidAgentState(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + + agent, err := registrationService.CreateExternalAgent("dummyName") + require.NoError(t, err) + agent.SetState(agent.RegisteredState) + + handler := NewAgentRegisterHandler(registrationService) + request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(RegisterRequest{})) + request.Header.Add(LambdaAgentName, "dummyName") + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + require.Equal(t, http.StatusForbidden, responseRecorder.Code) + + var errorResponse model.ErrorResponse + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &errorResponse)) + require.Equal(t, http.StatusForbidden, responseRecorder.Code) + require.Equal(t, errAgentInvalidState, errorResponse.ErrorType) +} + +func registerAgent(t *testing.T, agentName string, events []core.Event, registerHandler http.Handler) { + request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(RegisterRequest{Events: events})) + request.Header.Add(LambdaAgentName, agentName) + responseRecorder := httptest.NewRecorder() + registerHandler.ServeHTTP(responseRecorder, request) + require.Equal(t, http.StatusOK, responseRecorder.Code) +} + +func TestGetSubscribedExternalAgents(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + _, err := registrationService.CreateExternalAgent("externalInvokeAgent") + require.NoError(t, err) + _, err = registrationService.CreateExternalAgent("externalShutdownAgent") + require.NoError(t, err) + + handler := NewAgentRegisterHandler(registrationService) + + registerAgent(t, "externalShutdownAgent", []core.Event{core.ShutdownEvent}, handler) + registerAgent(t, "internalInvokeAgent", []core.Event{}, handler) + + subscribers := registrationService.GetSubscribedExternalAgents(core.ShutdownEvent) + require.Equal(t, 1, len(subscribers)) + require.Equal(t, "externalShutdownAgent", subscribers[0].Name()) +} + +func TestExternalAgentInvalidEventType(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + agentName := "ABC" + _, err := registrationService.CreateExternalAgent(agentName) + require.NoError(t, err) + + for i := 0; i < 2; i++ { + request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(RegisterRequest{Events: []core.Event{"abcdef"}})) + request.Header.Add(LambdaAgentName, agentName) + + responseRecorder := httptest.NewRecorder() + NewAgentRegisterHandler(registrationService).ServeHTTP(responseRecorder, request) + require.Equal(t, http.StatusForbidden, responseRecorder.Code) + + response := model.ErrorResponse{} + require.NoError(t, json.Unmarshal(responseRecorder.Body.Bytes(), &response)) + require.Equal(t, errInvalidEventType, response.ErrorType) + require.Contains(t, response.ErrorMessage, "abcdef") + + _, found := registrationService.FindExternalAgentByName(agentName) + require.True(t, found) + + shutdownSubscribers := registrationService.GetSubscribedExternalAgents(core.ShutdownEvent) + require.Equal(t, 0, len(shutdownSubscribers)) + + require.Equal(t, 1, registrationService.CountAgents()) + } +} + +type ExtensionRegisterResponseWithConfig struct { + model.ExtensionRegisterResponse + Configuration map[string]string `json:"configuration"` +} + +func TestRenderAgentResponse(t *testing.T) { + defaultFunctionMetadata := rapidmodel.FunctionMetadata{ + FunctionVersion: "$LATEST", + FunctionName: "my-func", + Handler: "lambda_handler", + } + + happyPathTests := map[string]struct { + agentName string + external bool + registrationRequest RegisterRequest + featuresHeader string + functionMetadata rapidmodel.FunctionMetadata + expectedResponse string + }{ + "no-config-internal": { + agentName: "internal", + external: false, + functionMetadata: defaultFunctionMetadata, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "my-func", + "functionVersion": "$LATEST", + "handler": "lambda_handler" + }`, + }, + "no-config-external": { + agentName: "external", + external: true, + functionMetadata: defaultFunctionMetadata, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "my-func", + "functionVersion": "$LATEST", + "handler": "lambda_handler" + }`, + }, + "function-md-override": { + agentName: "external", + external: true, + functionMetadata: rapidmodel.FunctionMetadata{FunctionName: "function-name", FunctionVersion: "1", Handler: "myHandler"}, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "function-name", + "functionVersion": "1", + "handler": "myHandler" + }`, + }, + "internal with account id feature": { + agentName: "internal", + external: false, + functionMetadata: rapidmodel.FunctionMetadata{ + FunctionName: "function-name", + FunctionVersion: "1", + Handler: "myHandler", + AccountID: "0123", + }, + featuresHeader: "accountId", + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "function-name", + "functionVersion": "1", + "handler": "myHandler", + "accountId": "0123" + }`, + }, + "external with account id feature": { + agentName: "external", + external: true, + functionMetadata: rapidmodel.FunctionMetadata{ + FunctionName: "function-name", + FunctionVersion: "1", + Handler: "myHandler", + AccountID: "0123", + }, + featuresHeader: "accountId", + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "function-name", + "functionVersion": "1", + "handler": "myHandler", + "accountId": "0123" + }`, + }, + "with non-existing accept feature": { + agentName: "external", + external: true, + featuresHeader: "some_non_existing_feature,", + functionMetadata: defaultFunctionMetadata, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "my-func", + "functionVersion": "$LATEST", + "handler": "lambda_handler" + }`, + }, + "account id feature and some non-existing feature": { + agentName: "external", + external: true, + featuresHeader: "some_non_existing_feature,accountId,", + functionMetadata: rapidmodel.FunctionMetadata{ + FunctionName: "function-name", + FunctionVersion: "1", + Handler: "myHandler", + AccountID: "0123", + }, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "function-name", + "functionVersion": "1", + "handler": "myHandler", + "accountId": "0123" + }`, + }, + "with empty account id data": { + agentName: "external", + external: true, + featuresHeader: "accountId", + functionMetadata: defaultFunctionMetadata, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "my-func", + "functionVersion": "$LATEST", + "handler": "lambda_handler" + }`, + }, + } + + for name, tt := range happyPathTests { + t.Run(name, func(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + _, _ = registrationService.CreateExternalAgent("external") + registrationService.SetFunctionMetadata(tt.functionMetadata) + + handler := NewAgentRegisterHandler(registrationService) + + request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(tt.registrationRequest)) + request.Header.Add(LambdaAgentName, tt.agentName) + if tt.featuresHeader != "" { + request.Header.Add(featuresHeader, tt.featuresHeader) + } + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusOK, responseRecorder.Code) + + respBody, err := io.ReadAll(responseRecorder.Body) + require.NoError(t, err) + assert.JSONEq(t, tt.expectedResponse, string(respBody)) + + if tt.external { + agent, found := registrationService.FindExternalAgentByName(tt.agentName) + assert.True(t, found) + assert.Equal(t, agent.RegisteredState, agent.GetState()) + } else { + agent, found := registrationService.FindInternalAgentByName(tt.agentName) + assert.True(t, found) + assert.Equal(t, agent.RegisteredState, agent.GetState()) + } + }) + } +} diff --git a/internal/lambda-managed-instances/rapi/handler/constants.go b/internal/lambda-managed-instances/rapi/handler/constants.go new file mode 100644 index 0000000..a7d76c6 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/constants.go @@ -0,0 +1,22 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +const ( + LambdaAgentFunctionErrorType string = "Lambda-Extension-Function-Error-Type" + + LambdaAgentName string = "Lambda-Extension-Name" + + errAgentNameInvalid string = "Extension.InvalidExtensionName" + errAgentRegistrationClosed string = "Extension.RegistrationClosed" + errAgentIdentifierUnknown string = "Extension.UnknownExtensionIdentifier" + errAgentInvalidState string = "Extension.InvalidExtensionState" + errAgentMissingHeader string = "Extension.MissingHeader" + errTooManyExtensions string = "Extension.TooManyExtensions" + errInvalidEventType string = "Extension.InvalidEventType" + errInvalidRequestFormat string = "InvalidRequestFormat" + + StateTransitionFailedForExtensionMessageFormat string = "State transition from %s to %s failed for extension %s. Error: %s" + StateTransitionFailedForRuntimeMessageFormat string = "State transition from %s to %s failed for runtime. Error: %s" +) diff --git a/internal/lambda-managed-instances/rapi/handler/initerror.go b/internal/lambda-managed-instances/rapi/handler/initerror.go new file mode 100644 index 0000000..d015576 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/initerror.go @@ -0,0 +1,54 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "net/http" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type initErrorHandler struct { + registrationService core.RegistrationService +} + +func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + appCtx := appctx.FromRequest(request) + ctx := request.Context() + + errorType := model.GetValidRuntimeOrFunctionErrorType(request.Header.Get("Lambda-Runtime-Function-Error-Type")) + ctx = logging.WithFields(ctx, "errType", errorType) + + logging.Warn(ctx, "Received Runtime Init Error") + + runtime := h.registrationService.GetRuntime() + + if err := runtime.InitError(); err != nil { + logging.Warn(ctx, "Runtime init error", "err", err) + rendering.RenderForbiddenWithTypeMsg( + writer, + request, + rendering.ErrorTypeInvalidStateTransition, + StateTransitionFailedForRuntimeMessageFormat, + runtime.GetState().Name(), + core.RuntimeInitErrorStateName, + err, + ) + return + } + + appctx.StoreFirstFatalError(appCtx, model.WrapErrorIntoCustomerFatalError(nil, errorType)) + + appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{}) + rendering.RenderAccepted(writer, request) +} + +func NewInitErrorHandler(registrationService core.RegistrationService) http.Handler { + return &initErrorHandler{registrationService: registrationService} +} diff --git a/internal/lambda-managed-instances/rapi/handler/initerror_test.go b/internal/lambda-managed-instances/rapi/handler/initerror_test.go new file mode 100644 index 0000000..aebc102 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/initerror_test.go @@ -0,0 +1,48 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testdata" +) + +func TestInitErrorHandler(t *testing.T) { + t.Run("GA", func(t *testing.T) { runTestInitErrorHandler(t) }) +} + +func runTestInitErrorHandler(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + + handler := NewInitErrorHandler(flowTest.RegistrationService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + errorBody := []byte("My byte array is yours") + errorType := "ErrorType" + errorContentType := "application/MyBinaryType" + + request := appctx.RequestWithAppCtx(httptest.NewRequest("POST", "/", bytes.NewReader(errorBody)), appCtx) + request.Header.Set("Content-Type", errorContentType) + request.Header.Set("Lambda-runtime-functioN-erroR-typE", errorType) + + handler.ServeHTTP(responseRecorder, request) + + require.Equal(t, http.StatusAccepted, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", + responseRecorder.Code, http.StatusAccepted) + require.JSONEq(t, fmt.Sprintf("{\"status\":\"%s\"}\n", "OK"), responseRecorder.Body.String()) + require.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) + + require.Nil(t, flowTest.InteropServer.Response) + +} diff --git a/internal/lambda-managed-instances/rapi/handler/invocationerror.go b/internal/lambda-managed-instances/rapi/handler/invocationerror.go new file mode 100644 index 0000000..1ba0b0b --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/invocationerror.go @@ -0,0 +1,68 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "bytes" + "context" + "net/http" + + "github.com/go-chi/chi" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/invoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type RuntimeErrorHandler interface { + RuntimeError(ctx context.Context, runtimeErrReq invoke.RuntimeErrorRequest) model.AppError +} + +type invocationErrorHandler struct { + runtimeErrHandler RuntimeErrorHandler +} + +func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + invokeID := chi.URLParam(request, "awsrequestid") + ctx := logging.WithInvokeID(request.Context(), invokeID) + + buf := new(bytes.Buffer) + if _, err := buf.ReadFrom(request.Body); err != nil { + logging.Warn(ctx, "Failed to parse error body", "err", err) + rendering.RenderRequestEntityTooLarge(writer, request) + return + } + + resp := invoke.NewRuntimeError(ctx, request, invokeID, buf.String()) + err := h.runtimeErrHandler.RuntimeError(ctx, &resp) + logging.Warn(ctx, "Received Runtime error", "err", err) + if err == nil { + rendering.RenderAccepted(writer, request) + return + } + + logging.Warn(ctx, "Runtime response error", "err", err) + + switch err.ErrorType() { + case model.ErrorRuntimeInvalidInvokeId, model.ErrorRuntimeInvokeErrorInProgress: + + rendering.RenderInvalidRequestID(writer, request) + case model.ErrorRuntimeInvokeTimeout: + rendering.RenderInvokeTimeout(writer, request) + case model.ErrorRuntimeInvokeResponseWasSent: + + rendering.RenderInvalidRequestID(writer, request) + default: + + logging.Error(ctx, "Received unexpected runtime error", "err", err) + rendering.RenderInternalServerError(writer, request) + } +} + +func NewInvocationErrorHandler(runtimeErrHandler RuntimeErrorHandler) http.Handler { + return &invocationErrorHandler{ + runtimeErrHandler: runtimeErrHandler, + } +} diff --git a/internal/lambda-managed-instances/rapi/handler/invocationnext.go b/internal/lambda-managed-instances/rapi/handler/invocationnext.go new file mode 100644 index 0000000..f269e9d --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/invocationnext.go @@ -0,0 +1,58 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "context" + "net/http" + "sync" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type RuntimeNextHandler interface { + RuntimeNext(ctx context.Context, runtimeReq http.ResponseWriter) (model.RuntimeNextWaiter, model.AppError) +} + +type invocationNextHandler struct { + registrationService core.RegistrationService + nextHandler RuntimeNextHandler + runtimeReadyOnce sync.Once +} + +func (h *invocationNextHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + ctx := request.Context() + logging.Debug(ctx, "Received Runtime /next") + + waiter, err := h.nextHandler.RuntimeNext(ctx, writer) + if err != nil { + logging.Error(ctx, "Runtime Next Error", "err", err) + rendering.RenderInternalServerError(writer, request) + return + } + + h.runtimeReadyOnce.Do(func() { + if err := h.registrationService.InitFlow().RuntimeReady(); err != nil { + logging.Warn(ctx, "Could not register runtime", "err", err) + rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, + h.registrationService.GetRuntime().GetState().Name(), core.RuntimeReadyStateName, err) + return + } + }) + + if err := waiter.RuntimeNextWait(ctx); err != nil { + logging.Warn(ctx, "Cancelled /next", "err", err) + rendering.RenderInternalServerError(writer, request) + } +} + +func NewInvocationNextHandler(registrationService core.RegistrationService, nextHandler RuntimeNextHandler) http.Handler { + return &invocationNextHandler{ + registrationService: registrationService, + nextHandler: nextHandler, + } +} diff --git a/internal/lambda-managed-instances/rapi/handler/invocationresponse.go b/internal/lambda-managed-instances/rapi/handler/invocationresponse.go new file mode 100644 index 0000000..8fb6ce2 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/invocationresponse.go @@ -0,0 +1,68 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "context" + "net/http" + + "github.com/go-chi/chi" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/invoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type RuntimeResponseHandler interface { + RuntimeResponse(ctx context.Context, runtimeRespReq invoke.RuntimeResponseRequest) model.AppError +} + +type invocationResponseHandler struct { + runtimeRespHandler RuntimeResponseHandler +} + +func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + invokeID := chi.URLParam(request, "awsrequestid") + ctx := logging.WithInvokeID(request.Context(), invokeID) + + logging.Debug(ctx, "Received Runtime Response") + resp := invoke.NewRuntimeResponse(ctx, request, invokeID) + + err := h.runtimeRespHandler.RuntimeResponse(ctx, &resp) + if err == nil { + rendering.RenderAccepted(writer, request) + return + } + + logging.Warn(ctx, "Runtime Response Error", "err", err) + + switch err.ErrorType() { + case model.ErrorRuntimeInvalidInvokeId, model.ErrorRuntimeInvokeResponseInProgress: + + rendering.RenderInvalidRequestID(writer, request) + case model.ErrorRuntimeInvokeTimeout, model.ErrorSandboxTimedout: + rendering.RenderInvokeTimeout(writer, request) + case model.ErrorRuntimeInvalidResponseModeHeader: + rendering.RenderInvalidFunctionResponseMode(writer, request) + case model.ErrorFunctionOversizedResponse: + rendering.RenderRequestEntityTooLarge(writer, request) + case model.ErrorRuntimeTruncatedResponse: + rendering.RenderTruncatedHTTPRequestError(writer, request) + default: + + if trailerError := resp.TrailerError(); trailerError.ErrorType() != "" { + rendering.RenderAccepted(writer, request) + return + } + logging.Error(ctx, "unexpected error in runtime response", "err", err) + rendering.RenderInternalServerError(writer, request) + } +} + +func NewInvocationResponseHandler(runtimeRespHandler RuntimeResponseHandler) http.Handler { + return &invocationResponseHandler{ + runtimeRespHandler: runtimeRespHandler, + } +} diff --git a/internal/lambda-managed-instances/rapi/handler/ping.go b/internal/lambda-managed-instances/rapi/handler/ping.go new file mode 100644 index 0000000..82a52d3 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/ping.go @@ -0,0 +1,23 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "log/slog" + "net/http" +) + +type pingHandler struct { +} + +func (h *pingHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if _, err := writer.Write([]byte("pong")); err != nil { + slog.Warn("Failed to write 'pong' response", "err", err) + panic(err) + } +} + +func NewPingHandler() http.Handler { + return &pingHandler{} +} diff --git a/internal/lambda-managed-instances/rapi/handler/runtimelogs.go b/internal/lambda-managed-instances/rapi/handler/runtimelogs.go new file mode 100644 index 0000000..1ca7ef9 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/runtimelogs.go @@ -0,0 +1,138 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "bytes" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + + "github.com/google/uuid" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" +) + +type runtimeLogsHandler struct { + registrationService core.RegistrationService + telemetrySubscription telemetry.SubscriptionAPI +} + +func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + agentName, err := h.verifyAgentID(writer, request) + if err != nil { + slog.Warn("Agent Verification Error", "err", err) + switch err := err.(type) { + case *ErrAgentIdentifierUnknown: + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension %s", err.agentID.String()) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) + default: + rendering.RenderInternalServerError(writer, request) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) + } + return + } + + delete(request.Header, model.LambdaAgentIdentifier) + + body, err := h.getBody(writer, request) + if err != nil { + slog.Warn("Failed to get request body", "err", err) + rendering.RenderInternalServerError(writer, request) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) + return + } + + respBody, status, headers, err := h.telemetrySubscription.Subscribe(agentName, bytes.NewReader(body), request.Header, request.RemoteAddr) + if err != nil { + slog.Warn("Telemetry API error", "err", err) + switch err { + case telemetry.ErrTelemetryServiceOff: + rendering.RenderForbiddenWithTypeMsg(writer, request, + h.telemetrySubscription.GetServiceClosedErrorType(), "%s", h.telemetrySubscription.GetServiceClosedErrorMessage()) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) + default: + rendering.RenderInternalServerError(writer, request) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) + } + return + } + + if err := rendering.RenderRuntimeLogsResponse(writer, respBody, status, headers); err != nil { + slog.Warn("Failed to render runtime logs response", "err", err) + } + switch status / 100 { + case 2: + if strings.Contains(string(respBody), "OK") { + h.telemetrySubscription.RecordCounterMetric(telemetry.NumSubscribers, 1) + } + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeSuccess, 1) + case 4: + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) + slog.Warn("rendered telemetry api subscription client error response", "body", respBody) + case 5: + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) + slog.Error("rendered telemetry api subscription server error response", "body", respBody) + } +} + +type ErrAgentIdentifierUnknown struct { + agentID uuid.UUID +} + +func NewErrAgentIdentifierUnknown(agentID uuid.UUID) *ErrAgentIdentifierUnknown { + return &ErrAgentIdentifierUnknown{ + agentID: agentID, + } +} + +func (e *ErrAgentIdentifierUnknown) Error() string { + return fmt.Sprintf("Unknown agent %s tried to call /runtime/logs", e.agentID.String()) +} + +func (h *runtimeLogsHandler) verifyAgentID(writer http.ResponseWriter, request *http.Request) (string, error) { + agentID, ok := request.Context().Value(model.AgentIDCtxKey).(uuid.UUID) + if !ok { + return "", errors.New("internal error: agent ID not set in context") + } + + agentName, found := h.getAgentName(agentID) + if !found { + return "", NewErrAgentIdentifierUnknown(agentID) + } + + return agentName, nil +} + +func (h *runtimeLogsHandler) getAgentName(agentID uuid.UUID) (string, bool) { + if agent, found := h.registrationService.FindExternalAgentByID(agentID); found { + return agent.Name(), true + } else if agent, found := h.registrationService.FindInternalAgentByID(agentID); found { + return agent.Name(), true + } else { + return "", false + } +} + +func (h *runtimeLogsHandler) getBody(writer http.ResponseWriter, request *http.Request) ([]byte, error) { + body, err := io.ReadAll(request.Body) + if err != nil { + return nil, fmt.Errorf("failed to read error body: %s", err) + } + + return body, nil +} + +func NewRuntimeTelemetrySubscriptionHandler(registrationService core.RegistrationService, telemetrySubscription telemetry.SubscriptionAPI) http.Handler { + return &runtimeLogsHandler{ + registrationService: registrationService, + telemetrySubscription: telemetrySubscription, + } +} diff --git a/internal/lambda-managed-instances/rapi/handler/runtimelogs_stub.go b/internal/lambda-managed-instances/rapi/handler/runtimelogs_stub.go new file mode 100644 index 0000000..bd05b7a --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/runtimelogs_stub.go @@ -0,0 +1,49 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "log/slog" + "net/http" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" +) + +const ( + logsAPIDisabledErrorType = "Logs.NotSupported" + telemetryAPIDisabledErrorType = "Telemetry.NotSupported" +) + +type runtimeLogsStubAPIHandler struct{} + +func (h *runtimeLogsStubAPIHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if err := rendering.RenderJSON(http.StatusAccepted, writer, request, &model.ErrorResponse{ + ErrorType: logsAPIDisabledErrorType, + ErrorMessage: "Logs API is not supported", + }); err != nil { + slog.Warn("Error while rendering response", "err", err) + http.Error(writer, err.Error(), http.StatusInternalServerError) + } +} + +func NewRuntimeLogsAPIStubHandler() http.Handler { + return &runtimeLogsStubAPIHandler{} +} + +type runtimeTelemetryAPIStubHandler struct{} + +func (h *runtimeTelemetryAPIStubHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if err := rendering.RenderJSON(http.StatusAccepted, writer, request, &model.ErrorResponse{ + ErrorType: telemetryAPIDisabledErrorType, + ErrorMessage: "Telemetry API is not supported", + }); err != nil { + slog.Warn("Error while rendering response", "err", err) + http.Error(writer, err.Error(), http.StatusInternalServerError) + } +} + +func NewRuntimeTelemetryAPIStubHandler() http.Handler { + return &runtimeTelemetryAPIStubHandler{} +} diff --git a/internal/lambda-managed-instances/rapi/handler/runtimelogs_stub_test.go b/internal/lambda-managed-instances/rapi/handler/runtimelogs_stub_test.go new file mode 100644 index 0000000..59d9ec3 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/runtimelogs_stub_test.go @@ -0,0 +1,25 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSuccessfulRuntimeTelemetryAPIStub202Response(t *testing.T) { + handler := NewRuntimeTelemetryAPIStubHandler() + requestBody := []byte(`foobar`) + request := httptest.NewRequest("PUT", "/telemetry", bytes.NewBuffer(requestBody)) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + assert.JSONEq(t, `{"errorMessage":"Telemetry API is not supported","errorType":"Telemetry.NotSupported"}`, responseRecorder.Body.String()) +} diff --git a/internal/lambda-managed-instances/rapi/handler/runtimelogs_test.go b/internal/lambda-managed-instances/rapi/handler/runtimelogs_test.go new file mode 100644 index 0000000..78ae44a --- /dev/null +++ b/internal/lambda-managed-instances/rapi/handler/runtimelogs_test.go @@ -0,0 +1,356 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" +) + +type mockSubscriptionAPI struct{ mock.Mock } + +func (s *mockSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) ([]byte, int, map[string][]string, error) { + args := s.Called(agentName, body, headers, remoteAddr) + return args.Get(0).([]byte), args.Int(1), args.Get(2).(map[string][]string), args.Error(3) +} + +func (s *mockSubscriptionAPI) RecordCounterMetric(metricName string, count int) { + s.Called(metricName, count) +} + +func (s *mockSubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { + s.Called() + return interop.TelemetrySubscriptionMetrics{} +} + +func (s *mockSubscriptionAPI) Clear() { + s.Called() +} + +func (s *mockSubscriptionAPI) TurnOff() { + s.Called() +} + +func (s *mockSubscriptionAPI) GetEndpointURL() string { + args := s.Called() + return args.Get(0).(string) +} + +func (s *mockSubscriptionAPI) GetServiceClosedErrorMessage() string { + args := s.Called() + return args.Get(0).(string) +} + +func (s *mockSubscriptionAPI) GetServiceClosedErrorType() string { + args := s.Called() + return args.Get(0).(string) +} + +func (s *mockSubscriptionAPI) Configure(passphrase string, addr netip.AddrPort) { + s.Called(passphrase, addr) +} + +func validIPPort(addr string) bool { + ip, _, err := net.SplitHostPort(addr) + return err == nil && net.ParseIP(ip) != nil +} + +func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { + agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": {"VAL1", "VAL2"}} + respBody, respStatus, respHeaders := []byte(`barbaz`), http.StatusNotFound, map[string][]string{"K": {"V1", "V2"}} + clientErrMetric := telemetry.SubscribeClientErr + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + + agent, err := registrationService.CreateExternalAgent(agentName) + assert.NoError(t, err) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return(respBody, respStatus, respHeaders, nil) + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, agent.ID())) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + assert.Equal(t, respStatus, responseRecorder.Code) + assert.Equal(t, respBody, recordedBody) + assert.Equal(t, http.Header(respHeaders), responseRecorder.Header()) +} + +func TestSuccessfulTelemetryAPIPutRequest(t *testing.T) { + agentName, reqBody, reqHeaders := "extensionName", []byte(`foobar`), map[string][]string{"Key": {"VAL1", "VAL2"}} + respBody, respStatus, respHeaders := []byte(`"OK"`), http.StatusOK, map[string][]string{"K": {"V1", "V2"}} + numSubscribersMetric := telemetry.NumSubscribers + subscribeSuccessMetric := telemetry.SubscribeSuccess + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + + agent, err := registrationService.CreateExternalAgent(agentName) + assert.NoError(t, err) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return(respBody, respStatus, respHeaders, nil) + telemetrySubscription.On("RecordCounterMetric", numSubscribersMetric, 1) + telemetrySubscription.On("RecordCounterMetric", subscribeSuccessMetric, 1) + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, agent.ID())) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", numSubscribersMetric, 1) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", subscribeSuccessMetric, 1) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + assert.Equal(t, respStatus, responseRecorder.Code) + assert.Equal(t, respBody, recordedBody) + assert.Equal(t, http.Header(respHeaders), responseRecorder.Header()) +} + +func TestNumberOfSubscribersWhenAnExtensionIsAlreadySubscribed(t *testing.T) { + agentName, reqBody, reqHeaders := "extensionName", []byte(`foobar`), map[string][]string{"Key": {"VAL1", "VAL2"}} + respBody, respStatus, respHeaders := []byte(`"AlreadySubcribed"`), http.StatusOK, map[string][]string{"K": {"V1", "V2"}} + numSubscribersMetric := telemetry.NumSubscribers + subscribeSuccessMetric := telemetry.SubscribeSuccess + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + + agent, err := registrationService.CreateExternalAgent(agentName) + assert.NoError(t, err) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return(respBody, respStatus, respHeaders, nil) + telemetrySubscription.On("RecordCounterMetric", subscribeSuccessMetric, 1) + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, agent.ID())) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", subscribeSuccessMetric, 1) + telemetrySubscription.AssertNotCalled(t, "RecordCounterMetric", numSubscribersMetric, mock.Anything) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + assert.Equal(t, respStatus, responseRecorder.Code) + assert.Equal(t, respBody, recordedBody) + assert.Equal(t, http.Header(respHeaders), responseRecorder.Header()) +} + +func TestErrorUnregisteredAgentID(t *testing.T) { + invalidAgentID := uuid.New() + reqBody, reqHeaders := []byte(`foobar`), map[string][]string{"Key": {"VAL1", "VAL2"}} + clientErrMetric := telemetry.SubscribeClientErr + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, invalidAgentID)) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + expectedErrorBody := fmt.Sprintf(`{"errorMessage":"Unknown extension %s","errorType":"Extension.UnknownExtensionIdentifier"}`+"\n", invalidAgentID.String()) + expectedHeaders := http.Header(map[string][]string{"Content-Type": {"application/json"}}) + + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + assert.Equal(t, expectedErrorBody, string(recordedBody)) + assert.Equal(t, expectedHeaders, responseRecorder.Header()) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) +} + +func TestErrorTelemetryAPICallFailure(t *testing.T) { + agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": {"VAL1", "VAL2"}} + apiError := errors.New("Error calling Telemetry API: connection refused") + serverErrMetric := telemetry.SubscribeServerErr + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + + agent, err := registrationService.CreateExternalAgent(agentName) + assert.NoError(t, err) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("RecordCounterMetric", serverErrMetric, 1) + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, agent.ID())) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + expectedErrorBody := `{"errorMessage":"Internal Server Error","errorType":"InternalServerError"}` + "\n" + expectedHeaders := http.Header(map[string][]string{"Content-Type": {"application/json"}}) + + assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) + assert.Equal(t, expectedErrorBody, string(recordedBody)) + assert.Equal(t, expectedHeaders, responseRecorder.Header()) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", serverErrMetric, 1) +} + +func TestRenderLogsSubscriptionClosed(t *testing.T) { + agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": {"VAL1", "VAL2"}} + apiError := telemetry.ErrTelemetryServiceOff + clientErrMetric := telemetry.SubscribeClientErr + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + + agent, err := registrationService.CreateExternalAgent(agentName) + assert.NoError(t, err) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Logs API subscription is closed already") + telemetrySubscription.On("GetServiceClosedErrorType").Return("Logs.SubscriptionClosed") + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, agent.ID())) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + expectedErrorBody := `{"errorMessage":"Logs API subscription is closed already","errorType":"Logs.SubscriptionClosed"}` + "\n" + expectedHeaders := http.Header(map[string][]string{"Content-Type": {"application/json"}}) + + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + assert.Equal(t, expectedErrorBody, string(recordedBody)) + assert.Equal(t, expectedHeaders, responseRecorder.Header()) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) +} + +func TestRenderTelemetrySubscriptionClosed(t *testing.T) { + agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": {"VAL1", "VAL2"}} + apiError := telemetry.ErrTelemetryServiceOff + clientErrMetric := telemetry.SubscribeClientErr + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization()) + + agent, err := registrationService.CreateExternalAgent(agentName) + assert.NoError(t, err) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Telemetry API subscription is closed already") + telemetrySubscription.On("GetServiceClosedErrorType").Return("Telemetry.SubscriptionClosed") + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), model.AgentIDCtxKey, agent.ID())) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + expectedErrorBody := `{"errorMessage":"Telemetry API subscription is closed already","errorType":"Telemetry.SubscriptionClosed"}` + "\n" + expectedHeaders := http.Header(map[string][]string{"Content-Type": {"application/json"}}) + + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + assert.Equal(t, expectedErrorBody, string(recordedBody)) + assert.Equal(t, expectedHeaders, responseRecorder.Header()) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) +} diff --git a/internal/lambda-managed-instances/rapi/middleware/middleware.go b/internal/lambda-managed-instances/rapi/middleware/middleware.go new file mode 100644 index 0000000..0964fad --- /dev/null +++ b/internal/lambda-managed-instances/rapi/middleware/middleware.go @@ -0,0 +1,66 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "log/slog" + "net/http" + + "github.com/google/uuid" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" +) + +func AgentUniqueIdentifierHeaderValidator(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + agentIdentifier := r.Header.Get(model.LambdaAgentIdentifier) + if len(agentIdentifier) == 0 { + rendering.RenderForbiddenWithTypeMsg(w, r, model.ErrAgentIdentifierMissing, "Missing Lambda-Extension-Identifier header") + return + } + agentID, e := uuid.Parse(agentIdentifier) + if e != nil { + rendering.RenderForbiddenWithTypeMsg(w, r, model.ErrAgentIdentifierInvalid, "Invalid Lambda-Extension-Identifier") + return + } + + r = r.WithContext(context.WithValue(r.Context(), model.AgentIDCtxKey, agentID)) + next.ServeHTTP(w, r) + }) +} + +func AppCtxMiddleware(appCtx appctx.ApplicationContext) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + r = appctx.RequestWithAppCtx(r, appCtx) + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} + +func AccessLogMiddleware() func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + slog.Debug("API request", "method", r.Method, "url", r.URL, "headers", r.Header) + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} + +func RuntimeReleaseMiddleware() func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + appCtx := appctx.FromRequest(r) + + appctx.UpdateAppCtxWithRuntimeRelease(r, appCtx) + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} diff --git a/internal/lambda-managed-instances/rapi/middleware/middleware_test.go b/internal/lambda-managed-instances/rapi/middleware/middleware_test.go new file mode 100644 index 0000000..c619b96 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/middleware/middleware_test.go @@ -0,0 +1,94 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" +) + +type mockHandler struct{} + +func (h *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {} + +func TestRuntimeReleaseMiddleware(t *testing.T) { + appCtx := appctx.NewApplicationContext() + router := chi.NewRouter() + handler := &mockHandler{} + router.Use(RuntimeReleaseMiddleware()) + router.Get("/", handler.ServeHTTP) + + userAgent := "foobar" + + responseRecorder := httptest.NewRecorder() + responseBody := make([]byte, 100) + request := httptest.NewRequest("GET", "/", bytes.NewReader(responseBody)) + request.Header.Set("User-Agent", userAgent) + router.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) + + assert.Equal(t, http.StatusOK, responseRecorder.Code) + ctxRuntimeRelease, ok := appCtx.Load(appctx.AppCtxRuntimeReleaseKey) + assert.True(t, ok) + assert.Equal(t, userAgent, ctxRuntimeRelease) +} + +func TestAgentUniqueIdentifierHeaderValidatorForbidden(t *testing.T) { + router := chi.NewRouter() + mockHandler := &mockHandler{} + router.Get("/", AgentUniqueIdentifierHeaderValidator(mockHandler).ServeHTTP) + responseBody := make([]byte, 100) + var errorResponse model.ErrorResponse + + request := httptest.NewRequest("GET", "/", bytes.NewReader(responseBody)) + + responseRecorder := httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + respBody, _ := io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &errorResponse)) + assert.Equal(t, model.ErrAgentIdentifierMissing, errorResponse.ErrorType) + + responseRecorder = httptest.NewRecorder() + request.Header.Set(model.LambdaAgentIdentifier, "invalid-unique-identifier") + router.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + respBody, _ = io.ReadAll(responseRecorder.Body) + require.NoError(t, json.Unmarshal(respBody, &errorResponse)) + assert.Equal(t, model.ErrAgentIdentifierInvalid, errorResponse.ErrorType) +} + +func TestAgentUniqueIdentifierHeaderValidatorSuccess(t *testing.T) { + router := chi.NewRouter() + mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + val, ok := r.Context().Value(model.AgentIDCtxKey).(uuid.UUID) + if !ok { + assert.FailNow(t, "expected key not in request context") + } + assert.Equal(t, "85083764-ff1e-476f-ada1-d51f26e4f6be", val.String()) + }) + router.Get("/", AgentUniqueIdentifierHeaderValidator(mockHandler).ServeHTTP) + responseBody := make([]byte, 100) + request := httptest.NewRequest("GET", "/", bytes.NewReader(responseBody)) + ctx := context.Background() + request = request.WithContext(ctx) + + responseRecorder := httptest.NewRecorder() + responseRecorder.Code = http.StatusOK + request.Header.Set(model.LambdaAgentIdentifier, "85083764-ff1e-476f-ada1-d51f26e4f6be") + router.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusOK, responseRecorder.Code) +} diff --git a/internal/lambda-managed-instances/rapi/model/agentevent.go b/internal/lambda-managed-instances/rapi/model/agentevent.go new file mode 100644 index 0000000..0677fb9 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/model/agentevent.go @@ -0,0 +1,28 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +type AgentEvent struct { + EventType string `json:"eventType"` + DeadlineMs int64 `json:"deadlineMs"` +} + +type AgentInvokeEvent struct { + *AgentEvent + RequestID string `json:"requestId"` + InvokedFunctionArn string `json:"invokedFunctionArn"` + Tracing *Tracing `json:"tracing,omitempty"` +} + +type ShutdownReason string + +const ( + Spindown ShutdownReason = "spindown" + Failure ShutdownReason = "failure" +) + +type AgentShutdownEvent struct { + *AgentEvent + ShutdownReason ShutdownReason `json:"shutdownReason"` +} diff --git a/internal/lambda-managed-instances/rapi/model/agentregisterresponse.go b/internal/lambda-managed-instances/rapi/model/agentregisterresponse.go new file mode 100644 index 0000000..48ccf0a --- /dev/null +++ b/internal/lambda-managed-instances/rapi/model/agentregisterresponse.go @@ -0,0 +1,11 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +type ExtensionRegisterResponse struct { + AccountID string `json:"accountId,omitempty"` + FunctionName string `json:"functionName"` + FunctionVersion string `json:"functionVersion"` + Handler string `json:"handler"` +} diff --git a/internal/lambda-managed-instances/rapi/model/cognitoidentity.go b/internal/lambda-managed-instances/rapi/model/cognitoidentity.go new file mode 100644 index 0000000..8da458e --- /dev/null +++ b/internal/lambda-managed-instances/rapi/model/cognitoidentity.go @@ -0,0 +1,9 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +type CognitoIdentity struct { + CognitoIdentityID string `json:"cognitoIdentityId"` + CognitoIdentityPoolID string `json:"cognitoIdentityPoolId"` +} diff --git a/internal/lambda-managed-instances/rapi/model/constants.go b/internal/lambda-managed-instances/rapi/model/constants.go new file mode 100644 index 0000000..bfee3e0 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/model/constants.go @@ -0,0 +1,16 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +const ( + LambdaAgentIdentifier string = "Lambda-Extension-Identifier" + ErrAgentIdentifierMissing string = "Extension.MissingExtensionIdentifier" + ErrAgentIdentifierInvalid string = "Extension.InvalidExtensionIdentifier" +) + +type CtxKey int + +const ( + AgentIDCtxKey CtxKey = iota +) diff --git a/internal/lambda-managed-instances/rapi/model/error_cause.go b/internal/lambda-managed-instances/rapi/model/error_cause.go new file mode 100644 index 0000000..7be2259 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/model/error_cause.go @@ -0,0 +1,98 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "encoding/json" + "fmt" +) + +const MaxErrorCauseSizeBytes = 64 << 10 + +type exceptionStackFrame struct { + Path string `json:"path,omitempty"` + Line int `json:"line,omitempty"` + Label string `json:"label,omitempty"` +} + +type exception struct { + Message string `json:"message,omitempty"` + Type string `json:"type,omitempty"` + Stack []exceptionStackFrame `json:"stack,omitempty"` +} + +type ErrorCause struct { + Exceptions []exception `json:"exceptions"` + WorkingDir string `json:"working_directory"` + Paths []string `json:"paths"` + Message string `json:"message,omitempty"` +} + +func newErrorCause(errorCauseJSON []byte) (*ErrorCause, error) { + var ec ErrorCause + + if err := json.Unmarshal(errorCauseJSON, &ec); err != nil { + return nil, fmt.Errorf("failed to parse error cause JSON: %s", err) + } + + return &ec, nil +} + +func (ec *ErrorCause) isValid() bool { + if len(ec.WorkingDir) == 0 && len(ec.Paths) == 0 && len(ec.Exceptions) == 0 && len(ec.Message) == 0 { + + return false + } + + return true +} + +func (ec *ErrorCause) croppedJSON() []byte { + + truncationFactors := []float64{0.8, 0.6, 0.4, 0.2} + for _, factor := range truncationFactors { + compactor := newErrorCauseCompactor(*ec) + compactor.crop(factor) + validErrorCauseJSON, err := json.Marshal(compactor.cause()) + if err != nil { + break + } + + if len(validErrorCauseJSON) <= MaxErrorCauseSizeBytes { + return validErrorCauseJSON + } + } + + compactor := newErrorCauseCompactor(*ec) + compactor.crop(0) + + validErrorCauseJSON, err := json.Marshal(compactor.cause()) + if err != nil { + return nil + } + + return validErrorCauseJSON +} + +func ValidatedErrorCauseJSON(errorCauseJSON []byte) ([]byte, error) { + errorCause, err := newErrorCause(errorCauseJSON) + if err != nil { + return nil, err + } + + if !errorCause.isValid() { + return nil, fmt.Errorf("error cause body has invalid format: %s", errorCauseJSON) + } + + validErrorCauseJSON, err := json.Marshal(errorCause) + if err != nil { + return nil, err + } + + if len(validErrorCauseJSON) > MaxErrorCauseSizeBytes { + return errorCause.croppedJSON(), nil + } + + return validErrorCauseJSON, nil +} diff --git a/internal/lambda-managed-instances/rapi/model/error_cause_compactor.go b/internal/lambda-managed-instances/rapi/model/error_cause_compactor.go new file mode 100644 index 0000000..f46c9a5 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/model/error_cause_compactor.go @@ -0,0 +1,69 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +const paddingForFieldNames = 4096 + +type errorCauseCompactor struct { + ec ErrorCause +} + +func newErrorCauseCompactor(errorCause ErrorCause) *errorCauseCompactor { + ec := errorCause + return &errorCauseCompactor{ec} +} + +func (c *errorCauseCompactor) cropStackTraces(factor float64) { + if factor > 0 { + factor = min(factor, 1.0) + exceptionsLen := float64(len(c.ec.Exceptions)) * factor + pathLen := float64(len(c.ec.Paths)) * factor + + c.ec.Exceptions = c.ec.Exceptions[:int(exceptionsLen)] + c.ec.Paths = c.ec.Paths[:int(pathLen)] + + return + } + + c.ec.Exceptions = nil + c.ec.Paths = nil +} + +func (c *errorCauseCompactor) cropMessage(factor float64) { + if factor > 0 { + return + } + + length := ((MaxErrorCauseSizeBytes - paddingForFieldNames) / 2) + c.ec.Message = cropString(c.ec.Message, length) +} + +func (c *errorCauseCompactor) cropWorkingDir(factor float64) { + if factor > 0 { + return + } + + length := ((MaxErrorCauseSizeBytes - paddingForFieldNames) / 2) + c.ec.WorkingDir = cropString(c.ec.WorkingDir, length) +} + +func (c *errorCauseCompactor) crop(factor float64) { + c.cropStackTraces(factor) + c.cropMessage(factor) + c.cropWorkingDir(factor) +} + +func (c *errorCauseCompactor) cause() *ErrorCause { + return &c.ec +} + +func cropString(str string, length int) string { + if len(str) <= length { + return str + } + + truncationIndicator := `...` + length -= len(truncationIndicator) + return str[:length] + truncationIndicator +} diff --git a/internal/lambda-managed-instances/rapi/model/error_cause_compactor_test.go b/internal/lambda-managed-instances/rapi/model/error_cause_compactor_test.go new file mode 100644 index 0000000..2322c48 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/model/error_cause_compactor_test.go @@ -0,0 +1,83 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorCauseCropMessageAndWorkingDir(t *testing.T) { + largeString := strings.Repeat("a", 4*MaxErrorCauseSizeBytes) + factorsAndExpectedLengths := map[float64]int{ + 1.5: len(largeString), + 1.0: len(largeString), + 0.5: len(largeString), + 0.0: (MaxErrorCauseSizeBytes - paddingForFieldNames) / 2, + } + + for factor, length := range factorsAndExpectedLengths { + cause := ErrorCause{ + Message: largeString, + WorkingDir: largeString, + } + + compactor := newErrorCauseCompactor(cause) + compactor.crop(factor) + + failureMsg := fmt.Sprintf("factor: %f, length: expected=%d, actual=%d", factor, length, len(compactor.ec.Message)) + assert.Len(t, compactor.ec.Message, length, "Message: "+failureMsg) + assert.Len(t, compactor.ec.WorkingDir, length, "WorkingDir: "+failureMsg) + } +} + +func TestErrorCauseCropStackTraces(t *testing.T) { + noOfElements := 3 * MaxErrorCauseSizeBytes + largeExceptions := make([]exception, noOfElements) + for i := range largeExceptions { + largeExceptions[i] = exception{Message: "a"} + } + + largePaths := make([]string, noOfElements) + for i := range largePaths { + largePaths[i] = "a" + } + + factorsAndExpectedLengths := map[float64]int{ + 1.5: noOfElements, + 1.0: noOfElements, + 0.5: int(noOfElements / 2), + 0.0: 0, + } + + for factor, length := range factorsAndExpectedLengths { + cause := ErrorCause{ + Exceptions: largeExceptions, + Paths: largePaths, + } + + compactor := newErrorCauseCompactor(cause) + compactor.crop(factor) + + failureMsg := fmt.Sprintf("factor: %f, length: expected=%d, actual=%d", factor, length, len(compactor.ec.WorkingDir)) + assert.Len(t, compactor.ec.Exceptions, length, "Exceptions: "+failureMsg) + assert.Len(t, compactor.ec.Paths, length, "Paths: "+failureMsg) + } +} + +func TestCropString(t *testing.T) { + maxLen := 5 + stringsAndExpectedCrops := map[string]string{ + "abcde": "abcde", + "abcdef": "ab...", + "": "", + } + + for str, expectedStr := range stringsAndExpectedCrops { + assert.Equal(t, expectedStr, cropString(str, maxLen)) + } +} diff --git a/internal/lambda-managed-instances/rapi/model/error_cause_test.go b/internal/lambda-managed-instances/rapi/model/error_cause_test.go new file mode 100644 index 0000000..677b5d9 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/model/error_cause_test.go @@ -0,0 +1,142 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorCauseValidationWhenCauseIsValid(t *testing.T) { + validCauses := [][]byte{ + []byte(`{"paths":[],"working_directory":"/foo/bar/baz","exceptions":[]}`), + []byte(`{"paths":["foo", "bar"]}`), + []byte(`{"working_directory":"/foo/bar/baz"}`), + []byte(`{"exceptions":[{"message": "foo"}, {"message": "bar"}]}`), + []byte(`{"exceptions":[{}]}`), + []byte(`{"exceptions":[{}], "arbitrary":"field"}`), + []byte(`{"message":"foo error"}`), + } + + for _, c := range validCauses { + _, err := ValidatedErrorCauseJSON(c) + assert.Nil(t, err, "validation failed for valid cause") + } +} + +func TestWorkingDirCropping(t *testing.T) { +} + +func TestErrorCauseMarshallingWhenCauseIsValid(t *testing.T) { + causesAndExpectations := map[string]string{ + `{"paths":[],"working_directory":"/","exceptions":[]}`: `{"paths":[],"working_directory":"/","exceptions":[]}`, + `{"paths":["f"]}`: `{"paths":["f"],"working_directory":"","exceptions":null}`, + `{"working_directory":"/foo"}`: `{"paths":null,"working_directory":"/foo","exceptions":null}`, + `{"exceptions":[{}], "arbitrary":"field"}`: `{"paths":null,"working_directory":"","exceptions":[{}]}`, + `{"message":"foo"}`: `{"paths":null,"working_directory":"","exceptions":null,"message":"foo"}`, + } + + for causeJSON, expectedJSON := range causesAndExpectations { + validCauseJSON, err := ValidatedErrorCauseJSON([]byte(causeJSON)) + assert.Nil(t, err, "validation failed for valid cause") + assert.JSONEq(t, string(expectedJSON), string(validCauseJSON)) + } +} + +func TestErrorCauseValidationWhenCauseIsInvalid(t *testing.T) { + invalidCauses := [][]byte{ + []byte(`{"paths":[],"working_directory":"","exceptions":[]}`), + []byte(`{"paths":"","working_directory":"","exceptions":[]}`), + []byte(`{"paths":"","exceptions":[]}`), + []byte(`{foo: invalid}`), + []byte(`{}`), + []byte(`{"arbitrary":"field"}`), + } + + for _, c := range invalidCauses { + causeJSON, err := ValidatedErrorCauseJSON(c) + assert.Error(t, err, "validation didn't return an error") + assert.Nil(t, causeJSON) + } +} + +func TestErrorCauseCroppedJSONForEmptyCause(t *testing.T) { + emptyCauseJSON := `{"exceptions":null, "paths":null, "working_directory":""}` + cause := ErrorCause{} + + causeJSON := cause.croppedJSON() + + assert.JSONEq(t, emptyCauseJSON, string(causeJSON)) +} + +func TestErrorCauseCroppedJSONForLargeCause(t *testing.T) { + noOfElements := MaxErrorCauseSizeBytes + largeExceptions := make([]exception, noOfElements) + for i := range largeExceptions { + largeExceptions[i] = exception{Message: "a"} + } + + largePaths := make([]string, noOfElements) + for i := range largePaths { + largePaths[i] = "a" + } + + largeCause := ErrorCause{ + Message: strings.Repeat("a", noOfElements), + WorkingDir: strings.Repeat("a", noOfElements), + Exceptions: largeExceptions, + Paths: largePaths, + } + expectedStringFieldsLen := (MaxErrorCauseSizeBytes - paddingForFieldNames) / 2 + + causeJSON := largeCause.croppedJSON() + assert.True(t, len(causeJSON) <= MaxErrorCauseSizeBytes, fmt.Sprintf("cropped JSON too long: len=%d", len(causeJSON))) + + parsedCause, err := newErrorCause(causeJSON) + assert.NoError(t, err, "failed to parse constructed JSON") + assert.Len(t, parsedCause.Message, expectedStringFieldsLen, "Message length incorrect") + assert.Len(t, parsedCause.WorkingDir, expectedStringFieldsLen, "WorkingDir length incorrect") + assert.Len(t, parsedCause.Exceptions, 0, "Exceptions length incorrect") + assert.Len(t, parsedCause.Paths, 0, "Paths length incorrect") +} + +func TestErrorCauseCroppedJSONForLargeCauseWithOnlyExceptionsAndPaths(t *testing.T) { + elementsAndExpectedLengthFactors := map[int]float64{ + 100: 0.8, + 5000: 0.6, + 8000: 0.4, + 10000: 0.2, + MaxErrorCauseSizeBytes / 4: 0.0, + } + + for noOfElements, factor := range elementsAndExpectedLengthFactors { + largeExceptions := make([]exception, noOfElements) + for i := range largeExceptions { + largeExceptions[i] = exception{Message: "a"} + } + + largePaths := make([]string, noOfElements) + for i := range largePaths { + largePaths[i] = "a" + } + + largeCause := ErrorCause{ + Exceptions: largeExceptions, + Paths: largePaths, + } + + causeJSON := largeCause.croppedJSON() + assert.True(t, len(causeJSON) <= MaxErrorCauseSizeBytes, fmt.Sprintf("cropped JSON too long: len=%d", len(causeJSON))) + + parsedCause, err := newErrorCause(causeJSON) + assert.NoError(t, err, "failed to parse constructed JSON") + assert.Len(t, parsedCause.Message, 0, "Message length incorrect") + assert.Len(t, parsedCause.WorkingDir, 0, "WorkingDir length incorrect") + assert.Len(t, parsedCause.Exceptions, int(float64(noOfElements)*factor), "Exceptions length incorrect") + assert.Len(t, parsedCause.Paths, int(float64(noOfElements)*factor), "Paths length incorrect") + } +} diff --git a/internal/lambda-managed-instances/rapi/model/errorresponse.go b/internal/lambda-managed-instances/rapi/model/errorresponse.go new file mode 100644 index 0000000..c49ce46 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/model/errorresponse.go @@ -0,0 +1,10 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +type ErrorResponse struct { + ErrorMessage string `json:"errorMessage"` + ErrorType string `json:"errorType"` + StackTrace []string `json:"stackTrace,omitempty"` +} diff --git a/internal/lambda-managed-instances/rapi/model/statusresponse.go b/internal/lambda-managed-instances/rapi/model/statusresponse.go new file mode 100644 index 0000000..3188702 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/model/statusresponse.go @@ -0,0 +1,8 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +type StatusResponse struct { + Status string `json:"status"` +} diff --git a/internal/lambda-managed-instances/rapi/model/tracing.go b/internal/lambda-managed-instances/rapi/model/tracing.go new file mode 100644 index 0000000..580918c --- /dev/null +++ b/internal/lambda-managed-instances/rapi/model/tracing.go @@ -0,0 +1,35 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +type TracingType string + +const ( + XRayTracingType TracingType = "X-Amzn-Trace-Id" +) + +const ( + XRaySampled = "1" + XRayNonSampled = "0" +) + +type Tracing struct { + Type TracingType `json:"type"` + XRayTracing +} + +type XRayTracing struct { + Value string `json:"value"` +} + +func NewXRayTracing(value string) *Tracing { + if len(value) == 0 { + return nil + } + + return &Tracing{ + XRayTracingType, + XRayTracing{value}, + } +} diff --git a/internal/lambda-managed-instances/rapi/rapi_fuzz_test.go b/internal/lambda-managed-instances/rapi/rapi_fuzz_test.go new file mode 100644 index 0000000..c3519b6 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/rapi_fuzz_test.go @@ -0,0 +1,98 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapi + +import ( + "errors" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "os" + "regexp" + "strings" + "testing" + "unicode" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testdata" +) + +func makeRapiServer(flowTest *testdata.FlowTest) *Server { + s, err := NewServer( + netip.MustParseAddrPort("127.0.0.1:0"), + flowTest.AppCtx, + flowTest.RegistrationService, + flowTest.RenderingService, + flowTest.TelemetrySubscription, + nil, + ) + if err != nil { + panic(err) + } + return s +} + +func makeTargetURL(path string, apiVersion string) string { + protocol := "http" + endpoint := os.Getenv("AWS_LAMBDA_RUNTIME_API") + baseurl := fmt.Sprintf("%s://%s%s", protocol, endpoint, apiVersion) + + url := fmt.Sprintf("%s%s", baseurl, path) + + return strings.TrimRight(url, "#") +} + +func serveTestRequest(rapiServer *Server, request *http.Request) *httptest.ResponseRecorder { + responseRecorder := httptest.NewRecorder() + rapiServer.server.Handler.ServeHTTP(responseRecorder, request) + slog.Debug("test request", "url", request.URL, "status_code", responseRecorder.Code) + + return responseRecorder +} + +func parseToURLStruct(rawPath string) (*url.URL, error) { + invalidChars := regexp.MustCompile(`[ %]+`) + if invalidChars.MatchString(rawPath) { + return nil, errors.New("url must not contain spaces or %") + } + + for _, r := range rawPath { + if !unicode.IsGraphic(r) { + return nil, errors.New("url contains non-graphic runes") + } + } + + if _, err := url.ParseRequestURI(rawPath); err != nil { + return nil, err + } + + u, err := url.Parse(rawPath) + if err != nil { + return nil, err + } + + if u.Scheme == "" { + return nil, errors.New("blank url scheme") + } + + return u, nil +} + +func assertExpectedPathResponseCode(t *testing.T, code int, target string) { + if code != http.StatusOK && + code != http.StatusAccepted && + code != http.StatusForbidden { + t.Errorf("Unexpected status code (%v) for target (%v)", code, target) + } +} + +func assertUnexpectedPathResponseCode(t *testing.T, code int, target string) { + if code != http.StatusNotFound && + code != http.StatusMethodNotAllowed && + code != http.StatusBadRequest { + t.Errorf("Unexpected status code (%v) for target (%v)", code, target) + } +} diff --git a/internal/lambda-managed-instances/rapi/rendering/doc.go b/internal/lambda-managed-instances/rapi/rendering/doc.go new file mode 100644 index 0000000..42cafb3 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/rendering/doc.go @@ -0,0 +1,4 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rendering diff --git a/internal/lambda-managed-instances/rapi/rendering/render_error.go b/internal/lambda-managed-instances/rapi/rendering/render_error.go new file mode 100644 index 0000000..804e011 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/rendering/render_error.go @@ -0,0 +1,93 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rendering + +import ( + "fmt" + "log/slog" + "net/http" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" +) + +func RenderForbiddenWithTypeMsg(w http.ResponseWriter, r *http.Request, errorType string, format string, args ...interface{}) { + if err := RenderJSON(http.StatusForbidden, w, r, &model.ErrorResponse{ + ErrorType: errorType, + ErrorMessage: fmt.Sprintf(format, args...), + }); err != nil { + slog.Warn("Error while rendering response", "err", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func RenderInternalServerError(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusInternalServerError, w, r, &model.ErrorResponse{ + ErrorMessage: "Internal Server Error", + ErrorType: ErrorTypeInternalServerError, + }); err != nil { + slog.Warn("Error while rendering response", "err", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func RenderRequestEntityTooLarge(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusRequestEntityTooLarge, w, r, &model.ErrorResponse{ + ErrorMessage: fmt.Sprintf("Exceeded maximum allowed payload size (%d bytes).", interop.MaxPayloadSize), + ErrorType: ErrorTypeRequestEntityTooLarge, + }); err != nil { + slog.Warn("Error while rendering response", "err", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func RenderTruncatedHTTPRequestError(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "HTTP request detected as truncated", + ErrorType: ErrorTypeTruncatedHTTPRequest, + }); err != nil { + slog.Warn("Error while rendering response", "err", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func RenderInvalidRequestID(w http.ResponseWriter, r *http.Request) { + + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "Invalid request ID", + ErrorType: "InvalidRequestID", + }); err != nil { + slog.Warn("Error while rendering response", "err", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func RenderInvokeTimeout(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusGone, w, r, &model.ErrorResponse{ + ErrorMessage: "Invoke timeout", + ErrorType: "InvokeTimeout", + }); err != nil { + slog.Warn("Error while rendering response", "err", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func RenderInvalidFunctionResponseMode(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "Invalid function response mode", + ErrorType: "InvalidFunctionResponseMode", + }); err != nil { + slog.Warn("Error while rendering response", "err", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func RenderInteropError(writer http.ResponseWriter, request *http.Request, err error) { + if err == interop.ErrResponseSent { + RenderInvalidRequestID(writer, request) + } else { + slog.Error("Interop error", "err", err) + panic(err) + } +} diff --git a/internal/lambda-managed-instances/rapi/rendering/render_json.go b/internal/lambda-managed-instances/rapi/rendering/render_json.go new file mode 100644 index 0000000..4991ad5 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/rendering/render_json.go @@ -0,0 +1,28 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rendering + +import ( + "bytes" + "encoding/json" + "log/slog" + "net/http" +) + +func RenderJSON(status int, w http.ResponseWriter, r *http.Request, v interface{}) error { + buf := &bytes.Buffer{} + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(true) + if err := enc.Encode(v); err != nil { + return err + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if _, err := w.Write(buf.Bytes()); err != nil { + slog.Warn("Error while writing response body", "err", err) + } + + return nil +} diff --git a/internal/lambda-managed-instances/rapi/rendering/rendering.go b/internal/lambda-managed-instances/rapi/rendering/rendering.go new file mode 100644 index 0000000..c655923 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/rendering/rendering.go @@ -0,0 +1,259 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rendering + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "log/slog" + "net/http" + "strconv" + "sync" + "time" + + "github.com/google/uuid" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" +) + +const ( + ErrorTypeInternalServerError = "InternalServerError" + + ErrorTypeInvalidStateTransition = "InvalidStateTransition" + + ErrorTypeInvalidRequestID = "InvalidRequestID" + + ErrorTypeRequestEntityTooLarge = "RequestEntityTooLarge" + + ErrorTypeTruncatedHTTPRequest = "TruncatedHTTPRequest" +) + +var ErrRenderingServiceStateNotSet = errors.New("EventRenderingService state not set") + +type RendererState interface { + RenderAgentEvent(w http.ResponseWriter, r *http.Request) error + RenderRuntimeEvent(w http.ResponseWriter, r *http.Request) error +} + +type EventRenderingService struct { + mutex *sync.RWMutex + currentState RendererState +} + +func NewRenderingService() *EventRenderingService { + return &EventRenderingService{ + mutex: &sync.RWMutex{}, + } +} + +func (s *EventRenderingService) SetRenderer(state RendererState) { + s.mutex.Lock() + defer s.mutex.Unlock() + s.currentState = state +} + +func (s *EventRenderingService) RenderAgentEvent(w http.ResponseWriter, r *http.Request) error { + s.mutex.RLock() + defer s.mutex.RUnlock() + if s.currentState == nil { + return ErrRenderingServiceStateNotSet + } + return s.currentState.RenderAgentEvent(w, r) +} + +func (s *EventRenderingService) RenderRuntimeEvent(w http.ResponseWriter, r *http.Request) error { + s.mutex.RLock() + defer s.mutex.RUnlock() + if s.currentState == nil { + return ErrRenderingServiceStateNotSet + } + return s.currentState.RenderRuntimeEvent(w, r) +} + +type InvokeRendererMetrics struct { + ReadTime time.Duration + SizeBytes int +} + +type InvokeRenderer struct { + ctx context.Context + invoke *interop.Invoke + tracingHeaderParser func(context.Context) string + requestBuffer *bytes.Buffer + requestMutex sync.Mutex + metrics InvokeRendererMetrics +} + +func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, requestBuffer *bytes.Buffer, traceParser func(context.Context) string) *InvokeRenderer { + requestBuffer.Reset() + return &InvokeRenderer{ + invoke: invoke, + ctx: ctx, + tracingHeaderParser: traceParser, + requestBuffer: requestBuffer, + requestMutex: sync.Mutex{}, + } +} + +func newAgentInvokeEvent(ctx context.Context, req *interop.Invoke) (*model.AgentInvokeEvent, error) { + return &model.AgentInvokeEvent{ + AgentEvent: &model.AgentEvent{ + EventType: "INVOKE", + DeadlineMs: req.GetDeadlineMs(ctx), + }, + RequestID: req.ID, + InvokedFunctionArn: req.InvokedFunctionArn, + Tracing: model.NewXRayTracing(req.TraceID), + }, nil +} + +func (s *InvokeRenderer) RenderAgentEvent(writer http.ResponseWriter, request *http.Request) error { + event, err := newAgentInvokeEvent(s.ctx, s.invoke) + if err != nil { + return err + } + + bytes, err := json.Marshal(event) + if err != nil { + return err + } + + eventID := uuid.New() + headers := writer.Header() + headers.Set("Lambda-Extension-Event-Identifier", eventID.String()) + headers.Set("Content-Type", "application/json") + writer.WriteHeader(http.StatusOK) + + if _, err := writer.Write(bytes); err != nil { + return err + } + return nil +} + +func (s *InvokeRenderer) bufferInvokeRequest() error { + s.requestMutex.Lock() + defer s.requestMutex.Unlock() + var err error = nil + if s.requestBuffer.Len() == 0 { + reader := io.LimitReader(s.invoke.Payload, interop.MaxPayloadSize) + start := time.Now() + _, err = s.requestBuffer.ReadFrom(reader) + s.metrics = InvokeRendererMetrics{ + ReadTime: time.Since(start), + SizeBytes: s.requestBuffer.Len(), + } + } + return err +} + +func (s *InvokeRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request *http.Request) error { + invoke := s.invoke + customerTraceID := s.tracingHeaderParser(s.ctx) + + cognitoIdentityJSON := "" + if len(invoke.CognitoIdentityID) != 0 || len(invoke.CognitoIdentityPoolID) != 0 { + cognitoJSON, err := json.Marshal(model.CognitoIdentity{ + CognitoIdentityID: invoke.CognitoIdentityID, + CognitoIdentityPoolID: invoke.CognitoIdentityPoolID, + }) + if err != nil { + return err + } + + cognitoIdentityJSON = string(cognitoJSON) + } + + var deadlineHeader string + if deadlineMs := invoke.GetDeadlineMs(s.ctx); deadlineMs > 0 { + deadlineHeader = strconv.FormatInt(deadlineMs, 10) + } + + renderInvokeHeaders(writer, invoke.ID, customerTraceID, invoke.ClientContext, + cognitoIdentityJSON, invoke.InvokedFunctionArn, deadlineHeader, invoke.ContentType) + + if invoke.Payload != nil { + if err := s.bufferInvokeRequest(); err != nil { + return err + } + _, err := writer.Write(s.requestBuffer.Bytes()) + return err + } + + return nil +} + +func (s *InvokeRenderer) GetMetrics() InvokeRendererMetrics { + s.requestMutex.Lock() + defer s.requestMutex.Unlock() + return s.metrics +} + +type ShutdownRenderer struct { + AgentEvent model.AgentShutdownEvent +} + +func (s *ShutdownRenderer) RenderAgentEvent(w http.ResponseWriter, r *http.Request) error { + bytes, err := json.Marshal(s.AgentEvent) + if err != nil { + return err + } + if _, err := w.Write(bytes); err != nil { + return err + } + return nil +} + +func (s *ShutdownRenderer) RenderRuntimeEvent(w http.ResponseWriter, r *http.Request) error { + panic("We should SIGTERM runtime") +} + +func renderInvokeHeaders(writer http.ResponseWriter, invokeID string, customerTraceID string, clientContext string, + cognitoIdentity string, invokedFunctionArn string, deadlineMs string, contentType string, +) { + setHeaderIfNotEmpty := func(headers http.Header, key string, value string) { + if value != "" { + headers.Set(key, value) + } + } + + headers := writer.Header() + setHeaderIfNotEmpty(headers, "Lambda-Runtime-Aws-Request-Id", invokeID) + setHeaderIfNotEmpty(headers, "Lambda-Runtime-Trace-Id", customerTraceID) + setHeaderIfNotEmpty(headers, "Lambda-Runtime-Client-Context", clientContext) + setHeaderIfNotEmpty(headers, "Lambda-Runtime-Cognito-Identity", cognitoIdentity) + setHeaderIfNotEmpty(headers, "Lambda-Runtime-Invoked-Function-Arn", invokedFunctionArn) + setHeaderIfNotEmpty(headers, "Lambda-Runtime-Deadline-Ms", deadlineMs) + if contentType == "" { + contentType = "application/json" + } + headers.Set("Content-Type", contentType) + writer.WriteHeader(http.StatusOK) +} + +func RenderRuntimeLogsResponse(w http.ResponseWriter, respBody []byte, status int, headers map[string][]string) error { + respHeaders := w.Header() + for k, vals := range headers { + for _, v := range vals { + respHeaders.Add(k, v) + } + } + + w.WriteHeader(status) + + _, err := w.Write(respBody) + return err +} + +func RenderAccepted(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusAccepted, w, r, &model.StatusResponse{ + Status: "OK", + }); err != nil { + slog.Warn("Error while rendering response", "err", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} diff --git a/internal/lambda-managed-instances/rapi/router.go b/internal/lambda-managed-instances/rapi/router.go new file mode 100644 index 0000000..63486ec --- /dev/null +++ b/internal/lambda-managed-instances/rapi/router.go @@ -0,0 +1,79 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapi + +import ( + "net/http" + + "github.com/go-chi/chi" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/handler" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/middleware" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" +) + +type runtimeRequestHandler interface { + handler.RuntimeNextHandler + handler.RuntimeResponseHandler + handler.RuntimeErrorHandler +} + +func NewRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService, runtimeReqHandler runtimeRequestHandler) http.Handler { + router := chi.NewRouter() + router.Use(middleware.AppCtxMiddleware(appCtx)) + router.Use(middleware.AccessLogMiddleware()) + router.Use(middleware.RuntimeReleaseMiddleware()) + + router.Get("/ping", http.MaxBytesHandler(handler.NewPingHandler(), 0).ServeHTTP) + + router.Get("/runtime/invocation/next", + http.MaxBytesHandler(handler.NewInvocationNextHandler(registrationService, runtimeReqHandler), 0).ServeHTTP) + + router.Post("/runtime/invocation/{awsrequestid}/response", + handler.NewInvocationResponseHandler(runtimeReqHandler).ServeHTTP) + + router.Post("/runtime/invocation/{awsrequestid}/error", + http.MaxBytesHandler(handler.NewInvocationErrorHandler(runtimeReqHandler), requestBodyLimitBytes).ServeHTTP) + + router.Post("/runtime/init/error", http.MaxBytesHandler(handler.NewInitErrorHandler(registrationService), requestBodyLimitBytes).ServeHTTP) + return router +} + +func ExtensionsRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { + router := chi.NewRouter() + router.Use(middleware.AccessLogMiddleware()) + router.Use(middleware.AppCtxMiddleware(appCtx)) + + registerHandler := handler.NewAgentRegisterHandler(registrationService) + router.Post("/extension/register", + registerHandler.ServeHTTP) + + router.Get("/extension/event/next", + middleware.AgentUniqueIdentifierHeaderValidator( + http.MaxBytesHandler(handler.NewAgentNextHandler(registrationService, renderingService), 0)).ServeHTTP) + + router.Post("/extension/init/error", + middleware.AgentUniqueIdentifierHeaderValidator( + handler.NewAgentInitErrorHandler(registrationService)).ServeHTTP) + + router.Post("/extension/exit/error", + middleware.AgentUniqueIdentifierHeaderValidator( + handler.NewAgentExitErrorHandler(registrationService)).ServeHTTP) + + return router +} + +func TelemetryAPIRouter(registrationService core.RegistrationService, telemetrySubscriptionAPI telemetry.SubscriptionAPI) http.Handler { + router := chi.NewRouter() + router.Use(middleware.AccessLogMiddleware()) + + router.Put("/telemetry", + middleware.AgentUniqueIdentifierHeaderValidator( + handler.NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscriptionAPI)).ServeHTTP) + + return router +} diff --git a/internal/lambda-managed-instances/rapi/server.go b/internal/lambda-managed-instances/rapi/server.go new file mode 100644 index 0000000..8a974bc --- /dev/null +++ b/internal/lambda-managed-instances/rapi/server.go @@ -0,0 +1,122 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapi + +import ( + "context" + "fmt" + "log/slog" + "net" + "net/http" + "net/netip" + + "github.com/go-chi/chi" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" +) + +const ( + version20180601 = "/2018-06-01" + version20200101 = "/2020-01-01" + version20220701 = "/2022-07-01" +) + +const requestBodyLimitBytes int64 = 1 * 1024 * 1024 + +type Server struct { + runtimeAPIAddrPort netip.AddrPort + server *http.Server + listener net.Listener + exit chan error +} + +func SaveConnInContext(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, interop.HTTPConnKey, c) +} + +func NewServer( + runtimeAPIAddrPort netip.AddrPort, + appCtx appctx.ApplicationContext, + registrationService core.RegistrationService, + renderingService *rendering.EventRenderingService, + telemetrySubscriptionAPI telemetry.SubscriptionAPI, + runtimeReqHandler runtimeRequestHandler, +) (*Server, error) { + exitErrors := make(chan error, 1) + + router := chi.NewRouter() + router.Mount(version20180601, NewRouter(appCtx, registrationService, renderingService, runtimeReqHandler)) + router.Mount(version20200101, http.MaxBytesHandler(ExtensionsRouter(appCtx, registrationService, renderingService), requestBodyLimitBytes)) + + router.Mount(version20220701, http.MaxBytesHandler(TelemetryAPIRouter(registrationService, telemetrySubscriptionAPI), requestBodyLimitBytes)) + + listener, err := net.Listen("tcp", runtimeAPIAddrPort.String()) + if err != nil { + return nil, err + } + + return &Server{ + listener: listener, + runtimeAPIAddrPort: netip.MustParseAddrPort(listener.Addr().String()), + + server: &http.Server{ + Handler: router, + ConnContext: SaveConnInContext, + }, + exit: exitErrors, + }, nil +} + +func (s *Server) Serve(ctx context.Context) error { + defer func() { + if err := s.Close(); err != nil { + slog.Error("Error closing server", "err", err) + } + }() + + select { + case err := <-s.serveAsync(): + return err + + case err := <-s.exit: + slog.Error("Error triggered exit", "err", err) + return err + + case <-ctx.Done(): + return ctx.Err() + } +} + +func (s *Server) serveAsync() chan error { + errors := make(chan error) + go func() { + errors <- s.server.Serve(s.listener) + }() + + return errors +} + +func (s *Server) AddrPort() netip.AddrPort { + return s.runtimeAPIAddrPort +} + +func (s *Server) URL(endpoint string) string { + return fmt.Sprintf("http://%s%s%s", s.runtimeAPIAddrPort, version20180601, endpoint) +} + +func (s *Server) Close() error { + err := s.server.Close() + if err == nil { + slog.Info("Runtime API Server closed") + } + return err +} + +func (s *Server) Shutdown() error { + return s.server.Shutdown(context.Background()) +} diff --git a/internal/lambda-managed-instances/rapi/telemetry_logs_fuzz_test.go b/internal/lambda-managed-instances/rapi/telemetry_logs_fuzz_test.go new file mode 100644 index 0000000..1f4dbf2 --- /dev/null +++ b/internal/lambda-managed-instances/rapi/telemetry_logs_fuzz_test.go @@ -0,0 +1,168 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapi + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/netip" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testdata" +) + +const ( + telemetryHandlerPath = "/telemetry" + + samplePayload = `{"foo" : "bar"}` +) + +func FuzzTelemetryLogRouters(f *testing.F) { + f.Add(makeTargetURL(telemetryHandlerPath, version20220701), []byte(samplePayload)) + f.Add(makeTargetURL("/telemetry#", version20220701), []byte(samplePayload)) + f.Add(makeTargetURL("/telemetry#fragment", version20220701), []byte(samplePayload)) + + telemetryPath := fmt.Sprintf("%s%s", version20220701, telemetryHandlerPath) + + f.Fuzz(func(t *testing.T, rawPath string, payload []byte) { + u, err := parseToURLStruct(rawPath) + if err != nil { + t.Skipf("error parsing url: %v. Skipping test.", err) + } + + flowTest := testdata.NewFlowTest() + + rapiServer := makeRapiServerWithMockSubscriptionAPI(flowTest, newMockSubscriptionAPI(true)) + + request := httptest.NewRequest("PUT", rawPath, bytes.NewReader(payload)) + responseRecorder := serveTestRequest(rapiServer, request) + + if u.Path == telemetryPath && u.Fragment == "" { + assertExpectedPathResponseCode(t, responseRecorder.Code, rawPath) + } else { + assertUnexpectedPathResponseCode(t, responseRecorder.Code, rawPath) + } + }) +} + +func FuzzTelemetryHandler(f *testing.F) { + fuzzSubscriptionAPIHandler(f, telemetryHandlerPath, version20220701) +} + +func fuzzSubscriptionAPIHandler(f *testing.F, handlerPath string, apiVersion string) { + flowTest := testdata.NewFlowTest() + agent := makeExternalAgent(flowTest.RegistrationService) + f.Add([]byte(samplePayload), agent.ID().String(), true) + f.Add([]byte(samplePayload), agent.ID().String(), false) + + f.Fuzz(func(t *testing.T, payload []byte, agentIdentifierHeader string, serviceOn bool) { + telemetrySubscriptionAPI := newMockSubscriptionAPI(serviceOn) + rapiServer := makeRapiServerWithMockSubscriptionAPI(flowTest, telemetrySubscriptionAPI) + + target := makeTargetURL(handlerPath, apiVersion) + request := httptest.NewRequest("PUT", target, bytes.NewReader(payload)) + request.Header.Set(model.LambdaAgentIdentifier, agentIdentifierHeader) + + responseRecorder := serveTestRequest(rapiServer, request) + + if agentIdentifierHeader == "" { + assertForbiddenErrorType(t, responseRecorder, model.ErrAgentIdentifierMissing) + return + } + + if _, err := uuid.Parse(agentIdentifierHeader); err != nil { + assertForbiddenErrorType(t, responseRecorder, model.ErrAgentIdentifierInvalid) + return + } + + if agentIdentifierHeader != agent.ID().String() { + assertForbiddenErrorType(t, responseRecorder, "Extension.UnknownExtensionIdentifier") + return + } + + if !serviceOn { + assertForbiddenErrorType(t, responseRecorder, telemetrySubscriptionAPI.GetServiceClosedErrorType()) + return + } + + assert.Equal(t, payload, telemetrySubscriptionAPI.receivedPayload) + }) +} + +func makeRapiServerWithMockSubscriptionAPI( + flowTest *testdata.FlowTest, + telemetrySubscription telemetry.SubscriptionAPI, +) *Server { + s, err := NewServer( + netip.MustParseAddrPort("127.0.0.1:0"), + flowTest.AppCtx, + flowTest.RegistrationService, + flowTest.RenderingService, + telemetrySubscription, + nil, + ) + if err != nil { + panic(err) + } + return s +} + +type mockSubscriptionAPI struct { + serviceOn bool + receivedPayload []byte +} + +func newMockSubscriptionAPI(serviceOn bool) *mockSubscriptionAPI { + return &mockSubscriptionAPI{ + serviceOn: serviceOn, + } +} + +func (m *mockSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) ([]byte, int, map[string][]string, error) { + if !m.serviceOn { + return nil, 0, map[string][]string{}, telemetry.ErrTelemetryServiceOff + } + + bodyBytes, err := io.ReadAll(body) + if err != nil { + return nil, 0, map[string][]string{}, fmt.Errorf("error Reading the body of subscription request: %s", err) + } + + m.receivedPayload = bodyBytes + + return []byte("OK"), http.StatusOK, map[string][]string{}, nil +} + +func (m *mockSubscriptionAPI) RecordCounterMetric(metricName string, count int) {} + +func (m *mockSubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { + return interop.TelemetrySubscriptionMetrics{} +} + +func (m *mockSubscriptionAPI) Clear() {} + +func (m *mockSubscriptionAPI) TurnOff() {} + +func (m *mockSubscriptionAPI) GetEndpointURL() string { + return "/subscribe" +} + +func (m *mockSubscriptionAPI) GetServiceClosedErrorMessage() string { + return "Subscription API is closed" +} + +func (m *mockSubscriptionAPI) GetServiceClosedErrorType() string { + return "SubscriptionClosed" +} + +func (m *mockSubscriptionAPI) Configure(passphrase string, addr netip.AddrPort) {} diff --git a/internal/lambda-managed-instances/rapid/handlers.go b/internal/lambda-managed-instances/rapid/handlers.go new file mode 100644 index 0000000..9fb77f2 --- /dev/null +++ b/internal/lambda-managed-instances/rapid/handlers.go @@ -0,0 +1,379 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapid + +import ( + "context" + "errors" + "fmt" + "log/slog" + "path" + "strings" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/agents" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/invoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + internalmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + supvmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/supervisor/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils/invariant" +) + +var ErrorInvalidEntryPoint = errors.New("invalid entrypoint, runtime process spawn failed") + +const ( + runtimeProcessName = "runtime" + + maxExtensionNamesLength = 127 +) + +type processSupervisor struct { + supvmodel.ProcessSupervisor +} + +type rapidContext struct { + interopServer interop.Server + initExecutionData interop.InitExecutionData + server *rapi.Server + appCtx appctx.ApplicationContext + supervisor processSupervisor + initFlow core.InitFlowSynchronization + registrationService core.RegistrationService + renderingService *rendering.EventRenderingService + telemetrySubscriptionAPI telemetry.SubscriptionAPI + logsEgressAPI telemetry.StdLogsEgressAPI + eventsAPI interop.EventsAPI + invokeRouter *invoke.InvokeRouter + initMetrics interop.InitMetrics + + shutdownContext *shutdownContext + fileUtils utils.FileUtil + RuntimeStartedTime time.Time + RuntimeOverheadStartedTime time.Time + + processTermChan chan rapidmodel.AppError +} + +var _ interop.RapidContext = (*rapidContext)(nil) + +func (r *rapidContext) ProcessTerminationNotifier() <-chan rapidmodel.AppError { + return r.processTermChan +} + +func (r *rapidContext) GetExtensionNames() string { + var extensionNamesList []string + for _, agent := range r.registrationService.AgentsInfo() { + extensionNamesList = append(extensionNamesList, agent.Name) + } + extensionNames := strings.Join(extensionNamesList, ";") + if len(extensionNames) > maxExtensionNamesLength { + if idx := strings.LastIndex(extensionNames[:maxExtensionNamesLength], ";"); idx != -1 { + return extensionNames[:idx] + } + return "" + } + return extensionNames +} + +func doInitExtensions(ctx context.Context, execCtx *rapidContext) rapidmodel.AppError { + initFlow := execCtx.registrationService.InitFlow() + + bootstraps := agents.ListExternalAgentPaths(execCtx.fileUtils, agents.ExtensionsDir, "/") + + if err := initFlow.SetExternalAgentsRegisterCount(uint16(len(bootstraps))); err != nil { + return rapidmodel.WrapErrorIntoPlatformFatalError(err, rapidmodel.ErrorAgentCountRegistrationFailed) + } + + for _, agentPath := range bootstraps { + + agent, err := execCtx.registrationService.CreateExternalAgent(path.Base(agentPath)) + if err != nil { + return rapidmodel.WrapErrorIntoPlatformFatalError(err, rapidmodel.ErrorAgentExtensionCreationFailed) + } + + if execCtx.registrationService.CountAgents() > core.MaxAgentsAllowed { + if err := agent.LaunchError(rapidmodel.ErrorAgentTooManyExtensions); err != nil { + logging.Warn(ctx, "LaunchError transition fail", "agent", agent, "state", agent.GetState().Name(), "err", err) + } + customerErr := rapidmodel.WrapErrorIntoCustomerInvalidError(nil, rapidmodel.ErrorAgentTooManyExtensions) + appctx.StoreFirstFatalError(execCtx.appCtx, customerErr) + return customerErr + } + + agentStdoutWriter, agentStderrWriter, err := execCtx.logsEgressAPI.GetExtensionSockets() + if err != nil { + return rapidmodel.WrapErrorIntoPlatformFatalError( + fmt.Errorf("failed to get Extension Sockets: %w", err), + rapidmodel.ErrSandboxLogSocketsUnavailable) + } + agentName := fmt.Sprintf("extension-%s", path.Base(agentPath)) + logging.Debug(ctx, "Starting extension", "name", agentName) + execReq := &supvmodel.ExecRequest{ + Name: agentName, + Path: agentPath, + Env: &execCtx.initExecutionData.ExtensionEnv, + Logging: supvmodel.Logging{ + Managed: supvmodel.ManagedLogging{ + Topic: supvmodel.RtExtensionManagedLoggingTopic, + Formats: []supvmodel.ManagedLoggingFormat{ + supvmodel.LineBasedManagedLogging, + }, + }, + }, + StdoutWriter: agentStdoutWriter, + StderrWriter: agentStderrWriter, + } + if err := execCtx.supervisor.Exec(ctx, execReq); err != nil { + logging.Warn(ctx, "Could not exec extension process", "err", err, "agent", agentName) + errorType := core.MapErrorToAgentInfoErrorType(err) + if launchErr := agent.LaunchError(errorType); launchErr != nil { + logging.Warn(ctx, "LaunchError transition fail", "agent", agent, "state", agent.GetState().Name(), "error", launchErr) + } + + customerErr := rapidmodel.WrapErrorIntoCustomerInvalidError(err, errorType) + appctx.StoreFirstFatalError(execCtx.appCtx, customerErr) + return customerErr + } + + execCtx.shutdownContext.createExitedChannel(agentName) + } + + if err := initFlow.AwaitExternalAgentsRegistered(ctx); err != nil { + return resolveGateError(ctx, execCtx.appCtx, err, rapidmodel.ErrorAgentRegistrationFailed) + } + + return nil +} + +func prepareRuntimeBootstrap(execCtx *rapidContext, sbStaticData interop.InitExecutionData) ([]string, internalmodel.KVMap, string, *rapidmodel.CustomerError) { + cmd := sbStaticData.Runtime.ExecConfig.Cmd + env := sbStaticData.Runtime.ExecConfig.Env + cwd := sbStaticData.Runtime.ExecConfig.WorkingDir + + if sbStaticData.StaticData.ArtefactType == internalmodel.ArtefactTypeZIP { + slog.Debug("No bootstrap command provided, searching default locations") + var bootstrap []string + for _, path := range []string{"/var/runtime/bootstrap", "/var/task/bootstrap", "/opt/bootstrap"} { + info, err := execCtx.fileUtils.Stat(path) + if err == nil && !info.IsDir() { + slog.Debug("Found bootstrap", "path", path) + bootstrap = []string{path} + cwd = "/var/task" + break + } + slog.Warn("Ignoring invalid bootstrap path", "path", path, "err", err) + } + + if len(bootstrap) == 0 { + slog.Error("No valid bootstrap binary found in default locations") + return failBootstrap(execCtx, sbStaticData, rapidmodel.InvalidEntrypoint, ErrorInvalidEntryPoint) + } + cmd = bootstrap + } + + if _, err := execCtx.fileUtils.Stat(cwd); err != nil { + slog.Warn("Invalid working directory", "cwd", cwd, "error", err) + return failBootstrap(execCtx, sbStaticData, rapidmodel.InvalidWorkingDir, err) + } + + return cmd, env, cwd, nil +} + +func failBootstrap(execCtx *rapidContext, sbStaticData interop.InitExecutionData, errType rapidmodel.RuntimeExecErrorType, err error) ([]string, internalmodel.KVMap, string, *rapidmodel.CustomerError) { + runtimeErr := &rapidmodel.RuntimeExecError{Type: errType, Err: err} + customerErr := rapidmodel.WrapErrorIntoCustomerInvalidError(err, errType.FatalErrorType()) + appctx.StoreFirstFatalError(execCtx.appCtx, customerErr) + execCtx.eventsAPI.SendImageError(interop.ImageErrorLogData{ + ExecError: *runtimeErr, + ExecConfig: sbStaticData.Runtime.ExecConfig, + }) + return nil, nil, "", &customerErr +} + +func (r *rapidContext) watchEvents(events <-chan supvmodel.Event) { + for event := range events { + slog.Debug("Received event", "event", event) + + switch event.Event.EvType { + case supvmodel.EventLossType: + + invariant.Violatef("Lost %d events from supervisor", *event.Event.Size) + + case supvmodel.ProcessTerminationType: + + r.shutdownContext.processTermination(*event.Event.ProcessTerminated(), r) + } + } +} + +func setupEventsWatcher(ctx context.Context, execCtx *rapidContext) error { + events, err := execCtx.supervisor.Events(ctx) + if err != nil { + + return fmt.Errorf("could not get runtime events stream from supervisor: %w", err) + } + go execCtx.watchEvents(events) + return nil +} + +func runtimeInitWithTelemetry(ctx context.Context, execCtx *rapidContext, phase interop.LifecyclePhase) rapidmodel.AppError { + execCtx.initMetrics.TriggerStartRequest() + telemetry.SendInitStartLogEvent(execCtx.eventsAPI, execCtx.initExecutionData.FunctionMetadata, execCtx.initExecutionData.LogStreamName, phase) + + err := doRuntimeInit(ctx, execCtx, phase) + execCtx.initMetrics.TriggerInitCustomerPhaseDone() + + telemetry.SendAgentsInitStatus(execCtx.eventsAPI, execCtx.registrationService.AgentsInfo()) + execCtx.initMetrics.SetExtensionsNumber(len(execCtx.registrationService.GetInternalAgents()), len(execCtx.registrationService.GetExternalAgents())) + + telemetry.SendInitReportLogEvent(execCtx.eventsAPI, execCtx.appCtx, execCtx.initMetrics.RunDuration(), phase, err) + + logsAPIMetrics := execCtx.telemetrySubscriptionAPI.FlushMetrics() + execCtx.initMetrics.SetLogsAPIMetrics(logsAPIMetrics) + + return err +} + +func doRuntimeInit(ctx context.Context, execCtx *rapidContext, phase interop.LifecyclePhase) rapidmodel.AppError { + if err := doInitExtensions(ctx, execCtx); err != nil { + return err + } + + execCtx.initMetrics.TriggerStartingRuntime() + if err := doInitRuntime(ctx, execCtx, phase); err != nil { + return err + } + execCtx.initMetrics.TriggerRuntimeDone() + + if err := waitExtensionsToBeReady(ctx, execCtx); err != nil { + return err + } + + return nil +} + +func waitExtensionsToBeReady(ctx context.Context, execCtx *rapidContext) rapidmodel.AppError { + initFlow := execCtx.registrationService.InitFlow() + + execCtx.registrationService.TurnOff() + + if err := initFlow.SetAgentsReadyCount(execCtx.registrationService.GetRegisteredAgentsSize()); err != nil { + return rapidmodel.WrapErrorIntoCustomerFatalError(err, rapidmodel.ErrorAgentGateCreationFailed) + } + if err := initFlow.AwaitAgentsReady(ctx); err != nil { + return resolveGateError(ctx, execCtx.appCtx, err, rapidmodel.ErrorAgentReadyFailed) + } + + execCtx.telemetrySubscriptionAPI.TurnOff() + + return nil +} + +func doInitRuntime( + ctx context.Context, + execCtx *rapidContext, + phase interop.LifecyclePhase, +) rapidmodel.AppError { + initFlow := execCtx.registrationService.InitFlow() + runtimeFsm := core.NewRuntime(initFlow) + + logging.Debug(ctx, "Preregister runtime") + registrationService := execCtx.registrationService + if err := registrationService.PreregisterRuntime(runtimeFsm); err != nil { + return rapidmodel.WrapErrorIntoPlatformFatalError(err, rapidmodel.ErrorRuntimeRegistrationFailed) + } + + bootstrapCmd, bootstrapEnv, bootstrapCwd, runtimeErr := prepareRuntimeBootstrap(execCtx, execCtx.initExecutionData) + if runtimeErr != nil { + return *runtimeErr + } + + runtimeStdoutWriter, runtimeStderrWriter, err := execCtx.logsEgressAPI.GetRuntimeSockets() + if err != nil { + return rapidmodel.WrapErrorIntoPlatformFatalError(err, rapidmodel.ErrSandboxLogSocketsUnavailable) + } + + name := runtimeProcessName + slog.Debug("Start runtime", "name", name) + + execReq := &supvmodel.ExecRequest{ + Name: name, + Cwd: &bootstrapCwd, + Path: bootstrapCmd[0], + Args: bootstrapCmd[1:], + Env: &bootstrapEnv, + Logging: supvmodel.Logging{ + Managed: supvmodel.ManagedLogging{ + Topic: supvmodel.RuntimeManagedLoggingTopic, + Formats: execCtx.initExecutionData.RuntimeManagedLoggingFormats, + }, + }, + StdoutWriter: runtimeStdoutWriter, + StderrWriter: runtimeStderrWriter, + } + + if err := execCtx.supervisor.Exec(ctx, execReq); err != nil { + logging.Warn(ctx, "Could not Exec Runtime process", "err", err) + execError := rapidmodel.RuntimeExecError{ + Type: rapidmodel.InvalidEntrypoint, + Err: ErrorInvalidEntryPoint, + } + customerErr := rapidmodel.WrapErrorIntoCustomerInvalidError(ErrorInvalidEntryPoint, execError.Type.FatalErrorType()) + appctx.StoreFirstFatalError(execCtx.appCtx, customerErr) + telemetry.SendImageError(execCtx.eventsAPI, execError, execCtx.initExecutionData.Runtime.ExecConfig) + telemetry.SendInitRuntimeDoneLogEvent(execCtx.eventsAPI, execCtx.appCtx, phase, customerErr) + + return customerErr + } + + execCtx.shutdownContext.createExitedChannel(name) + + if err := initFlow.AwaitRuntimeReady(ctx); err != nil { + customerErr := resolveGateError(ctx, execCtx.appCtx, err, rapidmodel.ErrorRuntimeReadyFailed) + + if err != interop.ErrTimeout { + telemetry.SendInitRuntimeDoneLogEvent(execCtx.eventsAPI, execCtx.appCtx, phase, customerErr) + } + return customerErr + } + + telemetry.SendInitRuntimeDoneLogEvent(execCtx.eventsAPI, execCtx.appCtx, phase, nil) + + return nil +} + +func handleInit(ctx context.Context, execCtx *rapidContext) rapidmodel.AppError { + execCtx.registrationService.SetFunctionMetadata(execCtx.initExecutionData.FunctionMetadata) + if err := setupEventsWatcher(ctx, execCtx); err != nil { + return rapidmodel.WrapErrorIntoPlatformFatalError(err, rapidmodel.ErrSandboxEventSetupFailure) + } + + execCtx.telemetrySubscriptionAPI.Configure(execCtx.initExecutionData.TelemetryPassphrase(), execCtx.initExecutionData.TelemetryAPIAddr()) + + return runtimeInitWithTelemetry(ctx, execCtx, interop.LifecyclePhaseInit) +} + +func resolveGateError(ctx context.Context, appCtx appctx.ApplicationContext, gateErr error, errorType rapidmodel.ErrorType) rapidmodel.AppError { + if fatalError, found := appctx.LoadFirstFatalError(appCtx); found { + logging.Warn(ctx, "Ignoring gate error due to existing fatal error", + "gateError", gateErr, + "existingFatalError", fatalError) + return fatalError + } + + if gateErr == interop.ErrTimeout { + logging.Warn(ctx, "Operation timed out", "err", gateErr, "errorType", errorType) + return rapidmodel.WrapErrorIntoCustomerFatalError(gateErr, rapidmodel.ErrorSandboxTimedout) + } + logging.Warn(ctx, "Operation failed", "err", gateErr, "errorType", errorType) + return rapidmodel.WrapErrorIntoCustomerFatalError(gateErr, errorType) +} diff --git a/internal/lambda-managed-instances/rapid/handlers_test.go b/internal/lambda-managed-instances/rapid/handlers_test.go new file mode 100644 index 0000000..dc37ba7 --- /dev/null +++ b/internal/lambda-managed-instances/rapid/handlers_test.go @@ -0,0 +1,958 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapid + +import ( + "context" + "fmt" + "net/netip" + "os" + "regexp" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + internalmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapidcore/env" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/supervisor/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils" +) + +func TestGetExtensionNamesWithNoExtensions(t *testing.T) { + rs := core.NewRegistrationService(nil) + + c := &rapidContext{ + registrationService: rs, + } + + assert.Equal(t, "", c.GetExtensionNames()) +} + +func TestGetExtensionNamesWithMultipleExtensions(t *testing.T) { + rs := core.NewRegistrationService(nil) + _, _ = rs.CreateExternalAgent("Example1") + _, _ = rs.CreateInternalAgent("Example2") + _, _ = rs.CreateExternalAgent("Example3") + _, _ = rs.CreateInternalAgent("Example4") + + c := &rapidContext{ + registrationService: rs, + } + + r := regexp.MustCompile(`^(Example\d;){3}(Example\d)$`) + assert.True(t, r.MatchString(c.GetExtensionNames())) +} + +func TestGetExtensionNamesWithTooManyExtensions(t *testing.T) { + rs := core.NewRegistrationService(nil) + for i := 10; i < 60; i++ { + _, _ = rs.CreateExternalAgent("E" + strconv.Itoa(i)) + } + + c := &rapidContext{ + registrationService: rs, + } + + output := c.GetExtensionNames() + + r := regexp.MustCompile(`^(E\d\d;){30}(E\d\d)$`) + assert.LessOrEqual(t, len(output), maxExtensionNamesLength) + assert.True(t, r.MatchString(output)) +} + +func TestGetExtensionNamesWithTooLongExtensionName(t *testing.T) { + rs := core.NewRegistrationService(nil) + for i := 10; i < 60; i++ { + _, _ = rs.CreateExternalAgent(strings.Repeat("E", 130)) + } + + c := &rapidContext{ + registrationService: rs, + } + + assert.Equal(t, "", c.GetExtensionNames()) +} + +func makeRapidTestEnv() (runtimeEnv internalmodel.KVMap, extensionEnv internalmodel.KVMap) { + + config := &internalmodel.InitRequestMessage{ + TaskName: "runtime", + AwsRegion: "platform", + AwsKey: "creds", + AwsSecret: "creds", + AwsSession: "creds", + FunctionVersion: "platform", + MemorySizeBytes: 3 * 1024 * 1024, + EnvVars: make(map[string]string), + } + + return env.SetupEnvironment(config, "host:port", "/path") +} + +func makeFileUtils(withExtensions bool) *utils.MockFileUtil { + + mockFileUtil := &utils.MockFileUtil{} + mockFileUtil.On("Stat", mock.Anything).Return(nil, nil) + if withExtensions { + mockFileUtil.On("ReadDirectory", mock.MatchedBy(func(path string) bool { + return strings.Contains(path, "/opt/extensions") + })).Return([]os.DirEntry{ + utils.NewMockDirEntry("NoOp.ext", false), + }, nil) + } + + mockFileUtil.On("ReadDirectory", mock.Anything).Return(nil, nil) + mockFileUtil.On("IsNotExist", mock.Anything).Return(true) + + return mockFileUtil +} + +func makeRapidContext(appCtx appctx.ApplicationContext, initFlow core.InitFlowSynchronization, registrationService core.RegistrationService, supervisor *processSupervisor, fileUtils *utils.MockFileUtil) (*rapidContext, internalmodel.KVMap, internalmodel.KVMap) { + appctx.StoreInteropServer(appCtx, &interop.MockServer{}) + + renderingService := rendering.NewRenderingService() + + runtime := core.NewRuntime(initFlow) + + if err := registrationService.PreregisterRuntime(runtime); err != nil { + panic(err) + } + runtime.SetState(runtime.RuntimeReadyState) + + rapidCtx := &rapidContext{ + + appCtx: appCtx, + initFlow: initFlow, + registrationService: registrationService, + renderingService: renderingService, + shutdownContext: newShutdownContext(), + eventsAPI: &telemetry.NoOpEventsAPI{}, + logsEgressAPI: &telemetry.NoOpLogsEgressAPI{}, + telemetrySubscriptionAPI: &telemetry.NoOpSubscriptionAPI{}, + processTermChan: make(chan rapidmodel.AppError, 20), + fileUtils: fileUtils, + } + if supervisor != nil { + rapidCtx.supervisor = *supervisor + } + + runtimeEnv, extensionEnv := makeRapidTestEnv() + return rapidCtx, runtimeEnv, extensionEnv +} + +func TestSetupEventWatcherRuntimeErrorHandling(t *testing.T) { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow) + mockedProcessSupervisor := &model.MockProcessSupervisor{} + mockedProcessSupervisor.On("Events", mock.Anything).Return(nil, fmt.Errorf("events call failed")) + procSupv := &processSupervisor{ProcessSupervisor: mockedProcessSupervisor} + initMetrics := &interop.MockInitMetrics{} + defer initMetrics.AssertExpectations(t) + + rapidCtx, _, _ := makeRapidContext(appCtx, initFlow, registrationService, procSupv, makeFileUtils(false)) + + sbStaticData := interop.InitExecutionData{ + TelemetrySubscriptionConfig: interop.TelemetrySubscriptionConfig{ + Passphrase: "test-passphrase", + APIAddr: netip.MustParseAddrPort("127.0.0.1:1234"), + }, + } + + ctx := context.Background() + initErr := rapidCtx.HandleInit(ctx, sbStaticData, initMetrics) + assert.NotNil(t, initErr) + + assert.NotNil(t, "We receive a not nil error", initErr.Error()) + assert.Equal(t, rapidmodel.ErrorSeverityFatal, initErr.Severity()) + assert.Equal(t, rapidmodel.ErrorSourceSandbox, initErr.Source()) +} + +func TestInitTimeoutHandling(t *testing.T) { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow) + mockedProcessSupervisor := &model.MockProcessSupervisor{} + eventChan := make(chan model.Event) + mockedProcessSupervisor.On("Events", mock.Anything).Return((<-chan model.Event)(eventChan), nil) + initMetrics := NewInitMetrics(nil) + + mockedProcessSupervisor.On("Exec", mock.Anything, mock.MatchedBy(func(execRequest *model.ExecRequest) bool { + return execRequest.Name == "extension-NoOp.ext" && + execRequest.Path == "/opt/extensions/NoOp.ext" + })).Return(nil) + + mockedProcessSupervisor.On("Exec", mock.Anything, mock.MatchedBy(func(execRequest *model.ExecRequest) bool { + return execRequest.Name == "runtime" + })).WaitUntil(time.After(300 * time.Millisecond)).Return(nil) + procSupv := &processSupervisor{ProcessSupervisor: mockedProcessSupervisor} + + rapidCtx, testEnv, testEnv2 := makeRapidContext(appCtx, initFlow, registrationService, procSupv, makeFileUtils(true)) + + var initTimeoutMs int64 = 200 + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(initTimeoutMs)*time.Millisecond) + defer cancel() + + _, extensionEnv := testEnv, testEnv2 + sbStaticData := interop.InitExecutionData{ + ExtensionEnv: extensionEnv, + TelemetrySubscriptionConfig: interop.TelemetrySubscriptionConfig{ + Passphrase: "test-passphrase", + APIAddr: netip.MustParseAddrPort("127.0.0.1:1234"), + }, + Runtime: rapidmodel.Runtime{ + ExecConfig: rapidmodel.RuntimeExec{ + Cmd: make([]string, 2), + WorkingDir: "", + Env: internalmodel.KVMap{"AWS_REGION": "us-west-2"}, + }, + }, + } + + initErr := rapidCtx.HandleInit(ctx, sbStaticData, initMetrics) + + assert.Equal(t, 1, initMetrics.externalExtensionCount+initMetrics.internalExtensionCount) + assert.Equal(t, "Sandbox.Timedout", string(initErr.ErrorType())) + assert.Equal(t, "Sandbox.Timedout: errTimeout", initErr.Error()) +} + +func TestRuntimeExecFailureOnPlatformError(t *testing.T) { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow) + mockedProcessSupervisor := &model.MockProcessSupervisor{} + defer mockedProcessSupervisor.AssertExpectations(t) + initMetrics := &interop.MockInitMetrics{} + mockInitMetricsRuntimeFail(initMetrics) + defer initMetrics.AssertExpectations(t) + + eventChan := make(chan model.Event) + mockedProcessSupervisor.On("Events", mock.Anything).Return((<-chan model.Event)(eventChan), nil) + err := &model.SupervisorError{ + SourceErr: model.ErrorSourceServer, + ReasonErr: "AnyErrorReason", + CauseErr: "msg", + } + + mockedProcessSupervisor.On("Exec", mock.Anything, mock.MatchedBy(func(execRequest *model.ExecRequest) bool { + return execRequest.Name == "extension-NoOp.ext" && + execRequest.Path == "/opt/extensions/NoOp.ext" + })).Return(nil) + + mockedProcessSupervisor.On("Exec", mock.Anything, mock.Anything).Return(err) + + procSupv := &processSupervisor{ProcessSupervisor: mockedProcessSupervisor} + + rapidCtx, testEnv, testEnv2 := makeRapidContext(appCtx, initFlow, registrationService, procSupv, makeFileUtils(true)) + + _, extensionEnv := testEnv, testEnv2 + sbStaticData := interop.InitExecutionData{ + ExtensionEnv: extensionEnv, + TelemetrySubscriptionConfig: interop.TelemetrySubscriptionConfig{ + Passphrase: "test-passphrase", + APIAddr: netip.MustParseAddrPort("127.0.0.1:1234"), + }, + Runtime: rapidmodel.Runtime{ + ExecConfig: rapidmodel.RuntimeExec{ + Cmd: make([]string, 2), + WorkingDir: "", + Env: internalmodel.KVMap{"FOO": "BAR"}, + }, + }, + } + + ctx := context.Background() + + require.NoError(t, initFlow.SetExternalAgentsRegisterCount(1)) + require.NoError(t, initFlow.ExternalAgentRegistered()) + + initErr := rapidCtx.HandleInit(ctx, sbStaticData, initMetrics) + + assert.Equal(t, rapidmodel.ErrorCategory("Invalid.Runtime"), initErr.ErrorCategory()) + assert.Equal(t, rapidmodel.ErrorRuntimeInvalidEntryPoint, initErr.ErrorType()) +} + +func TestRuntimeExecFailureOnCustomerError(t *testing.T) { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow) + mockedProcessSupervisor := &model.MockProcessSupervisor{} + defer mockedProcessSupervisor.AssertExpectations(t) + + eventChan := make(chan model.Event) + mockedProcessSupervisor.On("Events", mock.Anything).Return((<-chan model.Event)(eventChan), nil) + + mockedProcessSupervisor.On("Exec", mock.Anything, mock.MatchedBy(func(execRequest *model.ExecRequest) bool { + return execRequest.Name == "extension-NoOp.ext" && + execRequest.Path == "/opt/extensions/NoOp.ext" + })).Return(nil).Once() + + initMetrics := &interop.MockInitMetrics{} + mockInitMetricsRuntimeFail(initMetrics) + defer initMetrics.AssertExpectations(t) + + err := &model.SupervisorError{ + SourceErr: model.ErrorSourceFunction, + ReasonErr: "AnyErrorReason", + CauseErr: "msg", + } + + mockedProcessSupervisor.On("Exec", mock.Anything, mock.MatchedBy(func(exec *model.ExecRequest) bool { + return exec.Name == "runtime" + })).Return(err) + + procSupv := &processSupervisor{ProcessSupervisor: mockedProcessSupervisor} + + rapidCtx, testEnv, testEnv2 := makeRapidContext(appCtx, initFlow, registrationService, procSupv, makeFileUtils(true)) + + _, extensionEnv := testEnv, testEnv2 + sbStaticData := interop.InitExecutionData{ + ExtensionEnv: extensionEnv, + TelemetrySubscriptionConfig: interop.TelemetrySubscriptionConfig{ + Passphrase: "test-passphrase", + APIAddr: netip.MustParseAddrPort("127.0.0.1:1234"), + }, + Runtime: rapidmodel.Runtime{ + ExecConfig: rapidmodel.RuntimeExec{ + Cmd: make([]string, 2), + WorkingDir: "", + Env: internalmodel.KVMap{"FOO": "BAR"}, + }, + }, + } + + ctx := context.Background() + + _ = initFlow.SetExternalAgentsRegisterCount(1) + _ = initFlow.ExternalAgentRegistered() + + initErr := rapidCtx.HandleInit(ctx, sbStaticData, initMetrics) + + assert.Equal(t, "Runtime.InvalidEntrypoint", string(initErr.ErrorType())) + assert.Equal(t, "Runtime.InvalidEntrypoint: invalid entrypoint, runtime process spawn failed", initErr.Error()) +} + +func TestRuntimeProcessTerminationWithDifferentCause(t *testing.T) { + testCases := []struct { + name string + TerminationCause model.ProcessTerminationCause + expectedErrorType rapidmodel.ErrorType + }{ + { + name: "runtimeTerminationDueToOOM", + TerminationCause: model.OomKilled, + expectedErrorType: rapidmodel.ErrorRuntimeOutOfMemory, + }, + { + name: "runtimeTerminationWithExitCodeZero", + TerminationCause: model.Exited, + expectedErrorType: rapidmodel.ErrorRuntimeExit, + }, + { + name: "runtimeTerminationSiganled", + TerminationCause: model.Signaled, + expectedErrorType: rapidmodel.ErrorRuntimeExit, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + + registrationService := core.NewRegistrationService(initFlow) + mockedProcessSupervisor := &model.MockProcessSupervisor{} + defer mockedProcessSupervisor.AssertExpectations(t) + + runtimeEvents := make(chan model.Event) + mockedProcessSupervisor.On("Events", mock.Anything).Return((<-chan model.Event)(runtimeEvents), nil) + execDone := make(chan struct{}) + mockedProcessSupervisor.On("Exec", mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { execDone <- struct{}{} }) + + initMetrics := &interop.MockInitMetrics{} + mockInitMetricsRuntimeFail(initMetrics) + defer initMetrics.AssertExpectations(t) + + go func() { + <-execDone + runtimeEvents <- model.Event{ + Time: time.Now().UnixMilli(), + Event: model.EventData{ + EvType: model.ProcessTerminationType, + Name: "runtime", + Cause: tc.TerminationCause, + ExitStatus: new(int32), + }, + } + }() + + procSupv := &processSupervisor{ProcessSupervisor: mockedProcessSupervisor} + + rapidCtx, testEnv, testEnv2 := makeRapidContext(appCtx, initFlow, registrationService, procSupv, makeFileUtils(false)) + + _, extensionEnv := testEnv, testEnv2 + sbStaticData := interop.InitExecutionData{ + ExtensionEnv: extensionEnv, + TelemetrySubscriptionConfig: interop.TelemetrySubscriptionConfig{ + Passphrase: "test-passphrase", + APIAddr: netip.MustParseAddrPort("127.0.0.1:1234"), + }, + Runtime: rapidmodel.Runtime{ + ExecConfig: rapidmodel.RuntimeExec{ + Cmd: make([]string, 2), + WorkingDir: "", + Env: make(internalmodel.KVMap), + }, + }, + } + + ctx := context.Background() + initErr := rapidCtx.HandleInit(ctx, sbStaticData, initMetrics) + assert.Equal(t, tc.expectedErrorType, initErr.ErrorType()) + }) + } +} + +func TestRuntimeExtensionTerminationWithDifferentCause(t *testing.T) { + testCases := []struct { + name string + TerminationCause model.ProcessTerminationCause + expectedErrorType rapidmodel.ErrorType + }{ + { + name: "extensionTerminationDueToOOM", + TerminationCause: model.OomKilled, + expectedErrorType: rapidmodel.ErrorAgentCrash, + }, + { + name: "extensionTerminationWithExitCodeZero", + TerminationCause: model.Exited, + expectedErrorType: rapidmodel.ErrorAgentCrash, + }, + { + name: "extensionTerminationSiganled", + TerminationCause: model.Signaled, + expectedErrorType: rapidmodel.ErrorAgentCrash, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + + registrationService := core.NewRegistrationService(initFlow) + mockedProcessSupervisor := &model.MockProcessSupervisor{} + defer mockedProcessSupervisor.AssertExpectations(t) + + initMetrics := &interop.MockInitMetrics{} + mockInitMetricsExtensionStartFailed(initMetrics) + defer initMetrics.AssertExpectations(t) + + runtimeEvents := make(chan model.Event) + go func() { + runtimeEvents <- model.Event{ + Time: time.Now().UnixMilli(), + Event: model.EventData{ + EvType: model.ProcessTerminationType, + Name: "extension-example1", + Cause: tc.TerminationCause, + ExitStatus: new(int32), + }, + } + }() + mockedProcessSupervisor.On("Events", mock.Anything).Return((<-chan model.Event)(runtimeEvents), nil) + mockedProcessSupervisor.On("Exec", mock.Anything, mock.Anything).Return(nil) + + procSupv := &processSupervisor{ProcessSupervisor: mockedProcessSupervisor} + + rapidCtx, testEnv, testEnv2 := makeRapidContext(appCtx, initFlow, registrationService, procSupv, makeFileUtils(true)) + + _, extensionEnv := testEnv, testEnv2 + sbStaticData := interop.InitExecutionData{ + ExtensionEnv: extensionEnv, + TelemetrySubscriptionConfig: interop.TelemetrySubscriptionConfig{ + Passphrase: "test-passphrase", + APIAddr: netip.MustParseAddrPort("127.0.0.1:1234"), + }, + Runtime: rapidmodel.Runtime{ + ExecConfig: rapidmodel.RuntimeExec{ + Cmd: make([]string, 2), + WorkingDir: "", + Env: make(internalmodel.KVMap), + }, + }, + } + + ctx := context.Background() + initErr := rapidCtx.HandleInit(ctx, sbStaticData, initMetrics) + + assert.Equal(t, tc.expectedErrorType, initErr.ErrorType()) + }) + } +} + +type MockedEventsAPI struct { + mock.Mock + telemetry.NoOpEventsAPI +} + +var _ interop.EventsAPI = (*MockedEventsAPI)(nil) + +func (m *MockedEventsAPI) SendInitReport(report interop.InitReportData) error { + args := m.Called(report) + return args.Error(0) +} + +func (m *MockedEventsAPI) SendInitRuntimeDone(data interop.InitRuntimeDoneData) error { + args := m.Called(data) + return args.Error(0) +} + +func (m *MockedEventsAPI) Flush() { + +} + +type MockInvokeResponseSender struct { + mock.Mock +} + +var _ interop.InvokeResponseSender = (*MockInvokeResponseSender)(nil) + +func (m *MockInvokeResponseSender) SendResponse(invokeID string, response *interop.StreamableInvokeResponse) (*interop.InvokeResponseMetrics, error) { + args := m.Called(invokeID, response) + return args.Get(0).(*interop.InvokeResponseMetrics), args.Error(1) +} + +func (m *MockInvokeResponseSender) SendErrorResponse(invokeID string, response *interop.ErrorInvokeResponse) (*interop.InvokeResponseMetrics, error) { + args := m.Called(invokeID, response) + return args.Get(0).(*interop.InvokeResponseMetrics), args.Error(1) +} + +func (m *MockInvokeResponseSender) SetInvokeError(response *interop.ErrorInvokeResponse) { + m.Called(response) +} + +func (m *MockInvokeResponseSender) WaitForResponse() *interop.InvokeResponseMetrics { + args := m.Called() + return args.Get(0).(*interop.InvokeResponseMetrics) +} + +func TestAgentCountInitResponseTimeout(t *testing.T) { + rapidCtx, registrationService, sbStaticData, _, initMetrics := createRapidContext() + + var initTimeoutMs int64 = 50 + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(initTimeoutMs)*time.Millisecond) + defer cancel() + + numExternalExtensions := 3 + numInternalExtensions := 4 + + err := registerDummyExtensions(registrationService, numInternalExtensions, numExternalExtensions) + assert.NoError(t, err) + + initErr := rapidCtx.HandleInit(ctx, sbStaticData, initMetrics) + assert.NotNil(t, initErr) + + assert.Equal(t, numInternalExtensions, initMetrics.internalExtensionCount) + assert.Equal(t, numExternalExtensions, initMetrics.externalExtensionCount) +} + +func TestAgentCountInitResponseFailure(t *testing.T) { + rapidCtx, registrationService, sbStaticData, _, initMetrics := createRapidContext() + + var initTimeoutMs int64 = 50 + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(initTimeoutMs)*time.Millisecond) + defer cancel() + + numExternalExtensions := 3 + numInternalExtensions := 4 + + err := registerDummyExtensions(registrationService, numInternalExtensions, numExternalExtensions) + assert.NoError(t, err) + + registrationService.CancelFlows(fmt.Errorf("lolError")) + initErr := rapidCtx.HandleInit(ctx, sbStaticData, initMetrics) + assert.NotNil(t, initErr) + assert.Equal(t, numInternalExtensions, initMetrics.internalExtensionCount) + assert.Equal(t, numExternalExtensions, initMetrics.externalExtensionCount) +} + +func TestMultipleNextFromRuntimesDuringInit(t *testing.T) { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow) + mockedProcessSupervisor := &model.MockProcessSupervisor{} + defer mockedProcessSupervisor.AssertExpectations(t) + + initMetrics := &interop.MockInitMetrics{} + mockInitMetricsFullFlow(initMetrics) + defer initMetrics.AssertExpectations(t) + + eventChan := make(chan model.Event) + mockedProcessSupervisor.On("Events", mock.Anything).Return((<-chan model.Event)(eventChan), nil) + + mockedProcessSupervisor.On("Exec", mock.Anything, mock.Anything).Return(nil) + + procSupv := &processSupervisor{ProcessSupervisor: mockedProcessSupervisor} + + rapidCtx, testEnv, testEnv2 := makeRapidContext(appCtx, initFlow, registrationService, procSupv, makeFileUtils(false)) + + _, extensionEnv := testEnv, testEnv2 + sbStaticData := interop.InitExecutionData{ + ExtensionEnv: extensionEnv, + TelemetrySubscriptionConfig: interop.TelemetrySubscriptionConfig{ + Passphrase: "test-passphrase", + APIAddr: netip.MustParseAddrPort("127.0.0.1:1234"), + }, + Runtime: rapidmodel.Runtime{ + ExecConfig: rapidmodel.RuntimeExec{ + Cmd: make([]string, 2), + WorkingDir: "", + Env: internalmodel.KVMap{"FOO": "BAR"}, + }, + }, + } + + ctx := context.Background() + + for range 4 { + _ = registrationService.InitFlow().RuntimeReady() + } + + initErr := rapidCtx.HandleInit(ctx, sbStaticData, initMetrics) + assert.Nil(t, initErr) +} + +func createRapidContext() (*rapidContext, core.RegistrationService, interop.InitExecutionData, chan model.Event, *initMetrics) { + invokeStarted := make(chan bool) + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow) + mockedProcessSupervisor := &model.MockProcessSupervisor{} + eventChan := make(chan model.Event) + mockedProcessSupervisor.On("Events", mock.Anything).Return((<-chan model.Event)(eventChan), nil) + mockedProcessSupervisor.On("Exec", mock.Anything, mock.Anything).Return(nil).Once() + mockedProcessSupervisor.On("Exec", mock.Anything, mock.Anything).Return(nil).Once().Run(func(args mock.Arguments) { invokeStarted <- true }) + mockedProcessSupervisor.On("Kill", mock.Anything, mock.Anything).Return(nil) + procSupv := &processSupervisor{ProcessSupervisor: mockedProcessSupervisor} + + rapidCtx, testEnv, testEnv2 := makeRapidContext(appCtx, initFlow, registrationService, procSupv, makeFileUtils(false)) + + _, extensionEnv := testEnv, testEnv2 + sbStaticData := interop.InitExecutionData{ + ExtensionEnv: extensionEnv, + TelemetrySubscriptionConfig: interop.TelemetrySubscriptionConfig{ + Passphrase: "test-passphrase", + APIAddr: netip.MustParseAddrPort("127.0.0.1:1234"), + }, + Runtime: rapidmodel.Runtime{ + ExecConfig: rapidmodel.RuntimeExec{ + Cmd: make([]string, 2), + WorkingDir: "", + Env: make(internalmodel.KVMap), + }, + }, + } + + initMetrics := NewInitMetrics(nil) + return rapidCtx, registrationService, sbStaticData, eventChan, initMetrics +} + +func mockInitMetricsExtensionStartFailed(initMetrics *interop.MockInitMetrics) { + initMetrics.On("TriggerStartRequest") + initMetrics.On("SetExtensionsNumber", mock.AnythingOfType("int"), mock.AnythingOfType("int")) + initMetrics.On("TriggerInitCustomerPhaseDone") + initMetrics.On("RunDuration").Return(time.Millisecond) + initMetrics.On("SetLogsAPIMetrics", mock.AnythingOfType("interop.TelemetrySubscriptionMetrics")) +} + +func mockInitMetricsRuntimeFail(initMetrics *interop.MockInitMetrics) { + mockInitMetricsExtensionStartFailed(initMetrics) + initMetrics.On("TriggerStartingRuntime") +} + +func mockInitMetricsFullFlow(initMetrics *interop.MockInitMetrics) { + mockInitMetricsRuntimeFail(initMetrics) + initMetrics.On("TriggerRuntimeDone") +} + +func registerDummyExtensions(registrationService core.RegistrationService, numInternalExtensions, numExternalExtensions int) error { + for i := 0; i < numExternalExtensions; i++ { + _, err := registrationService.CreateExternalAgent("external/" + strconv.Itoa(i)) + if err != nil { + return err + } + } + + for i := 0; i < numInternalExtensions; i++ { + _, err := registrationService.CreateInternalAgent("internal/" + strconv.Itoa(i)) + if err != nil { + return err + } + } + + return nil +} + +func TestExecFailureOnPlatformErrorForExtensions(t *testing.T) { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow) + mockedProcessSupervisor := &model.MockProcessSupervisor{} + defer mockedProcessSupervisor.AssertExpectations(t) + + eventChan := make(chan model.Event) + mockedProcessSupervisor.On("Events", mock.Anything).Return((<-chan model.Event)(eventChan), nil) + err := &model.SupervisorError{ + SourceErr: model.ErrorSourceServer, + ReasonErr: "AnyErrorReason", + CauseErr: "msg", + } + mockedProcessSupervisor.On("Exec", mock.Anything, mock.MatchedBy(func(exec *model.ExecRequest) bool { + return exec.Name == "extension-NoOp.ext" + })).Return(err).Once() + + initMetrics := &interop.MockInitMetrics{} + mockInitMetricsExtensionStartFailed(initMetrics) + defer initMetrics.AssertExpectations(t) + + procSupv := &processSupervisor{ProcessSupervisor: mockedProcessSupervisor} + + rapidCtx, testEnv, testEnv2 := makeRapidContext(appCtx, initFlow, registrationService, procSupv, makeFileUtils(true)) + + _, extensionEnv := testEnv, testEnv2 + sbStaticData := interop.InitExecutionData{ + ExtensionEnv: extensionEnv, + TelemetrySubscriptionConfig: interop.TelemetrySubscriptionConfig{ + Passphrase: "test-passphrase", + APIAddr: netip.MustParseAddrPort("127.0.0.1:1234"), + }, + Runtime: rapidmodel.Runtime{ + ExecConfig: rapidmodel.RuntimeExec{ + Cmd: make([]string, 2), + WorkingDir: "", + Env: internalmodel.KVMap{"FOO": "BAR"}, + }, + }, + } + + ctx := context.Background() + + require.NoError(t, initFlow.SetExternalAgentsRegisterCount(1)) + require.NoError(t, initFlow.ExternalAgentRegistered()) + + initErr := rapidCtx.HandleInit(ctx, sbStaticData, initMetrics) + assert.NotNil(t, initErr) + + assert.Equal(t, "Extension.LaunchError", string(initErr.ErrorType())) + assert.Equal(t, "Invalid.Runtime", string(initErr.ErrorCategory())) +} + +func TestPrepareRuntimeBootstrap(t *testing.T) { + testCases := []struct { + name string + cmd []string + workingDir string + env internalmodel.KVMap + artefactType internalmodel.ArtefactType + setupMocks func(*utils.MockFileUtil) + expectedCmd []string + expectedEnv internalmodel.KVMap + expectedCwd string + expectedErrorType rapidmodel.ErrorType + shouldHaveError bool + }{ + { + name: "ZIP: Bootstrap order - first location exists", + cmd: []string{}, + workingDir: "/var/task", + env: internalmodel.KVMap{"KEY": "VALUE"}, + artefactType: internalmodel.ArtefactTypeZIP, + setupMocks: func(mockFileUtil *utils.MockFileUtil) { + fileInfo := utils.NewMockFileInfo() + fileInfo.On("IsDir").Return(false) + mockFileUtil.On("Stat", "/var/runtime/bootstrap").Return(fileInfo, nil) + mockFileUtil.On("Stat", "/var/task").Return(fileInfo, nil) + }, + expectedCmd: []string{"/var/runtime/bootstrap"}, + expectedEnv: internalmodel.KVMap{"KEY": "VALUE"}, + expectedCwd: "/var/task", + shouldHaveError: false, + }, + { + name: "ZIP: Bootstrap order - second location exists", + cmd: []string{}, + workingDir: "/var/task", + env: internalmodel.KVMap{"KEY": "VALUE"}, + artefactType: internalmodel.ArtefactTypeZIP, + setupMocks: func(mockFileUtil *utils.MockFileUtil) { + fileInfo := utils.NewMockFileInfo() + fileInfo.On("IsDir").Return(false) + + mockFileUtil.On("Stat", "/var/runtime/bootstrap").Return(nil, fmt.Errorf("not found")) + mockFileUtil.On("Stat", "/var/task/bootstrap").Return(fileInfo, nil) + mockFileUtil.On("Stat", "/var/task").Return(fileInfo, nil) + }, + expectedCmd: []string{"/var/task/bootstrap"}, + expectedEnv: internalmodel.KVMap{"KEY": "VALUE"}, + expectedCwd: "/var/task", + shouldHaveError: false, + }, + { + name: "ZIP: Bootstrap order - third location exists", + cmd: []string{}, + workingDir: "/var/task", + env: internalmodel.KVMap{"KEY": "VALUE"}, + artefactType: internalmodel.ArtefactTypeZIP, + setupMocks: func(mockFileUtil *utils.MockFileUtil) { + fileInfo := utils.NewMockFileInfo() + fileInfo.On("IsDir").Return(false) + + mockFileUtil.On("Stat", "/var/runtime/bootstrap").Return(nil, fmt.Errorf("not found")) + + mockFileUtil.On("Stat", "/var/task/bootstrap").Return(nil, fmt.Errorf("not found")) + mockFileUtil.On("Stat", "/opt/bootstrap").Return(fileInfo, nil) + mockFileUtil.On("Stat", "/var/task").Return(fileInfo, nil) + }, + expectedCmd: []string{"/opt/bootstrap"}, + expectedEnv: internalmodel.KVMap{"KEY": "VALUE"}, + expectedCwd: "/var/task", + shouldHaveError: false, + }, + { + name: "ZIP: Bootstrap order preference - multiple bootstrap files exist", + cmd: []string{}, + workingDir: "/var/task", + env: internalmodel.KVMap{"KEY": "VALUE"}, + artefactType: internalmodel.ArtefactTypeZIP, + setupMocks: func(mockFileUtil *utils.MockFileUtil) { + fileInfo := utils.NewMockFileInfo() + fileInfo.On("IsDir").Return(false) + + mockFileUtil.On("Stat", "/var/runtime/bootstrap").Return(fileInfo, nil) + mockFileUtil.On("Stat", "/var/task/bootstrap").Return(fileInfo, nil) + mockFileUtil.On("Stat", "/opt/bootstrap").Return(fileInfo, nil) + mockFileUtil.On("Stat", "/var/task").Return(fileInfo, nil) + }, + + expectedCmd: []string{"/var/runtime/bootstrap"}, + expectedEnv: internalmodel.KVMap{"KEY": "VALUE"}, + expectedCwd: "/var/task", + shouldHaveError: false, + }, + { + name: "ZIP: No bootstrap command and no valid default location", + cmd: []string{}, + workingDir: "/var/task", + env: internalmodel.KVMap{"KEY": "VALUE"}, + artefactType: internalmodel.ArtefactTypeZIP, + setupMocks: func(mockFileUtil *utils.MockFileUtil) { + fileInfo := utils.NewMockFileInfo() + fileInfo.On("IsDir").Return(false) + mockFileUtil.On("Stat", "/var/task").Return(fileInfo, nil) + mockFileUtil.On("Stat", "/var/runtime/bootstrap").Return(nil, fmt.Errorf("not found")) + mockFileUtil.On("Stat", "/var/task/bootstrap").Return(nil, fmt.Errorf("not found")) + mockFileUtil.On("Stat", "/opt/bootstrap").Return(nil, fmt.Errorf("not found")) + }, + expectedCmd: nil, + expectedEnv: nil, + expectedCwd: "", + expectedErrorType: rapidmodel.ErrorRuntimeInvalidEntryPoint, + shouldHaveError: true, + }, + { + name: "Invalid working directory", + cmd: []string{"cmd", "arg"}, + workingDir: "/non/existent/dir", + env: internalmodel.KVMap{"KEY": "VALUE"}, + artefactType: internalmodel.ArtefactTypeOCI, + setupMocks: func(mockFileUtil *utils.MockFileUtil) { + mockFileUtil.On("Stat", "/non/existent/dir").Return(nil, fmt.Errorf("directory does not exist")) + }, + expectedCmd: nil, + expectedEnv: nil, + expectedCwd: "", + expectedErrorType: rapidmodel.ErrorRuntimeInvalidWorkingDir, + shouldHaveError: true, + }, + { + name: "OCI: No bootstrap command (should not look for default locations)", + cmd: []string{}, + workingDir: "/var/task", + env: internalmodel.KVMap{"KEY": "VALUE"}, + artefactType: internalmodel.ArtefactTypeOCI, + setupMocks: func(mockFileUtil *utils.MockFileUtil) { + fileInfo := utils.NewMockFileInfo() + fileInfo.On("IsDir").Return(false) + mockFileUtil.On("Stat", "/var/task").Return(fileInfo, nil) + }, + expectedCmd: []string{}, + expectedEnv: internalmodel.KVMap{"KEY": "VALUE"}, + expectedCwd: "/var/task", + shouldHaveError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow) + + mockEventsAPI := &MockedEventsAPI{} + mockEventsAPI.On("SendImageError", mock.Anything).Return(nil) + + mockFileUtil := &utils.MockFileUtil{} + tc.setupMocks(mockFileUtil) + + rapidCtx := &rapidContext{ + appCtx: appCtx, + initFlow: initFlow, + registrationService: registrationService, + eventsAPI: mockEventsAPI, + fileUtils: mockFileUtil, + } + + sbStaticData := interop.InitExecutionData{ + Runtime: rapidmodel.Runtime{ + ExecConfig: rapidmodel.RuntimeExec{ + Cmd: tc.cmd, + Env: tc.env, + WorkingDir: tc.workingDir, + }, + }, + StaticData: interop.EEStaticData{ArtefactType: tc.artefactType}, + } + + cmd, env, cwd, err := prepareRuntimeBootstrap(rapidCtx, sbStaticData) + + if tc.shouldHaveError { + assert.NotNil(t, err) + assert.Equal(t, tc.expectedErrorType, err.ErrorType()) + assert.Empty(t, cmd) + assert.Empty(t, env) + assert.Empty(t, cwd) + } else { + assert.Equal(t, tc.expectedCmd, cmd) + assert.Equal(t, tc.expectedEnv, env) + assert.Equal(t, tc.expectedCwd, cwd) + } + }) + } +} diff --git a/internal/lambda-managed-instances/rapid/init_metrics.go b/internal/lambda-managed-instances/rapid/init_metrics.go new file mode 100644 index 0000000..a674765 --- /dev/null +++ b/internal/lambda-managed-instances/rapid/init_metrics.go @@ -0,0 +1,158 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapid + +import ( + "fmt" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/servicelogs" +) + +const ( + InitTimeoutProperty = "InitTimeoutSeconds" + + runtimeDurationMetric = "RuntimeDuration" + runDurationMetric = "RunDuration" +) + +type initMetrics struct { + logger servicelogs.Logger + + error model.AppError + + timeGetRequest time.Time + + timeStartRequest time.Time + timeStartedRuntime time.Time + timeRuntimeDone time.Time + timeCustomerPhaseDone time.Time + timeInitDone time.Time + + internalExtensionCount int + externalExtensionCount int + + logsAPIMetrics interop.TelemetrySubscriptionMetrics + + getCurrentTime func() time.Time +} + +func NewInitMetrics(logger servicelogs.Logger) *initMetrics { + return &initMetrics{ + getCurrentTime: time.Now, + logger: logger, + } +} + +func (m *initMetrics) TriggerGetRequest() { + m.timeGetRequest = m.getCurrentTime() +} + +func (m *initMetrics) SetLogsAPIMetrics(metrics interop.TelemetrySubscriptionMetrics) { + m.logsAPIMetrics = metrics +} + +func (m *initMetrics) SetExtensionsNumber(internal, external int) { + m.internalExtensionCount = internal + m.externalExtensionCount = external +} + +func (m *initMetrics) TriggerStartRequest() { + m.timeStartRequest = m.getCurrentTime() +} + +func (m *initMetrics) TriggerStartingRuntime() { + m.timeStartedRuntime = m.getCurrentTime() +} + +func (m *initMetrics) TriggerRuntimeDone() { + m.timeRuntimeDone = m.getCurrentTime() +} + +func (m *initMetrics) TriggerInitCustomerPhaseDone() { + m.timeCustomerPhaseDone = m.getCurrentTime() +} + +func (m *initMetrics) TriggerInitDone(err model.AppError) { + m.error = err + m.timeInitDone = m.getCurrentTime() +} + +func (m *initMetrics) RunDuration() time.Duration { + if m.timeCustomerPhaseDone.IsZero() { + return time.Duration(0) + } + + return m.timeCustomerPhaseDone.Sub(m.timeStartRequest) +} + +func (m *initMetrics) SendMetrics() error { + metrics := m.buildMetrics() + + m.logger.Log(servicelogs.InitOp, m.timeGetRequest, nil, nil, metrics) + return nil +} + +func (m *initMetrics) buildMetrics() []servicelogs.Metric { + totalDuration := m.timeInitDone.Sub(m.timeGetRequest) + runDuration := m.RunDuration() + overheadDuration := totalDuration - runDuration + + metrics := []servicelogs.Metric{ + servicelogs.Timer(interop.TotalDurationMetric, totalDuration), + servicelogs.Timer(interop.PlatformOverheadDurationMetric, overheadDuration), + servicelogs.Counter(interop.InternalExtensionsCountMetric, float64(m.internalExtensionCount)), + servicelogs.Counter(interop.ExternalExtensionsCountMetric, float64(m.externalExtensionCount)), + servicelogs.Counter(interop.TotalExtensionsCountMetric, float64(m.internalExtensionCount+m.externalExtensionCount)), + } + + if !m.timeRuntimeDone.IsZero() { + metrics = append(metrics, + servicelogs.Timer(runtimeDurationMetric, m.timeRuntimeDone.Sub(m.timeStartedRuntime)), + ) + } + + if runDuration > 0 { + metrics = append(metrics, + servicelogs.Timer(runDurationMetric, runDuration), + ) + } + + var clientErrCnt, customerErrCnt, platformErrCnt, nonCustomerErrCnt float64 + + switch m.error.(type) { + case model.ClientError: + clientErrCnt = 1 + nonCustomerErrCnt = 1 + metrics = append(metrics, + servicelogs.Counter(fmt.Sprintf(interop.ClientErrorReasonTemplate, m.error.ErrorType()), 1.0), + ) + case model.CustomerError: + customerErrCnt = 1 + metrics = append(metrics, + servicelogs.Counter(fmt.Sprintf(interop.CustomerErrorReasonTemplate, m.error.ErrorType()), 1.0), + ) + case model.PlatformError: + platformErrCnt = 1 + nonCustomerErrCnt = 1 + metrics = append(metrics, + servicelogs.Counter(fmt.Sprintf(interop.PlatformErrorReasonTemplate, m.error.ErrorType()), 1.0), + ) + } + + metrics = append(metrics, + servicelogs.Counter(interop.ClientErrorMetric, clientErrCnt), + servicelogs.Counter(interop.CustomerErrorMetric, customerErrCnt), + servicelogs.Counter(interop.PlatformErrorMetric, platformErrCnt), + servicelogs.Counter(interop.NonCustomerErrorMetric, nonCustomerErrCnt), + ) + + for metricName, value := range m.logsAPIMetrics { + metrics = append(metrics, servicelogs.Counter(metricName, float64(value))) + } + + return metrics +} diff --git a/internal/lambda-managed-instances/rapid/init_metrics_test.go b/internal/lambda-managed-instances/rapid/init_metrics_test.go new file mode 100644 index 0000000..b0e4986 --- /dev/null +++ b/internal/lambda-managed-instances/rapid/init_metrics_test.go @@ -0,0 +1,213 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapid + +import ( + "cmp" + "fmt" + "slices" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + intmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/servicelogs" +) + +type initMetricsMocks struct { + timeStamp time.Time + error model.AppError + initData interop.MockInitStaticDataProvider + logger servicelogs.MockLogger +} + +func checkMocksExpectations(t *testing.T, mocks *initMetricsMocks) { + mocks.initData.AssertExpectations(t) + mocks.logger.AssertExpectations(t) +} + +func Test_initMetrics(t *testing.T) { + tests := []struct { + name string + metricFlow func(m *initMetrics, mocks *initMetricsMocks) + expectedMetrics []servicelogs.Metric + expectedRunDuration time.Duration + }{ + { + name: "malformed_request_flow", + metricFlow: func(m *initMetrics, mocks *initMetricsMocks) { + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + mocks.error = model.NewClientError(nil, model.ErrorSeverityInvalid, model.ErrorMalformedRequest) + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "TotalExtensionsCount", Value: 0}, + {Type: servicelogs.CounterType, Key: "InternalExtensionsCount", Value: 0}, + {Type: servicelogs.CounterType, Key: "ExternalExtensionsCount", Value: 0}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "ClientError", Value: 1}, + {Type: servicelogs.CounterType, Key: "CustomerError", Value: 0}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 0}, + {Type: servicelogs.CounterType, Key: "ClientErrorReason-ErrMalformedRequest", Value: 1}, + {Type: servicelogs.CounterType, Key: "NonCustomerError", Value: 1}, + }, + }, + { + name: "init_platform_error", + metricFlow: func(m *initMetrics, mocks *initMetricsMocks) { + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + m.TriggerStartRequest() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + m.TriggerInitCustomerPhaseDone() + m.SetExtensionsNumber(2, 3) + mocks.error = model.NewPlatformError(nil, model.ErrorAgentCountRegistrationFailed) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 3000000}, + {Type: servicelogs.CounterType, Key: "TotalExtensionsCount", Value: 5}, + {Type: servicelogs.CounterType, Key: "InternalExtensionsCount", Value: 2}, + {Type: servicelogs.CounterType, Key: "ExternalExtensionsCount", Value: 3}, + {Type: servicelogs.TimerType, Key: "RunDuration", Value: 1000000}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 2000000}, + {Type: servicelogs.CounterType, Key: "ClientError", Value: 0}, + {Type: servicelogs.CounterType, Key: "CustomerError", Value: 0}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 1}, + {Type: servicelogs.CounterType, Key: "PlatformErrorReason-Extension.CountRegistrationFailed", Value: 1}, + {Type: servicelogs.CounterType, Key: "NonCustomerError", Value: 1}, + }, + expectedRunDuration: time.Second, + }, + { + name: "init_runtime_failed", + metricFlow: func(m *initMetrics, mocks *initMetricsMocks) { + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + m.TriggerStartRequest() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + m.TriggerStartingRuntime() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + m.TriggerInitCustomerPhaseDone() + m.SetExtensionsNumber(2, 3) + mocks.error = model.NewCustomerError(model.ErrorRuntimeInvalidWorkingDir, model.WithSeverity(model.ErrorSeverityInvalid)) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 4000000}, + {Type: servicelogs.CounterType, Key: "TotalExtensionsCount", Value: 5}, + {Type: servicelogs.CounterType, Key: "InternalExtensionsCount", Value: 2}, + {Type: servicelogs.CounterType, Key: "ExternalExtensionsCount", Value: 3}, + {Type: servicelogs.TimerType, Key: "RunDuration", Value: 2000000}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 2000000}, + {Type: servicelogs.CounterType, Key: "ClientError", Value: 0}, + {Type: servicelogs.CounterType, Key: "CustomerError", Value: 1}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 0}, + {Type: servicelogs.CounterType, Key: "CustomerErrorReason-Runtime.InvalidWorkingDir", Value: 1}, + {Type: servicelogs.CounterType, Key: "NonCustomerError", Value: 0}, + }, + expectedRunDuration: 2 * time.Second, + }, + { + name: "init_full_flow", + metricFlow: func(m *initMetrics, mocks *initMetricsMocks) { + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + m.TriggerStartRequest() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + m.TriggerStartingRuntime() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + m.TriggerRuntimeDone() + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + m.TriggerInitCustomerPhaseDone() + m.SetExtensionsNumber(2, 3) + m.SetLogsAPIMetrics(map[string]int{ + "logs_api_subscribe_success": 2, + "logs_api_subscribe_client_err": 1, + "logs_api_subscribe_server_err": 0, + "logs_api_num_subscribers": 2, + }) + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 5000000}, + {Type: servicelogs.CounterType, Key: "TotalExtensionsCount", Value: 5}, + {Type: servicelogs.CounterType, Key: "InternalExtensionsCount", Value: 2}, + {Type: servicelogs.CounterType, Key: "ExternalExtensionsCount", Value: 3}, + {Type: servicelogs.TimerType, Key: "RuntimeDuration", Value: 1000000}, + {Type: servicelogs.TimerType, Key: "RunDuration", Value: 3000000}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 2000000}, + {Type: servicelogs.CounterType, Key: "ClientError", Value: 0}, + {Type: servicelogs.CounterType, Key: "CustomerError", Value: 0}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 0}, + {Type: servicelogs.CounterType, Key: "NonCustomerError", Value: 0}, + {Type: servicelogs.CounterType, Key: "logs_api_subscribe_success", Value: 2}, + {Type: servicelogs.CounterType, Key: "logs_api_subscribe_client_err", Value: 1}, + {Type: servicelogs.CounterType, Key: "logs_api_subscribe_server_err", Value: 0}, + {Type: servicelogs.CounterType, Key: "logs_api_num_subscribers", Value: 2}, + }, + expectedRunDuration: 3 * time.Second, + }, + } + + metricsSortFunc := func(a, b servicelogs.Metric) int { + return cmp.Compare(a.Key, b.Key) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mocks := &initMetricsMocks{ + initData: interop.MockInitStaticDataProvider{}, + logger: servicelogs.MockLogger{}, + } + + m := NewInitMetrics(&mocks.logger) + m.getCurrentTime = func() time.Time { + return mocks.timeStamp + } + + mocks.initData.On("InitTimeout").Return(time.Second).Maybe() + mocks.initData.On("MemorySizeMB").Return(uint64(128)).Maybe() + mocks.initData.On("FunctionARN").Return("function-arn").Maybe() + mocks.initData.On("FunctionVersionID").Return("function-version-id").Maybe() + mocks.initData.On("RuntimeVersion").Return("python3.9").Maybe() + mocks.initData.On("ArtefactType").Return(intmodel.ArtefactTypeOCI).Maybe() + mocks.initData.On("AmiId").Return("ami-1234567").Maybe() + mocks.initData.On("AvailabilityZoneId").Return("us-west-2").Maybe() + + mocks.logger.On("Log", + mock.MatchedBy(func(op servicelogs.Operation) bool { + return assert.Equal(t, servicelogs.InitOp, op) + }), + mock.AnythingOfType("time.Time"), + []servicelogs.Tuple(nil), + []servicelogs.Tuple(nil), + mock.MatchedBy(func(metrics []servicelogs.Metric) bool { + slices.SortFunc(metrics, metricsSortFunc) + slices.SortFunc(tt.expectedMetrics, metricsSortFunc) + assert.Equal(t, len(tt.expectedMetrics), len(metrics)) + for i := range len(tt.expectedMetrics) { + require.Equal(t, tt.expectedMetrics[i].Key, metrics[i].Key) + require.Equal(t, tt.expectedMetrics[i].Type, metrics[i].Type, fmt.Sprintf("wrong format for %s", metrics[i].Key)) + require.Equal(t, tt.expectedMetrics[i].Value, metrics[i].Value, fmt.Sprintf("wrong value for %s", metrics[i].Key)) + } + + return true + }), + ).Once() + + mocks.timeStamp = time.Now() + + m.TriggerGetRequest() + tt.metricFlow(m, mocks) + m.TriggerInitDone(mocks.error) + + require.NoError(t, m.SendMetrics()) + assert.Equal(t, tt.expectedRunDuration, m.RunDuration()) + checkMocksExpectations(t, mocks) + }) + } +} diff --git a/internal/lambda-managed-instances/rapid/model/client_error.go b/internal/lambda-managed-instances/rapid/model/client_error.go new file mode 100644 index 0000000..a67e2a1 --- /dev/null +++ b/internal/lambda-managed-instances/rapid/model/client_error.go @@ -0,0 +1,36 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import "net/http" + +const ( + ErrorInvalidRequest ErrorType = "InvalidRequest" + ErrorMalformedRequest ErrorType = "ErrMalformedRequest" + ErrorInitIncomplete ErrorType = "Client.InitIncomplete" + ErrorEnvironmentUnhealthy ErrorType = "Client.ExecutionEnvironmentUnhealthy" + ErrorRuntimeUnavailable ErrorType = "Runtime.Unavailable" + ErrorDublicatedInvokeId ErrorType = "Client.DuplicatedInvokeId" + ErrorInvalidFunctionVersion ErrorType = "ErrInvalidFunctionVersion" + ErrorInvalidMaxPayloadSize ErrorType = "ErrInvalidMaxPayloadSize" + ErrorInvalidResponseBandwidthRate ErrorType = "ErrInvalidResponseBandwidthRate" + ErrorInvalidResponseBandwidthBurstSize ErrorType = "ErrInvalidResponseBandwidthBurstSize" + ErrorExecutionEnvironmentShutdown ErrorType = "Client.ExecutionEnvironmentShutDown" +) + +type ClientError struct { + *appError +} + +func NewClientError(cause error, severity ErrorSeverity, errorType ErrorType) ClientError { + return ClientError{ + appError: &appError{ + cause: cause, + severity: severity, + source: ErrorSourceClient, + errorType: errorType, + code: http.StatusBadRequest, + }, + } +} diff --git a/internal/lambda-managed-instances/rapid/model/client_error_test.go b/internal/lambda-managed-instances/rapid/model/client_error_test.go new file mode 100644 index 0000000..c2361d2 --- /dev/null +++ b/internal/lambda-managed-instances/rapid/model/client_error_test.go @@ -0,0 +1,33 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestClientError(t *testing.T) { + testCases := []struct { + name string + clientError ClientError + expectedError string + }{ + { + name: "New Client Error", + clientError: NewClientError(errors.New("Invalid Request"), ErrorSeverityFatal, ErrorInvalidRequest), + expectedError: "InvalidRequest: Invalid Request", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expectedError, tc.clientError.Error()) + assert.Equal(t, tc.clientError.Severity(), ErrorSeverityFatal) + assert.Equal(t, tc.clientError.Source(), ErrorSourceClient) + }) + } +} diff --git a/internal/lambda-managed-instances/rapid/model/customer_error.go b/internal/lambda-managed-instances/rapid/model/customer_error.go new file mode 100644 index 0000000..380370a --- /dev/null +++ b/internal/lambda-managed-instances/rapid/model/customer_error.go @@ -0,0 +1,94 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "net/http" +) + +const ( + ErrorAgentPermissionDenied ErrorType = "PermissionDenied" + ErrorAgentExtensionLaunch ErrorType = "Extension.LaunchError" + ErrorAgentTooManyExtensions ErrorType = "TooManyExtensions" + ErrorAgentInit ErrorType = "Extension.InitError" + ErrorAgentExit ErrorType = "Extension.ExitError" + ErrorAgentCrash ErrorType = "Extension.Crash" + ErrorAgentUnknown ErrorType = "Unknown" + + ErrorRuntimeInit ErrorType = "Runtime.InitError" + ErrorRuntimeInvalidWorkingDir ErrorType = "Runtime.InvalidWorkingDir" + ErrorRuntimeInvalidEntryPoint ErrorType = "Runtime.InvalidEntrypoint" + ErrorRuntimeExit ErrorType = "Runtime.ExitError" + ErrorRuntimeUnknown ErrorType = "Runtime.Unknown" + ErrorRuntimeOutOfMemory ErrorType = "Runtime.OutOfMemory" + ErrorRuntimeTruncatedResponse ErrorType = "Runtime.TruncatedResponse" + ErrorRuntimeInvalidResponseModeHeader ErrorType = "Runtime.InvalidResponseModeHeader" + + ErrorRuntimeInvokeResponseInProgress ErrorType = "Runtime.InvokeResponseInProgress" + ErrorRuntimeInvokeErrorInProgress ErrorType = "Runtime.ErrorResponseInProgress" + ErrorRuntimeInvokeResponseWasSent ErrorType = "Runtime.InvokeResponseWasSent" + ErrorRuntimeInvalidInvokeId ErrorType = "Runtime.InvalidInvokeId" + ErrorRuntimeInvokeTimeout ErrorType = "Runtime.InvokeTimeout" + ErrorRuntimeTooManyIdleRuntimes ErrorType = "Runtime.TooManyIdleRuntimes" + + ErrorSandboxTimedout ErrorType = "Sandbox.Timedout" + ErrorSandboxFailure ErrorType = "Sandbox.Failure" + ErrorSandboxTimeoutResponseTrailer ErrorType = "Sandbox.TimeoutResponseTrailer" + + ErrorFunctionOversizedResponse ErrorType = "Function.ResponseSizeTooLarge" + ErrorFunctionUnknown ErrorType = "Function.Unknown" +) + +type CustomerError struct { + *appError +} + +type ErrorOption func(err *appError) + +func WithErrorMessage(msg string) ErrorOption { + return func(err *appError) { + err.errorMessage = msg + } +} + +func WithSeverity(sev ErrorSeverity) ErrorOption { + return func(err *appError) { + err.severity = sev + } +} + +func WithCause(cause error) ErrorOption { + return func(err *appError) { + err.cause = cause + } +} + +func NewCustomerError(errorType ErrorType, opts ...ErrorOption) CustomerError { + err := appError{ + source: ErrorSourceRuntime, + severity: ErrorSeverityError, + errorType: errorType, + code: http.StatusOK, + } + for _, option := range opts { + option(&err) + } + return CustomerError{&err} +} + +func WrapErrorIntoCustomerInvalidError(e error, errorType ErrorType) CustomerError { + + if err, ok := e.(CustomerError); ok { + return err + } + return NewCustomerError(errorType, WithCause(e), WithSeverity(ErrorSeverityInvalid)) +} + +func WrapErrorIntoCustomerFatalError(e error, errorType ErrorType) CustomerError { + + if err, ok := e.(CustomerError); ok { + return err + } + return NewCustomerError(errorType, WithCause(e), WithSeverity(ErrorSeverityFatal)) +} diff --git a/internal/lambda-managed-instances/rapid/model/customer_error_test.go b/internal/lambda-managed-instances/rapid/model/customer_error_test.go new file mode 100644 index 0000000..de20937 --- /dev/null +++ b/internal/lambda-managed-instances/rapid/model/customer_error_test.go @@ -0,0 +1,36 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCustomerError(t *testing.T) { + testCases := []struct { + name string + customerError CustomerError + expectedError string + }{ + { + name: "WrapErrorIntoCustomerInvalidError", + customerError: WrapErrorIntoCustomerInvalidError(errors.New("permission denied"), ErrorAgentPermissionDenied), + expectedError: "PermissionDenied: permission denied", + }, + { + name: "WrapErrorIntoCustomerFatalError", + customerError: WrapErrorIntoCustomerFatalError(errors.New("runtime initialization failed"), ErrorRuntimeInit), + expectedError: "Runtime.InitError: runtime initialization failed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expectedError, tc.customerError.Error()) + }) + } +} diff --git a/internal/lambda-managed-instances/rapid/model/error_types.go b/internal/lambda-managed-instances/rapid/model/error_types.go new file mode 100644 index 0000000..81000fb --- /dev/null +++ b/internal/lambda-managed-instances/rapid/model/error_types.go @@ -0,0 +1,144 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils/invariant" +) + +type ErrorSeverity string + +const ( + ErrorSeverityError ErrorSeverity = "Error" + ErrorSeverityFatal ErrorSeverity = "Fatal" + ErrorSeverityInvalid ErrorSeverity = "Invalid" +) + +type ErrorSource string + +const ( + ErrorSourceClient ErrorSource = "Client" + ErrorSourceSandbox ErrorSource = "Sandbox" + ErrorSourceRuntime ErrorSource = "Runtime" +) + +type ( + ErrorType string + ErrorCategory string +) + +const ( + ErrorReasonUnknownError ErrorType = "UnknownError" + ErrorCategoryReasonUnknown ErrorCategory = "Fatal.Sandbox" + ErrorCategoryClientCaused ErrorCategory = "Invalid.Client" +) + +func (e ErrorType) String() string { + return string(e) +} + +func (e ErrorCategory) String() string { + return string(e) +} + +type AppError interface { + error + + Severity() ErrorSeverity + + Source() ErrorSource + + ErrorCategory() ErrorCategory + + ErrorType() ErrorType + + Unwrap() error + + ReturnCode() int + + ErrorDetails() string +} + +type appError struct { + cause error + severity ErrorSeverity + source ErrorSource + errorType ErrorType + code int + + errorMessage string +} + +func (e *appError) Severity() ErrorSeverity { + return e.severity +} + +func (e *appError) Source() ErrorSource { + return e.source +} + +func (e *appError) ErrorType() ErrorType { + return e.errorType +} + +func (e *appError) ErrorDetails() string { + errorDetails, err := json.Marshal(FunctionError{ + Type: e.errorType, + Message: e.errorMessage, + }) + invariant.Checkf(err == nil, "could not json marshal error details: %s", err) + return string(errorDetails) +} + +func (e *appError) Unwrap() error { + return e.cause +} + +func (e *appError) Error() string { + if e.cause == nil { + return string(e.errorType) + } + return string(e.errorType) + ": " + e.cause.Error() +} + +func (e *appError) ErrorCategory() ErrorCategory { + return ErrorCategory(fmt.Sprintf("%s.%s", e.Severity(), e.Source())) +} + +func (e *appError) ReturnCode() int { + return e.code +} + +func GetValidRuntimeOrFunctionErrorType(errorType string) ErrorType { + match, _ := regexp.MatchString("(Runtime|Function)\\.[A-Z][a-zA-Z]+", errorType) + if match { + return ErrorType(errorType) + } + + if strings.HasPrefix(errorType, "Function.") { + return ErrorFunctionUnknown + } + + return ErrorRuntimeUnknown +} + +func GetValidExtensionErrorType(errorType string, defaultErrorType ErrorType) ErrorType { + match, _ := regexp.MatchString("Extension\\.[A-Z][a-zA-Z]+", errorType) + if match { + return ErrorType(errorType) + } + + return defaultErrorType +} + +type FunctionError struct { + Type ErrorType `json:"errorType,omitempty"` + + Message string `json:"errorMessage,omitempty"` +} diff --git a/internal/lambda-managed-instances/rapid/model/error_types_test.go b/internal/lambda-managed-instances/rapid/model/error_types_test.go new file mode 100644 index 0000000..3aee46d --- /dev/null +++ b/internal/lambda-managed-instances/rapid/model/error_types_test.go @@ -0,0 +1,62 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetValidRuntimeOrFunctionErrorType(t *testing.T) { + type test struct { + input string + expected ErrorType + } + + tests := []test{ + {"", ErrorRuntimeUnknown}, + {"MyCustomError", ErrorRuntimeUnknown}, + {"MyCustomError.Error", ErrorRuntimeUnknown}, + {"Runtime.MyCustomErrorTypeHere", ErrorType("Runtime.MyCustomErrorTypeHere")}, + {"Function.MyCustomErrorTypeHere", ErrorType("Function.MyCustomErrorTypeHere")}, + } + + for _, tt := range tests { + testname := fmt.Sprintf("TestGetValidRuntimeOrFunctionErrorType with %s", tt.input) + t.Run(testname, func(t *testing.T) { + assert.Equal(t, GetValidRuntimeOrFunctionErrorType(tt.input), tt.expected) + }) + } +} + +func TestGetValidExtensionErrorType(t *testing.T) { + type test struct { + input string + expected ErrorType + } + + defaultErrorType := ErrorAgentExit + + tests := []test{ + {"", defaultErrorType}, + {"MyCustomError", defaultErrorType}, + {"MyCustomError.Error", defaultErrorType}, + {"Runtime.MyCustomErrorTypeHere", defaultErrorType}, + {"Function.MyCustomErrorTypeHere", defaultErrorType}, + {"Extension.", defaultErrorType}, + {"Extension.A", defaultErrorType}, + {"Extension.az", defaultErrorType}, + {"Extension.AA", ErrorType("Extension.AA")}, + {"Extension.Az", ErrorType("Extension.Az")}, + } + + for _, tt := range tests { + testname := fmt.Sprintf("TestGetValidExtensionErrorType with %s", tt.input) + t.Run(testname, func(t *testing.T) { + assert.Equal(t, tt.expected, GetValidExtensionErrorType(tt.input, defaultErrorType)) + }) + } +} diff --git a/internal/lambda-managed-instances/rapid/model/exec.go b/internal/lambda-managed-instances/rapid/model/exec.go new file mode 100644 index 0000000..c677062 --- /dev/null +++ b/internal/lambda-managed-instances/rapid/model/exec.go @@ -0,0 +1,55 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils/invariant" +) + +type RuntimeExecError struct { + Type RuntimeExecErrorType + Err error +} + +type RuntimeExec struct { + Env model.KVMap + Cmd []string + WorkingDir string +} + +type Runtime struct { + ExecConfig RuntimeExec + ConfigError *RuntimeExecError +} + +type ExtensionsExec struct { + Env model.KVMap + WorkingDir string +} + +type ExternalAgents struct { + Bootstraps []string + ExecConfig ExtensionsExec +} + +type RuntimeExecErrorType int + +const ( + InvalidTaskConfig RuntimeExecErrorType = iota + InvalidEntrypoint + InvalidWorkingDir +) + +func (e RuntimeExecErrorType) FatalErrorType() ErrorType { + switch e { + case InvalidEntrypoint: + return ErrorRuntimeInvalidEntryPoint + case InvalidWorkingDir: + return ErrorRuntimeInvalidWorkingDir + } + + invariant.Violatef("invalid runtime exec error value: %d", int(e)) + return ErrorType("") +} diff --git a/internal/lambda-managed-instances/rapid/model/function_metadata.go b/internal/lambda-managed-instances/rapid/model/function_metadata.go new file mode 100644 index 0000000..8349b7f --- /dev/null +++ b/internal/lambda-managed-instances/rapid/model/function_metadata.go @@ -0,0 +1,24 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +type FunctionMetadata struct { + AccountID string + FunctionName string + FunctionVersion string + MemorySizeBytes uint64 + Handler string + RuntimeInfo RuntimeInfo +} + +type RuntimeInfo struct { + Arn string + Version string +} + +type Credentials struct { + AwsKey string + AwsSecret string + AwsSession string +} diff --git a/internal/lambda-managed-instances/rapid/model/interfaces.go b/internal/lambda-managed-instances/rapid/model/interfaces.go new file mode 100644 index 0000000..eebe8fd --- /dev/null +++ b/internal/lambda-managed-instances/rapid/model/interfaces.go @@ -0,0 +1,10 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import "context" + +type RuntimeNextWaiter interface { + RuntimeNextWait(ctx context.Context) AppError +} diff --git a/internal/lambda-managed-instances/rapid/model/platform_error.go b/internal/lambda-managed-instances/rapid/model/platform_error.go new file mode 100644 index 0000000..3f67bc6 --- /dev/null +++ b/internal/lambda-managed-instances/rapid/model/platform_error.go @@ -0,0 +1,51 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import "net/http" + +const ( + ErrorReasonExtensionExecFailed ErrorType = "ExtensionExecFailure" + ErrorAgentCountRegistrationFailed ErrorType = "Extension.CountRegistrationFailed" + ErrorAgentExtensionCreationFailed ErrorType = "Extension.CreationFailed" + ErrorAgentRegistrationFailed ErrorType = "Extension.RegistrationFailed" + ErrorAgentReadyFailed ErrorType = "Extension.ReadyFailed" + ErrorAgentGateCreationFailed ErrorType = "Extension.GateCreationFailed" + + ErrorReasonRuntimeExecFailed ErrorType = "RuntimeExecFailure" + ErrorRuntimeReadyFailed ErrorType = "Runtime.ReadyFailed" + ErrorRuntimeRegistrationFailed ErrorType = "Runtime.RegistrationFailed" + + ErrSandboxLogSocketsUnavailable ErrorType = "Sandbox.LogSocketsUnavailable" + ErrSandboxEventSetupFailure ErrorType = "Sandbox.EventSetupFailure" + + ErrSandboxShutdownFailed ErrorType = "Sandbox.ShutdownFailed" +) + +type PlatformError struct { + *appError +} + +func NewPlatformError(cause error, errorType ErrorType) PlatformError { + return PlatformError{ + appError: &appError{ + cause: cause, + severity: ErrorSeverityFatal, + source: ErrorSourceSandbox, + errorType: errorType, + code: http.StatusInternalServerError, + }, + } +} + +func WrapErrorIntoPlatformFatalError(e error, errorType ErrorType) PlatformError { + + if platformErr, ok := e.(PlatformError); ok { + return platformErr + } + return NewPlatformError( + e, + errorType, + ) +} diff --git a/internal/lambda-managed-instances/rapid/model/platform_error_test.go b/internal/lambda-managed-instances/rapid/model/platform_error_test.go new file mode 100644 index 0000000..a67a6dd --- /dev/null +++ b/internal/lambda-managed-instances/rapid/model/platform_error_test.go @@ -0,0 +1,31 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPlatformError(t *testing.T) { + testCases := []struct { + name string + platformError PlatformError + expectedError string + }{ + { + name: "WrapGoErrorIntoPlatformError", + platformError: WrapErrorIntoPlatformFatalError(errors.New("connection has timed out"), ErrorSandboxTimedout), + expectedError: "Sandbox.Timedout: connection has timed out", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expectedError, tc.platformError.Error()) + }) + } +} diff --git a/internal/lambda-managed-instances/rapid/sandbox.go b/internal/lambda-managed-instances/rapid/sandbox.go new file mode 100644 index 0000000..17df54d --- /dev/null +++ b/internal/lambda-managed-instances/rapid/sandbox.go @@ -0,0 +1,137 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapid + +import ( + "context" + "log/slog" + "net/netip" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/invoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi" + rapimodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + supvmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/supervisor/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils" +) + +const MaxIdleRuntimesQueueSize = 10_000 + +type Dependencies struct { + InteropServer interop.Server + TelemetrySubscriptionAPI telemetry.SubscriptionAPI + LogsEgressAPI telemetry.StdLogsEgressAPI + EventsAPI interop.EventsAPI + Supervisor supvmodel.ProcessSupervisor + FileUtils utils.FileUtil + InvokeRouter *invoke.InvokeRouter + + RuntimeAPIAddrPort netip.AddrPort +} + +func Start(ctx context.Context, deps Dependencies) (interop.RapidContext, error) { + + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow) + renderingService := rendering.NewRenderingService() + + server, err := rapi.NewServer(deps.RuntimeAPIAddrPort, appCtx, registrationService, renderingService, deps.TelemetrySubscriptionAPI, deps.InvokeRouter) + if err != nil { + return nil, err + } + + appctx.StoreInteropServer(appCtx, deps.InteropServer) + + execCtx := &rapidContext{ + + server: server, + appCtx: appCtx, + initFlow: initFlow, + registrationService: registrationService, + renderingService: renderingService, + shutdownContext: newShutdownContext(), + fileUtils: deps.FileUtils, + invokeRouter: deps.InvokeRouter, + processTermChan: make(chan model.AppError), + + telemetrySubscriptionAPI: deps.TelemetrySubscriptionAPI, + logsEgressAPI: deps.LogsEgressAPI, + interopServer: deps.InteropServer, + eventsAPI: deps.EventsAPI, + supervisor: processSupervisor{ + ProcessSupervisor: deps.Supervisor, + }, + } + + go func() { + + if err := execCtx.server.Serve(ctx); err != nil { + slog.Error("Server error", "err", err) + } + + }() + + return execCtx, nil +} + +func (r *rapidContext) HandleInit(ctx context.Context, initData interop.InitExecutionData, initMetrics interop.InitMetrics) model.AppError { + r.initExecutionData = initData + r.initMetrics = initMetrics + return handleInit(ctx, r) +} + +func (r *rapidContext) HandleInvoke(ctx context.Context, invokeReq interop.InvokeRequest, metrics interop.InvokeMetrics) (err model.AppError, wasResponseSent bool) { + if err := invokeReq.UpdateFromInitData(&r.initExecutionData); err != nil { + return err, false + } + metrics.AttachDependencies(&r.initExecutionData, r.eventsAPI) + return r.invokeRouter.Invoke(ctx, &r.initExecutionData, invokeReq, metrics) +} + +func (r *rapidContext) HandleShutdown(shutdownCause model.AppError, metrics interop.ShutdownMetrics) model.AppError { + metrics.SetAgentCount(len(r.registrationService.GetInternalAgents()), len(r.registrationService.GetExternalAgents())) + + r.invokeRouter.AbortRunningInvokes(metrics, shutdownCause) + + reason := rapimodel.Spindown + + if shutdownCause != nil && shutdownCause.ErrorType() != model.ErrorExecutionEnvironmentShutdown { + reason = rapimodel.Failure + } + + slog.Info("ShutdownContext shutdown() initiated", "reason", reason) + + err := r.shutdownContext.shutdown( + r.supervisor, + r.renderingService, + r.registrationService.GetExternalAgents(), + r.registrationService.CountAgents(), + reason, + metrics, + r.eventsAPI, + ) + if err != nil { + + slog.Warn("Error during shutdown Context shutdown", "err", err) + return model.WrapErrorIntoPlatformFatalError(err, model.ErrSandboxShutdownFailed) + } + + duration := metrics.CreateDurationMetric(interop.ShutdownRuntimeServerDuration) + if err := r.server.Shutdown(); err != nil { + slog.Error("Error during runtime server shutdown", "err", err) + } + duration.Done() + + return nil +} + +func (r *rapidContext) RuntimeAPIAddrPort() netip.AddrPort { + return r.server.AddrPort() +} diff --git a/internal/lambda-managed-instances/rapid/shutdown.go b/internal/lambda-managed-instances/rapid/shutdown.go new file mode 100644 index 0000000..b0bc55b --- /dev/null +++ b/internal/lambda-managed-instances/rapid/shutdown.go @@ -0,0 +1,391 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapid + +import ( + "context" + "errors" + "fmt" + "log/slog" + "sync" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + supvmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/supervisor/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils/invariant" +) + +const ( + supervisorBlockingTimeout = 1 * time.Second + + maxProcessExitWait = 2 * time.Second + + defaultShutdownTimeout = 2 * time.Second + + defaultRuntimeShutdownTimeout = 600 * time.Millisecond +) + +type shutdownContext struct { + mu sync.Mutex + shuttingDown bool + agentsAwaitingExit map[string]*core.ExternalAgent + + processExited map[string]chan struct{} + shutdownTimeout time.Duration + runtimeShutdownTimeout time.Duration +} + +func newShutdownContext() *shutdownContext { + return &shutdownContext{ + shuttingDown: false, + agentsAwaitingExit: make(map[string]*core.ExternalAgent), + processExited: make(map[string]chan struct{}), + shutdownTimeout: defaultShutdownTimeout, + runtimeShutdownTimeout: defaultRuntimeShutdownTimeout, + } +} + +func (s *shutdownContext) isShuttingDown() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.shuttingDown +} + +func (s *shutdownContext) setShuttingDown(value bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.shuttingDown = value +} + +func (s *shutdownContext) handleProcessExit(termination supvmodel.ProcessTermination) { + name := termination.Name + s.mu.Lock() + agent, found := s.agentsAwaitingExit[name] + s.mu.Unlock() + + if found { + slog.Debug("Handling termination", "name", name) + if termination.ExitedWithZeroExitCode() { + + stateErr := agent.Exited() + if stateErr != nil { + slog.Warn("failed to transition to EXITED", "agent", agent.String(), "error", stateErr, "currentState", agent.GetState().Name()) + } + } else { + + stateErr := agent.ShutdownFailed() + if stateErr != nil { + slog.Warn("failed to transition to ShutdownFailed", "agent", agent, "error", stateErr, "currentState", agent.GetState().Name()) + } + } + } + + if exitedChannel, found := s.getExitedChannel(name); found { + + close(exitedChannel) + } else { + slog.Warn("Unknown process: possibly failed to launch, or it is from previous generation", "name", name) + } +} + +func (s *shutdownContext) getExitedChannel(name string) (chan struct{}, bool) { + s.mu.Lock() + defer s.mu.Unlock() + exitedChannel, found := s.processExited[name] + return exitedChannel, found +} + +func (s *shutdownContext) createExitedChannel(name string) { + s.mu.Lock() + defer s.mu.Unlock() + + _, found := s.processExited[name] + invariant.Checkf(!found, "Tried to create an exited channel for '%s' but one already exists.", name) + + s.processExited[name] = make(chan struct{}) +} + +func (s *shutdownContext) waitUntilAllProcessesExit(metrics interop.ShutdownMetrics) error { + duration := metrics.CreateDurationMetric(interop.ShutdownWaitAllProcessesDuration) + defer duration.Done() + + s.mu.Lock() + channels := make([]chan struct{}, 0, len(s.processExited)) + for _, v := range s.processExited { + channels = append(channels, v) + } + s.mu.Unlock() + + exitTimeout := time.After(maxProcessExitWait) + for _, v := range channels { + select { + case <-v: + case <-exitTimeout: + return errors.New("timed out waiting for runtime processes to exit") + } + } + + return nil +} + +func (s *shutdownContext) shutdownRuntime(ctx context.Context, supervisor processSupervisor, metrics interop.ShutdownMetrics) error { + duration := metrics.CreateDurationMetric(interop.ShutdownRuntimeDuration) + defer duration.Done() + + slog.Debug("Shutting down the runtime.") + name := runtimeProcessName + + exitedChannel, found := s.getExitedChannel(name) + if !found { + slog.Warn("runtime was not started", "name", name) + return errors.New("runtime was not started") + } + + err := supervisor.Terminate(ctx, &supvmodel.TerminateRequest{ + Name: name, + }) + if err != nil { + + slog.Warn("Failed sending Termination signal to runtime", "err", err) + } + err = waitProcessExitedOrKill(ctx, exitedChannel, name, supervisor) + slog.Debug("Shutdown the runtime.") + return err +} + +func (s *shutdownContext) shutdownAgents( + ctx context.Context, + supervisor processSupervisor, + renderingService *rendering.EventRenderingService, + agents []*core.ExternalAgent, + reason model.ShutdownReason, + extShutdownDeadline time.Time, + metrics interop.ShutdownMetrics, +) error { + duration := metrics.CreateDurationMetric(interop.ShutdownExtensionsDuration) + defer duration.Done() + + slog.Debug("Shutting down the agents.") + + renderingService.SetRenderer( + &rendering.ShutdownRenderer{ + AgentEvent: model.AgentShutdownEvent{ + AgentEvent: &model.AgentEvent{ + EventType: "SHUTDOWN", + DeadlineMs: extShutdownDeadline.UnixMilli(), + }, + ShutdownReason: reason, + }, + }) + + resultAwaiters := make([]<-chan error, 0, len(agents)) + + for _, a := range agents { + name := extensionProcessName(a.Name()) + exitedChannel, found := s.getExitedChannel(name) + + if !found { + slog.Warn("Agent failed to launch, therefore skipping shutting it down", "agent", a) + continue + } + + awaiter := make(chan error, 1) + resultAwaiters = append(resultAwaiters, awaiter) + + if a.IsSubscribed(core.ShutdownEvent) { + slog.Debug("Agent is registered for the shutdown event", "agent", a) + s.mu.Lock() + s.agentsAwaitingExit[name] = a + s.mu.Unlock() + + go func(name string, agent *core.ExternalAgent, ch chan<- error) { + + agent.Release() + ch <- waitProcessExitedOrKill(ctx, exitedChannel, name, supervisor) + }(name, a, awaiter) + } else { + slog.Debug("Agent is not registered for the shutdown event, so just killing it", "agent", a) + + go func(name string, ch chan<- error) { + + defer close(ch) + if err := killProcess(ctx, name, supervisor, nil); err != nil { + slog.Warn("Failed to kill process", "name", name, "err", err) + } + }(name, awaiter) + } + } + + errs := make([]error, 0, len(resultAwaiters)) + for _, ch := range resultAwaiters { + errs = append(errs, <-ch) + } + slog.Debug("Shutdown the agents.") + return errors.Join(errs...) +} + +func killProcess(ctx context.Context, name string, supervisor supvmodel.ProcessSupervisor, metrics interop.ShutdownMetrics) error { + if metrics != nil { + duration := metrics.CreateDurationMetric(fmt.Sprintf(interop.ShutdownKillProcessDurationMetricTemplate, name)) + defer duration.Done() + + } + + deadline, hasDeadline := ctx.Deadline() + invariant.Checkf(hasDeadline, "Context to kill process %q has a deadline", name) + + err := supervisor.Kill(ctx, &supvmodel.KillRequest{ + Name: name, + Deadline: deadline, + }) + if err != nil { + slog.Warn("Failed sending Kill signal to process", "name", name, "err", err) + } + return err +} + +func waitProcessExitedOrKill( + ctx context.Context, + exitedChannel <-chan struct{}, + name string, + supervisor supvmodel.ProcessSupervisor, +) error { + select { + case <-exitedChannel: + return nil + case <-ctx.Done(): + select { + case <-exitedChannel: + + return nil + default: + killCtx, killCtxCancel := context.WithDeadline(context.Background(), time.Now().Add(supervisorBlockingTimeout)) + defer killCtxCancel() + return killProcess(killCtx, name, supervisor, nil) + } + } +} + +func (s *shutdownContext) shutdown( + supervisor processSupervisor, + renderingService *rendering.EventRenderingService, + externalAgents []*core.ExternalAgent, + totalAgentsCount int, + reason model.ShutdownReason, + metrics interop.ShutdownMetrics, + eventsAPI interop.EventsAPI, +) error { + errs := make([]error, 0, 4) + s.setShuttingDown(true) + + if totalAgentsCount == 0 { + name := runtimeProcessName + + _, found := s.getExitedChannel(name) + + if found { + slog.Debug("SIGKILLing the runtime as no agents are registered.") + runtimeSigkillCtx, cancel := context.WithTimeout(context.Background(), supervisorBlockingTimeout) + defer cancel() + errs = append(errs, killProcess(runtimeSigkillCtx, name, supervisor, metrics)) + } else { + slog.Debug("Could not find runtime process in processes map. Already exited/never started", "name", name) + } + } else { + shutdownConfig := s.configureShutdownDeadlines() + extShutdownCtx, extCancel := context.WithDeadline(context.Background(), shutdownConfig.extSigkillDeadline) + defer extCancel() + rtShutdownCtx, rtCancel := context.WithDeadline(context.Background(), shutdownConfig.rtShutdownDeadline) + defer rtCancel() + + errs = append(errs, s.shutdownRuntime(rtShutdownCtx, supervisor, metrics)) + eventsAPI.Flush() + errs = append(errs, s.shutdownAgents(extShutdownCtx, supervisor, renderingService, externalAgents, reason, shutdownConfig.extShutdownDeadline, metrics)) + } + + if err := errors.Join(errs...); err == nil { + slog.Info("Waiting for runtime domain processes termination") + err = s.waitUntilAllProcessesExit(metrics) + + errs = append(errs, err) + } + + return errors.Join(errs...) +} + +type configShutdownDeadlines struct { + shutdownTimeout time.Duration + rtShutdownDeadline time.Time + extShutdownDeadline time.Time + extSigkillDeadline time.Time +} + +func (s *shutdownContext) configureShutdownDeadlines() configShutdownDeadlines { + now := time.Now() + runtimeShutdownDeadline := now.Add(s.runtimeShutdownTimeout) + return configShutdownDeadlines{ + shutdownTimeout: s.shutdownTimeout, + rtShutdownDeadline: runtimeShutdownDeadline, + extShutdownDeadline: now.Add(s.shutdownTimeout), + extSigkillDeadline: now.Add(s.shutdownTimeout), + } +} + +func (s *shutdownContext) processTermination( + event supvmodel.ProcessTermination, + execCtx *rapidContext, +) { + var fatalError rapidmodel.ErrorType + var err error + + defer s.handleProcessExit(event) + + if !s.isShuttingDown() { + + s.setShuttingDown(true) + + if event.Name == runtimeProcessName { + switch { + case event.OomKilled(): + err = fmt.Errorf("runtime exited with error: %s", event.String()) + fatalError = rapidmodel.ErrorRuntimeOutOfMemory + case event.ExitedWithZeroExitCode(): + err = fmt.Errorf("runtime exited without providing a reason") + fatalError = rapidmodel.ErrorRuntimeExit + default: + err = fmt.Errorf("runtime exited with error: %s", event.String()) + fatalError = rapidmodel.ErrorRuntimeExit + } + } else { + fatalError = rapidmodel.ErrorAgentCrash + if event.ExitedWithZeroExitCode() { + err = fmt.Errorf("exit code 0") + } else { + err = errors.New(event.String()) + } + } + + appctx.StoreFirstFatalError(execCtx.appCtx, rapidmodel.WrapErrorIntoCustomerFatalError(nil, fatalError)) + + customerError, _ := appctx.LoadFirstFatalError(execCtx.appCtx) + + select { + case execCtx.processTermChan <- customerError: + default: + + } + slog.Warn("Process exited", "name", event.Name, "event", event) + } + + execCtx.registrationService.CancelFlows(err) +} + +func extensionProcessName(extensionName string) string { + return fmt.Sprintf("extension-%s", extensionName) +} diff --git a/internal/lambda-managed-instances/rapid/shutdown_metrics.go b/internal/lambda-managed-instances/rapid/shutdown_metrics.go new file mode 100644 index 0000000..a6e2a0d --- /dev/null +++ b/internal/lambda-managed-instances/rapid/shutdown_metrics.go @@ -0,0 +1,149 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapid + +import ( + "fmt" + "regexp" + "sync" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/servicelogs" +) + +const ( + shutdownReasonTemplate = "ShutdownReason-%s" +) + +type durationMetricTimer struct { + metricName string + startTime time.Time + metrics *shutdownMetrics +} + +type durationMetric struct { + metricName string + duration time.Duration +} + +func (t *durationMetricTimer) Done() { + t.metrics.mutex.Lock() + defer t.metrics.mutex.Unlock() + + t.metrics.durationMetrics = append(t.metrics.durationMetrics, durationMetric{ + metricName: t.metricName, + duration: t.metrics.getCurrentTime().Sub(t.startTime), + }) +} + +type shutdownMetrics struct { + logger servicelogs.Logger + + reason model.AppError + error model.AppError + + mutex sync.Mutex + + props []servicelogs.Property + dims []servicelogs.Dimension + metrics []servicelogs.Metric + durationMetrics []durationMetric + + startTime time.Time + + internalExtensionCount int + externalExtensionCount int + + killProcessDurationRegex *regexp.Regexp + + getCurrentTime func() time.Time +} + +func NewShutdownMetrics(logger servicelogs.Logger, reason model.AppError) *shutdownMetrics { + return &shutdownMetrics{ + logger: logger, + reason: reason, + startTime: time.Now(), + getCurrentTime: time.Now, + + killProcessDurationRegex: regexp.MustCompile(`^Kill.+Duration$`), + } +} + +func (m *shutdownMetrics) CreateDurationMetric(name string) interop.DurationMetricTimer { + return &durationMetricTimer{ + metricName: name, + startTime: m.getCurrentTime(), + metrics: m, + } +} + +func (m *shutdownMetrics) AddMetric(metric servicelogs.Metric) { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.metrics = append(m.metrics, metric) +} + +func (m *shutdownMetrics) SetAgentCount(internal, external int) { + m.internalExtensionCount = internal + m.externalExtensionCount = external +} + +func (m *shutdownMetrics) SendMetrics(err model.AppError) { + m.error = err + + m.buildMetrics() + + m.mutex.Lock() + defer m.mutex.Unlock() + m.logger.Log(servicelogs.ShutdownOp, m.startTime, m.props, m.dims, m.metrics) +} + +func (m *shutdownMetrics) buildMetrics() { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.metrics = append(m.metrics, + servicelogs.Counter(interop.InternalExtensionsCountMetric, float64(m.internalExtensionCount)), + servicelogs.Counter(interop.ExternalExtensionsCountMetric, float64(m.externalExtensionCount)), + servicelogs.Counter(interop.TotalExtensionsCountMetric, float64(m.internalExtensionCount+m.externalExtensionCount)), + ) + + var totalDuration, sumCustomerDuration time.Duration + + for _, metric := range m.durationMetrics { + switch key := metric.metricName; { + case key == interop.TotalDurationMetric: + totalDuration = metric.duration + case key == interop.ShutdownRuntimeDuration, key == interop.ShutdownExtensionsDuration, key == interop.ShutdownWaitAllProcessesDuration, m.killProcessDurationRegex.MatchString(key): + sumCustomerDuration += metric.duration + } + + m.metrics = append(m.metrics, servicelogs.Timer(metric.metricName, metric.duration)) + } + + shutdodwnOverhead := totalDuration - sumCustomerDuration + m.metrics = append(m.metrics, servicelogs.Timer(interop.PlatformOverheadDurationMetric, shutdodwnOverhead)) + + if m.reason != nil { + m.metrics = append( + m.metrics, + servicelogs.Counter(fmt.Sprintf(shutdownReasonTemplate, m.reason), 1.0), + ) + } + + var platformErrCnt float64 + if m.error != nil { + platformErrCnt = 1 + m.metrics = append(m.metrics, + servicelogs.Counter(fmt.Sprintf(interop.PlatformErrorReasonTemplate, m.error.ErrorType()), 1.0), + ) + } + m.metrics = append(m.metrics, + servicelogs.Counter(interop.PlatformErrorMetric, platformErrCnt), + ) +} diff --git a/internal/lambda-managed-instances/rapid/shutdown_metrics_test.go b/internal/lambda-managed-instances/rapid/shutdown_metrics_test.go new file mode 100644 index 0000000..20893fe --- /dev/null +++ b/internal/lambda-managed-instances/rapid/shutdown_metrics_test.go @@ -0,0 +1,222 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapid + +import ( + "cmp" + "fmt" + "slices" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + intmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/servicelogs" +) + +type shutdownMetricsMocks struct { + timeStamp time.Time + err model.AppError + initData interop.MockInitStaticDataProvider + logger servicelogs.MockLogger +} + +func checkShutdownMocksExpectations(t *testing.T, mocks *shutdownMetricsMocks) { + mocks.initData.AssertExpectations(t) + mocks.logger.AssertExpectations(t) +} + +func Test_shutdownMetrics(t *testing.T) { + tests := []struct { + name string + shutdownReason model.AppError + metricFlow func(m *shutdownMetrics, mocks *shutdownMetricsMocks) + expectedMetrics []servicelogs.Metric + }{ + { + name: "shutdown_full_flow_no_extensions", + metricFlow: func(m *shutdownMetrics, mocks *shutdownMetricsMocks) { + m.SetAgentCount(0, 0) + + timer := m.CreateDurationMetric("TotalDuration") + mocks.timeStamp = mocks.timeStamp.Add(5 * time.Second) + timer.Done() + + timer = m.CreateDurationMetric("AbortInvokeDuration") + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + timer.Done() + + timer = m.CreateDurationMetric("KillruntimeDuration") + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + timer.Done() + + timer = m.CreateDurationMetric("WaitCustomerProcessesExitDuration") + mocks.timeStamp = mocks.timeStamp.Add(2 * time.Second) + timer.Done() + + timer = m.CreateDurationMetric("ShutdownRuntimeServerDuration") + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + timer.Done() + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 5000000}, + {Type: servicelogs.TimerType, Key: "AbortInvokeDuration", Value: 1000000}, + {Type: servicelogs.TimerType, Key: "KillruntimeDuration", Value: 1000000}, + {Type: servicelogs.TimerType, Key: "WaitCustomerProcessesExitDuration", Value: 2000000}, + {Type: servicelogs.TimerType, Key: "ShutdownRuntimeServerDuration", Value: 1000000}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 2000000}, + {Type: servicelogs.CounterType, Key: "TotalExtensionsCount", Value: 0}, + {Type: servicelogs.CounterType, Key: "InternalExtensionsCount", Value: 0}, + {Type: servicelogs.CounterType, Key: "ExternalExtensionsCount", Value: 0}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 0}, + }, + }, + { + name: "shutdown_full_flow_with_extensions", + shutdownReason: model.NewPlatformError(nil, model.ErrorAgentCrash), + metricFlow: func(m *shutdownMetrics, mocks *shutdownMetricsMocks) { + m.SetAgentCount(2, 3) + + timer := m.CreateDurationMetric("TotalDuration") + mocks.timeStamp = mocks.timeStamp.Add(5 * time.Second) + timer.Done() + + timer = m.CreateDurationMetric("AbortInvokeDuration") + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + timer.Done() + + timer = m.CreateDurationMetric("StopRuntimeDuration") + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + timer.Done() + + timer = m.CreateDurationMetric("WaitCustomerProcessesExitDuration") + mocks.timeStamp = mocks.timeStamp.Add(2 * time.Second) + timer.Done() + + timer = m.CreateDurationMetric("ShutdownRuntimeServerDuration") + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + timer.Done() + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 5000000}, + {Type: servicelogs.TimerType, Key: "AbortInvokeDuration", Value: 1000000}, + {Type: servicelogs.TimerType, Key: "StopRuntimeDuration", Value: 1000000}, + {Type: servicelogs.TimerType, Key: "WaitCustomerProcessesExitDuration", Value: 2000000}, + {Type: servicelogs.TimerType, Key: "ShutdownRuntimeServerDuration", Value: 1000000}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 2000000}, + {Type: servicelogs.CounterType, Key: "TotalExtensionsCount", Value: 5}, + {Type: servicelogs.CounterType, Key: "InternalExtensionsCount", Value: 2}, + {Type: servicelogs.CounterType, Key: "ExternalExtensionsCount", Value: 3}, + {Type: servicelogs.CounterType, Key: "ShutdownReason-Extension.Crash", Value: 1}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 0}, + }, + }, + { + name: "shutdown_no_init_data", + metricFlow: func(m *shutdownMetrics, mocks *shutdownMetricsMocks) { + timer := m.CreateDurationMetric("TotalDuration") + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + timer.Done() + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 1000000}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "TotalExtensionsCount", Value: 0}, + {Type: servicelogs.CounterType, Key: "InternalExtensionsCount", Value: 0}, + {Type: servicelogs.CounterType, Key: "ExternalExtensionsCount", Value: 0}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 0}, + }, + }, + { + name: "shutdown_failed", + metricFlow: func(m *shutdownMetrics, mocks *shutdownMetricsMocks) { + m.SetAgentCount(2, 3) + + timer := m.CreateDurationMetric("TotalDuration") + mocks.timeStamp = mocks.timeStamp.Add(2 * time.Second) + timer.Done() + + timer = m.CreateDurationMetric("AbortInvokeDuration") + mocks.timeStamp = mocks.timeStamp.Add(1 * time.Second) + timer.Done() + + timer = m.CreateDurationMetric("ShutdownRuntimeServerDuration") + mocks.timeStamp = mocks.timeStamp.Add(1 * time.Second) + timer.Done() + + mocks.err = model.NewPlatformError(nil, model.ErrorAgentCrash) + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 2000000}, + {Type: servicelogs.TimerType, Key: "AbortInvokeDuration", Value: 1000000}, + {Type: servicelogs.TimerType, Key: "ShutdownRuntimeServerDuration", Value: 1000000}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 2000000}, + {Type: servicelogs.CounterType, Key: "TotalExtensionsCount", Value: 5}, + {Type: servicelogs.CounterType, Key: "InternalExtensionsCount", Value: 2}, + {Type: servicelogs.CounterType, Key: "ExternalExtensionsCount", Value: 3}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 1}, + {Type: servicelogs.CounterType, Key: "PlatformErrorReason-Extension.Crash", Value: 1}, + }, + }, + } + + metricsSortFunc := func(a, b servicelogs.Metric) int { + return cmp.Compare(a.Key, b.Key) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mocks := &shutdownMetricsMocks{ + initData: interop.MockInitStaticDataProvider{}, + logger: servicelogs.MockLogger{}, + } + + m := NewShutdownMetrics(&mocks.logger, tt.shutdownReason) + m.getCurrentTime = func() time.Time { + return mocks.timeStamp + } + + mocks.initData.On("InitTimeout").Return(time.Second).Maybe() + mocks.initData.On("MemorySizeMB").Return(uint64(128)).Maybe() + mocks.initData.On("FunctionARN").Return("function-arn").Maybe() + mocks.initData.On("FunctionVersionID").Return("function-version-id").Maybe() + mocks.initData.On("RuntimeVersion").Return("python3.9").Maybe() + mocks.initData.On("ArtefactType").Return(intmodel.ArtefactTypeOCI).Maybe() + mocks.initData.On("AmiId").Return("ami-1234567").Maybe() + mocks.initData.On("AvailabilityZoneId").Return("us-west-2").Maybe() + + mocks.logger.On("Log", + mock.MatchedBy(func(op servicelogs.Operation) bool { + return assert.Equal(t, servicelogs.ShutdownOp, op) + }), + mock.AnythingOfType("time.Time"), + []servicelogs.Tuple(nil), + []servicelogs.Tuple(nil), + mock.MatchedBy(func(metrics []servicelogs.Metric) bool { + slices.SortFunc(metrics, metricsSortFunc) + slices.SortFunc(tt.expectedMetrics, metricsSortFunc) + assert.Equal(t, len(tt.expectedMetrics), len(metrics)) + for i := range len(tt.expectedMetrics) { + require.Equal(t, tt.expectedMetrics[i].Key, metrics[i].Key) + require.Equal(t, tt.expectedMetrics[i].Type, metrics[i].Type, fmt.Sprintf("wrong format for %s", metrics[i].Key)) + require.Equal(t, tt.expectedMetrics[i].Value, metrics[i].Value, fmt.Sprintf("wrong value for %s", metrics[i].Key)) + } + + return true + }), + ).Once() + + mocks.timeStamp = time.Now() + tt.metricFlow(m, mocks) + + m.SendMetrics(mocks.err) + checkShutdownMocksExpectations(t, mocks) + }) + } +} diff --git a/internal/lambda-managed-instances/rapid/shutdown_test.go b/internal/lambda-managed-instances/rapid/shutdown_test.go new file mode 100644 index 0000000..33b5f01 --- /dev/null +++ b/internal/lambda-managed-instances/rapid/shutdown_test.go @@ -0,0 +1,393 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapid + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + supervisormodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/supervisor/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" +) + +func TestShutdown(t *testing.T) { + type Agent struct { + name string + subscribed bool + } + + tests := []struct { + name string + + runtimeShutdownTimeout time.Duration + + shutdownTimeout time.Duration + + shutdownReason model.ShutdownReason + + agents []Agent + + internalAgentsCount int + + runtimeTermination func(*shutdownContext, *supervisormodel.MockProcessSupervisor) + + subscribedExtensionTermination func(*shutdownContext, *supervisormodel.MockProcessSupervisor) + + noSubscriptionExtensionTermination func(*shutdownContext, *supervisormodel.MockProcessSupervisor) + + verifyShutdownResult func(*testing.T, error, *shutdownContext) + }{ + { + name: "No extensions, runtime kill successful -> shutdown is success", + + runtimeShutdownTimeout: 30 * time.Millisecond, + shutdownTimeout: 100 * time.Millisecond, + shutdownReason: model.Spindown, + runtimeTermination: func(s *shutdownContext, supv *supervisormodel.MockProcessSupervisor) { + + supv.On("Kill", mock.Anything, mock.Anything).Return(nil).Once() + + exitedChannel, _ := s.getExitedChannel(runtimeProcessName) + close(exitedChannel) + }, + + verifyShutdownResult: func(t *testing.T, err error, s *shutdownContext) { + assert.Nil(t, err) + }, + }, + { + name: "No extensions, runtime kill fails -> shutdown is a failure", + + runtimeShutdownTimeout: 30 * time.Millisecond, + shutdownTimeout: 100 * time.Millisecond, + shutdownReason: model.Spindown, + runtimeTermination: func(s *shutdownContext, supv *supervisormodel.MockProcessSupervisor) { + + supv.On("Kill", mock.Anything, mock.Anything).Return(fmt.Errorf("boom")).Once() + + exitedChannel, _ := s.getExitedChannel(runtimeProcessName) + close(exitedChannel) + }, + + verifyShutdownResult: func(t *testing.T, err error, s *shutdownContext) { + assert.NotNil(t, err) + assert.NotEmpty(t, s.processExited) + }, + }, + { + name: "No extensions, we Kill runtime, but don't get a signal runtime actually exited -> shutdown is a failure", + + runtimeShutdownTimeout: 30 * time.Millisecond, + shutdownTimeout: 100 * time.Millisecond, + shutdownReason: model.Spindown, + runtimeTermination: func(s *shutdownContext, supv *supervisormodel.MockProcessSupervisor) { + + supv.On("Kill", mock.Anything, mock.Anything).Return(nil).Once() + }, + + verifyShutdownResult: func(t *testing.T, err error, s *shutdownContext) { + assert.NotNil(t, err) + assert.NotEmpty(t, s.processExited) + }, + }, + { + name: "Runtime and one extensions (subscribed to shutdown) and terminates gracefully -> shutdown is success", + runtimeShutdownTimeout: 36 * time.Millisecond, + shutdownTimeout: 120 * time.Millisecond, + shutdownReason: model.Spindown, + agents: []Agent{ + { + name: "agent1", + subscribed: true, + }, + }, + + runtimeTermination: func(s *shutdownContext, supv *supervisormodel.MockProcessSupervisor) { + supv.On("Terminate", mock.Anything, mock.Anything).Return(nil).Once() + + runtimeExitedChannel, _ := s.getExitedChannel(runtimeProcessName) + close(runtimeExitedChannel) + }, + + subscribedExtensionTermination: func(s *shutdownContext, _ *supervisormodel.MockProcessSupervisor) { + + go func() { + + extensionExitedChannel, _ := s.getExitedChannel(extensionProcessName("agent1")) + time.Sleep(20 * time.Millisecond) + close(extensionExitedChannel) + }() + }, + + verifyShutdownResult: func(t *testing.T, err error, s *shutdownContext) { + assert.Nil(t, err) + }, + }, + { + name: "Runtime and only one internal extension with successful runtime termination -> shutdown is success", + runtimeShutdownTimeout: 36 * time.Millisecond, + shutdownTimeout: 120 * time.Millisecond, + shutdownReason: model.Spindown, + + internalAgentsCount: 1, + + runtimeTermination: func(s *shutdownContext, supv *supervisormodel.MockProcessSupervisor) { + supv.On("Terminate", mock.Anything, mock.Anything).Return(nil).Once() + + runtimeExitedChannel, _ := s.getExitedChannel(runtimeProcessName) + close(runtimeExitedChannel) + }, + + verifyShutdownResult: func(t *testing.T, err error, s *shutdownContext) { + assert.Nil(t, err) + }, + }, + { + name: "Runtime and one extensions (subscribed to shutdown) and needs to be killed (fails) -> shutdown is a failure", + runtimeShutdownTimeout: 8 * time.Millisecond, + shutdownTimeout: 25 * time.Millisecond, + shutdownReason: model.Spindown, + agents: []Agent{ + { + name: "agent1", + subscribed: true, + }, + }, + + runtimeTermination: func(s *shutdownContext, supv *supervisormodel.MockProcessSupervisor) { + supv.On("Terminate", mock.Anything, mock.Anything).Return(nil) + + runtimeExitedChannel, _ := s.getExitedChannel(runtimeProcessName) + close(runtimeExitedChannel) + }, + + subscribedExtensionTermination: func(s *shutdownContext, supv *supervisormodel.MockProcessSupervisor) { + + supv.On("Kill", mock.Anything, mock.Anything).Return(fmt.Errorf("boom")).Once() + }, + + verifyShutdownResult: func(t *testing.T, err error, s *shutdownContext) { + assert.NotNil(t, err) + assert.NotEmpty(t, s.processExited) + }, + }, + { + name: "Runtime and one extensions (subscribed). Runtime doesn't terminate gracefully, but extension shutdown still has dedicated time -> success", + runtimeShutdownTimeout: 30 * time.Millisecond, + shutdownTimeout: 100 * time.Millisecond, + shutdownReason: model.Spindown, + agents: []Agent{ + { + name: "agent1", + subscribed: true, + }, + }, + + runtimeTermination: func(s *shutdownContext, supv *supervisormodel.MockProcessSupervisor) { + supv.On("Terminate", mock.Anything, mock.Anything).Return(nil) + supv.On("Kill", mock.Anything, mock.Anything).Return(nil) + + go func() { + + runtimeExitedChannel, _ := s.getExitedChannel(runtimeProcessName) + time.Sleep(50 * time.Millisecond) + close(runtimeExitedChannel) + }() + }, + subscribedExtensionTermination: func(s *shutdownContext, supv *supervisormodel.MockProcessSupervisor) { + go func() { + + extensionExitedChannel, _ := s.getExitedChannel(extensionProcessName("agent1")) + close(extensionExitedChannel) + }() + }, + + verifyShutdownResult: func(t *testing.T, err error, s *shutdownContext) { + assert.Nil(t, err) + }, + }, + { + name: "Runtime and one extensions (not subscribed) and is killed successfully -> shutdown is success", + runtimeShutdownTimeout: 30 * time.Millisecond, + shutdownTimeout: 100 * time.Millisecond, + shutdownReason: model.Spindown, + agents: []Agent{ + { + name: "agent1", + subscribed: false, + }, + }, + + runtimeTermination: func(s *shutdownContext, supv *supervisormodel.MockProcessSupervisor) { + supv.On("Terminate", mock.Anything, mock.Anything).Return(nil). + Run(func(args mock.Arguments) { time.Sleep(10 * time.Millisecond) }).Once() + + runtimeExitedChannel, _ := s.getExitedChannel(runtimeProcessName) + close(runtimeExitedChannel) + }, + + noSubscriptionExtensionTermination: func(s *shutdownContext, supv *supervisormodel.MockProcessSupervisor) { + supv.On("Kill", mock.Anything, mock.Anything).Return(nil).Once() + + extensionExitedChannel, _ := s.getExitedChannel(extensionProcessName("agent1")) + close(extensionExitedChannel) + }, + + verifyShutdownResult: func(t *testing.T, err error, s *shutdownContext) { + assert.Nil(t, err) + }, + }, + { + name: "Runtime and one extensions (subscribed). Runtime terminate gracefully, but extension don't and needs to be killed -> success", + runtimeShutdownTimeout: 30 * time.Millisecond, + shutdownTimeout: 100 * time.Millisecond, + shutdownReason: model.Spindown, + agents: []Agent{ + { + name: "agent1", + subscribed: true, + }, + }, + + runtimeTermination: func(s *shutdownContext, supv *supervisormodel.MockProcessSupervisor) { + supv.On("Terminate", mock.Anything, mock.Anything).Return(nil).Once() + + runtimeExitedChannel, _ := s.getExitedChannel(runtimeProcessName) + close(runtimeExitedChannel) + }, + + subscribedExtensionTermination: func(s *shutdownContext, supv *supervisormodel.MockProcessSupervisor) { + + supv.On("Kill", mock.Anything, mock.Anything).Return(nil) + go func() { + + extensionExitedChannel, _ := s.getExitedChannel(extensionProcessName("agent1")) + time.Sleep(120 * time.Millisecond) + close(extensionExitedChannel) + }() + }, + + verifyShutdownResult: func(t *testing.T, err error, s *shutdownContext) { + assert.Nil(t, err) + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + + supv := &supervisormodel.MockProcessSupervisor{} + processSupervisor := processSupervisor{ProcessSupervisor: supv} + renderingService := rendering.NewRenderingService() + shutdownContext := newShutdownContext() + shutdownContext.shutdownTimeout = tt.shutdownTimeout + shutdownContext.runtimeShutdownTimeout = tt.runtimeShutdownTimeout + + defer mock.AssertExpectationsForObjects(t, supv) + + var agents []*core.ExternalAgent + shutdownContext.createExitedChannel(runtimeProcessName) + + assert.NotNil(t, tt.runtimeTermination) + tt.runtimeTermination(shutdownContext, supv) + + for _, a := range tt.agents { + shutdownContext.createExitedChannel(extensionProcessName(a.name)) + + initFlow := core.NewInitFlowSynchronization() + extAgent := core.NewExternalAgent(a.name, initFlow) + agents = append(agents, extAgent) + + if a.subscribed { + require.NoError(t, extAgent.Register([]core.Event{core.ShutdownEvent})) + + assert.NotNil(t, tt.subscribedExtensionTermination) + tt.subscribedExtensionTermination(shutdownContext, supv) + } else { + + assert.NotNil(t, tt.noSubscriptionExtensionTermination) + tt.noSubscriptionExtensionTermination(shutdownContext, supv) + } + } + + metrics := NewShutdownMetrics(nil, nil) + + err := shutdownContext.shutdown( + processSupervisor, + renderingService, + agents, + len(agents)+tt.internalAgentsCount, + tt.shutdownReason, + metrics, + &telemetry.NoOpEventsAPI{}, + ) + + tt.verifyShutdownResult(t, err, shutdownContext) + }) + } +} + +type mockInitFlowSynchronization struct { + mock.Mock +} + +var _ core.InitFlowSynchronization = (*mockInitFlowSynchronization)(nil) + +func (m *mockInitFlowSynchronization) SetExternalAgentsRegisterCount(cnt uint16) error { + args := m.Called(cnt) + return args.Error(0) +} + +func (m *mockInitFlowSynchronization) SetAgentsReadyCount(cnt uint16) error { + args := m.Called(cnt) + return args.Error(0) +} + +func (m *mockInitFlowSynchronization) ExternalAgentRegistered() error { + args := m.Called() + return args.Error(0) +} + +func (m *mockInitFlowSynchronization) AwaitExternalAgentsRegistered(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *mockInitFlowSynchronization) RuntimeReady() error { + args := m.Called() + return args.Error(0) +} + +func (m *mockInitFlowSynchronization) AwaitRuntimeReady(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *mockInitFlowSynchronization) AgentReady() error { + args := m.Called() + return args.Error(0) +} + +func (m *mockInitFlowSynchronization) AwaitAgentsReady(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *mockInitFlowSynchronization) CancelWithError(err error) { + m.Called(err) +} + +func (m *mockInitFlowSynchronization) Clear() { + m.Called() +} diff --git a/internal/lambda-managed-instances/rapidcore/env/environment.go b/internal/lambda-managed-instances/rapidcore/env/environment.go new file mode 100644 index 0000000..6624677 --- /dev/null +++ b/internal/lambda-managed-instances/rapidcore/env/environment.go @@ -0,0 +1,179 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package env + +import ( + "strconv" + "strings" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" +) + +const ( + AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID" + AWS_DEFAULT_REGION = "AWS_DEFAULT_REGION" + AWS_LAMBDA_FUNCTION_MEMORY_SIZE = "AWS_LAMBDA_FUNCTION_MEMORY_SIZE" + AWS_LAMBDA_FUNCTION_NAME = "AWS_LAMBDA_FUNCTION_NAME" + AWS_LAMBDA_FUNCTION_VERSION = "AWS_LAMBDA_FUNCTION_VERSION" + AWS_LAMBDA_LOG_FORMAT = "AWS_LAMBDA_LOG_FORMAT" + AWS_LAMBDA_LOG_GROUP_NAME = "AWS_LAMBDA_LOG_GROUP_NAME" + AWS_LAMBDA_LOG_LEVEL = "AWS_LAMBDA_LOG_LEVEL" + AWS_LAMBDA_LOG_STREAM_NAME = "AWS_LAMBDA_LOG_STREAM_NAME" + AWS_LAMBDA_MAX_CONCURRENCY = "AWS_LAMBDA_MAX_CONCURRENCY" + AWS_REGION = "AWS_REGION" + AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY" + AWS_SESSION_TOKEN = "AWS_SESSION_TOKEN" + _AWS_XRAY_DAEMON_ADDRESS = "_AWS_XRAY_DAEMON_ADDRESS" + _AWS_XRAY_DAEMON_PORT = "_AWS_XRAY_DAEMON_PORT" + _LAMBDA_TELEMETRY_LOG_FD_PROVIDER_SOCKET = "_LAMBDA_TELEMETRY_LOG_FD_PROVIDER_SOCKET" + AWS_EXECUTION_ENV = "AWS_EXECUTION_ENV" + AWS_LAMBDA_INITIALIZATION_TYPE = "AWS_LAMBDA_INITIALIZATION_TYPE" + AWS_LAMBDA_RUNTIME_API = "AWS_LAMBDA_RUNTIME_API" + AWS_XRAY_CONTEXT_MISSING = "AWS_XRAY_CONTEXT_MISSING" + AWS_XRAY_DAEMON_ADDRESS = "AWS_XRAY_DAEMON_ADDRESS" + AWS_XRAY_DAEMON_PORT = "AWS_XRAY_DAEMON_PORT" + AWS_XRAY_TRACE_ID = "AWS_XRAY_TRACE_ID" + HANDLER = "_HANDLER" + LAMBDA_RUNTIME_DIR = "LAMBDA_RUNTIME_DIR" + LAMBDA_TASK_ROOT = "LAMBDA_TASK_ROOT" + LANG = "LANG" + LD_LIBRARY_PATH = "LD_LIBRARY_PATH" + PATH = "PATH" + TZ = "TZ" +) + +var Defined = map[string]struct{}{ + AWS_ACCESS_KEY_ID: {}, + AWS_DEFAULT_REGION: {}, + AWS_LAMBDA_FUNCTION_MEMORY_SIZE: {}, + AWS_LAMBDA_FUNCTION_NAME: {}, + AWS_LAMBDA_FUNCTION_VERSION: {}, + AWS_LAMBDA_LOG_FORMAT: {}, + AWS_LAMBDA_LOG_GROUP_NAME: {}, + AWS_LAMBDA_LOG_LEVEL: {}, + AWS_LAMBDA_LOG_STREAM_NAME: {}, + AWS_LAMBDA_MAX_CONCURRENCY: {}, + AWS_REGION: {}, + AWS_SECRET_ACCESS_KEY: {}, + AWS_SESSION_TOKEN: {}, + _AWS_XRAY_DAEMON_ADDRESS: {}, + _AWS_XRAY_DAEMON_PORT: {}, + _LAMBDA_TELEMETRY_LOG_FD_PROVIDER_SOCKET: {}, + AWS_EXECUTION_ENV: {}, + AWS_LAMBDA_INITIALIZATION_TYPE: {}, + AWS_LAMBDA_RUNTIME_API: {}, + AWS_XRAY_CONTEXT_MISSING: {}, + AWS_XRAY_DAEMON_ADDRESS: {}, + AWS_XRAY_DAEMON_PORT: {}, + AWS_XRAY_TRACE_ID: {}, + HANDLER: {}, + LAMBDA_RUNTIME_DIR: {}, + LAMBDA_TASK_ROOT: {}, + LANG: {}, + LD_LIBRARY_PATH: {}, + PATH: {}, + TZ: {}, +} + +var overridable = map[string]struct{}{ + AWS_LAMBDA_LOG_FORMAT: {}, + AWS_LAMBDA_LOG_LEVEL: {}, + AWS_XRAY_CONTEXT_MISSING: {}, + AWS_XRAY_DAEMON_ADDRESS: {}, + LANG: {}, + LD_LIBRARY_PATH: {}, + PATH: {}, + TZ: {}, +} + +func SetupEnvironment(config *model.InitRequestMessage, runtimePort, runtimeLoggingSocket string) (runtimeEnv, extensionEnv model.KVMap) { + + commonVars := model.KVMap{ + AWS_ACCESS_KEY_ID: config.AwsKey, + AWS_DEFAULT_REGION: config.AwsRegion, + AWS_LAMBDA_FUNCTION_MEMORY_SIZE: strconv.Itoa(config.MemorySizeBytes / 1024 / 1024), + AWS_LAMBDA_FUNCTION_NAME: config.TaskName, + AWS_LAMBDA_FUNCTION_VERSION: config.FunctionVersion, + AWS_REGION: config.AwsRegion, + AWS_SECRET_ACCESS_KEY: config.AwsSecret, + AWS_SESSION_TOKEN: config.AwsSession, + AWS_LAMBDA_INITIALIZATION_TYPE: interop.InitializationType, + AWS_LAMBDA_RUNTIME_API: runtimePort, + } + if config.ArtefactType == model.ArtefactTypeZIP { + commonVars[LANG] = "en_US.UTF-8" + commonVars[LD_LIBRARY_PATH] = "/var/lang/lib:/lib64:/usr/lib64:/var/runtime:/var/runtime/lib:/var/task:/var/task/lib:/opt/lib" + commonVars[PATH] = "/var/lang/bin:/usr/local/bin:/usr/bin/:/bin:/opt/bin" + commonVars[TZ] = ":UTC" + } + if config.LogFormat != "" { + commonVars[AWS_LAMBDA_LOG_FORMAT] = config.LogFormat + } + if config.LogLevel != "" { + commonVars[AWS_LAMBDA_LOG_LEVEL] = config.LogLevel + } + + for k, v := range cloneAndFilterCustomerEnvVars(config.EnvVars) { + commonVars[k] = v + } + + return getRuntimeOnlyEnvVars(commonVars, config, runtimeLoggingSocket), commonVars +} + +func getRuntimeOnlyEnvVars(common model.KVMap, config *model.InitRequestMessage, runtimeLoggingSocket string) model.KVMap { + + runtimeOnlyVars := model.KVMap{ + AWS_LAMBDA_LOG_GROUP_NAME: config.LogGroupName, + AWS_LAMBDA_LOG_STREAM_NAME: config.LogStreamName, + AWS_LAMBDA_MAX_CONCURRENCY: strconv.Itoa(config.RuntimeWorkerCount), + _AWS_XRAY_DAEMON_ADDRESS: config.XRayDaemonAddress, + _AWS_XRAY_DAEMON_PORT: "2000", + AWS_XRAY_CONTEXT_MISSING: "LOG_ERROR", + AWS_XRAY_DAEMON_ADDRESS: config.XRayDaemonAddress, + LAMBDA_RUNTIME_DIR: "/var/runtime", + LAMBDA_TASK_ROOT: "/var/task", + } + + if config.ArtefactType == model.ArtefactTypeOCI { + runtimeOnlyVars[AWS_EXECUTION_ENV] = "AWS_Lambda_Image" + } else { + runtimeOnlyVars[HANDLER] = config.Handler + } + + if runtimeLoggingSocket != "" { + runtimeOnlyVars[_LAMBDA_TELEMETRY_LOG_FD_PROVIDER_SOCKET] = runtimeLoggingSocket + } + + merge(runtimeOnlyVars, common) + + return runtimeOnlyVars +} + +func cloneAndFilterCustomerEnvVars(envVars model.KVMap) model.KVMap { + filtered := make(model.KVMap, len(envVars)) + for k, v := range envVars { + + if strings.HasPrefix(string(k), "_") { + continue + } + + _, defined := Defined[k] + _, overridable := overridable[k] + if defined && !overridable { + + continue + } + + filtered[k] = v + } + + return filtered +} + +func merge(to, from model.KVMap) { + for k, v := range from { + to[k] = v + } +} diff --git a/internal/lambda-managed-instances/rapidcore/env/environment_test.go b/internal/lambda-managed-instances/rapidcore/env/environment_test.go new file mode 100644 index 0000000..f29578a --- /dev/null +++ b/internal/lambda-managed-instances/rapidcore/env/environment_test.go @@ -0,0 +1,238 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package env + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + intmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testutils" +) + +func TestSetupEnvironment(t *testing.T) { + defaultRuntimeEnv := intmodel.KVMap{ + + AWS_ACCESS_KEY_ID: "AKIAIOSFODNN7EXAMPLE", + AWS_DEFAULT_REGION: "us-west-2", + AWS_LAMBDA_FUNCTION_MEMORY_SIZE: "3008", + AWS_LAMBDA_FUNCTION_NAME: "test_function", + AWS_LAMBDA_FUNCTION_VERSION: "$LATEST", + AWS_REGION: "us-west-2", + AWS_SECRET_ACCESS_KEY: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + AWS_SESSION_TOKEN: "FwoGZXIvYXdzEMj//////////wEaDM1Qz0oN8BNwV9GqyyLVAebxhwq9ZGqojXZe1UTJkzK6F9V+VZHhT5JSWYzJUKEwOqOkQyQXJpfJsYHfkJEXtR6Kh9mXnEbqKi", + AWS_LAMBDA_INITIALIZATION_TYPE: "lambda-managed-instances", + AWS_LAMBDA_RUNTIME_API: "127.0.0.1:9001", + HANDLER: "lambda_function.lambda_handler", + LANG: "en_US.UTF-8", + LD_LIBRARY_PATH: "/var/lang/lib:/lib64:/usr/lib64:/var/runtime:/var/runtime/lib:/var/task:/var/task/lib:/opt/lib", + PATH: "/var/lang/bin:/usr/local/bin:/usr/bin/:/bin:/opt/bin", + TZ: ":UTC", + + AWS_LAMBDA_LOG_FORMAT: "json", + AWS_LAMBDA_LOG_GROUP_NAME: "/aws/lambda/test_function", + AWS_LAMBDA_LOG_LEVEL: "info", + AWS_LAMBDA_LOG_STREAM_NAME: "$LATEST", + AWS_LAMBDA_MAX_CONCURRENCY: "1", + _AWS_XRAY_DAEMON_ADDRESS: "2.2.2.2:2345", + _AWS_XRAY_DAEMON_PORT: "2000", + AWS_XRAY_CONTEXT_MISSING: "LOG_ERROR", + AWS_XRAY_DAEMON_ADDRESS: "2.2.2.2:2345", + LAMBDA_RUNTIME_DIR: "/var/runtime", + LAMBDA_TASK_ROOT: "/var/task", + + "CUSTOMER_ENV_VAR_1": "customer_env_value_1", + } + defaultExtensionEnv := intmodel.KVMap{ + + AWS_ACCESS_KEY_ID: "AKIAIOSFODNN7EXAMPLE", + AWS_DEFAULT_REGION: "us-west-2", + AWS_LAMBDA_FUNCTION_MEMORY_SIZE: "3008", + AWS_LAMBDA_FUNCTION_NAME: "test_function", + AWS_LAMBDA_FUNCTION_VERSION: "$LATEST", + AWS_LAMBDA_LOG_FORMAT: "json", + AWS_LAMBDA_LOG_LEVEL: "info", + AWS_REGION: "us-west-2", + AWS_SECRET_ACCESS_KEY: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + AWS_SESSION_TOKEN: "FwoGZXIvYXdzEMj//////////wEaDM1Qz0oN8BNwV9GqyyLVAebxhwq9ZGqojXZe1UTJkzK6F9V+VZHhT5JSWYzJUKEwOqOkQyQXJpfJsYHfkJEXtR6Kh9mXnEbqKi", + AWS_LAMBDA_INITIALIZATION_TYPE: "lambda-managed-instances", + AWS_LAMBDA_RUNTIME_API: "127.0.0.1:9001", + LANG: "en_US.UTF-8", + LD_LIBRARY_PATH: "/var/lang/lib:/lib64:/usr/lib64:/var/runtime:/var/runtime/lib:/var/task:/var/task/lib:/opt/lib", + PATH: "/var/lang/bin:/usr/local/bin:/usr/bin/:/bin:/opt/bin", + TZ: ":UTC", + + "CUSTOMER_ENV_VAR_1": "customer_env_value_1", + } + + tests := []struct { + name string + initMsg intmodel.InitRequestMessage + runtimeLoggingSocket string + wantRuntimeEnv func(runtimeEnv intmodel.KVMap) intmodel.KVMap + wantExtensionEnv func(extensionEnv intmodel.KVMap) intmodel.KVMap + }{ + { + name: "zip", + initMsg: testutils.MakeInitPayload(), + }, + { + name: "runtimeLoggingSocket", + initMsg: testutils.MakeInitPayload(), + runtimeLoggingSocket: "/path/to/runtimeLoggingSocket", + wantRuntimeEnv: func(env intmodel.KVMap) intmodel.KVMap { + env[_LAMBDA_TELEMETRY_LOG_FD_PROVIDER_SOCKET] = "/path/to/runtimeLoggingSocket" + return env + }, + }, + { + name: "oci", + initMsg: testutils.MakeInitPayload(testutils.WithArtefactType(intmodel.ArtefactTypeOCI)), + wantRuntimeEnv: func(env intmodel.KVMap) intmodel.KVMap { + env[AWS_EXECUTION_ENV] = "AWS_Lambda_Image" + delete(env, HANDLER) + delete(env, LANG) + delete(env, LD_LIBRARY_PATH) + delete(env, PATH) + delete(env, TZ) + return env + }, + wantExtensionEnv: func(env intmodel.KVMap) intmodel.KVMap { + delete(env, HANDLER) + delete(env, LANG) + delete(env, LD_LIBRARY_PATH) + delete(env, PATH) + delete(env, TZ) + return env + }, + }, + { + name: "empty_AWS_LAMBDA_LOG_FORMAT", + initMsg: testutils.MakeInitPayload(testutils.WithLogFormat("")), + wantRuntimeEnv: func(env intmodel.KVMap) intmodel.KVMap { + delete(env, AWS_LAMBDA_LOG_FORMAT) + return env + }, + wantExtensionEnv: func(env intmodel.KVMap) intmodel.KVMap { + delete(env, AWS_LAMBDA_LOG_FORMAT) + return env + }, + }, + { + name: "empty_AWS_LAMBDA_LOG_LEVEL", + initMsg: testutils.MakeInitPayload(testutils.WithLogLevel("")), + wantRuntimeEnv: func(env intmodel.KVMap) intmodel.KVMap { + delete(env, AWS_LAMBDA_LOG_LEVEL) + return env + }, + wantExtensionEnv: func(env intmodel.KVMap) intmodel.KVMap { + delete(env, AWS_LAMBDA_LOG_LEVEL) + return env + }, + }, + { + name: "customer_sets_underscore_env_var", + initMsg: testutils.MakeInitPayload(testutils.WithEnvVars(intmodel.KVMap{ + "CUSTOMER_ENV_VAR_1": "customer_env_value_1", + "_ENV_1": "val_1", + "_ENV_2": "val_2", + })), + }, + { + name: "customer_overwrites_all_defined_env_vars_zip", + initMsg: testutils.MakeInitPayload(testutils.WithEnvVars(func() intmodel.KVMap { + customerEnvVars := make(intmodel.KVMap, len(Defined)) + for k := range Defined { + customerEnvVars[k] = "customer_" + k + } + customerEnvVars["CUSTOMER_ENV_VAR_1"] = "customer_env_value_1" + return customerEnvVars + }())), + wantRuntimeEnv: func(env intmodel.KVMap) intmodel.KVMap { + env[AWS_LAMBDA_LOG_FORMAT] = "customer_AWS_LAMBDA_LOG_FORMAT" + env[AWS_LAMBDA_LOG_LEVEL] = "customer_AWS_LAMBDA_LOG_LEVEL" + env[AWS_XRAY_CONTEXT_MISSING] = "customer_AWS_XRAY_CONTEXT_MISSING" + env[AWS_XRAY_DAEMON_ADDRESS] = "customer_AWS_XRAY_DAEMON_ADDRESS" + env[LANG] = "customer_LANG" + env[LD_LIBRARY_PATH] = "customer_LD_LIBRARY_PATH" + env[PATH] = "customer_PATH" + env[TZ] = "customer_TZ" + return env + }, + wantExtensionEnv: func(env intmodel.KVMap) intmodel.KVMap { + env[AWS_LAMBDA_LOG_FORMAT] = "customer_AWS_LAMBDA_LOG_FORMAT" + env[AWS_LAMBDA_LOG_LEVEL] = "customer_AWS_LAMBDA_LOG_LEVEL" + env[AWS_XRAY_CONTEXT_MISSING] = "customer_AWS_XRAY_CONTEXT_MISSING" + env[AWS_XRAY_DAEMON_ADDRESS] = "customer_AWS_XRAY_DAEMON_ADDRESS" + env[LANG] = "customer_LANG" + env[LD_LIBRARY_PATH] = "customer_LD_LIBRARY_PATH" + env[PATH] = "customer_PATH" + env[TZ] = "customer_TZ" + return env + }, + }, + { + name: "customer_overwrites_all_defined_env_vars_oci", + initMsg: testutils.MakeInitPayload( + testutils.WithArtefactType(intmodel.ArtefactTypeOCI), + testutils.WithEnvVars(func() intmodel.KVMap { + customerEnvVars := make(intmodel.KVMap, len(Defined)) + for k := range Defined { + customerEnvVars[k] = "customer_" + k + } + customerEnvVars["CUSTOMER_ENV_VAR_1"] = "customer_env_value_1" + return customerEnvVars + }()), + ), + wantRuntimeEnv: func(env intmodel.KVMap) intmodel.KVMap { + delete(env, HANDLER) + env[AWS_EXECUTION_ENV] = "AWS_Lambda_Image" + + env[AWS_LAMBDA_LOG_FORMAT] = "customer_AWS_LAMBDA_LOG_FORMAT" + env[AWS_LAMBDA_LOG_LEVEL] = "customer_AWS_LAMBDA_LOG_LEVEL" + env[AWS_XRAY_CONTEXT_MISSING] = "customer_AWS_XRAY_CONTEXT_MISSING" + env[AWS_XRAY_DAEMON_ADDRESS] = "customer_AWS_XRAY_DAEMON_ADDRESS" + env[LANG] = "customer_LANG" + env[LD_LIBRARY_PATH] = "customer_LD_LIBRARY_PATH" + env[PATH] = "customer_PATH" + env[TZ] = "customer_TZ" + return env + }, + wantExtensionEnv: func(env intmodel.KVMap) intmodel.KVMap { + env[AWS_LAMBDA_LOG_FORMAT] = "customer_AWS_LAMBDA_LOG_FORMAT" + env[AWS_LAMBDA_LOG_LEVEL] = "customer_AWS_LAMBDA_LOG_LEVEL" + env[AWS_XRAY_CONTEXT_MISSING] = "customer_AWS_XRAY_CONTEXT_MISSING" + env[AWS_XRAY_DAEMON_ADDRESS] = "customer_AWS_XRAY_DAEMON_ADDRESS" + env[LANG] = "customer_LANG" + env[LD_LIBRARY_PATH] = "customer_LD_LIBRARY_PATH" + env[PATH] = "customer_PATH" + env[TZ] = "customer_TZ" + return env + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.wantRuntimeEnv == nil { + tt.wantRuntimeEnv = func(runtimeEnv intmodel.KVMap) intmodel.KVMap { return runtimeEnv } + } + if tt.wantExtensionEnv == nil { + tt.wantExtensionEnv = func(extensionEnv intmodel.KVMap) intmodel.KVMap { return extensionEnv } + } + + gotRuntimeEnv, gotExtensionEnv := SetupEnvironment(&tt.initMsg, "127.0.0.1:9001", tt.runtimeLoggingSocket) + assert.Equal(t, tt.wantRuntimeEnv(clone(defaultRuntimeEnv)), gotRuntimeEnv) + assert.Equal(t, tt.wantExtensionEnv(clone(defaultExtensionEnv)), gotExtensionEnv) + }) + } +} + +func clone(m intmodel.KVMap) intmodel.KVMap { + cloned := make(intmodel.KVMap, len(m)) + for k, v := range m { + cloned[k] = v + } + return cloned +} diff --git a/internal/lambda-managed-instances/rapidcore/env/util.go b/internal/lambda-managed-instances/rapidcore/env/util.go new file mode 100644 index 0000000..174c13c --- /dev/null +++ b/internal/lambda-managed-instances/rapidcore/env/util.go @@ -0,0 +1,43 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package env + +import ( + "errors" + "fmt" + "log/slog" + "strings" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" +) + +func SplitEnvironmentVariable(envKeyVal string) (string, string, error) { + splitKeyVal := strings.SplitN(envKeyVal, "=", 2) + if len(splitKeyVal) < 2 { + return "", "", errors.New("could not split env var by '=' delimiter") + } + return splitKeyVal[0], splitKeyVal[1], nil +} + +func KVPairStringsToMap(envKVPairs model.KVSlice) model.KVMap { + bootstrapEnvMap := make(model.KVMap, len(envKVPairs)) + for _, es := range envKVPairs { + key, val, err := SplitEnvironmentVariable(es) + if err != nil { + + slog.Warn("invalid environment variable format", "err", err) + continue + } + bootstrapEnvMap[key] = val + } + return bootstrapEnvMap +} + +func MapToKVPairStrings(m model.KVMap) model.KVSlice { + var env model.KVSlice + for k, v := range m { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + return env +} diff --git a/internal/lambda-managed-instances/rapidcore/env/util_test.go b/internal/lambda-managed-instances/rapidcore/env/util_test.go new file mode 100644 index 0000000..2b0b139 --- /dev/null +++ b/internal/lambda-managed-instances/rapidcore/env/util_test.go @@ -0,0 +1,36 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package env + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEnvironmentVariableSplitting(t *testing.T) { + envVar := "FOO=BAR" + k, v, err := SplitEnvironmentVariable(envVar) + assert.NoError(t, err) + assert.Equal(t, k, "FOO") + assert.Equal(t, v, "BAR") + + envVar = "FOO=BAR=BAZ" + k, v, err = SplitEnvironmentVariable(envVar) + assert.NoError(t, err) + assert.Equal(t, k, "FOO") + assert.Equal(t, v, "BAR=BAZ") + + envVar = "FOO=" + k, v, err = SplitEnvironmentVariable(envVar) + assert.NoError(t, err) + assert.Equal(t, k, "FOO") + assert.Equal(t, v, "") + + envVar = "FOO" + k, v, err = SplitEnvironmentVariable(envVar) + assert.Error(t, err) + assert.Equal(t, k, "") + assert.Equal(t, v, "") +} diff --git a/internal/lambda-managed-instances/rapidcore/errors.go b/internal/lambda-managed-instances/rapidcore/errors.go new file mode 100644 index 0000000..60aabee --- /dev/null +++ b/internal/lambda-managed-instances/rapidcore/errors.go @@ -0,0 +1,27 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import "errors" + +var ( + ErrInitDoneFailed = errors.New("InitDoneFailed") + ErrInitNotStarted = errors.New("InitNotStarted") +) + +var ( + ErrAlreadyReplied = errors.New("AlreadyReplied") + ErrAlreadyInvocating = errors.New("AlreadyInvocating") +) + +var ( + ErrInvokeResponseAlreadyWritten = errors.New("InvokeResponseAlreadyWritten") + ErrInvokeDoneFailed = errors.New("InvokeDoneFailed") + ErrInvokeReservationDone = errors.New("InvokeReservationDone") +) + +var ( + ErrInternalServerError = errors.New("InternalServerError") + ErrInvokeTimeout = errors.New("InvokeTimeout") +) diff --git a/internal/lambda-managed-instances/rapidcore/runtime_release.go b/internal/lambda-managed-instances/rapidcore/runtime_release.go new file mode 100644 index 0000000..9209173 --- /dev/null +++ b/internal/lambda-managed-instances/rapidcore/runtime_release.go @@ -0,0 +1,97 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "bufio" + "errors" + "fmt" + "log/slog" + "os" + "strings" +) + +type Logging string + +const ( + AmznStdout Logging = "amzn-stdout" + AmznStdoutTLV Logging = "amzn-stdout-tlv" +) + +type RuntimeRelease struct { + Name string + Version string + Logging Logging +} + +func (rr *RuntimeRelease) GetUAProduct() (string, error) { + if rr.Name == "" { + return "", errors.New("runtime release name is empty") + } + if rr.Version == "" { + return rr.Name, nil + } + + return fmt.Sprintf("%s/%s", rr.Name, rr.Version), nil +} + +const RuntimeReleasePath = "/var/runtime/runtime-release" + +const runtimeReleaseFileSizeLimitBytes = 1024 + +func GetRuntimeRelease(path string) (*RuntimeRelease, error) { + + pairs, err := ParsePropertiesFile(path, runtimeReleaseFileSizeLimitBytes) + if err != nil { + return nil, fmt.Errorf("could not parse %s: %w", path, err) + } + + return &RuntimeRelease{pairs["NAME"], pairs["VERSION"], Logging(pairs["LOGGING"])}, nil +} + +func GetRuntimeLoggingType(rr *RuntimeRelease) Logging { + if rr == nil { + return AmznStdout + } + return rr.Logging +} + +func ParsePropertiesFile(path string, limitBytes int64) (map[string]string, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("could not open %s: %w", path, err) + } + defer func() { + if err := f.Close(); err != nil { + slog.Warn("could not close file", "path", path, "err", err) + } + }() + + stat, err := f.Stat() + if err != nil { + return nil, fmt.Errorf("could not stat file: %s", path) + } + if stat.Size() > limitBytes { + return nil, fmt.Errorf("file %s size %d > %d limit", path, stat.Size(), limitBytes) + } + + pairs := make(map[string]string) + + s := bufio.NewScanner(f) + for s.Scan() { + if s.Text() == "" || strings.HasPrefix(s.Text(), "#") { + continue + } + k, v, found := strings.Cut(s.Text(), "=") + if !found { + return nil, fmt.Errorf("could not parse key-value pair from a line: %s", s.Text()) + } + pairs[k] = strings.Trim(v, "'\"") + } + if err := s.Err(); err != nil { + return nil, fmt.Errorf("failed to read properties file: %w", err) + } + + return pairs, nil +} diff --git a/internal/lambda-managed-instances/rapidcore/runtime_release_test.go b/internal/lambda-managed-instances/rapidcore/runtime_release_test.go new file mode 100644 index 0000000..e47d358 --- /dev/null +++ b/internal/lambda-managed-instances/rapidcore/runtime_release_test.go @@ -0,0 +1,151 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetRuntimeRelease(t *testing.T) { + tests := []struct { + name string + content string + want *RuntimeRelease + wantErr bool + }{ + { + name: "simple", + content: "NAME=foo\nVERSION=bar\nLOGGING=baz\n", + want: &RuntimeRelease{"foo", "bar", "baz"}, + }, + { + name: "no trailing new line", + content: "NAME=foo\nVERSION=bar\nLOGGING=baz", + want: &RuntimeRelease{"foo", "bar", "baz"}, + }, + { + name: "nonexistent keys", + content: "LOGGING=baz\n", + want: &RuntimeRelease{"", "", "baz"}, + }, + { + name: "empty value", + content: "NAME=\nVERSION=\nLOGGING=\n", + want: &RuntimeRelease{"", "", ""}, + }, + { + name: "delimiter in value", + content: "NAME=Foo=Bar\nVERSION=bar\nLOGGING=baz\n", + want: &RuntimeRelease{"Foo=Bar", "bar", "baz"}, + }, + { + name: "empty file", + want: &RuntimeRelease{"", "", ""}, + }, + { + name: "quotes", + content: "NAME=\"foo\"\nVERSION='bar'\n", + want: &RuntimeRelease{"foo", "bar", ""}, + }, + { + name: "double quotes", + content: "NAME='\"foo\"'\nVERSION=\"'bar'\"\n", + want: &RuntimeRelease{"foo", "bar", ""}, + }, + { + name: "empty lines", + content: "\nNAME=foo\n\nVERSION=bar\n\nLOGGING=baz\n\n", + want: &RuntimeRelease{"foo", "bar", "baz"}, + }, + { + name: "comments", + content: "# comment 1\nNAME=foo\n# comment 2\nVERSION=bar\n# comment 3\nLOGGING=baz\n# comment 4\n", + want: &RuntimeRelease{"foo", "bar", "baz"}, + }, + { + name: "file exceeds size limit", + content: "NAME=foo\nVERSION=bar\nLOGGING=" + strings.Repeat("a", runtimeReleaseFileSizeLimitBytes), + wantErr: true, + }, + { + name: "invalid format", + content: "NAME=foo\nVERSION=bar\nLOGGING=baz\nLAST_LINE_IS_NOT_KV_PAIR", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f, err := os.CreateTemp("", "runtime-release") + require.NoError(t, err) + defer func() { + require.NoError(t, os.Remove(f.Name())) + }() + + _, err = f.WriteString(tt.content) + require.NoError(t, err) + + got, err := GetRuntimeRelease(f.Name()) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestGetRuntimeRelease_NotFound(t *testing.T) { + _, err := GetRuntimeRelease("/sys/not-exists") + assert.Error(t, err) +} + +func TestRuntimeRelease_GetUAProduct(t *testing.T) { + tests := []struct { + name string + rr *RuntimeRelease + want string + wantErr bool + }{ + { + name: "no name", + rr: &RuntimeRelease{ + Version: "2.7.7-6419c85c", + }, + wantErr: true, + }, + { + name: "no version", + rr: &RuntimeRelease{ + Name: "Ruby", + }, + want: "Ruby", + }, + { + name: "name and version", + rr: &RuntimeRelease{ + Name: "Ruby", + Version: "2.7.7-6419c85c", + }, + want: "Ruby/2.7.7-6419c85c", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.rr.GetUAProduct() + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} diff --git a/internal/lambda-managed-instances/raptor/app.go b/internal/lambda-managed-instances/raptor/app.go new file mode 100644 index 0000000..9a9ac06 --- /dev/null +++ b/internal/lambda-managed-instances/raptor/app.go @@ -0,0 +1,209 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package raptor + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/netip" + "sync" + "sync/atomic" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" + internalModel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/raptor/internal" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/servicelogs" +) + +var ( + ErrNotInitialized = errors.New("sandbox is not initialized") + ErrorEnvironmentUnhealthy = errors.New("environment is unhealthy") +) + +type App struct { + rapidCtx interop.RapidContext + state *internal.StateGuard + shutdownOnce sync.Once + + err atomic.Value + doneCh chan struct{} + telemetryFDSocketPath string + raptorLogger raptorLogger +} + +func StartApp(deps rapid.Dependencies, telemetryFDSocketPath string, raptorLogger raptorLogger) (*App, error) { + ctx := context.Background() + rapidCtx, err := rapid.Start(ctx, deps) + if err != nil { + return nil, err + } + + app := &App{ + rapidCtx: rapidCtx, + state: internal.NewStateGuard(), + doneCh: make(chan struct{}), + telemetryFDSocketPath: telemetryFDSocketPath, + raptorLogger: raptorLogger, + } + + app.StartProcessTerminationMonitor() + + return app, nil +} + +func (a *App) Init(ctx context.Context, init *internalModel.InitRequestMessage, initMetrics interop.InitMetrics) model.AppError { + + if err := a.state.SetState(internal.Initializing); err != nil { + logging.Error(ctx, "State error : can't switch to initializing", "state", a.state.GetState(), "err", err) + return interop.ClientError{ + ClientError: model.NewClientError( + err, + model.ErrorSeverityFatal, + model.ErrorInvalidRequest, + ), + } + } + + initMessage := getInitExecutionData(init, a.RuntimeAPIAddrPort().String(), a.telemetryFDSocketPath) + a.raptorLogger.SetInitData(&initMessage) + logging.Debug(ctx, "Start handling Init", "initRequest", init) + initErr := a.rapidCtx.HandleInit(ctx, initMessage, initMetrics) + + if initErr != nil { + logging.Err(ctx, "Received Init error", initErr) + a.Shutdown(initErr) + + return initErr + } + + logging.Debug(ctx, "Received Init Success") + + err := a.state.SetState(internal.Initialized) + if err != nil { + logging.Error(ctx, "State error : can't switch to initalized state") + + return model.NewClientError(err, model.ErrorSeverityFatal, model.ErrorInvalidRequest) + } + + return nil +} + +func (a *App) Invoke(ctx context.Context, invokeMsg interop.InvokeRequest, metrics interop.InvokeMetrics) (err model.AppError, wasResponseSent bool) { + currState := a.state.GetState() + switch currState { + case internal.Initialized: + return a.rapidCtx.HandleInvoke(ctx, invokeMsg, metrics) + case internal.Idle, internal.Initializing: + logging.Error(ctx, "Sandbox not Initialized", "state", currState) + return interop.ClientError{ + ClientError: model.NewClientError( + ErrNotInitialized, + model.ErrorSeverityError, + model.ErrorInitIncomplete, + ), + }, false + case internal.ShuttingDown, internal.Shutdown: + logging.Error(ctx, "Invoke while Sandbox shutting down") + return interop.ClientError{ + ClientError: model.NewClientError( + ErrorEnvironmentUnhealthy, + model.ErrorSeverityFatal, + model.ErrorEnvironmentUnhealthy, + ), + }, false + default: + panic(fmt.Sprintf("unknown current state: %d", currState)) + } +} + +func (a *App) Shutdown(shutdownReason model.AppError) { + + a.shutdownOnce.Do(func() { + + if err := a.state.SetState(internal.ShuttingDown); err != nil { + + return + } + + metrics := rapid.NewShutdownMetrics(a.raptorLogger, shutdownReason) + shutdownDuration := metrics.CreateDurationMetric(interop.TotalDurationMetric) + + slog.Info("Shutting down", "reason", shutdownReason) + + if shutdownReason != nil { + a.err.Store(shutdownReason) + } + + var shutdownErr model.AppError + if shutdownErr = a.rapidCtx.HandleShutdown(shutdownReason, metrics); shutdownErr != nil { + slog.Warn("Shutdown error", "err", shutdownErr) + } + + shutdownDuration.Done() + metrics.SendMetrics(shutdownErr) + + if err := a.state.SetState(internal.Shutdown); err != nil { + slog.Error("could not change status from ShuttingDown to Shutdown", "status", a.state.GetState()) + } + close(a.doneCh) + }) +} + +func (a *App) RuntimeAPIAddrPort() netip.AddrPort { + return a.rapidCtx.RuntimeAPIAddrPort() +} + +func (a *App) StartProcessTerminationMonitor() { + + go func() { + + appErr := <-a.rapidCtx.ProcessTerminationNotifier() + + slog.Debug("Process termination monitor received", "error", appErr) + a.Shutdown(appErr) + }() +} + +func (a *App) HandleHealthCheck() interop.HealthCheckResponse { + currState := a.state.GetState() + + switch currState { + + case internal.Idle, internal.Initializing, internal.Initialized: + return interop.HealthyContainerResponse{} + case internal.ShuttingDown, internal.Shutdown: + if appError := a.Err(); appError != nil { + return interop.UnhealthyContainerResponse{ + ErrorType: appError.ErrorType(), + } + } + return interop.UnhealthyContainerResponse{ + ErrorType: "", + } + default: + panic(fmt.Sprintf("unknown current state: %d", currState)) + } +} + +func (a *App) Done() <-chan struct{} { + return a.doneCh +} + +func (a *App) Err() model.AppError { + err := a.err.Load() + if err != nil { + return err.(model.AppError) + } + return nil +} + +type raptorLogger interface { + servicelogs.Logger + SetInitData(initData interop.InitStaticDataProvider) +} diff --git a/internal/lambda-managed-instances/raptor/app_test.go b/internal/lambda-managed-instances/raptor/app_test.go new file mode 100644 index 0000000..bf5c9d2 --- /dev/null +++ b/internal/lambda-managed-instances/raptor/app_test.go @@ -0,0 +1,263 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package raptor + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + internalModel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/raptor/internal" +) + +func TestStartApp(t *testing.T) { + t.Run("successful start", func(t *testing.T) { + app := &App{ + state: internal.NewStateGuard(), + doneCh: make(chan struct{}), + } + assert.Equal(t, internal.Idle, app.state.GetState()) + }) +} + +func TestAppInitSuccessful(t *testing.T) { + mockRapidCtx, app, initRequest, initMetrics, _, _ := setupAppTest(t) + assert.Equal(t, internal.Idle, app.state.GetState()) + + mockRapidCtx.On("HandleInit", mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + + initErr := app.Init(context.Background(), initRequest, initMetrics) + assert.Nil(t, initErr) + + assert.Equal(t, internal.Initialized, app.state.GetState()) + mockRapidCtx.AssertExpectations(t) +} + +func TestAppInitFailure(t *testing.T) { + expectedErr := model.NewCustomerError(model.ErrorReasonRuntimeExecFailed, model.WithSeverity(model.ErrorSeverityFatal)) + mockRapidCtx, app, initRequest, initMetrics, _, _ := setupAppTest(t) + mockRapidCtx.On("HandleInit", mock.Anything, mock.Anything, mock.Anything).Return(expectedErr) + mockRapidCtx.On("HandleShutdown", mock.Anything, mock.Anything).Return(nil) + + initErr := app.Init(context.Background(), initRequest, initMetrics) + assert.Equal(t, expectedErr, initErr) + + assert.Equal(t, internal.Shutdown, app.state.GetState()) + mockRapidCtx.AssertExpectations(t) +} + +func TestAppInitInvalidState(t *testing.T) { + _, app, initRequest, initMetrics, _, _ := setupAppTest(t) + + require.NoError(t, app.state.SetState(internal.ShuttingDown)) + + response := app.Init(context.Background(), initRequest, initMetrics) + clientErr, ok := response.(interop.ClientError) + assert.True(t, ok) + assert.Equal(t, model.ErrorInvalidRequest, clientErr.ErrorType()) +} + +func TestAppInitSuccessfulButStateError(t *testing.T) { + mockRapidCtx, app, initRequest, initMetrics, _, _ := setupAppTest(t) + assert.Equal(t, internal.Idle, app.state.GetState()) + + mockRapidCtx.On("HandleInit", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + + require.NoError(t, app.state.SetState(internal.ShuttingDown)) + }).Return(nil) + + initErr := app.Init(context.Background(), initRequest, initMetrics) + assert.NotNil(t, initErr) + assert.Equal(t, model.ErrorSeverityFatal, initErr.Severity()) + assert.Equal(t, model.ErrorInvalidRequest, initErr.ErrorType()) + + mockRapidCtx.AssertExpectations(t) +} + +func TestStartProcessTerminationMonitor(t *testing.T) { + mockRapidCtx := interop.NewMockRapidContext(t) + + termChan := make(chan model.AppError) + + notifierCalled := make(chan struct{}) + + mockRapidCtx.On("ProcessTerminationNotifier").Return((<-chan model.AppError)(termChan)).Run(func(args mock.Arguments) { + close(notifierCalled) + }).Maybe() + mockRapidCtx.On("HandleShutdown", mock.Anything, mock.Anything).Return(nil) + + mockLogger := newMockRaptorLogger(t) + mockLogger.On("Log", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything) + + app := &App{ + rapidCtx: mockRapidCtx, + state: internal.NewStateGuard(), + doneCh: make(chan struct{}), + raptorLogger: mockLogger, + } + + assert.Equal(t, internal.Idle, app.state.GetState()) + app.StartProcessTerminationMonitor() + + <-notifierCalled + close(termChan) + + stateChanged := make(chan struct{}) + go func() { + for { + if app.state.GetState() == internal.Shutdown { + close(stateChanged) + return + } + time.Sleep(1 * time.Millisecond) + } + }() + + select { + case <-stateChanged: + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for state to change to Shutdown") + } + + assert.Equal(t, internal.Shutdown, app.state.GetState()) + mockRapidCtx.AssertExpectations(t) +} + +func TestAppShutdown(t *testing.T) { + mockRapidCtx, app, _, _, _, _ := setupAppTest(t) + + mockRapidCtx.On("HandleShutdown", mock.Anything, mock.Anything).Return(nil) + + assert.NoError(t, app.Err()) + select { + case <-app.Done(): + t.Fatal("app.Done must have been blocked") + default: + } + + err := model.NewCustomerError(model.ErrorRuntimeUnknown) + app.Shutdown(err) + + assert.Equal(t, err, app.Err()) + select { + case <-app.Done(): + default: + t.Fatal("app.Done must have been unblocked") + } + + assert.Equal(t, internal.Shutdown, app.state.GetState()) +} + +func TestAppInvokeStateValidation(t *testing.T) { + testCases := []struct { + name string + states []internal.Status + wantErrorType model.ErrorType + wantError error + }{ + { + name: "Idle", + states: []internal.Status{}, + wantErrorType: model.ErrorInitIncomplete, + wantError: ErrNotInitialized, + }, + { + name: "Initializing", + states: []internal.Status{internal.Initializing}, + wantErrorType: model.ErrorInitIncomplete, + wantError: ErrNotInitialized, + }, + { + name: "ShuttingDown", + states: []internal.Status{internal.ShuttingDown}, + wantErrorType: model.ErrorEnvironmentUnhealthy, + wantError: ErrorEnvironmentUnhealthy, + }, + { + name: "Shutdown", + states: []internal.Status{internal.ShuttingDown, internal.Shutdown}, + wantErrorType: model.ErrorEnvironmentUnhealthy, + wantError: ErrorEnvironmentUnhealthy, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, app, _, _, invokeMsg, invokeMetrics := setupAppTest(t) + + for _, state := range tc.states { + require.NoError(t, app.state.SetState(state)) + } + + err, wasResponseSent := app.Invoke(context.Background(), invokeMsg, invokeMetrics) + + assert.False(t, wasResponseSent) + assert.ErrorAs(t, err, &interop.ClientError{}) + assert.Equal(t, tc.wantErrorType, err.ErrorType()) + assert.Equal(t, tc.wantError, err.Unwrap()) + }) + } +} + +func setupAppTest(t *testing.T) (*interop.MockRapidContext, *App, *internalModel.InitRequestMessage, interop.InitMetrics, interop.InvokeRequest, interop.InvokeMetrics) { + mockRapidCtx := interop.NewMockRapidContext(t) + + mockAddr, _ := netip.ParseAddrPort("127.0.0.1:8080") + mockRapidCtx.On("RuntimeAPIAddrPort").Return(mockAddr).Maybe() + + mockChan := make(chan model.AppError) + mockRapidCtx.On("ProcessTerminationNotifier").Return((<-chan model.AppError)(mockChan)).Maybe() + + mockLogger := newMockRaptorLogger(t) + mockLogger.On("Log", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe() + mockLogger.On("SetInitData", mock.Anything).Maybe() + + app := &App{ + rapidCtx: mockRapidCtx, + state: internal.NewStateGuard(), + doneCh: make(chan struct{}), + raptorLogger: mockLogger, + } + + app.StartProcessTerminationMonitor() + initRequest := &internalModel.InitRequestMessage{ + Handler: "handler", + TelemetryAPIAddress: internalModel.TelemetryAddr(netip.MustParseAddrPort("127.0.0.1:8081")), + } + + mockInvokeReq := interop.NewMockInvokeRequest(t) + + mockInvokeReq.On("ContentType").Return("").Maybe() + mockInvokeReq.On("InvokeID").Return("").Maybe() + mockInvokeReq.On("Deadline").Return(time.Time{}).Maybe() + mockInvokeReq.On("TraceId").Return("").Maybe() + mockInvokeReq.On("ClientContext").Return("").Maybe() + mockInvokeReq.On("CognitoId").Return("").Maybe() + mockInvokeReq.On("CognitoPoolId").Return("").Maybe() + mockInvokeReq.On("ResponseBandwidthRate").Return(int64(0)).Maybe() + mockInvokeReq.On("ResponseBandwidthBurstRate").Return(int64(0)).Maybe() + mockInvokeReq.On("MaxPayloadSize").Return(int64(0)).Maybe() + mockInvokeReq.On("BodyReader").Return(nil).Maybe() + mockInvokeReq.On("ResponseWriter").Return(nil).Maybe() + mockInvokeReq.On("SetResponseHeader", mock.Anything, mock.Anything).Return().Maybe() + mockInvokeReq.On("AddResponseHeader", mock.Anything, mock.Anything).Return().Maybe() + mockInvokeReq.On("WriteResponseHeaders", mock.Anything).Return().Maybe() + mockInvokeReq.On("UpdateFromInitData", mock.Anything).Return(nil).Maybe() + + mockInitMetrics := interop.NewMockInitMetrics(t) + mockInitMetrics.On("SetInitData", mock.Anything).Maybe() + + mockInvokeMetrics := interop.NewMockInvokeMetrics(t) + + return mockRapidCtx, app, initRequest, mockInitMetrics, mockInvokeReq, mockInvokeMetrics +} diff --git a/internal/lambda-managed-instances/raptor/internal/raptor_state.go b/internal/lambda-managed-instances/raptor/internal/raptor_state.go new file mode 100644 index 0000000..b5fa617 --- /dev/null +++ b/internal/lambda-managed-instances/raptor/internal/raptor_state.go @@ -0,0 +1,93 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "errors" + "fmt" + "log/slog" + "sync" +) + +type Status int + +var ErrInvalidTransition = errors.New("invalid state transition") + +func NewInvalidTransitionError(from, to Status) error { + return fmt.Errorf("%w: from %v to %v", ErrInvalidTransition, from.String(), to.String()) +} + +const ( + Idle Status = iota + Initializing + Initialized + ShuttingDown + Shutdown +) + +func (s Status) String() string { + switch s { + case Idle: + return "Idle" + case Initializing: + return "Initializing" + case Initialized: + return "Initialized" + case ShuttingDown: + return "ShuttingDown" + case Shutdown: + return "Shutdown" + default: + return fmt.Sprintf("Status(%d)", int(s)) + } +} + +type StateGuard struct { + current Status + mu sync.RWMutex +} + +func NewStateGuard() *StateGuard { + return &StateGuard{} +} + +func (sm *StateGuard) GetState() Status { + sm.mu.RLock() + defer sm.mu.RUnlock() + slog.Debug("Raptor current state", "currState:", sm.current.String()) + return sm.current +} + +func (sm *StateGuard) SetState(state Status) error { + sm.mu.Lock() + defer sm.mu.Unlock() + if !isValidTransition(sm.current, state) { + slog.Error("invalid state transition", "from", sm.current.String(), "to", state.String()) + return NewInvalidTransitionError(sm.current, state) + } + sm.current = state + return nil +} + +func isValidTransition(from, to Status) bool { + + if to == Idle { + return true + } + + switch from { + case Idle: + return to == Initializing || to == ShuttingDown + case Initializing: + return to == Initialized || to == ShuttingDown + case Initialized: + return to == ShuttingDown + case ShuttingDown: + return to == Shutdown + case Shutdown: + return false + default: + return false + } +} diff --git a/internal/lambda-managed-instances/raptor/internal/raptor_state_test.go b/internal/lambda-managed-instances/raptor/internal/raptor_state_test.go new file mode 100644 index 0000000..5ab57fa --- /dev/null +++ b/internal/lambda-managed-instances/raptor/internal/raptor_state_test.go @@ -0,0 +1,239 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewStateGuard(t *testing.T) { + sg := NewStateGuard() + assert.Equal(t, Idle, sg.current) +} + +func TestGetState(t *testing.T) { + sg := NewStateGuard() + assert.Equal(t, Idle, sg.GetState()) + + sg.current = Initialized + assert.Equal(t, Initialized, sg.GetState()) +} + +func TestSetState(t *testing.T) { + tests := []struct { + name string + initialState Status + targetState Status + expectError bool + expectedState Status + }{ + { + name: "Valid: Idle to Initializing", + initialState: Idle, + targetState: Initializing, + expectError: false, + expectedState: Initializing, + }, + { + name: "Valid: Idle to ShuttingDown", + initialState: Idle, + targetState: ShuttingDown, + expectError: false, + expectedState: ShuttingDown, + }, + { + name: "Valid: Initializing to Initialized", + initialState: Initializing, + targetState: Initialized, + expectError: false, + expectedState: Initialized, + }, + { + name: "Valid: Initializing to ShuttingDown", + initialState: Initializing, + targetState: ShuttingDown, + expectError: false, + expectedState: ShuttingDown, + }, + { + name: "Valid: Initialized to ShuttingDown", + initialState: Initialized, + targetState: ShuttingDown, + expectError: false, + expectedState: ShuttingDown, + }, + { + name: "Valid: ShuttingDown to Shutdown", + initialState: ShuttingDown, + targetState: Shutdown, + expectError: false, + expectedState: Shutdown, + }, + { + name: "Invalid: Idle to Initialized", + initialState: Idle, + targetState: Initialized, + expectError: true, + expectedState: Idle, + }, + { + name: "Invalid: Idle to Shutdown", + initialState: Idle, + targetState: Shutdown, + expectError: true, + expectedState: Idle, + }, + { + name: "Invalid: Initializing to Idle", + initialState: Initializing, + targetState: Idle, + expectError: false, + expectedState: Idle, + }, + { + name: "Invalid: Initialized to Idle", + initialState: Initialized, + targetState: Idle, + expectError: false, + expectedState: Idle, + }, + { + name: "Invalid: Initialized to Initializing", + initialState: Initialized, + targetState: Initializing, + expectError: true, + expectedState: Initialized, + }, + { + name: "Invalid: ShuttingDown to Idle", + initialState: ShuttingDown, + targetState: Idle, + expectError: false, + expectedState: Idle, + }, + { + name: "Valid: Shutdown to Idle", + initialState: Shutdown, + targetState: Idle, + expectError: false, + expectedState: Idle, + }, + { + name: "Invalid: Shutdown to other states", + initialState: Shutdown, + targetState: Initializing, + expectError: true, + expectedState: Shutdown, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sg := &StateGuard{current: tt.initialState} + err := sg.SetState(tt.targetState) + + if tt.expectError { + assert.ErrorIs(t, err, ErrInvalidTransition) + assert.Equal(t, tt.expectedState, sg.current) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedState, sg.current) + } + }) + } +} + +func TestIsValidTransition(t *testing.T) { + tests := []struct { + name string + from Status + to Status + expected bool + }{ + + {"Idle to Initializing", Idle, Initializing, true}, + {"Idle to ShuttingDown", Idle, ShuttingDown, true}, + {"Initializing to Initialized", Initializing, Initialized, true}, + {"Initializing to ShuttingDown", Initializing, ShuttingDown, true}, + {"Initialized to ShuttingDown", Initialized, ShuttingDown, true}, + {"ShuttingDown to Shutdown", ShuttingDown, Shutdown, true}, + + {"Idle to Initialized", Idle, Initialized, false}, + {"Idle to Shutdown", Idle, Shutdown, false}, + {"Initializing to Idle", Initializing, Idle, true}, + {"Initialized to Idle", Initialized, Idle, true}, + {"Initialized to Initializing", Initialized, Initializing, false}, + {"ShuttingDown to Idle", ShuttingDown, Idle, true}, + {"ShuttingDown to Initializing", ShuttingDown, Initializing, false}, + {"ShuttingDown to Initialized", ShuttingDown, Initialized, false}, + {"Shutdown to Idle", Shutdown, Idle, true}, + {"Shutdown to Initializing", Shutdown, Initializing, false}, + {"Shutdown to Initialized", Shutdown, Initialized, false}, + {"Shutdown to ShuttingDown", Shutdown, ShuttingDown, false}, + + {"Idle to Idle", Idle, Idle, true}, + {"Initializing to Initializing", Initializing, Initializing, false}, + {"Initialized to Initialized", Initialized, Initialized, false}, + {"ShuttingDown to ShuttingDown", ShuttingDown, ShuttingDown, false}, + {"Shutdown to Shutdown", Shutdown, Shutdown, false}, + + {"Invalid state to Idle", Status(99), Idle, true}, + {"Idle to invalid state", Idle, Status(99), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isValidTransition(tt.from, tt.to) + assert.Equal(t, tt.expected, result, "isValidTransition(%v, %v) should return %v", tt.from, tt.to, tt.expected) + }) + } +} + +func TestStatus_String(t *testing.T) { + tests := []struct { + name string + status Status + expected string + }{ + { + name: "Idle", + status: Idle, + expected: "Idle", + }, + { + name: "initializing_status", + status: Initializing, + expected: "Initializing", + }, + { + name: "initialized_status", + status: Initialized, + expected: "Initialized", + }, + { + name: "shutting_down_status", + status: ShuttingDown, + expected: "ShuttingDown", + }, + { + name: "shutdown_status", + status: Shutdown, + expected: "Shutdown", + }, + { + name: "unknown_status", + status: Status(99), + expected: "Status(99)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.status.String() + assert.Equal(t, tt.expected, result, "Status.String() should return correct string representation") + }) + } +} diff --git a/internal/lambda-managed-instances/raptor/mock_address.go b/internal/lambda-managed-instances/raptor/mock_address.go new file mode 100644 index 0000000..fdfe31b --- /dev/null +++ b/internal/lambda-managed-instances/raptor/mock_address.go @@ -0,0 +1,64 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package raptor + +import ( + net "net" + + mock "github.com/stretchr/testify/mock" +) + +type MockAddress struct { + mock.Mock +} + +func (_m *MockAddress) Protocol() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Protocol") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockAddress) String() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for String") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockAddress) UpdateFromListener(listener net.Listener) { + _m.Called(listener) +} + +func NewMockAddress(t interface { + mock.TestingT + Cleanup(func()) +}) *MockAddress { + mock := &MockAddress{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/raptor/mock_raptor_logger.go b/internal/lambda-managed-instances/raptor/mock_raptor_logger.go new file mode 100644 index 0000000..f926025 --- /dev/null +++ b/internal/lambda-managed-instances/raptor/mock_raptor_logger.go @@ -0,0 +1,54 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package raptor + +import ( + mock "github.com/stretchr/testify/mock" + interop "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + + servicelogs "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/servicelogs" + + time "time" +) + +type mockRaptorLogger struct { + mock.Mock +} + +func (_m *mockRaptorLogger) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *mockRaptorLogger) Log(op servicelogs.Operation, opStart time.Time, props []servicelogs.Tuple, dims []servicelogs.Tuple, metrics []servicelogs.Metric) { + _m.Called(op, opStart, props, dims, metrics) +} + +func (_m *mockRaptorLogger) SetInitData(initData interop.InitStaticDataProvider) { + _m.Called(initData) +} + +func newMockRaptorLogger(t interface { + mock.TestingT + Cleanup(func()) +}) *mockRaptorLogger { + mock := &mockRaptorLogger{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/raptor/mock_shutdown_handler.go b/internal/lambda-managed-instances/raptor/mock_shutdown_handler.go new file mode 100644 index 0000000..a8075ab --- /dev/null +++ b/internal/lambda-managed-instances/raptor/mock_shutdown_handler.go @@ -0,0 +1,29 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package raptor + +import ( + mock "github.com/stretchr/testify/mock" + model "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +type mockShutdownHandler struct { + mock.Mock +} + +func (_m *mockShutdownHandler) Shutdown(shutdownReason model.AppError) { + _m.Called(shutdownReason) +} + +func newMockShutdownHandler(t interface { + mock.TestingT + Cleanup(func()) +}) *mockShutdownHandler { + mock := &mockShutdownHandler{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/raptor/raptor_utils.go b/internal/lambda-managed-instances/raptor/raptor_utils.go new file mode 100644 index 0000000..90d2dd7 --- /dev/null +++ b/internal/lambda-managed-instances/raptor/raptor_utils.go @@ -0,0 +1,74 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package raptor + +import ( + "net/netip" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + internalModel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapidcore/env" + supvmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/supervisor/model" +) + +func getInitExecutionData(initRequest *internalModel.InitRequestMessage, runtimePort, telemetryFDSocketPath string) interop.InitExecutionData { + + runtimeEnv, extensionEnv := env.SetupEnvironment(initRequest, runtimePort, telemetryFDSocketPath) + + initMessage := interop.InitExecutionData{ + + ExtensionEnv: extensionEnv, + Runtime: model.Runtime{ + ExecConfig: model.RuntimeExec{ + Cmd: initRequest.RuntimeBinaryCommand, + WorkingDir: initRequest.CurrentWorkingDir, + Env: runtimeEnv, + }, + }, + + Credentials: model.Credentials{ + AwsKey: initRequest.AwsKey, + AwsSecret: initRequest.AwsSecret, + AwsSession: initRequest.AwsSession, + }, + LogGroupName: initRequest.LogGroupName, + LogStreamName: initRequest.LogStreamName, + FunctionMetadata: model.FunctionMetadata{ + AccountID: initRequest.AccountID, + FunctionName: initRequest.TaskName, + FunctionVersion: initRequest.FunctionVersion, + MemorySizeBytes: uint64(initRequest.MemorySizeBytes), + Handler: initRequest.Handler, + RuntimeInfo: model.RuntimeInfo{ + Arn: initRequest.RuntimeArn, + Version: initRequest.RuntimeVersion, + }, + }, + RuntimeManagedLoggingFormats: []supvmodel.ManagedLoggingFormat{ + supvmodel.LineBasedManagedLogging, + }, + + StaticData: interop.EEStaticData{ + InitTimeout: time.Duration(initRequest.InitTimeout), + FunctionTimeout: time.Duration(initRequest.InvokeTimeout), + FunctionARN: initRequest.FunctionARN, + FunctionVersionID: initRequest.FunctionVersionID, + LogGroupName: initRequest.LogGroupName, + LogStreamName: initRequest.LogStreamName, + XRayTracingMode: initRequest.XrayTracingMode, + ArtefactType: initRequest.ArtefactType, + AmiId: initRequest.AmiId, + RuntimeVersion: initRequest.RuntimeVersion, + AvailabilityZoneId: initRequest.AvailabilityZoneId, + }, + TelemetrySubscriptionConfig: interop.TelemetrySubscriptionConfig{ + APIAddr: netip.AddrPort(initRequest.TelemetryAPIAddress), + Passphrase: initRequest.TelemetryPassphrase, + }, + } + + return initMessage +} diff --git a/internal/lambda-managed-instances/raptor/server.go b/internal/lambda-managed-instances/raptor/server.go new file mode 100644 index 0000000..27c00ce --- /dev/null +++ b/internal/lambda-managed-instances/raptor/server.go @@ -0,0 +1,136 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package raptor + +import ( + "context" + "log/slog" + "net" + "net/http" + "net/netip" + "os" + "os/signal" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func StartServer(shutdownHandler shutdownHandler, handler http.Handler, addr Address) (*Server, error) { + listener, err := net.Listen(addr.Protocol(), addr.String()) + if err != nil { + return nil, err + } + + addr.UpdateFromListener(listener) + + s := &Server{ + httpServer: &http.Server{Handler: handler, ReadHeaderTimeout: 15 * time.Second}, + doneCh: make(chan struct{}), + shutdownHandler: shutdownHandler, + Addr: addr, + } + + go func() { + if err := s.httpServer.Serve(listener); err != nil { + + s.Shutdown(err) + } + }() + return s, nil +} + +func (s *Server) Shutdown(err error) { + s.shutdownOnce.Do(func() { + + s.shutdownHandler.Shutdown(model.NewClientError(err, model.ErrorSeverityFatal, model.ErrorExecutionEnvironmentShutdown)) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + slog.Info("Shutting down HTTP server...") + if err := s.httpServer.Shutdown(ctx); err != nil { + slog.Warn("could not gracefully shutdown EA http server", "err", err) + } + + if err != nil { + s.err.Store(err) + } + close(s.doneCh) + }) +} + +func (s *Server) Done() <-chan struct{} { + return s.doneCh +} + +func (s *Server) Err() error { + if err := s.err.Load(); err != nil { + return err.(error) + } + return nil +} + +func (s *Server) AttachShutdownSignalHandler(sigCh chan os.Signal) { + go func() { + signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) + sig := <-sigCh + slog.Info("SignalHandler received:", "signal", sig) + s.Shutdown(nil) + }() +} + +type Address interface { + String() string + Protocol() string + + UpdateFromListener(listener net.Listener) +} + +type TCPAddress struct { + AddrPort netip.AddrPort +} + +func (t *TCPAddress) String() string { + return t.AddrPort.String() +} + +func (t *TCPAddress) Protocol() string { + return "tcp" +} + +func (t *TCPAddress) UpdateFromListener(listener net.Listener) { + t.AddrPort = netip.MustParseAddrPort(listener.Addr().String()) +} + +type UnixAddress struct { + Path string +} + +func (u *UnixAddress) String() string { + return u.Path +} + +func (u *UnixAddress) Protocol() string { + return "unix" +} + +func (u *UnixAddress) UpdateFromListener(listener net.Listener) { + +} + +type shutdownHandler interface { + Shutdown(shutdownReason model.AppError) +} + +type Server struct { + httpServer *http.Server + Addr Address + + shutdownHandler shutdownHandler + + shutdownOnce sync.Once + doneCh chan struct{} + err atomic.Value +} diff --git a/internal/lambda-managed-instances/raptor/server_test.go b/internal/lambda-managed-instances/raptor/server_test.go new file mode 100644 index 0000000..eabae71 --- /dev/null +++ b/internal/lambda-managed-instances/raptor/server_test.go @@ -0,0 +1,86 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package raptor + +import ( + "math" + "math/rand/v2" + "net/http" + "net/netip" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testutils" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testutils/mocks" +) + +func TestStartNewServer_UDS(t *testing.T) { + var err error + socketPath, err := testutils.CreateTempSocketPath(t) + require.NoError(t, err) + + mockShutdownHandler := newMockShutdownHandler(t) + mockShutdownHandler.On("Shutdown", mock.Anything).Return().Maybe() + handler := mocks.NewNoOpHandler() + + server, err := StartServer(mockShutdownHandler, handler, &UnixAddress{ + Path: socketPath, + }) + require.NoError(t, err) + assert.Equal(t, socketPath, server.Addr.String()) + + client := testutils.NewUnixSocketClient(socketPath) + req, err := http.NewRequest("GET", "http://unix/", nil) + require.NoError(t, err) + + _, err = client.Do(req) + require.NoError(t, err) +} + +func TestStartNewServer_TCP(t *testing.T) { + var err error + port := uint16(rand.UintN(math.MaxUint16-1024)) + 1024 + eaAPIAddrPort := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), port) + + mockShutdownHandler := newMockShutdownHandler(t) + mockShutdownHandler.On("Shutdown", mock.Anything).Return().Maybe() + handler := mocks.NewNoOpHandler() + + server, err := StartServer(mockShutdownHandler, handler, &TCPAddress{ + eaAPIAddrPort, + }) + require.NoError(t, err) + assert.Equal(t, eaAPIAddrPort, server.Addr.(*TCPAddress).AddrPort) + + _, err = http.Get("http://" + server.Addr.String()) + require.NoError(t, err) +} + +func TestStartNewServer_UDS_ListenError(t *testing.T) { + invalidSocketPath := filepath.Join("/nonexistent", "socket.sock") + + mockShutdownHandler := newMockShutdownHandler(t) + handler := mocks.NewNoOpHandler() + + _, err := StartServer(mockShutdownHandler, handler, &UnixAddress{ + Path: invalidSocketPath, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), invalidSocketPath) +} + +func TestStartNewServe_TCP_ListenError(t *testing.T) { + eaAPIAddrPort := netip.MustParseAddrPort("1.1.1.1:49275") + + mockShutdownHandler := newMockShutdownHandler(t) + handler := mocks.NewNoOpHandler() + + _, err := StartServer(mockShutdownHandler, handler, &TCPAddress{eaAPIAddrPort}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "1.1.1.1:49275") +} diff --git a/internal/lambda-managed-instances/servicelogs/logger.go b/internal/lambda-managed-instances/servicelogs/logger.go new file mode 100644 index 0000000..a5b2450 --- /dev/null +++ b/internal/lambda-managed-instances/servicelogs/logger.go @@ -0,0 +1,54 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package servicelogs + +import ( + "io" + "time" +) + +type Logger interface { + io.Closer + Log(op Operation, opStart time.Time, props []Property, dims []Dimension, metrics []Metric) +} + +type Operation string + +const ( + InitOp Operation = "Init" + InvokeOp Operation = "Invoke" + ShutdownOp Operation = "Shutdown" +) + +type Tuple struct { + Name string + Value string +} + +type ( + Dimension = Tuple + Property = Tuple +) + +type MetricType uint8 + +const ( + CounterType MetricType = iota + TimerType +) + +type Metric struct { + Type MetricType + Key string + Value float64 + Dims []Dimension +} + +func Counter(name string, value float64, dims ...Dimension) Metric { + return Metric{Type: CounterType, Key: name, Value: value, Dims: dims} +} + +func Timer(name string, duration time.Duration, dims ...Dimension) Metric { + return Metric{Type: TimerType, Key: name, Value: float64(duration.Microseconds()), Dims: dims} +} diff --git a/internal/lambda-managed-instances/servicelogs/mock_logger.go b/internal/lambda-managed-instances/servicelogs/mock_logger.go new file mode 100644 index 0000000..be881af --- /dev/null +++ b/internal/lambda-managed-instances/servicelogs/mock_logger.go @@ -0,0 +1,47 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package servicelogs + +import ( + time "time" + + mock "github.com/stretchr/testify/mock" +) + +type MockLogger struct { + mock.Mock +} + +func (_m *MockLogger) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockLogger) Log(op Operation, opStart time.Time, props []Tuple, dims []Tuple, metrics []Metric) { + _m.Called(op, opStart, props, dims, metrics) +} + +func NewMockLogger(t interface { + mock.TestingT + Cleanup(func()) +}) *MockLogger { + mock := &MockLogger{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/supervisor/local/process.go b/internal/lambda-managed-instances/supervisor/local/process.go new file mode 100644 index 0000000..ec0bc0c --- /dev/null +++ b/internal/lambda-managed-instances/supervisor/local/process.go @@ -0,0 +1,324 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package local + +import ( + "context" + "errors" + "fmt" + "log/slog" + "os" + "os/exec" + "strconv" + "sync" + "syscall" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapidcore/env" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/supervisor/model" +) + +var _ model.ProcessSupervisorClient = (*ProcessSupervisor)(nil) + +type process struct { + pid int + + termination chan struct{} +} + +type ProcessSupervisor struct { + events chan model.Event + processMapLock sync.Mutex + credential *syscall.Credential + processMap map[string]process + freezeThawCycleStart time.Time + msgLogFile *os.File + lowerPriorities bool +} + +type ProcessSupervisorOption func(ls *ProcessSupervisor) + +func MsgLogFile(file *os.File) ProcessSupervisorOption { + return func(ls *ProcessSupervisor) { + ls.msgLogFile = file + } +} + +func WithProcessCredential(credential *syscall.Credential) ProcessSupervisorOption { + return func(ls *ProcessSupervisor) { + ls.credential = credential + } +} + +func WithLowerPriorities(lowerPriorities bool) ProcessSupervisorOption { + return func(ls *ProcessSupervisor) { + ls.lowerPriorities = lowerPriorities + } +} + +func NewProcessSupervisor(opts ...ProcessSupervisorOption) *ProcessSupervisor { + ls := &ProcessSupervisor{ + events: make(chan model.Event), + processMap: make(map[string]process), + + credential: &syscall.Credential{ + Uid: uint32(993), + Gid: uint32(990), + }, + lowerPriorities: true, + } + + for _, option := range opts { + option(ls) + } + + return ls +} + +func (*ProcessSupervisor) Exit(ctx context.Context) {} + +func (s *ProcessSupervisor) Exec(ctx context.Context, req *model.ExecRequest) error { + + command := exec.Command(req.Path, req.Args...) + if req.Env != nil { + command.Env = env.MapToKVPairStrings(*req.Env) + } + + if req.Cwd != nil && *req.Cwd != "" { + command.Dir = *req.Cwd + } + + for _, format := range req.Logging.Managed.Formats { + if format == model.MessageBasedManagedLogging { + if s.msgLogFile == nil { + return errors.New("invalid message logging setup: could not find message log fd") + } + + command.ExtraFiles = []*os.File{s.msgLogFile} + command.Env = append(command.Env, "_LAMBDA_TELEMETRY_LOG_FD=3") + } + } + + command.Stdout = req.StdoutWriter + command.Stderr = req.StderrWriter + + command.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + Credential: s.credential, + } + + err := command.Start() + if err != nil { + return fmt.Errorf("failed to start command: %w", err) + } + + slog.Info("LocalProcessSupervisor.Exec", + "pid", strconv.Itoa(command.Process.Pid), + "path", command.Path, + "args", command.Args) + + pid := command.Process.Pid + termination := make(chan struct{}) + s.processMapLock.Lock() + s.processMap[req.Name] = process{ + pid: pid, + termination: termination, + } + s.processMapLock.Unlock() + + s.freezeThawCycleStart = time.Now() + + if err := s.setProcessGroupPriorities(pid); err != nil { + return fmt.Errorf("could not set process priorities: %w", err) + } + + go func() { + err = command.Wait() + + close(termination) + + var cell int32 + var exitStatus *int32 + var signo *int32 + var exitErr *exec.ExitError + var cause model.ProcessTerminationCause + + if err == nil { + + exitStatus = &cell + cause = model.Exited + } else if errors.As(err, &exitErr) { + if status, ok := exitErr.Sys().(syscall.WaitStatus); ok { + if code := status.ExitStatus(); code >= 0 { + + cell = int32(code) + exitStatus = &cell + cause = model.Exited + } else { + + cell = int32(status.Signal()) + signo = &cell + cause = model.Signaled + } + } + } + + if signo == nil && exitStatus == nil { + slog.Error("Cannot convert process exit status to unix WaitStatus. This is unexpected. Assuming ExitStatus 1") + cell = 1 + exitStatus = &cell + cause = model.Exited + } + + s.events <- model.Event{ + Time: time.Now().UnixMilli(), + Event: model.EventData{ + EvType: model.ProcessTerminationType, + Name: req.Name, + Cause: cause, + Signo: signo, + ExitStatus: exitStatus, + }, + } + }() + + return nil +} + +func (s *ProcessSupervisor) StopAll(ctx context.Context, deadline time.Time) error { + + s.processMapLock.Lock() + defer s.processMapLock.Unlock() + + nprocs := len(s.processMap) + + successes := make(chan struct{}) + errors := make(chan error) + for name, proc := range s.processMap { + go func(n string, p process) { + slog.Debug("Killing", "name", n) + err := kill(p, n, deadline) + if err != nil { + errors <- err + } else { + successes <- struct{}{} + } + }(name, proc) + } + + var err error + for i := 0; i < nprocs; i++ { + select { + case <-successes: + case e := <-errors: + if err == nil { + err = fmt.Errorf("shutdown failed: %s", e.Error()) + } + } + } + + s.processMap = make(map[string]process) + + return err +} + +func kill(p process, name string, deadline time.Time) error { + + select { + + case <-p.termination: + slog.Debug("Process already terminated", "name", name) + return nil + default: + slog.Info("Sending SIGKILL to process", "name", name, "pid", p.pid) + } + + if (time.Since(deadline)) > 0 { + return fmt.Errorf("invalid timeout while killing %s", name) + } + + pgid, err := syscall.Getpgid(p.pid) + + if err == nil { + + _ = syscall.Kill(-pgid, syscall.SIGKILL) + } else { + _ = syscall.Kill(p.pid, syscall.SIGKILL) + } + + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + select { + case <-p.termination: + return nil + case <-ctx.Done(): + return fmt.Errorf("timed out while trying to SIGKILL %s", name) + } +} + +func (s *ProcessSupervisor) Kill(ctx context.Context, req *model.KillRequest) error { + s.processMapLock.Lock() + process, ok := s.processMap[req.Name] + s.processMapLock.Unlock() + if !ok { + msg := "Unknown process" + return &model.SupervisorError{ + SourceErr: model.ErrorSourceClient, + ReasonErr: "ProcessNotFound", + CauseErr: msg, + } + } + + return kill(process, req.Name, req.Deadline) +} + +func (s *ProcessSupervisor) Terminate(ctx context.Context, req *model.TerminateRequest) error { + s.processMapLock.Lock() + process, ok := s.processMap[req.Name] + pid := process.pid + s.processMapLock.Unlock() + if !ok { + msg := "Unknown process" + err := &model.SupervisorError{ + SourceErr: model.ErrorSourceClient, + ReasonErr: "ProcessNotFound", + CauseErr: msg, + } + slog.Error("Process not found in local supervisor map", "name", req.Name, "err", err) + return err + } + + pgid, err := syscall.Getpgid(pid) + + if err == nil { + + _ = syscall.Kill(-pgid, syscall.SIGTERM) + } else { + _ = syscall.Kill(pid, syscall.SIGTERM) + } + + return nil +} + +func (s *ProcessSupervisor) Events(ctx context.Context) (<-chan model.Event, error) { + return s.events, nil +} + +func (s *ProcessSupervisor) setProcessGroupPriorities(pid int) error { + if !s.lowerPriorities { + return nil + } + + pgid, err := syscall.Getpgid(pid) + if err != nil { + return fmt.Errorf("could not get pgid for pid %d: %w", pid, err) + } + + if err := syscall.Setpriority(syscall.PRIO_PGRP, pgid, 10); err != nil { + return fmt.Errorf("failed to set nice score for %d: %w", pgid, err) + } + + return nil +} diff --git a/internal/lambda-managed-instances/supervisor/local/process_test.go b/internal/lambda-managed-instances/supervisor/local/process_test.go new file mode 100644 index 0000000..bd1f10f --- /dev/null +++ b/internal/lambda-managed-instances/supervisor/local/process_test.go @@ -0,0 +1,228 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package local + +import ( + "context" + "errors" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/supervisor/model" +) + +func TestRuntimeExec(t *testing.T) { + supv := NewProcessSupervisor(WithProcessCredential(nil), WithLowerPriorities(false)) + err := supv.Exec(context.Background(), &model.ExecRequest{ + Name: "agent", + Path: "/bin/bash", + }) + + assert.Nil(t, err) +} + +func TestInvalidRuntimeExec(t *testing.T) { + supv := NewProcessSupervisor() + err := supv.Exec(context.Background(), &model.ExecRequest{ + Name: "agent", + Path: "/bin/none", + }) + + require.Error(t, err) +} + +func TestEvents(t *testing.T) { + supv := NewProcessSupervisor(WithProcessCredential(nil), WithLowerPriorities(false)) + sync := make(chan struct{}) + go func() { + eventCh, err := supv.Events(context.Background()) + require.NoError(t, err) + + evt, ok := <-eventCh + require.True(t, ok) + termination := evt.Event.ProcessTerminated() + require.NotNil(t, termination) + assert.Equal(t, "agent", termination.Name) + sync <- struct{}{} + }() + + err := supv.Exec(context.Background(), &model.ExecRequest{ + Name: "agent", + Path: "/bin/bash", + }) + require.NoError(t, err) + <-sync +} + +func TestTerminate(t *testing.T) { + supv := NewProcessSupervisor(WithProcessCredential(nil), WithLowerPriorities(false)) + err := supv.Exec(context.Background(), &model.ExecRequest{ + Name: "agent", + Path: "/bin/bash", + Args: []string{"-c", "sleep 10s"}, + }) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + err = supv.Terminate(context.Background(), &model.TerminateRequest{ + Name: "agent", + }) + require.NoError(t, err) + + eventCh, err := supv.Events(context.Background()) + require.NoError(t, err) + ev := <-eventCh + require.NotNil(t, ev.Event.ProcessTerminated()) + +} + +func TestStopAll(t *testing.T) { + supv := NewProcessSupervisor(WithProcessCredential(nil), WithLowerPriorities(false)) + + err := supv.Exec(context.Background(), &model.ExecRequest{ + Name: "agent1", + Path: "/bin/bash", + Args: []string{"-c", "sleep 10s"}, + }) + require.NoError(t, err) + + err = supv.Exec(context.Background(), &model.ExecRequest{ + Name: "agent2", + Path: "/bin/bash", + Args: []string{"-c", "sleep 10s"}, + }) + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = supv.StopAll(context.Background(), time.Now().Add(1*time.Second)) + require.NoError(t, err) +} + +func TestTerminateExited(t *testing.T) { + supv := NewProcessSupervisor(WithProcessCredential(nil), WithLowerPriorities(false)) + err := supv.Exec(context.Background(), &model.ExecRequest{ + Name: "agent", + Path: "/bin/bash", + }) + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + err = supv.Terminate(context.Background(), &model.TerminateRequest{ + Name: "agent", + }) + require.NoError(t, err) +} + +func TestKill(t *testing.T) { + supv := NewProcessSupervisor(WithProcessCredential(nil), WithLowerPriorities(false)) + err := supv.Exec(context.Background(), &model.ExecRequest{ + Name: "agent", + Path: "/bin/bash", + Args: []string{"-c", "sleep 10s"}, + }) + require.NoError(t, err) + err = supv.Kill(context.Background(), &model.KillRequest{ + Name: "agent", + Deadline: time.Now().Add(time.Second), + }) + require.NoError(t, err) + timer := time.NewTimer(50 * time.Millisecond) + eventCh, err := supv.Events(context.Background()) + require.NoError(t, err) + + select { + case _, ok := <-eventCh: + assert.True(t, ok) + case <-timer.C: + require.Fail(t, "Process should have exited by the time kill returns") + } +} + +func TestKillExited(t *testing.T) { + supv := NewProcessSupervisor(WithProcessCredential(nil), WithLowerPriorities(false)) + err := supv.Exec(context.Background(), &model.ExecRequest{ + Name: "agent", + Path: "/bin/bash", + }) + require.NoError(t, err) + + eventCh, err := supv.Events(context.Background()) + require.NoError(t, err) + <-eventCh + err = supv.Kill(context.Background(), &model.KillRequest{ + Name: "agent", + Deadline: time.Now().Add(time.Second), + }) + require.NoError(t, err, "Kill should succeed for exited processes") +} + +func TestKillUnknown(t *testing.T) { + supv := NewProcessSupervisor() + err := supv.Kill(context.Background(), &model.KillRequest{ + Name: "unknown", + Deadline: time.Now().Add(time.Second), + }) + require.Error(t, err) + var supvError *model.SupervisorError + assert.True(t, errors.As(err, &supvError)) + assert.Equal(t, supvError.Reason(), "ProcessNotFound") +} + +func TestTerminateUnknown(t *testing.T) { + supv := NewProcessSupervisor() + err := supv.Terminate(context.Background(), &model.TerminateRequest{ + Name: "unknown", + }) + require.Error(t, err) + var supvError *model.SupervisorError + assert.True(t, errors.As(err, &supvError)) + assert.Equal(t, supvError.Reason(), "ProcessNotFound") +} + +func TestEventsChannelShouldReturnEventsForRuntime(t *testing.T) { + supv := NewProcessSupervisor(WithProcessCredential(nil), WithLowerPriorities(false)) + ch, err := supv.Events(context.Background()) + assert.NoError(t, err) + err = supv.Exec(context.Background(), &model.ExecRequest{ + Path: "sleep", + Args: []string{"0.001s"}, + }) + assert.NoError(t, err) + + timeout, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + select { + case <-ch: + case <-timeout.Done(): + t.Error("We should get a message from the sleep binary") + } +} + +func TestTerminateCheckStatus(t *testing.T) { + supv := NewProcessSupervisor(WithProcessCredential(nil), WithLowerPriorities(false)) + err := supv.Exec(context.Background(), &model.ExecRequest{ + Name: "agent", + Path: "/bin/bash", + Args: []string{"-c", "sleep 10s"}, + }) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + err = supv.Terminate(context.Background(), &model.TerminateRequest{ + Name: "agent", + }) + require.NoError(t, err) + + eventCh, err := supv.Events(context.Background()) + require.NoError(t, err) + ev := <-eventCh + require.NotNil(t, ev.Event.ProcessTerminated()) + term := *ev.Event.ProcessTerminated() + require.Nil(t, term.Exited()) + require.NotNil(t, term.Signaled()) + require.EqualValues(t, syscall.SIGTERM, *term.Signo) +} diff --git a/internal/lambda-managed-instances/supervisor/model/mock_lock_hard_error.go b/internal/lambda-managed-instances/supervisor/model/mock_lock_hard_error.go new file mode 100644 index 0000000..3d0e95c --- /dev/null +++ b/internal/lambda-managed-instances/supervisor/model/mock_lock_hard_error.go @@ -0,0 +1,90 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import mock "github.com/stretchr/testify/mock" + +type MockLockHardError struct { + mock.Mock +} + +func (_m *MockLockHardError) Cause() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Cause") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockLockHardError) HookName() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for HookName") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockLockHardError) Reason() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Reason") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func (_m *MockLockHardError) Source() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Source") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +func NewMockLockHardError(t interface { + mock.TestingT + Cleanup(func()) +}) *MockLockHardError { + mock := &MockLockHardError{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/supervisor/model/mock_process_supervisor.go b/internal/lambda-managed-instances/supervisor/model/mock_process_supervisor.go new file mode 100644 index 0000000..6123b8b --- /dev/null +++ b/internal/lambda-managed-instances/supervisor/model/mock_process_supervisor.go @@ -0,0 +1,106 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +type MockProcessSupervisor struct { + mock.Mock +} + +func (_m *MockProcessSupervisor) Events(_a0 context.Context) (<-chan Event, error) { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Events") + } + + var r0 <-chan Event + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (<-chan Event, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(context.Context) <-chan Event); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan Event) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +func (_m *MockProcessSupervisor) Exec(_a0 context.Context, _a1 *ExecRequest) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Exec") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *ExecRequest) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockProcessSupervisor) Kill(_a0 context.Context, _a1 *KillRequest) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Kill") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *KillRequest) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockProcessSupervisor) Terminate(_a0 context.Context, _a1 *TerminateRequest) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Terminate") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *TerminateRequest) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func NewMockProcessSupervisor(t interface { + mock.TestingT + Cleanup(func()) +}) *MockProcessSupervisor { + mock := &MockProcessSupervisor{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/supervisor/model/mock_process_supervisor_client.go b/internal/lambda-managed-instances/supervisor/model/mock_process_supervisor_client.go new file mode 100644 index 0000000..5a9b5be --- /dev/null +++ b/internal/lambda-managed-instances/supervisor/model/mock_process_supervisor_client.go @@ -0,0 +1,106 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +type MockProcessSupervisorClient struct { + mock.Mock +} + +func (_m *MockProcessSupervisorClient) Events(_a0 context.Context) (<-chan Event, error) { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Events") + } + + var r0 <-chan Event + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (<-chan Event, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(context.Context) <-chan Event); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan Event) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +func (_m *MockProcessSupervisorClient) Exec(_a0 context.Context, _a1 *ExecRequest) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Exec") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *ExecRequest) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockProcessSupervisorClient) Kill(_a0 context.Context, _a1 *KillRequest) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Kill") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *KillRequest) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func (_m *MockProcessSupervisorClient) Terminate(_a0 context.Context, _a1 *TerminateRequest) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Terminate") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *TerminateRequest) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +func NewMockProcessSupervisorClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockProcessSupervisorClient { + mock := &MockProcessSupervisorClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/supervisor/model/process.go b/internal/lambda-managed-instances/supervisor/model/process.go new file mode 100644 index 0000000..3419500 --- /dev/null +++ b/internal/lambda-managed-instances/supervisor/model/process.go @@ -0,0 +1,230 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "context" + "fmt" + "io" + "strconv" + "syscall" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" +) + +type EventType string + +const ( + ProcessTerminationType EventType = "process_termination" + EventLossType EventType = "event_loss" +) + +type ProcessTerminationCause string + +const ( + Signaled ProcessTerminationCause = "signaled" + Exited ProcessTerminationCause = "exited" + OomKilled ProcessTerminationCause = "oom_killed" +) + +type ProcessSupervisor interface { + Exec(context.Context, *ExecRequest) error + Terminate(context.Context, *TerminateRequest) error + Kill(context.Context, *KillRequest) error + Events(context.Context) (<-chan Event, error) +} + +type ProcessSupervisorClient interface { + ProcessSupervisor +} + +type Event struct { + Time int64 `json:"timestamp_millis"` + Event EventData `json:"event"` +} + +func (e Event) String() string { + return fmt.Sprintf("{Time:%s Event:%s}", time.UnixMilli(e.Time).UTC().Format("2006-01-02T15:04:05.000Z"), e.Event) +} + +type EventData struct { + EvType EventType `json:"type"` + Name string `json:"name"` + Cause ProcessTerminationCause `json:"cause"` + Signo *int32 `json:"signo"` + ExitStatus *int32 `json:"exit_status"` + Size *uint64 `json:"size"` +} + +func (d EventData) String() string { + signo := "" + if d.Signo != nil { + signo = strconv.FormatInt(int64(*d.Signo), 10) + } + exitStatus := "" + if d.ExitStatus != nil { + exitStatus = strconv.FormatInt(int64(*d.ExitStatus), 10) + } + size := "" + if d.Size != nil { + size = strconv.FormatUint(*d.Size, 10) + } + return fmt.Sprintf("{EvType:%s Name:%s Cause:%s Signo:%s ExitStatus:%s Size:%s}", d.EvType, d.Name, d.Cause, signo, exitStatus, size) +} + +func (d EventData) ProcessTerminated() *ProcessTermination { + return &ProcessTermination{ + Name: d.Name, + Cause: d.Cause, + Signo: d.Signo, + ExitStatus: d.ExitStatus, + } +} + +type ProcessTermination struct { + Name string + Cause ProcessTerminationCause + Signo *int32 + ExitStatus *int32 +} + +func (t ProcessTermination) Signaled() *int32 { + if t.Cause != Signaled { + return nil + } + return t.Signo +} + +func (t ProcessTermination) Exited() *int32 { + if t.Cause != Exited { + return nil + } + return t.ExitStatus +} + +func (t ProcessTermination) ExitedWithZeroExitCode() bool { + if t.Cause != Exited { + return false + } + return t.ExitStatus != nil && *t.ExitStatus == 0 +} + +func (t ProcessTermination) OomKilled() bool { + return t.Cause == OomKilled +} + +func (t ProcessTermination) String() string { + if t.ExitStatus != nil { + return fmt.Sprintf("exit status %d", *t.ExitStatus) + } + if t.Signo != nil { + sig := syscall.Signal(*t.Signo) + return fmt.Sprintf("signal: %s", sig.String()) + } + + return "signal: killed" +} + +type ExecRequest struct { + Name string `json:"name"` + + Path string `json:"path"` + Args []string `json:"args,omitempty"` + + Cwd *string `json:"cwd,omitempty"` + Env *model.KVMap `json:"env,omitempty"` + Logging Logging `json:"log_config"` + StdoutWriter io.Writer `json:"-"` + StderrWriter io.Writer `json:"-"` +} + +type Logging struct { + Managed ManagedLogging `json:"managed"` +} + +type ManagedLogging struct { + Topic ManagedLoggingTopic `json:"topic"` + Formats []ManagedLoggingFormat `json:"formats"` +} + +type ManagedLoggingTopic string + +const ( + RuntimeManagedLoggingTopic ManagedLoggingTopic = "runtime" + RtExtensionManagedLoggingTopic ManagedLoggingTopic = "runtime_extension" +) + +type ManagedLoggingFormat string + +const ( + LineBasedManagedLogging ManagedLoggingFormat = "line" + MessageBasedManagedLogging ManagedLoggingFormat = "message" +) + +type LockHardError interface { + Source() string + + Reason() string + + Cause() string + + HookName() string +} + +var _ error = (*SupervisorError)(nil) + +type ErrorExitCode uint8 + +type SupervisorError struct { + SourceErr ErrorSource `json:"source"` + HookNameErr string `json:"hook_name"` + ExitCodeErr ErrorExitCode `json:"exit_code"` + ReasonErr string `json:"reason"` + CauseErr string `json:"cause"` +} + +type ErrorSource string + +const ( + ErrorSourceClient ErrorSource = "Client" + ErrorSourceServer ErrorSource = "Server" + ErrorSourceFunction ErrorSource = "Customer" + ErrorSourceHook ErrorSource = "Hook" +) + +type ErrorReason string + +func (l *SupervisorError) Reason() string { + return l.ReasonErr +} + +func (l *SupervisorError) Source() ErrorSource { + return l.SourceErr +} + +func (l *SupervisorError) Cause() string { + return l.CauseErr +} + +func (l *SupervisorError) HookName() string { + return l.HookNameErr +} + +func (l *SupervisorError) ExitCode() ErrorExitCode { + return l.ExitCodeErr +} + +func (l *SupervisorError) Error() string { + return string(l.ReasonErr) +} + +type TerminateRequest struct { + Name string `json:"name"` +} + +type KillRequest struct { + Name string `json:"name"` + Deadline time.Time `json:"deadline"` +} diff --git a/internal/lambda-managed-instances/supervisor/model/process_test.go b/internal/lambda-managed-instances/supervisor/model/process_test.go new file mode 100644 index 0000000..850f16d --- /dev/null +++ b/internal/lambda-managed-instances/supervisor/model/process_test.go @@ -0,0 +1,186 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_KillDeadlineIsMarshalledIntoRFC3339(t *testing.T) { + deadline, err := time.Parse(time.RFC3339, "2022-12-21T10:00:00Z") + if err != nil { + t.Error(err) + } + k := KillRequest{ + Name: "", + Deadline: deadline, + } + bytes, err := json.Marshal(k) + if err != nil { + t.Error(err) + } + exepected := `{"name":"","deadline":"2022-12-21T10:00:00Z"}` + if string(bytes) != exepected { + t.Errorf("error in marshaling `KillRequest` it does not match the expected string (Expected(%q) != Got(%q))", exepected, string(bytes)) + } +} + +func TestEventToString(t *testing.T) { + signo := int32(9) + event := Event{ + Time: 1725043643030, + Event: EventData{ + EvType: ProcessTerminationType, + Name: "runtime-1", + Cause: "signaled", + Signo: &signo, + }, + } + + assert.Equal(t, "{Time:2024-08-30T18:47:23.030Z Event:{EvType:process_termination Name:runtime-1 Cause:signaled Signo:9 ExitStatus: Size:}}", event.String()) +} + +func TestEventsDerserilize(t *testing.T) { + name := "runtime-1" + signo := int32(9) + exitStatus := int32(0) + + tests := map[string]struct { + eventsJSON string + expectedEvent Event + }{ + "signaled": { + eventsJSON: `{"timestamp_millis":1686735425063,"event":{"type":"process_termination","name":"runtime-1","cause":"signaled","signo":9}}`, + expectedEvent: Event{ + Time: 1686735425063, + Event: EventData{ + EvType: ProcessTerminationType, + Name: name, + Cause: "signaled", + Signo: &signo, + }, + }, + }, + "exited": { + eventsJSON: `{"timestamp_millis":1686735425063,"event":{"type":"process_termination","name":"runtime-1","cause":"exited","exit_status":0}}`, + expectedEvent: Event{ + Time: 1686735425063, + Event: EventData{ + EvType: ProcessTerminationType, + Name: name, + Cause: "exited", + ExitStatus: &exitStatus, + }, + }, + }, + "oom_killed": { + eventsJSON: `{"timestamp_millis":1686735425063,"event":{"type":"process_termination","name":"runtime-1","cause":"oom_killed"}}`, + expectedEvent: Event{ + Time: 1686735425063, + Event: EventData{ + EvType: ProcessTerminationType, + Name: name, + Cause: "oom_killed", + }, + }, + }, + } + + for name, data := range tests { + t.Run(name, func(t *testing.T) { + var eventStruct Event + require.NoError(t, json.Unmarshal([]byte(data.eventsJSON), &eventStruct)) + assert.EqualValues(t, eventStruct, data.expectedEvent) + }) + } +} + +func TestOomKilledEvent(t *testing.T) { + name := "runtime-1" + cause := OomKilled + + ev := Event{ + Time: 1686735425063, + Event: EventData{ + EvType: ProcessTerminationType, + Name: name, + Cause: cause, + }, + } + + require.NotNil(t, ev.Event.ProcessTerminated()) + term := *ev.Event.ProcessTerminated() + require.Nil(t, term.Exited()) + require.Nil(t, term.Signaled()) + require.True(t, term.OomKilled()) +} + +func TestSupervisorErrorDerserilize(t *testing.T) { + tests := map[string]struct { + eventsJSON string + expectedError SupervisorError + expectedReason string + expectedHookName string + expectedSource string + }{ + "hook_err": { + eventsJSON: `{ + "source": "Hook", + "hook_name": "UnzipTask", + "reason":"HookFailed", + "cause": "whatever" + }`, + expectedError: SupervisorError{ + SourceErr: ErrorSourceHook, + HookNameErr: "UnzipTask", + ReasonErr: "HookFailed", + CauseErr: "whatever", + }, + expectedReason: "HookFailed", + expectedHookName: "UnzipTask", + expectedSource: "Hook", + }, + "client_err": { + eventsJSON: `{"source": "Client","reason":"DomainStartFailed","cause": "whatever"}`, + expectedError: SupervisorError{ + SourceErr: ErrorSourceClient, + ReasonErr: "DomainStartFailed", + CauseErr: "whatever", + }, + expectedReason: "DomainStartFailed", + expectedHookName: "", + expectedSource: "Client", + }, + } + + for name, data := range tests { + t.Run(name, func(t *testing.T) { + var eventStruct SupervisorError + require.NoError(t, json.Unmarshal([]byte(data.eventsJSON), &eventStruct)) + assert.EqualValues(t, eventStruct, data.expectedError) + assert.EqualValues(t, eventStruct.Reason(), data.expectedReason) + assert.EqualValues(t, eventStruct.HookName(), data.expectedHookName) + assert.EqualValues(t, eventStruct.Source(), data.expectedSource) + }) + } +} + +func TestSupervisorError(t *testing.T) { + err := SupervisorError{ + SourceErr: ErrorSourceHook, + ReasonErr: "DomainStartFailed", + CauseErr: "whatever", + HookNameErr: "hook1", + ExitCodeErr: 3, + } + assert.EqualValues(t, err.Cause(), "whatever") + assert.EqualValues(t, err.Error(), "DomainStartFailed") + assert.EqualValues(t, err.ExitCode(), 3) + assert.EqualValues(t, err.HookNameErr, "hook1") +} diff --git a/internal/lambda-managed-instances/telemetry/constants.go b/internal/lambda-managed-instances/telemetry/constants.go new file mode 100644 index 0000000..d276f6c --- /dev/null +++ b/internal/lambda-managed-instances/telemetry/constants.go @@ -0,0 +1,17 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import "errors" + +const TimeFormat = "2006-01-02T15:04:05.000Z" + +const ( + SubscribeSuccess = "logs_api_subscribe_success" + SubscribeClientErr = "logs_api_subscribe_client_err" + SubscribeServerErr = "logs_api_subscribe_server_err" + NumSubscribers = "logs_api_num_subscribers" +) + +var ErrTelemetryServiceOff = errors.New("ErrTelemetryServiceOff") diff --git a/internal/lambda-managed-instances/telemetry/events.go b/internal/lambda-managed-instances/telemetry/events.go new file mode 100644 index 0000000..909d948 --- /dev/null +++ b/internal/lambda-managed-instances/telemetry/events.go @@ -0,0 +1,142 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "log/slog" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func SendInitStartLogEvent( + eventsAPI interop.EventsAPI, + functionMetadata model.FunctionMetadata, + logStreamName string, + phase interop.LifecyclePhase, +) { + initPhase, err := initPhaseFromLifecyclePhase(phase) + if err != nil { + slog.Error("failed to convert lifecycle phase into init phase", "err", err) + return + } + + initStartData := interop.InitStartData{ + InitializationType: interop.InitializationType, + RuntimeVersion: functionMetadata.RuntimeInfo.Version, + RuntimeVersionArn: functionMetadata.RuntimeInfo.Arn, + FunctionName: functionMetadata.FunctionName, + FunctionVersion: functionMetadata.FunctionVersion, + + InstanceID: logStreamName, + InstanceMaxMemory: functionMetadata.MemorySizeBytes, + Phase: initPhase, + } + slog.Debug("Init start data", "data", initStartData.String()) + + if err := eventsAPI.SendInitStart(initStartData); err != nil { + slog.Error("Failed to send Init START", "err", err) + } +} + +func prepareInitRuntimeDoneData(appCtx appctx.ApplicationContext, initError model.AppError, phase interop.LifecyclePhase) interop.InitRuntimeDoneData { + initPhase, _ := initPhaseFromLifecyclePhase(phase) + + status := interop.BuildStatusFromError(initError) + + initRuntimeDoneData := interop.InitRuntimeDoneData{ + InitializationType: interop.InitializationType, + Status: status, + Phase: initPhase, + ErrorType: getFirstFatalError(appCtx, status), + } + return initRuntimeDoneData +} + +func SendInitRuntimeDoneLogEvent( + eventsAPI interop.EventsAPI, + appCtx appctx.ApplicationContext, + phase interop.LifecyclePhase, + initError model.AppError, +) { + initRuntimeDoneData := prepareInitRuntimeDoneData(appCtx, initError, phase) + + slog.Debug("Init runtime done data", "data", initRuntimeDoneData.String()) + + if err := eventsAPI.SendInitRuntimeDone(initRuntimeDoneData); err != nil { + slog.Error("Failed to send Init RTDONE event", "err", err) + } +} + +func SendInitReportLogEvent( + eventsAPI interop.EventsAPI, + appCtx appctx.ApplicationContext, + initDuration time.Duration, + phase interop.LifecyclePhase, + initError model.AppError, +) { + initPhase, err := initPhaseFromLifecyclePhase(phase) + if err != nil { + slog.Error("failed to convert lifecycle phase into init phase", "err", err) + return + } + + status := interop.BuildStatusFromError(initError) + + initReportData := interop.InitReportData{ + InitializationType: interop.InitializationType, + Metrics: interop.InitReportMetrics{ + DurationMs: calculateDurationInMillis(initDuration), + }, + Phase: initPhase, + Status: status, + ErrorType: getFirstFatalError(appCtx, status), + } + slog.Debug("Init report data", "data", initReportData.String()) + + if err = eventsAPI.SendInitReport(initReportData); err != nil { + slog.Error("Failed to send INIT REPORT", "err", err) + } +} + +func SendAgentsInitStatus(eventsAPI interop.EventsAPI, agents []core.AgentInfo) { + for _, agent := range agents { + extensionInitData := interop.ExtensionInitData{ + AgentName: agent.Name, + State: agent.State, + ErrorType: string(agent.ErrorType), + Subscriptions: agent.Subscriptions, + } + if err := eventsAPI.SendExtensionInit(extensionInitData); err != nil { + slog.Error("Failed to send extension init", "err", err) + } + } +} + +func SendImageError(eventsAPI interop.EventsAPI, execError model.RuntimeExecError, execConfig model.RuntimeExec) { + eventsAPI.SendImageError(interop.ImageErrorLogData{ + ExecError: execError, + ExecConfig: execConfig, + }) +} + +func getFirstFatalError(appCtx appctx.ApplicationContext, status string) *string { + if status == interop.Success { + return nil + } + + customerError, found := appctx.LoadFirstFatalError(appCtx) + var errorType model.ErrorType + if !found { + + errorType = model.ErrorRuntimeUnknown + } else { + errorType = customerError.ErrorType() + } + stringifiedError := string(errorType) + return &stringifiedError +} diff --git a/internal/lambda-managed-instances/telemetry/events_api.go b/internal/lambda-managed-instances/telemetry/events_api.go new file mode 100644 index 0000000..194dd47 --- /dev/null +++ b/internal/lambda-managed-instances/telemetry/events_api.go @@ -0,0 +1,90 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils/invariant" +) + +const ( + InitInsideInitPhase interop.InitPhase = "init" + InitInsideInvokePhase interop.InitPhase = "invoke" +) + +func initPhaseFromLifecyclePhase(phase interop.LifecyclePhase) (interop.InitPhase, error) { + switch phase { + case interop.LifecyclePhaseInit: + return InitInsideInitPhase, nil + case interop.LifecyclePhaseInvoke: + return InitInsideInvokePhase, nil + default: + return interop.InitPhase(""), fmt.Errorf("unexpected lifecycle phase: %v", phase) + } +} + +func calculateDurationInMillis(duration time.Duration) float64 { + return float64(duration.Microseconds()) / 1000 +} + +type NoOpEventsAPI struct{} + +func (s *NoOpEventsAPI) SetCurrentRequestID(interop.InvokeID) {} + +func (s *NoOpEventsAPI) SendInitStart(interop.InitStartData) error { return nil } + +func (s *NoOpEventsAPI) SendInitRuntimeDone(interop.InitRuntimeDoneData) error { return nil } + +func (s *NoOpEventsAPI) SendInitReport(interop.InitReportData) error { return nil } + +func (s *NoOpEventsAPI) SendExtensionInit(interop.ExtensionInitData) error { return nil } + +func (s *NoOpEventsAPI) SendImageError(interop.ImageErrorLogData) {} + +func (s *NoOpEventsAPI) SendInvokeStart(interop.InvokeStartData) error { return nil } + +func (s *NoOpEventsAPI) SendReport(interop.ReportData) error { return nil } + +func (s *NoOpEventsAPI) SendInternalXRayErrorCause(interop.InternalXRayErrorCauseData) error { + return nil +} + +func (s *NoOpEventsAPI) Flush() { + +} + +type Event struct { + Time string `json:"time"` + Type string `json:"type"` + Record json.RawMessage `json:"record"` +} + +func FormatImageError(errLog interop.ImageErrorLogData) string { + switch errLog.ExecError.Type { + case model.InvalidTaskConfig: + return fmt.Sprintf("IMAGE\tInvalid task config: %s", errLog.ExecError.Err) + case model.InvalidEntrypoint: + return fmt.Sprintf("IMAGE\tLaunch error: %s\tEntrypoint: %s\tCmd: [%s]\tWorkingDir: [%s]", + errLog.ExecError.Err, + errLog.ExecConfig.Cmd[0], + strings.Join(errLog.ExecConfig.Cmd[1:], ","), + errLog.ExecConfig.WorkingDir) + case model.InvalidWorkingDir: + return fmt.Sprintf("IMAGE\tLaunch error: %s\tEntrypoint: %s\tCmd: [%s]\tWorkingDir: [%s]", + errLog.ExecError.Err, + errLog.ExecConfig.Cmd[0], + strings.Join(errLog.ExecConfig.Cmd[1:], ","), + errLog.ExecConfig.WorkingDir) + default: + + invariant.Violatef("Invalid runtime exec error type: %d", int(errLog.ExecError.Type)) + return "" + } +} diff --git a/internal/lambda-managed-instances/telemetry/events_api_test.go b/internal/lambda-managed-instances/telemetry/events_api_test.go new file mode 100644 index 0000000..b09c30d --- /dev/null +++ b/internal/lambda-managed-instances/telemetry/events_api_test.go @@ -0,0 +1,32 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" +) + +func TestPrepareInitRuntimeDoneDataOnCustomerError(t *testing.T) { + appCtx := appctx.NewApplicationContext() + err := interop.ErrPlatformError + customerErr := model.WrapErrorIntoCustomerFatalError(err, model.ErrorRuntimeInit) + + appctx.StoreFirstFatalError(appCtx, customerErr) + + stringifiedError := string(model.ErrorRuntimeInit) + ActualInitRuntimeDoneData := prepareInitRuntimeDoneData(appCtx, customerErr, interop.LifecyclePhaseInit) + expectedInitRuntimeDoneData := interop.InitRuntimeDoneData{ + InitializationType: "lambda-managed-instances", + Status: "error", + Phase: InitInsideInitPhase, + ErrorType: &stringifiedError, + } + assert.Equal(t, expectedInitRuntimeDoneData, ActualInitRuntimeDoneData) +} diff --git a/internal/lambda-managed-instances/telemetry/logs_egress_api.go b/internal/lambda-managed-instances/telemetry/logs_egress_api.go new file mode 100644 index 0000000..9ade336 --- /dev/null +++ b/internal/lambda-managed-instances/telemetry/logs_egress_api.go @@ -0,0 +1,28 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "io" + "os" +) + +type StdLogsEgressAPI interface { + GetExtensionSockets() (io.Writer, io.Writer, error) + GetRuntimeSockets() (io.Writer, io.Writer, error) +} + +type NoOpLogsEgressAPI struct{} + +func (s *NoOpLogsEgressAPI) GetExtensionSockets() (io.Writer, io.Writer, error) { + + return os.Stdout, os.Stdout, nil +} + +func (s *NoOpLogsEgressAPI) GetRuntimeSockets() (io.Writer, io.Writer, error) { + + return os.Stdout, os.Stdout, nil +} + +var _ StdLogsEgressAPI = (*NoOpLogsEgressAPI)(nil) diff --git a/internal/lambda-managed-instances/telemetry/logs_subscription_api.go b/internal/lambda-managed-instances/telemetry/logs_subscription_api.go new file mode 100644 index 0000000..46d6f35 --- /dev/null +++ b/internal/lambda-managed-instances/telemetry/logs_subscription_api.go @@ -0,0 +1,48 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "io" + "net/http" + "net/netip" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" +) + +type SubscriptionAPI interface { + Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) (resp []byte, status int, respHeaders map[string][]string, err error) + RecordCounterMetric(metricName string, count int) + FlushMetrics() interop.TelemetrySubscriptionMetrics + Clear() + TurnOff() + GetEndpointURL() string + GetServiceClosedErrorMessage() string + GetServiceClosedErrorType() string + Configure(passphrase string, addr netip.AddrPort) +} + +type NoOpSubscriptionAPI struct{} + +func (m *NoOpSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) ([]byte, int, map[string][]string, error) { + return []byte(`{}`), http.StatusOK, map[string][]string{}, nil +} + +func (m *NoOpSubscriptionAPI) RecordCounterMetric(metricName string, count int) {} + +func (m *NoOpSubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { + return interop.TelemetrySubscriptionMetrics{} +} + +func (m *NoOpSubscriptionAPI) Clear() {} + +func (m *NoOpSubscriptionAPI) TurnOff() {} + +func (m *NoOpSubscriptionAPI) GetEndpointURL() string { return "" } + +func (m *NoOpSubscriptionAPI) GetServiceClosedErrorMessage() string { return "" } + +func (m *NoOpSubscriptionAPI) GetServiceClosedErrorType() string { return "" } + +func (m *NoOpSubscriptionAPI) Configure(passphrase string, addr netip.AddrPort) {} diff --git a/internal/lambda-managed-instances/telemetry/xray/tracer.go b/internal/lambda-managed-instances/telemetry/xray/tracer.go new file mode 100644 index 0000000..9cd36e7 --- /dev/null +++ b/internal/lambda-managed-instances/telemetry/xray/tracer.go @@ -0,0 +1,109 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package xray + +import ( + "crypto/rand" + "fmt" + "strings" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + intmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils/invariant" +) + +const xRayNonSampled = "0" + +type tracingHeader struct { + rootId string + parentId string + sampled string + + lineage string +} + +func CreateTracingData(upstreamTraceId string, tracingMode intmodel.XrayTracingMode, segmentIDGenerator func() string) (downstreamTraceId string, tracingCtx *interop.TracingCtx) { + var segmentID string + + tracingHeader := parseTracingHeader(upstreamTraceId) + + switch tracingMode { + case intmodel.XRayTracingModeActive: + segmentID = segmentIDGenerator() + downstreamTraceId = newTraceID(tracingHeader.rootId, segmentID, tracingHeader.sampled, tracingHeader.lineage) + case intmodel.XRayTracingModePassThrough: + downstreamTraceId = upstreamTraceId + default: + invariant.Violatef("Unknown tracingMode: %v", tracingMode) + } + + tracingCtx = buildTracingCtx(tracingHeader, segmentID, upstreamTraceId) + + return downstreamTraceId, tracingCtx +} + +func buildTracingCtx(t tracingHeader, segmentId, upstreamTraceId string) *interop.TracingCtx { + if t.rootId == "" || t.sampled != rapidmodel.XRaySampled { + return nil + } + + return &interop.TracingCtx{ + SpanID: segmentId, + Type: rapidmodel.XRayTracingType, + Value: upstreamTraceId, + } +} + +func parseTracingHeader(trace string) tracingHeader { + var tracingHeader tracingHeader + + keyValuePairs := strings.Split(trace, ";") + for _, pair := range keyValuePairs { + var key, value string + keyValue := strings.Split(pair, "=") + if len(keyValue) == 2 { + key = keyValue[0] + value = keyValue[1] + } + switch key { + case "Root": + tracingHeader.rootId = value + case "Parent": + tracingHeader.parentId = value + case "Sampled": + tracingHeader.sampled = value + case "Lineage": + tracingHeader.lineage = value + } + } + return tracingHeader +} + +func newTraceID(root, parent, sample, lineage string) string { + if root == "" { + return "" + } + + parts := make([]string, 0, 4) + parts = append(parts, "Root="+root) + if parent != "" { + parts = append(parts, "Parent="+parent) + } + if sample == "" { + sample = xRayNonSampled + } + parts = append(parts, "Sampled="+sample) + if lineage != "" { + parts = append(parts, "Lineage="+lineage) + } + + return strings.Join(parts, ";") +} + +func GenerateSegmentID() string { + bytes := make([]byte, 8) + _, _ = rand.Read(bytes) + return fmt.Sprintf("%08x", bytes) +} diff --git a/internal/lambda-managed-instances/telemetry/xray/tracer_test.go b/internal/lambda-managed-instances/telemetry/xray/tracer_test.go new file mode 100644 index 0000000..e44877f --- /dev/null +++ b/internal/lambda-managed-instances/telemetry/xray/tracer_test.go @@ -0,0 +1,126 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package xray + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + intmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + rapidmodel "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" +) + +func mockSegmentIDGenerator() string { + return "12345678" +} + +func makeExpectedTracingCtx() *interop.TracingCtx { + return &interop.TracingCtx{ + SpanID: "", + Type: rapidmodel.XRayTracingType, + Value: "", + } +} + +func TestTracer(t *testing.T) { + tests := []struct { + name string + upstreamTraceId string + tracingMode intmodel.XrayTracingMode + segmentIDGenerator func() string + + expectedDownstreamId string + expectedTracingCtx *interop.TracingCtx + }{ + { + name: "Active", + upstreamTraceId: "Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535", + tracingMode: intmodel.XRayTracingModeActive, + segmentIDGenerator: mockSegmentIDGenerator, + expectedDownstreamId: "Root=root1;Parent=12345678;Sampled=1;Lineage=foo:1|bar:65535", + expectedTracingCtx: makeExpectedTracingCtx(), + }, + { + name: "Active_NoRoot", + upstreamTraceId: "Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535", + tracingMode: intmodel.XRayTracingModeActive, + segmentIDGenerator: mockSegmentIDGenerator, + expectedDownstreamId: "", + expectedTracingCtx: nil, + }, + { + name: "Active_NoParent", + upstreamTraceId: "Root=root1;Sampled=1;Lineage=foo:1|bar:65535", + tracingMode: intmodel.XRayTracingModeActive, + segmentIDGenerator: mockSegmentIDGenerator, + expectedDownstreamId: "Root=root1;Parent=12345678;Sampled=1;Lineage=foo:1|bar:65535", + expectedTracingCtx: makeExpectedTracingCtx(), + }, + { + name: "Active_NotSampled", + upstreamTraceId: "Root=root1;Parent=parent1;Lineage=foo:1|bar:65535", + tracingMode: intmodel.XRayTracingModeActive, + segmentIDGenerator: mockSegmentIDGenerator, + expectedDownstreamId: "Root=root1;Parent=12345678;Sampled=0;Lineage=foo:1|bar:65535", + expectedTracingCtx: nil, + }, + { + name: "Active_NoLineage", + upstreamTraceId: "Root=root1;Parent=parent1;Sampled=1", + tracingMode: intmodel.XRayTracingModeActive, + segmentIDGenerator: mockSegmentIDGenerator, + expectedDownstreamId: "Root=root1;Parent=12345678;Sampled=1", + expectedTracingCtx: makeExpectedTracingCtx(), + }, + { + name: "Active_UnorderedComponents", + upstreamTraceId: "Lineage=foo:1|bar:65535;Parent=parent1;Sampled=1;Root=root1", + tracingMode: intmodel.XRayTracingModeActive, + segmentIDGenerator: mockSegmentIDGenerator, + expectedDownstreamId: "Root=root1;Parent=12345678;Sampled=1;Lineage=foo:1|bar:65535", + expectedTracingCtx: makeExpectedTracingCtx(), + }, + { + name: "Active_EmptyTraceId", + upstreamTraceId: "", + tracingMode: intmodel.XRayTracingModeActive, + segmentIDGenerator: mockSegmentIDGenerator, + expectedDownstreamId: "", + expectedTracingCtx: nil, + }, + { + name: "Passthrough", + upstreamTraceId: "Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535", + tracingMode: intmodel.XRayTracingModePassThrough, + segmentIDGenerator: mockSegmentIDGenerator, + expectedDownstreamId: "Root=root1;Parent=parent1;Sampled=1;Lineage=foo:1|bar:65535", + expectedTracingCtx: makeExpectedTracingCtx(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + if tt.expectedTracingCtx != nil { + tt.expectedTracingCtx.Value = tt.upstreamTraceId + if tt.tracingMode == intmodel.XRayTracingModeActive { + tt.expectedTracingCtx.SpanID = tt.segmentIDGenerator() + } + } + + downstreamTraceId, tracingCtx := CreateTracingData(tt.upstreamTraceId, tt.tracingMode, tt.segmentIDGenerator) + + assert.Equal(t, tt.expectedDownstreamId, downstreamTraceId) + assert.Equal(t, tt.expectedTracingCtx, tracingCtx) + + if tt.expectedTracingCtx != nil && tt.tracingMode == intmodel.XRayTracingModeActive { + + parsedTraceId := parseTracingHeader(downstreamTraceId) + assert.Equal(t, parsedTraceId.parentId, tracingCtx.SpanID) + } + }) + } +} diff --git a/internal/lambda-managed-instances/testdata/agents/bash_true.sh b/internal/lambda-managed-instances/testdata/agents/bash_true.sh new file mode 100755 index 0000000..f1f641a --- /dev/null +++ b/internal/lambda-managed-instances/testdata/agents/bash_true.sh @@ -0,0 +1 @@ +#!/usr/bin/env bash diff --git a/internal/lambda-managed-instances/testdata/async_assertion_utils.go b/internal/lambda-managed-instances/testdata/async_assertion_utils.go new file mode 100644 index 0000000..5667c93 --- /dev/null +++ b/internal/lambda-managed-instances/testdata/async_assertion_utils.go @@ -0,0 +1,32 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package testdata + +import ( + "testing" + "time" +) + +func WaitForErrorWithTimeout(channel <-chan error, timeout time.Duration) error { + select { + case err := <-channel: + return err + case <-time.After(timeout): + return nil + } +} + +func Eventually(t *testing.T, testFunc func() (bool, error), pollingIntervalMultiple time.Duration, retries int) bool { + for try := 0; try < retries; try++ { + success, err := testFunc() + if success { + return true + } + if err != nil { + t.Logf("try %d: %v", try, err) + } + time.Sleep(time.Duration(try) * pollingIntervalMultiple) + } + return false +} diff --git a/internal/lambda-managed-instances/testdata/bash_function.sh b/internal/lambda-managed-instances/testdata/bash_function.sh new file mode 100755 index 0000000..c5c370b --- /dev/null +++ b/internal/lambda-managed-instances/testdata/bash_function.sh @@ -0,0 +1,7 @@ +function handler () { + EVENT_DATA=$1 + echo "$EVENT_DATA" 1>&2; + RESPONSE="Echoing request: '$EVENT_DATA'" + + echo $RESPONSE +} \ No newline at end of file diff --git a/internal/lambda-managed-instances/testdata/bash_runtime.sh b/internal/lambda-managed-instances/testdata/bash_runtime.sh new file mode 100755 index 0000000..f568d07 --- /dev/null +++ b/internal/lambda-managed-instances/testdata/bash_runtime.sh @@ -0,0 +1,22 @@ +#!/bin/sh + +set -euo pipefail + +# Initialization - load function handler +source $LAMBDA_TASK_ROOT/"bash_function.sh" + +# Processing +while true +do + HEADERS="$(mktemp)" + # Get an event + EVENT_DATA=$(curl -sS -LD "$HEADERS" -X GET "http://${AWS_LAMBDA_RUNTIME_API}/2018-06-01/runtime/invocation/next") + REQUEST_ID=$(grep -Fi Lambda-Runtime-Aws-Request-Id "$HEADERS" | tr -d '[:space:]' | cut -d: -f2) + + # Execute the handler function from the script + FN_PATH=$LAMBDA_TASK_ROOT/"bash_function.sh" + RESPONSE=$($FN_PATH "$EVENT_DATA") + + # Send the response + curl -X POST "http://${AWS_LAMBDA_RUNTIME_API}/2018-06-01/runtime/invocation/$REQUEST_ID/response" -d "response_from_runtime" +done \ No newline at end of file diff --git a/internal/lambda-managed-instances/testdata/bash_script_with_child_proc.sh b/internal/lambda-managed-instances/testdata/bash_script_with_child_proc.sh new file mode 100755 index 0000000..bdde5ab --- /dev/null +++ b/internal/lambda-managed-instances/testdata/bash_script_with_child_proc.sh @@ -0,0 +1,14 @@ +#!/bin/sh + +# Spawn one child process recursively and spin +# When parent process receives a SIGTERM, child process doesn't exit + +if [ -z "$DONT_SPAWN" ] +then + DONT_SPAWN=true ./$0 & +fi + +while true +do + sleep 1 +done \ No newline at end of file diff --git a/internal/lambda-managed-instances/testdata/env_setup_helpers.go b/internal/lambda-managed-instances/testdata/env_setup_helpers.go new file mode 100644 index 0000000..a3b733e --- /dev/null +++ b/internal/lambda-managed-instances/testdata/env_setup_helpers.go @@ -0,0 +1,108 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package testdata + +import ( + "bytes" + "fmt" + "io/ioutil" + "net" + "os" + "os/exec" + "path" + "path/filepath" + "runtime" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" +) + +func CreateTestSocketPair(t *testing.T) (fd [2]int) { + domain, socketType := syscall.AF_UNIX, syscall.SOCK_DGRAM + fds, err := syscall.Socketpair(domain, socketType, 0) + if err != nil { + t.Error("Could not create socketpair for testing: ", err) + } + + return fds +} + +func CreateTestLogFile(t *testing.T) *os.File { + file, err := ioutil.TempFile(os.TempDir(), "rapid-unit-tests") + assert.NoError(t, err, "error opening tmp log file for test") + return file +} + +type TestSocketsRapid struct { + CtrlFd int + CnslFd int +} + +type TestSocketsSlicer struct { + CtrlSock net.Conn + CnslSock net.Conn + CtrlFd int + CnslFd int +} + +func SetupTestSockets(t *testing.T) (TestSocketsRapid, TestSocketsSlicer) { + ctrlFds := CreateTestSocketPair(t) + testCtrlFd := os.NewFile(uintptr(ctrlFds[0]), "ctrlParent") + + cnslFds := CreateTestSocketPair(t) + testCnslFd := os.NewFile(uintptr(cnslFds[0]), "cnslParent") + + ctrlSock, err := net.FileConn(testCtrlFd) + assert.NoError(t, err, "failed to setup test socket") + + cnslSock, err := net.FileConn(testCnslFd) + assert.NoError(t, err, "failed to setup test socket") + + rapidSockets := TestSocketsRapid{ctrlFds[1], cnslFds[1]} + slicerSockets := TestSocketsSlicer{ctrlSock, cnslSock, ctrlFds[0], cnslFds[0]} + return rapidSockets, slicerSockets +} + +func SetupTestXRayUDPSocket(t *testing.T) net.PacketConn { + pc, err := net.ListenPacket("udp", "localhost:0") + assert.NoError(t, err, "failed to create udp listener for testing") + return pc +} + +func setTestDependenciesBinPath(t *testing.T) { + var testDepsPath bytes.Buffer + + brazilPathCmd := exec.Command("brazil-path", "testrun.runtimefarm") + brazilPathCmd.Stdout = &testDepsPath + + err := brazilPathCmd.Run() + if err != nil { + assert.Fail(t, "Could not run brazil-path to setup $PATH for test runtime") + } + + testDepsBinPath := fmt.Sprintf("%s/bin", testDepsPath.String()) + + err = os.Setenv("PATH", fmt.Sprintf("%s:%s", testDepsBinPath, os.Getenv("PATH"))) + if err != nil { + assert.Fail(t, "Could not run brazil-path to setup $PATH for test runtime") + } + + return +} + +func SetupTestRuntime(t *testing.T, bootstrapScriptName string) (string, string) { + _, b, _, _ := runtime.Caller(0) + base := filepath.Dir(b) + resourcesDir := path.Join(base, "../testdata") + + bootstrap := path.Join(resourcesDir, bootstrapScriptName) + taskRoot := filepath.Dir(bootstrap) + + setTestDependenciesBinPath(t) + err := os.Setenv("LAMBDA_TASK_ROOT", taskRoot) + assert.NoError(t, err) + + return bootstrap, taskRoot +} diff --git a/internal/lambda-managed-instances/testdata/flowtesting.go b/internal/lambda-managed-instances/testdata/flowtesting.go new file mode 100644 index 0000000..28a2013 --- /dev/null +++ b/internal/lambda-managed-instances/testdata/flowtesting.go @@ -0,0 +1,102 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package testdata + +import ( + "bytes" + "context" + "io/ioutil" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/appctx" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/rendering" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/testdata/mockthread" +) + +const ( + contentTypeHeader = "Content-Type" + functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" +) + +type MockInteropServer struct { + Response []byte + ErrorResponse *interop.ErrorInvokeResponse + ResponseContentType string + FunctionResponseMode string +} + +func (i *MockInteropServer) SendResponse(invokeID string, resp *interop.StreamableInvokeResponse) (*interop.InvokeResponseMetrics, error) { + bytes, err := ioutil.ReadAll(resp.Payload) + if err != nil { + return nil, err + } + if len(bytes) > interop.MaxPayloadSize { + return nil, &interop.ErrorResponseTooLarge{ + ResponseSize: len(bytes), + MaxResponseSize: interop.MaxPayloadSize, + } + } + i.Response = bytes + i.ResponseContentType = resp.Headers[contentTypeHeader] + i.FunctionResponseMode = resp.Headers[functionResponseModeHeader] + return nil, nil +} + +func (i *MockInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorInvokeResponse) (*interop.InvokeResponseMetrics, error) { + i.ErrorResponse = response + i.ResponseContentType = response.Headers.ContentType + i.FunctionResponseMode = response.Headers.FunctionResponseMode + return nil, nil +} + +func (i *MockInteropServer) SendInitErrorResponse(response *interop.ErrorInvokeResponse) (*interop.InvokeResponseMetrics, error) { + i.ErrorResponse = response + i.ResponseContentType = response.Headers.ContentType + return nil, nil +} + +type FlowTest struct { + AppCtx appctx.ApplicationContext + InitFlow core.InitFlowSynchronization + RegistrationService core.RegistrationService + RenderingService *rendering.EventRenderingService + Runtime *core.Runtime + InteropServer *MockInteropServer + TelemetrySubscription *telemetry.NoOpSubscriptionAPI + EventsAPI interop.EventsAPI +} + +func (s *FlowTest) ConfigureForInit() { + s.RegistrationService.PreregisterRuntime(s.Runtime) +} + +func (s *FlowTest) ConfigureInvokeRenderer(ctx context.Context, invoke *interop.Invoke, buf *bytes.Buffer) { + s.RenderingService.SetRenderer(rendering.NewInvokeRenderer(ctx, invoke, buf, func(context.Context) string { return "" })) +} + +func NewFlowTest() *FlowTest { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow) + renderingService := rendering.NewRenderingService() + runtime := core.NewRuntime(initFlow) + runtime.ManagedThread = &mockthread.MockManagedThread{} + interopServer := &MockInteropServer{} + eventsAPI := telemetry.NoOpEventsAPI{} + appctx.StoreInteropServer(appCtx, interopServer) + appctx.StoreResponseSender(appCtx, interopServer) + + return &FlowTest{ + AppCtx: appCtx, + InitFlow: initFlow, + RegistrationService: registrationService, + RenderingService: renderingService, + TelemetrySubscription: &telemetry.NoOpSubscriptionAPI{}, + Runtime: runtime, + InteropServer: interopServer, + EventsAPI: &eventsAPI, + } +} diff --git a/internal/lambda-managed-instances/testdata/mockcommand.go b/internal/lambda-managed-instances/testdata/mockcommand.go new file mode 100644 index 0000000..ac3bf52 --- /dev/null +++ b/internal/lambda-managed-instances/testdata/mockcommand.go @@ -0,0 +1,35 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package testdata + +import ( + "context" +) + +type MockCommand struct { + done chan error + ctx context.Context +} + +func NewMockCommand(ctx context.Context) MockCommand { + done := make(chan error) + return MockCommand{done, ctx} +} + +func (c MockCommand) Start() error { + return nil +} + +func (c MockCommand) Wait() error { + select { + case <-c.done: + return nil + case <-c.ctx.Done(): + return c.ctx.Err() + } +} + +func (c MockCommand) ForceExit() { + c.done <- nil +} diff --git a/internal/lambda-managed-instances/testdata/mockthread/mockthread.go b/internal/lambda-managed-instances/testdata/mockthread/mockthread.go new file mode 100644 index 0000000..b3922c2 --- /dev/null +++ b/internal/lambda-managed-instances/testdata/mockthread/mockthread.go @@ -0,0 +1,14 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package mockthread + +type MockManagedThread struct{} + +func (s *MockManagedThread) SuspendUnsafe() {} + +func (s *MockManagedThread) Release() {} + +func (s *MockManagedThread) Lock() {} + +func (s *MockManagedThread) Unlock() {} diff --git a/internal/lambda-managed-instances/testdata/mocktracer/mocktracer.go b/internal/lambda-managed-instances/testdata/mocktracer/mocktracer.go new file mode 100644 index 0000000..1de1665 --- /dev/null +++ b/internal/lambda-managed-instances/testdata/mocktracer/mocktracer.go @@ -0,0 +1,93 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package mocktracer + +import ( + "context" + "time" + + xray "golang.a2z.com/GoAmzn-LambdaXray" +) + +var MockStartTime = time.Now().UnixNano() + +var MockEndTime = time.Now().UnixNano() + 1 + +type MockTracer struct { + documentsMap map[xray.DocumentKey]xray.Document + sentDocuments []xray.Document +} + +func (m *MockTracer) Send(document xray.Document) (dk xray.DocumentKey, err error) { + if len(document.ID) == 0 { + + document.ID = IDFor(document.Name) + } + m.sentDocuments = append(m.sentDocuments, document) + return xray.DocumentKey{ + TraceID: document.TraceID, + DocumentID: document.ID, + }, nil +} + +func (m *MockTracer) Start(document xray.Document) (dk xray.DocumentKey, err error) { + document.StartTime = float64(MockStartTime) / xray.TimeDenominator + document.InProgress = true + dk, err = m.Send(document) + m.documentsMap[dk] = document + return +} + +func (m *MockTracer) SetOptions(dk xray.DocumentKey, documentOptions ...xray.DocumentOption) (err error) { + document := m.documentsMap[dk] + + for _, fieldValueSetter := range documentOptions { + fieldValueSetter(&document) + } + + m.documentsMap[dk] = document + + return nil +} + +func (m *MockTracer) End(dk xray.DocumentKey) (err error) { + document := m.documentsMap[dk] + document.EndTime = float64(MockEndTime) / xray.TimeDenominator + document.InProgress = false + + m.Send(document) + delete(m.documentsMap, dk) + return +} + +func (m *MockTracer) GetSentDocuments() []xray.Document { + return m.sentDocuments +} + +func (m *MockTracer) ResetSentDocuments() { + m.sentDocuments = []xray.Document{} +} + +func (m *MockTracer) SetDocumentMap(dm map[xray.DocumentKey]xray.Document) { + m.documentsMap = dm +} + +func (m *MockTracer) Capture(ctx context.Context, document xray.Document, criticalFunction func(context.Context) error) error { + return nil +} + +func (m *MockTracer) SetOptionsCtx(ctx context.Context, documentOptions ...xray.DocumentOption) (err error) { + return nil +} + +func NewMockTracer() xray.Tracer { + return &MockTracer{ + documentsMap: make(map[xray.DocumentKey]xray.Document), + sentDocuments: []xray.Document{}, + } +} + +func IDFor(name string) string { + return name + "_SEGMID" +} diff --git a/internal/lambda-managed-instances/testdata/parametrization.go b/internal/lambda-managed-instances/testdata/parametrization.go new file mode 100644 index 0000000..aa2ab0c --- /dev/null +++ b/internal/lambda-managed-instances/testdata/parametrization.go @@ -0,0 +1,12 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package testdata + +var SuppressInitTests = []struct { + TestName string + SuppressInit bool +}{ + {"Unsuppressed", false}, + {"Suppressed", true}, +} diff --git a/internal/lambda-managed-instances/testutils/functional/chunked.go b/internal/lambda-managed-instances/testutils/functional/chunked.go new file mode 100644 index 0000000..0772639 --- /dev/null +++ b/internal/lambda-managed-instances/testutils/functional/chunked.go @@ -0,0 +1,45 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package functional + +import ( + "bytes" + "io" + "time" +) + +type ChunkedReader struct { + buffers []*bytes.Buffer + currentIdx int + delay time.Duration +} + +func NewChunkedReader(chunks []string, delay time.Duration) *ChunkedReader { + buffers := make([]*bytes.Buffer, len(chunks)) + for i, chunk := range chunks { + buffers[i] = bytes.NewBuffer([]byte(chunk)) + } + + return &ChunkedReader{ + buffers: buffers, + currentIdx: 0, + delay: delay, + } +} + +func (r *ChunkedReader) Read(p []byte) (n int, err error) { + if r.currentIdx >= len(r.buffers) { + return 0, io.EOF + } + + n, err = r.buffers[r.currentIdx].Read(p) + + if err == io.EOF && r.currentIdx < len(r.buffers)-1 { + r.currentIdx++ + time.Sleep(r.delay) + return r.Read(p) + } + + return n, err +} diff --git a/internal/lambda-managed-instances/testutils/functional/doc.go b/internal/lambda-managed-instances/testutils/functional/doc.go new file mode 100644 index 0000000..2fc863f --- /dev/null +++ b/internal/lambda-managed-instances/testutils/functional/doc.go @@ -0,0 +1,4 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package functional diff --git a/internal/lambda-managed-instances/testutils/functional/extension_actions.go b/internal/lambda-managed-instances/testutils/functional/extension_actions.go new file mode 100644 index 0000000..6fa06bd --- /dev/null +++ b/internal/lambda-managed-instances/testutils/functional/extension_actions.go @@ -0,0 +1,397 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//go:build test + +package functional + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/netip" + "strings" + "sync" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapi/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" +) + +type SleepExtensionAction struct { + Duration time.Duration +} + +func (a SleepExtensionAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + if a.Duration == 0 { + a.Duration = 100 * time.Millisecond + } + t.Logf("Extensions sleeping for %v\n", a.Duration) + time.Sleep(a.Duration) + return nil, nil +} + +func (a SleepExtensionAction) ValidateStatus(t *testing.T, resp *http.Response) {} + +func (a SleepExtensionAction) String() string { + return fmt.Sprintf("Extensions: Sleep(duration=%v)", a.Duration) +} + +type ExtensionsRegisterAction struct { + AgentUniqueName string + + Events []Event + ExpectedStatus int +} + +func (a ExtensionsRegisterAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + resp, err := client.ExtensionsRegister(a.AgentUniqueName, a.Events) + require.NoError(t, err) + return resp, nil +} + +func (a ExtensionsRegisterAction) ValidateStatus(t *testing.T, resp *http.Response) { + if a.ExpectedStatus != 0 { + if resp != nil { + assert.Equal(t, a.ExpectedStatus, resp.StatusCode, "ExtensionsRegisterAction expected status code %d", a.ExpectedStatus) + } + } +} + +func (a ExtensionsRegisterAction) String() string { + return fmt.Sprintf("Extensions: Register(agentUniqueName=%s, events=%v)", a.AgentUniqueName, a.Events) +} + +type ExtensionsNextAction struct { + AgentIdentifier string + ExpectedStatus int +} + +func (a ExtensionsNextAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + _, err := client.ExtensionsNext(a.AgentIdentifier) + if err != nil { + require.True(t, errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, io.EOF)) + } + return nil, nil +} + +func (a ExtensionsNextAction) ValidateStatus(t *testing.T, resp *http.Response) { + if a.ExpectedStatus == 0 { + a.ExpectedStatus = http.StatusOK + } + if resp != nil { + assert.Equal(t, a.ExpectedStatus, resp.StatusCode, "ExtensionsNextAction expected status code %d", a.ExpectedStatus) + } +} + +func (a ExtensionsNextAction) String() string { + return "Extensions: Next()" +} + +type ExtensionsNextParallelAction struct { + AgentIdentifier string + ExpectedStatus int + ParallelActions []ExecutionEnvironmentAction + Environment *ExtensionsExecutionEnvironment +} + +func (a ExtensionsNextParallelAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + if a.Environment != nil && len(a.ParallelActions) > 0 { + go func() { + tempEnv := *a.Environment + tempEnv.Actions = a.ParallelActions + tempEnv.executeEnvActions(client, t) + }() + } + + _, err := client.ExtensionsNext(a.AgentIdentifier) + assert.NotNil(t, err) + return nil, nil +} + +func (a ExtensionsNextParallelAction) ValidateStatus(t *testing.T, resp *http.Response) { + if a.ExpectedStatus == 0 { + a.ExpectedStatus = http.StatusOK + } + if resp != nil { + assert.Equal(t, a.ExpectedStatus, resp.StatusCode, "ExtensionsNextParallelAction expected status code %d", a.ExpectedStatus) + } +} + +func (a ExtensionsNextParallelAction) String() string { + return "Extensions: Next() with other parallel actions" +} + +type ExtensionsInitErrorAction struct { + AgentIdentifier string + FunctionErrorType string + Payload string + ExpectedStatus int +} + +func (a ExtensionsInitErrorAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + _, err := client.ExtensionsInitError(a.AgentIdentifier, a.FunctionErrorType, a.Payload) + require.NoError(t, err) + return nil, nil +} + +func (a ExtensionsInitErrorAction) ValidateStatus(t *testing.T, resp *http.Response) { + if a.ExpectedStatus != 0 { + if resp != nil { + assert.Equal(t, a.ExpectedStatus, resp.StatusCode, "ExtensionsInitErrorAction expected status code %d", a.ExpectedStatus) + } + } +} + +func (a ExtensionsInitErrorAction) String() string { + return "Extensions: InitError()" +} + +type ExtensionsExitErrorAction struct { + AgentIdentifier string + FunctionErrorType string + Payload string + ExpectedStatus int +} + +func (a ExtensionsExitErrorAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + _, err := client.ExtensionsExitError(a.AgentIdentifier, a.FunctionErrorType, a.Payload) + require.NoError(t, err) + return nil, nil +} + +func (a ExtensionsExitErrorAction) ValidateStatus(t *testing.T, resp *http.Response) { + if a.ExpectedStatus != 0 { + if resp != nil { + assert.Equal(t, a.ExpectedStatus, resp.StatusCode, "ExtensionsExitErrorAction expected status code %d", a.ExpectedStatus) + } + } +} + +func (a ExtensionsExitErrorAction) String() string { + return fmt.Sprintf("Extensions: ExitError(type=%s)", a.FunctionErrorType) +} + +type ExtensionsTelemetryAPIHTTPSubscriberAction struct { + addrPort netip.AddrPort + Subscription ExtensionTelemetrySubscribeAction + InMemoryEventsApi *InMemoryEventsApi +} + +func (a ExtensionsTelemetryAPIHTTPSubscriberAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + var mu sync.Mutex + go func() { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + mu.Lock() + a.addrPort = netip.MustParseAddrPort(listener.Addr().String()) + mu.Unlock() + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + slog.Info(string(body)) + + var events []telemetry.Event + require.NoError(t, json.Unmarshal(body, &events)) + + for _, event := range events { + parseTelemetryEvent(t, event, a.InMemoryEventsApi) + } + })) + require.NoError(t, http.Serve(listener, handler)) + }() + + for { + mu.Lock() + address := a.addrPort.String() + mu.Unlock() + conn, err := net.DialTimeout("tcp", address, 500*time.Millisecond) + if err == nil { + _ = conn.Close() + break + } + time.Sleep(100 * time.Millisecond) + } + + a.Subscription.Payload = strings.NewReader(fmt.Sprintf(`{"schemaVersion": "2025-01-29", "destination":{"protocol":"HTTP","URI":"http://sandbox.localdomain:%d"}, "types": ["platform", "function", "extension"]}`, a.addrPort.Port())) + return a.Subscription.Execute(t, client) +} + +func (a ExtensionsTelemetryAPIHTTPSubscriberAction) ValidateStatus(t *testing.T, resp *http.Response) { + a.Subscription.ValidateStatus(t, resp) +} + +func (a ExtensionsTelemetryAPIHTTPSubscriberAction) String() string { + return "Extensions: ExtensionsTelemetryAPIHTTPSubscriberAction" +} + +type ExtensionsTelemetryAPITCPSubscriberAction struct { + addrPort netip.AddrPort + Subscription ExtensionTelemetrySubscribeAction + InMemoryEventsApi *InMemoryEventsApi +} + +func (a ExtensionsTelemetryAPITCPSubscriberAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + var mu sync.Mutex + go func() { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + mu.Lock() + a.addrPort = netip.MustParseAddrPort(listener.Addr().String()) + mu.Unlock() + for { + conn, err := listener.Accept() + require.NoError(t, err) + go func() { + scanner := bufio.NewScanner(conn) + for scanner.Scan() { + line := scanner.Bytes() + t.Log(string(line)) + + var event telemetry.Event + require.NoError(t, json.Unmarshal(line, &event)) + parseTelemetryEvent(t, event, a.InMemoryEventsApi) + } + require.NoError(t, scanner.Err()) + }() + } + }() + + for { + mu.Lock() + address := a.addrPort.String() + mu.Unlock() + conn, err := net.DialTimeout("tcp", address, 500*time.Millisecond) + if err == nil { + _ = conn.Close() + break + } + time.Sleep(100 * time.Millisecond) + } + + a.Subscription.Payload = strings.NewReader(fmt.Sprintf(`{"schemaVersion": "2025-01-29", "destination":{"protocol":"TCP","port":%d}, "types": ["platform", "function", "extension"]}`, a.addrPort.Port())) + return a.Subscription.Execute(t, client) +} + +func (a ExtensionsTelemetryAPITCPSubscriberAction) ValidateStatus(t *testing.T, resp *http.Response) { + a.Subscription.ValidateStatus(t, resp) +} + +func (a ExtensionsTelemetryAPITCPSubscriberAction) String() string { + return "Extensions: ExtensionsTelemetryAPITCPSubscriberAction" +} + +type ExtensionTelemetrySubscribeAction struct { + AgentIdentifier string + AgentName string + + Payload io.Reader + + Headers map[string][]string + + RemoteAddr string + + ExpectedStatus int + + ExpectedErrorType string + + ExpectedErrorMessage string +} + +func (a ExtensionTelemetrySubscribeAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + resp, err := client.ExtensionsTelemetrySubscribe( + a.AgentIdentifier, + a.AgentName, + a.Payload, + a.Headers, + a.RemoteAddr, + ) + + require.NoError(t, err) + return resp, nil +} + +func (a ExtensionTelemetrySubscribeAction) ValidateStatus(t *testing.T, resp *http.Response) { + + defer func() { require.NoError(t, resp.Body.Close()) }() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "failed to read resp body") + + if a.ExpectedStatus != 0 { + if resp != nil { + assert.Equal(t, a.ExpectedStatus, resp.StatusCode, "ExtensionTelemetrySubscribeAction expected status code %d, body: %s", a.ExpectedStatus, string(body)) + } + } + + if resp.StatusCode >= 400 { + var errorResp model.ErrorResponse + err = json.Unmarshal(body, &errorResp) + require.NoError(t, err, "failed to unmarshal resp body") + + if a.ExpectedErrorMessage != "" { + assert.Equal(t, a.ExpectedErrorMessage, errorResp.ErrorMessage, "ExtensionTelemetrySubscribeAction expected error message %s", a.ExpectedErrorMessage) + } + + if a.ExpectedErrorType != "" { + assert.Equal(t, a.ExpectedErrorType, errorResp.ErrorType, "ExtensionTelemetrySubscribeAction expected error type %s", a.ExpectedErrorType) + } + } +} + +func (a ExtensionTelemetrySubscribeAction) String() string { + return fmt.Sprintf("Extensions: TelemetrySubscribe(agentName=%s)", a.AgentName) +} + +func parseTelemetryEvent(t *testing.T, event telemetry.Event, eventsApi *InMemoryEventsApi) { + if eventsApi == nil { + return + } + + switch event.Type { + case "platform.initStart": + var data interop.InitStartData + require.NoError(t, json.Unmarshal(event.Record, &data)) + require.NoError(t, eventsApi.SendInitStart(data)) + + case "platform.initRuntimeDone": + var data interop.InitRuntimeDoneData + require.NoError(t, json.Unmarshal(event.Record, &data)) + require.NoError(t, eventsApi.SendInitRuntimeDone(data)) + + case "platform.initReport": + var data interop.InitReportData + require.NoError(t, json.Unmarshal(event.Record, &data)) + require.NoError(t, eventsApi.SendInitReport(data)) + + case "platform.start": + var data interop.InvokeStartData + require.NoError(t, json.Unmarshal(event.Record, &data)) + require.NoError(t, eventsApi.SendInvokeStart(data)) + + case "platform.report": + var data interop.ReportData + require.NoError(t, json.Unmarshal(event.Record, &data)) + require.NoError(t, eventsApi.SendReport(data)) + + case "platform.extension": + var data interop.ExtensionInitData + require.NoError(t, json.Unmarshal(event.Record, &data)) + require.NoError(t, eventsApi.SendExtensionInit(data)) + case "function", "extension": + eventsApi.RecordLogLine(event) + default: + + t.Logf("Received telemetry event of type %s (not forwarded to EventsAPI)", event.Type) + } +} diff --git a/internal/lambda-managed-instances/testutils/functional/extensions_client.go b/internal/lambda-managed-instances/testutils/functional/extensions_client.go new file mode 100644 index 0000000..40117e9 --- /dev/null +++ b/internal/lambda-managed-instances/testutils/functional/extensions_client.go @@ -0,0 +1,254 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//go:build test + +package functional + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/netip" + "net/url" + "time" +) + +type ( + Event string + ConfigKey string +) + +const ( + Shutdown Event = "SHUTDOWN" +) + +const ( + lambdaAgentIdentifierHeaderKey string = "Lambda-Extension-Identifier" + lambdaAgentNameHeaderKey string = "Lambda-Extension-Name" + lambdaAgentErrorTypeHeaderKey string = "Lambda-Extension-Function-Error-Type" + lambdaAcceptFeatureHeaderKey string = "Lambda-Extension-Accept-Feature" +) + +type errorExtensions interface { + Error() string +} + +func (client *Client) ExtensionsSleep(d time.Duration) { + if d == 0 { + d = 100 * time.Millisecond + } + time.Sleep(d) +} + +type RegisterRequest struct { + Events []Event +} + +type rapidRegisterResponse struct { + AccountID *string `json:"accountId"` + FunctionName string `json:"functionName"` + FunctionVersion string `json:"functionVersion"` + Handler string `json:"handler"` + Configuration map[string]string `json:"configuration"` +} + +type StatusResponse struct { + Status string `json:"status"` +} + +type NextResponse struct { + EventType Event `json:"eventType"` +} + +type ShutdownResponse struct { + *NextResponse + ShutdownReason string `json:"shutdownReason"` + DeadlineMs int64 `json:"deadlineMs"` +} + +type RapidHTTPError struct { + StatusCode int + Status string +} + +func (s *RapidHTTPError) Error() string { + return fmt.Sprintf("/event/next failed: %d[%s]", s.StatusCode, s.Status) +} + +func NewExtensionsClient(endpoint netip.AddrPort) *Client { + return &Client{ + baseurl: fmt.Sprintf("http://%s/2020-01-01/extension", endpoint), + client: http.Client{}, + } +} + +func (client *Client) ExtensionsRegister(agentUniqueName string, events []Event) (*http.Response, errorExtensions) { + data, err := json.Marshal( + &RegisterRequest{ + Events: events, + }) + if err != nil { + return nil, fmt.Errorf("failed to marshal RegisterRequest: %w", err) + } + + return client.extensionsRegisterWithMarshalledRequest(agentUniqueName, data, nil) +} + +func (client *Client) extensionsRegisterWithMarshalledRequest( + agentUniqueName string, data []byte, headersOverrides map[string]string, +) (*http.Response, errorExtensions) { + headers := map[string]string{lambdaAgentNameHeaderKey: agentUniqueName} + for name, val := range headersOverrides { + headers[name] = val + } + + resp, err := HttpPostWithHeaders(&client.client, fmt.Sprintf("%s/register", client.baseurl), data, &headers) + if err != nil { + return nil, err + } + defer func() { + if err := resp.Body.Close(); err != nil { + panic(err) + } + }() + + var registerResponse rapidRegisterResponse + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + if err := json.Unmarshal(body, ®isterResponse); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return resp, nil +} + +func (client *Client) ExtensionsInitError(agentIdentifier, functionErrorType, payload string) (*StatusResponse, errorExtensions) { + headers := make(map[string]string) + headers[lambdaAgentIdentifierHeaderKey] = agentIdentifier + headers[lambdaAgentErrorTypeHeaderKey] = functionErrorType + + resp, err := HttpPostWithHeaders(&client.client, fmt.Sprintf("%s/init/error", client.baseurl), []byte(payload), &headers) + if err != nil { + return nil, err + } + defer func() { + if err := resp.Body.Close(); err != nil { + panic(err) + } + }() + + var response StatusResponse + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + if err := json.Unmarshal(body, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &response, nil +} + +func (client *Client) ExtensionsExitError(agentIdentifier, functionErrorType, payload string) (*StatusResponse, errorExtensions) { + headers := make(map[string]string) + headers[lambdaAgentIdentifierHeaderKey] = agentIdentifier + headers[lambdaAgentErrorTypeHeaderKey] = functionErrorType + + resp, err := HttpPostWithHeaders(&client.client, fmt.Sprintf("%s/exit/error", client.baseurl), []byte(payload), &headers) + if err != nil { + return nil, err + } + defer func() { + if err := resp.Body.Close(); err != nil { + panic(err) + } + }() + + var response StatusResponse + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + if err := json.Unmarshal(body, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &response, nil +} + +func (client *Client) ExtensionsNext(agentIdentifier string) (interface{}, errorExtensions) { + invokeNextHeaders := map[string]string{lambdaAgentIdentifierHeaderKey: agentIdentifier} + return client.ExtensionsNextWithHeaders(invokeNextHeaders) +} + +func (client *Client) ExtensionsNextWithHeaders(headers map[string]string) (interface{}, errorExtensions) { + resp, err := HttpGetWithHeaders(&client.client, fmt.Sprintf("%s/event/next", client.baseurl), &headers) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, &RapidHTTPError{resp.StatusCode, resp.Status} + } + + defer func() { + if err := resp.Body.Close(); err != nil { + panic(err) + } + }() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var nextResponse NextResponse + if err := json.Unmarshal(body, &nextResponse); err != nil { + return nil, fmt.Errorf("could not unmarshal /next response: %s", err) + } + + switch nextResponse.EventType { + case Shutdown: + var shutdownResponse ShutdownResponse + err := json.Unmarshal(body, &shutdownResponse) + return shutdownResponse, err + default: + return nil, fmt.Errorf("unrecognisable eventType: %s", nextResponse.EventType) + } +} + +func (client *Client) ExtensionsTelemetrySubscribe(agentIdentifier string, agentName string, body io.Reader, headers map[string][]string, remoteAddr string) (*http.Response, errorExtensions) { + + baseURL, err := url.Parse(client.baseurl) + if err != nil { + return nil, fmt.Errorf("failed to parse base URL: %w", err) + } + + telemetryURL := fmt.Sprintf("http://%s/2022-07-01/telemetry", baseURL.Host) + + req, err := http.NewRequest(http.MethodPut, telemetryURL, body) + if err != nil { + return nil, fmt.Errorf("failed to create telemetry subscription request: %w", err) + } + + req.Header.Set(lambdaAgentIdentifierHeaderKey, agentIdentifier) + + for key, values := range headers { + for _, value := range values { + req.Header.Add(key, value) + } + } + + if remoteAddr != "" { + req.RemoteAddr = remoteAddr + } + + resp, err := client.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send telemetry subscription request: %w", err) + } + + return resp, nil +} diff --git a/internal/lambda-managed-instances/testutils/functional/fluxpump_server.go b/internal/lambda-managed-instances/testutils/functional/fluxpump_server.go new file mode 100644 index 0000000..7ff5e64 --- /dev/null +++ b/internal/lambda-managed-instances/testutils/functional/fluxpump_server.go @@ -0,0 +1,153 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//go:build test + +package functional + +import ( + "context" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/netip" + "sync" + "time" +) + +const ( + passphraseHeader = "Passphrase" + extensionNameHeader = "Lambda-Extension-Name" + forwardedHeader = "Forwarded" + subscriptionAPIendpoint = "/subscribeV2" +) + +type FluxPumpServer struct { + requests []SubscribeRequestLog + mutex sync.Mutex + server *http.Server + addrPort netip.AddrPort +} + +type SubscribeRequestLog struct { + AgentName string + Passphrase string + Body []byte + Headers map[string][]string + RemoteAddr string + Timestamp time.Time +} + +func NewFluxPumpServer() *FluxPumpServer { + return &FluxPumpServer{ + requests: make([]SubscribeRequestLog, 0), + mutex: sync.Mutex{}, + } +} + +func (s *FluxPumpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.mutex.Lock() + defer s.mutex.Unlock() + + slog.Debug("FluxPump received request", + "method", r.Method, + "path", r.URL.Path, + "remoteAddr", r.RemoteAddr) + + if r.Method != http.MethodPut || r.URL.Path != subscriptionAPIendpoint { + http.Error(w, "Not found", http.StatusNotFound) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + slog.Error("Error reading request body", "err", err) + http.Error(w, "Error reading request body", http.StatusInternalServerError) + return + } + + passphrase := r.Header.Get(passphraseHeader) + agentName := r.Header.Get(extensionNameHeader) + forwarded := r.Header.Get(forwardedHeader) + + SubscribeRequestLog := SubscribeRequestLog{ + AgentName: agentName, + Passphrase: passphrase, + Body: body, + Headers: r.Header, + RemoteAddr: r.RemoteAddr, + Timestamp: time.Now(), + } + s.requests = append(s.requests, SubscribeRequestLog) + + slog.Debug("Subscription request details", + "agentName", agentName, + "passphrase", passphrase, + "forwarded", forwarded, + "bodySize", len(body)) + + slog.Debug("Request body", "body", string(body)) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err = w.Write([]byte(`{"message":"Subscription accepted"}`)) + if err != nil { + panic(err) + } +} + +func (s *FluxPumpServer) GetRequests() []SubscribeRequestLog { + s.mutex.Lock() + defer s.mutex.Unlock() + return s.requests +} + +func (s *FluxPumpServer) GetAddrPort() netip.AddrPort { + return s.addrPort +} + +func (s *FluxPumpServer) Start(fxPumpAddrPort netip.AddrPort) error { + + listener, err := net.Listen("tcp", fxPumpAddrPort.String()) + if err != nil { + return fmt.Errorf("failed to find available port: %w", err) + } + + port := listener.Addr().(*net.TCPAddr).Port + + _ = listener.Close() + + s.server = &http.Server{ + Addr: fmt.Sprintf(":%d", port), + Handler: s, + ReadHeaderTimeout: 15 * time.Second, + } + + s.addrPort = netip.MustParseAddrPort(fmt.Sprintf("127.0.0.1:%d", port)) + + slog.Info("FluxPump server starting", "addrPort", s.addrPort.String()) + + go func() { + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + slog.Error("FluxPump server error", "err", err) + } + }() + + time.Sleep(100 * time.Millisecond) + + return nil +} + +func (s *FluxPumpServer) Stop() error { + if s.server == nil { + return nil + } + slog.Info("Shutting down FluxPump server") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + return s.server.Shutdown(ctx) +} diff --git a/internal/lambda-managed-instances/testutils/functional/httputils.go b/internal/lambda-managed-instances/testutils/functional/httputils.go new file mode 100644 index 0000000..82c2837 --- /dev/null +++ b/internal/lambda-managed-instances/testutils/functional/httputils.go @@ -0,0 +1,56 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//go:build test + +package functional + +import ( + "bytes" + "log/slog" + "net/http" +) + +const contentType = "application/json" + +func HttpPostWithHeaders(client *http.Client, url string, data []byte, headers *map[string]string) (*http.Response, error) { + req, err := http.NewRequest("POST", url, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", contentType) + if headers != nil { + for k, v := range *headers { + req.Header.Set(k, v) + } + } + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + + return resp, nil +} + +func HttpGetWithHeaders(client *http.Client, url string, headers *map[string]string) (*http.Response, error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + + if headers != nil { + for k, v := range *headers { + req.Header.Set(k, v) + } + } + + resp, err := client.Do(req) + if err != nil { + slog.Error("extension HttpGetWithHeaders failed", "err", err) + return nil, err + } + + return resp, nil +} diff --git a/internal/lambda-managed-instances/testutils/functional/in_memory_events_api.go b/internal/lambda-managed-instances/testutils/functional/in_memory_events_api.go new file mode 100644 index 0000000..773553e --- /dev/null +++ b/internal/lambda-managed-instances/testutils/functional/in_memory_events_api.go @@ -0,0 +1,350 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//go:build test + +package functional + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/invoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/telemetry" +) + +type EventType string + +const ( + PlatformInitStart EventType = "platform.initStart" + PlatformInitRuntimeDone EventType = "platform.initRuntimeDone" + PlatformInitReport EventType = "platform.initReport" + PlatformRuntimeStart EventType = "platform.runtimeStart" + PlatformReport EventType = "platform.report" +) + +type ExpectedInitEvent struct { + EventType EventType + Status string + ErrorType string +} + +type ExpectedExtensionEvents struct { + ExtensionName string + State string + ErrorType string +} + +type ExpectedInvokeEvents struct { + EventType EventType + + Status string + + Spans []string +} + +type ExpectedGlobalData struct { + FunctionARN string + FunctionVersion string +} + +type InMemoryEventsApi struct { + testState *testing.T + mu sync.Mutex + + initStart *interop.InitStartData + initRuntimeDone *interop.InitRuntimeDoneData + initReport *interop.InitReportData + imageError *interop.ImageErrorLogData + + initEventsCount int + + extensionInit map[string]interop.ExtensionInitData + + invokeStart map[interop.InvokeID]interop.InvokeStartData + invokeReport map[interop.InvokeID]interop.ReportData + InvokeXRAYErrorCause map[interop.InvokeID]interop.InternalXRayErrorCauseData + + logLines []telemetry.Event +} + +func NewInMemoryEventsApi(t *testing.T) *InMemoryEventsApi { + return &InMemoryEventsApi{ + testState: t, + extensionInit: make(map[string]interop.ExtensionInitData), + invokeStart: make(map[interop.InvokeID]interop.InvokeStartData), + invokeReport: make(map[interop.InvokeID]interop.ReportData), + InvokeXRAYErrorCause: make(map[interop.InvokeID]interop.InternalXRayErrorCauseData), + } +} + +func (e *InMemoryEventsApi) SendInitStart(data interop.InitStartData) error { + e.mu.Lock() + defer e.mu.Unlock() + + e.initEventsCount++ + + if e.initStart != nil { + assert.FailNow(e.testState, "InitStart already exists") + } + + e.initStart = &data + return nil +} + +func (e *InMemoryEventsApi) SendInitRuntimeDone(data interop.InitRuntimeDoneData) error { + e.mu.Lock() + defer e.mu.Unlock() + + e.initEventsCount++ + + if e.initRuntimeDone != nil { + assert.FailNow(e.testState, "InitRuntimeDone already exists") + } + + e.initRuntimeDone = &data + return nil +} + +func (e *InMemoryEventsApi) SendInitReport(data interop.InitReportData) error { + e.mu.Lock() + defer e.mu.Unlock() + + e.initEventsCount++ + + if e.initReport != nil { + assert.FailNow(e.testState, "InitReport already exists") + } + + e.initReport = &data + return nil +} + +func (e *InMemoryEventsApi) SendExtensionInit(data interop.ExtensionInitData) error { + e.mu.Lock() + defer e.mu.Unlock() + + if _, ok := e.extensionInit[data.AgentName]; ok { + assert.FailNow(e.testState, "ExtensionInit for %s already exists", data.AgentName) + } + + e.extensionInit[data.AgentName] = data + return nil +} + +func (e *InMemoryEventsApi) SendImageError(errLog interop.ImageErrorLogData) { + e.mu.Lock() + defer e.mu.Unlock() + + if e.imageError != nil { + assert.FailNow(e.testState, "SendImageError already exists") + } + + e.imageError = &errLog +} + +func (e *InMemoryEventsApi) SendInvokeStart(data interop.InvokeStartData) error { + e.mu.Lock() + defer e.mu.Unlock() + + if _, ok := e.invokeStart[data.InvokeID]; ok { + assert.FailNow(e.testState, "InvokeStart for %s already exists", data.InvokeID) + } + + e.invokeStart[data.InvokeID] = data + return nil +} + +func (e *InMemoryEventsApi) SendReport(data interop.ReportData) error { + e.mu.Lock() + defer e.mu.Unlock() + + if _, ok := e.invokeReport[data.InvokeID]; ok { + assert.FailNow(e.testState, "InvokeReport for %s already exists", data.InvokeID) + } + + e.invokeReport[data.InvokeID] = data + return nil +} + +func (e *InMemoryEventsApi) SendInternalXRayErrorCause(data interop.InternalXRayErrorCauseData) error { + e.mu.Lock() + defer e.mu.Unlock() + + if _, ok := e.InvokeXRAYErrorCause[data.InvokeID]; ok { + assert.FailNow(e.testState, "SendInternalXRayErrorCause for %s already exists", data.InvokeID) + } + + e.InvokeXRAYErrorCause[data.InvokeID] = data + return nil +} + +func (e *InMemoryEventsApi) CheckSimpleInitExpectations(startTimestamp time.Time, finishTimestamp time.Time, expectedInitEvents []ExpectedInitEvent, initReq model.InitRequestMessage) { + + e.CheckComprehensiveInitExpectations(startTimestamp, finishTimestamp, 0, expectedInitEvents, initReq) +} + +func (e *InMemoryEventsApi) CheckComprehensiveInitExpectations(startTimestamp time.Time, finishTimestamp time.Time, expectedMinimalInitDuration time.Duration, expectedInitEvents []ExpectedInitEvent, initReq model.InitRequestMessage) { + if expectedInitEvents == nil { + return + } + + require.Equal(e.testState, len(expectedInitEvents), e.initEventsCount) + + for _, initEvent := range expectedInitEvents { + switch initEvent.EventType { + case PlatformInitStart: + assert.NotNil(e.testState, e.initStart) + assert.Equal(e.testState, interop.InitializationType, e.initStart.InitializationType) + assert.Equal(e.testState, initReq.RuntimeVersion, e.initStart.RuntimeVersion) + assert.Equal(e.testState, initReq.RuntimeArn, e.initStart.RuntimeVersionArn) + assert.Equal(e.testState, initReq.TaskName, e.initStart.FunctionName) + assert.Equal(e.testState, initReq.FunctionVersion, e.initStart.FunctionVersion) + assert.Equal(e.testState, initReq.LogStreamName, e.initStart.InstanceID) + assert.Equal(e.testState, uint64(initReq.MemorySizeBytes), e.initStart.InstanceMaxMemory) + assert.Equal(e.testState, telemetry.InitInsideInitPhase, e.initStart.Phase) + assert.Nil(e.testState, e.initStart.Tracing) + case PlatformInitRuntimeDone: + assert.NotNil(e.testState, e.initRuntimeDone) + assert.Equal(e.testState, interop.InitializationType, e.initRuntimeDone.InitializationType) + assert.Equal(e.testState, initEvent.Status, e.initRuntimeDone.Status) + assert.Equal(e.testState, telemetry.InitInsideInitPhase, e.initRuntimeDone.Phase) + checkErrorTypePtr(e.testState, initEvent.ErrorType, e.initRuntimeDone.ErrorType) + assert.Nil(e.testState, e.initRuntimeDone.Tracing) + case PlatformInitReport: + assert.NotNil(e.testState, e.initReport) + assert.Equal(e.testState, interop.InitializationType, e.initReport.InitializationType) + assert.Equal(e.testState, telemetry.InitInsideInitPhase, e.initReport.Phase) + assert.Nil(e.testState, e.initReport.Tracing) + assert.Equal(e.testState, initEvent.Status, e.initReport.Status) + checkErrorTypePtr(e.testState, initEvent.ErrorType, e.initReport.ErrorType) + checkDuration(e.testState, expectedMinimalInitDuration, finishTimestamp.Sub(startTimestamp), e.initReport.Metrics.DurationMs) + } + } +} + +func checkErrorTypePtr(t *testing.T, expectedErr string, realErrPtr *string) { + if expectedErr == "" { + assert.Nil(t, realErrPtr) + } else { + assert.Equal(t, expectedErr, *realErrPtr) + } +} + +func (e *InMemoryEventsApi) CheckSimpleExtensionExpectations(expectedExtensionsEvents []ExpectedExtensionEvents) { + if expectedExtensionsEvents == nil { + return + } + + require.Equal(e.testState, len(expectedExtensionsEvents), len(e.extensionInit)) + + for _, expectedEvent := range expectedExtensionsEvents { + event, ok := e.extensionInit[expectedEvent.ExtensionName] + assert.True(e.testState, ok) + assert.Equal(e.testState, expectedEvent.ErrorType, event.ErrorType) + assert.Equal(e.testState, expectedEvent.State, event.State) + } +} + +func (e *InMemoryEventsApi) CheckSimpleInvokeExpectations(startTimestamp time.Time, finishTimestamp time.Time, invokeID interop.InvokeID, expectedInvokeEvents []ExpectedInvokeEvents, initReq model.InitRequestMessage) { + + e.CheckComprehensiveInvokeExpectations(startTimestamp, finishTimestamp, invokeID, expectedInvokeEvents, initReq, 0, 0) +} + +func (e *InMemoryEventsApi) CheckComprehensiveInvokeExpectations(startTimestamp time.Time, finishTimestamp time.Time, invokeID interop.InvokeID, expectedInvokeEvents []ExpectedInvokeEvents, initReq model.InitRequestMessage, expectedInvokeLatency time.Duration, expectedInvokeRespDuration time.Duration) { + if expectedInvokeEvents == nil { + return + } + + maxDuration := finishTimestamp.Sub(startTimestamp) + for _, expectedInvokeEvent := range expectedInvokeEvents { + switch expectedInvokeEvent.EventType { + case PlatformRuntimeStart: + event := e.invokeStart[invokeID] + require.NotEmpty(e.testState, event) + assert.Equal(e.testState, initReq.FunctionARN, event.FunctionARN) + assert.Equal(e.testState, initReq.FunctionVersion, event.Version) + case PlatformReport: + event := e.invokeReport[invokeID] + require.NotEmpty(e.testState, event) + assert.Equal(e.testState, expectedInvokeEvent.Status, event.Status) + checkDuration(e.testState, expectedInvokeLatency+expectedInvokeRespDuration, maxDuration, float64(event.Metrics.DurationMs)) + require.Equal(e.testState, len(expectedInvokeEvent.Spans), len(event.Spans)) + + for i := range len(expectedInvokeEvent.Spans) { + assert.Equal(e.testState, expectedInvokeEvent.Spans[i], event.Spans[i].Name) + spanStartTime, err := time.ParseInLocation("2006-01-02T15:04:05.000Z", event.Spans[i].Start, time.UTC) + assert.NoError(e.testState, err) + checkTimestamp(e.testState, spanStartTime, startTimestamp, finishTimestamp) + + var minDuration time.Duration + switch expectedInvokeEvent.Spans[i] { + case invoke.ResponseLatencySpanName: + minDuration = expectedInvokeLatency + case invoke.ResponseDurationSpanName: + minDuration = expectedInvokeRespDuration + default: + + } + checkDuration(e.testState, minDuration, maxDuration, event.Spans[i].DurationMs) + } + } + } +} + +func (e *InMemoryEventsApi) CheckXRayErrorCauseExpectations(invokeID interop.InvokeID, expectedErrorCause string) { + e.mu.Lock() + defer e.mu.Unlock() + + if expectedErrorCause == "" { + _, exists := e.InvokeXRAYErrorCause[invokeID] + assert.False(e.testState, exists, "Expected no XRay error cause for invoke %s, but one was recorded", invokeID) + return + } + + errorCauseData, exists := e.InvokeXRAYErrorCause[invokeID] + assert.True(e.testState, exists, "Expected XRay error cause for invoke %s, but none was recorded", invokeID) + if exists { + assert.Equal(e.testState, expectedErrorCause, errorCauseData.Cause, "XRay error cause mismatch for invoke %s", invokeID) + } +} + +func (e *InMemoryEventsApi) Flush() { + +} + +func (e *InMemoryEventsApi) RecordLogLine(ev telemetry.Event) { + e.mu.Lock() + defer e.mu.Unlock() + + e.logLines = append(e.logLines, ev) +} + +func (e *InMemoryEventsApi) LogLines() []telemetry.Event { + e.mu.Lock() + defer e.mu.Unlock() + + lines := make([]telemetry.Event, len(e.logLines)) + copy(lines, e.logLines) + + return lines +} + +func checkDuration(t *testing.T, minDuration, maxDuration time.Duration, realDuration float64) { + assert.GreaterOrEqual(t, realDuration, getDurationMs(minDuration)) + assert.LessOrEqual(t, realDuration, getDurationMs(maxDuration)) +} + +func getDurationMs(d time.Duration) float64 { + return float64(d.Microseconds()) / 1000.0 +} + +func checkTimestamp(t *testing.T, realTimestamp, leftBound, rightBound time.Time) { + assert.WithinRange(t, realTimestamp, leftBound.Truncate(time.Millisecond), rightBound.Truncate(time.Millisecond)) +} diff --git a/internal/lambda-managed-instances/testutils/functional/process_supervisor.go b/internal/lambda-managed-instances/testutils/functional/process_supervisor.go new file mode 100644 index 0000000..dbffbf8 --- /dev/null +++ b/internal/lambda-managed-instances/testutils/functional/process_supervisor.go @@ -0,0 +1,285 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//go:build test + +package functional + +import ( + "io" + "net/http" + "net/netip" + "os" + "strings" + "sync" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/invoke" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/supervisor/model" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/utils" +) + +type ExecutionEnvironmentAction interface { + Execute(t *testing.T, client *Client) (*http.Response, error) + ValidateStatus(t *testing.T, resp *http.Response) + String() string +} + +type RuntimeEnv struct { + Workers []RuntimeExecutionEnvironment + + ForcedError error + T *testing.T + ExitProcessOnce func() + Stdout, Stderr io.Writer + InvokeResponseWG *sync.WaitGroup + Done sync.WaitGroup +} + +type RuntimeExecutionEnvironment struct { + Actions []ExecutionEnvironmentAction + InvokeID interop.InvokeID + RuntimeEnv *RuntimeEnv +} + +func (r *RuntimeEnv) Exec(request *model.ExecRequest) (<-chan struct{}, error) { + if r.ForcedError != nil { + return nil, r.ForcedError + } + doneCh := make(chan struct{}) + r.ExitProcessOnce = sync.OnceFunc(func() { + close(doneCh) + }) + r.Stdout = request.StdoutWriter + r.Stderr = request.StderrWriter + runtimeAPIAddr := (*request.Env)["AWS_LAMBDA_RUNTIME_API"] + client := NewClient(netip.MustParseAddrPort(runtimeAPIAddr)) + for _, thread := range r.Workers { + r.Done.Add(1) + go func() { + thread.RuntimeEnv = r + thread.executeEnvActions(client, r.T) + r.Done.Done() + }() + } + + return doneCh, nil +} + +func (r *RuntimeEnv) Terminate() error { + r.ExitProcessOnce() + return nil +} + +func (r *RuntimeEnv) Kill() error { + r.ExitProcessOnce() + return nil +} + +type ExtensionsEnv = map[string]*ExtensionsExecutionEnvironment + +type ExtensionsExecutionEnvironment struct { + Actions []ExecutionEnvironmentAction + + ForcedError error + + ExtensionIdentifier uuid.UUID + T *testing.T + ExitProcessOnce func() + Stdout, Stderr io.Writer + Done sync.WaitGroup +} + +func (e *ExtensionsExecutionEnvironment) Exec(request *model.ExecRequest) (<-chan struct{}, error) { + if e.ForcedError != nil { + return nil, e.ForcedError + } + + runtimeAPIAddr := (*request.Env)["AWS_LAMBDA_RUNTIME_API"] + doneCh := make(chan struct{}) + e.ExitProcessOnce = sync.OnceFunc(func() { + close(doneCh) + }) + e.Stdout = request.StdoutWriter + e.Stderr = request.StderrWriter + e.Done.Add(1) + go func() { + extensionClient := NewExtensionsClient(netip.MustParseAddrPort(runtimeAPIAddr)) + e.executeEnvActions(extensionClient, e.T) + e.Done.Done() + }() + + return doneCh, nil +} + +func (e *ExtensionsExecutionEnvironment) Terminate() error { + e.ExitProcessOnce() + return nil +} + +func (e *ExtensionsExecutionEnvironment) Kill() error { + e.ExitProcessOnce() + return nil +} + +func (r *RuntimeExecutionEnvironment) executeEnvActions(client *Client, t *testing.T) { + for _, action := range r.Actions { + switch a := action.(type) { + case NextAction: + resp, err := a.Execute(t, client) + require.NoError(t, err, "Action %s failed: %v", action.String(), err) + + if resp != nil { + r.InvokeID = resp.Header.Get(invoke.RuntimeRequestIdHeader) + a.ValidateStatus(t, resp) + } + case StdoutAction: + a.stdout = r.RuntimeEnv.Stdout + executeAndValidateAction(a, client, t) + case StderrAction: + a.stderr = r.RuntimeEnv.Stderr + executeAndValidateAction(a, client, t) + case InvocationResponseAction: + if a.InvokeID == "" { + a.InvokeID = r.InvokeID + } + executeAndValidateAction(a, client, t) + case InvocationStreamingResponseAction: + if a.InvokeID == "" { + a.InvokeID = r.InvokeID + } + executeAndValidateAction(a, client, t) + case InvocationResponseErrorAction: + if a.InvokeID == "" { + a.InvokeID = r.InvokeID + } + executeAndValidateAction(a, client, t) + case ExitAction: + a.exitProcessOnce = r.RuntimeEnv.ExitProcessOnce + executeAndValidateAction(a, client, t) + case WaitInvokeResponseAction: + a.wg = r.RuntimeEnv.InvokeResponseWG + executeAndValidateAction(&a, client, t) + default: + executeAndValidateAction(a, client, t) + } + } +} + +func (e *ExtensionsExecutionEnvironment) executeEnvActions(client *Client, t *testing.T) { + for _, action := range e.Actions { + switch a := action.(type) { + case StdoutAction: + a.stdout = e.Stdout + executeAndValidateAction(a, client, t) + case StderrAction: + a.stderr = e.Stderr + executeAndValidateAction(a, client, t) + case ExtensionsRegisterAction: + resp, err := a.Execute(t, client) + require.NoError(t, err) + + agentIdentifier := resp.Header.Get("Lambda-Extension-Identifier") + if id, err := uuid.Parse(agentIdentifier); err == nil { + e.ExtensionIdentifier = id + } + require.NoError(t, err) + + a.ValidateStatus(t, resp) + case ExtensionsNextAction: + if a.AgentIdentifier == "" { + a.AgentIdentifier = e.ExtensionIdentifier.String() + } + executeAndValidateAction(a, client, t) + case ExtensionsNextParallelAction: + if a.AgentIdentifier == "" { + a.AgentIdentifier = e.ExtensionIdentifier.String() + } + a.Environment = e + executeAndValidateAction(a, client, t) + case ExtensionsInitErrorAction: + if a.AgentIdentifier == "" { + a.AgentIdentifier = e.ExtensionIdentifier.String() + } + executeAndValidateAction(a, client, t) + case ExtensionsExitErrorAction: + if a.AgentIdentifier == "" { + a.AgentIdentifier = e.ExtensionIdentifier.String() + } + executeAndValidateAction(a, client, t) + case ExtensionsTelemetryAPIHTTPSubscriberAction: + if a.Subscription.AgentIdentifier == "" { + a.Subscription.AgentIdentifier = e.ExtensionIdentifier.String() + } + executeAndValidateAction(a, client, t) + case ExtensionsTelemetryAPITCPSubscriberAction: + if a.Subscription.AgentIdentifier == "" { + a.Subscription.AgentIdentifier = e.ExtensionIdentifier.String() + } + executeAndValidateAction(a, client, t) + case ExtensionTelemetrySubscribeAction: + if a.AgentIdentifier == "" { + a.AgentIdentifier = e.ExtensionIdentifier.String() + } + executeAndValidateAction(a, client, t) + case ExitAction: + a.exitProcessOnce = e.ExitProcessOnce + executeAndValidateAction(a, client, t) + default: + executeAndValidateAction(a, client, t) + } + } +} + +func executeAndValidateAction(action ExecutionEnvironmentAction, client *Client, t *testing.T) { + resp, err := action.Execute(t, client) + require.NoError(t, err, "Action %s failed: %v", action.String(), err) + action.ValidateStatus(t, resp) +} + +func MakeMockFileUtil(extensions ExtensionsEnv) *utils.MockFileUtil { + return SetupMockFileUtil(extensions, true, true) +} + +func SetupMockFileUtil(extensions ExtensionsEnv, hasBootstrap bool, hasVarTask bool) *utils.MockFileUtil { + mockFileUtil := &utils.MockFileUtil{} + + var extensionEntries []os.DirEntry + for extensionId := range extensions { + extensionEntries = append(extensionEntries, utils.NewMockDirEntry(extensionId, false)) + } + + mockFileUtil.On("ReadDirectory", mock.MatchedBy(func(path string) bool { + return strings.Contains(path, "/opt/extensions") + })).Return(extensionEntries, nil) + + mockFileUtil.On("ReadDirectory", mock.Anything).Return(nil, nil) + + if hasBootstrap { + + bootstrapFileInfo := &utils.MockFileInfo{} + bootstrapFileInfo.On("IsDir").Return(false) + + mockFileUtil.On("Stat", "/var/runtime/bootstrap").Return(bootstrapFileInfo, nil) + } else { + + mockFileUtil.On("Stat", mock.MatchedBy(func(path string) bool { + return path == "/var/runtime/bootstrap" || path == "/var/task/bootstrap" || path == "/opt/bootstrap" + })).Return(nil, os.ErrNotExist) + } + + if hasVarTask { + mockFileUtil.On("Stat", "/var/task").Return(nil, nil) + } else { + mockFileUtil.On("Stat", "/var/task").Return(nil, os.ErrNotExist) + } + + mockFileUtil.On("Stat", mock.Anything).Return(nil, nil) + mockFileUtil.On("IsNotExist", mock.Anything).Return(os.IsNotExist) + return mockFileUtil +} diff --git a/internal/lambda-managed-instances/testutils/functional/runtime_actions.go b/internal/lambda-managed-instances/testutils/functional/runtime_actions.go new file mode 100644 index 0000000..601c027 --- /dev/null +++ b/internal/lambda-managed-instances/testutils/functional/runtime_actions.go @@ -0,0 +1,238 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//go:build test + +package functional + +import ( + "fmt" + "io" + "net/http" + "net/http/httputil" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" +) + +type WaitInvokeResponseAction struct { + wg *sync.WaitGroup +} + +func (a WaitInvokeResponseAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + a.wg.Wait() + return nil, nil +} + +func (a WaitInvokeResponseAction) ValidateStatus(t *testing.T, resp *http.Response) {} + +func (a WaitInvokeResponseAction) String() string { + return "WaitInvokeResponseAction" +} + +type SleepAction struct { + Duration time.Duration +} + +func (a SleepAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + if a.Duration == 0 { + a.Duration = 100 * time.Millisecond + } + t.Logf("runtime sleeping for %v\n", a.Duration) + time.Sleep(a.Duration) + return nil, nil +} + +func (a SleepAction) ValidateStatus(t *testing.T, resp *http.Response) {} + +func (a SleepAction) String() string { + return fmt.Sprintf("Sleep(duration=%v)", a.Duration) +} + +type NextAction struct { + Payload string + ExpectedStatus int +} + +func (a NextAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + var body io.Reader + if a.Payload != "" { + body = strings.NewReader(a.Payload) + } + resp := client.Next(body) + return resp, nil +} + +func (a NextAction) ValidateStatus(t *testing.T, resp *http.Response) { + if a.ExpectedStatus != 0 { + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, a.ExpectedStatus, resp.StatusCode, "response", string(body)) + } +} + +func (a NextAction) String() string { + return "Next()" +} + +type StdoutAction struct { + Payload string + stdout io.Writer +} + +func (a StdoutAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + _, err := a.stdout.Write([]byte(a.Payload)) + require.NoError(t, err) + return nil, err +} + +func (a StdoutAction) ValidateStatus(t *testing.T, resp *http.Response) {} + +func (a StdoutAction) String() string { + return "Stdout()" +} + +type StderrAction struct { + Payload string + stderr io.Writer +} + +func (a StderrAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + _, err := a.stderr.Write([]byte(a.Payload)) + require.NoError(t, err) + return nil, err +} + +func (a StderrAction) ValidateStatus(t *testing.T, resp *http.Response) {} + +func (a StderrAction) String() string { + return "Stderr()" +} + +type InitErrorAction struct { + Payload string + ContentType string + ErrorType string + ExpectedStatus int +} + +func (a InitErrorAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + return client.InitError(a.Payload, a.ContentType, a.ErrorType) +} + +func (a InitErrorAction) ValidateStatus(t *testing.T, resp *http.Response) { + if a.ExpectedStatus != 0 { + assert.Equal(t, a.ExpectedStatus, resp.StatusCode, "InitErrorAction expected status code %d", a.ExpectedStatus) + } +} + +func (a InitErrorAction) String() string { + return fmt.Sprintf("InitError(type=%s)", a.ErrorType) +} + +type ExitAction struct { + ExitCode int32 + exitProcessOnce func() + CrashingProcessName string +} + +func (a ExitAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + a.exitProcessOnce() + + return nil, nil +} + +func (a ExitAction) ValidateStatus(t *testing.T, resp *http.Response) {} + +func (a ExitAction) String() string { + return fmt.Sprintf("Exit(code=%d)", a.ExitCode) +} + +type InvocationResponseAction struct { + Payload io.Reader + ContentType string + InvokeID interop.InvokeID + ResponseModeHeader string + ExpectedStatus int + ExpectedBody string + Trailers map[string]string +} + +func (a InvocationResponseAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + return client.Response(a.InvokeID, a.Payload, a.ContentType, a.ResponseModeHeader, a.Trailers) +} + +func (a InvocationResponseAction) ValidateStatus(t *testing.T, resp *http.Response) { + + if a.ExpectedStatus != 0 && resp != nil { + dump, err := httputil.DumpResponse(resp, true) + require.NoError(t, err) + assert.Equal(t, a.ExpectedStatus, resp.StatusCode, "Expected status %d for RequestId=%s, got %s", resp.StatusCode, a.InvokeID, string(dump)) + } + if a.ExpectedBody != "" { + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equalf(t, a.ExpectedBody, string(body), "invokeID=%s", a.InvokeID) + } +} + +func (a InvocationResponseAction) String() string { + return fmt.Sprintf("InvocationResponse(%s)", a.InvokeID) +} + +type InvocationResponseErrorAction struct { + Payload string + ContentType string + InvokeID interop.InvokeID + ErrorType string + ErrorCause string + ExpectedStatus int + ExpectedBody string +} + +func (a InvocationResponseErrorAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + return client.ResponseError(a.InvokeID, a.Payload, a.ContentType, a.ErrorType, a.ErrorCause) +} + +func (a InvocationResponseErrorAction) ValidateStatus(t *testing.T, resp *http.Response) { + if a.ExpectedStatus != 0 { + assert.Equal(t, a.ExpectedStatus, resp.StatusCode, "InvocationResponseErrorAction expected status %d", a.ExpectedStatus) + } + if a.ExpectedBody != "" { + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equalf(t, a.ExpectedBody, string(body), "invokeID=%s", a.InvokeID) + } +} + +func (a InvocationResponseErrorAction) String() string { + return "InvocationResponse()" +} + +type InvocationStreamingResponseAction struct { + Chunks []string + ContentType string + InvokeID interop.InvokeID + ResponseModeHeader string + ChunkDelay time.Duration + Trailers map[string]string +} + +func (a InvocationStreamingResponseAction) Execute(t *testing.T, client *Client) (*http.Response, error) { + + chunkedReader := NewChunkedReader(a.Chunks, a.ChunkDelay) + + return client.Response(a.InvokeID, chunkedReader, a.ContentType, a.ResponseModeHeader, a.Trailers) +} + +func (a InvocationStreamingResponseAction) ValidateStatus(t *testing.T, resp *http.Response) {} + +func (a InvocationStreamingResponseAction) String() string { + return fmt.Sprintf("InvocationStreamingResponse(%s, %d chunks)", a.InvokeID, len(a.Chunks)) +} diff --git a/internal/lambda-managed-instances/testutils/functional/runtime_client.go b/internal/lambda-managed-instances/testutils/functional/runtime_client.go new file mode 100644 index 0000000..5963d3d --- /dev/null +++ b/internal/lambda-managed-instances/testutils/functional/runtime_client.go @@ -0,0 +1,149 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//go:build test + +package functional + +import ( + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/netip" + "os" + "strings" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/interop" +) + +const ( + protocol = "http" + apiversion = "2018-06-01" +) + +const ( + ContentTypeHeader = "Content-Type" + LambdaInvocationIDHeader = "Lambda-Runtime-Aws-Request-Id" + LambdaInvocationDeadlineHeader = "Lambda-Runtime-Deadline-Ms" + LambdaErrorTypeHeader = "Lambda-Runtime-Function-Error-Type" + LambdaErrorBodyHeader = "Lambda-Runtime-Function-Error-Body" + LambdaResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" + LambdaXRayErrorCauseHeader = "Lambda-Runtime-Function-XRay-Error-Cause" +) + +type Client struct { + baseurl string + client http.Client +} + +func NewClient(endpoint netip.AddrPort) *Client { + return &Client{ + baseurl: fmt.Sprintf("%s://%s/%s", protocol, endpoint, apiversion), + client: http.Client{}, + } +} + +func (client *Client) Next(body io.Reader) *http.Response { + slog.Debug("Runtime Client calling Next", "baseurl", client.baseurl) + + url := fmt.Sprintf("%s/runtime/invocation/next", client.baseurl) + req, err := http.NewRequest(http.MethodGet, url, body) + if err != nil { + slog.Error("could not create request", "url", url, "error", err) + panic(fmt.Sprintf("could not create request for %s: %s", url, err)) + } + + headers := make(map[string]string) + headersJSON := os.Getenv("INVOCATION_NEXT_REQUEST_HEADERS") + if headersJSON != "" { + if err := json.Unmarshal([]byte(headersJSON), &headers); err != nil { + slog.Error("failed to unmarshal headers", "error", err) + panic(err) + } + } + for k, v := range headers { + req.Header.Set(k, v) + } + + resp, err := client.client.Do(req) + if err != nil { + slog.Error("Unable to call URL", "url", url, "error", err) + + return nil + + } + + return resp +} + +func (client *Client) Response(invokeID interop.InvokeID, payload io.Reader, contentType string, responseModeHeader string, trailers map[string]string) (*http.Response, error) { + url := fmt.Sprintf("%s/runtime/invocation/%s/response", client.baseurl, invokeID) + headers := map[string]string{ContentTypeHeader: contentType} + + if responseModeHeader != "" { + headers[LambdaResponseModeHeader] = responseModeHeader + } + return client.postBufferedResponse(url, payload, headers, trailers) +} + +func (client *Client) ResponseError(invokeID interop.InvokeID, payload string, contentType string, errorType string, errorCause string) (*http.Response, error) { + headers := map[string]string{ContentTypeHeader: contentType} + if len(errorType) > 0 { + headers[LambdaErrorTypeHeader] = errorType + } + if len(errorCause) > 0 { + headers[LambdaXRayErrorCauseHeader] = errorCause + } + return client.ResponseErrorWithHeaders(invokeID, payload, headers) +} + +func (client *Client) ResponseErrorWithHeaders(invokeID interop.InvokeID, payload string, headers map[string]string) (*http.Response, error) { + url := fmt.Sprintf("%s/runtime/invocation/%s/error", client.baseurl, invokeID) + return client.postBufferedResponse(url, strings.NewReader(payload), headers, nil) +} + +func (client *Client) InitError(payload string, contentType string, errorType string) (*http.Response, error) { + url := fmt.Sprintf("%s/runtime/init/error", client.baseurl) + headers := map[string]string{ContentTypeHeader: contentType} + if len(errorType) > 0 { + headers[LambdaErrorTypeHeader] = errorType + } + return client.postBufferedResponse(url, strings.NewReader(payload), headers, nil) +} + +func (client *Client) postBufferedResponse(url string, payload io.Reader, headers, trailers map[string]string) (*http.Response, error) { + req, err := http.NewRequest("POST", url, payload) + if err != nil { + slog.Error("Unable to create request", "url", url, "error", err) + panic(err) + } + for key, value := range headers { + req.Header.Set(key, value) + } + + if len(trailers) > 0 { + req.ContentLength = -1 + req.Trailer = make(http.Header) + req.TransferEncoding = []string{"chunked"} + for key, value := range trailers { + req.Header.Add("Trailer", key) + req.Trailer.Set(key, value) + } + } + + return client.client.Do(req) +} + +func (client *Client) postResponse(req *http.Request) *http.Response { + resp, err := client.client.Do(req) + + if err != nil { + + slog.Error("Got postResponse error", "err", err) + + } + + return resp +} diff --git a/internal/lambda-managed-instances/testutils/functional/supv.go b/internal/lambda-managed-instances/testutils/functional/supv.go new file mode 100644 index 0000000..70b16d8 --- /dev/null +++ b/internal/lambda-managed-instances/testutils/functional/supv.go @@ -0,0 +1,88 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//go:build test + +package functional + +import ( + "context" + "testing" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/supervisor/model" +) + +type process interface { + Exec(request *model.ExecRequest) (doneCh <-chan struct{}, err error) + Terminate() error + Kill() error +} + +type MockSupervisor struct { + t *testing.T + ps map[string]process + eventsCh chan model.Event + eventsErr error +} + +func NewMockSupervisor(t *testing.T, runtime process, extensions ExtensionsEnv, eventsErr error) *MockSupervisor { + ps := map[string]process{ + "runtime": runtime, + } + for name, proc := range extensions { + ps["extension-"+name] = proc + } + return &MockSupervisor{ + t: t, + ps: ps, + eventsCh: make(chan model.Event), + eventsErr: eventsErr, + } +} + +func (m *MockSupervisor) Exec(_ context.Context, request *model.ExecRequest) error { + proc, ok := m.ps[request.Name] + if !ok { + m.t.Fatalf("unknown process name: %s", request.Name) + } + doneCh, err := proc.Exec(request) + if err != nil { + return err + } + go func() { + <-doneCh + var zero int32 + event := model.Event{ + Time: time.Now().UnixMilli(), + Event: model.EventData{ + EvType: model.ProcessTerminationType, + Name: request.Name, + Cause: model.Exited, + ExitStatus: &zero, + }, + } + m.eventsCh <- event + }() + return nil +} + +func (m *MockSupervisor) Terminate(_ context.Context, request *model.TerminateRequest) error { + proc, ok := m.ps[request.Name] + if !ok { + m.t.Fatalf("unknown process name: %s", request.Name) + } + return proc.Terminate() +} + +func (m *MockSupervisor) Kill(_ context.Context, request *model.KillRequest) error { + proc, ok := m.ps[request.Name] + if !ok { + m.t.Fatalf("unknown process name: %s", request.Name) + } + return proc.Kill() +} + +func (m *MockSupervisor) Events(_ context.Context) (<-chan model.Event, error) { + return m.eventsCh, m.eventsErr +} diff --git a/internal/lambda-managed-instances/testutils/mocks/http_handler_mock.go b/internal/lambda-managed-instances/testutils/mocks/http_handler_mock.go new file mode 100644 index 0000000..7ac33fb --- /dev/null +++ b/internal/lambda-managed-instances/testutils/mocks/http_handler_mock.go @@ -0,0 +1,25 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + "net/http" + + "github.com/stretchr/testify/mock" +) + +type MockHTTPHandler struct { + mock.Mock +} + +func (m *MockHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + m.Called(w, r) + w.WriteHeader(http.StatusOK) +} + +func NewNoOpHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) +} diff --git a/internal/lambda-managed-instances/testutils/mocks/http_mock.go b/internal/lambda-managed-instances/testutils/mocks/http_mock.go new file mode 100644 index 0000000..54afe73 --- /dev/null +++ b/internal/lambda-managed-instances/testutils/mocks/http_mock.go @@ -0,0 +1,56 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + "fmt" + "io" + "net/http" + "time" +) + +type ReaderMock struct { + PayloadSize int + ReadTotal int + WaitBeforeRead time.Duration +} + +func (r *ReaderMock) Read(p []byte) (int, error) { + if r.WaitBeforeRead != 0 { + time.Sleep(r.WaitBeforeRead) + } + + if r.ReadTotal >= r.PayloadSize { + return 0, io.EOF + } + + haveRead := min(len(p), r.PayloadSize-r.ReadTotal) + r.ReadTotal += haveRead + + return haveRead, nil +} + +type ReaderFailureMock struct{} + +func (r *ReaderFailureMock) Read(p []byte) (int, error) { + return 0, fmt.Errorf("can't read") +} + +type WriterFailureMock struct{} + +func (w *WriterFailureMock) Write(p []byte) (n int, err error) { + return 0, fmt.Errorf("can't write") +} + +type ResponseWriterMock struct { + io.Writer +} + +func (w *ResponseWriterMock) Header() http.Header { + return map[string][]string{} +} + +func (w *ResponseWriterMock) WriteHeader(int) {} + +func (w *ResponseWriterMock) Flush() {} diff --git a/internal/lambda-managed-instances/testutils/socket_utils.go b/internal/lambda-managed-instances/testutils/socket_utils.go new file mode 100644 index 0000000..b5fa7be --- /dev/null +++ b/internal/lambda-managed-instances/testutils/socket_utils.go @@ -0,0 +1,49 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package testutils + +import ( + "context" + "fmt" + "net" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/google/uuid" +) + +func CreateTempSocketPath(t *testing.T) (string, error) { + socketPath := filepath.Join(os.TempDir(), fmt.Sprintf("%s.sock", uuid.New().String())) + + t.Cleanup(func() { + if err := os.Remove(socketPath); err != nil { + t.Logf("could not cleanup unix socket file %s: %s", socketPath, err) + } + }) + + if len(socketPath) > 104 { + return "", fmt.Errorf("socket path is too long: %s", socketPath) + } + + return socketPath, nil +} + +func NewUnixSocketClient(socketPath string) *http.Client { + dialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", socketPath) + } + + transport := &http.Transport{ + DialContext: dialer, + DisableCompression: true, + ResponseHeaderTimeout: 5 * time.Second, + } + + return &http.Client{ + Transport: transport, + } +} diff --git a/internal/lambda-managed-instances/testutils/socket_utils_test.go b/internal/lambda-managed-instances/testutils/socket_utils_test.go new file mode 100644 index 0000000..c6823a6 --- /dev/null +++ b/internal/lambda-managed-instances/testutils/socket_utils_test.go @@ -0,0 +1,28 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package testutils + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCreateTempSocketPath(t *testing.T) { + + socketPath, err := CreateTempSocketPath(t) + require.NoError(t, err) + + if !strings.Contains(socketPath, ".sock") { + t.Errorf("Expected socket path to contain '.sock', got: %s", socketPath) + } + + parentDir := filepath.Dir(socketPath) + if _, err := os.Stat(parentDir); os.IsNotExist(err) { + t.Errorf("Expected parent directory %s to exist", parentDir) + } +} diff --git a/internal/lambda-managed-instances/testutils/test_data.go b/internal/lambda-managed-instances/testutils/test_data.go new file mode 100644 index 0000000..7a5c42c --- /dev/null +++ b/internal/lambda-managed-instances/testutils/test_data.go @@ -0,0 +1,133 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package testutils + +import ( + "encoding/json" + "fmt" + "net/netip" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/model" +) + +var ( + DefaultTestFunctionARN = "arn:aws:lambda:us-east-1:123456789012:function:test_function" + DefaultTestFunctionVersion = "$LATEST" +) + +func MakeValidInitPayload(opts ...InitPayloadOption) string { + return JsonEncode(MakeInitPayload(opts...)) +} + +func MakeInvalidInitPayload() string { + return JsonEncode(MakeInitPayload(WithInvalidPayload())) +} + +type InitPayloadOption func(*model.InitRequestMessage) + +func MakeInitPayload(opts ...InitPayloadOption) model.InitRequestMessage { + + payload := model.InitRequestMessage{ + AccountID: "123456789012", + AwsKey: "AKIAIOSFODNN7EXAMPLE", + AwsSecret: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + AwsSession: "FwoGZXIvYXdzEMj//////////wEaDM1Qz0oN8BNwV9GqyyLVAebxhwq9ZGqojXZe1UTJkzK6F9V+VZHhT5JSWYzJUKEwOqOkQyQXJpfJsYHfkJEXtR6Kh9mXnEbqKi", + AwsRegion: "us-west-2", + EnvVars: map[string]string{ + "CUSTOMER_ENV_VAR_1": "customer_env_value_1", + }, + ArtefactType: model.ArtefactTypeZIP, + MemorySizeBytes: 3008 * 1024 * 1024, + FunctionARN: DefaultTestFunctionARN, + FunctionVersion: DefaultTestFunctionVersion, + FunctionVersionID: "test-function-version-id", + TaskName: "test_function", + InvokeTimeout: model.DurationMS(3 * time.Second), + InitTimeout: model.DurationMS(10 * time.Second), + RuntimeWorkerCount: 1, + LogFormat: "json", + LogLevel: "info", + LogGroupName: "/aws/lambda/test_function", + LogStreamName: "$LATEST", + TelemetryAPIAddress: model.TelemetryAddr(netip.MustParseAddrPort("1.1.1.1:1234")), + TelemetryPassphrase: "hello", + XRayDaemonAddress: "2.2.2.2:2345", + XrayTracingMode: model.XRayTracingModeActive, + RuntimeBinaryCommand: []string{"cmd", "arg1", "arg2"}, + CurrentWorkingDir: "/", + AmiId: "ami-12345", + AvailabilityZoneId: "az-1", + Handler: "lambda_function.lambda_handler", + } + + for _, opt := range opts { + opt(&payload) + } + + return payload +} + +func JsonEncode(payload model.InitRequestMessage) string { + jsonBytes, err := json.MarshalIndent(payload, "", " ") + if err != nil { + panic(fmt.Sprintf("Failed to marshal init payload: %v", err)) + } + + return string(jsonBytes) +} + +func WithInvalidPayload() InitPayloadOption { + return func(p *model.InitRequestMessage) { + p.AwsKey = "AKIAIOSFODNN7EXAMPLE" + + p.RuntimeBinaryCommand = nil + } +} + +func WithTimeouts(invokeTimeout, initTimeout time.Duration) InitPayloadOption { + return func(p *model.InitRequestMessage) { + p.InvokeTimeout = model.DurationMS(invokeTimeout) + p.InitTimeout = model.DurationMS(initTimeout) + } +} + +func WithLogFormat(format string) InitPayloadOption { + return func(p *model.InitRequestMessage) { + p.LogFormat = format + } +} + +func WithLogLevel(level string) InitPayloadOption { + return func(p *model.InitRequestMessage) { + p.LogLevel = level + } +} + +func WithArtefactType(typ model.ArtefactType) InitPayloadOption { + return func(p *model.InitRequestMessage) { + p.ArtefactType = typ + } +} + +func WithTelemetry(apiAddress netip.AddrPort, passphrase string) InitPayloadOption { + return func(p *model.InitRequestMessage) { + p.TelemetryAPIAddress = model.TelemetryAddr(apiAddress) + p.TelemetryPassphrase = passphrase + } +} + +func WithExecValues(artefact model.ArtefactType, cmd []string, cwd string) InitPayloadOption { + return func(p *model.InitRequestMessage) { + p.ArtefactType = artefact + p.RuntimeBinaryCommand = cmd + p.CurrentWorkingDir = cwd + } +} + +func WithEnvVars(envVars model.KVMap) InitPayloadOption { + return func(p *model.InitRequestMessage) { + p.EnvVars = envVars + } +} diff --git a/internal/lambda-managed-instances/utils/buffer_pool.go b/internal/lambda-managed-instances/utils/buffer_pool.go new file mode 100644 index 0000000..1d25e05 --- /dev/null +++ b/internal/lambda-managed-instances/utils/buffer_pool.go @@ -0,0 +1,22 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package utils + +import ( + "io" + "sync" +) + +var bufferPool = sync.Pool{ + New: func() interface{} { + buf := make([]byte, 32*1024) + return &buf + }, +} + +func CopyWithPool(dst io.Writer, src io.Reader) (written int64, err error) { + bufPtr := bufferPool.Get().(*[]byte) + defer bufferPool.Put(bufPtr) + return io.CopyBuffer(dst, src, *bufPtr) +} diff --git a/internal/lambda-managed-instances/utils/buffer_pool_test.go b/internal/lambda-managed-instances/utils/buffer_pool_test.go new file mode 100644 index 0000000..9c92b19 --- /dev/null +++ b/internal/lambda-managed-instances/utils/buffer_pool_test.go @@ -0,0 +1,44 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package utils + +import ( + "bytes" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCopyWithPool_BasicFunctionality(t *testing.T) { + + testCases := []struct { + name string + data string + }{ + {"empty", ""}, + {"small", "hello world"}, + {"large", strings.Repeat("large test data ", 10000)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for range 10 { + go func() { + for range 100 { + src := strings.NewReader(tc.data) + dst := &bytes.Buffer{} + + written, err := CopyWithPool(dst, src) + + require.NoError(t, err) + assert.Equal(t, int64(len(tc.data)), written) + assert.Equal(t, tc.data, dst.String()) + } + }() + } + }) + } +} diff --git a/internal/lambda-managed-instances/utils/file_utils.go b/internal/lambda-managed-instances/utils/file_utils.go new file mode 100644 index 0000000..da55de2 --- /dev/null +++ b/internal/lambda-managed-instances/utils/file_utils.go @@ -0,0 +1,49 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package utils + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" +) + +type FileUtil interface { + ReadDirectory(dirPath string) ([]os.DirEntry, error) + Stat(name string) (os.FileInfo, error) + IsNotExist(err error) bool +} + +func NewFileUtil() FileUtil { + return &fileUtil{} +} + +type fileUtil struct{} + +func (l *fileUtil) ReadDirectory(dirPath string) ([]os.DirEntry, error) { + return os.ReadDir(dirPath) +} + +func (l *fileUtil) Stat(name string) (os.FileInfo, error) { + return os.Stat(name) +} + +func (l *fileUtil) IsNotExist(err error) bool { + return os.IsNotExist(err) +} + +func FixTmpDir(dir string, uid, gid int) error { + if err := os.Chown(dir, uid, gid); err != nil { + return fmt.Errorf("could not chown %s folder: %w", dir, err) + } + if err := os.Chmod(dir, fs.ModeDir|fs.ModeSticky|fs.ModePerm); err != nil { + return fmt.Errorf("could not chmod %s folder: %w", dir, err) + } + if err := os.RemoveAll(filepath.Join(dir, "lost+found")); err != nil { + return fmt.Errorf("could not remove %s/lost+found folder: %w", dir, err) + } + + return nil +} diff --git a/internal/lambda-managed-instances/utils/file_utils_test.go b/internal/lambda-managed-instances/utils/file_utils_test.go new file mode 100644 index 0000000..b6dc5e7 --- /dev/null +++ b/internal/lambda-managed-instances/utils/file_utils_test.go @@ -0,0 +1,39 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package utils + +import ( + "io/fs" + "os" + "path/filepath" + "syscall" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCreateStickyWorldWritableDir(t *testing.T) { + tempDir := t.TempDir() + + lostFoundDir := filepath.Join(tempDir, "lost+found") + require.NoError(t, os.Mkdir(lostFoundDir, 0o755)) + + uid := os.Getuid() + gid := os.Getgid() + + require.NoError(t, FixTmpDir(tempDir, uid, gid)) + + stat, err := os.Stat(tempDir) + require.NoError(t, err) + + require.Equal(t, fs.ModeDir|fs.ModeSticky|fs.ModePerm, stat.Mode()) + + if sysStat, ok := stat.Sys().(*syscall.Stat_t); ok { + require.Equal(t, uint32(uid), sysStat.Uid) + require.Equal(t, uint32(gid), sysStat.Gid) + } + + _, err = os.Stat(lostFoundDir) + require.True(t, os.IsNotExist(err)) +} diff --git a/internal/lambda-managed-instances/utils/invariant/invariant.go b/internal/lambda-managed-instances/utils/invariant/invariant.go new file mode 100644 index 0000000..f4ca149 --- /dev/null +++ b/internal/lambda-managed-instances/utils/invariant/invariant.go @@ -0,0 +1,46 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invariant + +import ( + "fmt" + "sync" +) + +func Check(cond bool, statement string) { + if !cond { + Violate(statement) + } +} + +func Checkf(cond bool, format string, args ...any) { + if !cond { + Violatef(format, args...) + } +} + +func Violate(statement string) { + std.mtx.Lock() + defer std.mtx.Unlock() + + std.executor.Exec(ViolationError{Statement: statement}) +} + +func Violatef(format string, args ...any) { + Violate(fmt.Sprintf(format, args...)) +} + +func SetViolationExecutor(exector ViolationExecutor) { + std.mtx.Lock() + defer std.mtx.Unlock() + + std.executor = exector +} + +var std = struct { + executor ViolationExecutor + mtx sync.Mutex +}{ + executor: NewPanicViolationExecuter(), +} diff --git a/internal/lambda-managed-instances/utils/invariant/invariant_test.go b/internal/lambda-managed-instances/utils/invariant/invariant_test.go new file mode 100644 index 0000000..b7074e0 --- /dev/null +++ b/internal/lambda-managed-instances/utils/invariant/invariant_test.go @@ -0,0 +1,76 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invariant + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestInvariantViolationError(t *testing.T) { + msg := "oops violation message" + err := ViolationError{Statement: msg} + assert.ErrorContains(t, err, msg) +} + +func TestGlobalChecksDoNothingOnOk(t *testing.T) { + defer SetViolationExecutor(std.executor) + + m := &mockViolationExecutor{} + SetViolationExecutor(m) + + Check(true, "oops") + Checkf(true, "oops with arg %v", 42) + + assert.Empty(t, m.Calls) + m.AssertExpectations(t) +} + +func TestGlobalViolationsUseProvidedImplementation(t *testing.T) { + defer SetViolationExecutor(std.executor) + + m := &mockViolationExecutor{} + SetViolationExecutor(m) + + { + m.On("Exec", ViolationError{Statement: "oops check"}).Once() + Check(false, "oops check") + m.AssertExpectations(t) + } + { + m.On("Exec", ViolationError{Statement: "oops check with arg 42"}).Once() + Checkf(false, "oops check with arg %v", 42) + m.AssertExpectations(t) + } + { + m.On("Exec", ViolationError{Statement: "oops violate"}).Once() + Violate("oops violate") + m.AssertExpectations(t) + } + { + m.On("Exec", ViolationError{Statement: "oops violate with arg 42"}).Once() + Violatef("oops violate with arg %v", 42) + m.AssertExpectations(t) + } +} + +func TestGlobalImplementationIsOfExpectedType(t *testing.T) { + var valueOfExpectedType *PanicViolationExecutor + assert.IsType(t, valueOfExpectedType, std.executor) + + violator := NewPanicViolationExecuter() + assert.IsType(t, valueOfExpectedType, violator) +} + +type mockViolationExecutor struct { + mock.Mock +} + +var _ ViolationExecutor = (*mockViolationExecutor)(nil) + +func (m *mockViolationExecutor) Exec(err ViolationError) { + m.Called(err) +} diff --git a/internal/lambda-managed-instances/utils/invariant/mock_violation_executor.go b/internal/lambda-managed-instances/utils/invariant/mock_violation_executor.go new file mode 100644 index 0000000..9ff2818 --- /dev/null +++ b/internal/lambda-managed-instances/utils/invariant/mock_violation_executor.go @@ -0,0 +1,26 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invariant + +import mock "github.com/stretchr/testify/mock" + +type MockViolationExecutor struct { + mock.Mock +} + +func (_m *MockViolationExecutor) Exec(_a0 ViolationError) { + _m.Called(_a0) +} + +func NewMockViolationExecutor(t interface { + mock.TestingT + Cleanup(func()) +}) *MockViolationExecutor { + mock := &MockViolationExecutor{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/utils/invariant/model.go b/internal/lambda-managed-instances/utils/invariant/model.go new file mode 100644 index 0000000..ad40bcf --- /dev/null +++ b/internal/lambda-managed-instances/utils/invariant/model.go @@ -0,0 +1,16 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invariant + +type ViolationError struct { + Statement string +} + +func (err ViolationError) Error() string { + return "Invariant violation: " + err.Statement +} + +type ViolationExecutor interface { + Exec(ViolationError) +} diff --git a/internal/lambda-managed-instances/utils/invariant/panic.go b/internal/lambda-managed-instances/utils/invariant/panic.go new file mode 100644 index 0000000..02d7d5f --- /dev/null +++ b/internal/lambda-managed-instances/utils/invariant/panic.go @@ -0,0 +1,16 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invariant + +type PanicViolationExecutor struct{} + +var _ ViolationExecutor = (*PanicViolationExecutor)(nil) + +func NewPanicViolationExecuter() *PanicViolationExecutor { + return &PanicViolationExecutor{} +} + +func (executor *PanicViolationExecutor) Exec(err ViolationError) { + panic(err) +} diff --git a/internal/lambda-managed-instances/utils/invariant/panic_test.go b/internal/lambda-managed-instances/utils/invariant/panic_test.go new file mode 100644 index 0000000..6eaf0be --- /dev/null +++ b/internal/lambda-managed-instances/utils/invariant/panic_test.go @@ -0,0 +1,20 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invariant + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPanicExecutor(t *testing.T) { + executor := NewPanicViolationExecuter() + + for i := 0; i < 2; i++ { + err := ViolationError{Statement: fmt.Sprintf("oops %v", i)} + assert.PanicsWithValue(t, err, func() { executor.Exec(err) }) + } +} diff --git a/internal/lambda-managed-instances/utils/io.go b/internal/lambda-managed-instances/utils/io.go new file mode 100644 index 0000000..f2260e9 --- /dev/null +++ b/internal/lambda-managed-instances/utils/io.go @@ -0,0 +1,58 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package utils + +import ( + "context" + "io" + "time" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/logging" +) + +type TimedReader struct { + Reader io.Reader + Name string + Ctx context.Context + TotalDuration time.Duration +} + +func (tr *TimedReader) Read(p []byte) (n int, err error) { + start := time.Now() + n, err = tr.Reader.Read(p) + duration := time.Since(start) + tr.TotalDuration += duration + + logging.Debug(tr.Ctx, "Read operation completed", + "name", tr.Name, + "bytes", n, + "duration", duration, + "totalReadTime", tr.TotalDuration, + "error", err) + + return n, err +} + +type TimedWriter struct { + Writer io.Writer + Name string + Ctx context.Context + TotalDuration time.Duration +} + +func (tw *TimedWriter) Write(p []byte) (n int, err error) { + start := time.Now() + n, err = tw.Writer.Write(p) + duration := time.Since(start) + tw.TotalDuration += duration + + logging.Debug(tw.Ctx, "Write operation completed", + "name", tw.Name, + "bytes", n, + "duration", duration, + "totalWriteDuration", tw.TotalDuration, + "error", err) + + return n, err +} diff --git a/internal/lambda-managed-instances/utils/io_test.go b/internal/lambda-managed-instances/utils/io_test.go new file mode 100644 index 0000000..8547c75 --- /dev/null +++ b/internal/lambda-managed-instances/utils/io_test.go @@ -0,0 +1,95 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package utils + +import ( + "context" + "io" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type mockReaderWithSleep struct { + reader io.Reader + sleepTime time.Duration +} + +func (m *mockReaderWithSleep) Read(p []byte) (n int, err error) { + time.Sleep(m.sleepTime) + return m.reader.Read(p) +} + +type mockWriterWithSleep struct { + writer io.Writer + sleepTime time.Duration +} + +func (m *mockWriterWithSleep) Write(p []byte) (n int, err error) { + time.Sleep(m.sleepTime) + return m.writer.Write(p) +} + +func TestTimedReader(t *testing.T) { + t.Parallel() + + testData := "test data for reading" + mockReader := &mockReaderWithSleep{ + reader: strings.NewReader(testData), + sleepTime: 1 * time.Nanosecond, + } + + timedReader := &TimedReader{ + Reader: mockReader, + Ctx: context.Background(), + } + + testStart := time.Now() + + buffer := make([]byte, len(testData)) + n, err := timedReader.Read(buffer) + + testDuration := time.Since(testStart) + + assert.NoError(t, err) + assert.Equal(t, len(testData), n) + assert.Equal(t, testData, string(buffer[:n])) + + assert.GreaterOrEqual(t, timedReader.TotalDuration, 1*time.Nanosecond, + "TotalTime should be greater than 1ns (the sleep duration)") + assert.LessOrEqual(t, timedReader.TotalDuration, testDuration, + "TotalTime should be less than total measured test duration") +} + +func TestTimedWriter(t *testing.T) { + t.Parallel() + + mockWriter := &mockWriterWithSleep{ + writer: io.Discard, + sleepTime: 1 * time.Nanosecond, + } + + ctx := context.Background() + timedWriter := &TimedWriter{ + Writer: mockWriter, + Ctx: ctx, + } + + testStart := time.Now() + + testData := []byte("test data for writing") + n, err := timedWriter.Write(testData) + + testDuration := time.Since(testStart) + + assert.NoError(t, err) + assert.Equal(t, len(testData), n) + + assert.GreaterOrEqual(t, timedWriter.TotalDuration, 1*time.Nanosecond, + "TotalTime should be greater than 1ns (the sleep duration)") + assert.LessOrEqual(t, timedWriter.TotalDuration, testDuration, + "TotalTime should be less than total measured test duration") +} diff --git a/internal/lambda-managed-instances/utils/mock_file_util.go b/internal/lambda-managed-instances/utils/mock_file_util.go new file mode 100644 index 0000000..6c5d171 --- /dev/null +++ b/internal/lambda-managed-instances/utils/mock_file_util.go @@ -0,0 +1,101 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package utils + +import ( + fs "io/fs" + + mock "github.com/stretchr/testify/mock" +) + +type MockFileUtil struct { + mock.Mock +} + +func (_m *MockFileUtil) IsNotExist(err error) bool { + ret := _m.Called(err) + + if len(ret) == 0 { + panic("no return value specified for IsNotExist") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(error) bool); ok { + r0 = rf(err) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +func (_m *MockFileUtil) ReadDirectory(dirPath string) ([]fs.DirEntry, error) { + ret := _m.Called(dirPath) + + if len(ret) == 0 { + panic("no return value specified for ReadDirectory") + } + + var r0 []fs.DirEntry + var r1 error + if rf, ok := ret.Get(0).(func(string) ([]fs.DirEntry, error)); ok { + return rf(dirPath) + } + if rf, ok := ret.Get(0).(func(string) []fs.DirEntry); ok { + r0 = rf(dirPath) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]fs.DirEntry) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(dirPath) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +func (_m *MockFileUtil) Stat(name string) (fs.FileInfo, error) { + ret := _m.Called(name) + + if len(ret) == 0 { + panic("no return value specified for Stat") + } + + var r0 fs.FileInfo + var r1 error + if rf, ok := ret.Get(0).(func(string) (fs.FileInfo, error)); ok { + return rf(name) + } + if rf, ok := ret.Get(0).(func(string) fs.FileInfo); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(fs.FileInfo) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +func NewMockFileUtil(t interface { + mock.TestingT + Cleanup(func()) +}) *MockFileUtil { + mock := &MockFileUtil{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/lambda-managed-instances/utils/mocks.go b/internal/lambda-managed-instances/utils/mocks.go new file mode 100644 index 0000000..7a85bc3 --- /dev/null +++ b/internal/lambda-managed-instances/utils/mocks.go @@ -0,0 +1,63 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package utils + +import ( + "os" + "time" + + "github.com/stretchr/testify/mock" +) + +type MockDirEntry struct { + name string + isDir bool +} + +func NewMockDirEntry(name string, isDir bool) MockDirEntry { + return MockDirEntry{name: name, isDir: isDir} +} + +func (m MockDirEntry) Name() string { return m.name } +func (m MockDirEntry) IsDir() bool { return m.isDir } +func (m MockDirEntry) Type() os.FileMode { return 0 } +func (m MockDirEntry) Info() (os.FileInfo, error) { return nil, nil } + +type MockFileInfo struct { + mock.Mock +} + +func NewMockFileInfo() *MockFileInfo { + return &MockFileInfo{} +} + +func (m *MockFileInfo) Name() string { + args := m.Called() + return args.String(0) +} + +func (m *MockFileInfo) Size() int64 { + args := m.Called() + return args.Get(0).(int64) +} + +func (m *MockFileInfo) Mode() os.FileMode { + args := m.Called() + return args.Get(0).(os.FileMode) +} + +func (m *MockFileInfo) ModTime() time.Time { + args := m.Called() + return args.Get(0).(time.Time) +} + +func (m *MockFileInfo) IsDir() bool { + args := m.Called() + return args.Bool(0) +} + +func (m *MockFileInfo) Sys() interface{} { + args := m.Called() + return args.Get(0) +} diff --git a/internal/lambda/appctx/appctxutil_test.go b/internal/lambda/appctx/appctxutil_test.go index bb78583..daa3bcc 100644 --- a/internal/lambda/appctx/appctxutil_test.go +++ b/internal/lambda/appctx/appctxutil_test.go @@ -8,9 +8,9 @@ import ( "strings" "testing" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/fatalerror" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/fatalerror" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" ) diff --git a/internal/lambda/core/directinvoke/directinvoke.go b/internal/lambda/core/directinvoke/directinvoke.go index e6f8304..6cd4f13 100644 --- a/internal/lambda/core/directinvoke/directinvoke.go +++ b/internal/lambda/core/directinvoke/directinvoke.go @@ -11,11 +11,11 @@ import ( "strconv" "strings" + "github.com/go-chi/chi/v5" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core/bandwidthlimiter" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/fatalerror" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/metering" - "github.com/go-chi/chi/v5" log "github.com/sirupsen/logrus" ) diff --git a/internal/lambda/core/directinvoke/directinvoke_test.go b/internal/lambda/core/directinvoke/directinvoke_test.go index 7ab6156..d15301a 100644 --- a/internal/lambda/core/directinvoke/directinvoke_test.go +++ b/internal/lambda/core/directinvoke/directinvoke_test.go @@ -17,12 +17,12 @@ import ( "testing" "time" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/fatalerror" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/metering" "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/fatalerror" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/metering" ) func NewResponseWriterWithoutFlushMethod() *ResponseWriterWithoutFlushMethod { diff --git a/internal/lambda/core/externalagent_states_test.go b/internal/lambda/core/externalagent_states_test.go index 37d1612..198b186 100644 --- a/internal/lambda/core/externalagent_states_test.go +++ b/internal/lambda/core/externalagent_states_test.go @@ -7,8 +7,8 @@ import ( "errors" "testing" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata/mockthread" "github.com/stretchr/testify/require" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata/mockthread" ) func TestExternalAgentStateUnknownEventType(t *testing.T) { diff --git a/internal/lambda/core/internalagent_states_test.go b/internal/lambda/core/internalagent_states_test.go index fa3e2c8..a259166 100644 --- a/internal/lambda/core/internalagent_states_test.go +++ b/internal/lambda/core/internalagent_states_test.go @@ -6,8 +6,8 @@ package core import ( "testing" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata/mockthread" "github.com/stretchr/testify/require" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata/mockthread" ) func TestInternalAgentStateUnknownEventType(t *testing.T) { diff --git a/internal/lambda/core/states_test.go b/internal/lambda/core/states_test.go index 1892e64..2e9f532 100644 --- a/internal/lambda/core/states_test.go +++ b/internal/lambda/core/states_test.go @@ -8,10 +8,10 @@ import ( "sync" "testing" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata/mockthread" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata/mockthread" ) func TestRuntimeInitErrorAfterReady(t *testing.T) { diff --git a/internal/lambda/interop/events_api_test.go b/internal/lambda/interop/events_api_test.go index 06ed5f0..af1396a 100644 --- a/internal/lambda/interop/events_api_test.go +++ b/internal/lambda/interop/events_api_test.go @@ -7,9 +7,9 @@ import ( "encoding/json" "testing" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/model" ) const requestID RequestID = "REQUEST_ID" diff --git a/internal/lambda/rapi/extensions_fuzz_test.go b/internal/lambda/rapi/extensions_fuzz_test.go index 1845a06..56b65e6 100644 --- a/internal/lambda/rapi/extensions_fuzz_test.go +++ b/internal/lambda/rapi/extensions_fuzz_test.go @@ -13,6 +13,8 @@ import ( "net/http/httptest" "testing" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/extensions" @@ -22,8 +24,6 @@ import ( "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/rendering" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/telemetry" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" ) func FuzzAgentRegisterHandler(f *testing.F) { diff --git a/internal/lambda/rapi/handler/agentexiterror.go b/internal/lambda/rapi/handler/agentexiterror.go index 3866a2f..022a03d 100644 --- a/internal/lambda/rapi/handler/agentexiterror.go +++ b/internal/lambda/rapi/handler/agentexiterror.go @@ -6,11 +6,11 @@ package handler import ( "net/http" + "github.com/google/uuid" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/fatalerror" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/rendering" - "github.com/google/uuid" log "github.com/sirupsen/logrus" ) diff --git a/internal/lambda/rapi/handler/agentiniterror.go b/internal/lambda/rapi/handler/agentiniterror.go index aa37ba6..0310e48 100644 --- a/internal/lambda/rapi/handler/agentiniterror.go +++ b/internal/lambda/rapi/handler/agentiniterror.go @@ -6,11 +6,11 @@ package handler import ( "net/http" + "github.com/google/uuid" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/fatalerror" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/rendering" - "github.com/google/uuid" log "github.com/sirupsen/logrus" ) diff --git a/internal/lambda/rapi/handler/agentiniterror_test.go b/internal/lambda/rapi/handler/agentiniterror_test.go index e91a758..09f5347 100644 --- a/internal/lambda/rapi/handler/agentiniterror_test.go +++ b/internal/lambda/rapi/handler/agentiniterror_test.go @@ -11,12 +11,12 @@ import ( "net/http/httptest" "testing" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/fatalerror" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/model" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" ) func newRequest(appCtx appctx.ApplicationContext, agentID uuid.UUID) *http.Request { diff --git a/internal/lambda/rapi/handler/agentnext.go b/internal/lambda/rapi/handler/agentnext.go index 73c69b3..15dc564 100644 --- a/internal/lambda/rapi/handler/agentnext.go +++ b/internal/lambda/rapi/handler/agentnext.go @@ -6,10 +6,10 @@ package handler import ( "net/http" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/rendering" "github.com/google/uuid" log "github.com/sirupsen/logrus" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/rendering" ) // A CtxKey type is used as a key for storing values in the request context. diff --git a/internal/lambda/rapi/handler/agentnext_test.go b/internal/lambda/rapi/handler/agentnext_test.go index 4be8291..d76cc8d 100644 --- a/internal/lambda/rapi/handler/agentnext_test.go +++ b/internal/lambda/rapi/handler/agentnext_test.go @@ -15,14 +15,14 @@ import ( "testing" "time" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/metering" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/model" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/rendering" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/telemetry" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" ) func TestRenderAgentInternalError(t *testing.T) { diff --git a/internal/lambda/rapi/handler/agentregister.go b/internal/lambda/rapi/handler/agentregister.go index 8b4e178..6e3c1f2 100644 --- a/internal/lambda/rapi/handler/agentregister.go +++ b/internal/lambda/rapi/handler/agentregister.go @@ -10,10 +10,10 @@ import ( "net/http" "strings" + log "github.com/sirupsen/logrus" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/model" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/rendering" - log "github.com/sirupsen/logrus" ) type agentRegisterHandler struct { diff --git a/internal/lambda/rapi/handler/agentregister_test.go b/internal/lambda/rapi/handler/agentregister_test.go index b2f02d4..0529ceb 100644 --- a/internal/lambda/rapi/handler/agentregister_test.go +++ b/internal/lambda/rapi/handler/agentregister_test.go @@ -13,9 +13,9 @@ import ( "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/model" - "github.com/stretchr/testify/require" ) func registerRequestReader(req RegisterRequest) io.Reader { diff --git a/internal/lambda/rapi/handler/credentials_test.go b/internal/lambda/rapi/handler/credentials_test.go index 29c6924..2f52537 100644 --- a/internal/lambda/rapi/handler/credentials_test.go +++ b/internal/lambda/rapi/handler/credentials_test.go @@ -11,9 +11,9 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata" - "github.com/stretchr/testify/assert" ) const InitCachingToken = "sampleInitCachingToken" diff --git a/internal/lambda/rapi/handler/initerror_test.go b/internal/lambda/rapi/handler/initerror_test.go index c465ae0..bd1936f 100644 --- a/internal/lambda/rapi/handler/initerror_test.go +++ b/internal/lambda/rapi/handler/initerror_test.go @@ -10,10 +10,10 @@ import ( "net/http/httptest" "testing" + "github.com/stretchr/testify/require" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/fatalerror" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata" - "github.com/stretchr/testify/require" ) // TestInitErrorHandler tests that API handler for diff --git a/internal/lambda/rapi/handler/invocationresponse_test.go b/internal/lambda/rapi/handler/invocationresponse_test.go index bc3d80b..9ce0a3d 100644 --- a/internal/lambda/rapi/handler/invocationresponse_test.go +++ b/internal/lambda/rapi/handler/invocationresponse_test.go @@ -15,11 +15,11 @@ import ( "testing" "github.com/aws/aws-lambda-go/events/test" + "github.com/stretchr/testify/assert" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/fatalerror" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata" - "github.com/stretchr/testify/assert" ) func TestResponseTooLarge(t *testing.T) { diff --git a/internal/lambda/rapi/handler/restoreerror.go b/internal/lambda/rapi/handler/restoreerror.go index e39d7c4..e4241a6 100644 --- a/internal/lambda/rapi/handler/restoreerror.go +++ b/internal/lambda/rapi/handler/restoreerror.go @@ -6,12 +6,12 @@ package handler import ( "net/http" + log "github.com/sirupsen/logrus" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/fatalerror" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/rendering" - log "github.com/sirupsen/logrus" ) type restoreErrorHandler struct { diff --git a/internal/lambda/rapi/handler/restoreerror_test.go b/internal/lambda/rapi/handler/restoreerror_test.go index c8f630c..b1d5a25 100644 --- a/internal/lambda/rapi/handler/restoreerror_test.go +++ b/internal/lambda/rapi/handler/restoreerror_test.go @@ -10,9 +10,9 @@ import ( "net/http/httptest" "testing" + "github.com/stretchr/testify/require" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata" - "github.com/stretchr/testify/require" ) func TestRestoreErrorHandler(t *testing.T) { diff --git a/internal/lambda/rapi/handler/restorenext.go b/internal/lambda/rapi/handler/restorenext.go index 3a5922a..31606fe 100644 --- a/internal/lambda/rapi/handler/restorenext.go +++ b/internal/lambda/rapi/handler/restorenext.go @@ -6,9 +6,9 @@ package handler import ( "net/http" + log "github.com/sirupsen/logrus" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/rendering" - log "github.com/sirupsen/logrus" ) type restoreNextHandler struct { diff --git a/internal/lambda/rapi/handler/restorenext_test.go b/internal/lambda/rapi/handler/restorenext_test.go index d25d1b4..d0af65d 100644 --- a/internal/lambda/rapi/handler/restorenext_test.go +++ b/internal/lambda/rapi/handler/restorenext_test.go @@ -11,11 +11,11 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/telemetry" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata" - "github.com/stretchr/testify/assert" ) func TestRenderRestoreNext(t *testing.T) { diff --git a/internal/lambda/rapi/handler/runtimelogs_stub.go b/internal/lambda/rapi/handler/runtimelogs_stub.go index 4f2e5e3..5c462be 100644 --- a/internal/lambda/rapi/handler/runtimelogs_stub.go +++ b/internal/lambda/rapi/handler/runtimelogs_stub.go @@ -6,9 +6,9 @@ package handler import ( "net/http" + log "github.com/sirupsen/logrus" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/model" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/rendering" - log "github.com/sirupsen/logrus" ) const ( diff --git a/internal/lambda/rapi/middleware/middleware.go b/internal/lambda/rapi/middleware/middleware.go index 0792edc..bd93a72 100644 --- a/internal/lambda/rapi/middleware/middleware.go +++ b/internal/lambda/rapi/middleware/middleware.go @@ -7,13 +7,13 @@ import ( "context" "net/http" + "github.com/google/uuid" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/extensions" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/handler" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/rendering" - "github.com/google/uuid" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/go-chi/chi/v5" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" log "github.com/sirupsen/logrus" ) diff --git a/internal/lambda/rapi/middleware/middleware_test.go b/internal/lambda/rapi/middleware/middleware_test.go index 2ebad22..4237aa1 100644 --- a/internal/lambda/rapi/middleware/middleware_test.go +++ b/internal/lambda/rapi/middleware/middleware_test.go @@ -12,13 +12,13 @@ import ( "net/http/httptest" "testing" + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/extensions" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/handler" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/model" - "github.com/go-chi/chi/v5" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" ) type mockHandler struct{} diff --git a/internal/lambda/rapi/rapi_fuzz_test.go b/internal/lambda/rapi/rapi_fuzz_test.go index 37582a8..017a11c 100644 --- a/internal/lambda/rapi/rapi_fuzz_test.go +++ b/internal/lambda/rapi/rapi_fuzz_test.go @@ -20,13 +20,13 @@ import ( "testing" "unicode" + "github.com/stretchr/testify/assert" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/extensions" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/fatalerror" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/telemetry" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata" - "github.com/stretchr/testify/assert" ) type runtimeFunctionErrStruct struct { diff --git a/internal/lambda/rapi/rendering/render_error.go b/internal/lambda/rapi/rendering/render_error.go index e6a5fe3..f71c761 100644 --- a/internal/lambda/rapi/rendering/render_error.go +++ b/internal/lambda/rapi/rendering/render_error.go @@ -7,9 +7,9 @@ import ( "fmt" "net/http" + log "github.com/sirupsen/logrus" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/model" - log "github.com/sirupsen/logrus" ) // RenderForbiddenWithTypeMsg method for rendering error response diff --git a/internal/lambda/rapi/server.go b/internal/lambda/rapi/server.go index cbb5307..93a9ba1 100644 --- a/internal/lambda/rapi/server.go +++ b/internal/lambda/rapi/server.go @@ -9,8 +9,8 @@ import ( "net" "net/http" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/go-chi/chi/v5" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/appctx" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" diff --git a/internal/lambda/rapi/telemetry_logs_fuzz_test.go b/internal/lambda/rapi/telemetry_logs_fuzz_test.go index bffdba6..046db49 100644 --- a/internal/lambda/rapi/telemetry_logs_fuzz_test.go +++ b/internal/lambda/rapi/telemetry_logs_fuzz_test.go @@ -11,13 +11,13 @@ import ( "net/http/httptest" "testing" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/extensions" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/handler" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/telemetry" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/testdata" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" ) const ( diff --git a/internal/lambda/rapidcore/server_test.go b/internal/lambda/rapidcore/server_test.go index 2c815ad..bfb0b6b 100644 --- a/internal/lambda/rapidcore/server_test.go +++ b/internal/lambda/rapidcore/server_test.go @@ -12,10 +12,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core/statejson" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapidcore/env" - "github.com/stretchr/testify/require" ) func waitForChanWithTimeout(channel <-chan error, timeout time.Duration) error { diff --git a/internal/lambda/rapidcore/standalone/directInvokeHandler.go b/internal/lambda/rapidcore/standalone/directInvokeHandler.go index bf74128..540538c 100644 --- a/internal/lambda/rapidcore/standalone/directInvokeHandler.go +++ b/internal/lambda/rapidcore/standalone/directInvokeHandler.go @@ -8,8 +8,8 @@ import ( "net/http" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core/directinvoke" log "github.com/sirupsen/logrus" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core/directinvoke" ) func DirectInvokeHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { diff --git a/internal/lambda/rapidcore/standalone/executeHandler.go b/internal/lambda/rapidcore/standalone/executeHandler.go index a21ca00..6587f70 100644 --- a/internal/lambda/rapidcore/standalone/executeHandler.go +++ b/internal/lambda/rapidcore/standalone/executeHandler.go @@ -6,10 +6,10 @@ package standalone import ( "net/http" + log "github.com/sirupsen/logrus" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/metering" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapidcore" - log "github.com/sirupsen/logrus" ) func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.LambdaInvokeAPI) { diff --git a/internal/lambda/rapidcore/standalone/reserveHandler.go b/internal/lambda/rapidcore/standalone/reserveHandler.go index a56818e..9ac41ef 100644 --- a/internal/lambda/rapidcore/standalone/reserveHandler.go +++ b/internal/lambda/rapidcore/standalone/reserveHandler.go @@ -6,10 +6,10 @@ package standalone import ( "net/http" + log "github.com/sirupsen/logrus" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/core/directinvoke" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapidcore" - log "github.com/sirupsen/logrus" ) const ( diff --git a/internal/lambda/rapidcore/standalone/restoreHandler.go b/internal/lambda/rapidcore/standalone/restoreHandler.go index 3b38210..ec6a734 100644 --- a/internal/lambda/rapidcore/standalone/restoreHandler.go +++ b/internal/lambda/rapidcore/standalone/restoreHandler.go @@ -9,8 +9,8 @@ import ( "strconv" "time" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" log "github.com/sirupsen/logrus" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" ) type RestoreBody struct { diff --git a/internal/lambda/rapidcore/standalone/util.go b/internal/lambda/rapidcore/standalone/util.go index ea46c56..ae8348d 100644 --- a/internal/lambda/rapidcore/standalone/util.go +++ b/internal/lambda/rapidcore/standalone/util.go @@ -9,8 +9,8 @@ import ( "io" "net/http" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/model" log "github.com/sirupsen/logrus" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapi/model" ) const ( diff --git a/internal/lambda/rie/http.go b/internal/lambda/rie/http.go index deaaf59..4bc2d5a 100644 --- a/internal/lambda/rie/http.go +++ b/internal/lambda/rie/http.go @@ -6,9 +6,9 @@ package rie import ( "net/http" + log "github.com/sirupsen/logrus" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapidcore" - log "github.com/sirupsen/logrus" ) func startHTTPServer(ipport string, sandbox *rapidcore.SandboxBuilder, bs interop.Bootstrap) { diff --git a/internal/lambda/rie/run.go b/internal/lambda/rie/run.go index 8c14fda..4b3c9ea 100644 --- a/internal/lambda/rie/run.go +++ b/internal/lambda/rie/run.go @@ -10,9 +10,9 @@ import ( "os" "runtime/debug" + "github.com/jessevdk/go-flags" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/rapidcore" - "github.com/jessevdk/go-flags" log "github.com/sirupsen/logrus" ) diff --git a/internal/lambda/supervisor/local_supervisor.go b/internal/lambda/supervisor/local_supervisor.go index 9e80782..2eb9f57 100644 --- a/internal/lambda/supervisor/local_supervisor.go +++ b/internal/lambda/supervisor/local_supervisor.go @@ -13,8 +13,8 @@ import ( "syscall" "time" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/supervisor/model" log "github.com/sirupsen/logrus" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/supervisor/model" ) // typecheck interface compliance diff --git a/internal/lambda/supervisor/local_supervisor_test.go b/internal/lambda/supervisor/local_supervisor_test.go index 1ff6c8f..de1db34 100644 --- a/internal/lambda/supervisor/local_supervisor_test.go +++ b/internal/lambda/supervisor/local_supervisor_test.go @@ -11,10 +11,10 @@ import ( "testing" "time" - "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/supervisor/model" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/supervisor/model" ) func TestRuntimeDomainExec(t *testing.T) { diff --git a/internal/lambda/telemetry/events_api_test.go b/internal/lambda/telemetry/events_api_test.go index decdc15..5ae75a1 100644 --- a/internal/lambda/telemetry/events_api_test.go +++ b/internal/lambda/telemetry/events_api_test.go @@ -7,9 +7,9 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/interop" "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda/metering" - "github.com/stretchr/testify/assert" ) func TestGetRuntimeDoneInvokeMetrics(t *testing.T) { diff --git a/licenses.tpl b/licenses.tpl new file mode 100644 index 0000000..2032a6c --- /dev/null +++ b/licenses.tpl @@ -0,0 +1,11 @@ +{{ range . }} +## {{ .Name }} + +* Name: {{ .Name }} +* Version: {{ .Version }} +* License: [{{ .LicenseName }}]({{ .LicenseURL }}) + +``` +{{ .LicenseText }} +``` +{{ end }} \ No newline at end of file