22
33import logging
44from abc import ABC , abstractmethod
5- from collections .abc import AsyncIterator , Callable
5+ from collections .abc import AsyncIterator , Callable , Sequence
66from contextlib import asynccontextmanager
77from enum import Enum
88from typing import Any
99
1010from fastmcp import FastMCP
1111from fastmcp .exceptions import FastMCPError , ValidationError
1212from fastmcp .server .dependencies import get_http_headers
13- from fastmcp .server .middleware import MiddlewareContext
13+ from fastmcp .server .middleware import Middleware
14+ from fastmcp .tools import Tool
1415from mcp .types import Tool as MCPTool
1516
16- from gg_api_core .client import get_personal_access_token_from_env , is_oauth_enabled
17+ from gg_api_core .client import GitGuardianClient , get_personal_access_token_from_env , is_oauth_enabled
1718from gg_api_core .utils import get_client
1819
1920# Configure logger
@@ -32,7 +33,11 @@ class AuthenticationMode(Enum):
3233
3334
3435class CachedTokenInfoMixin :
35- """Mixin for MCP servers that are mono-tenant (only one authenticated identity from startup to close of the server)"""
36+ """Mixin for MCP servers that are mono-tenant (only one authenticated identity from startup to close of the server)
37+
38+ Note: This mixin expects to be used with AbstractGitGuardianFastMCP which provides
39+ _fetch_token_scopes_from_api() and _fetch_token_info_from_api() methods.
40+ """
3641
3742 _token_scopes : set [str ] = set ()
3843 _token_info : dict [str , Any ] | None = None
@@ -65,7 +70,7 @@ async def token_scope_lifespan(fastmcp) -> AsyncIterator[dict]:
6570
6671 # Cache scopes at startup (single token throughout lifespan)
6772 try :
68- self ._token_scopes = await self ._fetch_token_scopes_from_api ()
73+ self ._token_scopes = await self ._fetch_token_scopes_from_api () # type: ignore[attr-defined]
6974 logger .debug (f"Retrieved token scopes: { self ._token_scopes } " )
7075 except Exception as e :
7176 logger .warning (f"Failed to fetch token scopes during startup: { str (e )} " )
@@ -82,11 +87,41 @@ async def get_token_info(self) -> dict[str, Any]:
8287 if self ._token_info is not None :
8388 return self ._token_info
8489
85- # Type ignore because this method will be available via MRO from AbstractGitGuardianFastMCP
8690 self ._token_info = await self ._fetch_token_info_from_api () # type: ignore[attr-defined]
8791 return self ._token_info
8892
8993
94+ class ScopeFilteringMiddleware (Middleware ):
95+ """Middleware to filter tools based on token scopes."""
96+
97+ def __init__ (self , mcp_server : "AbstractGitGuardianFastMCP" ):
98+ self ._mcp_server = mcp_server
99+
100+ async def on_list_tools (
101+ self ,
102+ context ,
103+ call_next ,
104+ ) -> Sequence [Tool ]:
105+ """Filter tools based on the user's API token scopes."""
106+ # Get all tools from the next middleware/handler
107+ all_tools = await call_next (context )
108+
109+ # Filter tools by scopes
110+ scopes = await self ._mcp_server .get_scopes ()
111+ filtered_tools : list [Tool ] = []
112+ for tool in all_tools :
113+ tool_name = tool .name
114+ required_scopes = self ._mcp_server ._tool_scopes .get (tool_name , set ())
115+
116+ if not required_scopes or required_scopes .issubset (scopes ):
117+ filtered_tools .append (tool )
118+ else :
119+ missing_scopes = required_scopes - scopes
120+ logger .info (f"Removing tool '{ tool_name } ' due to missing scopes: { ', ' .join (missing_scopes )} " )
121+
122+ return filtered_tools
123+
124+
90125class AbstractGitGuardianFastMCP (FastMCP , ABC ):
91126 """Abstract base class for GitGuardian MCP servers with scope-based tool filtering.
92127
@@ -106,7 +141,7 @@ def __init__(self, *args, default_scopes: list[str] | None = None, **kwargs):
106141 # Map each tool to its required scopes (instance attribute)
107142 self ._tool_scopes : dict [str , set [str ]] = {}
108143
109- self .add_middleware (self . _scope_filtering_middleware )
144+ self .add_middleware (ScopeFilteringMiddleware ( self ) )
110145
111146 def clear_cache (self ) -> None :
112147 """Clear cached data. Override in subclasses that cache."""
@@ -122,7 +157,7 @@ async def get_token_info(self) -> dict[str, Any]:
122157 """Return the token info dictionary."""
123158 pass
124159
125- def get_client (self ):
160+ def get_client (self ) -> GitGuardianClient :
126161 return get_client (personal_access_token = self .get_personal_access_token ())
127162
128163 async def revoke_current_token (self ) -> dict [str , Any ]:
@@ -199,9 +234,7 @@ async def _fetch_token_scopes_from_api(self, client=None) -> set[str]:
199234 async def _fetch_token_info_from_api (self ) -> dict [str , Any ]:
200235 try :
201236 client = self .get_client ()
202- token_info = await client .get_current_token_info ()
203- self ._token_info = token_info
204- return token_info
237+ return await client .get_current_token_info ()
205238 except Exception as e :
206239 raise FastMCPError ("Error fetching token info from /api_tokens/self endpoint" ) from e
207240
@@ -215,37 +248,6 @@ async def get_scopes(self) -> set[str]:
215248 logger .debug (f"scopes: { scopes } " )
216249 return scopes
217250
218- async def _scope_filtering_middleware (self , context : MiddlewareContext , call_next : Callable ) -> Any :
219- """Middleware to filter tools based on token scopes.
220-
221- This middleware intercepts tools/list requests and filters the tools
222- based on the user's API token scopes.
223-
224- The authentication strategy determines whether to use cached scopes
225- or fetch them per-request.
226- """
227- # Only apply filtering to tools/list requests
228- if context .method != "tools/list" :
229- return await call_next (context )
230-
231- # Get all tools from the next middleware/handler
232- all_tools = await call_next (context )
233-
234- # Filter tools by scopes
235- scopes = await self .get_scopes ()
236- filtered_tools = []
237- for tool in all_tools :
238- tool_name = tool .name
239- required_scopes = self ._tool_scopes .get (tool_name , set ())
240-
241- if not required_scopes or required_scopes .issubset (scopes ):
242- filtered_tools .append (tool )
243- else :
244- missing_scopes = required_scopes - scopes
245- logger .info (f"Removing tool '{ tool_name } ' due to missing scopes: { ', ' .join (missing_scopes )} " )
246-
247- return filtered_tools
248-
249251 async def list_tools (self ) -> list [MCPTool ]:
250252 """
251253 Public method to list tools (for compatibility with tests and external code).
0 commit comments