diff --git a/.gitignore b/.gitignore index f0f8c09495..0672e15083 100644 --- a/.gitignore +++ b/.gitignore @@ -208,13 +208,14 @@ WARP.md **/projectBrief.md # Azurite storage emulator files -*/__azurite_db_blob__.json -*/__azurite_db_blob_extent__.json -*/__azurite_db_queue__.json -*/__azurite_db_queue_extent__.json -*/__azurite_db_table__.json +*/__azurite_db_blob__.json* +*/__azurite_db_blob_extent__.json* +*/__azurite_db_queue__.json* +*/__azurite_db_queue_extent__.json* +*/__azurite_db_table__.json* */__blobstorage__/ */__queuestorage__/ +*/AzuriteConfig # Azure Functions local settings local.settings.json diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/__init__.py b/python/packages/azurefunctions/agent_framework_azurefunctions/__init__.py index 960b8d4023..e5be2aa36e 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/__init__.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/__init__.py @@ -2,10 +2,9 @@ import importlib.metadata -from agent_framework_durabletask import AgentCallbackContext, AgentResponseCallbackProtocol +from agent_framework_durabletask import AgentCallbackContext, AgentResponseCallbackProtocol, DurableAIAgent from ._app import AgentFunctionApp -from ._orchestration import DurableAIAgent try: __version__ = importlib.metadata.version(__name__) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 80f2b6dec9..99b0c1f7d6 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -8,6 +8,7 @@ import json import re +import uuid from collections.abc import Callable, Mapping from dataclasses import dataclass from datetime import datetime, timezone @@ -28,14 +29,16 @@ WAIT_FOR_RESPONSE_FIELD, WAIT_FOR_RESPONSE_HEADER, AgentResponseCallbackProtocol, + AgentSessionId, + ApiResponseFields, DurableAgentState, + DurableAIAgent, RunRequest, ) from ._entities import create_agent_entity from ._errors import IncomingRequestError -from ._models import AgentSessionId -from ._orchestration import AgentOrchestrationContextType, DurableAIAgent +from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor logger = get_logger("agent_framework.azurefunctions") @@ -296,7 +299,7 @@ def get_agent( self, context: AgentOrchestrationContextType, agent_name: str, - ) -> DurableAIAgent: + ) -> DurableAIAgent[AgentTask]: """Return a DurableAIAgent proxy for a registered agent. Args: @@ -307,14 +310,15 @@ def get_agent( ValueError: If the requested agent has not been registered. Returns: - DurableAIAgent wrapper bound to the orchestration context. + DurableAIAgent[AgentTask] wrapper bound to the orchestration context. """ normalized_name = str(agent_name) if normalized_name not in self._agent_metadata: raise ValueError(f"Agent '{normalized_name}' is not registered with this app.") - return DurableAIAgent(context, normalized_name) + executor = AzureFunctionsAgentExecutor(context) + return DurableAIAgent(executor, normalized_name) def _setup_agent_functions( self, @@ -377,8 +381,6 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien "enable_tool_calls": true|false (optional, default: true) } """ - logger.debug(f"[HTTP Trigger] Received request on route: /api/agents/{agent_name}/run") - request_response_format: str = REQUEST_RESPONSE_FORMAT_JSON thread_id: str | None = None @@ -387,9 +389,9 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien thread_id = self._resolve_thread_id(req=req, req_body=req_body) wait_for_response = self._should_wait_for_response(req=req, req_body=req_body) - logger.debug(f"[HTTP Trigger] Message: {message}") - logger.debug(f"[HTTP Trigger] Thread ID: {thread_id}") - logger.debug(f"[HTTP Trigger] wait_for_response: {wait_for_response}") + logger.debug( + f"[HTTP Trigger] Message: {message}, Thread ID: {thread_id}, wait_for_response: {wait_for_response}" + ) if not message: logger.warning("[HTTP Trigger] Request rejected: Missing message") @@ -403,15 +405,18 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien session_id = self._create_session_id(agent_name, thread_id) correlation_id = self._generate_unique_id() - logger.debug(f"[HTTP Trigger] Using session ID: {session_id}") - logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlation_id}") - logger.debug("[HTTP Trigger] Calling entity to run agent...") + logger.debug( + f"[HTTP Trigger] Calling entity to run agent using session ID: {session_id} " + f"and correlation ID: {correlation_id}" + ) - entity_instance_id = session_id.to_entity_id() + entity_instance_id = df.EntityId( + name=session_id.entity_name, + key=session_id.key, + ) run_request = self._build_request_data( req_body, message, - thread_id, correlation_id, request_response_format, ) @@ -624,14 +629,16 @@ async def _handle_mcp_tool_invocation( session_id = AgentSessionId.with_random_key(agent_name) # Build entity instance ID - entity_instance_id = session_id.to_entity_id() + entity_instance_id = df.EntityId( + name=session_id.entity_name, + key=session_id.key, + ) # Create run request correlation_id = self._generate_unique_id() run_request = self._build_request_data( req_body={"message": query, "role": "user"}, message=query, - thread_id=str(session_id), correlation_id=correlation_id, request_response_format=REQUEST_RESPONSE_FORMAT_TEXT, ) @@ -783,7 +790,7 @@ async def _poll_entity_for_response( agent_response = state.try_get_agent_response(correlation_id) if agent_response: result = self._build_success_result( - response_data=agent_response, + response_message=agent_response.text, message=message, thread_id=thread_id, correlation_id=correlation_id, @@ -829,23 +836,22 @@ async def _build_timeout_result(self, message: str, thread_id: str, correlation_ ) def _build_success_result( - self, response_data: dict[str, Any], message: str, thread_id: str, correlation_id: str, state: DurableAgentState + self, response_message: str, message: str, thread_id: str, correlation_id: str, state: DurableAgentState ) -> dict[str, Any]: """Build the success result returned to the HTTP caller.""" return self._build_response_payload( - response=response_data.get("content"), + response=response_message, message=message, thread_id=thread_id, status="success", correlation_id=correlation_id, - extra_fields={"message_count": response_data.get("message_count", state.message_count)}, + extra_fields={ApiResponseFields.MESSAGE_COUNT: state.message_count}, ) def _build_request_data( self, req_body: dict[str, Any], message: str, - thread_id: str, correlation_id: str, request_response_format: str, ) -> dict[str, Any]: @@ -912,15 +918,13 @@ def _convert_payload_to_text(self, payload: dict[str, Any]) -> str: def _generate_unique_id(self) -> str: """Generate a new unique identifier.""" - import uuid - return uuid.uuid4().hex - def _create_session_id(self, func_name: str, thread_id: str | None) -> AgentSessionId: + def _create_session_id(self, agent_name: str, thread_id: str | None) -> AgentSessionId: """Create a session identifier using the provided thread id or a random value.""" if thread_id: - return AgentSessionId(name=func_name, key=thread_id) - return AgentSessionId.with_random_key(name=func_name) + return AgentSessionId(name=agent_name, key=thread_id) + return AgentSessionId.with_random_key(name=agent_name) def _resolve_thread_id(self, req: func.HttpRequest, req_body: dict[str, Any]) -> str: """Retrieve the thread identifier from request body or query parameters.""" diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py deleted file mode 100644 index 4b069686df..0000000000 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Azure Functions-specific data models for Durable Agent Framework. - -This module contains Azure Functions-specific models: -- AgentSessionId: Entity ID management for Azure Durable Entities -- DurableAgentThread: Thread implementation that tracks AgentSessionId - -Common models like RunRequest have been moved to agent-framework-durabletask. -""" - -from __future__ import annotations - -import uuid -from collections.abc import MutableMapping -from dataclasses import dataclass -from typing import Any - -import azure.durable_functions as df -from agent_framework import AgentThread - - -@dataclass -class AgentSessionId: - """Represents an agent session ID, which is used to identify a long-running agent session. - - Attributes: - name: The name of the agent that owns the session (case-insensitive) - key: The unique key of the agent session (case-sensitive) - """ - - name: str - key: str - - ENTITY_NAME_PREFIX: str = "dafx-" - - @staticmethod - def to_entity_name(name: str) -> str: - """Converts an agent name to an entity name by adding the DAFx prefix. - - Args: - name: The agent name - - Returns: - The entity name with the dafx- prefix - """ - return f"{AgentSessionId.ENTITY_NAME_PREFIX}{name}" - - @staticmethod - def with_random_key(name: str) -> AgentSessionId: - """Creates a new AgentSessionId with the specified name and a randomly generated key. - - Args: - name: The name of the agent that owns the session - - Returns: - A new AgentSessionId with the specified name and a random GUID key - """ - return AgentSessionId(name=name, key=uuid.uuid4().hex) - - def to_entity_id(self) -> df.EntityId: - """Converts this AgentSessionId to a Durable Functions EntityId. - - Returns: - EntityId for use with Durable Functions APIs - """ - return df.EntityId(self.to_entity_name(self.name), self.key) - - @staticmethod - def from_entity_id(entity_id: df.EntityId) -> AgentSessionId: - """Creates an AgentSessionId from a Durable Functions EntityId. - - Args: - entity_id: The EntityId to convert - - Returns: - AgentSessionId instance - - Raises: - ValueError: If the entity ID does not have the expected prefix - """ - if not entity_id.name.startswith(AgentSessionId.ENTITY_NAME_PREFIX): - raise ValueError( - f"'{entity_id}' is not a valid agent session ID. " - f"Expected entity name to start with '{AgentSessionId.ENTITY_NAME_PREFIX}'" - ) - - agent_name = entity_id.name[len(AgentSessionId.ENTITY_NAME_PREFIX) :] - return AgentSessionId(name=agent_name, key=entity_id.key) - - def __str__(self) -> str: - """Returns a string representation in the form @name@key.""" - return f"@{self.name}@{self.key}" - - def __repr__(self) -> str: - """Returns a detailed string representation.""" - return f"AgentSessionId(name='{self.name}', key='{self.key}')" - - @staticmethod - def parse(session_id_string: str) -> AgentSessionId: - """Parses a string representation of an agent session ID. - - Args: - session_id_string: A string in the form @name@key - - Returns: - AgentSessionId instance - - Raises: - ValueError: If the string format is invalid - """ - if not session_id_string.startswith("@"): - raise ValueError(f"Invalid agent session ID format: {session_id_string}") - - parts = session_id_string[1:].split("@", 1) - if len(parts) != 2: - raise ValueError(f"Invalid agent session ID format: {session_id_string}") - - return AgentSessionId(name=parts[0], key=parts[1]) - - -class DurableAgentThread(AgentThread): - """Durable agent thread that tracks the owning :class:`AgentSessionId`.""" - - _SERIALIZED_SESSION_ID_KEY = "durable_session_id" - - def __init__( - self, - *, - session_id: AgentSessionId | None = None, - service_thread_id: str | None = None, - message_store: Any = None, - context_provider: Any = None, - ) -> None: - super().__init__( - service_thread_id=service_thread_id, - message_store=message_store, - context_provider=context_provider, - ) - self._session_id: AgentSessionId | None = session_id - - @property - def session_id(self) -> AgentSessionId | None: - """Returns the durable agent session identifier for this thread.""" - return self._session_id - - def attach_session(self, session_id: AgentSessionId) -> None: - """Associates the thread with the provided :class:`AgentSessionId`.""" - self._session_id = session_id - - @classmethod - def from_session_id( - cls, - session_id: AgentSessionId, - *, - service_thread_id: str | None = None, - message_store: Any = None, - context_provider: Any = None, - ) -> DurableAgentThread: - """Creates a durable thread pre-associated with the supplied session ID.""" - return cls( - session_id=session_id, - service_thread_id=service_thread_id, - message_store=message_store, - context_provider=context_provider, - ) - - async def serialize(self, **kwargs: Any) -> dict[str, Any]: - """Serializes thread state including the durable session identifier.""" - state = await super().serialize(**kwargs) - if self._session_id is not None: - state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id) - return state - - @classmethod - async def deserialize( - cls, - serialized_thread_state: MutableMapping[str, Any], - *, - message_store: Any = None, - **kwargs: Any, - ) -> DurableAgentThread: - """Restores a durable thread, rehydrating the stored session identifier.""" - state_payload = dict(serialized_thread_state) - session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None) - thread = await super().deserialize( - state_payload, - message_store=message_store, - **kwargs, - ) - if not isinstance(thread, DurableAgentThread): - raise TypeError("Deserialized thread is not a DurableAgentThread instance") - - if session_id_value is None: - return thread - - if not isinstance(session_id_value, str): - raise ValueError("durable_session_id must be a string when present in serialized state") - - thread.attach_session(AgentSessionId.parse(session_id_value)) - return thread diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index cf91132916..a1061d8ceb 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -5,25 +5,21 @@ This module provides support for using agents inside Durable Function orchestrations. """ -import uuid -from collections.abc import AsyncIterator, Callable -from typing import TYPE_CHECKING, Any, TypeAlias, cast - -from agent_framework import ( - AgentProtocol, - AgentRunResponse, - AgentRunResponseUpdate, - AgentThread, - ChatMessage, - get_logger, +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, TypeAlias + +import azure.durable_functions as df +from agent_framework import AgentThread, get_logger +from agent_framework_durabletask import ( + DurableAgentExecutor, + RunRequest, + ensure_response_format, + load_agent_response, ) -from agent_framework_durabletask import RunRequest from azure.durable_functions.models import TaskBase from azure.durable_functions.models.Task import CompoundTask, TaskState from pydantic import BaseModel -from ._models import AgentSessionId, DurableAgentThread - logger = get_logger("agent_framework.azurefunctions.orchestration") CompoundActionConstructor: TypeAlias = Callable[[list[Any]], Any] | None @@ -96,10 +92,10 @@ def try_set_value(self, child: TaskBase) -> None: ) try: - response = self._load_agent_response(raw_result) + response = load_agent_response(raw_result) if self._response_format is not None: - self._ensure_response_format( + ensure_response_format( self._response_format, self._correlation_id, response, @@ -119,249 +115,60 @@ def try_set_value(self, child: TaskBase) -> None: self._first_error = child.result self.set_value(is_error=True, value=self._first_error) - def _load_agent_response(self, agent_response: AgentRunResponse | dict[str, Any] | None) -> AgentRunResponse: - """Convert raw payloads into AgentRunResponse instance.""" - if agent_response is None: - raise ValueError("agent_response cannot be None") - logger.debug("[load_agent_response] Loading agent response of type: %s", type(agent_response)) +class AzureFunctionsAgentExecutor(DurableAgentExecutor[AgentTask]): + """Executor that executes durable agents inside Azure Functions orchestrations.""" - if isinstance(agent_response, AgentRunResponse): - return agent_response - if isinstance(agent_response, dict): - logger.debug("[load_agent_response] Converting dict payload using AgentRunResponse.from_dict") - return AgentRunResponse.from_dict(agent_response) + def __init__(self, context: AgentOrchestrationContextType): + self.context = context - raise TypeError(f"Unsupported type for agent_response: {type(agent_response)}") + def generate_unique_id(self) -> str: + return str(self.context.new_uuid()) - def _ensure_response_format( + def get_run_request( self, + message: str, response_format: type[BaseModel] | None, - correlation_id: str, - response: AgentRunResponse, - ) -> None: - """Ensure the AgentRunResponse value is parsed into the expected response_format.""" - if response_format is not None and not isinstance(response.value, response_format): - response.try_parse_value(response_format) - - logger.debug( - "[DurableAIAgent] Loaded AgentRunResponse.value for correlation_id %s with type: %s", - correlation_id, - type(response.value).__name__, - ) - - -class DurableAIAgent(AgentProtocol): - """A durable agent implementation that uses entity methods to interact with agent entities. - - This class implements AgentProtocol and provides methods to work with Azure Durable Functions - orchestrations, which use generators and yield instead of async/await. + enable_tool_calls: bool, + ) -> RunRequest: + """Get the current run request from the orchestration context. - Key methods: - - get_new_thread(): Create a new conversation thread - - run(): Execute the agent and return a Task for yielding in orchestrations - - Note: The run() method is NOT async. It returns a Task directly that must be - yielded in orchestrations to wait for the entity call to complete. - - Example usage in orchestration: - writer = app.get_agent(context, "WriterAgent") - thread = writer.get_new_thread() # NOT yielded - returns immediately - - response = yield writer.run( # Yielded - waits for entity call - message="Write a haiku about coding", - thread=thread + Returns: + RunRequest: The current run request + """ + request = super().get_run_request( + message, + response_format, + enable_tool_calls, ) - """ + request.orchestration_id = self.context.instance_id + return request - def __init__(self, context: AgentOrchestrationContextType, agent_name: str): - """Initialize the DurableAIAgent. - - Args: - context: The orchestration context - agent_name: Name of the agent (used to construct entity ID) - """ - self.context = context - self.agent_name = agent_name - self._id = str(uuid.uuid4()) - self._name = agent_name - self._display_name = agent_name - self._description = f"Durable agent proxy for {agent_name}" - logger.debug("[DurableAIAgent] Initialized for agent: %s", agent_name) - - @property - def id(self) -> str: - """Get the unique identifier for this agent.""" - return self._id - - @property - def name(self) -> str | None: - """Get the name of the agent.""" - return self._name - - @property - def display_name(self) -> str: - """Get the display name of the agent.""" - return self._display_name - - @property - def description(self) -> str | None: - """Get the description of the agent.""" - return self._description - - # We return an AgentTask here which is a TaskBase subclass. - # This is an intentional deviation from AgentProtocol which defines run() as async. - # The AgentTask can be yielded in Durable Functions orchestrations and will provide - # a typed AgentRunResponse result. - def run( # type: ignore[override] + def run_durable_agent( self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, + agent_name: str, + run_request: RunRequest, thread: AgentThread | None = None, - response_format: type[BaseModel] | None = None, - **kwargs: Any, ) -> AgentTask: - """Execute the agent with messages and return an AgentTask for orchestrations. - This method implements AgentProtocol and returns an AgentTask (subclass of TaskBase) - that can be yielded in Durable Functions orchestrations. The task's result will be - a typed AgentRunResponse. - - Args: - messages: The message(s) to send to the agent - thread: Optional agent thread for conversation context - response_format: Optional Pydantic model for response parsing - **kwargs: Additional arguments (enable_tool_calls) + # Resolve session + session_id = self._create_session_id(agent_name, thread) - Returns: - An AgentTask that resolves to an AgentRunResponse when yielded - - Example: - @app.orchestration_trigger(context_name="context") - def my_orchestration(context): - agent = app.get_agent(context, "MyAgent") - thread = agent.get_new_thread() - response = yield agent.run("Hello", thread=thread) - # response is typed as AgentRunResponse - """ - message_str = self._normalize_messages(messages) - - # Extract optional parameters from kwargs - enable_tool_calls = kwargs.get("enable_tool_calls", True) + entity_id = df.EntityId( + name=session_id.entity_name, + key=session_id.key, + ) - # Get the session ID for the entity - if isinstance(thread, DurableAgentThread) and thread.session_id is not None: - session_id = thread.session_id - else: - # Create a unique session ID for each call when no thread is provided - # This ensures each call gets its own conversation context - session_key = str(self.context.new_uuid()) - session_id = AgentSessionId(name=self.agent_name, key=session_key) - logger.debug("[DurableAIAgent] No thread provided, created unique session_id: %s", session_id) - - # Create entity ID from session ID - entity_id = session_id.to_entity_id() - - # Generate a deterministic correlation ID for this call - # This is required by the entity and must be unique per call - correlation_id = str(self.context.new_uuid()) logger.debug( - "[DurableAIAgent] Using correlation_id: %s for entity_id: %s for session_id: %s", - correlation_id, + "[AzureFunctionsAgentProvider] correlation_id: %s entity_id: %s session_id: %s", + run_request.correlation_id, entity_id, session_id, ) - # Prepare the request using RunRequest model - # Include the orchestration's instance_id so it can be stored in the agent's entity state - run_request = RunRequest( - message=message_str, - enable_tool_calls=enable_tool_calls, - correlation_id=correlation_id, - response_format=response_format, - orchestration_id=self.context.instance_id, - created_at=self.context.current_utc_datetime, - ) - - logger.debug("[DurableAIAgent] Calling entity %s with message: %s", entity_id, message_str[:100]) - - # Call the entity to get the underlying task entity_task = self.context.call_entity(entity_id, "run", run_request.to_dict()) - - # Wrap it in an AgentTask that will convert the result to AgentRunResponse - agent_task = AgentTask( + return AgentTask( entity_task=entity_task, - response_format=response_format, - correlation_id=correlation_id, + response_format=run_request.response_format, + correlation_id=run_request.correlation_id, ) - - logger.debug( - "[DurableAIAgent] Created AgentTask for correlation_id %s", - correlation_id, - ) - - return agent_task - - def run_stream( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AsyncIterator[AgentRunResponseUpdate]: - """Run the agent with streaming (not supported for durable agents). - - Raises: - NotImplementedError: Streaming is not supported for durable agents. - """ - raise NotImplementedError("Streaming is not supported for durable agents in orchestrations.") - - def get_new_thread(self, **kwargs: Any) -> AgentThread: - """Create a new agent thread for this orchestration instance. - - Each call creates a unique thread with its own conversation context. - The session ID is deterministic (uses context.new_uuid()) to ensure - orchestration replay works correctly. - - Returns: - A new AgentThread instance with a unique session ID - """ - # Generate a deterministic unique key for this thread - # Using context.new_uuid() ensures the same GUID is generated during replay - session_key = str(self.context.new_uuid()) - - # Create AgentSessionId with agent name and session key - session_id = AgentSessionId(name=self.agent_name, key=session_key) - - thread = DurableAgentThread.from_session_id(session_id, **kwargs) - - logger.debug("[DurableAIAgent] Created new thread with session_id: %s", session_id) - return thread - - def _messages_to_string(self, messages: list[ChatMessage]) -> str: - """Convert a list of ChatMessage objects to a single string. - - Args: - messages: List of ChatMessage objects - - Returns: - Concatenated string of message contents - """ - return "\n".join([msg.text or "" for msg in messages]) - - def _normalize_messages(self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None) -> str: - """Convert supported message inputs to a single string.""" - if messages is None: - return "" - if isinstance(messages, str): - return messages - if isinstance(messages, ChatMessage): - return messages.text or "" - if isinstance(messages, list): - if not messages: - return "" - first_item = messages[0] - if isinstance(first_item, str): - return "\n".join(cast(list[str], messages)) - return self._messages_to_string(cast(list[ChatMessage], messages)) - return str(messages) diff --git a/python/packages/azurefunctions/tests/test_models.py b/python/packages/azurefunctions/tests/test_models.py deleted file mode 100644 index 5f4dc47080..0000000000 --- a/python/packages/azurefunctions/tests/test_models.py +++ /dev/null @@ -1,402 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Unit tests for data models (AgentSessionId, RunRequest, AgentResponse).""" - -import azure.durable_functions as df -import pytest -from agent_framework import Role -from agent_framework_durabletask import RunRequest -from pydantic import BaseModel - -from agent_framework_azurefunctions._models import AgentSessionId - - -class ModuleStructuredResponse(BaseModel): - value: int - - -class TestAgentSessionId: - """Test suite for AgentSessionId.""" - - def test_init_creates_session_id(self) -> None: - """Test that AgentSessionId initializes correctly.""" - session_id = AgentSessionId(name="AgentEntity", key="test-key-123") - - assert session_id.name == "AgentEntity" - assert session_id.key == "test-key-123" - - def test_with_random_key_generates_guid(self) -> None: - """Test that with_random_key generates a GUID.""" - session_id = AgentSessionId.with_random_key(name="AgentEntity") - - assert session_id.name == "AgentEntity" - assert len(session_id.key) == 32 # UUID hex is 32 chars - # Verify it's a valid hex string - int(session_id.key, 16) - - def test_with_random_key_unique_keys(self) -> None: - """Test that with_random_key generates unique keys.""" - session_id1 = AgentSessionId.with_random_key(name="AgentEntity") - session_id2 = AgentSessionId.with_random_key(name="AgentEntity") - - assert session_id1.key != session_id2.key - - def test_to_entity_id_conversion(self) -> None: - """Test conversion to EntityId.""" - session_id = AgentSessionId(name="AgentEntity", key="test-key") - entity_id = session_id.to_entity_id() - - assert isinstance(entity_id, df.EntityId) - assert entity_id.name == "dafx-AgentEntity" - assert entity_id.key == "test-key" - - def test_from_entity_id_conversion(self) -> None: - """Test creation from EntityId.""" - entity_id = df.EntityId(name="dafx-AgentEntity", key="test-key") - session_id = AgentSessionId.from_entity_id(entity_id) - - assert isinstance(session_id, AgentSessionId) - assert session_id.name == "AgentEntity" - assert session_id.key == "test-key" - - def test_round_trip_entity_id_conversion(self) -> None: - """Test round-trip conversion to and from EntityId.""" - original = AgentSessionId(name="AgentEntity", key="test-key") - entity_id = original.to_entity_id() - restored = AgentSessionId.from_entity_id(entity_id) - - assert restored.name == original.name - assert restored.key == original.key - - def test_str_representation(self) -> None: - """Test string representation.""" - session_id = AgentSessionId(name="AgentEntity", key="test-key-123") - str_repr = str(session_id) - - assert str_repr == "@AgentEntity@test-key-123" - - def test_repr_representation(self) -> None: - """Test repr representation.""" - session_id = AgentSessionId(name="AgentEntity", key="test-key") - repr_str = repr(session_id) - - assert "AgentSessionId" in repr_str - assert "AgentEntity" in repr_str - assert "test-key" in repr_str - - def test_parse_valid_session_id(self) -> None: - """Test parsing valid session ID string.""" - session_id = AgentSessionId.parse("@AgentEntity@test-key-123") - - assert session_id.name == "AgentEntity" - assert session_id.key == "test-key-123" - - def test_parse_invalid_format_no_prefix(self) -> None: - """Test parsing invalid format without @ prefix.""" - with pytest.raises(ValueError) as exc_info: - AgentSessionId.parse("AgentEntity@test-key") - - assert "Invalid agent session ID format" in str(exc_info.value) - - def test_parse_invalid_format_single_part(self) -> None: - """Test parsing invalid format with single part.""" - with pytest.raises(ValueError) as exc_info: - AgentSessionId.parse("@AgentEntity") - - assert "Invalid agent session ID format" in str(exc_info.value) - - def test_parse_with_multiple_at_signs_in_key(self) -> None: - """Test parsing with @ signs in the key.""" - session_id = AgentSessionId.parse("@AgentEntity@key-with@symbols") - - assert session_id.name == "AgentEntity" - assert session_id.key == "key-with@symbols" - - def test_parse_round_trip(self) -> None: - """Test round-trip parse and string conversion.""" - original = AgentSessionId(name="AgentEntity", key="test-key") - str_repr = str(original) - parsed = AgentSessionId.parse(str_repr) - - assert parsed.name == original.name - assert parsed.key == original.key - - def test_to_entity_name_adds_prefix(self) -> None: - """Test that to_entity_name adds the dafx- prefix.""" - entity_name = AgentSessionId.to_entity_name("TestAgent") - assert entity_name == "dafx-TestAgent" - - def test_from_entity_id_strips_prefix(self) -> None: - """Test that from_entity_id strips the dafx- prefix.""" - entity_id = df.EntityId(name="dafx-TestAgent", key="key123") - session_id = AgentSessionId.from_entity_id(entity_id) - - assert session_id.name == "TestAgent" - assert session_id.key == "key123" - - def test_from_entity_id_raises_without_prefix(self) -> None: - """Test that from_entity_id raises ValueError when entity name lacks the prefix.""" - entity_id = df.EntityId(name="TestAgent", key="key123") - - with pytest.raises(ValueError) as exc_info: - AgentSessionId.from_entity_id(entity_id) - - assert "not a valid agent session ID" in str(exc_info.value) - assert "dafx-" in str(exc_info.value) - - -class TestRunRequest: - """Test suite for RunRequest.""" - - def test_init_with_defaults(self) -> None: - """Test RunRequest initialization with defaults.""" - request = RunRequest(message="Hello") - - assert request.message == "Hello" - assert request.role == Role.USER - assert request.response_format is None - assert request.enable_tool_calls is True - - def test_init_with_all_fields(self) -> None: - """Test RunRequest initialization with all fields.""" - schema = ModuleStructuredResponse - request = RunRequest( - message="Hello", - role=Role.SYSTEM, - response_format=schema, - enable_tool_calls=False, - ) - - assert request.message == "Hello" - assert request.role == Role.SYSTEM - assert request.response_format is schema - assert request.enable_tool_calls is False - - def test_init_coerces_string_role(self) -> None: - """Ensure string role values are coerced into Role instances.""" - request = RunRequest(message="Hello", role="system") # type: ignore[arg-type] - - assert request.role == Role.SYSTEM - - def test_to_dict_with_defaults(self) -> None: - """Test to_dict with default values.""" - request = RunRequest(message="Test message") - data = request.to_dict() - - assert data["message"] == "Test message" - assert data["enable_tool_calls"] is True - assert data["role"] == "user" - assert "response_format" not in data or data["response_format"] is None - assert "thread_id" not in data - - def test_to_dict_with_all_fields(self) -> None: - """Test to_dict with all fields.""" - schema = ModuleStructuredResponse - request = RunRequest( - message="Hello", - role=Role.ASSISTANT, - response_format=schema, - enable_tool_calls=False, - ) - data = request.to_dict() - - assert data["message"] == "Hello" - assert data["role"] == "assistant" - assert data["response_format"]["__response_schema_type__"] == "pydantic_model" - assert data["response_format"]["module"] == schema.__module__ - assert data["response_format"]["qualname"] == schema.__qualname__ - assert data["enable_tool_calls"] is False - assert "thread_id" not in data - - def test_from_dict_with_defaults(self) -> None: - """Test from_dict with minimal data.""" - data = {"message": "Hello"} - request = RunRequest.from_dict(data) - - assert request.message == "Hello" - assert request.role == Role.USER - assert request.enable_tool_calls is True - - def test_from_dict_ignores_thread_id_field(self) -> None: - """Ensure legacy thread_id input does not break RunRequest parsing.""" - request = RunRequest.from_dict({"message": "Hello", "thread_id": "ignored"}) - - assert request.message == "Hello" - - def test_from_dict_with_all_fields(self) -> None: - """Test from_dict with all fields.""" - data = { - "message": "Test", - "role": "system", - "response_format": { - "__response_schema_type__": "pydantic_model", - "module": ModuleStructuredResponse.__module__, - "qualname": ModuleStructuredResponse.__qualname__, - }, - "enable_tool_calls": False, - } - request = RunRequest.from_dict(data) - - assert request.message == "Test" - assert request.role == Role.SYSTEM - assert request.response_format is ModuleStructuredResponse - assert request.enable_tool_calls is False - - def test_from_dict_with_unknown_role_preserves_value(self) -> None: - """Test from_dict keeps custom roles intact.""" - data = {"message": "Test", "role": "reviewer"} - request = RunRequest.from_dict(data) - - assert request.role.value == "reviewer" - assert request.role != Role.USER - - def test_from_dict_empty_message(self) -> None: - """Test from_dict with empty message.""" - request = RunRequest.from_dict({}) - - assert request.message == "" - assert request.role == Role.USER - - def test_round_trip_dict_conversion(self) -> None: - """Test round-trip to_dict and from_dict.""" - original = RunRequest( - message="Test message", - role=Role.SYSTEM, - response_format=ModuleStructuredResponse, - enable_tool_calls=False, - ) - - data = original.to_dict() - restored = RunRequest.from_dict(data) - - assert restored.message == original.message - assert restored.role == original.role - assert restored.response_format is ModuleStructuredResponse - assert restored.enable_tool_calls == original.enable_tool_calls - - def test_round_trip_with_pydantic_response_format(self) -> None: - """Ensure Pydantic response formats serialize and deserialize properly.""" - original = RunRequest( - message="Structured", - response_format=ModuleStructuredResponse, - ) - - data = original.to_dict() - - assert data["response_format"]["__response_schema_type__"] == "pydantic_model" - assert data["response_format"]["module"] == ModuleStructuredResponse.__module__ - assert data["response_format"]["qualname"] == ModuleStructuredResponse.__qualname__ - - restored = RunRequest.from_dict(data) - assert restored.response_format is ModuleStructuredResponse - - def test_init_with_correlationId(self) -> None: - """Test RunRequest initialization with correlationId.""" - request = RunRequest(message="Test message", correlation_id="corr-123") - - assert request.message == "Test message" - assert request.correlation_id == "corr-123" - - def test_to_dict_with_correlationId(self) -> None: - """Test to_dict includes correlationId.""" - request = RunRequest(message="Test", correlation_id="corr-456") - data = request.to_dict() - - assert data["message"] == "Test" - assert data["correlationId"] == "corr-456" - - def test_from_dict_with_correlationId(self) -> None: - """Test from_dict with correlationId.""" - data = {"message": "Test", "correlationId": "corr-789"} - request = RunRequest.from_dict(data) - - assert request.message == "Test" - assert request.correlation_id == "corr-789" - - def test_round_trip_with_correlationId(self) -> None: - """Test round-trip to_dict and from_dict with correlationId.""" - original = RunRequest( - message="Test message", - role=Role.SYSTEM, - correlation_id="corr-123", - ) - - data = original.to_dict() - restored = RunRequest.from_dict(data) - - assert restored.message == original.message - assert restored.role == original.role - assert restored.correlation_id == original.correlation_id - - def test_init_with_orchestration_id(self) -> None: - """Test RunRequest initialization with orchestration_id.""" - request = RunRequest( - message="Test message", - orchestration_id="orch-123", - ) - - assert request.message == "Test message" - assert request.orchestration_id == "orch-123" - - def test_to_dict_with_orchestration_id(self) -> None: - """Test to_dict includes orchestrationId.""" - request = RunRequest( - message="Test", - orchestration_id="orch-456", - ) - data = request.to_dict() - - assert data["message"] == "Test" - assert data["orchestrationId"] == "orch-456" - - def test_to_dict_excludes_orchestration_id_when_none(self) -> None: - """Test to_dict excludes orchestrationId when not set.""" - request = RunRequest( - message="Test", - ) - data = request.to_dict() - - assert "orchestrationId" not in data - - def test_from_dict_with_orchestration_id(self) -> None: - """Test from_dict with orchestrationId.""" - data = { - "message": "Test", - "orchestrationId": "orch-789", - } - request = RunRequest.from_dict(data) - - assert request.message == "Test" - assert request.orchestration_id == "orch-789" - - def test_round_trip_with_orchestration_id(self) -> None: - """Test round-trip to_dict and from_dict with orchestration_id.""" - original = RunRequest( - message="Test message", - role=Role.SYSTEM, - correlation_id="corr-123", - orchestration_id="orch-123", - ) - - data = original.to_dict() - restored = RunRequest.from_dict(data) - - assert restored.message == original.message - assert restored.role == original.role - assert restored.correlation_id == original.correlation_id - assert restored.orchestration_id == original.orchestration_id - - -class TestModelIntegration: - """Test suite for integration between models.""" - - def test_run_request_with_session_id_string(self) -> None: - """AgentSessionId string can still be used by callers, but is not stored on RunRequest.""" - session_id = AgentSessionId.with_random_key("AgentEntity") - session_id_str = str(session_id) - - assert session_id_str.startswith("@AgentEntity@") - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py index 98f70af95a..3da03f12be 100644 --- a/python/packages/azurefunctions/tests/test_orchestration.py +++ b/python/packages/azurefunctions/tests/test_orchestration.py @@ -6,11 +6,11 @@ from unittest.mock import Mock import pytest -from agent_framework import AgentRunResponse, AgentThread, ChatMessage +from agent_framework import AgentRunResponse, ChatMessage +from agent_framework_durabletask import DurableAIAgent from azure.durable_functions.models.Task import TaskBase, TaskState -from agent_framework_azurefunctions import AgentFunctionApp, DurableAIAgent -from agent_framework_azurefunctions._models import AgentSessionId, DurableAgentThread +from agent_framework_azurefunctions import AgentFunctionApp from agent_framework_azurefunctions._orchestration import AgentTask @@ -38,46 +38,96 @@ def _create_entity_task(task_id: int = 1) -> TaskBase: return _FakeTask(task_id) -class TestAgentResponseHelpers: - """Tests for helper utilities that prepare AgentRunResponse values.""" +@pytest.fixture +def mock_context(): + """Create a mock orchestration context with UUID support.""" + context = Mock() + context.instance_id = "test-instance" + context.current_utc_datetime = Mock() + return context - @staticmethod - def _create_agent_task() -> AgentTask: - entity_task = _create_entity_task() - return AgentTask(entity_task, None, "correlation-id") - def test_load_agent_response_from_instance(self) -> None: - task = self._create_agent_task() - response = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"foo": "bar"}')]) +@pytest.fixture +def mock_context_with_uuid() -> tuple[Mock, str]: + """Create a mock context with a single UUID.""" + from uuid import UUID + + context = Mock() + context.instance_id = "test-instance" + context.current_utc_datetime = Mock() + test_uuid = UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + context.new_uuid = Mock(return_value=test_uuid) + return context, test_uuid.hex + + +@pytest.fixture +def mock_context_with_multiple_uuids() -> tuple[Mock, list[str]]: + """Create a mock context with multiple UUIDs via side_effect.""" + from uuid import UUID + + context = Mock() + context.instance_id = "test-instance" + context.current_utc_datetime = Mock() + uuids = [ + UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), + UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"), + UUID("cccccccc-cccc-cccc-cccc-cccccccccccc"), + ] + context.new_uuid = Mock(side_effect=uuids) + # Return the hex versions for assertion checking + hex_uuids = [uuid.hex for uuid in uuids] + return context, hex_uuids + - loaded = task._load_agent_response(response) +@pytest.fixture +def executor_with_uuid() -> tuple[Any, Mock, str]: + """Create an executor with a mocked generate_unique_id method.""" + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor - assert loaded is response - assert loaded.value is None + context = Mock() + context.instance_id = "test-instance" + context.current_utc_datetime = Mock() - def test_load_agent_response_from_serialized(self) -> None: - task = self._create_agent_task() - serialized = AgentRunResponse(messages=[ChatMessage(role="assistant", text="structured")]).to_dict() - serialized["value"] = {"answer": 42} + executor = AzureFunctionsAgentExecutor(context) + test_uuid_hex = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + executor.generate_unique_id = Mock(return_value=test_uuid_hex) - loaded = task._load_agent_response(serialized) + return executor, context, test_uuid_hex - assert loaded is not None - assert loaded.value == {"answer": 42} - loaded_dict = loaded.to_dict() - assert loaded_dict["type"] == "agent_run_response" - def test_load_agent_response_rejects_none(self) -> None: - task = self._create_agent_task() +@pytest.fixture +def executor_with_multiple_uuids() -> tuple[Any, Mock, list[str]]: + """Create an executor with multiple mocked UUIDs.""" + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor - with pytest.raises(ValueError): - task._load_agent_response(None) + context = Mock() + context.instance_id = "test-instance" + context.current_utc_datetime = Mock() - def test_load_agent_response_rejects_unsupported_type(self) -> None: - task = self._create_agent_task() + executor = AzureFunctionsAgentExecutor(context) + uuid_hexes = [ + "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", + "cccccccc-cccc-cccc-cccc-cccccccccccc", + "dddddddd-dddd-dddd-dddd-dddddddddddd", + "eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee", + ] + executor.generate_unique_id = Mock(side_effect=uuid_hexes) - with pytest.raises(TypeError, match="Unsupported type"): - task._load_agent_response(["invalid", "list"]) # type: ignore[arg-type] + return executor, context, uuid_hexes + + +@pytest.fixture +def executor_with_context(mock_context_with_uuid: tuple[Mock, str]) -> tuple[Any, Mock]: + """Create an executor with a mocked context.""" + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor + + context, _ = mock_context_with_uuid + return AzureFunctionsAgentExecutor(context), context + + +class TestAgentResponseHelpers: + """Tests for response handling through public AgentTask API.""" def test_try_set_value_success(self) -> None: """Test try_set_value correctly processes successful task completion.""" @@ -144,335 +194,10 @@ class TestSchema(BaseModel): assert isinstance(task.result.value, TestSchema) assert task.result.value.answer == "42" - def test_ensure_response_format_parses_value(self) -> None: - """Test _ensure_response_format correctly parses response value.""" - from pydantic import BaseModel - - class SampleSchema(BaseModel): - name: str - - task = self._create_agent_task() - response = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"name": "test"}')]) - - # Value should be None initially - assert response.value is None - - # Parse the value - task._ensure_response_format(SampleSchema, "test-correlation", response) - - # Value should now be parsed - assert isinstance(response.value, SampleSchema) - assert response.value.name == "test" - - def test_ensure_response_format_skips_if_already_parsed(self) -> None: - """Test _ensure_response_format does not re-parse if value already matches format.""" - from pydantic import BaseModel - - class SampleSchema(BaseModel): - name: str - - task = self._create_agent_task() - existing_value = SampleSchema(name="existing") - response = AgentRunResponse( - messages=[ChatMessage(role="assistant", text='{"name": "new"}')], - value=existing_value, - ) - - # Call _ensure_response_format - task._ensure_response_format(SampleSchema, "test-correlation", response) - - # Value should remain unchanged (not re-parsed) - assert response.value is existing_value - assert response.value.name == "existing" - - -class TestDurableAIAgent: - """Test suite for DurableAIAgent wrapper.""" - - def test_init(self) -> None: - """Test DurableAIAgent initialization.""" - mock_context = Mock() - mock_context.instance_id = "test-instance-123" - - agent = DurableAIAgent(mock_context, "TestAgent") - - assert agent.context == mock_context - assert agent.agent_name == "TestAgent" - - def test_implements_agent_protocol(self) -> None: - """Test that DurableAIAgent implements AgentProtocol.""" - from agent_framework import AgentProtocol - - mock_context = Mock() - agent = DurableAIAgent(mock_context, "TestAgent") - - # Check that agent satisfies AgentProtocol - assert isinstance(agent, AgentProtocol) - - def test_has_agent_protocol_properties(self) -> None: - """Test that DurableAIAgent has AgentProtocol properties.""" - mock_context = Mock() - agent = DurableAIAgent(mock_context, "TestAgent") - - # AgentProtocol properties - assert hasattr(agent, "id") - assert hasattr(agent, "name") - assert hasattr(agent, "description") - assert hasattr(agent, "display_name") - - # Verify values - assert agent.name == "TestAgent" - assert agent.description == "Durable agent proxy for TestAgent" - assert agent.display_name == "TestAgent" - assert agent.id is not None # Auto-generated UUID - - def test_get_new_thread(self) -> None: - """Test creating a new agent thread.""" - mock_context = Mock() - mock_context.instance_id = "test-instance-456" - mock_context.new_uuid = Mock(return_value="test-guid-456") - - agent = DurableAIAgent(mock_context, "WriterAgent") - thread = agent.get_new_thread() - - assert isinstance(thread, DurableAgentThread) - assert thread.session_id is not None - session_id = thread.session_id - assert isinstance(session_id, AgentSessionId) - assert session_id.name == "WriterAgent" - assert session_id.key == "test-guid-456" - mock_context.new_uuid.assert_called_once() - - def test_get_new_thread_deterministic(self) -> None: - """Test that get_new_thread creates deterministic session IDs.""" - - mock_context = Mock() - mock_context.instance_id = "test-instance-789" - mock_context.new_uuid = Mock(side_effect=["session-guid-1", "session-guid-2"]) - - agent = DurableAIAgent(mock_context, "EditorAgent") - - # Create multiple threads - they should have unique session IDs - thread1 = agent.get_new_thread() - thread2 = agent.get_new_thread() - - assert isinstance(thread1, DurableAgentThread) - assert isinstance(thread2, DurableAgentThread) - - session_id1 = thread1.session_id - session_id2 = thread2.session_id - assert session_id1 is not None and session_id2 is not None - assert isinstance(session_id1, AgentSessionId) - assert isinstance(session_id2, AgentSessionId) - assert session_id1.name == "EditorAgent" - assert session_id2.name == "EditorAgent" - assert session_id1.key == "session-guid-1" - assert session_id2.key == "session-guid-2" - assert mock_context.new_uuid.call_count == 2 - - def test_run_creates_entity_call(self) -> None: - """Test that run() creates proper entity call and returns a Task.""" - mock_context = Mock() - mock_context.instance_id = "test-instance-001" - mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"]) - - entity_task = _create_entity_task() - mock_context.call_entity = Mock(return_value=entity_task) - - agent = DurableAIAgent(mock_context, "TestAgent") - - # Create thread - thread = agent.get_new_thread() - - # Call run() - returns AgentTask directly - task = agent.run(messages="Test message", thread=thread, enable_tool_calls=True) - - assert isinstance(task, AgentTask) - assert task.children[0] == entity_task - - # Verify call_entity was called with correct parameters - assert mock_context.call_entity.called - call_args = mock_context.call_entity.call_args - entity_id, operation, request = call_args[0] - - assert operation == "run" - assert request["message"] == "Test message" - assert request["enable_tool_calls"] is True - assert "correlationId" in request - assert request["correlationId"] == "correlation-guid" - assert "thread_id" not in request - # Verify orchestration ID is set from context.instance_id - assert "orchestrationId" in request - assert request["orchestrationId"] == "test-instance-001" - - def test_run_sets_orchestration_id(self) -> None: - """Test that run() sets the orchestration_id from context.instance_id.""" - mock_context = Mock() - mock_context.instance_id = "my-orchestration-123" - mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"]) - - entity_task = _create_entity_task() - mock_context.call_entity = Mock(return_value=entity_task) - - agent = DurableAIAgent(mock_context, "TestAgent") - thread = agent.get_new_thread() - - agent.run(messages="Test", thread=thread) - - call_args = mock_context.call_entity.call_args - request = call_args[0][2] - - assert request["orchestrationId"] == "my-orchestration-123" - - def test_run_without_thread(self) -> None: - """Test that run() works without explicit thread (creates unique session key).""" - mock_context = Mock() - mock_context.instance_id = "test-instance-002" - mock_context.new_uuid = Mock(side_effect=["auto-generated-guid", "correlation-guid"]) - - entity_task = _create_entity_task() - mock_context.call_entity = Mock(return_value=entity_task) - - agent = DurableAIAgent(mock_context, "TestAgent") - - # Call without thread - task = agent.run(messages="Test message") - - assert isinstance(task, AgentTask) - assert task.children[0] == entity_task - - # Verify the entity ID uses the auto-generated GUID with dafx- prefix - call_args = mock_context.call_entity.call_args - entity_id = call_args[0][0] - assert entity_id.name == "dafx-TestAgent" - assert entity_id.key == "auto-generated-guid" - # Should be called twice: once for session_key, once for correlationId - assert mock_context.new_uuid.call_count == 2 - - def test_run_with_response_format(self) -> None: - """Test that run() passes response format correctly.""" - mock_context = Mock() - mock_context.instance_id = "test-instance-003" - - entity_task = _create_entity_task() - mock_context.call_entity = Mock(return_value=entity_task) - - agent = DurableAIAgent(mock_context, "TestAgent") - - from pydantic import BaseModel - - class SampleSchema(BaseModel): - key: str - - # Create thread and call - thread = agent.get_new_thread() - - task = agent.run(messages="Test message", thread=thread, response_format=SampleSchema) - - assert isinstance(task, AgentTask) - assert task.children[0] == entity_task - - # Verify schema was passed in the call_entity arguments - call_args = mock_context.call_entity.call_args - input_data = call_args[0][2] # Third argument is input_data - assert "response_format" in input_data - assert input_data["response_format"]["__response_schema_type__"] == "pydantic_model" - assert input_data["response_format"]["module"] == SampleSchema.__module__ - assert input_data["response_format"]["qualname"] == SampleSchema.__qualname__ - - def test_messages_to_string(self) -> None: - """Test converting ChatMessage list to string.""" - from agent_framework import ChatMessage - - mock_context = Mock() - agent = DurableAIAgent(mock_context, "TestAgent") - - messages = [ - ChatMessage(role="user", text="Hello"), - ChatMessage(role="assistant", text="Hi there"), - ChatMessage(role="user", text="How are you?"), - ] - - result = agent._messages_to_string(messages) - - assert result == "Hello\nHi there\nHow are you?" - - def test_run_with_chat_message(self) -> None: - """Test that run() handles ChatMessage input.""" - from agent_framework import ChatMessage - - mock_context = Mock() - mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"]) - entity_task = _create_entity_task() - mock_context.call_entity = Mock(return_value=entity_task) - - agent = DurableAIAgent(mock_context, "TestAgent") - thread = agent.get_new_thread() - - # Call with ChatMessage - msg = ChatMessage(role="user", text="Hello") - task = agent.run(messages=msg, thread=thread) - - assert isinstance(task, AgentTask) - assert task.children[0] == entity_task - - # Verify message was converted to string - call_args = mock_context.call_entity.call_args - request = call_args[0][2] - assert request["message"] == "Hello" - - def test_run_stream_raises_not_implemented(self) -> None: - """Test that run_stream() method raises NotImplementedError.""" - mock_context = Mock() - agent = DurableAIAgent(mock_context, "TestAgent") - - with pytest.raises(NotImplementedError) as exc_info: - agent.run_stream("Test message") - - error_msg = str(exc_info.value) - assert "Streaming is not supported" in error_msg - - def test_entity_id_format(self) -> None: - """Test that EntityId is created with correct format (name, key).""" - from azure.durable_functions import EntityId - - mock_context = Mock() - mock_context.new_uuid = Mock(return_value="test-guid-789") - mock_context.call_entity = Mock(return_value=_create_entity_task()) - - agent = DurableAIAgent(mock_context, "WriterAgent") - thread = agent.get_new_thread() - - # Call run() to trigger entity ID creation - agent.run("Test", thread=thread) - - # Verify call_entity was called with correct EntityId - call_args = mock_context.call_entity.call_args - entity_id = call_args[0][0] - - # EntityId should be EntityId(name="dafx-WriterAgent", key="test-guid-789") - # Which formats as "@dafx-writeragent@test-guid-789" - assert isinstance(entity_id, EntityId) - assert entity_id.name == "dafx-WriterAgent" - assert entity_id.key == "test-guid-789" - assert str(entity_id) == "@dafx-writeragent@test-guid-789" - class TestAgentFunctionAppGetAgent: """Test suite for AgentFunctionApp.get_agent.""" - def test_get_agent_method(self) -> None: - """Test get_agent method creates DurableAIAgent for registered agent.""" - app = _app_with_registered_agents("MyAgent") - mock_context = Mock() - mock_context.instance_id = "test-instance-100" - - agent = app.get_agent(mock_context, "MyAgent") - - assert isinstance(agent, DurableAIAgent) - assert agent.agent_name == "MyAgent" - assert agent.context == mock_context - def test_get_agent_raises_for_unregistered_agent(self) -> None: """Test get_agent raises ValueError when agent is not registered.""" app = _app_with_registered_agents("KnownAgent") @@ -484,15 +209,9 @@ def test_get_agent_raises_for_unregistered_agent(self) -> None: class TestOrchestrationIntegration: """Integration tests for orchestration scenarios.""" - def test_sequential_agent_calls_simulation(self) -> None: + def test_sequential_agent_calls_simulation(self, executor_with_multiple_uuids: tuple[Any, Mock, list[str]]) -> None: """Simulate sequential agent calls in an orchestration.""" - mock_context = Mock() - mock_context.instance_id = "test-orchestration-001" - # new_uuid will be called 3 times: - # 1. thread creation - # 2. correlationId for first call - # 3. correlationId for second call - mock_context.new_uuid = Mock(side_effect=["deterministic-guid-001", "corr-1", "corr-2"]) + executor, context, uuid_hexes = executor_with_multiple_uuids # Track entity calls entity_calls: list[dict[str, Any]] = [] @@ -501,10 +220,10 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic entity_calls.append({"entity_id": str(entity_id), "operation": operation, "input": input_data}) return _create_entity_task() - mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect) + context.call_entity = Mock(side_effect=mock_call_entity_side_effect) - app = _app_with_registered_agents("WriterAgent") - agent = app.get_agent(mock_context, "WriterAgent") + # Create agent directly with executor (not via app.get_agent) + agent = DurableAIAgent(executor, "WriterAgent") # Create thread thread = agent.get_new_thread() @@ -520,18 +239,15 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic # Verify both calls used the same entity (same session key) assert len(entity_calls) == 2 assert entity_calls[0]["entity_id"] == entity_calls[1]["entity_id"] - # EntityId format is @dafx-writeragent@deterministic-guid-001 - assert entity_calls[0]["entity_id"] == "@dafx-writeragent@deterministic-guid-001" - # new_uuid called 3 times: thread + 2 correlation IDs - assert mock_context.new_uuid.call_count == 3 + # EntityId format is @dafx-writeragent@ + expected_entity_id = f"@dafx-writeragent@{uuid_hexes[0]}" + assert entity_calls[0]["entity_id"] == expected_entity_id + # generate_unique_id called 3 times: thread + 2 correlation IDs + assert executor.generate_unique_id.call_count == 3 - def test_multiple_agents_in_orchestration(self) -> None: + def test_multiple_agents_in_orchestration(self, executor_with_multiple_uuids: tuple[Any, Mock, list[str]]) -> None: """Test using multiple different agents in one orchestration.""" - mock_context = Mock() - mock_context.instance_id = "test-orchestration-002" - # Mock new_uuid to return different GUIDs for each call - # Order: writer thread, editor thread, writer correlation, editor correlation - mock_context.new_uuid = Mock(side_effect=["writer-guid-001", "editor-guid-002", "writer-corr", "editor-corr"]) + executor, context, uuid_hexes = executor_with_multiple_uuids entity_calls: list[str] = [] @@ -539,11 +255,11 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic entity_calls.append(str(entity_id)) return _create_entity_task() - mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect) + context.call_entity = Mock(side_effect=mock_call_entity_side_effect) - app = _app_with_registered_agents("WriterAgent", "EditorAgent") - writer = app.get_agent(mock_context, "WriterAgent") - editor = app.get_agent(mock_context, "EditorAgent") + # Create agents directly with executor (not via app.get_agent) + writer = DurableAIAgent(executor, "WriterAgent") + editor = DurableAIAgent(executor, "EditorAgent") writer_thread = writer.get_new_thread() editor_thread = editor.get_new_thread() @@ -557,62 +273,11 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic # Verify different entity IDs were used assert len(entity_calls) == 2 - # EntityId format is @dafx-agentname@guid (lowercased agent name with dafx- prefix) - assert entity_calls[0] == "@dafx-writeragent@writer-guid-001" - assert entity_calls[1] == "@dafx-editoragent@editor-guid-002" - - -class TestAgentThreadSerialization: - """Test that AgentThread can be serialized for orchestration state.""" - - async def test_agent_thread_serialize(self) -> None: - """Test that AgentThread can be serialized.""" - thread = AgentThread() - - # Serialize - serialized = await thread.serialize() - - assert isinstance(serialized, dict) - assert "service_thread_id" in serialized - - async def test_agent_thread_deserialize(self) -> None: - """Test that AgentThread can be deserialized.""" - thread = AgentThread() - serialized = await thread.serialize() - - # Deserialize - restored = await AgentThread.deserialize(serialized) - - assert isinstance(restored, AgentThread) - assert restored.service_thread_id == thread.service_thread_id - - async def test_durable_agent_thread_serialization(self) -> None: - """Test that DurableAgentThread persists session metadata during serialization.""" - mock_context = Mock() - mock_context.instance_id = "test-instance-999" - mock_context.new_uuid = Mock(return_value="test-guid-999") - - agent = DurableAIAgent(mock_context, "TestAgent") - thread = agent.get_new_thread() - - assert isinstance(thread, DurableAgentThread) - # Verify custom attribute and property exist - assert thread.session_id is not None - session_id = thread.session_id - assert isinstance(session_id, AgentSessionId) - assert session_id.name == "TestAgent" - assert session_id.key == "test-guid-999" - - # Standard serialization should still work - serialized = await thread.serialize() - assert isinstance(serialized, dict) - assert serialized.get("durable_session_id") == str(session_id) - - # After deserialization, we'd need to restore the custom attribute - # This would be handled by the orchestration framework - restored = await DurableAgentThread.deserialize(serialized) - assert isinstance(restored, DurableAgentThread) - assert restored.session_id == session_id + # EntityId format is @dafx-agentname@uuid_hex (lowercased agent name with dafx- prefix) + expected_writer_id = f"@dafx-writeragent@{uuid_hexes[0]}" + expected_editor_id = f"@dafx-editoragent@{uuid_hexes[1]}" + assert entity_calls[0] == expected_writer_id + assert entity_calls[1] == expected_editor_id if __name__ == "__main__": diff --git a/python/packages/durabletask/DESIGN.md b/python/packages/durabletask/DESIGN.md deleted file mode 100644 index d8be9f7840..0000000000 --- a/python/packages/durabletask/DESIGN.md +++ /dev/null @@ -1,248 +0,0 @@ -# Design: Durable Task Provider for Agent Framework - -## Overview - -This package, `agent-framework-durabletask`, provides a durability layer for the Microsoft Agent Framework using the `durabletask` Python SDK. It enables stateful, reliable, and distributed agent execution on any platform (Bring Your Own Platform), decoupling the agent's durability from the Azure Functions platform. - -## Design Decision - -**Selected Approach: Object-Oriented Wrappers with Symmetric Factory Pattern** - -We will use a symmetric Object-Oriented design where both the Client (external) and Orchestrator (internal) expose a consistent interface for retrieving and interacting with durable agents. - -## Core Philosophy - -* **Native `DurableEntity` Support**: We will leverage the `DurableEntity` support introduced in `durabletask` v1.0.0. -* **Symmetric Factories**: `DurableAIAgentClient` (for external use) and `DurableAIAgentOrchestrator` (for internal use) both provide a `get_agent` method. -* **Unified Interface**: `DurableAIAgent` serves as the common interface for executing agents, regardless of the context (Client vs Orchestration). -* **Consistent Return Type**: `DurableAIAgent.run` always returns a `Task` (or compatible awaitable), ensuring consistent usage patterns. - -## Architecture - -### 1. Package Structure - -```text -packages/durabletask/ -├── pyproject.toml -├── README.md -├── agent_framework_durabletask/ -│ ├── __init__.py -│ ├── _worker.py # DurableAIAgentWorker -│ ├── _client.py # DurableAIAgentClient -│ ├── _orchestrator.py # DurableAIAgentOrchestrator -│ ├── _entities.py # AgentEntity implementation -│ ├── _models.py # Data models (RunRequest, AgentResponse, etc.) -│ ├── _durable_agent_state.py # State schema (Ported from azurefunctions) -│ ├── _shim.py # DurableAIAgent implementation (will be ported from azurefunctions) -│ └── _utils.py # Mixins and helpers -└── tests/ -``` - -### 2. State Management (`_durable_agent_state.py`) - -**Important**: This will be the state maintained in the durable entities for both `durabletask` and `azurefunctions` package. - -### 3. The Agent Entity (`_entities.py`) - -We will implement a class `AgentEntity` that inherits from `durabletask.entities.DurableEntity`. - -**Important**: This will be ported from `azurefunctions` package too but with slight modifications, details TBD. - -### 4. The Worker Wrapper (`_worker.py`) - -The `DurableAIAgentWorker` wraps an existing `durabletask` worker instance. - -```python -class DurableAIAgentWorker: - def __init__(self, worker: TaskHubGrpcWorker): - self._worker = worker - self._registered_agents: dict[str, AgentProtocol] = {} - - def add_agent(self, agent: AgentProtocol) -> None: - """Registers an agent with the worker. - - Uses the factory pattern to create an AgentEntity class with the agent - instance injected, then registers it with the durabletask worker. - """ - # Store the agent reference - self._registered_agents[agent.name] = agent - - # Create a configured entity class using the factory - entity_class = create_agent_entity(agent) - - # Register the entity class with the worker - # The worker.add_entity method takes a class or function - self._worker.add_entity(entity_class) - - def start(self): - """Start the worker to begin processing tasks.""" - self._worker.start() - - def stop(self): - """Stop the worker gracefully.""" - self._worker.stop() -``` - -### 5. The Mixin (`_utils.py`) - -```python -class GetDurableAgentMixin: - """Mixin to provide get_agent interface.""" - - def get_agent(self, agent_name: str) -> 'DurableAIAgent': - raise NotImplementedError -``` - -### 6. The Client Wrapper (`_client.py`) - -The `DurableAIAgentClient` is for external clients (e.g., FastAPI, CLI). - -```python -class DurableAIAgentClient(GetDurableAgentMixin): - def __init__(self, client: TaskHubGrpcClient): - self._client = client - - async def get_agent(self, agent_name: str) -> 'DurableAIAgent': - """Retrieves a DurableAIAgent shim. - - Validates existence by attempting to fetch entity state/metadata. - """ - # Validation logic using self._client.get_entity(...) - # ... - return DurableAIAgent(self, agent_name) - - def run_agent(self, agent_name: str, message: str, **kwargs) -> 'Task': - """Runs agent via signal + poll and returns a Task wrapper.""" - # Returns a ClientTask (wrapper around asyncio.Task) - pass -``` - -### 7. The Orchestration Context Wrapper (`_orchestration_context.py`) - -The `DurableAIAgentOrchestrationContext` is for use *inside* orchestrations to get access to agents that were registered in the workers. - -```python -class DurableAIAgentOrchestrationContext(GetDurableAgentMixin): - def __init__(self, context: OrchestrationContext): - self._context = context - - def get_agent(self, agent_name: str) -> 'DurableAIAgent': - """Retrieves a DurableAIAgent shim. - - Validation is deferred or performed via call_entity if needed. - """ - return DurableAIAgent(self, agent_name) - - def run_agent(self, agent_name: str, message: str, **kwargs) -> 'Task': - """Runs agent via call_entity and returns the Task.""" - # Returns the native durabletask.Task - pass -``` - -### 8. The Durable Agent Shim (`_shim.py`) - -The `DurableAIAgent` implements `AgentProtocol` but delegates execution to the provider. This will be ported from `azurefunctions` package and updated accordingly. - -```python -class DurableAIAgent(AgentProtocol): - """A shim that delegates execution to the provider (Client or Orchestrator).""" - - def __init__(self, provider: GetDurableAgentMixin, name: str): - self._provider = provider - self._name = name - - @property - def name(self) -> str: - return self._name - - def run(self, message: str, **kwargs) -> 'Task': - """Executes the agent. - - Returns: - Task: A yieldable/awaitable task object. - """ - return self._provider.run_agent( - agent_name=self.name, - message=message, - **kwargs - ) -``` - -## Usage Experience - -**Scenario A: Worker Side** -```python -# 1. Define your agent -# The agent can be any implementation of AgentProtocol. -# For example, a standard Agent with a model and instructions. -my_agent = Agent( - name="my_agent", - instructions="You are a helpful assistant.", - model=openai_model -) - -# 2. Create the worker and the agent worker wrapper -with DurableTaskSchedulerWorker(...) as worker: - - agent_worker = DurableAIAgentWorker(worker) - - # 3. Register the agent - agent_worker.add_agent(my_agent) - - # 4. Start the worker - worker.start() - - # ... keep running ... -``` - -**Scenario B: Client Side** -```python -# 1. Configure the Durable Task client -client = DurableTaskSchedulerClient(...) - -# 2. Create the Agent Client wrapper -agent_client = DurableAIAgentClient(client) - -# 3. Get a reference to the agent -agent = await agent_client.get_agent("my_agent") - -# 4. Run the agent -# The returned object is designed to be compatible with both `await` (Client) -# and `yield` (Orchestrator). Implementation details on this unified return type will follow. -response = await agent.run("Hello") -``` - -**Scenario C: Orchestration Side** -```python -def orchestrator(context: OrchestrationContext): - # 1. Create the Agent Orchestrator wrapper - agent_orch = DurableAIAgentOrchestrator(context) - - # 2. Get a reference to the agent - agent = agent_orch.get_agent("my_agent") - - # 3. Run the agent (returns a Task, so we yield it) - result = yield agent.run("Hello") - - return result -``` - -## Additional Styles Considered - -### Inheritance Pattern for worker and client (like `DurableAIAgentWorker`, `DurableAIAgentClient`, etc) - -We investigated inheriting `DurableAIAgentWorker` directly from `TaskHubGrpcWorker` (or `DurableTaskSchedulerWorker`) to provide a unified API where the agent worker *is* a durable task worker (and similarly the client). - -**Why we chose Composition over Inheritance:** - -1. **Initialization Divergence:** The `durabletask` package has two distinct worker classes with incompatible `__init__` signatures: - * `TaskHubGrpcWorker`: Requires `host_address`, `metadata`, etc. - * `DurableTaskSchedulerWorker`: Requires `host_address`, `taskhub`, `token_credential`, etc. - - To support both via inheritance, we would need to maintain two separate classes (e.g., `DurableAIAgentGrpcWorker` and `DurableAIAgentSchedulerWorker`) or use a complex Mixin approach. This increases the API surface area and maintenance burden. - -2. **Encapsulation:** The logic for Azure Managed DTS (authentication, routing) is currently encapsulated in an internal interceptor class within `durabletask`. Without changes to the upstream package to expose this logic, we cannot create a single "Universal" worker class that inherits from the base worker but supports Azure features. - -3. **Flexibility:** The Composition pattern allows `DurableAIAgentWorker` to accept *any* instance of a worker that satisfies the required interface. This makes it forward-compatible with future worker implementations or custom subclasses without requiring code changes in our package. - -4. **Simplicity:** While Composition requires a two-step setup (instantiate worker, then wrap it), it keeps the `agent-framework-durabletask` package simple, focused, and loosely coupled from the implementation details of the underlying `durabletask` workers. \ No newline at end of file diff --git a/python/packages/durabletask/agent_framework_durabletask/__init__.py b/python/packages/durabletask/agent_framework_durabletask/__init__.py index f28f4d8064..3283cf8959 100644 --- a/python/packages/durabletask/agent_framework_durabletask/__init__.py +++ b/python/packages/durabletask/agent_framework_durabletask/__init__.py @@ -3,6 +3,7 @@ """Durable Task integration for Microsoft Agent Framework.""" from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol +from ._client import DurableAIAgentClient from ._constants import ( DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS, @@ -41,7 +42,12 @@ DurableAgentStateUsageContent, ) from ._entities import AgentEntity, AgentEntityStateProviderMixin -from ._models import RunRequest, serialize_response_format +from ._executors import DurableAgentExecutor +from ._models import AgentSessionId, DurableAgentThread, RunRequest +from ._orchestration_context import DurableAIAgentOrchestrationContext +from ._response_utils import ensure_response_format, load_agent_response +from ._shim import DurableAIAgent +from ._worker import DurableAIAgentWorker __all__ = [ "DEFAULT_MAX_POLL_RETRIES", @@ -58,8 +64,14 @@ "AgentEntity", "AgentEntityStateProviderMixin", "AgentResponseCallbackProtocol", + "AgentSessionId", "ApiResponseFields", "ContentTypes", + "DurableAIAgent", + "DurableAIAgentClient", + "DurableAIAgentOrchestrationContext", + "DurableAIAgentWorker", + "DurableAgentExecutor", "DurableAgentState", "DurableAgentStateContent", "DurableAgentStateData", @@ -80,7 +92,10 @@ "DurableAgentStateUriContent", "DurableAgentStateUsage", "DurableAgentStateUsageContent", + "DurableAgentThread", + "DurableAgentThread", "DurableStateFields", "RunRequest", - "serialize_response_format", + "ensure_response_format", + "load_agent_response", ] diff --git a/python/packages/durabletask/agent_framework_durabletask/_client.py b/python/packages/durabletask/agent_framework_durabletask/_client.py new file mode 100644 index 0000000000..5e70a7ba1f --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_client.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Client wrapper for Durable Task Agent Framework. + +This module provides the DurableAIAgentClient class for external clients to interact +with durable agents via gRPC. +""" + +from __future__ import annotations + +from agent_framework import AgentRunResponse, get_logger +from durabletask.client import TaskHubGrpcClient + +from ._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS +from ._executors import ClientAgentExecutor +from ._shim import DurableAgentProvider, DurableAIAgent + +logger = get_logger("agent_framework.durabletask.client") + + +class DurableAIAgentClient(DurableAgentProvider[AgentRunResponse]): + """Client wrapper for interacting with durable agents externally. + + This class wraps a durabletask TaskHubGrpcClient and provides a convenient + interface for retrieving and executing durable agents from external contexts. + + Example: + ```python + from durabletask import TaskHubGrpcClient + from agent_framework_durabletask import DurableAIAgentClient + + # Create the underlying client + client = TaskHubGrpcClient(host_address="localhost:4001") + + # Wrap it with the agent client + agent_client = DurableAIAgentClient(client) + + # Get an agent reference + agent = agent_client.get_agent("assistant") + + # Run the agent (synchronous call that waits for completion) + response = agent.run("Hello, how are you?") + print(response.text) + ``` + """ + + def __init__( + self, + client: TaskHubGrpcClient, + max_poll_retries: int = DEFAULT_MAX_POLL_RETRIES, + poll_interval_seconds: float = DEFAULT_POLL_INTERVAL_SECONDS, + ): + """Initialize the client wrapper. + + Args: + client: The durabletask client instance to wrap + max_poll_retries: Maximum polling attempts when waiting for responses + poll_interval_seconds: Delay in seconds between polling attempts + """ + self._client = client + + # Validate and set polling parameters + self.max_poll_retries = max(1, max_poll_retries) + self.poll_interval_seconds = ( + poll_interval_seconds if poll_interval_seconds > 0 else DEFAULT_POLL_INTERVAL_SECONDS + ) + + self._executor = ClientAgentExecutor(self._client, self.max_poll_retries, self.poll_interval_seconds) + logger.debug("[DurableAIAgentClient] Initialized with client type: %s", type(client).__name__) + + def get_agent(self, agent_name: str) -> DurableAIAgent[AgentRunResponse]: + """Retrieve a DurableAIAgent shim for the specified agent. + + This method returns a proxy object that can be used to execute the agent. + The actual agent must be registered on a worker with the same name. + + Args: + agent_name: Name of the agent to retrieve (without the dafx- prefix) + + Returns: + DurableAIAgent instance that can be used to run the agent + + Note: + This method does not validate that the agent exists. Validation + will occur when the agent is executed. If the entity doesn't exist, + the execution will fail with an appropriate error. + """ + logger.debug("[DurableAIAgentClient] Creating agent proxy for: %s", agent_name) + + return DurableAIAgent(self._executor, agent_name) diff --git a/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py b/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py index a42a6fad43..453d180612 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py +++ b/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py @@ -53,7 +53,7 @@ ) from dateutil import parser as date_parser -from ._constants import ApiResponseFields, ContentTypes, DurableStateFields +from ._constants import ContentTypes, DurableStateFields from ._models import RunRequest, serialize_response_format logger = get_logger("agent_framework.durabletask.durable_agent_state") @@ -452,7 +452,7 @@ def message_count(self) -> int: """Get the count of conversation entries (requests + responses).""" return len(self.data.conversation_history) - def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None: + def try_get_agent_response(self, correlation_id: str) -> AgentRunResponse | None: """Try to get an agent response by correlation ID. This method searches the conversation history for a response entry matching the given @@ -474,14 +474,8 @@ def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None: for entry in self.data.conversation_history: if entry.correlation_id == correlation_id and isinstance(entry, DurableAgentStateResponse): # Found the entry, extract response data - # Get the text content from assistant messages only - content = "\n".join(message.text for message in entry.messages if message.text) + return DurableAgentStateResponse.to_run_response(entry) - return { - ApiResponseFields.CONTENT: content, - ApiResponseFields.MESSAGE_COUNT: self.message_count, - ApiResponseFields.CORRELATION_ID: correlation_id, - } return None @@ -705,6 +699,21 @@ def from_run_response(correlation_id: str, response: AgentRunResponse) -> Durabl usage=DurableAgentStateUsage.from_usage(response.usage_details), ) + @staticmethod + def to_run_response( + response_entry: DurableAgentStateResponse, + ) -> AgentRunResponse: + """Converts a DurableAgentStateResponse back to an AgentRunResponse.""" + messages = [m.to_chat_message() for m in response_entry.messages] + + usage_details = response_entry.usage.to_usage_details() if response_entry.usage is not None else UsageDetails() + + return AgentRunResponse( + created_at=response_entry.created_at.isoformat(), + messages=messages, + usage_details=usage_details, + ) + class DurableAgentStateMessage: """Represents a message within a conversation history entry. @@ -1214,14 +1223,24 @@ def from_usage(usage: UsageDetails | None) -> DurableAgentStateUsage | None: input_token_count=usage.input_token_count, output_token_count=usage.output_token_count, total_token_count=usage.total_token_count, + extensionData=usage.additional_counts, ) def to_usage_details(self) -> UsageDetails: # Convert back to AI SDK UsageDetails + extension_data: dict[str, int] = {} + if self.extensionData is not None: + for k, v in self.extensionData.items(): + try: + extension_data[k] = int(v) + except (ValueError, TypeError): + continue + return UsageDetails( input_token_count=self.input_token_count, output_token_count=self.output_token_count, total_token_count=self.total_token_count, + **extension_data, ) diff --git a/python/packages/durabletask/agent_framework_durabletask/_entities.py b/python/packages/durabletask/agent_framework_durabletask/_entities.py index 2f7bb8a62c..2e0b429233 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_entities.py +++ b/python/packages/durabletask/agent_framework_durabletask/_entities.py @@ -128,7 +128,7 @@ async def run( ) -> AgentRunResponse: """Execute the agent with a message.""" if isinstance(request, str): - run_request = RunRequest(message=request, role=Role.USER) + run_request = RunRequest.from_json(request) elif isinstance(request, dict): run_request = RunRequest.from_dict(request) else: @@ -139,8 +139,6 @@ async def run( correlation_id = run_request.correlation_id if not thread_id: raise ValueError("Entity State Provider must provide a thread_id") - if not correlation_id: - raise ValueError("RunRequest must include a correlation_id") response_format = run_request.response_format enable_tool_calls = run_request.enable_tool_calls diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py new file mode 100644 index 0000000000..8ac893b943 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -0,0 +1,460 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Provider strategies for Durable Agent execution. + +These classes are internal execution strategies used by the DurableAIAgent shim. +They are intentionally separate from the public client/orchestration APIs to keep +only `get_agent` exposed to consumers. Executors implement the execution contract +and are injected into the shim. +""" + +from __future__ import annotations + +import time +import uuid +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from typing import Any, Generic, TypeVar + +from agent_framework import AgentRunResponse, AgentThread, ChatMessage, ErrorContent, Role, get_logger +from durabletask.client import TaskHubGrpcClient +from durabletask.entities import EntityInstanceId +from durabletask.task import CompositeTask, OrchestrationContext, Task +from pydantic import BaseModel + +from ._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS +from ._durable_agent_state import DurableAgentState +from ._models import AgentSessionId, DurableAgentThread, RunRequest +from ._response_utils import ensure_response_format, load_agent_response + +logger = get_logger("agent_framework.durabletask.executors") + +# TypeVar for the task type returned by executors +TaskT = TypeVar("TaskT") + + +class DurableAgentTask(CompositeTask[AgentRunResponse]): + """A custom Task that wraps entity calls and provides typed AgentRunResponse results. + + This task wraps the underlying entity call task and intercepts its completion + to convert the raw result into a typed AgentRunResponse object. + """ + + def __init__( + self, + entity_task: Task[Any], + response_format: type[BaseModel] | None, + correlation_id: str, + ): + """Initialize the DurableAgentTask. + + Args: + entity_task: The underlying entity call task + response_format: Optional Pydantic model for response parsing + correlation_id: Correlation ID for logging + """ + self._response_format = response_format + self._correlation_id = correlation_id + super().__init__([entity_task]) # type: ignore[misc] + + def on_child_completed(self, task: Task[Any]) -> None: + """Handle completion of the underlying entity task. + + Parameters + ---------- + task : Task + The entity call task that just completed + """ + if self.is_complete: + return + + if task.is_failed: + # Propagate the failure + self._exception = task.get_exception() + self._is_complete = True + if self._parent is not None: + self._parent.on_child_completed(self) + return + + # Task succeeded - transform the raw result + raw_result = task.get_result() + logger.debug( + "[DurableAgentTask] Converting raw result for correlation_id %s", + self._correlation_id, + ) + + try: + response = load_agent_response(raw_result) + + if self._response_format is not None: + ensure_response_format( + self._response_format, + self._correlation_id, + response, + ) + + # Set the typed AgentRunResponse as this task's result + self._result = response + self._is_complete = True + + if self._parent is not None: + self._parent.on_child_completed(self) + + except Exception: + logger.exception( + "[DurableAgentTask] Failed to convert result for correlation_id: %s", + self._correlation_id, + ) + raise + + +class DurableAgentExecutor(ABC, Generic[TaskT]): + """Abstract base class for durable agent execution strategies. + + Type Parameters: + TaskT: The task type returned by this executor + """ + + @abstractmethod + def run_durable_agent( + self, + agent_name: str, + run_request: RunRequest, + thread: AgentThread | None = None, + ) -> TaskT: + """Execute the durable agent. + + Returns: + TaskT: The task type specific to this executor implementation + """ + raise NotImplementedError + + def get_new_thread(self, agent_name: str, **kwargs: Any) -> DurableAgentThread: + """Create a new DurableAgentThread with random session ID.""" + session_id = self._create_session_id(agent_name) + return DurableAgentThread.from_session_id(session_id, **kwargs) + + def _create_session_id( + self, + agent_name: str, + thread: AgentThread | None = None, + ) -> AgentSessionId: + """Create the AgentSessionId for the execution.""" + if isinstance(thread, DurableAgentThread) and thread.session_id is not None: + return thread.session_id + # Create new session ID - either no thread provided or it's a regular AgentThread + key = self.generate_unique_id() + return AgentSessionId(name=agent_name, key=key) + + def generate_unique_id(self) -> str: + """Generate a new Unique ID.""" + return uuid.uuid4().hex + + def get_run_request( + self, + message: str, + response_format: type[BaseModel] | None, + enable_tool_calls: bool, + ) -> RunRequest: + """Create a RunRequest for the given parameters.""" + correlation_id = self.generate_unique_id() + return RunRequest( + message=message, + response_format=response_format, + enable_tool_calls=enable_tool_calls, + correlation_id=correlation_id, + ) + + +class ClientAgentExecutor(DurableAgentExecutor[AgentRunResponse]): + """Execution strategy for external clients. + + Note: Returns AgentRunResponse directly since the execution + is blocking until response is available via polling + as per the design of TaskHubGrpcClient. + """ + + def __init__( + self, + client: TaskHubGrpcClient, + max_poll_retries: int = DEFAULT_MAX_POLL_RETRIES, + poll_interval_seconds: float = DEFAULT_POLL_INTERVAL_SECONDS, + ): + self._client = client + self.max_poll_retries = max_poll_retries + self.poll_interval_seconds = poll_interval_seconds + + def run_durable_agent( + self, + agent_name: str, + run_request: RunRequest, + thread: AgentThread | None = None, + ) -> AgentRunResponse: + """Execute the agent via the durabletask client. + + Signals the agent entity with a message request, then polls the entity + state to retrieve the response once processing is complete. + + Note: This is a blocking/synchronous operation (in line with how + TaskHubGrpcClient works) that polls until a response is available or + timeout occurs. + + Args: + agent_name: Name of the agent to execute + run_request: The run request containing message and optional response format + thread: Optional conversation thread (creates new if not provided) + + Returns: + AgentRunResponse: The agent's response after execution completes + """ + # Signal the entity with the request + entity_id = self._signal_agent_entity(agent_name, run_request, thread) + + # Poll for the response + agent_response = self._poll_for_agent_response(entity_id, run_request.correlation_id) + + # Handle and return the result + return self._handle_agent_response(agent_response, run_request.response_format, run_request.correlation_id) + + def _signal_agent_entity( + self, + agent_name: str, + run_request: RunRequest, + thread: AgentThread | None, + ) -> EntityInstanceId: + """Signal the agent entity with a run request. + + Args: + agent_name: Name of the agent to execute + run_request: The run request containing message and optional response format + thread: Optional conversation thread + + Returns: + entity_id + """ + # Get or create session ID + session_id = self._create_session_id(agent_name, thread) + + # Create the entity ID + entity_id = EntityInstanceId( + entity=session_id.entity_name, + key=session_id.key, + ) + + logger.debug( + "[ClientAgentExecutor] Signaling entity '%s' (session: %s, correlation: %s)", + agent_name, + session_id, + run_request.correlation_id, + ) + + self._client.signal_entity(entity_id, "run", run_request.to_dict()) + return entity_id + + def _poll_for_agent_response( + self, + entity_id: EntityInstanceId, + correlation_id: str, + ) -> AgentRunResponse | None: + """Poll the entity for a response with retries. + + Args: + entity_id: Entity instance identifier + correlation_id: Correlation ID to track the request + + Returns: + The agent response if found, None if timeout occurs + """ + agent_response = None + + for attempt in range(1, self.max_poll_retries + 1): + time.sleep(self.poll_interval_seconds) + + agent_response = self._poll_entity_for_response(entity_id, correlation_id) + if agent_response is not None: + logger.info( + "[ClientAgentExecutor] Found response (attempt %d/%d, correlation: %s)", + attempt, + self.max_poll_retries, + correlation_id, + ) + break + + logger.debug( + "[ClientAgentExecutor] Response not ready (attempt %d/%d)", + attempt, + self.max_poll_retries, + ) + + return agent_response + + def _handle_agent_response( + self, + agent_response: AgentRunResponse | None, + response_format: type[BaseModel] | None, + correlation_id: str, + ) -> AgentRunResponse: + """Handle the agent response or create an error response. + + Args: + agent_response: The response from polling, or None if timeout + response_format: Optional response format for validation + correlation_id: Correlation ID for logging + + Returns: + AgentRunResponse with either the agent's response or an error message + """ + if agent_response is not None: + try: + # Validate response format if specified + if response_format is not None: + ensure_response_format( + response_format, + correlation_id, + agent_response, + ) + + return agent_response + + except Exception as e: + logger.exception( + "[ClientAgentExecutor] Error converting response for correlation: %s", + correlation_id, + ) + error_message = ChatMessage( + role=Role.SYSTEM, + contents=[ + ErrorContent( + message=f"Error processing agent response: {e}", + error_code="response_processing_error", + ) + ], + ) + else: + logger.warning( + "[ClientAgentExecutor] Timeout after %d attempts (correlation: %s)", + self.max_poll_retries, + correlation_id, + ) + error_message = ChatMessage( + role=Role.SYSTEM, + contents=[ + ErrorContent( + message=f"Timeout waiting for agent response after {self.max_poll_retries} attempts", + error_code="response_timeout", + ) + ], + ) + + return AgentRunResponse( + messages=[error_message], + created_at=datetime.now(timezone.utc).isoformat(), + ) + + def _poll_entity_for_response( + self, + entity_id: EntityInstanceId, + correlation_id: str, + ) -> AgentRunResponse | None: + """Poll the entity state for a response matching the correlation ID. + + Args: + entity_id: Entity instance identifier + correlation_id: Correlation ID to search for + + Returns: + Response AgentRunResponse, None otherwise + """ + try: + entity_metadata = self._client.get_entity(entity_id, include_state=True) + + if entity_metadata is None: + return None + + state_json = entity_metadata.get_state() + if not state_json: + return None + + state = DurableAgentState.from_json(state_json) + + # Use the helper method to get response by correlation ID + return state.try_get_agent_response(correlation_id) + + except Exception as e: + logger.warning( + "[ClientAgentExecutor] Error reading entity state: %s", + e, + ) + return None + + +class OrchestrationAgentExecutor(DurableAgentExecutor[DurableAgentTask]): + """Execution strategy for orchestrations (sync/yield).""" + + def __init__(self, context: OrchestrationContext): + self._context = context + logger.debug("[OrchestrationAgentExecutor] Initialized") + + def get_run_request( + self, + message: str, + response_format: type[BaseModel] | None, + enable_tool_calls: bool, + ) -> RunRequest: + """Get the current run request from the orchestration context. + + Returns: + RunRequest: The current run request + """ + request = super().get_run_request( + message, + response_format, + enable_tool_calls, + ) + request.orchestration_id = self._context.instance_id + return request + + def run_durable_agent( + self, + agent_name: str, + run_request: RunRequest, + thread: AgentThread | None = None, + ) -> DurableAgentTask: + """Execute the agent via orchestration context. + + Calls the agent entity and returns a DurableAgentTask that can be yielded + in orchestrations to wait for the entity's response. + + Args: + agent_name: Name of the agent to execute + run_request: The run request containing message and optional response format + thread: Optional conversation thread (creates new if not provided) + + Returns: + DurableAgentTask: A task wrapping the entity call that yields AgentRunResponse + """ + # Resolve session + session_id = self._create_session_id(agent_name, thread) + + # Create the entity ID + entity_id = EntityInstanceId( + entity=session_id.entity_name, + key=session_id.key, + ) + + logger.debug( + "[OrchestrationAgentExecutor] correlation_id: %s entity_id: %s session_id: %s", + run_request.correlation_id, + entity_id, + session_id, + ) + + # Call the entity and get the underlying task + entity_task: Task[Any] = self._context.call_entity(entity_id, "run", run_request.to_dict()) # type: ignore + + # Wrap in DurableAgentTask for response transformation + return DurableAgentTask( + entity_task=entity_task, + response_format=run_request.response_format, + correlation_id=run_request.correlation_id, + ) diff --git a/python/packages/durabletask/agent_framework_durabletask/_models.py b/python/packages/durabletask/agent_framework_durabletask/_models.py index 14ca37f098..947ab7a17f 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_models.py +++ b/python/packages/durabletask/agent_framework_durabletask/_models.py @@ -8,12 +8,15 @@ from __future__ import annotations import inspect +import json +import uuid +from collections.abc import MutableMapping from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, timezone from importlib import import_module from typing import TYPE_CHECKING, Any, cast -from agent_framework import Role +from agent_framework import AgentThread, Role from ._constants import REQUEST_RESPONSE_FORMAT_TEXT @@ -101,38 +104,38 @@ class RunRequest: role: The role of the message sender (user, system, or assistant) response_format: Optional Pydantic BaseModel type describing the structured response format enable_tool_calls: Whether to enable tool calls for this request - correlation_id: Optional correlation ID for tracking the response to this specific request + correlation_id: Correlation ID for tracking the response to this specific request created_at: Optional timestamp when the request was created orchestration_id: Optional ID of the orchestration that initiated this request """ message: str request_response_format: str + correlation_id: str role: Role = Role.USER response_format: type[BaseModel] | None = None enable_tool_calls: bool = True - correlation_id: str | None = None created_at: datetime | None = None orchestration_id: str | None = None def __init__( self, message: str, + correlation_id: str, request_response_format: str = REQUEST_RESPONSE_FORMAT_TEXT, role: Role | str | None = Role.USER, response_format: type[BaseModel] | None = None, enable_tool_calls: bool = True, - correlation_id: str | None = None, created_at: datetime | None = None, orchestration_id: str | None = None, ) -> None: self.message = message + self.correlation_id = correlation_id self.role = self.coerce_role(role) self.response_format = response_format self.request_response_format = request_response_format self.enable_tool_calls = enable_tool_calls - self.correlation_id = correlation_id - self.created_at = created_at + self.created_at = created_at if created_at is not None else datetime.now(tz=timezone.utc) self.orchestration_id = orchestration_id @staticmethod @@ -154,11 +157,10 @@ def to_dict(self) -> dict[str, Any]: "enable_tool_calls": self.enable_tool_calls, "role": self.role.value, "request_response_format": self.request_response_format, + "correlationId": self.correlation_id, } if self.response_format: result["response_format"] = serialize_response_format(self.response_format) - if self.correlation_id: - result["correlationId"] = self.correlation_id if self.created_at: result["created_at"] = self.created_at.isoformat() if self.orchestration_id: @@ -166,6 +168,16 @@ def to_dict(self) -> dict[str, Any]: return result + @classmethod + def from_json(cls, data: str) -> RunRequest: + """Create RunRequest from JSON string.""" + try: + dict_data = json.loads(data) + except json.JSONDecodeError as e: + raise ValueError("The durable agent state is not valid JSON.") from e + + return cls.from_dict(dict_data) + @classmethod def from_dict(cls, data: dict[str, Any]) -> RunRequest: """Create RunRequest from dictionary.""" @@ -176,13 +188,120 @@ def from_dict(cls, data: dict[str, Any]) -> RunRequest: except ValueError: created_at = None + correlation_id = data.get("correlationId") + if not correlation_id: + raise ValueError("correlationId is required in RunRequest data") + return cls( message=data.get("message", ""), + correlation_id=correlation_id, request_response_format=data.get("request_response_format", REQUEST_RESPONSE_FORMAT_TEXT), role=cls.coerce_role(data.get("role")), response_format=_deserialize_response_format(data.get("response_format")), enable_tool_calls=data.get("enable_tool_calls", True), - correlation_id=data.get("correlationId"), created_at=created_at, orchestration_id=data.get("orchestrationId"), ) + + +@dataclass +class AgentSessionId: + """Represents an agent session identifier (name + key).""" + + name: str + key: str + + ENTITY_NAME_PREFIX: str = "dafx-" + + @staticmethod + def to_entity_name(name: str) -> str: + return f"{AgentSessionId.ENTITY_NAME_PREFIX}{name}" + + @staticmethod + def with_random_key(name: str) -> AgentSessionId: + return AgentSessionId(name=name, key=uuid.uuid4().hex) + + @property + def entity_name(self) -> str: + return self.to_entity_name(self.name) + + def __str__(self) -> str: + return f"@{self.name}@{self.key}" + + def __repr__(self) -> str: + return f"AgentSessionId(name='{self.name}', key='{self.key}')" + + @staticmethod + def parse(session_id_string: str) -> AgentSessionId: + if not session_id_string.startswith("@"): + raise ValueError(f"Invalid agent session ID format: {session_id_string}") + + parts = session_id_string[1:].split("@", 1) + if len(parts) != 2: + raise ValueError(f"Invalid agent session ID format: {session_id_string}") + + return AgentSessionId(name=parts[0], key=parts[1]) + + +class DurableAgentThread(AgentThread): + """Durable agent thread that tracks the owning :class:`AgentSessionId`.""" + + _SERIALIZED_SESSION_ID_KEY = "durable_session_id" + + def __init__( + self, + *, + session_id: AgentSessionId | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self._session_id: AgentSessionId | None = session_id + + @property + def session_id(self) -> AgentSessionId | None: + return self._session_id + + @session_id.setter + def session_id(self, value: AgentSessionId | None) -> None: + self._session_id = value + + @classmethod + def from_session_id( + cls, + session_id: AgentSessionId, + **kwargs: Any, + ) -> DurableAgentThread: + return cls(session_id=session_id, **kwargs) + + async def serialize(self, **kwargs: Any) -> dict[str, Any]: + state = await super().serialize(**kwargs) + if self._session_id is not None: + state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id) + return state + + @classmethod + async def deserialize( + cls, + serialized_thread_state: MutableMapping[str, Any], + *, + message_store: Any = None, + **kwargs: Any, + ) -> DurableAgentThread: + state_payload = dict(serialized_thread_state) + session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None) + thread = await super().deserialize( + state_payload, + message_store=message_store, + **kwargs, + ) + if not isinstance(thread, DurableAgentThread): + raise TypeError("Deserialized thread is not a DurableAgentThread instance") + + if session_id_value is None: + return thread + + if not isinstance(session_id_value, str): + raise ValueError("durable_session_id must be a string when present in serialized state") + + thread.session_id = AgentSessionId.parse(session_id_value) + return thread diff --git a/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py b/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py new file mode 100644 index 0000000000..f5e1a15238 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_orchestration_context.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Orchestration context wrapper for Durable Task Agent Framework. + +This module provides the DurableAIAgentOrchestrationContext class for use inside +orchestration functions to interact with durable agents. +""" + +from __future__ import annotations + +from agent_framework import get_logger +from durabletask.task import OrchestrationContext + +from ._executors import DurableAgentTask, OrchestrationAgentExecutor +from ._shim import DurableAgentProvider, DurableAIAgent + +logger = get_logger("agent_framework.durabletask.orchestration_context") + + +class DurableAIAgentOrchestrationContext(DurableAgentProvider[DurableAgentTask]): + """Orchestration context wrapper for interacting with durable agents internally. + + This class wraps a durabletask OrchestrationContext and provides a convenient + interface for retrieving and executing durable agents from within orchestration + functions. + + Example: + ```python + from durabletask import Orchestration + from agent_framework_durabletask import DurableAIAgentOrchestrationContext + + + def my_orchestration(context: OrchestrationContext): + # Wrap the context + agent_context = DurableAIAgentOrchestrationContext(context) + + # Get an agent reference + agent = agent_context.get_agent("assistant") + + # Run the agent (returns a Task to be yielded) + result = yield agent.run("Hello, how are you?") + + return result.text + ``` + """ + + def __init__(self, context: OrchestrationContext): + """Initialize the orchestration context wrapper. + + Args: + context: The durabletask orchestration context to wrap + """ + self._context = context + self._executor = OrchestrationAgentExecutor(self._context) + logger.debug("[DurableAIAgentOrchestrationContext] Initialized") + + def get_agent(self, agent_name: str) -> DurableAIAgent[DurableAgentTask]: + """Retrieve a DurableAIAgent shim for the specified agent. + + This method returns a proxy object that can be used to execute the agent + within an orchestration. The agent's run() method will return a Task that + must be yielded. + + Args: + agent_name: Name of the agent to retrieve (without the dafx- prefix) + + Returns: + DurableAIAgent instance that can be used to run the agent + + Note: + Validation is deferred to execution time. The entity must be registered + on a worker with the name f"dafx-{agent_name}". + """ + logger.debug("[DurableAIAgentOrchestrationContext] Creating agent proxy for: %s", agent_name) + return DurableAIAgent(self._executor, agent_name) diff --git a/python/packages/durabletask/agent_framework_durabletask/_response_utils.py b/python/packages/durabletask/agent_framework_durabletask/_response_utils.py new file mode 100644 index 0000000000..aeb0e19c6c --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_response_utils.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Shared utilities for handling AgentRunResponse parsing and validation.""" + +from typing import Any + +from agent_framework import AgentRunResponse, get_logger +from pydantic import BaseModel + +logger = get_logger("agent_framework.durabletask.response_utils") + + +def load_agent_response(agent_response: AgentRunResponse | dict[str, Any] | None) -> AgentRunResponse: + """Convert raw payloads into AgentRunResponse instance. + + Args: + agent_response: The response to convert, can be an AgentRunResponse, dict, or None + + Returns: + AgentRunResponse: The converted response object + + Raises: + ValueError: If agent_response is None + TypeError: If agent_response is an unsupported type + """ + if agent_response is None: + raise ValueError("agent_response cannot be None") + + logger.debug("[load_agent_response] Loading agent response of type: %s", type(agent_response)) + + if isinstance(agent_response, AgentRunResponse): + return agent_response + if isinstance(agent_response, dict): + logger.debug("[load_agent_response] Converting dict payload using AgentRunResponse.from_dict") + return AgentRunResponse.from_dict(agent_response) + + raise TypeError(f"Unsupported type for agent_response: {type(agent_response)}") + + +def ensure_response_format( + response_format: type[BaseModel] | None, + correlation_id: str, + response: AgentRunResponse, +) -> None: + """Ensure the AgentRunResponse value is parsed into the expected response_format. + + This function modifies the response in-place by parsing its value attribute + into the specified Pydantic model format. + + Args: + response_format: Optional Pydantic model class to parse the response value into + correlation_id: Correlation ID for logging purposes + response: The AgentRunResponse object to validate and parse + """ + if response_format is not None and not isinstance(response.value, response_format): + response.try_parse_value(response_format) + + logger.debug( + "[ensure_response_format] Loaded AgentRunResponse.value for correlation_id %s with type: %s", + correlation_id, + type(response.value).__name__, + ) diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py new file mode 100644 index 0000000000..c2e9aee039 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -0,0 +1,185 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Durable Agent Shim for Durable Task Framework. + +This module provides the DurableAIAgent shim that implements AgentProtocol +and provides a consistent interface for both Client and Orchestration contexts. +The actual execution is delegated to the context-specific providers. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from typing import Any, Generic, TypeVar + +from agent_framework import AgentProtocol, AgentRunResponseUpdate, AgentThread, ChatMessage +from pydantic import BaseModel + +from ._executors import DurableAgentExecutor +from ._models import DurableAgentThread + +# TypeVar for the task type returned by executors +# Covariant because TaskT only appears in return positions (output) +TaskT = TypeVar("TaskT", covariant=True) + + +class DurableAgentProvider(ABC, Generic[TaskT]): + """Abstract provider for constructing durable agent proxies. + + Implemented by context-specific wrappers (client/orchestration) to return a + `DurableAIAgent` shim backed by their respective `DurableAgentExecutor` + implementation, ensuring a consistent `get_agent` entry point regardless of + execution context. + """ + + @abstractmethod + def get_agent(self, agent_name: str) -> DurableAIAgent[TaskT]: + """Retrieve a DurableAIAgent shim for the specified agent. + + Args: + agent_name: Name of the agent to retrieve + + Returns: + DurableAIAgent instance that can be used to run the agent + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError("Subclasses must implement get_agent()") + + +class DurableAIAgent(AgentProtocol, Generic[TaskT]): + """A durable agent proxy that delegates execution to the provider. + + This class implements AgentProtocol but with one critical difference: + - AgentProtocol.run() returns a Coroutine (async, must await) + - DurableAIAgent.run() returns TaskT (sync Task object - must yield + or the AgentRunResponse directly in the case of TaskHubGrpcClient) + + This represents fundamentally different execution models but maintains the same + interface contract for all other properties and methods. + + The underlying provider determines how execution occurs (entity calls, HTTP requests, etc.) + and what type of Task object is returned. + + Type Parameters: + TaskT: The task type returned by this agent (e.g., AgentRunResponse, DurableAgentTask, AgentTask) + """ + + def __init__(self, executor: DurableAgentExecutor[TaskT], name: str, *, agent_id: str | None = None): + """Initialize the shim with a provider and agent name. + + Args: + executor: The execution provider (Client or OrchestrationContext) + name: The name of the agent to execute + agent_id: Optional unique identifier for the agent (defaults to name) + """ + self._executor = executor + self._name = name + self._id = agent_id if agent_id is not None else name + self._display_name = name + self._description = f"Durable agent proxy for {name}" + + @property + def id(self) -> str: + """Get the unique identifier for this agent.""" + return self._id + + @property + def name(self) -> str | None: + """Get the name of the agent.""" + return self._name + + @property + def display_name(self) -> str: + """Get the display name of the agent.""" + return self._display_name + + @property + def description(self) -> str | None: + """Get the description of the agent.""" + return self._description + + def run( # pyright: ignore[reportIncompatibleMethodOverride] + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + response_format: type[BaseModel] | None = None, + enable_tool_calls: bool = True, + ) -> TaskT: + """Execute the agent via the injected provider. + + Note: + This method overrides AgentProtocol.run() with a different return type: + - AgentProtocol.run() returns Coroutine[Any, Any, AgentRunResponse] (async) + - DurableAIAgent.run() returns TaskT (Task object for yielding) + + This is intentional to support orchestration contexts that use yield patterns + instead of async/await patterns. + + Returns: + TaskT: The task type specific to the executor + """ + message_str = self._normalize_messages(messages) + + run_request = self._executor.get_run_request( + message=message_str, + response_format=response_format, + enable_tool_calls=enable_tool_calls, + ) + + return self._executor.run_durable_agent( + agent_name=self._name, + run_request=run_request, + thread=thread, + ) + + def run_stream( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterator[AgentRunResponseUpdate]: + """Run the agent with streaming (not supported for durable agents). + + Args: + messages: The message(s) to send to the agent + thread: Optional agent thread for conversation context + **kwargs: Additional arguments + + Raises: + NotImplementedError: Streaming is not supported for durable agents + """ + raise NotImplementedError("Streaming is not supported for durable agents") + + def get_new_thread(self, **kwargs: Any) -> DurableAgentThread: + """Create a new agent thread via the provider.""" + return self._executor.get_new_thread(self._name, **kwargs) + + def _normalize_messages(self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None) -> str: + """Convert supported message inputs to a single string. + + Args: + messages: The messages to normalize + + Returns: + A single string representation of the messages + """ + if messages is None: + return "" + if isinstance(messages, str): + return messages + if isinstance(messages, ChatMessage): + return messages.text or "" + if isinstance(messages, list): + if not messages: + return "" + first_item = messages[0] + if isinstance(first_item, str): + return "\n".join(messages) # type: ignore[arg-type] + # List of ChatMessage + return "\n".join([msg.text or "" for msg in messages]) # type: ignore[union-attr] + return "" diff --git a/python/packages/durabletask/agent_framework_durabletask/_worker.py b/python/packages/durabletask/agent_framework_durabletask/_worker.py new file mode 100644 index 0000000000..fea4b8ba7c --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_worker.py @@ -0,0 +1,218 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Worker wrapper for Durable Task Agent Framework. + +This module provides the DurableAIAgentWorker class that wraps a durabletask worker +and enables registration of agents as durable entities. +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +from agent_framework import AgentProtocol, get_logger +from durabletask.worker import TaskHubGrpcWorker + +from ._callbacks import AgentResponseCallbackProtocol +from ._entities import AgentEntity, DurableTaskEntityStateProvider + +logger = get_logger("agent_framework.durabletask.worker") + + +class DurableAIAgentWorker: + """Wrapper for durabletask worker that enables agent registration. + + This class wraps an existing TaskHubGrpcWorker instance and provides + a convenient interface for registering agents as durable entities. + + Example: + ```python + from durabletask import TaskHubGrpcWorker + from agent_framework import ChatAgent + from agent_framework_durabletask import DurableAIAgentWorker + + # Create the underlying worker + worker = TaskHubGrpcWorker(host_address="localhost:4001") + + # Wrap it with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Register agents + my_agent = ChatAgent(chat_client=client, name="assistant") + agent_worker.add_agent(my_agent) + + # Start the worker + worker.start() + ``` + """ + + def __init__( + self, + worker: TaskHubGrpcWorker, + callback: AgentResponseCallbackProtocol | None = None, + ): + """Initialize the worker wrapper. + + Args: + worker: The durabletask worker instance to wrap + callback: Optional callback for agent response notifications + """ + self._worker = worker + self._callback = callback + self._registered_agents: dict[str, AgentProtocol] = {} + logger.debug("[DurableAIAgentWorker] Initialized with worker type: %s", type(worker).__name__) + + def add_agent( + self, + agent: AgentProtocol, + callback: AgentResponseCallbackProtocol | None = None, + ) -> None: + """Register an agent with the worker. + + This method creates a durable entity class for the agent and registers + it with the underlying durabletask worker. The entity will be accessible + by the name "dafx-{agent_name}". + + Args: + agent: The agent to register (must have a name) + callback: Optional callback for this specific agent (overrides worker-level callback) + + Raises: + ValueError: If the agent doesn't have a name or is already registered + """ + agent_name = agent.name + if not agent_name: + raise ValueError("Agent must have a name to be registered") + + if agent_name in self._registered_agents: + raise ValueError(f"Agent '{agent_name}' is already registered") + + logger.info("[DurableAIAgentWorker] Registering agent: %s as entity: dafx-%s", agent_name, agent_name) + + # Store the agent reference + self._registered_agents[agent_name] = agent + + # Use agent-specific callback if provided, otherwise use worker-level callback + effective_callback = callback or self._callback + + # Create a configured entity class using the factory + entity_class = self.__create_agent_entity(agent, effective_callback) + + # Register the entity class with the worker + # The worker.add_entity method takes a class + entity_registered: str = self._worker.add_entity(entity_class) # pyright: ignore[reportUnknownMemberType] + + logger.debug( + "[DurableAIAgentWorker] Successfully registered entity class %s for agent: %s", + entity_registered, + agent_name, + ) + + def start(self) -> None: + """Start the worker to begin processing tasks. + + Note: + This method delegates to the underlying worker's start method. + The worker will block until stopped. + """ + logger.info("[DurableAIAgentWorker] Starting worker with %d registered agents", len(self._registered_agents)) + self._worker.start() + + def stop(self) -> None: + """Stop the worker gracefully. + + Note: + This method delegates to the underlying worker's stop method. + """ + logger.info("[DurableAIAgentWorker] Stopping worker") + self._worker.stop() + + @property + def registered_agent_names(self) -> list[str]: + """Get the names of all registered agents. + + Returns: + List of agent names (without the dafx- prefix) + """ + return list(self._registered_agents.keys()) + + def __create_agent_entity( + self, + agent: AgentProtocol, + callback: AgentResponseCallbackProtocol | None = None, + ) -> type[DurableTaskEntityStateProvider]: + """Factory function to create a DurableEntity class configured with an agent. + + This factory creates a new class that combines the entity state provider + with the agent execution logic. Each agent gets its own entity class. + + Args: + agent: The agent instance to wrap + callback: Optional callback for agent responses + + Returns: + A new DurableEntity subclass configured for this agent + """ + agent_name = agent.name or type(agent).__name__ + entity_name = f"dafx-{agent_name}" + + class ConfiguredAgentEntity(DurableTaskEntityStateProvider): + """Durable entity configured with a specific agent instance.""" + + def __init__(self) -> None: + super().__init__() + # Create the AgentEntity with this state provider + self._agent_entity = AgentEntity( + agent=agent, + callback=callback, + state_provider=self, + ) + logger.debug( + "[ConfiguredAgentEntity] Initialized entity for agent: %s (entity name: %s)", + agent_name, + entity_name, + ) + + def run(self, request: Any) -> Any: + """Handle run requests from clients or orchestrations. + + Args: + request: RunRequest as dict or string + + Returns: + AgentRunResponse as dict + """ + logger.debug("[ConfiguredAgentEntity.run] Executing agent: %s", agent_name) + # Get or create event loop for async execution + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Run the async agent execution synchronously + if loop.is_running(): + # If loop is already running (shouldn't happen in entity context), + # create a temporary loop + temp_loop = asyncio.new_event_loop() + try: + response = temp_loop.run_until_complete(self._agent_entity.run(request)) + finally: + temp_loop.close() + else: + response = loop.run_until_complete(self._agent_entity.run(request)) + + return response.to_dict() + + def reset(self) -> None: + """Reset the agent's conversation history.""" + logger.debug("[ConfiguredAgentEntity.reset] Resetting agent: %s", agent_name) + self._agent_entity.reset() + + # Set the entity name to match the prefixed agent name + # This is used by durabletask to register the entity + ConfiguredAgentEntity.__name__ = entity_name + ConfiguredAgentEntity.__qualname__ = entity_name + + return ConfiguredAgentEntity diff --git a/python/packages/durabletask/tests/test_agent_session_id.py b/python/packages/durabletask/tests/test_agent_session_id.py new file mode 100644 index 0000000000..92c4e5f872 --- /dev/null +++ b/python/packages/durabletask/tests/test_agent_session_id.py @@ -0,0 +1,272 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for AgentSessionId and DurableAgentThread.""" + +import pytest +from agent_framework import AgentThread + +from agent_framework_durabletask._models import AgentSessionId, DurableAgentThread + + +class TestAgentSessionId: + """Test suite for AgentSessionId.""" + + def test_init_creates_session_id(self) -> None: + """Test that AgentSessionId initializes correctly.""" + session_id = AgentSessionId(name="AgentEntity", key="test-key-123") + + assert session_id.name == "AgentEntity" + assert session_id.key == "test-key-123" + + def test_with_random_key_generates_guid(self) -> None: + """Test that with_random_key generates a GUID.""" + session_id = AgentSessionId.with_random_key(name="AgentEntity") + + assert session_id.name == "AgentEntity" + assert len(session_id.key) == 32 # UUID hex is 32 chars + # Verify it's a valid hex string + int(session_id.key, 16) + + def test_with_random_key_unique_keys(self) -> None: + """Test that with_random_key generates unique keys.""" + session_id1 = AgentSessionId.with_random_key(name="AgentEntity") + session_id2 = AgentSessionId.with_random_key(name="AgentEntity") + + assert session_id1.key != session_id2.key + + def test_str_representation(self) -> None: + """Test string representation.""" + session_id = AgentSessionId(name="AgentEntity", key="test-key-123") + str_repr = str(session_id) + + assert str_repr == "@AgentEntity@test-key-123" + + def test_repr_representation(self) -> None: + """Test repr representation.""" + session_id = AgentSessionId(name="AgentEntity", key="test-key") + repr_str = repr(session_id) + + assert "AgentSessionId" in repr_str + assert "AgentEntity" in repr_str + assert "test-key" in repr_str + + def test_parse_valid_session_id(self) -> None: + """Test parsing valid session ID string.""" + session_id = AgentSessionId.parse("@AgentEntity@test-key-123") + + assert session_id.name == "AgentEntity" + assert session_id.key == "test-key-123" + + def test_parse_invalid_format_no_prefix(self) -> None: + """Test parsing invalid format without @ prefix.""" + with pytest.raises(ValueError) as exc_info: + AgentSessionId.parse("AgentEntity@test-key") + + assert "Invalid agent session ID format" in str(exc_info.value) + + def test_parse_invalid_format_single_part(self) -> None: + """Test parsing invalid format with single part.""" + with pytest.raises(ValueError) as exc_info: + AgentSessionId.parse("@AgentEntity") + + assert "Invalid agent session ID format" in str(exc_info.value) + + def test_parse_with_multiple_at_signs_in_key(self) -> None: + """Test parsing with @ signs in the key.""" + session_id = AgentSessionId.parse("@AgentEntity@key-with@symbols") + + assert session_id.name == "AgentEntity" + assert session_id.key == "key-with@symbols" + + def test_parse_round_trip(self) -> None: + """Test round-trip parse and string conversion.""" + original = AgentSessionId(name="AgentEntity", key="test-key") + str_repr = str(original) + parsed = AgentSessionId.parse(str_repr) + + assert parsed.name == original.name + assert parsed.key == original.key + + def test_to_entity_name_adds_prefix(self) -> None: + """Test that to_entity_name adds the dafx- prefix.""" + entity_name = AgentSessionId.to_entity_name("TestAgent") + assert entity_name == "dafx-TestAgent" + + +class TestDurableAgentThread: + """Test suite for DurableAgentThread.""" + + def test_init_with_session_id(self) -> None: + """Test DurableAgentThread initialization with session ID.""" + session_id = AgentSessionId(name="TestAgent", key="test-key") + thread = DurableAgentThread(session_id=session_id) + + assert thread.session_id is not None + assert thread.session_id == session_id + + def test_init_without_session_id(self) -> None: + """Test DurableAgentThread initialization without session ID.""" + thread = DurableAgentThread() + + assert thread.session_id is None + + def test_session_id_setter(self) -> None: + """Test setting a session ID to an existing thread.""" + thread = DurableAgentThread() + assert thread.session_id is None + + session_id = AgentSessionId(name="TestAgent", key="test-key") + thread.session_id = session_id + + assert thread.session_id is not None + assert thread.session_id == session_id + assert thread.session_id.name == "TestAgent" + + def test_from_session_id(self) -> None: + """Test creating DurableAgentThread from session ID.""" + session_id = AgentSessionId(name="TestAgent", key="test-key") + thread = DurableAgentThread.from_session_id(session_id) + + assert isinstance(thread, DurableAgentThread) + assert thread.session_id is not None + assert thread.session_id == session_id + assert thread.session_id.name == "TestAgent" + assert thread.session_id.key == "test-key" + + def test_from_session_id_with_service_thread_id(self) -> None: + """Test creating DurableAgentThread with service thread ID.""" + session_id = AgentSessionId(name="TestAgent", key="test-key") + thread = DurableAgentThread.from_session_id(session_id, service_thread_id="service-123") + + assert thread.session_id is not None + assert thread.session_id == session_id + assert thread.service_thread_id == "service-123" + + async def test_serialize_with_session_id(self) -> None: + """Test serialization includes session ID.""" + session_id = AgentSessionId(name="TestAgent", key="test-key") + thread = DurableAgentThread(session_id=session_id) + + serialized = await thread.serialize() + + assert isinstance(serialized, dict) + assert "durable_session_id" in serialized + assert serialized["durable_session_id"] == "@TestAgent@test-key" + + async def test_serialize_without_session_id(self) -> None: + """Test serialization without session ID.""" + thread = DurableAgentThread() + + serialized = await thread.serialize() + + assert isinstance(serialized, dict) + assert "durable_session_id" not in serialized + + async def test_deserialize_with_session_id(self) -> None: + """Test deserialization restores session ID.""" + serialized = { + "service_thread_id": "thread-123", + "durable_session_id": "@TestAgent@test-key", + } + + thread = await DurableAgentThread.deserialize(serialized) + + assert isinstance(thread, DurableAgentThread) + assert thread.session_id is not None + assert thread.session_id.name == "TestAgent" + assert thread.session_id.key == "test-key" + assert thread.service_thread_id == "thread-123" + + async def test_deserialize_without_session_id(self) -> None: + """Test deserialization without session ID.""" + serialized = { + "service_thread_id": "thread-456", + } + + thread = await DurableAgentThread.deserialize(serialized) + + assert isinstance(thread, DurableAgentThread) + assert thread.session_id is None + assert thread.service_thread_id == "thread-456" + + async def test_round_trip_serialization(self) -> None: + """Test round-trip serialization preserves session ID.""" + session_id = AgentSessionId(name="TestAgent", key="test-key-789") + original = DurableAgentThread(session_id=session_id) + + serialized = await original.serialize() + restored = await DurableAgentThread.deserialize(serialized) + + assert isinstance(restored, DurableAgentThread) + assert restored.session_id is not None + assert restored.session_id.name == session_id.name + assert restored.session_id.key == session_id.key + + async def test_deserialize_invalid_session_id_type(self) -> None: + """Test deserialization with invalid session ID type raises error.""" + serialized = { + "service_thread_id": "thread-123", + "durable_session_id": 12345, # Invalid type + } + + with pytest.raises(ValueError, match="durable_session_id must be a string"): + await DurableAgentThread.deserialize(serialized) + + +class TestAgentThreadCompatibility: + """Test suite for compatibility between AgentThread and DurableAgentThread.""" + + async def test_agent_thread_serialize(self) -> None: + """Test that base AgentThread can be serialized.""" + thread = AgentThread() + + serialized = await thread.serialize() + + assert isinstance(serialized, dict) + assert "service_thread_id" in serialized + + async def test_agent_thread_deserialize(self) -> None: + """Test that base AgentThread can be deserialized.""" + thread = AgentThread() + serialized = await thread.serialize() + + restored = await AgentThread.deserialize(serialized) + + assert isinstance(restored, AgentThread) + assert restored.service_thread_id == thread.service_thread_id + + async def test_durable_thread_is_agent_thread(self) -> None: + """Test that DurableAgentThread is an AgentThread.""" + thread = DurableAgentThread() + + assert isinstance(thread, AgentThread) + assert isinstance(thread, DurableAgentThread) + + +class TestModelIntegration: + """Test suite for integration between models.""" + + def test_session_id_string_format(self) -> None: + """Test that AgentSessionId string format is consistent.""" + session_id = AgentSessionId.with_random_key("AgentEntity") + session_id_str = str(session_id) + + assert session_id_str.startswith("@AgentEntity@") + + async def test_thread_with_session_preserves_on_serialization(self) -> None: + """Test that thread with session ID preserves it through serialization.""" + session_id = AgentSessionId(name="TestAgent", key="preserved-key") + thread = DurableAgentThread.from_session_id(session_id) + + # Serialize and deserialize + serialized = await thread.serialize() + restored = await DurableAgentThread.deserialize(serialized) + + # Session ID should be preserved + assert restored.session_id is not None + assert restored.session_id.name == "TestAgent" + assert restored.session_id.key == "preserved-key" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_client.py b/python/packages/durabletask/tests/test_client.py new file mode 100644 index 0000000000..cf2ccfe1af --- /dev/null +++ b/python/packages/durabletask/tests/test_client.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for DurableAIAgentClient. + +Focuses on critical client workflows: agent retrieval, protocol compliance, and integration. +Run with: pytest tests/test_client.py -v +""" + +from unittest.mock import Mock + +import pytest +from agent_framework import AgentProtocol + +from agent_framework_durabletask import DurableAgentThread, DurableAIAgentClient +from agent_framework_durabletask._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS +from agent_framework_durabletask._shim import DurableAIAgent + + +@pytest.fixture +def mock_grpc_client() -> Mock: + """Create a mock TaskHubGrpcClient for testing.""" + return Mock() + + +@pytest.fixture +def agent_client(mock_grpc_client: Mock) -> DurableAIAgentClient: + """Create a DurableAIAgentClient with mock gRPC client.""" + return DurableAIAgentClient(mock_grpc_client) + + +@pytest.fixture +def agent_client_with_custom_polling(mock_grpc_client: Mock) -> DurableAIAgentClient: + """Create a DurableAIAgentClient with custom polling parameters.""" + return DurableAIAgentClient( + mock_grpc_client, + max_poll_retries=15, + poll_interval_seconds=0.5, + ) + + +class TestDurableAIAgentClientGetAgent: + """Test core workflow: retrieving agents from the client.""" + + def test_get_agent_returns_durable_agent_shim(self, agent_client: DurableAIAgentClient) -> None: + """Verify get_agent returns a DurableAIAgent instance.""" + agent = agent_client.get_agent("assistant") + + assert isinstance(agent, DurableAIAgent) + assert isinstance(agent, AgentProtocol) + + def test_get_agent_shim_has_correct_name(self, agent_client: DurableAIAgentClient) -> None: + """Verify retrieved agent has the correct name.""" + agent = agent_client.get_agent("my_agent") + + assert agent.name == "my_agent" + + def test_get_agent_multiple_times_returns_new_instances(self, agent_client: DurableAIAgentClient) -> None: + """Verify multiple get_agent calls return independent instances.""" + agent1 = agent_client.get_agent("assistant") + agent2 = agent_client.get_agent("assistant") + + assert agent1 is not agent2 # Different object instances + + def test_get_agent_different_agents(self, agent_client: DurableAIAgentClient) -> None: + """Verify client can retrieve multiple different agents.""" + agent1 = agent_client.get_agent("agent1") + agent2 = agent_client.get_agent("agent2") + + assert agent1.name == "agent1" + assert agent2.name == "agent2" + + +class TestDurableAIAgentClientIntegration: + """Test integration scenarios between client and agent shim.""" + + def test_client_agent_has_working_run_method(self, agent_client: DurableAIAgentClient) -> None: + """Verify agent from client has callable run method (even if not yet implemented).""" + agent = agent_client.get_agent("assistant") + + assert hasattr(agent, "run") + assert callable(agent.run) + + def test_client_agent_can_create_threads(self, agent_client: DurableAIAgentClient) -> None: + """Verify agent from client can create DurableAgentThread instances.""" + agent = agent_client.get_agent("assistant") + + thread = agent.get_new_thread() + + assert isinstance(thread, DurableAgentThread) + + def test_client_agent_thread_with_parameters(self, agent_client: DurableAIAgentClient) -> None: + """Verify agent can create threads with custom parameters.""" + agent = agent_client.get_agent("assistant") + + thread = agent.get_new_thread(service_thread_id="client-session-123") + + assert isinstance(thread, DurableAgentThread) + assert thread.service_thread_id == "client-session-123" + + +class TestDurableAIAgentClientPollingConfiguration: + """Test polling configuration parameters for DurableAIAgentClient.""" + + def test_client_uses_default_polling_parameters(self, agent_client: DurableAIAgentClient) -> None: + """Verify client initializes with default polling parameters.""" + assert agent_client.max_poll_retries == DEFAULT_MAX_POLL_RETRIES + assert agent_client.poll_interval_seconds == DEFAULT_POLL_INTERVAL_SECONDS + + def test_client_accepts_custom_polling_parameters( + self, agent_client_with_custom_polling: DurableAIAgentClient + ) -> None: + """Verify client accepts and stores custom polling parameters.""" + assert agent_client_with_custom_polling.max_poll_retries == 15 + assert agent_client_with_custom_polling.poll_interval_seconds == 0.5 + + def test_client_validates_max_poll_retries(self, mock_grpc_client: Mock) -> None: + """Verify client validates and normalizes max_poll_retries.""" + # Test with zero - should enforce minimum of 1 + client = DurableAIAgentClient(mock_grpc_client, max_poll_retries=0) + assert client.max_poll_retries == 1 + + # Test with negative - should enforce minimum of 1 + client = DurableAIAgentClient(mock_grpc_client, max_poll_retries=-5) + assert client.max_poll_retries == 1 + + def test_client_validates_poll_interval_seconds(self, mock_grpc_client: Mock) -> None: + """Verify client validates and normalizes poll_interval_seconds.""" + # Test with zero - should use default + client = DurableAIAgentClient(mock_grpc_client, poll_interval_seconds=0) + assert client.poll_interval_seconds == DEFAULT_POLL_INTERVAL_SECONDS + + # Test with negative - should use default + client = DurableAIAgentClient(mock_grpc_client, poll_interval_seconds=-0.5) + assert client.poll_interval_seconds == DEFAULT_POLL_INTERVAL_SECONDS + + # Test with valid float + client = DurableAIAgentClient(mock_grpc_client, poll_interval_seconds=2.5) + assert client.poll_interval_seconds == 2.5 + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_durable_agent_state.py b/python/packages/durabletask/tests/test_durable_agent_state.py index 94e29f29a5..ec9a2f0cc9 100644 --- a/python/packages/durabletask/tests/test_durable_agent_state.py +++ b/python/packages/durabletask/tests/test_durable_agent_state.py @@ -112,20 +112,21 @@ class TestDurableAgentStateMessageCreatedAt: """Test suite for DurableAgentStateMessage created_at field handling.""" def test_message_from_run_request_without_created_at_preserves_none(self) -> None: - """Test from_run_request preserves None created_at instead of defaulting to current time. + """Test from_run_request handles auto-populated created_at from RunRequest. - When a RunRequest has no created_at value, the resulting DurableAgentStateMessage - should also have None for created_at, not default to current UTC time. + When a RunRequest is created with None for created_at, RunRequest defaults it to + current UTC time. The resulting DurableAgentStateMessage should have this timestamp. """ run_request = RunRequest( message="test message", correlation_id="corr-run", - created_at=None, # Explicitly None + created_at=None, # RunRequest will default this to current time ) durable_message = DurableAgentStateMessage.from_run_request(run_request) - assert durable_message.created_at is None + # RunRequest auto-populates created_at, so it should not be None + assert durable_message.created_at is not None def test_message_from_run_request_with_created_at_parses_correctly(self) -> None: """Test from_run_request correctly parses a valid created_at timestamp.""" diff --git a/python/packages/durabletask/tests/test_executors.py b/python/packages/durabletask/tests/test_executors.py new file mode 100644 index 0000000000..a42200bdea --- /dev/null +++ b/python/packages/durabletask/tests/test_executors.py @@ -0,0 +1,320 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for DurableAgentExecutor implementations. + +Focuses on critical behavioral flows for executor strategies. +Run with: pytest tests/test_executors.py -v +""" + +import time +from typing import Any +from unittest.mock import Mock + +import pytest +from agent_framework import AgentRunResponse, Role +from durabletask.entities import EntityInstanceId +from durabletask.task import Task +from pydantic import BaseModel + +from agent_framework_durabletask import DurableAgentThread +from agent_framework_durabletask._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS +from agent_framework_durabletask._executors import ( + ClientAgentExecutor, + DurableAgentTask, + OrchestrationAgentExecutor, +) +from agent_framework_durabletask._models import AgentSessionId, RunRequest + + +# Fixtures +@pytest.fixture +def mock_client() -> Mock: + """Provide a mock client for ClientAgentExecutor tests.""" + client = Mock() + client.signal_entity = Mock() + client.get_entity = Mock(return_value=None) + return client + + +@pytest.fixture +def mock_entity_task() -> Mock: + """Provide a mock entity task.""" + return Mock(spec=Task) + + +@pytest.fixture +def mock_orchestration_context(mock_entity_task: Mock) -> Mock: + """Provide a mock orchestration context with call_entity configured.""" + context = Mock() + context.call_entity = Mock(return_value=mock_entity_task) + return context + + +@pytest.fixture +def sample_run_request() -> RunRequest: + """Provide a sample RunRequest for tests.""" + return RunRequest(message="test message", correlation_id="test-123") + + +@pytest.fixture +def client_executor(mock_client: Mock) -> ClientAgentExecutor: + """Provide a ClientAgentExecutor with minimal polling for fast tests.""" + return ClientAgentExecutor(mock_client, max_poll_retries=1, poll_interval_seconds=0.01) + + +@pytest.fixture +def orchestration_executor(mock_orchestration_context: Mock) -> OrchestrationAgentExecutor: + """Provide an OrchestrationAgentExecutor.""" + return OrchestrationAgentExecutor(mock_orchestration_context) + + +@pytest.fixture +def successful_agent_response() -> dict[str, Any]: + """Provide a successful agent response dictionary.""" + return { + "messages": [{"role": "assistant", "contents": [{"type": "text", "text": "Hello!"}]}], + "created_at": "2025-12-30T10:00:00Z", + } + + +class TestExecutorThreadCreation: + """Test that executors properly create DurableAgentThread with parameters.""" + + def test_client_executor_creates_durable_thread(self, mock_client: Mock) -> None: + """Verify ClientAgentExecutor creates DurableAgentThread instances.""" + executor = ClientAgentExecutor(mock_client) + + thread = executor.get_new_thread("test_agent") + + assert isinstance(thread, DurableAgentThread) + + def test_client_executor_forwards_kwargs_to_thread(self, mock_client: Mock) -> None: + """Verify ClientAgentExecutor forwards kwargs to DurableAgentThread creation.""" + executor = ClientAgentExecutor(mock_client) + + thread = executor.get_new_thread("test_agent", service_thread_id="client-123") + + assert isinstance(thread, DurableAgentThread) + assert thread.service_thread_id == "client-123" + + def test_orchestration_executor_creates_durable_thread( + self, orchestration_executor: OrchestrationAgentExecutor + ) -> None: + """Verify OrchestrationAgentExecutor creates DurableAgentThread instances.""" + thread = orchestration_executor.get_new_thread("test_agent") + + assert isinstance(thread, DurableAgentThread) + + def test_orchestration_executor_forwards_kwargs_to_thread( + self, orchestration_executor: OrchestrationAgentExecutor + ) -> None: + """Verify OrchestrationAgentExecutor forwards kwargs to DurableAgentThread creation.""" + thread = orchestration_executor.get_new_thread("test_agent", service_thread_id="orch-456") + + assert isinstance(thread, DurableAgentThread) + assert thread.service_thread_id == "orch-456" + + +class TestClientAgentExecutorRun: + """Test that ClientAgentExecutor.run_durable_agent works as implemented.""" + + def test_client_executor_run_returns_response( + self, client_executor: ClientAgentExecutor, sample_run_request: RunRequest + ) -> None: + """Verify ClientAgentExecutor.run_durable_agent returns AgentRunResponse (synchronous).""" + result = client_executor.run_durable_agent("test_agent", sample_run_request) + + # Verify it returns an AgentRunResponse (synchronous, not a coroutine) + assert isinstance(result, AgentRunResponse) + assert result is not None + + +class TestClientAgentExecutorPollingConfiguration: + """Test polling configuration parameters for ClientAgentExecutor.""" + + def test_executor_uses_default_polling_parameters(self, mock_client: Mock) -> None: + """Verify executor initializes with default polling parameters.""" + executor = ClientAgentExecutor(mock_client) + + assert executor.max_poll_retries == DEFAULT_MAX_POLL_RETRIES + assert executor.poll_interval_seconds == DEFAULT_POLL_INTERVAL_SECONDS + + def test_executor_accepts_custom_polling_parameters(self, mock_client: Mock) -> None: + """Verify executor accepts and stores custom polling parameters.""" + executor = ClientAgentExecutor(mock_client, max_poll_retries=20, poll_interval_seconds=0.5) + + assert executor.max_poll_retries == 20 + assert executor.poll_interval_seconds == 0.5 + + def test_executor_respects_custom_max_poll_retries(self, mock_client: Mock, sample_run_request: RunRequest) -> None: + """Verify executor respects custom max_poll_retries during polling.""" + # Create executor with only 2 retries + executor = ClientAgentExecutor(mock_client, max_poll_retries=2, poll_interval_seconds=0.01) + + # Run the agent + result = executor.run_durable_agent("test_agent", sample_run_request) + + # Verify it returns AgentRunResponse (should timeout after 2 attempts) + assert isinstance(result, AgentRunResponse) + + # Verify get_entity was called 2 times (max_poll_retries) + assert mock_client.get_entity.call_count == 2 + + def test_executor_respects_custom_poll_interval(self, mock_client: Mock, sample_run_request: RunRequest) -> None: + """Verify executor respects custom poll_interval_seconds during polling.""" + # Create executor with very short interval + executor = ClientAgentExecutor(mock_client, max_poll_retries=3, poll_interval_seconds=0.01) + + # Measure time taken + start = time.time() + result = executor.run_durable_agent("test_agent", sample_run_request) + elapsed = time.time() - start + + # Should take roughly 3 * 0.01 = 0.03 seconds (plus overhead) + # Be generous with timing to avoid flakiness + assert elapsed < 0.2 # Should be quick with 0.01 interval + assert isinstance(result, AgentRunResponse) + + +class TestOrchestrationAgentExecutorRun: + """Test OrchestrationAgentExecutor.run_durable_agent implementation.""" + + def test_orchestration_executor_run_returns_durable_agent_task( + self, orchestration_executor: OrchestrationAgentExecutor, sample_run_request: RunRequest + ) -> None: + """Verify OrchestrationAgentExecutor.run_durable_agent returns DurableAgentTask.""" + result = orchestration_executor.run_durable_agent("test_agent", sample_run_request) + + assert isinstance(result, DurableAgentTask) + + def test_orchestration_executor_calls_entity_with_correct_parameters( + self, + mock_orchestration_context: Mock, + orchestration_executor: OrchestrationAgentExecutor, + sample_run_request: RunRequest, + ) -> None: + """Verify call_entity is invoked with correct entity ID and request.""" + orchestration_executor.run_durable_agent("test_agent", sample_run_request) + + # Verify call_entity was called once + assert mock_orchestration_context.call_entity.call_count == 1 + + # Get the call arguments + call_args = mock_orchestration_context.call_entity.call_args + entity_id_arg = call_args[0][0] + operation_arg = call_args[0][1] + request_dict_arg = call_args[0][2] + + # Verify entity ID + assert isinstance(entity_id_arg, EntityInstanceId) + assert entity_id_arg.entity == "dafx-test_agent" + + # Verify operation name + assert operation_arg == "run" + + # Verify request dict + assert request_dict_arg == sample_run_request.to_dict() + + def test_orchestration_executor_uses_thread_session_id( + self, + mock_orchestration_context: Mock, + orchestration_executor: OrchestrationAgentExecutor, + sample_run_request: RunRequest, + ) -> None: + """Verify executor uses thread's session ID when provided.""" + # Create thread with specific session ID + session_id = AgentSessionId(name="test_agent", key="specific-key-123") + thread = DurableAgentThread.from_session_id(session_id) + + result = orchestration_executor.run_durable_agent("test_agent", sample_run_request, thread=thread) + + # Verify call_entity was called with the specific key + call_args = mock_orchestration_context.call_entity.call_args + entity_id_arg = call_args[0][0] + + assert entity_id_arg.key == "specific-key-123" + assert isinstance(result, DurableAgentTask) + + +class TestDurableAgentTask: + """Test DurableAgentTask completion and response transformation.""" + + def test_durable_agent_task_transforms_successful_result( + self, mock_entity_task: Mock, successful_agent_response: dict[str, Any] + ) -> None: + """Verify DurableAgentTask converts successful entity result to AgentRunResponse.""" + mock_entity_task.is_failed = False + mock_entity_task.get_result = Mock(return_value=successful_agent_response) + + task = DurableAgentTask(entity_task=mock_entity_task, response_format=None, correlation_id="test-123") + + # Simulate child task completion + task.on_child_completed(mock_entity_task) + + assert task.is_complete + result = task.get_result() + assert isinstance(result, AgentRunResponse) + assert len(result.messages) == 1 + assert result.messages[0].role == Role.ASSISTANT + + def test_durable_agent_task_propagates_failure(self, mock_entity_task: Mock) -> None: + """Verify DurableAgentTask propagates task failures.""" + mock_entity_task.is_failed = True + mock_entity_task.get_exception = Mock(return_value=ValueError("Entity error")) + + task = DurableAgentTask(entity_task=mock_entity_task, response_format=None, correlation_id="test-123") + + # Simulate child task completion with failure + task.on_child_completed(mock_entity_task) + + assert task.is_complete + assert task.is_failed + exception = task.get_exception() + assert isinstance(exception, ValueError) + assert str(exception) == "Entity error" + + def test_durable_agent_task_validates_response_format(self, mock_entity_task: Mock) -> None: + """Verify DurableAgentTask validates response format when provided.""" + mock_entity_task.is_failed = False + mock_entity_task.get_result = Mock( + return_value={ + "messages": [{"role": "assistant", "contents": [{"type": "text", "text": '{"answer": "42"}'}]}], + "created_at": "2025-12-30T10:00:00Z", + } + ) + + class TestResponse(BaseModel): + answer: str + + task = DurableAgentTask(entity_task=mock_entity_task, response_format=TestResponse, correlation_id="test-123") + + # Simulate child task completion + task.on_child_completed(mock_entity_task) + + assert task.is_complete + result = task.get_result() + assert isinstance(result, AgentRunResponse) + + def test_durable_agent_task_ignores_duplicate_completion( + self, mock_entity_task: Mock, successful_agent_response: dict[str, Any] + ) -> None: + """Verify DurableAgentTask ignores duplicate completion calls.""" + mock_entity_task.is_failed = False + mock_entity_task.get_result = Mock(return_value=successful_agent_response) + + task = DurableAgentTask(entity_task=mock_entity_task, response_format=None, correlation_id="test-123") + + # Simulate child task completion twice + task.on_child_completed(mock_entity_task) + first_result = task.get_result() + + task.on_child_completed(mock_entity_task) + second_result = task.get_result() + + # Should be the same result, get_result should only be called once + assert first_result is second_result + assert mock_entity_task.get_result.call_count == 1 + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_models.py b/python/packages/durabletask/tests/test_models.py index f14bffcaf8..ffcfe868e1 100644 --- a/python/packages/durabletask/tests/test_models.py +++ b/python/packages/durabletask/tests/test_models.py @@ -18,9 +18,10 @@ class TestRunRequest: def test_init_with_defaults(self) -> None: """Test RunRequest initialization with defaults.""" - request = RunRequest(message="Hello") + request = RunRequest(message="Hello", correlation_id="corr-001") assert request.message == "Hello" + assert request.correlation_id == "corr-001" assert request.role == Role.USER assert request.response_format is None assert request.enable_tool_calls is True @@ -30,30 +31,33 @@ def test_init_with_all_fields(self) -> None: schema = ModuleStructuredResponse request = RunRequest( message="Hello", + correlation_id="corr-002", role=Role.SYSTEM, response_format=schema, enable_tool_calls=False, ) assert request.message == "Hello" + assert request.correlation_id == "corr-002" assert request.role == Role.SYSTEM assert request.response_format is schema assert request.enable_tool_calls is False def test_init_coerces_string_role(self) -> None: """Ensure string role values are coerced into Role instances.""" - request = RunRequest(message="Hello", role="system") # type: ignore[arg-type] + request = RunRequest(message="Hello", correlation_id="corr-003", role="system") # type: ignore[arg-type] assert request.role == Role.SYSTEM def test_to_dict_with_defaults(self) -> None: """Test to_dict with default values.""" - request = RunRequest(message="Test message") + request = RunRequest(message="Test message", correlation_id="corr-004") data = request.to_dict() assert data["message"] == "Test message" assert data["enable_tool_calls"] is True assert data["role"] == "user" + assert data["correlationId"] == "corr-004" assert "response_format" not in data or data["response_format"] is None assert "thread_id" not in data @@ -62,6 +66,7 @@ def test_to_dict_with_all_fields(self) -> None: schema = ModuleStructuredResponse request = RunRequest( message="Hello", + correlation_id="corr-005", role=Role.ASSISTANT, response_format=schema, enable_tool_calls=False, @@ -69,6 +74,7 @@ def test_to_dict_with_all_fields(self) -> None: data = request.to_dict() assert data["message"] == "Hello" + assert data["correlationId"] == "corr-005" assert data["role"] == "assistant" assert data["response_format"]["__response_schema_type__"] == "pydantic_model" assert data["response_format"]["module"] == schema.__module__ @@ -78,16 +84,17 @@ def test_to_dict_with_all_fields(self) -> None: def test_from_dict_with_defaults(self) -> None: """Test from_dict with minimal data.""" - data = {"message": "Hello"} + data = {"message": "Hello", "correlationId": "corr-006"} request = RunRequest.from_dict(data) assert request.message == "Hello" + assert request.correlation_id == "corr-006" assert request.role == Role.USER assert request.enable_tool_calls is True def test_from_dict_ignores_thread_id_field(self) -> None: """Ensure legacy thread_id input does not break RunRequest parsing.""" - request = RunRequest.from_dict({"message": "Hello", "thread_id": "ignored"}) + request = RunRequest.from_dict({"message": "Hello", "correlationId": "corr-007", "thread_id": "ignored"}) assert request.message == "Hello" @@ -95,6 +102,7 @@ def test_from_dict_with_all_fields(self) -> None: """Test from_dict with all fields.""" data = { "message": "Test", + "correlationId": "corr-008", "role": "system", "response_format": { "__response_schema_type__": "pydantic_model", @@ -106,13 +114,14 @@ def test_from_dict_with_all_fields(self) -> None: request = RunRequest.from_dict(data) assert request.message == "Test" + assert request.correlation_id == "corr-008" assert request.role == Role.SYSTEM assert request.response_format is ModuleStructuredResponse assert request.enable_tool_calls is False - def test_from_dict_with_unknown_role_preserves_value(self) -> None: + def test_from_dict_unknown_role_preserves_value(self) -> None: """Test from_dict keeps custom roles intact.""" - data = {"message": "Test", "role": "reviewer"} + data = {"message": "Test", "correlationId": "corr-009", "role": "reviewer"} request = RunRequest.from_dict(data) assert request.role.value == "reviewer" @@ -120,15 +129,22 @@ def test_from_dict_with_unknown_role_preserves_value(self) -> None: def test_from_dict_empty_message(self) -> None: """Test from_dict with empty message.""" - request = RunRequest.from_dict({}) + request = RunRequest.from_dict({"correlationId": "corr-010"}) assert request.message == "" + assert request.correlation_id == "corr-010" assert request.role == Role.USER + def test_from_dict_missing_correlation_id_raises(self) -> None: + """Test from_dict raises when correlationId is missing.""" + with pytest.raises(ValueError, match="correlationId is required"): + RunRequest.from_dict({"message": "Test"}) + def test_round_trip_dict_conversion(self) -> None: """Test round-trip to_dict and from_dict.""" original = RunRequest( message="Test message", + correlation_id="corr-011", role=Role.SYSTEM, response_format=ModuleStructuredResponse, enable_tool_calls=False, @@ -138,6 +154,7 @@ def test_round_trip_dict_conversion(self) -> None: restored = RunRequest.from_dict(data) assert restored.message == original.message + assert restored.correlation_id == original.correlation_id assert restored.role == original.role assert restored.response_format is ModuleStructuredResponse assert restored.enable_tool_calls == original.enable_tool_calls @@ -146,6 +163,7 @@ def test_round_trip_with_pydantic_response_format(self) -> None: """Ensure Pydantic response formats serialize and deserialize properly.""" original = RunRequest( message="Structured", + correlation_id="corr-012", response_format=ModuleStructuredResponse, ) @@ -186,7 +204,7 @@ def test_round_trip_with_correlationId(self) -> None: original = RunRequest( message="Test message", role=Role.SYSTEM, - correlation_id="corr-123", + correlation_id="corr-124", ) data = original.to_dict() @@ -200,6 +218,7 @@ def test_init_with_orchestration_id(self) -> None: """Test RunRequest initialization with orchestration_id.""" request = RunRequest( message="Test message", + correlation_id="corr-125", orchestration_id="orch-123", ) @@ -210,6 +229,7 @@ def test_to_dict_with_orchestration_id(self) -> None: """Test to_dict includes orchestrationId.""" request = RunRequest( message="Test", + correlation_id="corr-126", orchestration_id="orch-456", ) data = request.to_dict() @@ -221,6 +241,7 @@ def test_to_dict_excludes_orchestration_id_when_none(self) -> None: """Test to_dict excludes orchestrationId when not set.""" request = RunRequest( message="Test", + correlation_id="corr-127", ) data = request.to_dict() @@ -230,6 +251,7 @@ def test_from_dict_with_orchestration_id(self) -> None: """Test from_dict with orchestrationId.""" data = { "message": "Test", + "correlationId": "corr-128", "orchestrationId": "orch-789", } request = RunRequest.from_dict(data) @@ -242,7 +264,7 @@ def test_round_trip_with_orchestration_id(self) -> None: original = RunRequest( message="Test message", role=Role.SYSTEM, - correlation_id="corr-123", + correlation_id="corr-129", orchestration_id="orch-123", ) diff --git a/python/packages/durabletask/tests/test_orchestration_context.py b/python/packages/durabletask/tests/test_orchestration_context.py new file mode 100644 index 0000000000..f6a7755335 --- /dev/null +++ b/python/packages/durabletask/tests/test_orchestration_context.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for DurableAIAgentOrchestrationContext. + +Focuses on critical orchestration workflows: agent retrieval and integration. +Run with: pytest tests/test_orchestration_context.py -v +""" + +from unittest.mock import Mock + +import pytest +from agent_framework import AgentProtocol + +from agent_framework_durabletask import DurableAgentThread +from agent_framework_durabletask._orchestration_context import DurableAIAgentOrchestrationContext +from agent_framework_durabletask._shim import DurableAIAgent + + +@pytest.fixture +def mock_orchestration_context() -> Mock: + """Create a mock OrchestrationContext for testing.""" + return Mock() + + +@pytest.fixture +def agent_context(mock_orchestration_context: Mock) -> DurableAIAgentOrchestrationContext: + """Create a DurableAIAgentOrchestrationContext with mock context.""" + return DurableAIAgentOrchestrationContext(mock_orchestration_context) + + +class TestDurableAIAgentOrchestrationContextGetAgent: + """Test core workflow: retrieving agents from orchestration context.""" + + def test_get_agent_returns_durable_agent_shim(self, agent_context: DurableAIAgentOrchestrationContext) -> None: + """Verify get_agent returns a DurableAIAgent instance.""" + agent = agent_context.get_agent("assistant") + + assert isinstance(agent, DurableAIAgent) + assert isinstance(agent, AgentProtocol) + + def test_get_agent_shim_has_correct_name(self, agent_context: DurableAIAgentOrchestrationContext) -> None: + """Verify retrieved agent has the correct name.""" + agent = agent_context.get_agent("my_agent") + + assert agent.name == "my_agent" + + def test_get_agent_multiple_times_returns_new_instances( + self, agent_context: DurableAIAgentOrchestrationContext + ) -> None: + """Verify multiple get_agent calls return independent instances.""" + agent1 = agent_context.get_agent("assistant") + agent2 = agent_context.get_agent("assistant") + + assert agent1 is not agent2 # Different object instances + + def test_get_agent_different_agents(self, agent_context: DurableAIAgentOrchestrationContext) -> None: + """Verify context can retrieve multiple different agents.""" + agent1 = agent_context.get_agent("agent1") + agent2 = agent_context.get_agent("agent2") + + assert agent1.name == "agent1" + assert agent2.name == "agent2" + + +class TestDurableAIAgentOrchestrationContextIntegration: + """Test integration scenarios between orchestration context and agent shim.""" + + def test_orchestration_agent_has_working_run_method( + self, agent_context: DurableAIAgentOrchestrationContext + ) -> None: + """Verify agent from context has callable run method (even if not yet implemented).""" + agent = agent_context.get_agent("assistant") + + assert hasattr(agent, "run") + assert callable(agent.run) + + def test_orchestration_agent_can_create_threads(self, agent_context: DurableAIAgentOrchestrationContext) -> None: + """Verify agent from context can create DurableAgentThread instances.""" + agent = agent_context.get_agent("assistant") + + thread = agent.get_new_thread() + + assert isinstance(thread, DurableAgentThread) + + def test_orchestration_agent_thread_with_parameters( + self, agent_context: DurableAIAgentOrchestrationContext + ) -> None: + """Verify agent can create threads with custom parameters.""" + agent = agent_context.get_agent("assistant") + + thread = agent.get_new_thread(service_thread_id="orch-session-456") + + assert isinstance(thread, DurableAgentThread) + assert thread.service_thread_id == "orch-session-456" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_shim.py b/python/packages/durabletask/tests/test_shim.py new file mode 100644 index 0000000000..1fa348695c --- /dev/null +++ b/python/packages/durabletask/tests/test_shim.py @@ -0,0 +1,206 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for DurableAIAgent shim and DurableAgentProvider. + +Focuses on critical message normalization, delegation, and protocol compliance. +Run with: pytest tests/test_shim.py -v +""" + +from typing import Any +from unittest.mock import Mock + +import pytest +from agent_framework import AgentProtocol, ChatMessage +from pydantic import BaseModel + +from agent_framework_durabletask import DurableAgentThread +from agent_framework_durabletask._executors import DurableAgentExecutor +from agent_framework_durabletask._models import RunRequest +from agent_framework_durabletask._shim import DurableAgentProvider, DurableAIAgent + + +class ResponseFormatModel(BaseModel): + """Test Pydantic model for response format testing.""" + + result: str + + +@pytest.fixture +def mock_executor() -> Mock: + """Create a mock executor for testing.""" + mock = Mock(spec=DurableAgentExecutor) + mock.run_durable_agent = Mock(return_value=None) + mock.get_new_thread = Mock(return_value=DurableAgentThread()) + + # Mock get_run_request to create actual RunRequest objects + def create_run_request( + message: str, response_format: type[BaseModel] | None = None, enable_tool_calls: bool = True + ) -> RunRequest: + import uuid + + return RunRequest( + message=message, + correlation_id=str(uuid.uuid4()), + response_format=response_format, + enable_tool_calls=enable_tool_calls, + ) + + mock.get_run_request = Mock(side_effect=create_run_request) + return mock + + +@pytest.fixture +def test_agent(mock_executor: Mock) -> DurableAIAgent[Any]: + """Create a test agent with mock executor.""" + return DurableAIAgent(mock_executor, "test_agent") + + +class TestDurableAIAgentMessageNormalization: + """Test that DurableAIAgent properly normalizes various message input types.""" + + def test_run_accepts_string_message(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify run accepts and normalizes string messages.""" + test_agent.run("Hello, world!") + + mock_executor.run_durable_agent.assert_called_once() + # Verify agent_name and run_request were passed correctly as kwargs + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["agent_name"] == "test_agent" + assert kwargs["run_request"].message == "Hello, world!" + + def test_run_accepts_chat_message(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify run accepts and normalizes ChatMessage objects.""" + chat_msg = ChatMessage(role="user", text="Test message") + test_agent.run(chat_msg) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["run_request"].message == "Test message" + + def test_run_accepts_list_of_strings(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify run accepts and joins list of strings.""" + test_agent.run(["First message", "Second message"]) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["run_request"].message == "First message\nSecond message" + + def test_run_accepts_list_of_chat_messages(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify run accepts and joins list of ChatMessage objects.""" + messages = [ + ChatMessage(role="user", text="Message 1"), + ChatMessage(role="assistant", text="Message 2"), + ] + test_agent.run(messages) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["run_request"].message == "Message 1\nMessage 2" + + def test_run_handles_none_message(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify run handles None message gracefully.""" + test_agent.run(None) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["run_request"].message == "" + + def test_run_handles_empty_list(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify run handles empty list gracefully.""" + test_agent.run([]) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["run_request"].message == "" + + +class TestDurableAIAgentParameterFlow: + """Test that parameters flow correctly through the shim to executor.""" + + def test_run_forwards_thread_parameter(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify run forwards thread parameter to executor.""" + thread = DurableAgentThread(service_thread_id="test-thread") + test_agent.run("message", thread=thread) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["thread"] == thread + + def test_run_forwards_response_format(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify run forwards response_format parameter to executor.""" + test_agent.run("message", response_format=ResponseFormatModel) + + mock_executor.run_durable_agent.assert_called_once() + _, kwargs = mock_executor.run_durable_agent.call_args + assert kwargs["run_request"].response_format == ResponseFormatModel + + +class TestDurableAIAgentProtocolCompliance: + """Test that DurableAIAgent implements AgentProtocol correctly.""" + + def test_agent_implements_protocol(self, test_agent: DurableAIAgent[Any]) -> None: + """Verify DurableAIAgent implements AgentProtocol.""" + assert isinstance(test_agent, AgentProtocol) + + def test_agent_has_required_properties(self, test_agent: DurableAIAgent[Any]) -> None: + """Verify DurableAIAgent has all required AgentProtocol properties.""" + assert hasattr(test_agent, "id") + assert hasattr(test_agent, "name") + assert hasattr(test_agent, "display_name") + assert hasattr(test_agent, "description") + + def test_agent_id_defaults_to_name(self, mock_executor: Mock) -> None: + """Verify agent id defaults to name when not provided.""" + agent: DurableAIAgent[Any] = DurableAIAgent(mock_executor, "my_agent") + + assert agent.id == "my_agent" + assert agent.name == "my_agent" + + def test_agent_id_can_be_customized(self, mock_executor: Mock) -> None: + """Verify agent id can be set independently from name.""" + agent: DurableAIAgent[Any] = DurableAIAgent(mock_executor, "my_agent", agent_id="custom-id") + + assert agent.id == "custom-id" + assert agent.name == "my_agent" + + +class TestDurableAIAgentThreadManagement: + """Test thread creation and management.""" + + def test_get_new_thread_delegates_to_executor(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify get_new_thread delegates to executor.""" + mock_thread = DurableAgentThread() + mock_executor.get_new_thread.return_value = mock_thread + + thread = test_agent.get_new_thread() + + mock_executor.get_new_thread.assert_called_once_with("test_agent") + assert thread == mock_thread + + def test_get_new_thread_forwards_kwargs(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify get_new_thread forwards kwargs to executor.""" + mock_thread = DurableAgentThread(service_thread_id="thread-123") + mock_executor.get_new_thread.return_value = mock_thread + + test_agent.get_new_thread(service_thread_id="thread-123") + + mock_executor.get_new_thread.assert_called_once() + _, kwargs = mock_executor.get_new_thread.call_args + assert kwargs["service_thread_id"] == "thread-123" + + +class TestDurableAgentProviderInterface: + """Test that DurableAgentProvider defines the correct interface.""" + + def test_provider_cannot_be_instantiated(self) -> None: + """Verify DurableAgentProvider is abstract and cannot be instantiated.""" + with pytest.raises(TypeError): + DurableAgentProvider() # type: ignore[abstract] + + def test_provider_defines_get_agent_method(self) -> None: + """Verify DurableAgentProvider defines get_agent abstract method.""" + assert hasattr(DurableAgentProvider, "get_agent") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_worker.py b/python/packages/durabletask/tests/test_worker.py new file mode 100644 index 0000000000..e6dabcdfdf --- /dev/null +++ b/python/packages/durabletask/tests/test_worker.py @@ -0,0 +1,168 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for DurableAIAgentWorker. + +Focuses on critical worker flows: agent registration, validation, callbacks, and lifecycle. +""" + +from unittest.mock import Mock + +import pytest + +from agent_framework_durabletask import DurableAIAgentWorker + + +@pytest.fixture +def mock_grpc_worker() -> Mock: + """Create a mock TaskHubGrpcWorker for testing.""" + mock = Mock() + mock.add_entity = Mock(return_value="dafx-test_agent") + mock.start = Mock() + mock.stop = Mock() + return mock + + +@pytest.fixture +def mock_agent() -> Mock: + """Create a mock agent for testing.""" + agent = Mock() + agent.name = "test_agent" + return agent + + +@pytest.fixture +def agent_worker(mock_grpc_worker: Mock) -> DurableAIAgentWorker: + """Create a DurableAIAgentWorker with mock worker.""" + return DurableAIAgentWorker(mock_grpc_worker) + + +class TestDurableAIAgentWorkerRegistration: + """Test agent registration behavior.""" + + def test_add_agent_accepts_agent_with_name( + self, agent_worker: DurableAIAgentWorker, mock_agent: Mock, mock_grpc_worker: Mock + ) -> None: + """Verify that agents with names can be registered.""" + agent_worker.add_agent(mock_agent) + + # Verify entity was registered with underlying worker + mock_grpc_worker.add_entity.assert_called_once() + # Verify agent name is tracked + assert "test_agent" in agent_worker.registered_agent_names + + def test_add_agent_rejects_agent_without_name(self, agent_worker: DurableAIAgentWorker) -> None: + """Verify that agents without names are rejected.""" + agent_no_name = Mock() + agent_no_name.name = None + + with pytest.raises(ValueError, match="Agent must have a name"): + agent_worker.add_agent(agent_no_name) + + def test_add_agent_rejects_empty_name(self, agent_worker: DurableAIAgentWorker) -> None: + """Verify that agents with empty names are rejected.""" + agent_empty_name = Mock() + agent_empty_name.name = "" + + with pytest.raises(ValueError, match="Agent must have a name"): + agent_worker.add_agent(agent_empty_name) + + def test_add_agent_rejects_duplicate_names(self, agent_worker: DurableAIAgentWorker, mock_agent: Mock) -> None: + """Verify duplicate agent names are not allowed.""" + agent_worker.add_agent(mock_agent) + + # Try to register another agent with the same name + duplicate_agent = Mock() + duplicate_agent.name = "test_agent" + + with pytest.raises(ValueError, match="already registered"): + agent_worker.add_agent(duplicate_agent) + + def test_registered_agent_names_tracks_multiple_agents(self, agent_worker: DurableAIAgentWorker) -> None: + """Verify registered_agent_names tracks all registered agents.""" + agent1 = Mock() + agent1.name = "agent1" + agent2 = Mock() + agent2.name = "agent2" + agent3 = Mock() + agent3.name = "agent3" + + agent_worker.add_agent(agent1) + agent_worker.add_agent(agent2) + agent_worker.add_agent(agent3) + + registered = agent_worker.registered_agent_names + assert "agent1" in registered + assert "agent2" in registered + assert "agent3" in registered + assert len(registered) == 3 + + +class TestDurableAIAgentWorkerCallbacks: + """Test callback configuration behavior.""" + + def test_worker_level_callback_accepted(self, mock_grpc_worker: Mock) -> None: + """Verify worker-level callback can be set.""" + mock_callback = Mock() + agent_worker = DurableAIAgentWorker(mock_grpc_worker, callback=mock_callback) + + assert agent_worker is not None + + def test_agent_level_callback_accepted(self, agent_worker: DurableAIAgentWorker, mock_agent: Mock) -> None: + """Verify agent-level callback can be set during registration.""" + mock_callback = Mock() + + # Should not raise exception + agent_worker.add_agent(mock_agent, callback=mock_callback) + + assert "test_agent" in agent_worker.registered_agent_names + + def test_none_callback_accepted(self, mock_grpc_worker: Mock, mock_agent: Mock) -> None: + """Verify None callback is valid (no callbacks required).""" + agent_worker = DurableAIAgentWorker(mock_grpc_worker, callback=None) + agent_worker.add_agent(mock_agent, callback=None) + + assert "test_agent" in agent_worker.registered_agent_names + + +class TestDurableAIAgentWorkerLifecycle: + """Test worker lifecycle behavior.""" + + def test_start_delegates_to_underlying_worker( + self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock + ) -> None: + """Verify start() delegates to wrapped worker.""" + agent_worker.start() + + mock_grpc_worker.start.assert_called_once() + + def test_stop_delegates_to_underlying_worker( + self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock + ) -> None: + """Verify stop() delegates to wrapped worker.""" + agent_worker.stop() + + mock_grpc_worker.stop.assert_called_once() + + def test_start_works_with_no_agents(self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock) -> None: + """Verify worker can start even with no agents registered.""" + agent_worker.start() + + mock_grpc_worker.start.assert_called_once() + + def test_start_works_with_multiple_agents(self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock) -> None: + """Verify worker can start with multiple agents registered.""" + agent1 = Mock() + agent1.name = "agent1" + agent2 = Mock() + agent2.name = "agent2" + + agent_worker.add_agent(agent1) + agent_worker.add_agent(agent2) + agent_worker.start() + + mock_grpc_worker.start.assert_called_once() + assert len(agent_worker.registered_agent_names) == 2 + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/samples/getting_started/azure_functions/04_single_agent_orchestration_chaining/function_app.py b/python/samples/getting_started/azure_functions/04_single_agent_orchestration_chaining/function_app.py index cc05a323f3..32b2fab131 100644 --- a/python/samples/getting_started/azure_functions/04_single_agent_orchestration_chaining/function_app.py +++ b/python/samples/getting_started/azure_functions/04_single_agent_orchestration_chaining/function_app.py @@ -10,6 +10,7 @@ import json import logging +from collections.abc import Generator from typing import Any import azure.functions as func @@ -44,7 +45,7 @@ def _create_writer_agent() -> Any: # 4. Orchestration that runs the agent sequentially on a shared thread for chaining behaviour. @app.orchestration_trigger(context_name="context") -def single_agent_orchestration(context: DurableOrchestrationContext): +def single_agent_orchestration(context: DurableOrchestrationContext) -> Generator[Any, Any, str]: """Run the writer agent twice on the same thread to mirror chaining behaviour.""" writer = app.get_agent(context, WRITER_AGENT_NAME) @@ -116,12 +117,6 @@ async def get_orchestration_status( ) status = await client.get_status(instance_id) - if status is None: - return func.HttpResponse( - body=json.dumps({"error": "Instance not found"}), - status_code=404, - mimetype="application/json", - ) response_data: dict[str, Any] = { "instanceId": status.instance_id, diff --git a/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py b/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py index 69ea8816b2..41fb3f08b2 100644 --- a/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py +++ b/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py @@ -10,6 +10,7 @@ import json import logging +from collections.abc import Generator from typing import Any, cast from agent_framework import AgentRunResponse @@ -51,7 +52,7 @@ def _create_agents() -> list[Any]: # 4. Durable Functions orchestration that runs both agents in parallel. @app.orchestration_trigger(context_name="context") -def multi_agent_concurrent_orchestration(context: DurableOrchestrationContext): +def multi_agent_concurrent_orchestration(context: DurableOrchestrationContext) -> Generator[Any, Any, dict[str, str]]: """Fan out to two domain-specific agents and aggregate their responses.""" prompt = context.get_input() @@ -137,12 +138,6 @@ async def get_orchestration_status( ) status = await client.get_status(instance_id) - if status is None: - return func.HttpResponse( - body=json.dumps({"error": "Instance not found"}), - status_code=404, - mimetype="application/json", - ) response_data: dict[str, Any] = { "instanceId": status.instance_id, diff --git a/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py b/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py index 2ee445d423..8ef5ef7211 100644 --- a/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py +++ b/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py @@ -11,7 +11,7 @@ import json import logging -from collections.abc import Mapping +from collections.abc import Generator, Mapping from typing import Any, cast import azure.functions as func @@ -74,7 +74,7 @@ def send_email(message: str) -> str: # 4. Orchestration validates input, runs agents, and branches on spam results. @app.orchestration_trigger(context_name="context") -def spam_detection_orchestration(context: DurableOrchestrationContext): +def spam_detection_orchestration(context: DurableOrchestrationContext) -> Generator[Any, Any, str]: payload_raw = context.get_input() if not isinstance(payload_raw, Mapping): raise ValueError("Email data is required") @@ -105,7 +105,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext): spam_result = cast(SpamDetectionResult, spam_result_raw.value) if spam_result.is_spam: - result = yield context.call_activity("handle_spam_email", spam_result.reason) + result = yield context.call_activity("handle_spam_email", spam_result.reason) # type: ignore[misc] return result email_thread = email_agent.get_new_thread() @@ -125,7 +125,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext): email_result = cast(EmailResponse, email_result_raw.value) - result = yield context.call_activity("send_email", email_result.response) + result = yield context.call_activity("send_email", email_result.response) # type: ignore[misc] return result @@ -196,12 +196,6 @@ async def get_orchestration_status( ) status = await client.get_status(instance_id) - if status is None: - return func.HttpResponse( - body=json.dumps({"error": "Instance not found"}), - status_code=404, - mimetype="application/json", - ) response_data: dict[str, Any] = { "instanceId": status.instance_id, diff --git a/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py b/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py index c8e2bbaa9c..fc72ceb770 100644 --- a/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py +++ b/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py @@ -10,7 +10,7 @@ import json import logging -from collections.abc import Mapping +from collections.abc import Generator, Mapping from datetime import timedelta from typing import Any @@ -62,7 +62,7 @@ def _create_writer_agent() -> Any: # 3. Activities encapsulate external work for review notifications and publishing. @app.activity_trigger(input_name="content") -def notify_user_for_approval(content: dict) -> None: +def notify_user_for_approval(content: Any) -> None: model = GeneratedContent.model_validate(content) logger.info("NOTIFICATION: Please review the following content for approval:") logger.info("Title: %s", model.title or "(untitled)") @@ -71,7 +71,7 @@ def notify_user_for_approval(content: dict) -> None: @app.activity_trigger(input_name="content") -def publish_content(content: dict) -> None: +def publish_content(content: Any) -> None: model = GeneratedContent.model_validate(content) logger.info("PUBLISHING: Content has been published successfully:") logger.info("Title: %s", model.title or "(untitled)") @@ -80,7 +80,7 @@ def publish_content(content: dict) -> None: # 4. Orchestration loops until the human approves, times out, or attempts are exhausted. @app.orchestration_trigger(context_name="context") -def content_generation_hitl_orchestration(context: DurableOrchestrationContext): +def content_generation_hitl_orchestration(context: DurableOrchestrationContext) -> Generator[Any, Any, dict[str, str]]: payload_raw = context.get_input() if not isinstance(payload_raw, Mapping): raise ValueError("Content generation input is required") @@ -102,7 +102,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext): ) content = initial_raw.value - logger.info("Type of content after extraction: %s", type(content)) + logger.info("Type of content after extraction: %s", type(content)) # type: ignore[misc] if content is None or not isinstance(content, GeneratedContent): raise ValueError("Agent returned no content after extraction.") @@ -114,7 +114,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext): f"Requesting human feedback. Iteration #{attempt}. Timeout: {payload.approval_timeout_hours} hour(s)." ) - yield context.call_activity("notify_user_for_approval", content.model_dump()) + yield context.call_activity("notify_user_for_approval", content.model_dump()) # type: ignore[misc] approval_task = context.wait_for_external_event(HUMAN_APPROVAL_EVENT) timeout_task = context.create_timer( @@ -129,7 +129,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext): if approval_payload.approved: context.set_custom_status("Content approved by human reviewer. Publishing content...") - yield context.call_activity("publish_content", content.model_dump()) + yield context.call_activity("publish_content", content.model_dump()) # type: ignore[misc] context.set_custom_status( f"Content published successfully at {context.current_utc_datetime:%Y-%m-%dT%H:%M:%S}" ) @@ -287,7 +287,7 @@ async def get_orchestration_status( ) # Check if status is None or if the instance doesn't exist (runtime_status is None) - if status is None or getattr(status, "runtime_status", None) is None: + if getattr(status, "runtime_status", None) is None: return func.HttpResponse( body=json.dumps({"error": "Instance not found."}), status_code=404, diff --git a/python/samples/getting_started/durabletask/01_single_agent/README.md b/python/samples/getting_started/durabletask/01_single_agent/README.md new file mode 100644 index 0000000000..6e31b0a737 --- /dev/null +++ b/python/samples/getting_started/durabletask/01_single_agent/README.md @@ -0,0 +1,66 @@ +# Single Agent Sample + +This sample demonstrates how to use the durable agents extension to create a worker-client setup that hosts a single AI agent and provides interactive conversation via the Durable Task Scheduler. + +## Key Concepts Demonstrated + +- Using the Microsoft Agent Framework to define a simple AI agent with a name and instructions. +- Registering durable agents with the worker and interacting with them via a client. +- Conversation management (via threads) for isolated interactions. +- Worker-client architecture for distributed agent execution. + +## Environment Setup + +See the [README.md](../README.md) file in the parent directory for more information on how to configure the environment, including how to install and run common sample dependencies. + +## Running the Sample + +With the environment setup, you can run the sample using separate worker and client processes: + +**Start the worker:** + +```bash +cd samples/getting_started/durabletask/01_single_agent +python worker.py +``` + +The worker will register the Joker agent and listen for requests. + +**In a new terminal, run the client:** + +```bash +python client.py +``` + +The client will interact with the Joker agent: + +``` +Starting Durable Task Agent Client... +Using taskhub: default +Using endpoint: http://localhost:8080 + +Getting reference to Joker agent... +Created conversation thread: a1b2c3d4-e5f6-7890-abcd-ef1234567890 + +User: Tell me a short joke about cloud computing. + +Joker: Why did the cloud break up with the server? +Because it found someone more "uplifting"! + +User: Now tell me one about Python programming. + +Joker: Why do Python programmers prefer dark mode? +Because light attracts bugs! +``` + +## Viewing Agent State + +You can view the state of the agent in the Durable Task Scheduler dashboard: + +1. Open your browser and navigate to `http://localhost:8082` +2. In the dashboard, you can view the state of the Joker agent, including its conversation history and current state + +The agent maintains conversation state across multiple interactions, and you can inspect this state in the dashboard to understand how the durable agents extension manages conversation context. + + + diff --git a/python/samples/getting_started/durabletask/01_single_agent/client.py b/python/samples/getting_started/durabletask/01_single_agent/client.py new file mode 100644 index 0000000000..bfd5147ea7 --- /dev/null +++ b/python/samples/getting_started/durabletask/01_single_agent/client.py @@ -0,0 +1,92 @@ +"""Client application for interacting with a Durable Task hosted agent. + +This client connects to the Durable Task Scheduler and sends requests to +registered agents, demonstrating how to interact with agents from external processes. + +Prerequisites: +- The worker must be running with the agent registered +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Durable Task Scheduler must be running +""" + +import asyncio +import logging +import os + +from agent_framework_durabletask import DurableAIAgentClient +from azure.identity import DefaultAzureCredential +from durabletask.azuremanaged.client import DurableTaskSchedulerClient + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def main() -> None: + """Main entry point for the client application.""" + logger.info("Starting Durable Task Agent Client...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + logger.info("") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + + # Create a client using Azure Managed Durable Task + client = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + # Wrap it with the agent client + agent_client = DurableAIAgentClient(client) + + # Get a reference to the Joker agent + logger.info("Getting reference to Joker agent...") + joker = agent_client.get_agent("Joker") + + # Create a new thread for the conversation + thread = joker.get_new_thread() + + logger.info(f"Created conversation thread: {thread.session_id}") + logger.info("") + + try: + # First message + message1 = "Tell me a short joke about cloud computing." + logger.info(f"User: {message1}") + logger.info("") + + # Run the agent - this blocks until the response is ready + response1 = joker.run(message1, thread=thread) + logger.info(f"Agent: {response1.text}") + logger.info("") + + # Second message - continuing the conversation + message2 = "Now tell me one about Python programming." + logger.info(f"User: {message2}") + logger.info("") + + response2 = joker.run(message2, thread=thread) + logger.info(f"Agent: {response2.text}") + logger.info("") + + logger.info(f"Conversation completed successfully!") + logger.info(f"Thread ID: {thread.session_id}") + + except Exception as e: + logger.exception(f"Error during agent interaction: {e}") + finally: + logger.info("Client shutting down") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/durabletask/01_single_agent/requirements.txt b/python/samples/getting_started/durabletask/01_single_agent/requirements.txt new file mode 100644 index 0000000000..371b9e3b79 --- /dev/null +++ b/python/samples/getting_started/durabletask/01_single_agent/requirements.txt @@ -0,0 +1,6 @@ +# Agent Framework packages (installing from local package until a package is published) +-e ../../../../ +-e ../../../../packages/durabletask + +# Azure authentication +azure-identity diff --git a/python/samples/getting_started/durabletask/01_single_agent/sample.py b/python/samples/getting_started/durabletask/01_single_agent/sample.py new file mode 100644 index 0000000000..cfbceaaebd --- /dev/null +++ b/python/samples/getting_started/durabletask/01_single_agent/sample.py @@ -0,0 +1,137 @@ +"""Single Agent Sample - Durable Task Integration (Combined Worker + Client) + +This sample demonstrates running both the worker and client in a single process. +The worker is started first to register the agent, then client operations are +performed against the running worker. + +Prerequisites: +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Durable Task Scheduler must be running (e.g., using Docker) + +To run this sample: + python sample.py +""" + +import logging +import os + +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_durabletask import DurableAIAgentClient, DurableAIAgentWorker +from azure.identity import AzureCliCredential, DefaultAzureCredential +from dotenv import load_dotenv +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def create_joker_agent(): + """Create the Joker agent using Azure OpenAI. + + Returns: + AgentProtocol: The configured Joker agent + """ + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name="Joker", + instructions="You are good at telling jokes.", + ) + + +def main(): + """Main entry point - runs both worker and client in single process.""" + logger.info("Starting Durable Task Agent Sample (Combined Worker + Client)...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + logger.info("") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + secure_channel = endpoint != "http://localhost:8080" + + # Create and start the worker using a context manager + with DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential + ) as worker: + + # Wrap with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Create and register the Joker agent + logger.info("Creating and registering Joker agent...") + joker_agent = create_joker_agent() + agent_worker.add_agent(joker_agent) + + logger.info(f"✓ Registered agent: {joker_agent.name}") + logger.info(f" Entity name: dafx-{joker_agent.name}") + logger.info("") + + # Start the worker + worker.start() + logger.info("Worker started and listening for requests...") + logger.info("") + + # Create the client + client = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential + ) + + # Wrap it with the agent client + agent_client = DurableAIAgentClient(client) + + # Get a reference to the Joker agent + logger.info("Getting reference to Joker agent...") + joker = agent_client.get_agent("Joker") + + # Create a new thread for the conversation + thread = joker.get_new_thread() + + logger.info(f"Created conversation thread: {thread.session_id}") + logger.info("") + + try: + # First message + message1 = "Tell me a short joke about cloud computing." + logger.info(f"User: {message1}") + logger.info("") + + # Run the agent - this blocks until the response is ready + response1 = joker.run(message1, thread=thread) + logger.info(f"Agent: {response1.text}; {response1}") + logger.info("") + + # Second message - continuing the conversation + message2 = "Now tell me one about Python programming." + logger.info(f"User: {message2}") + logger.info("") + + response2 = joker.run(message2, thread=thread) + logger.info(f"Agent: {response2.text}; {response2}") + logger.info("") + + logger.info(f"Conversation completed successfully!") + logger.info(f"Thread ID: {thread.session_id}") + + except Exception as e: + logger.exception(f"Error during agent interaction: {e}") + + logger.info("") + logger.info("Sample completed. Worker shutting down...") + + +if __name__ == "__main__": + load_dotenv() + main() diff --git a/python/samples/getting_started/durabletask/01_single_agent/worker.py b/python/samples/getting_started/durabletask/01_single_agent/worker.py new file mode 100644 index 0000000000..4c0d172915 --- /dev/null +++ b/python/samples/getting_started/durabletask/01_single_agent/worker.py @@ -0,0 +1,89 @@ +"""Worker process for hosting a single Azure OpenAI-powered agent using Durable Task. + +This worker registers agents as durable entities and continuously listens for requests. +The worker should run as a background service, processing incoming agent requests. + +Prerequisites: +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Start a Durable Task Scheduler (e.g., using Docker) +""" + +import asyncio +import logging +import os + +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_durabletask import DurableAIAgentWorker +from azure.identity import AzureCliCredential, DefaultAzureCredential +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def create_joker_agent(): + """Create the Joker agent using Azure OpenAI. + + Returns: + AgentProtocol: The configured Joker agent + """ + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name="Joker", + instructions="You are good at telling jokes.", + ) + + +async def main(): + """Main entry point for the worker process.""" + logger.info("Starting Durable Task Agent Worker...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + + # Create a worker using Azure Managed Durable Task + worker = DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + # Wrap it with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Create and register the Joker agent + logger.info("Creating and registering Joker agent...") + joker_agent = create_joker_agent() + agent_worker.add_agent(joker_agent) + + logger.info(f"✓ Registered agent: {joker_agent.name}") + logger.info(f" Entity name: dafx-{joker_agent.name}") + logger.info("") + logger.info("Worker is ready and listening for requests...") + logger.info("Press Ctrl+C to stop.") + logger.info("") + + try: + # Start the worker (this blocks until stopped) + worker.start() + + # Keep the worker running + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Worker shutdown initiated") + + logger.info("Worker stopped") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/README.md b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/README.md new file mode 100644 index 0000000000..090fdec3b9 --- /dev/null +++ b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/README.md @@ -0,0 +1,72 @@ +# Single Agent Orchestration Chaining Sample + +This sample demonstrates how to chain multiple invocations of the same agent using a durable orchestration while preserving conversation state between runs. + +## Key Concepts Demonstrated + +- Using durable orchestrations to coordinate sequential agent invocations. +- Chaining agent calls where the output of one run becomes input to the next. +- Maintaining conversation context across sequential runs using a shared thread. +- Using `DurableAIAgentOrchestrationContext` to access agents within orchestrations. + +## Environment Setup + +See the [README.md](../README.md) file in the parent directory for more information on how to configure the environment, including how to install and run common sample dependencies. + +## Running the Sample + +With the environment setup, you can run the sample using one of two approaches: + +### Option 1: Combined Worker + Client (Quick Start) + +```bash +cd samples/getting_started/durabletask/04_single_agent_orchestration_chaining +python sample.py +``` + +This runs both worker and client in a single process. + +### Option 2: Separate Worker and Client + +**Start the worker in one terminal:** + +```bash +python worker.py +``` + +**In a new terminal, run the client:** + +```bash +python client.py +``` + +The orchestration will execute the writer agent twice sequentially, and you'll see output like: + +``` +[Orchestration] Starting single agent chaining... +[Orchestration] Created thread: abc-123 +[Orchestration] First agent run: Generating initial sentence... +[Orchestration] Initial response: Every small step forward is progress toward mastery. +[Orchestration] Second agent run: Refining the sentence... +[Orchestration] Refined response: Each small step forward brings you closer to mastery and growth. +[Orchestration] Chaining complete + +================================================================================ +Orchestration Result +================================================================================ +Each small step forward brings you closer to mastery and growth. +``` + +## Viewing Orchestration State + +You can view the state of the orchestration in the Durable Task Scheduler dashboard: + +1. Open your browser and navigate to `http://localhost:8082` +2. In the dashboard, you can view the orchestration instance, including: + - The sequential execution of both agent runs + - The conversation thread shared between runs + - Input and output at each step + - Overall orchestration state and history + +The orchestration maintains the conversation context across both agent invocations, demonstrating how durable orchestrations can coordinate multi-step agent workflows. + diff --git a/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/client.py b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/client.py new file mode 100644 index 0000000000..1b5331e47e --- /dev/null +++ b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/client.py @@ -0,0 +1,104 @@ +"""Client application for starting a single agent chaining orchestration. + +This client connects to the Durable Task Scheduler and starts an orchestration +that runs a writer agent twice sequentially on the same thread, demonstrating +how conversation context is maintained across multiple agent invocations. + +Prerequisites: +- The worker must be running with the writer agent and orchestration registered +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Durable Task Scheduler must be running +""" + +import asyncio +import json +import logging +import os + +from azure.identity import DefaultAzureCredential +from durabletask.azuremanaged.client import DurableTaskSchedulerClient + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def main() -> None: + """Main entry point for the client application.""" + logger.info("Starting Durable Task Single Agent Chaining Orchestration Client...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + logger.info("") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + + # Create a client using Azure Managed Durable Task + client = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + logger.info("Starting single agent chaining orchestration...") + logger.info("") + + try: + # Start the orchestration + instance_id = client.schedule_new_orchestration( + orchestrator="single_agent_chaining_orchestration", + input="", + ) + + logger.info(f"Orchestration started with instance ID: {instance_id}") + logger.info("Waiting for orchestration to complete...") + logger.info("") + + # Retrieve the final state + metadata = client.wait_for_orchestration_completion( + instance_id=instance_id, + timeout=300 + ) + + if metadata and metadata.runtime_status.name == "COMPLETED": + result = metadata.serialized_output + + logger.info("=" * 80) + logger.info("Orchestration completed successfully!") + logger.info("=" * 80) + logger.info("") + logger.info("Results:") + logger.info("") + + # Parse and display the result + if result: + final_text = json.loads(result) + logger.info("Final refined sentence:") + logger.info(f" {final_text}") + logger.info("") + + logger.info("=" * 80) + + elif metadata: + logger.error(f"Orchestration ended with status: {metadata.runtime_status.name}") + if metadata.serialized_output: + logger.error(f"Output: {metadata.serialized_output}") + else: + logger.error("Orchestration did not complete within the timeout period") + + except Exception as e: + logger.exception(f"Error during orchestration: {e}") + finally: + logger.info("") + logger.info("Client shutting down") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/requirements.txt b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/requirements.txt new file mode 100644 index 0000000000..371b9e3b79 --- /dev/null +++ b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/requirements.txt @@ -0,0 +1,6 @@ +# Agent Framework packages (installing from local package until a package is published) +-e ../../../../ +-e ../../../../packages/durabletask + +# Azure authentication +azure-identity diff --git a/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/sample.py b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/sample.py new file mode 100644 index 0000000000..1f1ee81deb --- /dev/null +++ b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/sample.py @@ -0,0 +1,255 @@ +"""Single Agent Orchestration Chaining Sample - Durable Task Integration + +This sample demonstrates chaining two invocations of the same agent inside a Durable Task +orchestration while preserving the conversation state between runs. The orchestration +runs the writer agent sequentially on a shared thread to refine text iteratively. + +Components used: +- AzureOpenAIChatClient to construct the writer agent +- DurableTaskSchedulerWorker and DurableAIAgentWorker for agent hosting +- DurableTaskSchedulerClient and orchestration for sequential agent invocations +- Thread management to maintain conversation context across invocations + +Prerequisites: +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Durable Task Scheduler must be running (e.g., using Docker emulator) + +To run this sample: + python sample.py +""" + +import asyncio +import json +import logging +import os +from collections.abc import Generator +from typing import Any + +from agent_framework import AgentRunResponse +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker +from azure.identity import AzureCliCredential, DefaultAzureCredential +from dotenv import load_dotenv +from durabletask.task import OrchestrationContext, Task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Agent name +WRITER_AGENT_NAME = "WriterAgent" + + +def create_writer_agent(): + """Create the Writer agent using Azure OpenAI. + + This agent refines short pieces of text, enhancing initial sentences + and polishing improved versions further. + + Returns: + AgentProtocol: The configured Writer agent + """ + instructions = ( + "You refine short pieces of text. When given an initial sentence you enhance it;\n" + "when given an improved sentence you polish it further." + ) + + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name=WRITER_AGENT_NAME, + instructions=instructions, + ) + + +def single_agent_chaining_orchestration( + context: OrchestrationContext, _: str +) -> Generator[Task[Any], Any, str]: + """Orchestration that runs the writer agent twice on the same thread. + + This demonstrates chaining behavior where the output of the first agent run + becomes part of the input for the second run, all while maintaining the + conversation context through a shared thread. + + Args: + context: The orchestration context + _: Input parameter (unused) + + Returns: + str: The final refined text from the second agent run + """ + logger.info("[Orchestration] Starting single agent chaining...") + + # Wrap the orchestration context to access agents + agent_context = DurableAIAgentOrchestrationContext(context) + + # Get the writer agent using the agent context + writer = agent_context.get_agent(WRITER_AGENT_NAME) + + # Create a new thread for the conversation - this will be shared across both runs + writer_thread = writer.get_new_thread() + + logger.info(f"[Orchestration] Created thread: {writer_thread.session_id}") + + # First run: Generate an initial inspirational sentence + logger.info("[Orchestration] First agent run: Generating initial sentence...") + initial_response: AgentRunResponse = yield writer.run( + messages="Write a concise inspirational sentence about learning.", + thread=writer_thread, + ) + logger.info(f"[Orchestration] Initial response: {initial_response.text}") + + # Second run: Refine the initial response on the same thread + improved_prompt = ( + f"Improve this further while keeping it under 25 words: " + f"{initial_response.text}" + ) + + logger.info("[Orchestration] Second agent run: Refining the sentence...") + refined_response: AgentRunResponse = yield writer.run( + messages=improved_prompt, + thread=writer_thread, + ) + + logger.info(f"[Orchestration] Refined response: {refined_response.text}") + + logger.info("[Orchestration] Chaining complete") + return refined_response.text + + +async def run_client( + endpoint: str, taskhub_name: str, credential: DefaultAzureCredential | None +): + """Run the client to start and monitor the orchestration. + + Args: + endpoint: The durable task scheduler endpoint + taskhub_name: The task hub name + credential: The credential for authentication + """ + logger.info("") + logger.info("=" * 80) + logger.info("CLIENT: Starting orchestration...") + logger.info("=" * 80) + logger.info("") + + # Create a client + client = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + try: + # Start the orchestration + instance_id = client.schedule_new_orchestration( + single_agent_chaining_orchestration + ) + + logger.info(f"Orchestration started with instance ID: {instance_id}") + logger.info("Waiting for orchestration to complete...") + logger.info("") + + # Retrieve the final state + metadata = client.wait_for_orchestration_completion( + instance_id=instance_id, + timeout=300 + ) + + if metadata and metadata.runtime_status.name == "COMPLETED": + result = metadata.serialized_output + + logger.info("") + logger.info("=" * 80) + logger.info("ORCHESTRATION COMPLETED SUCCESSFULLY!") + logger.info("=" * 80) + logger.info("") + + # Parse and display the result + if result: + final_text = json.loads(result) + logger.info("Final refined sentence:") + logger.info(f" {final_text}") + else: + logger.warning("No output returned from orchestration") + + elif metadata: + logger.error(f"Orchestration did not complete successfully: {metadata.runtime_status.name}") + if metadata.serialized_output: + logger.error(f"Output: {metadata.serialized_output}") + else: + logger.error("Could not retrieve orchestration metadata") + + except Exception as e: + logger.exception(f"Error during orchestration: {e}") + + logger.info("") + logger.info("Client shutting down") + + +def main(): + """Main entry point - runs both worker and client in single process.""" + logger.info("Starting Single Agent Orchestration Chaining Sample...") + logger.info("") + + # Load environment variables + load_dotenv() + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + logger.info("") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + secure_channel = endpoint != "http://localhost:8080" + + # Create and start the worker using a context manager + with DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential + ) as worker: + + # Wrap with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Create and register the Writer agent + logger.info("Creating and registering Writer agent...") + writer_agent = create_writer_agent() + agent_worker.add_agent(writer_agent) + + logger.info(f"✓ Registered agent: {writer_agent.name}") + logger.info(f" Entity name: dafx-{writer_agent.name}") + + # Register the orchestration function + logger.info("Registering orchestration function...") + worker.add_orchestrator(single_agent_chaining_orchestration) + logger.info("✓ Registered orchestration: single_agent_chaining_orchestration") + logger.info("") + + # Start the worker + worker.start() + logger.info("Worker started and listening for requests...") + logger.info("") + + # Run the client in the same process + try: + asyncio.run(run_client(endpoint, taskhub_name, credential)) + except KeyboardInterrupt: + logger.info("Sample interrupted by user") + finally: + logger.info("Worker stopping...") + + logger.info("Sample completed") + + +if __name__ == "__main__": + load_dotenv() + main() diff --git a/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/worker.py b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/worker.py new file mode 100644 index 0000000000..b033bca155 --- /dev/null +++ b/python/samples/getting_started/durabletask/04_single_agent_orchestration_chaining/worker.py @@ -0,0 +1,167 @@ +"""Worker process for hosting a single agent with chaining orchestration using Durable Task. + +This worker registers a writer agent and an orchestration function that demonstrates +chaining behavior by running the agent twice sequentially on the same thread, +preserving conversation context between invocations. + +Prerequisites: +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Start a Durable Task Scheduler (e.g., using Docker) +""" + +import asyncio +from collections.abc import Generator +import logging +import os +from typing import Any + +from agent_framework import AgentRunResponse +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker +from azure.identity import AzureCliCredential, DefaultAzureCredential +from durabletask.task import OrchestrationContext, Task +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Agent name +WRITER_AGENT_NAME = "WriterAgent" + + +def create_writer_agent(): + """Create the Writer agent using Azure OpenAI. + + This agent refines short pieces of text, enhancing initial sentences + and polishing improved versions further. + + Returns: + AgentProtocol: The configured Writer agent + """ + instructions = ( + "You refine short pieces of text. When given an initial sentence you enhance it;\n" + "when given an improved sentence you polish it further." + ) + + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name=WRITER_AGENT_NAME, + instructions=instructions, + ) + + +def single_agent_chaining_orchestration( + context: OrchestrationContext, _: str +) -> Generator[Task[Any], Any, str]: + """Orchestration that runs the writer agent twice on the same thread. + + This demonstrates chaining behavior where the output of the first agent run + becomes part of the input for the second run, all while maintaining the + conversation context through a shared thread. + + Args: + context: The orchestration context + _: Input parameter (unused) + + Returns: + str: The final refined text from the second agent run + """ + logger.info("[Orchestration] Starting single agent chaining...") + + # Wrap the orchestration context to access agents + agent_context = DurableAIAgentOrchestrationContext(context) + + # Get the writer agent using the agent context + writer = agent_context.get_agent(WRITER_AGENT_NAME) + + # Create a new thread for the conversation - this will be shared across both runs + writer_thread = writer.get_new_thread() + + logger.info(f"[Orchestration] Created thread: {writer_thread.session_id}") + + # First run: Generate an initial inspirational sentence + logger.info("[Orchestration] First agent run: Generating initial sentence...") + initial_response: AgentRunResponse = yield writer.run( + messages="Write a concise inspirational sentence about learning.", + thread=writer_thread, + ) + logger.info(f"[Orchestration] Initial response: {initial_response.text}") + + # Second run: Refine the initial response on the same thread + improved_prompt = ( + f"Improve this further while keeping it under 25 words: " + f"{initial_response.text}" + ) + + logger.info("[Orchestration] Second agent run: Refining the sentence...") + refined_response: AgentRunResponse = yield writer.run( + messages=improved_prompt, + thread=writer_thread, + ) + + logger.info(f"[Orchestration] Refined response: {refined_response.text}") + + logger.info("[Orchestration] Chaining complete") + return refined_response.text + + +async def main(): + """Main entry point for the worker process.""" + logger.info("Starting Durable Task Single Agent Chaining Worker with Orchestration...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + + # Create a worker using Azure Managed Durable Task + worker = DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + # Wrap it with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Create and register the Writer agent + logger.info("Creating and registering Writer agent...") + writer_agent = create_writer_agent() + agent_worker.add_agent(writer_agent) + + logger.info(f"✓ Registered agent: {writer_agent.name}") + logger.info(f" Entity name: dafx-{writer_agent.name}") + logger.info("") + + # Register the orchestration function + logger.info("Registering orchestration function...") + worker.add_orchestrator(single_agent_chaining_orchestration) + logger.info(f"✓ Registered orchestration: {single_agent_chaining_orchestration.__name__}") + logger.info("") + + logger.info("Worker is ready and listening for requests...") + logger.info("Press Ctrl+C to stop.") + logger.info("") + + try: + # Start the worker (this blocks until stopped) + worker.start() + + # Keep the worker running + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Worker shutdown initiated") + + logger.info("Worker stopped") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/README.md b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/README.md new file mode 100644 index 0000000000..89efdb5e8d --- /dev/null +++ b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/README.md @@ -0,0 +1,75 @@ +# Multi-Agent Orchestration with Concurrency Sample + +This sample demonstrates how to host multiple agents and run them concurrently using a durable orchestration, aggregating their responses into a single result. + +## Key Concepts Demonstrated + +- Running multiple specialized agents in parallel within an orchestration. +- Using `OrchestrationAgentExecutor` to get `DurableAgentTask` objects for concurrent execution. +- Aggregating results from multiple agents using `task.when_all()`. +- Creating separate conversation threads for independent agent contexts. + +## Environment Setup + +See the [README.md](../README.md) file in the parent directory for more information on how to configure the environment, including how to install and run common sample dependencies. + +## Running the Sample + +With the environment setup, you can run the sample using one of two approaches: + +### Option 1: Combined Worker + Client (Quick Start) + +```bash +cd samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency +python sample.py +``` + +This runs both worker and client in a single process. + +### Option 2: Separate Worker and Client + +**Start the worker in one terminal:** + +```bash +python worker.py +``` + +**In a new terminal, run the client:** + +```bash +python client.py +``` + +The orchestration will execute both agents concurrently, and you'll see output like: + +``` +Prompt: What is temperature? + +Starting multi-agent concurrent orchestration... +Orchestration started with instance ID: abc123... +Orchestration status: COMPLETED + +Results: + +Physicist's response: + Temperature measures the average kinetic energy of particles in a system... + +Chemist's response: + Temperature reflects how molecular motion influences reaction rates... +``` + +## Viewing Orchestration State + +You can view the state of the orchestration in the Durable Task Scheduler dashboard: + +1. Open your browser and navigate to `http://localhost:8082` +2. In the dashboard, you can view the orchestration instance, including: + - The concurrent execution of both agents (Physicist and Chemist) + - Separate conversation threads for each agent + - Parallel task execution and completion timing + - Aggregated results from both agents + - Overall orchestration state and history + +The orchestration demonstrates how multiple agents can be executed in parallel, with results collected and aggregated once all agents complete. + + diff --git a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/client.py b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/client.py new file mode 100644 index 0000000000..2e517a8b6a --- /dev/null +++ b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/client.py @@ -0,0 +1,114 @@ +"""Client application for starting a multi-agent concurrent orchestration. + +This client connects to the Durable Task Scheduler and starts an orchestration +that runs two agents (physicist and chemist) concurrently, then retrieves and +displays the aggregated results. + +Prerequisites: +- The worker must be running with both agents and orchestration registered +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Durable Task Scheduler must be running +""" + +import asyncio +import json +import logging +import os + +from azure.identity import DefaultAzureCredential +from durabletask.azuremanaged.client import DurableTaskSchedulerClient + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def main() -> None: + """Main entry point for the client application.""" + logger.info("Starting Durable Task Multi-Agent Orchestration Client...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + logger.info("") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + + # Create a client using Azure Managed Durable Task + client = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + # Define the prompt to send to both agents + prompt = "What is temperature?" + + logger.info(f"Prompt: {prompt}") + logger.info("") + logger.info("Starting multi-agent concurrent orchestration...") + + try: + # Start the orchestration with the prompt as input + instance_id = client.schedule_new_orchestration( + orchestrator="multi_agent_concurrent_orchestration", + input=prompt, + ) + + logger.info(f"Orchestration started with instance ID: {instance_id}") + logger.info("Waiting for orchestration to complete...") + logger.info("") + + # Retrieve the final state + metadata = client.wait_for_orchestration_completion( + instance_id=instance_id, + ) + + if metadata and metadata.runtime_status.name == "COMPLETED": + result = metadata.serialized_output + + logger.info("=" * 80) + logger.info("Orchestration completed successfully!") + logger.info("=" * 80) + logger.info("") + logger.info(f"Prompt: {prompt}") + logger.info("") + logger.info("Results:") + logger.info("") + + # Parse and display the result + if result: + result_dict = json.loads(result) + + logger.info("Physicist's response:") + logger.info(f" {result_dict.get('physicist', 'N/A')}") + logger.info("") + + logger.info("Chemist's response:") + logger.info(f" {result_dict.get('chemist', 'N/A')}") + logger.info("") + + logger.info("=" * 80) + + elif metadata: + logger.error(f"Orchestration ended with status: {metadata.runtime_status.name}") + if metadata.serialized_output: + logger.error(f"Output: {metadata.serialized_output}") + else: + logger.error("Orchestration did not complete within the timeout period") + + except Exception as e: + logger.exception(f"Error during orchestration: {e}") + finally: + logger.info("") + logger.info("Client shutting down") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/requirements.txt b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/requirements.txt new file mode 100644 index 0000000000..371b9e3b79 --- /dev/null +++ b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/requirements.txt @@ -0,0 +1,6 @@ +# Agent Framework packages (installing from local package until a package is published) +-e ../../../../ +-e ../../../../packages/durabletask + +# Azure authentication +azure-identity diff --git a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/sample.py b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/sample.py new file mode 100644 index 0000000000..1f6723695e --- /dev/null +++ b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/sample.py @@ -0,0 +1,266 @@ +"""Multi-Agent Orchestration Sample - Durable Task Integration (Combined Worker + Client) + +This sample demonstrates running both the worker and client in a single process for +concurrent multi-agent orchestration. The worker registers two domain-specific agents +(physicist and chemist) and an orchestration function that runs them in parallel. + +The orchestration uses OrchestrationAgentExecutor to execute agents concurrently +and aggregate their responses. + +Prerequisites: +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Durable Task Scheduler must be running (e.g., using Docker) + +To run this sample: + python sample.py +""" + +import asyncio +import json +import logging +import os +from collections.abc import Generator +from typing import Any + +from agent_framework import AgentRunResponse +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker +from azure.identity import AzureCliCredential, DefaultAzureCredential +from dotenv import load_dotenv +from durabletask.task import OrchestrationContext, when_all, Task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Agent names +PHYSICIST_AGENT_NAME = "PhysicistAgent" +CHEMIST_AGENT_NAME = "ChemistAgent" + + +def create_physicist_agent(): + """Create the Physicist agent using Azure OpenAI. + + Returns: + AgentProtocol: The configured Physicist agent + """ + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name=PHYSICIST_AGENT_NAME, + instructions="You are an expert in physics. You answer questions from a physics perspective.", + ) + + +def create_chemist_agent(): + """Create the Chemist agent using Azure OpenAI. + + Returns: + AgentProtocol: The configured Chemist agent + """ + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name=CHEMIST_AGENT_NAME, + instructions="You are an expert in chemistry. You answer questions from a chemistry perspective.", + ) + + +def multi_agent_concurrent_orchestration(context: OrchestrationContext, prompt: str) -> Generator[Task[Any], Any, dict[str, str]]: + """Orchestration that runs both agents in parallel and aggregates results. + + Uses DurableAIAgentOrchestrationContext to wrap the orchestration context and + access agents via the OrchestrationAgentExecutor. + + Args: + context: The orchestration context + + Returns: + dict: Dictionary with 'physicist' and 'chemist' response texts + """ + logger.info(f"[Orchestration] Starting concurrent execution for prompt: {prompt}") + + # Wrap the orchestration context to access agents + agent_context = DurableAIAgentOrchestrationContext(context) + + # Get agents using the agent context (returns DurableAIAgent proxies) + physicist = agent_context.get_agent(PHYSICIST_AGENT_NAME) + chemist = agent_context.get_agent(CHEMIST_AGENT_NAME) + + # Create separate threads for each agent + physicist_thread = physicist.get_new_thread() + chemist_thread = chemist.get_new_thread() + + logger.info(f"[Orchestration] Created threads - Physicist: {physicist_thread.session_id}, Chemist: {chemist_thread.session_id}") + + # Create tasks from agent.run() calls - these return DurableAgentTask instances + physicist_task = physicist.run(messages=str(prompt), thread=physicist_thread) + chemist_task = chemist.run(messages=str(prompt), thread=chemist_thread) + + logger.info("[Orchestration] Created agent tasks, executing concurrently...") + + # Execute both tasks concurrently using task.when_all + # The DurableAgentTask instances wrap the underlying entity calls + task_results = yield when_all([physicist_task, chemist_task]) + + logger.info("[Orchestration] Both agents completed") + + # Extract results from the tasks - DurableAgentTask yields AgentRunResponse + physicist_result: AgentRunResponse = task_results[0] + chemist_result: AgentRunResponse = task_results[1] + + result = { + "physicist": physicist_result.text, + "chemist": chemist_result.text, + } + + logger.info(f"[Orchestration] Aggregated results ready") + return result + + +async def run_client(endpoint: str, taskhub_name: str, credential: DefaultAzureCredential | None, prompt: str): + """Run the client to start and monitor the orchestration. + + Args: + endpoint: The durable task scheduler endpoint + taskhub_name: The task hub name + credential: The credential for authentication + prompt: The prompt to send to both agents + """ + logger.info("") + logger.info("=" * 80) + logger.info("CLIENT: Starting orchestration...") + logger.info("=" * 80) + + # Create a client + client = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + logger.info(f"Prompt: {prompt}") + logger.info("") + + try: + # Start the orchestration with the prompt as input + instance_id = client.schedule_new_orchestration( + multi_agent_concurrent_orchestration, + input=prompt, + ) + + logger.info(f"Orchestration started with instance ID: {instance_id}") + logger.info("Waiting for orchestration to complete...") + logger.info("") + + # Retrieve the final state + metadata = client.wait_for_orchestration_completion( + instance_id=instance_id + ) + + if metadata and metadata.runtime_status.name == "COMPLETED": + result = metadata.serialized_output + + logger.info("") + logger.info("=" * 80) + logger.info("ORCHESTRATION COMPLETED SUCCESSFULLY!") + logger.info("=" * 80) + logger.info("") + logger.info(f"Prompt: {prompt}") + logger.info("") + logger.info("Results:") + logger.info("") + + # Parse and display the result + if result: + result_dict = json.loads(result) + + logger.info("Physicist's response:") + logger.info(f" {result_dict.get('physicist', 'N/A')}") + logger.info("") + + logger.info("Chemist's response:") + logger.info(f" {result_dict.get('chemist', 'N/A')}") + logger.info("") + + logger.info("=" * 80) + + elif metadata: + logger.error(f"Orchestration ended with status: {metadata.runtime_status.name}") + if metadata.serialized_output: + logger.error(f"Output: {metadata.serialized_output}") + else: + logger.error("Orchestration did not complete within the timeout period") + + except Exception as e: + logger.exception(f"Error during orchestration: {e}") + + +def main(): + """Main entry point - runs both worker and client in single process.""" + logger.info("Starting Durable Task Multi-Agent Orchestration Sample (Combined Worker + Client)...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + logger.info("") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + secure_channel = endpoint != "http://localhost:8080" + + # Create and start the worker using a context manager + with DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential + ) as worker: + + # Wrap with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Create and register both agents + logger.info("Creating and registering agents...") + physicist_agent = create_physicist_agent() + chemist_agent = create_chemist_agent() + + agent_worker.add_agent(physicist_agent) + agent_worker.add_agent(chemist_agent) + + logger.info(f"✓ Registered agent: {physicist_agent.name}") + logger.info(f" Entity name: dafx-{physicist_agent.name}") + logger.info(f"✓ Registered agent: {chemist_agent.name}") + logger.info(f" Entity name: dafx-{chemist_agent.name}") + logger.info("") + + # Register the orchestration function + logger.info("Registering orchestration function...") + worker.add_orchestrator(multi_agent_concurrent_orchestration) + logger.info(f"✓ Registered orchestration: {multi_agent_concurrent_orchestration.__name__}") + logger.info("") + + # Start the worker + worker.start() + logger.info("Worker started and listening for requests...") + + # Define the prompt + prompt = "What is temperature?" + + try: + # Run the client to start the orchestration + asyncio.run(run_client(endpoint, taskhub_name, credential, prompt)) + + except Exception as e: + logger.exception(f"Error during sample execution: {e}") + + logger.info("") + logger.info("Sample completed. Worker shutting down...") + + +if __name__ == "__main__": + load_dotenv() + main() diff --git a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/worker.py b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/worker.py new file mode 100644 index 0000000000..bb35b5a8d4 --- /dev/null +++ b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/worker.py @@ -0,0 +1,175 @@ +"""Worker process for hosting multiple agents with orchestration using Durable Task. + +This worker registers two domain-specific agents (physicist and chemist) and an orchestration +function that runs them concurrently. The orchestration uses OrchestrationAgentExecutor +to execute agents in parallel and aggregate their responses. + +Prerequisites: +- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME + (plus AZURE_OPENAI_API_KEY or Azure CLI authentication) +- Start a Durable Task Scheduler (e.g., using Docker) +""" + +import asyncio +from collections.abc import Generator +import logging +import os +from typing import Any + +from agent_framework import AgentRunResponse +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker +from azure.identity import AzureCliCredential, DefaultAzureCredential +from durabletask.task import OrchestrationContext, when_all, Task +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Agent names +PHYSICIST_AGENT_NAME = "PhysicistAgent" +CHEMIST_AGENT_NAME = "ChemistAgent" + + +def create_physicist_agent(): + """Create the Physicist agent using Azure OpenAI. + + Returns: + AgentProtocol: The configured Physicist agent + """ + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name=PHYSICIST_AGENT_NAME, + instructions="You are an expert in physics. You answer questions from a physics perspective.", + ) + + +def create_chemist_agent(): + """Create the Chemist agent using Azure OpenAI. + + Returns: + AgentProtocol: The configured Chemist agent + """ + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + name=CHEMIST_AGENT_NAME, + instructions="You are an expert in chemistry. You answer questions from a chemistry perspective.", + ) + + +def multi_agent_concurrent_orchestration(context: OrchestrationContext, prompt: str) -> Generator[Task[Any], Any, dict[str, str]]: + """Orchestration that runs both agents in parallel and aggregates results. + + Uses DurableAIAgentOrchestrationContext to wrap the orchestration context and + access agents via the OrchestrationAgentExecutor. + + Args: + context: The orchestration context + + Returns: + dict: Dictionary with 'physicist' and 'chemist' response texts + """ + + logger.info(f"[Orchestration] Starting concurrent execution for prompt: {prompt}") + + # Wrap the orchestration context to access agents + agent_context = DurableAIAgentOrchestrationContext(context) + + # Get agents using the agent context (returns DurableAIAgent proxies) + physicist = agent_context.get_agent(PHYSICIST_AGENT_NAME) + chemist = agent_context.get_agent(CHEMIST_AGENT_NAME) + + # Create separate threads for each agent + physicist_thread = physicist.get_new_thread() + chemist_thread = chemist.get_new_thread() + + logger.info(f"[Orchestration] Created threads - Physicist: {physicist_thread.session_id}, Chemist: {chemist_thread.session_id}") + + # Create tasks from agent.run() calls - these return DurableAgentTask instances + physicist_task = physicist.run(messages=str(prompt), thread=physicist_thread) + chemist_task = chemist.run(messages=str(prompt), thread=chemist_thread) + + logger.info("[Orchestration] Created agent tasks, executing concurrently...") + + # Execute both tasks concurrently using when_all + # The DurableAgentTask instances wrap the underlying entity calls + task_results = yield when_all([physicist_task, chemist_task]) + + logger.info("[Orchestration] Both agents completed") + + # Extract results from the tasks - DurableAgentTask yields AgentRunResponse + physicist_result: AgentRunResponse = task_results[0] + chemist_result: AgentRunResponse = task_results[1] + + result = { + "physicist": physicist_result.text, + "chemist": chemist_result.text, + } + + logger.info(f"[Orchestration] Aggregated results ready") + return result + + +async def main(): + """Main entry point for the worker process.""" + logger.info("Starting Durable Task Multi-Agent Worker with Orchestration...") + + # Get environment variables for taskhub and endpoint with defaults + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + logger.info(f"Using taskhub: {taskhub_name}") + logger.info(f"Using endpoint: {endpoint}") + + # Set credential to None for emulator, or DefaultAzureCredential for Azure + credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential() + + # Create a worker using Azure Managed Durable Task + worker = DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=endpoint != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential + ) + + # Wrap it with the agent worker + agent_worker = DurableAIAgentWorker(worker) + + # Create and register both agents + logger.info("Creating and registering agents...") + physicist_agent = create_physicist_agent() + chemist_agent = create_chemist_agent() + + agent_worker.add_agent(physicist_agent) + agent_worker.add_agent(chemist_agent) + + logger.info(f"✓ Registered agent: {physicist_agent.name}") + logger.info(f" Entity name: dafx-{physicist_agent.name}") + logger.info(f"✓ Registered agent: {chemist_agent.name}") + logger.info(f" Entity name: dafx-{chemist_agent.name}") + logger.info("") + + # Register the orchestration function + logger.info("Registering orchestration function...") + worker.add_orchestrator(multi_agent_concurrent_orchestration) + logger.info(f"✓ Registered orchestration: {multi_agent_concurrent_orchestration.__name__}") + logger.info("") + + logger.info("Worker is ready and listening for requests...") + logger.info("Press Ctrl+C to stop.") + logger.info("") + + try: + # Start the worker (this blocks until stopped) + worker.start() + + # Keep the worker running + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Worker shutdown initiated") + + logger.info("Worker stopped") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/durabletask/README.md b/python/samples/getting_started/durabletask/README.md new file mode 100644 index 0000000000..3b63294756 --- /dev/null +++ b/python/samples/getting_started/durabletask/README.md @@ -0,0 +1,124 @@ +# Durable Task Samples + +This directory contains samples for durable agent hosting using the Durable Task Scheduler. These samples demonstrate the worker-client architecture pattern, enabling distributed agent execution with persistent conversation state. + +- **[01_single_agent](01_single_agent/)**: A sample that demonstrates how to host a single conversational agent using the Durable Task Scheduler and interact with it via a client. +- **[04_single_agent_orchestration_chaining](04_single_agent_orchestration_chaining/)**: A sample that demonstrates how to chain multiple invocations of the same agent using a durable orchestration. +- **[05_multi_agent_orchestration_concurrency](05_multi_agent_orchestration_concurrency/)**: A sample that demonstrates how to host multiple agents and run them concurrently using a durable orchestration. + +## Running the Samples + +These samples are designed to be run locally in a cloned repository. + +### Prerequisites + +The following prerequisites are required to run the samples: + +- [Python 3.9 or later](https://www.python.org/downloads/) +- [Azure CLI](https://learn.microsoft.com/cli/azure/install-azure-cli) installed and authenticated (`az login`) or an API key for the Azure OpenAI service +- [Azure OpenAI Service](https://learn.microsoft.com/azure/ai-services/openai/how-to/create-resource) with a deployed model (gpt-4o-mini or better is recommended) +- [Durable Task Scheduler](https://learn.microsoft.com/azure/azure-functions/durable/durable-task-scheduler/develop-with-durable-task-scheduler) (local emulator or Azure-hosted) +- [Docker](https://docs.docker.com/get-docker/) installed if running the Durable Task Scheduler emulator locally + +### Configuring RBAC Permissions for Azure OpenAI + +These samples are configured to use the Azure OpenAI service with RBAC permissions to access the model. You'll need to configure the RBAC permissions for the Azure OpenAI service to allow the Python app to access the model. + +Below is an example of how to configure the RBAC permissions for the Azure OpenAI service to allow the current user to access the model. + +Bash (Linux/macOS/WSL): + +```bash +az role assignment create \ + --assignee "yourname@contoso.com" \ + --role "Cognitive Services OpenAI User" \ + --scope /subscriptions//resourceGroups//providers/Microsoft.CognitiveServices/accounts/ +``` + +PowerShell: + +```powershell +az role assignment create ` + --assignee "yourname@contoso.com" ` + --role "Cognitive Services OpenAI User" ` + --scope /subscriptions//resourceGroups//providers/Microsoft.CognitiveServices/accounts/ +``` + +More information on how to configure RBAC permissions for Azure OpenAI can be found in the [Azure OpenAI documentation](https://learn.microsoft.com/azure/ai-services/openai/how-to/create-resource?pivots=cli). + +### Setting an API key for the Azure OpenAI service + +As an alternative to configuring Azure RBAC permissions, you can set an API key for the Azure OpenAI service by setting the `AZURE_OPENAI_API_KEY` environment variable. + +Bash (Linux/macOS/WSL): + +```bash +export AZURE_OPENAI_API_KEY="your-api-key" +``` + +PowerShell: + +```powershell +$env:AZURE_OPENAI_API_KEY="your-api-key" +``` + +### Start Durable Task Scheduler + +Most samples use the Durable Task Scheduler (DTS) to support hosted agents and durable orchestrations. DTS also allows you to view the status of orchestrations and their inputs and outputs from a web UI. + +To run the Durable Task Scheduler locally, you can use the following `docker` command: + +```bash +docker run -d --name dts-emulator -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest +``` + +The DTS dashboard will be available at `http://localhost:8082`. + +### Environment Configuration + +Each sample reads configuration from environment variables. You'll need to set the following environment variables: + +```bash +export AZURE_OPENAI_ENDPOINT="https://your-resource.openai.azure.com/" +export AZURE_OPENAI_CHAT_DEPLOYMENT_NAME="your-deployment-name" +``` + +### Installing Dependencies + +Navigate to the sample directory and install dependencies: + +```bash +cd samples/getting_started/durabletask/01_single_agent +pip install -r requirements.txt +``` + +### Running the Samples + +Each sample follows a worker-client architecture. Most samples provide separate `worker.py` and `client.py` files, though some include a combined `sample.py` for convenience. + +**Running with separate worker and client:** + +In one terminal, start the worker: + +```bash +python worker.py +``` + +In another terminal, run the client: + +```bash +python client.py +``` + +**Running with combined sample:** + +```bash +python sample.py +``` + +### Viewing the Sample Output + +The sample output is displayed directly in the terminal where you ran the Python script. Agent responses are printed to stdout with log formatting for better readability. + +You can also see the state of agents and orchestrations in the Durable Task Scheduler dashboard at `http://localhost:8082`. +