diff --git a/.github/workflows/durabletask-azuremanaged.yml b/.github/workflows/durabletask-azuremanaged.yml
index 73017e4..e2215a3 100644
--- a/.github/workflows/durabletask-azuremanaged.yml
+++ b/.github/workflows/durabletask-azuremanaged.yml
@@ -72,6 +72,15 @@ jobs:
           python -m pip install --upgrade pip
           pip install -r requirements.txt

+      - name: Install durabletask-azuremanaged locally
+        working-directory: durabletask-azuremanaged
+        run: |
+          pip install . --no-deps --force-reinstall
+
+      - name: Install durabletask locally
+        run: |
+          pip install . --no-deps --force-reinstall
+
       - name: Run the tests
         working-directory: tests/durabletask-azuremanaged
         run: |
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 1c929ac..824a8c3 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -30,5 +30,6 @@
         "jacoco.xml",
         "coverage.cobertura.xml"
     ],
-    "makefile.configureOnOpen": false
+    "makefile.configureOnOpen": false,
+    "debugpy.debugJustMyCode": false
 }
\ No newline at end of file
diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py
index e1c2445..ffc0a7e 100644
--- a/durabletask-azuremanaged/durabletask/azuremanaged/client.py
+++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py
@@ -17,7 +17,8 @@ def __init__(self, *,
                  host_address: str,
                  taskhub: str,
                  token_credential: Optional[TokenCredential],
-                 secure_channel: bool = True):
+                 secure_channel: bool = True,
+                 default_version: Optional[str] = None):

         if not taskhub:
             raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
@@ -30,4 +31,5 @@ def __init__(self, *,
             host_address=host_address,
             secure_channel=secure_channel,
             metadata=None,
-            interceptors=interceptors)
+            interceptors=interceptors,
+            default_version=default_version)
diff --git a/durabletask/__init__.py b/durabletask/__init__.py
index 88af82b..e0e73d3 100644
--- a/durabletask/__init__.py
+++ b/durabletask/__init__.py
@@ -3,8 +3,8 @@

 """Durable Task SDK for Python"""

-from durabletask.worker import ConcurrencyOptions
+from durabletask.worker import ConcurrencyOptions, VersioningOptions

-__all__ = ["ConcurrencyOptions"]
+__all__ = ["ConcurrencyOptions", "VersioningOptions"]

 PACKAGE_NAME = "durabletask"
diff --git a/durabletask/client.py b/durabletask/client.py
index 591aac3..bc3abed 100644
--- a/durabletask/client.py
+++ b/durabletask/client.py
@@ -98,7 +98,8 @@ def __init__(self, *,
                  log_handler: Optional[logging.Handler] = None,
                  log_formatter: Optional[logging.Formatter] = None,
                  secure_channel: bool = False,
-                 interceptors: Optional[Sequence[shared.ClientInterceptor]] = None):
+                 interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
+                 default_version: Optional[str] = None):

         # If the caller provided metadata, we need to create a new interceptor for it and
         # add it to the list of interceptors.
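The client-side half of the feature: `schedule_new_orchestration` gains a per-call `version`, and the client falls back to its `default_version` when the call omits one (see the hunk below). A minimal usage sketch, assuming a locally reachable sidecar and an already-registered orchestrator; `my_orchestrator` is a placeholder name:

```python
from durabletask import client

c = client.TaskHubGrpcClient(default_version="1.0.0")

# An explicit version overrides the client's default_version...
first = c.schedule_new_orchestration("my_orchestrator", input=1, version="1.1.0")

# ...while omitting it stamps the instance with default_version ("1.0.0").
second = c.schedule_new_orchestration("my_orchestrator", input=1)
```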
@@ -118,13 +119,15 @@ def __init__(self, *,
         )
         self._stub = stubs.TaskHubSidecarServiceStub(channel)
         self._logger = shared.get_logger("client", log_handler, log_formatter)
+        self.default_version = default_version

     def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
                                    input: Optional[TInput] = None,
                                    instance_id: Optional[str] = None,
                                    start_at: Optional[datetime] = None,
                                    reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None,
-                                   tags: Optional[dict[str, str]] = None) -> str:
+                                   tags: Optional[dict[str, str]] = None,
+                                   version: Optional[str] = None) -> str:

         name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)

@@ -133,9 +136,9 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
             instanceId=instance_id if instance_id else uuid.uuid4().hex,
             input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
             scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
-            version=wrappers_pb2.StringValue(value=""),
+            version=helpers.get_string_value(version if version else self.default_version),
             orchestrationIdReusePolicy=reuse_id_policy,
-            tags=tags,
+            tags=tags
         )

         self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
diff --git a/durabletask/internal/exceptions.py b/durabletask/internal/exceptions.py
new file mode 100644
index 0000000..efda599
--- /dev/null
+++ b/durabletask/internal/exceptions.py
@@ -0,0 +1,7 @@
+class VersionFailureException(Exception):
+    pass
+
+
+class AbandonOrchestrationError(Exception):
+    def __init__(self, *args: object) -> None:
+        super().__init__(*args)
diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py
index 29f29e0..6140dec 100644
--- a/durabletask/internal/helpers.py
+++ b/durabletask/internal/helpers.py
@@ -199,11 +199,13 @@ def new_create_sub_orchestration_action(
         id: int,
         name: str,
         instance_id: Optional[str],
-        encoded_input: Optional[str]) -> pb.OrchestratorAction:
+        encoded_input: Optional[str],
+        version: Optional[str]) -> pb.OrchestratorAction:
     return pb.OrchestratorAction(id=id, createSubOrchestration=pb.CreateSubOrchestrationAction(
         name=name,
         instanceId=instance_id,
-        input=get_string_value(encoded_input)
+        input=get_string_value(encoded_input),
+        version=get_string_value(version)
     ))
diff --git a/durabletask/task.py b/durabletask/task.py
index 1424436..14f5fac 100644
--- a/durabletask/task.py
+++ b/durabletask/task.py
@@ -35,6 +35,21 @@ def instance_id(self) -> str:
         """
         pass

+    @property
+    @abstractmethod
+    def version(self) -> Optional[str]:
+        """Get the version of the orchestration instance.
+
+        This version is set when the orchestration is scheduled and can be used
+        to determine which version of the orchestrator function is being executed.
+
+        Returns
+        -------
+        Optional[str]
+            The version of the orchestration instance, or None if not set.
+        """
+        pass
+
     @property
     @abstractmethod
     def current_utc_datetime(self) -> datetime:
@@ -126,7 +141,8 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
     def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
                               input: Optional[TInput] = None,
                               instance_id: Optional[str] = None,
-                              retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]:
+                              retry_policy: Optional[RetryPolicy] = None,
+                              version: Optional[str] = None) -> Task[TOutput]:
         """Schedule sub-orchestrator function for execution.

         Parameters
diff --git a/durabletask/worker.py b/durabletask/worker.py
index 8a85070..2a1626d 100644
--- a/durabletask/worker.py
+++ b/durabletask/worker.py
@@ -10,12 +10,15 @@
 from datetime import datetime, timedelta
 from threading import Event, Thread
 from types import GeneratorType
+from enum import Enum
 from typing import Any, Generator, Optional, Sequence, TypeVar, Union
+from packaging.version import InvalidVersion, parse

 import grpc
 from google.protobuf import empty_pb2

 import durabletask.internal.helpers as ph
+import durabletask.internal.exceptions as pe
 import durabletask.internal.orchestrator_service_pb2 as pb
 import durabletask.internal.orchestrator_service_pb2_grpc as stubs
 import durabletask.internal.shared as shared
@@ -72,9 +75,56 @@ def __init__(
         )


+class VersionMatchStrategy(Enum):
+    """Enumeration for version matching strategies."""
+
+    NONE = 1
+    STRICT = 2
+    CURRENT_OR_OLDER = 3
+
+
+class VersionFailureStrategy(Enum):
+    """Enumeration for version failure strategies."""
+
+    REJECT = 1
+    FAIL = 2
+
+
+class VersioningOptions:
+    """Configuration options for orchestrator and activity versioning.
+
+    This class provides options to control how versioning is handled for orchestrators
+    and activities, including whether to use the default version and how to compare versions.
+    """
+
+    version: Optional[str] = None
+    default_version: Optional[str] = None
+    match_strategy: Optional[VersionMatchStrategy] = None
+    failure_strategy: Optional[VersionFailureStrategy] = None
+
+    def __init__(self, version: Optional[str] = None,
+                 default_version: Optional[str] = None,
+                 match_strategy: Optional[VersionMatchStrategy] = None,
+                 failure_strategy: Optional[VersionFailureStrategy] = None
+                 ):
+        """Initialize versioning options.
+
+        Args:
+            version: The version of orchestrations that the worker can work on.
+            default_version: The default version that will be used for starting new orchestrations.
+            match_strategy: The versioning strategy for the Durable Task worker.
+            failure_strategy: The versioning failure strategy for the Durable Task worker.
+        """
+        self.version = version
+        self.default_version = default_version
+        self.match_strategy = match_strategy
+        self.failure_strategy = failure_strategy
+
+
 class _Registry:
     orchestrators: dict[str, task.Orchestrator]
     activities: dict[str, task.Activity]
+    versioning: Optional[VersioningOptions] = None

     def __init__(self):
         self.orchestrators = {}
@@ -279,6 +329,12 @@ def add_activity(self, fn: task.Activity) -> str:
         )
         return self._registry.add_activity(fn)

+    def use_versioning(self, version: VersioningOptions) -> None:
+        """Initializes versioning options for sub-orchestrators and activities."""
+        if self._is_running:
+            raise RuntimeError("Cannot set default version while the worker is running.")
+        self._registry.versioning = version
+
     def start(self):
         """Starts the worker on a background thread and begins listening for work items."""
         if self._is_running:
@@ -513,6 +569,16 @@ def _execute_orchestrator(
                 customStatus=ph.get_string_value(result.encoded_custom_status),
                 completionToken=completionToken,
             )
+        except pe.AbandonOrchestrationError:
+            self._logger.info(
+                f"Abandoning orchestration. InstanceId = '{req.instanceId}'. Completion token = '{completionToken}'"
+            )
+            stub.AbandonTaskOrchestratorWorkItem(
+                pb.AbandonOrchestrationTaskRequest(
+                    completionToken=completionToken
+                )
+            )
+            return
         except Exception as ex:
             self._logger.exception(
                 f"An error occurred while trying to execute instance '{req.instanceId}': {ex}"
@@ -574,7 +640,7 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
     _generator: Optional[Generator[task.Task, Any, Any]]
     _previous_task: Optional[task.Task]

-    def __init__(self, instance_id: str):
+    def __init__(self, instance_id: str, registry: _Registry):
         self._generator = None
         self._is_replaying = True
         self._is_complete = False
@@ -584,6 +650,8 @@ def __init__(self, instance_id: str):
         self._sequence_number = 0
         self._current_utc_datetime = datetime(1000, 1, 1)
         self._instance_id = instance_id
+        self._registry = registry
+        self._version: Optional[str] = None
         self._completion_status: Optional[pb.OrchestrationStatus] = None
         self._received_events: dict[str, list[Any]] = {}
         self._pending_events: dict[str, list[task.CompletableTask]] = {}
@@ -646,7 +714,7 @@ def set_complete(
         )
         self._pending_actions[action.id] = action

-    def set_failed(self, ex: Exception):
+    def set_failed(self, ex: Union[Exception, pb.TaskFailureDetails]):
         if self._is_complete:
             return

@@ -658,7 +726,7 @@ def set_failed(self, ex: Exception):
             self.next_sequence_number(),
             pb.ORCHESTRATION_STATUS_FAILED,
             None,
-            ph.new_failure_details(ex),
+            ph.new_failure_details(ex) if isinstance(ex, Exception) else ex,
         )
         self._pending_actions[action.id] = action

@@ -709,6 +777,10 @@ def next_sequence_number(self) -> int:
     def instance_id(self) -> str:
         return self._instance_id

+    @property
+    def version(self) -> Optional[str]:
+        return self._version
+
     @property
     def current_utc_datetime(self) -> datetime:
         return self._current_utc_datetime
@@ -768,9 +840,12 @@ def call_sub_orchestrator(
         input: Optional[TInput] = None,
         instance_id: Optional[str] = None,
         retry_policy: Optional[task.RetryPolicy] = None,
+        version: Optional[str] = None,
     ) -> task.Task[TOutput]:
         id = self.next_sequence_number()
         orchestrator_name = task.get_name(orchestrator)
+        default_version = self._registry.versioning.default_version if self._registry.versioning else None
+        orchestrator_version = version if version else default_version
         self.call_activity_function_helper(
             id,
             orchestrator_name,
@@ -778,6 +853,7 @@ def call_sub_orchestrator(
             retry_policy=retry_policy,
             is_sub_orch=True,
             instance_id=instance_id,
+            version=orchestrator_version
         )
         return self._pending_tasks.get(id, task.CompletableTask())

@@ -792,6 +868,7 @@ def call_activity_function_helper(
         is_sub_orch: bool = False,
         instance_id: Optional[str] = None,
         fn_task: Optional[task.CompletableTask[TOutput]] = None,
+        version: Optional[str] = None,
     ):
         if id is None:
             id = self.next_sequence_number()
@@ -816,7 +893,7 @@ def call_activity_function_helper(
             if not isinstance(activity_function, str):
                 raise ValueError("Orchestrator function name must be a string")
             action = ph.new_create_sub_orchestration_action(
-                id, activity_function, instance_id, encoded_input
+                id, activity_function, instance_id, encoded_input, version
             )
         self._pending_actions[id] = action
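On the worker side, `VersioningOptions` is applied through `use_versioning` before `start()` (it raises `RuntimeError` on a running worker). A minimal sketch of the intended configuration, mirroring the tests later in this diff; connection details are omitted:

```python
from durabletask import worker

with worker.TaskHubGrpcWorker() as w:
    # Accept work at this worker's version or older; fail anything newer.
    w.use_versioning(worker.VersioningOptions(
        version="1.1.0",          # the version this worker executes
        default_version="1.1.0",  # default stamped onto sub-orchestrations
        match_strategy=worker.VersionMatchStrategy.CURRENT_OR_OLDER,
        failure_strategy=worker.VersionFailureStrategy.FAIL,
    ))
    # Orchestrators and activities are registered here, then:
    w.start()
```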
@@ -892,7 +969,8 @@ def execute(
                 "The new history event list must have at least one event in it."
             )

-        ctx = _RuntimeOrchestrationContext(instance_id)
+        ctx = _RuntimeOrchestrationContext(instance_id, self._registry)
+        version_failure = None
         try:
             # Rebuild local state by replaying old history into the orchestrator function
             self._logger.debug(
@@ -902,6 +980,23 @@ def execute(
             for old_event in old_events:
                 self.process_event(ctx, old_event)

+            # Process versioning if applicable
+            execution_started_events = [e.executionStarted for e in old_events if e.HasField("executionStarted")]
+            # We only check versioning if there are executionStarted events - otherwise, on the first replay when
+            # ctx.version will be None, we may invalidate orchestrations early depending on the versioning strategy.
+            if self._registry.versioning and len(execution_started_events) > 0:
+                version_failure = self.evaluate_orchestration_versioning(
+                    self._registry.versioning,
+                    ctx.version
+                )
+                if version_failure:
+                    self._logger.warning(
+                        f"Orchestration version did not meet worker versioning requirements. "
+                        f"Error action = '{self._registry.versioning.failure_strategy}'. "
+                        f"Version error = '{version_failure}'"
+                    )
+                    raise pe.VersionFailureException
+
             # Get new actions by executing newly received events into the orchestrator function
             if self._logger.level <= logging.DEBUG:
                 summary = _get_new_event_summary(new_events)
@@ -912,6 +1007,15 @@ def execute(
             for new_event in new_events:
                 self.process_event(ctx, new_event)

+        except pe.VersionFailureException as ex:
+            if self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.FAIL:
+                if version_failure:
+                    ctx.set_failed(version_failure)
+                else:
+                    ctx.set_failed(ex)
+            elif self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.REJECT:
+                raise pe.AbandonOrchestrationError
+
         except Exception as ex:
             # Unhandled exceptions fail the orchestration
             ctx.set_failed(ex)
@@ -961,6 +1065,9 @@ def process_event(
                         f"A '{event.executionStarted.name}' orchestrator was not registered."
                     )

+                if event.executionStarted.version:
+                    ctx._version = event.executionStarted.version.value
+
                 # deserialize the input, if any
                 input = None
                 if (
@@ -1223,6 +1330,48 @@ def process_event(
             # The orchestrator generator function completed
             ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED)

+    def evaluate_orchestration_versioning(self, versioning: Optional[VersioningOptions], orchestration_version: Optional[str]) -> Optional[pb.TaskFailureDetails]:
+        if versioning is None:
+            return None
+        version_comparison = self.compare_versions(orchestration_version, versioning.version)
+        if versioning.match_strategy == VersionMatchStrategy.NONE:
+            return None
+        elif versioning.match_strategy == VersionMatchStrategy.STRICT:
+            if version_comparison != 0:
+                return pb.TaskFailureDetails(
+                    errorType="VersionMismatch",
+                    errorMessage=f"The orchestration version '{orchestration_version}' does not match the worker version '{versioning.version}'.",
+                    isNonRetriable=True,
+                )
+        elif versioning.match_strategy == VersionMatchStrategy.CURRENT_OR_OLDER:
+            if version_comparison > 0:
+                return pb.TaskFailureDetails(
+                    errorType="VersionMismatch",
+                    errorMessage=f"The orchestration version '{orchestration_version}' is greater than the worker version '{versioning.version}'.",
+                    isNonRetriable=True,
+                )
+        else:
+            # If there is a type of versioning we don't understand, it is better to treat it as a versioning failure.
+            return pb.TaskFailureDetails(
+                errorType="VersionMismatch",
+                errorMessage=f"The version match strategy '{versioning.match_strategy}' is unknown.",
+                isNonRetriable=True,
+            )
+
+    def compare_versions(self, source_version: Optional[str], default_version: Optional[str]) -> int:
+        if not source_version and not default_version:
+            return 0
+        if not source_version:
+            return -1
+        if not default_version:
+            return 1
+        try:
+            source_version_parsed = parse(source_version)
+            default_version_parsed = parse(default_version)
+            return (source_version_parsed > default_version_parsed) - (source_version_parsed < default_version_parsed)
+        except InvalidVersion:
+            return (source_version > default_version) - (source_version < default_version)
+

 class _ActivityExecutor:
     def __init__(self, registry: _Registry, logger: logging.Logger):
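`compare_versions` returns -1/0/1 and prefers PEP 440 semantics via `packaging`, degrading to plain string comparison only when a version does not parse; `None` sorts below any concrete version. A standalone illustration of the same logic (not part of the diff):

```python
from packaging.version import InvalidVersion, parse

def cmp(a: str, b: str) -> int:
    # Same shape as compare_versions above: PEP 440 first, lexical fallback.
    try:
        va, vb = parse(a), parse(b)
        return (va > vb) - (va < vb)
    except InvalidVersion:
        return (a > b) - (a < b)

assert cmp("1.0.10", "1.0.9") == 1   # numeric, not lexical: 10 > 9
assert cmp("1.1", "1.1.0") == 0      # PEP 440 pads missing release segments
```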
diff --git a/pyproject.toml b/pyproject.toml
index 5438ca4..bc8ddb0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,7 +27,8 @@ readme = "README.md"
 dependencies = [
     "grpcio",
     "protobuf",
-    "asyncio"
+    "asyncio",
+    "packaging"
 ]

 [project.urls]
diff --git a/requirements.txt b/requirements.txt
index 721453b..f32d350 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ protobuf
 pytest
 pytest-cov
 azure-identity
-asyncio
\ No newline at end of file
+asyncio
+packaging
\ No newline at end of file
diff --git a/tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py b/tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py
index 9b7603f..6155733 100644
--- a/tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py
+++ b/tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py
@@ -13,7 +13,7 @@ from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker

 # NOTE: These tests assume a sidecar process is running. Example command:
-# docker run --name durabletask-sidecar -p 4001:4001 --env 'DURABLETASK_SIDECAR_LOGLEVEL=Debug' --rm cgillum/durabletask-sidecar:latest start --backend Emulator
+# docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest
 pytestmark = pytest.mark.dts

 # Read the environment variables
diff --git a/tests/durabletask-azuremanaged/test_dts_orchestration_versioning_e2e.py b/tests/durabletask-azuremanaged/test_dts_orchestration_versioning_e2e.py
new file mode 100644
index 0000000..8b62185
--- /dev/null
+++ b/tests/durabletask-azuremanaged/test_dts_orchestration_versioning_e2e.py
@@ -0,0 +1,374 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import json
+import os
+
+import pytest
+
+from durabletask import client, task, worker
+from durabletask.azuremanaged.client import DurableTaskSchedulerClient
+from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
+
+# NOTE: These tests assume a sidecar process is running. Example command:
+# docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest
+pytestmark = pytest.mark.dts
+
+# Read the environment variables
+taskhub_name = os.getenv("TASKHUB", "default")
+endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
+
+
+def plus_one(_: task.ActivityContext, input: int) -> int:
+    return input + 1
+
+
+def plus_two(_: task.ActivityContext, input: int) -> int:
+    return input + 2
+
+
+def single_activity(ctx: task.OrchestrationContext, start_val: int):
+    yield ctx.call_activity(plus_one, input=start_val)
+    return "Success"
+
+
+def sequence(ctx: task.OrchestrationContext, start_val: int):
+    numbers = [start_val]
+    current = start_val
+    for _ in range(10):
+        current = yield ctx.call_activity(plus_one, input=current)
+        numbers.append(current)
+    return numbers
+
+
+def test_versioned_orchestration_succeeds():
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(sequence)
+        w.add_activity(plus_one)
+        w.use_versioning(worker.VersioningOptions(
+            version="1.0.0",
+            default_version="1.0.0",
+            match_strategy=worker.VersionMatchStrategy.CURRENT_OR_OLDER,
+            failure_strategy=worker.VersionFailureStrategy.FAIL
+        ))
+        w.start()
+
+        task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                                     taskhub=taskhub_name, token_credential=None,
+                                                     default_version="1.0.0")
+        id = task_hub_client.schedule_new_orchestration(sequence, input=1, version="1.0.0")
+        state = task_hub_client.wait_for_orchestration_completion(
+            id, timeout=30)
+
+        assert state is not None
+        assert state.name == task.get_name(sequence)
+        assert state.instance_id == id
+        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+        assert state.failure_details is None
+        assert state.serialized_input == json.dumps(1)
+        assert state.serialized_output == json.dumps([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+        assert state.serialized_custom_status is None
+
+
+def test_lower_version_worker_fails():
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(single_activity)
+        w.use_versioning(worker.VersioningOptions(
+            version="1.0.0",
+            default_version="1.0.0",
+            match_strategy=worker.VersionMatchStrategy.CURRENT_OR_OLDER,
+            failure_strategy=worker.VersionFailureStrategy.FAIL
+        ))
+        w.start()
+
+        task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                                     taskhub=taskhub_name, token_credential=None,
+                                                     default_version="1.0.0")
+        id = task_hub_client.schedule_new_orchestration(single_activity, input=1, version="1.1.0")
+        state = task_hub_client.wait_for_orchestration_completion(
+            id, timeout=30)
+
+        assert state is not None
+        assert state.name == task.get_name(single_activity)
+        assert state.instance_id == id
+        assert state.runtime_status == client.OrchestrationStatus.FAILED
+        assert state.failure_details is not None
+        assert state.failure_details.message.find("The orchestration version '1.1.0' is greater than the worker version '1.0.0'.") >= 0
+
+
+def test_lower_version_worker_no_match_strategy_succeeds():
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(single_activity)
+        w.add_activity(plus_one)
+        w.use_versioning(worker.VersioningOptions(
+            version="1.0.0",
+            default_version="1.0.0",
+            match_strategy=worker.VersionMatchStrategy.NONE,
+            failure_strategy=worker.VersionFailureStrategy.FAIL
+        ))
+        w.start()
+
+        task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                                     taskhub=taskhub_name, token_credential=None,
+                                                     default_version="1.0.0")
+        id = task_hub_client.schedule_new_orchestration(single_activity, input=1, version="1.1.0")
+        state = task_hub_client.wait_for_orchestration_completion(
+            id, timeout=30)
+
+        assert state is not None
+        assert state.name == task.get_name(single_activity)
+        assert state.instance_id == id
+        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+        assert state.failure_details is None
+
+
+def test_upper_version_worker_succeeds():
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(single_activity)
+        w.add_activity(plus_one)
+        w.use_versioning(worker.VersioningOptions(
+            version="1.1.0",
+            default_version="1.1.0",
+            match_strategy=worker.VersionMatchStrategy.CURRENT_OR_OLDER,
+            failure_strategy=worker.VersionFailureStrategy.FAIL
+        ))
+        w.start()
+
+        task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                                     taskhub=taskhub_name, token_credential=None,
+                                                     default_version="1.1.0")
+        id = task_hub_client.schedule_new_orchestration(single_activity, input=1, version="1.0.0")
+        state = task_hub_client.wait_for_orchestration_completion(
+            id, timeout=30)
+
+        assert state is not None
+        assert state.name == task.get_name(single_activity)
+        assert state.instance_id == id
+        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+        assert state.failure_details is None
+
+
+def test_upper_version_worker_strict_fails():
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(single_activity)
+        w.add_activity(plus_one)
+        w.use_versioning(worker.VersioningOptions(
+            version="1.1.0",
+            default_version="1.1.0",
+            match_strategy=worker.VersionMatchStrategy.STRICT,
+            failure_strategy=worker.VersionFailureStrategy.FAIL
+        ))
+        w.start()
+
+        task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                                     taskhub=taskhub_name, token_credential=None,
+                                                     default_version="1.1.0")
+        id = task_hub_client.schedule_new_orchestration(single_activity, input=1, version="1.0.0")
+        state = task_hub_client.wait_for_orchestration_completion(
+            id, timeout=30)
+
+        assert state is not None
+        assert state.name == task.get_name(single_activity)
+        assert state.instance_id == id
+        assert state.runtime_status == client.OrchestrationStatus.FAILED
+        assert state.failure_details is not None
+        assert state.failure_details.message.find("The orchestration version '1.0.0' does not match the worker version '1.1.0'.") >= 0
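The two failure strategies diverge here: `FAIL` converts the mismatch into a terminal `VersionMismatch` failure via `ctx.set_failed`, while `REJECT` raises `AbandonOrchestrationError` so the work item is handed back through `AbandonTaskOrchestratorWorkItem` and can be redelivered to a worker with an acceptable version. A sketch of the `REJECT` configuration the next test exercises (`w` is a worker as in the tests above):

```python
# This 1.0.0 worker strictly refuses 1.1.0 instances but abandons rather than
# fails them, leaving the orchestration for a 1.1.0 worker to complete later.
w.use_versioning(worker.VersioningOptions(
    version="1.0.0",
    default_version="1.1.0",
    match_strategy=worker.VersionMatchStrategy.STRICT,
    failure_strategy=worker.VersionFailureStrategy.REJECT,
))
```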
+
+
+def test_reject_abandons_and_reprocess():
+    # Start a worker, which will connect to the sidecar in a background thread
+    instance_id: str = ''
+    thrown = False
+    try:
+        with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                        taskhub=taskhub_name, token_credential=None) as w:
+            w.add_orchestrator(single_activity)
+            w.add_activity(plus_one)
+            w.use_versioning(worker.VersioningOptions(
+                version="1.0.0",
+                default_version="1.1.0",
+                match_strategy=worker.VersionMatchStrategy.STRICT,
+                failure_strategy=worker.VersionFailureStrategy.REJECT
+            ))
+            w.start()
+
+            task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                                         taskhub=taskhub_name, token_credential=None,
+                                                         default_version="1.1.0")
+            instance_id = task_hub_client.schedule_new_orchestration(single_activity, input=1)
+            state = task_hub_client.wait_for_orchestration_completion(
+                instance_id, timeout=5)
+    except TimeoutError as e:
+        thrown = True
+        assert str(e).find("Timed-out waiting for the orchestration to complete") >= 0
+
+    assert thrown is True
+
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(single_activity)
+        w.add_activity(plus_one)
+        w.use_versioning(worker.VersioningOptions(
+            version="1.1.0",
+            default_version="1.1.0",
+            match_strategy=worker.VersionMatchStrategy.STRICT,
+            failure_strategy=worker.VersionFailureStrategy.REJECT
+        ))
+        w.start()
+
+        task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                                     taskhub=taskhub_name, token_credential=None,
+                                                     default_version="1.1.0")
+        state = task_hub_client.wait_for_orchestration_completion(
+            instance_id, timeout=5)
+
+        assert state is not None
+        assert state.name == task.get_name(single_activity)
+        assert state.instance_id == instance_id
+        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+        assert state.failure_details is None
+
+
+def multiversion_sequence(ctx: task.OrchestrationContext, start_val: int):
+    if ctx.version == "1.0.0":
+        result = yield ctx.call_activity(plus_one, input=start_val)
+    elif ctx.version == "1.1.0":
+        result = yield ctx.call_activity(plus_two, input=start_val)
+    else:
+        raise ValueError(f"Unsupported version: {ctx.version}")
+    return result
+
+
+def test_multiversion_orchestration_succeeds():
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(multiversion_sequence)
+        w.add_activity(plus_one)
+        w.add_activity(plus_two)
+        w.use_versioning(worker.VersioningOptions(
+            version="1.1.0",
+            default_version="1.1.0",
+            match_strategy=worker.VersionMatchStrategy.CURRENT_OR_OLDER,
+            failure_strategy=worker.VersionFailureStrategy.FAIL
+        ))
+        w.start()
+
+        task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                                     taskhub=taskhub_name, token_credential=None,
+                                                     default_version="1.1.0")
+        id = task_hub_client.schedule_new_orchestration(multiversion_sequence, input=1, version="1.0.0")
+        state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)
+
+        id_2 = task_hub_client.schedule_new_orchestration(multiversion_sequence, input=1, version="1.1.0")
+        state_2 = task_hub_client.wait_for_orchestration_completion(id_2, timeout=30)
+
+        print(state.failure_details.message if state and state.failure_details else "State is None")
+        print(state_2.failure_details.message if state_2 and state_2.failure_details else "State is None")
+
+        assert state is not None
+        assert state.name == task.get_name(multiversion_sequence)
+        assert state.instance_id == id
+        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+        assert state.failure_details is None
+        assert state.serialized_input == json.dumps(1)
+        assert state.serialized_output == json.dumps(2)
+
+        assert state_2 is not None
+        assert state_2.name == task.get_name(multiversion_sequence)
+        assert state_2.instance_id == id_2
+        assert state_2.runtime_status == client.OrchestrationStatus.COMPLETED
+        assert state_2.failure_details is None
+        assert state_2.serialized_input == json.dumps(1)
+        assert state_2.serialized_output == json.dumps(3)
+
+
+def sequence_suborchestrator(ctx: task.OrchestrationContext, start_val: int):
+    numbers = []
+    for current in range(start_val, start_val + 5):
+        current = yield ctx.call_activity(plus_one, input=current)
+        numbers.append(current)
+    return numbers
+
+
+def sequence_parent(ctx: task.OrchestrationContext, sub_orchestration_version: str):
+    tasks = []
+    for current in range(2):
+        tasks.append(ctx.call_sub_orchestrator(sequence_suborchestrator, input=current * 5, version=sub_orchestration_version))
+    results = yield task.when_all(tasks)
+    numbers = []
+    for result in results:
+        numbers.extend(result)
+    return numbers
+
+
+def test_versioned_sub_orchestration_succeeds():
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(sequence_parent)
+        w.add_orchestrator(sequence_suborchestrator)
+        w.add_activity(plus_one)
+        w.use_versioning(worker.VersioningOptions(
+            version="1.0.0",
+            default_version="1.0.0",
+            match_strategy=worker.VersionMatchStrategy.CURRENT_OR_OLDER,
+            failure_strategy=worker.VersionFailureStrategy.FAIL
+        ))
+        w.start()
+
+        task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                                     taskhub=taskhub_name, token_credential=None,
+                                                     default_version="1.0.0")
+        id = task_hub_client.schedule_new_orchestration(sequence_parent, input='1.0.0', version="1.0.0")
+        state = task_hub_client.wait_for_orchestration_completion(
+            id, timeout=30)
+
+        assert state is not None
+        assert state.name == task.get_name(sequence_parent)
+        assert state.instance_id == id
+        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+        assert state.failure_details is None
+        assert state.serialized_input == json.dumps("1.0.0")
+        assert state.serialized_output == json.dumps([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
+        assert state.serialized_custom_status is None
+
+
+def test_higher_versioned_sub_orchestration_fails():
+    # Start a worker, which will connect to the sidecar in a background thread
+    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
+                                    taskhub=taskhub_name, token_credential=None) as w:
+        w.add_orchestrator(sequence_parent)
+        w.add_orchestrator(sequence_suborchestrator)
+        w.add_activity(plus_one)
+        w.use_versioning(worker.VersioningOptions(
+            version="1.0.0",
+            default_version="1.0.0",
+            match_strategy=worker.VersionMatchStrategy.CURRENT_OR_OLDER,
+            failure_strategy=worker.VersionFailureStrategy.FAIL
+        ))
+        w.start()
+
+        task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
+                                                     taskhub=taskhub_name, token_credential=None,
+                                                     default_version="1.0.0")
+        id = task_hub_client.schedule_new_orchestration(sequence_parent, input='1.1.0', version="1.0.0")
+        state = task_hub_client.wait_for_orchestration_completion(
+            id, timeout=30)
+
+        assert state is not None
+        assert state.name == task.get_name(sequence_parent)
+        assert state.instance_id == id
+        assert state.runtime_status == client.OrchestrationStatus.FAILED
+        assert state.failure_details is not None
+        assert state.failure_details.message.find("The orchestration version '1.1.0' is greater than the worker version '1.0.0'.") >= 0
diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py
index 63d2058..63f14c5 100644
--- a/tests/durabletask/test_orchestration_e2e.py
+++ b/tests/durabletask/test_orchestration_e2e.py
@@ -11,7 +11,8 @@ from durabletask import client, task, worker

 # NOTE: These tests assume a sidecar process is running. Example command:
-# docker run --name durabletask-sidecar -p 4001:4001 --env 'DURABLETASK_SIDECAR_LOGLEVEL=Debug' --rm cgillum/durabletask-sidecar:latest start --backend Emulator
+# go install github.com/microsoft/durabletask-go@main
+# durabletask-go --port 4001
 pytestmark = pytest.mark.e2e

diff --git a/tests/durabletask/test_orchestration_versioning_e2e.py b/tests/durabletask/test_orchestration_versioning_e2e.py
new file mode 100644
index 0000000..45dd2bd
--- /dev/null
+++ b/tests/durabletask/test_orchestration_versioning_e2e.py
@@ -0,0 +1,58 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import json
+import warnings
+
+import pytest
+
+from durabletask import client, task, worker
+
+# NOTE: These tests assume a sidecar process is running. Example command:
+# go install github.com/microsoft/durabletask-go@main
+# durabletask-go --port 4001
+pytestmark = pytest.mark.e2e
+
+
+def test_versioned_orchestration_succeeds():
+    warnings.warn("Skipping test_versioned_orchestration_succeeds. "
+                  "Currently not passing as the sidecar does not support versioning yet")
+    return  # Currently not passing as the sidecar does not support versioning yet
+    # Remove these lines to run the test after the sidecar is updated
+
+    def plus_one(_: task.ActivityContext, input: int) -> int:
+        return input + 1
+
+    def sequence(ctx: task.OrchestrationContext, start_val: int):
+        numbers = [start_val]
+        current = start_val
+        for _ in range(10):
+            current = yield ctx.call_activity(plus_one, input=current, tags={'Activity': 'PlusOne'})
+            numbers.append(current)
+        return numbers
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(sequence)
+        w.add_activity(plus_one)
+        w.use_versioning(worker.VersioningOptions(
+            version="1.0.0",
+            default_version="1.0.0",
+            match_strategy=worker.VersionMatchStrategy.CURRENT_OR_OLDER,
+            failure_strategy=worker.VersionFailureStrategy.FAIL
+        ))
+        w.start()
+
+        task_hub_client = client.TaskHubGrpcClient(default_version="1.0.0")
+        id = task_hub_client.schedule_new_orchestration(sequence, input=1, tags={'Orchestration': 'Sequence'}, version="1.0.0")
+        state = task_hub_client.wait_for_orchestration_completion(
+            id, timeout=30)
+
+        assert state is not None
+        assert state.name == task.get_name(sequence)
+        assert state.instance_id == id
+        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+        assert state.failure_details is None
+        assert state.serialized_input == json.dumps(1)
+        assert state.serialized_output == json.dumps([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+        assert state.serialized_custom_status is None
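Taken together with `call_sub_orchestrator` in worker.py, the version attached to a sub-orchestration resolves in a fixed order: an explicit `version` argument wins, otherwise the worker's `VersioningOptions.default_version` applies, and with no versioning configured nothing is attached. A hypothetical sketch (`child` stands for any registered sub-orchestrator):

```python
from durabletask import task

def parent(ctx: task.OrchestrationContext, _):
    # 1. An explicit argument always wins.
    r1 = yield ctx.call_sub_orchestrator(child, input=1, version="2.0.0")
    # 2. Otherwise the worker's VersioningOptions.default_version is used.
    r2 = yield ctx.call_sub_orchestrator(child, input=2)
    return [r1, r2]
```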