From 721069e4e2c1387fdf66367f41416348b2d952d0 Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Tue, 23 Dec 2025 16:22:20 -0800 Subject: [PATCH 01/11] Add worker and clients --- .../__init__.py | 3 +- .../agent_framework_azurefunctions/_app.py | 18 +- .../agent_framework_azurefunctions/_models.py | 201 --------- .../_orchestration.py | 220 ++-------- .../azurefunctions/tests/test_models.py | 402 ------------------ .../tests/test_orchestration.py | 206 ++++----- .../agent_framework/_workflows/_handoff.py | 6 +- .../tests/workflow/test_workflow_builder.py | 6 +- python/packages/durabletask/DESIGN.md | 262 +++++++++--- .../agent_framework_durabletask/__init__.py | 15 +- .../agent_framework_durabletask/_client.py | 79 ++++ .../agent_framework_durabletask/_executors.py | 108 +++++ .../agent_framework_durabletask/_models.py | 107 ++++- .../_orchestration_context.py | 76 ++++ .../agent_framework_durabletask/_shim.py | 168 ++++++++ .../agent_framework_durabletask/_worker.py | 199 +++++++++ .../tests/test_agent_session_id.py | 272 ++++++++++++ .../packages/durabletask/tests/test_client.py | 91 ++++ .../durabletask/tests/test_executors.py | 83 ++++ .../tests/test_orchestration_context.py | 98 +++++ .../packages/durabletask/tests/test_shim.py | 197 +++++++++ .../packages/durabletask/tests/test_worker.py | 168 ++++++++ 22 files changed, 1983 insertions(+), 1002 deletions(-) delete mode 100644 python/packages/azurefunctions/agent_framework_azurefunctions/_models.py delete mode 100644 python/packages/azurefunctions/tests/test_models.py create mode 100644 python/packages/durabletask/agent_framework_durabletask/_client.py create mode 100644 python/packages/durabletask/agent_framework_durabletask/_executors.py create mode 100644 python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py create mode 100644 python/packages/durabletask/agent_framework_durabletask/_shim.py create mode 100644 python/packages/durabletask/agent_framework_durabletask/_worker.py create mode 100644 python/packages/durabletask/tests/test_agent_session_id.py create mode 100644 python/packages/durabletask/tests/test_client.py create mode 100644 python/packages/durabletask/tests/test_executors.py create mode 100644 python/packages/durabletask/tests/test_orchestration_context.py create mode 100644 python/packages/durabletask/tests/test_shim.py create mode 100644 python/packages/durabletask/tests/test_worker.py diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/__init__.py b/python/packages/azurefunctions/agent_framework_azurefunctions/__init__.py index 960b8d4023..e5be2aa36e 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/__init__.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/__init__.py @@ -2,10 +2,9 @@ import importlib.metadata -from agent_framework_durabletask import AgentCallbackContext, AgentResponseCallbackProtocol +from agent_framework_durabletask import AgentCallbackContext, AgentResponseCallbackProtocol, DurableAIAgent from ._app import AgentFunctionApp -from ._orchestration import DurableAIAgent try: __version__ = importlib.metadata.version(__name__) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 80f2b6dec9..2f4df67e58 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -28,14 +28,15 @@ WAIT_FOR_RESPONSE_FIELD, WAIT_FOR_RESPONSE_HEADER, AgentResponseCallbackProtocol, + AgentSessionId, DurableAgentState, + DurableAIAgent, RunRequest, ) from ._entities import create_agent_entity from ._errors import IncomingRequestError -from ._models import AgentSessionId -from ._orchestration import AgentOrchestrationContextType, DurableAIAgent +from ._orchestration import AgentOrchestrationContextType, AzureFunctionsAgentExecutor logger = get_logger("agent_framework.azurefunctions") @@ -314,7 +315,8 @@ def get_agent( if normalized_name not in self._agent_metadata: raise ValueError(f"Agent '{normalized_name}' is not registered with this app.") - return DurableAIAgent(context, normalized_name) + executor = AzureFunctionsAgentExecutor(context) + return DurableAIAgent(executor, normalized_name) def _setup_agent_functions( self, @@ -407,7 +409,10 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlation_id}") logger.debug("[HTTP Trigger] Calling entity to run agent...") - entity_instance_id = session_id.to_entity_id() + entity_instance_id = df.EntityId( + name=session_id.entity_name, + key=session_id.key, + ) run_request = self._build_request_data( req_body, message, @@ -624,7 +629,10 @@ async def _handle_mcp_tool_invocation( session_id = AgentSessionId.with_random_key(agent_name) # Build entity instance ID - entity_instance_id = session_id.to_entity_id() + entity_instance_id = df.EntityId( + name=session_id.entity_name, + key=session_id.key, + ) # Create run request correlation_id = self._generate_unique_id() diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py deleted file mode 100644 index 4b069686df..0000000000 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Azure Functions-specific data models for Durable Agent Framework. - -This module contains Azure Functions-specific models: -- AgentSessionId: Entity ID management for Azure Durable Entities -- DurableAgentThread: Thread implementation that tracks AgentSessionId - -Common models like RunRequest have been moved to agent-framework-durabletask. -""" - -from __future__ import annotations - -import uuid -from collections.abc import MutableMapping -from dataclasses import dataclass -from typing import Any - -import azure.durable_functions as df -from agent_framework import AgentThread - - -@dataclass -class AgentSessionId: - """Represents an agent session ID, which is used to identify a long-running agent session. - - Attributes: - name: The name of the agent that owns the session (case-insensitive) - key: The unique key of the agent session (case-sensitive) - """ - - name: str - key: str - - ENTITY_NAME_PREFIX: str = "dafx-" - - @staticmethod - def to_entity_name(name: str) -> str: - """Converts an agent name to an entity name by adding the DAFx prefix. - - Args: - name: The agent name - - Returns: - The entity name with the dafx- prefix - """ - return f"{AgentSessionId.ENTITY_NAME_PREFIX}{name}" - - @staticmethod - def with_random_key(name: str) -> AgentSessionId: - """Creates a new AgentSessionId with the specified name and a randomly generated key. - - Args: - name: The name of the agent that owns the session - - Returns: - A new AgentSessionId with the specified name and a random GUID key - """ - return AgentSessionId(name=name, key=uuid.uuid4().hex) - - def to_entity_id(self) -> df.EntityId: - """Converts this AgentSessionId to a Durable Functions EntityId. - - Returns: - EntityId for use with Durable Functions APIs - """ - return df.EntityId(self.to_entity_name(self.name), self.key) - - @staticmethod - def from_entity_id(entity_id: df.EntityId) -> AgentSessionId: - """Creates an AgentSessionId from a Durable Functions EntityId. - - Args: - entity_id: The EntityId to convert - - Returns: - AgentSessionId instance - - Raises: - ValueError: If the entity ID does not have the expected prefix - """ - if not entity_id.name.startswith(AgentSessionId.ENTITY_NAME_PREFIX): - raise ValueError( - f"'{entity_id}' is not a valid agent session ID. " - f"Expected entity name to start with '{AgentSessionId.ENTITY_NAME_PREFIX}'" - ) - - agent_name = entity_id.name[len(AgentSessionId.ENTITY_NAME_PREFIX) :] - return AgentSessionId(name=agent_name, key=entity_id.key) - - def __str__(self) -> str: - """Returns a string representation in the form @name@key.""" - return f"@{self.name}@{self.key}" - - def __repr__(self) -> str: - """Returns a detailed string representation.""" - return f"AgentSessionId(name='{self.name}', key='{self.key}')" - - @staticmethod - def parse(session_id_string: str) -> AgentSessionId: - """Parses a string representation of an agent session ID. - - Args: - session_id_string: A string in the form @name@key - - Returns: - AgentSessionId instance - - Raises: - ValueError: If the string format is invalid - """ - if not session_id_string.startswith("@"): - raise ValueError(f"Invalid agent session ID format: {session_id_string}") - - parts = session_id_string[1:].split("@", 1) - if len(parts) != 2: - raise ValueError(f"Invalid agent session ID format: {session_id_string}") - - return AgentSessionId(name=parts[0], key=parts[1]) - - -class DurableAgentThread(AgentThread): - """Durable agent thread that tracks the owning :class:`AgentSessionId`.""" - - _SERIALIZED_SESSION_ID_KEY = "durable_session_id" - - def __init__( - self, - *, - session_id: AgentSessionId | None = None, - service_thread_id: str | None = None, - message_store: Any = None, - context_provider: Any = None, - ) -> None: - super().__init__( - service_thread_id=service_thread_id, - message_store=message_store, - context_provider=context_provider, - ) - self._session_id: AgentSessionId | None = session_id - - @property - def session_id(self) -> AgentSessionId | None: - """Returns the durable agent session identifier for this thread.""" - return self._session_id - - def attach_session(self, session_id: AgentSessionId) -> None: - """Associates the thread with the provided :class:`AgentSessionId`.""" - self._session_id = session_id - - @classmethod - def from_session_id( - cls, - session_id: AgentSessionId, - *, - service_thread_id: str | None = None, - message_store: Any = None, - context_provider: Any = None, - ) -> DurableAgentThread: - """Creates a durable thread pre-associated with the supplied session ID.""" - return cls( - session_id=session_id, - service_thread_id=service_thread_id, - message_store=message_store, - context_provider=context_provider, - ) - - async def serialize(self, **kwargs: Any) -> dict[str, Any]: - """Serializes thread state including the durable session identifier.""" - state = await super().serialize(**kwargs) - if self._session_id is not None: - state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id) - return state - - @classmethod - async def deserialize( - cls, - serialized_thread_state: MutableMapping[str, Any], - *, - message_store: Any = None, - **kwargs: Any, - ) -> DurableAgentThread: - """Restores a durable thread, rehydrating the stored session identifier.""" - state_payload = dict(serialized_thread_state) - session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None) - thread = await super().deserialize( - state_payload, - message_store=message_store, - **kwargs, - ) - if not isinstance(thread, DurableAgentThread): - raise TypeError("Deserialized thread is not a DurableAgentThread instance") - - if session_id_value is None: - return thread - - if not isinstance(session_id_value, str): - raise ValueError("durable_session_id must be a string when present in serialized state") - - thread.attach_session(AgentSessionId.parse(session_id_value)) - return thread diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index cf91132916..5e03bd934d 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -5,25 +5,20 @@ This module provides support for using agents inside Durable Function orchestrations. """ -import uuid -from collections.abc import AsyncIterator, Callable -from typing import TYPE_CHECKING, Any, TypeAlias, cast +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, TypeAlias +import azure.durable_functions as df from agent_framework import ( - AgentProtocol, AgentRunResponse, - AgentRunResponseUpdate, AgentThread, - ChatMessage, get_logger, ) -from agent_framework_durabletask import RunRequest +from agent_framework_durabletask import AgentSessionId, DurableAgentExecutor, DurableAgentThread, RunRequest from azure.durable_functions.models import TaskBase from azure.durable_functions.models.Task import CompoundTask, TaskState from pydantic import BaseModel -from ._models import AgentSessionId, DurableAgentThread - logger = get_logger("agent_framework.azurefunctions.orchestration") CompoundActionConstructor: TypeAlias = Callable[[list[Any]], Any] | None @@ -151,217 +146,64 @@ def _ensure_response_format( ) -class DurableAIAgent(AgentProtocol): - """A durable agent implementation that uses entity methods to interact with agent entities. - - This class implements AgentProtocol and provides methods to work with Azure Durable Functions - orchestrations, which use generators and yield instead of async/await. - - Key methods: - - get_new_thread(): Create a new conversation thread - - run(): Execute the agent and return a Task for yielding in orchestrations - - Note: The run() method is NOT async. It returns a Task directly that must be - yielded in orchestrations to wait for the entity call to complete. - - Example usage in orchestration: - writer = app.get_agent(context, "WriterAgent") - thread = writer.get_new_thread() # NOT yielded - returns immediately - - response = yield writer.run( # Yielded - waits for entity call - message="Write a haiku about coding", - thread=thread - ) - """ - - def __init__(self, context: AgentOrchestrationContextType, agent_name: str): - """Initialize the DurableAIAgent. +class AzureFunctionsAgentExecutor(DurableAgentExecutor): + """Executor that executes durable agents inside Azure Functions orchestrations.""" - Args: - context: The orchestration context - agent_name: Name of the agent (used to construct entity ID) - """ + def __init__(self, context: AgentOrchestrationContextType): self.context = context - self.agent_name = agent_name - self._id = str(uuid.uuid4()) - self._name = agent_name - self._display_name = agent_name - self._description = f"Durable agent proxy for {agent_name}" - logger.debug("[DurableAIAgent] Initialized for agent: %s", agent_name) - - @property - def id(self) -> str: - """Get the unique identifier for this agent.""" - return self._id - - @property - def name(self) -> str | None: - """Get the name of the agent.""" - return self._name - - @property - def display_name(self) -> str: - """Get the display name of the agent.""" - return self._display_name - - @property - def description(self) -> str | None: - """Get the description of the agent.""" - return self._description - - # We return an AgentTask here which is a TaskBase subclass. - # This is an intentional deviation from AgentProtocol which defines run() as async. - # The AgentTask can be yielded in Durable Functions orchestrations and will provide - # a typed AgentRunResponse result. - def run( # type: ignore[override] + + def run_durable_agent( self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, + agent_name: str, + message: str, thread: AgentThread | None = None, response_format: type[BaseModel] | None = None, + enable_tool_calls: bool | None = None, **kwargs: Any, ) -> AgentTask: - """Execute the agent with messages and return an AgentTask for orchestrations. + # Extract optional parameters + enable_tools = True if enable_tool_calls is None else enable_tool_calls - This method implements AgentProtocol and returns an AgentTask (subclass of TaskBase) - that can be yielded in Durable Functions orchestrations. The task's result will be - a typed AgentRunResponse. - - Args: - messages: The message(s) to send to the agent - thread: Optional agent thread for conversation context - response_format: Optional Pydantic model for response parsing - **kwargs: Additional arguments (enable_tool_calls) - - Returns: - An AgentTask that resolves to an AgentRunResponse when yielded - - Example: - @app.orchestration_trigger(context_name="context") - def my_orchestration(context): - agent = app.get_agent(context, "MyAgent") - thread = agent.get_new_thread() - response = yield agent.run("Hello", thread=thread) - # response is typed as AgentRunResponse - """ - message_str = self._normalize_messages(messages) - - # Extract optional parameters from kwargs - enable_tool_calls = kwargs.get("enable_tool_calls", True) - - # Get the session ID for the entity + # Resolve session if isinstance(thread, DurableAgentThread) and thread.session_id is not None: session_id = thread.session_id else: - # Create a unique session ID for each call when no thread is provided - # This ensures each call gets its own conversation context session_key = str(self.context.new_uuid()) - session_id = AgentSessionId(name=self.agent_name, key=session_key) - logger.debug("[DurableAIAgent] No thread provided, created unique session_id: %s", session_id) - - # Create entity ID from session ID - entity_id = session_id.to_entity_id() + session_id = AgentSessionId(name=agent_name, key=session_key) + logger.debug( + "[AzureFunctionsAgentProvider] No thread provided, created session_id: %s", + session_id, + ) - # Generate a deterministic correlation ID for this call - # This is required by the entity and must be unique per call + entity_id = df.EntityId( + name=session_id.entity_name, + key=session_id.key, + ) correlation_id = str(self.context.new_uuid()) logger.debug( - "[DurableAIAgent] Using correlation_id: %s for entity_id: %s for session_id: %s", + "[AzureFunctionsAgentProvider] correlation_id: %s entity_id: %s session_id: %s", correlation_id, entity_id, session_id, ) - # Prepare the request using RunRequest model - # Include the orchestration's instance_id so it can be stored in the agent's entity state run_request = RunRequest( - message=message_str, - enable_tool_calls=enable_tool_calls, + message=message, + enable_tool_calls=enable_tools, correlation_id=correlation_id, response_format=response_format, orchestration_id=self.context.instance_id, created_at=self.context.current_utc_datetime, ) - logger.debug("[DurableAIAgent] Calling entity %s with message: %s", entity_id, message_str[:100]) - - # Call the entity to get the underlying task entity_task = self.context.call_entity(entity_id, "run", run_request.to_dict()) - - # Wrap it in an AgentTask that will convert the result to AgentRunResponse - agent_task = AgentTask( + return AgentTask( entity_task=entity_task, response_format=response_format, correlation_id=correlation_id, ) - logger.debug( - "[DurableAIAgent] Created AgentTask for correlation_id %s", - correlation_id, - ) - - return agent_task - - def run_stream( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AsyncIterator[AgentRunResponseUpdate]: - """Run the agent with streaming (not supported for durable agents). - - Raises: - NotImplementedError: Streaming is not supported for durable agents. - """ - raise NotImplementedError("Streaming is not supported for durable agents in orchestrations.") - - def get_new_thread(self, **kwargs: Any) -> AgentThread: - """Create a new agent thread for this orchestration instance. - - Each call creates a unique thread with its own conversation context. - The session ID is deterministic (uses context.new_uuid()) to ensure - orchestration replay works correctly. - - Returns: - A new AgentThread instance with a unique session ID - """ - # Generate a deterministic unique key for this thread - # Using context.new_uuid() ensures the same GUID is generated during replay + def get_new_thread(self, agent_name: str, **kwargs: Any) -> DurableAgentThread: session_key = str(self.context.new_uuid()) - - # Create AgentSessionId with agent name and session key - session_id = AgentSessionId(name=self.agent_name, key=session_key) - - thread = DurableAgentThread.from_session_id(session_id, **kwargs) - - logger.debug("[DurableAIAgent] Created new thread with session_id: %s", session_id) - return thread - - def _messages_to_string(self, messages: list[ChatMessage]) -> str: - """Convert a list of ChatMessage objects to a single string. - - Args: - messages: List of ChatMessage objects - - Returns: - Concatenated string of message contents - """ - return "\n".join([msg.text or "" for msg in messages]) - - def _normalize_messages(self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None) -> str: - """Convert supported message inputs to a single string.""" - if messages is None: - return "" - if isinstance(messages, str): - return messages - if isinstance(messages, ChatMessage): - return messages.text or "" - if isinstance(messages, list): - if not messages: - return "" - first_item = messages[0] - if isinstance(first_item, str): - return "\n".join(cast(list[str], messages)) - return self._messages_to_string(cast(list[ChatMessage], messages)) - return str(messages) + session_id = AgentSessionId(name=agent_name, key=session_key) + return DurableAgentThread.from_session_id(session_id, **kwargs) diff --git a/python/packages/azurefunctions/tests/test_models.py b/python/packages/azurefunctions/tests/test_models.py deleted file mode 100644 index 5f4dc47080..0000000000 --- a/python/packages/azurefunctions/tests/test_models.py +++ /dev/null @@ -1,402 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Unit tests for data models (AgentSessionId, RunRequest, AgentResponse).""" - -import azure.durable_functions as df -import pytest -from agent_framework import Role -from agent_framework_durabletask import RunRequest -from pydantic import BaseModel - -from agent_framework_azurefunctions._models import AgentSessionId - - -class ModuleStructuredResponse(BaseModel): - value: int - - -class TestAgentSessionId: - """Test suite for AgentSessionId.""" - - def test_init_creates_session_id(self) -> None: - """Test that AgentSessionId initializes correctly.""" - session_id = AgentSessionId(name="AgentEntity", key="test-key-123") - - assert session_id.name == "AgentEntity" - assert session_id.key == "test-key-123" - - def test_with_random_key_generates_guid(self) -> None: - """Test that with_random_key generates a GUID.""" - session_id = AgentSessionId.with_random_key(name="AgentEntity") - - assert session_id.name == "AgentEntity" - assert len(session_id.key) == 32 # UUID hex is 32 chars - # Verify it's a valid hex string - int(session_id.key, 16) - - def test_with_random_key_unique_keys(self) -> None: - """Test that with_random_key generates unique keys.""" - session_id1 = AgentSessionId.with_random_key(name="AgentEntity") - session_id2 = AgentSessionId.with_random_key(name="AgentEntity") - - assert session_id1.key != session_id2.key - - def test_to_entity_id_conversion(self) -> None: - """Test conversion to EntityId.""" - session_id = AgentSessionId(name="AgentEntity", key="test-key") - entity_id = session_id.to_entity_id() - - assert isinstance(entity_id, df.EntityId) - assert entity_id.name == "dafx-AgentEntity" - assert entity_id.key == "test-key" - - def test_from_entity_id_conversion(self) -> None: - """Test creation from EntityId.""" - entity_id = df.EntityId(name="dafx-AgentEntity", key="test-key") - session_id = AgentSessionId.from_entity_id(entity_id) - - assert isinstance(session_id, AgentSessionId) - assert session_id.name == "AgentEntity" - assert session_id.key == "test-key" - - def test_round_trip_entity_id_conversion(self) -> None: - """Test round-trip conversion to and from EntityId.""" - original = AgentSessionId(name="AgentEntity", key="test-key") - entity_id = original.to_entity_id() - restored = AgentSessionId.from_entity_id(entity_id) - - assert restored.name == original.name - assert restored.key == original.key - - def test_str_representation(self) -> None: - """Test string representation.""" - session_id = AgentSessionId(name="AgentEntity", key="test-key-123") - str_repr = str(session_id) - - assert str_repr == "@AgentEntity@test-key-123" - - def test_repr_representation(self) -> None: - """Test repr representation.""" - session_id = AgentSessionId(name="AgentEntity", key="test-key") - repr_str = repr(session_id) - - assert "AgentSessionId" in repr_str - assert "AgentEntity" in repr_str - assert "test-key" in repr_str - - def test_parse_valid_session_id(self) -> None: - """Test parsing valid session ID string.""" - session_id = AgentSessionId.parse("@AgentEntity@test-key-123") - - assert session_id.name == "AgentEntity" - assert session_id.key == "test-key-123" - - def test_parse_invalid_format_no_prefix(self) -> None: - """Test parsing invalid format without @ prefix.""" - with pytest.raises(ValueError) as exc_info: - AgentSessionId.parse("AgentEntity@test-key") - - assert "Invalid agent session ID format" in str(exc_info.value) - - def test_parse_invalid_format_single_part(self) -> None: - """Test parsing invalid format with single part.""" - with pytest.raises(ValueError) as exc_info: - AgentSessionId.parse("@AgentEntity") - - assert "Invalid agent session ID format" in str(exc_info.value) - - def test_parse_with_multiple_at_signs_in_key(self) -> None: - """Test parsing with @ signs in the key.""" - session_id = AgentSessionId.parse("@AgentEntity@key-with@symbols") - - assert session_id.name == "AgentEntity" - assert session_id.key == "key-with@symbols" - - def test_parse_round_trip(self) -> None: - """Test round-trip parse and string conversion.""" - original = AgentSessionId(name="AgentEntity", key="test-key") - str_repr = str(original) - parsed = AgentSessionId.parse(str_repr) - - assert parsed.name == original.name - assert parsed.key == original.key - - def test_to_entity_name_adds_prefix(self) -> None: - """Test that to_entity_name adds the dafx- prefix.""" - entity_name = AgentSessionId.to_entity_name("TestAgent") - assert entity_name == "dafx-TestAgent" - - def test_from_entity_id_strips_prefix(self) -> None: - """Test that from_entity_id strips the dafx- prefix.""" - entity_id = df.EntityId(name="dafx-TestAgent", key="key123") - session_id = AgentSessionId.from_entity_id(entity_id) - - assert session_id.name == "TestAgent" - assert session_id.key == "key123" - - def test_from_entity_id_raises_without_prefix(self) -> None: - """Test that from_entity_id raises ValueError when entity name lacks the prefix.""" - entity_id = df.EntityId(name="TestAgent", key="key123") - - with pytest.raises(ValueError) as exc_info: - AgentSessionId.from_entity_id(entity_id) - - assert "not a valid agent session ID" in str(exc_info.value) - assert "dafx-" in str(exc_info.value) - - -class TestRunRequest: - """Test suite for RunRequest.""" - - def test_init_with_defaults(self) -> None: - """Test RunRequest initialization with defaults.""" - request = RunRequest(message="Hello") - - assert request.message == "Hello" - assert request.role == Role.USER - assert request.response_format is None - assert request.enable_tool_calls is True - - def test_init_with_all_fields(self) -> None: - """Test RunRequest initialization with all fields.""" - schema = ModuleStructuredResponse - request = RunRequest( - message="Hello", - role=Role.SYSTEM, - response_format=schema, - enable_tool_calls=False, - ) - - assert request.message == "Hello" - assert request.role == Role.SYSTEM - assert request.response_format is schema - assert request.enable_tool_calls is False - - def test_init_coerces_string_role(self) -> None: - """Ensure string role values are coerced into Role instances.""" - request = RunRequest(message="Hello", role="system") # type: ignore[arg-type] - - assert request.role == Role.SYSTEM - - def test_to_dict_with_defaults(self) -> None: - """Test to_dict with default values.""" - request = RunRequest(message="Test message") - data = request.to_dict() - - assert data["message"] == "Test message" - assert data["enable_tool_calls"] is True - assert data["role"] == "user" - assert "response_format" not in data or data["response_format"] is None - assert "thread_id" not in data - - def test_to_dict_with_all_fields(self) -> None: - """Test to_dict with all fields.""" - schema = ModuleStructuredResponse - request = RunRequest( - message="Hello", - role=Role.ASSISTANT, - response_format=schema, - enable_tool_calls=False, - ) - data = request.to_dict() - - assert data["message"] == "Hello" - assert data["role"] == "assistant" - assert data["response_format"]["__response_schema_type__"] == "pydantic_model" - assert data["response_format"]["module"] == schema.__module__ - assert data["response_format"]["qualname"] == schema.__qualname__ - assert data["enable_tool_calls"] is False - assert "thread_id" not in data - - def test_from_dict_with_defaults(self) -> None: - """Test from_dict with minimal data.""" - data = {"message": "Hello"} - request = RunRequest.from_dict(data) - - assert request.message == "Hello" - assert request.role == Role.USER - assert request.enable_tool_calls is True - - def test_from_dict_ignores_thread_id_field(self) -> None: - """Ensure legacy thread_id input does not break RunRequest parsing.""" - request = RunRequest.from_dict({"message": "Hello", "thread_id": "ignored"}) - - assert request.message == "Hello" - - def test_from_dict_with_all_fields(self) -> None: - """Test from_dict with all fields.""" - data = { - "message": "Test", - "role": "system", - "response_format": { - "__response_schema_type__": "pydantic_model", - "module": ModuleStructuredResponse.__module__, - "qualname": ModuleStructuredResponse.__qualname__, - }, - "enable_tool_calls": False, - } - request = RunRequest.from_dict(data) - - assert request.message == "Test" - assert request.role == Role.SYSTEM - assert request.response_format is ModuleStructuredResponse - assert request.enable_tool_calls is False - - def test_from_dict_with_unknown_role_preserves_value(self) -> None: - """Test from_dict keeps custom roles intact.""" - data = {"message": "Test", "role": "reviewer"} - request = RunRequest.from_dict(data) - - assert request.role.value == "reviewer" - assert request.role != Role.USER - - def test_from_dict_empty_message(self) -> None: - """Test from_dict with empty message.""" - request = RunRequest.from_dict({}) - - assert request.message == "" - assert request.role == Role.USER - - def test_round_trip_dict_conversion(self) -> None: - """Test round-trip to_dict and from_dict.""" - original = RunRequest( - message="Test message", - role=Role.SYSTEM, - response_format=ModuleStructuredResponse, - enable_tool_calls=False, - ) - - data = original.to_dict() - restored = RunRequest.from_dict(data) - - assert restored.message == original.message - assert restored.role == original.role - assert restored.response_format is ModuleStructuredResponse - assert restored.enable_tool_calls == original.enable_tool_calls - - def test_round_trip_with_pydantic_response_format(self) -> None: - """Ensure Pydantic response formats serialize and deserialize properly.""" - original = RunRequest( - message="Structured", - response_format=ModuleStructuredResponse, - ) - - data = original.to_dict() - - assert data["response_format"]["__response_schema_type__"] == "pydantic_model" - assert data["response_format"]["module"] == ModuleStructuredResponse.__module__ - assert data["response_format"]["qualname"] == ModuleStructuredResponse.__qualname__ - - restored = RunRequest.from_dict(data) - assert restored.response_format is ModuleStructuredResponse - - def test_init_with_correlationId(self) -> None: - """Test RunRequest initialization with correlationId.""" - request = RunRequest(message="Test message", correlation_id="corr-123") - - assert request.message == "Test message" - assert request.correlation_id == "corr-123" - - def test_to_dict_with_correlationId(self) -> None: - """Test to_dict includes correlationId.""" - request = RunRequest(message="Test", correlation_id="corr-456") - data = request.to_dict() - - assert data["message"] == "Test" - assert data["correlationId"] == "corr-456" - - def test_from_dict_with_correlationId(self) -> None: - """Test from_dict with correlationId.""" - data = {"message": "Test", "correlationId": "corr-789"} - request = RunRequest.from_dict(data) - - assert request.message == "Test" - assert request.correlation_id == "corr-789" - - def test_round_trip_with_correlationId(self) -> None: - """Test round-trip to_dict and from_dict with correlationId.""" - original = RunRequest( - message="Test message", - role=Role.SYSTEM, - correlation_id="corr-123", - ) - - data = original.to_dict() - restored = RunRequest.from_dict(data) - - assert restored.message == original.message - assert restored.role == original.role - assert restored.correlation_id == original.correlation_id - - def test_init_with_orchestration_id(self) -> None: - """Test RunRequest initialization with orchestration_id.""" - request = RunRequest( - message="Test message", - orchestration_id="orch-123", - ) - - assert request.message == "Test message" - assert request.orchestration_id == "orch-123" - - def test_to_dict_with_orchestration_id(self) -> None: - """Test to_dict includes orchestrationId.""" - request = RunRequest( - message="Test", - orchestration_id="orch-456", - ) - data = request.to_dict() - - assert data["message"] == "Test" - assert data["orchestrationId"] == "orch-456" - - def test_to_dict_excludes_orchestration_id_when_none(self) -> None: - """Test to_dict excludes orchestrationId when not set.""" - request = RunRequest( - message="Test", - ) - data = request.to_dict() - - assert "orchestrationId" not in data - - def test_from_dict_with_orchestration_id(self) -> None: - """Test from_dict with orchestrationId.""" - data = { - "message": "Test", - "orchestrationId": "orch-789", - } - request = RunRequest.from_dict(data) - - assert request.message == "Test" - assert request.orchestration_id == "orch-789" - - def test_round_trip_with_orchestration_id(self) -> None: - """Test round-trip to_dict and from_dict with orchestration_id.""" - original = RunRequest( - message="Test message", - role=Role.SYSTEM, - correlation_id="corr-123", - orchestration_id="orch-123", - ) - - data = original.to_dict() - restored = RunRequest.from_dict(data) - - assert restored.message == original.message - assert restored.role == original.role - assert restored.correlation_id == original.correlation_id - assert restored.orchestration_id == original.orchestration_id - - -class TestModelIntegration: - """Test suite for integration between models.""" - - def test_run_request_with_session_id_string(self) -> None: - """AgentSessionId string can still be used by callers, but is not stored on RunRequest.""" - session_id = AgentSessionId.with_random_key("AgentEntity") - session_id_str = str(session_id) - - assert session_id_str.startswith("@AgentEntity@") - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py index 98f70af95a..a5f67b0510 100644 --- a/python/packages/azurefunctions/tests/test_orchestration.py +++ b/python/packages/azurefunctions/tests/test_orchestration.py @@ -6,14 +6,36 @@ from unittest.mock import Mock import pytest -from agent_framework import AgentRunResponse, AgentThread, ChatMessage +from agent_framework import AgentRunResponse, ChatMessage +from agent_framework_durabletask import AgentSessionId, DurableAgentThread, DurableAIAgent from azure.durable_functions.models.Task import TaskBase, TaskState -from agent_framework_azurefunctions import AgentFunctionApp, DurableAIAgent -from agent_framework_azurefunctions._models import AgentSessionId, DurableAgentThread +from agent_framework_azurefunctions import AgentFunctionApp from agent_framework_azurefunctions._orchestration import AgentTask +def _create_mock_context(instance_id: str = "test-instance", uuid_values: list[str] | None = None) -> Mock: + """Create a mock orchestration context with common attributes. + + Args: + instance_id: The orchestration instance ID + uuid_values: List of UUIDs to return from new_uuid() calls (if None, returns "test-guid") + + Returns: + Mock context object configured for testing + """ + mock_context = Mock() + mock_context.instance_id = instance_id + mock_context.current_utc_datetime = Mock() + + if uuid_values: + mock_context.new_uuid = Mock(side_effect=uuid_values) + else: + mock_context.new_uuid = Mock(return_value="test-guid") + + return mock_context + + def _app_with_registered_agents(*agent_names: str) -> AgentFunctionApp: app = AgentFunctionApp(enable_health_check=False, enable_http_endpoints=False) for name in agent_names: @@ -189,16 +211,6 @@ class SampleSchema(BaseModel): class TestDurableAIAgent: """Test suite for DurableAIAgent wrapper.""" - def test_init(self) -> None: - """Test DurableAIAgent initialization.""" - mock_context = Mock() - mock_context.instance_id = "test-instance-123" - - agent = DurableAIAgent(mock_context, "TestAgent") - - assert agent.context == mock_context - assert agent.agent_name == "TestAgent" - def test_implements_agent_protocol(self) -> None: """Test that DurableAIAgent implements AgentProtocol.""" from agent_framework import AgentProtocol @@ -228,11 +240,12 @@ def test_has_agent_protocol_properties(self) -> None: def test_get_new_thread(self) -> None: """Test creating a new agent thread.""" - mock_context = Mock() - mock_context.instance_id = "test-instance-456" - mock_context.new_uuid = Mock(return_value="test-guid-456") + mock_context = _create_mock_context("test-instance-456", ["test-guid-456"]) - agent = DurableAIAgent(mock_context, "WriterAgent") + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor + + executor = AzureFunctionsAgentExecutor(mock_context) + agent = DurableAIAgent(executor, "WriterAgent") thread = agent.get_new_thread() assert isinstance(thread, DurableAgentThread) @@ -245,12 +258,12 @@ def test_get_new_thread(self) -> None: def test_get_new_thread_deterministic(self) -> None: """Test that get_new_thread creates deterministic session IDs.""" + mock_context = _create_mock_context("test-instance-789", ["session-guid-1", "session-guid-2"]) - mock_context = Mock() - mock_context.instance_id = "test-instance-789" - mock_context.new_uuid = Mock(side_effect=["session-guid-1", "session-guid-2"]) + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor - agent = DurableAIAgent(mock_context, "EditorAgent") + executor = AzureFunctionsAgentExecutor(mock_context) + agent = DurableAIAgent(executor, "EditorAgent") # Create multiple threads - they should have unique session IDs thread1 = agent.get_new_thread() @@ -272,14 +285,15 @@ def test_get_new_thread_deterministic(self) -> None: def test_run_creates_entity_call(self) -> None: """Test that run() creates proper entity call and returns a Task.""" - mock_context = Mock() - mock_context.instance_id = "test-instance-001" - mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"]) + mock_context = _create_mock_context("test-instance-001", ["thread-guid", "correlation-guid"]) entity_task = _create_entity_task() mock_context.call_entity = Mock(return_value=entity_task) - agent = DurableAIAgent(mock_context, "TestAgent") + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor + + executor = AzureFunctionsAgentExecutor(mock_context) + agent = DurableAIAgent(executor, "TestAgent") # Create thread thread = agent.get_new_thread() @@ -293,7 +307,7 @@ def test_run_creates_entity_call(self) -> None: # Verify call_entity was called with correct parameters assert mock_context.call_entity.called call_args = mock_context.call_entity.call_args - entity_id, operation, request = call_args[0] + _, operation, request = call_args[0] assert operation == "run" assert request["message"] == "Test message" @@ -307,14 +321,15 @@ def test_run_creates_entity_call(self) -> None: def test_run_sets_orchestration_id(self) -> None: """Test that run() sets the orchestration_id from context.instance_id.""" - mock_context = Mock() - mock_context.instance_id = "my-orchestration-123" - mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"]) + mock_context = _create_mock_context("my-orchestration-123", ["thread-guid", "correlation-guid"]) entity_task = _create_entity_task() mock_context.call_entity = Mock(return_value=entity_task) - agent = DurableAIAgent(mock_context, "TestAgent") + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor + + executor = AzureFunctionsAgentExecutor(mock_context) + agent = DurableAIAgent(executor, "TestAgent") thread = agent.get_new_thread() agent.run(messages="Test", thread=thread) @@ -326,14 +341,15 @@ def test_run_sets_orchestration_id(self) -> None: def test_run_without_thread(self) -> None: """Test that run() works without explicit thread (creates unique session key).""" - mock_context = Mock() - mock_context.instance_id = "test-instance-002" - mock_context.new_uuid = Mock(side_effect=["auto-generated-guid", "correlation-guid"]) + mock_context = _create_mock_context("test-instance-002", ["auto-generated-guid", "correlation-guid"]) entity_task = _create_entity_task() mock_context.call_entity = Mock(return_value=entity_task) - agent = DurableAIAgent(mock_context, "TestAgent") + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor + + executor = AzureFunctionsAgentExecutor(mock_context) + agent = DurableAIAgent(executor, "TestAgent") # Call without thread task = agent.run(messages="Test message") @@ -346,18 +362,20 @@ def test_run_without_thread(self) -> None: entity_id = call_args[0][0] assert entity_id.name == "dafx-TestAgent" assert entity_id.key == "auto-generated-guid" - # Should be called twice: once for session_key, once for correlationId + # Should be called twice: once for session_key, once for correlation_id assert mock_context.new_uuid.call_count == 2 def test_run_with_response_format(self) -> None: """Test that run() passes response format correctly.""" - mock_context = Mock() - mock_context.instance_id = "test-instance-003" + mock_context = _create_mock_context("test-instance-003", ["thread-guid", "correlation-guid"]) entity_task = _create_entity_task() mock_context.call_entity = Mock(return_value=entity_task) - agent = DurableAIAgent(mock_context, "TestAgent") + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor + + executor = AzureFunctionsAgentExecutor(mock_context) + agent = DurableAIAgent(executor, "TestAgent") from pydantic import BaseModel @@ -380,33 +398,18 @@ class SampleSchema(BaseModel): assert input_data["response_format"]["module"] == SampleSchema.__module__ assert input_data["response_format"]["qualname"] == SampleSchema.__qualname__ - def test_messages_to_string(self) -> None: - """Test converting ChatMessage list to string.""" - from agent_framework import ChatMessage - - mock_context = Mock() - agent = DurableAIAgent(mock_context, "TestAgent") - - messages = [ - ChatMessage(role="user", text="Hello"), - ChatMessage(role="assistant", text="Hi there"), - ChatMessage(role="user", text="How are you?"), - ] - - result = agent._messages_to_string(messages) - - assert result == "Hello\nHi there\nHow are you?" - def test_run_with_chat_message(self) -> None: """Test that run() handles ChatMessage input.""" from agent_framework import ChatMessage - mock_context = Mock() - mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"]) + mock_context = _create_mock_context(uuid_values=["thread-guid", "correlation-guid"]) entity_task = _create_entity_task() mock_context.call_entity = Mock(return_value=entity_task) - agent = DurableAIAgent(mock_context, "TestAgent") + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor + + executor = AzureFunctionsAgentExecutor(mock_context) + agent = DurableAIAgent(executor, "TestAgent") thread = agent.get_new_thread() # Call with ChatMessage @@ -436,11 +439,13 @@ def test_entity_id_format(self) -> None: """Test that EntityId is created with correct format (name, key).""" from azure.durable_functions import EntityId - mock_context = Mock() - mock_context.new_uuid = Mock(return_value="test-guid-789") + mock_context = _create_mock_context(uuid_values=["test-guid-789", "correlation-guid"]) mock_context.call_entity = Mock(return_value=_create_entity_task()) - agent = DurableAIAgent(mock_context, "WriterAgent") + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor + + executor = AzureFunctionsAgentExecutor(mock_context) + agent = DurableAIAgent(executor, "WriterAgent") thread = agent.get_new_thread() # Call run() to trigger entity ID creation @@ -461,18 +466,6 @@ def test_entity_id_format(self) -> None: class TestAgentFunctionAppGetAgent: """Test suite for AgentFunctionApp.get_agent.""" - def test_get_agent_method(self) -> None: - """Test get_agent method creates DurableAIAgent for registered agent.""" - app = _app_with_registered_agents("MyAgent") - mock_context = Mock() - mock_context.instance_id = "test-instance-100" - - agent = app.get_agent(mock_context, "MyAgent") - - assert isinstance(agent, DurableAIAgent) - assert agent.agent_name == "MyAgent" - assert agent.context == mock_context - def test_get_agent_raises_for_unregistered_agent(self) -> None: """Test get_agent raises ValueError when agent is not registered.""" app = _app_with_registered_agents("KnownAgent") @@ -486,13 +479,11 @@ class TestOrchestrationIntegration: def test_sequential_agent_calls_simulation(self) -> None: """Simulate sequential agent calls in an orchestration.""" - mock_context = Mock() - mock_context.instance_id = "test-orchestration-001" # new_uuid will be called 3 times: # 1. thread creation - # 2. correlationId for first call - # 3. correlationId for second call - mock_context.new_uuid = Mock(side_effect=["deterministic-guid-001", "corr-1", "corr-2"]) + # 2. correlation_id for first call + # 3. correlation_id for second call + mock_context = _create_mock_context("test-orchestration-001", ["deterministic-guid-001", "corr-1", "corr-2"]) # Track entity calls entity_calls: list[dict[str, Any]] = [] @@ -527,11 +518,11 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic def test_multiple_agents_in_orchestration(self) -> None: """Test using multiple different agents in one orchestration.""" - mock_context = Mock() - mock_context.instance_id = "test-orchestration-002" # Mock new_uuid to return different GUIDs for each call # Order: writer thread, editor thread, writer correlation, editor correlation - mock_context.new_uuid = Mock(side_effect=["writer-guid-001", "editor-guid-002", "writer-corr", "editor-corr"]) + mock_context = _create_mock_context( + "test-orchestration-002", ["writer-guid-001", "editor-guid-002", "writer-corr", "editor-corr"] + ) entity_calls: list[str] = [] @@ -562,58 +553,5 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic assert entity_calls[1] == "@dafx-editoragent@editor-guid-002" -class TestAgentThreadSerialization: - """Test that AgentThread can be serialized for orchestration state.""" - - async def test_agent_thread_serialize(self) -> None: - """Test that AgentThread can be serialized.""" - thread = AgentThread() - - # Serialize - serialized = await thread.serialize() - - assert isinstance(serialized, dict) - assert "service_thread_id" in serialized - - async def test_agent_thread_deserialize(self) -> None: - """Test that AgentThread can be deserialized.""" - thread = AgentThread() - serialized = await thread.serialize() - - # Deserialize - restored = await AgentThread.deserialize(serialized) - - assert isinstance(restored, AgentThread) - assert restored.service_thread_id == thread.service_thread_id - - async def test_durable_agent_thread_serialization(self) -> None: - """Test that DurableAgentThread persists session metadata during serialization.""" - mock_context = Mock() - mock_context.instance_id = "test-instance-999" - mock_context.new_uuid = Mock(return_value="test-guid-999") - - agent = DurableAIAgent(mock_context, "TestAgent") - thread = agent.get_new_thread() - - assert isinstance(thread, DurableAgentThread) - # Verify custom attribute and property exist - assert thread.session_id is not None - session_id = thread.session_id - assert isinstance(session_id, AgentSessionId) - assert session_id.name == "TestAgent" - assert session_id.key == "test-guid-999" - - # Standard serialization should still work - serialized = await thread.serialize() - assert isinstance(serialized, dict) - assert serialized.get("durable_session_id") == str(session_id) - - # After deserialization, we'd need to restore the custom attribute - # This would be handled by the orchestration framework - restored = await DurableAgentThread.deserialize(serialized) - assert isinstance(restored, DurableAgentThread) - assert restored.session_id == session_id - - if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py index 9a99657902..33c533c5e5 100644 --- a/python/packages/core/agent_framework/_workflows/_handoff.py +++ b/python/packages/core/agent_framework/_workflows/_handoff.py @@ -871,8 +871,10 @@ def create_specialist() -> AgentProtocol: HandoffBuilder(participants=[coordinator, refund, shipping]) .set_coordinator(coordinator) .with_termination_condition( - lambda conv: sum(1 for msg in conv if msg.role.value == "user") >= 5 - or any("goodbye" in msg.text.lower() for msg in conv[-2:]) + lambda conv: ( + sum(1 for msg in conv if msg.role.value == "user") >= 5 + or any("goodbye" in msg.text.lower() for msg in conv[-2:]) + ) ) .build() ) diff --git a/python/packages/core/tests/workflow/test_workflow_builder.py b/python/packages/core/tests/workflow/test_workflow_builder.py index 83c9d41c22..91a213e3c2 100644 --- a/python/packages/core/tests/workflow/test_workflow_builder.py +++ b/python/packages/core/tests/workflow/test_workflow_builder.py @@ -245,7 +245,8 @@ def test_register_multiple_executors(): # Build workflow with edges using registered names workflow = ( - builder.set_start_executor("ExecutorA") + builder + .set_start_executor("ExecutorA") .add_edge("ExecutorA", "ExecutorB") .add_edge("ExecutorB", "ExecutorC") .build() @@ -426,7 +427,8 @@ def test_register_with_fan_in_edges(): # Add fan-in edges using registered names # Both Source1 and Source2 need to be reachable, so connect Source1 to Source2 workflow = ( - builder.set_start_executor("Source1") + builder + .set_start_executor("Source1") .add_edge("Source1", "Source2") .add_fan_in_edges(["Source1", "Source2"], "Aggregator") .build() diff --git a/python/packages/durabletask/DESIGN.md b/python/packages/durabletask/DESIGN.md index d8be9f7840..3fe14a385d 100644 --- a/python/packages/durabletask/DESIGN.md +++ b/python/packages/durabletask/DESIGN.md @@ -6,16 +6,17 @@ This package, `agent-framework-durabletask`, provides a durability layer for the ## Design Decision -**Selected Approach: Object-Oriented Wrappers with Symmetric Factory Pattern** +**Selected Approach: Object-Oriented Wrappers with Symmetric Factory Pattern + Strategy Pattern for Execution** -We will use a symmetric Object-Oriented design where both the Client (external) and Orchestrator (internal) expose a consistent interface for retrieving and interacting with durable agents. +We will use a symmetric Object-Oriented design where both the Client (external) and Orchestrator (internal) expose a consistent interface for retrieving and interacting with durable agents. Execution logic is delegated to dedicated provider strategies. ## Core Philosophy * **Native `DurableEntity` Support**: We will leverage the `DurableEntity` support introduced in `durabletask` v1.0.0. -* **Symmetric Factories**: `DurableAIAgentClient` (for external use) and `DurableAIAgentOrchestrator` (for internal use) both provide a `get_agent` method. +* **Symmetric Factories**: `DurableAIAgentClient` (for external use) and `DurableAIAgentOrchestrationContext` (for internal use) both provide a `get_agent` method. * **Unified Interface**: `DurableAIAgent` serves as the common interface for executing agents, regardless of the context (Client vs Orchestration). -* **Consistent Return Type**: `DurableAIAgent.run` always returns a `Task` (or compatible awaitable), ensuring consistent usage patterns. +* **Strategy Pattern for Execution**: Execution logic is encapsulated in `DurableAgentExecutor` implementations, allowing flexible delegation while keeping the public API clean. +* **Consistent Return Type**: `DurableAIAgent.run` returns context-appropriate objects (awaitable for Client, yieldable Task for Orchestrator), ensuring consistent usage patterns. ## Architecture @@ -29,12 +30,12 @@ packages/durabletask/ │ ├── __init__.py │ ├── _worker.py # DurableAIAgentWorker │ ├── _client.py # DurableAIAgentClient -│ ├── _orchestrator.py # DurableAIAgentOrchestrator +│ ├── _orchestration_context.py # DurableAIAgentOrchestrationContext +│ ├── _executors.py # DurableAgentExecutor ABC and implementations │ ├── _entities.py # AgentEntity implementation │ ├── _models.py # Data models (RunRequest, AgentResponse, etc.) │ ├── _durable_agent_state.py # State schema (Ported from azurefunctions) -│ ├── _shim.py # DurableAIAgent implementation (will be ported from azurefunctions) -│ └── _utils.py # Mixins and helpers +│ └── _shim.py # DurableAIAgent and DurableAgentProvider ABC └── tests/ ``` @@ -83,89 +84,191 @@ class DurableAIAgentWorker: self._worker.stop() ``` -### 5. The Mixin (`_utils.py`) +### 5. The Shim and Provider ABC (`_shim.py`) + +The `_shim.py` module contains two key abstractions: + +1. **`DurableAgentProvider` ABC**: Defines the contract for constructing durable agent proxies. Implemented by context-specific wrappers (client/orchestration) to provide a consistent `get_agent` entry point. + +2. **`DurableAIAgent`**: The agent shim that delegates execution to an executor strategy. ```python -class GetDurableAgentMixin: - """Mixin to provide get_agent interface.""" +from abc import ABC, abstractmethod + +class DurableAgentProvider(ABC): + """Abstract provider for constructing durable agent proxies. + + Implemented by context-specific wrappers (client/orchestration) to return a + DurableAIAgent shim backed by their respective DurableAgentExecutor + implementation, ensuring a consistent get_agent entry point regardless of + execution context. + """ + + @abstractmethod + def get_agent(self, agent_name: str) -> DurableAIAgent: + """Retrieve a DurableAIAgent shim for the specified agent.""" + raise NotImplementedError("Subclasses must implement get_agent()") + + +class DurableAIAgent(AgentProtocol): + """A durable agent proxy that delegates execution to an executor. + + This class implements AgentProtocol but doesn't contain any agent logic itself. + Instead, it serves as a consistent interface that delegates to the underlying + executor, which can be either ClientAgentExecutor or OrchestrationAgentExecutor. + """ + + def __init__(self, executor: DurableAgentExecutor, name: str, *, agent_id: str | None = None): + self._executor = executor + self._name = name + self._id = agent_id if agent_id is not None else name - def get_agent(self, agent_name: str) -> 'DurableAIAgent': + def run(self, messages: ..., **kwargs) -> Any: + """Execute the agent via the injected executor.""" + message_str = self._normalize_messages(messages) + return self._executor.run_durable_agent( + agent_name=self._name, + message=message_str, + thread=kwargs.get('thread'), + response_format=kwargs.get('response_format'), + **kwargs + ) +``` + +### 6. The Executor Strategy (`_executors.py`) + +We introduce dedicated "Executor" classes to handle execution logic using the Strategy Pattern. These are internal execution strategies that are injected into the `DurableAIAgent` shim. This ensures the public API of the Client and Orchestration Context remains clean, while allowing the Shim to be reused across different environments. + +```python +from abc import ABC, abstractmethod +from typing import Any +from agent_framework import AgentThread +from pydantic import BaseModel + +class DurableAgentExecutor(ABC): + """Abstract base class for durable agent execution strategies.""" + + @abstractmethod + def run_durable_agent( + self, + agent_name: str, + message: str, + *, + thread: AgentThread | None = None, + response_format: type[BaseModel] | None = None, + **kwargs: Any, + ) -> Any: + """Execute the durable agent. + + Returns: + Any: Either an awaitable AgentRunResponse (Client) or a yieldable Task (Orchestrator). + """ + raise NotImplementedError + + @abstractmethod + def get_new_thread(self, agent_name: str, **kwargs: Any) -> AgentThread: + """Create a new thread appropriate for the context.""" raise NotImplementedError + + +class ClientAgentExecutor(DurableAgentExecutor): + """Execution strategy for external clients (async).""" + + def __init__(self, client: 'TaskHubGrpcClient'): + self._client = client + + async def run_durable_agent( + self, + agent_name: str, + message: str, + *, + thread: AgentThread | None = None, + response_format: type[BaseModel] | None = None, + **kwargs: Any, + ) -> Any: + # Implementation using self._client + # Returns an awaitable AgentRunResponse + raise NotImplementedError("ClientAgentExecutor.run_durable_agent is not yet implemented") + + def get_new_thread(self, agent_name: str, **kwargs: Any) -> AgentThread: + # Implementation for client context + return AgentThread(**kwargs) + + +class OrchestrationAgentExecutor(DurableAgentExecutor): + """Execution strategy for orchestrations (sync/yield).""" + + def __init__(self, context: 'OrchestrationContext'): + self._context = context + + def run_durable_agent( + self, + agent_name: str, + message: str, + *, + thread: AgentThread | None = None, + response_format: type[BaseModel] | None = None, + **kwargs: Any, + ) -> Any: + # Implementation using self._context + # Returns a yieldable Task + raise NotImplementedError("OrchestrationAgentExecutor.run_durable_agent is not yet implemented") + + def get_new_thread(self, agent_name: str, **kwargs: Any) -> AgentThread: + # Implementation for orchestration context + return AgentThread(**kwargs) ``` -### 6. The Client Wrapper (`_client.py`) +**Benefits of the Strategy Pattern:** -The `DurableAIAgentClient` is for external clients (e.g., FastAPI, CLI). +1. **Strong Contract**: ABC enforces implementation of required methods. +2. **Encapsulation**: Execution logic is hidden in provider classes. +3. **Flexibility**: Easy to add new providers (e.g., for Azure Functions). +4. **Separation of Concerns**: Client/Context wrappers focus on being factories/adapters. +5. **Reusability**: The shim can be reused across different environments without modification. + +### 7. The Client Wrapper (`_client.py`) + +The `DurableAIAgentClient` is for external clients (e.g., FastAPI, CLI). It implements `DurableAgentProvider` to provide the `get_agent` factory method, and instantiates the `ClientAgentExecutor` to inject into the `DurableAIAgent`. ```python -class DurableAIAgentClient(GetDurableAgentMixin): +from ._executors import ClientAgentExecutor +from ._shim import DurableAgentProvider, DurableAIAgent + +class DurableAIAgentClient(DurableAgentProvider): def __init__(self, client: TaskHubGrpcClient): self._client = client - async def get_agent(self, agent_name: str) -> 'DurableAIAgent': + def get_agent(self, agent_name: str) -> DurableAIAgent: """Retrieves a DurableAIAgent shim. Validates existence by attempting to fetch entity state/metadata. """ # Validation logic using self._client.get_entity(...) # ... - return DurableAIAgent(self, agent_name) - - def run_agent(self, agent_name: str, message: str, **kwargs) -> 'Task': - """Runs agent via signal + poll and returns a Task wrapper.""" - # Returns a ClientTask (wrapper around asyncio.Task) - pass + executor = ClientAgentExecutor(self._client) + return DurableAIAgent(executor, agent_name) ``` -### 7. The Orchestration Context Wrapper (`_orchestration_context.py`) +### 8. The Orchestration Context Wrapper (`_orchestration_context.py`) -The `DurableAIAgentOrchestrationContext` is for use *inside* orchestrations to get access to agents that were registered in the workers. +The `DurableAIAgentOrchestrationContext` is for use *inside* orchestrations to get access to agents that were registered in the workers. It implements `DurableAgentProvider` to provide the `get_agent` factory method, and instantiates the `OrchestrationAgentExecutor`. ```python -class DurableAIAgentOrchestrationContext(GetDurableAgentMixin): +from ._executors import OrchestrationAgentExecutor +from ._shim import DurableAgentProvider, DurableAIAgent + +class DurableAIAgentOrchestrationContext(DurableAgentProvider): def __init__(self, context: OrchestrationContext): self._context = context - def get_agent(self, agent_name: str) -> 'DurableAIAgent': + def get_agent(self, agent_name: str) -> DurableAIAgent: """Retrieves a DurableAIAgent shim. Validation is deferred or performed via call_entity if needed. """ - return DurableAIAgent(self, agent_name) - - def run_agent(self, agent_name: str, message: str, **kwargs) -> 'Task': - """Runs agent via call_entity and returns the Task.""" - # Returns the native durabletask.Task - pass -``` - -### 8. The Durable Agent Shim (`_shim.py`) - -The `DurableAIAgent` implements `AgentProtocol` but delegates execution to the provider. This will be ported from `azurefunctions` package and updated accordingly. - -```python -class DurableAIAgent(AgentProtocol): - """A shim that delegates execution to the provider (Client or Orchestrator).""" - - def __init__(self, provider: GetDurableAgentMixin, name: str): - self._provider = provider - self._name = name - - @property - def name(self) -> str: - return self._name - - def run(self, message: str, **kwargs) -> 'Task': - """Executes the agent. - - Returns: - Task: A yieldable/awaitable task object. - """ - return self._provider.run_agent( - agent_name=self.name, - message=message, - **kwargs - ) + executor = OrchestrationAgentExecutor(self._context) + return DurableAIAgent(executor, agent_name) ``` ## Usage Experience @@ -215,8 +318,8 @@ response = await agent.run("Hello") **Scenario C: Orchestration Side** ```python def orchestrator(context: OrchestrationContext): - # 1. Create the Agent Orchestrator wrapper - agent_orch = DurableAIAgentOrchestrator(context) + # 1. Create the Agent Orchestration Context + agent_orch = DurableAIAgentOrchestrationContext(context) # 2. Get a reference to the agent agent = agent_orch.get_agent("my_agent") @@ -245,4 +348,37 @@ We investigated inheriting `DurableAIAgentWorker` directly from `TaskHubGrpcWork 3. **Flexibility:** The Composition pattern allows `DurableAIAgentWorker` to accept *any* instance of a worker that satisfies the required interface. This makes it forward-compatible with future worker implementations or custom subclasses without requiring code changes in our package. -4. **Simplicity:** While Composition requires a two-step setup (instantiate worker, then wrap it), it keeps the `agent-framework-durabletask` package simple, focused, and loosely coupled from the implementation details of the underlying `durabletask` workers. \ No newline at end of file +4. **Simplicity:** While Composition requires a two-step setup (instantiate worker, then wrap it), it keeps the `agent-framework-durabletask` package simple, focused, and loosely coupled from the implementation details of the underlying `durabletask` workers. + +## Extension Point: Azure Functions Integration + +The Strategy Pattern design allows for easy integration with Azure Functions. The `azurefunctions` package can define its own `AzureFunctionsAgentExecutor` in `packages/azurefunctions/agent_framework_azurefunctions/_executors.py`. + +```python +from agent_framework_durabletask._executors import DurableAgentExecutor + +class AzureFunctionsAgentExecutor(DurableAgentExecutor): + """Execution strategy for Azure Functions orchestrations.""" + + def __init__(self, context: DurableOrchestrationContext): + self._context = context + + def run_durable_agent( + self, + agent_name: str, + message: str, + *, + thread: AgentThread | None = None, + response_format: type[BaseModel] | None = None, + **kwargs: Any, + ) -> Any: + # Implementation using Azure Functions context + # Returns AgentTask + ... + + def get_new_thread(self, agent_name: str, **kwargs: Any) -> AgentThread: + # Implementation for Azure Functions context + ... +``` + +Then `packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py` implements `DurableAgentProvider` and uses this executor when creating the agent, ensuring consistent behavior across platforms while accommodating Azure Functions-specific features. \ No newline at end of file diff --git a/python/packages/durabletask/agent_framework_durabletask/__init__.py b/python/packages/durabletask/agent_framework_durabletask/__init__.py index f28f4d8064..10d880f001 100644 --- a/python/packages/durabletask/agent_framework_durabletask/__init__.py +++ b/python/packages/durabletask/agent_framework_durabletask/__init__.py @@ -3,6 +3,7 @@ """Durable Task integration for Microsoft Agent Framework.""" from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol +from ._client import DurableAIAgentClient from ._constants import ( DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS, @@ -41,7 +42,11 @@ DurableAgentStateUsageContent, ) from ._entities import AgentEntity, AgentEntityStateProviderMixin -from ._models import RunRequest, serialize_response_format +from ._executors import DurableAgentExecutor +from ._models import AgentSessionId, DurableAgentThread, RunRequest, serialize_response_format +from ._orchestration_context import DurableAIAgentOrchestrationContext +from ._shim import DurableAIAgent +from ._worker import DurableAIAgentWorker __all__ = [ "DEFAULT_MAX_POLL_RETRIES", @@ -58,8 +63,14 @@ "AgentEntity", "AgentEntityStateProviderMixin", "AgentResponseCallbackProtocol", + "AgentSessionId", "ApiResponseFields", "ContentTypes", + "DurableAIAgent", + "DurableAIAgentClient", + "DurableAIAgentOrchestrationContext", + "DurableAIAgentWorker", + "DurableAgentExecutor", "DurableAgentState", "DurableAgentStateContent", "DurableAgentStateData", @@ -80,6 +91,8 @@ "DurableAgentStateUriContent", "DurableAgentStateUsage", "DurableAgentStateUsageContent", + "DurableAgentThread", + "DurableAgentThread", "DurableStateFields", "RunRequest", "serialize_response_format", diff --git a/python/packages/durabletask/agent_framework_durabletask/_client.py b/python/packages/durabletask/agent_framework_durabletask/_client.py new file mode 100644 index 0000000000..e26c9d6ba5 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_client.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Client wrapper for Durable Task Agent Framework. + +This module provides the DurableAIAgentClient class for external clients to interact +with durable agents via gRPC. +""" + +from __future__ import annotations + +from agent_framework import get_logger +from durabletask.client import TaskHubGrpcClient + +from ._executors import ClientAgentExecutor +from ._shim import DurableAgentProvider, DurableAIAgent + +logger = get_logger("agent_framework.durabletask.client") + + +class DurableAIAgentClient(DurableAgentProvider): + """Client wrapper for interacting with durable agents externally. + + This class wraps a durabletask TaskHubGrpcClient and provides a convenient + interface for retrieving and executing durable agents from external contexts + (e.g., FastAPI endpoints, CLI tools, etc.). + + Example: + ```python + from durabletask import TaskHubGrpcClient + from agent_framework_durabletask import DurableAIAgentClient + + # Create the underlying client + client = TaskHubGrpcClient(host_address="localhost:4001") + + # Wrap it with the agent client + agent_client = DurableAIAgentClient(client) + + # Get an agent reference + agent = await agent_client.get_agent("assistant") + + # Run the agent + response = await agent.run("Hello, how are you?") + print(response.text) + ``` + """ + + def __init__(self, client: TaskHubGrpcClient): + """Initialize the client wrapper. + + Args: + client: The durabletask client instance to wrap + """ + self._client = client + self._executor = ClientAgentExecutor(self._client) + logger.debug("[DurableAIAgentClient] Initialized with client type: %s", type(client).__name__) + + def get_agent(self, agent_name: str) -> DurableAIAgent: + """Retrieve a DurableAIAgent shim for the specified agent. + + This method returns a proxy object that can be used to execute the agent. + The actual agent must be registered on a worker with the same name. + + Args: + agent_name: Name of the agent to retrieve (without the dafx- prefix) + + Returns: + DurableAIAgent instance that can be used to run the agent + + Note: + This method does not validate that the agent exists. Validation + will occur when the agent is executed. If the entity doesn't exist, + the execution will fail with an appropriate error. + """ + logger.debug("[DurableAIAgentClient] Creating agent proxy for: %s", agent_name) + + # Note: Validation would require async, so we defer it to execution time + # The entity name will be f"dafx-{agent_name}" + + return DurableAIAgent(self._executor, agent_name) diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py new file mode 100644 index 0000000000..d165166c60 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Provider strategies for Durable Agent execution. + +These classes are internal execution strategies used by the DurableAIAgent shim. +They are intentionally separate from the public client/orchestration APIs to keep +only `get_agent` exposed to consumers. Providers implement the execution contract +and are injected into the shim. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from agent_framework import AgentThread, get_logger +from pydantic import BaseModel + +from ._models import DurableAgentThread + +if TYPE_CHECKING: # pragma: no cover + from durabletask.client import TaskHubGrpcClient + from durabletask.task import OrchestrationContext + +logger = get_logger("agent_framework.durabletask.executors") + + +class DurableAgentExecutor(ABC): + """Abstract base class for durable agent execution strategies.""" + + @abstractmethod + def run_durable_agent( + self, + agent_name: str, + message: str, + *, + thread: AgentThread | None = None, + response_format: type[BaseModel] | None = None, + **kwargs: Any, + ) -> Any: + """Execute the durable agent. + + Returns: + Any: Either an awaitable AgentRunResponse (Client) or a yieldable Task (Orchestrator). + """ + raise NotImplementedError + + @abstractmethod + def get_new_thread(self, agent_name: str, **kwargs: Any) -> DurableAgentThread: + """Create a new thread appropriate for the provider context.""" + raise NotImplementedError + + +class ClientAgentExecutor(DurableAgentExecutor): + """Execution strategy for external clients (async).""" + + def __init__(self, client: TaskHubGrpcClient): + self._client = client + logger.debug("[ClientAgentExecutor] Initialized with client type: %s", type(client).__name__) + + async def run_durable_agent( + self, + agent_name: str, + message: str, + *, + thread: AgentThread | None = None, + response_format: type[BaseModel] | None = None, + **kwargs: Any, + ) -> Any: + """Execute the agent via the durabletask client. + + Note: Implementation is backend-specific and should signal/call the entity + and await the durable response. This placeholder raises NotImplementedError + until wired to concrete durabletask calls. + """ + raise NotImplementedError("ClientAgentProvider.run_durable_agent is not yet implemented") + + def get_new_thread(self, agent_name: str, **kwargs: Any) -> DurableAgentThread: + """Create a new AgentThread for client-side execution.""" + return DurableAgentThread(**kwargs) + + +class OrchestrationAgentExecutor(DurableAgentExecutor): + """Execution strategy for orchestrations (sync/yield).""" + + def __init__(self, context: OrchestrationContext): + self._context = context + logger.debug("[OrchestrationAgentExecutor] Initialized") + + def run_durable_agent( + self, + agent_name: str, + message: str, + *, + thread: AgentThread | None = None, + response_format: type[BaseModel] | None = None, + **kwargs: Any, + ) -> Any: + """Execute the agent via orchestration context. + + Note: Implementation should call the entity (e.g., context.call_entity) + and return the native Task for yielding. Placeholder until wired. + """ + raise NotImplementedError("OrchestrationAgentProvider.run_durable_agent is not yet implemented") + + def get_new_thread(self, agent_name: str, **kwargs: Any) -> DurableAgentThread: + """Create a new AgentThread for orchestration context.""" + return DurableAgentThread(**kwargs) diff --git a/python/packages/durabletask/agent_framework_durabletask/_models.py b/python/packages/durabletask/agent_framework_durabletask/_models.py index 14ca37f098..4317b87832 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_models.py +++ b/python/packages/durabletask/agent_framework_durabletask/_models.py @@ -8,12 +8,14 @@ from __future__ import annotations import inspect +import uuid +from collections.abc import MutableMapping from dataclasses import dataclass from datetime import datetime from importlib import import_module from typing import TYPE_CHECKING, Any, cast -from agent_framework import Role +from agent_framework import AgentThread, Role from ._constants import REQUEST_RESPONSE_FORMAT_TEXT @@ -186,3 +188,106 @@ def from_dict(cls, data: dict[str, Any]) -> RunRequest: created_at=created_at, orchestration_id=data.get("orchestrationId"), ) + + +@dataclass +class AgentSessionId: + """Represents an agent session identifier (name + key).""" + + name: str + key: str + + ENTITY_NAME_PREFIX: str = "dafx-" + + @staticmethod + def to_entity_name(name: str) -> str: + return f"{AgentSessionId.ENTITY_NAME_PREFIX}{name}" + + @staticmethod + def with_random_key(name: str) -> AgentSessionId: + return AgentSessionId(name=name, key=uuid.uuid4().hex) + + @property + def entity_name(self) -> str: + return self.to_entity_name(self.name) + + def __str__(self) -> str: + return f"@{self.name}@{self.key}" + + def __repr__(self) -> str: + return f"AgentSessionId(name='{self.name}', key='{self.key}')" + + @staticmethod + def parse(session_id_string: str) -> AgentSessionId: + if not session_id_string.startswith("@"): + raise ValueError(f"Invalid agent session ID format: {session_id_string}") + + parts = session_id_string[1:].split("@", 1) + if len(parts) != 2: + raise ValueError(f"Invalid agent session ID format: {session_id_string}") + + return AgentSessionId(name=parts[0], key=parts[1]) + + +class DurableAgentThread(AgentThread): + """Durable agent thread that tracks the owning :class:`AgentSessionId`.""" + + _SERIALIZED_SESSION_ID_KEY = "durable_session_id" + + def __init__( + self, + *, + session_id: AgentSessionId | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self._session_id: AgentSessionId | None = session_id + + @property + def session_id(self) -> AgentSessionId | None: + return self._session_id + + @session_id.setter + def session_id(self, value: AgentSessionId | None) -> None: + self._session_id = value + + @classmethod + def from_session_id( + cls, + session_id: AgentSessionId, + **kwargs: Any, + ) -> DurableAgentThread: + return cls(session_id=session_id, **kwargs) + + async def serialize(self, **kwargs: Any) -> dict[str, Any]: + state = await super().serialize(**kwargs) + if self._session_id is not None: + state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id) + return state + + @classmethod + async def deserialize( + cls, + serialized_thread_state: MutableMapping[str, Any], + *, + message_store: Any = None, + **kwargs: Any, + ) -> DurableAgentThread: + state_payload = dict(serialized_thread_state) + session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None) + thread = await super().deserialize( + state_payload, + message_store=message_store, + **kwargs, + ) + if not isinstance(thread, DurableAgentThread): + raise TypeError("Deserialized thread is not a DurableAgentThread instance") + + if session_id_value is None: + return thread + + if not isinstance(session_id_value, str): + raise ValueError("durable_session_id must be a string when present in serialized state") + + thread.session_id = AgentSessionId.parse(session_id_value) + return thread diff --git a/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py b/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py new file mode 100644 index 0000000000..442d247831 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Orchestration context wrapper for Durable Task Agent Framework. + +This module provides the DurableAIAgentOrchestrationContext class for use inside +orchestration functions to interact with durable agents. +""" + +from __future__ import annotations + +from agent_framework import get_logger +from durabletask.task import OrchestrationContext + +from ._executors import OrchestrationAgentExecutor +from ._shim import DurableAgentProvider, DurableAIAgent + +logger = get_logger("agent_framework.durabletask.orchestration_context") + + +class DurableAIAgentOrchestrationContext(DurableAgentProvider): + """Orchestration context wrapper for interacting with durable agents internally. + + This class wraps a durabletask OrchestrationContext and provides a convenient + interface for retrieving and executing durable agents from within orchestration + functions. + + Example: + ```python + from durabletask import Orchestration + from agent_framework_durabletask import DurableAIAgentOrchestrationContext + + + @orchestration + def my_orchestration(context: OrchestrationContext): + # Wrap the context + agent_context = DurableAIAgentOrchestrationContext(context) + + # Get an agent reference + agent = agent_context.get_agent("assistant") + + # Run the agent (returns a Task to be yielded) + result = yield agent.run("Hello, how are you?") + + return result.text + ``` + """ + + def __init__(self, context: OrchestrationContext): + """Initialize the orchestration context wrapper. + + Args: + context: The durabletask orchestration context to wrap + """ + self._context = context + self._executor = OrchestrationAgentExecutor(self._context) + logger.debug("[DurableAIAgentOrchestrationContext] Initialized") + + def get_agent(self, agent_name: str) -> DurableAIAgent: + """Retrieve a DurableAIAgent shim for the specified agent. + + This method returns a proxy object that can be used to execute the agent + within an orchestration. The agent's run() method will return a Task that + must be yielded. + + Args: + agent_name: Name of the agent to retrieve (without the dafx- prefix) + + Returns: + DurableAIAgent instance that can be used to run the agent + + Note: + Validation is deferred to execution time. The entity must be registered + on a worker with the name f"dafx-{agent_name}". + """ + logger.debug("[DurableAIAgentOrchestrationContext] Creating agent proxy for: %s", agent_name) + return DurableAIAgent(self._executor, agent_name) diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py new file mode 100644 index 0000000000..696bddf611 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -0,0 +1,168 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Durable Agent Shim for Durable Task Framework. + +This module provides the DurableAIAgent shim that implements AgentProtocol +and provides a consistent interface for both Client and Orchestration contexts. +The actual execution is delegated to the context-specific providers. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any + +from agent_framework import AgentProtocol, AgentRunResponseUpdate, AgentThread, ChatMessage +from pydantic import BaseModel + +if TYPE_CHECKING: + from ._executors import DurableAgentExecutor + + +class DurableAgentProvider(ABC): + """Abstract provider for constructing durable agent proxies. + + Implemented by context-specific wrappers (client/orchestration) to return a + `DurableAIAgent` shim backed by their respective `DurableAgentExecutor` + implementation, ensuring a consistent `get_agent` entry point regardless of + execution context. + """ + + @abstractmethod + def get_agent(self, agent_name: str) -> DurableAIAgent: + """Retrieve a DurableAIAgent shim for the specified agent. + + Args: + agent_name: Name of the agent to retrieve + + Returns: + DurableAIAgent instance that can be used to run the agent + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError("Subclasses must implement get_agent()") + + +class DurableAIAgent(AgentProtocol): + """A durable agent proxy that delegates execution to the provider. + + This class implements AgentProtocol but doesn't contain any agent logic itself. + Instead, it serves as a consistent interface that delegates to the underlying + provider, which can be either: + - DurableAIAgentClient (for external usage via HTTP/gRPC) + - DurableAIAgentOrchestrationContext (for use inside orchestrations) + + The provider determines how execution occurs (entity calls, HTTP requests, etc.) + and what type of Task object is returned (asyncio.Task vs durabletask.Task). + + Note: + This class intentionally does NOT inherit from BaseAgent because: + - BaseAgent assumes async/await patterns + - Orchestration contexts require yield patterns + - BaseAgent methods like as_tool() would fail in orchestrations + """ + + def __init__(self, executor: DurableAgentExecutor, name: str, *, agent_id: str | None = None): + """Initialize the shim with a provider and agent name. + + Args: + executor: The execution provider (Client or OrchestrationContext) + name: The name of the agent to execute + agent_id: Optional unique identifier for the agent (defaults to name) + """ + self._executor = executor + self._name = name + self._id = agent_id if agent_id is not None else name + self._display_name = name + self._description = f"Durable agent proxy for {name}" + + @property + def id(self) -> str: + """Get the unique identifier for this agent.""" + return self._id + + @property + def name(self) -> str | None: + """Get the name of the agent.""" + return self._name + + @property + def display_name(self) -> str: + """Get the display name of the agent.""" + return self._display_name + + @property + def description(self) -> str | None: + """Get the description of the agent.""" + return self._description + + def run( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + response_format: type[BaseModel] | None = None, + **kwargs: Any, + ) -> Any: + """Execute the agent via the injected provider. + + The provider determines whether the return is awaitable (client) or yieldable (orchestration). + """ + message_str = self._normalize_messages(messages) + return self._executor.run_durable_agent( + agent_name=self._name, + message=message_str, + thread=thread, + response_format=response_format, + **kwargs, + ) + + def run_stream( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterator[AgentRunResponseUpdate]: + """Run the agent with streaming (not supported for durable agents). + + Args: + messages: The message(s) to send to the agent + thread: Optional agent thread for conversation context + **kwargs: Additional arguments + + Raises: + NotImplementedError: Streaming is not supported for durable agents + """ + raise NotImplementedError("Streaming is not supported for durable agents") + + def get_new_thread(self, **kwargs: Any) -> AgentThread: + """Create a new agent thread via the provider.""" + return self._executor.get_new_thread(self._name, **kwargs) + + def _normalize_messages(self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None) -> str: + """Convert supported message inputs to a single string. + + Args: + messages: The messages to normalize + + Returns: + A single string representation of the messages + """ + if messages is None: + return "" + if isinstance(messages, str): + return messages + if isinstance(messages, ChatMessage): + return messages.text or "" + if isinstance(messages, list): + if not messages: + return "" + first_item = messages[0] + if isinstance(first_item, str): + return "\n".join(messages) # type: ignore[arg-type] + # List of ChatMessage + return "\n".join([msg.text or "" for msg in messages]) # type: ignore[union-attr] + return "" diff --git a/python/packages/durabletask/agent_framework_durabletask/_worker.py b/python/packages/durabletask/agent_framework_durabletask/_worker.py new file mode 100644 index 0000000000..aa2c90b3b4 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_worker.py @@ -0,0 +1,199 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Worker wrapper for Durable Task Agent Framework. + +This module provides the DurableAIAgentWorker class that wraps a durabletask worker +and enables registration of agents as durable entities. +""" + +from __future__ import annotations + +from typing import Any + +from agent_framework import AgentProtocol, get_logger +from durabletask.worker import TaskHubGrpcWorker + +from ._callbacks import AgentResponseCallbackProtocol +from ._entities import AgentEntity, DurableTaskEntityStateProvider + +logger = get_logger("agent_framework.durabletask.worker") + + +class DurableAIAgentWorker: + """Wrapper for durabletask worker that enables agent registration. + + This class wraps an existing TaskHubGrpcWorker instance and provides + a convenient interface for registering agents as durable entities. + + Example: + ```python + from durabletask import TaskHubGrpcWorker + from agent_framework import ChatAgent + from agent_framework_durabletask import DurableAIAgentWorker + + # Create the underlying worker + worker = TaskHubGrpcWorker(host_address="localhost:4001") + + # Wrap it with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Register agents + my_agent = ChatAgent(chat_client=client, name="assistant") + agent_worker.add_agent(my_agent) + + # Start the worker + worker.start() + ``` + """ + + def __init__( + self, + worker: TaskHubGrpcWorker, + callback: AgentResponseCallbackProtocol | None = None, + ): + """Initialize the worker wrapper. + + Args: + worker: The durabletask worker instance to wrap + callback: Optional callback for agent response notifications + """ + self._worker = worker + self._callback = callback + self._registered_agents: dict[str, AgentProtocol] = {} + logger.debug("[DurableAIAgentWorker] Initialized with worker type: %s", type(worker).__name__) + + def add_agent( + self, + agent: AgentProtocol, + callback: AgentResponseCallbackProtocol | None = None, + ) -> None: + """Register an agent with the worker. + + This method creates a durable entity class for the agent and registers + it with the underlying durabletask worker. The entity will be accessible + by the name "dafx-{agent_name}". + + Args: + agent: The agent to register (must have a name) + callback: Optional callback for this specific agent (overrides worker-level callback) + + Raises: + ValueError: If the agent doesn't have a name or is already registered + """ + agent_name = agent.name + if not agent_name: + raise ValueError("Agent must have a name to be registered") + + if agent_name in self._registered_agents: + raise ValueError(f"Agent '{agent_name}' is already registered") + + logger.info("[DurableAIAgentWorker] Registering agent: %s as entity: dafx-%s", agent_name, agent_name) + + # Store the agent reference + self._registered_agents[agent_name] = agent + + # Use agent-specific callback if provided, otherwise use worker-level callback + effective_callback = callback or self._callback + + # Create a configured entity class using the factory + entity_class = self.__create_agent_entity(agent, effective_callback) + + # Register the entity class with the worker + # The worker.add_entity method takes a class + entity_registered: str = self._worker.add_entity(entity_class) # pyright: ignore[reportUnknownMemberType] + + logger.debug( + "[DurableAIAgentWorker] Successfully registered entity class %s for agent: %s", + entity_registered, + agent_name, + ) + + def start(self) -> None: + """Start the worker to begin processing tasks. + + Note: + This method delegates to the underlying worker's start method. + The worker will block until stopped. + """ + logger.info("[DurableAIAgentWorker] Starting worker with %d registered agents", len(self._registered_agents)) + self._worker.start() + + def stop(self) -> None: + """Stop the worker gracefully. + + Note: + This method delegates to the underlying worker's stop method. + """ + logger.info("[DurableAIAgentWorker] Stopping worker") + self._worker.stop() + + @property + def registered_agent_names(self) -> list[str]: + """Get the names of all registered agents. + + Returns: + List of agent names (without the dafx- prefix) + """ + return list(self._registered_agents.keys()) + + def __create_agent_entity( + self, + agent: AgentProtocol, + callback: AgentResponseCallbackProtocol | None = None, + ) -> type[DurableTaskEntityStateProvider]: + """Factory function to create a DurableEntity class configured with an agent. + + This factory creates a new class that combines the entity state provider + with the agent execution logic. Each agent gets its own entity class. + + Args: + agent: The agent instance to wrap + callback: Optional callback for agent responses + + Returns: + A new DurableEntity subclass configured for this agent + """ + agent_name = agent.name or type(agent).__name__ + entity_name = f"dafx-{agent_name}" + + class ConfiguredAgentEntity(DurableTaskEntityStateProvider): + """Durable entity configured with a specific agent instance.""" + + def __init__(self) -> None: + super().__init__() + # Create the AgentEntity with this state provider + self._agent_entity = AgentEntity( + agent=agent, + callback=callback, + state_provider=self, + ) + logger.debug( + "[ConfiguredAgentEntity] Initialized entity for agent: %s (entity name: %s)", + agent_name, + entity_name, + ) + + async def run(self, request: Any) -> Any: + """Handle run requests from clients or orchestrations. + + Args: + request: RunRequest as dict or string + + Returns: + AgentRunResponse as dict + """ + logger.debug("[ConfiguredAgentEntity.run] Executing agent: %s", agent_name) + response = await self._agent_entity.run(request) + return response.to_dict() + + def reset(self) -> None: + """Reset the agent's conversation history.""" + logger.debug("[ConfiguredAgentEntity.reset] Resetting agent: %s", agent_name) + self._agent_entity.reset() + + # Set the entity name to match the prefixed agent name + # This is used by durabletask to register the entity + ConfiguredAgentEntity.__name__ = entity_name + ConfiguredAgentEntity.__qualname__ = entity_name + + return ConfiguredAgentEntity diff --git a/python/packages/durabletask/tests/test_agent_session_id.py b/python/packages/durabletask/tests/test_agent_session_id.py new file mode 100644 index 0000000000..92c4e5f872 --- /dev/null +++ b/python/packages/durabletask/tests/test_agent_session_id.py @@ -0,0 +1,272 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for AgentSessionId and DurableAgentThread.""" + +import pytest +from agent_framework import AgentThread + +from agent_framework_durabletask._models import AgentSessionId, DurableAgentThread + + +class TestAgentSessionId: + """Test suite for AgentSessionId.""" + + def test_init_creates_session_id(self) -> None: + """Test that AgentSessionId initializes correctly.""" + session_id = AgentSessionId(name="AgentEntity", key="test-key-123") + + assert session_id.name == "AgentEntity" + assert session_id.key == "test-key-123" + + def test_with_random_key_generates_guid(self) -> None: + """Test that with_random_key generates a GUID.""" + session_id = AgentSessionId.with_random_key(name="AgentEntity") + + assert session_id.name == "AgentEntity" + assert len(session_id.key) == 32 # UUID hex is 32 chars + # Verify it's a valid hex string + int(session_id.key, 16) + + def test_with_random_key_unique_keys(self) -> None: + """Test that with_random_key generates unique keys.""" + session_id1 = AgentSessionId.with_random_key(name="AgentEntity") + session_id2 = AgentSessionId.with_random_key(name="AgentEntity") + + assert session_id1.key != session_id2.key + + def test_str_representation(self) -> None: + """Test string representation.""" + session_id = AgentSessionId(name="AgentEntity", key="test-key-123") + str_repr = str(session_id) + + assert str_repr == "@AgentEntity@test-key-123" + + def test_repr_representation(self) -> None: + """Test repr representation.""" + session_id = AgentSessionId(name="AgentEntity", key="test-key") + repr_str = repr(session_id) + + assert "AgentSessionId" in repr_str + assert "AgentEntity" in repr_str + assert "test-key" in repr_str + + def test_parse_valid_session_id(self) -> None: + """Test parsing valid session ID string.""" + session_id = AgentSessionId.parse("@AgentEntity@test-key-123") + + assert session_id.name == "AgentEntity" + assert session_id.key == "test-key-123" + + def test_parse_invalid_format_no_prefix(self) -> None: + """Test parsing invalid format without @ prefix.""" + with pytest.raises(ValueError) as exc_info: + AgentSessionId.parse("AgentEntity@test-key") + + assert "Invalid agent session ID format" in str(exc_info.value) + + def test_parse_invalid_format_single_part(self) -> None: + """Test parsing invalid format with single part.""" + with pytest.raises(ValueError) as exc_info: + AgentSessionId.parse("@AgentEntity") + + assert "Invalid agent session ID format" in str(exc_info.value) + + def test_parse_with_multiple_at_signs_in_key(self) -> None: + """Test parsing with @ signs in the key.""" + session_id = AgentSessionId.parse("@AgentEntity@key-with@symbols") + + assert session_id.name == "AgentEntity" + assert session_id.key == "key-with@symbols" + + def test_parse_round_trip(self) -> None: + """Test round-trip parse and string conversion.""" + original = AgentSessionId(name="AgentEntity", key="test-key") + str_repr = str(original) + parsed = AgentSessionId.parse(str_repr) + + assert parsed.name == original.name + assert parsed.key == original.key + + def test_to_entity_name_adds_prefix(self) -> None: + """Test that to_entity_name adds the dafx- prefix.""" + entity_name = AgentSessionId.to_entity_name("TestAgent") + assert entity_name == "dafx-TestAgent" + + +class TestDurableAgentThread: + """Test suite for DurableAgentThread.""" + + def test_init_with_session_id(self) -> None: + """Test DurableAgentThread initialization with session ID.""" + session_id = AgentSessionId(name="TestAgent", key="test-key") + thread = DurableAgentThread(session_id=session_id) + + assert thread.session_id is not None + assert thread.session_id == session_id + + def test_init_without_session_id(self) -> None: + """Test DurableAgentThread initialization without session ID.""" + thread = DurableAgentThread() + + assert thread.session_id is None + + def test_session_id_setter(self) -> None: + """Test setting a session ID to an existing thread.""" + thread = DurableAgentThread() + assert thread.session_id is None + + session_id = AgentSessionId(name="TestAgent", key="test-key") + thread.session_id = session_id + + assert thread.session_id is not None + assert thread.session_id == session_id + assert thread.session_id.name == "TestAgent" + + def test_from_session_id(self) -> None: + """Test creating DurableAgentThread from session ID.""" + session_id = AgentSessionId(name="TestAgent", key="test-key") + thread = DurableAgentThread.from_session_id(session_id) + + assert isinstance(thread, DurableAgentThread) + assert thread.session_id is not None + assert thread.session_id == session_id + assert thread.session_id.name == "TestAgent" + assert thread.session_id.key == "test-key" + + def test_from_session_id_with_service_thread_id(self) -> None: + """Test creating DurableAgentThread with service thread ID.""" + session_id = AgentSessionId(name="TestAgent", key="test-key") + thread = DurableAgentThread.from_session_id(session_id, service_thread_id="service-123") + + assert thread.session_id is not None + assert thread.session_id == session_id + assert thread.service_thread_id == "service-123" + + async def test_serialize_with_session_id(self) -> None: + """Test serialization includes session ID.""" + session_id = AgentSessionId(name="TestAgent", key="test-key") + thread = DurableAgentThread(session_id=session_id) + + serialized = await thread.serialize() + + assert isinstance(serialized, dict) + assert "durable_session_id" in serialized + assert serialized["durable_session_id"] == "@TestAgent@test-key" + + async def test_serialize_without_session_id(self) -> None: + """Test serialization without session ID.""" + thread = DurableAgentThread() + + serialized = await thread.serialize() + + assert isinstance(serialized, dict) + assert "durable_session_id" not in serialized + + async def test_deserialize_with_session_id(self) -> None: + """Test deserialization restores session ID.""" + serialized = { + "service_thread_id": "thread-123", + "durable_session_id": "@TestAgent@test-key", + } + + thread = await DurableAgentThread.deserialize(serialized) + + assert isinstance(thread, DurableAgentThread) + assert thread.session_id is not None + assert thread.session_id.name == "TestAgent" + assert thread.session_id.key == "test-key" + assert thread.service_thread_id == "thread-123" + + async def test_deserialize_without_session_id(self) -> None: + """Test deserialization without session ID.""" + serialized = { + "service_thread_id": "thread-456", + } + + thread = await DurableAgentThread.deserialize(serialized) + + assert isinstance(thread, DurableAgentThread) + assert thread.session_id is None + assert thread.service_thread_id == "thread-456" + + async def test_round_trip_serialization(self) -> None: + """Test round-trip serialization preserves session ID.""" + session_id = AgentSessionId(name="TestAgent", key="test-key-789") + original = DurableAgentThread(session_id=session_id) + + serialized = await original.serialize() + restored = await DurableAgentThread.deserialize(serialized) + + assert isinstance(restored, DurableAgentThread) + assert restored.session_id is not None + assert restored.session_id.name == session_id.name + assert restored.session_id.key == session_id.key + + async def test_deserialize_invalid_session_id_type(self) -> None: + """Test deserialization with invalid session ID type raises error.""" + serialized = { + "service_thread_id": "thread-123", + "durable_session_id": 12345, # Invalid type + } + + with pytest.raises(ValueError, match="durable_session_id must be a string"): + await DurableAgentThread.deserialize(serialized) + + +class TestAgentThreadCompatibility: + """Test suite for compatibility between AgentThread and DurableAgentThread.""" + + async def test_agent_thread_serialize(self) -> None: + """Test that base AgentThread can be serialized.""" + thread = AgentThread() + + serialized = await thread.serialize() + + assert isinstance(serialized, dict) + assert "service_thread_id" in serialized + + async def test_agent_thread_deserialize(self) -> None: + """Test that base AgentThread can be deserialized.""" + thread = AgentThread() + serialized = await thread.serialize() + + restored = await AgentThread.deserialize(serialized) + + assert isinstance(restored, AgentThread) + assert restored.service_thread_id == thread.service_thread_id + + async def test_durable_thread_is_agent_thread(self) -> None: + """Test that DurableAgentThread is an AgentThread.""" + thread = DurableAgentThread() + + assert isinstance(thread, AgentThread) + assert isinstance(thread, DurableAgentThread) + + +class TestModelIntegration: + """Test suite for integration between models.""" + + def test_session_id_string_format(self) -> None: + """Test that AgentSessionId string format is consistent.""" + session_id = AgentSessionId.with_random_key("AgentEntity") + session_id_str = str(session_id) + + assert session_id_str.startswith("@AgentEntity@") + + async def test_thread_with_session_preserves_on_serialization(self) -> None: + """Test that thread with session ID preserves it through serialization.""" + session_id = AgentSessionId(name="TestAgent", key="preserved-key") + thread = DurableAgentThread.from_session_id(session_id) + + # Serialize and deserialize + serialized = await thread.serialize() + restored = await DurableAgentThread.deserialize(serialized) + + # Session ID should be preserved + assert restored.session_id is not None + assert restored.session_id.name == "TestAgent" + assert restored.session_id.key == "preserved-key" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_client.py b/python/packages/durabletask/tests/test_client.py new file mode 100644 index 0000000000..d6a13379a6 --- /dev/null +++ b/python/packages/durabletask/tests/test_client.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for DurableAIAgentClient. + +Focuses on critical client workflows: agent retrieval, protocol compliance, and integration. +Run with: pytest tests/test_client.py -v +""" + +from unittest.mock import Mock + +import pytest +from agent_framework import AgentProtocol + +from agent_framework_durabletask import DurableAgentThread, DurableAIAgentClient +from agent_framework_durabletask._shim import DurableAIAgent + + +@pytest.fixture +def mock_grpc_client() -> Mock: + """Create a mock TaskHubGrpcClient for testing.""" + return Mock() + + +@pytest.fixture +def agent_client(mock_grpc_client: Mock) -> DurableAIAgentClient: + """Create a DurableAIAgentClient with mock gRPC client.""" + return DurableAIAgentClient(mock_grpc_client) + + +class TestDurableAIAgentClientGetAgent: + """Test core workflow: retrieving agents from the client.""" + + def test_get_agent_returns_durable_agent_shim(self, agent_client: DurableAIAgentClient) -> None: + """Verify get_agent returns a DurableAIAgent instance.""" + agent = agent_client.get_agent("assistant") + + assert isinstance(agent, DurableAIAgent) + assert isinstance(agent, AgentProtocol) + + def test_get_agent_shim_has_correct_name(self, agent_client: DurableAIAgentClient) -> None: + """Verify retrieved agent has the correct name.""" + agent = agent_client.get_agent("my_agent") + + assert agent.name == "my_agent" + + def test_get_agent_multiple_times_returns_new_instances(self, agent_client: DurableAIAgentClient) -> None: + """Verify multiple get_agent calls return independent instances.""" + agent1 = agent_client.get_agent("assistant") + agent2 = agent_client.get_agent("assistant") + + assert agent1 is not agent2 # Different object instances + + def test_get_agent_different_agents(self, agent_client: DurableAIAgentClient) -> None: + """Verify client can retrieve multiple different agents.""" + agent1 = agent_client.get_agent("agent1") + agent2 = agent_client.get_agent("agent2") + + assert agent1.name == "agent1" + assert agent2.name == "agent2" + + +class TestDurableAIAgentClientIntegration: + """Test integration scenarios between client and agent shim.""" + + def test_client_agent_has_working_run_method(self, agent_client: DurableAIAgentClient) -> None: + """Verify agent from client has callable run method (even if not yet implemented).""" + agent = agent_client.get_agent("assistant") + + assert hasattr(agent, "run") + assert callable(agent.run) + + def test_client_agent_can_create_threads(self, agent_client: DurableAIAgentClient) -> None: + """Verify agent from client can create DurableAgentThread instances.""" + agent = agent_client.get_agent("assistant") + + thread = agent.get_new_thread() + + assert isinstance(thread, DurableAgentThread) + + def test_client_agent_thread_with_parameters(self, agent_client: DurableAIAgentClient) -> None: + """Verify agent can create threads with custom parameters.""" + agent = agent_client.get_agent("assistant") + + thread = agent.get_new_thread(service_thread_id="client-session-123") + + assert isinstance(thread, DurableAgentThread) + assert thread.service_thread_id == "client-session-123" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_executors.py b/python/packages/durabletask/tests/test_executors.py new file mode 100644 index 0000000000..ff6efe768e --- /dev/null +++ b/python/packages/durabletask/tests/test_executors.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for DurableAgentExecutor implementations. + +Focuses on critical behavioral flows for executor strategies. +Run with: pytest tests/test_executors.py -v +""" + +from unittest.mock import Mock + +import pytest + +from agent_framework_durabletask import DurableAgentThread +from agent_framework_durabletask._executors import ( + ClientAgentExecutor, + OrchestrationAgentExecutor, +) + + +class TestExecutorThreadCreation: + """Test that executors properly create DurableAgentThread with parameters.""" + + def test_client_executor_creates_durable_thread(self) -> None: + """Verify ClientAgentExecutor creates DurableAgentThread instances.""" + mock_client = Mock() + executor = ClientAgentExecutor(mock_client) + + thread = executor.get_new_thread("test_agent") + + assert isinstance(thread, DurableAgentThread) + + def test_client_executor_forwards_kwargs_to_thread(self) -> None: + """Verify ClientAgentExecutor forwards kwargs to DurableAgentThread creation.""" + mock_client = Mock() + executor = ClientAgentExecutor(mock_client) + + thread = executor.get_new_thread("test_agent", service_thread_id="client-123") + + assert isinstance(thread, DurableAgentThread) + assert thread.service_thread_id == "client-123" + + def test_orchestration_executor_creates_durable_thread(self) -> None: + """Verify OrchestrationAgentExecutor creates DurableAgentThread instances.""" + mock_context = Mock() + executor = OrchestrationAgentExecutor(mock_context) + + thread = executor.get_new_thread("test_agent") + + assert isinstance(thread, DurableAgentThread) + + def test_orchestration_executor_forwards_kwargs_to_thread(self) -> None: + """Verify OrchestrationAgentExecutor forwards kwargs to DurableAgentThread creation.""" + mock_context = Mock() + executor = OrchestrationAgentExecutor(mock_context) + + thread = executor.get_new_thread("test_agent", service_thread_id="orch-456") + + assert isinstance(thread, DurableAgentThread) + assert thread.service_thread_id == "orch-456" + + +class TestExecutorRunNotImplemented: + """Test that run_durable_agent raises NotImplementedError until wired.""" + + async def test_client_executor_run_not_implemented(self) -> None: + """Verify ClientAgentExecutor run raises NotImplementedError until implementation.""" + mock_client = Mock() + executor = ClientAgentExecutor(mock_client) + + with pytest.raises(NotImplementedError, match="ClientAgentProvider.run_durable_agent"): + await executor.run_durable_agent("test_agent", "test message") + + def test_orchestration_executor_run_not_implemented(self) -> None: + """Verify OrchestrationAgentExecutor run raises NotImplementedError until implementation.""" + mock_context = Mock() + executor = OrchestrationAgentExecutor(mock_context) + + with pytest.raises(NotImplementedError, match="OrchestrationAgentProvider.run_durable_agent"): + executor.run_durable_agent("test_agent", "test message") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_orchestration_context.py b/python/packages/durabletask/tests/test_orchestration_context.py new file mode 100644 index 0000000000..f6a7755335 --- /dev/null +++ b/python/packages/durabletask/tests/test_orchestration_context.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for DurableAIAgentOrchestrationContext. + +Focuses on critical orchestration workflows: agent retrieval and integration. +Run with: pytest tests/test_orchestration_context.py -v +""" + +from unittest.mock import Mock + +import pytest +from agent_framework import AgentProtocol + +from agent_framework_durabletask import DurableAgentThread +from agent_framework_durabletask._orchestration_context import DurableAIAgentOrchestrationContext +from agent_framework_durabletask._shim import DurableAIAgent + + +@pytest.fixture +def mock_orchestration_context() -> Mock: + """Create a mock OrchestrationContext for testing.""" + return Mock() + + +@pytest.fixture +def agent_context(mock_orchestration_context: Mock) -> DurableAIAgentOrchestrationContext: + """Create a DurableAIAgentOrchestrationContext with mock context.""" + return DurableAIAgentOrchestrationContext(mock_orchestration_context) + + +class TestDurableAIAgentOrchestrationContextGetAgent: + """Test core workflow: retrieving agents from orchestration context.""" + + def test_get_agent_returns_durable_agent_shim(self, agent_context: DurableAIAgentOrchestrationContext) -> None: + """Verify get_agent returns a DurableAIAgent instance.""" + agent = agent_context.get_agent("assistant") + + assert isinstance(agent, DurableAIAgent) + assert isinstance(agent, AgentProtocol) + + def test_get_agent_shim_has_correct_name(self, agent_context: DurableAIAgentOrchestrationContext) -> None: + """Verify retrieved agent has the correct name.""" + agent = agent_context.get_agent("my_agent") + + assert agent.name == "my_agent" + + def test_get_agent_multiple_times_returns_new_instances( + self, agent_context: DurableAIAgentOrchestrationContext + ) -> None: + """Verify multiple get_agent calls return independent instances.""" + agent1 = agent_context.get_agent("assistant") + agent2 = agent_context.get_agent("assistant") + + assert agent1 is not agent2 # Different object instances + + def test_get_agent_different_agents(self, agent_context: DurableAIAgentOrchestrationContext) -> None: + """Verify context can retrieve multiple different agents.""" + agent1 = agent_context.get_agent("agent1") + agent2 = agent_context.get_agent("agent2") + + assert agent1.name == "agent1" + assert agent2.name == "agent2" + + +class TestDurableAIAgentOrchestrationContextIntegration: + """Test integration scenarios between orchestration context and agent shim.""" + + def test_orchestration_agent_has_working_run_method( + self, agent_context: DurableAIAgentOrchestrationContext + ) -> None: + """Verify agent from context has callable run method (even if not yet implemented).""" + agent = agent_context.get_agent("assistant") + + assert hasattr(agent, "run") + assert callable(agent.run) + + def test_orchestration_agent_can_create_threads(self, agent_context: DurableAIAgentOrchestrationContext) -> None: + """Verify agent from context can create DurableAgentThread instances.""" + agent = agent_context.get_agent("assistant") + + thread = agent.get_new_thread() + + assert isinstance(thread, DurableAgentThread) + + def test_orchestration_agent_thread_with_parameters( + self, agent_context: DurableAIAgentOrchestrationContext + ) -> None: + """Verify agent can create threads with custom parameters.""" + agent = agent_context.get_agent("assistant") + + thread = agent.get_new_thread(service_thread_id="orch-session-456") + + assert isinstance(thread, DurableAgentThread) + assert thread.service_thread_id == "orch-session-456" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_shim.py b/python/packages/durabletask/tests/test_shim.py new file mode 100644 index 0000000000..79c8f641da --- /dev/null +++ b/python/packages/durabletask/tests/test_shim.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for DurableAIAgent shim and DurableAgentProvider. + +Focuses on critical message normalization, delegation, and protocol compliance. +Run with: pytest tests/test_shim.py -v +""" + +from unittest.mock import Mock + +import pytest +from agent_framework import AgentProtocol, ChatMessage +from pydantic import BaseModel + +from agent_framework_durabletask import DurableAgentThread +from agent_framework_durabletask._executors import DurableAgentExecutor +from agent_framework_durabletask._shim import DurableAgentProvider, DurableAIAgent + + +class ResponseFormatModel(BaseModel): + """Test Pydantic model for response format testing.""" + + result: str + + +@pytest.fixture +def mock_executor() -> Mock: + """Create a mock executor for testing.""" + mock = Mock(spec=DurableAgentExecutor) + mock.run_durable_agent = Mock(return_value=None) + mock.get_new_thread = Mock(return_value=DurableAgentThread()) + return mock + + +@pytest.fixture +def test_agent(mock_executor: Mock) -> DurableAIAgent: + """Create a test agent with mock executor.""" + return DurableAIAgent(mock_executor, "test_agent") + + +class TestDurableAIAgentMessageNormalization: + """Test that DurableAIAgent properly normalizes various message input types.""" + + def test_run_accepts_string_message(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + """Verify run accepts and normalizes string messages.""" + test_agent.run("Hello, world!") + + mock_executor.run_durable_agent.assert_called_once() + # Verify agent_name and message were passed correctly as kwargs + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["agent_name"] == "test_agent" + assert kwargs["message"] == "Hello, world!" + + def test_run_accepts_chat_message(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + """Verify run accepts and normalizes ChatMessage objects.""" + chat_msg = ChatMessage(role="user", text="Test message") + test_agent.run(chat_msg) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["message"] == "Test message" + + def test_run_accepts_list_of_strings(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + """Verify run accepts and joins list of strings.""" + test_agent.run(["First message", "Second message"]) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["message"] == "First message\nSecond message" + + def test_run_accepts_list_of_chat_messages(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + """Verify run accepts and joins list of ChatMessage objects.""" + messages = [ + ChatMessage(role="user", text="Message 1"), + ChatMessage(role="assistant", text="Message 2"), + ] + test_agent.run(messages) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["message"] == "Message 1\nMessage 2" + + def test_run_handles_none_message(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + """Verify run handles None message gracefully.""" + test_agent.run(None) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["message"] == "" + + def test_run_handles_empty_list(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + """Verify run handles empty list gracefully.""" + test_agent.run([]) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["message"] == "" + + +class TestDurableAIAgentParameterFlow: + """Test that parameters flow correctly through the shim to executor.""" + + def test_run_forwards_thread_parameter(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + """Verify run forwards thread parameter to executor.""" + thread = DurableAgentThread(service_thread_id="test-thread") + test_agent.run("message", thread=thread) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["thread"] == thread + + def test_run_forwards_response_format(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + """Verify run forwards response_format parameter to executor.""" + test_agent.run("message", response_format=ResponseFormatModel) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["response_format"] == ResponseFormatModel + + def test_run_forwards_additional_kwargs(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + """Verify run forwards additional kwargs to executor.""" + test_agent.run("message", custom_param="custom_value") + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["custom_param"] == "custom_value" + + +class TestDurableAIAgentProtocolCompliance: + """Test that DurableAIAgent implements AgentProtocol correctly.""" + + def test_agent_implements_protocol(self, test_agent: DurableAIAgent) -> None: + """Verify DurableAIAgent implements AgentProtocol.""" + assert isinstance(test_agent, AgentProtocol) + + def test_agent_has_required_properties(self, test_agent: DurableAIAgent) -> None: + """Verify DurableAIAgent has all required AgentProtocol properties.""" + assert hasattr(test_agent, "id") + assert hasattr(test_agent, "name") + assert hasattr(test_agent, "display_name") + assert hasattr(test_agent, "description") + + def test_agent_id_defaults_to_name(self, mock_executor: Mock) -> None: + """Verify agent id defaults to name when not provided.""" + agent = DurableAIAgent(mock_executor, "my_agent") + + assert agent.id == "my_agent" + assert agent.name == "my_agent" + + def test_agent_id_can_be_customized(self, mock_executor: Mock) -> None: + """Verify agent id can be set independently from name.""" + agent = DurableAIAgent(mock_executor, "my_agent", agent_id="custom-id") + + assert agent.id == "custom-id" + assert agent.name == "my_agent" + + +class TestDurableAIAgentThreadManagement: + """Test thread creation and management.""" + + def test_get_new_thread_delegates_to_executor(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + """Verify get_new_thread delegates to executor.""" + mock_thread = DurableAgentThread() + mock_executor.get_new_thread.return_value = mock_thread + + thread = test_agent.get_new_thread() + + mock_executor.get_new_thread.assert_called_once_with("test_agent") + assert thread == mock_thread + + def test_get_new_thread_forwards_kwargs(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + """Verify get_new_thread forwards kwargs to executor.""" + mock_thread = DurableAgentThread(service_thread_id="thread-123") + mock_executor.get_new_thread.return_value = mock_thread + + test_agent.get_new_thread(service_thread_id="thread-123") + + mock_executor.get_new_thread.assert_called_once() + _, kwargs = mock_executor.get_new_thread.call_args + assert kwargs["service_thread_id"] == "thread-123" + + +class TestDurableAgentProviderInterface: + """Test that DurableAgentProvider defines the correct interface.""" + + def test_provider_cannot_be_instantiated(self) -> None: + """Verify DurableAgentProvider is abstract and cannot be instantiated.""" + with pytest.raises(TypeError): + DurableAgentProvider() # type: ignore[abstract] + + def test_provider_defines_get_agent_method(self) -> None: + """Verify DurableAgentProvider defines get_agent abstract method.""" + assert hasattr(DurableAgentProvider, "get_agent") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_worker.py b/python/packages/durabletask/tests/test_worker.py new file mode 100644 index 0000000000..e6dabcdfdf --- /dev/null +++ b/python/packages/durabletask/tests/test_worker.py @@ -0,0 +1,168 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for DurableAIAgentWorker. + +Focuses on critical worker flows: agent registration, validation, callbacks, and lifecycle. +""" + +from unittest.mock import Mock + +import pytest + +from agent_framework_durabletask import DurableAIAgentWorker + + +@pytest.fixture +def mock_grpc_worker() -> Mock: + """Create a mock TaskHubGrpcWorker for testing.""" + mock = Mock() + mock.add_entity = Mock(return_value="dafx-test_agent") + mock.start = Mock() + mock.stop = Mock() + return mock + + +@pytest.fixture +def mock_agent() -> Mock: + """Create a mock agent for testing.""" + agent = Mock() + agent.name = "test_agent" + return agent + + +@pytest.fixture +def agent_worker(mock_grpc_worker: Mock) -> DurableAIAgentWorker: + """Create a DurableAIAgentWorker with mock worker.""" + return DurableAIAgentWorker(mock_grpc_worker) + + +class TestDurableAIAgentWorkerRegistration: + """Test agent registration behavior.""" + + def test_add_agent_accepts_agent_with_name( + self, agent_worker: DurableAIAgentWorker, mock_agent: Mock, mock_grpc_worker: Mock + ) -> None: + """Verify that agents with names can be registered.""" + agent_worker.add_agent(mock_agent) + + # Verify entity was registered with underlying worker + mock_grpc_worker.add_entity.assert_called_once() + # Verify agent name is tracked + assert "test_agent" in agent_worker.registered_agent_names + + def test_add_agent_rejects_agent_without_name(self, agent_worker: DurableAIAgentWorker) -> None: + """Verify that agents without names are rejected.""" + agent_no_name = Mock() + agent_no_name.name = None + + with pytest.raises(ValueError, match="Agent must have a name"): + agent_worker.add_agent(agent_no_name) + + def test_add_agent_rejects_empty_name(self, agent_worker: DurableAIAgentWorker) -> None: + """Verify that agents with empty names are rejected.""" + agent_empty_name = Mock() + agent_empty_name.name = "" + + with pytest.raises(ValueError, match="Agent must have a name"): + agent_worker.add_agent(agent_empty_name) + + def test_add_agent_rejects_duplicate_names(self, agent_worker: DurableAIAgentWorker, mock_agent: Mock) -> None: + """Verify duplicate agent names are not allowed.""" + agent_worker.add_agent(mock_agent) + + # Try to register another agent with the same name + duplicate_agent = Mock() + duplicate_agent.name = "test_agent" + + with pytest.raises(ValueError, match="already registered"): + agent_worker.add_agent(duplicate_agent) + + def test_registered_agent_names_tracks_multiple_agents(self, agent_worker: DurableAIAgentWorker) -> None: + """Verify registered_agent_names tracks all registered agents.""" + agent1 = Mock() + agent1.name = "agent1" + agent2 = Mock() + agent2.name = "agent2" + agent3 = Mock() + agent3.name = "agent3" + + agent_worker.add_agent(agent1) + agent_worker.add_agent(agent2) + agent_worker.add_agent(agent3) + + registered = agent_worker.registered_agent_names + assert "agent1" in registered + assert "agent2" in registered + assert "agent3" in registered + assert len(registered) == 3 + + +class TestDurableAIAgentWorkerCallbacks: + """Test callback configuration behavior.""" + + def test_worker_level_callback_accepted(self, mock_grpc_worker: Mock) -> None: + """Verify worker-level callback can be set.""" + mock_callback = Mock() + agent_worker = DurableAIAgentWorker(mock_grpc_worker, callback=mock_callback) + + assert agent_worker is not None + + def test_agent_level_callback_accepted(self, agent_worker: DurableAIAgentWorker, mock_agent: Mock) -> None: + """Verify agent-level callback can be set during registration.""" + mock_callback = Mock() + + # Should not raise exception + agent_worker.add_agent(mock_agent, callback=mock_callback) + + assert "test_agent" in agent_worker.registered_agent_names + + def test_none_callback_accepted(self, mock_grpc_worker: Mock, mock_agent: Mock) -> None: + """Verify None callback is valid (no callbacks required).""" + agent_worker = DurableAIAgentWorker(mock_grpc_worker, callback=None) + agent_worker.add_agent(mock_agent, callback=None) + + assert "test_agent" in agent_worker.registered_agent_names + + +class TestDurableAIAgentWorkerLifecycle: + """Test worker lifecycle behavior.""" + + def test_start_delegates_to_underlying_worker( + self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock + ) -> None: + """Verify start() delegates to wrapped worker.""" + agent_worker.start() + + mock_grpc_worker.start.assert_called_once() + + def test_stop_delegates_to_underlying_worker( + self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock + ) -> None: + """Verify stop() delegates to wrapped worker.""" + agent_worker.stop() + + mock_grpc_worker.stop.assert_called_once() + + def test_start_works_with_no_agents(self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock) -> None: + """Verify worker can start even with no agents registered.""" + agent_worker.start() + + mock_grpc_worker.start.assert_called_once() + + def test_start_works_with_multiple_agents(self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock) -> None: + """Verify worker can start with multiple agents registered.""" + agent1 = Mock() + agent1.name = "agent1" + agent2 = Mock() + agent2.name = "agent2" + + agent_worker.add_agent(agent1) + agent_worker.add_agent(agent2) + agent_worker.start() + + mock_grpc_worker.start.assert_called_once() + assert len(agent_worker.registered_agent_names) == 2 + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) From 6e94ca8a9964145f56b83e178f7bd12985751a03 Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Mon, 29 Dec 2025 12:46:52 -0800 Subject: [PATCH 02/11] Clean code and refactor common code --- .../agent_framework_azurefunctions/_app.py | 6 +- .../_orchestration.py | 50 ++------- .../agent_framework_durabletask/__init__.py | 6 +- .../agent_framework_durabletask/_executors.py | 106 ++++++++++++++++-- .../_response_utils.py | 66 +++++++++++ .../agent_framework_durabletask/_shim.py | 49 +++++--- .../function_app.py | 9 +- .../function_app.py | 9 +- .../function_app.py | 14 +-- .../function_app.py | 22 ++-- 10 files changed, 228 insertions(+), 109 deletions(-) create mode 100644 python/packages/durabletask/agent_framework_durabletask/_response_utils.py diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 2f4df67e58..49d3cb5251 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -36,7 +36,7 @@ from ._entities import create_agent_entity from ._errors import IncomingRequestError -from ._orchestration import AgentOrchestrationContextType, AzureFunctionsAgentExecutor +from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor logger = get_logger("agent_framework.azurefunctions") @@ -297,7 +297,7 @@ def get_agent( self, context: AgentOrchestrationContextType, agent_name: str, - ) -> DurableAIAgent: + ) -> DurableAIAgent[AgentTask]: """Return a DurableAIAgent proxy for a registered agent. Args: @@ -308,7 +308,7 @@ def get_agent( ValueError: If the requested agent has not been registered. Returns: - DurableAIAgent wrapper bound to the orchestration context. + DurableAIAgent[AgentTask] wrapper bound to the orchestration context. """ normalized_name = str(agent_name) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index 5e03bd934d..2d4b424b1e 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -9,12 +9,15 @@ from typing import TYPE_CHECKING, Any, TypeAlias import azure.durable_functions as df -from agent_framework import ( - AgentRunResponse, - AgentThread, - get_logger, +from agent_framework import AgentThread, get_logger +from agent_framework_durabletask import ( + AgentSessionId, + DurableAgentExecutor, + DurableAgentThread, + RunRequest, + ensure_response_format, + load_agent_response, ) -from agent_framework_durabletask import AgentSessionId, DurableAgentExecutor, DurableAgentThread, RunRequest from azure.durable_functions.models import TaskBase from azure.durable_functions.models.Task import CompoundTask, TaskState from pydantic import BaseModel @@ -91,10 +94,10 @@ def try_set_value(self, child: TaskBase) -> None: ) try: - response = self._load_agent_response(raw_result) + response = load_agent_response(raw_result) if self._response_format is not None: - self._ensure_response_format( + ensure_response_format( self._response_format, self._correlation_id, response, @@ -114,39 +117,8 @@ def try_set_value(self, child: TaskBase) -> None: self._first_error = child.result self.set_value(is_error=True, value=self._first_error) - def _load_agent_response(self, agent_response: AgentRunResponse | dict[str, Any] | None) -> AgentRunResponse: - """Convert raw payloads into AgentRunResponse instance.""" - if agent_response is None: - raise ValueError("agent_response cannot be None") - logger.debug("[load_agent_response] Loading agent response of type: %s", type(agent_response)) - - if isinstance(agent_response, AgentRunResponse): - return agent_response - if isinstance(agent_response, dict): - logger.debug("[load_agent_response] Converting dict payload using AgentRunResponse.from_dict") - return AgentRunResponse.from_dict(agent_response) - - raise TypeError(f"Unsupported type for agent_response: {type(agent_response)}") - - def _ensure_response_format( - self, - response_format: type[BaseModel] | None, - correlation_id: str, - response: AgentRunResponse, - ) -> None: - """Ensure the AgentRunResponse value is parsed into the expected response_format.""" - if response_format is not None and not isinstance(response.value, response_format): - response.try_parse_value(response_format) - - logger.debug( - "[DurableAIAgent] Loaded AgentRunResponse.value for correlation_id %s with type: %s", - correlation_id, - type(response.value).__name__, - ) - - -class AzureFunctionsAgentExecutor(DurableAgentExecutor): +class AzureFunctionsAgentExecutor(DurableAgentExecutor[AgentTask]): """Executor that executes durable agents inside Azure Functions orchestrations.""" def __init__(self, context: AgentOrchestrationContextType): diff --git a/python/packages/durabletask/agent_framework_durabletask/__init__.py b/python/packages/durabletask/agent_framework_durabletask/__init__.py index 10d880f001..3283cf8959 100644 --- a/python/packages/durabletask/agent_framework_durabletask/__init__.py +++ b/python/packages/durabletask/agent_framework_durabletask/__init__.py @@ -43,8 +43,9 @@ ) from ._entities import AgentEntity, AgentEntityStateProviderMixin from ._executors import DurableAgentExecutor -from ._models import AgentSessionId, DurableAgentThread, RunRequest, serialize_response_format +from ._models import AgentSessionId, DurableAgentThread, RunRequest from ._orchestration_context import DurableAIAgentOrchestrationContext +from ._response_utils import ensure_response_format, load_agent_response from ._shim import DurableAIAgent from ._worker import DurableAIAgentWorker @@ -95,5 +96,6 @@ "DurableAgentThread", "DurableStateFields", "RunRequest", - "serialize_response_format", + "ensure_response_format", + "load_agent_response", ] diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py index d165166c60..e6061716fc 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_executors.py +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -11,12 +11,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic, TypeVar -from agent_framework import AgentThread, get_logger +from agent_framework import AgentRunResponse, AgentThread, get_logger +from durabletask.task import CompositeTask, Task from pydantic import BaseModel from ._models import DurableAgentThread +from ._response_utils import ensure_response_format, load_agent_response if TYPE_CHECKING: # pragma: no cover from durabletask.client import TaskHubGrpcClient @@ -24,9 +26,91 @@ logger = get_logger("agent_framework.durabletask.executors") +# TypeVar for the task type returned by executors +TaskT = TypeVar("TaskT") -class DurableAgentExecutor(ABC): - """Abstract base class for durable agent execution strategies.""" + +class DurableAgentTask(CompositeTask[AgentRunResponse]): + """A custom Task that wraps entity calls and provides typed AgentRunResponse results. + + This task wraps the underlying entity call task and intercepts its completion + to convert the raw result into a typed AgentRunResponse object. + """ + + def __init__( + self, + entity_task: Task[AgentRunResponse], + response_format: type[BaseModel] | None, + correlation_id: str, + ): + """Initialize the DurableAgentTask. + + Args: + entity_task: The underlying entity call task + response_format: Optional Pydantic model for response parsing + correlation_id: Correlation ID for logging + """ + super().__init__([entity_task]) # type: ignore[misc] + self._response_format = response_format + self._correlation_id = correlation_id + + def on_child_completed(self, task: Task[Any]) -> None: + """Handle completion of the underlying entity task. + + Parameters + ---------- + task : Task + The entity call task that just completed + """ + if self.is_complete: + return + + if task.is_failed: + # Propagate the failure + self._exception = task.get_exception() + self._is_complete = True + if self._parent is not None: + self._parent.on_child_completed(self) + return + + # Task succeeded - transform the raw result + raw_result = task.get_result() + logger.debug( + "[DurableAgentTask] Converting raw result for correlation_id %s", + self._correlation_id, + ) + + try: + response = load_agent_response(raw_result) + + if self._response_format is not None: + ensure_response_format( + self._response_format, + self._correlation_id, + response, + ) + + # Set the typed AgentRunResponse as this task's result + self._result = response + self._is_complete = True + + if self._parent is not None: + self._parent.on_child_completed(self) + + except Exception: + logger.exception( + "[DurableAgentTask] Failed to convert result for correlation_id: %s", + self._correlation_id, + ) + raise + + +class DurableAgentExecutor(ABC, Generic[TaskT]): + """Abstract base class for durable agent execution strategies. + + Type Parameters: + TaskT: The task type returned by this executor + """ @abstractmethod def run_durable_agent( @@ -37,11 +121,11 @@ def run_durable_agent( thread: AgentThread | None = None, response_format: type[BaseModel] | None = None, **kwargs: Any, - ) -> Any: + ) -> TaskT: """Execute the durable agent. Returns: - Any: Either an awaitable AgentRunResponse (Client) or a yieldable Task (Orchestrator). + TaskT: The task type specific to this executor implementation """ raise NotImplementedError @@ -51,14 +135,14 @@ def get_new_thread(self, agent_name: str, **kwargs: Any) -> DurableAgentThread: raise NotImplementedError -class ClientAgentExecutor(DurableAgentExecutor): +class ClientAgentExecutor(DurableAgentExecutor[DurableAgentTask]): """Execution strategy for external clients (async).""" def __init__(self, client: TaskHubGrpcClient): self._client = client logger.debug("[ClientAgentExecutor] Initialized with client type: %s", type(client).__name__) - async def run_durable_agent( + def run_durable_agent( self, agent_name: str, message: str, @@ -66,7 +150,7 @@ async def run_durable_agent( thread: AgentThread | None = None, response_format: type[BaseModel] | None = None, **kwargs: Any, - ) -> Any: + ) -> DurableAgentTask: """Execute the agent via the durabletask client. Note: Implementation is backend-specific and should signal/call the entity @@ -80,7 +164,7 @@ def get_new_thread(self, agent_name: str, **kwargs: Any) -> DurableAgentThread: return DurableAgentThread(**kwargs) -class OrchestrationAgentExecutor(DurableAgentExecutor): +class OrchestrationAgentExecutor(DurableAgentExecutor[DurableAgentTask]): """Execution strategy for orchestrations (sync/yield).""" def __init__(self, context: OrchestrationContext): @@ -95,7 +179,7 @@ def run_durable_agent( thread: AgentThread | None = None, response_format: type[BaseModel] | None = None, **kwargs: Any, - ) -> Any: + ) -> DurableAgentTask: """Execute the agent via orchestration context. Note: Implementation should call the entity (e.g., context.call_entity) diff --git a/python/packages/durabletask/agent_framework_durabletask/_response_utils.py b/python/packages/durabletask/agent_framework_durabletask/_response_utils.py new file mode 100644 index 0000000000..ec7ed50c0c --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_response_utils.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Shared utilities for handling AgentRunResponse parsing and validation. + +These utilities are used by both DurableAgentTask and AzureFunctions AgentTask +to ensure consistent response handling across different durable task implementations. +""" + +from typing import Any + +from agent_framework import AgentRunResponse, get_logger +from pydantic import BaseModel + +logger = get_logger("agent_framework.durabletask.response_utils") + + +def load_agent_response(agent_response: AgentRunResponse | dict[str, Any] | None) -> AgentRunResponse: + """Convert raw payloads into AgentRunResponse instance. + + Args: + agent_response: The response to convert, can be an AgentRunResponse, dict, or None + + Returns: + AgentRunResponse: The converted response object + + Raises: + ValueError: If agent_response is None + TypeError: If agent_response is an unsupported type + """ + if agent_response is None: + raise ValueError("agent_response cannot be None") + + logger.debug("[load_agent_response] Loading agent response of type: %s", type(agent_response)) + + if isinstance(agent_response, AgentRunResponse): + return agent_response + if isinstance(agent_response, dict): + logger.debug("[load_agent_response] Converting dict payload using AgentRunResponse.from_dict") + return AgentRunResponse.from_dict(agent_response) + + raise TypeError(f"Unsupported type for agent_response: {type(agent_response)}") + + +def ensure_response_format( + response_format: type[BaseModel] | None, + correlation_id: str, + response: AgentRunResponse, +) -> None: + """Ensure the AgentRunResponse value is parsed into the expected response_format. + + This function modifies the response in-place by parsing its value attribute + into the specified Pydantic model format. + + Args: + response_format: Optional Pydantic model class to parse the response value into + correlation_id: Correlation ID for logging purposes + response: The AgentRunResponse object to validate and parse + """ + if response_format is not None and not isinstance(response.value, response_format): + response.try_parse_value(response_format) + + logger.debug( + "[ensure_response_format] Loaded AgentRunResponse.value for correlation_id %s with type: %s", + correlation_id, + type(response.value).__name__, + ) diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index 696bddf611..3b4cdcc95f 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -11,7 +11,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic, TypeVar from agent_framework import AgentProtocol, AgentRunResponseUpdate, AgentThread, ChatMessage from pydantic import BaseModel @@ -19,8 +19,12 @@ if TYPE_CHECKING: from ._executors import DurableAgentExecutor +# TypeVar for the task type returned by executors +# Covariant because TaskT only appears in return positions (output) +TaskT = TypeVar("TaskT", covariant=True) -class DurableAgentProvider(ABC): + +class DurableAgentProvider(ABC, Generic[TaskT]): """Abstract provider for constructing durable agent proxies. Implemented by context-specific wrappers (client/orchestration) to return a @@ -30,7 +34,7 @@ class DurableAgentProvider(ABC): """ @abstractmethod - def get_agent(self, agent_name: str) -> DurableAIAgent: + def get_agent(self, agent_name: str) -> DurableAIAgent[TaskT]: """Retrieve a DurableAIAgent shim for the specified agent. Args: @@ -45,17 +49,21 @@ def get_agent(self, agent_name: str) -> DurableAIAgent: raise NotImplementedError("Subclasses must implement get_agent()") -class DurableAIAgent(AgentProtocol): +class DurableAIAgent(AgentProtocol, Generic[TaskT]): """A durable agent proxy that delegates execution to the provider. - This class implements AgentProtocol but doesn't contain any agent logic itself. - Instead, it serves as a consistent interface that delegates to the underlying - provider, which can be either: - - DurableAIAgentClient (for external usage via HTTP/gRPC) - - DurableAIAgentOrchestrationContext (for use inside orchestrations) + This class implements AgentProtocol but with one critical difference: + - AgentProtocol.run() returns a Coroutine (async, must await) + - DurableAIAgent.run() returns TaskT (sync Task object, must yield) + + This represents fundamentally different execution models but maintains the same + interface contract for all other properties and methods. + + The underlying provider determines how execution occurs (entity calls, HTTP requests, etc.) + and what type of Task object is returned. - The provider determines how execution occurs (entity calls, HTTP requests, etc.) - and what type of Task object is returned (asyncio.Task vs durabletask.Task). + Type Parameters: + TaskT: The task type returned by this agent (e.g., DurableAgentTask, AgentTask) Note: This class intentionally does NOT inherit from BaseAgent because: @@ -64,7 +72,7 @@ class DurableAIAgent(AgentProtocol): - BaseAgent methods like as_tool() would fail in orchestrations """ - def __init__(self, executor: DurableAgentExecutor, name: str, *, agent_id: str | None = None): + def __init__(self, executor: DurableAgentExecutor[TaskT], name: str, *, agent_id: str | None = None): """Initialize the shim with a provider and agent name. Args: @@ -98,17 +106,27 @@ def description(self) -> str | None: """Get the description of the agent.""" return self._description - def run( + def run( # pyright: ignore[reportIncompatibleMethodOverride] self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: AgentThread | None = None, response_format: type[BaseModel] | None = None, + enable_tool_calls: bool | None = None, **kwargs: Any, - ) -> Any: + ) -> TaskT: """Execute the agent via the injected provider. - The provider determines whether the return is awaitable (client) or yieldable (orchestration). + Note: + This method overrides AgentProtocol.run() with a different return type: + - AgentProtocol.run() returns Coroutine[Any, Any, AgentRunResponse] (async) + - DurableAIAgent.run() returns TaskT (Task object for yielding) + + This is intentional to support orchestration contexts that use yield patterns + instead of async/await patterns. + + Returns: + TaskT: The task type specific to the executor (e.g., DurableAgentTask or AgentTask) """ message_str = self._normalize_messages(messages) return self._executor.run_durable_agent( @@ -116,6 +134,7 @@ def run( message=message_str, thread=thread, response_format=response_format, + enable_tool_calls=enable_tool_calls, **kwargs, ) diff --git a/python/samples/getting_started/azure_functions/04_single_agent_orchestration_chaining/function_app.py b/python/samples/getting_started/azure_functions/04_single_agent_orchestration_chaining/function_app.py index cc05a323f3..32b2fab131 100644 --- a/python/samples/getting_started/azure_functions/04_single_agent_orchestration_chaining/function_app.py +++ b/python/samples/getting_started/azure_functions/04_single_agent_orchestration_chaining/function_app.py @@ -10,6 +10,7 @@ import json import logging +from collections.abc import Generator from typing import Any import azure.functions as func @@ -44,7 +45,7 @@ def _create_writer_agent() -> Any: # 4. Orchestration that runs the agent sequentially on a shared thread for chaining behaviour. @app.orchestration_trigger(context_name="context") -def single_agent_orchestration(context: DurableOrchestrationContext): +def single_agent_orchestration(context: DurableOrchestrationContext) -> Generator[Any, Any, str]: """Run the writer agent twice on the same thread to mirror chaining behaviour.""" writer = app.get_agent(context, WRITER_AGENT_NAME) @@ -116,12 +117,6 @@ async def get_orchestration_status( ) status = await client.get_status(instance_id) - if status is None: - return func.HttpResponse( - body=json.dumps({"error": "Instance not found"}), - status_code=404, - mimetype="application/json", - ) response_data: dict[str, Any] = { "instanceId": status.instance_id, diff --git a/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py b/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py index 69ea8816b2..41fb3f08b2 100644 --- a/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py +++ b/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py @@ -10,6 +10,7 @@ import json import logging +from collections.abc import Generator from typing import Any, cast from agent_framework import AgentRunResponse @@ -51,7 +52,7 @@ def _create_agents() -> list[Any]: # 4. Durable Functions orchestration that runs both agents in parallel. @app.orchestration_trigger(context_name="context") -def multi_agent_concurrent_orchestration(context: DurableOrchestrationContext): +def multi_agent_concurrent_orchestration(context: DurableOrchestrationContext) -> Generator[Any, Any, dict[str, str]]: """Fan out to two domain-specific agents and aggregate their responses.""" prompt = context.get_input() @@ -137,12 +138,6 @@ async def get_orchestration_status( ) status = await client.get_status(instance_id) - if status is None: - return func.HttpResponse( - body=json.dumps({"error": "Instance not found"}), - status_code=404, - mimetype="application/json", - ) response_data: dict[str, Any] = { "instanceId": status.instance_id, diff --git a/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py b/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py index 2ee445d423..8ef5ef7211 100644 --- a/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py +++ b/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py @@ -11,7 +11,7 @@ import json import logging -from collections.abc import Mapping +from collections.abc import Generator, Mapping from typing import Any, cast import azure.functions as func @@ -74,7 +74,7 @@ def send_email(message: str) -> str: # 4. Orchestration validates input, runs agents, and branches on spam results. @app.orchestration_trigger(context_name="context") -def spam_detection_orchestration(context: DurableOrchestrationContext): +def spam_detection_orchestration(context: DurableOrchestrationContext) -> Generator[Any, Any, str]: payload_raw = context.get_input() if not isinstance(payload_raw, Mapping): raise ValueError("Email data is required") @@ -105,7 +105,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext): spam_result = cast(SpamDetectionResult, spam_result_raw.value) if spam_result.is_spam: - result = yield context.call_activity("handle_spam_email", spam_result.reason) + result = yield context.call_activity("handle_spam_email", spam_result.reason) # type: ignore[misc] return result email_thread = email_agent.get_new_thread() @@ -125,7 +125,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext): email_result = cast(EmailResponse, email_result_raw.value) - result = yield context.call_activity("send_email", email_result.response) + result = yield context.call_activity("send_email", email_result.response) # type: ignore[misc] return result @@ -196,12 +196,6 @@ async def get_orchestration_status( ) status = await client.get_status(instance_id) - if status is None: - return func.HttpResponse( - body=json.dumps({"error": "Instance not found"}), - status_code=404, - mimetype="application/json", - ) response_data: dict[str, Any] = { "instanceId": status.instance_id, diff --git a/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py b/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py index c8e2bbaa9c..5c79bb4a86 100644 --- a/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py +++ b/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py @@ -10,7 +10,7 @@ import json import logging -from collections.abc import Mapping +from collections.abc import Generator, Mapping from datetime import timedelta from typing import Any @@ -62,7 +62,7 @@ def _create_writer_agent() -> Any: # 3. Activities encapsulate external work for review notifications and publishing. @app.activity_trigger(input_name="content") -def notify_user_for_approval(content: dict) -> None: +def notify_user_for_approval(content: dict[str, Any]) -> None: model = GeneratedContent.model_validate(content) logger.info("NOTIFICATION: Please review the following content for approval:") logger.info("Title: %s", model.title or "(untitled)") @@ -71,7 +71,7 @@ def notify_user_for_approval(content: dict) -> None: @app.activity_trigger(input_name="content") -def publish_content(content: dict) -> None: +def publish_content(content: dict[str, Any]) -> None: model = GeneratedContent.model_validate(content) logger.info("PUBLISHING: Content has been published successfully:") logger.info("Title: %s", model.title or "(untitled)") @@ -80,7 +80,7 @@ def publish_content(content: dict) -> None: # 4. Orchestration loops until the human approves, times out, or attempts are exhausted. @app.orchestration_trigger(context_name="context") -def content_generation_hitl_orchestration(context: DurableOrchestrationContext): +def content_generation_hitl_orchestration(context: DurableOrchestrationContext) -> Generator[Any, Any, dict[str, str]]: payload_raw = context.get_input() if not isinstance(payload_raw, Mapping): raise ValueError("Content generation input is required") @@ -102,7 +102,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext): ) content = initial_raw.value - logger.info("Type of content after extraction: %s", type(content)) + logger.info("Type of content after extraction: %s", type(content)) # type: ignore[misc] if content is None or not isinstance(content, GeneratedContent): raise ValueError("Agent returned no content after extraction.") @@ -114,7 +114,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext): f"Requesting human feedback. Iteration #{attempt}. Timeout: {payload.approval_timeout_hours} hour(s)." ) - yield context.call_activity("notify_user_for_approval", content.model_dump()) + yield context.call_activity("notify_user_for_approval", content.model_dump()) # type: ignore[misc] approval_task = context.wait_for_external_event(HUMAN_APPROVAL_EVENT) timeout_task = context.create_timer( @@ -129,7 +129,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext): if approval_payload.approved: context.set_custom_status("Content approved by human reviewer. Publishing content...") - yield context.call_activity("publish_content", content.model_dump()) + yield context.call_activity("publish_content", content.model_dump()) # type: ignore[misc] context.set_custom_status( f"Content published successfully at {context.current_utc_datetime:%Y-%m-%dT%H:%M:%S}" ) @@ -286,14 +286,6 @@ async def get_orchestration_status( show_input=True, ) - # Check if status is None or if the instance doesn't exist (runtime_status is None) - if status is None or getattr(status, "runtime_status", None) is None: - return func.HttpResponse( - body=json.dumps({"error": "Instance not found."}), - status_code=404, - mimetype="application/json", - ) - response_data: dict[str, Any] = { "instanceId": getattr(status, "instance_id", None), "runtimeStatus": getattr(status.runtime_status, "name", None) From 0cd3bdf65c4caa6e9486ddc29e8a97e181a3e60a Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Tue, 30 Dec 2025 12:47:19 -0800 Subject: [PATCH 03/11] Implement sample --- .gitignore | 3 + .../agent_framework_azurefunctions/_app.py | 22 +- .../_orchestration.py | 63 ++- .../agent_framework_durabletask/_client.py | 25 +- .../_durable_agent_state.py | 27 +- .../agent_framework_durabletask/_entities.py | 4 +- .../agent_framework_durabletask/_executors.py | 302 ++++++++++++-- .../agent_framework_durabletask/_models.py | 32 +- .../_orchestration_context.py | 6 +- .../agent_framework_durabletask/_shim.py | 23 +- .../agent_framework_durabletask/_worker.py | 23 +- .../packages/durabletask/tests/test_client.py | 51 +++ .../tests/test_durable_agent_state.py | 11 +- .../durabletask/tests/test_executors.py | 96 ++++- .../packages/durabletask/tests/test_models.py | 42 +- .../packages/durabletask/tests/test_shim.py | 71 ++-- .../durabletask/01_single_agent/README.md | 387 ++++++++++++++++++ .../durabletask/01_single_agent/client.py | 92 +++++ .../01_single_agent/requirements.txt | 9 + .../durabletask/01_single_agent/worker.py | 89 ++++ 20 files changed, 1198 insertions(+), 180 deletions(-) create mode 100644 python/samples/getting_started/durabletask/01_single_agent/README.md create mode 100644 python/samples/getting_started/durabletask/01_single_agent/client.py create mode 100644 python/samples/getting_started/durabletask/01_single_agent/requirements.txt create mode 100644 python/samples/getting_started/durabletask/01_single_agent/worker.py diff --git a/.gitignore b/.gitignore index f0f8c09495..6d6f957b2a 100644 --- a/.gitignore +++ b/.gitignore @@ -226,3 +226,6 @@ local.settings.json # Database files *.db + +# Sample files +**/sample.py diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 49d3cb5251..fc0a480be3 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -8,6 +8,7 @@ import json import re +import uuid from collections.abc import Callable, Mapping from dataclasses import dataclass from datetime import datetime, timezone @@ -29,6 +30,7 @@ WAIT_FOR_RESPONSE_HEADER, AgentResponseCallbackProtocol, AgentSessionId, + ApiResponseFields, DurableAgentState, DurableAIAgent, RunRequest, @@ -416,7 +418,6 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien run_request = self._build_request_data( req_body, message, - thread_id, correlation_id, request_response_format, ) @@ -639,7 +640,6 @@ async def _handle_mcp_tool_invocation( run_request = self._build_request_data( req_body={"message": query, "role": "user"}, message=query, - thread_id=str(session_id), correlation_id=correlation_id, request_response_format=REQUEST_RESPONSE_FORMAT_TEXT, ) @@ -790,8 +790,9 @@ async def _poll_entity_for_response( agent_response = state.try_get_agent_response(correlation_id) if agent_response: + response_message = "\n".join(message.text for message in agent_response.messages if message.text) result = self._build_success_result( - response_data=agent_response, + response_message=response_message, message=message, thread_id=thread_id, correlation_id=correlation_id, @@ -837,23 +838,22 @@ async def _build_timeout_result(self, message: str, thread_id: str, correlation_ ) def _build_success_result( - self, response_data: dict[str, Any], message: str, thread_id: str, correlation_id: str, state: DurableAgentState + self, response_message: str, message: str, thread_id: str, correlation_id: str, state: DurableAgentState ) -> dict[str, Any]: """Build the success result returned to the HTTP caller.""" return self._build_response_payload( - response=response_data.get("content"), + response=response_message, message=message, thread_id=thread_id, status="success", correlation_id=correlation_id, - extra_fields={"message_count": response_data.get("message_count", state.message_count)}, + extra_fields={ApiResponseFields.MESSAGE_COUNT: state.message_count}, ) def _build_request_data( self, req_body: dict[str, Any], message: str, - thread_id: str, correlation_id: str, request_response_format: str, ) -> dict[str, Any]: @@ -920,15 +920,13 @@ def _convert_payload_to_text(self, payload: dict[str, Any]) -> str: def _generate_unique_id(self) -> str: """Generate a new unique identifier.""" - import uuid - return uuid.uuid4().hex - def _create_session_id(self, func_name: str, thread_id: str | None) -> AgentSessionId: + def _create_session_id(self, agent_name: str, thread_id: str | None) -> AgentSessionId: """Create a session identifier using the provided thread id or a random value.""" if thread_id: - return AgentSessionId(name=func_name, key=thread_id) - return AgentSessionId.with_random_key(name=func_name) + return AgentSessionId(name=agent_name, key=thread_id) + return AgentSessionId.with_random_key(name=agent_name) def _resolve_thread_id(self, req: func.HttpRequest, req_body: dict[str, Any]) -> str: """Retrieve the thread identifier from request body or query parameters.""" diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index 2d4b424b1e..d28a2540db 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -11,9 +11,7 @@ import azure.durable_functions as df from agent_framework import AgentThread, get_logger from agent_framework_durabletask import ( - AgentSessionId, DurableAgentExecutor, - DurableAgentThread, RunRequest, ensure_response_format, load_agent_response, @@ -124,58 +122,53 @@ class AzureFunctionsAgentExecutor(DurableAgentExecutor[AgentTask]): def __init__(self, context: AgentOrchestrationContextType): self.context = context + def _generate_unique_id(self) -> str: + return str(self.context.new_uuid()) + + def get_run_request( + self, + message: str, + response_format: type[BaseModel] | None, + enable_tool_calls: bool, + ) -> RunRequest: + """Get the current run request from the orchestration context. + + Returns: + RunRequest: The current run request + """ + request = super().get_run_request( + message, + response_format, + enable_tool_calls, + ) + request.orchestration_id = self.context.instance_id + return request + def run_durable_agent( self, agent_name: str, - message: str, + run_request: RunRequest, thread: AgentThread | None = None, - response_format: type[BaseModel] | None = None, - enable_tool_calls: bool | None = None, - **kwargs: Any, ) -> AgentTask: - # Extract optional parameters - enable_tools = True if enable_tool_calls is None else enable_tool_calls # Resolve session - if isinstance(thread, DurableAgentThread) and thread.session_id is not None: - session_id = thread.session_id - else: - session_key = str(self.context.new_uuid()) - session_id = AgentSessionId(name=agent_name, key=session_key) - logger.debug( - "[AzureFunctionsAgentProvider] No thread provided, created session_id: %s", - session_id, - ) + session_id = self._create_session_id(agent_name, thread) entity_id = df.EntityId( name=session_id.entity_name, key=session_id.key, ) - correlation_id = str(self.context.new_uuid()) + logger.debug( "[AzureFunctionsAgentProvider] correlation_id: %s entity_id: %s session_id: %s", - correlation_id, + run_request.correlation_id, entity_id, session_id, ) - run_request = RunRequest( - message=message, - enable_tool_calls=enable_tools, - correlation_id=correlation_id, - response_format=response_format, - orchestration_id=self.context.instance_id, - created_at=self.context.current_utc_datetime, - ) - entity_task = self.context.call_entity(entity_id, "run", run_request.to_dict()) return AgentTask( entity_task=entity_task, - response_format=response_format, - correlation_id=correlation_id, + response_format=run_request.response_format, + correlation_id=run_request.correlation_id, ) - - def get_new_thread(self, agent_name: str, **kwargs: Any) -> DurableAgentThread: - session_key = str(self.context.new_uuid()) - session_id = AgentSessionId(name=agent_name, key=session_key) - return DurableAgentThread.from_session_id(session_id, **kwargs) diff --git a/python/packages/durabletask/agent_framework_durabletask/_client.py b/python/packages/durabletask/agent_framework_durabletask/_client.py index e26c9d6ba5..8c16e67445 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_client.py +++ b/python/packages/durabletask/agent_framework_durabletask/_client.py @@ -8,16 +8,17 @@ from __future__ import annotations -from agent_framework import get_logger +from agent_framework import AgentRunResponse, get_logger from durabletask.client import TaskHubGrpcClient +from ._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS from ._executors import ClientAgentExecutor from ._shim import DurableAgentProvider, DurableAIAgent logger = get_logger("agent_framework.durabletask.client") -class DurableAIAgentClient(DurableAgentProvider): +class DurableAIAgentClient(DurableAgentProvider[AgentRunResponse]): """Client wrapper for interacting with durable agents externally. This class wraps a durabletask TaskHubGrpcClient and provides a convenient @@ -44,17 +45,31 @@ class DurableAIAgentClient(DurableAgentProvider): ``` """ - def __init__(self, client: TaskHubGrpcClient): + def __init__( + self, + client: TaskHubGrpcClient, + max_poll_retries: int = DEFAULT_MAX_POLL_RETRIES, + poll_interval_seconds: float = DEFAULT_POLL_INTERVAL_SECONDS, + ): """Initialize the client wrapper. Args: client: The durabletask client instance to wrap + max_poll_retries: Maximum polling attempts when waiting for responses + poll_interval_seconds: Delay in seconds between polling attempts """ self._client = client - self._executor = ClientAgentExecutor(self._client) + + # Validate and set polling parameters + self.max_poll_retries = max(1, max_poll_retries) + self.poll_interval_seconds = ( + poll_interval_seconds if poll_interval_seconds > 0 else DEFAULT_POLL_INTERVAL_SECONDS + ) + + self._executor = ClientAgentExecutor(self._client, self.max_poll_retries, self.poll_interval_seconds) logger.debug("[DurableAIAgentClient] Initialized with client type: %s", type(client).__name__) - def get_agent(self, agent_name: str) -> DurableAIAgent: + def get_agent(self, agent_name: str) -> DurableAIAgent[AgentRunResponse]: """Retrieve a DurableAIAgent shim for the specified agent. This method returns a proxy object that can be used to execute the agent. diff --git a/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py b/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py index a42a6fad43..3e5ced3ad6 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py +++ b/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py @@ -53,7 +53,7 @@ ) from dateutil import parser as date_parser -from ._constants import ApiResponseFields, ContentTypes, DurableStateFields +from ._constants import ContentTypes, DurableStateFields from ._models import RunRequest, serialize_response_format logger = get_logger("agent_framework.durabletask.durable_agent_state") @@ -452,7 +452,7 @@ def message_count(self) -> int: """Get the count of conversation entries (requests + responses).""" return len(self.data.conversation_history) - def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None: + def try_get_agent_response(self, correlation_id: str) -> AgentRunResponse | None: """Try to get an agent response by correlation ID. This method searches the conversation history for a response entry matching the given @@ -474,14 +474,8 @@ def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None: for entry in self.data.conversation_history: if entry.correlation_id == correlation_id and isinstance(entry, DurableAgentStateResponse): # Found the entry, extract response data - # Get the text content from assistant messages only - content = "\n".join(message.text for message in entry.messages if message.text) + return DurableAgentStateResponse.to_run_response(entry) - return { - ApiResponseFields.CONTENT: content, - ApiResponseFields.MESSAGE_COUNT: self.message_count, - ApiResponseFields.CORRELATION_ID: correlation_id, - } return None @@ -705,6 +699,21 @@ def from_run_response(correlation_id: str, response: AgentRunResponse) -> Durabl usage=DurableAgentStateUsage.from_usage(response.usage_details), ) + @staticmethod + def to_run_response( + response_entry: DurableAgentStateResponse, + ) -> AgentRunResponse: + """Converts a DurableAgentStateResponse back to an AgentRunResponse.""" + messages = [m.to_chat_message() for m in response_entry.messages] + + usage_details = response_entry.usage.to_usage_details() if response_entry.usage is not None else UsageDetails() + + return AgentRunResponse( + created_at=response_entry.created_at.isoformat(), + messages=messages, + usage_details=usage_details, + ) + class DurableAgentStateMessage: """Represents a message within a conversation history entry. diff --git a/python/packages/durabletask/agent_framework_durabletask/_entities.py b/python/packages/durabletask/agent_framework_durabletask/_entities.py index 2f7bb8a62c..2e0b429233 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_entities.py +++ b/python/packages/durabletask/agent_framework_durabletask/_entities.py @@ -128,7 +128,7 @@ async def run( ) -> AgentRunResponse: """Execute the agent with a message.""" if isinstance(request, str): - run_request = RunRequest(message=request, role=Role.USER) + run_request = RunRequest.from_json(request) elif isinstance(request, dict): run_request = RunRequest.from_dict(request) else: @@ -139,8 +139,6 @@ async def run( correlation_id = run_request.correlation_id if not thread_id: raise ValueError("Entity State Provider must provide a thread_id") - if not correlation_id: - raise ValueError("RunRequest must include a correlation_id") response_format = run_request.response_format enable_tool_calls = run_request.enable_tool_calls diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py index e6061716fc..9d4d15845f 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_executors.py +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -10,20 +10,23 @@ from __future__ import annotations +import time +import uuid from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from datetime import datetime, timezone +from typing import Any, Generic, TypeVar -from agent_framework import AgentRunResponse, AgentThread, get_logger -from durabletask.task import CompositeTask, Task +from agent_framework import AgentRunResponse, AgentThread, ChatMessage, ErrorContent, Role, get_logger +from durabletask.client import TaskHubGrpcClient +from durabletask.entities import EntityInstanceId +from durabletask.task import CompositeTask, OrchestrationContext, Task from pydantic import BaseModel -from ._models import DurableAgentThread +from ._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS +from ._durable_agent_state import DurableAgentState +from ._models import AgentSessionId, DurableAgentThread, RunRequest from ._response_utils import ensure_response_format, load_agent_response -if TYPE_CHECKING: # pragma: no cover - from durabletask.client import TaskHubGrpcClient - from durabletask.task import OrchestrationContext - logger = get_logger("agent_framework.durabletask.executors") # TypeVar for the task type returned by executors @@ -116,11 +119,8 @@ class DurableAgentExecutor(ABC, Generic[TaskT]): def run_durable_agent( self, agent_name: str, - message: str, - *, + run_request: RunRequest, thread: AgentThread | None = None, - response_format: type[BaseModel] | None = None, - **kwargs: Any, ) -> TaskT: """Execute the durable agent. @@ -129,39 +129,270 @@ def run_durable_agent( """ raise NotImplementedError - @abstractmethod def get_new_thread(self, agent_name: str, **kwargs: Any) -> DurableAgentThread: - """Create a new thread appropriate for the provider context.""" - raise NotImplementedError + """Create a new DurableAgentThread with random session ID.""" + session_id = self._create_session_id(agent_name) + return DurableAgentThread.from_session_id(session_id, **kwargs) + + def _create_session_id( + self, + agent_name: str, + thread: AgentThread | None = None, + ) -> AgentSessionId: + """Resolve or create the AgentSessionId for the execution.""" + if isinstance(thread, DurableAgentThread) and thread.session_id is not None: + return thread.session_id + # Create new session ID - either no thread provided or it's a regular AgentThread + key = self.generate_unique_id() + return AgentSessionId(name=agent_name, key=key) + + def generate_unique_id(self) -> str: + """Generate a new Unique ID.""" + return uuid.uuid4().hex + + def get_run_request( + self, + message: str, + response_format: type[BaseModel] | None, + enable_tool_calls: bool, + ) -> RunRequest: + """Create a RunRequest for the given parameters.""" + correlation_id = self.generate_unique_id() + return RunRequest( + message=message, + response_format=response_format, + enable_tool_calls=enable_tool_calls, + correlation_id=correlation_id, + ) -class ClientAgentExecutor(DurableAgentExecutor[DurableAgentTask]): - """Execution strategy for external clients (async).""" +class ClientAgentExecutor(DurableAgentExecutor[AgentRunResponse]): + """Execution strategy for external clients. - def __init__(self, client: TaskHubGrpcClient): + Note: Returns AgentRunResponse directly since the execution + is blocking until response is available via polling + as per the design of TaskHubGrpcClient. + """ + + def __init__( + self, + client: TaskHubGrpcClient, + max_poll_retries: int = DEFAULT_MAX_POLL_RETRIES, + poll_interval_seconds: float = DEFAULT_POLL_INTERVAL_SECONDS, + ): self._client = client - logger.debug("[ClientAgentExecutor] Initialized with client type: %s", type(client).__name__) + self.max_poll_retries = max_poll_retries + self.poll_interval_seconds = poll_interval_seconds def run_durable_agent( self, agent_name: str, - message: str, - *, + run_request: RunRequest, thread: AgentThread | None = None, - response_format: type[BaseModel] | None = None, - **kwargs: Any, - ) -> DurableAgentTask: + ) -> AgentRunResponse: """Execute the agent via the durabletask client. - Note: Implementation is backend-specific and should signal/call the entity - and await the durable response. This placeholder raises NotImplementedError - until wired to concrete durabletask calls. + Signals the agent entity with a message request, then polls the entity + state to retrieve the response once processing is complete. + + Note: This is a blocking/synchronous operation (in line with how + TaskHubGrpcClient works) that polls until a response is available or + timeout occurs. + + Args: + agent_name: Name of the agent to execute + run_request: The run request containing message and optional response format + thread: Optional conversation thread (creates new if not provided) + + Returns: + AgentRunResponse: The agent's response after execution completes """ - raise NotImplementedError("ClientAgentProvider.run_durable_agent is not yet implemented") + # Signal the entity with the request + entity_id = self._signal_agent_entity(agent_name, run_request, thread) - def get_new_thread(self, agent_name: str, **kwargs: Any) -> DurableAgentThread: - """Create a new AgentThread for client-side execution.""" - return DurableAgentThread(**kwargs) + # Poll for the response + agent_response = self._poll_for_agent_response(entity_id, run_request.correlation_id) + + # Handle and return the result + return self._handle_agent_response(agent_response, run_request.response_format, run_request.correlation_id) + + def _signal_agent_entity( + self, + agent_name: str, + run_request: RunRequest, + thread: AgentThread | None, + ) -> EntityInstanceId: + """Signal the agent entity with a run request. + + Args: + agent_name: Name of the agent to execute + run_request: The run request containing message and optional response format + thread: Optional conversation thread + + Returns: + entity_id + """ + # Get or create session ID + session_id = self._create_session_id(agent_name, thread) + + # Create the entity ID + entity_id = EntityInstanceId( + entity=session_id.entity_name, + key=session_id.key, + ) + + logger.debug( + "[ClientAgentExecutor] Signaling entity '%s' (session: %s, correlation: %s)", + agent_name, + session_id, + run_request.correlation_id, + ) + + self._client.signal_entity(entity_id, "run", run_request.to_dict()) + + logger.info( + "[ClientAgentExecutor] Signaled entity '%s' for correlation: %s", + agent_name, + run_request.correlation_id, + ) + + return entity_id + + def _poll_for_agent_response( + self, + entity_id: EntityInstanceId, + correlation_id: str, + ) -> AgentRunResponse | None: + """Poll the entity for a response with retries. + + Args: + entity_id: Entity instance identifier + correlation_id: Correlation ID to track the request + + Returns: + The agent response if found, None if timeout occurs + """ + agent_response = None + + for attempt in range(1, self.max_poll_retries + 1): + time.sleep(self.poll_interval_seconds) + + agent_response = self._poll_entity_for_response(entity_id, correlation_id) + if agent_response is not None: + logger.info( + "[ClientAgentExecutor] Found response (attempt %d/%d, correlation: %s)", + attempt, + self.max_poll_retries, + correlation_id, + ) + break + + logger.debug( + "[ClientAgentExecutor] Response not ready (attempt %d/%d)", + attempt, + self.max_poll_retries, + ) + + return agent_response + + def _handle_agent_response( + self, + agent_response: AgentRunResponse | None, + response_format: type[BaseModel] | None, + correlation_id: str, + ) -> AgentRunResponse: + """Handle the agent response or create an error response. + + Args: + agent_response: The response from polling, or None if timeout + response_format: Optional response format for validation + correlation_id: Correlation ID for logging + + Returns: + AgentRunResponse with either the agent's response or an error message + """ + if agent_response is not None: + try: + # Validate response format if specified + if response_format is not None: + ensure_response_format( + response_format, + correlation_id, + agent_response, + ) + + return agent_response + + except Exception as e: + logger.exception( + "[ClientAgentExecutor] Error converting response for correlation: %s", + correlation_id, + ) + error_message = ChatMessage( + role=Role.SYSTEM, + contents=[ + ErrorContent( + message=f"Error processing agent response: {e}", + error_code="response_processing_error", + ) + ], + ) + else: + logger.warning( + "[ClientAgentExecutor] Timeout after %d attempts (correlation: %s)", + self.max_poll_retries, + correlation_id, + ) + error_message = ChatMessage( + role=Role.SYSTEM, + contents=[ + ErrorContent( + message=f"Timeout waiting for agent response after {self.max_poll_retries} attempts", + error_code="response_timeout", + ) + ], + ) + + return AgentRunResponse( + messages=[error_message], + created_at=datetime.now(timezone.utc).isoformat(), + ) + + def _poll_entity_for_response( + self, + entity_id: EntityInstanceId, + correlation_id: str, + ) -> AgentRunResponse | None: + """Poll the entity state for a response matching the correlation ID. + + Args: + entity_id: Entity instance identifier + correlation_id: Correlation ID to search for + + Returns: + Response data dict if found, None otherwise + """ + try: + entity_metadata = self._client.get_entity(entity_id, include_state=True) + + if entity_metadata is None or not entity_metadata.includes_state: + return None + + state_json = entity_metadata.get_state() + if not state_json: + return None + + state = DurableAgentState.from_json(state_json) + + # Use the helper method to get response by correlation ID + return state.try_get_agent_response(correlation_id) + + except Exception as e: + logger.warning( + "[ClientAgentExecutor] Error reading entity state: %s", + e, + ) + return None class OrchestrationAgentExecutor(DurableAgentExecutor[DurableAgentTask]): @@ -174,11 +405,8 @@ def __init__(self, context: OrchestrationContext): def run_durable_agent( self, agent_name: str, - message: str, - *, + run_request: RunRequest, thread: AgentThread | None = None, - response_format: type[BaseModel] | None = None, - **kwargs: Any, ) -> DurableAgentTask: """Execute the agent via orchestration context. @@ -186,7 +414,3 @@ def run_durable_agent( and return the native Task for yielding. Placeholder until wired. """ raise NotImplementedError("OrchestrationAgentProvider.run_durable_agent is not yet implemented") - - def get_new_thread(self, agent_name: str, **kwargs: Any) -> DurableAgentThread: - """Create a new AgentThread for orchestration context.""" - return DurableAgentThread(**kwargs) diff --git a/python/packages/durabletask/agent_framework_durabletask/_models.py b/python/packages/durabletask/agent_framework_durabletask/_models.py index 4317b87832..947ab7a17f 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_models.py +++ b/python/packages/durabletask/agent_framework_durabletask/_models.py @@ -8,10 +8,11 @@ from __future__ import annotations import inspect +import json import uuid from collections.abc import MutableMapping from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, timezone from importlib import import_module from typing import TYPE_CHECKING, Any, cast @@ -103,38 +104,38 @@ class RunRequest: role: The role of the message sender (user, system, or assistant) response_format: Optional Pydantic BaseModel type describing the structured response format enable_tool_calls: Whether to enable tool calls for this request - correlation_id: Optional correlation ID for tracking the response to this specific request + correlation_id: Correlation ID for tracking the response to this specific request created_at: Optional timestamp when the request was created orchestration_id: Optional ID of the orchestration that initiated this request """ message: str request_response_format: str + correlation_id: str role: Role = Role.USER response_format: type[BaseModel] | None = None enable_tool_calls: bool = True - correlation_id: str | None = None created_at: datetime | None = None orchestration_id: str | None = None def __init__( self, message: str, + correlation_id: str, request_response_format: str = REQUEST_RESPONSE_FORMAT_TEXT, role: Role | str | None = Role.USER, response_format: type[BaseModel] | None = None, enable_tool_calls: bool = True, - correlation_id: str | None = None, created_at: datetime | None = None, orchestration_id: str | None = None, ) -> None: self.message = message + self.correlation_id = correlation_id self.role = self.coerce_role(role) self.response_format = response_format self.request_response_format = request_response_format self.enable_tool_calls = enable_tool_calls - self.correlation_id = correlation_id - self.created_at = created_at + self.created_at = created_at if created_at is not None else datetime.now(tz=timezone.utc) self.orchestration_id = orchestration_id @staticmethod @@ -156,11 +157,10 @@ def to_dict(self) -> dict[str, Any]: "enable_tool_calls": self.enable_tool_calls, "role": self.role.value, "request_response_format": self.request_response_format, + "correlationId": self.correlation_id, } if self.response_format: result["response_format"] = serialize_response_format(self.response_format) - if self.correlation_id: - result["correlationId"] = self.correlation_id if self.created_at: result["created_at"] = self.created_at.isoformat() if self.orchestration_id: @@ -168,6 +168,16 @@ def to_dict(self) -> dict[str, Any]: return result + @classmethod + def from_json(cls, data: str) -> RunRequest: + """Create RunRequest from JSON string.""" + try: + dict_data = json.loads(data) + except json.JSONDecodeError as e: + raise ValueError("The durable agent state is not valid JSON.") from e + + return cls.from_dict(dict_data) + @classmethod def from_dict(cls, data: dict[str, Any]) -> RunRequest: """Create RunRequest from dictionary.""" @@ -178,13 +188,17 @@ def from_dict(cls, data: dict[str, Any]) -> RunRequest: except ValueError: created_at = None + correlation_id = data.get("correlationId") + if not correlation_id: + raise ValueError("correlationId is required in RunRequest data") + return cls( message=data.get("message", ""), + correlation_id=correlation_id, request_response_format=data.get("request_response_format", REQUEST_RESPONSE_FORMAT_TEXT), role=cls.coerce_role(data.get("role")), response_format=_deserialize_response_format(data.get("response_format")), enable_tool_calls=data.get("enable_tool_calls", True), - correlation_id=data.get("correlationId"), created_at=created_at, orchestration_id=data.get("orchestrationId"), ) diff --git a/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py b/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py index 442d247831..d9a7ae3c02 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py +++ b/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py @@ -11,13 +11,13 @@ from agent_framework import get_logger from durabletask.task import OrchestrationContext -from ._executors import OrchestrationAgentExecutor +from ._executors import DurableAgentTask, OrchestrationAgentExecutor from ._shim import DurableAgentProvider, DurableAIAgent logger = get_logger("agent_framework.durabletask.orchestration_context") -class DurableAIAgentOrchestrationContext(DurableAgentProvider): +class DurableAIAgentOrchestrationContext(DurableAgentProvider[DurableAgentTask]): """Orchestration context wrapper for interacting with durable agents internally. This class wraps a durabletask OrchestrationContext and provides a convenient @@ -55,7 +55,7 @@ def __init__(self, context: OrchestrationContext): self._executor = OrchestrationAgentExecutor(self._context) logger.debug("[DurableAIAgentOrchestrationContext] Initialized") - def get_agent(self, agent_name: str) -> DurableAIAgent: + def get_agent(self, agent_name: str) -> DurableAIAgent[DurableAgentTask]: """Retrieve a DurableAIAgent shim for the specified agent. This method returns a proxy object that can be used to execute the agent diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index 3b4cdcc95f..e469011753 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -11,13 +11,13 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import Any, Generic, TypeVar from agent_framework import AgentProtocol, AgentRunResponseUpdate, AgentThread, ChatMessage from pydantic import BaseModel -if TYPE_CHECKING: - from ._executors import DurableAgentExecutor +from ._executors import DurableAgentExecutor +from ._models import DurableAgentThread # TypeVar for the task type returned by executors # Covariant because TaskT only appears in return positions (output) @@ -112,8 +112,7 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride] *, thread: AgentThread | None = None, response_format: type[BaseModel] | None = None, - enable_tool_calls: bool | None = None, - **kwargs: Any, + enable_tool_calls: bool = True, ) -> TaskT: """Execute the agent via the injected provider. @@ -129,13 +128,17 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride] TaskT: The task type specific to the executor (e.g., DurableAgentTask or AgentTask) """ message_str = self._normalize_messages(messages) - return self._executor.run_durable_agent( - agent_name=self._name, + + run_request = self._executor.get_run_request( message=message_str, - thread=thread, response_format=response_format, enable_tool_calls=enable_tool_calls, - **kwargs, + ) + + return self._executor.run_durable_agent( + agent_name=self._name, + run_request=run_request, + thread=thread, ) def run_stream( @@ -157,7 +160,7 @@ def run_stream( """ raise NotImplementedError("Streaming is not supported for durable agents") - def get_new_thread(self, **kwargs: Any) -> AgentThread: + def get_new_thread(self, **kwargs: Any) -> DurableAgentThread: """Create a new agent thread via the provider.""" return self._executor.get_new_thread(self._name, **kwargs) diff --git a/python/packages/durabletask/agent_framework_durabletask/_worker.py b/python/packages/durabletask/agent_framework_durabletask/_worker.py index aa2c90b3b4..fea4b8ba7c 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_worker.py +++ b/python/packages/durabletask/agent_framework_durabletask/_worker.py @@ -8,6 +8,7 @@ from __future__ import annotations +import asyncio from typing import Any from agent_framework import AgentProtocol, get_logger @@ -173,7 +174,7 @@ def __init__(self) -> None: entity_name, ) - async def run(self, request: Any) -> Any: + def run(self, request: Any) -> Any: """Handle run requests from clients or orchestrations. Args: @@ -183,7 +184,25 @@ async def run(self, request: Any) -> Any: AgentRunResponse as dict """ logger.debug("[ConfiguredAgentEntity.run] Executing agent: %s", agent_name) - response = await self._agent_entity.run(request) + # Get or create event loop for async execution + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Run the async agent execution synchronously + if loop.is_running(): + # If loop is already running (shouldn't happen in entity context), + # create a temporary loop + temp_loop = asyncio.new_event_loop() + try: + response = temp_loop.run_until_complete(self._agent_entity.run(request)) + finally: + temp_loop.close() + else: + response = loop.run_until_complete(self._agent_entity.run(request)) + return response.to_dict() def reset(self) -> None: diff --git a/python/packages/durabletask/tests/test_client.py b/python/packages/durabletask/tests/test_client.py index d6a13379a6..cf2ccfe1af 100644 --- a/python/packages/durabletask/tests/test_client.py +++ b/python/packages/durabletask/tests/test_client.py @@ -12,6 +12,7 @@ from agent_framework import AgentProtocol from agent_framework_durabletask import DurableAgentThread, DurableAIAgentClient +from agent_framework_durabletask._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS from agent_framework_durabletask._shim import DurableAIAgent @@ -27,6 +28,16 @@ def agent_client(mock_grpc_client: Mock) -> DurableAIAgentClient: return DurableAIAgentClient(mock_grpc_client) +@pytest.fixture +def agent_client_with_custom_polling(mock_grpc_client: Mock) -> DurableAIAgentClient: + """Create a DurableAIAgentClient with custom polling parameters.""" + return DurableAIAgentClient( + mock_grpc_client, + max_poll_retries=15, + poll_interval_seconds=0.5, + ) + + class TestDurableAIAgentClientGetAgent: """Test core workflow: retrieving agents from the client.""" @@ -87,5 +98,45 @@ def test_client_agent_thread_with_parameters(self, agent_client: DurableAIAgentC assert thread.service_thread_id == "client-session-123" +class TestDurableAIAgentClientPollingConfiguration: + """Test polling configuration parameters for DurableAIAgentClient.""" + + def test_client_uses_default_polling_parameters(self, agent_client: DurableAIAgentClient) -> None: + """Verify client initializes with default polling parameters.""" + assert agent_client.max_poll_retries == DEFAULT_MAX_POLL_RETRIES + assert agent_client.poll_interval_seconds == DEFAULT_POLL_INTERVAL_SECONDS + + def test_client_accepts_custom_polling_parameters( + self, agent_client_with_custom_polling: DurableAIAgentClient + ) -> None: + """Verify client accepts and stores custom polling parameters.""" + assert agent_client_with_custom_polling.max_poll_retries == 15 + assert agent_client_with_custom_polling.poll_interval_seconds == 0.5 + + def test_client_validates_max_poll_retries(self, mock_grpc_client: Mock) -> None: + """Verify client validates and normalizes max_poll_retries.""" + # Test with zero - should enforce minimum of 1 + client = DurableAIAgentClient(mock_grpc_client, max_poll_retries=0) + assert client.max_poll_retries == 1 + + # Test with negative - should enforce minimum of 1 + client = DurableAIAgentClient(mock_grpc_client, max_poll_retries=-5) + assert client.max_poll_retries == 1 + + def test_client_validates_poll_interval_seconds(self, mock_grpc_client: Mock) -> None: + """Verify client validates and normalizes poll_interval_seconds.""" + # Test with zero - should use default + client = DurableAIAgentClient(mock_grpc_client, poll_interval_seconds=0) + assert client.poll_interval_seconds == DEFAULT_POLL_INTERVAL_SECONDS + + # Test with negative - should use default + client = DurableAIAgentClient(mock_grpc_client, poll_interval_seconds=-0.5) + assert client.poll_interval_seconds == DEFAULT_POLL_INTERVAL_SECONDS + + # Test with valid float + client = DurableAIAgentClient(mock_grpc_client, poll_interval_seconds=2.5) + assert client.poll_interval_seconds == 2.5 + + if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_durable_agent_state.py b/python/packages/durabletask/tests/test_durable_agent_state.py index 94e29f29a5..ec9a2f0cc9 100644 --- a/python/packages/durabletask/tests/test_durable_agent_state.py +++ b/python/packages/durabletask/tests/test_durable_agent_state.py @@ -112,20 +112,21 @@ class TestDurableAgentStateMessageCreatedAt: """Test suite for DurableAgentStateMessage created_at field handling.""" def test_message_from_run_request_without_created_at_preserves_none(self) -> None: - """Test from_run_request preserves None created_at instead of defaulting to current time. + """Test from_run_request handles auto-populated created_at from RunRequest. - When a RunRequest has no created_at value, the resulting DurableAgentStateMessage - should also have None for created_at, not default to current UTC time. + When a RunRequest is created with None for created_at, RunRequest defaults it to + current UTC time. The resulting DurableAgentStateMessage should have this timestamp. """ run_request = RunRequest( message="test message", correlation_id="corr-run", - created_at=None, # Explicitly None + created_at=None, # RunRequest will default this to current time ) durable_message = DurableAgentStateMessage.from_run_request(run_request) - assert durable_message.created_at is None + # RunRequest auto-populates created_at, so it should not be None + assert durable_message.created_at is not None def test_message_from_run_request_with_created_at_parses_correctly(self) -> None: """Test from_run_request correctly parses a valid created_at timestamp.""" diff --git a/python/packages/durabletask/tests/test_executors.py b/python/packages/durabletask/tests/test_executors.py index ff6efe768e..8b425962b0 100644 --- a/python/packages/durabletask/tests/test_executors.py +++ b/python/packages/durabletask/tests/test_executors.py @@ -6,15 +6,19 @@ Run with: pytest tests/test_executors.py -v """ +import time from unittest.mock import Mock import pytest +from agent_framework import AgentRunResponse from agent_framework_durabletask import DurableAgentThread +from agent_framework_durabletask._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS from agent_framework_durabletask._executors import ( ClientAgentExecutor, OrchestrationAgentExecutor, ) +from agent_framework_durabletask._models import RunRequest class TestExecutorThreadCreation: @@ -60,23 +64,101 @@ def test_orchestration_executor_forwards_kwargs_to_thread(self) -> None: class TestExecutorRunNotImplemented: - """Test that run_durable_agent raises NotImplementedError until wired.""" + """Test that run_durable_agent works as implemented.""" - async def test_client_executor_run_not_implemented(self) -> None: - """Verify ClientAgentExecutor run raises NotImplementedError until implementation.""" + def test_client_executor_run_returns_response(self) -> None: + """Verify ClientAgentExecutor.run_durable_agent returns AgentRunResponse (synchronous).""" mock_client = Mock() - executor = ClientAgentExecutor(mock_client) + mock_client.signal_entity = Mock() + mock_client.get_entity = Mock(return_value=None) + # Use minimal polling parameters to avoid long test times + executor = ClientAgentExecutor(mock_client, max_poll_retries=1, poll_interval_seconds=0.01) + + # Create a RunRequest + run_request = RunRequest(message="test message", correlation_id="test-123") - with pytest.raises(NotImplementedError, match="ClientAgentProvider.run_durable_agent"): - await executor.run_durable_agent("test_agent", "test message") + # This should return a timeout response (since mock doesn't have state) + result = executor.run_durable_agent("test_agent", run_request) + + # Verify it returns an AgentRunResponse (synchronous, not a coroutine) + assert isinstance(result, AgentRunResponse) + assert result is not None def test_orchestration_executor_run_not_implemented(self) -> None: """Verify OrchestrationAgentExecutor run raises NotImplementedError until implementation.""" mock_context = Mock() executor = OrchestrationAgentExecutor(mock_context) + # Create a RunRequest + run_request = RunRequest(message="test message", correlation_id="test-123") + with pytest.raises(NotImplementedError, match="OrchestrationAgentProvider.run_durable_agent"): - executor.run_durable_agent("test_agent", "test message") + executor.run_durable_agent("test_agent", run_request) + + +class TestClientAgentExecutorPollingConfiguration: + """Test polling configuration parameters for ClientAgentExecutor.""" + + def test_executor_uses_default_polling_parameters(self) -> None: + """Verify executor initializes with default polling parameters.""" + mock_client = Mock() + executor = ClientAgentExecutor(mock_client) + + assert executor.max_poll_retries == DEFAULT_MAX_POLL_RETRIES + assert executor.poll_interval_seconds == DEFAULT_POLL_INTERVAL_SECONDS + + def test_executor_accepts_custom_polling_parameters(self) -> None: + """Verify executor accepts and stores custom polling parameters.""" + mock_client = Mock() + executor = ClientAgentExecutor(mock_client, max_poll_retries=20, poll_interval_seconds=0.5) + + assert executor.max_poll_retries == 20 + assert executor.poll_interval_seconds == 0.5 + + def test_executor_respects_custom_max_poll_retries(self) -> None: + """Verify executor respects custom max_poll_retries during polling.""" + + mock_client = Mock() + mock_client.signal_entity = Mock() + mock_client.get_entity = Mock(return_value=None) + + # Create executor with only 2 retries + executor = ClientAgentExecutor(mock_client, max_poll_retries=2, poll_interval_seconds=0.01) + + # Create a RunRequest + run_request = RunRequest(message="test message", correlation_id="test-123") + + # Run the agent + result = executor.run_durable_agent("test_agent", run_request) + + # Verify it returns AgentRunResponse (should timeout after 2 attempts) + assert isinstance(result, AgentRunResponse) + + # Verify get_entity was called 2 times (max_poll_retries) + assert mock_client.get_entity.call_count == 2 + + def test_executor_respects_custom_poll_interval(self) -> None: + """Verify executor respects custom poll_interval_seconds during polling.""" + + mock_client = Mock() + mock_client.signal_entity = Mock() + mock_client.get_entity = Mock(return_value=None) + + # Create executor with very short interval + executor = ClientAgentExecutor(mock_client, max_poll_retries=3, poll_interval_seconds=0.01) + + # Create a RunRequest + run_request = RunRequest(message="test message", correlation_id="test-123") + + # Measure time taken + start = time.time() + result = executor.run_durable_agent("test_agent", run_request) + elapsed = time.time() - start + + # Should take roughly 3 * 0.01 = 0.03 seconds (plus overhead) + # Be generous with timing to avoid flakiness + assert elapsed < 0.2 # Should be quick with 0.01 interval + assert isinstance(result, AgentRunResponse) if __name__ == "__main__": diff --git a/python/packages/durabletask/tests/test_models.py b/python/packages/durabletask/tests/test_models.py index f14bffcaf8..ffcfe868e1 100644 --- a/python/packages/durabletask/tests/test_models.py +++ b/python/packages/durabletask/tests/test_models.py @@ -18,9 +18,10 @@ class TestRunRequest: def test_init_with_defaults(self) -> None: """Test RunRequest initialization with defaults.""" - request = RunRequest(message="Hello") + request = RunRequest(message="Hello", correlation_id="corr-001") assert request.message == "Hello" + assert request.correlation_id == "corr-001" assert request.role == Role.USER assert request.response_format is None assert request.enable_tool_calls is True @@ -30,30 +31,33 @@ def test_init_with_all_fields(self) -> None: schema = ModuleStructuredResponse request = RunRequest( message="Hello", + correlation_id="corr-002", role=Role.SYSTEM, response_format=schema, enable_tool_calls=False, ) assert request.message == "Hello" + assert request.correlation_id == "corr-002" assert request.role == Role.SYSTEM assert request.response_format is schema assert request.enable_tool_calls is False def test_init_coerces_string_role(self) -> None: """Ensure string role values are coerced into Role instances.""" - request = RunRequest(message="Hello", role="system") # type: ignore[arg-type] + request = RunRequest(message="Hello", correlation_id="corr-003", role="system") # type: ignore[arg-type] assert request.role == Role.SYSTEM def test_to_dict_with_defaults(self) -> None: """Test to_dict with default values.""" - request = RunRequest(message="Test message") + request = RunRequest(message="Test message", correlation_id="corr-004") data = request.to_dict() assert data["message"] == "Test message" assert data["enable_tool_calls"] is True assert data["role"] == "user" + assert data["correlationId"] == "corr-004" assert "response_format" not in data or data["response_format"] is None assert "thread_id" not in data @@ -62,6 +66,7 @@ def test_to_dict_with_all_fields(self) -> None: schema = ModuleStructuredResponse request = RunRequest( message="Hello", + correlation_id="corr-005", role=Role.ASSISTANT, response_format=schema, enable_tool_calls=False, @@ -69,6 +74,7 @@ def test_to_dict_with_all_fields(self) -> None: data = request.to_dict() assert data["message"] == "Hello" + assert data["correlationId"] == "corr-005" assert data["role"] == "assistant" assert data["response_format"]["__response_schema_type__"] == "pydantic_model" assert data["response_format"]["module"] == schema.__module__ @@ -78,16 +84,17 @@ def test_to_dict_with_all_fields(self) -> None: def test_from_dict_with_defaults(self) -> None: """Test from_dict with minimal data.""" - data = {"message": "Hello"} + data = {"message": "Hello", "correlationId": "corr-006"} request = RunRequest.from_dict(data) assert request.message == "Hello" + assert request.correlation_id == "corr-006" assert request.role == Role.USER assert request.enable_tool_calls is True def test_from_dict_ignores_thread_id_field(self) -> None: """Ensure legacy thread_id input does not break RunRequest parsing.""" - request = RunRequest.from_dict({"message": "Hello", "thread_id": "ignored"}) + request = RunRequest.from_dict({"message": "Hello", "correlationId": "corr-007", "thread_id": "ignored"}) assert request.message == "Hello" @@ -95,6 +102,7 @@ def test_from_dict_with_all_fields(self) -> None: """Test from_dict with all fields.""" data = { "message": "Test", + "correlationId": "corr-008", "role": "system", "response_format": { "__response_schema_type__": "pydantic_model", @@ -106,13 +114,14 @@ def test_from_dict_with_all_fields(self) -> None: request = RunRequest.from_dict(data) assert request.message == "Test" + assert request.correlation_id == "corr-008" assert request.role == Role.SYSTEM assert request.response_format is ModuleStructuredResponse assert request.enable_tool_calls is False - def test_from_dict_with_unknown_role_preserves_value(self) -> None: + def test_from_dict_unknown_role_preserves_value(self) -> None: """Test from_dict keeps custom roles intact.""" - data = {"message": "Test", "role": "reviewer"} + data = {"message": "Test", "correlationId": "corr-009", "role": "reviewer"} request = RunRequest.from_dict(data) assert request.role.value == "reviewer" @@ -120,15 +129,22 @@ def test_from_dict_with_unknown_role_preserves_value(self) -> None: def test_from_dict_empty_message(self) -> None: """Test from_dict with empty message.""" - request = RunRequest.from_dict({}) + request = RunRequest.from_dict({"correlationId": "corr-010"}) assert request.message == "" + assert request.correlation_id == "corr-010" assert request.role == Role.USER + def test_from_dict_missing_correlation_id_raises(self) -> None: + """Test from_dict raises when correlationId is missing.""" + with pytest.raises(ValueError, match="correlationId is required"): + RunRequest.from_dict({"message": "Test"}) + def test_round_trip_dict_conversion(self) -> None: """Test round-trip to_dict and from_dict.""" original = RunRequest( message="Test message", + correlation_id="corr-011", role=Role.SYSTEM, response_format=ModuleStructuredResponse, enable_tool_calls=False, @@ -138,6 +154,7 @@ def test_round_trip_dict_conversion(self) -> None: restored = RunRequest.from_dict(data) assert restored.message == original.message + assert restored.correlation_id == original.correlation_id assert restored.role == original.role assert restored.response_format is ModuleStructuredResponse assert restored.enable_tool_calls == original.enable_tool_calls @@ -146,6 +163,7 @@ def test_round_trip_with_pydantic_response_format(self) -> None: """Ensure Pydantic response formats serialize and deserialize properly.""" original = RunRequest( message="Structured", + correlation_id="corr-012", response_format=ModuleStructuredResponse, ) @@ -186,7 +204,7 @@ def test_round_trip_with_correlationId(self) -> None: original = RunRequest( message="Test message", role=Role.SYSTEM, - correlation_id="corr-123", + correlation_id="corr-124", ) data = original.to_dict() @@ -200,6 +218,7 @@ def test_init_with_orchestration_id(self) -> None: """Test RunRequest initialization with orchestration_id.""" request = RunRequest( message="Test message", + correlation_id="corr-125", orchestration_id="orch-123", ) @@ -210,6 +229,7 @@ def test_to_dict_with_orchestration_id(self) -> None: """Test to_dict includes orchestrationId.""" request = RunRequest( message="Test", + correlation_id="corr-126", orchestration_id="orch-456", ) data = request.to_dict() @@ -221,6 +241,7 @@ def test_to_dict_excludes_orchestration_id_when_none(self) -> None: """Test to_dict excludes orchestrationId when not set.""" request = RunRequest( message="Test", + correlation_id="corr-127", ) data = request.to_dict() @@ -230,6 +251,7 @@ def test_from_dict_with_orchestration_id(self) -> None: """Test from_dict with orchestrationId.""" data = { "message": "Test", + "correlationId": "corr-128", "orchestrationId": "orch-789", } request = RunRequest.from_dict(data) @@ -242,7 +264,7 @@ def test_round_trip_with_orchestration_id(self) -> None: original = RunRequest( message="Test message", role=Role.SYSTEM, - correlation_id="corr-123", + correlation_id="corr-129", orchestration_id="orch-123", ) diff --git a/python/packages/durabletask/tests/test_shim.py b/python/packages/durabletask/tests/test_shim.py index 79c8f641da..1fa348695c 100644 --- a/python/packages/durabletask/tests/test_shim.py +++ b/python/packages/durabletask/tests/test_shim.py @@ -6,6 +6,7 @@ Run with: pytest tests/test_shim.py -v """ +from typing import Any from unittest.mock import Mock import pytest @@ -14,6 +15,7 @@ from agent_framework_durabletask import DurableAgentThread from agent_framework_durabletask._executors import DurableAgentExecutor +from agent_framework_durabletask._models import RunRequest from agent_framework_durabletask._shim import DurableAgentProvider, DurableAIAgent @@ -29,11 +31,26 @@ def mock_executor() -> Mock: mock = Mock(spec=DurableAgentExecutor) mock.run_durable_agent = Mock(return_value=None) mock.get_new_thread = Mock(return_value=DurableAgentThread()) + + # Mock get_run_request to create actual RunRequest objects + def create_run_request( + message: str, response_format: type[BaseModel] | None = None, enable_tool_calls: bool = True + ) -> RunRequest: + import uuid + + return RunRequest( + message=message, + correlation_id=str(uuid.uuid4()), + response_format=response_format, + enable_tool_calls=enable_tool_calls, + ) + + mock.get_run_request = Mock(side_effect=create_run_request) return mock @pytest.fixture -def test_agent(mock_executor: Mock) -> DurableAIAgent: +def test_agent(mock_executor: Mock) -> DurableAIAgent[Any]: """Create a test agent with mock executor.""" return DurableAIAgent(mock_executor, "test_agent") @@ -41,34 +58,34 @@ def test_agent(mock_executor: Mock) -> DurableAIAgent: class TestDurableAIAgentMessageNormalization: """Test that DurableAIAgent properly normalizes various message input types.""" - def test_run_accepts_string_message(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + def test_run_accepts_string_message(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: """Verify run accepts and normalizes string messages.""" test_agent.run("Hello, world!") mock_executor.run_durable_agent.assert_called_once() - # Verify agent_name and message were passed correctly as kwargs + # Verify agent_name and run_request were passed correctly as kwargs _, kwargs = mock_executor.run_durable_agent.call_args assert kwargs["agent_name"] == "test_agent" - assert kwargs["message"] == "Hello, world!" + assert kwargs["run_request"].message == "Hello, world!" - def test_run_accepts_chat_message(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + def test_run_accepts_chat_message(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: """Verify run accepts and normalizes ChatMessage objects.""" chat_msg = ChatMessage(role="user", text="Test message") test_agent.run(chat_msg) mock_executor.run_durable_agent.assert_called_once() _, kwargs = mock_executor.run_durable_agent.call_args - assert kwargs["message"] == "Test message" + assert kwargs["run_request"].message == "Test message" - def test_run_accepts_list_of_strings(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + def test_run_accepts_list_of_strings(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: """Verify run accepts and joins list of strings.""" test_agent.run(["First message", "Second message"]) mock_executor.run_durable_agent.assert_called_once() _, kwargs = mock_executor.run_durable_agent.call_args - assert kwargs["message"] == "First message\nSecond message" + assert kwargs["run_request"].message == "First message\nSecond message" - def test_run_accepts_list_of_chat_messages(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + def test_run_accepts_list_of_chat_messages(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: """Verify run accepts and joins list of ChatMessage objects.""" messages = [ ChatMessage(role="user", text="Message 1"), @@ -78,29 +95,29 @@ def test_run_accepts_list_of_chat_messages(self, test_agent: DurableAIAgent, moc mock_executor.run_durable_agent.assert_called_once() _, kwargs = mock_executor.run_durable_agent.call_args - assert kwargs["message"] == "Message 1\nMessage 2" + assert kwargs["run_request"].message == "Message 1\nMessage 2" - def test_run_handles_none_message(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + def test_run_handles_none_message(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: """Verify run handles None message gracefully.""" test_agent.run(None) mock_executor.run_durable_agent.assert_called_once() _, kwargs = mock_executor.run_durable_agent.call_args - assert kwargs["message"] == "" + assert kwargs["run_request"].message == "" - def test_run_handles_empty_list(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + def test_run_handles_empty_list(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: """Verify run handles empty list gracefully.""" test_agent.run([]) mock_executor.run_durable_agent.assert_called_once() _, kwargs = mock_executor.run_durable_agent.call_args - assert kwargs["message"] == "" + assert kwargs["run_request"].message == "" class TestDurableAIAgentParameterFlow: """Test that parameters flow correctly through the shim to executor.""" - def test_run_forwards_thread_parameter(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + def test_run_forwards_thread_parameter(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: """Verify run forwards thread parameter to executor.""" thread = DurableAgentThread(service_thread_id="test-thread") test_agent.run("message", thread=thread) @@ -109,31 +126,23 @@ def test_run_forwards_thread_parameter(self, test_agent: DurableAIAgent, mock_ex _, kwargs = mock_executor.run_durable_agent.call_args assert kwargs["thread"] == thread - def test_run_forwards_response_format(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + def test_run_forwards_response_format(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: """Verify run forwards response_format parameter to executor.""" test_agent.run("message", response_format=ResponseFormatModel) mock_executor.run_durable_agent.assert_called_once() _, kwargs = mock_executor.run_durable_agent.call_args - assert kwargs["response_format"] == ResponseFormatModel - - def test_run_forwards_additional_kwargs(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: - """Verify run forwards additional kwargs to executor.""" - test_agent.run("message", custom_param="custom_value") - - mock_executor.run_durable_agent.assert_called_once() - _, kwargs = mock_executor.run_durable_agent.call_args - assert kwargs["custom_param"] == "custom_value" + assert kwargs["run_request"].response_format == ResponseFormatModel class TestDurableAIAgentProtocolCompliance: """Test that DurableAIAgent implements AgentProtocol correctly.""" - def test_agent_implements_protocol(self, test_agent: DurableAIAgent) -> None: + def test_agent_implements_protocol(self, test_agent: DurableAIAgent[Any]) -> None: """Verify DurableAIAgent implements AgentProtocol.""" assert isinstance(test_agent, AgentProtocol) - def test_agent_has_required_properties(self, test_agent: DurableAIAgent) -> None: + def test_agent_has_required_properties(self, test_agent: DurableAIAgent[Any]) -> None: """Verify DurableAIAgent has all required AgentProtocol properties.""" assert hasattr(test_agent, "id") assert hasattr(test_agent, "name") @@ -142,14 +151,14 @@ def test_agent_has_required_properties(self, test_agent: DurableAIAgent) -> None def test_agent_id_defaults_to_name(self, mock_executor: Mock) -> None: """Verify agent id defaults to name when not provided.""" - agent = DurableAIAgent(mock_executor, "my_agent") + agent: DurableAIAgent[Any] = DurableAIAgent(mock_executor, "my_agent") assert agent.id == "my_agent" assert agent.name == "my_agent" def test_agent_id_can_be_customized(self, mock_executor: Mock) -> None: """Verify agent id can be set independently from name.""" - agent = DurableAIAgent(mock_executor, "my_agent", agent_id="custom-id") + agent: DurableAIAgent[Any] = DurableAIAgent(mock_executor, "my_agent", agent_id="custom-id") assert agent.id == "custom-id" assert agent.name == "my_agent" @@ -158,7 +167,7 @@ def test_agent_id_can_be_customized(self, mock_executor: Mock) -> None: class TestDurableAIAgentThreadManagement: """Test thread creation and management.""" - def test_get_new_thread_delegates_to_executor(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + def test_get_new_thread_delegates_to_executor(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: """Verify get_new_thread delegates to executor.""" mock_thread = DurableAgentThread() mock_executor.get_new_thread.return_value = mock_thread @@ -168,7 +177,7 @@ def test_get_new_thread_delegates_to_executor(self, test_agent: DurableAIAgent, mock_executor.get_new_thread.assert_called_once_with("test_agent") assert thread == mock_thread - def test_get_new_thread_forwards_kwargs(self, test_agent: DurableAIAgent, mock_executor: Mock) -> None: + def test_get_new_thread_forwards_kwargs(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: """Verify get_new_thread forwards kwargs to executor.""" mock_thread = DurableAgentThread(service_thread_id="thread-123") mock_executor.get_new_thread.return_value = mock_thread diff --git a/python/samples/getting_started/durabletask/01_single_agent/README.md b/python/samples/getting_started/durabletask/01_single_agent/README.md new file mode 100644 index 0000000000..f18d32e573 --- /dev/null +++ b/python/samples/getting_started/durabletask/01_single_agent/README.md @@ -0,0 +1,387 @@ +# Single Agent Sample (Python) - Durable Task + +This sample demonstrates how to use the **Durable Task package** for Agent Framework to create a simple agent hosting setup with persistent conversation state and distributed execution capabilities. + +## Description of the Sample + +This sample shows how to host a single AI agent (the "Joker" agent) using the Durable Task Scheduler. The agent responds to user messages by telling jokes, demonstrating: + +- How to register agents as durable entities that can persist state +- How to interact with registered agents from external clients +- How to maintain conversation context across multiple interactions +- The worker-client architecture pattern for distributed agent execution + +## Key Concepts Demonstrated + +- **Worker Registration**: Using `DurableAIAgentWorker` to register agents as durable entities that can process requests +- **Client Interaction**: Using `DurableAIAgentClient` to send messages to registered agents from external contexts +- **Thread Management**: Creating and maintaining conversation threads for stateful interactions +- **Distributed Architecture**: Separating worker (agent host) and client (caller) into independent processes +- **BYOP (Bring Your Own Platform)**: Not tied to Azure Functions - run anywhere with Durable Task Scheduler + +## Architecture Overview + +This sample uses a **client-worker architecture**: + +1. **Worker Process** (`worker.py`): Registers agents as durable entities and continuously listens for requests +2. **Client Process** (`client.py`): Connects to the same scheduler and sends requests to agents by name +3. **Durable Task Scheduler**: Coordinates communication between clients and workers (runs separately) + +This architecture enables: +- **Scalability**: Multiple workers can process requests in parallel +- **Reliability**: State is persisted, so conversations survive process restarts +- **Flexibility**: Clients and workers can be on different machines +- **BYOP (Bring Your Own Platform)**: Not tied to Azure Functions - run anywhere + +## Prerequisites + +### 1. Python 3.9+ + +Ensure you have Python 3.9 or later installed. + +### 2. Azure OpenAI Setup + +Configure your Azure OpenAI credentials: +- Set `AZURE_OPENAI_ENDPOINT` environment variable +- Set `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME` environment variable +- Either: + - Set `AZURE_OPENAI_API_KEY` environment variable, OR + - Run `az login` to authenticate with Azure CLI + +### 3. Install Dependencies + +Install the required packages: + +```bash +pip install -r requirements.txt +``` + +Or if using uv: + +```bash +uv pip install -r requirements.txt +``` + +### 4. Durable Task Scheduler + +The sample requires a Durable Task Scheduler running. There are two options: + +#### Using the Emulator (Recommended for Local Development) + +The emulator simulates a scheduler and taskhub in a Docker container, making it ideal for development and learning. + +1. Pull the Docker Image for the Emulator: + ```bash + docker pull mcr.microsoft.com/dts/dts-emulator:latest + ``` + +2. Run the Emulator: + ```bash + docker run --name dtsemulator -d -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest + ``` + Wait a few seconds for the container to be ready. + +> *How to Run the Sample + +Once you have set up either the emulator or deployed scheduler, follow these steps to run the sample: + +1. **Activate your Python virtual environment** (if you're using one): + ```bash + python -m venv venv + source venv/bin/activate # On Windows, use: venv\Scripts\activate + ``` + +2. **If you're using a deployed scheduler**, set environment variables: + ```bash + export ENDPOINT=$(az durabletask scheduler show \ + --resource-group my-resource-group \ + --name my-scheduler \ + --query "properties.endpoint" \ + --output tsv) + + export TASKHUB="my-taskhub" + ``` + +3. **Install the required packages**: + ```bash + pip install -r requirements.txt + ``` + +4. **Start the worker** in a terminal: + ```bash + python worker.py + ``` + You should see output indicating the worker has started and registered the agent: + ``` + INFO:__main__:Starting Durable Task Agent Worker... + INFO:__main__:Using taskhub: default + INFO:__main__:Using endpoint: http://localhost:8080 + INFO:__main__:Creating and registering Joker agent... + INFO:__main__:✓ Registered agent: Joker + INFO:__main__: Entity name: dafx-Joker + INFO:__main__: + INFO:__main__:Worker is ready and listening for requests... + INFO:__main__:Press Ctrl+C to stop. + ``` + +5. **In a new terminal** (with the virtual environment activated if applicable), **run the client**: + > **Note:** Remember to set the environment variables again if you're using a deployed scheduler. + + ```bash + python client.py + ``` + az role assignment create \ + --assignee $loggedInUser \ + --role "Durable Task Data Contributor" \ + --scope "/subscriptions/$subscriptionId/resourceGroups/my-resource-group/providers/Microsoft.DurableTask/schedulers/my-scheduler/taskHubs/my-taskhub" + ``` + +5. Set environment variables: + ```bash + export ENDPOINT=$(az durabletask scheduler show \ + --resource-group my-resource-group \ + --name my-scheduler \ + --query "properties.endpoint" \ + --output tsv) + + export TASKHUB="my-taskhub" + ``` + +## Running the Sample + +### Step 1: Start the Worker + +In one terminal, start the worker to host the agent: + +```bash +python sample.py worker +``` + +You should see output similar to: +``` +Starting Durable Task worker... +Connecting to scheduler at: localhost:4001 +✓ Registered agent: Joker + Entity name: dafx-Joker + +Worker is ready and listening for requests... +Press Ctrl+C to stop. +``` + +The worker will continue running and processing requests until you stop it (Ctrl+C). + +### Step 2: Run the Client + +In a **separate terminal**, run the client to interact with the agent: +Understanding the Output + +When you run the sample, you'll see output from both the worker and client processes: + +### Worker Output + +The worker shows: +- Connection information (taskhub and endpoint) +- Registration of the Joker agent as a durable entity +- Entity name (`dafx-Joker`) +- Status message indicating it's ready to process requests + +Example: +``` +INFO:__main__:Starting Durable Task Agent Worker... +INFO:__main__:Using taskhub: default +INFO:__main__:Using endpoint: http://localhost:8080 +INFO:__main__:Creating and registering Joker agent... +INFO:__main__:✓ Registered agent: Joker +INFO:__main__: Entity name: dafx-Joker +INFO:__main__: +INFO:__main__:Worker is ready and listening for requests... +INFO:__main__:Press Ctrl+C to stop. +``` + +### Client Output + +The client shows: +- Connection information +- Thread creation +- User messages sent to the agent +- Agent responses (jokes) +- Token usage statistics +- Conversation completion status + +Example: +``` +INFO:__main__:Starting Durable Task Agent Client... +INFO:__main__:Using taskhub: default +INFO:__main__:Using endpoint: http://localhost:8080 +INFO:__main__: +INFO:__main__:Getting reference to Joker agent... +INFO:__main__:Created conversation thread: a1b2c3d4-e5f6-7890-abcd-ef1234567890 +INFO:__main__: +INFO:__main__:User: Tell me a short joke about cloud computing. +INFO:__main__: +INFO:__main__:Joker: Why did the cloud break up with the server? + +Because it found someone more "uplifting"! +INFO:__main__:Usage: UsageStats(input_tokens=42, output_tokens=18, total_tokens=60) +INFO:__main__: +INFO:__main__:User: Now tell me one about Python programming. +INFO:__main__: +INFO:__main__:Joker: Why do Python programmers prefer dark mode? +Understanding the Code + +### Worker (`worker.py`) + +The worker process is responsible for hosting agents: + +```python +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker +from agent_framework_durabletask import DurableAIAgentWorker + +# Create a worker using Azure Managed Durable Task +worker = DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential +) + +# Wrap it with the agent worker +agent_worker = DurableAIAgentWorker(worker) + +# Create and register agents +joker_agent = create_joker_agent() +agent_worker.add_agent(joker_agent) + +# Start processing (blocks until stopped) +worker.start() +``` + +**What happens:** +- The agent is registered as a durable entity with name `dafx-{agent_name}` +- The worker continuously polls for requests directed to this entity +- Each request is routed to the agent's execution logic +- Conversation state is persisted automatically in the entity + +### Client (`client.py`) + +The client process interacts with registered agents: + +```python +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from agent_framework_durabletask import DurableAIAgentClient + +# Create a client using Azure Managed Durable Task +client = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential +) + +# Wrap it with the agent client +agent_client = DurableAIAgentClient(client) + +# Get agent reference (no validation until execution) +joker = agent_client.get_agent("Joker") + +# Create thread and run +thread = joker.get_new_thread() +response = await joker.run(message, thread=thread) +``` + +**What happens:** +- The client constructs a request with the message and thread information +- The request is sent to the entity `dafx-Joker` via the scheduler +- The client waits for the entity to process theEmulator is running: +```bash +docker ps | grep dts-emulator +``` + +If not running, start it: +```bash +docker run --name dtsemulator -d -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest +``` + +### Agent Not Found + +**Error**: Agent execution fails with "entity not found" or similar + +**Solution**: +1. Ensure the worker is running and has registered the agent +2. Check that the agent name matches exactly (case-sensitive) +3. Verify both client and worker are connecting to the same endpoint and taskhub +4. Check worker logs for successful agent registration + +### Azure OpenAI Authentication + +**Error**: Authentication errors when creating the agent + +**Solution**: +1. Ensure `AZURE_OPENAI_ENDPOINT` and `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME` are set +2. Either: + - Set `AZURE_OPENAI_API_KEY` environment variable, OR + - Run `az login` to authenticate with Azure CLI + +### Environment Variables Not Set + +If using a deployed scheduler, ensure you set the environment variables in **both** terminals (worker and client): +```bash +export ENDPOINT="" +export TASKHUB="" +``` + +## Reviewing the Agent in the Durable Task Scheduler Dashboard + +### Using the Emulator + +1. Navigate to http://localhost:8082 in your web browser +2. Click on the "default" task hub +3. You'll see the agent entity (`dafx-Joker`) in the list +4. Click on the entity to view: + - Entity state and conversation history + - Request and response details + - Execution timeline + +### Using a Deployed Scheduler + +1. Navigate to the Scheduler resource in the Azure portal +2. Go to the Task Hub subresource that you're using +3. Click on the dashboard URL in the top right corner +4. Search for your entity (`dafx-Joker`) +5. Review the entity state and execution history + +## Comparison with Azure Functions Sample + +| Aspect | Azure Functions | Durable Task (BYOP) | +|--------|----------------|---------------------| +| **Platform** | Azure Functions (PaaS) | Any platform with gRPC | +| **Hosting** | AgentFunctionApp | DurableTaskSchedulerWorker + DurableAIAgentWorker | +| **Client API** | HTTP endpoints | DurableAIAgentClient | +| **Infrastructure** | Managed by Azure | Self-hosted scheduler or Azure DTS | +| **Scalability** | Auto-scaling | Manual scaling or K8s | +| **Use Case** | Production cloud workloads | Local dev, on-prem, custom platforms | + +## Identity-based Authentication + +Learn how to set up [identity-based authentication](https://learn.microsoft.com/azure/azure-functions/durable/durable-task-scheduler/durable-task-scheduler-identity?tabs=df&pivots=az-cli) when you deploy to Azure. + +## Next Steps + +- **Multiple Agents**: Modify the sample to register multiple agents with different capabilities +- **Structured Responses**: Use `response_format` parameter to get JSON structured output +- **Agent Orchestration**: Create orchestrations that coordinate multiple agents (see advanced samples) +- **Production Deployment**: Deploy workers to Kubernetes, VMs, or container services +- **Monitoring**: Add telemetry and logging for production workloads + +## Related Samples + +- [Azure Functions Single Agent Sample](../../../azure_functions/01_single_agent/) - Azure Functions hosting +- [Durable Task Scheduler Samples](https://github.com/Azure-Samples/Durable-Task-Scheduler) - More patterns and examples + +## Additional Resources + +- [Durable Task Framework](https://github.com/microsoft/durabletask-python) +- [Agent Framework Documentation](https://github.com/microsoft/agent-framework) +- [Durable Task Scheduler](https://github.com/Azure-Samples/Durable-Task-Scheduler) +- [Azure Durable Task Scheduler Documentation](https://learn.microsoft.com/azure/azure-functions/durable/durable-task-scheduler/) + diff --git a/python/samples/getting_started/durabletask/01_single_agent/client.py b/python/samples/getting_started/durabletask/01_single_agent/client.py new file mode 100644 index 0000000000..bfd5147ea7 --- /dev/null +++ b/python/samples/getting_started/durabletask/01_single_agent/client.py @@ -0,0 +1,92 @@ +"""Client application for interacting with a Durable Task hosted agent. + +This client connects to the Durable Task Scheduler and sends requests to +registered agents, demonstrating how to interact with agents from external processes. + +Prerequisites: +- The worker must be running with the agent registered +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Durable Task Scheduler must be running +""" + +import asyncio +import logging +import os + +from agent_framework_durabletask import DurableAIAgentClient +from azure.identity import DefaultAzureCredential +from durabletask.azuremanaged.client import DurableTaskSchedulerClient + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def main() -> None: + """Main entry point for the client application.""" + logger.info("Starting Durable Task Agent Client...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + logger.info("") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + + # Create a client using Azure Managed Durable Task + client = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + # Wrap it with the agent client + agent_client = DurableAIAgentClient(client) + + # Get a reference to the Joker agent + logger.info("Getting reference to Joker agent...") + joker = agent_client.get_agent("Joker") + + # Create a new thread for the conversation + thread = joker.get_new_thread() + + logger.info(f"Created conversation thread: {thread.session_id}") + logger.info("") + + try: + # First message + message1 = "Tell me a short joke about cloud computing." + logger.info(f"User: {message1}") + logger.info("") + + # Run the agent - this blocks until the response is ready + response1 = joker.run(message1, thread=thread) + logger.info(f"Agent: {response1.text}") + logger.info("") + + # Second message - continuing the conversation + message2 = "Now tell me one about Python programming." + logger.info(f"User: {message2}") + logger.info("") + + response2 = joker.run(message2, thread=thread) + logger.info(f"Agent: {response2.text}") + logger.info("") + + logger.info(f"Conversation completed successfully!") + logger.info(f"Thread ID: {thread.session_id}") + + except Exception as e: + logger.exception(f"Error during agent interaction: {e}") + finally: + logger.info("Client shutting down") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/durabletask/01_single_agent/requirements.txt b/python/samples/getting_started/durabletask/01_single_agent/requirements.txt new file mode 100644 index 0000000000..da871507c8 --- /dev/null +++ b/python/samples/getting_started/durabletask/01_single_agent/requirements.txt @@ -0,0 +1,9 @@ +# Agent Framework packages +agent-framework-azure +agent-framework-durabletask + +# Durable Task Python SDK with Azure Managed support +durabletask-python + +# Azure authentication +azure-identity diff --git a/python/samples/getting_started/durabletask/01_single_agent/worker.py b/python/samples/getting_started/durabletask/01_single_agent/worker.py new file mode 100644 index 0000000000..4c0d172915 --- /dev/null +++ b/python/samples/getting_started/durabletask/01_single_agent/worker.py @@ -0,0 +1,89 @@ +"""Worker process for hosting a single Azure OpenAI-powered agent using Durable Task. + +This worker registers agents as durable entities and continuously listens for requests. +The worker should run as a background service, processing incoming agent requests. + +Prerequisites: +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Start a Durable Task Scheduler (e.g., using Docker) +""" + +import asyncio +import logging +import os + +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_durabletask import DurableAIAgentWorker +from azure.identity import AzureCliCredential, DefaultAzureCredential +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def create_joker_agent(): + """Create the Joker agent using Azure OpenAI. + + Returns: + AgentProtocol: The configured Joker agent + """ + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name="Joker", + instructions="You are good at telling jokes.", + ) + + +async def main(): + """Main entry point for the worker process.""" + logger.info("Starting Durable Task Agent Worker...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + + # Create a worker using Azure Managed Durable Task + worker = DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + # Wrap it with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Create and register the Joker agent + logger.info("Creating and registering Joker agent...") + joker_agent = create_joker_agent() + agent_worker.add_agent(joker_agent) + + logger.info(f"✓ Registered agent: {joker_agent.name}") + logger.info(f" Entity name: dafx-{joker_agent.name}") + logger.info("") + logger.info("Worker is ready and listening for requests...") + logger.info("Press Ctrl+C to stop.") + logger.info("") + + try: + # Start the worker (this blocks until stopped) + worker.start() + + # Keep the worker running + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Worker shutdown initiated") + + logger.info("Worker stopped") + + +if __name__ == "__main__": + asyncio.run(main()) From 2eb71282d5439933d0aa5810079147992e5a6dd7 Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Wed, 31 Dec 2025 10:42:29 -0800 Subject: [PATCH 04/11] Add sample --- python/packages/durabletask/DESIGN.md | 384 ------------------ .../agent_framework_durabletask/_client.py | 6 +- .../agent_framework_durabletask/_executors.py | 61 ++- .../_response_utils.py | 6 +- .../agent_framework_durabletask/_shim.py | 13 +- .../durabletask/tests/test_executors.py | 285 ++++++++++--- .../README.md | 317 +++++++++++++++ .../client.py | 114 ++++++ .../requirements.txt | 9 + .../worker.py | 175 ++++++++ 10 files changed, 897 insertions(+), 473 deletions(-) delete mode 100644 python/packages/durabletask/DESIGN.md create mode 100644 python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/README.md create mode 100644 python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/client.py create mode 100644 python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/requirements.txt create mode 100644 python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/worker.py diff --git a/python/packages/durabletask/DESIGN.md b/python/packages/durabletask/DESIGN.md deleted file mode 100644 index 3fe14a385d..0000000000 --- a/python/packages/durabletask/DESIGN.md +++ /dev/null @@ -1,384 +0,0 @@ -# Design: Durable Task Provider for Agent Framework - -## Overview - -This package, `agent-framework-durabletask`, provides a durability layer for the Microsoft Agent Framework using the `durabletask` Python SDK. It enables stateful, reliable, and distributed agent execution on any platform (Bring Your Own Platform), decoupling the agent's durability from the Azure Functions platform. - -## Design Decision - -**Selected Approach: Object-Oriented Wrappers with Symmetric Factory Pattern + Strategy Pattern for Execution** - -We will use a symmetric Object-Oriented design where both the Client (external) and Orchestrator (internal) expose a consistent interface for retrieving and interacting with durable agents. Execution logic is delegated to dedicated provider strategies. - -## Core Philosophy - -* **Native `DurableEntity` Support**: We will leverage the `DurableEntity` support introduced in `durabletask` v1.0.0. -* **Symmetric Factories**: `DurableAIAgentClient` (for external use) and `DurableAIAgentOrchestrationContext` (for internal use) both provide a `get_agent` method. -* **Unified Interface**: `DurableAIAgent` serves as the common interface for executing agents, regardless of the context (Client vs Orchestration). -* **Strategy Pattern for Execution**: Execution logic is encapsulated in `DurableAgentExecutor` implementations, allowing flexible delegation while keeping the public API clean. -* **Consistent Return Type**: `DurableAIAgent.run` returns context-appropriate objects (awaitable for Client, yieldable Task for Orchestrator), ensuring consistent usage patterns. - -## Architecture - -### 1. Package Structure - -```text -packages/durabletask/ -├── pyproject.toml -├── README.md -├── agent_framework_durabletask/ -│ ├── __init__.py -│ ├── _worker.py # DurableAIAgentWorker -│ ├── _client.py # DurableAIAgentClient -│ ├── _orchestration_context.py # DurableAIAgentOrchestrationContext -│ ├── _executors.py # DurableAgentExecutor ABC and implementations -│ ├── _entities.py # AgentEntity implementation -│ ├── _models.py # Data models (RunRequest, AgentResponse, etc.) -│ ├── _durable_agent_state.py # State schema (Ported from azurefunctions) -│ └── _shim.py # DurableAIAgent and DurableAgentProvider ABC -└── tests/ -``` - -### 2. State Management (`_durable_agent_state.py`) - -**Important**: This will be the state maintained in the durable entities for both `durabletask` and `azurefunctions` package. - -### 3. The Agent Entity (`_entities.py`) - -We will implement a class `AgentEntity` that inherits from `durabletask.entities.DurableEntity`. - -**Important**: This will be ported from `azurefunctions` package too but with slight modifications, details TBD. - -### 4. The Worker Wrapper (`_worker.py`) - -The `DurableAIAgentWorker` wraps an existing `durabletask` worker instance. - -```python -class DurableAIAgentWorker: - def __init__(self, worker: TaskHubGrpcWorker): - self._worker = worker - self._registered_agents: dict[str, AgentProtocol] = {} - - def add_agent(self, agent: AgentProtocol) -> None: - """Registers an agent with the worker. - - Uses the factory pattern to create an AgentEntity class with the agent - instance injected, then registers it with the durabletask worker. - """ - # Store the agent reference - self._registered_agents[agent.name] = agent - - # Create a configured entity class using the factory - entity_class = create_agent_entity(agent) - - # Register the entity class with the worker - # The worker.add_entity method takes a class or function - self._worker.add_entity(entity_class) - - def start(self): - """Start the worker to begin processing tasks.""" - self._worker.start() - - def stop(self): - """Stop the worker gracefully.""" - self._worker.stop() -``` - -### 5. The Shim and Provider ABC (`_shim.py`) - -The `_shim.py` module contains two key abstractions: - -1. **`DurableAgentProvider` ABC**: Defines the contract for constructing durable agent proxies. Implemented by context-specific wrappers (client/orchestration) to provide a consistent `get_agent` entry point. - -2. **`DurableAIAgent`**: The agent shim that delegates execution to an executor strategy. - -```python -from abc import ABC, abstractmethod - -class DurableAgentProvider(ABC): - """Abstract provider for constructing durable agent proxies. - - Implemented by context-specific wrappers (client/orchestration) to return a - DurableAIAgent shim backed by their respective DurableAgentExecutor - implementation, ensuring a consistent get_agent entry point regardless of - execution context. - """ - - @abstractmethod - def get_agent(self, agent_name: str) -> DurableAIAgent: - """Retrieve a DurableAIAgent shim for the specified agent.""" - raise NotImplementedError("Subclasses must implement get_agent()") - - -class DurableAIAgent(AgentProtocol): - """A durable agent proxy that delegates execution to an executor. - - This class implements AgentProtocol but doesn't contain any agent logic itself. - Instead, it serves as a consistent interface that delegates to the underlying - executor, which can be either ClientAgentExecutor or OrchestrationAgentExecutor. - """ - - def __init__(self, executor: DurableAgentExecutor, name: str, *, agent_id: str | None = None): - self._executor = executor - self._name = name - self._id = agent_id if agent_id is not None else name - - def run(self, messages: ..., **kwargs) -> Any: - """Execute the agent via the injected executor.""" - message_str = self._normalize_messages(messages) - return self._executor.run_durable_agent( - agent_name=self._name, - message=message_str, - thread=kwargs.get('thread'), - response_format=kwargs.get('response_format'), - **kwargs - ) -``` - -### 6. The Executor Strategy (`_executors.py`) - -We introduce dedicated "Executor" classes to handle execution logic using the Strategy Pattern. These are internal execution strategies that are injected into the `DurableAIAgent` shim. This ensures the public API of the Client and Orchestration Context remains clean, while allowing the Shim to be reused across different environments. - -```python -from abc import ABC, abstractmethod -from typing import Any -from agent_framework import AgentThread -from pydantic import BaseModel - -class DurableAgentExecutor(ABC): - """Abstract base class for durable agent execution strategies.""" - - @abstractmethod - def run_durable_agent( - self, - agent_name: str, - message: str, - *, - thread: AgentThread | None = None, - response_format: type[BaseModel] | None = None, - **kwargs: Any, - ) -> Any: - """Execute the durable agent. - - Returns: - Any: Either an awaitable AgentRunResponse (Client) or a yieldable Task (Orchestrator). - """ - raise NotImplementedError - - @abstractmethod - def get_new_thread(self, agent_name: str, **kwargs: Any) -> AgentThread: - """Create a new thread appropriate for the context.""" - raise NotImplementedError - - -class ClientAgentExecutor(DurableAgentExecutor): - """Execution strategy for external clients (async).""" - - def __init__(self, client: 'TaskHubGrpcClient'): - self._client = client - - async def run_durable_agent( - self, - agent_name: str, - message: str, - *, - thread: AgentThread | None = None, - response_format: type[BaseModel] | None = None, - **kwargs: Any, - ) -> Any: - # Implementation using self._client - # Returns an awaitable AgentRunResponse - raise NotImplementedError("ClientAgentExecutor.run_durable_agent is not yet implemented") - - def get_new_thread(self, agent_name: str, **kwargs: Any) -> AgentThread: - # Implementation for client context - return AgentThread(**kwargs) - - -class OrchestrationAgentExecutor(DurableAgentExecutor): - """Execution strategy for orchestrations (sync/yield).""" - - def __init__(self, context: 'OrchestrationContext'): - self._context = context - - def run_durable_agent( - self, - agent_name: str, - message: str, - *, - thread: AgentThread | None = None, - response_format: type[BaseModel] | None = None, - **kwargs: Any, - ) -> Any: - # Implementation using self._context - # Returns a yieldable Task - raise NotImplementedError("OrchestrationAgentExecutor.run_durable_agent is not yet implemented") - - def get_new_thread(self, agent_name: str, **kwargs: Any) -> AgentThread: - # Implementation for orchestration context - return AgentThread(**kwargs) -``` - -**Benefits of the Strategy Pattern:** - -1. **Strong Contract**: ABC enforces implementation of required methods. -2. **Encapsulation**: Execution logic is hidden in provider classes. -3. **Flexibility**: Easy to add new providers (e.g., for Azure Functions). -4. **Separation of Concerns**: Client/Context wrappers focus on being factories/adapters. -5. **Reusability**: The shim can be reused across different environments without modification. - -### 7. The Client Wrapper (`_client.py`) - -The `DurableAIAgentClient` is for external clients (e.g., FastAPI, CLI). It implements `DurableAgentProvider` to provide the `get_agent` factory method, and instantiates the `ClientAgentExecutor` to inject into the `DurableAIAgent`. - -```python -from ._executors import ClientAgentExecutor -from ._shim import DurableAgentProvider, DurableAIAgent - -class DurableAIAgentClient(DurableAgentProvider): - def __init__(self, client: TaskHubGrpcClient): - self._client = client - - def get_agent(self, agent_name: str) -> DurableAIAgent: - """Retrieves a DurableAIAgent shim. - - Validates existence by attempting to fetch entity state/metadata. - """ - # Validation logic using self._client.get_entity(...) - # ... - executor = ClientAgentExecutor(self._client) - return DurableAIAgent(executor, agent_name) -``` - -### 8. The Orchestration Context Wrapper (`_orchestration_context.py`) - -The `DurableAIAgentOrchestrationContext` is for use *inside* orchestrations to get access to agents that were registered in the workers. It implements `DurableAgentProvider` to provide the `get_agent` factory method, and instantiates the `OrchestrationAgentExecutor`. - -```python -from ._executors import OrchestrationAgentExecutor -from ._shim import DurableAgentProvider, DurableAIAgent - -class DurableAIAgentOrchestrationContext(DurableAgentProvider): - def __init__(self, context: OrchestrationContext): - self._context = context - - def get_agent(self, agent_name: str) -> DurableAIAgent: - """Retrieves a DurableAIAgent shim. - - Validation is deferred or performed via call_entity if needed. - """ - executor = OrchestrationAgentExecutor(self._context) - return DurableAIAgent(executor, agent_name) -``` - -## Usage Experience - -**Scenario A: Worker Side** -```python -# 1. Define your agent -# The agent can be any implementation of AgentProtocol. -# For example, a standard Agent with a model and instructions. -my_agent = Agent( - name="my_agent", - instructions="You are a helpful assistant.", - model=openai_model -) - -# 2. Create the worker and the agent worker wrapper -with DurableTaskSchedulerWorker(...) as worker: - - agent_worker = DurableAIAgentWorker(worker) - - # 3. Register the agent - agent_worker.add_agent(my_agent) - - # 4. Start the worker - worker.start() - - # ... keep running ... -``` - -**Scenario B: Client Side** -```python -# 1. Configure the Durable Task client -client = DurableTaskSchedulerClient(...) - -# 2. Create the Agent Client wrapper -agent_client = DurableAIAgentClient(client) - -# 3. Get a reference to the agent -agent = await agent_client.get_agent("my_agent") - -# 4. Run the agent -# The returned object is designed to be compatible with both `await` (Client) -# and `yield` (Orchestrator). Implementation details on this unified return type will follow. -response = await agent.run("Hello") -``` - -**Scenario C: Orchestration Side** -```python -def orchestrator(context: OrchestrationContext): - # 1. Create the Agent Orchestration Context - agent_orch = DurableAIAgentOrchestrationContext(context) - - # 2. Get a reference to the agent - agent = agent_orch.get_agent("my_agent") - - # 3. Run the agent (returns a Task, so we yield it) - result = yield agent.run("Hello") - - return result -``` - -## Additional Styles Considered - -### Inheritance Pattern for worker and client (like `DurableAIAgentWorker`, `DurableAIAgentClient`, etc) - -We investigated inheriting `DurableAIAgentWorker` directly from `TaskHubGrpcWorker` (or `DurableTaskSchedulerWorker`) to provide a unified API where the agent worker *is* a durable task worker (and similarly the client). - -**Why we chose Composition over Inheritance:** - -1. **Initialization Divergence:** The `durabletask` package has two distinct worker classes with incompatible `__init__` signatures: - * `TaskHubGrpcWorker`: Requires `host_address`, `metadata`, etc. - * `DurableTaskSchedulerWorker`: Requires `host_address`, `taskhub`, `token_credential`, etc. - - To support both via inheritance, we would need to maintain two separate classes (e.g., `DurableAIAgentGrpcWorker` and `DurableAIAgentSchedulerWorker`) or use a complex Mixin approach. This increases the API surface area and maintenance burden. - -2. **Encapsulation:** The logic for Azure Managed DTS (authentication, routing) is currently encapsulated in an internal interceptor class within `durabletask`. Without changes to the upstream package to expose this logic, we cannot create a single "Universal" worker class that inherits from the base worker but supports Azure features. - -3. **Flexibility:** The Composition pattern allows `DurableAIAgentWorker` to accept *any* instance of a worker that satisfies the required interface. This makes it forward-compatible with future worker implementations or custom subclasses without requiring code changes in our package. - -4. **Simplicity:** While Composition requires a two-step setup (instantiate worker, then wrap it), it keeps the `agent-framework-durabletask` package simple, focused, and loosely coupled from the implementation details of the underlying `durabletask` workers. - -## Extension Point: Azure Functions Integration - -The Strategy Pattern design allows for easy integration with Azure Functions. The `azurefunctions` package can define its own `AzureFunctionsAgentExecutor` in `packages/azurefunctions/agent_framework_azurefunctions/_executors.py`. - -```python -from agent_framework_durabletask._executors import DurableAgentExecutor - -class AzureFunctionsAgentExecutor(DurableAgentExecutor): - """Execution strategy for Azure Functions orchestrations.""" - - def __init__(self, context: DurableOrchestrationContext): - self._context = context - - def run_durable_agent( - self, - agent_name: str, - message: str, - *, - thread: AgentThread | None = None, - response_format: type[BaseModel] | None = None, - **kwargs: Any, - ) -> Any: - # Implementation using Azure Functions context - # Returns AgentTask - ... - - def get_new_thread(self, agent_name: str, **kwargs: Any) -> AgentThread: - # Implementation for Azure Functions context - ... -``` - -Then `packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py` implements `DurableAgentProvider` and uses this executor when creating the agent, ensuring consistent behavior across platforms while accommodating Azure Functions-specific features. \ No newline at end of file diff --git a/python/packages/durabletask/agent_framework_durabletask/_client.py b/python/packages/durabletask/agent_framework_durabletask/_client.py index 8c16e67445..2ccf269509 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_client.py +++ b/python/packages/durabletask/agent_framework_durabletask/_client.py @@ -22,8 +22,7 @@ class DurableAIAgentClient(DurableAgentProvider[AgentRunResponse]): """Client wrapper for interacting with durable agents externally. This class wraps a durabletask TaskHubGrpcClient and provides a convenient - interface for retrieving and executing durable agents from external contexts - (e.g., FastAPI endpoints, CLI tools, etc.). + interface for retrieving and executing durable agents from external contexts. Example: ```python @@ -88,7 +87,4 @@ def get_agent(self, agent_name: str) -> DurableAIAgent[AgentRunResponse]: """ logger.debug("[DurableAIAgentClient] Creating agent proxy for: %s", agent_name) - # Note: Validation would require async, so we defer it to execution time - # The entity name will be f"dafx-{agent_name}" - return DurableAIAgent(self._executor, agent_name) diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py index 9d4d15845f..563706f531 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_executors.py +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -42,7 +42,7 @@ class DurableAgentTask(CompositeTask[AgentRunResponse]): def __init__( self, - entity_task: Task[AgentRunResponse], + entity_task: Task[Any], response_format: type[BaseModel] | None, correlation_id: str, ): @@ -139,7 +139,7 @@ def _create_session_id( agent_name: str, thread: AgentThread | None = None, ) -> AgentSessionId: - """Resolve or create the AgentSessionId for the execution.""" + """Create the AgentSessionId for the execution.""" if isinstance(thread, DurableAgentThread) and thread.session_id is not None: return thread.session_id # Create new session ID - either no thread provided or it's a regular AgentThread @@ -402,6 +402,25 @@ def __init__(self, context: OrchestrationContext): self._context = context logger.debug("[OrchestrationAgentExecutor] Initialized") + def get_run_request( + self, + message: str, + response_format: type[BaseModel] | None, + enable_tool_calls: bool, + ) -> RunRequest: + """Get the current run request from the orchestration context. + + Returns: + RunRequest: The current run request + """ + request = super().get_run_request( + message, + response_format, + enable_tool_calls, + ) + request.orchestration_id = self._context.instance_id + return request + def run_durable_agent( self, agent_name: str, @@ -410,7 +429,39 @@ def run_durable_agent( ) -> DurableAgentTask: """Execute the agent via orchestration context. - Note: Implementation should call the entity (e.g., context.call_entity) - and return the native Task for yielding. Placeholder until wired. + Calls the agent entity and returns a DurableAgentTask that can be yielded + in orchestrations to wait for the entity's response. + + Args: + agent_name: Name of the agent to execute + run_request: The run request containing message and optional response format + thread: Optional conversation thread (creates new if not provided) + + Returns: + DurableAgentTask: A task wrapping the entity call that yields AgentRunResponse """ - raise NotImplementedError("OrchestrationAgentProvider.run_durable_agent is not yet implemented") + # Resolve session + session_id = self._create_session_id(agent_name, thread) + + # Create the entity ID + entity_id = EntityInstanceId( + entity=session_id.entity_name, + key=session_id.key, + ) + + logger.debug( + "[OrchestrationAgentExecutor] correlation_id: %s entity_id: %s session_id: %s", + run_request.correlation_id, + entity_id, + session_id, + ) + + # Call the entity and get the underlying task + entity_task: Task[Any] = self._context.call_entity(entity_id, "run", run_request.to_dict()) + + # Wrap in DurableAgentTask for response transformation + return DurableAgentTask( + entity_task=entity_task, + response_format=run_request.response_format, + correlation_id=run_request.correlation_id, + ) diff --git a/python/packages/durabletask/agent_framework_durabletask/_response_utils.py b/python/packages/durabletask/agent_framework_durabletask/_response_utils.py index ec7ed50c0c..aeb0e19c6c 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_response_utils.py +++ b/python/packages/durabletask/agent_framework_durabletask/_response_utils.py @@ -1,10 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. -"""Shared utilities for handling AgentRunResponse parsing and validation. - -These utilities are used by both DurableAgentTask and AzureFunctions AgentTask -to ensure consistent response handling across different durable task implementations. -""" +"""Shared utilities for handling AgentRunResponse parsing and validation.""" from typing import Any diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index e469011753..c2e9aee039 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -54,7 +54,8 @@ class DurableAIAgent(AgentProtocol, Generic[TaskT]): This class implements AgentProtocol but with one critical difference: - AgentProtocol.run() returns a Coroutine (async, must await) - - DurableAIAgent.run() returns TaskT (sync Task object, must yield) + - DurableAIAgent.run() returns TaskT (sync Task object - must yield + or the AgentRunResponse directly in the case of TaskHubGrpcClient) This represents fundamentally different execution models but maintains the same interface contract for all other properties and methods. @@ -63,13 +64,7 @@ class DurableAIAgent(AgentProtocol, Generic[TaskT]): and what type of Task object is returned. Type Parameters: - TaskT: The task type returned by this agent (e.g., DurableAgentTask, AgentTask) - - Note: - This class intentionally does NOT inherit from BaseAgent because: - - BaseAgent assumes async/await patterns - - Orchestration contexts require yield patterns - - BaseAgent methods like as_tool() would fail in orchestrations + TaskT: The task type returned by this agent (e.g., AgentRunResponse, DurableAgentTask, AgentTask) """ def __init__(self, executor: DurableAgentExecutor[TaskT], name: str, *, agent_id: str | None = None): @@ -125,7 +120,7 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride] instead of async/await patterns. Returns: - TaskT: The task type specific to the executor (e.g., DurableAgentTask or AgentTask) + TaskT: The task type specific to the executor """ message_str = self._normalize_messages(messages) diff --git a/python/packages/durabletask/tests/test_executors.py b/python/packages/durabletask/tests/test_executors.py index 8b425962b0..e518528832 100644 --- a/python/packages/durabletask/tests/test_executors.py +++ b/python/packages/durabletask/tests/test_executors.py @@ -7,35 +7,89 @@ """ import time +from typing import Any from unittest.mock import Mock import pytest -from agent_framework import AgentRunResponse +from agent_framework import AgentRunResponse, Role +from durabletask.entities import EntityInstanceId +from durabletask.task import Task +from pydantic import BaseModel from agent_framework_durabletask import DurableAgentThread from agent_framework_durabletask._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS from agent_framework_durabletask._executors import ( ClientAgentExecutor, + DurableAgentTask, OrchestrationAgentExecutor, ) -from agent_framework_durabletask._models import RunRequest +from agent_framework_durabletask._models import AgentSessionId, RunRequest + + +# Fixtures +@pytest.fixture +def mock_client() -> Mock: + """Provide a mock client for ClientAgentExecutor tests.""" + client = Mock() + client.signal_entity = Mock() + client.get_entity = Mock(return_value=None) + return client + + +@pytest.fixture +def mock_entity_task() -> Mock: + """Provide a mock entity task.""" + return Mock(spec=Task) + + +@pytest.fixture +def mock_orchestration_context(mock_entity_task: Mock) -> Mock: + """Provide a mock orchestration context with call_entity configured.""" + context = Mock() + context.call_entity = Mock(return_value=mock_entity_task) + return context + + +@pytest.fixture +def sample_run_request() -> RunRequest: + """Provide a sample RunRequest for tests.""" + return RunRequest(message="test message", correlation_id="test-123") + + +@pytest.fixture +def client_executor(mock_client: Mock) -> ClientAgentExecutor: + """Provide a ClientAgentExecutor with minimal polling for fast tests.""" + return ClientAgentExecutor(mock_client, max_poll_retries=1, poll_interval_seconds=0.01) + + +@pytest.fixture +def orchestration_executor(mock_orchestration_context: Mock) -> OrchestrationAgentExecutor: + """Provide an OrchestrationAgentExecutor.""" + return OrchestrationAgentExecutor(mock_orchestration_context) + + +@pytest.fixture +def successful_agent_response() -> dict[str, Any]: + """Provide a successful agent response dictionary.""" + return { + "messages": [{"role": "assistant", "contents": [{"type": "text", "text": "Hello!"}]}], + "created_at": "2025-12-30T10:00:00Z", + } class TestExecutorThreadCreation: """Test that executors properly create DurableAgentThread with parameters.""" - def test_client_executor_creates_durable_thread(self) -> None: + def test_client_executor_creates_durable_thread(self, mock_client: Mock) -> None: """Verify ClientAgentExecutor creates DurableAgentThread instances.""" - mock_client = Mock() executor = ClientAgentExecutor(mock_client) thread = executor.get_new_thread("test_agent") assert isinstance(thread, DurableAgentThread) - def test_client_executor_forwards_kwargs_to_thread(self) -> None: + def test_client_executor_forwards_kwargs_to_thread(self, mock_client: Mock) -> None: """Verify ClientAgentExecutor forwards kwargs to DurableAgentThread creation.""" - mock_client = Mock() executor = ClientAgentExecutor(mock_client) thread = executor.get_new_thread("test_agent", service_thread_id="client-123") @@ -43,93 +97,62 @@ def test_client_executor_forwards_kwargs_to_thread(self) -> None: assert isinstance(thread, DurableAgentThread) assert thread.service_thread_id == "client-123" - def test_orchestration_executor_creates_durable_thread(self) -> None: + def test_orchestration_executor_creates_durable_thread( + self, orchestration_executor: OrchestrationAgentExecutor + ) -> None: """Verify OrchestrationAgentExecutor creates DurableAgentThread instances.""" - mock_context = Mock() - executor = OrchestrationAgentExecutor(mock_context) - - thread = executor.get_new_thread("test_agent") + thread = orchestration_executor.get_new_thread("test_agent") assert isinstance(thread, DurableAgentThread) - def test_orchestration_executor_forwards_kwargs_to_thread(self) -> None: + def test_orchestration_executor_forwards_kwargs_to_thread( + self, orchestration_executor: OrchestrationAgentExecutor + ) -> None: """Verify OrchestrationAgentExecutor forwards kwargs to DurableAgentThread creation.""" - mock_context = Mock() - executor = OrchestrationAgentExecutor(mock_context) - - thread = executor.get_new_thread("test_agent", service_thread_id="orch-456") + thread = orchestration_executor.get_new_thread("test_agent", service_thread_id="orch-456") assert isinstance(thread, DurableAgentThread) assert thread.service_thread_id == "orch-456" -class TestExecutorRunNotImplemented: - """Test that run_durable_agent works as implemented.""" +class TestClientAgentExecutorRun: + """Test that ClientAgentExecutor.run_durable_agent works as implemented.""" - def test_client_executor_run_returns_response(self) -> None: + def test_client_executor_run_returns_response( + self, client_executor: ClientAgentExecutor, sample_run_request: RunRequest + ) -> None: """Verify ClientAgentExecutor.run_durable_agent returns AgentRunResponse (synchronous).""" - mock_client = Mock() - mock_client.signal_entity = Mock() - mock_client.get_entity = Mock(return_value=None) - # Use minimal polling parameters to avoid long test times - executor = ClientAgentExecutor(mock_client, max_poll_retries=1, poll_interval_seconds=0.01) - - # Create a RunRequest - run_request = RunRequest(message="test message", correlation_id="test-123") - - # This should return a timeout response (since mock doesn't have state) - result = executor.run_durable_agent("test_agent", run_request) + result = client_executor.run_durable_agent("test_agent", sample_run_request) # Verify it returns an AgentRunResponse (synchronous, not a coroutine) assert isinstance(result, AgentRunResponse) assert result is not None - def test_orchestration_executor_run_not_implemented(self) -> None: - """Verify OrchestrationAgentExecutor run raises NotImplementedError until implementation.""" - mock_context = Mock() - executor = OrchestrationAgentExecutor(mock_context) - - # Create a RunRequest - run_request = RunRequest(message="test message", correlation_id="test-123") - - with pytest.raises(NotImplementedError, match="OrchestrationAgentProvider.run_durable_agent"): - executor.run_durable_agent("test_agent", run_request) - class TestClientAgentExecutorPollingConfiguration: """Test polling configuration parameters for ClientAgentExecutor.""" - def test_executor_uses_default_polling_parameters(self) -> None: + def test_executor_uses_default_polling_parameters(self, mock_client: Mock) -> None: """Verify executor initializes with default polling parameters.""" - mock_client = Mock() executor = ClientAgentExecutor(mock_client) assert executor.max_poll_retries == DEFAULT_MAX_POLL_RETRIES assert executor.poll_interval_seconds == DEFAULT_POLL_INTERVAL_SECONDS - def test_executor_accepts_custom_polling_parameters(self) -> None: + def test_executor_accepts_custom_polling_parameters(self, mock_client: Mock) -> None: """Verify executor accepts and stores custom polling parameters.""" - mock_client = Mock() executor = ClientAgentExecutor(mock_client, max_poll_retries=20, poll_interval_seconds=0.5) assert executor.max_poll_retries == 20 assert executor.poll_interval_seconds == 0.5 - def test_executor_respects_custom_max_poll_retries(self) -> None: + def test_executor_respects_custom_max_poll_retries(self, mock_client: Mock, sample_run_request: RunRequest) -> None: """Verify executor respects custom max_poll_retries during polling.""" - - mock_client = Mock() - mock_client.signal_entity = Mock() - mock_client.get_entity = Mock(return_value=None) - # Create executor with only 2 retries executor = ClientAgentExecutor(mock_client, max_poll_retries=2, poll_interval_seconds=0.01) - # Create a RunRequest - run_request = RunRequest(message="test message", correlation_id="test-123") - # Run the agent - result = executor.run_durable_agent("test_agent", run_request) + result = executor.run_durable_agent("test_agent", sample_run_request) # Verify it returns AgentRunResponse (should timeout after 2 attempts) assert isinstance(result, AgentRunResponse) @@ -137,22 +160,14 @@ def test_executor_respects_custom_max_poll_retries(self) -> None: # Verify get_entity was called 2 times (max_poll_retries) assert mock_client.get_entity.call_count == 2 - def test_executor_respects_custom_poll_interval(self) -> None: + def test_executor_respects_custom_poll_interval(self, mock_client: Mock, sample_run_request: RunRequest) -> None: """Verify executor respects custom poll_interval_seconds during polling.""" - - mock_client = Mock() - mock_client.signal_entity = Mock() - mock_client.get_entity = Mock(return_value=None) - # Create executor with very short interval executor = ClientAgentExecutor(mock_client, max_poll_retries=3, poll_interval_seconds=0.01) - # Create a RunRequest - run_request = RunRequest(message="test message", correlation_id="test-123") - # Measure time taken start = time.time() - result = executor.run_durable_agent("test_agent", run_request) + result = executor.run_durable_agent("test_agent", sample_run_request) elapsed = time.time() - start # Should take roughly 3 * 0.01 = 0.03 seconds (plus overhead) @@ -161,5 +176,145 @@ def test_executor_respects_custom_poll_interval(self) -> None: assert isinstance(result, AgentRunResponse) +class TestOrchestrationAgentExecutorRun: + """Test OrchestrationAgentExecutor.run_durable_agent implementation.""" + + def test_orchestration_executor_run_returns_durable_agent_task( + self, orchestration_executor: OrchestrationAgentExecutor, sample_run_request: RunRequest + ) -> None: + """Verify OrchestrationAgentExecutor.run_durable_agent returns DurableAgentTask.""" + result = orchestration_executor.run_durable_agent("test_agent", sample_run_request) + + assert isinstance(result, DurableAgentTask) + + def test_orchestration_executor_calls_entity_with_correct_parameters( + self, + mock_orchestration_context: Mock, + orchestration_executor: OrchestrationAgentExecutor, + sample_run_request: RunRequest, + ) -> None: + """Verify call_entity is invoked with correct entity ID and request.""" + orchestration_executor.run_durable_agent("test_agent", sample_run_request) + + # Verify call_entity was called once + assert mock_orchestration_context.call_entity.call_count == 1 + + # Get the call arguments + call_args = mock_orchestration_context.call_entity.call_args + entity_id_arg = call_args[0][0] + operation_arg = call_args[0][1] + request_dict_arg = call_args[0][2] + + # Verify entity ID + assert isinstance(entity_id_arg, EntityInstanceId) + assert entity_id_arg.entity.startswith("test_agent-") + + # Verify operation name + assert operation_arg == "run" + + # Verify request dict + assert request_dict_arg == sample_run_request.to_dict() + + def test_orchestration_executor_uses_thread_session_id( + self, + mock_orchestration_context: Mock, + orchestration_executor: OrchestrationAgentExecutor, + sample_run_request: RunRequest, + ) -> None: + """Verify executor uses thread's session ID when provided.""" + # Create thread with specific session ID + session_id = AgentSessionId(name="test_agent", key="specific-key-123") + thread = DurableAgentThread.from_session_id(session_id) + + result = orchestration_executor.run_durable_agent("test_agent", sample_run_request, thread=thread) + + # Verify call_entity was called with the specific key + call_args = mock_orchestration_context.call_entity.call_args + entity_id_arg = call_args[0][0] + + assert entity_id_arg.key == "specific-key-123" + assert isinstance(result, DurableAgentTask) + + +class TestDurableAgentTask: + """Test DurableAgentTask completion and response transformation.""" + + def test_durable_agent_task_transforms_successful_result( + self, mock_entity_task: Mock, successful_agent_response: dict[str, Any] + ) -> None: + """Verify DurableAgentTask converts successful entity result to AgentRunResponse.""" + mock_entity_task.is_failed = False + mock_entity_task.get_result = Mock(return_value=successful_agent_response) + + task = DurableAgentTask(entity_task=mock_entity_task, response_format=None, correlation_id="test-123") + + # Simulate child task completion + task.on_child_completed(mock_entity_task) + + assert task.is_complete + result = task.get_result() + assert isinstance(result, AgentRunResponse) + assert len(result.messages) == 1 + assert result.messages[0].role == Role.ASSISTANT + + def test_durable_agent_task_propagates_failure(self, mock_entity_task: Mock) -> None: + """Verify DurableAgentTask propagates task failures.""" + mock_entity_task.is_failed = True + mock_entity_task.get_exception = Mock(return_value=ValueError("Entity error")) + + task = DurableAgentTask(entity_task=mock_entity_task, response_format=None, correlation_id="test-123") + + # Simulate child task completion with failure + task.on_child_completed(mock_entity_task) + + assert task.is_complete + assert task.is_failed + exception = task.get_exception() + assert isinstance(exception, ValueError) + assert str(exception) == "Entity error" + + def test_durable_agent_task_validates_response_format(self, mock_entity_task: Mock) -> None: + """Verify DurableAgentTask validates response format when provided.""" + mock_entity_task.is_failed = False + mock_entity_task.get_result = Mock( + return_value={ + "messages": [{"role": "assistant", "contents": [{"type": "text", "text": '{"answer": "42"}'}]}], + "created_at": "2025-12-30T10:00:00Z", + } + ) + + class TestResponse(BaseModel): + answer: str + + task = DurableAgentTask(entity_task=mock_entity_task, response_format=TestResponse, correlation_id="test-123") + + # Simulate child task completion + task.on_child_completed(mock_entity_task) + + assert task.is_complete + result = task.get_result() + assert isinstance(result, AgentRunResponse) + + def test_durable_agent_task_ignores_duplicate_completion( + self, mock_entity_task: Mock, successful_agent_response: dict[str, Any] + ) -> None: + """Verify DurableAgentTask ignores duplicate completion calls.""" + mock_entity_task.is_failed = False + mock_entity_task.get_result = Mock(return_value=successful_agent_response) + + task = DurableAgentTask(entity_task=mock_entity_task, response_format=None, correlation_id="test-123") + + # Simulate child task completion twice + task.on_child_completed(mock_entity_task) + first_result = task.get_result() + + task.on_child_completed(mock_entity_task) + second_result = task.get_result() + + # Should be the same result, get_result should only be called once + assert first_result is second_result + assert mock_entity_task.get_result.call_count == 1 + + if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/README.md b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/README.md new file mode 100644 index 0000000000..f11ab7aaa0 --- /dev/null +++ b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/README.md @@ -0,0 +1,317 @@ +# Multi-Agent Orchestration with Concurrency Sample (Python) - Durable Task + +This sample demonstrates how to use the **Durable Task package** with **OrchestrationAgentExecutor** to orchestrate multiple AI agents running concurrently and aggregate their responses. + +## Description of the Sample + +This sample shows how to run two domain-specific agents (a Physicist and a Chemist) concurrently using Durable Task orchestration. The agents respond to the same prompt from their respective domain perspectives, demonstrating: + +- How to register multiple agents as durable entities +- How to create an orchestration function that executes agents in parallel +- How to use `OrchestrationAgentExecutor` for concurrent agent execution +- How to aggregate results from multiple agents +- The benefits of concurrent execution for improved performance + +## Key Concepts Demonstrated + +- **Multi-Agent Architecture**: Running multiple specialized agents in parallel +- **Orchestration Functions**: Using Durable Task orchestrations to coordinate agent execution +- **OrchestrationAgentExecutor**: Execution strategy for orchestrations that returns `DurableAgentTask` objects +- **Concurrent Execution**: Using `task.when_all()` to run agents in parallel +- **Result Aggregation**: Collecting and combining responses from multiple agents +- **Thread Management**: Creating separate conversation threads for each agent +- **BYOP (Bring Your Own Platform)**: Not tied to Azure Functions - run anywhere with Durable Task Scheduler + +## Architecture Overview + +This sample uses a **worker-orchestration-client architecture**: + +1. **Worker Process** (`worker.py`): Registers two agents (Physicist and Chemist) as durable entities and an orchestration function +2. **Orchestration Function**: Coordinates concurrent execution of both agents and aggregates results +3. **Client Process** (`client.py`): Starts the orchestration and retrieves aggregated results +4. **Durable Task Scheduler**: Coordinates communication and orchestration execution (runs separately) + +### Execution Flow + +``` +Client → Start Orchestration → Orchestration Context + ↓ + ┌─────────────┴─────────────┐ + ↓ ↓ + Physicist Agent Chemist Agent + (Concurrent) (Concurrent) + ↓ ↓ + └─────────────┬─────────────┘ + ↓ + Aggregate Results → Client +``` + +## What Makes This Different? + +This sample differs from the single agent sample in several key ways: + +1. **Multiple Agents**: Two specialized agents with different domain expertise +2. **Concurrent Execution**: Both agents run simultaneously using `task.when_all()` +3. **Orchestration Function**: Uses a Durable Task orchestrator to coordinate execution +4. **OrchestrationAgentExecutor**: Different execution strategy that returns tasks instead of blocking +5. **Result Aggregation**: Combines responses from multiple agents into a single result + +## Prerequisites + +### 1. Python 3.9+ + +Ensure you have Python 3.9 or later installed. + +### 2. Azure OpenAI Setup + +Configure your Azure OpenAI credentials: +- Set `AZURE_OPENAI_ENDPOINT` environment variable +- Set `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME` environment variable +- Either: + - Set `AZURE_OPENAI_API_KEY` environment variable, OR + - Run `az login` to authenticate with Azure CLI + +### 3. Install Dependencies + +Install the required packages: + +```bash +pip install -r requirements.txt +``` + +Or if using uv: + +```bash +uv pip install -r requirements.txt +``` + +### 4. Durable Task Scheduler + +The sample requires a Durable Task Scheduler running. For local development, use the emulator: + +#### Using the Emulator (Recommended for Local Development) + +1. Pull the Docker Image for the Emulator: + ```bash + docker pull mcr.microsoft.com/dts/dts-emulator:latest + ``` + +2. Run the Emulator: + ```bash + docker run --name dtsemulator -d -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest + ``` + Wait a few seconds for the container to be ready. + +## Running the Sample + +### Option 1: Combined Worker + Client (Recommended for Quick Start) + +The easiest way to run the sample is using `sample.py`, which runs both worker and client in a single process: + +```bash +python sample.py +``` + +This will: +1. Start the worker and register both agents and the orchestration +2. Start the orchestration with a sample prompt +3. Wait for completion and display aggregated results +4. Shut down the worker + +### Option 2: Separate Worker and Client + +For a more realistic distributed setup, run the worker and client separately: + +1. **Start the worker** in one terminal: + ```bash + python worker.py + ``` + + You should see output indicating both agents are registered: + ``` + INFO:__main__:Starting Durable Task Multi-Agent Worker with Orchestration... + INFO:__main__:Using taskhub: default + INFO:__main__:Using endpoint: http://localhost:8080 + INFO:__main__:Creating and registering agents... + INFO:__main__:✓ Registered agent: PhysicistAgent + INFO:__main__: Entity name: dafx-PhysicistAgent + INFO:__main__:✓ Registered agent: ChemistAgent + INFO:__main__: Entity name: dafx-ChemistAgent + INFO:__main__:✓ Registered orchestration: multi_agent_concurrent_orchestration + INFO:__main__:Worker is ready and listening for requests... + ``` + +2. **In a new terminal**, run the client: + ```bash + python client.py + ``` + + The client will start the orchestration and wait for results. + +## Understanding the Output + +### Worker Output + +The worker shows: +- Connection information (taskhub and endpoint) +- Registration of both agents (Physicist and Chemist) as durable entities +- Registration of the orchestration function +- Status messages during orchestration execution + +Example: +``` +INFO:__main__:Starting Durable Task Multi-Agent Worker with Orchestration... +INFO:__main__:Using taskhub: default +INFO:__main__:Using endpoint: http://localhost:8080 +INFO:__main__:Creating and registering agents... +INFO:__main__:✓ Registered agent: PhysicistAgent +INFO:__main__: Entity name: dafx-PhysicistAgent +INFO:__main__:✓ Registered agent: ChemistAgent +INFO:__main__: Entity name: dafx-ChemistAgent +INFO:__main__:✓ Registered orchestration: multi_agent_concurrent_orchestration +INFO:__main__:Worker is ready and listening for requests... +``` + +### Client Output + +The client shows: +- The prompt sent to both agents +- Orchestration instance ID +- Status updates during execution +- Aggregated results from both agents + +Example: +``` +INFO:__main__:Starting Durable Task Multi-Agent Orchestration Client... +INFO:__main__:Using taskhub: default +INFO:__main__:Using endpoint: http://localhost:8080 + +INFO:__main__:Prompt: What is temperature? + +INFO:__main__:Starting multi-agent concurrent orchestration... +INFO:__main__:Orchestration started with instance ID: abc123... +INFO:__main__:Waiting for orchestration to complete... + +INFO:__main__:Orchestration status: COMPLETED +================================================================================ +Orchestration completed successfully! +================================================================================ + +Prompt: What is temperature? + +Results: + +Physicist's response: + Temperature measures the average kinetic energy of particles in a system... + +Chemist's response: + Temperature reflects how molecular motion influences reaction rates... + +================================================================================ +``` + +### Orchestration Output + +During execution, the orchestration logs show: +- When concurrent execution starts +- Thread creation for each agent +- Task creation and execution +- Completion and result aggregation + +Example: +``` +INFO:__main__:[Orchestration] Starting concurrent execution for prompt: What is temperature? +INFO:__main__:[Orchestration] Created threads - Physicist: session-123, Chemist: session-456 +INFO:__main__:[Orchestration] Created agent tasks, executing concurrently... +INFO:__main__:[Orchestration] Both agents completed +INFO:__main__:[Orchestration] Aggregated results ready +``` + +## How It Works + +### 1. Agent Registration + +Both agents are registered as durable entities in the worker: + +```python +agent_worker.add_agent(physicist_agent) +agent_worker.add_agent(chemist_agent) +``` + +### 2. Orchestration Registration + +The orchestration function is registered with the worker: + +```python +worker.add_orchestrator(multi_agent_concurrent_orchestration) +``` + +### 3. Orchestration Execution + +The orchestration uses `OrchestrationAgentExecutor` (implicitly through `context.get_agent()`): + +```python +@task.orchestrator +def multi_agent_concurrent_orchestration(context: OrchestrationContext): + # Get agents (uses OrchestrationAgentExecutor internally) + physicist = context.get_agent(PHYSICIST_AGENT_NAME) + chemist = context.get_agent(CHEMIST_AGENT_NAME) + + # Create tasks (returns DurableAgentTask instances) + physicist_task = physicist.run(messages=prompt, thread=physicist_thread) + chemist_task = chemist.run(messages=prompt, thread=chemist_thread) + + # Execute concurrently + task_results = yield task.when_all([physicist_task, chemist_task]) + + # Aggregate results + return { + "physicist": task_results[0].text, + "chemist": task_results[1].text, + } +``` + +### 4. Key Differences from Single Agent + +- **OrchestrationAgentExecutor**: Returns `DurableAgentTask` instead of `AgentRunResponse` +- **Concurrent Execution**: Uses `task.when_all()` for parallel execution +- **Yield Syntax**: Orchestrations use `yield` to await async operations +- **Result Aggregation**: Combines multiple agent responses into one result + +## Customization + +You can modify the sample to: + +1. **Change the prompt**: Edit the `prompt` variable in `sample.py` or `client.py` +2. **Add more agents**: Create additional agents and add them to the orchestration +3. **Change agent instructions**: Modify the `instructions` parameter when creating agents +4. **Adjust concurrency**: Use different task combination patterns (e.g., sequential, selective) +5. **Add error handling**: Implement retry logic or fallback agents + +## Troubleshooting + +### Orchestration times out +- Increase `max_wait_time` in the client code +- Check that both agents are properly registered in the worker +- Verify Azure OpenAI endpoint and credentials + +### Agents return errors +- Verify Azure OpenAI deployment name is correct +- Check Azure OpenAI quota and rate limits +- Review worker logs for detailed error messages + +### Connection errors +- Ensure the Durable Task Scheduler is running +- Verify `ENDPOINT` and `TASKHUB` environment variables +- Check network connectivity and firewall rules + +## Related Samples + +- **01_single_agent**: Basic single agent setup +- **Azure Functions 05_multi_agent_orchestration_concurrency**: Similar pattern using Azure Functions + +## Learn More + +- [Agent Framework Documentation](https://github.com/microsoft/agent-framework) +- [Durable Task Framework](https://github.com/microsoft/durabletask) +- [Azure OpenAI Service](https://azure.microsoft.com/en-us/products/cognitive-services/openai-service) diff --git a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/client.py b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/client.py new file mode 100644 index 0000000000..2e517a8b6a --- /dev/null +++ b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/client.py @@ -0,0 +1,114 @@ +"""Client application for starting a multi-agent concurrent orchestration. + +This client connects to the Durable Task Scheduler and starts an orchestration +that runs two agents (physicist and chemist) concurrently, then retrieves and +displays the aggregated results. + +Prerequisites: +- The worker must be running with both agents and orchestration registered +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Durable Task Scheduler must be running +""" + +import asyncio +import json +import logging +import os + +from azure.identity import DefaultAzureCredential +from durabletask.azuremanaged.client import DurableTaskSchedulerClient + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def main() -> None: + """Main entry point for the client application.""" + logger.info("Starting Durable Task Multi-Agent Orchestration Client...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + logger.info("") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + + # Create a client using Azure Managed Durable Task + client = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + # Define the prompt to send to both agents + prompt = "What is temperature?" + + logger.info(f"Prompt: {prompt}") + logger.info("") + logger.info("Starting multi-agent concurrent orchestration...") + + try: + # Start the orchestration with the prompt as input + instance_id = client.schedule_new_orchestration( + orchestrator="multi_agent_concurrent_orchestration", + input=prompt, + ) + + logger.info(f"Orchestration started with instance ID: {instance_id}") + logger.info("Waiting for orchestration to complete...") + logger.info("") + + # Retrieve the final state + metadata = client.wait_for_orchestration_completion( + instance_id=instance_id, + ) + + if metadata and metadata.runtime_status.name == "COMPLETED": + result = metadata.serialized_output + + logger.info("=" * 80) + logger.info("Orchestration completed successfully!") + logger.info("=" * 80) + logger.info("") + logger.info(f"Prompt: {prompt}") + logger.info("") + logger.info("Results:") + logger.info("") + + # Parse and display the result + if result: + result_dict = json.loads(result) + + logger.info("Physicist's response:") + logger.info(f" {result_dict.get('physicist', 'N/A')}") + logger.info("") + + logger.info("Chemist's response:") + logger.info(f" {result_dict.get('chemist', 'N/A')}") + logger.info("") + + logger.info("=" * 80) + + elif metadata: + logger.error(f"Orchestration ended with status: {metadata.runtime_status.name}") + if metadata.serialized_output: + logger.error(f"Output: {metadata.serialized_output}") + else: + logger.error("Orchestration did not complete within the timeout period") + + except Exception as e: + logger.exception(f"Error during orchestration: {e}") + finally: + logger.info("") + logger.info("Client shutting down") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/requirements.txt b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/requirements.txt new file mode 100644 index 0000000000..da871507c8 --- /dev/null +++ b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/requirements.txt @@ -0,0 +1,9 @@ +# Agent Framework packages +agent-framework-azure +agent-framework-durabletask + +# Durable Task Python SDK with Azure Managed support +durabletask-python + +# Azure authentication +azure-identity diff --git a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/worker.py b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/worker.py new file mode 100644 index 0000000000..bb35b5a8d4 --- /dev/null +++ b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/worker.py @@ -0,0 +1,175 @@ +"""Worker process for hosting multiple agents with orchestration using Durable Task. + +This worker registers two domain-specific agents (physicist and chemist) and an orchestration +function that runs them concurrently. The orchestration uses OrchestrationAgentExecutor +to execute agents in parallel and aggregate their responses. + +Prerequisites: +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Start a Durable Task Scheduler (e.g., using Docker) +""" + +import asyncio +from collections.abc import Generator +import logging +import os +from typing import Any + +from agent_framework import AgentRunResponse +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker +from azure.identity import AzureCliCredential, DefaultAzureCredential +from durabletask.task import OrchestrationContext, when_all, Task +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Agent names +PHYSICIST_AGENT_NAME = "PhysicistAgent" +CHEMIST_AGENT_NAME = "ChemistAgent" + + +def create_physicist_agent(): + """Create the Physicist agent using Azure OpenAI. + + Returns: + AgentProtocol: The configured Physicist agent + """ + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name=PHYSICIST_AGENT_NAME, + instructions="You are an expert in physics. You answer questions from a physics perspective.", + ) + + +def create_chemist_agent(): + """Create the Chemist agent using Azure OpenAI. + + Returns: + AgentProtocol: The configured Chemist agent + """ + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name=CHEMIST_AGENT_NAME, + instructions="You are an expert in chemistry. You answer questions from a chemistry perspective.", + ) + + +def multi_agent_concurrent_orchestration(context: OrchestrationContext, prompt: str) -> Generator[Task[Any], Any, dict[str, str]]: + """Orchestration that runs both agents in parallel and aggregates results. + + Uses DurableAIAgentOrchestrationContext to wrap the orchestration context and + access agents via the OrchestrationAgentExecutor. + + Args: + context: The orchestration context + + Returns: + dict: Dictionary with 'physicist' and 'chemist' response texts + """ + + logger.info(f"[Orchestration] Starting concurrent execution for prompt: {prompt}") + + # Wrap the orchestration context to access agents + agent_context = DurableAIAgentOrchestrationContext(context) + + # Get agents using the agent context (returns DurableAIAgent proxies) + physicist = agent_context.get_agent(PHYSICIST_AGENT_NAME) + chemist = agent_context.get_agent(CHEMIST_AGENT_NAME) + + # Create separate threads for each agent + physicist_thread = physicist.get_new_thread() + chemist_thread = chemist.get_new_thread() + + logger.info(f"[Orchestration] Created threads - Physicist: {physicist_thread.session_id}, Chemist: {chemist_thread.session_id}") + + # Create tasks from agent.run() calls - these return DurableAgentTask instances + physicist_task = physicist.run(messages=str(prompt), thread=physicist_thread) + chemist_task = chemist.run(messages=str(prompt), thread=chemist_thread) + + logger.info("[Orchestration] Created agent tasks, executing concurrently...") + + # Execute both tasks concurrently using when_all + # The DurableAgentTask instances wrap the underlying entity calls + task_results = yield when_all([physicist_task, chemist_task]) + + logger.info("[Orchestration] Both agents completed") + + # Extract results from the tasks - DurableAgentTask yields AgentRunResponse + physicist_result: AgentRunResponse = task_results[0] + chemist_result: AgentRunResponse = task_results[1] + + result = { + "physicist": physicist_result.text, + "chemist": chemist_result.text, + } + + logger.info(f"[Orchestration] Aggregated results ready") + return result + + +async def main(): + """Main entry point for the worker process.""" + logger.info("Starting Durable Task Multi-Agent Worker with Orchestration...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + + # Create a worker using Azure Managed Durable Task + worker = DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + # Wrap it with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Create and register both agents + logger.info("Creating and registering agents...") + physicist_agent = create_physicist_agent() + chemist_agent = create_chemist_agent() + + agent_worker.add_agent(physicist_agent) + agent_worker.add_agent(chemist_agent) + + logger.info(f"✓ Registered agent: {physicist_agent.name}") + logger.info(f" Entity name: dafx-{physicist_agent.name}") + logger.info(f"✓ Registered agent: {chemist_agent.name}") + logger.info(f" Entity name: dafx-{chemist_agent.name}") + logger.info("") + + # Register the orchestration function + logger.info("Registering orchestration function...") + worker.add_orchestrator(multi_agent_concurrent_orchestration) + logger.info(f"✓ Registered orchestration: {multi_agent_concurrent_orchestration.__name__}") + logger.info("") + + logger.info("Worker is ready and listening for requests...") + logger.info("Press Ctrl+C to stop.") + logger.info("") + + try: + # Start the worker (this blocks until stopped) + worker.start() + + # Keep the worker running + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Worker shutdown initiated") + + logger.info("Worker stopped") + + +if __name__ == "__main__": + asyncio.run(main()) From 5f926aa00905208ba7ab8ec317392fbd0eb498d3 Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Wed, 31 Dec 2025 12:34:58 -0800 Subject: [PATCH 05/11] Update readmes --- .gitignore | 5 +- .../durabletask/01_single_agent/README.md | 385 ++---------------- .../durabletask/01_single_agent/sample.py | 137 +++++++ .../README.md | 72 ++++ .../client.py | 104 +++++ .../requirements.txt | 9 + .../sample.py | 255 ++++++++++++ .../worker.py | 167 ++++++++ .../README.md | 310 ++------------ .../sample.py | 266 ++++++++++++ .../getting_started/durabletask/README.md | 124 ++++++ 11 files changed, 1201 insertions(+), 633 deletions(-) create mode 100644 python/samples/getting_started/durabletask/01_single_agent/sample.py create mode 100644 python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/README.md create mode 100644 python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/client.py create mode 100644 python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/requirements.txt create mode 100644 python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/sample.py create mode 100644 python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/worker.py create mode 100644 python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/sample.py create mode 100644 python/samples/getting_started/durabletask/README.md diff --git a/.gitignore b/.gitignore index 6d6f957b2a..e4ebc70b54 100644 --- a/.gitignore +++ b/.gitignore @@ -225,7 +225,4 @@ local.settings.json **/frontend/dist/ # Database files -*.db - -# Sample files -**/sample.py +*.db \ No newline at end of file diff --git a/python/samples/getting_started/durabletask/01_single_agent/README.md b/python/samples/getting_started/durabletask/01_single_agent/README.md index f18d32e573..6e31b0a737 100644 --- a/python/samples/getting_started/durabletask/01_single_agent/README.md +++ b/python/samples/getting_started/durabletask/01_single_agent/README.md @@ -1,387 +1,66 @@ -# Single Agent Sample (Python) - Durable Task +# Single Agent Sample -This sample demonstrates how to use the **Durable Task package** for Agent Framework to create a simple agent hosting setup with persistent conversation state and distributed execution capabilities. - -## Description of the Sample - -This sample shows how to host a single AI agent (the "Joker" agent) using the Durable Task Scheduler. The agent responds to user messages by telling jokes, demonstrating: - -- How to register agents as durable entities that can persist state -- How to interact with registered agents from external clients -- How to maintain conversation context across multiple interactions -- The worker-client architecture pattern for distributed agent execution +This sample demonstrates how to use the durable agents extension to create a worker-client setup that hosts a single AI agent and provides interactive conversation via the Durable Task Scheduler. ## Key Concepts Demonstrated -- **Worker Registration**: Using `DurableAIAgentWorker` to register agents as durable entities that can process requests -- **Client Interaction**: Using `DurableAIAgentClient` to send messages to registered agents from external contexts -- **Thread Management**: Creating and maintaining conversation threads for stateful interactions -- **Distributed Architecture**: Separating worker (agent host) and client (caller) into independent processes -- **BYOP (Bring Your Own Platform)**: Not tied to Azure Functions - run anywhere with Durable Task Scheduler - -## Architecture Overview - -This sample uses a **client-worker architecture**: - -1. **Worker Process** (`worker.py`): Registers agents as durable entities and continuously listens for requests -2. **Client Process** (`client.py`): Connects to the same scheduler and sends requests to agents by name -3. **Durable Task Scheduler**: Coordinates communication between clients and workers (runs separately) - -This architecture enables: -- **Scalability**: Multiple workers can process requests in parallel -- **Reliability**: State is persisted, so conversations survive process restarts -- **Flexibility**: Clients and workers can be on different machines -- **BYOP (Bring Your Own Platform)**: Not tied to Azure Functions - run anywhere - -## Prerequisites - -### 1. Python 3.9+ - -Ensure you have Python 3.9 or later installed. - -### 2. Azure OpenAI Setup +- Using the Microsoft Agent Framework to define a simple AI agent with a name and instructions. +- Registering durable agents with the worker and interacting with them via a client. +- Conversation management (via threads) for isolated interactions. +- Worker-client architecture for distributed agent execution. -Configure your Azure OpenAI credentials: -- Set `AZURE_OPENAI_ENDPOINT` environment variable -- Set `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME` environment variable -- Either: - - Set `AZURE_OPENAI_API_KEY` environment variable, OR - - Run `az login` to authenticate with Azure CLI +## Environment Setup -### 3. Install Dependencies +See the [README.md](../README.md) file in the parent directory for more information on how to configure the environment, including how to install and run common sample dependencies. -Install the required packages: +## Running the Sample -```bash -pip install -r requirements.txt -``` +With the environment setup, you can run the sample using separate worker and client processes: -Or if using uv: +**Start the worker:** ```bash -uv pip install -r requirements.txt +cd samples/getting_started/durabletask/01_single_agent +python worker.py ``` -### 4. Durable Task Scheduler - -The sample requires a Durable Task Scheduler running. There are two options: - -#### Using the Emulator (Recommended for Local Development) - -The emulator simulates a scheduler and taskhub in a Docker container, making it ideal for development and learning. - -1. Pull the Docker Image for the Emulator: - ```bash - docker pull mcr.microsoft.com/dts/dts-emulator:latest - ``` - -2. Run the Emulator: - ```bash - docker run --name dtsemulator -d -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest - ``` - Wait a few seconds for the container to be ready. - -> *How to Run the Sample - -Once you have set up either the emulator or deployed scheduler, follow these steps to run the sample: +The worker will register the Joker agent and listen for requests. -1. **Activate your Python virtual environment** (if you're using one): - ```bash - python -m venv venv - source venv/bin/activate # On Windows, use: venv\Scripts\activate - ``` - -2. **If you're using a deployed scheduler**, set environment variables: - ```bash - export ENDPOINT=$(az durabletask scheduler show \ - --resource-group my-resource-group \ - --name my-scheduler \ - --query "properties.endpoint" \ - --output tsv) - - export TASKHUB="my-taskhub" - ``` - -3. **Install the required packages**: - ```bash - pip install -r requirements.txt - ``` - -4. **Start the worker** in a terminal: - ```bash - python worker.py - ``` - You should see output indicating the worker has started and registered the agent: - ``` - INFO:__main__:Starting Durable Task Agent Worker... - INFO:__main__:Using taskhub: default - INFO:__main__:Using endpoint: http://localhost:8080 - INFO:__main__:Creating and registering Joker agent... - INFO:__main__:✓ Registered agent: Joker - INFO:__main__: Entity name: dafx-Joker - INFO:__main__: - INFO:__main__:Worker is ready and listening for requests... - INFO:__main__:Press Ctrl+C to stop. - ``` - -5. **In a new terminal** (with the virtual environment activated if applicable), **run the client**: - > **Note:** Remember to set the environment variables again if you're using a deployed scheduler. - - ```bash - python client.py - ``` - az role assignment create \ - --assignee $loggedInUser \ - --role "Durable Task Data Contributor" \ - --scope "/subscriptions/$subscriptionId/resourceGroups/my-resource-group/providers/Microsoft.DurableTask/schedulers/my-scheduler/taskHubs/my-taskhub" - ``` - -5. Set environment variables: - ```bash - export ENDPOINT=$(az durabletask scheduler show \ - --resource-group my-resource-group \ - --name my-scheduler \ - --query "properties.endpoint" \ - --output tsv) - - export TASKHUB="my-taskhub" - ``` - -## Running the Sample - -### Step 1: Start the Worker - -In one terminal, start the worker to host the agent: +**In a new terminal, run the client:** ```bash -python sample.py worker +python client.py ``` -You should see output similar to: -``` -Starting Durable Task worker... -Connecting to scheduler at: localhost:4001 -✓ Registered agent: Joker - Entity name: dafx-Joker - -Worker is ready and listening for requests... -Press Ctrl+C to stop. -``` - -The worker will continue running and processing requests until you stop it (Ctrl+C). - -### Step 2: Run the Client +The client will interact with the Joker agent: -In a **separate terminal**, run the client to interact with the agent: -Understanding the Output - -When you run the sample, you'll see output from both the worker and client processes: - -### Worker Output - -The worker shows: -- Connection information (taskhub and endpoint) -- Registration of the Joker agent as a durable entity -- Entity name (`dafx-Joker`) -- Status message indicating it's ready to process requests - -Example: -``` -INFO:__main__:Starting Durable Task Agent Worker... -INFO:__main__:Using taskhub: default -INFO:__main__:Using endpoint: http://localhost:8080 -INFO:__main__:Creating and registering Joker agent... -INFO:__main__:✓ Registered agent: Joker -INFO:__main__: Entity name: dafx-Joker -INFO:__main__: -INFO:__main__:Worker is ready and listening for requests... -INFO:__main__:Press Ctrl+C to stop. ``` +Starting Durable Task Agent Client... +Using taskhub: default +Using endpoint: http://localhost:8080 -### Client Output +Getting reference to Joker agent... +Created conversation thread: a1b2c3d4-e5f6-7890-abcd-ef1234567890 -The client shows: -- Connection information -- Thread creation -- User messages sent to the agent -- Agent responses (jokes) -- Token usage statistics -- Conversation completion status - -Example: -``` -INFO:__main__:Starting Durable Task Agent Client... -INFO:__main__:Using taskhub: default -INFO:__main__:Using endpoint: http://localhost:8080 -INFO:__main__: -INFO:__main__:Getting reference to Joker agent... -INFO:__main__:Created conversation thread: a1b2c3d4-e5f6-7890-abcd-ef1234567890 -INFO:__main__: -INFO:__main__:User: Tell me a short joke about cloud computing. -INFO:__main__: -INFO:__main__:Joker: Why did the cloud break up with the server? +User: Tell me a short joke about cloud computing. +Joker: Why did the cloud break up with the server? Because it found someone more "uplifting"! -INFO:__main__:Usage: UsageStats(input_tokens=42, output_tokens=18, total_tokens=60) -INFO:__main__: -INFO:__main__:User: Now tell me one about Python programming. -INFO:__main__: -INFO:__main__:Joker: Why do Python programmers prefer dark mode? -Understanding the Code - -### Worker (`worker.py`) - -The worker process is responsible for hosting agents: - -```python -from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker -from agent_framework_durabletask import DurableAIAgentWorker - -# Create a worker using Azure Managed Durable Task -worker = DurableTaskSchedulerWorker( - host_address=endpoint, - secure_channel=endpoint != "http://localhost:8080", - taskhub=taskhub_name, - token_credential=credential -) -# Wrap it with the agent worker -agent_worker = DurableAIAgentWorker(worker) +User: Now tell me one about Python programming. -# Create and register agents -joker_agent = create_joker_agent() -agent_worker.add_agent(joker_agent) - -# Start processing (blocks until stopped) -worker.start() +Joker: Why do Python programmers prefer dark mode? +Because light attracts bugs! ``` -**What happens:** -- The agent is registered as a durable entity with name `dafx-{agent_name}` -- The worker continuously polls for requests directed to this entity -- Each request is routed to the agent's execution logic -- Conversation state is persisted automatically in the entity - -### Client (`client.py`) - -The client process interacts with registered agents: - -```python -from durabletask.azuremanaged.client import DurableTaskSchedulerClient -from agent_framework_durabletask import DurableAIAgentClient - -# Create a client using Azure Managed Durable Task -client = DurableTaskSchedulerClient( - host_address=endpoint, - secure_channel=endpoint != "http://localhost:8080", - taskhub=taskhub_name, - token_credential=credential -) - -# Wrap it with the agent client -agent_client = DurableAIAgentClient(client) - -# Get agent reference (no validation until execution) -joker = agent_client.get_agent("Joker") - -# Create thread and run -thread = joker.get_new_thread() -response = await joker.run(message, thread=thread) -``` - -**What happens:** -- The client constructs a request with the message and thread information -- The request is sent to the entity `dafx-Joker` via the scheduler -- The client waits for the entity to process theEmulator is running: -```bash -docker ps | grep dts-emulator -``` - -If not running, start it: -```bash -docker run --name dtsemulator -d -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest -``` - -### Agent Not Found - -**Error**: Agent execution fails with "entity not found" or similar - -**Solution**: -1. Ensure the worker is running and has registered the agent -2. Check that the agent name matches exactly (case-sensitive) -3. Verify both client and worker are connecting to the same endpoint and taskhub -4. Check worker logs for successful agent registration - -### Azure OpenAI Authentication - -**Error**: Authentication errors when creating the agent - -**Solution**: -1. Ensure `AZURE_OPENAI_ENDPOINT` and `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME` are set -2. Either: - - Set `AZURE_OPENAI_API_KEY` environment variable, OR - - Run `az login` to authenticate with Azure CLI - -### Environment Variables Not Set - -If using a deployed scheduler, ensure you set the environment variables in **both** terminals (worker and client): -```bash -export ENDPOINT="" -export TASKHUB="" -``` - -## Reviewing the Agent in the Durable Task Scheduler Dashboard - -### Using the Emulator - -1. Navigate to http://localhost:8082 in your web browser -2. Click on the "default" task hub -3. You'll see the agent entity (`dafx-Joker`) in the list -4. Click on the entity to view: - - Entity state and conversation history - - Request and response details - - Execution timeline - -### Using a Deployed Scheduler - -1. Navigate to the Scheduler resource in the Azure portal -2. Go to the Task Hub subresource that you're using -3. Click on the dashboard URL in the top right corner -4. Search for your entity (`dafx-Joker`) -5. Review the entity state and execution history - -## Comparison with Azure Functions Sample - -| Aspect | Azure Functions | Durable Task (BYOP) | -|--------|----------------|---------------------| -| **Platform** | Azure Functions (PaaS) | Any platform with gRPC | -| **Hosting** | AgentFunctionApp | DurableTaskSchedulerWorker + DurableAIAgentWorker | -| **Client API** | HTTP endpoints | DurableAIAgentClient | -| **Infrastructure** | Managed by Azure | Self-hosted scheduler or Azure DTS | -| **Scalability** | Auto-scaling | Manual scaling or K8s | -| **Use Case** | Production cloud workloads | Local dev, on-prem, custom platforms | - -## Identity-based Authentication - -Learn how to set up [identity-based authentication](https://learn.microsoft.com/azure/azure-functions/durable/durable-task-scheduler/durable-task-scheduler-identity?tabs=df&pivots=az-cli) when you deploy to Azure. - -## Next Steps +## Viewing Agent State -- **Multiple Agents**: Modify the sample to register multiple agents with different capabilities -- **Structured Responses**: Use `response_format` parameter to get JSON structured output -- **Agent Orchestration**: Create orchestrations that coordinate multiple agents (see advanced samples) -- **Production Deployment**: Deploy workers to Kubernetes, VMs, or container services -- **Monitoring**: Add telemetry and logging for production workloads +You can view the state of the agent in the Durable Task Scheduler dashboard: -## Related Samples +1. Open your browser and navigate to `http://localhost:8082` +2. In the dashboard, you can view the state of the Joker agent, including its conversation history and current state -- [Azure Functions Single Agent Sample](../../../azure_functions/01_single_agent/) - Azure Functions hosting -- [Durable Task Scheduler Samples](https://github.com/Azure-Samples/Durable-Task-Scheduler) - More patterns and examples +The agent maintains conversation state across multiple interactions, and you can inspect this state in the dashboard to understand how the durable agents extension manages conversation context. -## Additional Resources -- [Durable Task Framework](https://github.com/microsoft/durabletask-python) -- [Agent Framework Documentation](https://github.com/microsoft/agent-framework) -- [Durable Task Scheduler](https://github.com/Azure-Samples/Durable-Task-Scheduler) -- [Azure Durable Task Scheduler Documentation](https://learn.microsoft.com/azure/azure-functions/durable/durable-task-scheduler/) diff --git a/python/samples/getting_started/durabletask/01_single_agent/sample.py b/python/samples/getting_started/durabletask/01_single_agent/sample.py new file mode 100644 index 0000000000..cfbceaaebd --- /dev/null +++ b/python/samples/getting_started/durabletask/01_single_agent/sample.py @@ -0,0 +1,137 @@ +"""Single Agent Sample - Durable Task Integration (Combined Worker + Client) + +This sample demonstrates running both the worker and client in a single process. +The worker is started first to register the agent, then client operations are +performed against the running worker. + +Prerequisites: +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Durable Task Scheduler must be running (e.g., using Docker) + +To run this sample: + python sample.py +""" + +import logging +import os + +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_durabletask import DurableAIAgentClient, DurableAIAgentWorker +from azure.identity import AzureCliCredential, DefaultAzureCredential +from dotenv import load_dotenv +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def create_joker_agent(): + """Create the Joker agent using Azure OpenAI. + + Returns: + AgentProtocol: The configured Joker agent + """ + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name="Joker", + instructions="You are good at telling jokes.", + ) + + +def main(): + """Main entry point - runs both worker and client in single process.""" + logger.info("Starting Durable Task Agent Sample (Combined Worker + Client)...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + logger.info("") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + secure_channel = endpoint != "http://localhost:8080" + + # Create and start the worker using a context manager + with DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential + ) as worker: + + # Wrap with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Create and register the Joker agent + logger.info("Creating and registering Joker agent...") + joker_agent = create_joker_agent() + agent_worker.add_agent(joker_agent) + + logger.info(f"✓ Registered agent: {joker_agent.name}") + logger.info(f" Entity name: dafx-{joker_agent.name}") + logger.info("") + + # Start the worker + worker.start() + logger.info("Worker started and listening for requests...") + logger.info("") + + # Create the client + client = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential + ) + + # Wrap it with the agent client + agent_client = DurableAIAgentClient(client) + + # Get a reference to the Joker agent + logger.info("Getting reference to Joker agent...") + joker = agent_client.get_agent("Joker") + + # Create a new thread for the conversation + thread = joker.get_new_thread() + + logger.info(f"Created conversation thread: {thread.session_id}") + logger.info("") + + try: + # First message + message1 = "Tell me a short joke about cloud computing." + logger.info(f"User: {message1}") + logger.info("") + + # Run the agent - this blocks until the response is ready + response1 = joker.run(message1, thread=thread) + logger.info(f"Agent: {response1.text}; {response1}") + logger.info("") + + # Second message - continuing the conversation + message2 = "Now tell me one about Python programming." + logger.info(f"User: {message2}") + logger.info("") + + response2 = joker.run(message2, thread=thread) + logger.info(f"Agent: {response2.text}; {response2}") + logger.info("") + + logger.info(f"Conversation completed successfully!") + logger.info(f"Thread ID: {thread.session_id}") + + except Exception as e: + logger.exception(f"Error during agent interaction: {e}") + + logger.info("") + logger.info("Sample completed. Worker shutting down...") + + +if __name__ == "__main__": + load_dotenv() + main() diff --git a/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/README.md b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/README.md new file mode 100644 index 0000000000..090fdec3b9 --- /dev/null +++ b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/README.md @@ -0,0 +1,72 @@ +# Single Agent Orchestration Chaining Sample + +This sample demonstrates how to chain multiple invocations of the same agent using a durable orchestration while preserving conversation state between runs. + +## Key Concepts Demonstrated + +- Using durable orchestrations to coordinate sequential agent invocations. +- Chaining agent calls where the output of one run becomes input to the next. +- Maintaining conversation context across sequential runs using a shared thread. +- Using `DurableAIAgentOrchestrationContext` to access agents within orchestrations. + +## Environment Setup + +See the [README.md](../README.md) file in the parent directory for more information on how to configure the environment, including how to install and run common sample dependencies. + +## Running the Sample + +With the environment setup, you can run the sample using one of two approaches: + +### Option 1: Combined Worker + Client (Quick Start) + +```bash +cd samples/getting_started/durabletask/04_single_agent_orchestration_chaining +python sample.py +``` + +This runs both worker and client in a single process. + +### Option 2: Separate Worker and Client + +**Start the worker in one terminal:** + +```bash +python worker.py +``` + +**In a new terminal, run the client:** + +```bash +python client.py +``` + +The orchestration will execute the writer agent twice sequentially, and you'll see output like: + +``` +[Orchestration] Starting single agent chaining... +[Orchestration] Created thread: abc-123 +[Orchestration] First agent run: Generating initial sentence... +[Orchestration] Initial response: Every small step forward is progress toward mastery. +[Orchestration] Second agent run: Refining the sentence... +[Orchestration] Refined response: Each small step forward brings you closer to mastery and growth. +[Orchestration] Chaining complete + +================================================================================ +Orchestration Result +================================================================================ +Each small step forward brings you closer to mastery and growth. +``` + +## Viewing Orchestration State + +You can view the state of the orchestration in the Durable Task Scheduler dashboard: + +1. Open your browser and navigate to `http://localhost:8082` +2. In the dashboard, you can view the orchestration instance, including: + - The sequential execution of both agent runs + - The conversation thread shared between runs + - Input and output at each step + - Overall orchestration state and history + +The orchestration maintains the conversation context across both agent invocations, demonstrating how durable orchestrations can coordinate multi-step agent workflows. + diff --git a/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/client.py b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/client.py new file mode 100644 index 0000000000..1b5331e47e --- /dev/null +++ b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/client.py @@ -0,0 +1,104 @@ +"""Client application for starting a single agent chaining orchestration. + +This client connects to the Durable Task Scheduler and starts an orchestration +that runs a writer agent twice sequentially on the same thread, demonstrating +how conversation context is maintained across multiple agent invocations. + +Prerequisites: +- The worker must be running with the writer agent and orchestration registered +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Durable Task Scheduler must be running +""" + +import asyncio +import json +import logging +import os + +from azure.identity import DefaultAzureCredential +from durabletask.azuremanaged.client import DurableTaskSchedulerClient + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def main() -> None: + """Main entry point for the client application.""" + logger.info("Starting Durable Task Single Agent Chaining Orchestration Client...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + logger.info("") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + + # Create a client using Azure Managed Durable Task + client = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + logger.info("Starting single agent chaining orchestration...") + logger.info("") + + try: + # Start the orchestration + instance_id = client.schedule_new_orchestration( + orchestrator="single_agent_chaining_orchestration", + input="", + ) + + logger.info(f"Orchestration started with instance ID: {instance_id}") + logger.info("Waiting for orchestration to complete...") + logger.info("") + + # Retrieve the final state + metadata = client.wait_for_orchestration_completion( + instance_id=instance_id, + timeout=300 + ) + + if metadata and metadata.runtime_status.name == "COMPLETED": + result = metadata.serialized_output + + logger.info("=" * 80) + logger.info("Orchestration completed successfully!") + logger.info("=" * 80) + logger.info("") + logger.info("Results:") + logger.info("") + + # Parse and display the result + if result: + final_text = json.loads(result) + logger.info("Final refined sentence:") + logger.info(f" {final_text}") + logger.info("") + + logger.info("=" * 80) + + elif metadata: + logger.error(f"Orchestration ended with status: {metadata.runtime_status.name}") + if metadata.serialized_output: + logger.error(f"Output: {metadata.serialized_output}") + else: + logger.error("Orchestration did not complete within the timeout period") + + except Exception as e: + logger.exception(f"Error during orchestration: {e}") + finally: + logger.info("") + logger.info("Client shutting down") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/requirements.txt b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/requirements.txt new file mode 100644 index 0000000000..da871507c8 --- /dev/null +++ b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/requirements.txt @@ -0,0 +1,9 @@ +# Agent Framework packages +agent-framework-azure +agent-framework-durabletask + +# Durable Task Python SDK with Azure Managed support +durabletask-python + +# Azure authentication +azure-identity diff --git a/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/sample.py b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/sample.py new file mode 100644 index 0000000000..1f1ee81deb --- /dev/null +++ b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/sample.py @@ -0,0 +1,255 @@ +"""Single Agent Orchestration Chaining Sample - Durable Task Integration + +This sample demonstrates chaining two invocations of the same agent inside a Durable Task +orchestration while preserving the conversation state between runs. The orchestration +runs the writer agent sequentially on a shared thread to refine text iteratively. + +Components used: +- AzureOpenAIChatClient to construct the writer agent +- DurableTaskSchedulerWorker and DurableAIAgentWorker for agent hosting +- DurableTaskSchedulerClient and orchestration for sequential agent invocations +- Thread management to maintain conversation context across invocations + +Prerequisites: +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Durable Task Scheduler must be running (e.g., using Docker emulator) + +To run this sample: + python sample.py +""" + +import asyncio +import json +import logging +import os +from collections.abc import Generator +from typing import Any + +from agent_framework import AgentRunResponse +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker +from azure.identity import AzureCliCredential, DefaultAzureCredential +from dotenv import load_dotenv +from durabletask.task import OrchestrationContext, Task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Agent name +WRITER_AGENT_NAME = "WriterAgent" + + +def create_writer_agent(): + """Create the Writer agent using Azure OpenAI. + + This agent refines short pieces of text, enhancing initial sentences + and polishing improved versions further. + + Returns: + AgentProtocol: The configured Writer agent + """ + instructions = ( + "You refine short pieces of text. When given an initial sentence you enhance it;\n" + "when given an improved sentence you polish it further." + ) + + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name=WRITER_AGENT_NAME, + instructions=instructions, + ) + + +def single_agent_chaining_orchestration( + context: OrchestrationContext, _: str +) -> Generator[Task[Any], Any, str]: + """Orchestration that runs the writer agent twice on the same thread. + + This demonstrates chaining behavior where the output of the first agent run + becomes part of the input for the second run, all while maintaining the + conversation context through a shared thread. + + Args: + context: The orchestration context + _: Input parameter (unused) + + Returns: + str: The final refined text from the second agent run + """ + logger.info("[Orchestration] Starting single agent chaining...") + + # Wrap the orchestration context to access agents + agent_context = DurableAIAgentOrchestrationContext(context) + + # Get the writer agent using the agent context + writer = agent_context.get_agent(WRITER_AGENT_NAME) + + # Create a new thread for the conversation - this will be shared across both runs + writer_thread = writer.get_new_thread() + + logger.info(f"[Orchestration] Created thread: {writer_thread.session_id}") + + # First run: Generate an initial inspirational sentence + logger.info("[Orchestration] First agent run: Generating initial sentence...") + initial_response: AgentRunResponse = yield writer.run( + messages="Write a concise inspirational sentence about learning.", + thread=writer_thread, + ) + logger.info(f"[Orchestration] Initial response: {initial_response.text}") + + # Second run: Refine the initial response on the same thread + improved_prompt = ( + f"Improve this further while keeping it under 25 words: " + f"{initial_response.text}" + ) + + logger.info("[Orchestration] Second agent run: Refining the sentence...") + refined_response: AgentRunResponse = yield writer.run( + messages=improved_prompt, + thread=writer_thread, + ) + + logger.info(f"[Orchestration] Refined response: {refined_response.text}") + + logger.info("[Orchestration] Chaining complete") + return refined_response.text + + +async def run_client( + endpoint: str, taskhub_name: str, credential: DefaultAzureCredential | None +): + """Run the client to start and monitor the orchestration. + + Args: + endpoint: The durable task scheduler endpoint + taskhub_name: The task hub name + credential: The credential for authentication + """ + logger.info("") + logger.info("=" * 80) + logger.info("CLIENT: Starting orchestration...") + logger.info("=" * 80) + logger.info("") + + # Create a client + client = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + try: + # Start the orchestration + instance_id = client.schedule_new_orchestration( + single_agent_chaining_orchestration + ) + + logger.info(f"Orchestration started with instance ID: {instance_id}") + logger.info("Waiting for orchestration to complete...") + logger.info("") + + # Retrieve the final state + metadata = client.wait_for_orchestration_completion( + instance_id=instance_id, + timeout=300 + ) + + if metadata and metadata.runtime_status.name == "COMPLETED": + result = metadata.serialized_output + + logger.info("") + logger.info("=" * 80) + logger.info("ORCHESTRATION COMPLETED SUCCESSFULLY!") + logger.info("=" * 80) + logger.info("") + + # Parse and display the result + if result: + final_text = json.loads(result) + logger.info("Final refined sentence:") + logger.info(f" {final_text}") + else: + logger.warning("No output returned from orchestration") + + elif metadata: + logger.error(f"Orchestration did not complete successfully: {metadata.runtime_status.name}") + if metadata.serialized_output: + logger.error(f"Output: {metadata.serialized_output}") + else: + logger.error("Could not retrieve orchestration metadata") + + except Exception as e: + logger.exception(f"Error during orchestration: {e}") + + logger.info("") + logger.info("Client shutting down") + + +def main(): + """Main entry point - runs both worker and client in single process.""" + logger.info("Starting Single Agent Orchestration Chaining Sample...") + logger.info("") + + # Load environment variables + load_dotenv() + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + logger.info("") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + secure_channel = endpoint != "http://localhost:8080" + + # Create and start the worker using a context manager + with DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential + ) as worker: + + # Wrap with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Create and register the Writer agent + logger.info("Creating and registering Writer agent...") + writer_agent = create_writer_agent() + agent_worker.add_agent(writer_agent) + + logger.info(f"✓ Registered agent: {writer_agent.name}") + logger.info(f" Entity name: dafx-{writer_agent.name}") + + # Register the orchestration function + logger.info("Registering orchestration function...") + worker.add_orchestrator(single_agent_chaining_orchestration) + logger.info("✓ Registered orchestration: single_agent_chaining_orchestration") + logger.info("") + + # Start the worker + worker.start() + logger.info("Worker started and listening for requests...") + logger.info("") + + # Run the client in the same process + try: + asyncio.run(run_client(endpoint, taskhub_name, credential)) + except KeyboardInterrupt: + logger.info("Sample interrupted by user") + finally: + logger.info("Worker stopping...") + + logger.info("Sample completed") + + +if __name__ == "__main__": + load_dotenv() + main() diff --git a/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/worker.py b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/worker.py new file mode 100644 index 0000000000..b033bca155 --- /dev/null +++ b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/worker.py @@ -0,0 +1,167 @@ +"""Worker process for hosting a single agent with chaining orchestration using Durable Task. + +This worker registers a writer agent and an orchestration function that demonstrates +chaining behavior by running the agent twice sequentially on the same thread, +preserving conversation context between invocations. + +Prerequisites: +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Start a Durable Task Scheduler (e.g., using Docker) +""" + +import asyncio +from collections.abc import Generator +import logging +import os +from typing import Any + +from agent_framework import AgentRunResponse +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker +from azure.identity import AzureCliCredential, DefaultAzureCredential +from durabletask.task import OrchestrationContext, Task +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Agent name +WRITER_AGENT_NAME = "WriterAgent" + + +def create_writer_agent(): + """Create the Writer agent using Azure OpenAI. + + This agent refines short pieces of text, enhancing initial sentences + and polishing improved versions further. + + Returns: + AgentProtocol: The configured Writer agent + """ + instructions = ( + "You refine short pieces of text. When given an initial sentence you enhance it;\n" + "when given an improved sentence you polish it further." + ) + + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name=WRITER_AGENT_NAME, + instructions=instructions, + ) + + +def single_agent_chaining_orchestration( + context: OrchestrationContext, _: str +) -> Generator[Task[Any], Any, str]: + """Orchestration that runs the writer agent twice on the same thread. + + This demonstrates chaining behavior where the output of the first agent run + becomes part of the input for the second run, all while maintaining the + conversation context through a shared thread. + + Args: + context: The orchestration context + _: Input parameter (unused) + + Returns: + str: The final refined text from the second agent run + """ + logger.info("[Orchestration] Starting single agent chaining...") + + # Wrap the orchestration context to access agents + agent_context = DurableAIAgentOrchestrationContext(context) + + # Get the writer agent using the agent context + writer = agent_context.get_agent(WRITER_AGENT_NAME) + + # Create a new thread for the conversation - this will be shared across both runs + writer_thread = writer.get_new_thread() + + logger.info(f"[Orchestration] Created thread: {writer_thread.session_id}") + + # First run: Generate an initial inspirational sentence + logger.info("[Orchestration] First agent run: Generating initial sentence...") + initial_response: AgentRunResponse = yield writer.run( + messages="Write a concise inspirational sentence about learning.", + thread=writer_thread, + ) + logger.info(f"[Orchestration] Initial response: {initial_response.text}") + + # Second run: Refine the initial response on the same thread + improved_prompt = ( + f"Improve this further while keeping it under 25 words: " + f"{initial_response.text}" + ) + + logger.info("[Orchestration] Second agent run: Refining the sentence...") + refined_response: AgentRunResponse = yield writer.run( + messages=improved_prompt, + thread=writer_thread, + ) + + logger.info(f"[Orchestration] Refined response: {refined_response.text}") + + logger.info("[Orchestration] Chaining complete") + return refined_response.text + + +async def main(): + """Main entry point for the worker process.""" + logger.info("Starting Durable Task Single Agent Chaining Worker with Orchestration...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + + # Create a worker using Azure Managed Durable Task + worker = DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + # Wrap it with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Create and register the Writer agent + logger.info("Creating and registering Writer agent...") + writer_agent = create_writer_agent() + agent_worker.add_agent(writer_agent) + + logger.info(f"✓ Registered agent: {writer_agent.name}") + logger.info(f" Entity name: dafx-{writer_agent.name}") + logger.info("") + + # Register the orchestration function + logger.info("Registering orchestration function...") + worker.add_orchestrator(single_agent_chaining_orchestration) + logger.info(f"✓ Registered orchestration: {single_agent_chaining_orchestration.__name__}") + logger.info("") + + logger.info("Worker is ready and listening for requests...") + logger.info("Press Ctrl+C to stop.") + logger.info("") + + try: + # Start the worker (this blocks until stopped) + worker.start() + + # Keep the worker running + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Worker shutdown initiated") + + logger.info("Worker stopped") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/README.md b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/README.md index f11ab7aaa0..89efdb5e8d 100644 --- a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/README.md +++ b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/README.md @@ -1,204 +1,54 @@ -# Multi-Agent Orchestration with Concurrency Sample (Python) - Durable Task +# Multi-Agent Orchestration with Concurrency Sample -This sample demonstrates how to use the **Durable Task package** with **OrchestrationAgentExecutor** to orchestrate multiple AI agents running concurrently and aggregate their responses. - -## Description of the Sample - -This sample shows how to run two domain-specific agents (a Physicist and a Chemist) concurrently using Durable Task orchestration. The agents respond to the same prompt from their respective domain perspectives, demonstrating: - -- How to register multiple agents as durable entities -- How to create an orchestration function that executes agents in parallel -- How to use `OrchestrationAgentExecutor` for concurrent agent execution -- How to aggregate results from multiple agents -- The benefits of concurrent execution for improved performance +This sample demonstrates how to host multiple agents and run them concurrently using a durable orchestration, aggregating their responses into a single result. ## Key Concepts Demonstrated -- **Multi-Agent Architecture**: Running multiple specialized agents in parallel -- **Orchestration Functions**: Using Durable Task orchestrations to coordinate agent execution -- **OrchestrationAgentExecutor**: Execution strategy for orchestrations that returns `DurableAgentTask` objects -- **Concurrent Execution**: Using `task.when_all()` to run agents in parallel -- **Result Aggregation**: Collecting and combining responses from multiple agents -- **Thread Management**: Creating separate conversation threads for each agent -- **BYOP (Bring Your Own Platform)**: Not tied to Azure Functions - run anywhere with Durable Task Scheduler - -## Architecture Overview - -This sample uses a **worker-orchestration-client architecture**: - -1. **Worker Process** (`worker.py`): Registers two agents (Physicist and Chemist) as durable entities and an orchestration function -2. **Orchestration Function**: Coordinates concurrent execution of both agents and aggregates results -3. **Client Process** (`client.py`): Starts the orchestration and retrieves aggregated results -4. **Durable Task Scheduler**: Coordinates communication and orchestration execution (runs separately) - -### Execution Flow - -``` -Client → Start Orchestration → Orchestration Context - ↓ - ┌─────────────┴─────────────┐ - ↓ ↓ - Physicist Agent Chemist Agent - (Concurrent) (Concurrent) - ↓ ↓ - └─────────────┬─────────────┘ - ↓ - Aggregate Results → Client -``` - -## What Makes This Different? - -This sample differs from the single agent sample in several key ways: - -1. **Multiple Agents**: Two specialized agents with different domain expertise -2. **Concurrent Execution**: Both agents run simultaneously using `task.when_all()` -3. **Orchestration Function**: Uses a Durable Task orchestrator to coordinate execution -4. **OrchestrationAgentExecutor**: Different execution strategy that returns tasks instead of blocking -5. **Result Aggregation**: Combines responses from multiple agents into a single result - -## Prerequisites - -### 1. Python 3.9+ - -Ensure you have Python 3.9 or later installed. - -### 2. Azure OpenAI Setup - -Configure your Azure OpenAI credentials: -- Set `AZURE_OPENAI_ENDPOINT` environment variable -- Set `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME` environment variable -- Either: - - Set `AZURE_OPENAI_API_KEY` environment variable, OR - - Run `az login` to authenticate with Azure CLI +- Running multiple specialized agents in parallel within an orchestration. +- Using `OrchestrationAgentExecutor` to get `DurableAgentTask` objects for concurrent execution. +- Aggregating results from multiple agents using `task.when_all()`. +- Creating separate conversation threads for independent agent contexts. -### 3. Install Dependencies +## Environment Setup -Install the required packages: - -```bash -pip install -r requirements.txt -``` - -Or if using uv: - -```bash -uv pip install -r requirements.txt -``` - -### 4. Durable Task Scheduler - -The sample requires a Durable Task Scheduler running. For local development, use the emulator: - -#### Using the Emulator (Recommended for Local Development) - -1. Pull the Docker Image for the Emulator: - ```bash - docker pull mcr.microsoft.com/dts/dts-emulator:latest - ``` - -2. Run the Emulator: - ```bash - docker run --name dtsemulator -d -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest - ``` - Wait a few seconds for the container to be ready. +See the [README.md](../README.md) file in the parent directory for more information on how to configure the environment, including how to install and run common sample dependencies. ## Running the Sample -### Option 1: Combined Worker + Client (Recommended for Quick Start) +With the environment setup, you can run the sample using one of two approaches: -The easiest way to run the sample is using `sample.py`, which runs both worker and client in a single process: +### Option 1: Combined Worker + Client (Quick Start) ```bash +cd samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency python sample.py ``` -This will: -1. Start the worker and register both agents and the orchestration -2. Start the orchestration with a sample prompt -3. Wait for completion and display aggregated results -4. Shut down the worker +This runs both worker and client in a single process. ### Option 2: Separate Worker and Client -For a more realistic distributed setup, run the worker and client separately: - -1. **Start the worker** in one terminal: - ```bash - python worker.py - ``` - - You should see output indicating both agents are registered: - ``` - INFO:__main__:Starting Durable Task Multi-Agent Worker with Orchestration... - INFO:__main__:Using taskhub: default - INFO:__main__:Using endpoint: http://localhost:8080 - INFO:__main__:Creating and registering agents... - INFO:__main__:✓ Registered agent: PhysicistAgent - INFO:__main__: Entity name: dafx-PhysicistAgent - INFO:__main__:✓ Registered agent: ChemistAgent - INFO:__main__: Entity name: dafx-ChemistAgent - INFO:__main__:✓ Registered orchestration: multi_agent_concurrent_orchestration - INFO:__main__:Worker is ready and listening for requests... - ``` - -2. **In a new terminal**, run the client: - ```bash - python client.py - ``` - - The client will start the orchestration and wait for results. - -## Understanding the Output - -### Worker Output - -The worker shows: -- Connection information (taskhub and endpoint) -- Registration of both agents (Physicist and Chemist) as durable entities -- Registration of the orchestration function -- Status messages during orchestration execution +**Start the worker in one terminal:** -Example: -``` -INFO:__main__:Starting Durable Task Multi-Agent Worker with Orchestration... -INFO:__main__:Using taskhub: default -INFO:__main__:Using endpoint: http://localhost:8080 -INFO:__main__:Creating and registering agents... -INFO:__main__:✓ Registered agent: PhysicistAgent -INFO:__main__: Entity name: dafx-PhysicistAgent -INFO:__main__:✓ Registered agent: ChemistAgent -INFO:__main__: Entity name: dafx-ChemistAgent -INFO:__main__:✓ Registered orchestration: multi_agent_concurrent_orchestration -INFO:__main__:Worker is ready and listening for requests... +```bash +python worker.py ``` -### Client Output - -The client shows: -- The prompt sent to both agents -- Orchestration instance ID -- Status updates during execution -- Aggregated results from both agents +**In a new terminal, run the client:** -Example: +```bash +python client.py ``` -INFO:__main__:Starting Durable Task Multi-Agent Orchestration Client... -INFO:__main__:Using taskhub: default -INFO:__main__:Using endpoint: http://localhost:8080 - -INFO:__main__:Prompt: What is temperature? -INFO:__main__:Starting multi-agent concurrent orchestration... -INFO:__main__:Orchestration started with instance ID: abc123... -INFO:__main__:Waiting for orchestration to complete... - -INFO:__main__:Orchestration status: COMPLETED -================================================================================ -Orchestration completed successfully! -================================================================================ +The orchestration will execute both agents concurrently, and you'll see output like: +``` Prompt: What is temperature? +Starting multi-agent concurrent orchestration... +Orchestration started with instance ID: abc123... +Orchestration status: COMPLETED + Results: Physicist's response: @@ -206,112 +56,20 @@ Physicist's response: Chemist's response: Temperature reflects how molecular motion influences reaction rates... - -================================================================================ -``` - -### Orchestration Output - -During execution, the orchestration logs show: -- When concurrent execution starts -- Thread creation for each agent -- Task creation and execution -- Completion and result aggregation - -Example: -``` -INFO:__main__:[Orchestration] Starting concurrent execution for prompt: What is temperature? -INFO:__main__:[Orchestration] Created threads - Physicist: session-123, Chemist: session-456 -INFO:__main__:[Orchestration] Created agent tasks, executing concurrently... -INFO:__main__:[Orchestration] Both agents completed -INFO:__main__:[Orchestration] Aggregated results ready -``` - -## How It Works - -### 1. Agent Registration - -Both agents are registered as durable entities in the worker: - -```python -agent_worker.add_agent(physicist_agent) -agent_worker.add_agent(chemist_agent) ``` -### 2. Orchestration Registration - -The orchestration function is registered with the worker: - -```python -worker.add_orchestrator(multi_agent_concurrent_orchestration) -``` - -### 3. Orchestration Execution - -The orchestration uses `OrchestrationAgentExecutor` (implicitly through `context.get_agent()`): - -```python -@task.orchestrator -def multi_agent_concurrent_orchestration(context: OrchestrationContext): - # Get agents (uses OrchestrationAgentExecutor internally) - physicist = context.get_agent(PHYSICIST_AGENT_NAME) - chemist = context.get_agent(CHEMIST_AGENT_NAME) - - # Create tasks (returns DurableAgentTask instances) - physicist_task = physicist.run(messages=prompt, thread=physicist_thread) - chemist_task = chemist.run(messages=prompt, thread=chemist_thread) - - # Execute concurrently - task_results = yield task.when_all([physicist_task, chemist_task]) - - # Aggregate results - return { - "physicist": task_results[0].text, - "chemist": task_results[1].text, - } -``` - -### 4. Key Differences from Single Agent - -- **OrchestrationAgentExecutor**: Returns `DurableAgentTask` instead of `AgentRunResponse` -- **Concurrent Execution**: Uses `task.when_all()` for parallel execution -- **Yield Syntax**: Orchestrations use `yield` to await async operations -- **Result Aggregation**: Combines multiple agent responses into one result - -## Customization - -You can modify the sample to: - -1. **Change the prompt**: Edit the `prompt` variable in `sample.py` or `client.py` -2. **Add more agents**: Create additional agents and add them to the orchestration -3. **Change agent instructions**: Modify the `instructions` parameter when creating agents -4. **Adjust concurrency**: Use different task combination patterns (e.g., sequential, selective) -5. **Add error handling**: Implement retry logic or fallback agents - -## Troubleshooting - -### Orchestration times out -- Increase `max_wait_time` in the client code -- Check that both agents are properly registered in the worker -- Verify Azure OpenAI endpoint and credentials - -### Agents return errors -- Verify Azure OpenAI deployment name is correct -- Check Azure OpenAI quota and rate limits -- Review worker logs for detailed error messages +## Viewing Orchestration State -### Connection errors -- Ensure the Durable Task Scheduler is running -- Verify `ENDPOINT` and `TASKHUB` environment variables -- Check network connectivity and firewall rules +You can view the state of the orchestration in the Durable Task Scheduler dashboard: -## Related Samples +1. Open your browser and navigate to `http://localhost:8082` +2. In the dashboard, you can view the orchestration instance, including: + - The concurrent execution of both agents (Physicist and Chemist) + - Separate conversation threads for each agent + - Parallel task execution and completion timing + - Aggregated results from both agents + - Overall orchestration state and history -- **01_single_agent**: Basic single agent setup -- **Azure Functions 05_multi_agent_orchestration_concurrency**: Similar pattern using Azure Functions +The orchestration demonstrates how multiple agents can be executed in parallel, with results collected and aggregated once all agents complete. -## Learn More -- [Agent Framework Documentation](https://github.com/microsoft/agent-framework) -- [Durable Task Framework](https://github.com/microsoft/durabletask) -- [Azure OpenAI Service](https://azure.microsoft.com/en-us/products/cognitive-services/openai-service) diff --git a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/sample.py b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/sample.py new file mode 100644 index 0000000000..1f6723695e --- /dev/null +++ b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/sample.py @@ -0,0 +1,266 @@ +"""Multi-Agent Orchestration Sample - Durable Task Integration (Combined Worker + Client) + +This sample demonstrates running both the worker and client in a single process for +concurrent multi-agent orchestration. The worker registers two domain-specific agents +(physicist and chemist) and an orchestration function that runs them in parallel. + +The orchestration uses OrchestrationAgentExecutor to execute agents concurrently +and aggregate their responses. + +Prerequisites: +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Durable Task Scheduler must be running (e.g., using Docker) + +To run this sample: + python sample.py +""" + +import asyncio +import json +import logging +import os +from collections.abc import Generator +from typing import Any + +from agent_framework import AgentRunResponse +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker +from azure.identity import AzureCliCredential, DefaultAzureCredential +from dotenv import load_dotenv +from durabletask.task import OrchestrationContext, when_all, Task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Agent names +PHYSICIST_AGENT_NAME = "PhysicistAgent" +CHEMIST_AGENT_NAME = "ChemistAgent" + + +def create_physicist_agent(): + """Create the Physicist agent using Azure OpenAI. + + Returns: + AgentProtocol: The configured Physicist agent + """ + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name=PHYSICIST_AGENT_NAME, + instructions="You are an expert in physics. You answer questions from a physics perspective.", + ) + + +def create_chemist_agent(): + """Create the Chemist agent using Azure OpenAI. + + Returns: + AgentProtocol: The configured Chemist agent + """ + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name=CHEMIST_AGENT_NAME, + instructions="You are an expert in chemistry. You answer questions from a chemistry perspective.", + ) + + +def multi_agent_concurrent_orchestration(context: OrchestrationContext, prompt: str) -> Generator[Task[Any], Any, dict[str, str]]: + """Orchestration that runs both agents in parallel and aggregates results. + + Uses DurableAIAgentOrchestrationContext to wrap the orchestration context and + access agents via the OrchestrationAgentExecutor. + + Args: + context: The orchestration context + + Returns: + dict: Dictionary with 'physicist' and 'chemist' response texts + """ + logger.info(f"[Orchestration] Starting concurrent execution for prompt: {prompt}") + + # Wrap the orchestration context to access agents + agent_context = DurableAIAgentOrchestrationContext(context) + + # Get agents using the agent context (returns DurableAIAgent proxies) + physicist = agent_context.get_agent(PHYSICIST_AGENT_NAME) + chemist = agent_context.get_agent(CHEMIST_AGENT_NAME) + + # Create separate threads for each agent + physicist_thread = physicist.get_new_thread() + chemist_thread = chemist.get_new_thread() + + logger.info(f"[Orchestration] Created threads - Physicist: {physicist_thread.session_id}, Chemist: {chemist_thread.session_id}") + + # Create tasks from agent.run() calls - these return DurableAgentTask instances + physicist_task = physicist.run(messages=str(prompt), thread=physicist_thread) + chemist_task = chemist.run(messages=str(prompt), thread=chemist_thread) + + logger.info("[Orchestration] Created agent tasks, executing concurrently...") + + # Execute both tasks concurrently using task.when_all + # The DurableAgentTask instances wrap the underlying entity calls + task_results = yield when_all([physicist_task, chemist_task]) + + logger.info("[Orchestration] Both agents completed") + + # Extract results from the tasks - DurableAgentTask yields AgentRunResponse + physicist_result: AgentRunResponse = task_results[0] + chemist_result: AgentRunResponse = task_results[1] + + result = { + "physicist": physicist_result.text, + "chemist": chemist_result.text, + } + + logger.info(f"[Orchestration] Aggregated results ready") + return result + + +async def run_client(endpoint: str, taskhub_name: str, credential: DefaultAzureCredential | None, prompt: str): + """Run the client to start and monitor the orchestration. + + Args: + endpoint: The durable task scheduler endpoint + taskhub_name: The task hub name + credential: The credential for authentication + prompt: The prompt to send to both agents + """ + logger.info("") + logger.info("=" * 80) + logger.info("CLIENT: Starting orchestration...") + logger.info("=" * 80) + + # Create a client + client = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + logger.info(f"Prompt: {prompt}") + logger.info("") + + try: + # Start the orchestration with the prompt as input + instance_id = client.schedule_new_orchestration( + multi_agent_concurrent_orchestration, + input=prompt, + ) + + logger.info(f"Orchestration started with instance ID: {instance_id}") + logger.info("Waiting for orchestration to complete...") + logger.info("") + + # Retrieve the final state + metadata = client.wait_for_orchestration_completion( + instance_id=instance_id + ) + + if metadata and metadata.runtime_status.name == "COMPLETED": + result = metadata.serialized_output + + logger.info("") + logger.info("=" * 80) + logger.info("ORCHESTRATION COMPLETED SUCCESSFULLY!") + logger.info("=" * 80) + logger.info("") + logger.info(f"Prompt: {prompt}") + logger.info("") + logger.info("Results:") + logger.info("") + + # Parse and display the result + if result: + result_dict = json.loads(result) + + logger.info("Physicist's response:") + logger.info(f" {result_dict.get('physicist', 'N/A')}") + logger.info("") + + logger.info("Chemist's response:") + logger.info(f" {result_dict.get('chemist', 'N/A')}") + logger.info("") + + logger.info("=" * 80) + + elif metadata: + logger.error(f"Orchestration ended with status: {metadata.runtime_status.name}") + if metadata.serialized_output: + logger.error(f"Output: {metadata.serialized_output}") + else: + logger.error("Orchestration did not complete within the timeout period") + + except Exception as e: + logger.exception(f"Error during orchestration: {e}") + + +def main(): + """Main entry point - runs both worker and client in single process.""" + logger.info("Starting Durable Task Multi-Agent Orchestration Sample (Combined Worker + Client)...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + logger.info("") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + secure_channel = endpoint != "http://localhost:8080" + + # Create and start the worker using a context manager + with DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential + ) as worker: + + # Wrap with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Create and register both agents + logger.info("Creating and registering agents...") + physicist_agent = create_physicist_agent() + chemist_agent = create_chemist_agent() + + agent_worker.add_agent(physicist_agent) + agent_worker.add_agent(chemist_agent) + + logger.info(f"✓ Registered agent: {physicist_agent.name}") + logger.info(f" Entity name: dafx-{physicist_agent.name}") + logger.info(f"✓ Registered agent: {chemist_agent.name}") + logger.info(f" Entity name: dafx-{chemist_agent.name}") + logger.info("") + + # Register the orchestration function + logger.info("Registering orchestration function...") + worker.add_orchestrator(multi_agent_concurrent_orchestration) + logger.info(f"✓ Registered orchestration: {multi_agent_concurrent_orchestration.__name__}") + logger.info("") + + # Start the worker + worker.start() + logger.info("Worker started and listening for requests...") + + # Define the prompt + prompt = "What is temperature?" + + try: + # Run the client to start the orchestration + asyncio.run(run_client(endpoint, taskhub_name, credential, prompt)) + + except Exception as e: + logger.exception(f"Error during sample execution: {e}") + + logger.info("") + logger.info("Sample completed. Worker shutting down...") + + +if __name__ == "__main__": + load_dotenv() + main() diff --git a/python/samples/getting_started/durabletask/README.md b/python/samples/getting_started/durabletask/README.md new file mode 100644 index 0000000000..3b63294756 --- /dev/null +++ b/python/samples/getting_started/durabletask/README.md @@ -0,0 +1,124 @@ +# Durable Task Samples + +This directory contains samples for durable agent hosting using the Durable Task Scheduler. These samples demonstrate the worker-client architecture pattern, enabling distributed agent execution with persistent conversation state. + +- **[01_single_agent](01_single_agent/)**: A sample that demonstrates how to host a single conversational agent using the Durable Task Scheduler and interact with it via a client. +- **[04_single_agent_orchestration_chaining](04_single_agent_orchestration_chaining/)**: A sample that demonstrates how to chain multiple invocations of the same agent using a durable orchestration. +- **[05_multi_agent_orchestration_concurrency](05_multi_agent_orchestration_concurrency/)**: A sample that demonstrates how to host multiple agents and run them concurrently using a durable orchestration. + +## Running the Samples + +These samples are designed to be run locally in a cloned repository. + +### Prerequisites + +The following prerequisites are required to run the samples: + +- [Python 3.9 or later](https://www.python.org/downloads/) +- [Azure CLI](https://learn.microsoft.com/cli/azure/install-azure-cli) installed and authenticated (`az login`) or an API key for the Azure OpenAI service +- [Azure OpenAI Service](https://learn.microsoft.com/azure/ai-services/openai/how-to/create-resource) with a deployed model (gpt-4o-mini or better is recommended) +- [Durable Task Scheduler](https://learn.microsoft.com/azure/azure-functions/durable/durable-task-scheduler/develop-with-durable-task-scheduler) (local emulator or Azure-hosted) +- [Docker](https://docs.docker.com/get-docker/) installed if running the Durable Task Scheduler emulator locally + +### Configuring RBAC Permissions for Azure OpenAI + +These samples are configured to use the Azure OpenAI service with RBAC permissions to access the model. You'll need to configure the RBAC permissions for the Azure OpenAI service to allow the Python app to access the model. + +Below is an example of how to configure the RBAC permissions for the Azure OpenAI service to allow the current user to access the model. + +Bash (Linux/macOS/WSL): + +```bash +az role assignment create \ + --assignee "yourname@contoso.com" \ + --role "Cognitive Services OpenAI User" \ + --scope /subscriptions//resourceGroups//providers/Microsoft.CognitiveServices/accounts/ +``` + +PowerShell: + +```powershell +az role assignment create ` + --assignee "yourname@contoso.com" ` + --role "Cognitive Services OpenAI User" ` + --scope /subscriptions//resourceGroups//providers/Microsoft.CognitiveServices/accounts/ +``` + +More information on how to configure RBAC permissions for Azure OpenAI can be found in the [Azure OpenAI documentation](https://learn.microsoft.com/azure/ai-services/openai/how-to/create-resource?pivots=cli). + +### Setting an API key for the Azure OpenAI service + +As an alternative to configuring Azure RBAC permissions, you can set an API key for the Azure OpenAI service by setting the `AZURE_OPENAI_API_KEY` environment variable. + +Bash (Linux/macOS/WSL): + +```bash +export AZURE_OPENAI_API_KEY="your-api-key" +``` + +PowerShell: + +```powershell +$env:AZURE_OPENAI_API_KEY="your-api-key" +``` + +### Start Durable Task Scheduler + +Most samples use the Durable Task Scheduler (DTS) to support hosted agents and durable orchestrations. DTS also allows you to view the status of orchestrations and their inputs and outputs from a web UI. + +To run the Durable Task Scheduler locally, you can use the following `docker` command: + +```bash +docker run -d --name dts-emulator -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest +``` + +The DTS dashboard will be available at `http://localhost:8082`. + +### Environment Configuration + +Each sample reads configuration from environment variables. You'll need to set the following environment variables: + +```bash +export AZURE_OPENAI_ENDPOINT="https://your-resource.openai.azure.com/" +export AZURE_OPENAI_CHAT_DEPLOYMENT_NAME="your-deployment-name" +``` + +### Installing Dependencies + +Navigate to the sample directory and install dependencies: + +```bash +cd samples/getting_started/durabletask/01_single_agent +pip install -r requirements.txt +``` + +### Running the Samples + +Each sample follows a worker-client architecture. Most samples provide separate `worker.py` and `client.py` files, though some include a combined `sample.py` for convenience. + +**Running with separate worker and client:** + +In one terminal, start the worker: + +```bash +python worker.py +``` + +In another terminal, run the client: + +```bash +python client.py +``` + +**Running with combined sample:** + +```bash +python sample.py +``` + +### Viewing the Sample Output + +The sample output is displayed directly in the terminal where you ran the Python script. Agent responses are printed to stdout with log formatting for better readability. + +You can also see the state of agents and orchestrations in the Durable Task Scheduler dashboard at `http://localhost:8082`. + From b00f8e08afaae84879a0228737f1f1224c595436 Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Wed, 31 Dec 2025 12:50:41 -0800 Subject: [PATCH 06/11] Fix tests --- .../durabletask/agent_framework_durabletask/_executors.py | 2 +- python/packages/durabletask/tests/test_executors.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py index 563706f531..1753dbeacb 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_executors.py +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -53,9 +53,9 @@ def __init__( response_format: Optional Pydantic model for response parsing correlation_id: Correlation ID for logging """ - super().__init__([entity_task]) # type: ignore[misc] self._response_format = response_format self._correlation_id = correlation_id + super().__init__([entity_task]) # type: ignore[misc] def on_child_completed(self, task: Task[Any]) -> None: """Handle completion of the underlying entity task. diff --git a/python/packages/durabletask/tests/test_executors.py b/python/packages/durabletask/tests/test_executors.py index e518528832..a42200bdea 100644 --- a/python/packages/durabletask/tests/test_executors.py +++ b/python/packages/durabletask/tests/test_executors.py @@ -207,7 +207,7 @@ def test_orchestration_executor_calls_entity_with_correct_parameters( # Verify entity ID assert isinstance(entity_id_arg, EntityInstanceId) - assert entity_id_arg.entity.startswith("test_agent-") + assert entity_id_arg.entity == "dafx-test_agent" # Verify operation name assert operation_arg == "run" From 04741e75f04b254d202da3b35fedc792cd578cd2 Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Wed, 31 Dec 2025 14:58:33 -0800 Subject: [PATCH 07/11] Fix tests --- .../tests/test_orchestration.py | 475 ++++-------------- .../function_app.py | 12 +- 2 files changed, 111 insertions(+), 376 deletions(-) diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py index a5f67b0510..3da03f12be 100644 --- a/python/packages/azurefunctions/tests/test_orchestration.py +++ b/python/packages/azurefunctions/tests/test_orchestration.py @@ -7,35 +7,13 @@ import pytest from agent_framework import AgentRunResponse, ChatMessage -from agent_framework_durabletask import AgentSessionId, DurableAgentThread, DurableAIAgent +from agent_framework_durabletask import DurableAIAgent from azure.durable_functions.models.Task import TaskBase, TaskState from agent_framework_azurefunctions import AgentFunctionApp from agent_framework_azurefunctions._orchestration import AgentTask -def _create_mock_context(instance_id: str = "test-instance", uuid_values: list[str] | None = None) -> Mock: - """Create a mock orchestration context with common attributes. - - Args: - instance_id: The orchestration instance ID - uuid_values: List of UUIDs to return from new_uuid() calls (if None, returns "test-guid") - - Returns: - Mock context object configured for testing - """ - mock_context = Mock() - mock_context.instance_id = instance_id - mock_context.current_utc_datetime = Mock() - - if uuid_values: - mock_context.new_uuid = Mock(side_effect=uuid_values) - else: - mock_context.new_uuid = Mock(return_value="test-guid") - - return mock_context - - def _app_with_registered_agents(*agent_names: str) -> AgentFunctionApp: app = AgentFunctionApp(enable_health_check=False, enable_http_endpoints=False) for name in agent_names: @@ -60,46 +38,96 @@ def _create_entity_task(task_id: int = 1) -> TaskBase: return _FakeTask(task_id) -class TestAgentResponseHelpers: - """Tests for helper utilities that prepare AgentRunResponse values.""" +@pytest.fixture +def mock_context(): + """Create a mock orchestration context with UUID support.""" + context = Mock() + context.instance_id = "test-instance" + context.current_utc_datetime = Mock() + return context + + +@pytest.fixture +def mock_context_with_uuid() -> tuple[Mock, str]: + """Create a mock context with a single UUID.""" + from uuid import UUID + + context = Mock() + context.instance_id = "test-instance" + context.current_utc_datetime = Mock() + test_uuid = UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + context.new_uuid = Mock(return_value=test_uuid) + return context, test_uuid.hex + + +@pytest.fixture +def mock_context_with_multiple_uuids() -> tuple[Mock, list[str]]: + """Create a mock context with multiple UUIDs via side_effect.""" + from uuid import UUID + + context = Mock() + context.instance_id = "test-instance" + context.current_utc_datetime = Mock() + uuids = [ + UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), + UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"), + UUID("cccccccc-cccc-cccc-cccc-cccccccccccc"), + ] + context.new_uuid = Mock(side_effect=uuids) + # Return the hex versions for assertion checking + hex_uuids = [uuid.hex for uuid in uuids] + return context, hex_uuids - @staticmethod - def _create_agent_task() -> AgentTask: - entity_task = _create_entity_task() - return AgentTask(entity_task, None, "correlation-id") - def test_load_agent_response_from_instance(self) -> None: - task = self._create_agent_task() - response = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"foo": "bar"}')]) +@pytest.fixture +def executor_with_uuid() -> tuple[Any, Mock, str]: + """Create an executor with a mocked generate_unique_id method.""" + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor - loaded = task._load_agent_response(response) + context = Mock() + context.instance_id = "test-instance" + context.current_utc_datetime = Mock() - assert loaded is response - assert loaded.value is None + executor = AzureFunctionsAgentExecutor(context) + test_uuid_hex = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + executor.generate_unique_id = Mock(return_value=test_uuid_hex) - def test_load_agent_response_from_serialized(self) -> None: - task = self._create_agent_task() - serialized = AgentRunResponse(messages=[ChatMessage(role="assistant", text="structured")]).to_dict() - serialized["value"] = {"answer": 42} + return executor, context, test_uuid_hex - loaded = task._load_agent_response(serialized) - assert loaded is not None - assert loaded.value == {"answer": 42} - loaded_dict = loaded.to_dict() - assert loaded_dict["type"] == "agent_run_response" +@pytest.fixture +def executor_with_multiple_uuids() -> tuple[Any, Mock, list[str]]: + """Create an executor with multiple mocked UUIDs.""" + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor - def test_load_agent_response_rejects_none(self) -> None: - task = self._create_agent_task() + context = Mock() + context.instance_id = "test-instance" + context.current_utc_datetime = Mock() - with pytest.raises(ValueError): - task._load_agent_response(None) + executor = AzureFunctionsAgentExecutor(context) + uuid_hexes = [ + "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", + "cccccccc-cccc-cccc-cccc-cccccccccccc", + "dddddddd-dddd-dddd-dddd-dddddddddddd", + "eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee", + ] + executor.generate_unique_id = Mock(side_effect=uuid_hexes) - def test_load_agent_response_rejects_unsupported_type(self) -> None: - task = self._create_agent_task() + return executor, context, uuid_hexes - with pytest.raises(TypeError, match="Unsupported type"): - task._load_agent_response(["invalid", "list"]) # type: ignore[arg-type] + +@pytest.fixture +def executor_with_context(mock_context_with_uuid: tuple[Mock, str]) -> tuple[Any, Mock]: + """Create an executor with a mocked context.""" + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor + + context, _ = mock_context_with_uuid + return AzureFunctionsAgentExecutor(context), context + + +class TestAgentResponseHelpers: + """Tests for response handling through public AgentTask API.""" def test_try_set_value_success(self) -> None: """Test try_set_value correctly processes successful task completion.""" @@ -166,302 +194,6 @@ class TestSchema(BaseModel): assert isinstance(task.result.value, TestSchema) assert task.result.value.answer == "42" - def test_ensure_response_format_parses_value(self) -> None: - """Test _ensure_response_format correctly parses response value.""" - from pydantic import BaseModel - - class SampleSchema(BaseModel): - name: str - - task = self._create_agent_task() - response = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"name": "test"}')]) - - # Value should be None initially - assert response.value is None - - # Parse the value - task._ensure_response_format(SampleSchema, "test-correlation", response) - - # Value should now be parsed - assert isinstance(response.value, SampleSchema) - assert response.value.name == "test" - - def test_ensure_response_format_skips_if_already_parsed(self) -> None: - """Test _ensure_response_format does not re-parse if value already matches format.""" - from pydantic import BaseModel - - class SampleSchema(BaseModel): - name: str - - task = self._create_agent_task() - existing_value = SampleSchema(name="existing") - response = AgentRunResponse( - messages=[ChatMessage(role="assistant", text='{"name": "new"}')], - value=existing_value, - ) - - # Call _ensure_response_format - task._ensure_response_format(SampleSchema, "test-correlation", response) - - # Value should remain unchanged (not re-parsed) - assert response.value is existing_value - assert response.value.name == "existing" - - -class TestDurableAIAgent: - """Test suite for DurableAIAgent wrapper.""" - - def test_implements_agent_protocol(self) -> None: - """Test that DurableAIAgent implements AgentProtocol.""" - from agent_framework import AgentProtocol - - mock_context = Mock() - agent = DurableAIAgent(mock_context, "TestAgent") - - # Check that agent satisfies AgentProtocol - assert isinstance(agent, AgentProtocol) - - def test_has_agent_protocol_properties(self) -> None: - """Test that DurableAIAgent has AgentProtocol properties.""" - mock_context = Mock() - agent = DurableAIAgent(mock_context, "TestAgent") - - # AgentProtocol properties - assert hasattr(agent, "id") - assert hasattr(agent, "name") - assert hasattr(agent, "description") - assert hasattr(agent, "display_name") - - # Verify values - assert agent.name == "TestAgent" - assert agent.description == "Durable agent proxy for TestAgent" - assert agent.display_name == "TestAgent" - assert agent.id is not None # Auto-generated UUID - - def test_get_new_thread(self) -> None: - """Test creating a new agent thread.""" - mock_context = _create_mock_context("test-instance-456", ["test-guid-456"]) - - from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor - - executor = AzureFunctionsAgentExecutor(mock_context) - agent = DurableAIAgent(executor, "WriterAgent") - thread = agent.get_new_thread() - - assert isinstance(thread, DurableAgentThread) - assert thread.session_id is not None - session_id = thread.session_id - assert isinstance(session_id, AgentSessionId) - assert session_id.name == "WriterAgent" - assert session_id.key == "test-guid-456" - mock_context.new_uuid.assert_called_once() - - def test_get_new_thread_deterministic(self) -> None: - """Test that get_new_thread creates deterministic session IDs.""" - mock_context = _create_mock_context("test-instance-789", ["session-guid-1", "session-guid-2"]) - - from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor - - executor = AzureFunctionsAgentExecutor(mock_context) - agent = DurableAIAgent(executor, "EditorAgent") - - # Create multiple threads - they should have unique session IDs - thread1 = agent.get_new_thread() - thread2 = agent.get_new_thread() - - assert isinstance(thread1, DurableAgentThread) - assert isinstance(thread2, DurableAgentThread) - - session_id1 = thread1.session_id - session_id2 = thread2.session_id - assert session_id1 is not None and session_id2 is not None - assert isinstance(session_id1, AgentSessionId) - assert isinstance(session_id2, AgentSessionId) - assert session_id1.name == "EditorAgent" - assert session_id2.name == "EditorAgent" - assert session_id1.key == "session-guid-1" - assert session_id2.key == "session-guid-2" - assert mock_context.new_uuid.call_count == 2 - - def test_run_creates_entity_call(self) -> None: - """Test that run() creates proper entity call and returns a Task.""" - mock_context = _create_mock_context("test-instance-001", ["thread-guid", "correlation-guid"]) - - entity_task = _create_entity_task() - mock_context.call_entity = Mock(return_value=entity_task) - - from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor - - executor = AzureFunctionsAgentExecutor(mock_context) - agent = DurableAIAgent(executor, "TestAgent") - - # Create thread - thread = agent.get_new_thread() - - # Call run() - returns AgentTask directly - task = agent.run(messages="Test message", thread=thread, enable_tool_calls=True) - - assert isinstance(task, AgentTask) - assert task.children[0] == entity_task - - # Verify call_entity was called with correct parameters - assert mock_context.call_entity.called - call_args = mock_context.call_entity.call_args - _, operation, request = call_args[0] - - assert operation == "run" - assert request["message"] == "Test message" - assert request["enable_tool_calls"] is True - assert "correlationId" in request - assert request["correlationId"] == "correlation-guid" - assert "thread_id" not in request - # Verify orchestration ID is set from context.instance_id - assert "orchestrationId" in request - assert request["orchestrationId"] == "test-instance-001" - - def test_run_sets_orchestration_id(self) -> None: - """Test that run() sets the orchestration_id from context.instance_id.""" - mock_context = _create_mock_context("my-orchestration-123", ["thread-guid", "correlation-guid"]) - - entity_task = _create_entity_task() - mock_context.call_entity = Mock(return_value=entity_task) - - from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor - - executor = AzureFunctionsAgentExecutor(mock_context) - agent = DurableAIAgent(executor, "TestAgent") - thread = agent.get_new_thread() - - agent.run(messages="Test", thread=thread) - - call_args = mock_context.call_entity.call_args - request = call_args[0][2] - - assert request["orchestrationId"] == "my-orchestration-123" - - def test_run_without_thread(self) -> None: - """Test that run() works without explicit thread (creates unique session key).""" - mock_context = _create_mock_context("test-instance-002", ["auto-generated-guid", "correlation-guid"]) - - entity_task = _create_entity_task() - mock_context.call_entity = Mock(return_value=entity_task) - - from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor - - executor = AzureFunctionsAgentExecutor(mock_context) - agent = DurableAIAgent(executor, "TestAgent") - - # Call without thread - task = agent.run(messages="Test message") - - assert isinstance(task, AgentTask) - assert task.children[0] == entity_task - - # Verify the entity ID uses the auto-generated GUID with dafx- prefix - call_args = mock_context.call_entity.call_args - entity_id = call_args[0][0] - assert entity_id.name == "dafx-TestAgent" - assert entity_id.key == "auto-generated-guid" - # Should be called twice: once for session_key, once for correlation_id - assert mock_context.new_uuid.call_count == 2 - - def test_run_with_response_format(self) -> None: - """Test that run() passes response format correctly.""" - mock_context = _create_mock_context("test-instance-003", ["thread-guid", "correlation-guid"]) - - entity_task = _create_entity_task() - mock_context.call_entity = Mock(return_value=entity_task) - - from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor - - executor = AzureFunctionsAgentExecutor(mock_context) - agent = DurableAIAgent(executor, "TestAgent") - - from pydantic import BaseModel - - class SampleSchema(BaseModel): - key: str - - # Create thread and call - thread = agent.get_new_thread() - - task = agent.run(messages="Test message", thread=thread, response_format=SampleSchema) - - assert isinstance(task, AgentTask) - assert task.children[0] == entity_task - - # Verify schema was passed in the call_entity arguments - call_args = mock_context.call_entity.call_args - input_data = call_args[0][2] # Third argument is input_data - assert "response_format" in input_data - assert input_data["response_format"]["__response_schema_type__"] == "pydantic_model" - assert input_data["response_format"]["module"] == SampleSchema.__module__ - assert input_data["response_format"]["qualname"] == SampleSchema.__qualname__ - - def test_run_with_chat_message(self) -> None: - """Test that run() handles ChatMessage input.""" - from agent_framework import ChatMessage - - mock_context = _create_mock_context(uuid_values=["thread-guid", "correlation-guid"]) - entity_task = _create_entity_task() - mock_context.call_entity = Mock(return_value=entity_task) - - from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor - - executor = AzureFunctionsAgentExecutor(mock_context) - agent = DurableAIAgent(executor, "TestAgent") - thread = agent.get_new_thread() - - # Call with ChatMessage - msg = ChatMessage(role="user", text="Hello") - task = agent.run(messages=msg, thread=thread) - - assert isinstance(task, AgentTask) - assert task.children[0] == entity_task - - # Verify message was converted to string - call_args = mock_context.call_entity.call_args - request = call_args[0][2] - assert request["message"] == "Hello" - - def test_run_stream_raises_not_implemented(self) -> None: - """Test that run_stream() method raises NotImplementedError.""" - mock_context = Mock() - agent = DurableAIAgent(mock_context, "TestAgent") - - with pytest.raises(NotImplementedError) as exc_info: - agent.run_stream("Test message") - - error_msg = str(exc_info.value) - assert "Streaming is not supported" in error_msg - - def test_entity_id_format(self) -> None: - """Test that EntityId is created with correct format (name, key).""" - from azure.durable_functions import EntityId - - mock_context = _create_mock_context(uuid_values=["test-guid-789", "correlation-guid"]) - mock_context.call_entity = Mock(return_value=_create_entity_task()) - - from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor - - executor = AzureFunctionsAgentExecutor(mock_context) - agent = DurableAIAgent(executor, "WriterAgent") - thread = agent.get_new_thread() - - # Call run() to trigger entity ID creation - agent.run("Test", thread=thread) - - # Verify call_entity was called with correct EntityId - call_args = mock_context.call_entity.call_args - entity_id = call_args[0][0] - - # EntityId should be EntityId(name="dafx-WriterAgent", key="test-guid-789") - # Which formats as "@dafx-writeragent@test-guid-789" - assert isinstance(entity_id, EntityId) - assert entity_id.name == "dafx-WriterAgent" - assert entity_id.key == "test-guid-789" - assert str(entity_id) == "@dafx-writeragent@test-guid-789" - class TestAgentFunctionAppGetAgent: """Test suite for AgentFunctionApp.get_agent.""" @@ -477,13 +209,9 @@ def test_get_agent_raises_for_unregistered_agent(self) -> None: class TestOrchestrationIntegration: """Integration tests for orchestration scenarios.""" - def test_sequential_agent_calls_simulation(self) -> None: + def test_sequential_agent_calls_simulation(self, executor_with_multiple_uuids: tuple[Any, Mock, list[str]]) -> None: """Simulate sequential agent calls in an orchestration.""" - # new_uuid will be called 3 times: - # 1. thread creation - # 2. correlation_id for first call - # 3. correlation_id for second call - mock_context = _create_mock_context("test-orchestration-001", ["deterministic-guid-001", "corr-1", "corr-2"]) + executor, context, uuid_hexes = executor_with_multiple_uuids # Track entity calls entity_calls: list[dict[str, Any]] = [] @@ -492,10 +220,10 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic entity_calls.append({"entity_id": str(entity_id), "operation": operation, "input": input_data}) return _create_entity_task() - mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect) + context.call_entity = Mock(side_effect=mock_call_entity_side_effect) - app = _app_with_registered_agents("WriterAgent") - agent = app.get_agent(mock_context, "WriterAgent") + # Create agent directly with executor (not via app.get_agent) + agent = DurableAIAgent(executor, "WriterAgent") # Create thread thread = agent.get_new_thread() @@ -511,18 +239,15 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic # Verify both calls used the same entity (same session key) assert len(entity_calls) == 2 assert entity_calls[0]["entity_id"] == entity_calls[1]["entity_id"] - # EntityId format is @dafx-writeragent@deterministic-guid-001 - assert entity_calls[0]["entity_id"] == "@dafx-writeragent@deterministic-guid-001" - # new_uuid called 3 times: thread + 2 correlation IDs - assert mock_context.new_uuid.call_count == 3 + # EntityId format is @dafx-writeragent@ + expected_entity_id = f"@dafx-writeragent@{uuid_hexes[0]}" + assert entity_calls[0]["entity_id"] == expected_entity_id + # generate_unique_id called 3 times: thread + 2 correlation IDs + assert executor.generate_unique_id.call_count == 3 - def test_multiple_agents_in_orchestration(self) -> None: + def test_multiple_agents_in_orchestration(self, executor_with_multiple_uuids: tuple[Any, Mock, list[str]]) -> None: """Test using multiple different agents in one orchestration.""" - # Mock new_uuid to return different GUIDs for each call - # Order: writer thread, editor thread, writer correlation, editor correlation - mock_context = _create_mock_context( - "test-orchestration-002", ["writer-guid-001", "editor-guid-002", "writer-corr", "editor-corr"] - ) + executor, context, uuid_hexes = executor_with_multiple_uuids entity_calls: list[str] = [] @@ -530,11 +255,11 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic entity_calls.append(str(entity_id)) return _create_entity_task() - mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect) + context.call_entity = Mock(side_effect=mock_call_entity_side_effect) - app = _app_with_registered_agents("WriterAgent", "EditorAgent") - writer = app.get_agent(mock_context, "WriterAgent") - editor = app.get_agent(mock_context, "EditorAgent") + # Create agents directly with executor (not via app.get_agent) + writer = DurableAIAgent(executor, "WriterAgent") + editor = DurableAIAgent(executor, "EditorAgent") writer_thread = writer.get_new_thread() editor_thread = editor.get_new_thread() @@ -548,9 +273,11 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic # Verify different entity IDs were used assert len(entity_calls) == 2 - # EntityId format is @dafx-agentname@guid (lowercased agent name with dafx- prefix) - assert entity_calls[0] == "@dafx-writeragent@writer-guid-001" - assert entity_calls[1] == "@dafx-editoragent@editor-guid-002" + # EntityId format is @dafx-agentname@uuid_hex (lowercased agent name with dafx- prefix) + expected_writer_id = f"@dafx-writeragent@{uuid_hexes[0]}" + expected_editor_id = f"@dafx-editoragent@{uuid_hexes[1]}" + assert entity_calls[0] == expected_writer_id + assert entity_calls[1] == expected_editor_id if __name__ == "__main__": diff --git a/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py b/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py index 5c79bb4a86..fc72ceb770 100644 --- a/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py +++ b/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py @@ -62,7 +62,7 @@ def _create_writer_agent() -> Any: # 3. Activities encapsulate external work for review notifications and publishing. @app.activity_trigger(input_name="content") -def notify_user_for_approval(content: dict[str, Any]) -> None: +def notify_user_for_approval(content: Any) -> None: model = GeneratedContent.model_validate(content) logger.info("NOTIFICATION: Please review the following content for approval:") logger.info("Title: %s", model.title or "(untitled)") @@ -71,7 +71,7 @@ def notify_user_for_approval(content: dict[str, Any]) -> None: @app.activity_trigger(input_name="content") -def publish_content(content: dict[str, Any]) -> None: +def publish_content(content: Any) -> None: model = GeneratedContent.model_validate(content) logger.info("PUBLISHING: Content has been published successfully:") logger.info("Title: %s", model.title or "(untitled)") @@ -286,6 +286,14 @@ async def get_orchestration_status( show_input=True, ) + # Check if status is None or if the instance doesn't exist (runtime_status is None) + if getattr(status, "runtime_status", None) is None: + return func.HttpResponse( + body=json.dumps({"error": "Instance not found."}), + status_code=404, + mimetype="application/json", + ) + response_data: dict[str, Any] = { "instanceId": getattr(status, "instance_id", None), "runtimeStatus": getattr(status.runtime_status, "name", None) From f956cc0c9121c42bca68fffaa9d640dacd3e30c2 Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Fri, 2 Jan 2026 14:02:42 -0800 Subject: [PATCH 08/11] Update requirements --- .../agent_framework_durabletask/_executors.py | 2 +- .../durabletask/01_single_agent/requirements.txt | 9 +++------ .../requirements.txt | 9 +++------ .../requirements.txt | 9 +++------ 4 files changed, 10 insertions(+), 19 deletions(-) diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py index 1753dbeacb..0a82988c12 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_executors.py +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -4,7 +4,7 @@ These classes are internal execution strategies used by the DurableAIAgent shim. They are intentionally separate from the public client/orchestration APIs to keep -only `get_agent` exposed to consumers. Providers implement the execution contract +only `get_agent` exposed to consumers. Executors implement the execution contract and are injected into the shim. """ diff --git a/python/samples/getting_started/durabletask/01_single_agent/requirements.txt b/python/samples/getting_started/durabletask/01_single_agent/requirements.txt index da871507c8..371b9e3b79 100644 --- a/python/samples/getting_started/durabletask/01_single_agent/requirements.txt +++ b/python/samples/getting_started/durabletask/01_single_agent/requirements.txt @@ -1,9 +1,6 @@ -# Agent Framework packages -agent-framework-azure -agent-framework-durabletask - -# Durable Task Python SDK with Azure Managed support -durabletask-python +# Agent Framework packages (installing from local package until a package is published) +-e ../../../../ +-e ../../../../packages/durabletask # Azure authentication azure-identity diff --git a/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/requirements.txt b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/requirements.txt index da871507c8..371b9e3b79 100644 --- a/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/requirements.txt +++ b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/requirements.txt @@ -1,9 +1,6 @@ -# Agent Framework packages -agent-framework-azure -agent-framework-durabletask - -# Durable Task Python SDK with Azure Managed support -durabletask-python +# Agent Framework packages (installing from local package until a package is published) +-e ../../../../ +-e ../../../../packages/durabletask # Azure authentication azure-identity diff --git a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/requirements.txt b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/requirements.txt index da871507c8..371b9e3b79 100644 --- a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/requirements.txt +++ b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/requirements.txt @@ -1,9 +1,6 @@ -# Agent Framework packages -agent-framework-azure -agent-framework-durabletask - -# Durable Task Python SDK with Azure Managed support -durabletask-python +# Agent Framework packages (installing from local package until a package is published) +-e ../../../../ +-e ../../../../packages/durabletask # Azure authentication azure-identity From 760e9750d9af8f150409bb6cd092acd6b798a185 Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Tue, 6 Jan 2026 14:31:23 -0800 Subject: [PATCH 09/11] Fix typo --- .gitignore | 13 +++++++------ .../_orchestration.py | 2 +- .../_durable_agent_state.py | 10 ++++++++++ .../agent_framework_durabletask/_executors.py | 2 +- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index e4ebc70b54..0672e15083 100644 --- a/.gitignore +++ b/.gitignore @@ -208,13 +208,14 @@ WARP.md **/projectBrief.md # Azurite storage emulator files -*/__azurite_db_blob__.json -*/__azurite_db_blob_extent__.json -*/__azurite_db_queue__.json -*/__azurite_db_queue_extent__.json -*/__azurite_db_table__.json +*/__azurite_db_blob__.json* +*/__azurite_db_blob_extent__.json* +*/__azurite_db_queue__.json* +*/__azurite_db_queue_extent__.json* +*/__azurite_db_table__.json* */__blobstorage__/ */__queuestorage__/ +*/AzuriteConfig # Azure Functions local settings local.settings.json @@ -225,4 +226,4 @@ local.settings.json **/frontend/dist/ # Database files -*.db \ No newline at end of file +*.db diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index d28a2540db..a1061d8ceb 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -122,7 +122,7 @@ class AzureFunctionsAgentExecutor(DurableAgentExecutor[AgentTask]): def __init__(self, context: AgentOrchestrationContextType): self.context = context - def _generate_unique_id(self) -> str: + def generate_unique_id(self) -> str: return str(self.context.new_uuid()) def get_run_request( diff --git a/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py b/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py index 3e5ced3ad6..453d180612 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py +++ b/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py @@ -1223,14 +1223,24 @@ def from_usage(usage: UsageDetails | None) -> DurableAgentStateUsage | None: input_token_count=usage.input_token_count, output_token_count=usage.output_token_count, total_token_count=usage.total_token_count, + extensionData=usage.additional_counts, ) def to_usage_details(self) -> UsageDetails: # Convert back to AI SDK UsageDetails + extension_data: dict[str, int] = {} + if self.extensionData is not None: + for k, v in self.extensionData.items(): + try: + extension_data[k] = int(v) + except (ValueError, TypeError): + continue + return UsageDetails( input_token_count=self.input_token_count, output_token_count=self.output_token_count, total_token_count=self.total_token_count, + **extension_data, ) diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py index 0a82988c12..92b4440c4a 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_executors.py +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -457,7 +457,7 @@ def run_durable_agent( ) # Call the entity and get the underlying task - entity_task: Task[Any] = self._context.call_entity(entity_id, "run", run_request.to_dict()) + entity_task: Task[Any] = self._context.call_entity(entity_id, "run", run_request.to_dict()) # type: ignore # Wrap in DurableAgentTask for response transformation return DurableAgentTask( From 73a2a86a20be34019748f9796fce99fda30a3e43 Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Tue, 6 Jan 2026 16:36:32 -0800 Subject: [PATCH 10/11] Address comments --- .../agent_framework_azurefunctions/_app.py | 15 +++++++-------- .../agent_framework_durabletask/_client.py | 6 +++--- .../agent_framework_durabletask/_executors.py | 11 ++--------- .../_orchestration_context.py | 1 - 4 files changed, 12 insertions(+), 21 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index fc0a480be3..697a957923 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -381,8 +381,6 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien "enable_tool_calls": true|false (optional, default: true) } """ - logger.debug(f"[HTTP Trigger] Received request on route: /api/agents/{agent_name}/run") - request_response_format: str = REQUEST_RESPONSE_FORMAT_JSON thread_id: str | None = None @@ -391,9 +389,9 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien thread_id = self._resolve_thread_id(req=req, req_body=req_body) wait_for_response = self._should_wait_for_response(req=req, req_body=req_body) - logger.debug(f"[HTTP Trigger] Message: {message}") - logger.debug(f"[HTTP Trigger] Thread ID: {thread_id}") - logger.debug(f"[HTTP Trigger] wait_for_response: {wait_for_response}") + logger.debug( + f"[HTTP Trigger] Message: {message}, Thread ID: {thread_id}, wait_for_response: {wait_for_response}" + ) if not message: logger.warning("[HTTP Trigger] Request rejected: Missing message") @@ -407,9 +405,10 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien session_id = self._create_session_id(agent_name, thread_id) correlation_id = self._generate_unique_id() - logger.debug(f"[HTTP Trigger] Using session ID: {session_id}") - logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlation_id}") - logger.debug("[HTTP Trigger] Calling entity to run agent...") + logger.debug( + f"[HTTP Trigger] Calling entity to run agent using session ID: {session_id} " + f"and correlation ID: {correlation_id}" + ) entity_instance_id = df.EntityId( name=session_id.entity_name, diff --git a/python/packages/durabletask/agent_framework_durabletask/_client.py b/python/packages/durabletask/agent_framework_durabletask/_client.py index 2ccf269509..5e70a7ba1f 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_client.py +++ b/python/packages/durabletask/agent_framework_durabletask/_client.py @@ -36,10 +36,10 @@ class DurableAIAgentClient(DurableAgentProvider[AgentRunResponse]): agent_client = DurableAIAgentClient(client) # Get an agent reference - agent = await agent_client.get_agent("assistant") + agent = agent_client.get_agent("assistant") - # Run the agent - response = await agent.run("Hello, how are you?") + # Run the agent (synchronous call that waits for completion) + response = agent.run("Hello, how are you?") print(response.text) ``` """ diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py index 92b4440c4a..8ac893b943 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_executors.py +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -249,13 +249,6 @@ def _signal_agent_entity( ) self._client.signal_entity(entity_id, "run", run_request.to_dict()) - - logger.info( - "[ClientAgentExecutor] Signaled entity '%s' for correlation: %s", - agent_name, - run_request.correlation_id, - ) - return entity_id def _poll_for_agent_response( @@ -370,12 +363,12 @@ def _poll_entity_for_response( correlation_id: Correlation ID to search for Returns: - Response data dict if found, None otherwise + Response AgentRunResponse, None otherwise """ try: entity_metadata = self._client.get_entity(entity_id, include_state=True) - if entity_metadata is None or not entity_metadata.includes_state: + if entity_metadata is None: return None state_json = entity_metadata.get_state() diff --git a/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py b/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py index d9a7ae3c02..f5e1a15238 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py +++ b/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py @@ -30,7 +30,6 @@ class DurableAIAgentOrchestrationContext(DurableAgentProvider[DurableAgentTask]) from agent_framework_durabletask import DurableAIAgentOrchestrationContext - @orchestration def my_orchestration(context: OrchestrationContext): # Wrap the context agent_context = DurableAIAgentOrchestrationContext(context) From 08538f94222148b4c630295f76fbd95fdd896b8d Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Wed, 7 Jan 2026 13:41:47 -0800 Subject: [PATCH 11/11] use response.text --- .../azurefunctions/agent_framework_azurefunctions/_app.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 697a957923..99b0c1f7d6 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -789,9 +789,8 @@ async def _poll_entity_for_response( agent_response = state.try_get_agent_response(correlation_id) if agent_response: - response_message = "\n".join(message.text for message in agent_response.messages if message.text) result = self._build_success_result( - response_message=response_message, + response_message=agent_response.text, message=message, thread_id=thread_id, correlation_id=correlation_id,