Skip to content

Commit e62d740

Browse files
committed
Add correlation ID system for unified request tracking
Signed-off-by: Shoumi <shoumimukherjee@gmail.com>
1 parent 04789d8 commit e62d740

File tree

15 files changed

+1226
-13
lines changed

15 files changed

+1226
-13
lines changed

.env.example

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,17 @@ LOG_MAX_SIZE_MB=1
653653
LOG_BACKUP_COUNT=5
654654
LOG_BUFFER_SIZE_MB=1.0
655655

656+
# Correlation ID / Request Tracking
657+
# Enable automatic correlation ID tracking for unified request tracing
658+
# Options: true (default), false
659+
CORRELATION_ID_ENABLED=true
660+
# HTTP header name for correlation ID (default: X-Correlation-ID)
661+
CORRELATION_ID_HEADER=X-Correlation-ID
662+
# Preserve incoming correlation IDs from clients (default: true)
663+
CORRELATION_ID_PRESERVE=true
664+
# Include correlation ID in HTTP response headers (default: true)
665+
CORRELATION_ID_RESPONSE_HEADER=true
666+
656667
# Transport Protocol Configuration
657668
# Options: all (default), sse, streamablehttp, http
658669
# - all: Enable all transport protocols

mcpgateway/auth.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,62 @@
2727
from mcpgateway.db import EmailUser, SessionLocal
2828
from mcpgateway.plugins.framework import get_plugin_manager, GlobalContext, HttpAuthResolveUserPayload, HttpHeaderPayload, HttpHookType, PluginViolationError
2929
from mcpgateway.services.team_management_service import TeamManagementService # pylint: disable=import-outside-toplevel
30+
from mcpgateway.utils.correlation_id import get_correlation_id
3031
from mcpgateway.utils.verify_credentials import verify_jwt_token
3132

3233
# Security scheme
33-
bearer_scheme = HTTPBearer(auto_error=False)
34+
security = HTTPBearer(auto_error=False)
35+
36+
37+
def _log_auth_event(
38+
logger: logging.Logger,
39+
message: str,
40+
level: int = logging.INFO,
41+
user_id: Optional[str] = None,
42+
auth_method: Optional[str] = None,
43+
auth_success: bool = False,
44+
security_event: Optional[str] = None,
45+
security_severity: str = "low",
46+
**extra_context
47+
) -> None:
48+
"""Log authentication event with structured context and request_id.
49+
50+
This helper creates structured log records that include request_id from the
51+
correlation ID context, enabling end-to-end tracing of authentication flows.
52+
53+
Args:
54+
logger: Logger instance to use
55+
message: Log message
56+
level: Log level (default: INFO)
57+
user_id: User identifier
58+
auth_method: Authentication method used (jwt, api_token, etc.)
59+
auth_success: Whether authentication succeeded
60+
security_event: Type of security event (authentication, authorization, etc.)
61+
security_severity: Severity level (low, medium, high, critical)
62+
**extra_context: Additional context fields
63+
"""
64+
# Get request_id from correlation ID context
65+
request_id = get_correlation_id()
66+
67+
# Build structured log record
68+
extra = {
69+
'request_id': request_id,
70+
'entity_type': 'auth',
71+
'auth_success': auth_success,
72+
'security_event': security_event or 'authentication',
73+
'security_severity': security_severity,
74+
}
75+
76+
if user_id:
77+
extra['user_id'] = user_id
78+
if auth_method:
79+
extra['auth_method'] = auth_method
80+
81+
# Add any additional context
82+
extra.update(extra_context)
83+
84+
# Log with structured context
85+
logger.log(level, message, extra=extra)
3486

3587

