From c62edb7a8b645aecc5661e88064129010513ba61 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Sun, 7 Sep 2025 21:30:07 -0600 Subject: [PATCH 01/21] Signalling working, some other logic --- durabletask/client.py | 16 +- durabletask/entities/__init__.py | 10 + durabletask/entities/entity_instance_id.py | 28 ++ durabletask/internal/entity_lock_releaser.py | 6 + durabletask/internal/entity_state_shim.py | 26 ++ durabletask/internal/helpers.py | 52 +++ .../internal/orchestration_entity_context.py | 78 +++++ durabletask/task.py | 107 +++++- durabletask/worker.py | 326 +++++++++++++++++- 9 files changed, 643 insertions(+), 6 deletions(-) create mode 100644 durabletask/entities/__init__.py create mode 100644 durabletask/entities/entity_instance_id.py create mode 100644 durabletask/internal/entity_lock_releaser.py create mode 100644 durabletask/internal/entity_state_shim.py create mode 100644 durabletask/internal/orchestration_entity_context.py diff --git a/durabletask/client.py b/durabletask/client.py index bc3abed..5f01678 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -4,13 +4,14 @@ import logging import uuid from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, timezone from enum import Enum from typing import Any, Optional, Sequence, TypeVar, Union import grpc from google.protobuf import wrappers_pb2 +from durabletask.entities.entity_instance_id import EntityInstanceId import durabletask.internal.helpers as helpers import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs @@ -227,3 +228,16 @@ def purge_orchestration(self, instance_id: str, recursive: bool = True): req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive) self._logger.info(f"Purging instance '{instance_id}'.") self._stub.PurgeInstances(req) + + def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, input: Optional[Any] = None, 
signal_entity_options=None, cancellation=None): + scheduled_time = signal_entity_options.scheduled_time if signal_entity_options and signal_entity_options.scheduled_time else None + req = pb.SignalEntityRequest( + instanceId=str(entity_instance_id), + requestId=str(uuid.uuid4()), + name=operation_name, + input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None, + scheduledTime=scheduled_time, + requestTime=helpers.new_timestamp(datetime.now(timezone.utc)) + ) + self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") + self._stub.SignalEntity(req, timeout=cancellation.timeout if cancellation else None) diff --git a/durabletask/entities/__init__.py b/durabletask/entities/__init__.py new file mode 100644 index 0000000..47ea475 --- /dev/null +++ b/durabletask/entities/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Durable Task SDK for Python entities component""" + +from durabletask.entities.entity_instance_id import EntityInstanceId + +__all__ = ["EntityInstanceId"] + +PACKAGE_NAME = "durabletask.entities" diff --git a/durabletask/entities/entity_instance_id.py b/durabletask/entities/entity_instance_id.py new file mode 100644 index 0000000..1fee44f --- /dev/null +++ b/durabletask/entities/entity_instance_id.py @@ -0,0 +1,28 @@ +from typing import Optional + + +class EntityInstanceId: + def __init__(self, entity: str, key: str): + self.entity = entity + self.key = key + + def __str__(self) -> str: + return f"@{self.entity}@{self.key}" + + def __eq__(self, other): + if not isinstance(other, EntityInstanceId): + return False + return self.entity == other.entity and self.key == other.key + + def __lt__(self, other): + if not isinstance(other, EntityInstanceId): + return self < other + return str(self) < str(other) + + @staticmethod + def parse(entity_id: str) -> Optional["EntityInstanceId"]: + try: + _, entity, key = entity_id.split("@", 2) + return 
EntityInstanceId(entity=entity, key=key) + except ValueError as ex: + raise ValueError("Invalid entity ID format", ex) diff --git a/durabletask/internal/entity_lock_releaser.py b/durabletask/internal/entity_lock_releaser.py new file mode 100644 index 0000000..bc898b5 --- /dev/null +++ b/durabletask/internal/entity_lock_releaser.py @@ -0,0 +1,6 @@ +from durabletask.entities.entity_instance_id import EntityInstanceId + + +class EntityLockReleaser: + def __init__(self, entities: list[EntityInstanceId]): + self.entities = entities diff --git a/durabletask/internal/entity_state_shim.py b/durabletask/internal/entity_state_shim.py new file mode 100644 index 0000000..fba2d81 --- /dev/null +++ b/durabletask/internal/entity_state_shim.py @@ -0,0 +1,26 @@ +from typing import Optional, Type + + +class StateShim: + def __init__(self, start_state): + self._current_state = start_state + self._checkpoint_state = start_state + + def get_state(self, intended_type: Optional[Type]): + if not intended_type: + return self._current_state + if isinstance(self._current_state, intended_type) or self._current_state is None: + return self._current_state + return intended_type(self._current_state) + + def set_state(self, state): + self._current_state = state + + def commit(self): + self._checkpoint_state = self._current_state + + def rollback(self): + self._current_state = self._checkpoint_state + + def reset(self): + self._current_state = None diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 6140dec..2fe35e8 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -159,6 +159,12 @@ def get_string_value(val: Optional[str]) -> Optional[wrappers_pb2.StringValue]: return wrappers_pb2.StringValue(value=val) +def get_string_value_or_empty(val: Optional[str]) -> wrappers_pb2.StringValue: + if val is None: + return wrappers_pb2.StringValue(value="") + return wrappers_pb2.StringValue(value=val) + + def new_complete_orchestration_action( 
id: int, status: pb.OrchestrationStatus, @@ -189,6 +195,52 @@ def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str], )) +def new_call_entity_action(id: int, name: str, encoded_input: Optional[str]): + return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationCalled=pb.EntityOperationCalledEvent( + requestId=None, + targetInstanceId=get_string_value(name), + input=get_string_value(encoded_input) + ))) + + +def new_signal_entity_action(id: int, name: str): + return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationSignaled=pb.EntityOperationSignaledEvent( + requestId=None, + targetInstanceId=get_string_value(name) + ))) + + +def new_lock_entities_action(id: int, instance_id: str, critical_section_id: str, entity_ids: list[str]): + return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityLockRequested=pb.EntityLockRequestedEvent( + parentInstanceId=get_string_value(instance_id), + criticalSectionId=critical_section_id, + lockSet=entity_ids, + position=0 + ))) + + +def convert_to_entity_batch_request(req: pb.EntityRequest) -> tuple[pb.EntityBatchRequest, list[pb.OperationInfo]]: + batch_request = pb.EntityBatchRequest(entityState=req.entityState, instanceId=req.instanceId, operations=[]) + + operation_infos: list[pb.OperationInfo] = [] + + for op in req.operationRequests: + if op.HasField("entityOperationSignaled"): + batch_request.operations.append(pb.OperationRequest(requestId=op.entityOperationSignaled.requestId, + operation=op.entityOperationSignaled.operation, + input=op.entityOperationSignaled.input)) + operation_infos.append(pb.OperationInfo(requestId=op.entityOperationSignaled.requestId, + responseDestination=None)) + elif op.HasField("entityOperationCalled"): + batch_request.operations.append(pb.OperationRequest(requestId=op.entityOperationCalled.requestId, + operation=op.entityOperationCalled.operation, + 
input=op.entityOperationCalled.input)) + operation_infos.append(pb.OperationInfo(requestId=op.entityOperationCalled.requestId, + responseDestination=None)) + + return batch_request, operation_infos + + def new_timestamp(dt: datetime) -> timestamp_pb2.Timestamp: ts = timestamp_pb2.Timestamp() ts.FromDatetime(dt) diff --git a/durabletask/internal/orchestration_entity_context.py b/durabletask/internal/orchestration_entity_context.py new file mode 100644 index 0000000..b74097f --- /dev/null +++ b/durabletask/internal/orchestration_entity_context.py @@ -0,0 +1,78 @@ +from datetime import datetime +from typing import Generator, List, Optional, Tuple + +from durabletask.entities.entity_instance_id import EntityInstanceId + + +class OrchestrationEntityContext: + def __init__(self, instance_id: str): + self.instance_id = instance_id + + self.lock_acquisition_pending = False + + self.critical_section_id = None + self.critical_section_locks = [] + self.available_locks = [] + + @property + def is_inside_critical_section(self) -> bool: + return self.critical_section_id is not None + + def get_available_entities(self) -> Generator[str, None, None]: + if self.is_inside_critical_section: + for available_lock in self.available_locks: + yield available_lock + + def validate_suborchestration_transition(self) -> Tuple[bool, str]: + if self.is_inside_critical_section: + return False, "While holding locks, cannot call suborchestrators." + return True, "" + + def validate_operation_transition(self, target_instance_id: EntityInstanceId, one_way: bool) -> Tuple[bool, str]: + if self.is_inside_critical_section: + lock_to_use = target_instance_id + if one_way: + if target_instance_id in self.critical_section_locks: + return False, "Must not signal a locked entity from a critical section." 
+ else: + try: + self.available_locks.remove(lock_to_use) + except ValueError: + if self.lock_acquisition_pending: + return False, "Must await the completion of the lock request prior to calling any entity." + if lock_to_use in self.critical_section_locks: + return False, "Must not call an entity from a critical section while a prior call to the same entity is still pending." + else: + return False, "Must not call an entity from a critical section if it is not one of the locked entities." + return True, "" + + def validate_acquire_transition(self) -> Tuple[bool, str]: + if self.is_inside_critical_section: + return False, "Must not enter another critical section from within a critical section." + return True, "" + + def recover_lock_after_call(self, target_instance_id: EntityInstanceId): + if self.is_inside_critical_section: + self.available_locks.append(target_instance_id) + + def emit_lock_release_messages(self): + raise NotImplementedError() + + def emit_request_message(self, target, operation_name: str, one_way: bool, operation_id: str, + scheduled_time_utc: datetime, input: Optional[str], + request_time: Optional[datetime] = None, create_trace: bool = False): + raise NotImplementedError() + + def emit_acquire_message(self, lock_request_id: str, entities: List[str]): + raise NotImplementedError() + + def complete_acquire(self, result, critical_section_id): + # TODO: HashSet or equivalent + self.available_locks = self.critical_section_locks + self.lock_acquisition_pending = False + + def adjust_outgoing_message(self, instance_id: str, request_message, capped_time: datetime) -> str: + raise NotImplementedError() + + def deserialize_entity_response_event(self, event_content: str): + raise NotImplementedError() diff --git a/durabletask/task.py b/durabletask/task.py index 14f5fac..043e83c 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -7,8 +7,11 @@ import math from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing 
import Any, Callable, Generator, Generic, Optional, TypeVar, Union +from typing import Any, Callable, Generator, Generic, Optional, Type, TypeVar, Union +from durabletask.entities.entity_instance_id import EntityInstanceId +from durabletask.internal.entity_lock_releaser import EntityLockReleaser +from durabletask.internal.entity_state_shim import StateShim import durabletask.internal.helpers as pbh import durabletask.internal.orchestrator_service_pb2 as pb @@ -137,6 +140,55 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *, """ pass + @abstractmethod + def call_entity(self, entity: EntityInstanceId, *, + input: Optional[TInput] = None): + """Schedule entity function for execution. + + Parameters + ---------- + entity: EntityInstanceId + The ID of the entity instance to call. + input: Optional[TInput] + The optional JSON-serializable input to pass to the entity function. + + Returns + ------- + Task + A Durable Task that completes when the called entity function completes or fails. + """ + pass + + @abstractmethod + def signal_entity( + self, + entity_id: EntityInstanceId + ) -> None: + """Signal an entity function for execution. + + Parameters + ---------- + entity_id: EntityInstanceId + The ID of the entity instance to signal. + """ + pass + + @abstractmethod + def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLockReleaser: + """Lock the specified entity instances for the duration of the orchestration. + + Parameters + ---------- + entities: list[EntityInstanceId] + The list of entity instance IDs to lock. + + Returns + ------- + EntityLockReleaser + A context manager that releases the locks when disposed. 
+ """ + pass + @abstractmethod def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *, input: Optional[TInput] = None, @@ -452,12 +504,65 @@ def task_id(self) -> int: return self._task_id +class EntityContext: + def __init__(self, orchestration_id: str, operation: str, state: StateShim, entity_id: EntityInstanceId): + self._orchestration_id = orchestration_id + self._operation = operation + self._state = state + self._entity_id = entity_id + + @property + def orchestration_id(self) -> str: + """Get the ID of the orchestration instance that scheduled this entity. + + Returns + ------- + str + The ID of the current orchestration instance. + """ + return self._orchestration_id + + @property + def operation(self) -> str: + """Get the operation associated with this entity invocation. + + The operation is a string that identifies the specific action being + performed on the entity. It can be used to distinguish between + multiple operations that are part of the same entity invocation. + + Returns + ------- + str + The operation associated with this entity invocation. + """ + return self._operation + + def get_state(self, intended_type: Optional[Type] = None): + return self._state.get_state(intended_type) + + def set_state(self, new_state): + self._state.set_state(new_state) + + @property + def entity_id(self) -> EntityInstanceId: + """Get the ID of the entity instance. + + Returns + ------- + str + The ID of the current entity instance. 
+ """ + return self._entity_id + + # Orchestrators are generators that yield tasks and receive/return any type Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task, Any, Any], TOutput]] # Activities are simple functions that can be scheduled by orchestrators Activity = Callable[[ActivityContext, TInput], TOutput] +Entity = Callable[[EntityContext, TInput], TOutput] + class RetryPolicy: """Represents the retry policy for an orchestration or activity function.""" diff --git a/durabletask/worker.py b/durabletask/worker.py index ba5f0ba..5011903 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -3,20 +3,28 @@ import asyncio import inspect +import json import logging import os import random from concurrent.futures import ThreadPoolExecutor -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from threading import Event, Thread from types import GeneratorType from enum import Enum -from typing import Any, Generator, Optional, Sequence, TypeVar, Union +from typing import Any, Generator, List, Optional, Sequence, TypeVar, Union +import uuid from packaging.version import InvalidVersion, parse import grpc from google.protobuf import empty_pb2 +from durabletask.internal import helpers +from durabletask.internal.entity_state_shim import StateShim +from durabletask.internal.helpers import new_timestamp +from durabletask.entities.entity_instance_id import EntityInstanceId +from durabletask.internal.entity_lock_releaser import EntityLockReleaser +from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext import durabletask.internal.helpers as ph import durabletask.internal.exceptions as pe import durabletask.internal.orchestrator_service_pb2 as pb @@ -124,11 +132,13 @@ def __init__(self, version: Optional[str] = None, class _Registry: orchestrators: dict[str, task.Orchestrator] activities: dict[str, task.Activity] + entities: dict[str, task.Entity] versioning: 
Optional[VersioningOptions] = None def __init__(self): self.orchestrators = {} self.activities = {} + self.entities = {} def add_orchestrator(self, fn: task.Orchestrator) -> str: if fn is None: @@ -168,6 +178,25 @@ def add_named_activity(self, name: str, fn: task.Activity) -> None: def get_activity(self, name: str) -> Optional[task.Activity]: return self.activities.get(name) + def add_entity(self, fn: task.Entity) -> str: + if fn is None: + raise ValueError("An entity function argument is required.") + + name = task.get_name(fn) + self.add_named_entity(name, fn) + return name + + def add_named_entity(self, name: str, fn: task.Entity) -> None: + if not name: + raise ValueError("A non-empty entity name is required.") + if name in self.entities: + raise ValueError(f"A '{name}' entity already exists.") + + self.entities[name] = fn + + def get_entity(self, name: str) -> Optional[task.Entity]: + return self.entities.get(name) + class OrchestratorNotRegisteredError(ValueError): """Raised when attempting to start an orchestration that is not registered""" @@ -181,6 +210,12 @@ class ActivityNotRegisteredError(ValueError): pass +class EntityNotRegisteredError(ValueError): + """Raised when attempting to call an entity that is not registered""" + + pass + + class TaskHubGrpcWorker: """A gRPC-based worker for processing durable task orchestrations and activities. @@ -329,6 +364,14 @@ def add_activity(self, fn: task.Activity) -> str: ) return self._registry.add_activity(fn) + def add_entity(self, fn: task.Entity) -> str: + """Registers an entity function with the worker.""" + if self._is_running: + raise RuntimeError( + "Entities cannot be added while the worker is running." 
+ ) + return self._registry.add_entity(fn) + def use_versioning(self, version: VersioningOptions) -> None: """Initializes versioning options for sub-orchestrators and activities.""" if self._is_running: @@ -490,6 +533,20 @@ def stream_reader(): stub, work_item.completionToken, ) + elif work_item.HasField("entityRequest"): + self._async_worker_manager.submit_activity( + self._execute_entity_batch, + work_item.entityRequest, + stub, + work_item.completionToken, + ) + elif work_item.HasField("entityRequestV2"): + self._async_worker_manager.submit_activity( + self._execute_entity_batch, + work_item.entityRequestV2, + stub, + work_item.completionToken + ) elif work_item.HasField("healthPing"): pass else: @@ -635,12 +692,80 @@ def _execute_activity( f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}" ) + def _execute_entity_batch( + self, + req: Union[pb.EntityBatchRequest, pb.EntityRequest], + stub: stubs.TaskHubSidecarServiceStub, + completionToken, + ): + if isinstance(req, pb.EntityRequest): + req, operation_infos = helpers.convert_to_entity_batch_request(req) + + entity_state = StateShim(shared.from_json(req.entityState.value) if req.entityState.value else None) + + instance_id = req.instanceId + + results: list[pb.OperationResult] = [] + for operation in req.operations: + start_time = datetime.now(timezone.utc) + executor = _EntityExecutor(self._registry, self._logger) + entity_instance_id = EntityInstanceId.parse(instance_id) + if not entity_instance_id: + raise RuntimeError(f"Invalid entity instance ID '{operation.requestId}' in entity operation request.") + + operation_result = None + + try: + entity_result = executor.execute( + instance_id, entity_instance_id, operation.operation, entity_state, operation.input.value + ) + + entity_result = ph.get_string_value_or_empty(entity_result) + operation_result = pb.OperationResult(success=pb.OperationResultSuccess( + result=entity_result, + 
startTimeUtc=new_timestamp(start_time), + endTimeUtc=new_timestamp(datetime.now(timezone.utc)) + )) + results.append(operation_result) + + entity_state.commit() + except Exception as ex: + self._logger.exception(ex) + operation_result = pb.OperationResult(failure=pb.OperationResultFailure( + failureDetails=ph.new_failure_details(ex), + startTimeUtc=new_timestamp(start_time), + endTimeUtc=new_timestamp(datetime.now(timezone.utc)) + )) + results.append(operation_result) + + entity_state.rollback() + + try: + stub.CompleteEntityTask(operation_result) + except Exception as ex: + self._logger.exception( + f"Failed to deliver entity response for '{entity_instance_id}' of orchestration ID '{instance_id}' to sidecar: {ex}" + ) + + batch_result = pb.EntityBatchResult( + results=results, + actions=None, # TODO: Context should also provide actions like signaling another entity or starting sub-orchestrations + entityState=helpers.get_string_value(shared.to_json(entity_state._current_state)) if entity_state._current_state else None, + failureDetails=None, + operationInfos=operation_infos, + completionToken=completionToken, + ) + + # TODO: Reset context + + return batch_result + class _RuntimeOrchestrationContext(task.OrchestrationContext): _generator: Optional[Generator[task.Task, Any, Any]] _previous_task: Optional[task.Task] - def __init__(self, instance_id: str, registry: _Registry): + def __init__(self, instance_id: str, registry: _Registry, entity_context: OrchestrationEntityContext): self._generator = None self._is_replaying = True self._is_complete = False @@ -651,6 +776,7 @@ def __init__(self, instance_id: str, registry: _Registry): self._current_utc_datetime = datetime(1000, 1, 1) self._instance_id = instance_id self._registry = registry + self._entity_context = entity_context self._version: Optional[str] = None self._completion_status: Optional[pb.OrchestrationStatus] = None self._received_events: dict[str, list[Any]] = {} @@ -833,6 +959,41 @@ def call_activity( ) 
return self._pending_tasks.get(id, task.CompletableTask()) + def call_entity( + self, + entity_id: EntityInstanceId, + *, + input: Optional[TInput] = None, + ) -> task.Task: + id = self.next_sequence_number() + + self.call_entity_function_helper( + id, entity_id, input=input + ) + + return self._pending_tasks.get(id, task.CompletableTask()) + + def signal_entity( + self, + entity_id: EntityInstanceId + ) -> None: + id = self.next_sequence_number() + + self.signal_entity_function_helper( + id, entity_id + ) + + def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLockReleaser: + id = self.next_sequence_number() + + self.lock_entities_function_helper( + id, entities + ) + + # Todo: EntityLockReleaser should be a disposable that uses python's using statement + # and should release the locks when disposed + return EntityLockReleaser(entities) + def call_sub_orchestrator( self, orchestrator: task.Orchestrator[TInput, TOutput], @@ -909,6 +1070,70 @@ def call_activity_function_helper( ) self._pending_tasks[id] = fn_task + def call_entity_function_helper( + self, + id: Optional[int], + entity_id: EntityInstanceId, + *, + input: Optional[TInput] = None + ): + if id is None: + id = self.next_sequence_number() + + transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, False) + if not transition_valid: + raise RuntimeError(error_message) + + encoded_input = shared.to_json(input) if input is not None else None + action = ph.new_call_entity_action(id, str(entity_id), encoded_input) + self._pending_actions[id] = action + + fn_task = task.CompletableTask() + self._pending_tasks[id] = fn_task + + def signal_entity_function_helper( + self, + id: Optional[int], + entity_id: EntityInstanceId + ) -> None: + if id is None: + id = self.next_sequence_number() + + transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, True) + + if not transition_valid: + raise RuntimeError(error_message) + + 
action = ph.new_signal_entity_action(id, str(entity_id)) + self._pending_actions[id] = action + + def lock_entities_function_helper( + self, + id: Optional[int], + entity_ids: List[EntityInstanceId] + ): + if id is None: + id = self.next_sequence_number() + + transition_valid, error_message = self._entity_context.validate_acquire_transition() + if not transition_valid: + raise RuntimeError(error_message) + + # Acquire the locks in a globally fixed order to avoid deadlocks + # Also remove duplicates - this can be optimized for perf if necessary + entity_ids = sorted(entity_ids) + entity_ids_dedup = [] + for i, entity_id in enumerate(entity_ids): + if entity_id != entity_ids[i - 1] if i > 0 else None: + entity_ids_dedup.append(entity_id) + + # Use a deterministically replayable unique ID for this lock request + # TODO: Implement deterministically replayable IDs + critical_section_id = str(uuid.uuid4()) + + action = ph.new_lock_entities_action(id, self.instance_id, critical_section_id, [str(eid) for eid in entity_ids_dedup]) + self._pending_actions[id] = action + def wait_for_external_event(self, name: str) -> task.Task: # Check to see if this event has already been received, in which case we # can return it immediately. Otherwise, record out intent to receive an @@ -957,6 +1182,7 @@ def __init__(self, registry: _Registry, logger: logging.Logger): self._logger = logger self._is_suspended = False self._suspended_events: list[pb.HistoryEvent] = [] + self._entity_state: Optional[OrchestrationEntityContext] = None def execute( self, @@ -964,12 +1190,14 @@ def execute( old_events: Sequence[pb.HistoryEvent], new_events: Sequence[pb.HistoryEvent], ) -> ExecutionResults: + self._entity_state = OrchestrationEntityContext(instance_id) + if not new_events: raise task.OrchestrationStateError( "The new history event list must have at least one event in it." 
) - ctx = _RuntimeOrchestrationContext(instance_id, self._registry) + ctx = _RuntimeOrchestrationContext(instance_id, self._registry, self._entity_state) try: # Rebuild local state by replaying old history into the orchestrator function self._logger.debug( @@ -1316,6 +1544,57 @@ def process_event( pb.ORCHESTRATION_STATUS_TERMINATED, is_result_encoded=True, ) + elif event.HasField("entityOperationCalled"): + entity_call_id = event.eventId + action = ctx._pending_actions.pop(entity_call_id, None) + entity_task = ctx._pending_tasks.get(entity_call_id, None) + if not action: + raise _get_non_determinism_error( + entity_call_id, task.get_name(ctx.call_entity) + ) + elif not action.HasField("callEntity"): + expected_method_name = task.get_name(ctx.call_entity) + raise _get_wrong_action_type_error( + entity_call_id, expected_method_name, action + ) + # TODO: Validate entity ID + elif event.HasField("entityOperationSignaled"): + entity_signal_id = event.eventId + action = ctx._pending_actions.pop(entity_signal_id, None) + if not action: + raise _get_non_determinism_error( + entity_signal_id, task.get_name(ctx.signal_entity) + ) + elif not action.HasField("signalEntity"): + expected_method_name = task.get_name(ctx.signal_entity) + raise _get_wrong_action_type_error( + entity_signal_id, expected_method_name, action + ) + elif event.HasField("entityLockRequested"): + if not ctx.is_replaying: + self._logger.info(f"{ctx.instance_id}: Entity lock requested.") + self._logger.info(f"Data: {json.dumps(event.entityLockRequested)}") + pass + elif event.HasField("entityUnlockSent"): + if not ctx.is_replaying: + self._logger.info(f"{ctx.instance_id}: Entity unlock sent.") + self._logger.info(f"Data: {json.dumps(event.entityUnlockSent)}") + pass + elif event.HasField("entityLockGranted"): + if not ctx.is_replaying: + self._logger.info(f"{ctx.instance_id}: Entity lock granted.") + self._logger.info(f"Data: {json.dumps(event.entityLockGranted)}") + pass + elif 
event.HasField("entityOperationCompleted"): + if not ctx.is_replaying: + self._logger.info(f"{ctx.instance_id}: Entity operation completed.") + self._logger.info(f"Data: {json.dumps(event.entityOperationCompleted)}") + pass + elif event.HasField("entityOperationFailed"): + if not ctx.is_replaying: + self._logger.info(f"{ctx.instance_id}: Entity operation failed.") + self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}") + pass else: eventType = event.WhichOneof("eventType") raise task.OrchestrationStateError( @@ -1406,6 +1685,45 @@ def execute( return encoded_output +class _EntityExecutor: + def __init__(self, registry: _Registry, logger: logging.Logger): + self._registry = registry + self._logger = logger + + def execute( + self, + orchestration_id: str, + entity_id: EntityInstanceId, + operation: str, + state: StateShim, + encoded_input: Optional[str], + ) -> Optional[str]: + """Executes an entity function and returns the serialized result, if any.""" + self._logger.debug( + f"{orchestration_id}: Executing entity '{entity_id}'..." + ) + fn = self._registry.get_entity(entity_id.entity) + if not fn: + raise EntityNotRegisteredError( + f"Entity function named '{entity_id.entity}' was not registered!" + ) + + entity_input = shared.from_json(encoded_input) if encoded_input else None + ctx = task.EntityContext(orchestration_id, operation, state, entity_id) + + # Execute the entity function + entity_output = fn(ctx, entity_input) + + encoded_output = ( + shared.to_json(entity_output) if entity_output is not None else None + ) + chars = len(encoded_output) if encoded_output else 0 + self._logger.debug( + f"{orchestration_id}: Entity '{entity_id}' completed successfully with {chars} char(s) of encoded output." 
+ ) + return encoded_output + + def _get_non_determinism_error( task_id: int, action_name: str ) -> task.NonDeterminismError: From 9a74b6d3adc992fc555c4491d8e7e2ac5bd77a17 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Wed, 10 Sep 2025 10:03:58 -0700 Subject: [PATCH 02/21] Entities kind of working --- durabletask/internal/helpers.py | 30 +++++--- durabletask/task.py | 13 +++- durabletask/worker.py | 123 +++++++++++++++++++++++++------- 3 files changed, 126 insertions(+), 40 deletions(-) diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 2fe35e8..aaee47c 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -7,6 +7,7 @@ from google.protobuf import timestamp_pb2, wrappers_pb2 +from durabletask.entities.entity_instance_id import EntityInstanceId import durabletask.internal.orchestrator_service_pb2 as pb # TODO: The new_xxx_event methods are only used by test code and should be moved elsewhere @@ -195,26 +196,30 @@ def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str], )) -def new_call_entity_action(id: int, name: str, encoded_input: Optional[str]): +def new_call_entity_action(id: int, parent_instance_id: str, entity_id: EntityInstanceId, operation: str, encoded_input: Optional[str]): return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationCalled=pb.EntityOperationCalledEvent( - requestId=None, - targetInstanceId=get_string_value(name), - input=get_string_value(encoded_input) + requestId=f"{parent_instance_id}:{id}", + parentInstanceId=get_string_value(parent_instance_id), + targetInstanceId=get_string_value(str(entity_id)), + input=get_string_value(encoded_input), + operation=operation ))) -def new_signal_entity_action(id: int, name: str): +def new_signal_entity_action(id: int, entity_id: EntityInstanceId, operation: str, encoded_input: Optional[str]): return pb.OrchestratorAction(id=id, 
sendEntityMessage=pb.SendEntityMessageAction(entityOperationSignaled=pb.EntityOperationSignaledEvent( - requestId=None, - targetInstanceId=get_string_value(name) + requestId=f"{entity_id}:{id}", + targetInstanceId=get_string_value(str(entity_id)), + operation=operation, + input=get_string_value(encoded_input) ))) -def new_lock_entities_action(id: int, instance_id: str, critical_section_id: str, entity_ids: list[str]): +def new_lock_entities_action(id: int, parent_instance_id: str, critical_section_id: str, entity_ids: list[EntityInstanceId]): return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityLockRequested=pb.EntityLockRequestedEvent( - parentInstanceId=get_string_value(instance_id), + parentInstanceId=get_string_value(parent_instance_id), criticalSectionId=critical_section_id, - lockSet=entity_ids, + lockSet=[str(eid) for eid in entity_ids], position=0 ))) @@ -236,7 +241,10 @@ def convert_to_entity_batch_request(req: pb.EntityRequest) -> tuple[pb.EntityBat operation=op.entityOperationCalled.operation, input=op.entityOperationCalled.input)) operation_infos.append(pb.OperationInfo(requestId=op.entityOperationCalled.requestId, - responseDestination=None)) + responseDestination=pb.OrchestrationInstance( + instanceId=op.entityOperationCalled.parentInstanceId.value, + executionId=op.entityOperationCalled.parentExecutionId + ))) return batch_request, operation_infos diff --git a/durabletask/task.py b/durabletask/task.py index 043e83c..80d982d 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -141,7 +141,8 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *, pass @abstractmethod - def call_entity(self, entity: EntityInstanceId, *, + def call_entity(self, entity: EntityInstanceId, + operation: str, *, input: Optional[TInput] = None): """Schedule entity function for execution. 
@@ -149,6 +150,8 @@ def call_entity(self, entity: EntityInstanceId, *, ---------- entity: EntityInstanceId The ID of the entity instance to call. + operation: str + The name of the operation to invoke on the entity. input: Optional[TInput] The optional JSON-serializable input to pass to the entity function. @@ -162,7 +165,9 @@ def call_entity(self, entity: EntityInstanceId, *, @abstractmethod def signal_entity( self, - entity_id: EntityInstanceId + entity_id: EntityInstanceId, + operation_name: str, + input: Optional[TInput] = None ) -> None: """Signal an entity function for execution. @@ -170,6 +175,10 @@ def signal_entity( ---------- entity_id: EntityInstanceId The ID of the entity instance to signal. + operation_name: str + The name of the operation to invoke on the entity. + input: Optional[TInput] + The optional JSON-serializable input to pass to the entity function. """ pass diff --git a/durabletask/worker.py b/durabletask/worker.py index 5011903..0062768 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -48,6 +48,7 @@ def __init__( self, maximum_concurrent_activity_work_items: Optional[int] = None, maximum_concurrent_orchestration_work_items: Optional[int] = None, + maximum_concurrent_entity_work_items: Optional[int] = None, maximum_thread_pool_workers: Optional[int] = None, ): """Initialize concurrency options. 
@@ -76,6 +77,12 @@ def __init__( else default_concurrency ) + self.maximum_concurrent_entity_work_items = ( + maximum_concurrent_entity_work_items + if maximum_concurrent_entity_work_items is not None + else default_concurrency + ) + self.maximum_thread_pool_workers = ( maximum_thread_pool_workers if maximum_thread_pool_workers is not None @@ -534,14 +541,14 @@ def stream_reader(): work_item.completionToken, ) elif work_item.HasField("entityRequest"): - self._async_worker_manager.submit_activity( + self._async_worker_manager.submit_entity_batch( self._execute_entity_batch, work_item.entityRequest, stub, work_item.completionToken, ) elif work_item.HasField("entityRequestV2"): - self._async_worker_manager.submit_activity( + self._async_worker_manager.submit_entity_batch( self._execute_entity_batch, work_item.entityRequestV2, stub, @@ -740,22 +747,22 @@ def _execute_entity_batch( entity_state.rollback() - try: - stub.CompleteEntityTask(operation_result) - except Exception as ex: - self._logger.exception( - f"Failed to deliver entity response for '{entity_instance_id}' of orchestration ID '{instance_id}' to sidecar: {ex}" - ) - batch_result = pb.EntityBatchResult( results=results, actions=None, # TODO: Context should also provide actions like signaling another entity or starting sub-orchestrations entityState=helpers.get_string_value(shared.to_json(entity_state._current_state)) if entity_state._current_state else None, failureDetails=None, - operationInfos=operation_infos, completionToken=completionToken, + operationInfos=operation_infos, ) + try: + stub.CompleteEntityTask(batch_result) + except Exception as ex: + self._logger.exception( + f"Failed to deliver entity response for '{entity_instance_id}' of orchestration ID '{instance_id}' to sidecar: {ex}" + ) + # TODO: Reset context return batch_result @@ -772,6 +779,7 @@ def __init__(self, instance_id: str, registry: _Registry, entity_context: Orches self._result = None self._pending_actions: dict[int, 
pb.OrchestratorAction] = {} self._pending_tasks: dict[int, task.CompletableTask] = {} + self._entity_task_id_map: dict[str, int] = {} self._sequence_number = 0 self._current_utc_datetime = datetime(1000, 1, 1) self._instance_id = instance_id @@ -962,25 +970,28 @@ def call_activity( def call_entity( self, entity_id: EntityInstanceId, + operation: str, *, input: Optional[TInput] = None, ) -> task.Task: id = self.next_sequence_number() self.call_entity_function_helper( - id, entity_id, input=input + id, entity_id, operation, input=input ) return self._pending_tasks.get(id, task.CompletableTask()) def signal_entity( self, - entity_id: EntityInstanceId + entity_id: EntityInstanceId, + operation: str, + input: Optional[TInput] = None ) -> None: id = self.next_sequence_number() self.signal_entity_function_helper( - id, entity_id + id, entity_id, operation, input ) def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLockReleaser: @@ -1074,8 +1085,9 @@ def call_entity_function_helper( self, id: Optional[int], entity_id: EntityInstanceId, + operation: str, *, - input: Optional[TInput] = None + input: Optional[TInput] = None, ): if id is None: id = self.next_sequence_number() @@ -1083,18 +1095,21 @@ def call_entity_function_helper( transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, False) if not transition_valid: raise RuntimeError(error_message) - + encoded_input = shared.to_json(input) if input is not None else None - action = ph.new_call_entity_action(id, str(entity_id), encoded_input) + action = ph.new_call_entity_action(id, self.instance_id, entity_id, operation, encoded_input) self._pending_actions[id] = action fn_task = task.CompletableTask() self._pending_tasks[id] = fn_task + def signal_entity_function_helper( self, id: Optional[int], - entity_id: EntityInstanceId + entity_id: EntityInstanceId, + operation: str, + input: Optional[TInput] ) -> None: if id is None: id = self.next_sequence_number() @@ -1104,7 
+1119,9 @@ def signal_entity_function_helper( if not transition_valid: raise RuntimeError(error_message) - action = ph.new_signal_entity_action(id, str(entity_id)) + encoded_input = shared.to_json(input) if input is not None else None + + action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input) self._pending_actions[id] = action def lock_entities_function_helper( @@ -1131,7 +1148,7 @@ def lock_entities_function_helper( # TODO: Implement deterministically replayable IDs critical_section_id = str(uuid.uuid4()) - action = ph.new_lock_entities_action(id, self.instance_id, critical_section_id, [str(eid) for eid in entity_ids_dedup]) + action = ph.new_lock_entities_action(id, self.instance_id, critical_section_id, entity_ids_dedup) self._pending_actions[id] = action def wait_for_external_event(self, name: str) -> task.Task: @@ -1545,6 +1562,8 @@ def process_event( is_result_encoded=True, ) elif event.HasField("entityOperationCalled"): + # This history event confirms that the entity operation was successfully scheduled. + # Remove the entityOperationCalled event from the pending action list so we don't schedule it again entity_call_id = event.eventId action = ctx._pending_actions.pop(entity_call_id, None) entity_task = ctx._pending_tasks.get(entity_call_id, None) @@ -1552,20 +1571,23 @@ def process_event( raise _get_non_determinism_error( entity_call_id, task.get_name(ctx.call_entity) ) - elif not action.HasField("callEntity"): + elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityOperationCalled"): expected_method_name = task.get_name(ctx.call_entity) raise _get_wrong_action_type_error( entity_call_id, expected_method_name, action ) + ctx._entity_task_id_map[event.entityOperationCalled.requestId] = entity_call_id # TODO: Validate entity ID elif event.HasField("entityOperationSignaled"): + # This history event confirms that the entity signal was successfully scheduled. 
+ # Remove the entityOperationSignaled event from the pending action list so we don't schedule it entity_signal_id = event.eventId action = ctx._pending_actions.pop(entity_signal_id, None) if not action: raise _get_non_determinism_error( entity_signal_id, task.get_name(ctx.signal_entity) ) - elif not action.HasField("signalEntity"): + elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityOperationSignaled"): expected_method_name = task.get_name(ctx.signal_entity) raise _get_wrong_action_type_error( entity_signal_id, expected_method_name, action @@ -1586,10 +1608,22 @@ def process_event( self._logger.info(f"Data: {json.dumps(event.entityLockGranted)}") pass elif event.HasField("entityOperationCompleted"): - if not ctx.is_replaying: - self._logger.info(f"{ctx.instance_id}: Entity operation completed.") - self._logger.info(f"Data: {json.dumps(event.entityOperationCompleted)}") - pass + request_id = event.entityOperationCompleted.requestId + task_id = ctx._entity_task_id_map.pop(request_id, None) + if not task_id: + raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'") + entity_task = ctx._pending_tasks.pop(task_id, None) + if not entity_task: + if not ctx.is_replaying: + self._logger.warning( + f"{ctx.instance_id}: Ignoring unexpected entityOperationCompleted event with request ID = {request_id}." 
+ ) + return + result = None + if not ph.is_empty(event.entityOperationCompleted.output): + result = shared.from_json(event.entityOperationCompleted.output.value) + entity_task.complete(result) + ctx.resume() elif event.HasField("entityOperationFailed"): if not ctx.is_replaying: self._logger.info(f"{ctx.instance_id}: Entity operation failed.") @@ -1815,13 +1849,16 @@ def __init__(self, concurrency_options: ConcurrencyOptions): self.concurrency_options = concurrency_options self.activity_semaphore = None self.orchestration_semaphore = None + self.entity_semaphore = None # Don't create queues here - defer until we have an event loop self.activity_queue: Optional[asyncio.Queue] = None self.orchestration_queue: Optional[asyncio.Queue] = None + self.entity_batch_queue: Optional[asyncio.Queue] = None self._queue_event_loop: Optional[asyncio.AbstractEventLoop] = None # Store work items when no event loop is available self._pending_activity_work: list = [] self._pending_orchestration_work: list = [] + self._pending_entity_batch_work: list = [] self.thread_pool = ThreadPoolExecutor( max_workers=concurrency_options.maximum_thread_pool_workers, thread_name_prefix="DurableTask", @@ -1838,7 +1875,7 @@ def _ensure_queues_for_current_loop(self): # Check if queues are already properly set up for current loop if self._queue_event_loop is current_loop: - if self.activity_queue is not None and self.orchestration_queue is not None: + if self.activity_queue is not None and self.orchestration_queue is not None and self.entity_batch_queue is not None: # Queues are already bound to the current loop and exist return @@ -1846,6 +1883,7 @@ def _ensure_queues_for_current_loop(self): # First, preserve any existing work items existing_activity_items = [] existing_orchestration_items = [] + existing_entity_batch_items = [] if self.activity_queue is not None: try: @@ -1863,9 +1901,19 @@ def _ensure_queues_for_current_loop(self): except Exception: pass + if self.entity_batch_queue is not None: + 
try: + while not self.entity_batch_queue.empty(): + existing_entity_batch_items.append( + self.entity_batch_queue.get_nowait() + ) + except Exception: + pass + # Create fresh queues for the current event loop self.activity_queue = asyncio.Queue() self.orchestration_queue = asyncio.Queue() + self.entity_batch_queue = asyncio.Queue() self._queue_event_loop = current_loop # Restore the work items to the new queues @@ -1873,16 +1921,21 @@ def _ensure_queues_for_current_loop(self): self.activity_queue.put_nowait(item) for item in existing_orchestration_items: self.orchestration_queue.put_nowait(item) + for item in existing_entity_batch_items: + self.entity_batch_queue.put_nowait(item) # Move pending work items to the queues for item in self._pending_activity_work: self.activity_queue.put_nowait(item) for item in self._pending_orchestration_work: self.orchestration_queue.put_nowait(item) + for item in self._pending_entity_batch_work: + self.entity_batch_queue.put_nowait(item) # Clear the pending work lists self._pending_activity_work.clear() self._pending_orchestration_work.clear() + self._pending_entity_batch_work.clear() async def run(self): # Reset shutdown flag in case this manager is being reused @@ -1898,14 +1951,21 @@ async def run(self): self.orchestration_semaphore = asyncio.Semaphore( self.concurrency_options.maximum_concurrent_orchestration_work_items ) + self.entity_semaphore = asyncio.Semaphore( + self.concurrency_options.maximum_concurrent_entity_work_items + ) # Start background consumers for each work type - if self.activity_queue is not None and self.orchestration_queue is not None: + if self.activity_queue is not None and self.orchestration_queue is not None \ + and self.entity_batch_queue is not None: await asyncio.gather( self._consume_queue(self.activity_queue, self.activity_semaphore), self._consume_queue( self.orchestration_queue, self.orchestration_semaphore ), + self._consume_queue( + self.entity_batch_queue, self.entity_semaphore + ) ) async def 
_consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore): @@ -1975,6 +2035,15 @@ def submit_orchestration(self, func, *args, **kwargs): # No event loop running, store in pending list self._pending_orchestration_work.append(work_item) + def submit_entity_batch(self, func, *args, **kwargs): + work_item = (func, args, kwargs) + self._ensure_queues_for_current_loop() + if self.entity_batch_queue is not None: + self.entity_batch_queue.put_nowait(work_item) + else: + # No event loop running, store in pending list + self._pending_entity_batch_work.append(work_item) + def shutdown(self): self._shutdown = True self.thread_pool.shutdown(wait=True) From e5b6db27347c10318b7dc6710cdd0deccad10108 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Wed, 10 Sep 2025 13:27:21 -0700 Subject: [PATCH 03/21] Entity lock incremental change --- durabletask/entities/entity_lock.py | 14 +++++++++ durabletask/internal/entity_lock_releaser.py | 6 ---- durabletask/internal/helpers.py | 9 ++---- .../internal/orchestration_entity_context.py | 28 +++++++++++++++-- durabletask/task.py | 8 ++--- durabletask/worker.py | 30 ++++++++----------- 6 files changed, 58 insertions(+), 37 deletions(-) create mode 100644 durabletask/entities/entity_lock.py delete mode 100644 durabletask/internal/entity_lock_releaser.py diff --git a/durabletask/entities/entity_lock.py b/durabletask/entities/entity_lock.py new file mode 100644 index 0000000..94df13f --- /dev/null +++ b/durabletask/entities/entity_lock.py @@ -0,0 +1,14 @@ +from durabletask.entities.entity_instance_id import EntityInstanceId +from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext + + +class EntityLock: + def __init__(self, entity_context: OrchestrationEntityContext, entities: list[EntityInstanceId]): + self.entity_context = entity_context + self.entities = entities + + def __enter__(self): + print(f"Locking entities: {self.entities}") + + def __exit__(self, exc_type, exc_val, exc_tb): + 
print(f"Unlocking entities: {self.entities}") diff --git a/durabletask/internal/entity_lock_releaser.py b/durabletask/internal/entity_lock_releaser.py deleted file mode 100644 index bc898b5..0000000 --- a/durabletask/internal/entity_lock_releaser.py +++ /dev/null @@ -1,6 +0,0 @@ -from durabletask.entities.entity_instance_id import EntityInstanceId - - -class EntityLockReleaser: - def __init__(self, entities: list[EntityInstanceId]): - self.entities = entities diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index aaee47c..0ecd650 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -215,13 +215,8 @@ def new_signal_entity_action(id: int, entity_id: EntityInstanceId, operation: st ))) -def new_lock_entities_action(id: int, parent_instance_id: str, critical_section_id: str, entity_ids: list[EntityInstanceId]): - return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityLockRequested=pb.EntityLockRequestedEvent( - parentInstanceId=get_string_value(parent_instance_id), - criticalSectionId=critical_section_id, - lockSet=[str(eid) for eid in entity_ids], - position=0 - ))) +def new_lock_entities_action(id: int, entity_message: pb.SendEntityMessageAction): + return pb.OrchestratorAction(id=id, sendEntityMessage=entity_message) def convert_to_entity_batch_request(req: pb.EntityRequest) -> tuple[pb.EntityBatchRequest, list[pb.OperationInfo]]: diff --git a/durabletask/internal/orchestration_entity_context.py b/durabletask/internal/orchestration_entity_context.py index b74097f..64d8ca9 100644 --- a/durabletask/internal/orchestration_entity_context.py +++ b/durabletask/internal/orchestration_entity_context.py @@ -1,6 +1,8 @@ from datetime import datetime -from typing import Generator, List, Optional, Tuple +from typing import Generator, List, Optional, Tuple, Union +from durabletask.internal.helpers import get_string_value +import durabletask.internal.orchestrator_service_pb2 as pb from 
durabletask.entities.entity_instance_id import EntityInstanceId @@ -63,8 +65,28 @@ def emit_request_message(self, target, operation_name: str, one_way: bool, opera request_time: Optional[datetime] = None, create_trace: bool = False): raise NotImplementedError() - def emit_acquire_message(self, lock_request_id: str, entities: List[str]): - raise NotImplementedError() + def emit_acquire_message(self, critical_section_id: str, entities: List[EntityInstanceId]) -> Union[Tuple[None, None, None], Tuple[str, pb.SendEntityMessageAction, pb.OrchestrationInstance]]: + if not entities: + return None, None, None + + # Acquire the locks in a globally fixed order to avoid deadlocks + # Also remove duplicates - this can be optimized for perf if necessary + entity_ids = sorted(entities) + entity_ids_dedup = [] + for i, entity_id in enumerate(entity_ids): + if entity_id != entity_ids[i - 1] if i > 0 else True: + entity_ids_dedup.append(entity_id) + + target = pb.OrchestrationInstance(instanceId=str(entity_ids_dedup[0])) + request = pb.SendEntityMessageAction(entityLockRequested=pb.EntityLockRequestedEvent( + criticalSectionId=critical_section_id, + parentInstanceId=get_string_value(self.instance_id), + lockSet=entity_ids_dedup, + position=0, + )) + + return "op", request, target + def complete_acquire(self, result, critical_section_id): # TODO: HashSet or equivalent diff --git a/durabletask/task.py b/durabletask/task.py index 80d982d..c8adb56 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -10,7 +10,7 @@ from typing import Any, Callable, Generator, Generic, Optional, Type, TypeVar, Union from durabletask.entities.entity_instance_id import EntityInstanceId -from durabletask.internal.entity_lock_releaser import EntityLockReleaser +from durabletask.entities.entity_lock import EntityLock from durabletask.internal.entity_state_shim import StateShim import durabletask.internal.helpers as pbh import durabletask.internal.orchestrator_service_pb2 as pb @@ -183,7 +183,7 @@ def 
signal_entity( pass @abstractmethod - def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLockReleaser: + def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLock: """Lock the specified entity instances for the duration of the orchestration. Parameters @@ -193,8 +193,8 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLockReleaser: Returns ------- - EntityLockReleaser - A context manager that releases the locks when disposed. + EntityLock + A disposable object that acquires and releases the locks when initialized or disposed. """ pass diff --git a/durabletask/worker.py b/durabletask/worker.py index 0062768..c86032b 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -14,6 +14,7 @@ from enum import Enum from typing import Any, Generator, List, Optional, Sequence, TypeVar, Union import uuid +from durabletask.entities.entity_lock import EntityLock from packaging.version import InvalidVersion, parse import grpc @@ -23,7 +24,6 @@ from durabletask.internal.entity_state_shim import StateShim from durabletask.internal.helpers import new_timestamp from durabletask.entities.entity_instance_id import EntityInstanceId -from durabletask.internal.entity_lock_releaser import EntityLockReleaser from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext import durabletask.internal.helpers as ph import durabletask.internal.exceptions as pe @@ -994,7 +994,7 @@ def signal_entity( id, entity_id, operation, input ) - def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLockReleaser: + def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLock: id = self.next_sequence_number() self.lock_entities_function_helper( @@ -1003,7 +1003,7 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLockReleaser: # Todo: EntityLockReleaser should be a disposable that uses python's using statement # and should release the locks when disposed - return 
EntityLockReleaser(entities) + return EntityLock(self._entity_context, entities) def call_sub_orchestrator( self, @@ -1129,26 +1129,22 @@ def lock_entities_function_helper( id: Optional[int], entity_ids: List[EntityInstanceId] ): + valid, message = self._entity_context.validate_acquire_transition() + if not valid: + raise RuntimeError(message) + if id is None: id = self.next_sequence_number() - transition_valid, error_message = self._entity_context.validate_acquire_transition() - if not transition_valid: - raise RuntimeError(error_message) + # Use a deterministically replayable unique ID for this lock request + critical_section_id = f"{self.instance_id}:{id}" - # Acquire the locks in a globally fixed order to avoid deadlocks - # Also remove duplicates - this can be optimized for perf if necessary - entity_ids = sorted(entity_ids) - entity_ids_dedup = [] - for i, entity_id in enumerate(entity_ids): - if entity_id != entity_ids[i - 1] if i > 0 else None: - entity_ids_dedup.append(entity_id) + event_name, request, target = self._entity_context.emit_acquire_message(critical_section_id, entity_ids) - # Use a deterministically replayable unique ID for this lock request - # TODO: Implement deterministically replayable IDs - critical_section_id = str(uuid.uuid4()) + if not event_name or not request or not target: + raise RuntimeError("Failed to create entity lock request.") - action = ph.new_lock_entities_action(id, self.instance_id, critical_section_id, entity_ids_dedup) + action = ph.new_lock_entities_action(id, request) self._pending_actions[id] = action def wait_for_external_event(self, name: str) -> task.Task: From 4107182d418d3aacd975b00e469dce87216d1fce Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 16 Sep 2025 10:05:18 -0700 Subject: [PATCH 04/21] More entity implementing --- durabletask/entities/__init__.py | 4 +- durabletask/entities/durable_entity.py | 20 ++++ durabletask/entities/entity_lock.py | 19 +-- durabletask/internal/entity_state_shim.py | 31 
++++- .../internal/orchestration_entity_context.py | 32 +++-- durabletask/task.py | 20 ++-- durabletask/worker.py | 109 ++++++++++++------ 7 files changed, 169 insertions(+), 66 deletions(-) create mode 100644 durabletask/entities/durable_entity.py diff --git a/durabletask/entities/__init__.py b/durabletask/entities/__init__.py index 47ea475..93ea7c5 100644 --- a/durabletask/entities/__init__.py +++ b/durabletask/entities/__init__.py @@ -4,7 +4,9 @@ """Durable Task SDK for Python entities component""" from durabletask.entities.entity_instance_id import EntityInstanceId +from durabletask.entities.durable_entity import DurableEntity +from durabletask.entities.entity_lock import EntityLock -__all__ = ["EntityInstanceId"] +__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock"] PACKAGE_NAME = "durabletask.entities" diff --git a/durabletask/entities/durable_entity.py b/durabletask/entities/durable_entity.py new file mode 100644 index 0000000..8f03cfe --- /dev/null +++ b/durabletask/entities/durable_entity.py @@ -0,0 +1,20 @@ +from typing import Any, Optional, Type, TypeVar, overload + +TState = TypeVar("TState") + + +class DurableEntity: + def _initialize_entity_context(self, context): + self.entity_context = context + + @overload + def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ... + + @overload + def get_state(self, intended_type: None = None) -> Any: ... 
+ + def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TState] | Any: + return self.entity_context.get_state(intended_type) + + def set_state(self, state: Any): + self.entity_context.set_state(state) \ No newline at end of file diff --git a/durabletask/entities/entity_lock.py b/durabletask/entities/entity_lock.py index 94df13f..b693ad3 100644 --- a/durabletask/entities/entity_lock.py +++ b/durabletask/entities/entity_lock.py @@ -1,14 +1,19 @@ +import durabletask.internal.helpers as ph + from durabletask.entities.entity_instance_id import EntityInstanceId -from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext +import durabletask.internal.orchestrator_service_pb2 as pb class EntityLock: - def __init__(self, entity_context: OrchestrationEntityContext, entities: list[EntityInstanceId]): - self.entity_context = entity_context - self.entities = entities + def __init__(self, context): + self._context = context def __enter__(self): - print(f"Locking entities: {self.entities}") + return self - def __exit__(self, exc_type, exc_val, exc_tb): - print(f"Unlocking entities: {self.entities}") + def __exit__(self, exc_type, exc_val, exc_tb): # TODO: Handle exceptions? 
+ print(f"Unlocking entities: {self._context._entity_context.critical_section_locks}") + for entity_unlock_message in self._context._entity_context.emit_lock_release_messages(): + task_id = self._context.next_sequence_number() + action = pb.OrchestratorAction(task_id, sendEntityMessage=entity_unlock_message) + self._context._pending_actions[task_id] = action diff --git a/durabletask/internal/entity_state_shim.py b/durabletask/internal/entity_state_shim.py index fba2d81..2239427 100644 --- a/durabletask/internal/entity_state_shim.py +++ b/durabletask/internal/entity_state_shim.py @@ -1,17 +1,36 @@ -from typing import Optional, Type +from ctypes import Union +from typing import Any, TypeVar, runtime_checkable +from typing import Optional, Type, overload +from typing_extensions import Protocol + + +TState = TypeVar("TState") class StateShim: def __init__(self, start_state): - self._current_state = start_state - self._checkpoint_state = start_state + self._current_state: Any = start_state + self._checkpoint_state: Any = start_state + + @overload + def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ... + + @overload + def get_state(self, intended_type: None = None) -> Any: ... 
- def get_state(self, intended_type: Optional[Type]): - if not intended_type: + def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TState] | Any: + if intended_type is None: return self._current_state + if isinstance(self._current_state, intended_type) or self._current_state is None: return self._current_state - return intended_type(self._current_state) + + try: + return intended_type(self._current_state) # type: ignore[call-arg] + except Exception as ex: + raise TypeError( + f"Could not convert state of type '{type(self._current_state).__name__}' to '{intended_type.__name__}'" + ) from ex def set_state(self, state): self._current_state = state diff --git a/durabletask/internal/orchestration_entity_context.py b/durabletask/internal/orchestration_entity_context.py index 64d8ca9..38effed 100644 --- a/durabletask/internal/orchestration_entity_context.py +++ b/durabletask/internal/orchestration_entity_context.py @@ -13,14 +13,14 @@ def __init__(self, instance_id: str): self.lock_acquisition_pending = False self.critical_section_id = None - self.critical_section_locks = [] - self.available_locks = [] + self.critical_section_locks: list[EntityInstanceId] = [] + self.available_locks: list[EntityInstanceId] = [] @property def is_inside_critical_section(self) -> bool: return self.critical_section_id is not None - def get_available_entities(self) -> Generator[str, None, None]: + def get_available_entities(self) -> Generator[EntityInstanceId, None, None]: if self.is_inside_critical_section: for available_lock in self.available_locks: yield available_lock @@ -58,16 +58,27 @@ def recover_lock_after_call(self, target_instance_id: EntityInstanceId): self.available_locks.append(target_instance_id) def emit_lock_release_messages(self): - raise NotImplementedError() + if self.is_inside_critical_section: + for entity_id in self.critical_section_locks: + unlock_event = pb.SendEntityMessageAction(entityUnlockSent=pb.EntityUnlockSentEvent( + 
criticalSectionId=self.critical_section_id, + targetInstanceId=get_string_value(str(entity_id)) + )) + yield unlock_event + + # TODO: Emit the actual release messages (?) + self.critical_section_locks = [] + self.available_locks = [] + self.critical_section_id = None def emit_request_message(self, target, operation_name: str, one_way: bool, operation_id: str, scheduled_time_utc: datetime, input: Optional[str], request_time: Optional[datetime] = None, create_trace: bool = False): raise NotImplementedError() - def emit_acquire_message(self, critical_section_id: str, entities: List[EntityInstanceId]) -> Union[Tuple[None, None, None], Tuple[str, pb.SendEntityMessageAction, pb.OrchestrationInstance]]: + def emit_acquire_message(self, critical_section_id: str, entities: List[EntityInstanceId]) -> Union[Tuple[None, None], Tuple[pb.SendEntityMessageAction, pb.OrchestrationInstance]]: if not entities: - return None, None, None + return None, None # Acquire the locks in a globally fixed order to avoid deadlocks # Also remove duplicates - this can be optimized for perf if necessary @@ -81,12 +92,15 @@ def emit_acquire_message(self, critical_section_id: str, entities: List[EntityIn request = pb.SendEntityMessageAction(entityLockRequested=pb.EntityLockRequestedEvent( criticalSectionId=critical_section_id, parentInstanceId=get_string_value(self.instance_id), - lockSet=entity_ids_dedup, + lockSet=[str(eid) for eid in entity_ids_dedup], position=0, )) - return "op", request, target - + self.critical_section_id = critical_section_id + self.critical_section_locks = entity_ids_dedup + self.lock_acquisition_pending = True + + return request, target def complete_acquire(self, result, critical_section_id): # TODO: HashSet or equivalent diff --git a/durabletask/task.py b/durabletask/task.py index c8adb56..0f4016d 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -7,10 +7,9 @@ import math from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing 
import Any, Callable, Generator, Generic, Optional, Type, TypeVar, Union +from typing import Any, Callable, Generator, Generic, Optional, Type, TypeVar, Union, overload -from durabletask.entities.entity_instance_id import EntityInstanceId -from durabletask.entities.entity_lock import EntityLock +from durabletask.entities import DurableEntity, EntityInstanceId, EntityLock from durabletask.internal.entity_state_shim import StateShim import durabletask.internal.helpers as pbh import durabletask.internal.orchestrator_service_pb2 as pb @@ -18,6 +17,7 @@ T = TypeVar('T') TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') +TState = TypeVar("TState") class OrchestrationContext(ABC): @@ -545,11 +545,17 @@ def operation(self) -> str: The operation associated with this entity invocation. """ return self._operation - - def get_state(self, intended_type: Optional[Type] = None): + + @overload + def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ... + + @overload + def get_state(self, intended_type: None = None) -> Any: ... 
+ + def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TState] | Any: return self._state.get_state(intended_type) - def set_state(self, new_state): + def set_state(self, new_state: Any): self._state.set_state(new_state) @property @@ -570,7 +576,7 @@ def entity_id(self) -> EntityInstanceId: # Activities are simple functions that can be scheduled by orchestrators Activity = Callable[[ActivityContext, TInput], TOutput] -Entity = Callable[[EntityContext, TInput], TOutput] +Entity = Union[Callable[[EntityContext, TInput], TOutput], type[DurableEntity]] class RetryPolicy: diff --git a/durabletask/worker.py b/durabletask/worker.py index c86032b..2d3be95 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -12,9 +12,7 @@ from threading import Event, Thread from types import GeneratorType from enum import Enum -from typing import Any, Generator, List, Optional, Sequence, TypeVar, Union -import uuid -from durabletask.entities.entity_lock import EntityLock +from typing import Any, Generator, Optional, Sequence, TypeVar, Union from packaging.version import InvalidVersion, parse import grpc @@ -23,7 +21,7 @@ from durabletask.internal import helpers from durabletask.internal.entity_state_shim import StateShim from durabletask.internal.helpers import new_timestamp -from durabletask.entities.entity_instance_id import EntityInstanceId +from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext import durabletask.internal.helpers as ph import durabletask.internal.exceptions as pe @@ -140,12 +138,14 @@ class _Registry: orchestrators: dict[str, task.Orchestrator] activities: dict[str, task.Activity] entities: dict[str, task.Entity] + entity_instances: dict[str, DurableEntity] versioning: Optional[VersioningOptions] = None def __init__(self): self.orchestrators = {} self.activities = {} self.entities = {} + self.entity_instances = {} def 
add_orchestrator(self, fn: task.Orchestrator) -> str: if fn is None: @@ -189,8 +189,12 @@ def add_entity(self, fn: task.Entity) -> str: if fn is None: raise ValueError("An entity function argument is required.") - name = task.get_name(fn) - self.add_named_entity(name, fn) + if isinstance(fn, type) and issubclass(fn, DurableEntity): + name = fn.__name__ + self.add_named_entity(name, fn) + else: + name = task.get_name(fn) + self.add_named_entity(name, fn) return name def add_named_entity(self, name: str, fn: task.Entity) -> None: @@ -994,16 +998,13 @@ def signal_entity( id, entity_id, operation, input ) - def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLock: + def lock_entities(self, entities: list[EntityInstanceId]) -> task.Task[EntityLock]: id = self.next_sequence_number() - + self.lock_entities_function_helper( id, entities ) - - # Todo: EntityLockReleaser should be a disposable that uses python's using statement - # and should release the locks when disposed - return EntityLock(self._entity_context, entities) + return self._pending_tasks.get(id, task.CompletableTask()) def call_sub_orchestrator( self, @@ -1124,29 +1125,29 @@ def signal_entity_function_helper( action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input) self._pending_actions[id] = action - def lock_entities_function_helper( - self, - id: Optional[int], - entity_ids: List[EntityInstanceId] - ): - valid, message = self._entity_context.validate_acquire_transition() - if not valid: - raise RuntimeError(message) - + + def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId]) -> None: if id is None: id = self.next_sequence_number() + + transition_valid, error_message = self._entity_context.validate_acquire_transition() + if not transition_valid: + raise RuntimeError(error_message) + + critical_section_id = f"{self.instance_id}:{id:04x}" - # Use a deterministically replayable unique ID for this lock request - critical_section_id = 
f"{self.instance_id}:{id}" - - event_name, request, target = self._entity_context.emit_acquire_message(critical_section_id, entity_ids) + request, target = self._entity_context.emit_acquire_message(critical_section_id, entities) - if not event_name or not request or not target: + if not request or not target: raise RuntimeError("Failed to create entity lock request.") action = ph.new_lock_entities_action(id, request) self._pending_actions[id] = action + fn_task = task.CompletableTask[EntityLock]() + self._pending_tasks[id] = fn_task + + def wait_for_external_event(self, name: str) -> task.Task: # Check to see if this event has already been received, in which case we # can return it immediately. Otherwise, record out intent to receive an @@ -1589,20 +1590,41 @@ def process_event( entity_signal_id, expected_method_name, action ) elif event.HasField("entityLockRequested"): - if not ctx.is_replaying: - self._logger.info(f"{ctx.instance_id}: Entity lock requested.") - self._logger.info(f"Data: {json.dumps(event.entityLockRequested)}") - pass + section_id = event.entityLockRequested.criticalSectionId + task_id = ctx._entity_task_id_map.get(section_id, None) + if not task_id: + raise RuntimeError(f"Unexpected entityLockRequested event for criticalSectionId '{section_id}'") + action = ctx._pending_actions.pop(task_id, None) + entity_task = ctx._pending_tasks.get(task_id, None) + if not action: + raise _get_non_determinism_error( + task_id, task.get_name(ctx.lock_entities) + ) + elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityLockRequested"): + expected_method_name = task.get_name(ctx.lock_entities) + raise _get_wrong_action_type_error( + task_id, expected_method_name, action + ) elif event.HasField("entityUnlockSent"): if not ctx.is_replaying: self._logger.info(f"{ctx.instance_id}: Entity unlock sent.") self._logger.info(f"Data: {json.dumps(event.entityUnlockSent)}") + # I don't think there's anything we need to do here - if we 
decide we need to send the lock + # release messages before continuing replay, we can confirm that they were processed here pass elif event.HasField("entityLockGranted"): - if not ctx.is_replaying: - self._logger.info(f"{ctx.instance_id}: Entity lock granted.") - self._logger.info(f"Data: {json.dumps(event.entityLockGranted)}") - pass + section_id = event.entityLockGranted.criticalSectionId + task_id = ctx._entity_task_id_map.pop(section_id, None) + if not task_id: + raise RuntimeError(f"Unexpected entityLockGranted event for criticalSectionId '{section_id}'") + entity_task = ctx._pending_tasks.pop(task_id, None) + if not entity_task: + if not ctx.is_replaying: + self._logger.warning( + f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'." + ) + return + entity_task.complete(EntityLock(ctx)) elif event.HasField("entityOperationCompleted"): request_id = event.entityOperationCompleted.requestId task_id = ctx._entity_task_id_map.pop(request_id, None) @@ -1741,8 +1763,23 @@ def execute( entity_input = shared.from_json(encoded_input) if encoded_input else None ctx = task.EntityContext(orchestration_id, operation, state, entity_id) - # Execute the entity function - entity_output = fn(ctx, entity_input) + if isinstance(fn, type) and issubclass(fn, DurableEntity): + if self._registry.entity_instances.get(str(entity_id), None): + entity_instance = self._registry.entity_instances[str(entity_id)] + else: + entity_instance = fn() + self._registry.entity_instances[str(entity_id)] = entity_instance + if not hasattr(entity_instance, operation): + raise AttributeError(f"Entity '{entity_id}' does not have operation '{operation}'") + method = getattr(entity_instance, operation) + if not callable(method): + raise TypeError(f"Entity operation '{operation}' is not callable") + # Execute the entity method + entity_instance._initialize_entity_context(ctx) + entity_output = method(entity_input) + else: + # Execute the entity function + 
entity_output = fn(ctx, entity_input) encoded_output = ( shared.to_json(entity_output) if entity_output is not None else None From 8ac1876a541f31c02bec817ed8c1644e33e7057d Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Sep 2025 11:45:01 -0600 Subject: [PATCH 05/21] Finish locking, add operationactions --- durabletask/client.py | 10 ++-- durabletask/entities/durable_entity.py | 14 ++++- durabletask/entities/entity_lock.py | 7 +-- durabletask/internal/entity_state_shim.py | 22 +++++-- durabletask/internal/helpers.py | 11 ++-- .../internal/orchestration_entity_context.py | 9 ++- durabletask/task.py | 39 +++++++++++- durabletask/worker.py | 59 +++++++++++-------- 8 files changed, 120 insertions(+), 51 deletions(-) diff --git a/durabletask/client.py b/durabletask/client.py index 5f01678..a7a3775 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -229,15 +229,15 @@ def purge_orchestration(self, instance_id: str, recursive: bool = True): self._logger.info(f"Purging instance '{instance_id}'.") self._stub.PurgeInstances(req) - def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, input: Optional[Any] = None, signal_entity_options=None, cancellation=None): - scheduled_time = signal_entity_options.scheduled_time if signal_entity_options and signal_entity_options.scheduled_time else None + def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, input: Optional[Any] = None): req = pb.SignalEntityRequest( instanceId=str(entity_instance_id), - requestId=str(uuid.uuid4()), name=operation_name, input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None, - scheduledTime=scheduled_time, + requestId=str(uuid.uuid4()), + scheduledTime=None, + parentTraceContext=None, requestTime=helpers.new_timestamp(datetime.now(timezone.utc)) ) self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") - self._stub.SignalEntity(req, timeout=cancellation.timeout if 
cancellation else None) + self._stub.SignalEntity(req, None) # TODO: Cancellation timeout? diff --git a/durabletask/entities/durable_entity.py b/durabletask/entities/durable_entity.py index 8f03cfe..03c5dce 100644 --- a/durabletask/entities/durable_entity.py +++ b/durabletask/entities/durable_entity.py @@ -1,5 +1,7 @@ from typing import Any, Optional, Type, TypeVar, overload +from durabletask.entities.entity_instance_id import EntityInstanceId + TState = TypeVar("TState") @@ -9,12 +11,18 @@ def _initialize_entity_context(self, context): @overload def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ... - + @overload def get_state(self, intended_type: None = None) -> Any: ... - + def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TState] | Any: return self.entity_context.get_state(intended_type) def set_state(self, state: Any): - self.entity_context.set_state(state) \ No newline at end of file + self.entity_context.set_state(state) + + def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Optional[Any] = None) -> None: + self.entity_context.signal_entity(entity_instance_id, operation, input) + + def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> None: + self.entity_context.schedule_new_orchestration(orchestration_name, input, instance_id=instance_id) diff --git a/durabletask/entities/entity_lock.py b/durabletask/entities/entity_lock.py index b693ad3..5cbf7ea 100644 --- a/durabletask/entities/entity_lock.py +++ b/durabletask/entities/entity_lock.py @@ -1,6 +1,3 @@ -import durabletask.internal.helpers as ph - -from durabletask.entities.entity_instance_id import EntityInstanceId import durabletask.internal.orchestrator_service_pb2 as pb @@ -11,9 +8,9 @@ def __init__(self, context): def __enter__(self): return self - def __exit__(self, exc_type, exc_val, exc_tb): # TODO: Handle exceptions? 
+ def __exit__(self, exc_type, exc_val, exc_tb): # TODO: Handle exceptions? print(f"Unlocking entities: {self._context._entity_context.critical_section_locks}") for entity_unlock_message in self._context._entity_context.emit_lock_release_messages(): task_id = self._context.next_sequence_number() - action = pb.OrchestratorAction(task_id, sendEntityMessage=entity_unlock_message) + action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message) self._context._pending_actions[task_id] = action diff --git a/durabletask/internal/entity_state_shim.py b/durabletask/internal/entity_state_shim.py index 2239427..5b410a3 100644 --- a/durabletask/internal/entity_state_shim.py +++ b/durabletask/internal/entity_state_shim.py @@ -1,8 +1,7 @@ -from ctypes import Union -from typing import Any, TypeVar, runtime_checkable +from typing import Any, TypeVar from typing import Optional, Type, overload -from typing_extensions import Protocol +import durabletask.internal.orchestrator_service_pb2 as pb TState = TypeVar("TState") @@ -11,10 +10,12 @@ class StateShim: def __init__(self, start_state): self._current_state: Any = start_state self._checkpoint_state: Any = start_state + self._operation_actions: list[pb.OperationAction] = [] + self._actions_checkpoint_state: int = 0 @overload def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ... - + @overload def get_state(self, intended_type: None = None) -> Any: ... 
@@ -26,7 +27,7 @@ def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TS return self._current_state try: - return intended_type(self._current_state) # type: ignore[call-arg] + return intended_type(self._current_state) # type: ignore[call-arg] except Exception as ex: raise TypeError( f"Could not convert state of type '{type(self._current_state).__name__}' to '{intended_type.__name__}'" @@ -35,11 +36,22 @@ def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TS def set_state(self, state): self._current_state = state + def add_operation_action(self, action: pb.OperationAction): + self._operation_actions.append(action) + + def get_operation_actions(self) -> list[pb.OperationAction]: + return self._operation_actions[:self._actions_checkpoint_state] + def commit(self): self._checkpoint_state = self._current_state + self._actions_checkpoint_state = len(self._operation_actions) def rollback(self): self._current_state = self._checkpoint_state + self._operation_actions = self._operation_actions[:self._actions_checkpoint_state] def reset(self): self._current_state = None + self._checkpoint_state = None + self._operation_actions = [] + self._actions_checkpoint_state = 0 diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 0ecd650..3e6d887 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -199,19 +199,22 @@ def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str], def new_call_entity_action(id: int, parent_instance_id: str, entity_id: EntityInstanceId, operation: str, encoded_input: Optional[str]): return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationCalled=pb.EntityOperationCalledEvent( requestId=f"{parent_instance_id}:{id}", + operation=operation, + scheduledTime=None, + input=get_string_value(encoded_input), parentInstanceId=get_string_value(parent_instance_id), + parentExecutionId=None, 
targetInstanceId=get_string_value(str(entity_id)), - input=get_string_value(encoded_input), - operation=operation ))) def new_signal_entity_action(id: int, entity_id: EntityInstanceId, operation: str, encoded_input: Optional[str]): return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationSignaled=pb.EntityOperationSignaledEvent( requestId=f"{entity_id}:{id}", - targetInstanceId=get_string_value(str(entity_id)), operation=operation, - input=get_string_value(encoded_input) + scheduledTime=None, + input=get_string_value(encoded_input), + targetInstanceId=get_string_value(str(entity_id)), ))) diff --git a/durabletask/internal/orchestration_entity_context.py b/durabletask/internal/orchestration_entity_context.py index 38effed..21085a9 100644 --- a/durabletask/internal/orchestration_entity_context.py +++ b/durabletask/internal/orchestration_entity_context.py @@ -62,7 +62,8 @@ def emit_lock_release_messages(self): for entity_id in self.critical_section_locks: unlock_event = pb.SendEntityMessageAction(entityUnlockSent=pb.EntityUnlockSentEvent( criticalSectionId=self.critical_section_id, - targetInstanceId=get_string_value(str(entity_id)) + targetInstanceId=get_string_value(str(entity_id)), + parentInstanceId=get_string_value(self.instance_id) )) yield unlock_event @@ -79,7 +80,7 @@ def emit_request_message(self, target, operation_name: str, one_way: bool, opera def emit_acquire_message(self, critical_section_id: str, entities: List[EntityInstanceId]) -> Union[Tuple[None, None], Tuple[pb.SendEntityMessageAction, pb.OrchestrationInstance]]: if not entities: return None, None - + # Acquire the locks in a globally fixed order to avoid deadlocks # Also remove duplicates - this can be optimized for perf if necessary entity_ids = sorted(entities) @@ -102,8 +103,10 @@ def emit_acquire_message(self, critical_section_id: str, entities: List[EntityIn return request, target - def complete_acquire(self, result, critical_section_id): + def 
complete_acquire(self, critical_section_id): # TODO: HashSet or equivalent + if self.critical_section_id != critical_section_id: + raise RuntimeError(f"Unexpected lock acquire for critical section ID '{critical_section_id}' (expected '{self.critical_section_id}')") self.available_locks = self.critical_section_locks self.lock_acquisition_pending = False diff --git a/durabletask/task.py b/durabletask/task.py index 0f4016d..349dcb2 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -8,8 +8,10 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta from typing import Any, Callable, Generator, Generic, Optional, Type, TypeVar, Union, overload +import uuid from durabletask.entities import DurableEntity, EntityInstanceId, EntityLock +from durabletask.internal import shared from durabletask.internal.entity_state_shim import StateShim import durabletask.internal.helpers as pbh import durabletask.internal.orchestrator_service_pb2 as pb @@ -141,7 +143,7 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *, pass @abstractmethod - def call_entity(self, entity: EntityInstanceId, + def call_entity(self, entity: EntityInstanceId, operation: str, *, input: Optional[TInput] = None): """Schedule entity function for execution. @@ -545,10 +547,10 @@ def operation(self) -> str: The operation associated with this entity invocation. """ return self._operation - + @overload def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ... - + @overload def get_state(self, intended_type: None = None) -> Any: ... 
@@ -558,6 +560,37 @@ def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TS def set_state(self, new_state: Any): self._state.set_state(new_state) + def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Optional[Any] = None) -> None: + encoded_input = shared.to_json(input) if input is not None else None + self._state.add_operation_action( + pb.OperationAction( + sendSignal=pb.SendSignalAction( + instanceId=str(entity_instance_id), + name=operation, + input=pbh.get_string_value(encoded_input), + scheduledTime=None, + requestTime=None, + parentTraceContext=None, + ) + ) + ) + + def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> None: + encoded_input = shared.to_json(input) if input is not None else None + self._state.add_operation_action( + pb.OperationAction( + startNewOrchestration=pb.StartNewOrchestrationAction( + instanceId=instance_id if instance_id else uuid.uuid4().hex, # TODO: Should this be non-none? + name=orchestration_name, + input=pbh.get_string_value(encoded_input), + version=None, + scheduledTime=None, + requestTime=None, + parentTraceContext=None + ) + ) + ) + @property def entity_id(self) -> EntityInstanceId: """Get the ID of the entity instance. 
diff --git a/durabletask/worker.py b/durabletask/worker.py index 2d3be95..e28b469 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -753,7 +753,7 @@ def _execute_entity_batch( batch_result = pb.EntityBatchResult( results=results, - actions=None, # TODO: Context should also provide actions like signaling another entity or starting sub-orchestrations + actions=entity_state.get_operation_actions(), entityState=helpers.get_string_value(shared.to_json(entity_state._current_state)) if entity_state._current_state else None, failureDetails=None, completionToken=completionToken, @@ -783,7 +783,10 @@ def __init__(self, instance_id: str, registry: _Registry, entity_context: Orches self._result = None self._pending_actions: dict[int, pb.OrchestratorAction] = {} self._pending_tasks: dict[int, task.CompletableTask] = {} - self._entity_task_id_map: dict[str, int] = {} + # Maps entity ID to task ID + self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {} + # Maps criticalSectionId to task ID + self._entity_lock_id_map: dict[str, int] = {} self._sequence_number = 0 self._current_utc_datetime = datetime(1000, 1, 1) self._instance_id = instance_id @@ -1000,7 +1003,7 @@ def signal_entity( def lock_entities(self, entities: list[EntityInstanceId]) -> task.Task[EntityLock]: id = self.next_sequence_number() - + self.lock_entities_function_helper( id, entities ) @@ -1096,7 +1099,7 @@ def call_entity_function_helper( transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, False) if not transition_valid: raise RuntimeError(error_message) - + encoded_input = shared.to_json(input) if input is not None else None action = ph.new_call_entity_action(id, self.instance_id, entity_id, operation, encoded_input) self._pending_actions[id] = action @@ -1104,7 +1107,6 @@ def call_entity_function_helper( fn_task = task.CompletableTask() self._pending_tasks[id] = fn_task - def signal_entity_function_helper( self, id: Optional[int], @@ 
-1125,15 +1127,14 @@ def signal_entity_function_helper( action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input) self._pending_actions[id] = action - def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId]) -> None: if id is None: id = self.next_sequence_number() - + transition_valid, error_message = self._entity_context.validate_acquire_transition() if not transition_valid: raise RuntimeError(error_message) - + critical_section_id = f"{self.instance_id}:{id:04x}" request, target = self._entity_context.emit_acquire_message(critical_section_id, entities) @@ -1147,7 +1148,6 @@ def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId fn_task = task.CompletableTask[EntityLock]() self._pending_tasks[id] = fn_task - def wait_for_external_event(self, name: str) -> task.Task: # Check to see if this event has already been received, in which case we # can return it immediately. Otherwise, record out intent to receive an @@ -1573,8 +1573,10 @@ def process_event( raise _get_wrong_action_type_error( entity_call_id, expected_method_name, action ) - ctx._entity_task_id_map[event.entityOperationCalled.requestId] = entity_call_id - # TODO: Validate entity ID + entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value) + if not entity_id: + raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'") + ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id) elif event.HasField("entityOperationSignaled"): # This history event confirms that the entity signal was successfully scheduled. 
# Remove the entityOperationSignaled event from the pending action list so we don't schedule it @@ -1591,9 +1593,7 @@ def process_event( ) elif event.HasField("entityLockRequested"): section_id = event.entityLockRequested.criticalSectionId - task_id = ctx._entity_task_id_map.get(section_id, None) - if not task_id: - raise RuntimeError(f"Unexpected entityLockRequested event for criticalSectionId '{section_id}'") + task_id = event.eventId action = ctx._pending_actions.pop(task_id, None) entity_task = ctx._pending_tasks.get(task_id, None) if not action: @@ -1605,18 +1605,26 @@ def process_event( raise _get_wrong_action_type_error( task_id, expected_method_name, action ) + ctx._entity_lock_id_map[section_id] = task_id elif event.HasField("entityUnlockSent"): - if not ctx.is_replaying: - self._logger.info(f"{ctx.instance_id}: Entity unlock sent.") - self._logger.info(f"Data: {json.dumps(event.entityUnlockSent)}") - # I don't think there's anything we need to do here - if we decide we need to send the lock - # release messages before continuing replay, we can confirm that they were processed here - pass + # Remove the unlock tasks as they have already been processed + tasks_to_remove = [] + for task_id, action in ctx._pending_actions.items(): + if action.HasField("sendEntityMessage") and action.sendEntityMessage.HasField("entityUnlockSent"): + if action.sendEntityMessage.entityUnlockSent.criticalSectionId == event.entityUnlockSent.criticalSectionId: + tasks_to_remove.append(task_id) + for task_to_remove in tasks_to_remove: + ctx._pending_actions.pop(task_to_remove, None) elif event.HasField("entityLockGranted"): section_id = event.entityLockGranted.criticalSectionId - task_id = ctx._entity_task_id_map.pop(section_id, None) + task_id = ctx._entity_lock_id_map.pop(section_id, None) if not task_id: - raise RuntimeError(f"Unexpected entityLockGranted event for criticalSectionId '{section_id}'") + # TODO: Should this be an error? When would it ever happen? 
+ if not ctx.is_replaying: + self._logger.warning( + f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'." + ) + return entity_task = ctx._pending_tasks.pop(task_id, None) if not entity_task: if not ctx.is_replaying: @@ -1624,10 +1632,14 @@ def process_event( f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'." ) return + ctx._entity_context.complete_acquire(section_id) entity_task.complete(EntityLock(ctx)) + ctx.resume() elif event.HasField("entityOperationCompleted"): request_id = event.entityOperationCompleted.requestId - task_id = ctx._entity_task_id_map.pop(request_id, None) + entity_id, task_id = ctx._entity_task_id_map.pop(request_id, (None, None)) + if not entity_id: + raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'") if not task_id: raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'") entity_task = ctx._pending_tasks.pop(task_id, None) @@ -1640,6 +1652,7 @@ def process_event( result = None if not ph.is_empty(event.entityOperationCompleted.output): result = shared.from_json(event.entityOperationCompleted.output.value) + ctx._entity_context.recover_lock_after_call(entity_id) entity_task.complete(result) ctx.resume() elif event.HasField("entityOperationFailed"): From 29c645466189efa7921741d05d2d7834b9c031e3 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Sep 2025 13:16:40 -0600 Subject: [PATCH 06/21] Add samples, test, documentation --- docs/features.md | 82 ++++++ durabletask/entities/durable_entity.py | 16 +- durabletask/internal/entity_state_shim.py | 12 +- durabletask/task.py | 20 +- examples/entities/class_based_entity.py | 65 +++++ .../entities/class_based_entity_actions.py | 85 ++++++ examples/entities/entity_locking.py | 67 +++++ examples/entities/function_based_entity.py | 66 +++++ .../entities/function_based_entity_actions.py | 79 ++++++ 
.../test_dts_class_based_entities_e2e.py | 111 ++++++++ ...st_dts_function_based_entities_e2e copy.py | 259 ++++++++++++++++++ 11 files changed, 846 insertions(+), 16 deletions(-) create mode 100644 examples/entities/class_based_entity.py create mode 100644 examples/entities/class_based_entity_actions.py create mode 100644 examples/entities/entity_locking.py create mode 100644 examples/entities/function_based_entity.py create mode 100644 examples/entities/function_based_entity_actions.py create mode 100644 tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py create mode 100644 tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py diff --git a/docs/features.md b/docs/features.md index d3fcc56..1357b22 100644 --- a/docs/features.md +++ b/docs/features.md @@ -48,6 +48,88 @@ Orchestrations can schedule durable timers using the `create_timer` API. These t Orchestrations can start child orchestrations using the `call_sub_orchestrator` API. Child orchestrations are useful for encapsulating complex logic and for breaking up large orchestrations into smaller, more manageable pieces. Sub-orchestrations can also be versioned in a similar manner to their parent orchestrations, however, they do not inherit the parent orchestrator's version. Instead, they will use the default_version defined in the current worker's VersioningOptions unless otherwise specified during `call_sub_orchestrator`. +### Entities + +#### Concepts + +Durable Entities provide a way to model small, stateful objects within your orchestration workflows. Each entity has a unique identity and maintains its own state, which is persisted durably. Entities can be interacted with by sending them operations (messages) that mutate or query their state. These operations are processed sequentially, ensuring consistency. Examples of uses for durable entities include counters, accumulators, or any other operation which requires state to persist across orchestrations. 
+ +Entities can be invoked from durable clients directly, or from durable orchestrators. They support features like automatic state persistence and concurrency control, and can be locked for exclusive access during critical operations. + +Entities are accessed by a unique ID, implemented here as EntityInstanceId. This ID is composed of two parts: an entity name referring to the function or class that defines the behavior of the entity, and a key, which is any string defined in your code. Each entity instance, represented by a distinct EntityInstanceId, has its own state. + +#### Syntax + +##### Defining Entities + +Entities can be defined using either function-based or class-based syntax. + +```python +# Function-based entity +def counter(ctx: task.EntityContext, input: int): + state = ctx.get_state(int, 0) + if ctx.operation == "add": + state += input + ctx.set_state(state) + elif ctx.operation == "get": + return state + +# Class-based entity +class Counter(entities.DurableEntity): + def __init__(self): + self.set_state(0) + + def add(self, amount: int): + self.set_state(self.get_state(int, 0) + amount) + + def get(self): + return self.get_state(int, 0) +``` + +> Note that the object properties of class-based entities may not be preserved across invocations. Use the derived get_state and set_state methods to access the persisted entity data. + +##### Invoking entities + +Entities are invoked using the `signal_entity` or `call_entity` APIs.
The Durable Client only allows `signal_entity`: + +```python +c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) +entity_id = entities.EntityInstanceId("my_entity_function", "myEntityId") +c.signal_entity(entity_id, "do_nothing") +``` + +Whereas orchestrators can choose to use `signal_entity` or `call_entity`: + +```python +# Signal an entity (fire-and-forget) +entity_id = entities.EntityInstanceId("my_entity_function", "myEntityId") +ctx.signal_entity(entity_id, operation_name="add", input=5) + +# Call an entity (wait for result) +entity_id = entities.EntityInstanceId("my_entity_function", "myEntityId") +result = yield ctx.call_entity(entity_id, operation="get") +``` + +##### Entity actions + +Entities can perform actions such as signaling other entities or starting new orchestrations: + +- `ctx.signal_entity(entity_id, operation, input)` +- `ctx.schedule_new_orchestration(orchestrator_name, input)` + +##### Locking and concurrency + +Because entities can be accessed from multiple running orchestrations at the same time, entities may also be locked by a single orchestrator, ensuring exclusive access for the duration of the lock (also known as a critical section). Think semaphores: + +```python +with (yield ctx.lock_entities([entity_id_1, entity_id_2])): + # Perform entity call operations that require exclusive access + ... +``` + +Note that locked entities may not be signalled, and every call to a locked entity must return a result before another call to the same entity may be made from within the critical section. For more details and advanced usage, see the examples and API documentation. + ### External events Orchestrations can wait for external events using the `wait_for_external_event` API. External events are useful for implementing human interaction patterns, such as waiting for a user to approve an order before continuing. 
diff --git a/durabletask/entities/durable_entity.py b/durabletask/entities/durable_entity.py index 03c5dce..8369cb1 100644 --- a/durabletask/entities/durable_entity.py +++ b/durabletask/entities/durable_entity.py @@ -1,22 +1,26 @@ from typing import Any, Optional, Type, TypeVar, overload from durabletask.entities.entity_instance_id import EntityInstanceId +from durabletask.task import EntityContext TState = TypeVar("TState") class DurableEntity: - def _initialize_entity_context(self, context): + def _initialize_entity_context(self, context: EntityContext): self.entity_context = context + @overload + def get_state(self, intended_type: Type[TState], default: TState) -> TState: ... + @overload def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ... @overload - def get_state(self, intended_type: None = None) -> Any: ... + def get_state(self, intended_type: None = None, default: Any = None) -> Any: ... - def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TState] | Any: - return self.entity_context.get_state(intended_type) + def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: + return self.entity_context.get_state(intended_type, default) def set_state(self, state: Any): self.entity_context.set_state(state) @@ -24,5 +28,5 @@ def set_state(self, state: Any): def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Optional[Any] = None) -> None: self.entity_context.signal_entity(entity_instance_id, operation, input) - def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> None: - self.entity_context.schedule_new_orchestration(orchestration_name, input, instance_id=instance_id) + def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> str: + return 
self.entity_context.schedule_new_orchestration(orchestration_name, input, instance_id=instance_id) diff --git a/durabletask/internal/entity_state_shim.py b/durabletask/internal/entity_state_shim.py index 5b410a3..fd025b4 100644 --- a/durabletask/internal/entity_state_shim.py +++ b/durabletask/internal/entity_state_shim.py @@ -13,17 +13,23 @@ def __init__(self, start_state): self._operation_actions: list[pb.OperationAction] = [] self._actions_checkpoint_state: int = 0 + @overload + def get_state(self, intended_type: Type[TState], default: TState) -> TState: ... + @overload def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ... @overload - def get_state(self, intended_type: None = None) -> Any: ... + def get_state(self, intended_type: None = None, default: Any = None) -> Any: ... + + def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: + if self._current_state is None and default is not None: + return default - def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TState] | Any: if intended_type is None: return self._current_state - if isinstance(self._current_state, intended_type) or self._current_state is None: + if isinstance(self._current_state, intended_type): return self._current_state try: diff --git a/durabletask/task.py b/durabletask/task.py index 349dcb2..da63e65 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -144,8 +144,8 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *, @abstractmethod def call_entity(self, entity: EntityInstanceId, - operation: str, *, - input: Optional[TInput] = None): + operation: str, + input: Optional[TInput] = None) -> Task: """Schedule entity function for execution. Parameters @@ -548,14 +548,17 @@ def operation(self) -> str: """ return self._operation + @overload + def get_state(self, intended_type: Type[TState], default: TState) -> TState: ... 
+ @overload def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ... @overload - def get_state(self, intended_type: None = None) -> Any: ... + def get_state(self, intended_type: None = None, default: Any = None) -> Any: ... - def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TState] | Any: - return self._state.get_state(intended_type) + def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: + return self._state.get_state(intended_type, default) def set_state(self, new_state: Any): self._state.set_state(new_state) @@ -575,12 +578,14 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, in ) ) - def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> None: + def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> str: encoded_input = shared.to_json(input) if input is not None else None + if not instance_id: + instance_id = uuid.uuid4().hex self._state.add_operation_action( pb.OperationAction( startNewOrchestration=pb.StartNewOrchestrationAction( - instanceId=instance_id if instance_id else uuid.uuid4().hex, # TODO: Should this be non-none? 
+ instanceId=instance_id, name=orchestration_name, input=pbh.get_string_value(encoded_input), version=None, @@ -590,6 +595,7 @@ def schedule_new_orchestration(self, orchestration_name: str, input: Optional[An ) ) ) + return instance_id @property def entity_id(self) -> EntityInstanceId: diff --git a/examples/entities/class_based_entity.py b/examples/entities/class_based_entity.py new file mode 100644 index 0000000..f211b65 --- /dev/null +++ b/examples/entities/class_based_entity.py @@ -0,0 +1,65 @@ +"""End-to-end sample that demonstrates how to configure an orchestrator +that calls an activity function in a sequence and prints the outputs.""" +import os + +from azure.identity import DefaultAzureCredential + +from durabletask import client, task, entities +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + +class Counter(entities.DurableEntity): + def set(self, input: int): + self.set_state(input) + + def add(self, input: int): + current_state = self.get_state(int, 0) + new_state = current_state + (input or 1) + self.set_state(new_state) + return new_state + + def get(self): + return self.get_state(int, 0) + + +def counter_orchestrator(ctx: task.OrchestrationContext, _): + """Orchestrator function that demonstrates the behavior of the counter entity""" + + entity_id = task.EntityInstanceId("Counter", "myCounter") + + # Initialize the entity with state 0 + ctx.signal_entity(entity_id, "set", 0) + # Increment the counter by 1 + yield ctx.call_entity(entity_id, "add", 1) + # Return the entity's current value (should be 1) + return (yield ctx.call_entity(entity_id, "get")) + + +# Use environment variables if provided, otherwise use default emulator values +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + +print(f"Using taskhub: {taskhub_name}") +print(f"Using endpoint: {endpoint}") + +# Set credential to None for 
emulator, or DefaultAzureCredential for Azure +credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + +# configure and start the worker - use secure_channel=False for emulator +secure_channel = endpoint != "http://localhost:8080" +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) as w: + w.add_orchestrator(counter_orchestrator) + w.add_entity(Counter) + w.start() + + # Construct the client and run the orchestrations + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) + instance_id = c.schedule_new_orchestration(counter_orchestrator) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') diff --git a/examples/entities/class_based_entity_actions.py b/examples/entities/class_based_entity_actions.py new file mode 100644 index 0000000..1804aaf --- /dev/null +++ b/examples/entities/class_based_entity_actions.py @@ -0,0 +1,85 @@ +"""End-to-end sample that demonstrates how to configure an orchestrator +that calls an activity function in a sequence and prints the outputs.""" +import os + +from azure.identity import DefaultAzureCredential + +from durabletask import client, task, entities +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + +class Counter(entities.DurableEntity): + def set(self, input: int): + self.set_state(input) + + def add(self, input: int): + current_state = self.get_state(int, 0) + new_state = current_state + (input or 1) + self.set_state(new_state) + return new_state + + def update_parent(self): + parent_entity_id = 
entities.EntityInstanceId("Counter", "parentCounter") + if self.entity_context.entity_id == parent_entity_id: + return # Prevent self-update + self.signal_entity(parent_entity_id, "set", self.get_state(int, 0)) + + def start_hello(self): + self.schedule_new_orchestration("hello_orchestrator") + + def get(self): + return self.get_state(int, 0) + + +def counter_orchestrator(ctx: task.OrchestrationContext, _): + """Orchestrator function that demonstrates the behavior of the counter entity""" + + entity_id = task.EntityInstanceId("Counter", "myCounter") + parent_entity_id = task.EntityInstanceId("Counter", "parentCounter") + + # Use Counter to demonstrate starting an orchestration from an entity + ctx.signal_entity(entity_id, "start_hello") + + # User Counter to demonstrate signaling an entity from another entity + # Initialize myCounter with state 0, increment it by 1, and set the state of parentCounter using + # update_parent on myCounter. Retrieve and return the state of parentCounter (should be 1). + ctx.signal_entity(entity_id, "set", 0) + yield ctx.call_entity(entity_id, "add", 1) + yield ctx.call_entity(entity_id, "update_parent") + + return (yield ctx.call_entity(parent_entity_id, "get")) + + +def hello_orchestrator(ctx: task.OrchestrationContext, _): + return f"Hello world!" 
+ + +# Use environment variables if provided, otherwise use default emulator values +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + +print(f"Using taskhub: {taskhub_name}") +print(f"Using endpoint: {endpoint}") + +# Set credential to None for emulator, or DefaultAzureCredential for Azure +credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + +# configure and start the worker - use secure_channel=False for emulator +secure_channel = endpoint != "http://localhost:8080" +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) as w: + w.add_orchestrator(counter_orchestrator) + w.add_orchestrator(hello_orchestrator) + w.add_entity(Counter) + w.start() + + # Construct the client and run the orchestrations + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) + instance_id = c.schedule_new_orchestration(counter_orchestrator) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! 
Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') diff --git a/examples/entities/entity_locking.py b/examples/entities/entity_locking.py new file mode 100644 index 0000000..cdc25ab --- /dev/null +++ b/examples/entities/entity_locking.py @@ -0,0 +1,67 @@ +"""End-to-end sample that demonstrates how to configure an orchestrator +that calls an activity function in a sequence and prints the outputs.""" +import os + +from azure.identity import DefaultAzureCredential + +from durabletask import client, task, entities +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + +class Counter(entities.DurableEntity): + def set(self, input: int): + self.set_state(input) + + def add(self, input: int): + current_state = self.get_state(int, 0) + new_state = current_state + (input or 1) + self.set_state(new_state) + return new_state + + def get(self): + return self.get_state(int, 0) + + +def counter_orchestrator(ctx: task.OrchestrationContext, _): + """Orchestrator function that demonstrates the behavior of the counter entity""" + + entity_id = entities.EntityInstanceId("Counter", "myCounter") + + # Initialize the entity with state 0, increment the counter by 1, and get the entity state using + # entity locking to ensure no other orchestrator can modify the entity state between the calls to call_entity + with (yield ctx.lock_entities([entity_id])): + yield ctx.call_entity(entity_id, "set", 0) + yield ctx.call_entity(entity_id, "add", 1) + result = yield ctx.call_entity(entity_id, "get") + # Return the entity's current value (will be 1) + return result + + +# Use environment variables if provided, otherwise use default emulator values +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + +print(f"Using taskhub: {taskhub_name}") +print(f"Using endpoint: {endpoint}") + +# Set 
credential to None for emulator, or DefaultAzureCredential for Azure +credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + +# configure and start the worker - use secure_channel=False for emulator +secure_channel = endpoint != "http://localhost:8080" +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) as w: + w.add_orchestrator(counter_orchestrator) + w.add_entity(Counter) + w.start() + + # Construct the client and run the orchestrations + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) + instance_id = c.schedule_new_orchestration(counter_orchestrator) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') diff --git a/examples/entities/function_based_entity.py b/examples/entities/function_based_entity.py new file mode 100644 index 0000000..cd2aa85 --- /dev/null +++ b/examples/entities/function_based_entity.py @@ -0,0 +1,66 @@ +"""End-to-end sample that demonstrates how to configure an orchestrator +that calls an activity function in a sequence and prints the outputs.""" +import os +from typing import Optional + +from azure.identity import DefaultAzureCredential + +from durabletask import client, entities, task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + +def counter(ctx: task.EntityContext, input: int) -> Optional[int]: + if ctx.operation == "set": + ctx.set_state(input) + if ctx.operation == "add": + current_state = ctx.get_state(int, 0) + new_state = current_state + (input or 1) + ctx.set_state(new_state) + return 
new_state + elif ctx.operation == "get": + return ctx.get_state(int, 0) + else: + raise ValueError(f"Unknown operation '{ctx.operation}'") + + +def counter_orchestrator(ctx: task.OrchestrationContext, _): + """Orchestrator function that demonstrates the behavior of the counter entity""" + + entity_id = entities.EntityInstanceId("counter", "myCounter") + + # Initialize the entity with state 0 + ctx.signal_entity(entity_id, "set", 0) + # Increment the counter by 1 + yield ctx.call_entity(entity_id, "add", 1) + # Return the entity's current value (should be 1) + return (yield ctx.call_entity(entity_id, "get")) + + +# Use environment variables if provided, otherwise use default emulator values +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + +print(f"Using taskhub: {taskhub_name}") +print(f"Using endpoint: {endpoint}") + +# Set credential to None for emulator, or DefaultAzureCredential for Azure +credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + +# configure and start the worker - use secure_channel=False for emulator +secure_channel = endpoint != "http://localhost:8080" +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) as w: + w.add_orchestrator(counter_orchestrator) + w.add_entity(counter) + w.start() + + # Construct the client and run the orchestrations + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) + instance_id = c.schedule_new_orchestration(counter_orchestrator) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! 
Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') diff --git a/examples/entities/function_based_entity_actions.py b/examples/entities/function_based_entity_actions.py new file mode 100644 index 0000000..be3a97d --- /dev/null +++ b/examples/entities/function_based_entity_actions.py @@ -0,0 +1,79 @@ +"""End-to-end sample that demonstrates how to configure an orchestrator +that calls an activity function in a sequence and prints the outputs.""" +import os +from typing import Optional + +from azure.identity import DefaultAzureCredential + +from durabletask import client, task, entities +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + +def counter(ctx: task.EntityContext, input: int) -> Optional[int]: + if ctx.operation == "set": + ctx.set_state(input) + elif ctx.operation == "get": + return ctx.get_state(int, 0) + elif ctx.operation == "update_parent": + parent_entity_id = entities.EntityInstanceId("counter", "parentCounter") + if ctx.entity_id == parent_entity_id: + return # Prevent self-update + ctx.signal_entity(parent_entity_id, "set", ctx.get_state(int, 0)) + elif ctx.operation == "start_hello": + ctx.schedule_new_orchestration("hello_orchestrator") + else: + raise ValueError(f"Unknown operation '{ctx.operation}'") + + +def counter_orchestrator(ctx: task.OrchestrationContext, _): + """Orchestrator function that demonstrates the behavior of the counter entity""" + + entity_id = task.EntityInstanceId("counter", "myCounter") + parent_entity_id = task.EntityInstanceId("counter", "parentCounter") + + # Use counter to demonstrate starting an orchestration from an entity + ctx.signal_entity(entity_id, "start_hello") + + # User counter to demonstrate signaling an entity from another entity + # Initialize myCounter with state 0, increment it by 1, and set the state of parentCounter using + # update_parent on myCounter. 
Retrieve and return the state of parentCounter (should be 1). + ctx.signal_entity(entity_id, "set", 0) + yield ctx.call_entity(entity_id, "add", 1) + yield ctx.call_entity(entity_id, "update_parent") + + return (yield ctx.call_entity(parent_entity_id, "get")) + + +def hello_orchestrator(ctx: task.OrchestrationContext, _): + return f"Hello world!" + + +# Use environment variables if provided, otherwise use default emulator values +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + +print(f"Using taskhub: {taskhub_name}") +print(f"Using endpoint: {endpoint}") + +# Set credential to None for emulator, or DefaultAzureCredential for Azure +credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + +# configure and start the worker - use secure_channel=False for emulator +secure_channel = endpoint != "http://localhost:8080" +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) as w: + w.add_orchestrator(counter_orchestrator) + w.add_orchestrator(hello_orchestrator) + w.add_entity(counter) + w.start() + + # Construct the client and run the orchestrations + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) + instance_id = c.schedule_new_orchestration(counter_orchestrator) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! 
Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') diff --git a/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py b/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py new file mode 100644 index 0000000..2d98931 --- /dev/null +++ b/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py @@ -0,0 +1,111 @@ +import os +import time +from typing import Optional + +import pytest + +from durabletask import client, entities, task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# NOTE: These tests assume a sidecar process is running. Example command: +# docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest +pytestmark = pytest.mark.dts + +# Read the environment variables +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + +def test_client_signal_class_entity(): + invoked = False + + class EmptyEntity(entities.DurableEntity): + def do_nothing(self, _): + nonlocal invoked # don't do this in a real app! + invoked = True + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_entity(EmptyEntity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity") + c.signal_entity(entity_id, "do_nothing") + time.sleep(2) # wait for the signal to be processed + + assert invoked + + +def test_orchestration_signal_class_entity(): + invoked = False + + class EmptyEntity(entities.DurableEntity): + def do_nothing(self, _): + nonlocal invoked # don't do this in a real app! 
+ invoked = True + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity") + ctx.signal_entity(entity_id, "do_nothing") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(EmptyEntity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + time.sleep(2) # wait for the signal to be processed - signals cannot be awaited from inside the orchestrator + + assert invoked + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None + + +def test_orchestration_call_class_entity(): + invoked = False + + class EmptyEntity(entities.DurableEntity): + def do_nothing(self, _): + nonlocal invoked # don't do this in a real app! 
+ invoked = True + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity") + yield ctx.call_entity(entity_id, "do_nothing") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(EmptyEntity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert invoked + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None diff --git a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py new file mode 100644 index 0000000..b4bdb15 --- /dev/null +++ b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py @@ -0,0 +1,259 @@ +import os +import time +from typing import Optional + +import pytest + +from durabletask import client, entities, task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# NOTE: These tests assume a sidecar process is running. 
Example command: +# docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest +pytestmark = pytest.mark.dts + +# Read the environment variables +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + +def test_client_signal_entity(): + invoked = False + + def empty_entity(ctx: task.EntityContext, _): + nonlocal invoked # don't do this in a real app! + if ctx.operation == "do_nothing": + invoked = True + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + c.signal_entity(entity_id, "do_nothing") + time.sleep(2) # wait for the signal to be processed + + assert invoked + + +def test_orchestration_signal_entity(): + invoked = False + + def empty_entity(ctx: task.EntityContext, _): + if ctx.operation == "do_nothing": + nonlocal invoked # don't do this in a real app! 
+ invoked = True + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + ctx.signal_entity(entity_id, "do_nothing") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + time.sleep(2) # wait for the signal to be processed - signals cannot be awaited from inside the orchestrator + + assert invoked + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None + + +def test_orchestration_call_entity(): + invoked = False + + def empty_entity(ctx: task.EntityContext, _): + if ctx.operation == "do_nothing": + nonlocal invoked # don't do this in a real app! 
+ invoked = True + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + yield ctx.call_entity(entity_id, "do_nothing") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert invoked + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None + + +def test_orchestration_call_entity_with_lock(): + invoked = False + + def empty_entity(ctx: task.EntityContext, _): + if ctx.operation == "do_nothing": + nonlocal invoked # don't do this in a real app! 
+ invoked = True + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + with (yield ctx.lock_entities([entity_id])): + yield ctx.call_entity(entity_id, "do_nothing") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + # Call this a second time to ensure the entity is still responsive after being locked and unlocked + id_2 = c.schedule_new_orchestration(empty_orchestrator) + state_2 = c.wait_for_orchestration_completion(id_2, timeout=30) + + assert invoked + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None + + assert state_2 is not None + assert state_2.name == task.get_name(empty_orchestrator) + assert state_2.instance_id == id_2 + assert state_2.failure_details is None + assert state_2.runtime_status == client.OrchestrationStatus.COMPLETED + assert state_2.serialized_input is None + assert state_2.serialized_output is None + assert state_2.serialized_custom_status is None + + +def test_orchestration_entity_signals_entity(): + invoked = False + + def empty_entity(ctx: task.EntityContext, _): + if ctx.operation == "do_nothing": + nonlocal invoked # don't do this in a real app! 
+ invoked = True + elif ctx.operation == "signal_other": + entity_id = entities.EntityInstanceId("empty_entity", "otherEntity") + ctx.signal_entity(entity_id, "do_nothing") + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + yield ctx.call_entity(entity_id, "signal_other") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert invoked + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None + + +def test_entity_starts_orchestration(): + invoked = False + + def empty_entity(ctx: task.EntityContext, _): + if ctx.operation == "start_orchestration": + ctx.schedule_new_orchestration("empty_orchestrator") + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + nonlocal invoked # don't do this in a real app! 
+ invoked = True + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + c.signal_entity(entities.EntityInstanceId("empty_entity", "testEntity"), "start_orchestration") + time.sleep(2) # wait for the signal and orchestration to be processed + + assert invoked + + +def test_entity_locking_behavior(): + def empty_entity(ctx: task.EntityContext, _): + pass + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + with (yield ctx.lock_entities([entity_id])): + # Cannot signal entities that have been locked + assert pytest.raises(Exception, ctx.signal_entity, entity_id, "do_nothing") + ctx.call_entity(entity_id, "do_nothing") + # Cannot call entities that have been locked and already called, but not yet returned a result + assert pytest.raises(Exception, ctx.call_entity, entity_id, "do_nothing") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == 
client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None From bea195436be08deb0217a2498076439e0453e342 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Sep 2025 15:28:46 -0600 Subject: [PATCH 07/21] Linting --- .../test_dts_class_based_entities_e2e.py | 1 - .../test_dts_function_based_entities_e2e copy.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py b/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py index 2d98931..19e8e5b 100644 --- a/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py +++ b/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py @@ -1,6 +1,5 @@ import os import time -from typing import Optional import pytest diff --git a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py index b4bdb15..f45c10c 100644 --- a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py +++ b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py @@ -1,6 +1,5 @@ import os import time -from typing import Optional import pytest From 0f9b19acb718504bc8d41f2d57ed4129c0ab7364 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Sep 2025 15:30:19 -0600 Subject: [PATCH 08/21] Linting --- durabletask/entities/durable_entity.py | 9 ++++++--- durabletask/internal/entity_state_shim.py | 9 ++++++--- durabletask/task.py | 9 ++++++--- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/durabletask/entities/durable_entity.py b/durabletask/entities/durable_entity.py index 8369cb1..a79ec1d 100644 --- a/durabletask/entities/durable_entity.py +++ b/durabletask/entities/durable_entity.py @@ -11,13 +11,16 @@ def _initialize_entity_context(self, context: EntityContext): self.entity_context = context 
@overload - def get_state(self, intended_type: Type[TState], default: TState) -> TState: ... + def get_state(self, intended_type: Type[TState], default: TState) -> TState: + ... @overload - def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ... + def get_state(self, intended_type: Type[TState]) -> Optional[TState]: + ... @overload - def get_state(self, intended_type: None = None, default: Any = None) -> Any: ... + def get_state(self, intended_type: None = None, default: Any = None) -> Any: + ... def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: return self.entity_context.get_state(intended_type, default) diff --git a/durabletask/internal/entity_state_shim.py b/durabletask/internal/entity_state_shim.py index fd025b4..f27edc5 100644 --- a/durabletask/internal/entity_state_shim.py +++ b/durabletask/internal/entity_state_shim.py @@ -14,13 +14,16 @@ def __init__(self, start_state): self._actions_checkpoint_state: int = 0 @overload - def get_state(self, intended_type: Type[TState], default: TState) -> TState: ... + def get_state(self, intended_type: Type[TState], default: TState) -> TState: + ... @overload - def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ... + def get_state(self, intended_type: Type[TState]) -> Optional[TState]: + ... @overload - def get_state(self, intended_type: None = None, default: Any = None) -> Any: ... + def get_state(self, intended_type: None = None, default: Any = None) -> Any: + ... 
def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: if self._current_state is None and default is not None: diff --git a/durabletask/task.py b/durabletask/task.py index da63e65..8d52e77 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -549,13 +549,16 @@ def operation(self) -> str: return self._operation @overload - def get_state(self, intended_type: Type[TState], default: TState) -> TState: ... + def get_state(self, intended_type: Type[TState], default: TState) -> TState: + ... @overload - def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ... + def get_state(self, intended_type: Type[TState]) -> Optional[TState]: + ... @overload - def get_state(self, intended_type: None = None, default: Any = None) -> Any: ... + def get_state(self, intended_type: None = None, default: Any = None) -> Any: + ... def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: return self._state.get_state(intended_type, default) From 35563892104543bb1531ba6d4e4980e61c64a1ce Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Sep 2025 15:31:28 -0600 Subject: [PATCH 09/21] Linting --- examples/entities/class_based_entity_actions.py | 2 +- examples/entities/function_based_entity_actions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/entities/class_based_entity_actions.py b/examples/entities/class_based_entity_actions.py index 1804aaf..8a38218 100644 --- a/examples/entities/class_based_entity_actions.py +++ b/examples/entities/class_based_entity_actions.py @@ -52,7 +52,7 @@ def counter_orchestrator(ctx: task.OrchestrationContext, _): def hello_orchestrator(ctx: task.OrchestrationContext, _): - return f"Hello world!" + return "Hello world!" 
# Use environment variables if provided, otherwise use default emulator values diff --git a/examples/entities/function_based_entity_actions.py b/examples/entities/function_based_entity_actions.py index be3a97d..9c349a8 100644 --- a/examples/entities/function_based_entity_actions.py +++ b/examples/entities/function_based_entity_actions.py @@ -46,7 +46,7 @@ def counter_orchestrator(ctx: task.OrchestrationContext, _): def hello_orchestrator(ctx: task.OrchestrationContext, _): - return f"Hello world!" + return "Hello world!" # Use environment variables if provided, otherwise use default emulator values From f72cde72bb84ba3439a4d58abc9dd1c8d8652a9c Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Sep 2025 15:45:28 -0600 Subject: [PATCH 10/21] Remove circular import --- docs/features.md | 2 +- durabletask/client.py | 2 +- durabletask/entities/__init__.py | 3 +- durabletask/entities/durable_entity.py | 3 +- durabletask/entities/entity_context.py | 106 ++++++++++++++++++ durabletask/internal/helpers.py | 2 +- .../internal/orchestration_entity_context.py | 2 +- durabletask/task.py | 105 +---------------- durabletask/worker.py | 5 +- examples/entities/function_based_entity.py | 2 +- .../entities/function_based_entity_actions.py | 2 +- ...st_dts_function_based_entities_e2e copy.py | 14 +-- 12 files changed, 126 insertions(+), 122 deletions(-) create mode 100644 durabletask/entities/entity_context.py diff --git a/docs/features.md b/docs/features.md index 1357b22..cd28b2c 100644 --- a/docs/features.md +++ b/docs/features.md @@ -66,7 +66,7 @@ Entities can be defined using either function-based or class-based syntax. 
```python # Funtion-based entity -def counter(ctx: task.EntityContext, input: int): +def counter(ctx: entities.EntityContext, input: int): state = ctx.get_state(int, 0) if ctx.operation == "add": state += input diff --git a/durabletask/client.py b/durabletask/client.py index a7a3775..c150822 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -11,7 +11,7 @@ import grpc from google.protobuf import wrappers_pb2 -from durabletask.entities.entity_instance_id import EntityInstanceId +from durabletask.entities import EntityInstanceId import durabletask.internal.helpers as helpers import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs diff --git a/durabletask/entities/__init__.py b/durabletask/entities/__init__.py index 93ea7c5..4ab03c0 100644 --- a/durabletask/entities/__init__.py +++ b/durabletask/entities/__init__.py @@ -6,7 +6,8 @@ from durabletask.entities.entity_instance_id import EntityInstanceId from durabletask.entities.durable_entity import DurableEntity from durabletask.entities.entity_lock import EntityLock +from durabletask.entities.entity_context import EntityContext -__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock"] +__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext"] PACKAGE_NAME = "durabletask.entities" diff --git a/durabletask/entities/durable_entity.py b/durabletask/entities/durable_entity.py index a79ec1d..c9b0dc1 100644 --- a/durabletask/entities/durable_entity.py +++ b/durabletask/entities/durable_entity.py @@ -1,7 +1,6 @@ from typing import Any, Optional, Type, TypeVar, overload -from durabletask.entities.entity_instance_id import EntityInstanceId -from durabletask.task import EntityContext +from durabletask.entities import EntityContext, EntityInstanceId TState = TypeVar("TState") diff --git a/durabletask/entities/entity_context.py b/durabletask/entities/entity_context.py new file mode 100644 index 0000000..775a57f --- 
/dev/null +++ b/durabletask/entities/entity_context.py @@ -0,0 +1,106 @@ + +from typing import Any, Optional, Type, TypeVar, overload +import uuid +from durabletask.entities import EntityInstanceId +from durabletask.internal import helpers, shared +from durabletask.internal.entity_state_shim import StateShim +import durabletask.internal.orchestrator_service_pb2 as pb + +TState = TypeVar("TState") + + +class EntityContext: + def __init__(self, orchestration_id: str, operation: str, state: StateShim, entity_id: EntityInstanceId): + self._orchestration_id = orchestration_id + self._operation = operation + self._state = state + self._entity_id = entity_id + + @property + def orchestration_id(self) -> str: + """Get the ID of the orchestration instance that scheduled this entity. + + Returns + ------- + str + The ID of the current orchestration instance. + """ + return self._orchestration_id + + @property + def operation(self) -> str: + """Get the operation associated with this entity invocation. + + The operation is a string that identifies the specific action being + performed on the entity. It can be used to distinguish between + multiple operations that are part of the same entity invocation. + + Returns + ------- + str + The operation associated with this entity invocation. + """ + return self._operation + + @overload + def get_state(self, intended_type: Type[TState], default: TState) -> TState: + ... + + @overload + def get_state(self, intended_type: Type[TState]) -> Optional[TState]: + ... + + @overload + def get_state(self, intended_type: None = None, default: Any = None) -> Any: + ... 
+ + def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: + return self._state.get_state(intended_type, default) + + def set_state(self, new_state: Any): + self._state.set_state(new_state) + + def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Optional[Any] = None) -> None: + encoded_input = shared.to_json(input) if input is not None else None + self._state.add_operation_action( + pb.OperationAction( + sendSignal=pb.SendSignalAction( + instanceId=str(entity_instance_id), + name=operation, + input=helpers.get_string_value(encoded_input), + scheduledTime=None, + requestTime=None, + parentTraceContext=None, + ) + ) + ) + + def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> str: + encoded_input = shared.to_json(input) if input is not None else None + if not instance_id: + instance_id = uuid.uuid4().hex + self._state.add_operation_action( + pb.OperationAction( + startNewOrchestration=pb.StartNewOrchestrationAction( + instanceId=instance_id, + name=orchestration_name, + input=helpers.get_string_value(encoded_input), + version=None, + scheduledTime=None, + requestTime=None, + parentTraceContext=None + ) + ) + ) + return instance_id + + @property + def entity_id(self) -> EntityInstanceId: + """Get the ID of the entity instance. + + Returns + ------- + str + The ID of the current entity instance. 
+ """ + return self._entity_id diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 3e6d887..ccd8558 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -7,7 +7,7 @@ from google.protobuf import timestamp_pb2, wrappers_pb2 -from durabletask.entities.entity_instance_id import EntityInstanceId +from durabletask.entities import EntityInstanceId import durabletask.internal.orchestrator_service_pb2 as pb # TODO: The new_xxx_event methods are only used by test code and should be moved elsewhere diff --git a/durabletask/internal/orchestration_entity_context.py b/durabletask/internal/orchestration_entity_context.py index 21085a9..1cb4619 100644 --- a/durabletask/internal/orchestration_entity_context.py +++ b/durabletask/internal/orchestration_entity_context.py @@ -3,7 +3,7 @@ from durabletask.internal.helpers import get_string_value import durabletask.internal.orchestrator_service_pb2 as pb -from durabletask.entities.entity_instance_id import EntityInstanceId +from durabletask.entities import EntityInstanceId class OrchestrationEntityContext: diff --git a/durabletask/task.py b/durabletask/task.py index 8d52e77..645354e 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -7,19 +7,15 @@ import math from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Any, Callable, Generator, Generic, Optional, Type, TypeVar, Union, overload -import uuid +from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union -from durabletask.entities import DurableEntity, EntityInstanceId, EntityLock -from durabletask.internal import shared -from durabletask.internal.entity_state_shim import StateShim +from durabletask.entities import DurableEntity, EntityInstanceId, EntityLock, EntityContext import durabletask.internal.helpers as pbh import durabletask.internal.orchestrator_service_pb2 as pb T = TypeVar('T') TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') 
-TState = TypeVar("TState") class OrchestrationContext(ABC): @@ -515,103 +511,6 @@ def task_id(self) -> int: return self._task_id -class EntityContext: - def __init__(self, orchestration_id: str, operation: str, state: StateShim, entity_id: EntityInstanceId): - self._orchestration_id = orchestration_id - self._operation = operation - self._state = state - self._entity_id = entity_id - - @property - def orchestration_id(self) -> str: - """Get the ID of the orchestration instance that scheduled this entity. - - Returns - ------- - str - The ID of the current orchestration instance. - """ - return self._orchestration_id - - @property - def operation(self) -> str: - """Get the operation associated with this entity invocation. - - The operation is a string that identifies the specific action being - performed on the entity. It can be used to distinguish between - multiple operations that are part of the same entity invocation. - - Returns - ------- - str - The operation associated with this entity invocation. - """ - return self._operation - - @overload - def get_state(self, intended_type: Type[TState], default: TState) -> TState: - ... - - @overload - def get_state(self, intended_type: Type[TState]) -> Optional[TState]: - ... - - @overload - def get_state(self, intended_type: None = None, default: Any = None) -> Any: - ... 
- - def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: - return self._state.get_state(intended_type, default) - - def set_state(self, new_state: Any): - self._state.set_state(new_state) - - def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Optional[Any] = None) -> None: - encoded_input = shared.to_json(input) if input is not None else None - self._state.add_operation_action( - pb.OperationAction( - sendSignal=pb.SendSignalAction( - instanceId=str(entity_instance_id), - name=operation, - input=pbh.get_string_value(encoded_input), - scheduledTime=None, - requestTime=None, - parentTraceContext=None, - ) - ) - ) - - def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> str: - encoded_input = shared.to_json(input) if input is not None else None - if not instance_id: - instance_id = uuid.uuid4().hex - self._state.add_operation_action( - pb.OperationAction( - startNewOrchestration=pb.StartNewOrchestrationAction( - instanceId=instance_id, - name=orchestration_name, - input=pbh.get_string_value(encoded_input), - version=None, - scheduledTime=None, - requestTime=None, - parentTraceContext=None - ) - ) - ) - return instance_id - - @property - def entity_id(self) -> EntityInstanceId: - """Get the ID of the entity instance. - - Returns - ------- - str - The ID of the current entity instance. 
- """ - return self._entity_id - - # Orchestrators are generators that yield tasks and receive/return any type Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task, Any, Any], TOutput]] diff --git a/durabletask/worker.py b/durabletask/worker.py index e28b469..7d4c8d6 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -21,7 +21,7 @@ from durabletask.internal import helpers from durabletask.internal.entity_state_shim import StateShim from durabletask.internal.helpers import new_timestamp -from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId +from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext import durabletask.internal.helpers as ph import durabletask.internal.exceptions as pe @@ -978,7 +978,6 @@ def call_entity( self, entity_id: EntityInstanceId, operation: str, - *, input: Optional[TInput] = None, ) -> task.Task: id = self.next_sequence_number() @@ -1774,7 +1773,7 @@ def execute( ) entity_input = shared.from_json(encoded_input) if encoded_input else None - ctx = task.EntityContext(orchestration_id, operation, state, entity_id) + ctx = EntityContext(orchestration_id, operation, state, entity_id) if isinstance(fn, type) and issubclass(fn, DurableEntity): if self._registry.entity_instances.get(str(entity_id), None): diff --git a/examples/entities/function_based_entity.py b/examples/entities/function_based_entity.py index cd2aa85..a43b86d 100644 --- a/examples/entities/function_based_entity.py +++ b/examples/entities/function_based_entity.py @@ -10,7 +10,7 @@ from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker -def counter(ctx: task.EntityContext, input: int) -> Optional[int]: +def counter(ctx: entities.EntityContext, input: int) -> Optional[int]: if ctx.operation == "set": ctx.set_state(input) if ctx.operation == "add": diff --git 
a/examples/entities/function_based_entity_actions.py b/examples/entities/function_based_entity_actions.py index 9c349a8..129eb6c 100644 --- a/examples/entities/function_based_entity_actions.py +++ b/examples/entities/function_based_entity_actions.py @@ -10,7 +10,7 @@ from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker -def counter(ctx: task.EntityContext, input: int) -> Optional[int]: +def counter(ctx: entities.EntityContext, input: int) -> Optional[int]: if ctx.operation == "set": ctx.set_state(input) elif ctx.operation == "get": diff --git a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py index f45c10c..e5018a5 100644 --- a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py +++ b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py @@ -19,7 +19,7 @@ def test_client_signal_entity(): invoked = False - def empty_entity(ctx: task.EntityContext, _): + def empty_entity(ctx: entities.EntityContext, _): nonlocal invoked # don't do this in a real app! if ctx.operation == "do_nothing": invoked = True @@ -42,7 +42,7 @@ def empty_entity(ctx: task.EntityContext, _): def test_orchestration_signal_entity(): invoked = False - def empty_entity(ctx: task.EntityContext, _): + def empty_entity(ctx: entities.EntityContext, _): if ctx.operation == "do_nothing": nonlocal invoked # don't do this in a real app! invoked = True @@ -78,7 +78,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): def test_orchestration_call_entity(): invoked = False - def empty_entity(ctx: task.EntityContext, _): + def empty_entity(ctx: entities.EntityContext, _): if ctx.operation == "do_nothing": nonlocal invoked # don't do this in a real app! 
invoked = True @@ -113,7 +113,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): def test_orchestration_call_entity_with_lock(): invoked = False - def empty_entity(ctx: task.EntityContext, _): + def empty_entity(ctx: entities.EntityContext, _): if ctx.operation == "do_nothing": nonlocal invoked # don't do this in a real app! invoked = True @@ -162,7 +162,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): def test_orchestration_entity_signals_entity(): invoked = False - def empty_entity(ctx: task.EntityContext, _): + def empty_entity(ctx: entities.EntityContext, _): if ctx.operation == "do_nothing": nonlocal invoked # don't do this in a real app! invoked = True @@ -200,7 +200,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): def test_entity_starts_orchestration(): invoked = False - def empty_entity(ctx: task.EntityContext, _): + def empty_entity(ctx: entities.EntityContext, _): if ctx.operation == "start_orchestration": ctx.schedule_new_orchestration("empty_orchestrator") @@ -224,7 +224,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): def test_entity_locking_behavior(): - def empty_entity(ctx: task.EntityContext, _): + def empty_entity(ctx: entities.EntityContext, _): pass def empty_orchestrator(ctx: task.OrchestrationContext, _): From ebf98bf59ffc46f1499c0e99908ed9c4ef090688 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Sep 2025 15:47:26 -0600 Subject: [PATCH 11/21] Remove circular import --- durabletask/entities/durable_entity.py | 3 ++- durabletask/entities/entity_context.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/durabletask/entities/durable_entity.py b/durabletask/entities/durable_entity.py index c9b0dc1..31e3488 100644 --- a/durabletask/entities/durable_entity.py +++ b/durabletask/entities/durable_entity.py @@ -1,6 +1,7 @@ from typing import Any, Optional, Type, TypeVar, overload -from durabletask.entities import EntityContext, EntityInstanceId +from 
durabletask.entities.entity_context import EntityContext +from durabletask.entities.entity_instance_id import EntityInstanceId TState = TypeVar("TState") diff --git a/durabletask/entities/entity_context.py b/durabletask/entities/entity_context.py index 775a57f..b861030 100644 --- a/durabletask/entities/entity_context.py +++ b/durabletask/entities/entity_context.py @@ -1,7 +1,7 @@ from typing import Any, Optional, Type, TypeVar, overload import uuid -from durabletask.entities import EntityInstanceId +from durabletask.entities.entity_instance_id import EntityInstanceId from durabletask.internal import helpers, shared from durabletask.internal.entity_state_shim import StateShim import durabletask.internal.orchestrator_service_pb2 as pb From a993cfc2d738eb9e40ca1fb33446e99ae156873f Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Mon, 29 Sep 2025 11:08:44 -0600 Subject: [PATCH 12/21] Remove resolved todos --- durabletask/entities/entity_lock.py | 1 - durabletask/internal/orchestration_entity_context.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/durabletask/entities/entity_lock.py b/durabletask/entities/entity_lock.py index 5cbf7ea..9f978bc 100644 --- a/durabletask/entities/entity_lock.py +++ b/durabletask/entities/entity_lock.py @@ -9,7 +9,6 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): # TODO: Handle exceptions? 
- print(f"Unlocking entities: {self._context._entity_context.critical_section_locks}") for entity_unlock_message in self._context._entity_context.emit_lock_release_messages(): task_id = self._context.next_sequence_number() action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message) diff --git a/durabletask/internal/orchestration_entity_context.py b/durabletask/internal/orchestration_entity_context.py index 1cb4619..d8465f0 100644 --- a/durabletask/internal/orchestration_entity_context.py +++ b/durabletask/internal/orchestration_entity_context.py @@ -67,7 +67,6 @@ def emit_lock_release_messages(self): )) yield unlock_event - # TODO: Emit the actual release messages (?) self.critical_section_locks = [] self.available_locks = [] self.critical_section_id = None @@ -104,7 +103,6 @@ def emit_acquire_message(self, critical_section_id: str, entities: List[EntityIn return request, target def complete_acquire(self, critical_section_id): - # TODO: HashSet or equivalent if self.critical_section_id != critical_section_id: raise RuntimeError(f"Unexpected lock acquire for critical section ID '{critical_section_id}' (expected '{self.critical_section_id}')") self.available_locks = self.critical_section_locks From f66abf415f821c1c84f48eb829473a93aac846d3 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Mon, 29 Sep 2025 11:19:39 -0600 Subject: [PATCH 13/21] Add SDK docs for exposed methods --- durabletask/entities/durable_entity.py | 48 ++++++++++++++++++++++ durabletask/entities/entity_context.py | 48 ++++++++++++++++++++++ durabletask/entities/entity_instance_id.py | 12 ++++++ durabletask/task.py | 10 +++-- 4 files changed, 115 insertions(+), 3 deletions(-) diff --git a/durabletask/entities/durable_entity.py b/durabletask/entities/durable_entity.py index 31e3488..41fcab4 100644 --- a/durabletask/entities/durable_entity.py +++ b/durabletask/entities/durable_entity.py @@ -23,13 +23,61 @@ def get_state(self, intended_type: None = None, default: Any = None) -> 
Any: ... def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: + """Get the current state of the entity, optionally converting it to a specified type. + + Parameters + ---------- + intended_type : Type[TState] | None, optional + The type to which the state should be converted. If None, the state is returned as-is. + default : TState, optional + The default value to return if the state is not found or cannot be converted. + + Returns + ------- + TState | Any + The current state of the entity, optionally converted to the specified type. + """ return self.entity_context.get_state(intended_type, default) def set_state(self, state: Any): + """Set the state of the entity to a new value. + + Parameters + ---------- + new_state : Any + The new state to set for the entity. + """ self.entity_context.set_state(state) def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Optional[Any] = None) -> None: + """Signal another entity to perform an operation. + + Parameters + ---------- + entity_instance_id : EntityInstanceId + The ID of the entity instance to signal. + operation : str + The operation to perform on the entity. + input : Any, optional + The input to provide to the entity for the operation. + """ self.entity_context.signal_entity(entity_instance_id, operation, input) def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> str: + """Schedule a new orchestration instance. + + Parameters + ---------- + orchestration_name : str + The name of the orchestration to schedule. + input : Any, optional + The input to provide to the new orchestration. + instance_id : str, optional + The instance ID to assign to the new orchestration. If None, a new ID will be generated. + + Returns + ------- + str + The instance ID of the scheduled orchestration. 
+ """ return self.entity_context.schedule_new_orchestration(orchestration_name, input, instance_id=instance_id) diff --git a/durabletask/entities/entity_context.py b/durabletask/entities/entity_context.py index b861030..974767b 100644 --- a/durabletask/entities/entity_context.py +++ b/durabletask/entities/entity_context.py @@ -55,12 +55,44 @@ def get_state(self, intended_type: None = None, default: Any = None) -> Any: ... def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: + """Get the current state of the entity, optionally converting it to a specified type. + + Parameters + ---------- + intended_type : Type[TState] | None, optional + The type to which the state should be converted. If None, the state is returned as-is. + default : TState, optional + The default value to return if the state is not found or cannot be converted. + + Returns + ------- + TState | Any + The current state of the entity, optionally converted to the specified type. + """ return self._state.get_state(intended_type, default) def set_state(self, new_state: Any): + """Set the state of the entity to a new value. + + Parameters + ---------- + new_state : Any + The new state to set for the entity. + """ self._state.set_state(new_state) def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Optional[Any] = None) -> None: + """Signal another entity to perform an operation. + + Parameters + ---------- + entity_instance_id : EntityInstanceId + The ID of the entity instance to signal. + operation : str + The operation to perform on the entity. + input : Any, optional + The input to provide to the entity for the operation. 
+ """ encoded_input = shared.to_json(input) if input is not None else None self._state.add_operation_action( pb.OperationAction( @@ -76,6 +108,22 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, in ) def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> str: + """Schedule a new orchestration instance. + + Parameters + ---------- + orchestration_name : str + The name of the orchestration to schedule. + input : Any, optional + The input to provide to the new orchestration. + instance_id : str, optional + The instance ID to assign to the new orchestration. If None, a new ID will be generated. + + Returns + ------- + str + The instance ID of the scheduled orchestration. + """ encoded_input = shared.to_json(input) if input is not None else None if not instance_id: instance_id = uuid.uuid4().hex diff --git a/durabletask/entities/entity_instance_id.py b/durabletask/entities/entity_instance_id.py index 1fee44f..53c1171 100644 --- a/durabletask/entities/entity_instance_id.py +++ b/durabletask/entities/entity_instance_id.py @@ -21,6 +21,18 @@ def __lt__(self, other): @staticmethod def parse(entity_id: str) -> Optional["EntityInstanceId"]: + """Parse a string representation of an entity ID into an EntityInstanceId object. + + Parameters + ---------- + entity_id : str + The string representation of the entity ID, in the format '@entity@key'. + + Returns + ------- + Optional[EntityInstanceId] + The parsed EntityInstanceId object, or None if the input is None. 
+ """ try: _, entity, key = entity_id.split("@", 2) return EntityInstanceId(entity=entity, key=key) diff --git a/durabletask/task.py b/durabletask/task.py index 645354e..ddb9ef0 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -181,8 +181,12 @@ def signal_entity( pass @abstractmethod - def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLock: - """Lock the specified entity instances for the duration of the orchestration. + def lock_entities(self, entities: list[EntityInstanceId]) -> Task[EntityLock]: + """Creates a Task object that locks the specified entity instances. + + The locks will be acquired the next time the orchestrator yields. + Best practice is to immediately yield this Task and enter the returned EntityLock. + The lock is released when the EntityLock is exited. Parameters ---------- @@ -192,7 +196,7 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> EntityLock: Returns ------- EntityLock - A disposable object that acquires and releases the locks when initialized or disposed. + A context manager object that releases the locks when exited. 
""" pass From a377de7ce845a1fc1ad4c163a65d27e920d3a584 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 30 Sep 2025 15:03:25 -0600 Subject: [PATCH 14/21] Various - Return pending actions when orchestrations complete - Ensure locked entities are unlocked when orchestration ends (success/fail/continue_as_new) - Provide default "delete" operation and document deleting entities --- docs/features.md | 4 + durabletask/entities/durable_entity.py | 10 ++ durabletask/entities/entity_lock.py | 5 +- durabletask/worker.py | 42 ++++++-- ...st_dts_function_based_entities_e2e copy.py | 100 ++++++++++++++++++ 5 files changed, 151 insertions(+), 10 deletions(-) diff --git a/docs/features.md b/docs/features.md index cd28b2c..daa727e 100644 --- a/docs/features.md +++ b/docs/features.md @@ -130,6 +130,10 @@ with (yield ctx.lock_entities([entity_id_1, entity_id_2]): Note that locked entities may not be signalled, and every call to a locked entity must return a result before another call to the same entity may be made from within the critical section. For more details and advanced usage, see the examples and API documentation. +##### Deleting entities + +Entities are represented as orchestration instances in your Task Hub, and their state is persisted in the Task Hub as well. When using the Durable Task Scheduler as your durability provider, the backend will automatically clean up entities when their state is empty; this is effectively the "delete" operation to save space in the Task Hub. In the DTS Dashboard, "delete entity" simply signals the entity with the "delete" operation. In this SDK, we provide a default implementation for the "delete" operation to clear the state when using class-based entities, which end users are free to override as needed. Users must implement "delete" manually for function-based entities. + ### External events Orchestrations can wait for external events using the `wait_for_external_event` API. 
External events are useful for implementing human interaction patterns, such as waiting for a user to approve an order before continuing. diff --git a/durabletask/entities/durable_entity.py b/durabletask/entities/durable_entity.py index 41fcab4..7c81e79 100644 --- a/durabletask/entities/durable_entity.py +++ b/durabletask/entities/durable_entity.py @@ -81,3 +81,13 @@ def schedule_new_orchestration(self, orchestration_name: str, input: Optional[An The instance ID of the scheduled orchestration. """ return self.entity_context.schedule_new_orchestration(orchestration_name, input, instance_id=instance_id) + + def delete(self, input: Any = None) -> None: + """Delete the entity instance. + + Parameters + ---------- + input : Any, optional + Unused: The input for the entity "delete" operation. + """ + self.set_state(None) diff --git a/durabletask/entities/entity_lock.py b/durabletask/entities/entity_lock.py index 9f978bc..0768d1f 100644 --- a/durabletask/entities/entity_lock.py +++ b/durabletask/entities/entity_lock.py @@ -9,7 +9,4 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): # TODO: Handle exceptions? 
- for entity_unlock_message in self._context._entity_context.emit_lock_release_messages(): - task_id = self._context.next_sequence_number() - action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message) - self._context._pending_actions[task_id] = action + self._context._exit_critical_section() diff --git a/durabletask/worker.py b/durabletask/worker.py index 7d4c8d6..5463835 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -842,9 +842,15 @@ def set_complete( if self._is_complete: return + # If the user code returned without yielding the entity unlock, do that now + if self._entity_context.is_inside_critical_section: + self._exit_critical_section() + self._is_complete = True self._completion_status = status - self._pending_actions.clear() # Cancel any pending actions + # This is probably a bug - an orchestrator may complete with some actions remaining that the user still + # wants to execute - for example, signaling an entity. So we shouldn't clear the pending actions here. + # self._pending_actions.clear() # Cancel any pending actions self._result = result result_json: Optional[str] = None @@ -859,8 +865,14 @@ def set_failed(self, ex: Union[Exception, pb.TaskFailureDetails]): if self._is_complete: return + # If the user code crashed inside a critical section, or did not exit it, do that now + if self._entity_context.is_inside_critical_section: + self._exit_critical_section() + self._is_complete = True - self._pending_actions.clear() # Cancel any pending actions + # We also cannot cancel the pending actions in the failure case - if the user code had released an entity + # lock, we *must* send that action to the sidecar. 
+ # self._pending_actions.clear() # Cancel any pending actions self._completion_status = pb.ORCHESTRATION_STATUS_FAILED action = ph.new_complete_orchestration_action( @@ -875,13 +887,20 @@ def set_continued_as_new(self, new_input: Any, save_events: bool): if self._is_complete: return + # If the user code called continue_as_new while holding an entity lock, unlock it now + if self._entity_context.is_inside_critical_section: + self._exit_critical_section() + self._is_complete = True - self._pending_actions.clear() # Cancel any pending actions + # We also cannot cancel the pending actions in the continue as new case - if the user code had released an + # entity lock, we *must* send that action to the sidecar. + # self._pending_actions.clear() # Cancel any pending actions self._completion_status = pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW self._new_input = new_input self._save_events = save_events def get_actions(self) -> list[pb.OrchestratorAction]: + current_actions = list(self._pending_actions.values()) if self._completion_status == pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW: # When continuing-as-new, we only return a single completion action. 
carryover_events: Optional[list[pb.HistoryEvent]] = None @@ -906,9 +925,9 @@ def get_actions(self) -> list[pb.OrchestratorAction]: failure_details=None, carryover_events=carryover_events, ) - return [action] - else: - return list(self._pending_actions.values()) + # We must return the existing tasks as well, to capture entity unlocks + current_actions.append(action) + return current_actions def next_sequence_number(self) -> int: self._sequence_number += 1 @@ -1147,6 +1166,17 @@ def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId fn_task = task.CompletableTask[EntityLock]() self._pending_tasks[id] = fn_task + def _exit_critical_section(self) -> None: + if not self._entity_context.is_inside_critical_section: + # Possible if the user calls continue_as_new inside the lock - in the success case, we will call + # _exit_critical_section both from the EntityLock and the exit logic. We must keep both calls in + # case the user code crashes after calling continue_as_new but before the EntityLock object is exited. + return + for entity_unlock_message in self._entity_context.emit_lock_release_messages(): + task_id = self.next_sequence_number() + action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message) + self._pending_actions[task_id] = action + def wait_for_external_event(self, name: str) -> task.Task: # Check to see if this event has already been received, in which case we # can return it immediately. 
Otherwise, record out intent to receive an diff --git a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py index e5018a5..a15d0f2 100644 --- a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py +++ b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py @@ -1,7 +1,9 @@ +from datetime import timedelta import os import time import pytest +from azure.identity import DefaultAzureCredential from durabletask import client, entities, task from durabletask.azuremanaged.client import DurableTaskSchedulerClient @@ -14,6 +16,8 @@ # Read the environment variables taskhub_name = os.getenv("TASKHUB", "default") endpoint = os.getenv("ENDPOINT", "http://localhost:8080") +# endpoint = os.getenv("ENDPOINT", "https://andy-dts-testin-byaje2c8.northcentralus.durabletask.io") +credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() def test_client_signal_entity(): @@ -256,3 +260,99 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): assert state.serialized_input is None assert state.serialized_output is None assert state.serialized_custom_status is None + + +def test_entity_unlocks_when_user_code_throws(): + invoke_count = 0 + + def empty_entity(ctx: entities.EntityContext, _): + nonlocal invoke_count # don't do this in a real app! 
+ invoke_count += 1 + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("empty_entity", "testEntity3") + with (yield ctx.lock_entities([entity_id])): + yield ctx.call_entity(entity_id, "do_nothing") + raise Exception("Simulated exception") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + time.sleep(2) # wait for the signal and orchestration to be processed + id = c.schedule_new_orchestration(empty_orchestrator) + c.wait_for_orchestration_completion(id, timeout=30) + id = c.schedule_new_orchestration(empty_orchestrator) + c.wait_for_orchestration_completion(id, timeout=30) + + assert invoke_count == 2 + + +def test_entity_unlocks_when_user_mishandles_lock(): + invoke_count = 0 + + def empty_entity(ctx: entities.EntityContext, _): + nonlocal invoke_count # don't do this in a real app! 
+ invoke_count += 1 + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("empty_entity", "testEntity3") + yield ctx.lock_entities([entity_id]) + yield ctx.call_entity(entity_id, "do_nothing") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + time.sleep(2) # wait for the signal and orchestration to be processed + id = c.schedule_new_orchestration(empty_orchestrator) + c.wait_for_orchestration_completion(id, timeout=30) + id = c.schedule_new_orchestration(empty_orchestrator) + c.wait_for_orchestration_completion(id, timeout=30) + + assert invoke_count == 2 + + +# TODO: Uncomment this test +# Will not pass until https://msazure.visualstudio.com/One/_git/AAPT-DTMB/pullrequest/13610881 is merged and deployed to the docker image +# def test_entity_unlocks_when_user_calls_continue_as_new(): +# invoke_count = 0 + +# def empty_entity(ctx: entities.EntityContext, _): +# nonlocal invoke_count # don't do this in a real app! 
+# invoke_count += 1 + +# def empty_orchestrator(ctx: task.OrchestrationContext, entity_call_count: int): +# entity_id = entities.EntityInstanceId("empty_entity", "testEntity6") +# nonlocal invoke_count +# if not ctx.is_replaying: +# invoke_count += 1 +# with (yield ctx.lock_entities([entity_id])): +# yield ctx.call_entity(entity_id, "do_nothing") +# if entity_call_count > 0: +# ctx.continue_as_new(entity_call_count - 1, save_events=True) + +# # Start a worker, which will connect to the sidecar in a background thread +# with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, +# taskhub=taskhub_name, token_credential=credential) as w: +# w.add_orchestrator(empty_orchestrator) +# w.add_entity(empty_entity) +# w.start() + +# c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, +# taskhub=taskhub_name, token_credential=credential) +# time.sleep(2) # wait for the signal and orchestration to be processed +# id = c.schedule_new_orchestration(empty_orchestrator, input=2) +# c.wait_for_orchestration_completion(id, timeout=500) + +# assert invoke_count == 6 From ac3d4d8bc20cf40d692c65bd52f8307e88714f74 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 30 Sep 2025 15:05:57 -0600 Subject: [PATCH 15/21] Linting --- .../test_dts_function_based_entities_e2e copy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py index a15d0f2..b3a5b2f 100644 --- a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py +++ b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py @@ -1,4 +1,3 @@ -from datetime import timedelta import os import time From 88abe0a3d73d40295921c6cba304b7633f12ffd8 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 30 Sep 2025 15:07:27 -0600 Subject: [PATCH 16/21] Linting --- durabletask/entities/entity_lock.py | 3 --- 1 file 
changed, 3 deletions(-) diff --git a/durabletask/entities/entity_lock.py b/durabletask/entities/entity_lock.py index 0768d1f..06377f1 100644 --- a/durabletask/entities/entity_lock.py +++ b/durabletask/entities/entity_lock.py @@ -1,6 +1,3 @@ -import durabletask.internal.orchestrator_service_pb2 as pb - - class EntityLock: def __init__(self, context): self._context = context From 81ca041244bd8e56d44108cf49f2370db98d3c30 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 30 Sep 2025 15:17:16 -0600 Subject: [PATCH 17/21] Linting --- durabletask/entities/entity_lock.py | 12 ++++++++++-- durabletask/task.py | 4 ++++ durabletask/worker.py | 2 +- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/durabletask/entities/entity_lock.py b/durabletask/entities/entity_lock.py index 06377f1..bb7921c 100644 --- a/durabletask/entities/entity_lock.py +++ b/durabletask/entities/entity_lock.py @@ -1,9 +1,17 @@ +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from durabletask.task import OrchestrationContext + + class EntityLock: - def __init__(self, context): + # Note: This should + def __init__(self, context: 'OrchestrationContext'): self._context = context def __enter__(self): return self - def __exit__(self, exc_type, exc_val, exc_tb): # TODO: Handle exceptions? 
+ def __exit__(self, exc_type, exc_val, exc_tb): self._context._exit_critical_section() diff --git a/durabletask/task.py b/durabletask/task.py index ddb9ef0..23723fb 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -258,6 +258,10 @@ def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None: """ pass + @abstractmethod + def _exit_critical_section(self) -> None: + pass + class FailureDetails: def __init__(self, message: str, error_type: str, stack_trace: Optional[str]): diff --git a/durabletask/worker.py b/durabletask/worker.py index 5463835..09f6559 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -1169,7 +1169,7 @@ def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId def _exit_critical_section(self) -> None: if not self._entity_context.is_inside_critical_section: # Possible if the user calls continue_as_new inside the lock - in the success case, we will call - # _exit_critical_section both from the EntityLock and the exit logic. We must keep both calls in + # _exit_critical_section both from the EntityLock and the continue_as_new logic. We must keep both calls in # case the user code crashes after calling continue_as_new but before the EntityLock object is exited. return for entity_unlock_message in self._entity_context.emit_lock_release_messages(): From a4d9e74fb0d979d880b1341a26ca80c6b83f1640 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 30 Sep 2025 15:18:38 -0600 Subject: [PATCH 18/21] Linting --- durabletask/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/durabletask/task.py b/durabletask/task.py index 23723fb..2f49371 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -184,7 +184,7 @@ def signal_entity( def lock_entities(self, entities: list[EntityInstanceId]) -> Task[EntityLock]: """Creates a Task object that locks the specified entity instances. - The locks will be acquired the next time the orchestrator yields. 
+ The locks will be acquired the next time the orchestrator yields. Best practice is to immediately yield this Task and enter the returned EntityLock. The lock is released when the EntityLock is exited. From bd49c8f0bdbd2469077394733c6dcdcfee9c5fd1 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 30 Sep 2025 15:38:53 -0600 Subject: [PATCH 19/21] Fix tests --- ...st_dts_function_based_entities_e2e copy.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py index b3a5b2f..4220655 100644 --- a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py +++ b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py @@ -2,7 +2,6 @@ import time import pytest -from azure.identity import DefaultAzureCredential from durabletask import client, entities, task from durabletask.azuremanaged.client import DurableTaskSchedulerClient @@ -15,8 +14,6 @@ # Read the environment variables taskhub_name = os.getenv("TASKHUB", "default") endpoint = os.getenv("ENDPOINT", "http://localhost:8080") -# endpoint = os.getenv("ENDPOINT", "https://andy-dts-testin-byaje2c8.northcentralus.durabletask.io") -credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() def test_client_signal_entity(): @@ -51,7 +48,7 @@ def empty_entity(ctx: entities.EntityContext, _): invoked = True def empty_orchestrator(ctx: task.OrchestrationContext, _): - entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + entity_id = entities.EntityInstanceId("empty_entity", f"{ctx.instance_id}_testEntity") ctx.signal_entity(entity_id, "do_nothing") # Start a worker, which will connect to the sidecar in a background thread @@ -87,7 +84,7 @@ def empty_entity(ctx: entities.EntityContext, _): invoked = True def empty_orchestrator(ctx: task.OrchestrationContext, 
_): - entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + entity_id = entities.EntityInstanceId("empty_entity", f"{ctx.instance_id}_testEntity") yield ctx.call_entity(entity_id, "do_nothing") # Start a worker, which will connect to the sidecar in a background thread @@ -122,7 +119,7 @@ def empty_entity(ctx: entities.EntityContext, _): invoked = True def empty_orchestrator(ctx: task.OrchestrationContext, _): - entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + entity_id = entities.EntityInstanceId("empty_entity", f"{ctx.instance_id}_testEntity") with (yield ctx.lock_entities([entity_id])): yield ctx.call_entity(entity_id, "do_nothing") @@ -170,11 +167,12 @@ def empty_entity(ctx: entities.EntityContext, _): nonlocal invoked # don't do this in a real app! invoked = True elif ctx.operation == "signal_other": - entity_id = entities.EntityInstanceId("empty_entity", "otherEntity") + entity_id = entities.EntityInstanceId("empty_entity", + ctx.entity_id.key.replace("_testEntity", "_otherEntity")) ctx.signal_entity(entity_id, "do_nothing") def empty_orchestrator(ctx: task.OrchestrationContext, _): - entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + entity_id = entities.EntityInstanceId("empty_entity", f"{ctx.instance_id}_testEntity") yield ctx.call_entity(entity_id, "signal_other") # Start a worker, which will connect to the sidecar in a background thread @@ -231,13 +229,14 @@ def empty_entity(ctx: entities.EntityContext, _): pass def empty_orchestrator(ctx: task.OrchestrationContext, _): - entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + entity_id = entities.EntityInstanceId("empty_entity", f"{ctx.instance_id}_testEntity") with (yield ctx.lock_entities([entity_id])): # Cannot signal entities that have been locked assert pytest.raises(Exception, ctx.signal_entity, entity_id, "do_nothing") - ctx.call_entity(entity_id, "do_nothing") + entity_call_task = ctx.call_entity(entity_id, "do_nothing") # 
Cannot call entities that have been locked and already called, but not yet returned a result assert pytest.raises(Exception, ctx.call_entity, entity_id, "do_nothing") + yield entity_call_task # Start a worker, which will connect to the sidecar in a background thread with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, @@ -269,7 +268,7 @@ def empty_entity(ctx: entities.EntityContext, _): invoke_count += 1 def empty_orchestrator(ctx: task.OrchestrationContext, _): - entity_id = entities.EntityInstanceId("empty_entity", "testEntity3") + entity_id = entities.EntityInstanceId("empty_entity", f"{ctx.instance_id}_testEntity") with (yield ctx.lock_entities([entity_id])): yield ctx.call_entity(entity_id, "do_nothing") raise Exception("Simulated exception") @@ -300,7 +299,7 @@ def empty_entity(ctx: entities.EntityContext, _): invoke_count += 1 def empty_orchestrator(ctx: task.OrchestrationContext, _): - entity_id = entities.EntityInstanceId("empty_entity", "testEntity3") + entity_id = entities.EntityInstanceId("empty_entity", f"{ctx.instance_id}_testEntity") yield ctx.lock_entities([entity_id]) yield ctx.call_entity(entity_id, "do_nothing") @@ -323,7 +322,8 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): # TODO: Uncomment this test -# Will not pass until https://msazure.visualstudio.com/One/_git/AAPT-DTMB/pullrequest/13610881 is merged and deployed to the docker image +# Will not pass until https://msazure.visualstudio.com/One/_git/AAPT-DTMB/pullrequest/13610881 is merged and +# deployed to the docker image # def test_entity_unlocks_when_user_calls_continue_as_new(): # invoke_count = 0 @@ -332,7 +332,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): # invoke_count += 1 # def empty_orchestrator(ctx: task.OrchestrationContext, entity_call_count: int): -# entity_id = entities.EntityInstanceId("empty_entity", "testEntity6") +# entity_id = entities.EntityInstanceId("empty_entity", "testEntity") # nonlocal invoke_count # if not 
ctx.is_replaying: # invoke_count += 1 @@ -343,15 +343,15 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): # # Start a worker, which will connect to the sidecar in a background thread # with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, -# taskhub=taskhub_name, token_credential=credential) as w: +# taskhub=taskhub_name, token_credential=None) as w: # w.add_orchestrator(empty_orchestrator) # w.add_entity(empty_entity) # w.start() # c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, -# taskhub=taskhub_name, token_credential=credential) +# taskhub=taskhub_name, token_credential=None) # time.sleep(2) # wait for the signal and orchestration to be processed # id = c.schedule_new_orchestration(empty_orchestrator, input=2) -# c.wait_for_orchestration_completion(id, timeout=500) +# c.wait_for_orchestration_completion(id, timeout=30) # assert invoke_count == 6 From 6cb1f4ef4e00039d34fec36fb333c34c8c759214 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 30 Sep 2025 15:54:23 -0600 Subject: [PATCH 20/21] 3.9 compat, fix tests for action output --- durabletask/entities/durable_entity.py | 4 +- durabletask/entities/entity_context.py | 4 +- durabletask/internal/entity_state_shim.py | 4 +- ...> test_dts_function_based_entities_e2e.py} | 0 .../test_orchestration_executor.py | 62 +++++++++---------- 5 files changed, 37 insertions(+), 37 deletions(-) rename tests/durabletask-azuremanaged/{test_dts_function_based_entities_e2e copy.py => test_dts_function_based_entities_e2e.py} (100%) diff --git a/durabletask/entities/durable_entity.py b/durabletask/entities/durable_entity.py index 7c81e79..34ea3e0 100644 --- a/durabletask/entities/durable_entity.py +++ b/durabletask/entities/durable_entity.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Type, TypeVar, overload +from typing import Any, Optional, Type, TypeVar, Union, overload from durabletask.entities.entity_context import EntityContext from 
durabletask.entities.entity_instance_id import EntityInstanceId @@ -22,7 +22,7 @@ def get_state(self, intended_type: Type[TState]) -> Optional[TState]: def get_state(self, intended_type: None = None, default: Any = None) -> Any: ... - def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: + def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Union[None, TState, Any]: """Get the current state of the entity, optionally converting it to a specified type. Parameters diff --git a/durabletask/entities/entity_context.py b/durabletask/entities/entity_context.py index 974767b..da5ef19 100644 --- a/durabletask/entities/entity_context.py +++ b/durabletask/entities/entity_context.py @@ -1,5 +1,5 @@ -from typing import Any, Optional, Type, TypeVar, overload +from typing import Any, Optional, Type, TypeVar, Union, overload import uuid from durabletask.entities.entity_instance_id import EntityInstanceId from durabletask.internal import helpers, shared @@ -54,7 +54,7 @@ def get_state(self, intended_type: Type[TState]) -> Optional[TState]: def get_state(self, intended_type: None = None, default: Any = None) -> Any: ... - def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: + def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Union[None, TState, Any]: """Get the current state of the entity, optionally converting it to a specified type. 
Parameters diff --git a/durabletask/internal/entity_state_shim.py b/durabletask/internal/entity_state_shim.py index f27edc5..57a9e15 100644 --- a/durabletask/internal/entity_state_shim.py +++ b/durabletask/internal/entity_state_shim.py @@ -1,4 +1,4 @@ -from typing import Any, TypeVar +from typing import Any, TypeVar, Union from typing import Optional, Type, overload import durabletask.internal.orchestrator_service_pb2 as pb @@ -25,7 +25,7 @@ def get_state(self, intended_type: Type[TState]) -> Optional[TState]: def get_state(self, intended_type: None = None, default: Any = None) -> Any: ... - def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Optional[TState] | Any: + def get_state(self, intended_type: Optional[Type[TState]] = None, default: Optional[TState] = None) -> Union[None, TState, Any]: if self._current_state is None and default is not None: return default diff --git a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e.py similarity index 100% rename from tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e copy.py rename to tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e.py diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py index cb77c81..b2f65cd 100644 --- a/tests/durabletask/test_orchestration_executor.py +++ b/tests/durabletask/test_orchestration_executor.py @@ -40,7 +40,7 @@ def orchestrator(ctx: task.OrchestrationContext, my_input: int): result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED assert complete_action.result is not None @@ -62,7 
+62,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED assert complete_action.result.value == '"done"' # results are JSON-encoded @@ -77,7 +77,7 @@ def test_orchestrator_not_registered(): result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == "OrchestratorNotRegisteredError" assert complete_action.failureDetails.errorMessage @@ -137,7 +137,7 @@ def delay_orchestrator(ctx: task.OrchestrationContext, _): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED assert complete_action.result is not None assert complete_action.result.value == '"done"' # results are JSON-encoded @@ -196,7 +196,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED assert complete_action.result.value 
== encoded_output @@ -225,7 +225,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == 'TaskFailedError' # TODO: Should this be the specific error? assert str(ex) in complete_action.failureDetails.errorMessage @@ -405,7 +405,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - assert len(actions) == 1 + assert len(actions) == 7 assert actions[0].completeOrchestration.failureDetails.errorMessage.__contains__("Activity task #1 failed: Kah-BOOOOM!!!") assert actions[0].id == 7 @@ -433,7 +433,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == 'NonDeterminismError' assert "1" in complete_action.failureDetails.errorMessage # task ID @@ -461,7 +461,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == 
pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == 'NonDeterminismError' assert "1" in complete_action.failureDetails.errorMessage # task ID @@ -491,7 +491,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == 'NonDeterminismError' assert "1" in complete_action.failureDetails.errorMessage # task ID @@ -522,7 +522,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == 'NonDeterminismError' assert "1" in complete_action.failureDetails.errorMessage # task ID @@ -556,7 +556,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED assert complete_action.result.value == "42" @@ -586,7 +586,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = 
get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == 'TaskFailedError' # TODO: Should this be the specific error? assert str(ex) in complete_action.failureDetails.errorMessage @@ -617,7 +617,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == 'NonDeterminismError' assert "1" in complete_action.failureDetails.errorMessage # task ID @@ -647,7 +647,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == 'NonDeterminismError' assert "1" in complete_action.failureDetails.errorMessage # task ID @@ -682,7 +682,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED assert complete_action.result.value == "42" @@ -718,7 +718,7 @@ def orchestrator(ctx: 
task.OrchestrationContext, _): executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED assert complete_action.result.value == "42" @@ -753,7 +753,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED assert complete_action.result.value == "42" @@ -779,7 +779,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_TERMINATED assert complete_action.result.value == json.dumps("terminated!") @@ -808,7 +808,7 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == 
pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW assert complete_action.result.value == json.dumps(2) assert len(complete_action.carryoverEvents) == (3 if save_events else 0) @@ -893,7 +893,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED assert complete_action.result.value == "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]" @@ -935,7 +935,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == 'TaskFailedError' # TODO: Is this the right error type? 
assert str(ex) in complete_action.failureDetails.errorMessage @@ -983,7 +983,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED assert complete_action.result.value == encoded_output @@ -993,7 +993,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(1, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED assert complete_action.result.value == encoded_output @@ -1078,7 +1078,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = get_and_validate_complete_orchestration_action_list(3, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED assert complete_action.result.value == encoded_output @@ -1177,14 +1177,14 @@ def orchestrator(ctx: task.OrchestrationContext, _): executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions - complete_action = get_and_validate_single_complete_orchestration_action(actions) + complete_action = 
get_and_validate_complete_orchestration_action_list(4, actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == 'TaskFailedError' # TODO: Should this be the specific error? assert str(ex) in complete_action.failureDetails.errorMessage -def get_and_validate_single_complete_orchestration_action(actions: list[pb.OrchestratorAction]) -> pb.CompleteOrchestrationAction: - assert len(actions) == 1 - assert type(actions[0]) is pb.OrchestratorAction - assert actions[0].HasField("completeOrchestration") - return actions[0].completeOrchestration +def get_and_validate_complete_orchestration_action_list(expected_action_count: int, actions: list[pb.OrchestratorAction]) -> pb.CompleteOrchestrationAction: + assert len(actions) == expected_action_count + assert type(actions[-1]) is pb.OrchestratorAction + assert actions[-1].HasField("completeOrchestration") + return actions[-1].completeOrchestration From 31e9b15ab83b264757eeb7552294ae8b3dd790a4 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 30 Sep 2025 15:56:35 -0600 Subject: [PATCH 21/21] Action output test fix --- tests/durabletask/test_orchestration_executor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py index b2f65cd..5646f07 100644 --- a/tests/durabletask/test_orchestration_executor.py +++ b/tests/durabletask/test_orchestration_executor.py @@ -406,8 +406,8 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 7 - assert actions[0].completeOrchestration.failureDetails.errorMessage.__contains__("Activity task #1 failed: Kah-BOOOOM!!!") - assert actions[0].id == 7 + assert actions[-1].completeOrchestration.failureDetails.errorMessage.__contains__("Activity task #1 failed: Kah-BOOOOM!!!") 
+ assert actions[-1].id == 7 def test_nondeterminism_expected_timer():