Skip to content

Commit 3a7031d

Browse files
committed
feat: normalize http responses and types
1 parent addc36a commit 3a7031d

17 files changed

+879
-463
lines changed

packages/gg_api_core/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,5 @@ sentry = [
3737
]
3838

3939
[build-system]
40-
requires = ["hatchling"]
41-
build-backend = "hatchling.build"
40+
requires = ["setuptools>=42", "wheel"]
41+
build-backend = "setuptools.build_meta"

packages/gg_api_core/src/gg_api_core/client.py

Lines changed: 222 additions & 124 deletions
Large diffs are not rendered by default.

packages/gg_api_core/src/gg_api_core/mcp_server.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,19 @@ class CachedTokenInfoMixin:
3535
"""Mixin for MCP servers that are mono-tenant (only one authenticated identity from startup to close of the server)"""
3636

3737
_token_scopes: set[str] = set()
38-
_token_info = None
38+
_token_info: dict[str, Any] | None = None
3939

4040
def __init__(self, *args, **kwargs):
4141
# Add a custom lifespan contextmanager that fetches and cache token scopes and infos
4242
original_lifespan = kwargs.get("lifespan")
4343
kwargs["lifespan"] = self._create_token_scope_lifespan(original_lifespan)
4444
# Call parent __init__ in the MRO chain
4545
super().__init__(*args, **kwargs)
46+
47+
def clear_cache(self) -> None:
48+
"""Clear cached token information and scopes."""
49+
self._token_scopes = set()
50+
self._token_info = None
4651

4752
def _create_token_scope_lifespan(self, original_lifespan=None):
4853
"""Create a lifespan context manager that fetches token scopes."""
@@ -72,12 +77,14 @@ async def token_scope_lifespan(fastmcp) -> AsyncIterator[dict]:
7277

7378
return token_scope_lifespan
7479

75-
async def get_token_info(self) -> dict:
80+
async def get_token_info(self) -> dict[str, Any]:
7681
"""Return the token info dictionary."""
77-
if token_info := getattr(self, "_token_info", None):
78-
return token_info
82+
if self._token_info is not None:
83+
return self._token_info
7984

80-
return await self._fetch_token_info_from_api()
85+
# Type ignore because this method will be available via MRO from AbstractGitGuardianFastMCP
86+
self._token_info = await self._fetch_token_info_from_api() # type: ignore[attr-defined]
87+
return self._token_info
8188

8289

8390
class AbstractGitGuardianFastMCP(FastMCP, ABC):
@@ -87,6 +94,8 @@ class AbstractGitGuardianFastMCP(FastMCP, ABC):
8794
Subclasses implement authentication-specific behavior.
8895
"""
8996

97+
authentication_mode: AuthenticationMode
98+
9099
def __init__(self, *args, default_scopes: list[str] | None = None, **kwargs):
91100
"""
92101
Initialize the GitGuardian MCP server.
@@ -97,27 +106,31 @@ def __init__(self, *args, default_scopes: list[str] | None = None, **kwargs):
97106
# Map each tool to its required scopes (instance attribute)
98107
self._tool_scopes: dict[str, set[str]] = {}
99108

100-
self.add_middleware(self._scope_filtering_middleware)
109+
self.add_middleware(self._scope_filtering_middleware)
110+
111+
def clear_cache(self) -> None:
112+
"""Clear cached data. Override in subclasses that cache."""
113+
pass
101114

102115
@abstractmethod
103-
def get_personal_access_token(self):
116+
def get_personal_access_token(self) -> str | None:
104117
"""Get the personal access token for the current request"""
105118
pass
106119

107120
@abstractmethod
108-
async def get_token_info(self):
121+
async def get_token_info(self) -> dict[str, Any]:
109122
"""Return the token info dictionary."""
110123
pass
111124

112125
def get_client(self):
113126
return get_client(personal_access_token=self.get_personal_access_token())
114127

115-
async def revoke_current_token(self) -> dict:
128+
async def revoke_current_token(self) -> dict[str, Any]:
116129
"""Revoke the current API token via GitGuardian API."""
117130
try:
118131
logger.debug("Revoking current API token")
119132
# Call the DELETE /api_tokens/self endpoint
120-
result = await self.get_client()._request("DELETE", "/api_tokens/self")
133+
result = await self.get_client().revoke_current_token()
121134
logger.debug("API token revoked")
122135
return result
123136
except Exception as e:
@@ -157,7 +170,7 @@ def wrapper(fn):
157170

