|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +"""Location: ./mcpgateway/middleware/correlation_id.py |
| 3 | +Copyright 2025 |
| 4 | +SPDX-License-Identifier: Apache-2.0 |
| 5 | +Authors: MCP Gateway Contributors |
| 6 | +
|
| 7 | +Correlation ID (Request ID) Middleware. |
| 8 | +
|
| 9 | +This middleware handles X-Correlation-ID HTTP headers and maps them to the internal |
| 10 | +request_id used throughout the system for unified request tracing. |
| 11 | +
|
| 12 | +Key concept: HTTP X-Correlation-ID header → Internal request_id field (single ID for entire request flow) |
| 13 | +
|
| 14 | +The middleware automatically extracts or generates request IDs for every HTTP request, |
| 15 | +stores them in context variables for async-safe propagation across services, and |
| 16 | +injects them back into response headers for client-side correlation. |
| 17 | +
|
| 18 | +This enables end-to-end tracing: HTTP → Middleware → Services → Plugins → Logs (all with same request_id) |
| 19 | +""" |
| 20 | + |
| 21 | +# Standard |
| 22 | +import logging |
| 23 | +from typing import Callable |
| 24 | + |
| 25 | +# Third-Party |
| 26 | +from fastapi import Request, Response |
| 27 | +from starlette.middleware.base import BaseHTTPMiddleware |
| 28 | + |
| 29 | +# First-Party |
| 30 | +from mcpgateway.config import settings |
| 31 | +from mcpgateway.utils.correlation_id import ( |
| 32 | + clear_correlation_id, |
| 33 | + extract_correlation_id_from_headers, |
| 34 | + generate_correlation_id, |
| 35 | + set_correlation_id, |
| 36 | +) |
| 37 | + |
| 38 | +logger = logging.getLogger(__name__) |
| 39 | + |
| 40 | + |
class CorrelationIDMiddleware(BaseHTTPMiddleware):
    """Middleware for automatic request ID (correlation ID) handling.

    For every HTTP request this middleware:
        1. Optionally extracts an ID from the configured request header
           (only when ``preserve_incoming`` is true).
        2. Falls back to ``generate_correlation_id()`` when no ID was found.
        3. Stores the ID via ``set_correlation_id()`` so downstream code can
           read it for the duration of the request.
        4. Optionally writes the ID back into the same header on the response.
        5. Calls ``clear_correlation_id()`` once the request has finished,
           whether it succeeded or raised.

    Settings read in ``__init__`` (each falls back to the default shown when
    the attribute is missing from ``settings``):
        - correlation_id_header: Header name to use (default: X-Correlation-ID)
        - correlation_id_preserve: Whether to honour incoming IDs (default: True)
        - correlation_id_response_header: Whether to add the ID to responses
          (default: True)

    NOTE(review): the ``correlation_id_enabled`` setting is not consulted by
    this class; presumably it is checked where the middleware is registered —
    confirm against the application setup code.
    """

    def __init__(self, app) -> None:
        """Initialize the correlation ID (request ID) middleware.

        Args:
            app: The ASGI application this middleware wraps.
        """
        super().__init__(app)
        # getattr() with defaults keeps the middleware usable even if the
        # settings object does not define these attributes.
        self.header_name = getattr(settings, "correlation_id_header", "X-Correlation-ID")
        self.preserve_incoming = getattr(settings, "correlation_id_preserve", True)
        self.add_to_response = getattr(settings, "correlation_id_response_header", True)

    def _resolve_correlation_id(self, request: Request) -> str:
        """Return the incoming correlation ID if allowed and present, else a new one.

        Args:
            request: The incoming HTTP request whose headers are inspected.

        Returns:
            The client-provided ID when ``preserve_incoming`` is enabled and the
            header lookup yields a truthy value; otherwise a freshly generated ID.
        """
        incoming_id = None
        if self.preserve_incoming:
            incoming_id = extract_correlation_id_from_headers(
                dict(request.headers),
                self.header_name,
            )

        if incoming_id:
            # Lazy %-style arguments: the message is only formatted when
            # DEBUG logging is actually enabled.
            logger.debug("Using client-provided correlation ID: %s", incoming_id)
            return incoming_id

        generated_id = generate_correlation_id()
        logger.debug("Generated new correlation ID: %s", generated_id)
        return generated_id

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process the request and manage the correlation ID lifecycle.

        Resolves the correlation ID for this request, stores it via
        ``set_correlation_id()`` before the downstream handler runs, and, when
        ``add_to_response`` is enabled, sets it on the response under
        ``header_name``.

        Args:
            request: The incoming HTTP request.
            call_next: The next middleware or route handler.

        Returns:
            Response: The downstream response, with the correlation ID header
            added when ``add_to_response`` is enabled.
        """
        correlation_id = self._resolve_correlation_id(request)

        # Must be set before call_next() so everything downstream of this
        # middleware sees the ID.
        set_correlation_id(correlation_id)

        try:
            response = await call_next(request)

            if self.add_to_response:
                response.headers[self.header_name] = correlation_id

            return response
        finally:
            # Always clear, including when call_next() raises, so the ID is
            # not left behind once this request is done.
            clear_correlation_id()
0 commit comments