Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 179 additions & 19 deletions src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@
import abc
import asyncio
import inspect
import sys
from collections.abc import Awaitable
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar

import httpx

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup # pyright: ignore[reportMissingImports]
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
from mcp.client.session import MessageHandlerFnT
Expand Down Expand Up @@ -251,6 +256,35 @@ def invalidate_tools_cache(self):
"""Invalidate the tools cache."""
self._cache_dirty = True

def _extract_http_error_from_exception(self, e: BaseException) -> Exception | None:
    """Extract HTTP error from exception or ExceptionGroup.

    Args:
        e: The exception to inspect. It may itself be an exception group.

    Returns:
        ``e`` when it is an ``httpx.HTTPStatusError``, ``httpx.ConnectError`` or
        ``httpx.TimeoutException``; otherwise the first such error found
        depth-first inside an exception group (nested groups included), or
        ``None`` when no HTTP error is present.
    """
    http_error_types = (httpx.HTTPStatusError, httpx.ConnectError, httpx.TimeoutException)
    if isinstance(e, http_error_types):
        return e

    # Exception groups may contain other groups, so descend into each member
    # instead of checking only the top level.
    if isinstance(e, BaseExceptionGroup):
        for exc in e.exceptions:
            found = self._extract_http_error_from_exception(exc)
            if found is not None:
                return found

    return None

def _raise_user_error_for_http_error(self, http_error: Exception) -> None:
    """Raise a ``UserError`` describing a failed connection to this MCP server.

    The message always starts with the server name; a detail is appended for
    HTTP status errors, connect errors and timeouts.

    Args:
        http_error: The underlying error; it becomes the ``__cause__`` of the
            raised ``UserError``.

    Raises:
        UserError: Always.
    """
    prefix = f"Failed to connect to MCP server '{self.name}': "

    if isinstance(http_error, httpx.HTTPStatusError):
        detail = f"HTTP error {http_error.response.status_code} ({http_error.response.reason_phrase})"  # noqa: E501
    elif isinstance(http_error, httpx.ConnectError):
        detail = "Could not reach the server."
    elif isinstance(http_error, httpx.TimeoutException):
        detail = "Connection timeout."
    else:
        # Any other error type gets the bare prefix with no detail.
        detail = ""

    raise UserError(prefix + detail) from http_error

async def _run_with_retries(self, func: Callable[[], Awaitable[T]]) -> T:
attempts = 0
while True:
Expand All @@ -265,6 +299,7 @@ async def _run_with_retries(self, func: Callable[[], Awaitable[T]]) -> T:

async def connect(self):
"""Connect to the server."""
connection_succeeded = False
try:
transport = await self.exit_stack.enter_async_context(self.create_streams())
# streamablehttp_client returns (read, write, get_session_id)
Expand All @@ -285,10 +320,49 @@ async def connect(self):
server_result = await session.initialize()
self.server_initialize_result = server_result
self.session = session
connection_succeeded = True
except Exception as e:
logger.error(f"Error initializing MCP server: {e}")
await self.cleanup()
# Try to extract HTTP error from exception or ExceptionGroup
http_error = self._extract_http_error_from_exception(e)
if http_error:
self._raise_user_error_for_http_error(http_error)

# For CancelledError, preserve cancellation semantics - don't wrap it.
# If it's masking an HTTP error, cleanup() will extract and raise UserError.
if isinstance(e, asyncio.CancelledError):
raise

# For HTTP-related errors, wrap them
if isinstance(e, (httpx.HTTPStatusError, httpx.ConnectError, httpx.TimeoutException)):
self._raise_user_error_for_http_error(e)

# For other errors, re-raise as-is (don't wrap non-HTTP errors)
raise
finally:
# Always attempt cleanup on error, but suppress cleanup errors that mask the original
if not connection_succeeded:
try:
await self.cleanup()
except UserError:
# Re-raise UserError from cleanup (contains the real HTTP error)
raise
except Exception as cleanup_error:
# Suppress RuntimeError about cancel scopes during cleanup - this is a known
# issue with the MCP library's async generator cleanup and shouldn't mask the
# original error
if isinstance(cleanup_error, RuntimeError) and "cancel scope" in str(
cleanup_error
):
logger.debug(
f"Ignoring cancel scope error during cleanup of MCP server "
f"'{self.name}': {cleanup_error}"
)
else:
# Log other cleanup errors but don't raise - original error is more
# important
logger.warning(
f"Error during cleanup of MCP server '{self.name}': {cleanup_error}"
)