158171
return result
159172

160-
async def _fetch_token_scopes_from_api(self, client=None):
173+
async def _fetch_token_scopes_from_api(self, client=None) -> set[str]:
161174
"""Fetch token scopes from the GitGuardian API.
162175
163176
Args:
@@ -183,19 +196,20 @@ async def _fetch_token_scopes_from_api(self, client=None):
183196
except Exception as e:
184197
raise FastMCPError("Error fetching token scopes from /api_tokens/self endpoint") from e
185198

186-
async def _fetch_token_info_from_api(self) -> dict:
199+
async def _fetch_token_info_from_api(self) -> dict[str, Any]:
187200
try:
188201
client = self.get_client()
189202
token_info = await client.get_current_token_info()
190203
self._token_info = token_info
191204
return token_info
192205
except Exception as e:
193-
raise FastMCPError("Error fetching token scopes from /api_tokens/self endpoint") from e
206+
raise FastMCPError("Error fetching token info from /api_tokens/self endpoint") from e
194207

195-
async def get_scopes(self):
196-
if scopes := getattr(self, "_token_scopes", None):
208+
async def get_scopes(self) -> set[str]:
209+
cached_scopes: set[str] | None = getattr(self, "_token_scopes", None)
210+
if cached_scopes:
197211
logger.debug("reading from cached scopes")
198-
return scopes
212+
return cached_scopes
199213

200214
scopes = await self._fetch_token_scopes_from_api()
201215
logger.debug(f"scopes: {scopes}")
@@ -242,7 +256,7 @@ async def list_tools(self) -> list[MCPTool]:
242256

243257

244258
# Common MCP tools for user information and token management
245-
def register_common_tools(mcp_instance: "AbstractGitGuardianFastMCP"):
259+
def register_common_tools(mcp_instance: AbstractGitGuardianFastMCP):
246260
"""Register common MCP tools for user information and token management."""
247261

248262
logger.debug("Registering common MCP tools...")
@@ -251,42 +265,37 @@ def register_common_tools(mcp_instance: "AbstractGitGuardianFastMCP"):
251265
name="get_authenticated_user_info",
252266
description="Get comprehensive information about the authenticated user and current API token including scopes and authentication method",
253267
)
254-
async def get_authenticated_user_info() -> dict:
268+
async def get_authenticated_user_info() -> dict[str, Any]:
255269
"""Get information about the authenticated user and current API token."""
256270
logger.debug("Getting authenticated user information")
257271

258272
token_info = await mcp_instance.get_token_info()
259273
scopes = await mcp_instance.get_scopes()
260274
return {
261275
"token_info": token_info,
262-
"authentication_mode": mcp_instance.authentication_mode,
276+
"authentication_mode": mcp_instance.authentication_mode.value,
263277
"available_scopes": list(scopes),
264278
}
265279

266280
@mcp_instance.tool(
267281
name="revoke_current_token",
268282
description="Revoke the current API token and clean up stored credentials",
269283
)
270-
async def revoke_current_token() -> dict:
284+
async def revoke_current_token() -> dict[str, Any]:
271285
"""Revoke the current API token and clean up stored credentials."""
272286
logger.debug("Starting token revocation process")
273287

274288
try:
275-
client = mcp_instance.client
276-
await client._request("DELETE", "/api_tokens/self")
289+
await mcp_instance.revoke_current_token()
277290
logger.debug("Token revoked via API")
278291

279-
# Clear cached client and token info
280-
mcp_instance._client = None
281-
mcp_instance._token_info = None
282-
# Only clear _token_scopes if it exists (cached modes only)
283-
if hasattr(mcp_instance, "_token_scopes"):
284-
mcp_instance._token_scopes = set()
292+
# Clear cached data
293+
mcp_instance.clear_cache()
285294

286295
return {
287296
"success": True,
288297
"message": "Token revoked and credentials cleaned up",
289-
"authentication_method": mcp_instance._get_auth_method(),
298+
"authentication_method": mcp_instance.authentication_mode.value,
290299
}
291300

292301
except Exception as e:
@@ -304,7 +313,7 @@ class GitGuardianLocalOAuthMCP(CachedTokenInfoMixin, AbstractGitGuardianFastMCP)
304313

305314
authentication_mode = AuthenticationMode.LOCAL_OAUTH_FLOW
306315

307-
def get_personal_access_token(self):
316+
def get_personal_access_token(self) -> str | None:
308317
# It will be actually provided within the client by the OAuth flow, or from the filesystem storage
309318
return None
310319

@@ -318,16 +327,16 @@ def __init__(self, *args, personal_access_token: str, **kwargs):
318327
super().__init__(*args, **kwargs)
319328
self.personal_access_token = personal_access_token
320329

321-
def get_personal_access_token(self):
330+
def get_personal_access_token(self) -> str:
322331
return self.personal_access_token
323332

324333

325334
class GitGuardianAuthorizationHeaderMCP(AbstractGitGuardianFastMCP):
326335
"""GitGuardian MCP server using per-request Authorization header (HTTP/SSE mode)."""
327336

328-
authentication_mode = (AuthenticationMode.AUTHORIZATION_HEADER,)
337+
authentication_mode = AuthenticationMode.AUTHORIZATION_HEADER
329338

330-
def get_personal_access_token(self):
339+
def get_personal_access_token(self) -> str:
331340
headers = get_http_headers()
332341
if not headers:
333342
raise ValidationError("No HTTP headers available - Authorization header required")
@@ -364,7 +373,7 @@ def _default_extract_token(auth_header: str) -> str | None:
364373

365374
return None
366375

367-
async def get_token_info(self):
376+
async def get_token_info(self) -> dict[str, Any]:
368377
return await self._fetch_token_info_from_api()
369378

370379

packages/gg_api_core/src/gg_api_core/tools/assign_incident.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,8 @@ async def assign_incident(params: AssignIncidentParams) -> AssignIncidentResult:
9595
logger.debug(f"Looking up member ID for email: {params.email}")
9696
try:
9797
# Use the /members endpoint to search by email
98-
result, _ = await client._request("GET", "/members", params={"search": params.email}, return_headers=True)
99-
100-
# Handle response format
101-
if isinstance(result, dict):
102-
members = cast(list[dict[str, Any]], result.get("results", result.get("data", [])))
103-
elif isinstance(result, list):
104-
members = result
105-
else:
106-
raise ToolError(f"Unexpected response format when searching for member: {type(result).__name__}")
98+
result = await client._request_list("/members", params={"search": params.email})
99+
members = result["data"]
107100

108101
# Find exact email match
109102
matching_member: dict[str, Any] | None = None
@@ -150,14 +143,14 @@ async def assign_incident(params: AssignIncidentParams) -> AssignIncidentResult:
150143

151144
try:
152145
# Call the client method
153-
result = await client.assign_incident(incident_id=str(params.incident_id), assignee_id=str(assignee_id))
146+
api_result = await client.assign_incident(incident_id=str(params.incident_id), assignee_id=str(assignee_id))
154147

155148
logger.debug(f"Successfully assigned incident {params.incident_id} to member {assignee_id}")
156149

157150
# Parse the response
158-
if isinstance(result, dict):
151+
if isinstance(api_result, dict):
159152
# Remove assignee_id from result dict to avoid conflict with our explicit parameter
160-
result_copy = result.copy()
153+
result_copy = api_result.copy()
161154
result_copy.pop("assignee_id", None)
162155
return AssignIncidentResult(
163156
incident_id=params.incident_id, assignee_id=assignee_id, success=True, **result_copy

packages/gg_api_core/src/gg_api_core/tools/generate_honey_token.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,7 @@ async def generate_honeytoken(params: GenerateHoneytokenParams) -> GenerateHoney
8383
result = await client.list_honeytokens(**filters)
8484

8585
# Process the result to get the list of tokens
86-
if isinstance(result, dict):
87-
honeytokens = result.get("honeytokens", [])
88-
else:
89-
honeytokens = result
86+
honeytokens = result.get("data", [])
9087

9188
# Find the most recent active token
9289
if honeytokens:
@@ -111,22 +108,22 @@ async def generate_honeytoken(params: GenerateHoneytokenParams) -> GenerateHoney
111108
{"key": "source", "value": "auto-generated"},
112109
{"key": "type", "value": "aws"},
113110
]
114-
result = await client.create_honeytoken(
111+
creation_result = await client.create_honeytoken(
115112
name=params.name, description=params.description, custom_tags=custom_tags
116113
)
117114

118115
# Validate that we got an ID in the response
119-
if not result.get("id"):
116+
if not creation_result.get("id"):
120117
raise ToolError("Failed to get honeytoken ID from GitGuardian API")
121118

122-
logger.debug(f"Generated new honeytoken with ID: {result.get('id')}")
119+
logger.debug(f"Generated new honeytoken with ID: {creation_result.get('id')}")
123120

124121
# Add injection recommendations to the response
125-
result["injection_recommendations"] = {
122+
creation_result["injection_recommendations"] = {
126123
"instructions": "Add the honeytoken to your codebase in configuration files, environment variables, or code comments to detect unauthorized access."
127124
}
128125

129-
return GenerateHoneytokenResult(**result)
126+
return GenerateHoneytokenResult(**creation_result)
130127
except Exception as e:
131128
logger.error(f"Error generating honeytoken: {str(e)}")
132129
raise ToolError(f"Failed to generate honeytoken: {str(e)}")

packages/gg_api_core/src/gg_api_core/tools/list_honeytokens.py

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,15 @@ class ListHoneytokensParams(BaseModel):
2626
creator_id: str | int | None = Field(default=None, description="Filter by creator ID")
2727
creator_api_token_id: str | int | None = Field(default=None, description="Filter by creator API token ID")
2828
per_page: int = Field(default=20, description="Number of results per page (default: 20, min: 1, max: 100)")
29+
cursor: str | None = Field(default=None, description="Pagination cursor from a previous response")
2930
get_all: bool = Field(default=False, description="If True, fetch all results using cursor-based pagination")
3031

3132

3233
class ListHoneytokensResult(BaseModel):
3334
"""Result from listing honeytokens."""
3435

3536
honeytokens: list[dict[str, Any]] = Field(description="List of honeytoken objects")
37+
next_cursor: str | None = Field(default=None, description="Cursor for fetching the next page (null if no more results)")
3638

3739

3840
async def list_honeytokens(params: ListHoneytokensParams) -> ListHoneytokensResult:
@@ -73,48 +75,25 @@ async def list_honeytokens(params: ListHoneytokensParams) -> ListHoneytokensResu
7375
except Exception as e:
7476
logger.warning(f"Failed to get current user info for 'mine' filter: {str(e)}")
7577

76-
# Build filters dictionary with parameters supported by the client API
77-
filters: dict[str, Any] = {}
78-
if params.status is not None:
79-
filters["status"] = params.status
80-
if params.search is not None:
81-
filters["search"] = params.search
82-
if params.ordering is not None:
83-
filters["ordering"] = params.ordering
84-
if params.show_token is not None:
85-
filters["show_token"] = params.show_token
86-
if creator_id is not None:
87-
filters["creator_id"] = creator_id
88-
if params.creator_api_token_id is not None:
89-
filters["creator_api_token_id"] = params.creator_api_token_id
90-
if params.per_page is not None:
91-
filters["per_page"] = params.per_page
92-
if params.get_all is not None:
93-
filters["get_all"] = params.get_all
94-
95-
logger.debug(f"Filters: {json.dumps({k: v for k, v in filters.items() if v is not None})}")
9678

9779
try:
98-
result = await client.list_honeytokens(
80+
response = await client.list_honeytokens(
9981
status=params.status,
10082
search=params.search,
10183
ordering=params.ordering,
10284
show_token=params.show_token,
103-
creator_id=creator_id,
104-
creator_api_token_id=params.creator_api_token_id,
85+
creator_id=str(creator_id) if creator_id is not None else None,
86+
creator_api_token_id=str(params.creator_api_token_id) if params.creator_api_token_id is not None else None,
10587
per_page=params.per_page,
88+
cursor=params.cursor,
10689
get_all=params.get_all,
10790
)
10891

109-
# Handle both response formats: either a dict with 'honeytokens' key or a list directly
110-
if isinstance(result, dict):
111-
honeytokens = result.get("honeytokens", [])
112-
else:
113-
# If the result is already a list, use it directly
114-
honeytokens = result
115-
116-
logger.debug(f"Found {len(honeytokens)} honeytokens")
117-
return ListHoneytokensResult(honeytokens=honeytokens)
92+
honeytokens_data = response["data"]
93+
next_cursor = response["cursor"]
94+
95+
logger.debug(f"Found {len(honeytokens_data)} honeytokens")
96+
return ListHoneytokensResult(honeytokens=honeytokens_data, next_cursor=next_cursor)
11897
except Exception as e:
11998
logger.error(f"Error listing honeytokens: {str(e)}")
12099
raise ToolError(str(e))

0 commit comments

Comments
 (0)