Skip to content

Commit ee2a010

Browse files
committed
Fix UnboundLocalError in send_request when response not assigned
Fixes an intermittent UnboundLocalError that could occur in send_request() when anyio's fail_after context manager incorrectly suppresses exceptions due to a race condition (anyio#589). Changes: - Initialize response_or_error to None before the try block - Add _process_response static method with defensive None check - Handle EndOfStream and ClosedResourceError exceptions explicitly - Use try/except/else structure for cleaner control flow - Add unit tests for _process_response Closes #1717
1 parent 8e02fc1 commit ee2a010

File tree

2 files changed

+163
-5
lines changed

2 files changed

+163
-5
lines changed

src/mcp/shared/session.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from mcp.shared.response_router import ResponseRouter
1717
from mcp.types import (
1818
CONNECTION_CLOSED,
19+
INTERNAL_ERROR,
1920
INVALID_PARAMS,
2021
CancelledNotification,
2122
ClientNotification,
@@ -237,6 +238,34 @@ async def __aexit__(
237238
self._task_group.cancel_scope.cancel()
238239
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
239240

241+
@staticmethod
242+
def _process_response(
243+
response_or_error: JSONRPCResponse | JSONRPCError | None,
244+
result_type: type[ReceiveResultT],
245+
) -> ReceiveResultT:
246+
"""
247+
Process a JSON-RPC response, validating and returning the result.
248+
249+
Raises McpError if the response is an error or if response_or_error is None.
250+
The None check is a defensive guard against anyio race conditions - see #1717.
251+
"""
252+
if response_or_error is None:
253+
# Defensive check for anyio fail_after race condition (#1717).
254+
# If anyio's CancelScope incorrectly suppresses an exception,
255+
# the response variable may never be assigned. See:
256+
# https://github.com/agronholm/anyio/issues/589
257+
raise McpError(
258+
ErrorData(
259+
code=INTERNAL_ERROR,
260+
message="Internal error: no response received",
261+
)
262+
)
263+
264+
if isinstance(response_or_error, JSONRPCError):
265+
raise McpError(response_or_error.error)
266+
267+
return result_type.model_validate(response_or_error.result)
268+
240269
async def send_request(
241270
self,
242271
request: SendRequestT,
@@ -287,6 +316,10 @@ async def send_request(
287316
elif self._session_read_timeout_seconds is not None: # pragma: no cover
288317
timeout = self._session_read_timeout_seconds.total_seconds()
289318

319+
# Initialize to None as a defensive guard against anyio race conditions
320+
# where fail_after may incorrectly suppress exceptions (#1717)
321+
response_or_error: JSONRPCResponse | JSONRPCError | None = None
322+
290323
try:
291324
with anyio.fail_after(timeout):
292325
response_or_error = await response_stream_reader.receive()
@@ -301,12 +334,22 @@ async def send_request(
301334
),
302335
)
303336
)
304-
305-
if isinstance(response_or_error, JSONRPCError):
306-
raise McpError(response_or_error.error)
337+
except anyio.EndOfStream:
338+
raise McpError(
339+
ErrorData(
340+
code=CONNECTION_CLOSED,
341+
message="Connection closed: stream ended unexpectedly",
342+
)
343+
)
344+
except anyio.ClosedResourceError:
345+
raise McpError(
346+
ErrorData(
347+
code=CONNECTION_CLOSED,
348+
message="Connection closed",
349+
)
350+
)
307351
else:
308-
return result_type.model_validate(response_or_error.result)
309-
352+
return self._process_response(response_or_error, result_type)
310353
finally:
311354
self._response_streams.pop(request_id, None)
312355
self._progress_callbacks.pop(request_id, None)

tests/shared/test_session.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import AsyncGenerator
22
from typing import Any
3+
from unittest.mock import AsyncMock, patch
34

45
import anyio
56
import pytest
@@ -9,12 +10,18 @@
910
from mcp.server.lowlevel.server import Server
1011
from mcp.shared.exceptions import McpError
1112
from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session
13+
from mcp.shared.session import BaseSession
1214
from mcp.types import (
15+
CONNECTION_CLOSED,
16+
INTERNAL_ERROR,
1317
CancelledNotification,
1418
CancelledNotificationParams,
1519
ClientNotification,
1620
ClientRequest,
1721
EmptyResult,
22+
ErrorData,
23+
JSONRPCError,
24+
JSONRPCResponse,
1825
TextContent,
1926
)
2027

@@ -168,3 +175,111 @@ async def mock_server():
168175
await ev_closed.wait()
169176
with anyio.fail_after(1):
170177
await ev_response.wait()
178+
179+
180+
class TestProcessResponse:
181+
"""Tests for BaseSession._process_response static method."""
182+
183+
def test_process_response_with_valid_response(self):
184+
"""Test that a valid JSONRPCResponse is processed correctly."""
185+
response = JSONRPCResponse(
186+
jsonrpc="2.0",
187+
id=1,
188+
result={},
189+
)
190+
191+
result = BaseSession._process_response(response, EmptyResult)
192+
193+
assert isinstance(result, EmptyResult)
194+
195+
def test_process_response_with_error(self):
196+
"""Test that a JSONRPCError raises McpError."""
197+
error = JSONRPCError(
198+
jsonrpc="2.0",
199+
id=1,
200+
error=ErrorData(code=-32600, message="Invalid request"),
201+
)
202+
203+
with pytest.raises(McpError) as exc_info:
204+
BaseSession._process_response(error, EmptyResult)
205+
206+
assert exc_info.value.error.code == -32600
207+
assert exc_info.value.error.message == "Invalid request"
208+
209+
def test_process_response_with_none(self):
210+
"""
211+
Test defensive check for anyio fail_after race condition (#1717).
212+
213+
If anyio's CancelScope incorrectly suppresses an exception during
214+
receive(), the response variable may never be assigned. This test
215+
verifies we handle this gracefully instead of raising UnboundLocalError.
216+
217+
See: https://github.com/agronholm/anyio/issues/589
218+
"""
219+
with pytest.raises(McpError) as exc_info:
220+
BaseSession._process_response(None, EmptyResult)
221+
222+
assert exc_info.value.error.code == INTERNAL_ERROR
223+
assert "no response received" in exc_info.value.error.message
224+
225+
226+
@pytest.mark.anyio
227+
async def test_send_request_handles_end_of_stream():
228+
"""Test that EndOfStream from response stream raises McpError with CONNECTION_CLOSED."""
229+
230+
async with create_client_server_memory_streams() as (client_streams, _):
231+
client_read, client_write = client_streams
232+
233+
async with ClientSession(read_stream=client_read, write_stream=client_write) as client_session:
234+
# Mock create_memory_object_stream to return a stream that raises EndOfStream
235+
mock_reader = AsyncMock()
236+
mock_reader.receive = AsyncMock(side_effect=anyio.EndOfStream)
237+
mock_reader.aclose = AsyncMock()
238+
239+
mock_sender = AsyncMock()
240+
mock_sender.aclose = AsyncMock()
241+
242+
# The subscripted form returns a callable that returns the tuple
243+
with patch("mcp.shared.session.anyio.create_memory_object_stream") as mock_create:
244+
# pyright: ignore[reportUnknownLambdaType]
245+
mock_create.__getitem__ = lambda _s, _k: lambda _z: (mock_sender, mock_reader) # type: ignore
246+
247+
with pytest.raises(McpError) as exc_info:
248+
await client_session.send_request(
249+
ClientRequest(types.PingRequest()),
250+
EmptyResult,
251+
)
252+
253+
assert exc_info.value.error.code == CONNECTION_CLOSED
254+
assert "stream ended unexpectedly" in exc_info.value.error.message
255+
256+
257+
@pytest.mark.anyio
258+
async def test_send_request_handles_closed_resource_error():
259+
"""Test that ClosedResourceError from response stream raises McpError with CONNECTION_CLOSED."""
260+
261+
async with create_client_server_memory_streams() as (client_streams, _):
262+
client_read, client_write = client_streams
263+
264+
async with ClientSession(read_stream=client_read, write_stream=client_write) as client_session:
265+
# Mock create_memory_object_stream to return a stream that raises ClosedResourceError
266+
mock_reader = AsyncMock()
267+
mock_reader.receive = AsyncMock(side_effect=anyio.ClosedResourceError)
268+
mock_reader.aclose = AsyncMock()
269+
270+
mock_sender = AsyncMock()
271+
mock_sender.aclose = AsyncMock()
272+
273+
# The subscripted form returns a callable that returns the tuple
274+
with patch("mcp.shared.session.anyio.create_memory_object_stream") as mock_create:
275+
# pyright: ignore[reportUnknownLambdaType]
276+
mock_create.__getitem__ = lambda _s, _k: lambda _z: (mock_sender, mock_reader) # type: ignore
277+
278+
with pytest.raises(McpError) as exc_info:
279+
await client_session.send_request(
280+
ClientRequest(types.PingRequest()),
281+
EmptyResult,
282+
)
283+
284+
assert exc_info.value.error.code == CONNECTION_CLOSED
285+
assert "Connection closed" in exc_info.value.error.message

0 commit comments

Comments
 (0)