Skip to content

Commit 640dbf5

Browse files
committed
refactor(ScopeMiddleware): use a proper Middleware class for ScopeFiltering
1 parent 5a78ffa commit 640dbf5

File tree

1 file changed

+44
-42
lines changed

1 file changed

+44
-42
lines changed

packages/gg_api_core/src/gg_api_core/mcp_server.py

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@
22

33
import logging
44
from abc import ABC, abstractmethod
5-
from collections.abc import AsyncIterator, Callable
5+
from collections.abc import AsyncIterator, Sequence
66
from contextlib import asynccontextmanager
77
from enum import Enum
88
from typing import Any
99

1010
from fastmcp import FastMCP
1111
from fastmcp.exceptions import FastMCPError, ValidationError
1212
from fastmcp.server.dependencies import get_http_headers
13-
from fastmcp.server.middleware import MiddlewareContext
13+
from fastmcp.server.middleware import Middleware
14+
from fastmcp.tools import Tool
1415
from mcp.types import Tool as MCPTool
1516

16-
from gg_api_core.client import get_personal_access_token_from_env, is_oauth_enabled
17+
from gg_api_core.client import GitGuardianClient, get_personal_access_token_from_env, is_oauth_enabled
1718
from gg_api_core.utils import get_client
1819

1920
# Configure logger
@@ -32,7 +33,11 @@ class AuthenticationMode(Enum):
3233

3334

3435
class CachedTokenInfoMixin:
35-
"""Mixin for MCP servers that are mono-tenant (only one authenticated identity from startup to close of the server)"""
36+
"""Mixin for MCP servers that are mono-tenant (only one authenticated identity from startup to close of the server)
37+
38+
Note: This mixin expects to be used with AbstractGitGuardianFastMCP which provides
39+
_fetch_token_scopes_from_api() and _fetch_token_info_from_api() methods.
40+
"""
3641

3742
_token_scopes: set[str] = set()
3843
_token_info: dict[str, Any] | None = None
@@ -65,7 +70,7 @@ async def token_scope_lifespan(fastmcp) -> AsyncIterator[dict]:
6570

6671
# Cache scopes at startup (single token throughout lifespan)
6772
try:
68-
self._token_scopes = await self._fetch_token_scopes_from_api()
73+
self._token_scopes = await self._fetch_token_scopes_from_api() # type: ignore[attr-defined]
6974
logger.debug(f"Retrieved token scopes: {self._token_scopes}")
7075
except Exception as e:
7176
logger.warning(f"Failed to fetch token scopes during startup: {str(e)}")
@@ -82,11 +87,41 @@ async def get_token_info(self) -> dict[str, Any]:
8287
if self._token_info is not None:
8388
return self._token_info
8489

85-
# Type ignore because this method will be available via MRO from AbstractGitGuardianFastMCP
8690
self._token_info = await self._fetch_token_info_from_api() # type: ignore[attr-defined]
8791
return self._token_info
8892

8993

94+
class ScopeFilteringMiddleware(Middleware):
95+
"""Middleware to filter tools based on token scopes."""
96+
97+
def __init__(self, mcp_server: "AbstractGitGuardianFastMCP"):
98+
self._mcp_server = mcp_server
99+
100+
async def on_list_tools(
101+
self,
102+
context,
103+
call_next,
104+
) -> Sequence[Tool]:
105+
"""Filter tools based on the user's API token scopes."""
106+
# Get all tools from the next middleware/handler
107+
all_tools = await call_next(context)
108+
109+
# Filter tools by scopes
110+
scopes = await self._mcp_server.get_scopes()
111+
filtered_tools: list[Tool] = []
112+
for tool in all_tools:
113+
tool_name = tool.name
114+
required_scopes = self._mcp_server._tool_scopes.get(tool_name, set())
115+
116+
if not required_scopes or required_scopes.issubset(scopes):
117+
filtered_tools.append(tool)
118+
else:
119+
missing_scopes = required_scopes - scopes
120+
logger.info(f"Removing tool '{tool_name}' due to missing scopes: {', '.join(missing_scopes)}")
121+
122+
return filtered_tools
123+
124+
90125
class AbstractGitGuardianFastMCP(FastMCP, ABC):
91126
"""Abstract base class for GitGuardian MCP servers with scope-based tool filtering.
92127
@@ -106,7 +141,7 @@ def __init__(self, *args, default_scopes: list[str] | None = None, **kwargs):
106141
# Map each tool to its required scopes (instance attribute)
107142
self._tool_scopes: dict[str, set[str]] = {}
108143

109-
self.add_middleware(self._scope_filtering_middleware)
144+
self.add_middleware(ScopeFilteringMiddleware(self))
110145

111146
def clear_cache(self) -> None:
112147
"""Clear cached data. Override in subclasses that cache."""
@@ -122,7 +157,7 @@ async def get_token_info(self) -> dict[str, Any]:
122157
"""Return the token info dictionary."""
123158
pass
124159

125-
def get_client(self):
160+
def get_client(self) -> GitGuardianClient:
126161
return get_client(personal_access_token=self.get_personal_access_token())
127162

128163
async def revoke_current_token(self) -> dict[str, Any]:
@@ -199,9 +234,7 @@ async def _fetch_token_scopes_from_api(self, client=None) -> set[str]:
199234
async def _fetch_token_info_from_api(self) -> dict[str, Any]:
200235
try:
201236
client = self.get_client()
202-
token_info = await client.get_current_token_info()
203-
self._token_info = token_info
204-
return token_info
237+
return await client.get_current_token_info()
205238
except Exception as e:
206239
raise FastMCPError("Error fetching token info from /api_tokens/self endpoint") from e
207240

@@ -215,37 +248,6 @@ async def get_scopes(self) -> set[str]:
215248
logger.debug(f"scopes: {scopes}")
216249
return scopes
217250

218-
async def _scope_filtering_middleware(self, context: MiddlewareContext, call_next: Callable) -> Any:
219-
"""Middleware to filter tools based on token scopes.
220-
221-
This middleware intercepts tools/list requests and filters the tools
222-
based on the user's API token scopes.
223-
224-
The authentication strategy determines whether to use cached scopes
225-
or fetch them per-request.
226-
"""
227-
# Only apply filtering to tools/list requests
228-
if context.method != "tools/list":
229-
return await call_next(context)
230-
231-
# Get all tools from the next middleware/handler
232-
all_tools = await call_next(context)
233-
234-
# Filter tools by scopes
235-
scopes = await self.get_scopes()
236-
filtered_tools = []
237-
for tool in all_tools:
238-
tool_name = tool.name
239-
required_scopes = self._tool_scopes.get(tool_name, set())
240-
241-
if not required_scopes or required_scopes.issubset(scopes):
242-
filtered_tools.append(tool)
243-
else:
244-
missing_scopes = required_scopes - scopes
245-
logger.info(f"Removing tool '{tool_name}' due to missing scopes: {', '.join(missing_scopes)}")
246-
247-
return filtered_tools
248-
249251
async def list_tools(self) -> list[MCPTool]:
250252
"""
251253
Public method to list tools (for compatibility with tests and external code).

0 commit comments

Comments
 (0)