3688
def get_db() -> Generator[Session, Never, None]:
@@ -169,12 +221,15 @@ async def get_current_user(
169221
if request and hasattr(request, "headers"):
170222
headers = dict(request.headers)
171223

172-
# Get request ID from request state (set by middleware) or generate new one
173-
request_id = None
174-
if request and hasattr(request, "state") and hasattr(request.state, "request_id"):
175-
request_id = request.state.request_id
176-
else:
177-
request_id = uuid.uuid4().hex
224+
# Get request ID from correlation ID context (set by CorrelationIDMiddleware)
225+
request_id = get_correlation_id()
226+
if not request_id:
227+
# Fallback chain for safety
228+
if request and hasattr(request, "state") and hasattr(request.state, "request_id"):
229+
request_id = request.state.request_id
230+
else:
231+
request_id = uuid.uuid4().hex
232+
logger.debug(f"Generated fallback request ID in get_current_user: {request_id}")
178233

179234
# Create global context
180235
global_context = GlobalContext(

mcpgateway/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,12 @@ def _parse_allowed_origins(cls, v: Any) -> Set[str]:
747747
# Enable span events
748748
observability_events_enabled: bool = Field(default=True, description="Enable event logging within spans")
749749

750+
# Correlation ID Settings
751+
correlation_id_enabled: bool = Field(default=True, description="Enable automatic correlation ID tracking for requests")
752+
correlation_id_header: str = Field(default="X-Correlation-ID", description="HTTP header name for correlation ID")
753+
correlation_id_preserve: bool = Field(default=True, description="Preserve correlation IDs from incoming requests")
754+
correlation_id_response_header: bool = Field(default=True, description="Include correlation ID in response headers")
755+
750756
@field_validator("log_level", mode="before")
751757
@classmethod
752758
def validate_log_level(cls, v: str) -> str:

mcpgateway/main.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
from mcpgateway.db import refresh_slugs_on_startup, SessionLocal
7171
from mcpgateway.db import Tool as DbTool
7272
from mcpgateway.handlers.sampling import SamplingHandler
73+
from mcpgateway.middleware.correlation_id import CorrelationIDMiddleware
7374
from mcpgateway.middleware.http_auth_middleware import HttpAuthMiddleware
7475
from mcpgateway.middleware.protocol_version import MCPProtocolVersionMiddleware
7576
from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission
@@ -1065,6 +1066,13 @@ async def _call_streamable_http(self, scope, receive, send):
10651066
# Add HTTP authentication hook middleware for plugins (before auth dependencies)
10661067
if plugin_manager:
10671068
app.add_middleware(HttpAuthMiddleware, plugin_manager=plugin_manager)
1069+
logger.info("🔌 HTTP authentication hooks enabled for plugins")
1070+
1071+
# Add correlation ID middleware if enabled
1072+
# Note: Registered AFTER HttpAuthMiddleware so it executes FIRST (middleware runs in LIFO order)
1073+
if settings.correlation_id_enabled:
1074+
app.add_middleware(CorrelationIDMiddleware)
1075+
logger.info(f"✅ Correlation ID tracking enabled (header: {settings.correlation_id_header})")
10681076

10691077
# Add custom DocsAuthMiddleware
10701078
app.add_middleware(DocsAuthMiddleware)
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# -*- coding: utf-8 -*-
2+
"""Location: ./mcpgateway/middleware/correlation_id.py
3+
Copyright 2025
4+
SPDX-License-Identifier: Apache-2.0
5+
Authors: MCP Gateway Contributors
6+
7+
Correlation ID (Request ID) Middleware.
8+
9+
This middleware handles X-Correlation-ID HTTP headers and maps them to the internal
10+
request_id used throughout the system for unified request tracing.
11+
12+
Key concept: HTTP X-Correlation-ID header → Internal request_id field (single ID for entire request flow)
13+
14+
The middleware automatically extracts or generates request IDs for every HTTP request,
15+
stores them in context variables for async-safe propagation across services, and
16+
injects them back into response headers for client-side correlation.
17+
18+
This enables end-to-end tracing: HTTP → Middleware → Services → Plugins → Logs (all with same request_id)
19+
"""
20+
21+
# Standard
22+
import logging
23+
from typing import Callable
24+
25+
# Third-Party
26+
from fastapi import Request, Response
27+
from starlette.middleware.base import BaseHTTPMiddleware
28+
29+
# First-Party
30+
from mcpgateway.config import settings
31+
from mcpgateway.utils.correlation_id import (
32+
clear_correlation_id,
33+
extract_correlation_id_from_headers,
34+
generate_correlation_id,
35+
set_correlation_id,
36+
)
37+
38+
logger = logging.getLogger(__name__)
39+
40+
41+
class CorrelationIDMiddleware(BaseHTTPMiddleware):
42+
"""Middleware for automatic request ID (correlation ID) handling.
43+
44+
This middleware:
45+
1. Extracts request ID from X-Correlation-ID header in incoming requests
46+
2. Generates a new UUID if no correlation ID is present
47+
3. Stores the ID in context variables for the request lifecycle (used as request_id throughout system)
48+
4. Injects the request ID into X-Correlation-ID response header
49+
5. Cleans up context after request completion
50+
51+
The request ID extracted/generated here becomes the unified request_id used in:
52+
- All log entries (request_id field)
53+
- GlobalContext.request_id (when plugins execute)
54+
- Service method calls for tracing
55+
- Database queries for request tracking
56+
57+
Configuration is controlled via settings:
58+
- correlation_id_enabled: Enable/disable the middleware
59+
- correlation_id_header: Header name to use (default: X-Correlation-ID)
60+
- correlation_id_preserve: Whether to preserve incoming IDs (default: True)
61+
- correlation_id_response_header: Whether to add ID to responses (default: True)
62+
"""
63+
64+
def __init__(self, app):
65+
"""Initialize the correlation ID (request ID) middleware.
66+
67+
Args:
68+
app: The FastAPI application instance
69+
"""
70+
super().__init__(app)
71+
self.header_name = getattr(settings, 'correlation_id_header', 'X-Correlation-ID')
72+
self.preserve_incoming = getattr(settings, 'correlation_id_preserve', True)
73+
self.add_to_response = getattr(settings, 'correlation_id_response_header', True)
74+
75+
async def dispatch(self, request: Request, call_next: Callable) -> Response:
76+
"""Process the request and manage request ID (correlation ID) lifecycle.
77+
78+
Extracts or generates a request ID, stores it in context variables for use throughout
79+
the request lifecycle (becomes request_id in logs, services, plugins), and injects
80+
it back into the X-Correlation-ID response header.
81+
82+
Args:
83+
request: The incoming HTTP request
84+
call_next: The next middleware or route handler
85+
86+
Returns:
87+
Response: The HTTP response with correlation ID header added
88+
"""
89+
# Extract correlation ID from incoming request headers
90+
correlation_id = None
91+
if self.preserve_incoming:
92+
correlation_id = extract_correlation_id_from_headers(
93+
dict(request.headers),
94+
self.header_name
95+
)
96+
97+
# Generate new correlation ID if none was provided
98+
if not correlation_id:
99+
correlation_id = generate_correlation_id()
100+
logger.debug(f"Generated new correlation ID: {correlation_id}")
101+
else:
102+
logger.debug(f"Using client-provided correlation ID: {correlation_id}")
103+
104+
# Store correlation ID in context variable for this request
105+
# This makes it available to all downstream code (auth, services, plugins, logs)
106+
set_correlation_id(correlation_id)
107+
108+
try:
109+
# Process the request
110+
response = await call_next(request)
111+
112+
# Add correlation ID to response headers if enabled
113+
if self.add_to_response:
114+
response.headers[self.header_name] = correlation_id
115+
116+
return response
117+
118+
finally:
119+
# Clean up context after request completes
120+
# Note: ContextVar automatically cleans up, but explicit cleanup is good practice
121+
clear_correlation_id()

mcpgateway/middleware/http_auth_middleware.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
# First-Party
1919
from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, HttpHookType, HttpPostRequestPayload, HttpPreRequestPayload, PluginManager
20+
from mcpgateway.utils.correlation_id import generate_correlation_id, get_correlation_id
2021

2122
logger = logging.getLogger(__name__)
2223

@@ -60,9 +61,14 @@ async def dispatch(self, request: Request, call_next):
6061
if not self.plugin_manager:
6162
return await call_next(request)
6263

63-
# Generate request ID for tracing and store in request state
64-
# This ensures all hooks and downstream code see the same request ID
65-
request_id = uuid.uuid4().hex
64+
# Use correlation ID from CorrelationIDMiddleware if available
65+
# This ensures all hooks and downstream code see the same unified request ID
66+
request_id = get_correlation_id()
67+
if not request_id:
68+
# Fallback if correlation ID middleware is disabled
69+
request_id = generate_correlation_id()
70+
logger.debug(f"Correlation ID not found, generated fallback: {request_id}")
71+
6672
request.state.request_id = request_id
6773

6874
# Create global context for hooks

mcpgateway/middleware/request_logging_middleware.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
# First-Party
2525
from mcpgateway.services.logging_service import LoggingService
26+
from mcpgateway.utils.correlation_id import get_correlation_id
2627

2728
# Initialize logging service first
2829
logging_service = LoggingService()
@@ -171,12 +172,16 @@ async def dispatch(self, request: Request, call_next: Callable):
171172
# Mask sensitive headers
172173
masked_headers = mask_sensitive_headers(dict(request.headers))
173174

175+
# Get correlation ID for request tracking
176+
request_id = get_correlation_id()
177+
174178
logger.log(
175179
log_level,
176180
f"📩 Incoming request: {request.method} {request.url.path}\n"
177181
f"Query params: {dict(request.query_params)}\n"
178182
f"Headers: {masked_headers}\n"
179183
f"Body: {payload_str}{'... [truncated]' if truncated else ''}",
184+
extra={"request_id": request_id},
180185
)
181186

182187
except Exception as e:

mcpgateway/observability.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,24 @@ def create_span(name: str, attributes: Optional[Dict[str, Any]] = None) -> Any:
393393
# Return a no-op context manager if tracing is not configured or available
394394
return nullcontext()
395395

396+
# Auto-inject correlation ID into all spans for request tracing
397+
try:
398+
# Import here to avoid circular dependency
399+
from mcpgateway.utils.correlation_id import get_correlation_id
400+
401+
correlation_id = get_correlation_id()
402+
if correlation_id:
403+
if attributes is None:
404+
attributes = {}
405+
# Add correlation ID if not already present
406+
if "correlation_id" not in attributes:
407+
attributes["correlation_id"] = correlation_id
408+
if "request_id" not in attributes:
409+
attributes["request_id"] = correlation_id # Alias for compatibility
410+
except ImportError:
411+
# Correlation ID module not available, continue without it
412+
pass
413+
396414
# Start span and return the context manager
397415
span_context = _TRACER.start_as_current_span(name)
398416

mcpgateway/services/a2a_service.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from mcpgateway.services.logging_service import LoggingService
2929
from mcpgateway.services.team_management_service import TeamManagementService
3030
from mcpgateway.services.tool_service import ToolService
31+
from mcpgateway.utils.correlation_id import get_correlation_id
3132
from mcpgateway.utils.create_slug import slugify
3233
from mcpgateway.utils.services_auth import encode_auth # ,decode_auth
3334

0 commit comments

Comments
 (0)