|
1 | 1 | from collections.abc import AsyncGenerator |
2 | 2 | from typing import Any |
| 3 | +from unittest.mock import AsyncMock, patch |
3 | 4 |
|
4 | 5 | import anyio |
5 | 6 | import pytest |
|
9 | 10 | from mcp.server.lowlevel.server import Server |
10 | 11 | from mcp.shared.exceptions import McpError |
11 | 12 | from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session |
| 13 | +from mcp.shared.session import BaseSession |
12 | 14 | from mcp.types import ( |
| 15 | + CONNECTION_CLOSED, |
| 16 | + INTERNAL_ERROR, |
13 | 17 | CancelledNotification, |
14 | 18 | CancelledNotificationParams, |
15 | 19 | ClientNotification, |
16 | 20 | ClientRequest, |
17 | 21 | EmptyResult, |
| 22 | + ErrorData, |
| 23 | + JSONRPCError, |
| 24 | + JSONRPCResponse, |
18 | 25 | TextContent, |
19 | 26 | ) |
20 | 27 |
|
@@ -168,3 +175,111 @@ async def mock_server(): |
168 | 175 | await ev_closed.wait() |
169 | 176 | with anyio.fail_after(1): |
170 | 177 | await ev_response.wait() |
| 178 | + |
| 179 | + |
| 180 | +class TestProcessResponse: |
| 181 | + """Tests for BaseSession._process_response static method.""" |
| 182 | + |
| 183 | + def test_process_response_with_valid_response(self): |
| 184 | + """Test that a valid JSONRPCResponse is processed correctly.""" |
| 185 | + response = JSONRPCResponse( |
| 186 | + jsonrpc="2.0", |
| 187 | + id=1, |
| 188 | + result={}, |
| 189 | + ) |
| 190 | + |
| 191 | + result = BaseSession._process_response(response, EmptyResult) |
| 192 | + |
| 193 | + assert isinstance(result, EmptyResult) |
| 194 | + |
| 195 | + def test_process_response_with_error(self): |
| 196 | + """Test that a JSONRPCError raises McpError.""" |
| 197 | + error = JSONRPCError( |
| 198 | + jsonrpc="2.0", |
| 199 | + id=1, |
| 200 | + error=ErrorData(code=-32600, message="Invalid request"), |
| 201 | + ) |
| 202 | + |
| 203 | + with pytest.raises(McpError) as exc_info: |
| 204 | + BaseSession._process_response(error, EmptyResult) |
| 205 | + |
| 206 | + assert exc_info.value.error.code == -32600 |
| 207 | + assert exc_info.value.error.message == "Invalid request" |
| 208 | + |
| 209 | + def test_process_response_with_none(self): |
| 210 | + """ |
| 211 | + Test defensive check for anyio fail_after race condition (#1717). |
| 212 | +
|
| 213 | + If anyio's CancelScope incorrectly suppresses an exception during |
| 214 | + receive(), the response variable may never be assigned. This test |
| 215 | + verifies we handle this gracefully instead of raising UnboundLocalError. |
| 216 | +
|
| 217 | + See: https://github.com/agronholm/anyio/issues/589 |
| 218 | + """ |
| 219 | + with pytest.raises(McpError) as exc_info: |
| 220 | + BaseSession._process_response(None, EmptyResult) |
| 221 | + |
| 222 | + assert exc_info.value.error.code == INTERNAL_ERROR |
| 223 | + assert "no response received" in exc_info.value.error.message |
| 224 | + |
| 225 | + |
| 226 | +@pytest.mark.anyio |
| 227 | +async def test_send_request_handles_end_of_stream(): |
| 228 | + """Test that EndOfStream from response stream raises McpError with CONNECTION_CLOSED.""" |
| 229 | + |
| 230 | + async with create_client_server_memory_streams() as (client_streams, _): |
| 231 | + client_read, client_write = client_streams |
| 232 | + |
| 233 | + async with ClientSession(read_stream=client_read, write_stream=client_write) as client_session: |
| 234 | + # Mock create_memory_object_stream to return a stream that raises EndOfStream |
| 235 | + mock_reader = AsyncMock() |
| 236 | + mock_reader.receive = AsyncMock(side_effect=anyio.EndOfStream) |
| 237 | + mock_reader.aclose = AsyncMock() |
| 238 | + |
| 239 | + mock_sender = AsyncMock() |
| 240 | + mock_sender.aclose = AsyncMock() |
| 241 | + |
| 242 | + # The subscripted form returns a callable that returns the tuple |
| 243 | + with patch("mcp.shared.session.anyio.create_memory_object_stream") as mock_create: |
| 244 | + # pyright: ignore[reportUnknownLambdaType] |
| 245 | + mock_create.__getitem__ = lambda _s, _k: lambda _z: (mock_sender, mock_reader) # type: ignore |
| 246 | + |
| 247 | + with pytest.raises(McpError) as exc_info: |
| 248 | + await client_session.send_request( |
| 249 | + ClientRequest(types.PingRequest()), |
| 250 | + EmptyResult, |
| 251 | + ) |
| 252 | + |
| 253 | + assert exc_info.value.error.code == CONNECTION_CLOSED |
| 254 | + assert "stream ended unexpectedly" in exc_info.value.error.message |
| 255 | + |
| 256 | + |
| 257 | +@pytest.mark.anyio |
| 258 | +async def test_send_request_handles_closed_resource_error(): |
| 259 | + """Test that ClosedResourceError from response stream raises McpError with CONNECTION_CLOSED.""" |
| 260 | + |
| 261 | + async with create_client_server_memory_streams() as (client_streams, _): |
| 262 | + client_read, client_write = client_streams |
| 263 | + |
| 264 | + async with ClientSession(read_stream=client_read, write_stream=client_write) as client_session: |
| 265 | + # Mock create_memory_object_stream to return a stream that raises ClosedResourceError |
| 266 | + mock_reader = AsyncMock() |
| 267 | + mock_reader.receive = AsyncMock(side_effect=anyio.ClosedResourceError) |
| 268 | + mock_reader.aclose = AsyncMock() |
| 269 | + |
| 270 | + mock_sender = AsyncMock() |
| 271 | + mock_sender.aclose = AsyncMock() |
| 272 | + |
| 273 | + # The subscripted form returns a callable that returns the tuple |
| 274 | + with patch("mcp.shared.session.anyio.create_memory_object_stream") as mock_create: |
| 275 | + # pyright: ignore[reportUnknownLambdaType] |
| 276 | + mock_create.__getitem__ = lambda _s, _k: lambda _z: (mock_sender, mock_reader) # type: ignore |
| 277 | + |
| 278 | + with pytest.raises(McpError) as exc_info: |
| 279 | + await client_session.send_request( |
| 280 | + ClientRequest(types.PingRequest()), |
| 281 | + EmptyResult, |
| 282 | + ) |
| 283 | + |
| 284 | + assert exc_info.value.error.code == CONNECTION_CLOSED |
| 285 | + assert "Connection closed" in exc_info.value.error.message |
0 commit comments