Skip to content

Commit bba229f

Browse files
authored
Merge pull request #17963 from BerriAI/litellm_feat_rest-mcp-list-tools-auth-header
add MCP auth header propagation
2 parents 87ba4fa + 73e00c1 commit bba229f

File tree

2 files changed

+153
-3
lines changed

2 files changed

+153
-3
lines changed

litellm/proxy/_experimental/mcp_server/rest_endpoints.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from litellm._logging import verbose_logger
88
from litellm.proxy._types import UserAPIKeyAuth
99
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
10+
from litellm.types.mcp import MCPAuth
1011

1112
MCP_AVAILABLE: bool = True
1213
try:
@@ -297,6 +298,7 @@ async def call_tool_rest_api(
297298
async def _execute_with_mcp_client(
298299
request: NewMCPServerRequest,
299300
operation,
301+
mcp_auth_header: Optional[Union[str, Dict[str, str]]] = None,
300302
oauth2_headers: Optional[Dict[str, str]] = None,
301303
):
302304
"""
@@ -319,7 +321,7 @@ async def _execute_with_mcp_client(
319321
auth_type=request.auth_type,
320322
mcp_info=request.mcp_info,
321323
),
322-
mcp_auth_header=None,
324+
mcp_auth_header=mcp_auth_header,
323325
extra_headers=oauth2_headers,
324326
)
325327

@@ -365,7 +367,21 @@ async def test_tools_list(
365367
)
366368

367369
headers = request.headers
368-
oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers(headers)
370+
371+
mcp_auth_header: Optional[str] = None
372+
if new_mcp_server_request.auth_type in {
373+
MCPAuth.api_key,
374+
MCPAuth.bearer_token,
375+
MCPAuth.basic,
376+
MCPAuth.authorization,
377+
}:
378+
credentials = getattr(new_mcp_server_request, "credentials", None)
379+
if isinstance(credentials, dict):
380+
mcp_auth_header = credentials.get("auth_value")
381+
382+
oauth2_headers: Optional[Dict[str, str]] = None
383+
if new_mcp_server_request.auth_type == MCPAuth.oauth2:
384+
oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers(headers)
369385

370386
async def _list_tools_operation(client):
371387
async def _list_tools_session_operation(session):
@@ -385,5 +401,8 @@ async def _list_tools_session_operation(session):
385401
}
386402

387403
return await _execute_with_mcp_client(
388-
new_mcp_server_request, _list_tools_operation, oauth2_headers
404+
new_mcp_server_request,
405+
_list_tools_operation,
406+
mcp_auth_header=mcp_auth_header,
407+
oauth2_headers=oauth2_headers,
389408
)
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
from typing import Dict, Optional
2+
3+
import pytest
4+
from starlette.requests import Request
5+
6+
from litellm.proxy._experimental.mcp_server import rest_endpoints
7+
from litellm.proxy._experimental.mcp_server.auth import (
8+
user_api_key_auth_mcp as auth_mcp,
9+
)
10+
from litellm.proxy._types import NewMCPServerRequest, UserAPIKeyAuth
11+
from litellm.types.mcp import MCPAuth
12+
13+
14+
def _build_request(headers: Optional[Dict[str, str]] = None) -> Request:
15+
headers = headers or {}
16+
raw_headers = [
17+
(key.lower().encode("latin-1"), value.encode("latin-1"))
18+
for key, value in headers.items()
19+
]
20+
scope = {
21+
"type": "http",
22+
"http_version": "1.1",
23+
"method": "POST",
24+
"path": "/mcp-rest/test/tools/list",
25+
"headers": raw_headers,
26+
}
27+
28+
async def receive():
29+
return {"type": "http.request", "body": b"", "more_body": False}
30+
31+
return Request(scope, receive=receive)
32+
33+
34+
@pytest.mark.asyncio
35+
async def test_test_tools_list_forwards_mcp_auth_header(monkeypatch):
36+
"""Ensure credential-based auth forwards the auth_value to the MCP client."""
37+
38+
captured: dict = {}
39+
40+
async def fake_execute(request, operation, mcp_auth_header=None, oauth2_headers=None):
41+
captured["mcp_auth_header"] = mcp_auth_header
42+
captured["oauth2_headers"] = oauth2_headers
43+
return {
44+
"tools": [],
45+
"error": None,
46+
"message": "Successfully retrieved tools",
47+
}
48+
49+
monkeypatch.setattr(
50+
rest_endpoints, "_execute_with_mcp_client", fake_execute, raising=False
51+
)
52+
53+
oauth_call_counter = {"count": 0}
54+
55+
def fake_oauth(headers):
56+
oauth_call_counter["count"] += 1
57+
return {"Authorization": "Bearer oauth"}
58+
59+
monkeypatch.setattr(
60+
auth_mcp.MCPRequestHandler,
61+
"_get_oauth2_headers_from_headers",
62+
staticmethod(fake_oauth),
63+
raising=False,
64+
)
65+
66+
request = _build_request()
67+
payload = NewMCPServerRequest(
68+
server_name="example",
69+
url="https://example.com",
70+
auth_type=MCPAuth.api_key,
71+
credentials={"auth_value": "secret-key"},
72+
)
73+
74+
result = await rest_endpoints.test_tools_list(
75+
request, payload, user_api_key_dict=UserAPIKeyAuth()
76+
)
77+
78+
assert result["message"] == "Successfully retrieved tools"
79+
assert captured["mcp_auth_header"] == "secret-key"
80+
assert captured["oauth2_headers"] is None
81+
assert oauth_call_counter["count"] == 0
82+
83+
84+
@pytest.mark.asyncio
85+
async def test_test_tools_list_extracts_oauth2_headers(monkeypatch):
86+
"""Ensure oauth2 auth type pulls oauth headers and omits MCP auth header."""
87+
88+
captured: dict = {}
89+
90+
async def fake_execute(request, operation, mcp_auth_header=None, oauth2_headers=None):
91+
captured["mcp_auth_header"] = mcp_auth_header
92+
captured["oauth2_headers"] = oauth2_headers
93+
return {
94+
"tools": [],
95+
"error": None,
96+
"message": "Successfully retrieved tools",
97+
}
98+
99+
monkeypatch.setattr(
100+
rest_endpoints, "_execute_with_mcp_client", fake_execute, raising=False
101+
)
102+
103+
oauth_headers = {"Authorization": "Bearer oauth"}
104+
oauth_call_counter = {"count": 0}
105+
106+
def fake_oauth(headers):
107+
oauth_call_counter["count"] += 1
108+
return oauth_headers
109+
110+
monkeypatch.setattr(
111+
auth_mcp.MCPRequestHandler,
112+
"_get_oauth2_headers_from_headers",
113+
staticmethod(fake_oauth),
114+
raising=False,
115+
)
116+
117+
request = _build_request({"authorization": "Bearer incoming"})
118+
payload = NewMCPServerRequest(
119+
server_name="example",
120+
url="https://example.com",
121+
auth_type=MCPAuth.oauth2,
122+
)
123+
124+
result = await rest_endpoints.test_tools_list(
125+
request, payload, user_api_key_dict=UserAPIKeyAuth()
126+
)
127+
128+
assert result["message"] == "Successfully retrieved tools"
129+
assert captured["mcp_auth_header"] is None
130+
assert captured["oauth2_headers"] == oauth_headers
131+
assert oauth_call_counter["count"] == 1

0 commit comments

Comments
 (0)