From f430bc2ec1dbee488297331743b3ab71e1b92d65 Mon Sep 17 00:00:00 2001 From: Elena Kolevska Date: Wed, 19 Feb 2025 00:25:40 +0000 Subject: [PATCH] Removes default timeout for `wait_for_orchestration_start` and `wait_for_orchestration_completion` Signed-off-by: Elena Kolevska --- durabletask/client.py | 16 ++++++----- tests/test_client.py | 63 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 71 insertions(+), 8 deletions(-) diff --git a/durabletask/client.py b/durabletask/client.py index 31953ae..fae968d 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -129,11 +129,13 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr def wait_for_orchestration_start(self, instance_id: str, *, fetch_payloads: bool = False, - timeout: int = 60) -> Optional[OrchestrationState]: + timeout: int = 0) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: - self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.") - res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=timeout) + grpc_timeout = None if timeout == 0 else timeout + self._logger.info( + f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start.") + res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=grpc_timeout) return new_orchestration_state(req.instanceId, res) except grpc.RpcError as rpc_error: if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore @@ -144,11 +146,13 @@ def wait_for_orchestration_start(self, instance_id: str, *, def wait_for_orchestration_completion(self, instance_id: str, *, fetch_payloads: bool = True, - timeout: int = 60) -> Optional[OrchestrationState]: + timeout: int = 0) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: - self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.") - res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=timeout) + grpc_timeout = None if timeout == 0 else timeout + self._logger.info( + f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete.") + res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=grpc_timeout) state = new_orchestration_state(req.instanceId, res) if not state: return None diff --git a/tests/test_client.py b/tests/test_client.py index caacf65..5990db0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,8 +1,10 @@ -from unittest.mock import patch, ANY +from unittest.mock import patch, ANY, Mock +from durabletask.client import TaskHubGrpcClient from durabletask.internal.shared import (DefaultClientInterceptorImpl, get_default_host_address, get_grpc_channel) +import pytest HOST_ADDRESS = 'localhost:50051' METADATA = [('key1', 'value1'), ('key2', 'value2')] @@ -85,4 +87,61 @@ def test_grpc_channel_with_host_name_protocol_stripping(): prefix = "" get_grpc_channel(prefix + host_name, METADATA, True) - mock_secure_channel.assert_called_with(host_name, ANY) \ No newline at end of file + mock_secure_channel.assert_called_with(host_name, ANY) + + +@pytest.mark.parametrize("timeout", [None, 0, 5]) +def test_wait_for_orchestration_start_timeout(timeout): + instance_id = "test-instance" + + from durabletask.internal.orchestrator_service_pb2 import GetInstanceResponse, \ + OrchestrationState, ORCHESTRATION_STATUS_RUNNING + + response = GetInstanceResponse() + state = OrchestrationState() + state.instanceId = instance_id + state.orchestrationStatus = ORCHESTRATION_STATUS_RUNNING + response.orchestrationState.CopyFrom(state) + + c = TaskHubGrpcClient() + c._stub = Mock() + c._stub.WaitForInstanceStart.return_value = response + + grpc_timeout = None if timeout is None else timeout + c.wait_for_orchestration_start(instance_id, timeout=grpc_timeout) + + # Verify WaitForInstanceStart was called with timeout=None + c._stub.WaitForInstanceStart.assert_called_once() + _, kwargs = c._stub.WaitForInstanceStart.call_args + if timeout is None or timeout == 0: + assert kwargs.get('timeout') is None + else: + assert kwargs.get('timeout') == timeout + +@pytest.mark.parametrize("timeout", [None, 0, 5]) +def test_wait_for_orchestration_completion_timeout(timeout): + instance_id = "test-instance" + + from durabletask.internal.orchestrator_service_pb2 import GetInstanceResponse, \ + OrchestrationState, ORCHESTRATION_STATUS_COMPLETED + + response = GetInstanceResponse() + state = OrchestrationState() + state.instanceId = instance_id + state.orchestrationStatus = ORCHESTRATION_STATUS_COMPLETED + response.orchestrationState.CopyFrom(state) + + c = TaskHubGrpcClient() + c._stub = Mock() + c._stub.WaitForInstanceCompletion.return_value = response + + grpc_timeout = None if timeout is None else timeout + c.wait_for_orchestration_completion(instance_id, timeout=grpc_timeout) + + # Verify WaitForInstanceStart was called with timeout=None + c._stub.WaitForInstanceCompletion.assert_called_once() + _, kwargs = c._stub.WaitForInstanceCompletion.call_args + if timeout is None or timeout == 0: + assert kwargs.get('timeout') is None + else: + assert kwargs.get('timeout') == timeout \ No newline at end of file