diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml
index 63540ac..33de31f 100644
--- a/.github/workflows/pr-validation.yml
+++ b/.github/workflows/pr-validation.yml
@@ -28,7 +28,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install flake8 pytest
+          pip install flake8 pytest pytest-cov pytest-asyncio
           pip install -r requirements.txt
       - name: Lint with flake8
         run: |
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 119f072..ba589ab 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1 +1 @@
-grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python
+grpcio-tools==1.62.3 # pinned to the 1.62.X series: the last line before grpcio-tools adopts a protobuf release with breaking changes for Python; this pin aligns with the generated code and supports protobuf 6.x
\ No newline at end of file
diff --git a/durabletask/aio/__init__.py b/durabletask/aio/__init__.py
new file mode 100644
index 0000000..d446228
--- /dev/null
+++ b/durabletask/aio/__init__.py
@@ -0,0 +1,5 @@
+from .client import AsyncTaskHubGrpcClient
+
+__all__ = [
+    "AsyncTaskHubGrpcClient",
+]
diff --git a/durabletask/aio/client.py b/durabletask/aio/client.py
new file mode 100644
index 0000000..4ec9bbf
--- /dev/null
+++ b/durabletask/aio/client.py
@@ -0,0 +1,170 @@
+# Copyright (c) The Dapr Authors.
+# Licensed under the MIT License.
+
+import logging
+import uuid
+from datetime import datetime
+from typing import Any, Optional, Sequence, Union
+
+import grpc
+from google.protobuf import wrappers_pb2
+
+import durabletask.internal.helpers as helpers
+import durabletask.internal.orchestrator_service_pb2 as pb
+import durabletask.internal.orchestrator_service_pb2_grpc as stubs
+import durabletask.internal.shared as shared
+from durabletask.aio.internal.shared import get_grpc_aio_channel, ClientInterceptor
+from durabletask import task
+from durabletask.client import OrchestrationState, OrchestrationStatus, new_orchestration_state, TInput, TOutput
+from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl
+
+
+class AsyncTaskHubGrpcClient:
+
+    def __init__(self, *,
+                 host_address: Optional[str] = None,
+                 metadata: Optional[list[tuple[str, str]]] = None,
+                 log_handler: Optional[logging.Handler] = None,
+                 log_formatter: Optional[logging.Formatter] = None,
+                 secure_channel: bool = False,
+                 interceptors: Optional[Sequence[ClientInterceptor]] = None):
+
+        if interceptors is not None:
+            interceptors = list(interceptors)
+            if metadata is not None:
+                interceptors.append(DefaultClientInterceptorImpl(metadata))
+        elif metadata is not None:
+            interceptors = [DefaultClientInterceptorImpl(metadata)]
+        else:
+            interceptors = None
+
+        channel = get_grpc_aio_channel(
+            host_address=host_address,
+            secure_channel=secure_channel,
+            interceptors=interceptors
+        )
+        self._channel = channel
+        self._stub = stubs.TaskHubSidecarServiceStub(channel)
+        self._logger = shared.get_logger("client", log_handler, log_formatter)
+
+    async def aclose(self):
+        await self._channel.close()
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.aclose()
+        return False
+
+    async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
+                                         input: Optional[TInput] = None,
+                                         instance_id: Optional[str] = None,
+                                         start_at: Optional[datetime] = None,
+                                         reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None) -> str:
+        name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)
+
+        req = pb.CreateInstanceRequest(
+            name=name,
+            instanceId=instance_id if instance_id else uuid.uuid4().hex,
+            input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
+            scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
+            version=helpers.get_string_value(None),
+            orchestrationIdReusePolicy=reuse_id_policy,
+        )
+
+        self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
+        res: pb.CreateInstanceResponse = await self._stub.StartInstance(req)
+        return res.instanceId
+
+    async def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]:
+        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
+        res: pb.GetInstanceResponse = await self._stub.GetInstance(req)
+        return new_orchestration_state(req.instanceId, res)
+
+    async def wait_for_orchestration_start(self, instance_id: str, *,
+                                           fetch_payloads: bool = False,
+                                           timeout: int = 0) -> Optional[OrchestrationState]:
+        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
+        try:
+            grpc_timeout = None if timeout == 0 else timeout
+            self._logger.info(
+                f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start.")
+            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=grpc_timeout)
+            return new_orchestration_state(req.instanceId, res)
+        except grpc.RpcError as rpc_error:
+            if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:  # type: ignore
+                # Replace gRPC error with the built-in TimeoutError
+                raise TimeoutError("Timed-out waiting for the orchestration to start")
+            else:
+                raise
+
+    async def wait_for_orchestration_completion(self, instance_id: str, *,
+                                                fetch_payloads: bool = True,
+                                                timeout: int = 0) -> Optional[OrchestrationState]:
+        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
+        try:
+            grpc_timeout = None if timeout == 0 else timeout
+            self._logger.info(
+                f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete.")
+            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=grpc_timeout)
+            state = new_orchestration_state(req.instanceId, res)
+            if not state:
+                return None
+
+            if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None:
+                details = state.failure_details
+                self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}")
+            elif state.runtime_status == OrchestrationStatus.TERMINATED:
+                self._logger.info(f"Instance '{instance_id}' was terminated.")
+            elif state.runtime_status == OrchestrationStatus.COMPLETED:
+                self._logger.info(f"Instance '{instance_id}' completed.")
+
+            return state
+        except grpc.RpcError as rpc_error:
+            if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:  # type: ignore
+                # Replace gRPC error with the built-in TimeoutError
+                raise TimeoutError("Timed-out waiting for the orchestration to complete")
+            else:
+                raise
+
+    async def raise_orchestration_event(
+            self,
+            instance_id: str,
+            event_name: str,
+            *,
+            data: Optional[Any] = None):
+        req = pb.RaiseEventRequest(
+            instanceId=instance_id,
+            name=event_name,
+            input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None)
+
+        self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
+        await self._stub.RaiseEvent(req)
+
+    async def terminate_orchestration(self, instance_id: str, *,
+                                      output: Optional[Any] = None,
+                                      recursive: bool = True):
+        req = pb.TerminateRequest(
+            instanceId=instance_id,
+            output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None,
+            recursive=recursive)
+
+        self._logger.info(f"Terminating instance '{instance_id}'.")
+        await self._stub.TerminateInstance(req)
+
+    async def suspend_orchestration(self, instance_id: str):
+        req = pb.SuspendRequest(instanceId=instance_id)
+        self._logger.info(f"Suspending instance '{instance_id}'.")
+        await self._stub.SuspendInstance(req)
+
+    async def resume_orchestration(self, instance_id: str):
+        req = pb.ResumeRequest(instanceId=instance_id)
+        self._logger.info(f"Resuming instance '{instance_id}'.")
+        await self._stub.ResumeInstance(req)
+
+    async def purge_orchestration(self, instance_id: str, recursive: bool = True):
+        req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
+        self._logger.info(f"Purging instance '{instance_id}'.")
+        await self._stub.PurgeInstances(req)
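For reviewers, a minimal usage sketch of the new async client (illustrative only, not part of the patch; assumes a sidecar on the default address and an orchestrator registered under the hypothetical name "hello"):

    import asyncio

    from durabletask.aio import AsyncTaskHubGrpcClient


    async def main():
        # The async context manager closes the underlying grpc.aio channel on exit.
        async with AsyncTaskHubGrpcClient() as client:
            # Orchestrators can be referenced by callable or by registered name.
            instance_id = await client.schedule_new_orchestration("hello", input=42)
            state = await client.wait_for_orchestration_completion(instance_id, timeout=30)
            if state:
                print(state.runtime_status, state.serialized_output)

    asyncio.run(main())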
'{instance_id}'.") + await self._stub.RaiseEvent(req) + + async def terminate_orchestration(self, instance_id: str, *, + output: Optional[Any] = None, + recursive: bool = True): + req = pb.TerminateRequest( + instanceId=instance_id, + output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None, + recursive=recursive) + + self._logger.info(f"Terminating instance '{instance_id}'.") + await self._stub.TerminateInstance(req) + + async def suspend_orchestration(self, instance_id: str): + req = pb.SuspendRequest(instanceId=instance_id) + self._logger.info(f"Suspending instance '{instance_id}'.") + await self._stub.SuspendInstance(req) + + async def resume_orchestration(self, instance_id: str): + req = pb.ResumeRequest(instanceId=instance_id) + self._logger.info(f"Resuming instance '{instance_id}'.") + await self._stub.ResumeInstance(req) + + async def purge_orchestration(self, instance_id: str, recursive: bool = True): + req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive) + self._logger.info(f"Purging instance '{instance_id}'.") + await self._stub.PurgeInstances(req) diff --git a/durabletask/aio/internal/__init__.py b/durabletask/aio/internal/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/durabletask/aio/internal/grpc_interceptor.py b/durabletask/aio/internal/grpc_interceptor.py new file mode 100644 index 0000000..bf1ac98 --- /dev/null +++ b/durabletask/aio/internal/grpc_interceptor.py @@ -0,0 +1,58 @@ +# Copyright (c) The Dapr Authors. +# Licensed under the MIT License. + +from collections import namedtuple + +from grpc import aio as grpc_aio + + +class _ClientCallDetails( + namedtuple( + '_ClientCallDetails', + ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']), + grpc_aio.ClientCallDetails): + pass + + +class DefaultClientInterceptorImpl( + grpc_aio.UnaryUnaryClientInterceptor, grpc_aio.UnaryStreamClientInterceptor, + grpc_aio.StreamUnaryClientInterceptor, grpc_aio.StreamStreamClientInterceptor): + """Async gRPC client interceptor to add metadata to all calls.""" + + def __init__(self, metadata: list[tuple[str, str]]): + super().__init__() + self._metadata = metadata + + def _intercept_call(self, client_call_details: _ClientCallDetails) -> grpc_aio.ClientCallDetails: + if self._metadata is None: + return client_call_details + + if client_call_details.metadata is not None: + metadata = list(client_call_details.metadata) + else: + metadata = [] + + metadata.extend(self._metadata) + return _ClientCallDetails( + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + client_call_details.compression) + + async def intercept_unary_unary(self, continuation, client_call_details, request): + new_client_call_details = self._intercept_call(client_call_details) + return await continuation(new_client_call_details, request) + + async def intercept_unary_stream(self, continuation, client_call_details, request): + new_client_call_details = self._intercept_call(client_call_details) + return await continuation(new_client_call_details, request) + + async def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + new_client_call_details = self._intercept_call(client_call_details) + return await continuation(new_client_call_details, request_iterator) + + async def intercept_stream_stream(self, continuation, client_call_details, request_iterator): + new_client_call_details = 
diff --git a/durabletask/aio/internal/shared.py b/durabletask/aio/internal/shared.py
new file mode 100644
index 0000000..6bdb256
--- /dev/null
+++ b/durabletask/aio/internal/shared.py
@@ -0,0 +1,49 @@
+# Copyright (c) The Dapr Authors.
+# Licensed under the MIT License.
+
+from typing import Optional, Sequence, Union
+
+import grpc
+from grpc import aio as grpc_aio
+
+from durabletask.internal.shared import (
+    get_default_host_address,
+    SECURE_PROTOCOLS,
+    INSECURE_PROTOCOLS,
+)
+
+
+ClientInterceptor = Union[
+    grpc_aio.UnaryUnaryClientInterceptor,
+    grpc_aio.UnaryStreamClientInterceptor,
+    grpc_aio.StreamUnaryClientInterceptor,
+    grpc_aio.StreamStreamClientInterceptor
+]
+
+
+def get_grpc_aio_channel(
+        host_address: Optional[str],
+        secure_channel: bool = False,
+        interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc_aio.Channel:
+
+    if host_address is None:
+        host_address = get_default_host_address()
+
+    for protocol in SECURE_PROTOCOLS:
+        if host_address.lower().startswith(protocol):
+            secure_channel = True
+            host_address = host_address[len(protocol):]
+            break
+
+    for protocol in INSECURE_PROTOCOLS:
+        if host_address.lower().startswith(protocol):
+            secure_channel = False
+            host_address = host_address[len(protocol):]
+            break
+
+    if secure_channel:
+        channel = grpc_aio.secure_channel(host_address, grpc.ssl_channel_credentials(), interceptors=interceptors)
+    else:
+        channel = grpc_aio.insecure_channel(host_address, interceptors=interceptors)
+
+    return channel
diff --git a/durabletask/task.py b/durabletask/task.py
index 29af2c5..5210c99 100644
--- a/durabletask/task.py
+++ b/durabletask/task.py
@@ -283,6 +283,7 @@ def get_tasks(self) -> list[Task]:
     def on_child_completed(self, task: Task[T]):
         pass
 
+
 class WhenAllTask(CompositeTask[list[T]]):
     """A task that completes when all of its child tasks complete."""
 
@@ -290,6 +291,10 @@ def __init__(self, tasks: list[Task[T]]):
         super().__init__(tasks)
         self._completed_tasks = 0
         self._failed_tasks = 0
+        # If there are no child tasks, this composite should complete immediately
+        if len(self._tasks) == 0:
+            self._result = []  # type: ignore[assignment]
+            self._is_complete = True
 
     @property
     def pending_tasks(self) -> int:
@@ -387,6 +392,10 @@ class WhenAnyTask(CompositeTask[Task]):
 
     def __init__(self, tasks: list[Task]):
         super().__init__(tasks)
+        # If there are no child tasks, complete immediately with an empty result
+        if len(self._tasks) == 0:
+            self._result = []  # type: ignore[assignment]
+            self._is_complete = True
 
     def on_child_completed(self, task: Task):
         # The first task to complete is the result of the WhenAnyTask.
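The net effect of the task.py change, sketched below (this mirrors the new unit tests in tests/durabletask/test_task.py at the end of this patch): with no children there is nothing to wait for, so the composite task is created already complete.

    from durabletask import task

    all_task = task.when_all([])
    assert all_task.is_complete and all_task.get_result() == []

    # when_any([]) mirrors this; note it yields an empty list rather than a
    # "winner" task, because no child can ever complete.
    any_task = task.when_any([])
    assert any_task.is_complete and any_task.get_result() == []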
diff --git a/durabletask/worker.py b/durabletask/worker.py
index 7a04649..e8e1fa9 100644
--- a/durabletask/worker.py
+++ b/durabletask/worker.py
@@ -880,13 +880,13 @@ class ExecutionResults:
     actions: list[pb.OrchestratorAction]
     encoded_custom_status: Optional[str]
-
    def __init__(
        self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]
    ):
        self.actions = actions
        self.encoded_custom_status = encoded_custom_status
 
+
 class _OrchestrationExecutor:
     _generator: Optional[task.Orchestrator] = None
diff --git a/requirements.txt b/requirements.txt
index 07426eb..06750e2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,8 @@
 autopep8
 grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newer versions are backwards compatible
 protobuf
+asyncio
 pytest
 pytest-cov
-asyncio
+pytest-asyncio
+flake8
\ No newline at end of file
diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py
index e5a8e9b..e750134 100644
--- a/tests/durabletask/test_client.py
+++ b/tests/durabletask/test_client.py
@@ -21,6 +21,7 @@ def test_get_grpc_channel_secure():
         get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS)
         mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value)
 
+
 def test_get_grpc_channel_default_host_address():
     with patch('grpc.insecure_channel') as mock_channel:
         get_grpc_channel(None, False, interceptors=INTERCEPTORS)
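The pytest-asyncio dependency added above is what makes the async tests below runnable; they opt in explicitly via pytest.mark.asyncio, which works under pytest-asyncio's default strict mode with no extra configuration. A hypothetical pytest.ini sketch, only needed if auto mode or marker registration were preferred:

    [pytest]
    markers =
        e2e: end-to-end tests that require a running sidecar
    asyncio_mode = auto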
diff --git a/tests/durabletask/test_client_async.py b/tests/durabletask/test_client_async.py
new file mode 100644
index 0000000..8f2b83e
--- /dev/null
+++ b/tests/durabletask/test_client_async.py
@@ -0,0 +1,106 @@
+# Copyright (c) The Dapr Authors.
+# Licensed under the MIT License.
+
+from unittest.mock import ANY, patch
+
+from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl
+from durabletask.internal.shared import get_default_host_address
+from durabletask.aio.internal.shared import get_grpc_aio_channel
+from durabletask.aio.client import AsyncTaskHubGrpcClient
+
+
+HOST_ADDRESS = 'localhost:50051'
+METADATA = [('key1', 'value1'), ('key2', 'value2')]
+INTERCEPTORS_AIO = [DefaultClientInterceptorImpl(METADATA)]
+
+
+def test_get_grpc_aio_channel_insecure():
+    with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel:
+        get_grpc_aio_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS_AIO)
+        mock_channel.assert_called_once_with(HOST_ADDRESS, interceptors=INTERCEPTORS_AIO)
+
+
+def test_get_grpc_aio_channel_secure():
+    with patch('durabletask.aio.internal.shared.grpc_aio.secure_channel') as mock_channel, patch(
+            'grpc.ssl_channel_credentials') as mock_credentials:
+        get_grpc_aio_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS_AIO)
+        mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value, interceptors=INTERCEPTORS_AIO)
+
+
+def test_get_grpc_aio_channel_default_host_address():
+    with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel:
+        get_grpc_aio_channel(None, False, interceptors=INTERCEPTORS_AIO)
+        mock_channel.assert_called_once_with(get_default_host_address(), interceptors=INTERCEPTORS_AIO)
+
+
+def test_get_grpc_aio_channel_with_interceptors():
+    with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel:
+        get_grpc_aio_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS_AIO)
+        mock_channel.assert_called_once_with(HOST_ADDRESS, interceptors=INTERCEPTORS_AIO)
+
+        # Capture and check the arguments passed to insecure_channel()
+        args, kwargs = mock_channel.call_args
+        assert args[0] == HOST_ADDRESS
+        assert 'interceptors' in kwargs
+        interceptors = kwargs['interceptors']
+        assert isinstance(interceptors[0], DefaultClientInterceptorImpl)
+        assert interceptors[0]._metadata == METADATA
+
+
+def test_grpc_aio_channel_with_host_name_protocol_stripping():
+    with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_insecure_channel, patch(
+            'durabletask.aio.internal.shared.grpc_aio.secure_channel') as mock_secure_channel:
+
+        host_name = "myserver.com:1234"
+
+        prefix = "grpc://"
+        get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
+        mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO)
+
+        prefix = "http://"
+        get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
+        mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO)
+
+        prefix = "HTTP://"
+        get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
+        mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO)
+
+        prefix = "GRPC://"
+        get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
+        mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO)
+
+        prefix = ""
+        get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
+        mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO)
+
+        prefix = "grpcs://"
+        get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
+        mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO)
+
+        prefix = "https://"
+        get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
+        mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO)
+
+        prefix = "HTTPS://"
+        get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
+        mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO)
+
+        prefix = "GRPCS://"
+        get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO)
+        mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO)
+
+        prefix = ""
+        get_grpc_aio_channel(prefix + host_name, True, interceptors=INTERCEPTORS_AIO)
+        mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO)
+
+
+def test_async_client_construct_with_metadata():
+    with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel:
+        AsyncTaskHubGrpcClient(host_address=HOST_ADDRESS, metadata=METADATA)
+        # Ensure the channel is created with an interceptor that has the expected metadata
+        args, kwargs = mock_channel.call_args
+        assert args[0] == HOST_ADDRESS
+        assert 'interceptors' in kwargs
+        interceptors = kwargs['interceptors']
+        assert isinstance(interceptors[0], DefaultClientInterceptorImpl)
+        assert interceptors[0]._metadata == METADATA
diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py
index 2343184..76ec355 100644
--- a/tests/durabletask/test_orchestration_e2e.py
+++ b/tests/durabletask/test_orchestration_e2e.py
@@ -316,7 +316,6 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int):
     output = "Recursive termination = {recurse}"
     task_hub_client.terminate_orchestration(instance_id, output=output, recursive=recurse)
-
     metadata = task_hub_client.wait_for_orchestration_completion(instance_id, timeout=30)
     assert metadata is not None
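One behavioral detail worth calling out before the end-to-end suite: both wait_* helpers translate gRPC DEADLINE_EXCEEDED into the built-in TimeoutError, so callers use plain exception handling. A sketch (assumes a reachable sidecar; instance_id is a placeholder):

    from durabletask.aio import AsyncTaskHubGrpcClient


    async def wait_briefly(instance_id: str):
        async with AsyncTaskHubGrpcClient() as client:
            try:
                # timeout=0 waits indefinitely; a positive value becomes the gRPC deadline.
                return await client.wait_for_orchestration_completion(instance_id, timeout=3)
            except TimeoutError:
                return None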
diff --git a/tests/durabletask/test_orchestration_e2e_async.py b/tests/durabletask/test_orchestration_e2e_async.py
new file mode 100644
index 0000000..de586f1
--- /dev/null
+++ b/tests/durabletask/test_orchestration_e2e_async.py
@@ -0,0 +1,480 @@
+# Copyright (c) The Dapr Authors.
+# Licensed under the MIT License.
+
+import asyncio
+import json
+import threading
+from datetime import timedelta
+
+import pytest
+
+from durabletask.aio.client import AsyncTaskHubGrpcClient
+from durabletask.client import OrchestrationStatus
+from durabletask import task, worker
+
+
+# NOTE: These tests assume a sidecar process is running. Example commands:
+#     go install github.com/microsoft/durabletask-go@main
+#     durabletask-go --port 4001
+pytestmark = [pytest.mark.e2e, pytest.mark.asyncio]
+
+
+async def test_empty_orchestration():
+
+    invoked = False
+
+    def empty_orchestrator(ctx: task.OrchestrationContext, _):
+        nonlocal invoked  # don't do this in a real app!
+        invoked = True
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(empty_orchestrator)
+        w.start()
+
+        c = AsyncTaskHubGrpcClient()
+        id = await c.schedule_new_orchestration(empty_orchestrator)
+        state = await c.wait_for_orchestration_completion(id, timeout=30)
+        await c.aclose()
+
+    assert invoked
+    assert state is not None
+    assert state.name == task.get_name(empty_orchestrator)
+    assert state.instance_id == id
+    assert state.failure_details is None
+    assert state.runtime_status == OrchestrationStatus.COMPLETED
+    assert state.serialized_input is None
+    assert state.serialized_output is None
+    assert state.serialized_custom_status is None
+
+
+async def test_activity_sequence():
+
+    def plus_one(_: task.ActivityContext, input: int) -> int:
+        return input + 1
+
+    def sequence(ctx: task.OrchestrationContext, start_val: int):
+        numbers = [start_val]
+        current = start_val
+        for _ in range(10):
+            current = yield ctx.call_activity(plus_one, input=current)
+            numbers.append(current)
+        return numbers
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(sequence)
+        w.add_activity(plus_one)
+        w.start()
+
+        client = AsyncTaskHubGrpcClient()
+        id = await client.schedule_new_orchestration(sequence, input=1)
+        state = await client.wait_for_orchestration_completion(id, timeout=30)
+        await client.aclose()
+
+    assert state is not None
+    assert state.name == task.get_name(sequence)
+    assert state.instance_id == id
+    assert state.runtime_status == OrchestrationStatus.COMPLETED
+    assert state.failure_details is None
+    assert state.serialized_input == json.dumps(1)
+    assert state.serialized_output == json.dumps([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+    assert state.serialized_custom_status is None
+
+
+async def test_activity_error_handling():
+
+    def throw(_: task.ActivityContext, input: int) -> int:
+        raise RuntimeError("Kah-BOOOOM!!!")
+
+    compensation_counter = 0
+
+    def increment_counter(ctx, _):
+        nonlocal compensation_counter
+        compensation_counter += 1
+
+    def orchestrator(ctx: task.OrchestrationContext, input: int):
+        error_msg = ""
+        try:
+            yield ctx.call_activity(throw, input=input)
+        except task.TaskFailedError as e:
+            error_msg = e.details.message
+
+            # compensating actions
+            yield ctx.call_activity(increment_counter)
+            yield ctx.call_activity(increment_counter)
+
+        return error_msg
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(orchestrator)
+        w.add_activity(throw)
+        w.add_activity(increment_counter)
+        w.start()
+
+        client = AsyncTaskHubGrpcClient()
+        id = await client.schedule_new_orchestration(orchestrator, input=1)
+        state = await client.wait_for_orchestration_completion(id, timeout=30)
+        await client.aclose()
+
+    assert state is not None
+    assert state.name == task.get_name(orchestrator)
+    assert state.instance_id == id
+    assert state.runtime_status == OrchestrationStatus.COMPLETED
+    assert state.serialized_output == json.dumps("Kah-BOOOOM!!!")
+    assert state.failure_details is None
+    assert state.serialized_custom_status is None
+    assert compensation_counter == 2
+
+
+async def test_sub_orchestration_fan_out():
+    threadLock = threading.Lock()
+    activity_counter = 0
+
+    def increment(ctx, _):
+        with threadLock:
+            nonlocal activity_counter
+            activity_counter += 1
+
+    def orchestrator_child(ctx: task.OrchestrationContext, activity_count: int):
+        for _ in range(activity_count):
+            yield ctx.call_activity(increment)
+
+    def parent_orchestrator(ctx: task.OrchestrationContext, count: int):
+        # Fan out to multiple sub-orchestrations
+        tasks = []
+        for _ in range(count):
+            tasks.append(ctx.call_sub_orchestrator(
+                orchestrator_child, input=3))
+        # Wait for all sub-orchestrations to complete
+        yield task.when_all(tasks)
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_activity(increment)
+        w.add_orchestrator(orchestrator_child)
+        w.add_orchestrator(parent_orchestrator)
+        w.start()
+
+        client = AsyncTaskHubGrpcClient()
+        id = await client.schedule_new_orchestration(parent_orchestrator, input=10)
+        state = await client.wait_for_orchestration_completion(id, timeout=30)
+        await client.aclose()
+
+    assert state is not None
+    assert state.runtime_status == OrchestrationStatus.COMPLETED
+    assert state.failure_details is None
+    assert activity_counter == 30
+
+
+async def test_wait_for_multiple_external_events():
+    def orchestrator(ctx: task.OrchestrationContext, _):
+        a = yield ctx.wait_for_external_event('A')
+        b = yield ctx.wait_for_external_event('B')
+        c = yield ctx.wait_for_external_event('C')
+        return [a, b, c]
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(orchestrator)
+        w.start()
+
+        # Start the orchestration and immediately raise events to it.
+        client = AsyncTaskHubGrpcClient()
+        id = await client.schedule_new_orchestration(orchestrator)
+        await client.raise_orchestration_event(id, 'A', data='a')
+        await client.raise_orchestration_event(id, 'B', data='b')
+        await client.raise_orchestration_event(id, 'C', data='c')
+        state = await client.wait_for_orchestration_completion(id, timeout=30)
+        await client.aclose()
+
+    assert state is not None
+    assert state.runtime_status == OrchestrationStatus.COMPLETED
+    assert state.serialized_output == json.dumps(['a', 'b', 'c'])
+
+
+@pytest.mark.parametrize("raise_event", [True, False])
+async def test_wait_for_external_event_timeout(raise_event: bool):
+    def orchestrator(ctx: task.OrchestrationContext, _):
+        approval: task.Task[bool] = ctx.wait_for_external_event('Approval')
+        timeout = ctx.create_timer(timedelta(seconds=3))
+        winner = yield task.when_any([approval, timeout])
+        if winner == approval:
+            return "approved"
+        else:
+            return "timed out"
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(orchestrator)
+        w.start()
+
+        # Start the orchestration and optionally raise the approval event to it.
+        client = AsyncTaskHubGrpcClient()
+        id = await client.schedule_new_orchestration(orchestrator)
+        if raise_event:
+            await client.raise_orchestration_event(id, 'Approval')
+        state = await client.wait_for_orchestration_completion(id, timeout=30)
+        await client.aclose()
+
+    assert state is not None
+    assert state.runtime_status == OrchestrationStatus.COMPLETED
+    if raise_event:
+        assert state.serialized_output == json.dumps("approved")
+    else:
+        assert state.serialized_output == json.dumps("timed out")
+
+
+async def test_suspend_and_resume():
+    def orchestrator(ctx: task.OrchestrationContext, _):
+        result = yield ctx.wait_for_external_event("my_event")
+        return result
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(orchestrator)
+        w.start()
+        # Guard against a race where the orchestration is scheduled before the worker has fully started.
+        await asyncio.sleep(0.2)
+
+        async with AsyncTaskHubGrpcClient() as client:
+            id = await client.schedule_new_orchestration(orchestrator)
+            state = await client.wait_for_orchestration_start(id, timeout=30)
+            assert state is not None
+
+            # Suspend the orchestration and wait for it to go into the SUSPENDED state
+            await client.suspend_orchestration(id)
+            while state.runtime_status == OrchestrationStatus.RUNNING:
+                await asyncio.sleep(0.1)
+                state = await client.get_orchestration_state(id)
+                assert state is not None
+            assert state.runtime_status == OrchestrationStatus.SUSPENDED
+
+            # Raise an event to the orchestration and confirm that it does NOT complete
+            await client.raise_orchestration_event(id, "my_event", data=42)
+            try:
+                state = await client.wait_for_orchestration_completion(id, timeout=3)
+                assert False, "Orchestration should not have completed"
+            except TimeoutError:
+                pass
+
+            # Resume the orchestration and wait for it to complete
+            await client.resume_orchestration(id)
+            state = await client.wait_for_orchestration_completion(id, timeout=30)
+            assert state is not None
+            assert state.runtime_status == OrchestrationStatus.COMPLETED
+            assert state.serialized_output == json.dumps(42)
+
+
+async def test_terminate():
+    def orchestrator(ctx: task.OrchestrationContext, _):
+        result = yield ctx.wait_for_external_event("my_event")
+        return result
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(orchestrator)
+        w.start()
+
+        async with AsyncTaskHubGrpcClient() as client:
+            id = await client.schedule_new_orchestration(orchestrator)
+            state = await client.wait_for_orchestration_start(id, timeout=30)
+            assert state is not None
+            assert state.runtime_status == OrchestrationStatus.RUNNING
+
+            await client.terminate_orchestration(id, output="some reason for termination")
+            state = await client.wait_for_orchestration_completion(id, timeout=30)
+            assert state is not None
+            assert state.runtime_status == OrchestrationStatus.TERMINATED
+            assert state.serialized_output == json.dumps("some reason for termination")
+
+
+async def test_terminate_recursive():
+    def root(ctx: task.OrchestrationContext, _):
+        result = yield ctx.call_sub_orchestrator(child)
+        return result
+
+    def child(ctx: task.OrchestrationContext, _):
+        result = yield ctx.wait_for_external_event("my_event")
+        return result
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(root)
+        w.add_orchestrator(child)
+        w.start()
+
+        async with AsyncTaskHubGrpcClient() as client:
+            id = await client.schedule_new_orchestration(root)
+            state = await client.wait_for_orchestration_start(id, timeout=30)
+            assert state is not None
+            assert state.runtime_status == OrchestrationStatus.RUNNING
+
+            # Terminate the root orchestration (recursive is True by default)
+            await client.terminate_orchestration(id, output="some reason for termination")
+            state = await client.wait_for_orchestration_completion(id, timeout=30)
+            assert state is not None
+            assert state.runtime_status == OrchestrationStatus.TERMINATED
+
+            # Verify that the child orchestration is also terminated
+            await client.wait_for_orchestration_completion(id, timeout=30)
+            assert state is not None
+            assert state.runtime_status == OrchestrationStatus.TERMINATED
+
+            await client.purge_orchestration(id)
+            state = await client.get_orchestration_state(id)
+            assert state is None
+
+
+async def test_continue_as_new():
+    all_results = []
+
+    def orchestrator(ctx: task.OrchestrationContext, input: int):
+        result = yield ctx.wait_for_external_event("my_event")
+        if not ctx.is_replaying:
+            # NOTE: Real orchestrations should never interact with nonlocal variables like this.
+            nonlocal all_results  # noqa: F824
+            all_results.append(result)
+
+        if len(all_results) <= 4:
+            ctx.continue_as_new(max(all_results), save_events=True)
+        else:
+            return all_results
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(orchestrator)
+        w.start()
+
+        async with AsyncTaskHubGrpcClient() as client:
+            id = await client.schedule_new_orchestration(orchestrator, input=0)
+            await client.raise_orchestration_event(id, "my_event", data=1)
+            await client.raise_orchestration_event(id, "my_event", data=2)
+            await client.raise_orchestration_event(id, "my_event", data=3)
+            await client.raise_orchestration_event(id, "my_event", data=4)
+            await client.raise_orchestration_event(id, "my_event", data=5)
+
+            state = await client.wait_for_orchestration_completion(id, timeout=30)
+            assert state is not None
+            assert state.runtime_status == OrchestrationStatus.COMPLETED
+            assert state.serialized_output == json.dumps(all_results)
+            assert state.serialized_input == json.dumps(4)
+            assert all_results == [1, 2, 3, 4, 5]
+
+
+async def test_retry_policies():
+    # This test verifies that retry policies work as expected. It creates an
+    # orchestration that calls a sub-orchestrator, which in turn calls an
+    # activity that always fails. With retry policies in place the
+    # orchestration still fails, but the number of times the sub-orchestrator
+    # and the activity run increases according to the policies.
+
+    child_orch_counter = 0
+    throw_activity_counter = 0
+
+    # Configure retry policies for both the sub-orchestration and the activity
+    retry_policy = task.RetryPolicy(
+        first_retry_interval=timedelta(seconds=1),
+        max_number_of_attempts=3,
+        backoff_coefficient=1,
+        max_retry_interval=timedelta(seconds=10),
+        retry_timeout=timedelta(seconds=30))
+
+    def parent_orchestrator_with_retry(ctx: task.OrchestrationContext, _):
+        yield ctx.call_sub_orchestrator(child_orchestrator_with_retry, retry_policy=retry_policy)
+
+    def child_orchestrator_with_retry(ctx: task.OrchestrationContext, _):
+        nonlocal child_orch_counter
+        if not ctx.is_replaying:
+            # NOTE: Real orchestrations should never interact with nonlocal variables like this.
+            # This is done only for testing purposes.
+            child_orch_counter += 1
+        yield ctx.call_activity(throw_activity_with_retry, retry_policy=retry_policy)
+
+    def throw_activity_with_retry(ctx: task.ActivityContext, _):
+        nonlocal throw_activity_counter
+        throw_activity_counter += 1
+        raise RuntimeError("Kah-BOOOOM!!!")
+
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(parent_orchestrator_with_retry)
+        w.add_orchestrator(child_orchestrator_with_retry)
+        w.add_activity(throw_activity_with_retry)
+        w.start()
+
+        async with AsyncTaskHubGrpcClient() as client:
+            id = await client.schedule_new_orchestration(parent_orchestrator_with_retry)
+            state = await client.wait_for_orchestration_completion(id, timeout=30)
+            assert state is not None
+            assert state.runtime_status == OrchestrationStatus.FAILED
+            assert state.failure_details is not None
+            assert state.failure_details.error_type == "TaskFailedError"
+            assert state.failure_details.message.startswith("Sub-orchestration task #1 failed:")
+            assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!")
+            assert state.failure_details.stack_trace is not None
+            assert throw_activity_counter == 9
+            assert child_orch_counter == 3
+
+
+async def test_retry_timeout():
+    # This test verifies that the retry timeout works as expected.
+    # The max number of attempts is 5 and the retry timeout is 14 seconds.
+    # The total backoff through the 4th attempt is 1 + 2 + 4 + 8 = 15 seconds,
+    # so the 5th attempt is never made and the orchestration fails.
+    throw_activity_counter = 0
+    retry_policy = task.RetryPolicy(
+        first_retry_interval=timedelta(seconds=1),
+        max_number_of_attempts=5,
+        backoff_coefficient=2,
+        max_retry_interval=timedelta(seconds=10),
+        retry_timeout=timedelta(seconds=14))
+
+    def mock_orchestrator(ctx: task.OrchestrationContext, _):
+        yield ctx.call_activity(throw_activity, retry_policy=retry_policy)
+
+    def throw_activity(ctx: task.ActivityContext, _):
+        nonlocal throw_activity_counter
+        throw_activity_counter += 1
+        raise RuntimeError("Kah-BOOOOM!!!")
+
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(mock_orchestrator)
+        w.add_activity(throw_activity)
+        w.start()
+
+        async with AsyncTaskHubGrpcClient() as client:
+            id = await client.schedule_new_orchestration(mock_orchestrator)
+            state = await client.wait_for_orchestration_completion(id, timeout=30)
+            assert state is not None
+            assert state.runtime_status == OrchestrationStatus.FAILED
+            assert state.failure_details is not None
+            assert state.failure_details.error_type == "TaskFailedError"
+            assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!")
+            assert state.failure_details.stack_trace is not None
+            assert throw_activity_counter == 4
+
+
+async def test_custom_status():
+
+    def empty_orchestrator(ctx: task.OrchestrationContext, _):
+        ctx.set_custom_status("foobaz")
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(empty_orchestrator)
+        w.start()
+
+        async with AsyncTaskHubGrpcClient() as client:
+            id = await client.schedule_new_orchestration(empty_orchestrator)
+            state = await client.wait_for_orchestration_completion(id, timeout=30)
+
+    assert state is not None
+    assert state.name == task.get_name(empty_orchestrator)
+    assert state.instance_id == id
+    assert state.failure_details is None
+    assert state.runtime_status == OrchestrationStatus.COMPLETED
+    assert state.serialized_input is None
+    assert state.serialized_output is None
+    assert state.serialized_custom_status == "\"foobaz\""
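A sanity check on the retry arithmetic asserted above: nested policies multiply, because each sub-orchestration attempt restarts the activity's retry loop, while test_retry_timeout stops after 4 attempts because the backoff sum 1 + 2 + 4 + 8 = 15s exceeds the 14s retry_timeout. A sketch using the same RetryPolicy shape as the tests:

    from datetime import timedelta

    from durabletask import task

    retry_policy = task.RetryPolicy(
        first_retry_interval=timedelta(seconds=1),
        max_number_of_attempts=3,
        backoff_coefficient=1,
        max_retry_interval=timedelta(seconds=10),
        retry_timeout=timedelta(seconds=30))

    # 3 sub-orchestration attempts x 3 activity attempts each = 9 activity calls.
    assert 3 * 3 == 9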
diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py
index 21f6c6c..c784135 100644
--- a/tests/durabletask/test_orchestration_executor.py
+++ b/tests/durabletask/test_orchestration_executor.py
@@ -634,7 +634,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
         yield ctx.call_sub_orchestrator(suborchestrator, input=None)
 
     registry = worker._Registry()
-    suborchestrator_name = registry.add_orchestrator(suborchestrator)
+    registry.add_orchestrator(suborchestrator)
     orchestrator_name = registry.add_orchestrator(orchestrator)
 
     exec_evt = helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None)
@@ -666,7 +666,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
         yield ctx.call_sub_orchestrator(suborchestrator, input=None, app_id="target-app")
 
     registry = worker._Registry()
-    suborchestrator_name = registry.add_orchestrator(suborchestrator)
+    registry.add_orchestrator(suborchestrator)
     orchestrator_name = registry.add_orchestrator(orchestrator)
 
     exec_evt = helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None)
diff --git a/tests/durabletask/test_orchestration_wait.py b/tests/durabletask/test_orchestration_wait.py
index 03f7e30..c27345f 100644
--- a/tests/durabletask/test_orchestration_wait.py
+++ b/tests/durabletask/test_orchestration_wait.py
@@ -1,11 +1,9 @@
-from unittest.mock import patch, ANY, Mock
+from unittest.mock import Mock
 
 from durabletask.client import TaskHubGrpcClient
-from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
-from durabletask.internal.shared import (get_default_host_address,
-                                         get_grpc_channel)
 
 import pytest
 
+
 @pytest.mark.parametrize("timeout", [None, 0, 5])
 def test_wait_for_orchestration_start_timeout(timeout):
     instance_id = "test-instance"
@@ -34,6 +32,7 @@ def test_wait_for_orchestration_start_timeout(timeout):
     else:
         assert kwargs.get('timeout') == timeout
 
+
 @pytest.mark.parametrize("timeout", [None, 0, 5])
 def test_wait_for_orchestration_completion_timeout(timeout):
     instance_id = "test-instance"
diff --git a/tests/durabletask/test_task.py b/tests/durabletask/test_task.py
new file mode 100644
index 0000000..81cc8a2
--- /dev/null
+++ b/tests/durabletask/test_task.py
@@ -0,0 +1,70 @@
+# Copyright (c) The Dapr Authors.
+# Licensed under the MIT License.
+
+"""Unit tests for durabletask.task primitives."""
+
+from durabletask import task
+
+
+def test_when_all_empty_returns_successfully():
+    """task.when_all([]) should complete immediately and return an empty list."""
+    when_all_task = task.when_all([])
+
+    assert when_all_task.is_complete
+    assert when_all_task.get_result() == []
+
+
+def test_when_any_empty_returns_successfully():
+    """task.when_any([]) should complete immediately and return an empty list."""
+    when_any_task = task.when_any([])
+
+    assert when_any_task.is_complete
+    assert when_any_task.get_result() == []
+
+
+def test_when_all_happy_path_returns_ordered_results_and_completes_last():
+    c1 = task.CompletableTask()
+    c2 = task.CompletableTask()
+    c3 = task.CompletableTask()
+
+    all_task = task.when_all([c1, c2, c3])
+
+    assert not all_task.is_complete
+
+    c2.complete("two")
+
+    assert not all_task.is_complete
+
+    c1.complete("one")
+
+    assert not all_task.is_complete
+
+    c3.complete("three")
+
+    assert all_task.is_complete
+
+    assert all_task.get_result() == ["one", "two", "three"]
+
+
+def test_when_any_happy_path_returns_winner_task_and_completes_on_first():
+    a = task.CompletableTask()
+    b = task.CompletableTask()
+
+    any_task = task.when_any([a, b])
+
+    assert not any_task.is_complete
+
+    b.complete("B")
+
+    assert any_task.is_complete
+
+    winner = any_task.get_result()
+
+    assert winner is b
+
+    assert winner.get_result() == "B"
+
+    # Completing the other child should not change the winner
+    a.complete("A")
+
+    assert any_task.get_result() is b