From dde19fb371649861abfa437288ec970030f9bafb Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Fri, 16 Jan 2026 17:58:41 +0100 Subject: [PATCH] refactor: use Client class in tests Refactored tests to use the ergonomic Client class instead of the verbose InMemoryTransport + ClientSession pattern: - tests/client/transports/test_memory.py: 3 tests - tests/shared/test_session.py: 2 tests - tests/server/test_cancel_handling.py: 1 test - tests/shared/test_progress_notifications.py: 1 test - tests/client/test_list_methods_cursor.py: 6 tests (including stream_spy test) The stream_spy fixture continues to work with Client since it patches the underlying memory streams used by InMemoryTransport. Github-Issue: #1891 --- tests/client/test_list_methods_cursor.py | 85 ++++++-------- tests/client/transports/test_memory.py | 41 ++----- tests/server/test_cancel_handling.py | 98 +++++++--------- tests/shared/test_progress_notifications.py | 42 +++---- tests/shared/test_session.py | 124 +++++++++----------- 5 files changed, 163 insertions(+), 227 deletions(-) diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index a5f79910f..2d2b8f823 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -3,8 +3,7 @@ import pytest import mcp.types as types -from mcp.client._memory import InMemoryTransport -from mcp.client.session import ClientSession +from mcp import Client from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.types import ListToolsRequest, ListToolsResult @@ -66,49 +65,43 @@ async def test_list_methods_params_parameter( See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format """ - transport = InMemoryTransport(full_featured_server) - async with transport.connect() as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - spies = stream_spy() - - # Test without params (omitted) - method = getattr(session, method_name) - _ = await method() - requests = spies.get_client_requests(method=request_method) - assert len(requests) == 1 - assert requests[0].params is None - - spies.clear() - - # Test with params containing cursor - _ = await method(params=types.PaginatedRequestParams(cursor="from_params")) - requests = spies.get_client_requests(method=request_method) - assert len(requests) == 1 - assert requests[0].params is not None - assert requests[0].params["cursor"] == "from_params" - - spies.clear() - - # Test with empty params - _ = await method(params=types.PaginatedRequestParams()) - requests = spies.get_client_requests(method=request_method) - assert len(requests) == 1 - # Empty params means no cursor - assert requests[0].params is None or "cursor" not in requests[0].params + async with Client(full_featured_server) as client: + spies = stream_spy() + + # Test without params (omitted) + method = getattr(client, method_name) + _ = await method() + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + assert requests[0].params is None + + spies.clear() + + # Test with params containing cursor + _ = await method(params=types.PaginatedRequestParams(cursor="from_params")) + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + assert requests[0].params is not None + assert requests[0].params["cursor"] == "from_params" + + spies.clear() + + # Test with empty params + _ = await method(params=types.PaginatedRequestParams()) + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + # Empty params means no cursor + assert requests[0].params is None or "cursor" not in requests[0].params async def test_list_tools_with_strict_server_validation( full_featured_server: FastMCP, ): """Test pagination with a server that validates request format strictly.""" - transport = InMemoryTransport(full_featured_server) - async with transport.connect() as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - result = await session.list_tools(params=types.PaginatedRequestParams()) - assert isinstance(result, ListToolsResult) - assert len(result.tools) > 0 + async with Client(full_featured_server) as client: + result = await client.list_tools(params=types.PaginatedRequestParams()) + assert isinstance(result, ListToolsResult) + assert len(result.tools) > 0 async def test_list_tools_with_lowlevel_server(): @@ -129,13 +122,9 @@ async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: ] ) - transport = InMemoryTransport(server) - async with transport.connect() as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - result = await session.list_tools(params=types.PaginatedRequestParams()) - assert result.tools[0].description == "cursor=None" + async with Client(server) as client: + result = await client.list_tools(params=types.PaginatedRequestParams()) + assert result.tools[0].description == "cursor=None" - result = await session.list_tools(params=types.PaginatedRequestParams(cursor="page2")) - assert result.tools[0].description == "cursor=page2" + result = await client.list_tools(params=types.PaginatedRequestParams(cursor="page2")) + assert result.tools[0].description == "cursor=page2" diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index fbaf9a982..b97ebcea2 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -2,6 +2,7 @@ import pytest +from mcp import Client from mcp.client._memory import InMemoryTransport from mcp.server import Server from mcp.server.fastmcp import FastMCP @@ -69,42 +70,26 @@ async def test_with_fastmcp(fastmcp_server: FastMCP): async def test_server_is_running(fastmcp_server: FastMCP): """Test that the server is running and responding to requests.""" - from mcp.client.session import ClientSession - - transport = InMemoryTransport(fastmcp_server) - async with transport.connect() as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - result = await session.initialize() - assert result is not None - assert result.server_info.name == "test" + async with Client(fastmcp_server) as client: + assert client.server_capabilities is not None async def test_list_tools(fastmcp_server: FastMCP): """Test listing tools through the transport.""" - from mcp.client.session import ClientSession - - transport = InMemoryTransport(fastmcp_server) - async with transport.connect() as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - tools_result = await session.list_tools() - assert len(tools_result.tools) > 0 - tool_names = [t.name for t in tools_result.tools] - assert "greet" in tool_names + async with Client(fastmcp_server) as client: + tools_result = await client.list_tools() + assert len(tools_result.tools) > 0 + tool_names = [t.name for t in tools_result.tools] + assert "greet" in tool_names async def test_call_tool(fastmcp_server: FastMCP): """Test calling a tool through the transport.""" - from mcp.client.session import ClientSession - - transport = InMemoryTransport(fastmcp_server) - async with transport.connect() as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - result = await session.call_tool("greet", {"name": "World"}) - assert result is not None - assert len(result.content) > 0 - assert "Hello, World!" in str(result.content[0]) + async with Client(fastmcp_server) as client: + result = await client.call_tool("greet", {"name": "World"}) + assert result is not None + assert len(result.content) > 0 + assert "Hello, World!" in str(result.content[0]) async def test_raise_exceptions(fastmcp_server: FastMCP): diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 8f109d9fb..aa8d42261 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -6,8 +6,7 @@ import pytest import mcp.types as types -from mcp.client._memory import InMemoryTransport -from mcp.client.session import ClientSession +from mcp import Client from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.types import ( @@ -55,61 +54,50 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ return [types.TextContent(type="text", text=f"Call number: {call_count}")] raise ValueError(f"Unknown tool: {name}") # pragma: no cover - transport = InMemoryTransport(server) - async with transport.connect() as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as client: - await client.initialize() - - # First request (will be cancelled) - async def first_request(): - try: - await client.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}), - ) - ), - CallToolResult, - ) - pytest.fail("First request should have been cancelled") # pragma: no cover - except McpError: - pass # Expected - - # Start first request - async with anyio.create_task_group() as tg: - tg.start_soon(first_request) - - # Wait for it to start - await ev_first_call.wait() - - # Cancel it - assert first_request_id is not None - await client.send_notification( - ClientNotification( - CancelledNotification( - params=CancelledNotificationParams( - request_id=first_request_id, - reason="Testing server recovery", - ), + async with Client(server) as client: + # First request (will be cancelled) + async def first_request(): + try: + await client.session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams(name="test_tool", arguments={}), ) - ) + ), + CallToolResult, ) - - # Second request (should work normally) - result = await client.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}), + pytest.fail("First request should have been cancelled") # pragma: no cover + except McpError: + pass # Expected + + # Start first request + async with anyio.create_task_group() as tg: + tg.start_soon(first_request) + + # Wait for it to start + await ev_first_call.wait() + + # Cancel it + assert first_request_id is not None + await client.session.send_notification( + ClientNotification( + CancelledNotification( + params=CancelledNotificationParams( + request_id=first_request_id, + reason="Testing server recovery", + ), ) - ), - CallToolResult, + ) ) - # Verify second request completed successfully - assert len(result.content) == 1 - # Type narrowing for pyright - content = result.content[0] - assert content.type == "text" - assert isinstance(content, types.TextContent) - assert content.text == "Call number: 2" - assert call_count == 2 + # Second request (should work normally) + result = await client.call_tool("test_tool", {}) + + # Verify second request completed successfully + assert len(result.content) == 1 + # Type narrowing for pyright + content = result.content[0] + assert content.type == "text" + assert isinstance(content, types.TextContent) + assert content.text == "Call number: 2" + assert call_count == 2 diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 1d7de0b34..78896397b 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -5,7 +5,7 @@ import pytest import mcp.types as types -from mcp.client._memory import InMemoryTransport +from mcp import Client from mcp.client.session import ClientSession from mcp.server import Server from mcp.server.lowlevel import NotificationOptions @@ -369,30 +369,20 @@ async def handle_list_tools() -> list[types.Tool]: # Test with mocked logging with patch("mcp.shared.session.logging.error", side_effect=mock_log_error): - transport = InMemoryTransport(server) - async with transport.connect() as (read_stream, write_stream): - async with ClientSession( # pragma: no branch - read_stream=read_stream, write_stream=write_stream - ) as session: - await session.initialize() - # Send a request with a failing progress callback - result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams(name="progress_tool", arguments={}), - ) - ), - types.CallToolResult, - progress_callback=failing_progress_callback, - ) + async with Client(server) as client: + # Call tool with a failing progress callback + result = await client.call_tool( + "progress_tool", + arguments={}, + progress_callback=failing_progress_callback, + ) - # Verify the request completed successfully despite the callback failure - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, types.TextContent) - assert content.text == "progress_result" + # Verify the request completed successfully despite the callback failure + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, types.TextContent) + assert content.text == "progress_result" - # Check that a warning was logged for the progress callback exception - assert len(logged_errors) > 0 - assert any("Progress callback raised an exception" in warning for warning in logged_errors) + # Check that a warning was logged for the progress callback exception + assert len(logged_errors) > 0 + assert any("Progress callback raised an exception" in warning for warning in logged_errors) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index be5ab4862..8b4ebd81f 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -4,7 +4,7 @@ import pytest import mcp.types as types -from mcp.client._memory import InMemoryTransport +from mcp import Client from mcp.client.session import ClientSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError @@ -25,68 +25,55 @@ ) -@pytest.fixture -def mcp_server() -> Server: - return Server(name="test server") - - @pytest.mark.anyio -async def test_in_flight_requests_cleared_after_completion(mcp_server: Server): +async def test_in_flight_requests_cleared_after_completion(): """Verify that _in_flight is empty after all requests complete.""" - transport = InMemoryTransport(mcp_server) - async with transport.connect() as (read_stream, write_stream): - async with ClientSession(read_stream=read_stream, write_stream=write_stream) as session: - await session.initialize() - - # Send a request and wait for response - response = await session.send_ping() - assert isinstance(response, EmptyResult) + server = Server(name="test server") + async with Client(server) as client: + # Send a request and wait for response + response = await client.send_ping() + assert isinstance(response, EmptyResult) - # Verify _in_flight is empty - assert len(session._in_flight) == 0 + # Verify _in_flight is empty + assert len(client.session._in_flight) == 0 @pytest.mark.anyio async def test_request_cancellation(): """Test that requests can be cancelled while in-flight.""" - # The tool is already registered in the fixture - ev_tool_called = anyio.Event() ev_cancelled = anyio.Event() request_id = None - # Start the request in a separate task so we can cancel it - def make_server() -> Server: - server = Server(name="TestSessionServer") - - # Register the tool handler - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: - nonlocal request_id, ev_tool_called - if name == "slow_tool": - request_id = server.request_context.request_id - ev_tool_called.set() - await anyio.sleep(10) # Long enough to ensure we can cancel - return [] # pragma: no cover - raise ValueError(f"Unknown tool: {name}") # pragma: no cover - - # Register the tool so it shows up in list_tools - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="slow_tool", - description="A slow tool that takes 10 seconds to complete", - input_schema={}, - ) - ] - - return server + # Create a server with a slow tool + server = Server(name="TestSessionServer") + + # Register the tool handler + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: + nonlocal request_id, ev_tool_called + if name == "slow_tool": + request_id = server.request_context.request_id + ev_tool_called.set() + await anyio.sleep(10) # Long enough to ensure we can cancel + return [] # pragma: no cover + raise ValueError(f"Unknown tool: {name}") # pragma: no cover + + # Register the tool so it shows up in list_tools + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="slow_tool", + description="A slow tool that takes 10 seconds to complete", + input_schema={}, + ) + ] - async def make_request(session: ClientSession): + async def make_request(client: Client): nonlocal ev_cancelled try: - await session.send_request( + await client.session.send_request( ClientRequest( types.CallToolRequest( params=types.CallToolRequestParams(name="slow_tool", arguments={}), @@ -100,31 +87,28 @@ async def make_request(session: ClientSession): assert "Request cancelled" in str(e) ev_cancelled.set() - transport = InMemoryTransport(make_server()) - async with transport.connect() as (read_stream, write_stream): - async with ClientSession(read_stream=read_stream, write_stream=write_stream) as session: - await session.initialize() - async with anyio.create_task_group() as tg: # pragma: no branch - tg.start_soon(make_request, session) - - # Wait for the request to be in-flight - with anyio.fail_after(1): # Timeout after 1 second - await ev_tool_called.wait() - - # Send cancellation notification - assert request_id is not None - await session.send_notification( - ClientNotification( - CancelledNotification( - params=CancelledNotificationParams(request_id=request_id), - ) + async with Client(server) as client: + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(make_request, client) + + # Wait for the request to be in-flight + with anyio.fail_after(1): # Timeout after 1 second + await ev_tool_called.wait() + + # Send cancellation notification + assert request_id is not None + await client.session.send_notification( + ClientNotification( + CancelledNotification( + params=CancelledNotificationParams(request_id=request_id), ) ) + ) - # Give cancellation time to process - # TODO(Marcelo): Drop the pragma once https://github.com/coveragepy/coveragepy/issues/1987 is fixed. - with anyio.fail_after(1): # pragma: no cover - await ev_cancelled.wait() + # Give cancellation time to process + # TODO(Marcelo): Drop the pragma once https://github.com/coveragepy/coveragepy/issues/1987 is fixed. + with anyio.fail_after(1): # pragma: no cover + await ev_cancelled.wait() @pytest.mark.anyio