async def list_tools(
self,
Expand All @@ -301,21 +375,32 @@ async def list_tools(
session = self.session
assert session is not None

# Return from cache if caching is enabled, we have tools, and the cache is not dirty
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
tools = self._tools_list
else:
# Fetch the tools from the server
result = await self._run_with_retries(lambda: session.list_tools())
self._tools_list = result.tools
self._cache_dirty = False
tools = self._tools_list

# Filter tools based on tool_filter
filtered_tools = tools
if self.tool_filter is not None:
filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent)
return filtered_tools
try:
# Return from cache if caching is enabled, we have tools, and the cache is not dirty
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
tools = self._tools_list
else:
# Fetch the tools from the server
result = await self._run_with_retries(lambda: session.list_tools())
self._tools_list = result.tools
self._cache_dirty = False
tools = self._tools_list

# Filter tools based on tool_filter
filtered_tools = tools
if self.tool_filter is not None:
filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent)
return filtered_tools
except httpx.HTTPStatusError as e:
status_code = e.response.status_code
raise UserError(
f"Failed to list tools from MCP server '{self.name}': HTTP error {status_code}"
) from e
except httpx.ConnectError as e:
raise UserError(
f"Failed to list tools from MCP server '{self.name}': Connection lost. "
f"The server may have disconnected."
) from e

async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
"""Invoke a tool on the server."""
Expand All @@ -324,7 +409,19 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
session = self.session
assert session is not None

return await self._run_with_retries(lambda: session.call_tool(tool_name, arguments))
try:
return await self._run_with_retries(lambda: session.call_tool(tool_name, arguments))
except httpx.HTTPStatusError as e:
status_code = e.response.status_code
raise UserError(
f"Failed to call tool '{tool_name}' on MCP server '{self.name}': "
f"HTTP error {status_code}"
) from e
except httpx.ConnectError as e:
raise UserError(
f"Failed to call tool '{tool_name}' on MCP server '{self.name}': Connection lost. "
f"The server may have disconnected."
) from e

async def list_prompts(
self,
Expand All @@ -347,10 +444,73 @@ async def get_prompt(
async def cleanup(self):
"""Cleanup the server."""
async with self._cleanup_lock:
# Only raise HTTP errors if we're cleaning up after a failed connection.
# During normal teardown (via __aexit__), log but don't raise to avoid
# masking the original exception.
is_failed_connection_cleanup = self.session is None

try:
await self.exit_stack.aclose()
except BaseExceptionGroup as eg:
# Extract HTTP errors from ExceptionGroup raised during cleanup
# This happens when background tasks fail (e.g., HTTP errors)
http_error = None
connect_error = None
timeout_error = None
error_message = f"Failed to connect to MCP server '{self.name}': "

for exc in eg.exceptions:
if isinstance(exc, httpx.HTTPStatusError):
http_error = exc
elif isinstance(exc, httpx.ConnectError):
connect_error = exc
elif isinstance(exc, httpx.TimeoutException):
timeout_error = exc

# Only raise HTTP errors if we're cleaning up after a failed connection.
# During normal teardown, log them instead.
if http_error:
if is_failed_connection_cleanup:
error_message += f"HTTP error {http_error.response.status_code} ({http_error.response.reason_phrase})" # noqa: E501
raise UserError(error_message) from http_error
else:
# Normal teardown - log but don't raise
logger.warning(
f"HTTP error during cleanup of MCP server '{self.name}': {http_error}"
)
elif connect_error:
if is_failed_connection_cleanup:
error_message += "Could not reach the server."
raise UserError(error_message) from connect_error
else:
logger.warning(
f"Connection error during cleanup of MCP server '{self.name}': {connect_error}" # noqa: E501
)
elif timeout_error:
if is_failed_connection_cleanup:
error_message += "Connection timeout."
raise UserError(error_message) from timeout_error
else:
logger.warning(
f"Timeout error during cleanup of MCP server '{self.name}': {timeout_error}" # noqa: E501
)
else:
# No HTTP error found, suppress RuntimeError about cancel scopes
has_cancel_scope_error = any(
isinstance(exc, RuntimeError) and "cancel scope" in str(exc)
for exc in eg.exceptions
)
if has_cancel_scope_error:
logger.debug(f"Ignoring cancel scope error during cleanup: {eg}")
else:
logger.error(f"Error cleaning up server: {eg}")
except Exception as e:
logger.error(f"Error cleaning up server: {e}")
# Suppress RuntimeError about cancel scopes - this is a known issue with the MCP
# library when background tasks fail during async generator cleanup
if isinstance(e, RuntimeError) and "cancel scope" in str(e):
logger.debug(f"Ignoring cancel scope error during cleanup: {e}")
else:
logger.error(f"Error cleaning up server: {e}")
finally:
self.session = None

Expand Down
9 changes: 7 additions & 2 deletions src/agents/mcp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,14 @@ async def invoke_mcp_tool(

try:
result = await server.call_tool(tool.name, json_data)
except UserError:
# Re-raise UserError as-is (it already has a good message)
raise
except Exception as e:
logger.error(f"Error invoking MCP tool {tool.name}: {e}")
raise AgentsException(f"Error invoking MCP tool {tool.name}: {e}") from e
logger.error(f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}")
raise AgentsException(
f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}"
) from e

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"MCP tool {tool.name} completed.")
Expand Down