Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,13 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr

def wait_for_orchestration_start(self, instance_id: str, *,
fetch_payloads: bool = False,
timeout: int = 60) -> Optional[OrchestrationState]:
timeout: int = 0) -> Optional[OrchestrationState]:
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
try:
self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.")
res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=timeout)
grpc_timeout = None if timeout == 0 else timeout
self._logger.info(
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start.")
res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=grpc_timeout)
return new_orchestration_state(req.instanceId, res)
except grpc.RpcError as rpc_error:
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
Expand All @@ -144,11 +146,13 @@ def wait_for_orchestration_start(self, instance_id: str, *,

def wait_for_orchestration_completion(self, instance_id: str, *,
fetch_payloads: bool = True,
timeout: int = 60) -> Optional[OrchestrationState]:
timeout: int = 0) -> Optional[OrchestrationState]:
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
try:
self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=timeout)
grpc_timeout = None if timeout == 0 else timeout
self._logger.info(
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete.")
res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=grpc_timeout)
state = new_orchestration_state(req.instanceId, res)
if not state:
return None
Expand Down
63 changes: 61 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from unittest.mock import patch, ANY
from unittest.mock import patch, ANY, Mock

from durabletask.client import TaskHubGrpcClient
from durabletask.internal.shared import (DefaultClientInterceptorImpl,
get_default_host_address,
get_grpc_channel)
import pytest

HOST_ADDRESS = 'localhost:50051'
METADATA = [('key1', 'value1'), ('key2', 'value2')]
Expand Down Expand Up @@ -85,4 +87,61 @@ def test_grpc_channel_with_host_name_protocol_stripping():

prefix = ""
get_grpc_channel(prefix + host_name, METADATA, True)
mock_secure_channel.assert_called_with(host_name, ANY)
mock_secure_channel.assert_called_with(host_name, ANY)


@pytest.mark.parametrize("timeout", [None, 0, 5])
def test_wait_for_orchestration_start_timeout(timeout):
instance_id = "test-instance"

from durabletask.internal.orchestrator_service_pb2 import GetInstanceResponse, \
OrchestrationState, ORCHESTRATION_STATUS_RUNNING

response = GetInstanceResponse()
state = OrchestrationState()
state.instanceId = instance_id
state.orchestrationStatus = ORCHESTRATION_STATUS_RUNNING
response.orchestrationState.CopyFrom(state)

c = TaskHubGrpcClient()
c._stub = Mock()
c._stub.WaitForInstanceStart.return_value = response

grpc_timeout = None if timeout is None else timeout
c.wait_for_orchestration_start(instance_id, timeout=grpc_timeout)

# Verify WaitForInstanceStart was called with timeout=None
c._stub.WaitForInstanceStart.assert_called_once()
_, kwargs = c._stub.WaitForInstanceStart.call_args
if timeout is None or timeout == 0:
assert kwargs.get('timeout') is None
else:
assert kwargs.get('timeout') == timeout

@pytest.mark.parametrize("timeout", [None, 0, 5])
def test_wait_for_orchestration_completion_timeout(timeout):
instance_id = "test-instance"

from durabletask.internal.orchestrator_service_pb2 import GetInstanceResponse, \
OrchestrationState, ORCHESTRATION_STATUS_COMPLETED

response = GetInstanceResponse()
state = OrchestrationState()
state.instanceId = instance_id
state.orchestrationStatus = ORCHESTRATION_STATUS_COMPLETED
response.orchestrationState.CopyFrom(state)

c = TaskHubGrpcClient()
c._stub = Mock()
c._stub.WaitForInstanceCompletion.return_value = response

grpc_timeout = None if timeout is None else timeout
c.wait_for_orchestration_completion(instance_id, timeout=grpc_timeout)

# Verify WaitForInstanceStart was called with timeout=None
c._stub.WaitForInstanceCompletion.assert_called_once()
_, kwargs = c._stub.WaitForInstanceCompletion.call_args
if timeout is None or timeout == 0:
assert kwargs.get('timeout') is None
else:
assert kwargs.get('timeout') == timeout
Loading