From 416b804b88af88470292b6c20d1f1291e4bc9c52 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Mon, 5 May 2025 15:12:42 +0200 Subject: [PATCH] Add user agent --- .../durabletask/azuremanaged/client.py | 8 +- .../internal/durabletask_grpc_interceptor.py | 21 +++- .../durabletask/azuremanaged/worker.py | 8 +- .../test_durabletask_grpc_interceptor.py | 108 ++++++++++++++++++ 4 files changed, 134 insertions(+), 11 deletions(-) create mode 100644 tests/durabletask-azuremanaged/test_durabletask_grpc_interceptor.py diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py index 1d8cecd..e1c2445 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/client.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py @@ -1,11 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from azure.core.credentials import TokenCredential from typing import Optional -from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \ - DTSDefaultClientInterceptorImpl +from azure.core.credentials import TokenCredential + +from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import ( + DTSDefaultClientInterceptorImpl, +) from durabletask.client import TaskHubGrpcClient diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py b/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py index 077905e..fa1459f 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py @@ -1,15 +1,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import grpc +from importlib.metadata import version from typing import Optional +import grpc from azure.core.credentials import TokenCredential -from durabletask.azuremanaged.internal.access_token_manager import \ - AccessTokenManager +from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager from durabletask.internal.grpc_interceptor import ( - DefaultClientInterceptorImpl, _ClientCallDetails) + DefaultClientInterceptorImpl, + _ClientCallDetails, +) class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): @@ -18,7 +20,16 @@ class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): interceptor to add additional headers to all calls as needed.""" def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: str): - self._metadata = [("taskhub", taskhub_name)] + try: + # Get the version of the azuremanaged package + sdk_version = version('durabletask-azuremanaged') + except Exception: + # Fallback if version cannot be determined + sdk_version = "unknown" + user_agent = f"durabletask-python/{sdk_version}" + self._metadata = [ + ("taskhub", taskhub_name), + ("x-user-agent", user_agent)] # 'user-agent' is a reserved header in grpc, so we use 'x-user-agent' instead super().__init__(self._metadata) if token_credential is not None: diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py index 8bdff3d..fd3b1e4 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -1,11 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from azure.core.credentials import TokenCredential from typing import Optional -from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \ - DTSDefaultClientInterceptorImpl +from azure.core.credentials import TokenCredential + +from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import ( + DTSDefaultClientInterceptorImpl, +) from durabletask.worker import TaskHubGrpcWorker diff --git a/tests/durabletask-azuremanaged/test_durabletask_grpc_interceptor.py b/tests/durabletask-azuremanaged/test_durabletask_grpc_interceptor.py new file mode 100644 index 0000000..62978f9 --- /dev/null +++ b/tests/durabletask-azuremanaged/test_durabletask_grpc_interceptor.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import threading +import unittest +from concurrent import futures +from importlib.metadata import version + +import grpc + +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import ( + DTSDefaultClientInterceptorImpl, +) +from durabletask.internal import orchestrator_service_pb2 as pb +from durabletask.internal import orchestrator_service_pb2_grpc as stubs + + +class MockTaskHubSidecarServiceServicer(stubs.TaskHubSidecarServiceServicer): + """Mock implementation of the TaskHubSidecarService for testing.""" + + def __init__(self): + self.captured_metadata = {} + self.requests_received = 0 + + def GetInstance(self, request, context): + """Implementation of GetInstance that captures the metadata.""" + # Store all metadata key-value pairs from the context + for key, value in context.invocation_metadata(): + self.captured_metadata[key] = value + + self.requests_received += 1 + + # Return a mock response + response = pb.GetInstanceResponse(exists=False) + return response + + +class TestDurableTaskGrpcInterceptor(unittest.TestCase): + """Tests for the DTSDefaultClientInterceptorImpl class.""" + + @classmethod + def setUpClass(cls): + # Start a real gRPC server on a free port + cls.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + cls.port = cls.server.add_insecure_port('[::]:0') # Bind to a random free port + cls.server_address = f"localhost:{cls.port}" + + # Add our mock service implementation to the server + cls.mock_servicer = MockTaskHubSidecarServiceServicer() + stubs.add_TaskHubSidecarServiceServicer_to_server(cls.mock_servicer, cls.server) + + # Start the server in a background thread + cls.server.start() + + @classmethod + def tearDownClass(cls): + cls.server.stop(grace=None) + + def test_user_agent_metadata_passed_in_request(self): + """Test that the user agent metadata is correctly passed in gRPC requests.""" + # Create a client that connects to our mock server + # Note: secure_channel is False and token_credential is None as specified + task_hub_client = DurableTaskSchedulerClient( + host_address=self.server_address, + secure_channel=False, + taskhub="test-taskhub", + token_credential=None + ) + + # Make a client call that will trigger our interceptor + task_hub_client.get_orchestration_state("test-instance-id") + + # Verify the request was received by our mock server + self.assertEqual(1, self.mock_servicer.requests_received, "Expected one request to be received") + + # Check if our custom x-user-agent header was correctly set + self.assertIn("x-user-agent", self.mock_servicer.captured_metadata, "x-user-agent header not found") + + # Get what we expect our user agent to be + try: + expected_version = version('durabletask-azuremanaged') + except Exception: + expected_version = "unknown" + + expected_user_agent = f"durabletask-python/{expected_version}" + self.assertEqual( + expected_user_agent, + self.mock_servicer.captured_metadata["x-user-agent"], + f"Expected x-user-agent header to be '{expected_user_agent}'" + ) + + # Check if the taskhub header was correctly set + self.assertIn("taskhub", self.mock_servicer.captured_metadata, "taskhub header not found") + self.assertEqual("test-taskhub", self.mock_servicer.captured_metadata["taskhub"]) + + # Verify the standard gRPC user-agent is different from our custom one + # Note: gRPC automatically adds its own "user-agent" header + self.assertIn("user-agent", self.mock_servicer.captured_metadata, "gRPC user-agent header not found") + self.assertNotEqual( + self.mock_servicer.captured_metadata["user-agent"], + self.mock_servicer.captured_metadata["x-user-agent"], + "gRPC user-agent should be different from our custom x-user-agent" + ) + + +if __name__ == "__main__": + unittest.main()