diff --git a/.coveragerc b/.coveragerc index 5e84e1a2c..fe8b64305 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,6 +3,9 @@ branch = True omit = mcpgateway/alembic/* */__init__.py + mcpgateway/tools/builder/common.py + mcpgateway/tools/builder/dagger_deploy.py + mcpgateway/tools/builder/python_deploy.py [report] exclude_lines = diff --git a/.env.example b/.env.example index 693715cf2..8afdf1313 100644 --- a/.env.example +++ b/.env.example @@ -537,7 +537,7 @@ SECURITY_HEADERS_ENABLED=true # null or none: Completely removes iframe restrictions (no headers sent) # ALLOW-FROM uri: Allows specific domain (deprecated, use CSP instead) # ALLOW-ALL uri: Allows all (*, http, https) -# +# # Both X-Frame-Options header and CSP frame-ancestors directive are automatically synced. # Modern browsers prioritize CSP frame-ancestors over X-Frame-Options. X_FRAME_OPTIONS=DENY @@ -659,6 +659,17 @@ LOG_MAX_SIZE_MB=1 LOG_BACKUP_COUNT=5 LOG_BUFFER_SIZE_MB=1.0 +# Correlation ID / Request Tracking +# Enable automatic correlation ID tracking for unified request tracing +# Options: true (default), false +CORRELATION_ID_ENABLED=true +# HTTP header name for correlation ID (default: X-Correlation-ID) +CORRELATION_ID_HEADER=X-Correlation-ID +# Preserve incoming correlation IDs from clients (default: true) +CORRELATION_ID_PRESERVE=true +# Include correlation ID in HTTP response headers (default: true) +CORRELATION_ID_RESPONSE_HEADER=true + # Transport Protocol Configuration # Options: all (default), sse, streamablehttp, http # - all: Enable all transport protocols @@ -1193,6 +1204,16 @@ PAGINATION_INCLUDE_LINKS=true # Enable TLS for gRPC connections by default # MCPGATEWAY_GRPC_TLS_ENABLED=false +##################################### +# Security Event Logging +##################################### + +# Enable security event logging (authentication attempts, authorization failures, etc.) 
+# Options: true (default), false +# When enabled, the AuthContextMiddleware will log all authentication attempts to the database +# This is INDEPENDENT of observability settings - security logging is critical for audit trails +# SECURITY_LOGGING_ENABLED=true + ##################################### # Observability Settings ##################################### diff --git a/MANIFEST.in b/MANIFEST.in index 6090fb781..c1e0a957a 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -200,3 +200,12 @@ exclude plugins/external/llmguard/Containerfile exclude plugins/external/llmguard/MANIFEST.in exclude plugins/external/llmguard/pyproject.toml exclude plugins/external/llmguard/run-server.sh + +# Exclude cedar + +exclude plugins/external/cedar/.dockerignore +exclude plugins/external/cedar/.env.template +exclude plugins/external/cedar/.ruff.toml +exclude plugins/external/cedar/Containerfile +exclude plugins/external/cedar/MANIFEST.in +exclude plugins/external/cedar/pyproject.toml diff --git a/Makefile b/Makefile index fb6a7c460..76503fa83 100644 --- a/Makefile +++ b/Makefile @@ -183,6 +183,8 @@ check-env-dev: # help: ▶️ SERVE # help: serve - Run production Gunicorn server on :4444 # help: certs - Generate self-signed TLS cert & key in ./certs (won't overwrite) +# help: certs-passphrase - Generate self-signed cert with passphrase-protected key +# help: certs-remove-passphrase - Remove passphrase from encrypted key # help: certs-jwt - Generate JWT RSA keys in ./certs/jwt/ (idempotent) # help: certs-jwt-ecdsa - Generate JWT ECDSA keys in ./certs/jwt/ (idempotent) # help: certs-all - Generate both TLS certs and JWT keys (combo target) @@ -225,6 +227,45 @@ certs: ## Generate ./certs/cert.pem & ./certs/key.pem fi chmod 640 certs/key.pem +certs-passphrase: ## Generate self-signed cert with passphrase-protected key + @if [ -f certs/cert.pem ] && [ -f certs/key-encrypted.pem ]; then \ + echo "🔏 Existing passphrase-protected certificates found - skipping."; \ + else \ + echo "🔏 Generating passphrase-protected certificate (1 year)..."; \ + mkdir -p certs; \ + read -sp "Enter passphrase for private key: " PASSPHRASE; echo; \ + read -sp "Confirm passphrase: " PASSPHRASE2; echo; \ + if [ "$$PASSPHRASE" != "$$PASSPHRASE2" ]; then \ + echo "❌ Passphrases do not match!"; \ + exit 1; \ + fi; \ + openssl req -x509 -newkey rsa:4096 -sha256 -days 365 \ + -keyout certs/key-encrypted.pem -out certs/cert.pem \ + -subj "/CN=localhost" \ + -addext "subjectAltName=DNS:localhost,IP:127.0.0.1" \ + -passout pass:"$$PASSPHRASE"; \ + echo "✅ Passphrase-protected certificate created"; \ + echo "📁 Certificate: ./certs/cert.pem"; \ + echo "📁 Encrypted Key: ./certs/key-encrypted.pem"; \ + echo ""; \ + echo "💡 To use this certificate:"; \ + echo " 1. Set KEY_FILE_PASSWORD environment variable"; \ + echo " 2. Run: KEY_FILE_PASSWORD='your-passphrase' SSL=true CERT_FILE=certs/cert.pem KEY_FILE=certs/key-encrypted.pem make serve-ssl"; \ + fi + @chmod 600 certs/key-encrypted.pem + +certs-remove-passphrase: ## Remove passphrase from encrypted key (creates key.pem from key-encrypted.pem) + @if [ ! -f certs/key-encrypted.pem ]; then \ + echo "❌ No encrypted key found at certs/key-encrypted.pem"; \ + echo "💡 Generate one with: make certs-passphrase"; \ + exit 1; \ + fi + @echo "🔓 Removing passphrase from private key..." + @openssl rsa -in certs/key-encrypted.pem -out certs/key.pem + @chmod 600 certs/key.pem + @echo "✅ Passphrase removed - unencrypted key saved to certs/key.pem" + @echo "⚠️ Keep this file secure! 
It contains your unencrypted private key." + certs-jwt: ## Generate JWT RSA keys in ./certs/jwt/ (idempotent) @if [ -f certs/jwt/private.pem ] && [ -f certs/jwt/public.pem ]; then \ echo "🔐 Existing JWT RSA keys found in ./certs/jwt - skipping generation."; \ diff --git a/README.md b/README.md index fae6a2b60..ff0390d3b 100644 --- a/README.md +++ b/README.md @@ -1619,7 +1619,7 @@ ContextForge implements **OAuth 2.0 Dynamic Client Registration (RFC 7591)** and > > **iframe Embedding**: The gateway controls iframe embedding through both `X-Frame-Options` header and CSP `frame-ancestors` directive (both are automatically synced). Options: > - `X_FRAME_OPTIONS=DENY` (default): Blocks all iframe embedding -> - `X_FRAME_OPTIONS=SAMEORIGIN`: Allows embedding from same domain only +> - `X_FRAME_OPTIONS=SAMEORIGIN`: Allows embedding from same domain only > - `X_FRAME_OPTIONS="ALLOW-ALL"`: Allows embedding from all sources (sets `frame-ancestors * file: http: https:`) > - `X_FRAME_OPTIONS=null` or `none`: Completely removes iframe restrictions (no headers sent) > diff --git a/docker-compose.yml b/docker-compose.yml index c47e70a9e..4e94f7374 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -95,9 +95,11 @@ services: - SECURITY_HEADERS_ENABLED=true - CORS_ALLOW_CREDENTIALS=true - SECURE_COOKIES=false + ## Uncomment to enable HTTPS # - SSL=true # - CERT_FILE=/app/certs/cert.pem # - KEY_FILE=/app/certs/key.pem + # - KEY_FILE_PASSWORD=${KEY_FILE_PASSWORD} # Optional: Set in .env for passphrase-protected keys # Uncomment to enable plugins - PLUGINS_ENABLED=true # Uncomment to enable catalog @@ -128,16 +130,32 @@ services: # condition: service_completed_successfully healthcheck: + ## Uncomment for HTTP healthcheck test: ["CMD", "python3", "-c", "import urllib.request; import json; resp = urllib.request.urlopen('http://localhost:4444/health', timeout=5); data = json.loads(resp.read()); exit(0 if data.get('status') == 'healthy' else 1)"] - #test: ["CMD", "curl", "-f", "https://localhost:4444/health"] + ## Uncomment for HTTPS healthcheck + # test: ["CMD", "curl", "-f", "https://localhost:4444/health"] + ## Uncomment to skip SSL validation (self-signed certs) + # test: ["CMD", "curl", "-fk", "https://localhost:4444/health"] interval: 30s timeout: 10s retries: 5 start_period: 30s + # ────────────────────────────────────────────────────────────────────── + # Volume Mounts + # ────────────────────────────────────────────────────────────────────── + # Uncomment to mount catalog configuration and SSL certificates # volumes: # - ./mcp-catalog.yml:/app/mcp-catalog.yml:ro # mount catalog configuration - # - ./certs:/app/certs:ro # mount certs folder read-only (includes both SSL and JWT keys) + # - ./certs:/app/mcpgateway/certs:ro # mount certs folder read-only (includes both SSL and JWT keys) + # + # SSL/TLS Certificate Setup: + # 1. Generate certificates: + # - Without passphrase: make certs + # - With passphrase: make certs-passphrase + # 2. Uncomment the volumes mount above + # 3. Set SSL environment variables + # 4. If using passphrase-protected key, set KEY_FILE_PASSWORD in .env file # # For JWT asymmetric keys: # 1. Generate keys: make certs-jwt diff --git a/docs/docs/architecture/plugins.md b/docs/docs/architecture/plugins.md index b34b8ffab..d2ba0da3d 100644 --- a/docs/docs/architecture/plugins.md +++ b/docs/docs/architecture/plugins.md @@ -103,7 +103,7 @@ flowchart TB subgraph "External Services" AI["AI Safety Services
(LlamaGuard, OpenAI)"] - Security["Security Services
(Vault, OPA)"] + Security["Security Services
(Vault, OPA, Cedar)"] end Client --request--> GW @@ -141,7 +141,7 @@ The framework supports two distinct plugin deployment patterns: - Standalone MCP servers implementing plugin logic - Can be written in any language (Python, TypeScript, Go, Rust, etc.) - Communicate via MCP protocol (Streamable HTTP, STDIO, SSE) -- Examples: OPA filter, LlamaGuard, OpenAI Moderation, custom AI services +- Examples: OPA filter, Cedar Policy Plugin (RBAC), LlamaGuard, OpenAI Moderation, custom AI services ### Plugin Configuration Schema @@ -1721,6 +1721,7 @@ FEDERATION_POST_SYNC = "federation_post_sync" # Post-federation processing #### Current Integrations - ✅ **Open Policy Agent (OPA):** Policy-as-code enforcement engine +- ✅ **Cedar Policy Plugin:** Policy-as-code enforcement engine, RBAC - ✅ **LlamaGuard:** Content safety classification and filtering - ✅ **OpenAI Moderation API:** Commercial content moderation - ✅ **Custom MCP Servers:** Any language, any protocol diff --git a/docs/docs/architecture/roadmap.md b/docs/docs/architecture/roadmap.md index b099bd446..afee9508b 100644 --- a/docs/docs/architecture/roadmap.md +++ b/docs/docs/architecture/roadmap.md @@ -8,13 +8,13 @@ | Release | Due Date | Completion | Status | Description | | ------- | ----------- | ---------- | ------------ | ----------- | -| 1.6.0 | 02 Jun 2026 | 0 % | Open | New MCP Servers and Agents | -| 1.5.0 | 05 May 2026 | 0 % | Open | Documentation, Technical Debt, Bugfixes | -| 1.4.0 | 07 Apr 2026 | 0 % | Open | Technical Debt and Quality | -| 1.3.0 | 03 Mar 2026 | 0 % | Open | Catalog Improvements, A2A Improvements, MCP Standard Review and Sync, Technical Debt | -| 1.2.0 | 03 Feb 2026 | 0 % | Open | Release 1.2.0 - Catalog Enhancements, Ratings, experience and UI | -| 1.1.0 | 06 Jan 2026 | 0 % | Open | Post-GA Testing, Bugfixing, Documentation, Performance and Scale | -| 1.0.0 | 02 Dec 2025 | 4 % | Open | Release 1.0 General Availability & Release Candidate Hardening - stable & audited | +| 1.3.0 | 26 May 2026 | 0 % | Open | New MCP Servers and Agents | +| 1.2.0 | 28 Apr 2026 | 0 % | Open | Documentation, Technical Debt, Bugfixes | +| 1.1.0 | 31 Mar 2026 | 0 % | Open | Technical Debt and Quality | +| 1.0.0-GA | 24 Feb 2026 | 0 % | Open | Catalog Improvements, A2A Improvements, MCP Standard Review and Sync, Technical Debt | +| 1.0.0-RC1 | 03 Feb 2026 | 0 % | Open | Release Candidate 1 - Catalog Enhancements, Ratings, experience and UI | +| 1.0.0-BETA-2 | 20 Jan 2026 | 1 % | Open | Testing, Bugfixing, Documentation, Performance and Scale | +| 1.0.0-BETA-1 | 16 Dec 2025 | 12 % | Open | Release 1.0.0-BETA-1 | | 0.9.0 | 04 Nov 2025 | 36 % | Open | Interoperability, marketplaces & advanced connectivity | | 0.8.0 | 07 Oct 2025 | 100 % | **Closed** | Enterprise Security & Policy Guardrails | | 0.7.0 | 16 Sep 2025 | 100 % | **Closed** | Multitenancy and RBAC (Private/Team/Global catalogs), Extended Connectivity, Core Observability & Starter Agents (OpenAI and A2A) | @@ -27,10 +27,10 @@ --- -## Release 1.6.0 +## Release 1.3.0 -!!! warning "Release 1.6.0 - In Progress (0%)" - **Due:** 02 Jun 2026 | **Status:** Open +!!! warning "Release 1.3.0 - In Progress (0%)" + **Due:** 26 May 2026 | **Status:** Open New MCP Servers and Agents ???+ info "✨ Features - Remaining (1)" @@ -40,28 +40,28 @@ --- -## Release 1.5.0 +## Release 1.2.0 -!!! warning "Release 1.5.0 - In Progress (0%)" - **Due:** 05 May 2026 | **Status:** Open +!!! 
warning "Release 1.2.0 - In Progress (0%)" + **Due:** 28 Apr 2026 | **Status:** Open Documentation, Technical Debt, Bugfixes --- -## Release 1.4.0 +## Release 1.1.0 -!!! warning "Release 1.4.0 - In Progress (0%)" - **Due:** 07 Apr 2026 | **Status:** Open +!!! warning "Release 1.1.0 - In Progress (0%)" + **Due:** 31 Mar 2026 | **Status:** Open Technical Debt and Quality --- -## Release 1.3.0 +## Release 1.0.0-GA -!!! warning "Release 1.3.0 - In Progress (0%)" - **Due:** 03 Mar 2026 | **Status:** Open +!!! warning "Release 1.0.0-GA - In Progress (0%)" + **Due:** 24 Feb 2026 | **Status:** Open Catalog Improvements, A2A Improvements, MCP Standard Review and Sync, Technical Debt ???+ info "✨ Features - Remaining (1)" @@ -71,11 +71,11 @@ --- -## Release 1.2.0 +## Release 1.0.0-RC1 -!!! warning "Release 1.2.0 - In Progress (0%)" +!!! warning "Release 1.0.0-RC1 - In Progress (0%)" **Due:** 03 Feb 2026 | **Status:** Open - Release 1.2.0 - Catalog Enhancements, Ratings, experience and UI + Release Candidate 1 - Catalog Enhancements, Ratings, experience and UI ???+ info "✨ Features - Remaining (3)" @@ -86,11 +86,11 @@ --- -## Release 1.1.0 +## Release 1.0.0-BETA-2 -!!! warning "Release 1.1.0 - In Progress (0%)" - **Due:** 06 Jan 2026 | **Status:** Open - Post-GA Testing, Bugfixing, Documentation, Performance and Scale +!!! warning "Release 1.0.0-BETA-2 - In Progress (1%)" + **Due:** 20 Jan 2026 | **Status:** Open + Testing, Bugfixing, Documentation, Performance and Scale ???+ info "✨ Features - Remaining (38)" @@ -144,11 +144,11 @@ --- -## Release 1.0.0 +## Release 1.0.0-BETA-1 -!!! warning "Release 1.0.0 - In Progress (4%)" - **Due:** 02 Dec 2025 | **Status:** Open - Release 1.0 General Availability & Release Candidate Hardening - stable & audited +!!! warning "Release 1.0.0-BETA-1 - In Progress (12%)" + **Due:** 16 Dec 2025 | **Status:** Open + Release 1.0.0-BETA-1 ???+ info "📋 Epics - Remaining (12)" diff --git a/docs/docs/deployment/.pages b/docs/docs/deployment/.pages index f7e568e00..2e093a6ad 100644 --- a/docs/docs/deployment/.pages +++ b/docs/docs/deployment/.pages @@ -14,3 +14,4 @@ nav: - azure.md - fly-io.md - proxy-auth.md + - cforge-gateway.md diff --git a/docs/docs/deployment/cforge-gateway.md b/docs/docs/deployment/cforge-gateway.md new file mode 100644 index 000000000..12085caf9 --- /dev/null +++ b/docs/docs/deployment/cforge-gateway.md @@ -0,0 +1,2099 @@ +# cforge gateway - Deployment Tool + +## Overview + +The `cforge gateway` command is a powerful deployment tool for MCP Gateway and its external plugins. It provides a unified, declarative way to build, configure, and deploy the complete MCP stack from a single YAML configuration file. + +--- + +## Quick Start + +### Installation + +The `cforge` CLI is installed with the MCP Gateway package: + +```bash +pip install -e . +``` + +Verify installation: + +```bash +cforge --help +cforge gateway --help +``` + +### Basic Workflow + +```bash +# 1. Validate your configuration +cforge gateway validate examples/deployment-configs/deploy-compose.yaml + +# 2. Build containers (if building from source) +cforge gateway build examples/deployment-configs/deploy-compose.yaml + +# 3. Generate mTLS certificates (if needed) +cforge gateway certs examples/deployment-configs/deploy-compose.yaml + +# 4. Deploy the stack +cforge gateway deploy examples/deployment-configs/deploy-compose.yaml + +# 5. Verify deployment health +cforge gateway verify examples/deployment-configs/deploy-compose.yaml + +# 6. 
(Optional) Tear down +cforge gateway destroy examples/deployment-configs/deploy-compose.yaml +``` + +--- + +## Simple Configuration Example + +The `cforge gateway` tool uses **custom YAML configuration files** to describe your deployment. These are **not** standard Docker Compose or Kubernetes manifests - instead, `cforge` reads these configuration files and generates the actual deployment manifests for your target environment. + +Here's a minimal example configuration that demonstrates the key components: + +```yaml +deployment: + type: compose # Target: 'compose' or 'kubernetes' + project_name: mcp-stack-test + +gateway: + image: mcpgateway/mcpgateway:latest # Use pre-built image + port: 4444 + host_port: 4444 # Expose on localhost:4444 + + env_vars: + LOG_LEVEL: DEBUG + MCPGATEWAY_UI_ENABLED: "true" + AUTH_REQUIRED: "false" # Simplified for testing + + mtls_enabled: false # Disable mTLS for simple setup + +plugins: + - name: OPAPluginFilter + repo: https://github.com/terylt/mcp-context-forge.git + ref: feat/use_mtls_plugins # Git branch/tag/commit + context: plugins/external/opa # Build context path + containerfile: Containerfile + + expose_port: true + mtls_enabled: false + + plugin_overrides: + priority: 10 + mode: "enforce" + description: "OPA policy enforcement" + +certificates: + auto_generate: true # Auto-generate certs if needed +``` + +**Key sections explained:** + +- **deployment**: Specifies the target environment (Docker Compose or Kubernetes) and basic settings +- **gateway**: Defines the MCP Gateway configuration - can use a pre-built image or build from a Git repository +- **plugins**: Array of external plugins to deploy. Each plugin can be built from source or use pre-built images +- **certificates**: mTLS certificate configuration (auto-generated by default) + +**How it works:** + +When you run `cforge gateway deploy `, the tool: +1. Reads your custom configuration YAML +2. Builds container images (if building from source) +3. Generates mTLS certificates (if needed) +4. **Generates actual deployment files**: + - For `type: compose` → `deploy/docker-compose.yaml` + - For `type: kubernetes` → `deploy/manifests/*.yaml` (Deployment, Service, ConfigMap, etc.) +5. Deploys the generated manifests to your target environment + +**Additional example configurations** are available in `examples/deployment-configs/`: +- `deploy-compose.yaml` - Docker Compose without mTLS +- `deploy-compose.mtls.yaml` - Docker Compose with mTLS +- `deploy-k8s.yaml` - Kubernetes with pre-built images +- `deploy-k8s-cert-manager.yaml` - Kubernetes with cert-manager integration +- More examples for OpenShift, registry integration, and advanced scenarios + +See the [Example Configurations](#example-configurations) section below for detailed examples with full explanations. + +--- + +## Commands + +### `cforge gateway validate` + +Validates the deployment configuration file without making any changes. + +```bash +cforge gateway validate +``` + +**Example:** +```bash +cforge gateway validate deploy.yaml +``` + +**Output:** +- ✅ Configuration syntax validation +- ✅ Plugin name uniqueness check +- ✅ Required field verification +- ✅ Build configuration validation (image XOR repo) + +--- + +### `cforge gateway build` + +Builds container images for gateway and/or plugins from source repositories. 
+ +```bash +cforge gateway build [OPTIONS] +``` + +**Options:** + +| Option | Description | Default | +|--------|-------------|---------| +| `--plugins-only` | Only build plugin containers, skip gateway | `false` | +| `--plugin NAME`, `-p NAME` | Build specific plugin(s) only (can specify multiple) | All plugins | +| `--no-cache` | Disable Docker build cache | `false` | +| `--copy-env-templates` | Copy `.env.template` files from plugin repos | `true` | + +**Examples:** +```bash +# Build everything +cforge gateway build deploy.yaml + +# Build only plugins +cforge gateway build deploy.yaml --plugins-only + +# Build specific plugin +cforge gateway build deploy.yaml --plugin OPAPluginFilter + +# Build multiple plugins with no cache +cforge gateway build deploy.yaml --plugin OPAPluginFilter --plugin LLMGuardPlugin --no-cache +``` + +**What it does:** +1. Clones Git repositories (if `repo` specified) +2. Checks out specified branch/tag/commit (`ref`) +3. Builds Docker images from `containerfile` in `context` directory +4. Tags images appropriately for deployment +5. Copies `.env.template` files to `deploy/env/` for customization + +--- + +### `cforge gateway certs` + +Generates mTLS certificate hierarchy for secure gateway ↔ plugin communication. + +```bash +cforge gateway certs +``` + +**Example:** +```bash +cforge gateway certs deploy.yaml +``` + +**What it generates:** +``` +certs/mcp/ +├── ca/ +│ ├── ca.crt # Root CA certificate +│ └── ca.key # Root CA private key +├── gateway/ +│ ├── client.crt # Gateway client certificate +│ ├── client.key # Gateway client private key +│ └── ca.crt # CA cert (for verification) +└── plugins/ + ├── PluginName1/ + │ ├── server.crt # Plugin server certificate + │ ├── server.key # Plugin server private key + │ └── ca.crt # CA cert (for verification) + └── PluginName2/ + ├── server.crt + ├── server.key + └── ca.crt +``` + +**Certificate Properties:** +- Validity: Configurable (default: 825 days) +- CN for gateway: `mcp-gateway` +- CN for plugins: `mcp-plugin-{PluginName}` +- SANs: `{PluginName}, mcp-plugin-{PluginName}, localhost` + +--- + +### `cforge gateway deploy` + +Deploys the complete MCP stack to the target environment. + +```bash +cforge gateway deploy [OPTIONS] +``` + +**Options:** + +| Option | Description | Default | +|--------|-------------|---------| +| `--output-dir DIR`, `-o DIR` | Custom output directory for manifests | `deploy/` | +| `--dry-run` | Generate manifests without deploying | `false` | +| `--skip-build` | Skip container build step | `false` | +| `--skip-certs` | Skip certificate generation | `false` | + +**Examples:** +```bash +# Full deployment +cforge gateway deploy deploy.yaml + +# Dry-run (generate manifests only) +cforge gateway deploy deploy.yaml --dry-run + +# Deploy with existing images and certs +cforge gateway deploy deploy.yaml --skip-build --skip-certs + +# Custom output directory +cforge gateway deploy deploy.yaml --output-dir ./my-deployment +``` + +**Deployment Process:** +1. **Validate** configuration +2. **Build** containers (unless `--skip-build`) +3. **Generate certificates** (unless `--skip-certs` or already exist) +4. **Generate manifests** (Kubernetes or Docker Compose) +5. 
**Apply** to target environment: + - **Kubernetes**: `kubectl apply -f` + - **Docker Compose**: `docker-compose up -d` + +**Generated Files:** +``` +deploy/ +├── env/ # Environment files +│ ├── .env.gateway +│ ├── .env.PluginName1 +│ └── .env.PluginName2 +├── manifests/ # Kubernetes OR +│ ├── namespace.yaml +│ ├── configmaps.yaml +│ ├── secrets.yaml +│ ├── gateway-deployment.yaml +│ ├── gateway-service.yaml +│ ├── plugin-deployments.yaml +│ └── plugin-services.yaml +└── docker-compose.yaml # Docker Compose +``` + +--- + +### `cforge gateway verify` + +Verifies that the deployed stack is healthy and running. + +```bash +cforge gateway verify [OPTIONS] +``` + +**Options:** + +| Option | Description | Default | +|--------|-------------|---------| +| `--wait` | Wait for deployment to be ready | `true` | +| `--timeout SECONDS` | Wait timeout in seconds | `300` | + +**Examples:** +```bash +# Verify deployment (wait up to 5 minutes) +cforge gateway verify deploy.yaml + +# Quick check without waiting +cforge gateway verify deploy.yaml --no-wait + +# Custom timeout +cforge gateway verify deploy.yaml --timeout 600 +``` + +**Checks:** +- Container/pod readiness +- Health endpoint responses +- Service connectivity +- mTLS handshake (if enabled) + +--- + +### `cforge gateway destroy` + +Tears down the deployed MCP stack. + +```bash +cforge gateway destroy [OPTIONS] +``` + +**Options:** + +| Option | Description | Default | +|--------|-------------|---------| +| `--force` | Skip confirmation prompt | `false` | + +**Examples:** +```bash +# Destroy with confirmation +cforge gateway destroy deploy.yaml + +# Force destroy without prompt +cforge gateway destroy deploy.yaml --force +``` + +**What it removes:** +- **Kubernetes**: Deletes all resources in namespace +- **Docker Compose**: Stops and removes containers, networks, volumes + +⚠️ **Note:** This does NOT delete generated certificates or build artifacts. To clean those: +```bash +rm -rf certs/ deploy/ +``` + +--- + +### `cforge gateway generate` + +Generates deployment manifests without deploying them. + +```bash +cforge gateway generate [OPTIONS] +``` + +**Options:** + +| Option | Description | Default | +|--------|-------------|---------| +| `--output DIR`, `-o DIR` | Output directory for manifests | `deploy/` | + +**Examples:** +```bash +# Generate manifests +cforge gateway generate deploy.yaml + +# Custom output directory +cforge gateway generate deploy.yaml --output ./manifests +``` + +**Use cases:** +- GitOps workflows (commit generated manifests) +- Manual review before deployment +- Integration with external deployment tools +- CI/CD pipeline artifact generation + +--- + +### `cforge gateway version` + +Shows version and runtime information. 
+ +```bash +cforge gateway version +``` + +**Output:** +``` +┌─ Version Info ─────────────────┐ +│ MCP Deploy │ +│ Version: 1.0.0 │ +│ Mode: dagger │ +│ Environment: local │ +└────────────────────────────────┘ +``` + +--- + +## Global Options + +These options apply to all commands: + +| Option | Description | Default | +|--------|-------------|---------| +| `--dagger` | Enable Dagger mode (auto-downloads CLI if needed) | `false` (uses plain Python) | +| `--verbose`, `-v` | Verbose output | `false` | + +**Examples:** +```bash +# Use plain Python mode (default) +cforge gateway deploy deploy.yaml + +# Enable Dagger mode for optimized builds +cforge gateway --dagger deploy deploy.yaml + +# Verbose mode +cforge gateway -v build deploy.yaml + +# Combine options +cforge gateway --dagger -v deploy deploy.yaml +``` + +--- + +## Configuration Reference + +### Deployment Configuration + +Top-level deployment settings: + +```yaml +deployment: + type: kubernetes | compose # Required: Deployment target + project_name: my-project # Docker Compose only + namespace: mcp-gateway # Kubernetes only + container_engine: podman | docker # Container runtime (auto-detected if not specified) + + # OpenShift-specific configuration (optional) + openshift: + create_routes: true # Create OpenShift Route resources + domain: apps-crc.testing # OpenShift apps domain (auto-detected if omitted) + tls_termination: edge # TLS termination mode: edge, passthrough, or reencrypt +``` + +| Field | Type | Required | Description | Default | +|-------|------|----------|-------------|---------| +| `type` | string | ✅ | Deployment type: `kubernetes` or `compose` | - | +| `project_name` | string | ❌ | Docker Compose project name | - | +| `namespace` | string | ❌ | Kubernetes namespace | - | +| `container_engine` | string | ❌ | Container runtime: `docker` or `podman` | Auto-detected | +| `openshift` | object | ❌ | OpenShift-specific configuration (see below) | - | + +#### OpenShift Configuration + +OpenShift Routes provide native external access to services, with built-in TLS termination and integration with OpenShift's router/HAProxy infrastructure. + +| Field | Type | Required | Description | Default | +|-------|------|----------|-------------|---------| +| `create_routes` | boolean | ❌ | Create OpenShift Route resources for external access | `false` | +| `domain` | string | ❌ | OpenShift apps domain for route hostnames | Auto-detected from cluster | +| `tls_termination` | string | ❌ | TLS termination mode: `edge`, `passthrough`, or `reencrypt` | `edge` | + +**Example:** +```yaml +deployment: + type: kubernetes + namespace: mcp-gateway-test + openshift: + create_routes: true + domain: apps-crc.testing + tls_termination: edge +``` + +When `create_routes: true`, the tool generates an OpenShift Route for the gateway: +- **Host**: `mcpgateway-admin-{namespace}.{domain}` +- **Path**: `/` +- **TLS**: Edge termination (default) +- **Target**: Gateway service on HTTP port + +**Access the gateway:** +```bash +# OpenShift Local (CRC) example +https://mcpgateway-admin-mcp-gateway-test.apps-crc.testing +``` + +**Domain auto-detection:** +If `domain` is not specified, the tool attempts to auto-detect the OpenShift apps domain from the cluster: +```bash +kubectl get ingresses.config.openshift.io cluster -o jsonpath='{.spec.domain}' +``` + +If auto-detection fails, it defaults to `apps-crc.testing` (OpenShift Local). 
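+After deploying with `create_routes: true`, a quick way to confirm the generated Route is reachable might look like the following (a sketch assuming the example namespace `mcp-gateway-test` and the default CRC domain shown above; adjust names to your cluster):
+
+```bash
+# List the Route generated for the gateway and note its hostname
+oc get route -n mcp-gateway-test
+
+# Probe the gateway health endpoint through the route
+# (-k skips verification of the self-signed edge certificate)
+curl -k https://mcpgateway-admin-mcp-gateway-test.apps-crc.testing/health
+```
+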
+ +--- + +### Gateway Configuration + +Gateway server settings: + +```yaml +gateway: + # Build Configuration (choose ONE) + image: mcpgateway/mcpgateway:latest # Pre-built image + # OR + repo: https://github.com/org/repo.git # Build from source + ref: main # Git branch/tag/commit + context: . # Build context directory + containerfile: Containerfile # Dockerfile path + target: production # Multi-stage build target + + # Runtime Configuration + port: 4444 # Internal port + host_port: 4444 # Host port mapping (compose only) + + # mTLS Client Configuration (gateway → plugins) + mtls_enabled: true # Enable mTLS + mtls_verify: true # Verify server certs + mtls_check_hostname: false # Verify hostname + + # Container Registry Configuration (optional) + registry: + enabled: true # Enable registry push + url: registry.example.com # Registry URL + namespace: myproject # Registry namespace/org + push: true # Push after build + image_pull_policy: IfNotPresent # Kubernetes imagePullPolicy + + # Environment Variables + env_vars: + LOG_LEVEL: INFO + MCPGATEWAY_UI_ENABLED: "true" + AUTH_REQUIRED: "true" + # ... (see full reference below) + + # Kubernetes-specific + replicas: 2 # Number of replicas + service_type: ClusterIP # Service type + service_port: 4444 # Service port + memory_request: 256Mi # Memory request + memory_limit: 512Mi # Memory limit + cpu_request: 100m # CPU request + cpu_limit: 500m # CPU limit + image_pull_policy: IfNotPresent # Image pull policy +``` + +**Build Configuration Fields:** + +| Field | Type | Required | Description | Default | +|-------|------|----------|-------------|---------| +| `image` | string | ❌* | Pre-built Docker image | - | +| `repo` | string | ❌* | Git repository URL | - | +| `ref` | string | ❌ | Git branch/tag/commit | `main` | +| `context` | string | ❌ | Build context subdirectory | `.` | +| `containerfile` | string | ❌ | Containerfile/Dockerfile path | `Containerfile` | +| `target` | string | ❌ | Multi-stage build target | - | + +\* **Either `image` OR `repo` must be specified** + +**Runtime Configuration Fields:** + +| Field | Type | Required | Description | Default | +|-------|------|----------|-------------|---------| +| `port` | integer | ❌ | Internal container port | `4444` | +| `host_port` | integer | ❌ | Host port mapping (compose only) | - | +| `env_vars` | object | ❌ | Environment variables | `{}` | +| `mtls_enabled` | boolean | ❌ | Enable mTLS client | `true` | +| `mtls_verify` | boolean | ❌ | Verify server certificates | `true` | +| `mtls_check_hostname` | boolean | ❌ | Verify hostname in cert | `false` | +| `registry` | object | ❌ | Container registry configuration | - | + +**Container Registry Configuration Fields:** + +| Field | Type | Required | Description | Default | +|-------|------|----------|-------------|---------| +| `enabled` | boolean | ❌ | Enable registry integration | `false` | +| `url` | string | ❌* | Registry URL (e.g., `docker.io`, `quay.io`, OpenShift registry) | - | +| `namespace` | string | ❌* | Registry namespace/organization/project | - | +| `push` | boolean | ❌ | Push image to registry after build | `true` | +| `image_pull_policy` | string | ❌ | Kubernetes imagePullPolicy (`Always`, `IfNotPresent`, `Never`) | `IfNotPresent` | + +\* Required when `enabled: true` + +**Kubernetes-specific Fields:** + +| Field | Type | Required | Description | Default | +|-------|------|----------|-------------|---------| +| `replicas` | integer | ❌ | Number of pod replicas | `1` | +| `service_type` | string | ❌ | Service type (ClusterIP, 
NodePort, LoadBalancer) | `ClusterIP` | +| `service_port` | integer | ❌ | Service port | `4444` | +| `memory_request` | string | ❌ | Memory request | `256Mi` | +| `memory_limit` | string | ❌ | Memory limit | `512Mi` | +| `cpu_request` | string | ❌ | CPU request | `100m` | +| `cpu_limit` | string | ❌ | CPU limit | `500m` | +| `image_pull_policy` | string | ❌ | Image pull policy | `IfNotPresent` | + +--- + +### Plugin Configuration + +External plugin settings (array of plugin objects): + +```yaml +plugins: + - name: MyPlugin # Required: Unique plugin name + + # Build Configuration (choose ONE) + image: myorg/myplugin:latest # Pre-built image + # OR + repo: https://github.com/org/repo.git # Build from source + ref: main + context: plugins/myplugin + containerfile: Containerfile + target: builder + + # Runtime Configuration + port: 8000 # Internal port + expose_port: true # Expose on host (compose only) + + # mTLS Server Configuration (plugin server) + mtls_enabled: true # Enable mTLS server + + # Container Registry Configuration (optional) + registry: + enabled: true # Enable registry push + url: registry.example.com # Registry URL + namespace: myproject # Registry namespace/org + push: true # Push after build + image_pull_policy: IfNotPresent # Kubernetes imagePullPolicy + + # Environment Variables + env_vars: + LOG_LEVEL: DEBUG + CUSTOM_SETTING: value + + # Plugin Manager Overrides (client-side) + plugin_overrides: + priority: 10 + mode: enforce + description: "My custom plugin" + tags: ["security", "filter"] + + # Kubernetes-specific + replicas: 1 + service_type: ClusterIP + service_port: 8000 + memory_request: 128Mi + memory_limit: 256Mi + cpu_request: 50m + cpu_limit: 200m + image_pull_policy: IfNotPresent +``` + +**Required Fields:** + +| Field | Type | Description | +|-------|------|-------------| +| `name` | string | Unique plugin identifier (used for cert CN, service names, etc.) | + +**Build Configuration:** Same as Gateway (see above) + +**Runtime Configuration:** + +| Field | Type | Required | Description | Default | +|-------|------|----------|-------------|---------| +| `port` | integer | ❌ | Internal container port | `8000` | +| `expose_port` | boolean | ❌ | Expose port on host (compose only) | `false` | +| `env_vars` | object | ❌ | Environment variables | `{}` | +| `mtls_enabled` | boolean | ❌ | Enable mTLS server | `true` | +| `registry` | object | ❌ | Container registry configuration (same fields as gateway) | - | +| `plugin_overrides` | object | ❌ | Plugin manager config overrides | `{}` | + +**Plugin Overrides:** + +| Field | Type | Description | Default | +|-------|------|-------------|---------| +| `priority` | integer | Plugin execution priority (lower = earlier) | - | +| `mode` | string | `enforce`, `monitor`, or `dry-run` | - | +| `description` | string | Plugin description | - | +| `tags` | array | Plugin tags for categorization | - | +| `hooks` | array | Enabled hooks: `prompt_pre_fetch`, `tool_pre_invoke`, etc. 
| All hooks | + +**Kubernetes-specific:** Same as Gateway (see above) + +--- + +### Certificate Configuration + +mTLS certificate generation settings: + +```yaml +certificates: + # Local certificate generation (default) + validity_days: 825 # Certificate validity period + auto_generate: true # Auto-generate if missing + ca_path: ./certs/mcp/ca # CA certificate directory + gateway_path: ./certs/mcp/gateway # Gateway cert directory + plugins_path: ./certs/mcp/plugins # Plugins cert directory + + # OR use cert-manager (Kubernetes only) + use_cert_manager: true # Use cert-manager for certificates + cert_manager_issuer: mcp-ca-issuer # Issuer/ClusterIssuer name + cert_manager_kind: Issuer # Issuer or ClusterIssuer +``` + +| Field | Type | Required | Description | Default | +|-------|------|----------|-------------|---------| +| `validity_days` | integer | ❌ | Certificate validity in days | `825` | +| `auto_generate` | boolean | ❌ | Auto-generate certificates locally if missing | `true` | +| `ca_path` | string | ❌ | CA certificate directory (local mode) | `./certs/mcp/ca` | +| `gateway_path` | string | ❌ | Gateway client cert directory (local mode) | `./certs/mcp/gateway` | +| `plugins_path` | string | ❌ | Plugin server certs base directory (local mode) | `./certs/mcp/plugins` | +| `use_cert_manager` | boolean | ❌ | Use cert-manager for certificate management (Kubernetes only) | `false` | +| `cert_manager_issuer` | string | ❌ | cert-manager Issuer/ClusterIssuer name | `mcp-ca-issuer` | +| `cert_manager_kind` | string | ❌ | cert-manager issuer kind: `Issuer` or `ClusterIssuer` | `Issuer` | + +#### cert-manager Integration (Kubernetes Only) + +[cert-manager](https://cert-manager.io) is a Kubernetes-native certificate management controller that automates certificate issuance and renewal. + +**Benefits:** +- ✅ **Automatic Renewal**: Certificates renewed before expiry (default: at 2/3 of lifetime) +- ✅ **Native Kubernetes**: Certificates defined as Kubernetes Custom Resources +- ✅ **Simplified Operations**: No manual certificate generation or rotation +- ✅ **GitOps Friendly**: Certificate definitions version-controlled + +**Prerequisites:** +1. Install cert-manager in your cluster: + ```bash + kubectl apply -f https://github.com/cert-manager/cert-manager/releases/download/v1.13.0/cert-manager.yaml + ``` + +2. Create namespace and CA Issuer (one-time setup): + ```bash + # Create namespace first + kubectl create namespace mcp-gateway-test + + # Apply CA Issuer + kubectl apply -f examples/deployment-configs/cert-manager-issuer-example.yaml + ``` + +**Configuration:** +```yaml +certificates: + use_cert_manager: true + cert_manager_issuer: mcp-ca-issuer + cert_manager_kind: Issuer + validity_days: 825 +``` + +When `use_cert_manager: true`: +- Local certificate generation is skipped +- cert-manager Certificate CRDs are generated for gateway and plugins +- cert-manager automatically creates Kubernetes TLS secrets +- Certificates are auto-renewed before expiry + +**Important**: The cert-manager Issuer and CA certificate are long-lived infrastructure. When you destroy your MCP deployment, the Issuer remains (by design) for reuse across deployments. 
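+For reference, a CA-backed Issuer of the kind referenced above might look roughly like the sketch below. The repository's `examples/deployment-configs/cert-manager-issuer-example.yaml` is authoritative; the Secret name here is purely illustrative:
+
+```yaml
+# Illustrative sketch only - see cert-manager-issuer-example.yaml in the repo for the real manifest.
+apiVersion: cert-manager.io/v1
+kind: Issuer
+metadata:
+  name: mcp-ca-issuer           # matches certificates.cert_manager_issuer
+  namespace: mcp-gateway-test   # namespace used in the examples above
+spec:
+  ca:
+    secretName: mcp-ca-secret   # hypothetical Secret holding the CA's tls.crt / tls.key
+```
+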
+ +--- + +### Infrastructure Services + +PostgreSQL and Redis are **automatically deployed** with the MCP Gateway stack using hardcoded defaults: + +**PostgreSQL (always deployed):** +- Image: `postgres:17` +- Database: `mcp` +- User: `postgres` +- Password: `mysecretpassword` (override with `POSTGRES_PASSWORD` env var) +- Port: `5432` +- Kubernetes: Uses 10Gi PVC + +**Redis (always deployed):** +- Image: `redis:latest` +- Port: `6379` + +**Connection strings (auto-configured):** +```bash +DATABASE_URL=postgresql://postgres:${POSTGRES_PASSWORD}@postgres:5432/mcp +REDIS_URL=redis://redis:6379/0 +``` + +These services are included in all deployments and cannot currently be disabled or customized via the deployment YAML. To customize PostgreSQL password: + +```bash +# Set before deploying +export POSTGRES_PASSWORD=your-secure-password +cforge gateway deploy deploy.yaml +``` + +--- + +## Example Configurations + +### Example 1: Docker Compose (No mTLS) + +**File:** `examples/deployment-configs/deploy-compose.yaml` + +Simple local deployment for development and testing: + +```yaml +deployment: + type: compose + project_name: mcp-stack-test + +gateway: + image: mcpgateway/mcpgateway:latest + port: 4444 + host_port: 4444 + + env_vars: + LOG_LEVEL: DEBUG + MCPGATEWAY_UI_ENABLED: "true" + AUTH_REQUIRED: "false" + + mtls_enabled: false + +plugins: + - name: OPAPluginFilter + repo: https://github.com/terylt/mcp-context-forge.git + ref: feat/use_mtls_plugins + context: plugins/external/opa + + expose_port: true + mtls_enabled: false + + plugin_overrides: + priority: 10 + mode: "enforce" + +certificates: + auto_generate: true +``` + +**Use case:** Quick local testing without security overhead + +**Deploy:** +```bash +cforge gateway deploy examples/deployment-configs/deploy-compose.yaml +``` + +**Access:** +- Gateway: http://localhost:4444 +- Admin UI: http://localhost:4444/admin +- Plugin (exposed): http://localhost:8000 + +--- + +### Example 2: Docker Compose (With mTLS) + +**File:** `examples/deployment-configs/deploy-compose.mtls.yaml` + +Secure local deployment with mutual TLS: + +```yaml +deployment: + type: compose + project_name: mcp-stack-test + +gateway: + image: mcpgateway/mcpgateway:latest + port: 4444 + host_port: 4444 + + mtls_enabled: true # ← Enable mTLS client + mtls_verify: true + mtls_check_hostname: false # Don't verify hostname for localhost + +plugins: + - name: OPAPluginFilter + repo: https://github.com/terylt/mcp-context-forge.git + ref: feat/use_mtls_plugins + context: plugins/external/opa + + mtls_enabled: true # ← Enable mTLS server + + plugin_overrides: + priority: 10 + mode: "enforce" + +certificates: + validity_days: 825 + auto_generate: true # Auto-generate mTLS certs +``` + +**Use case:** Local testing with production-like security + +**Deploy:** +```bash +# Certificates are auto-generated during deploy +cforge gateway deploy examples/deployment-configs/deploy-compose.mtls.yaml +``` + +**How mTLS works:** +1. `cforge gateway certs` generates CA + gateway client cert + plugin server certs +2. Gateway connects to plugins using client certificate +3. Plugins verify gateway's client certificate against CA +4. 
All communication is encrypted and mutually authenticated + +--- + +### Example 3: Kubernetes (Pre-built Images) + +**File:** `examples/deployment-configs/deploy-k8s.yaml` + +Production-ready Kubernetes deployment using pre-built images: + +```yaml +deployment: + type: kubernetes + namespace: mcp-gateway-prod + +gateway: + image: mcpgateway/mcpgateway:latest + image_pull_policy: IfNotPresent + + replicas: 2 # High availability + service_type: LoadBalancer + service_port: 4444 + + memory_request: 256Mi + memory_limit: 512Mi + cpu_request: 100m + cpu_limit: 500m + + mtls_enabled: true + +plugins: + - name: OPAPluginFilter + image: mcpgateway-opapluginfilter:latest + image_pull_policy: IfNotPresent + + replicas: 2 + service_type: ClusterIP + + memory_request: 128Mi + memory_limit: 256Mi + cpu_request: 50m + cpu_limit: 200m + + mtls_enabled: true + + plugin_overrides: + priority: 10 + mode: "enforce" + +infrastructure: + postgres: + enabled: true + storage_size: 20Gi + storage_class: fast-ssd + redis: + enabled: true + +certificates: + auto_generate: true +``` + +**Use case:** Production deployment with HA and resource limits + +**Deploy:** +```bash +# Deploy to Kubernetes +cforge gateway deploy examples/deployment-configs/deploy-k8s.yaml + +# Verify +kubectl get all -n mcp-gateway-prod + +# Check logs +kubectl logs -n mcp-gateway-prod -l app=mcp-gateway +``` + +--- + +### Example 4: Kubernetes (Build from Source) + +Building plugins from Git repositories in Kubernetes: + +```yaml +deployment: + type: kubernetes + namespace: mcp-gateway-dev + +gateway: + image: mcpgateway/mcpgateway:latest + +plugins: + - name: OPAPluginFilter + # Build from source + repo: https://github.com/terylt/mcp-context-forge.git + ref: feat/use_mtls_plugins + context: plugins/external/opa + containerfile: Containerfile + + # Push to registry (configure with env vars) + # See DOCKER_REGISTRY in deploy process + + replicas: 1 + mtls_enabled: true + +certificates: + auto_generate: true +``` + +**Deploy:** +```bash +# Build locally and push to registry +export DOCKER_REGISTRY=myregistry.io/myorg +cforge gateway build deploy-k8s-build.yaml + +# Deploy to Kubernetes +cforge gateway deploy deploy-k8s-build.yaml --skip-build +``` + +--- + +### Example 5: Kubernetes with cert-manager + +**File:** `examples/deployment-configs/deploy-k8s-cert-manager.yaml` + +Production deployment using cert-manager for automated certificate management: + +```yaml +deployment: + type: kubernetes + namespace: mcp-gateway-test + +gateway: + image: mcpgateway/mcpgateway:latest + image_pull_policy: IfNotPresent + + port: 4444 + service_type: ClusterIP + service_port: 4444 + + replicas: 1 + memory_request: 256Mi + memory_limit: 512Mi + cpu_request: 100m + cpu_limit: 500m + + env_vars: + LOG_LEVEL: DEBUG + MCPGATEWAY_UI_ENABLED: "true" + + mtls_enabled: true + mtls_verify: true + mtls_check_hostname: false + +plugins: + - name: OPAPluginFilter + image: mcpgateway-opapluginfilter:latest + image_pull_policy: IfNotPresent + + port: 8000 + service_type: ClusterIP + + replicas: 1 + memory_request: 128Mi + memory_limit: 256Mi + + mtls_enabled: true + + plugin_overrides: + priority: 10 + mode: "enforce" + +# cert-manager configuration +certificates: + # Use cert-manager for automatic certificate management + use_cert_manager: true + + # Reference the Issuer created in prerequisites + cert_manager_issuer: mcp-ca-issuer + cert_manager_kind: Issuer + + # Certificate validity (auto-renewed at 2/3 of lifetime) + validity_days: 825 + + # Local paths not used when 
use_cert_manager=true + auto_generate: false +``` + +**Prerequisites:** + +1. Install cert-manager: + ```bash + kubectl apply -f https://github.com/cert-manager/cert-manager/releases/download/v1.13.0/cert-manager.yaml + ``` + +2. Create namespace and CA Issuer (one-time setup): + ```bash + # Create namespace first + kubectl create namespace mcp-gateway-test + + # Apply CA Issuer + kubectl apply -f examples/deployment-configs/cert-manager-issuer-example.yaml + ``` + +**Deploy:** +```bash +# Deploy (no need to generate certificates manually) +cforge gateway deploy examples/deployment-configs/deploy-k8s-cert-manager.yaml + +# Verify cert-manager created certificates +kubectl get certificates -n mcp-gateway-test +kubectl get secrets -n mcp-gateway-test | grep mcp- +``` + +**How it works:** +1. `cforge gateway deploy` skips local certificate generation +2. Generates cert-manager Certificate CRDs for gateway and plugins +3. Applies Certificate CRDs to Kubernetes +4. cert-manager automatically creates TLS secrets +5. Pods use the secrets created by cert-manager +6. cert-manager auto-renews certificates before expiry + +**Certificate lifecycle:** +- **Creation**: cert-manager generates certificates when CRDs are applied +- **Renewal**: Automatic renewal at 2/3 of lifetime (550 days for 825-day cert) +- **Deletion**: Certificates deleted when stack is destroyed, Issuer remains + +--- + +## mTLS Configuration Guide + +### Understanding mTLS in MCP Gateway + +**mTLS (Mutual TLS)** provides: +- **Encryption**: All gateway ↔ plugin traffic is encrypted +- **Authentication**: Both parties prove their identity +- **Authorization**: Only trusted certificates can communicate + +### Certificate Hierarchy + +``` +CA (Root Certificate Authority) +├── Gateway Client Certificate +│ └── Used by gateway to connect to plugins +└── Plugin Server Certificates (one per plugin) + └── Used by plugins to authenticate gateway +``` + +### Enabling mTLS + +**In your configuration:** + +```yaml +gateway: + mtls_enabled: true # Enable mTLS client + mtls_verify: true # Verify server certificates + mtls_check_hostname: false # Skip hostname verification (for localhost/IPs) + +plugins: + - name: MyPlugin + mtls_enabled: true # Enable mTLS server +``` + +### Certificate Generation + +**Automatic (recommended):** +```yaml +certificates: + auto_generate: true # Auto-generate during deploy + validity_days: 825 # ~2.3 years +``` + +**Manual:** +```bash +# Generate certificates explicitly +cforge gateway certs deploy.yaml + +# Certificates are created in: +# - certs/mcp/ca/ (CA) +# - certs/mcp/gateway/ (gateway client cert) +# - certs/mcp/plugins/*/ (plugin server certs) +``` + +### Environment Variables + +The deployment tool automatically sets these environment variables: + +**Gateway (client):** +```bash +PLUGINS_CLIENT_MTLS_CERTFILE=/certs/gateway/client.crt +PLUGINS_CLIENT_MTLS_KEYFILE=/certs/gateway/client.key +PLUGINS_CLIENT_MTLS_CA_BUNDLE=/certs/gateway/ca.crt +PLUGINS_CLIENT_MTLS_VERIFY=true +PLUGINS_CLIENT_MTLS_CHECK_HOSTNAME=false +``` + +**Plugin (server):** +```bash +PLUGINS_SERVER_SSL_CERTFILE=/certs/server.crt +PLUGINS_SERVER_SSL_KEYFILE=/certs/server.key +PLUGINS_SERVER_SSL_CA_CERTS=/certs/ca.crt +PLUGINS_SERVER_SSL_CERT_REQS=2 # CERT_REQUIRED +``` + +### Troubleshooting mTLS + +**Problem: Certificate verification fails** + +Check certificate validity: +```bash +openssl x509 -in certs/mcp/gateway/client.crt -noout -dates +openssl x509 -in certs/mcp/plugins/MyPlugin/server.crt -noout -dates +``` + +**Problem: Hostname 
mismatch errors** + +Solution: Set `mtls_check_hostname: false` in gateway config, or use service DNS names + +**Problem: Connection refused** + +- Verify plugin has `mtls_enabled: true` +- Check plugin logs for certificate errors +- Ensure certificates are mounted correctly + +**Problem: Expired certificates** + +Regenerate: +```bash +rm -rf certs/ +cforge gateway certs deploy.yaml +``` + +Then redeploy to distribute new certificates. + +--- + +## Container Registry Integration + +### Overview + +The container registry feature allows you to build images locally and automatically push them to container registries (Docker Hub, Quay.io, OpenShift internal registry, private registries, etc.). This is essential for: + +✅ **Kubernetes/OpenShift deployments** - Avoid ImagePullBackOff errors +✅ **Team collaboration** - Share images across developers and environments +✅ **CI/CD pipelines** - Build once, deploy everywhere +✅ **Production deployments** - Use trusted registry sources + +### How It Works + +1. **Build**: Images are built locally using docker/podman +2. **Tag**: Images are automatically tagged with the registry path +3. **Push**: Images are pushed to the registry (if `push: true`) +4. **Deploy**: Kubernetes manifests reference the registry images + +### Configuration + +Add a `registry` section to your gateway and/or plugin configurations: + +```yaml +gateway: + repo: https://github.com/yourorg/yourrepo.git + + # Container registry configuration + registry: + enabled: true # Enable registry integration + url: registry.example.com # Registry URL + namespace: myproject # Registry namespace/org/project + push: true # Push after build (default: true) + image_pull_policy: IfNotPresent # Kubernetes imagePullPolicy +``` + +**Configuration Fields:** + +| Field | Required | Description | Example | +|-------|----------|-------------|---------| +| `enabled` | Yes | Enable registry push | `true` | +| `url` | Yes* | Registry URL | `docker.io`, `quay.io`, `registry.mycompany.com` | +| `namespace` | Yes* | Registry namespace/organization/project | `myusername`, `myorg`, `mcp-gateway-test` | +| `push` | No | Push image after build | `true` (default) | +| `image_pull_policy` | No | Kubernetes imagePullPolicy | `IfNotPresent` (default) | + +\* Required when `enabled: true` + +### Common Registry Examples + +#### Docker Hub + +```yaml +registry: + enabled: true + url: docker.io + namespace: myusername + push: true + image_pull_policy: IfNotPresent +``` + +**Authentication:** +```bash +docker login +``` + +#### Quay.io + +```yaml +registry: + enabled: true + url: quay.io + namespace: myorganization + push: true + image_pull_policy: IfNotPresent +``` + +**Authentication:** +```bash +podman login quay.io +``` + +#### OpenShift Internal Registry + +```yaml +registry: + enabled: true + url: default-route-openshift-image-registry.apps-crc.testing + namespace: mcp-gateway-test + push: true + image_pull_policy: Always +``` + +**Authentication:** +```bash +# OpenShift Local (CRC) +podman login $(oc registry info) -u $(oc whoami) -p $(oc whoami -t) + +# OpenShift on cloud +oc registry login +``` + +#### Private Registry + +```yaml +registry: + enabled: true + url: registry.mycompany.com + namespace: devteam + push: true + image_pull_policy: IfNotPresent +``` + +**Authentication:** +```bash +podman login registry.mycompany.com -u myusername +``` + +### Image Naming + +When registry is enabled, images are automatically tagged with the full registry path: + +**Local tag (without registry):** +``` 
+mcpgateway-gateway:latest +mcpgateway-opapluginfilter:latest +``` + +**Registry tag (with registry enabled):** +``` +registry.example.com/myproject/mcpgateway-gateway:latest +registry.example.com/myproject/mcpgateway-opapluginfilter:latest +``` + +### Image Pull Policies + +Choose the appropriate policy for your use case: + +| Policy | Description | Best For | +|--------|-------------|----------| +| `Always` | Pull image every time pod starts | Development, testing latest changes | +| `IfNotPresent` | Pull only if image doesn't exist locally | Production, stable releases | +| `Never` | Never pull, only use local images | Air-gapped environments | + +### Workflow Example + +#### OpenShift Local Deployment + +```bash +# 1. Authenticate to OpenShift registry +podman login $(oc registry info) -u $(oc whoami) -p $(oc whoami -t) + +# 2. Build and push images +cforge gateway deploy examples/deployment-configs/deploy-openshift-local-registry.yaml + +# The tool will: +# - Build images locally +# - Tag with registry paths +# - Push to OpenShift internal registry +# - Generate manifests with registry image references +# - Deploy to cluster + +# 3. Verify images were pushed +oc get imagestreams -n mcp-gateway-test + +# Output: +# NAME IMAGE REPOSITORY +# mcpgateway-gateway default-route-.../mcp-gateway-test/mcpgateway-gateway +# mcpgateway-opapluginfilter default-route-.../mcp-gateway-test/mcpgateway-opapluginfilter +``` + +#### CI/CD Pipeline Example + +```bash +# In your CI/CD pipeline: + +# 1. Authenticate to registry +echo "$REGISTRY_PASSWORD" | docker login $REGISTRY_URL -u $REGISTRY_USER --password-stdin + +# 2. Build and push +cforge gateway build deploy-prod.yaml + +# 3. Images are automatically pushed to registry + +# 4. Deploy to Kubernetes (manifests already reference registry images) +cforge gateway deploy deploy-prod.yaml --skip-build --skip-certs +``` + +### Per-Component Configuration + +Each component (gateway and plugins) can have different registry settings: + +```yaml +gateway: + repo: https://github.com/myorg/gateway.git + registry: + enabled: true + url: quay.io + namespace: myorg + push: true + +plugins: + - name: MyPlugin + repo: https://github.com/myorg/plugin.git + registry: + enabled: true + url: docker.io # Different registry + namespace: myusername # Different namespace + push: true + + - name: InternalPlugin + repo: https://github.com/myorg/internal-plugin.git + # No registry - use local image only + registry: + enabled: false +``` + +This allows you to: +- Push gateway to one registry, plugins to another +- Skip registry push for some components +- Use different namespaces per component +- Mix local and registry images + +### Tag-Only Mode + +To tag images without pushing (useful for testing): + +```yaml +registry: + enabled: true + url: registry.example.com + namespace: myproject + push: false # Tag but don't push +``` + +**Use cases:** +- Test registry configuration before pushing +- Generate manifests with registry paths for GitOps +- Manual push workflow + +### Troubleshooting + +#### Authentication Errors + +**Error:** `Failed to push to registry: unauthorized` + +**Solution:** Authenticate to the registry before building: +```bash +# Docker Hub +docker login + +# Quay.io +podman login quay.io + +# Private registry +podman login registry.mycompany.com -u myusername + +# OpenShift +podman login $(oc registry info) -u $(oc whoami) -p $(oc whoami -t) +``` + +#### ImagePullBackOff in Kubernetes + +**Error:** Pods show `ImagePullBackOff` status + +**Possible causes:** 
+1. Image doesn't exist in registry (push failed) +2. Registry authentication not configured in Kubernetes +3. Network connectivity issues +4. Wrong image path/tag + +**Solutions:** + +**1. Verify image exists:** +```bash +# OpenShift +oc get imagestreams -n mcp-gateway-test + +# Docker Hub/Quay +podman search your-registry.com/namespace/image-name +``` + +**2. Configure Kubernetes pull secrets:** +```bash +# Create docker-registry secret +kubectl create secret docker-registry regcred \ + --docker-server=registry.example.com \ + --docker-username=myusername \ + --docker-password=mypassword \ + --docker-email=myemail@example.com \ + -n mcp-gateway-test + +# Update deployment to use secret (manual step, or add to template) +``` + +**3. For OpenShift, grant pull permissions:** +```bash +# Allow default service account to pull from namespace +oc policy add-role-to-user system:image-puller \ + system:serviceaccount:mcp-gateway-test:default \ + -n mcp-gateway-test +``` + +#### Push Failed: Too Large + +**Error:** `image push failed: blob upload exceeds max size` + +**Solution:** Some registries have size limits. Options: +1. Use multi-stage builds to reduce image size +2. Switch to a registry with larger limits +3. Split into smaller images + +#### Podman Trying HTTP Instead of HTTPS (OpenShift/CRC) + +**Error:** `pinging container registry ...: Get "http://...: dial tcp 127.0.0.1:80: connection refused` + +**Cause:** Podman doesn't know the registry uses HTTPS and defaults to HTTP on port 80. + +**Solution:** Configure podman to use HTTPS for the registry: + +```bash +# SSH into podman machine and configure registries.conf +podman machine ssh -- "sudo bash -c ' +if ! grep -q \"default-route-openshift-image-registry.apps-crc.testing\" /etc/containers/registries.conf 2>/dev/null; then + echo \"\" >> /etc/containers/registries.conf + echo \"[[registry]]\" >> /etc/containers/registries.conf + echo \"location = \\\"default-route-openshift-image-registry.apps-crc.testing\\\"\" >> /etc/containers/registries.conf + echo \"insecure = true\" >> /etc/containers/registries.conf + echo \"Registry configuration added\" +else + echo \"Registry already configured\" +fi +'" + +# Restart podman machine +podman machine restart + +# Wait for restart +sleep 10 + +# Verify you can now push +podman push default-route-openshift-image-registry.apps-crc.testing/namespace/image:tag +``` + +**Alternative solution:** Use the internal registry service name instead of the route: + +```yaml +registry: + url: image-registry.openshift-image-registry.svc:5000 + namespace: mcp-gateway-test +``` + +This bypasses the external route and connects directly to the internal service (HTTPS on port 5000). 
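+If pushes still fail after the restart, it can help to confirm that the entry actually landed in the podman machine's `registries.conf` (a quick check assuming the CRC registry route shown above):
+
+```bash
+# Show the registry entry inside the podman machine VM
+podman machine ssh -- grep -A 2 'default-route-openshift-image-registry.apps-crc.testing' /etc/containers/registries.conf
+```
+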
+ +#### Registry URL Format + +**Correct formats:** +```yaml +url: docker.io # Docker Hub +url: quay.io # Quay.io +url: gcr.io # Google Container Registry +url: registry.mycompany.com # Private registry +url: default-route-openshift-image-registry.apps-crc.testing # OpenShift +``` + +**Incorrect formats:** +```yaml +url: https://docker.io # No protocol +url: docker.io/myusername # No namespace in URL +url: registry:5000 # Include port in URL, not namespace +``` + +### Best Practices + +✅ **DO:** +- Authenticate to registry before building +- Use specific version tags in production (not `:latest`) +- Test registry configuration with `push: false` first +- Use `image_pull_policy: Always` for development +- Use `image_pull_policy: IfNotPresent` for production +- Organize images by namespace/project + +❌ **DON'T:** +- Commit registry credentials to Git +- Use `latest` tag in production +- Mix local and registry images without testing +- Skip authentication step +- Use `push: true` for testing without verifying first + +### Example Configurations + +Full examples available in: +- `examples/deployment-configs/deploy-openshift-local.yaml` - Registry config commented +- `examples/deployment-configs/deploy-openshift-local-registry.yaml` - Full registry setup + +--- + +## Deployment Modes + +### Plain Python Mode (Default) + +**What is it?** +Pure Python implementation using standard tools (`docker`, `kubectl`, `git`, etc.). This is the **default mode** to avoid automatic downloads. + +**When to use:** +- ✅ Default choice (no surprises) +- ✅ Environments without Dagger support +- ✅ Air-gapped networks +- ✅ Simple deployments +- ✅ Debugging/troubleshooting + +**Requirements:** +- Python 3.11+ +- Docker CLI +- `kubectl` (for Kubernetes deployments) +- `git` (for building from source) + +**Usage:** +```bash +# Plain Python mode (default, no flag needed) +cforge gateway deploy deploy.yaml +``` + +**Characteristics:** +- Sequential builds +- Standard caching +- No external dependencies beyond Docker/kubectl + +--- + +### Dagger Mode (Opt-in) + +**What is Dagger?** +Dagger is a programmable CI/CD engine that runs pipelines in containers. It provides: +- **Reproducible builds**: Same results everywhere +- **Parallel execution**: Faster builds +- **Intelligent caching**: Only rebuild what changed +- **Cross-platform**: Works on any system with Docker + +**When to use:** +- ✅ Local development (fastest builds) +- ✅ CI/CD pipelines (GitHub Actions, GitLab CI, etc.) +- ✅ Team environments (consistent results) +- ✅ When you want optimized build performance + +**Requirements:** +- Docker or compatible container runtime +- `dagger-io` Python package (optional, installed separately) +- **Note**: First use will auto-download the Dagger CLI (~100MB) + +**Enable:** +```bash +# Install dagger-io package first +pip install dagger-io + +# Use Dagger mode (opt-in with --dagger flag) +cforge gateway --dagger deploy deploy.yaml +``` + +**Performance benefits:** +- 2-3x faster builds with caching +- Parallel plugin builds +- Efficient layer reuse + +**Important**: Using `--dagger` will automatically download the Dagger CLI binary on first use if not already present. Use plain Python mode if you want to avoid automatic downloads + +--- + +## CI/CD Integration + +### GitHub Actions + +```yaml +name: Deploy MCP Gateway + +on: + push: + branches: [main] + +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install cforge + run: pip install -e . 
+ + - name: Validate configuration + run: cforge gateway validate deploy/deploy-prod.yaml + + - name: Build containers + run: cforge gateway build deploy/deploy-prod.yaml + env: + DOCKER_REGISTRY: ${{ secrets.DOCKER_REGISTRY }} + + - name: Generate certificates + run: cforge gateway certs deploy/deploy-prod.yaml + + - name: Deploy to Kubernetes + run: cforge gateway deploy deploy/deploy-prod.yaml --skip-build + env: + KUBECONFIG: ${{ secrets.KUBECONFIG }} + + - name: Verify deployment + run: cforge gateway verify deploy/deploy-prod.yaml +``` + +--- + +### GitLab CI + +```yaml +stages: + - validate + - build + - deploy + +variables: + CONFIG_FILE: deploy/deploy-prod.yaml + +validate: + stage: validate + script: + - pip install -e . + - cforge gateway validate $CONFIG_FILE + +build: + stage: build + script: + - pip install -e . + - cforge gateway build $CONFIG_FILE + artifacts: + paths: + - deploy/ + +deploy: + stage: deploy + script: + - pip install -e . + - cforge gateway deploy $CONFIG_FILE --skip-build + environment: + name: production + only: + - main +``` + +--- + +## Best Practices + +### Configuration Management + +✅ **DO:** +- Version control your `deploy.yaml` +- Use Git tags/branches for plugin versions (`ref: v1.2.3`) +- Separate configs for dev/staging/prod +- Document custom `env_vars` in comments + +❌ **DON'T:** +- Hardcode secrets in YAML (use environment files) +- Use `ref: main` in production (pin versions) +- Commit generated certificates to Git + +### Environment Variables + +✅ **DO:** +```bash +# Review and customize .env files after build +cforge gateway build deploy.yaml +# Edit deploy/env/.env.gateway +# Edit deploy/env/.env.PluginName +cforge gateway deploy deploy.yaml --skip-build +``` + +❌ **DON'T:** +```bash +# Deploy without reviewing environment +cforge gateway deploy deploy.yaml # May use default/insecure values +``` + +### Certificate Management + +✅ **DO:** +- Let `cforge` auto-generate certificates +- Rotate certificates before expiry +- Use separate CAs for dev/staging/prod +- Backup CA private key securely + +❌ **DON'T:** +- Share certificates between environments +- Commit CA private key to Git +- Use expired certificates + +### Resource Limits + +✅ **DO:** +```yaml +gateway: + memory_request: 256Mi + memory_limit: 512Mi # 2x request for burst capacity + cpu_request: 100m + cpu_limit: 500m # Allow bursting +``` + +❌ **DON'T:** +```yaml +gateway: + # Missing resource limits = unbounded usage + # OR + memory_limit: 256Mi # Too tight, may OOM +``` + +### High Availability + +✅ **DO:** +```yaml +gateway: + replicas: 2 # Multiple replicas + service_type: LoadBalancer + +plugins: + - name: CriticalPlugin + replicas: 2 # HA for critical plugins +``` + +❌ **DON'T:** +```yaml +gateway: + replicas: 1 # Single point of failure in production +``` + +--- + +## Troubleshooting + +### Build Issues + +**Problem: Git clone fails** +``` +Error: Failed to clone repository +``` + +**Solution:** +- Check `repo` URL is correct +- Verify Git credentials/SSH keys +- Ensure network connectivity +- For private repos, configure Git auth + +--- + +**Problem: Docker build fails** +``` +Error: Build failed for plugin MyPlugin +``` + +**Solution:** +1. Check `context` and `containerfile` paths +2. Verify Containerfile syntax +3. Review plugin repository structure +4. Try building manually: + ```bash + git clone + cd + docker build -f . + ``` + +--- + +### Deployment Issues + +**Problem: Pod/container fails to start** +``` +Error: CrashLoopBackOff +``` + +**Solution:** +1. 
Check logs:
+   ```bash
+   # Kubernetes
+   kubectl logs <pod-name> -n <namespace>
+
+   # Docker Compose
+   docker-compose -f deploy/docker-compose.yaml logs
+   ```
+2. Verify environment variables in `deploy/env/`
+3. Check resource limits (may be too low)
+4. Verify image was built/pulled correctly
+
+---
+
+**Problem: mTLS connection fails**
+```
+Error: SSL certificate verification failed
+```
+
+**Solution:**
+1. Regenerate certificates:
+   ```bash
+   rm -rf certs/
+   cforge gateway certs deploy.yaml
+   ```
+2. Redeploy to distribute new certs:
+   ```bash
+   cforge gateway deploy deploy.yaml --skip-build --skip-certs
+   ```
+3. Check certificate expiry:
+   ```bash
+   openssl x509 -in certs/mcp/gateway/client.crt -noout -dates
+   ```
+
+---
+
+### Verification Issues
+
+**Problem: Deployment verification timeout**
+```
+Error: Verification failed: timeout waiting for deployment
+```
+
+**Solution:**
+1. Increase timeout:
+   ```bash
+   cforge gateway verify deploy.yaml --timeout 600
+   ```
+2. Check pod/container status manually
+3. Review resource availability (CPU/memory)
+4. Check for image pull errors
+
+---
+
+## FAQ
+
+**Q: Can I use pre-built images instead of building from source?**
+
+A: Yes! Just specify `image` instead of `repo`:
+```yaml
+plugins:
+  - name: MyPlugin
+    image: myorg/myplugin:v1.0.0
+```
+
+---
+
+**Q: How do I update a plugin to a new version?**
+
+A: Update the `ref` and redeploy:
+```yaml
+plugins:
+  - name: MyPlugin
+    repo: https://github.com/org/repo.git
+    ref: v2.0.0  # ← Update version
+```
+
+Then:
+```bash
+cforge gateway build deploy.yaml --plugin MyPlugin --no-cache
+cforge gateway deploy deploy.yaml --skip-certs
+```
+
+---
+
+**Q: Can I deploy only the gateway without plugins?**
+
+A: Yes, just omit the `plugins` section or use an empty array:
+```yaml
+plugins: []
+```
+
+---
+
+**Q: How do I add custom environment variables?**
+
+A: Two ways:
+
+**1. In YAML (committed to Git):**
+```yaml
+gateway:
+  env_vars:
+    CUSTOM_VAR: value
+```
+
+**2. In .env file (not committed):**
+```bash
+# deploy/env/.env.gateway
+CUSTOM_VAR=value
+```
+
+---
+
+**Q: Can I use cforge in a CI/CD pipeline?**
+
+A: Absolutely! See [CI/CD Integration](#cicd-integration) section above.
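+
+For CI systems other than GitHub Actions or GitLab CI, the same flow is just a sequence of shell steps. A minimal sketch using the commands documented above (the config path `deploy/deploy-prod.yaml` is illustrative; adjust it to your repository layout):
+
+```bash
+set -euo pipefail
+
+# Install the cforge CLI from the repository checkout
+pip install -e .
+
+# Validate, build, generate certs, deploy, and verify in order
+cforge gateway validate deploy/deploy-prod.yaml
+cforge gateway build deploy/deploy-prod.yaml
+cforge gateway certs deploy/deploy-prod.yaml
+cforge gateway deploy deploy/deploy-prod.yaml --skip-build
+cforge gateway verify deploy/deploy-prod.yaml
+```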
+ +--- + +**Q: How do I switch between Dagger and plain Python modes?** + +A: +```bash +# Plain Python mode (default) +cforge gateway deploy deploy.yaml + +# Dagger mode (opt-in, requires dagger-io package) +cforge gateway --dagger deploy deploy.yaml +``` + +**Note**: Dagger mode requires installing the `dagger-io` package and will auto-download the Dagger CLI (~100MB) on first use + +--- + +**Q: Where are the generated manifests stored?** + +A: Default: `deploy/` directory +- `deploy/docker-compose.yaml` (Compose mode) +- `deploy/manifests/` (Kubernetes mode) + +Custom location: +```bash +cforge gateway deploy deploy.yaml --output-dir ./my-deploy +``` + +--- + +**Q: How do I access the gateway after deployment?** + +A: +- **Docker Compose**: `http://localhost:` (default: 4444) +- **Kubernetes LoadBalancer**: Get external IP: + ```bash + kubectl get svc -n mcp-gateway + ``` +- **Kubernetes ClusterIP**: Port-forward: + ```bash + kubectl port-forward -n svc/mcp-gateway 4444:4444 + ``` + +--- + +## Additional Resources + +- **Main Documentation**: [ContextForge Documentation](/) +- **Plugin Development**: [Plugin Framework Guide](/plugins/framework) +- **mTLS Setup**: [mTLS Configuration Guide](/using/plugins/mtls) +- **Example Configs**: [`examples/deployment-configs/`](https://github.com/terylt/mcp-context-forge/tree/main/examples/deployment-configs) +- **Source Code**: [`mcpgateway/tools/builder/`](https://github.com/terylt/mcp-context-forge/tree/main/mcpgateway/tools/builder) + +--- + +## Getting Help + +If you encounter issues: + +1. **Check logs**: Review detailed error messages +2. **Validate config**: Run `cforge gateway validate deploy.yaml` +3. **Dry-run**: Test with `cforge gateway deploy deploy.yaml --dry-run` +4. **Verbose mode**: Use `cforge gateway -v ` for detailed output +5. **Debug mode**: Set `export MCP_DEBUG=1` for stack traces +6. **GitHub Issues**: [Report bugs and request features](https://github.com/terylt/mcp-context-forge/issues) + +--- diff --git a/docs/docs/deployment/container.md b/docs/docs/deployment/container.md index 775aeb430..8342e4680 100644 --- a/docs/docs/deployment/container.md +++ b/docs/docs/deployment/container.md @@ -31,12 +31,12 @@ docker logs mcpgateway You can now access the UI at [http://localhost:4444/admin](http://localhost:4444/admin) ### Multi-architecture containers -Note: the container build process creates container images for 'amd64', 'arm64' and 's390x' architectures. The version `ghcr.io/ibm/mcp-context-forge:VERSION` +Note: the container build process creates container images for 'amd64', 'arm64' and 's390x' architectures. The version `ghcr.io/ibm/mcp-context-forge:VERSION` not points to a manifest so that if all commands will pull the correct image for the architecture being used (whether that be locally or on Kubernetes or OpenShift). If the specific image is needed for one architecture on a different architecture use the appropriate arguments for your given container execution tool: -With docker run: +With docker run: ``` docker run [... all your options...] --platform linux/arm64 ghcr.io/ibm/mcp-context-forge:VERSION ``` diff --git a/docs/docs/manage/mtls.md b/docs/docs/manage/mtls.md index e02ed047f..e69de29bb 100644 --- a/docs/docs/manage/mtls.md +++ b/docs/docs/manage/mtls.md @@ -1,943 +0,0 @@ -# mTLS (Mutual TLS) Configuration - -Configure mutual TLS authentication for MCP Gateway to enable certificate-based client authentication and enhanced security. 
- -## Overview - -Mutual TLS (mTLS) provides bidirectional authentication between clients and servers using X.509 certificates. While native mTLS support is in development ([#568](https://github.com/IBM/mcp-context-forge/issues/568)), MCP Gateway can leverage reverse proxies for production-ready mTLS today. - -## Current Status - -- **Native mTLS**: 🚧 In Progress - tracked in [#568](https://github.com/IBM/mcp-context-forge/issues/568) -- **Proxy-based mTLS**: ✅ Available - using Nginx, Caddy, or other reverse proxies -- **Container Support**: ✅ Ready - lightweight containers support proxy deployment - -## Architecture - -```mermaid -sequenceDiagram - participant Client - participant Proxy as Reverse Proxy
(Nginx/Caddy) - participant Gateway as MCP Gateway - participant MCP as MCP Server - - Client->>Proxy: TLS Handshake
+ Client Certificate - Proxy->>Proxy: Verify Client Cert - Proxy->>Gateway: HTTP + X-SSL Headers - Gateway->>Gateway: Extract User from Headers - Gateway->>MCP: Forward Request - MCP-->>Gateway: Response - Gateway-->>Proxy: Response - Proxy-->>Client: TLS Response -``` - -## Quick Start - -### Option 1: Docker Compose with Nginx mTLS - -1. **Generate certificates** (for testing): - -```bash -# Create certificates directory -mkdir -p certs/mtls - -# Generate CA certificate -openssl req -x509 -newkey rsa:4096 -days 365 -nodes \ - -keyout certs/mtls/ca.key -out certs/mtls/ca.crt \ - -subj "/C=US/ST=State/L=City/O=MCP-CA/CN=MCP Root CA" - -# Generate server certificate -openssl req -newkey rsa:4096 -nodes \ - -keyout certs/mtls/server.key -out certs/mtls/server.csr \ - -subj "/CN=gateway.local" - -openssl x509 -req -in certs/mtls/server.csr \ - -CA certs/mtls/ca.crt -CAkey certs/mtls/ca.key \ - -CAcreateserial -out certs/mtls/server.crt -days 365 - -# Generate client certificate -openssl req -newkey rsa:4096 -nodes \ - -keyout certs/mtls/client.key -out certs/mtls/client.csr \ - -subj "/CN=admin@example.com" - -openssl x509 -req -in certs/mtls/client.csr \ - -CA certs/mtls/ca.crt -CAkey certs/mtls/ca.key \ - -CAcreateserial -out certs/mtls/client.crt -days 365 - -# Create client bundle for testing -cat certs/mtls/client.crt certs/mtls/client.key > certs/mtls/client.pem -``` - -2. **Create Nginx configuration** (`nginx-mtls.conf`): - -```nginx -events { - worker_connections 1024; -} - -http { - upstream mcp_gateway { - server gateway:4444; - } - - server { - listen 443 ssl; - server_name gateway.local; - - # Server certificates - ssl_certificate /etc/nginx/certs/server.crt; - ssl_certificate_key /etc/nginx/certs/server.key; - - # mTLS client verification - ssl_client_certificate /etc/nginx/certs/ca.crt; - ssl_verify_client on; - ssl_verify_depth 2; - - # Strong TLS settings - ssl_protocols TLSv1.2 TLSv1.3; - ssl_ciphers HIGH:!aNULL:!MD5; - ssl_prefer_server_ciphers on; - - location / { - proxy_pass http://mcp_gateway; - proxy_http_version 1.1; - - # Pass client certificate info to MCP Gateway - proxy_set_header X-SSL-Client-Cert $ssl_client_escaped_cert; - proxy_set_header X-SSL-Client-S-DN $ssl_client_s_dn; - proxy_set_header X-SSL-Client-S-DN-CN $ssl_client_s_dn_cn; - proxy_set_header X-SSL-Client-Verify $ssl_client_verify; - proxy_set_header X-Authenticated-User $ssl_client_s_dn_cn; - - # Standard proxy headers - proxy_set_header Host $host; - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - } - - # WebSocket support - location /ws { - proxy_pass http://mcp_gateway; - proxy_http_version 1.1; - proxy_set_header Upgrade $http_upgrade; - proxy_set_header Connection "upgrade"; - proxy_set_header X-SSL-Client-S-DN-CN $ssl_client_s_dn_cn; - proxy_set_header X-Authenticated-User $ssl_client_s_dn_cn; - } - - # SSE support - location ~ ^/servers/.*/sse$ { - proxy_pass http://mcp_gateway; - proxy_http_version 1.1; - proxy_set_header X-SSL-Client-S-DN-CN $ssl_client_s_dn_cn; - proxy_set_header X-Authenticated-User $ssl_client_s_dn_cn; - proxy_set_header Connection ""; - proxy_buffering off; - proxy_cache off; - } - } -} -``` - -3. 
**Create Docker Compose file** (`docker-compose-mtls.yml`): - -```yaml -version: '3.8' - -services: - nginx-mtls: - image: nginx:alpine - ports: - - "443:443" - volumes: - - ./nginx-mtls.conf:/etc/nginx/nginx.conf:ro - - ./certs/mtls:/etc/nginx/certs:ro - networks: - - mcpnet - depends_on: - - gateway - - gateway: - image: ghcr.io/ibm/mcp-context-forge:latest - environment: - - HOST=0.0.0.0 - - PORT=4444 - - DATABASE_URL=sqlite:////app/data/mcp.db - - # Disable JWT auth and trust proxy headers - - MCP_CLIENT_AUTH_ENABLED=false - - TRUST_PROXY_AUTH=true - - PROXY_USER_HEADER=X-SSL-Client-S-DN-CN - - # Keep admin UI protected - - AUTH_REQUIRED=true - - BASIC_AUTH_USER=admin - - BASIC_AUTH_PASSWORD=changeme - - # Enable admin features - - MCPGATEWAY_UI_ENABLED=true - - MCPGATEWAY_ADMIN_API_ENABLED=true - networks: - - mcpnet - volumes: - - ./data:/app/data # persists SQLite database at /app/data/mcp.db - -networks: - mcpnet: - driver: bridge -``` -> 💾 Run `mkdir -p data` before `docker-compose up` so the SQLite database survives restarts. - - -4. **Test the connection**: - -```bash -# Start the services -docker-compose -f docker-compose-mtls.yml up -d - -# Test with client certificate -curl --cert certs/mtls/client.pem \ - --cacert certs/mtls/ca.crt \ - https://localhost/health - -# Test without certificate (should fail) -curl https://localhost/health -# Error: SSL certificate problem -``` - -### Option 2: Caddy with mTLS - -1. **Create Caddyfile** (`Caddyfile.mtls`): - -```caddyfile -{ - # Global options - debug -} - -gateway.local { - # Enable mTLS - tls { - client_auth { - mode require_and_verify - trusted_ca_cert_file /etc/caddy/certs/ca.crt - } - } - - # Reverse proxy to MCP Gateway - reverse_proxy gateway:4444 { - # Pass certificate info as headers - header_up X-SSL-Client-Cert {http.request.tls.client.certificate_pem_escaped} - header_up X-SSL-Client-S-DN {http.request.tls.client.subject} - header_up X-SSL-Client-S-DN-CN {http.request.tls.client.subject_cn} - header_up X-Authenticated-User {http.request.tls.client.subject_cn} - - # WebSocket support - @websocket { - header Connection *Upgrade* - header Upgrade websocket - } - transport http { - versions 1.1 - } - } -} -``` - -2. **Docker Compose with Caddy**: - -```yaml -version: '3.8' - -services: - caddy-mtls: - image: caddy:alpine - ports: - - "443:443" - volumes: - - ./Caddyfile.mtls:/etc/caddy/Caddyfile:ro - - ./certs/mtls:/etc/caddy/certs:ro - - caddy_data:/data - - caddy_config:/config - networks: - - mcpnet - depends_on: - - gateway - - gateway: - # Same configuration as Nginx example - image: ghcr.io/ibm/mcp-context-forge:latest - environment: - - MCP_CLIENT_AUTH_ENABLED=false - - TRUST_PROXY_AUTH=true - - PROXY_USER_HEADER=X-SSL-Client-S-DN-CN - # ... rest of config ... 
- networks: - - mcpnet - -volumes: - caddy_data: - caddy_config: - -networks: - mcpnet: - driver: bridge -``` - -## Production Configuration - -### Enterprise PKI Integration - -For production deployments, integrate with your enterprise PKI: - -```nginx -# nginx.conf - Enterprise PKI -server { - listen 443 ssl; - - # Server certificates from enterprise CA - ssl_certificate /etc/pki/tls/certs/gateway.crt; - ssl_certificate_key /etc/pki/tls/private/gateway.key; - - # Client CA chain - ssl_client_certificate /etc/pki/tls/certs/enterprise-ca-chain.crt; - ssl_verify_client on; - ssl_verify_depth 3; - - # CRL verification - ssl_crl /etc/pki/tls/crl/enterprise.crl; - - # OCSP stapling - ssl_stapling on; - ssl_stapling_verify on; - ssl_trusted_certificate /etc/pki/tls/certs/enterprise-ca-chain.crt; - - location / { - proxy_pass http://mcp-gateway:4444; - - # Extract user from certificate DN - if ($ssl_client_s_dn ~ /CN=([^\/]+)/) { - set $cert_cn $1; - } - proxy_set_header X-Authenticated-User $cert_cn; - - # Extract organization - if ($ssl_client_s_dn ~ /O=([^\/]+)/) { - set $cert_org $1; - } - proxy_set_header X-User-Organization $cert_org; - } -} -``` - -### Kubernetes Deployment Options - -### Option 1: Helm Chart with TLS Ingress - -The MCP Gateway Helm chart (`charts/mcp-stack`) includes built-in TLS support via Ingress: - -```bash -# Install with TLS enabled -helm install mcp-gateway ./charts/mcp-stack \ - --set mcpContextForge.ingress.enabled=true \ - --set mcpContextForge.ingress.host=gateway.example.com \ - --set mcpContextForge.ingress.tls.enabled=true \ - --set mcpContextForge.ingress.tls.secretName=gateway-tls \ - --set mcpContextForge.ingress.annotations."cert-manager\.io/cluster-issuer"=letsencrypt-prod \ - --set mcpContextForge.ingress.annotations."nginx.ingress.kubernetes.io/auth-tls-secret"=mcp-system/gateway-client-ca \ - --set mcpContextForge.ingress.annotations."nginx.ingress.kubernetes.io/auth-tls-verify-client"=on \ - --set mcpContextForge.ingress.annotations."nginx.ingress.kubernetes.io/auth-tls-verify-depth"="2" \ - --set mcpContextForge.ingress.annotations."nginx.ingress.kubernetes.io/auth-tls-pass-certificate-to-upstream"="true" -``` - - -> ℹ️ The configuration snippet that forwards the client CN is easier to maintain in `values.yaml`; the one-liner above focuses on core flags. - -Or configure via `values.yaml`: - -```yaml -# charts/mcp-stack/values.yaml excerpt -mcpContextForge: - ingress: - enabled: true - className: nginx - host: gateway.example.com - annotations: - cert-manager.io/cluster-issuer: letsencrypt-prod - nginx.ingress.kubernetes.io/auth-tls-secret: mcp-system/gateway-client-ca - nginx.ingress.kubernetes.io/auth-tls-verify-client: "on" - nginx.ingress.kubernetes.io/auth-tls-verify-depth: "2" - nginx.ingress.kubernetes.io/auth-tls-pass-certificate-to-upstream: "true" - nginx.ingress.kubernetes.io/configuration-snippet: | - proxy_set_header X-SSL-Client-S-DN $ssl_client_s_dn; - proxy_set_header X-SSL-Client-S-DN-CN $ssl_client_s_dn_cn; - proxy_set_header X-Authenticated-User $ssl_client_s_dn_cn; - tls: - enabled: true - secretName: gateway-tls # cert-manager will generate this - - secret: - MCP_CLIENT_AUTH_ENABLED: "false" - TRUST_PROXY_AUTH: "true" - PROXY_USER_HEADER: X-SSL-Client-S-DN-CN -``` - -Create the `gateway-client-ca` secret in the same namespace as the release so the Ingress controller can validate client certificates. 
For example: - -```bash -kubectl create secret generic gateway-client-ca \ - --from-file=ca.crt=certs/mtls/ca.crt \ - --namespace mcp-system -``` - -### Option 2: Kubernetes with Istio mTLS - -Deploy MCP Gateway with automatic mTLS in Istio service mesh: - -```yaml -# gateway-deployment.yaml -apiVersion: apps/v1 -kind: Deployment -metadata: - name: mcp-gateway - namespace: mcp-system -spec: - template: - metadata: - labels: - app: mcp-gateway - annotations: - sidecar.istio.io/inject: "true" - spec: - containers: - - name: mcp-gateway - image: ghcr.io/ibm/mcp-context-forge:latest - env: - - name: MCP_CLIENT_AUTH_ENABLED - value: "false" - - name: TRUST_PROXY_AUTH - value: "true" - - name: PROXY_USER_HEADER - value: "X-SSL-Client-S-DN-CN" ---- -# peer-authentication.yaml -apiVersion: security.istio.io/v1beta1 -kind: PeerAuthentication -metadata: - name: mcp-gateway-mtls - namespace: mcp-system -spec: - selector: - matchLabels: - app: mcp-gateway - mtls: - mode: STRICT -``` - -Istio does not add `X-SSL-Client-S-DN-CN` automatically. Use an `EnvoyFilter` to extract the client certificate common name and forward it as the header referenced by `PROXY_USER_HEADER`: - -```yaml -# envoy-filter-client-cn.yaml -apiVersion: networking.istio.io/v1alpha3 -kind: EnvoyFilter -metadata: - name: append-client-cn-header - namespace: mcp-system -spec: - workloadSelector: - labels: - app: mcp-gateway - configPatches: - - applyTo: HTTP_FILTER - match: - context: SIDECAR_INBOUND - listener: - portNumber: 4444 - filterChain: - filter: - name: envoy.filters.network.http_connection_manager - patch: - operation: INSERT_BEFORE - value: - name: envoy.filters.http.lua - typed_config: - "@type": type.googleapis.com/envoy.extensions.filters.http.lua.v3.Lua - inlineCode: | - function envoy_on_request(handle) - local ssl = handle:streamInfo():downstreamSslConnection() - if ssl ~= nil and ssl:peerCertificatePresented() then - local subject = ssl:subjectPeerCertificate() - if subject ~= nil then - local cn = subject:match("CN=([^,/]+)") - if cn ~= nil then - handle:headers():replace("X-SSL-Client-S-DN-CN", cn) - end - end - end - end - function envoy_on_response(handle) - end -``` - -The filter runs in the sidecar and ensures the gateway receives the client's common name rather than the full certificate payload. - -### HAProxy with mTLS - -```haproxy -# haproxy.cfg -global - ssl-default-bind-options ssl-min-ver TLSv1.2 - tune.ssl.default-dh-param 2048 - -frontend mcp_gateway_mtls - bind *:443 ssl crt /etc/haproxy/certs/server.pem ca-file /etc/haproxy/certs/ca.crt verify required - - # Extract certificate information - http-request set-header X-SSL-Client-Cert %[ssl_c_der,base64] - http-request set-header X-SSL-Client-S-DN %[ssl_c_s_dn] - http-request set-header X-SSL-Client-S-DN-CN %[ssl_c_s_dn(CN)] - http-request set-header X-Authenticated-User %[ssl_c_s_dn(CN)] - - default_backend mcp_gateway_backend - -backend mcp_gateway_backend - server gateway gateway:4444 check -``` - -## Certificate Management - -### Certificate Generation Scripts - -Create a script for certificate management (`generate-certs.sh`): - -```bash -#!/bin/bash -set -e - -CERT_DIR="${CERT_DIR:-./certs/mtls}" -CA_DAYS="${CA_DAYS:-3650}" -CERT_DAYS="${CERT_DAYS:-365}" -KEY_SIZE="${KEY_SIZE:-4096}" - -mkdir -p "$CERT_DIR" - -# Generate CA if it doesn't exist -if [ ! -f "$CERT_DIR/ca.crt" ]; then - echo "Generating CA certificate..." 
- openssl req -x509 -newkey rsa:$KEY_SIZE -days $CA_DAYS -nodes \ - -keyout "$CERT_DIR/ca.key" -out "$CERT_DIR/ca.crt" \ - -subj "/C=US/ST=State/L=City/O=Organization/CN=MCP CA" - echo "CA certificate generated." -fi - -# Function to generate certificates -generate_cert() { - local name=$1 - local cn=$2 - - if [ -f "$CERT_DIR/${name}.crt" ]; then - echo "Certificate for $name already exists, skipping..." - return - fi - - echo "Generating certificate for $name (CN=$cn)..." - - # Generate private key and CSR - openssl req -newkey rsa:$KEY_SIZE -nodes \ - -keyout "$CERT_DIR/${name}.key" -out "$CERT_DIR/${name}.csr" \ - -subj "/CN=$cn" - - # Sign with CA - openssl x509 -req -in "$CERT_DIR/${name}.csr" \ - -CA "$CERT_DIR/ca.crt" -CAkey "$CERT_DIR/ca.key" \ - -CAcreateserial -out "$CERT_DIR/${name}.crt" -days $CERT_DAYS \ - -extfile <(echo "subjectAltName=DNS:$cn") - - # Create bundle - cat "$CERT_DIR/${name}.crt" "$CERT_DIR/${name}.key" > "$CERT_DIR/${name}.pem" - - # Clean up CSR - rm "$CERT_DIR/${name}.csr" - - echo "Certificate for $name generated." -} - -# Generate server certificate -generate_cert "server" "gateway.local" - -# Generate client certificates -generate_cert "admin" "admin@example.com" -generate_cert "user1" "user1@example.com" -generate_cert "service-account" "mcp-service@example.com" - -echo "All certificates generated in $CERT_DIR" -``` - -### Certificate Rotation - -Implement automatic certificate rotation: - -```yaml -# kubernetes CronJob for cert rotation -apiVersion: batch/v1 -kind: CronJob -metadata: - name: cert-rotation - namespace: mcp-system -spec: - schedule: "0 2 * * *" # Daily at 2 AM - jobTemplate: - spec: - template: - spec: - serviceAccountName: cert-rotation - containers: - - name: cert-rotator - image: bitnami/kubectl:1.30 - command: - - /bin/sh - - -c - - | - set -euo pipefail - SECRET_NAME=${CERT_SECRET:-gateway-tls} - CERT_NAME=${CERT_NAME:-gateway-tls-cert} - NAMESPACE=${TARGET_NAMESPACE:-mcp-system} - TLS_CERT=$(kubectl get secret "$SECRET_NAME" -n "$NAMESPACE" -o jsonpath='{.data.tls\.crt}') - if [ -z "$TLS_CERT" ]; then - echo "TLS secret $SECRET_NAME missing or empty" - exit 1 - fi - echo "$TLS_CERT" | base64 -d > /tmp/current.crt - if openssl x509 -checkend 604800 -noout -in /tmp/current.crt; then - echo "Certificate valid for more than 7 days" - else - echo "Certificate expiring soon, requesting renewal" - kubectl cert-manager renew "$CERT_NAME" -n "$NAMESPACE" || echo "Install the kubectl-cert_manager plugin inside the job image to enable automatic renewal" - fi - env: - - name: CERT_SECRET - value: gateway-tls - - name: CERT_NAME - value: gateway-tls-cert - - name: TARGET_NAMESPACE - value: mcp-system - volumeMounts: - - name: tmp - mountPath: /tmp - restartPolicy: OnFailure - volumes: - - name: tmp - emptyDir: {} -``` - -Create a `ServiceAccount`, `Role`, and `RoleBinding` that grant `get` access to the TLS secret and `update` access to the related `Certificate` resource so the job can request renewals. - - -> 🔧 Install the [`kubectl-cert_manager` plugin](https://cert-manager.io/docs/reference/kubectl-plugin/) or swap the command for `cmctl renew` if you prefer Jetstack's CLI image, and ensure your job image bundles both `kubectl` and `openssl`. - -## mTLS for External MCP Plugins - -External plugins that use the `STREAMABLEHTTP` transport now support mutual TLS directly from the gateway. This is optional—if you skip the configuration below, the gateway continues to call plugins exactly as before. 
Enabling mTLS lets you restrict remote plugin servers so they only accept connections from gateways presenting a trusted client certificate. - -### 1. Issue Certificates for the Remote Plugin - -Reuse the same CA you generated earlier or provision a dedicated one. Create a **server** certificate for the remote plugin endpoint and a **client** certificate for the MCP Gateway: - -```bash -# Server cert for the remote plugin (served by your reverse proxy/mcp server) -openssl req -newkey rsa:4096 -nodes \ - -keyout certs/plugins/remote.key -out certs/plugins/remote.csr \ - -subj "/CN=plugins.internal.example.com" - -openssl x509 -req -in certs/plugins/remote.csr \ - -CA certs/mtls/ca.crt -CAkey certs/mtls/ca.key \ - -CAcreateserial -out certs/plugins/remote.crt -days 365 \ - -extfile <(echo "subjectAltName=DNS:plugins.internal.example.com") - -# Client cert for the gateway -openssl req -newkey rsa:4096 -nodes \ - -keyout certs/plugins/gateway-client.key -out certs/plugins/gateway-client.csr \ - -subj "/CN=mcpgateway" - -openssl x509 -req -in certs/plugins/gateway-client.csr \ - -CA certs/mtls/ca.crt -CAkey certs/mtls/ca.key \ - -CAcreateserial -out certs/plugins/gateway-client.crt -days 365 - -cat certs/plugins/gateway-client.crt certs/plugins/gateway-client.key > certs/plugins/gateway-client.pem -``` - -### 2. Protect the Remote Plugin with mTLS - -Front the remote MCP plugin with a reverse proxy (Nginx, Caddy, Envoy, etc.) that enforces client certificate verification using the CA above. Example Nginx snippet: - -```nginx -server { - listen 9443 ssl; - server_name plugins.internal.example.com; - - ssl_certificate /etc/ssl/private/remote.crt; - ssl_certificate_key /etc/ssl/private/remote.key; - ssl_client_certificate /etc/ssl/private/ca.crt; - ssl_verify_client on; - - location /mcp { - proxy_pass http://plugin-runtime:8000/mcp; - proxy_set_header Host $host; - proxy_set_header X-Forwarded-Proto https; - } -} -``` - -### 3. Mount Certificates into the Gateway - -Expose the CA bundle and gateway client certificate to the gateway container: - -```yaml -# docker-compose override - gateway: - volumes: - - ./certs/plugins:/app/certs/plugins:ro - -# Kubernetes deployment (snippet) -volumeMounts: - - name: plugin-mtls - mountPath: /app/certs/plugins - readOnly: true -volumes: - - name: plugin-mtls - secret: - secretName: gateway-plugin-mtls -``` - -### 4. Configure the Plugin Entry - -Use the new `mcp.tls` block in `plugins/config.yaml` (or the Admin UI) to point the gateway at the certificates. Example external plugin definition: - -```yaml -plugins: - - name: "LlamaGuardSafety" - kind: "external" - hooks: ["prompt_pre_fetch", "tool_pre_invoke"] - mode: "enforce" - priority: 20 - mcp: - proto: STREAMABLEHTTP - url: https://plugins.internal.example.com:9443/mcp - tls: - ca_bundle: /app/certs/plugins/ca.crt - client_cert: /app/certs/plugins/gateway-client.pem - client_key: /app/certs/plugins/gateway-client.key # optional if PEM already bundles key - verify: true - check_hostname: true - - config: - policy: strict -``` - -**Key behavior** -- `verify` controls whether the gateway validates the remote server certificate. Leave `true` in production; set `false` only for local debugging. -- `ca_bundle` may point to a custom CA chain; omit it if the remote certificate chains to a system-trusted CA. -- `client_cert` must reference the gateway certificate. Provide `client_key` only when the key is stored separately. -- `check_hostname` defaults to `true`. 
Set it to `false` for scenarios where the certificate CN does not match the URL (not recommended outside testing). - -Restart the gateway after updating the config so the external plugin client reloads with the TLS settings. Watch the logs for `Connected to plugin MCP (http) server` to confirm a successful handshake; TLS errors will surface as plugin initialization failures. - -> 💡 **Tip:** You can set gateway-wide defaults via `PLUGINS_MTLS_CA_BUNDLE`, -> `PLUGINS_MTLS_CLIENT_CERT`, `PLUGINS_MTLS_CLIENT_KEY`, and the other -> `PLUGINS_MTLS_*` environment variables. Any plugin without an explicit -> `tls` block will inherit these values automatically. - - -## Security Best Practices - -### 1. Certificate Validation - -```nginx -# Strict certificate validation -ssl_verify_client on; -ssl_verify_depth 2; - -# Check certificate validity -ssl_session_cache shared:SSL:10m; -ssl_session_timeout 10m; - -# Enable OCSP stapling -ssl_stapling on; -ssl_stapling_verify on; -resolver 8.8.8.8 8.8.4.4 valid=300s; -resolver_timeout 5s; -``` - -### 2. Certificate Pinning - -```python -# MCP Gateway plugin for cert pinning -class CertificatePinningPlugin: - def __init__(self): - self.pinned_certs = { - "admin@example.com": "sha256:HASH...", - "service@example.com": "sha256:HASH..." - } - - async def on_request(self, request): - cert_header = request.headers.get("X-SSL-Client-Cert") - if cert_header: - cert_hash = self.calculate_hash(cert_header) - user = request.headers.get("X-Authenticated-User") - - if user in self.pinned_certs: - if self.pinned_certs[user] != cert_hash: - raise SecurityException("Certificate pin mismatch") -``` - -### 3. Audit Logging - -Configure comprehensive audit logging for mTLS connections: - -```nginx -# nginx.conf - Audit logging -log_format mtls_audit '$remote_addr - $ssl_client_s_dn [$time_local] ' - '"$request" $status $body_bytes_sent ' - '"$http_user_agent" cert_verify:$ssl_client_verify'; - -access_log /var/log/nginx/mtls-audit.log mtls_audit; -``` - -### 4. Rate Limiting by Certificate - -```nginx -# Rate limit by certificate CN -limit_req_zone $ssl_client_s_dn_cn zone=cert_limit:10m rate=10r/s; - -location / { - limit_req zone=cert_limit burst=20 nodelay; - proxy_pass http://mcp-gateway; -} -``` - -## Monitoring & Troubleshooting - -### Health Checks - -```bash -# Check mTLS connectivity -openssl s_client -connect gateway.local:443 \ - -cert certs/mtls/client.crt \ - -key certs/mtls/client.key \ - -CAfile certs/mtls/ca.crt \ - -showcerts - -# Verify certificate -openssl x509 -in certs/mtls/client.crt -text -noout - -# Test with curl -curl -v --cert certs/mtls/client.pem \ - --cacert certs/mtls/ca.crt \ - https://gateway.local/health -``` - -### Common Issues - -| Issue | Cause | Solution | -|-------|-------|----------| -| `SSL certificate verify error` | Missing/invalid client cert | Ensure client cert is valid and signed by CA | -| `400 No required SSL certificate` | mTLS not configured | Check `ssl_verify_client on` in proxy | -| `X-Authenticated-User missing` | Header not passed | Verify proxy_set_header configuration | -| `Connection refused` | Service not running | Check docker-compose logs | -| `Certificate expired` | Cert past validity | Regenerate certificates | - -### Debug Logging - -Enable debug logging in your reverse proxy: - -```nginx -# nginx.conf -error_log /var/log/nginx/error.log debug; - -# Log SSL handshake details -ssl_session_cache shared:SSL:10m; -ssl_session_timeout 10m; -``` - -## Migration Path - -### From JWT to mTLS - -1. 
**Phase 1**: Deploy proxy with mTLS alongside existing JWT auth -2. **Phase 2**: Run dual-mode (both JWT and mTLS accepted) -3. **Phase 3**: Migrate all clients to certificates -4. **Phase 4**: Disable JWT, enforce mTLS only - -```yaml -# Dual-mode configuration -environment: - # Accept both methods during migration - - MCP_CLIENT_AUTH_ENABLED=true # Keep JWT active - - TRUST_PROXY_AUTH=true # Also trust proxy - - PROXY_USER_HEADER=X-SSL-Client-S-DN-CN -``` - -## Helm Chart Configuration - -The MCP Gateway Helm chart in `charts/mcp-stack/` provides extensive configuration options for TLS and security: - -### Key Security Settings in values.yaml - -```yaml -mcpContextForge: - # JWT Configuration - supports both HMAC and RSA - secret: - JWT_ALGORITHM: HS256 # or RS256 for asymmetric - JWT_SECRET_KEY: my-test-key # for HMAC algorithms - # For RSA/ECDSA, mount keys and set: - # JWT_PUBLIC_KEY_PATH: /app/certs/jwt/public.pem - # JWT_PRIVATE_KEY_PATH: /app/certs/jwt/private.pem - - # Security Headers (enabled by default) - config: - SECURITY_HEADERS_ENABLED: "true" - X_FRAME_OPTIONS: DENY - HSTS_ENABLED: "true" - HSTS_MAX_AGE: "31536000" - SECURE_COOKIES: "true" - - # Ingress with TLS - ingress: - enabled: true - tls: - enabled: true - secretName: gateway-tls -``` - -### Deploying with Helm and mTLS - -```bash -# Create namespace -kubectl create namespace mcp-gateway - -# Install with custom TLS settings -helm install mcp-gateway ./charts/mcp-stack \ - --namespace mcp-gateway \ - --set mcpContextForge.ingress.tls.enabled=true \ - --set mcpContextForge.secret.JWT_ALGORITHM=RS256 \ - --values custom-values.yaml -``` - -## Future Native mTLS Support - -When native mTLS support lands ([#568](https://github.com/IBM/mcp-context-forge/issues/568)), expect: - -- Direct TLS termination in MCP Gateway -- Certificate-based authorization policies -- Integration with enterprise PKI systems -- Built-in certificate validation and revocation checking -- Automatic certificate rotation -- Per-service certificate management - -## Related Documentation - -- [Proxy Authentication](./proxy.md) - Configuring proxy-based authentication -- [Security Features](../architecture/security-features.md) - Overall security architecture -- [Deployment Guide](../deployment/index.md) - Production deployment options -- [Authentication Overview](./securing.md) - All authentication methods diff --git a/docs/docs/using/plugins/plugins.md b/docs/docs/using/plugins/plugins.md index 660afdce3..36e57c2ce 100644 --- a/docs/docs/using/plugins/plugins.md +++ b/docs/docs/using/plugins/plugins.md @@ -104,6 +104,7 @@ Plugins for enforcing custom policies and business rules. | Plugin | Type | Description | |--------|------|-------------| | [OPA Plugin](https://github.com/IBM/mcp-context-forge/tree/main/plugins/external/opa) | External | Enforces Rego policies on tool invocations via an OPA server. Allows selective policy application per tool with context injection and customizable policy endpoints | +| [Cedar (RBAC) Plugin](https://github.com/IBM/mcp-context-forge/tree/main/plugins/external/cedar) | External | Enforces RBAC-based policies on MCP servers using Cedar (leveraging the cedarpy library) or a custom DSL, for local evaluation with flexible configuration and output redaction. 
| ## Plugin Types diff --git a/examples/deployment-configs/cert-manager-issuer-example.yaml b/examples/deployment-configs/cert-manager-issuer-example.yaml new file mode 100644 index 000000000..5b96aae91 --- /dev/null +++ b/examples/deployment-configs/cert-manager-issuer-example.yaml @@ -0,0 +1,58 @@ +# cert-manager CA Issuer Setup (APPLY ONCE) +# This example shows how to set up a self-signed CA using cert-manager +# for issuing mTLS certificates to the MCP Gateway and plugins. +# +# Prerequisites: +# - cert-manager must be installed in your cluster +# Install: kubectl apply -f https://github.com/cert-manager/cert-manager/releases/download/v1.13.0/cert-manager.yaml +# +# Usage: +# 1. Create namespace: kubectl create namespace mcp-gateway-test +# 2. Apply this file ONCE: kubectl apply -f cert-manager-issuer-example.yaml +# 3. Deploy stack with use_cert_manager: true in mcp-stack.yaml +# +# NOTE: This creates long-lived infrastructure (CA + Issuer). +# Do NOT delete this when tearing down your MCP stack deployment. +# The CA certificate will be reused across deployments. +# +--- +# Self-signed Issuer (used to create the CA certificate) +apiVersion: cert-manager.io/v1 +kind: Issuer +metadata: + name: mcp-selfsigned-issuer + namespace: mcp-gateway-test +spec: + selfSigned: {} + +--- +# CA Certificate (root of trust for all mTLS certificates) +apiVersion: cert-manager.io/v1 +kind: Certificate +metadata: + name: mcp-ca-certificate + namespace: mcp-gateway-test +spec: + isCA: true + commonName: mcp-ca + secretName: mcp-ca-secret + duration: 19800h # 825 days (≈ 2.25 years) + renewBefore: 13200h # Renew at 2/3 of lifetime + privateKey: + algorithm: RSA + size: 4096 + issuerRef: + name: mcp-selfsigned-issuer + kind: Issuer + +--- +# CA Issuer (used to sign gateway and plugin certificates) +# This is what your mcp-stack.yaml references via cert_manager_issuer +apiVersion: cert-manager.io/v1 +kind: Issuer +metadata: + name: mcp-ca-issuer + namespace: mcp-gateway-test +spec: + ca: + secretName: mcp-ca-secret diff --git a/examples/deployment-configs/deploy-compose.mtls.yaml b/examples/deployment-configs/deploy-compose.mtls.yaml new file mode 100644 index 000000000..c77dddc01 --- /dev/null +++ b/examples/deployment-configs/deploy-compose.mtls.yaml @@ -0,0 +1,99 @@ +# MCP Stack - Local Docker Compose Test Configuration +# This config deploys MCP Gateway + external plugins locally with mTLS + +deployment: + type: compose + project_name: mcp-stack-test + +# MCP Gateway configuration +gateway: + # Use local gateway image (build first with: make container-build) + image: mcpgateway/mcpgateway:latest + + port: 4444 + host_port: 4444 # Expose on localhost:4444 + + # Environment configuration + # env_file will auto-detect deploy/env/.env.gateway if not specified + env_vars: + LOG_LEVEL: DEBUG + HOST: 0.0.0.0 + PORT: 4444 + + # Enable features + MCPGATEWAY_UI_ENABLED: "true" + MCPGATEWAY_ADMIN_API_ENABLED: "true" + MCPGATEWAY_A2A_ENABLED: "true" + + # Auth + AUTH_REQUIRED: "false" # Disabled for easy testing + + # Federation + MCPGATEWAY_ENABLE_FEDERATION: "false" + + # mTLS client configuration (gateway connects to plugins) + mtls_enabled: true + mtls_verify: true # Verify server certificates (default: true) + mtls_check_hostname: false # Don't verify hostname (default: false for compose) + + # Note: plugins-config.yaml is auto-generated from the plugins section below + # No need to specify config_file anymore! 
+ +# External plugins +plugins: + # OPA Plugin Filter + - name: OPAPluginFilter + + # Build from GitHub repository + repo: https://github.com/terylt/mcp-context-forge.git + ref: feat/use_mtls_plugins + context: plugins/external/opa + containerfile: Containerfile + + # Defaults: port=8000, host_port auto-assigned (8000, 8001, ...) + expose_port: true # Expose for testing + + # env_file will auto-detect deploy/env/.env.OPAPluginFilter if not specified + env_vars: + LOG_LEVEL: DEBUG + + # OPA-specific settings + OPA_POLICY_PATH: /app/policies + + # mTLS server configuration + mtls_enabled: true + + # Plugin manager overrides (client-side configuration) + plugin_overrides: + priority: 10 + mode: "enforce" + description: "OPA policy enforcement for tool and resource filtering" + tags: ["security", "policy", "opa"] + + # LLMGuard Plugin (content filtering) + #- name: LLMGuardPlugin + + # Build from GitHub repository + # repo: https://github.com/terylt/mcp-context-forge.git + # ref: feat/use_mtls_plugins + # context: plugins/external/llmguard + # containerfile: Containerfile + # target: builder # Build only the 'builder' stage (multi-stage build) + + # Defaults: port=8000, host_port auto-assigned (8000, 8001, ...) + # port: 8001 + # expose_port: true + + # env_file will auto-detect deploy/env/.env.LLMGuardPlugin if not specified + # env_vars: + # LOG_LEVEL: DEBUG + + # mtls_enabled: true + +# mTLS Certificate configuration +certificates: + validity_days: 825 + auto_generate: true + ca_path: ./certs/mcp/ca + gateway_path: ./certs/mcp/gateway + plugins_path: ./certs/mcp/plugins diff --git a/examples/deployment-configs/deploy-compose.yaml b/examples/deployment-configs/deploy-compose.yaml new file mode 100644 index 000000000..800700f9a --- /dev/null +++ b/examples/deployment-configs/deploy-compose.yaml @@ -0,0 +1,96 @@ +# MCP Stack - Local Docker Compose Test Configuration +# This config deploys MCP Gateway + external plugins locally with mTLS + +deployment: + type: compose + project_name: mcp-stack-test + +# MCP Gateway configuration +gateway: + # Use local gateway image (build first with: make container-build) + image: mcpgateway/mcpgateway:latest + + port: 4444 + host_port: 4444 # Expose on localhost:4444 + + # Environment configuration + # env_file will auto-detect deploy/env/.env.gateway if not specified + env_vars: + LOG_LEVEL: DEBUG + HOST: 0.0.0.0 + PORT: 4444 + # Enable features + MCPGATEWAY_UI_ENABLED: "true" + MCPGATEWAY_ADMIN_API_ENABLED: "true" + MCPGATEWAY_A2A_ENABLED: "true" + + # Auth + AUTH_REQUIRED: "false" # Disabled for easy testing + + # Federation + MCPGATEWAY_ENABLE_FEDERATION: "false" + + # mTLS client configuration (gateway connects to plugins) + mtls_enabled: false + + # Note: plugins-config.yaml is auto-generated from the plugins section below + # No need to specify config_file anymore! + +# External plugins +plugins: + # OPA Plugin Filter + - name: OPAPluginFilter + + # Build from GitHub repository + repo: https://github.com/terylt/mcp-context-forge.git + ref: feat/use_mtls_plugins + context: plugins/external/opa + containerfile: Containerfile + + # Defaults: port=8000, host_port auto-assigned (8000, 8001, ...) 
+ expose_port: true # Expose for testing + + # env_file will auto-detect deploy/env/.env.OPAPluginFilter if not specified + env_vars: + LOG_LEVEL: DEBUG + + # OPA-specific settings + OPA_POLICY_PATH: /app/policies + + # mTLS server configuration + mtls_enabled: false + + # Plugin manager overrides (client-side configuration) + plugin_overrides: + priority: 10 + mode: "enforce" + description: "OPA policy enforcement for tool and resource filtering" + tags: ["security", "policy", "opa"] + + # LLMGuard Plugin (content filtering) + #- name: LLMGuardPlugin + + # # Build from GitHub repository + # repo: https://github.com/terylt/mcp-context-forge.git + # ref: feat/use_mtls_plugins + # context: plugins/external/llmguard + # containerfile: Containerfile + # target: builder # Build only the 'builder' stage (multi-stage build) + + # Defaults: port=8000, host_port auto-assigned (8000, 8001, ...) + # port: 8001 + # expose_port: true + + # env_file will auto-detect deploy/env/.env.LLMGuardPlugin if not specified + # env_vars: + # LOG_LEVEL: DEBUG + + # mtls_enabled: false + +# mTLS Certificate configuration +certificates: + validity_days: 825 + auto_generate: true + ca_path: ./certs/mcp/ca + gateway_path: ./certs/mcp/gateway + plugins_path: ./certs/mcp/plugins diff --git a/examples/deployment-configs/deploy-k8s-cert-manager.yaml b/examples/deployment-configs/deploy-k8s-cert-manager.yaml new file mode 100644 index 000000000..d59c7bc57 --- /dev/null +++ b/examples/deployment-configs/deploy-k8s-cert-manager.yaml @@ -0,0 +1,100 @@ +# MCP Stack - Kubernetes Configuration with cert-manager +# This config uses cert-manager for automatic certificate management +# +# Prerequisites: +# 1. Install cert-manager in your cluster +# 2. Apply cert-manager-issuer-example.yaml to create the CA Issuer +# 3. 
Deploy this config + +deployment: + type: kubernetes + namespace: mcp-gateway-test + +# MCP Gateway configuration +gateway: + # Use pre-built gateway image + image: mcpgateway/mcpgateway:latest + image_pull_policy: IfNotPresent + + port: 4444 + + # Service configuration + service_type: ClusterIP + service_port: 4444 + + # Resource limits + replicas: 1 + memory_request: 256Mi + memory_limit: 512Mi + cpu_request: 100m + cpu_limit: 500m + + # Environment configuration + env_vars: + LOG_LEVEL: DEBUG + HOST: 0.0.0.0 + PORT: 4444 + MCPGATEWAY_UI_ENABLED: "true" + MCPGATEWAY_ADMIN_API_ENABLED: "true" + MCPGATEWAY_A2A_ENABLED: "true" + AUTH_REQUIRED: "false" + MCPGATEWAY_ENABLE_FEDERATION: "false" + + # mTLS client configuration (gateway connects to plugins) + mtls_enabled: true + mtls_verify: true + mtls_check_hostname: false + +# External plugins +plugins: + # OPA Plugin Filter + - name: OPAPluginFilter + + # Use pre-built image for faster testing + image: mcpgateway-opapluginfilter:latest + image_pull_policy: IfNotPresent + + port: 8000 + + # Service configuration + service_type: ClusterIP + service_port: 8000 + + # Resource limits + replicas: 1 + memory_request: 128Mi + memory_limit: 256Mi + cpu_request: 50m + cpu_limit: 200m + + env_vars: + LOG_LEVEL: DEBUG + OPA_POLICY_PATH: /app/policies + + mtls_enabled: true + + # Plugin manager overrides + plugin_overrides: + priority: 10 + mode: "enforce" + description: "OPA policy enforcement" + tags: ["security", "policy", "opa"] + +# cert-manager Certificate configuration +certificates: + # Use cert-manager for automatic certificate management + use_cert_manager: true + + # cert-manager issuer reference (must exist in namespace) + cert_manager_issuer: mcp-ca-issuer + cert_manager_kind: Issuer # or ClusterIssuer + + # Certificate validity (cert-manager will auto-renew at 2/3 of lifetime) + validity_days: 825 # ≈ 2.25 years + + # Local paths not used when use_cert_manager=true + # (included for backward compatibility if switching back) + auto_generate: false + ca_path: ./certs/mcp/ca + gateway_path: ./certs/mcp/gateway + plugins_path: ./certs/mcp/plugins diff --git a/examples/deployment-configs/deploy-k8s.mtls.yaml b/examples/deployment-configs/deploy-k8s.mtls.yaml new file mode 100644 index 000000000..32e653406 --- /dev/null +++ b/examples/deployment-configs/deploy-k8s.mtls.yaml @@ -0,0 +1,84 @@ +# MCP Stack - Kubernetes Test Configuration +# Simple test config using pre-built images + +deployment: + type: kubernetes + namespace: mcp-gateway-test + +# MCP Gateway configuration +gateway: + # Use pre-built gateway image + image: mcpgateway/mcpgateway:latest + image_pull_policy: IfNotPresent + + port: 4444 + + # Service configuration + service_type: ClusterIP + service_port: 4444 + + # Resource limits + replicas: 1 + memory_request: 256Mi + memory_limit: 512Mi + cpu_request: 100m + cpu_limit: 500m + + # Environment configuration + env_vars: + LOG_LEVEL: DEBUG + HOST: 0.0.0.0 + PORT: 4444 + MCPGATEWAY_UI_ENABLED: "true" + MCPGATEWAY_ADMIN_API_ENABLED: "true" + MCPGATEWAY_A2A_ENABLED: "true" + AUTH_REQUIRED: "false" + MCPGATEWAY_ENABLE_FEDERATION: "false" + + # mTLS client configuration (gateway connects to plugins) + mtls_enabled: true + mtls_verify: true # Verify server certificates (default: true) + mtls_check_hostname: false # Don't verify hostname (default: false for compose) + +# External plugins +plugins: + # OPA Plugin Filter + - name: OPAPluginFilter + + # Use pre-built image for faster testing + image: mcpgateway-opapluginfilter:latest + 
image_pull_policy: IfNotPresent + + port: 8000 + + # Service configuration + service_type: ClusterIP + service_port: 8000 + + # Resource limits + replicas: 1 + memory_request: 128Mi + memory_limit: 256Mi + cpu_request: 50m + cpu_limit: 200m + + env_vars: + LOG_LEVEL: DEBUG + OPA_POLICY_PATH: /app/policies + + mtls_enabled: true + + # Plugin manager overrides + plugin_overrides: + priority: 10 + mode: "enforce" + description: "OPA policy enforcement" + tags: ["security", "policy", "opa"] + +# mTLS Certificate configuration +certificates: + validity_days: 825 + auto_generate: true + ca_path: ./certs/mcp/ca + gateway_path: ./certs/mcp/gateway + plugins_path: ./certs/mcp/plugins diff --git a/examples/deployment-configs/deploy-k8s.yaml b/examples/deployment-configs/deploy-k8s.yaml new file mode 100644 index 000000000..518e61bd4 --- /dev/null +++ b/examples/deployment-configs/deploy-k8s.yaml @@ -0,0 +1,82 @@ +# MCP Stack - Kubernetes Test Configuration +# Simple test config using pre-built images + +deployment: + type: kubernetes + namespace: mcp-gateway-test + +# MCP Gateway configuration +gateway: + # Use pre-built gateway image + image: mcpgateway/mcpgateway:latest + image_pull_policy: IfNotPresent + + port: 4444 + + # Service configuration + service_type: ClusterIP + service_port: 4444 + + # Resource limits + replicas: 1 + memory_request: 256Mi + memory_limit: 512Mi + cpu_request: 100m + cpu_limit: 500m + + # Environment configuration + env_vars: + LOG_LEVEL: DEBUG + HOST: 0.0.0.0 + PORT: 4444 + MCPGATEWAY_UI_ENABLED: "true" + MCPGATEWAY_ADMIN_API_ENABLED: "true" + MCPGATEWAY_A2A_ENABLED: "true" + AUTH_REQUIRED: "false" + MCPGATEWAY_ENABLE_FEDERATION: "false" + + # mTLS disabled for simplicity + mtls_enabled: false + +# External plugins +plugins: + # OPA Plugin Filter + - name: OPAPluginFilter + + # Use pre-built image for faster testing + image: mcpgateway-opapluginfilter:latest + image_pull_policy: IfNotPresent + + port: 8000 + + # Service configuration + service_type: ClusterIP + service_port: 8000 + + # Resource limits + replicas: 1 + memory_request: 128Mi + memory_limit: 256Mi + cpu_request: 50m + cpu_limit: 200m + + env_vars: + LOG_LEVEL: DEBUG + OPA_POLICY_PATH: /app/policies + + mtls_enabled: false + + # Plugin manager overrides + plugin_overrides: + priority: 10 + mode: "enforce" + description: "OPA policy enforcement" + tags: ["security", "policy", "opa"] + +# mTLS Certificate configuration +certificates: + validity_days: 825 + auto_generate: true + ca_path: ./certs/mcp/ca + gateway_path: ./certs/mcp/gateway + plugins_path: ./certs/mcp/plugins diff --git a/examples/deployment-configs/deploy-openshift-local-registry.yaml b/examples/deployment-configs/deploy-openshift-local-registry.yaml new file mode 100644 index 000000000..95e1e8f52 --- /dev/null +++ b/examples/deployment-configs/deploy-openshift-local-registry.yaml @@ -0,0 +1,146 @@ +# MCP Stack - OpenShift Local with Registry Push +# Build from source and push to OpenShift internal registry +# +# This example demonstrates how to build images locally and push them to +# OpenShift's internal registry. This is useful for: +# - Testing images in a production-like environment +# - Avoiding ImagePullBackOff errors when deploying to OpenShift +# - Sharing images across multiple namespaces +# +# Prerequisites: +# 1. Install cert-manager in your cluster +# 2. Apply cert-manager-issuer-example.yaml to create the CA Issuer +# 3. 
Authenticate to OpenShift internal registry: +# podman login $(oc registry info) -u $(oc whoami) -p $(oc whoami -t) +# 4. Deploy this config + +deployment: + type: kubernetes + namespace: mcp-gateway-test + container_engine: podman + openshift: + create_routes: true + domain: apps-crc.testing # Optional, auto-detected if omitted + tls_termination: edge + +# MCP Gateway configuration +gateway: + # Build gateway from current repository + repo: https://github.com/terylt/mcp-context-forge.git + ref: feat/configurable_plugin_deployment + context: . + containerfile: Containerfile + image: mcpgateway-gateway:latest + + port: 4444 + + # Service configuration + service_type: ClusterIP + service_port: 4444 + + # Resource limits + replicas: 1 + memory_request: 256Mi + memory_limit: 512Mi + cpu_request: 100m + cpu_limit: 500m + + # Environment configuration + env_vars: + LOG_LEVEL: DEBUG + HOST: 0.0.0.0 + PORT: 4444 + MCPGATEWAY_UI_ENABLED: "true" + MCPGATEWAY_ADMIN_API_ENABLED: "true" + MCPGATEWAY_A2A_ENABLED: "true" + AUTH_REQUIRED: "false" + MCPGATEWAY_ENABLE_FEDERATION: "false" + + # mTLS client configuration (gateway connects to plugins) + mtls_enabled: true + mtls_verify: true + mtls_check_hostname: false + + # Container registry configuration + # Build locally, then tag and push to OpenShift internal registry + registry: + enabled: true + # OpenShift internal registry URL (get with: oc registry info) + url: default-route-openshift-image-registry.apps-crc.testing + # Namespace where images will be pushed (must have push permissions) + namespace: mcp-gateway-test + # Push image after build + push: true + # imagePullPolicy for Kubernetes pods + image_pull_policy: Always + +# External plugins +plugins: + # OPA Plugin Filter - build from source and push to registry + - name: OPAPluginFilter + + # Build from repository + repo: https://github.com/terylt/mcp-context-forge.git + ref: feat/use_mtls_plugins + context: plugins/external/opa + containerfile: Containerfile + image: mcpgateway-opapluginfilter:latest + + port: 8000 + + # Service configuration + service_type: ClusterIP + service_port: 8000 + + # Resource limits + replicas: 1 + memory_request: 128Mi + memory_limit: 256Mi + cpu_request: 50m + cpu_limit: 200m + + env_vars: + LOG_LEVEL: DEBUG + OPA_POLICY_PATH: /app/policies + + mtls_enabled: true + + # Container registry configuration + # Push plugin image to same registry as gateway + registry: + enabled: true + url: default-route-openshift-image-registry.apps-crc.testing + namespace: mcp-gateway-test + push: true + image_pull_policy: Always + + # Plugin manager overrides + plugin_overrides: + priority: 10 + mode: "enforce" + description: "OPA policy enforcement" + tags: ["security", "policy", "opa"] + +# Infrastructure services +infrastructure: + postgres: + enabled: true + image: quay.io/sclorg/postgresql-15-c9s:latest + user: mcpuser # Use non-'postgres' username for Red Hat images + database: mcp + password: mysecretpassword + +# cert-manager Certificate configuration +certificates: + # Use cert-manager for automatic certificate management + use_cert_manager: true + + # cert-manager issuer reference (must exist in namespace) + cert_manager_issuer: mcp-ca-issuer + cert_manager_kind: Issuer # or ClusterIssuer + + # Certificate validity (cert-manager will auto-renew at 2/3 of lifetime) + validity_days: 825 # ≈ 2.25 years + + # Local paths not used when use_cert_manager=true + auto_generate: false diff --git a/examples/deployment-configs/deploy-openshift-local.yaml 
b/examples/deployment-configs/deploy-openshift-local.yaml new file mode 100644 index 000000000..8256478ba --- /dev/null +++ b/examples/deployment-configs/deploy-openshift-local.yaml @@ -0,0 +1,131 @@ +# MCP Stack - OpenShift Local Configuration with cert-manager +# Build from source for local development +# +# Prerequisites: +# 1. Install cert-manager in your cluster +# 2. Apply cert-manager-issuer-example.yaml to create the CA Issuer +# 3. (Optional) Authenticate to OpenShift internal registry: +# podman login $(oc registry info) -u $(oc whoami) -p $(oc whoami -t) +# 4. Deploy this config + +deployment: + type: kubernetes + namespace: mcp-gateway-test + +# MCP Gateway configuration +gateway: + # Build gateway from current repository + repo: https://github.com/terylt/mcp-context-forge.git + ref: feat/configurable_plugin_deployment + context: . + containerfile: Containerfile + image: mcpgateway-gateway:latest + + port: 4444 + + # Service configuration + service_type: ClusterIP + service_port: 4444 + + # Resource limits + replicas: 1 + memory_request: 256Mi + memory_limit: 512Mi + cpu_request: 100m + cpu_limit: 500m + + # Environment configuration + env_vars: + LOG_LEVEL: DEBUG + HOST: 0.0.0.0 + PORT: 4444 + MCPGATEWAY_UI_ENABLED: "true" + MCPGATEWAY_ADMIN_API_ENABLED: "true" + MCPGATEWAY_A2A_ENABLED: "true" + AUTH_REQUIRED: "false" + MCPGATEWAY_ENABLE_FEDERATION: "false" + + # mTLS client configuration (gateway connects to plugins) + mtls_enabled: true + mtls_verify: true + mtls_check_hostname: false + + # Container registry configuration (optional) + # Uncomment to push images to OpenShift internal registry + # registry: + # enabled: true + # url: default-route-openshift-image-registry.apps-crc.testing + # namespace: mcp-gateway-test + # push: true + # image_pull_policy: Always + +# External plugins +plugins: + # OPA Plugin Filter - build from source + - name: OPAPluginFilter + + # Build from repository + repo: https://github.com/terylt/mcp-context-forge.git + ref: feat/configurable_plugin_deployment + context: plugins/external/opa + containerfile: Containerfile + image: mcpgateway-opapluginfilter:latest + + port: 8000 + + # Service configuration + service_type: ClusterIP + service_port: 8000 + + # Resource limits + replicas: 1 + memory_request: 128Mi + memory_limit: 256Mi + cpu_request: 50m + cpu_limit: 200m + + env_vars: + LOG_LEVEL: DEBUG + OPA_POLICY_PATH: /app/policies + + mtls_enabled: true + + # Container registry configuration (optional) + # Uncomment to push images to OpenShift internal registry + # registry: + # enabled: true + # url: default-route-openshift-image-registry.apps-crc.testing + # namespace: mcp-gateway-test + # push: true + # image_pull_policy: Always + + # Plugin manager overrides + plugin_overrides: + priority: 10 + mode: "enforce" + description: "OPA policy enforcement" + tags: ["security", "policy", "opa"] + +# Infrastructure services +infrastructure: + postgres: + enabled: true + image: quay.io/sclorg/postgresql-15-c9s:latest + user: mcpuser # Use non-'postgres' username for Red Hat images + database: mcp + password: mysecretpassword + +# cert-manager Certificate configuration +certificates: + # Use cert-manager for automatic certificate management + use_cert_manager: true + + # cert-manager issuer reference (must exist in namespace) + cert_manager_issuer: mcp-ca-issuer + cert_manager_kind: Issuer # or ClusterIssuer + + # Certificate validity (cert-manager will auto-renew at 2/3 of lifetime) + validity_days: 825 # ≈ 2.25 years + + # Local paths not used when 
use_cert_manager=true + auto_generate: false diff --git a/gunicorn.config.py b/gunicorn.config.py index f57ef1766..df888da42 100644 --- a/gunicorn.config.py +++ b/gunicorn.config.py @@ -14,6 +14,9 @@ Reference: https://stackoverflow.com/questions/10855197/frequent-worker-timeout """ +# Standard +import os + # First-Party # Import Pydantic Settings singleton from mcpgateway.config import settings @@ -46,13 +49,58 @@ # accesslog = '/tmp/gunicorn-accesslog' # access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"' +# SSL/TLS Configuration +# Note: certfile and keyfile are set via command-line arguments in run-gunicorn.sh +# If a passphrase is provided via SSL_KEY_PASSWORD environment variable, +# the key will be decrypted by the SSL key manager before Gunicorn starts. # certfile = 'certs/cert.pem' # keyfile = 'certs/key.pem' # ca-certs = '/etc/ca_bundle.crt' +# Global variable to store the prepared key file path +_prepared_key_file = None + # server hooks +def on_starting(server): + """Called just before the master process is initialized. + + This is where we handle passphrase-protected SSL keys by decrypting + them to a temporary file before Gunicorn workers start. + """ + global _prepared_key_file + + # Check if SSL is enabled via environment variable (set by run-gunicorn.sh) + # and a passphrase is provided + ssl_enabled = os.environ.get("SSL", "false").lower() == "true" + ssl_key_password = os.environ.get("SSL_KEY_PASSWORD") + + if ssl_enabled and ssl_key_password: + try: + from mcpgateway.utils.ssl_key_manager import prepare_ssl_key + + # Get the key file path from environment (set by run-gunicorn.sh) + key_file = os.environ.get("KEY_FILE", "certs/key.pem") + + server.log.info(f"Preparing passphrase-protected SSL key: {key_file}") + + # Decrypt the key and get the temporary file path + _prepared_key_file = prepare_ssl_key(key_file, ssl_key_password) + + server.log.info(f"SSL key prepared successfully: {_prepared_key_file}") + + # Update the keyfile setting to use the decrypted temporary file + # This is a bit of a hack, but Gunicorn doesn't provide a better way + # to modify the keyfile after it's been set via command line + if hasattr(server, 'cfg'): + server.cfg.set('keyfile', _prepared_key_file) + + except Exception as e: + server.log.error(f"Failed to prepare SSL key: {e}") + raise + + def when_ready(server): server.log.info("Server is ready. 
Spawning workers") diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index e45604d54..123bee450 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -47,7 +47,7 @@ from pydantic import SecretStr, ValidationError from pydantic_core import ValidationError as CoreValidationError from sqlalchemy import and_, case, cast, desc, func, or_, select, String -from sqlalchemy.exc import IntegrityError +from sqlalchemy.exc import IntegrityError, InvalidRequestError, OperationalError from sqlalchemy.orm import joinedload, Session from sqlalchemy.sql.functions import coalesce from starlette.datastructures import UploadFile as StarletteUploadFile @@ -105,6 +105,7 @@ ) from mcpgateway.services.a2a_service import A2AAgentError, A2AAgentNameConflictError, A2AAgentNotFoundError, A2AAgentService from mcpgateway.services.argon2_service import Argon2PasswordService +from mcpgateway.services.audit_trail_service import get_audit_trail_service from mcpgateway.services.catalog_service import catalog_service from mcpgateway.services.email_auth_service import AuthenticationError, EmailAuthService, PasswordValidationError from mcpgateway.services.encryption_service import get_encryption_service @@ -120,6 +121,7 @@ from mcpgateway.services.resource_service import ResourceNotFoundError, ResourceService, ResourceURIConflictError from mcpgateway.services.root_service import RootService from mcpgateway.services.server_service import ServerError, ServerNameConflictError, ServerNotFoundError, ServerService +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.tag_service import TagService from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.services.tool_service import ToolError, ToolNameConflictError, ToolNotFoundError, ToolService @@ -8079,6 +8081,87 @@ async def admin_delete_gateway(gateway_id: str, request: Request, db: Session = return RedirectResponse(f"{root_path}/admin#gateways", status_code=303) +@admin_router.get("/resources/test/{resource_uri:path}") +async def admin_test_resource(resource_uri: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, Any]: + """ + Test reading a resource by its URI for the admin UI. + + Args: + resource_uri: The full resource URI (may include encoded characters). + db: Database session dependency. + user: Authenticated user with proper permissions. + + Returns: + A dictionary containing the resolved resource content. + + Raises: + HTTPException: If the resource is not found. + Exception: For unexpected errors. + + Examples: + >>> import asyncio + >>> from unittest.mock import AsyncMock, MagicMock + >>> from mcpgateway.services.resource_service import ResourceNotFoundError + >>> from fastapi import HTTPException + + >>> mock_db = MagicMock() + >>> mock_user = {"email": "test_user"} + >>> test_uri = "resource://example/demo" + + >>> # --- Mock successful content read --- + >>> original_read_resource = resource_service.read_resource + >>> resource_service.read_resource = AsyncMock(return_value={"hello": "world"}) + + >>> async def test_success(): + ... result = await admin_test_resource(test_uri, mock_db, mock_user) + ... return result["content"] == {"hello": "world"} + + >>> asyncio.run(test_success()) + True + + >>> # --- Mock resource not found --- + >>> resource_service.read_resource = AsyncMock( + ... side_effect=ResourceNotFoundError("Not found") + ... ) + + >>> async def test_not_found(): + ... try: + ... 
await admin_test_resource("resource://missing", mock_db, mock_user) + ... return False + ... except HTTPException as e: + ... return e.status_code == 404 and "Not found" in e.detail + + >>> asyncio.run(test_not_found()) + True + + >>> # --- Mock unexpected exception --- + >>> resource_service.read_resource = AsyncMock(side_effect=Exception("Boom")) + + >>> async def test_error(): + ... try: + ... await admin_test_resource(test_uri, mock_db, mock_user) + ... return False + ... except Exception as e: + ... return str(e) == "Boom" + + >>> asyncio.run(test_error()) + True + + >>> # Restore original method + >>> resource_service.read_resource = original_read_resource + """ + LOGGER.debug(f"User {get_user_email(user)} requested details for resource ID {resource_uri}") + + try: + resource_content = await resource_service.read_resource(db, resource_uri=resource_uri) + return {"content": resource_content} + except ResourceNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + LOGGER.error(f"Error getting resource for {resource_uri}: {e}") + raise e + + @admin_router.get("/resources/{resource_id}") async def admin_get_resource(resource_id: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, Any]: """Get resource details for the admin UI. @@ -8089,7 +8172,7 @@ async def admin_get_resource(resource_id: str, db: Session = Depends(get_db), us user: Authenticated user. Returns: - A dictionary containing resource details and its content. + A dictionary containing resource details. Raises: HTTPException: If the resource is not found. @@ -8098,77 +8181,79 @@ async def admin_get_resource(resource_id: str, db: Session = Depends(get_db), us Examples: >>> import asyncio >>> from unittest.mock import AsyncMock, MagicMock - >>> from mcpgateway.schemas import ResourceRead, ResourceMetrics, ResourceContent + >>> from mcpgateway.schemas import ResourceRead, ResourceMetrics >>> from datetime import datetime, timezone - >>> from mcpgateway.services.resource_service import ResourceNotFoundError # Added import + >>> from mcpgateway.services.resource_service import ResourceNotFoundError >>> from fastapi import HTTPException >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = {"email": "test_user"} + >>> resource_id = "1" >>> resource_uri = "test://resource/get" - >>> resource_id = "ca627760127d409080fdefc309147e08" >>> >>> # Mock resource data >>> mock_resource = ResourceRead( ... id=resource_id, uri=resource_uri, name="Get Resource", description="Test", ... mime_type="text/plain", size=10, created_at=datetime.now(timezone.utc), - ... updated_at=datetime.now(timezone.utc), enabled=True, metrics=ResourceMetrics( + ... updated_at=datetime.now(timezone.utc), is_active=True,enabled=True, + ... metrics=ResourceMetrics( ... total_executions=0, successful_executions=0, failed_executions=0, - ... failure_rate=0.0, min_response_time=0.0, max_response_time=0.0, avg_response_time=0.0, - ... last_execution_time=None + ... failure_rate=0.0, min_response_time=0.0, max_response_time=0.0, + ... avg_response_time=0.0, last_execution_time=None ... ), ... tags=[] ... 
) - >>> mock_content = ResourceContent(id=str(resource_id), type="resource", uri=resource_uri, mime_type="text/plain", text="Hello content") >>> - >>> # Mock service methods + >>> # Mock service call >>> original_get_resource_by_id = resource_service.get_resource_by_id - >>> original_read_resource = resource_service.read_resource >>> resource_service.get_resource_by_id = AsyncMock(return_value=mock_resource) - >>> resource_service.read_resource = AsyncMock(return_value=mock_content) >>> - >>> # Test successful retrieval - >>> async def test_admin_get_resource_success(): + >>> # Test: successful retrieval + >>> async def test_success(): ... result = await admin_get_resource(resource_id, mock_db, mock_user) - ... return isinstance(result, dict) and result['resource']['id'] == resource_id and result['content'].text == "Hello content" # Corrected to .text + ... return result["resource"]["id"] == resource_id >>> - >>> asyncio.run(test_admin_get_resource_success()) + >>> asyncio.run(test_success()) True >>> - >>> # Test resource not found - >>> resource_service.get_resource_by_id = AsyncMock(side_effect=ResourceNotFoundError("Resource not found")) - >>> async def test_admin_get_resource_not_found(): + >>> # Test: resource not found + >>> resource_service.get_resource_by_id = AsyncMock( + ... side_effect=ResourceNotFoundError("Resource not found") + ... ) + >>> + >>> async def test_not_found(): ... try: ... await admin_get_resource("39334ce0ed2644d79ede8913a66930c9", mock_db, mock_user) ... return False ... except HTTPException as e: ... return e.status_code == 404 and "Resource not found" in e.detail >>> - >>> asyncio.run(test_admin_get_resource_not_found()) + >>> asyncio.run(test_not_found()) True >>> - >>> # Test exception during content read (resource found but content fails) - >>> resource_service.get_resource_by_id = AsyncMock(return_value=mock_resource) # Resource found - >>> resource_service.read_resource = AsyncMock(side_effect=Exception("Content read error")) - >>> async def test_admin_get_resource_content_error(): + >>> # Test: unexpected exception + >>> resource_service.get_resource_by_id = AsyncMock( + ... side_effect=Exception("Unexpected error") + ... ) + >>> + >>> async def test_exception(): ... try: ... await admin_get_resource(resource_id, mock_db, mock_user) ... return False ... except Exception as e: - ... return str(e) == "Content read error" + ... 
return str(e) == "Unexpected error" >>> - >>> asyncio.run(test_admin_get_resource_content_error()) + >>> asyncio.run(test_exception()) True >>> - >>> # Restore original methods + >>> # Restore original method >>> resource_service.get_resource_by_id = original_get_resource_by_id - >>> resource_service.read_resource = original_read_resource """ LOGGER.debug(f"User {get_user_email(user)} requested details for resource ID {resource_id}") try: resource = await resource_service.get_resource_by_id(db, resource_id) - content = await resource_service.read_resource(db, resource_id) - return {"resource": resource.model_dump(by_alias=True), "content": content} + # content = await resource_service.read_resource(db, resource_id=resource_id) + return {"resource": resource.model_dump(by_alias=True)} # , "content": None} except ResourceNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: @@ -8286,6 +8371,17 @@ async def admin_add_resource(request: Request, db: Session = Depends(get_db), us status_code=200, ) except Exception as ex: + # Roll back only when a transaction is active to avoid sqlite3 "no transaction" errors. + try: + active_transaction = db.get_transaction() if hasattr(db, "get_transaction") else None + if db.is_active and active_transaction is not None: + db.rollback() + except (InvalidRequestError, OperationalError) as rollback_error: + LOGGER.warning( + "Rollback failed (ignoring for SQLite compatibility): %s", + rollback_error, + ) + if isinstance(ex, ValidationError): LOGGER.error(f"ValidationError in admin_add_resource: {ErrorFormatter.format_validation_error(ex)}") return JSONResponse(content=ErrorFormatter.format_validation_error(ex), status_code=422) @@ -8504,7 +8600,11 @@ async def admin_delete_resource(resource_id: str, request: Request, db: Session LOGGER.debug(f"User {get_user_email(user)} is deleting resource ID {resource_id}") error_message = None try: - await resource_service.delete_resource(user["db"] if isinstance(user, dict) else db, resource_id) + await resource_service.delete_resource( + user["db"] if isinstance(user, dict) else db, + resource_id, + user_email=user_email, + ) except PermissionError as e: LOGGER.warning(f"Permission denied for user {user_email} deleting resource {resource_id}: {e}") error_message = str(e) @@ -9593,7 +9693,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] >>> async def test_admin_test_gateway(): ... with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class: ... mock_client_class.return_value = MockClient() - ... response = await admin_test_gateway(mock_request, mock_user) + ... response = await admin_test_gateway(mock_request, None, mock_user, mock_db) ... return isinstance(response, GatewayTestResponse) and response.status_code == 200 >>> >>> result = asyncio.run(test_admin_test_gateway()) @@ -9619,7 +9719,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] >>> async def test_admin_test_gateway_text_response(): ... with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class: ... mock_client_class.return_value = MockClientTextOnly() - ... response = await admin_test_gateway(mock_request, mock_user) + ... response = await admin_test_gateway(mock_request, None, mock_user, mock_db) ... 
return isinstance(response, GatewayTestResponse) and response.body.get("details") == "plain text response" >>> >>> asyncio.run(test_admin_test_gateway_text_response()) @@ -9637,7 +9737,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] >>> async def test_admin_test_gateway_network_error(): ... with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class: ... mock_client_class.return_value = MockClientError() - ... response = await admin_test_gateway(mock_request, mock_user) + ... response = await admin_test_gateway(mock_request, None, mock_user, mock_db) ... return response.status_code == 502 and "Network error" in str(response.body) >>> >>> asyncio.run(test_admin_test_gateway_network_error()) @@ -9655,7 +9755,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] >>> async def test_admin_test_gateway_post(): ... with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class: ... mock_client_class.return_value = MockClient() - ... response = await admin_test_gateway(mock_request_post, mock_user) + ... response = await admin_test_gateway(mock_request_post, None, mock_user, mock_db) ... return isinstance(response, GatewayTestResponse) and response.status_code == 200 >>> >>> asyncio.run(test_admin_test_gateway_post()) @@ -9673,7 +9773,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] >>> async def test_admin_test_gateway_trailing_slash(): ... with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class: ... mock_client_class.return_value = MockClient() - ... response = await admin_test_gateway(mock_request_trailing, mock_user) + ... response = await admin_test_gateway(mock_request_trailing, None, mock_user, mock_db) ... return isinstance(response, GatewayTestResponse) and response.status_code == 200 >>> >>> asyncio.run(test_admin_test_gateway_trailing_slash()) @@ -9763,11 +9863,56 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] except json.JSONDecodeError: response_body = {"details": response.text} + # Structured logging: Log successful gateway test + structured_logger = get_structured_logger("gateway_service") + structured_logger.log( + level="INFO", + message=f"Gateway test completed: {request.base_url}", + event_type="gateway_tested", + component="gateway_service", + user_email=get_user_email(user), + team_id=team_id, + resource_type="gateway", + resource_id=gateway.id if gateway else None, + custom_fields={ + "gateway_name": gateway.name if gateway else None, + "gateway_url": str(request.base_url), + "test_method": request.method, + "test_path": request.path, + "status_code": response.status_code, + "latency_ms": latency_ms, + }, + db=db, + ) + return GatewayTestResponse(status_code=response.status_code, latency_ms=latency_ms, body=response_body) except httpx.RequestError as e: LOGGER.warning(f"Gateway test failed: {e}") latency_ms = int((time.monotonic() - start_time) * 1000) + + # Structured logging: Log failed gateway test + structured_logger = get_structured_logger("gateway_service") + structured_logger.log( + level="ERROR", + message=f"Gateway test failed: {request.base_url}", + event_type="gateway_test_failed", + component="gateway_service", + user_email=get_user_email(user), + team_id=team_id, + resource_type="gateway", + resource_id=gateway.id if gateway else None, + error=e, + custom_fields={ + "gateway_name": gateway.name if gateway else None, + "gateway_url": str(request.base_url), + "test_method": 
request.method, + "test_path": request.path, + "latency_ms": latency_ms, + }, + db=db, + ) + return GatewayTestResponse(status_code=502, latency_ms=latency_ms, body={"error": "Request failed", "details": str(e)}) @@ -11813,6 +11958,7 @@ async def admin_test_a2a_agent( return JSONResponse(content={"success": False, "error": "A2A features are disabled"}, status_code=403) try: + user_email = get_user_email(user) # Get the agent by ID agent = await a2a_service.get_agent(db, agent_id) @@ -11828,7 +11974,14 @@ async def admin_test_a2a_agent( test_params = {"message": "Hello from MCP Gateway Admin UI test!", "test": True, "timestamp": int(time.time())} # Invoke the agent - result = await a2a_service.invoke_agent(db, agent.name, test_params, "admin_test") + result = await a2a_service.invoke_agent( + db, + agent.name, + test_params, + "admin_test", + user_email=user_email, + user_id=user_email, + ) return JSONResponse(content={"success": True, "result": result, "agent_name": agent.name, "test_timestamp": time.time()}) @@ -12453,6 +12606,7 @@ async def list_plugins( HTTPException: If there's an error retrieving plugins """ LOGGER.debug(f"User {get_user_email(user)} requested plugin list") + structured_logger = get_structured_logger() try: # Get plugin service @@ -12473,10 +12627,35 @@ async def list_plugins( enabled_count = sum(1 for p in plugins if p["status"] == "enabled") disabled_count = sum(1 for p in plugins if p["status"] == "disabled") + # Log plugin marketplace browsing activity + structured_logger.info( + "User browsed plugin marketplace", + user_id=str(user.id), + user_email=get_user_email(user), + component="plugin_marketplace", + category="business_logic", + resource_type="plugin_list", + resource_action="browse", + custom_fields={ + "search_query": search, + "filter_mode": mode, + "filter_hook": hook, + "filter_tag": tag, + "results_count": len(plugins), + "enabled_count": enabled_count, + "disabled_count": disabled_count, + "has_filters": any([search, mode, hook, tag]), + }, + db=db, + ) + return PluginListResponse(plugins=plugins, total=len(plugins), enabled_count=enabled_count, disabled_count=disabled_count) except Exception as e: LOGGER.error(f"Error listing plugins: {e}") + structured_logger.error( + "Failed to list plugins in marketplace", user_id=str(user.id), user_email=get_user_email(user), error=e, component="plugin_marketplace", category="business_logic", db=db + ) raise HTTPException(status_code=500, detail=str(e)) @@ -12496,6 +12675,7 @@ async def get_plugin_stats(request: Request, db: Session = Depends(get_db), user HTTPException: If there's an error getting plugin statistics """ LOGGER.debug(f"User {get_user_email(user)} requested plugin statistics") + structured_logger = get_structured_logger() try: # Get plugin service @@ -12509,10 +12689,33 @@ async def get_plugin_stats(request: Request, db: Session = Depends(get_db), user # Get statistics stats = plugin_service.get_plugin_statistics() + # Log marketplace analytics access + structured_logger.info( + "User accessed plugin marketplace statistics", + user_id=str(user.id), + user_email=get_user_email(user), + component="plugin_marketplace", + category="business_logic", + resource_type="plugin_stats", + resource_action="view", + custom_fields={ + "total_plugins": stats.get("total_plugins", 0), + "enabled_plugins": stats.get("enabled_plugins", 0), + "disabled_plugins": stats.get("disabled_plugins", 0), + "hooks_count": len(stats.get("plugins_by_hook", {})), + "tags_count": len(stats.get("plugins_by_tag", {})), + 
"authors_count": len(stats.get("plugins_by_author", {})), + }, + db=db, + ) + return PluginStatsResponse(**stats) except Exception as e: LOGGER.error(f"Error getting plugin statistics: {e}") + structured_logger.error( + "Failed to get plugin marketplace statistics", user_id=str(user.id), user_email=get_user_email(user), error=e, component="plugin_marketplace", category="business_logic", db=db + ) raise HTTPException(status_code=500, detail=str(e)) @@ -12533,6 +12736,8 @@ async def get_plugin_details(name: str, request: Request, db: Session = Depends( HTTPException: If plugin not found """ LOGGER.debug(f"User {get_user_email(user)} requested details for plugin {name}") + structured_logger = get_structured_logger() + audit_service = get_audit_trail_service() try: # Get plugin service @@ -12547,14 +12752,53 @@ async def get_plugin_details(name: str, request: Request, db: Session = Depends( plugin = plugin_service.get_plugin_by_name(name) if not plugin: + structured_logger.warning( + f"Plugin '{name}' not found in marketplace", + user_id=str(user.id), + user_email=get_user_email(user), + component="plugin_marketplace", + category="business_logic", + custom_fields={"plugin_name": name, "action": "view_details"}, + db=db, + ) raise HTTPException(status_code=404, detail=f"Plugin '{name}' not found") + # Log plugin view activity + structured_logger.info( + f"User viewed plugin details: '{name}'", + user_id=str(user.id), + user_email=get_user_email(user), + component="plugin_marketplace", + category="business_logic", + resource_type="plugin", + resource_id=name, + resource_action="view_details", + custom_fields={ + "plugin_name": name, + "plugin_version": plugin.get("version"), + "plugin_author": plugin.get("author"), + "plugin_status": plugin.get("status"), + "plugin_mode": plugin.get("mode"), + "plugin_hooks": plugin.get("hooks", []), + "plugin_tags": plugin.get("tags", []), + }, + db=db, + ) + + # Create audit trail for plugin access + audit_service.log_audit( + user_id=str(user.id), user_email=get_user_email(user), resource_type="plugin", resource_id=name, action="view", description=f"Viewed plugin '{name}' details in marketplace", db=db + ) + return PluginDetail(**plugin) except HTTPException: raise except Exception as e: LOGGER.error(f"Error getting plugin details: {e}") + structured_logger.error( + f"Failed to get plugin details: '{name}'", user_id=str(user.id), user_email=get_user_email(user), error=e, component="plugin_marketplace", category="business_logic", db=db + ) raise HTTPException(status_code=500, detail=str(e)) diff --git a/mcpgateway/alembic/versions/356a2d4eed6f_uuid_change_for_prompt_and_resources.py b/mcpgateway/alembic/versions/356a2d4eed6f_uuid_change_for_prompt_and_resources.py index 325096243..b616a0892 100644 --- a/mcpgateway/alembic/versions/356a2d4eed6f_uuid_change_for_prompt_and_resources.py +++ b/mcpgateway/alembic/versions/356a2d4eed6f_uuid_change_for_prompt_and_resources.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """UUID Change for Prompt and Resources Revision ID: 356a2d4eed6f @@ -25,6 +26,7 @@ def upgrade() -> None: """Upgrade schema.""" conn = op.get_bind() + dialect = conn.dialect.name if hasattr(conn, "dialect") else None # 1) Add temporary id_new column to prompts and populate with uuid.hex op.add_column("prompts", sa.Column("id_new", sa.String(36), nullable=True)) @@ -35,6 +37,8 @@ def upgrade() -> None: conn.execute(text("UPDATE prompts SET id_new = :new WHERE id = :old"), {"new": new_id, "old": old_id}) # 2) Create new prompts table (temporary) with 
varchar(36) id + prompts_pk_name = "pk_prompts" if dialect == "sqlite" else "pk_prompts_tmp" + prompts_uq_name = "uq_team_owner_name_prompt" if dialect == "sqlite" else "uq_team_owner_name_prompt_tmp" op.create_table( "prompts_tmp", sa.Column("id", sa.String(36), primary_key=True, nullable=False), @@ -61,8 +65,8 @@ def upgrade() -> None: sa.Column("team_id", sa.String(36), nullable=True), sa.Column("owner_email", sa.String(255), nullable=True), sa.Column("visibility", sa.String(20), nullable=False, server_default="public"), - sa.UniqueConstraint("team_id", "owner_email", "name", name="uq_team_owner_name_prompt"), - sa.PrimaryKeyConstraint("id", name="pk_prompts"), + sa.UniqueConstraint("team_id", "owner_email", "name", name=prompts_uq_name), + sa.PrimaryKeyConstraint("id", name=prompts_pk_name), ) # 3) Copy data from prompts into prompts_tmp using id_new as id @@ -184,6 +188,7 @@ def upgrade() -> None: conn.execute(ins) # 4) Create new prompt_metrics table with prompt_id varchar(36) + prompt_metrics_pk_name = "pk_prompt_metrics" if dialect == "sqlite" else "pk_prompt_metrics_tmp" op.create_table( "prompt_metrics_tmp", sa.Column("id", sa.Integer, primary_key=True, nullable=False), @@ -193,7 +198,7 @@ def upgrade() -> None: sa.Column("is_success", sa.Boolean, nullable=False), sa.Column("error_message", sa.Text, nullable=True), sa.ForeignKeyConstraint(["prompt_id"], ["prompts_tmp.id"], name="fk_prompt_metrics_prompt_id"), - sa.PrimaryKeyConstraint("id", name="pk_prompt_metrics"), + sa.PrimaryKeyConstraint("id", name=prompt_metrics_pk_name), ) # 5) Copy prompt_metrics mapping old integer prompt_id -> new uuid via join @@ -204,11 +209,12 @@ def upgrade() -> None: ) # 6) Create new server_prompt_association table with prompt_id varchar(36) + server_prompt_assoc_pk = "pk_server_prompt_assoc" if dialect == "sqlite" else "pk_server_prompt_assoc_tmp" op.create_table( "server_prompt_association_tmp", sa.Column("server_id", sa.String(36), nullable=False), sa.Column("prompt_id", sa.String(36), nullable=False), - sa.PrimaryKeyConstraint("server_id", "prompt_id", name="pk_server_prompt_assoc"), + sa.PrimaryKeyConstraint("server_id", "prompt_id", name=server_prompt_assoc_pk), sa.ForeignKeyConstraint(["server_id"], ["servers.id"], name="fk_server_prompt_server_id"), sa.ForeignKeyConstraint(["prompt_id"], ["prompts_tmp.id"], name="fk_server_prompt_prompt_id"), ) @@ -216,7 +222,12 @@ def upgrade() -> None: conn.execute(text("INSERT INTO server_prompt_association_tmp (server_id, prompt_id) SELECT spa.server_id, p.id_new FROM server_prompt_association spa JOIN prompts p ON spa.prompt_id = p.id")) # Update observability spans that reference prompts: remap integer prompt IDs -> new uuid - conn.execute(text("UPDATE observability_spans SET resource_id = p.id_new FROM prompts p WHERE observability_spans.resource_type = 'prompts' AND observability_spans.resource_id = p.id")) + # PostgreSQL requires explicit cast when comparing varchar to int; other DBs (SQLite/MySQL) are permissive. 
+ dialect = conn.dialect.name if hasattr(conn, "dialect") else None + if dialect == "postgresql": + conn.execute(text("UPDATE observability_spans SET resource_id = p.id_new FROM prompts p WHERE observability_spans.resource_type = 'prompts' AND observability_spans.resource_id = p.id::text")) + else: + conn.execute(text("UPDATE observability_spans SET resource_id = p.id_new FROM prompts p WHERE observability_spans.resource_type = 'prompts' AND observability_spans.resource_id = p.id")) # 7) Drop old tables and rename tmp tables into place op.drop_table("prompt_metrics") @@ -226,6 +237,27 @@ def upgrade() -> None: op.rename_table("prompts_tmp", "prompts") op.rename_table("prompt_metrics_tmp", "prompt_metrics") op.rename_table("server_prompt_association_tmp", "server_prompt_association") + # For SQLite we cannot ALTER constraints directly; skip constraint renames there. + if dialect != "sqlite": + # Drop dependent foreign keys first to allow primary key rename/recreation + op.drop_constraint("fk_prompt_metrics_prompt_id", "prompt_metrics", type_="foreignkey") + op.drop_constraint("fk_server_prompt_prompt_id", "server_prompt_association", type_="foreignkey") + + # Restore original constraint names for prompts and dependent tables + op.drop_constraint("pk_prompts_tmp", "prompts", type_="primary") + op.create_primary_key("pk_prompts", "prompts", ["id"]) + op.drop_constraint("uq_team_owner_name_prompt_tmp", "prompts", type_="unique") + op.create_unique_constraint("uq_team_owner_name_prompt", "prompts", ["team_id", "owner_email", "name"]) + + op.drop_constraint("pk_prompt_metrics_tmp", "prompt_metrics", type_="primary") + op.create_primary_key("pk_prompt_metrics", "prompt_metrics", ["id"]) + + op.drop_constraint("pk_server_prompt_assoc_tmp", "server_prompt_association", type_="primary") + op.create_primary_key("pk_server_prompt_assoc", "server_prompt_association", ["server_id", "prompt_id"]) + + # Recreate foreign keys referencing the new primary key name + op.create_foreign_key("fk_prompt_metrics_prompt_id", "prompt_metrics", "prompts", ["prompt_id"], ["id"]) + op.create_foreign_key("fk_server_prompt_prompt_id", "server_prompt_association", "prompts", ["prompt_id"], ["id"]) # ----------------------------- # Resources -> change id to VARCHAR(32) and remap FKs @@ -239,6 +271,8 @@ def upgrade() -> None: conn.execute(text("UPDATE resources SET id_new = :new WHERE id = :old"), {"new": new_id, "old": old_id}) # Create resources_tmp with varchar(32) id + resources_pk_name = "pk_resources" if dialect == "sqlite" else "pk_resources_tmp" + resources_uq_name = "uq_team_owner_uri_resource" if dialect == "sqlite" else "uq_team_owner_uri_resource_tmp" op.create_table( "resources_tmp", sa.Column("id", sa.String(36), primary_key=True, nullable=False), @@ -269,8 +303,8 @@ def upgrade() -> None: sa.Column("team_id", sa.String(36), nullable=True), sa.Column("owner_email", sa.String(255), nullable=True), sa.Column("visibility", sa.String(20), nullable=False, server_default="public"), - sa.UniqueConstraint("team_id", "owner_email", "uri", name="uq_team_owner_uri_resource"), - sa.PrimaryKeyConstraint("id", name="pk_resources"), + sa.UniqueConstraint("team_id", "owner_email", "uri", name=resources_uq_name), + sa.PrimaryKeyConstraint("id", name=resources_pk_name), ) # Copy data into resources_tmp using id_new via SQLAlchemy Core @@ -405,6 +439,7 @@ def upgrade() -> None: conn.execute(ins_res) # resource_metrics_tmp with resource_id varchar(32) + resource_metrics_pk = "pk_resource_metrics" if dialect == "sqlite" else 
"pk_resource_metrics_tmp" op.create_table( "resource_metrics_tmp", sa.Column("id", sa.Integer, primary_key=True, nullable=False), @@ -414,7 +449,7 @@ def upgrade() -> None: sa.Column("is_success", sa.Boolean, nullable=False), sa.Column("error_message", sa.Text, nullable=True), sa.ForeignKeyConstraint(["resource_id"], ["resources_tmp.id"], name="fk_resource_metrics_resource_id"), - sa.PrimaryKeyConstraint("id", name="pk_resource_metrics"), + sa.PrimaryKeyConstraint("id", name=resource_metrics_pk), ) # copy resource_metrics mapping old int->new uuid @@ -425,11 +460,12 @@ def upgrade() -> None: ) # server_resource_association_tmp + server_resource_assoc_pk = "pk_server_resource_assoc" if dialect == "sqlite" else "pk_server_resource_assoc_tmp" op.create_table( "server_resource_association_tmp", sa.Column("server_id", sa.String(36), nullable=False), sa.Column("resource_id", sa.String(36), nullable=False), - sa.PrimaryKeyConstraint("server_id", "resource_id", name="pk_server_resource_assoc"), + sa.PrimaryKeyConstraint("server_id", "resource_id", name=server_resource_assoc_pk), sa.ForeignKeyConstraint(["server_id"], ["servers.id"], name="fk_server_resource_server_id"), sa.ForeignKeyConstraint(["resource_id"], ["resources_tmp.id"], name="fk_server_resource_resource_id"), ) @@ -439,7 +475,14 @@ def upgrade() -> None: ) # Update observability spans that reference resources: remap integer resource IDs -> new uuid - conn.execute(text("UPDATE observability_spans SET resource_id = r.id_new FROM resources r WHERE observability_spans.resource_type = 'resources' AND observability_spans.resource_id = r.id")) + # Cast for PostgreSQL to avoid varchar = integer operator error + dialect = conn.dialect.name if hasattr(conn, "dialect") else None + if dialect == "postgresql": + conn.execute( + text("UPDATE observability_spans SET resource_id = r.id_new FROM resources r WHERE observability_spans.resource_type = 'resources' AND observability_spans.resource_id = r.id::text") + ) + else: + conn.execute(text("UPDATE observability_spans SET resource_id = r.id_new FROM resources r WHERE observability_spans.resource_type = 'resources' AND observability_spans.resource_id = r.id")) # resource_subscriptions_tmp op.create_table( @@ -468,6 +511,29 @@ def upgrade() -> None: op.rename_table("resource_metrics_tmp", "resource_metrics") op.rename_table("server_resource_association_tmp", "server_resource_association") op.rename_table("resource_subscriptions_tmp", "resource_subscriptions") + # For SQLite we cannot ALTER constraints directly; skip constraint renames there. 
+ if dialect != "sqlite": + # Drop dependent foreign keys first to allow primary key rename/recreation + op.drop_constraint("fk_resource_metrics_resource_id", "resource_metrics", type_="foreignkey") + op.drop_constraint("fk_server_resource_resource_id", "server_resource_association", type_="foreignkey") + op.drop_constraint("fk_resource_subscriptions_resource_id", "resource_subscriptions", type_="foreignkey") + + # Restore original constraint names for resources and dependent tables + op.drop_constraint("pk_resources_tmp", "resources", type_="primary") + op.create_primary_key("pk_resources", "resources", ["id"]) + op.drop_constraint("uq_team_owner_uri_resource_tmp", "resources", type_="unique") + op.create_unique_constraint("uq_team_owner_uri_resource", "resources", ["team_id", "owner_email", "uri"]) + + op.drop_constraint("pk_resource_metrics_tmp", "resource_metrics", type_="primary") + op.create_primary_key("pk_resource_metrics", "resource_metrics", ["id"]) + + op.drop_constraint("pk_server_resource_assoc_tmp", "server_resource_association", type_="primary") + op.create_primary_key("pk_server_resource_assoc", "server_resource_association", ["server_id", "resource_id"]) + + # Recreate foreign keys referencing restored primary key + op.create_foreign_key("fk_resource_metrics_resource_id", "resource_metrics", "resources", ["resource_id"], ["id"]) + op.create_foreign_key("fk_server_resource_resource_id", "server_resource_association", "resources", ["resource_id"], ["id"]) + op.create_foreign_key("fk_resource_subscriptions_resource_id", "resource_subscriptions", "resources", ["resource_id"], ["id"]) with op.batch_alter_table("servers") as batch_op: batch_op.alter_column( @@ -482,9 +548,14 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade schema.""" conn = op.get_bind() + dialect = conn.dialect.name if hasattr(conn, "dialect") else None # Best-effort: rebuild integer prompt ids and remap dependent FK columns. # 1) Create old-style prompts table with integer id (autoincrement) + # If a previous partial downgrade left these tables behind, drop them first + conn.execute(text("DROP TABLE IF EXISTS prompts_old")) + prompts_old_pk = "pk_prompts" if dialect == "sqlite" else "pk_prompts_old" + prompts_old_uq = "uq_team_owner_name_prompt" if dialect == "sqlite" else "uq_team_owner_name_prompt_old" op.create_table( "prompts_old", sa.Column("id", sa.Integer, primary_key=True, autoincrement=True, nullable=False), @@ -511,8 +582,8 @@ def downgrade() -> None: sa.Column("team_id", sa.String(36), nullable=True), sa.Column("owner_email", sa.String(255), nullable=True), sa.Column("visibility", sa.String(20), nullable=False, server_default="public"), - sa.UniqueConstraint("team_id", "owner_email", "name", name="uq_team_owner_name_prompt"), - sa.PrimaryKeyConstraint("id", name="pk_prompts"), + sa.UniqueConstraint("team_id", "owner_email", "name", name=prompts_old_uq), + sa.PrimaryKeyConstraint("id", name=prompts_old_pk), ) # 2) Insert rows from current prompts into prompts_old letting id autoincrement. 
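The downgrade steps that follow rely on a plain Python mapping to translate each prompt's uuid primary key back to the freshly autoincremented integer id before repopulating the dependent tables. A simplified sketch of that idea, with hypothetical column choices used purely for illustration (the actual migration builds its mapping dict from the rows copied above and uses it when filling the *_old tables):

    from sqlalchemy import text

    # Assumes the migration's connection: conn = op.get_bind()
    # Pair each uuid id with the new integer id by matching rows on name (illustrative key only).
    uuid_rows = conn.execute(text("SELECT id, name FROM prompts")).fetchall()
    int_rows = conn.execute(text("SELECT id, name FROM prompts_old")).fetchall()
    int_id_by_name = {name: int_id for int_id, name in int_rows}
    mapping = {uuid_id: int_id_by_name[name] for uuid_id, name in uuid_rows}

    # Rewrite a dependent foreign-key column using the mapping.
    for uuid_id, int_id in mapping.items():
        conn.execute(
            text("UPDATE prompt_metrics SET prompt_id = :new WHERE prompt_id = :old"),
            {"new": int_id, "old": uuid_id},
        )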
@@ -549,6 +620,8 @@ def downgrade() -> None: mapping[row[0]] = row[4] # 4) Recreate prompt_metrics_old and remap prompt_id + conn.execute(text("DROP TABLE IF EXISTS prompt_metrics_old")) + prompt_metrics_old_pk = "pk_prompt_metrics" if dialect == "sqlite" else "pk_prompt_metric_old" op.create_table( "prompt_metrics_old", sa.Column("id", sa.Integer, primary_key=True, nullable=False), @@ -558,7 +631,7 @@ def downgrade() -> None: sa.Column("is_success", sa.Boolean, nullable=False), sa.Column("error_message", sa.Text, nullable=True), sa.ForeignKeyConstraint(["prompt_id"], ["prompts_old.id"], name="fk_prompt_metrics_prompt_id"), - sa.PrimaryKeyConstraint("id", name="pk_prompt_metric"), + sa.PrimaryKeyConstraint("id", name=prompt_metrics_old_pk), ) # Copy metrics mapping prompt_id via Python mapping @@ -575,11 +648,13 @@ def downgrade() -> None: ) # 5) Recreate server_prompt_association_old and remap prompt_id + conn.execute(text("DROP TABLE IF EXISTS server_prompt_association_old")) + server_prompt_assoc_old_pk = "pk_server_prompt_assoc" if dialect == "sqlite" else "pk_server_prompt_assoc_old" op.create_table( "server_prompt_association_old", sa.Column("server_id", sa.String(36), nullable=False), sa.Column("prompt_id", sa.Integer, nullable=False), - sa.PrimaryKeyConstraint("server_id", "prompt_id", name="pk_server_prompt_assoc"), + sa.PrimaryKeyConstraint("server_id", "prompt_id", name=server_prompt_assoc_old_pk), sa.ForeignKeyConstraint(["server_id"], ["servers.id"], name="fk_server_prompt_server_id"), sa.ForeignKeyConstraint(["prompt_id"], ["prompts_old.id"], name="fk_server_prompt_prompt_id"), ) @@ -609,10 +684,33 @@ def downgrade() -> None: op.rename_table("prompt_metrics_old", "prompt_metrics") op.rename_table("server_prompt_association_old", "server_prompt_association") + # For SQLite we cannot ALTER constraints directly; skip those steps there. 
+ if dialect != "sqlite": + # Drop dependent foreign keys first to allow primary key rename/recreation + op.drop_constraint("fk_prompt_metrics_prompt_id", "prompt_metrics", type_="foreignkey") + op.drop_constraint("fk_server_prompt_prompt_id", "server_prompt_association", type_="foreignkey") + + # Restore original constraint names after renaming old tables back + op.drop_constraint("pk_prompts_old", "prompts", type_="primary") + op.create_primary_key("pk_prompts", "prompts", ["id"]) + op.drop_constraint("uq_team_owner_name_prompt_old", "prompts", type_="unique") + op.create_unique_constraint("uq_team_owner_name_prompt", "prompts", ["team_id", "owner_email", "name"]) + + op.drop_constraint("pk_prompt_metric_old", "prompt_metrics", type_="primary") + op.create_primary_key("pk_prompt_metrics", "prompt_metrics", ["id"]) + + op.drop_constraint("pk_server_prompt_assoc_old", "server_prompt_association", type_="primary") + op.create_primary_key("pk_server_prompt_assoc", "server_prompt_association", ["server_id", "prompt_id"]) + + # Recreate foreign keys referencing the new primary key name + op.create_foreign_key("fk_prompt_metrics_prompt_id", "prompt_metrics", "prompts", ["prompt_id"], ["id"]) + op.create_foreign_key("fk_server_prompt_prompt_id", "server_prompt_association", "prompts", ["prompt_id"], ["id"]) + # ============================= # Resources downgrade: rebuild integer ids and remap FKs # ============================= # 1) Create old-style resources table with integer id (autoincrement) + conn.execute(text("DROP TABLE IF EXISTS resources_old")) op.create_table( "resources_old", sa.Column("id", sa.Integer, primary_key=True, autoincrement=True, nullable=False), @@ -643,8 +741,8 @@ def downgrade() -> None: sa.Column("team_id", sa.String(36), nullable=True), sa.Column("owner_email", sa.String(255), nullable=True), sa.Column("visibility", sa.String(20), nullable=False, server_default="public"), - sa.UniqueConstraint("team_id", "owner_email", "uri", name="uq_team_owner_uri_resource"), - sa.PrimaryKeyConstraint("id", name="pk_resources"), + sa.UniqueConstraint("team_id", "owner_email", "uri", name="uq_team_owner_uri_resource_old"), + sa.PrimaryKeyConstraint("id", name="pk_resources_old"), ) # 2) Insert rows from current resources into resources_old letting id autoincrement. 
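As with the prompts path above, the resources downgrade below restores the original constraint names on non-SQLite backends. After running the migration in either direction, the resulting names can be sanity-checked with SQLAlchemy's runtime inspector; a hypothetical snippet (database URL assumed, not part of the migration):

    from sqlalchemy import create_engine, inspect

    engine = create_engine("postgresql://user:pass@localhost/mcp")  # assumed DSN, for illustration only
    insp = inspect(engine)
    print(insp.get_pk_constraint("prompts")["name"])                        # expect: pk_prompts
    print([uq["name"] for uq in insp.get_unique_constraints("resources")])  # expect: uq_team_owner_uri_resource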
@@ -680,6 +778,7 @@ def downgrade() -> None: mapping_res[row[0]] = row[4] # 4) Recreate resource_metrics_old and remap resource_id + conn.execute(text("DROP TABLE IF EXISTS resource_metrics_old")) op.create_table( "resource_metrics_old", sa.Column("id", sa.Integer, primary_key=True, nullable=False), @@ -705,11 +804,12 @@ def downgrade() -> None: ) # 5) Recreate server_resource_association_old and remap resource_id + conn.execute(text("DROP TABLE IF EXISTS server_resource_association_old")) op.create_table( "server_resource_association_old", sa.Column("server_id", sa.String(36), nullable=False), sa.Column("resource_id", sa.Integer, nullable=False), - sa.PrimaryKeyConstraint("server_id", "resource_id", name="pk_server_resource_assoc"), + sa.PrimaryKeyConstraint("server_id", "resource_id", name="pk_server_resource_assoc_old"), sa.ForeignKeyConstraint(["server_id"], ["servers.id"], name="fk_server_resource_server_id"), sa.ForeignKeyConstraint(["resource_id"], ["resources_old.id"], name="fk_server_resource_resource_id"), ) @@ -722,6 +822,7 @@ def downgrade() -> None: conn.execute(text("INSERT INTO server_resource_association_old (server_id, resource_id) VALUES (:sid, :rid)"), {"sid": server_id, "rid": int_id}) # 6) Recreate resource_subscriptions_old and remap resource_id + conn.execute(text("DROP TABLE IF EXISTS resource_subscriptions_old")) op.create_table( "resource_subscriptions_old", sa.Column("id", sa.Integer, primary_key=True, nullable=False), @@ -760,6 +861,29 @@ def downgrade() -> None: op.rename_table("resource_metrics_old", "resource_metrics") op.rename_table("server_resource_association_old", "server_resource_association") op.rename_table("resource_subscriptions_old", "resource_subscriptions") + # For SQLite we cannot ALTER constraints directly; skip those steps there. 
+ if dialect != "sqlite": + # Drop dependent foreign keys first to allow primary key rename/recreation + op.drop_constraint("fk_resource_metrics_resource_id", "resource_metrics", type_="foreignkey") + op.drop_constraint("fk_server_resource_resource_id", "server_resource_association", type_="foreignkey") + op.drop_constraint("fk_resource_subscriptions_resource_id", "resource_subscriptions", type_="foreignkey") + + # Restore original constraint names for resources after downgrade + op.drop_constraint("pk_resources_old", "resources", type_="primary") + op.create_primary_key("pk_resources", "resources", ["id"]) + op.drop_constraint("uq_team_owner_uri_resource_old", "resources", type_="unique") + op.create_unique_constraint("uq_team_owner_uri_resource", "resources", ["team_id", "owner_email", "uri"]) + + op.drop_constraint("pk_resource_metrics_old", "resource_metrics", type_="primary") + op.create_primary_key("pk_resource_metrics", "resource_metrics", ["id"]) + + op.drop_constraint("pk_server_resource_assoc_old", "server_resource_association", type_="primary") + op.create_primary_key("pk_server_resource_assoc", "server_resource_association", ["server_id", "resource_id"]) + + # Recreate foreign keys to point to restored primary key + op.create_foreign_key("fk_resource_metrics_resource_id", "resource_metrics", "resources", ["resource_id"], ["id"]) + op.create_foreign_key("fk_server_resource_resource_id", "server_resource_association", "resources", ["resource_id"], ["id"]) + op.create_foreign_key("fk_resource_subscriptions_resource_id", "resource_subscriptions", "resources", ["resource_id"], ["id"]) with op.batch_alter_table("servers") as batch_op: batch_op.alter_column( "enabled", diff --git a/mcpgateway/alembic/versions/9e028ecf59c4_tag_records_changes_list_str_to_list_.py b/mcpgateway/alembic/versions/9e028ecf59c4_tag_records_changes_list_str_to_list_.py index 61ba1ed7c..481f303f5 100644 --- a/mcpgateway/alembic/versions/9e028ecf59c4_tag_records_changes_list_str_to_list_.py +++ b/mcpgateway/alembic/versions/9e028ecf59c4_tag_records_changes_list_str_to_list_.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """tag records changes list[str] to list[Dict[str,str]] Revision ID: 9e028ecf59c4 diff --git a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py new file mode 100644 index 000000000..a83afbd28 --- /dev/null +++ b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +"""Add structured logging tables + +Revision ID: k5e6f7g8h9i0 +Revises: 356a2d4eed6f +Create Date: 2025-01-15 12:00:00.000000 + +""" + +# Third-Party +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = "k5e6f7g8h9i0" +down_revision = "356a2d4eed6f" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Add structured logging tables.""" + # Create structured_log_entries table + op.create_table( + "structured_log_entries", + sa.Column("id", sa.String(36), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.Column("correlation_id", sa.String(64), nullable=True), + sa.Column("request_id", sa.String(64), nullable=True), + sa.Column("level", sa.String(20), nullable=False), + sa.Column("component", sa.String(100), nullable=False), + sa.Column("message", sa.Text(), nullable=False), + sa.Column("logger", sa.String(255), nullable=True), + sa.Column("user_id", sa.String(255), nullable=True), + sa.Column("user_email", sa.String(255), nullable=True), + sa.Column("client_ip", sa.String(45), nullable=True), + sa.Column("user_agent", sa.Text(), nullable=True), + sa.Column("request_path", sa.String(500), nullable=True), + sa.Column("request_method", sa.String(10), nullable=True), + sa.Column("duration_ms", sa.Float(), nullable=True), + sa.Column("operation_type", sa.String(100), nullable=True), + sa.Column("is_security_event", sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column("security_severity", sa.String(20), nullable=True), + sa.Column("threat_indicators", sa.JSON(), nullable=True), + sa.Column("context", sa.JSON(), nullable=True), + sa.Column("error_details", sa.JSON(), nullable=True), + sa.Column("performance_metrics", sa.JSON(), nullable=True), + sa.Column("hostname", sa.String(255), nullable=False), + sa.Column("process_id", sa.Integer(), nullable=False), + sa.Column("thread_id", sa.Integer(), nullable=True), + sa.Column("version", sa.String(50), nullable=False), + sa.Column("environment", sa.String(50), nullable=False, server_default="production"), + sa.Column("trace_id", sa.String(32), nullable=True), + sa.Column("span_id", sa.String(16), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + # Create indexes for structured_log_entries + op.create_index("ix_structured_log_entries_timestamp", "structured_log_entries", ["timestamp"], unique=False) + op.create_index("ix_structured_log_entries_level", "structured_log_entries", ["level"], unique=False) + op.create_index("ix_structured_log_entries_component", "structured_log_entries", ["component"], unique=False) + op.create_index("ix_structured_log_entries_correlation_id", "structured_log_entries", ["correlation_id"], unique=False) + op.create_index("ix_structured_log_entries_request_id", "structured_log_entries", ["request_id"], unique=False) + op.create_index("ix_structured_log_entries_user_id", "structured_log_entries", ["user_id"], unique=False) + op.create_index("ix_structured_log_entries_user_email", "structured_log_entries", ["user_email"], unique=False) + op.create_index("ix_structured_log_entries_operation_type", "structured_log_entries", ["operation_type"], unique=False) + op.create_index("ix_structured_log_entries_is_security_event", "structured_log_entries", ["is_security_event"], unique=False) + op.create_index("ix_structured_log_entries_security_severity", "structured_log_entries", ["security_severity"], unique=False) + op.create_index("ix_structured_log_entries_trace_id", "structured_log_entries", ["trace_id"], unique=False) + + # Composite indexes matching db.py + op.create_index("idx_log_correlation_time", "structured_log_entries", ["correlation_id", "timestamp"], unique=False) + op.create_index("idx_log_user_time", "structured_log_entries", 
["user_id", "timestamp"], unique=False) + op.create_index("idx_log_level_time", "structured_log_entries", ["level", "timestamp"], unique=False) + op.create_index("idx_log_component_time", "structured_log_entries", ["component", "timestamp"], unique=False) + op.create_index("idx_log_security", "structured_log_entries", ["is_security_event", "security_severity", "timestamp"], unique=False) + op.create_index("idx_log_operation", "structured_log_entries", ["operation_type", "timestamp"], unique=False) + op.create_index("idx_log_trace", "structured_log_entries", ["trace_id", "timestamp"], unique=False) + + # Create performance_metrics table + op.create_table( + "performance_metrics", + sa.Column("id", sa.String(36), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.Column("operation_type", sa.String(100), nullable=False), + sa.Column("component", sa.String(100), nullable=False), + sa.Column("request_count", sa.Integer(), nullable=False, server_default="0"), + sa.Column("error_count", sa.Integer(), nullable=False, server_default="0"), + sa.Column("error_rate", sa.Float(), nullable=False, server_default="0.0"), + sa.Column("avg_duration_ms", sa.Float(), nullable=False), + sa.Column("min_duration_ms", sa.Float(), nullable=False), + sa.Column("max_duration_ms", sa.Float(), nullable=False), + sa.Column("p50_duration_ms", sa.Float(), nullable=False), + sa.Column("p95_duration_ms", sa.Float(), nullable=False), + sa.Column("p99_duration_ms", sa.Float(), nullable=False), + sa.Column("window_start", sa.DateTime(timezone=True), nullable=False), + sa.Column("window_end", sa.DateTime(timezone=True), nullable=False), + sa.Column("window_duration_seconds", sa.Integer(), nullable=False), + sa.Column("metric_metadata", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + # Create indexes for performance_metrics + op.create_index("ix_performance_metrics_timestamp", "performance_metrics", ["timestamp"], unique=False) + op.create_index("ix_performance_metrics_component", "performance_metrics", ["component"], unique=False) + op.create_index("ix_performance_metrics_operation_type", "performance_metrics", ["operation_type"], unique=False) + op.create_index("ix_performance_metrics_window_start", "performance_metrics", ["window_start"], unique=False) + op.create_index("idx_perf_operation_time", "performance_metrics", ["operation_type", "window_start"], unique=False) + op.create_index("idx_perf_component_time", "performance_metrics", ["component", "window_start"], unique=False) + op.create_index("idx_perf_window", "performance_metrics", ["window_start", "window_end"], unique=False) + + # Create security_events table + op.create_table( + "security_events", + sa.Column("id", sa.String(36), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.Column("detected_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("correlation_id", sa.String(64), nullable=True), + sa.Column("log_entry_id", sa.String(36), nullable=True), + sa.Column("event_type", sa.String(100), nullable=False), + sa.Column("severity", sa.String(20), nullable=False), + sa.Column("category", sa.String(50), nullable=False), + sa.Column("user_id", sa.String(255), nullable=True), + sa.Column("user_email", sa.String(255), nullable=True), + sa.Column("client_ip", sa.String(45), nullable=False), + sa.Column("user_agent", sa.Text(), nullable=True), + sa.Column("description", sa.Text(), nullable=False), + sa.Column("action_taken", sa.String(100), nullable=True), + 
sa.Column("threat_score", sa.Float(), nullable=False, server_default="0.0"), + sa.Column("threat_indicators", sa.JSON(), nullable=False, server_default=sa.text("'{}'")), + sa.Column("failed_attempts_count", sa.Integer(), nullable=False, server_default="0"), + sa.Column("resolved", sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("resolved_by", sa.String(255), nullable=True), + sa.Column("resolution_notes", sa.Text(), nullable=True), + sa.Column("alert_sent", sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column("alert_sent_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("alert_recipients", sa.JSON(), nullable=True), + sa.Column("context", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["log_entry_id"], ["structured_log_entries.id"]), + ) + + # Create indexes for security_events + op.create_index("ix_security_events_timestamp", "security_events", ["timestamp"], unique=False) + op.create_index("ix_security_events_detected_at", "security_events", ["detected_at"], unique=False) + op.create_index("ix_security_events_correlation_id", "security_events", ["correlation_id"], unique=False) + op.create_index("ix_security_events_event_type", "security_events", ["event_type"], unique=False) + op.create_index("ix_security_events_severity", "security_events", ["severity"], unique=False) + op.create_index("ix_security_events_category", "security_events", ["category"], unique=False) + op.create_index("ix_security_events_user_id", "security_events", ["user_id"], unique=False) + op.create_index("ix_security_events_user_email", "security_events", ["user_email"], unique=False) + op.create_index("ix_security_events_client_ip", "security_events", ["client_ip"], unique=False) + op.create_index("ix_security_events_log_entry_id", "security_events", ["log_entry_id"], unique=False) + op.create_index("ix_security_events_resolved", "security_events", ["resolved"], unique=False) + op.create_index("idx_security_type_time", "security_events", ["event_type", "timestamp"], unique=False) + op.create_index("idx_security_severity_time", "security_events", ["severity", "timestamp"], unique=False) + op.create_index("idx_security_user_time", "security_events", ["user_id", "timestamp"], unique=False) + op.create_index("idx_security_ip_time", "security_events", ["client_ip", "timestamp"], unique=False) + op.create_index("idx_security_unresolved", "security_events", ["resolved", "severity", "timestamp"], unique=False) + + # Create audit_trails table + op.create_table( + "audit_trails", + sa.Column("id", sa.String(36), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.Column("correlation_id", sa.String(64), nullable=True), + sa.Column("request_id", sa.String(64), nullable=True), + sa.Column("action", sa.String(100), nullable=False), + sa.Column("resource_type", sa.String(100), nullable=False), + sa.Column("resource_id", sa.String(255), nullable=False), + sa.Column("resource_name", sa.String(500), nullable=True), + sa.Column("user_id", sa.String(255), nullable=False), + sa.Column("user_email", sa.String(255), nullable=True), + sa.Column("team_id", sa.String(36), nullable=True), + sa.Column("client_ip", sa.String(45), nullable=True), + sa.Column("user_agent", sa.Text(), nullable=True), + sa.Column("request_path", sa.String(500), nullable=True), + sa.Column("request_method", sa.String(10), nullable=True), + sa.Column("old_values", 
sa.JSON(), nullable=True), + sa.Column("new_values", sa.JSON(), nullable=True), + sa.Column("changes", sa.JSON(), nullable=True), + sa.Column("data_classification", sa.String(50), nullable=True), + sa.Column("requires_review", sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column("success", sa.Boolean(), nullable=False), + sa.Column("error_message", sa.Text(), nullable=True), + sa.Column("context", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + # Create indexes for audit_trails + op.create_index("ix_audit_trails_timestamp", "audit_trails", ["timestamp"], unique=False) + op.create_index("ix_audit_trails_correlation_id", "audit_trails", ["correlation_id"], unique=False) + op.create_index("ix_audit_trails_request_id", "audit_trails", ["request_id"], unique=False) + op.create_index("ix_audit_trails_action", "audit_trails", ["action"], unique=False) + op.create_index("ix_audit_trails_resource_type", "audit_trails", ["resource_type"], unique=False) + op.create_index("ix_audit_trails_resource_id", "audit_trails", ["resource_id"], unique=False) + op.create_index("ix_audit_trails_user_id", "audit_trails", ["user_id"], unique=False) + op.create_index("ix_audit_trails_user_email", "audit_trails", ["user_email"], unique=False) + op.create_index("ix_audit_trails_team_id", "audit_trails", ["team_id"], unique=False) + op.create_index("ix_audit_trails_data_classification", "audit_trails", ["data_classification"], unique=False) + op.create_index("ix_audit_trails_requires_review", "audit_trails", ["requires_review"], unique=False) + op.create_index("ix_audit_trails_success", "audit_trails", ["success"], unique=False) + op.create_index("idx_audit_action_time", "audit_trails", ["action", "timestamp"], unique=False) + op.create_index("idx_audit_resource_time", "audit_trails", ["resource_type", "resource_id", "timestamp"], unique=False) + op.create_index("idx_audit_user_time", "audit_trails", ["user_id", "timestamp"], unique=False) + op.create_index("idx_audit_classification", "audit_trails", ["data_classification", "timestamp"], unique=False) + op.create_index("idx_audit_review", "audit_trails", ["requires_review", "timestamp"], unique=False) + + +def downgrade() -> None: + """Remove structured logging tables.""" + op.drop_table("audit_trails") + op.drop_table("security_events") + op.drop_table("performance_metrics") + op.drop_table("structured_log_entries") diff --git a/mcpgateway/auth.py b/mcpgateway/auth.py index ea633ee5d..c0300a124 100644 --- a/mcpgateway/auth.py +++ b/mcpgateway/auth.py @@ -26,11 +26,63 @@ from mcpgateway.config import settings from mcpgateway.db import EmailUser, SessionLocal from mcpgateway.plugins.framework import get_plugin_manager, GlobalContext, HttpAuthResolveUserPayload, HttpHeaderPayload, HttpHookType, PluginViolationError -from mcpgateway.services.team_management_service import TeamManagementService +from mcpgateway.services.team_management_service import TeamManagementService # pylint: disable=import-outside-toplevel +from mcpgateway.utils.correlation_id import get_correlation_id from mcpgateway.utils.verify_credentials import verify_jwt_token # Security scheme -bearer_scheme = HTTPBearer(auto_error=False) +security = HTTPBearer(auto_error=False) + + +def _log_auth_event( + logger: logging.Logger, + message: str, + level: int = logging.INFO, + user_id: Optional[str] = None, + auth_method: Optional[str] = None, + auth_success: bool = False, + security_event: Optional[str] = None, + security_severity: str = "low", + **extra_context, +) -> 
None: + """Log authentication event with structured context and request_id. + + This helper creates structured log records that include request_id from the + correlation ID context, enabling end-to-end tracing of authentication flows. + + Args: + logger: Logger instance to use + message: Log message + level: Log level (default: INFO) + user_id: User identifier + auth_method: Authentication method used (jwt, api_token, etc.) + auth_success: Whether authentication succeeded + security_event: Type of security event (authentication, authorization, etc.) + security_severity: Severity level (low, medium, high, critical) + **extra_context: Additional context fields + """ + # Get request_id from correlation ID context + request_id = get_correlation_id() + + # Build structured log record + extra = { + "request_id": request_id, + "entity_type": "auth", + "auth_success": auth_success, + "security_event": security_event or "authentication", + "security_severity": security_severity, + } + + if user_id: + extra["user_id"] = user_id + if auth_method: + extra["auth_method"] = auth_method + + # Add any additional context + extra.update(extra_context) + + # Log with structured context + logger.log(level, message, extra=extra) def get_db() -> Generator[Session, Never, None]: @@ -119,7 +171,7 @@ async def get_team_from_token(payload: Dict[str, Any], db: Session) -> Optional[ async def get_current_user( - credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme), + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), db: Session = Depends(get_db), request: Optional[object] = None, ) -> EmailUser: @@ -169,10 +221,15 @@ async def get_current_user( if request and hasattr(request, "headers"): headers = dict(request.headers) - # Get request ID from request state (set by middleware) or generate new one - request_id = getattr(request.state, "request_id", None) if request else None + # Get request ID from correlation ID context (set by CorrelationIDMiddleware) + request_id = get_correlation_id() if not request_id: - request_id = uuid.uuid4().hex + # Fallback chain for safety + if request and hasattr(request, "state") and hasattr(request.state, "request_id"): + request_id = request.state.request_id + else: + request_id = uuid.uuid4().hex + logger.debug(f"Generated fallback request ID in get_current_user: {request_id}") # Get plugin contexts from request state if available global_context = getattr(request.state, "plugin_global_context", None) if request else None diff --git a/mcpgateway/common/models.py b/mcpgateway/common/models.py index ee6e87478..5df0c5e2a 100644 --- a/mcpgateway/common/models.py +++ b/mcpgateway/common/models.py @@ -705,6 +705,7 @@ class ResourceTemplate(BaseModelWithConfigDict): """A template for constructing resource URIs (MCP spec-compliant). Attributes: + id (Optional[str]): Unique identifier for resource uri_template (str): The URI template string. name (str): The unique name of the template. description (Optional[str]): A description of the template. 
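Note: the _log_auth_event helper added to mcpgateway/auth.py above takes the logger and message positionally and everything else as keyword arguments, folding **extra_context into the structured "extra" dict and picking up request_id from the correlation ID context on its own. A minimal usage sketch follows; the call site, the message text, and the token_source field are illustrative assumptions, not part of this change:

    import logging

    logger = logging.getLogger("mcpgateway.auth")

    # Hypothetical failed-JWT path: request_id is attached automatically
    # by _log_auth_event via get_correlation_id(), so it is not passed here.
    _log_auth_event(
        logger,
        "JWT verification failed",
        level=logging.WARNING,
        auth_method="jwt",
        auth_success=False,
        security_event="authentication",
        security_severity="medium",
        token_source="authorization_header",  # illustrative extra context field
    )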
@@ -717,7 +718,7 @@ class ResourceTemplate(BaseModelWithConfigDict): # ✅ DB field name: uri_template # ✅ API (JSON) alias: - id: Optional[int] = None + id: Optional[str] = None uri_template: str = Field(..., alias="uriTemplate") name: str description: Optional[str] = None diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 9d3017876..49b367484 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -776,6 +776,51 @@ def _parse_allowed_origins(cls, v: Any) -> Set[str]: # Enable span events observability_events_enabled: bool = Field(default=True, description="Enable event logging within spans") + # Correlation ID Settings + correlation_id_enabled: bool = Field(default=True, description="Enable automatic correlation ID tracking for requests") + correlation_id_header: str = Field(default="X-Correlation-ID", description="HTTP header name for correlation ID") + correlation_id_preserve: bool = Field(default=True, description="Preserve correlation IDs from incoming requests") + correlation_id_response_header: bool = Field(default=True, description="Include correlation ID in response headers") + + # Structured Logging Configuration + structured_logging_enabled: bool = Field(default=True, description="Enable structured JSON logging with database persistence") + structured_logging_database_enabled: bool = Field(default=True, description="Persist structured logs to database") + structured_logging_external_enabled: bool = Field(default=False, description="Send logs to external systems") + + # Performance Tracking Configuration + performance_tracking_enabled: bool = Field(default=True, description="Enable performance tracking and metrics") + performance_threshold_database_query_ms: float = Field(default=100.0, description="Alert threshold for database queries (ms)") + performance_threshold_tool_invocation_ms: float = Field(default=2000.0, description="Alert threshold for tool invocations (ms)") + performance_threshold_resource_read_ms: float = Field(default=1000.0, description="Alert threshold for resource reads (ms)") + performance_threshold_http_request_ms: float = Field(default=500.0, description="Alert threshold for HTTP requests (ms)") + performance_degradation_multiplier: float = Field(default=1.5, description="Alert if performance degrades by this multiplier vs baseline") + + # Security Logging Configuration + security_logging_enabled: bool = Field(default=True, description="Enable security event logging") + security_failed_auth_threshold: int = Field(default=5, description="Failed auth attempts before high severity alert") + security_threat_score_alert: float = Field(default=0.7, description="Threat score threshold for alerts (0.0-1.0)") + security_rate_limit_window_minutes: int = Field(default=5, description="Time window for rate limit checks (minutes)") + + # Metrics Aggregation Configuration + metrics_aggregation_enabled: bool = Field(default=True, description="Enable automatic log aggregation into performance metrics") + metrics_aggregation_backfill_hours: int = Field(default=6, ge=0, le=168, description="Hours of structured logs to backfill into performance metrics on startup") + metrics_aggregation_window_minutes: int = Field(default=5, description="Time window for metrics aggregation (minutes)") + metrics_aggregation_auto_start: bool = Field(default=False, description="Automatically run the log aggregation loop on application startup") + + # Log Search Configuration + log_search_max_results: int = Field(default=1000, description="Maximum results per log search query") + 
log_retention_days: int = Field(default=30, description="Number of days to retain logs in database") + + # External Log Integration Configuration + elasticsearch_enabled: bool = Field(default=False, description="Send logs to Elasticsearch") + elasticsearch_url: Optional[str] = Field(default=None, description="Elasticsearch cluster URL") + elasticsearch_index_prefix: str = Field(default="mcpgateway-logs", description="Elasticsearch index prefix") + syslog_enabled: bool = Field(default=False, description="Send logs to syslog") + syslog_host: Optional[str] = Field(default=None, description="Syslog server host") + syslog_port: int = Field(default=514, description="Syslog server port") + webhook_logging_enabled: bool = Field(default=False, description="Send logs to webhook endpoints") + webhook_logging_urls: List[str] = Field(default_factory=list, description="Webhook URLs for log delivery") + @field_validator("log_level", mode="before") @classmethod def validate_log_level(cls, v: str) -> str: diff --git a/mcpgateway/db.py b/mcpgateway/db.py index f5289951d..f548c8af4 100644 --- a/mcpgateway/db.py +++ b/mcpgateway/db.py @@ -3797,6 +3797,252 @@ def init_db(): raise Exception(f"Failed to initialize database: {str(e)}") +# ============================================================================ +# Structured Logging Models +# ============================================================================ + + +class StructuredLogEntry(Base): + """Structured log entry for comprehensive logging and analysis. + + Stores all log entries with correlation IDs, performance metrics, + and security context for advanced search and analytics. + """ + + __tablename__ = "structured_log_entries" + + # Primary key + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) + + # Timestamps + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) + + # Correlation and request tracking + correlation_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) + request_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) + + # Log metadata + level: Mapped[str] = mapped_column(String(20), nullable=False, index=True) # DEBUG, INFO, WARNING, ERROR, CRITICAL + component: Mapped[str] = mapped_column(String(100), nullable=False, index=True) + message: Mapped[str] = mapped_column(Text, nullable=False) + logger: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + + # User and request context + user_id: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + user_email: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + client_ip: Mapped[Optional[str]] = mapped_column(String(45), nullable=True) # IPv6 max length + user_agent: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + request_path: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) + request_method: Mapped[Optional[str]] = mapped_column(String(10), nullable=True) + + # Performance data + duration_ms: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + operation_type: Mapped[Optional[str]] = mapped_column(String(100), index=True, nullable=True) + + # Security context + is_security_event: Mapped[bool] = mapped_column(Boolean, default=False, index=True, nullable=False) + security_severity: Mapped[Optional[str]] = mapped_column(String(20), index=True, nullable=True) # LOW, MEDIUM, HIGH, CRITICAL + threat_indicators: 
Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + # Structured context data + context: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + error_details: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + performance_metrics: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + # System information + hostname: Mapped[str] = mapped_column(String(255), nullable=False) + process_id: Mapped[int] = mapped_column(Integer, nullable=False) + thread_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + version: Mapped[str] = mapped_column(String(50), nullable=False) + environment: Mapped[str] = mapped_column(String(50), nullable=False, default="production") + + # OpenTelemetry trace context + trace_id: Mapped[Optional[str]] = mapped_column(String(32), index=True, nullable=True) + span_id: Mapped[Optional[str]] = mapped_column(String(16), nullable=True) + + # Indexes for performance + __table_args__ = ( + Index("idx_log_correlation_time", "correlation_id", "timestamp"), + Index("idx_log_user_time", "user_id", "timestamp"), + Index("idx_log_level_time", "level", "timestamp"), + Index("idx_log_component_time", "component", "timestamp"), + Index("idx_log_security", "is_security_event", "security_severity", "timestamp"), + Index("idx_log_operation", "operation_type", "timestamp"), + Index("idx_log_trace", "trace_id", "timestamp"), + ) + + +class PerformanceMetric(Base): + """Aggregated performance metrics from log analysis. + + Stores time-windowed aggregations of operation performance + for analytics and trend analysis. + """ + + __tablename__ = "performance_metrics" + + # Primary key + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) + + # Timestamp + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) + + # Metric identification + operation_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) + component: Mapped[str] = mapped_column(String(100), nullable=False, index=True) + + # Aggregated metrics + request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + error_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + error_rate: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) + + # Duration metrics (in milliseconds) + avg_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + min_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + max_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + p50_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + p95_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + p99_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + + # Time window + window_start: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True) + window_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + window_duration_seconds: Mapped[int] = mapped_column(Integer, nullable=False) + + # Additional context + metric_metadata: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + __table_args__ = ( + Index("idx_perf_operation_time", "operation_type", "window_start"), + Index("idx_perf_component_time", "component", "window_start"), + Index("idx_perf_window", "window_start", "window_end"), + ) + + +class SecurityEvent(Base): + """Security event logging for 
threat detection and audit trails. + + Specialized table for security events with enhanced context + and threat analysis capabilities. + """ + + __tablename__ = "security_events" + + # Primary key + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) + + # Timestamps + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) + detected_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=utc_now) + + # Correlation tracking + correlation_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) + log_entry_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("structured_log_entries.id"), index=True, nullable=True) + + # Event classification + event_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # auth_failure, suspicious_activity, rate_limit, etc. + severity: Mapped[str] = mapped_column(String(20), nullable=False, index=True) # LOW, MEDIUM, HIGH, CRITICAL + category: Mapped[str] = mapped_column(String(50), nullable=False, index=True) # authentication, authorization, data_access, etc. + + # User and request context + user_id: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + user_email: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + client_ip: Mapped[str] = mapped_column(String(45), nullable=False, index=True) + user_agent: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Event details + description: Mapped[str] = mapped_column(Text, nullable=False) + action_taken: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) # blocked, allowed, flagged, etc. + + # Threat analysis + threat_score: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) # 0.0-1.0 + threat_indicators: Mapped[Dict[str, Any]] = mapped_column(JSON, nullable=False, default=dict) + failed_attempts_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + # Resolution tracking + resolved: Mapped[bool] = mapped_column(Boolean, default=False, index=True, nullable=False) + resolved_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + resolved_by: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + resolution_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Alert tracking + alert_sent: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + alert_sent_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + alert_recipients: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) + + # Additional context + context: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + __table_args__ = ( + Index("idx_security_type_time", "event_type", "timestamp"), + Index("idx_security_severity_time", "severity", "timestamp"), + Index("idx_security_user_time", "user_id", "timestamp"), + Index("idx_security_ip_time", "client_ip", "timestamp"), + Index("idx_security_unresolved", "resolved", "severity", "timestamp"), + ) + + +class AuditTrail(Base): + """Comprehensive audit trail for data access and changes. + + Tracks all significant system changes and data access for + compliance and security auditing. 
+ """ + + __tablename__ = "audit_trails" + + # Primary key + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) + + # Timestamps + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) + + # Correlation tracking + correlation_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) + request_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) + + # Action details + action: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # create, read, update, delete, execute, etc. + resource_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # tool, resource, prompt, user, etc. + resource_id: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + resource_name: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) + + # User context + user_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True) + user_email: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + team_id: Mapped[Optional[str]] = mapped_column(String(36), index=True, nullable=True) + + # Request context + client_ip: Mapped[Optional[str]] = mapped_column(String(45), nullable=True) + user_agent: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + request_path: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) + request_method: Mapped[Optional[str]] = mapped_column(String(10), nullable=True) + + # Change tracking + old_values: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + new_values: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + changes: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + # Data classification + data_classification: Mapped[Optional[str]] = mapped_column(String(50), index=True, nullable=True) # public, internal, confidential, restricted + requires_review: Mapped[bool] = mapped_column(Boolean, default=False, index=True, nullable=False) + + # Result + success: Mapped[bool] = mapped_column(Boolean, nullable=False, index=True) + error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Additional context + context: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + __table_args__ = ( + Index("idx_audit_action_time", "action", "timestamp"), + Index("idx_audit_resource_time", "resource_type", "resource_id", "timestamp"), + Index("idx_audit_user_time", "user_id", "timestamp"), + Index("idx_audit_classification", "data_classification", "timestamp"), + Index("idx_audit_review", "requires_review", "timestamp"), + ) + + if __name__ == "__main__": # Wait for database to be ready before initializing wait_for_db_ready(max_tries=int(settings.db_max_retries), interval=int(settings.db_retry_interval_ms) / 1000, sync=True) # Converting ms to s diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 6eb6e0f2d..1b0e39eae 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -27,7 +27,7 @@ # Standard import asyncio -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, suppress from datetime import datetime import json import os as _os # local alias to avoid collisions @@ -70,6 +70,7 @@ from mcpgateway.db import refresh_slugs_on_startup, SessionLocal from mcpgateway.db import Tool as DbTool from mcpgateway.handlers.sampling import SamplingHandler +from mcpgateway.middleware.correlation_id 
import CorrelationIDMiddleware from mcpgateway.middleware.http_auth_middleware import HttpAuthMiddleware from mcpgateway.middleware.protocol_version import MCPProtocolVersionMiddleware from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission @@ -112,6 +113,7 @@ from mcpgateway.services.import_service import ConflictStrategy, ImportConflictError from mcpgateway.services.import_service import ImportError as ImportServiceError from mcpgateway.services.import_service import ImportService, ImportValidationError +from mcpgateway.services.log_aggregator import get_log_aggregator from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.metrics import setup_metrics from mcpgateway.services.prompt_service import PromptError, PromptNameConflictError, PromptNotFoundError, PromptService @@ -406,6 +408,10 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: Exception: Any unhandled error that occurs during service initialisation or shutdown is re-raised to the caller. """ + aggregation_stop_event: Optional[asyncio.Event] = None + aggregation_loop_task: Optional[asyncio.Task] = None + aggregation_backfill_task: Optional[asyncio.Task] = None + # Initialize logging service FIRST to ensure all logging goes to dual output await logging_service.initialize() logger.info("Starting MCP Gateway services") @@ -461,6 +467,54 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: # Reconfigure uvicorn loggers after startup to capture access logs in dual output logging_service.configure_uvicorn_after_startup() + if settings.metrics_aggregation_enabled and settings.metrics_aggregation_auto_start: + aggregation_stop_event = asyncio.Event() + log_aggregator = get_log_aggregator() + + async def run_log_backfill() -> None: + """Backfill log aggregation metrics for configured hours.""" + hours = getattr(settings, "metrics_aggregation_backfill_hours", 0) + if hours <= 0: + return + try: + await asyncio.to_thread(log_aggregator.backfill, hours) + logger.info("Log aggregation backfill completed for last %s hour(s)", hours) + except Exception as backfill_error: # pragma: no cover - defensive logging + logger.warning("Log aggregation backfill failed: %s", backfill_error) + + async def run_log_aggregation_loop() -> None: + """Run continuous log aggregation at configured intervals. 
+ + Raises: + asyncio.CancelledError: When aggregation is stopped + """ + interval_seconds = max(1, int(settings.metrics_aggregation_window_minutes)) * 60 + logger.info( + "Starting log aggregation loop (window=%s min)", + log_aggregator.aggregation_window_minutes, + ) + try: + while not aggregation_stop_event.is_set(): + try: + await asyncio.to_thread(log_aggregator.aggregate_all_components) + except Exception as agg_error: # pragma: no cover - defensive logging + logger.warning("Log aggregation loop iteration failed: %s", agg_error) + + try: + await asyncio.wait_for(aggregation_stop_event.wait(), timeout=interval_seconds) + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + logger.debug("Log aggregation loop cancelled") + raise + finally: + logger.info("Log aggregation loop stopped") + + aggregation_backfill_task = asyncio.create_task(run_log_backfill()) + aggregation_loop_task = asyncio.create_task(run_log_aggregation_loop()) + elif settings.metrics_aggregation_enabled: + logger.info("Metrics aggregation auto-start disabled; performance metrics will be generated on-demand when requested.") + yield except Exception as e: logger.error(f"Error during startup: {str(e)}") @@ -474,6 +528,14 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: raise SystemExit(1) raise finally: + if aggregation_stop_event is not None: + aggregation_stop_event.set() + for task in (aggregation_backfill_task, aggregation_loop_task): + if task: + task.cancel() + with suppress(asyncio.CancelledError): + await task + # Shutdown plugin manager if plugin_manager: try: @@ -1169,6 +1231,15 @@ async def _call_streamable_http(self, scope, receive, send): # Add HTTP authentication hook middleware for plugins (before auth dependencies) if plugin_manager: app.add_middleware(HttpAuthMiddleware, plugin_manager=plugin_manager) + logger.info("🔌 HTTP authentication hooks enabled for plugins") + +# Add request logging middleware FIRST (always enabled for gateway boundary logging) +# IMPORTANT: Must be registered BEFORE CorrelationIDMiddleware so it executes AFTER correlation ID is set +# Gateway boundary logging (request_started/completed) runs regardless of log_requests setting +# Detailed payload logging only runs if log_detailed_requests=True +app.add_middleware( + RequestLoggingMiddleware, enable_gateway_logging=True, log_detailed_requests=settings.log_requests, log_level=settings.log_level, max_body_size=settings.log_max_size_mb * 1024 * 1024 +) # Convert MB to bytes # Add custom DocsAuthMiddleware app.add_middleware(DocsAuthMiddleware) @@ -1176,13 +1247,27 @@ async def _call_streamable_http(self, scope, receive, send): # Trust all proxies (or lock down with a list of host patterns) app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*") -# Add request logging middleware if enabled -if settings.log_requests: - app.add_middleware(RequestLoggingMiddleware, log_requests=settings.log_requests, log_level=settings.log_level, max_body_size=settings.log_max_size_mb * 1024 * 1024) # Convert MB to bytes +# Add correlation ID middleware if enabled +# Note: Registered AFTER RequestLoggingMiddleware so correlation ID is available when RequestLoggingMiddleware executes +if settings.correlation_id_enabled: + app.add_middleware(CorrelationIDMiddleware) + logger.info(f"✅ Correlation ID tracking enabled (header: {settings.correlation_id_header})") + +# Add authentication context middleware if security logging is enabled +# This middleware extracts user context and logs security events (authentication 
attempts) +# Note: This is independent of observability - security logging is always important +if settings.security_logging_enabled: + # First-Party + from mcpgateway.middleware.auth_middleware import AuthContextMiddleware + + app.add_middleware(AuthContextMiddleware) + logger.info("🔐 Authentication context middleware enabled - logging security events") +else: + logger.info("🔐 Security event logging disabled") # Add observability middleware if enabled # Note: Middleware runs in REVERSE order (last added runs first) -# We add ObservabilityMiddleware first so it wraps AuthContextMiddleware +# If AuthContextMiddleware is already registered, ObservabilityMiddleware wraps it # Execution order will be: AuthContext -> Observability -> Request Handler if settings.observability_enabled: # First-Party @@ -1190,13 +1275,6 @@ async def _call_streamable_http(self, scope, receive, send): app.add_middleware(ObservabilityMiddleware, enabled=True) logger.info("🔍 Observability middleware enabled - tracing all HTTP requests") - - # Add authentication context middleware (runs BEFORE observability in execution) - # First-Party - from mcpgateway.middleware.auth_middleware import AuthContextMiddleware - - app.add_middleware(AuthContextMiddleware) - logger.info("🔐 Authentication context middleware enabled - extracting user info for observability") else: logger.info("🔍 Observability middleware disabled") @@ -2402,7 +2480,20 @@ async def invoke_a2a_agent( logger.debug(f"User {user} is invoking A2A agent '{agent_name}' with type '{interaction_type}'") if a2a_service is None: raise HTTPException(status_code=503, detail="A2A service not available") - return await a2a_service.invoke_agent(db, agent_name, parameters, interaction_type) + user_email = get_user_email(user) + user_id = None + if isinstance(user, dict): + user_id = str(user.get("id") or user.get("sub") or user_email) + else: + user_id = str(user) + return await a2a_service.invoke_agent( + db, + agent_name, + parameters, + interaction_type, + user_id=user_id, + user_email=user_email, + ) except A2AAgentNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except A2AAgentError as e: @@ -4980,6 +5071,19 @@ async def cleanup_import_statuses(max_age_hours: int = 24, user=Depends(get_curr app.include_router(tag_router) app.include_router(export_import_router) +# Include log search router if structured logging is enabled +if getattr(settings, "structured_logging_enabled", True): + try: + # First-Party + from mcpgateway.routers.log_search import router as log_search_router + + app.include_router(log_search_router) + logger.info("Log search router included - structured logging enabled") + except ImportError as e: + logger.warning(f"Failed to import log search router: {e}") +else: + logger.info("Log search router not included - structured logging disabled") + # Conditionally include observability router if enabled if settings.observability_enabled: # First-Party diff --git a/mcpgateway/middleware/auth_middleware.py b/mcpgateway/middleware/auth_middleware.py index a8868ccbe..1c2dc7a6c 100644 --- a/mcpgateway/middleware/auth_middleware.py +++ b/mcpgateway/middleware/auth_middleware.py @@ -28,8 +28,10 @@ # First-Party from mcpgateway.auth import get_current_user from mcpgateway.db import SessionLocal +from mcpgateway.services.security_logger import get_security_logger logger = logging.getLogger(__name__) +security_logger = get_security_logger() class AuthContextMiddleware(BaseHTTPMiddleware): @@ -85,14 +87,47 @@ async def dispatch(self, request: 
Request, call_next: Callable) -> Response: credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token) user = await get_current_user(credentials, db) + # Eagerly access user attributes before session closes to prevent DetachedInstanceError + # This forces SQLAlchemy to load the data while the session is still active + # Note: EmailUser uses 'email' as primary key, not 'id' + user_email = user.email + user_id = user_email # For EmailUser, email IS the ID + + # Expunge the user from the session so it can be used after session closes + # This makes the object detached but with all attributes already loaded + db.expunge(user) + # Store user in request state for downstream use request.state.user = user - logger.info(f"✓ Authenticated user for observability: {user.email}") + logger.info(f"✓ Authenticated user: {user_email if user_email else user_id}") + + # Log successful authentication + security_logger.log_authentication_attempt( + user_id=user_id, + user_email=user_email, + auth_method="bearer_token", + success=True, + client_ip=request.client.host if request.client else "unknown", + user_agent=request.headers.get("user-agent"), + db=db, + ) except Exception as e: # Silently fail - let route handlers enforce auth if needed logger.info(f"✗ Auth context extraction failed (continuing as anonymous): {e}") + # Log failed authentication attempt + security_logger.log_authentication_attempt( + user_id="unknown", + user_email=None, + auth_method="bearer_token", + success=False, + client_ip=request.client.host if request.client else "unknown", + user_agent=request.headers.get("user-agent"), + failure_reason=str(e), + db=db if db else None, + ) + finally: # Always close database session if db: diff --git a/mcpgateway/middleware/correlation_id.py b/mcpgateway/middleware/correlation_id.py new file mode 100644 index 000000000..7d9a31193 --- /dev/null +++ b/mcpgateway/middleware/correlation_id.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/middleware/correlation_id.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: MCP Gateway Contributors + +Correlation ID (Request ID) Middleware. + +This middleware handles X-Correlation-ID HTTP headers and maps them to the internal +request_id used throughout the system for unified request tracing. + +Key concept: HTTP X-Correlation-ID header → Internal request_id field (single ID for entire request flow) + +The middleware automatically extracts or generates request IDs for every HTTP request, +stores them in context variables for async-safe propagation across services, and +injects them back into response headers for client-side correlation. + +This enables end-to-end tracing: HTTP → Middleware → Services → Plugins → Logs (all with same request_id) +""" + +# Standard +import logging +from typing import Callable + +# Third-Party +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware + +# First-Party +from mcpgateway.config import settings +from mcpgateway.utils.correlation_id import ( + clear_correlation_id, + extract_correlation_id_from_headers, + generate_correlation_id, + set_correlation_id, +) + +logger = logging.getLogger(__name__) + + +class CorrelationIDMiddleware(BaseHTTPMiddleware): + """Middleware for automatic request ID (correlation ID) handling. + + This middleware: + 1. Extracts request ID from X-Correlation-ID header in incoming requests + 2. Generates a new UUID if no correlation ID is present + 3. 
Stores the ID in context variables for the request lifecycle (used as request_id throughout system) + 4. Injects the request ID into X-Correlation-ID response header + 5. Cleans up context after request completion + + The request ID extracted/generated here becomes the unified request_id used in: + - All log entries (request_id field) + - GlobalContext.request_id (when plugins execute) + - Service method calls for tracing + - Database queries for request tracking + + Configuration is controlled via settings: + - correlation_id_enabled: Enable/disable the middleware + - correlation_id_header: Header name to use (default: X-Correlation-ID) + - correlation_id_preserve: Whether to preserve incoming IDs (default: True) + - correlation_id_response_header: Whether to add ID to responses (default: True) + """ + + def __init__(self, app): + """Initialize the correlation ID (request ID) middleware. + + Args: + app: The FastAPI application instance + """ + super().__init__(app) + self.header_name = getattr(settings, "correlation_id_header", "X-Correlation-ID") + self.preserve_incoming = getattr(settings, "correlation_id_preserve", True) + self.add_to_response = getattr(settings, "correlation_id_response_header", True) + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """Process the request and manage request ID (correlation ID) lifecycle. + + Extracts or generates a request ID, stores it in context variables for use throughout + the request lifecycle (becomes request_id in logs, services, plugins), and injects + it back into the X-Correlation-ID response header. + + Args: + request: The incoming HTTP request + call_next: The next middleware or route handler + + Returns: + Response: The HTTP response with correlation ID header added + """ + # Extract correlation ID from incoming request headers + correlation_id = None + if self.preserve_incoming: + correlation_id = extract_correlation_id_from_headers(dict(request.headers), self.header_name) + + # Generate new correlation ID if none was provided + if not correlation_id: + correlation_id = generate_correlation_id() + logger.debug(f"Generated new correlation ID: {correlation_id}") + else: + logger.debug(f"Using client-provided correlation ID: {correlation_id}") + + # Store correlation ID in context variable for this request + # This makes it available to all downstream code (auth, services, plugins, logs) + set_correlation_id(correlation_id) + + try: + # Process the request + response = await call_next(request) + + # Add correlation ID to response headers if enabled + if self.add_to_response: + response.headers[self.header_name] = correlation_id + + return response + + finally: + # Clean up context after request completes + # Note: ContextVar automatically cleans up, but explicit cleanup is good practice + clear_correlation_id() diff --git a/mcpgateway/middleware/http_auth_middleware.py b/mcpgateway/middleware/http_auth_middleware.py index 84058641f..8b73ffacd 100644 --- a/mcpgateway/middleware/http_auth_middleware.py +++ b/mcpgateway/middleware/http_auth_middleware.py @@ -8,7 +8,6 @@ # Standard import logging -import uuid # Third-Party from fastapi import Request @@ -17,6 +16,7 @@ # First-Party from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, HttpHookType, HttpPostRequestPayload, HttpPreRequestPayload, PluginManager +from mcpgateway.utils.correlation_id import generate_correlation_id, get_correlation_id logger = logging.getLogger(__name__) @@ -60,9 +60,14 @@ async def dispatch(self, request: 
Request, call_next): if not self.plugin_manager: return await call_next(request) - # Generate request ID for tracing and store in request state - # This ensures all hooks and downstream code see the same request ID - request_id = uuid.uuid4().hex + # Use correlation ID from CorrelationIDMiddleware if available + # This ensures all hooks and downstream code see the same unified request ID + request_id = get_correlation_id() + if not request_id: + # Fallback if correlation ID middleware is disabled + request_id = generate_correlation_id() + logger.debug(f"Correlation ID not found, generated fallback: {request_id}") + request.state.request_id = request_id # Create global context for hooks diff --git a/mcpgateway/middleware/request_logging_middleware.py b/mcpgateway/middleware/request_logging_middleware.py index db286b20f..f241197ab 100644 --- a/mcpgateway/middleware/request_logging_middleware.py +++ b/mcpgateway/middleware/request_logging_middleware.py @@ -15,19 +15,29 @@ # Standard import json import logging +import time from typing import Callable # Third-Party -from fastapi import Request, Response +from fastapi.security import HTTPAuthorizationCredentials from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response # First-Party +from mcpgateway.auth import get_current_user +from mcpgateway.db import SessionLocal from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.structured_logger import get_structured_logger +from mcpgateway.utils.correlation_id import get_correlation_id # Initialize logging service first logging_service = LoggingService() logger = logging_service.get_logger(__name__) +# Initialize structured logger for gateway boundary logging +structured_logger = get_structured_logger("http_gateway") + SENSITIVE_KEYS = {"password", "secret", "token", "apikey", "access_token", "refresh_token", "client_secret", "authorization", "jwt_token"} @@ -106,20 +116,67 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): masking sensitive information like passwords, tokens, and authorization headers. """ - def __init__(self, app, log_requests: bool = True, log_level: str = "DEBUG", max_body_size: int = 4096): + def __init__(self, app, enable_gateway_logging: bool = True, log_detailed_requests: bool = False, log_level: str = "DEBUG", max_body_size: int = 4096): """Initialize the request logging middleware. Args: app: The FastAPI application instance - log_requests: Whether to enable request logging + enable_gateway_logging: Whether to enable gateway boundary logging (request_started/completed) + log_detailed_requests: Whether to enable detailed request/response payload logging log_level: The log level for requests (not used, logs at INFO) max_body_size: Maximum request body size to log in bytes """ super().__init__(app) - self.log_requests = log_requests + self.enable_gateway_logging = enable_gateway_logging + self.log_detailed_requests = log_detailed_requests self.log_level = log_level.upper() self.max_body_size = max_body_size # Expected to be in bytes + async def _resolve_user_identity(self, request: Request): + """Best-effort extraction of user identity for request logs. 
+ + Args: + request: The incoming HTTP request + + Returns: + Tuple[Optional[str], Optional[str]]: User ID and email + """ + # Prefer context injected by upstream middleware + if hasattr(request.state, "user") and request.state.user is not None: + raw_user_id = getattr(request.state.user, "id", None) + user_email = getattr(request.state.user, "email", None) + return (str(raw_user_id) if raw_user_id is not None else None, user_email) + + # Fallback: try to authenticate using cookies/headers (matches AuthContextMiddleware) + token = None + if request.cookies: + token = request.cookies.get("jwt_token") or request.cookies.get("access_token") or request.cookies.get("token") + + if not token: + auth_header = request.headers.get("authorization") + if auth_header and auth_header.startswith("Bearer "): + token = auth_header.replace("Bearer ", "") + + if not token: + return (None, None) + + db = None + try: + db = SessionLocal() + credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token) + user = await get_current_user(credentials, db) + raw_user_id = getattr(user, "id", None) + user_email = getattr(user, "email", None) + return (str(raw_user_id) if raw_user_id is not None else None, user_email) + except Exception: + return (None, None) + finally: + if db: + try: + db.close() + except Exception: # nosec B110 - Silently handle db.close() failures during cleanup + pass + async def dispatch(self, request: Request, call_next: Callable): """Process incoming request and log details with sensitive data masked. @@ -129,10 +186,74 @@ async def dispatch(self, request: Request, call_next: Callable): Returns: Response: The HTTP response from downstream handlers + + Raises: + Exception: Any exception from downstream handlers is re-raised """ - # Skip logging if disabled - if not self.log_requests: - return await call_next(request) + # Track start time for total duration + start_time = time.time() + + # Get correlation ID and request metadata for boundary logging + correlation_id = get_correlation_id() + path = request.url.path + method = request.method + user_agent = request.headers.get("user-agent", "unknown") + client_ip = request.client.host if request.client else "unknown" + user_id, user_email = await self._resolve_user_identity(request) + + # Skip boundary logging for health checks and static assets + skip_paths = ["/health", "/healthz", "/static", "/favicon.ico"] + should_log_boundary = self.enable_gateway_logging and not any(path.startswith(skip_path) for skip_path in skip_paths) + + # Log gateway request started + if should_log_boundary: + try: + structured_logger.log( + level="INFO", + message=f"Request started: {method} {path}", + component="http_gateway", + correlation_id=correlation_id, + user_email=user_email, + user_id=user_id, + operation_type="http_request", + request_method=method, + request_path=path, + user_agent=user_agent, + client_ip=client_ip, + metadata={"event": "request_started", "query_params": str(request.query_params) if request.query_params else None}, + ) + except Exception as e: + logger.warning(f"Failed to log request start: {e}") + + # Skip detailed logging if disabled + if not self.log_detailed_requests: + response = await call_next(request) + + # Still log request completed even if detailed logging is disabled + if should_log_boundary: + duration_ms = (time.time() - start_time) * 1000 + try: + log_level = "ERROR" if response.status_code >= 500 else "WARNING" if response.status_code >= 400 else "INFO" + structured_logger.log( + level=log_level, + 
message=f"Request completed: {method} {path} - {response.status_code}", + component="http_gateway", + correlation_id=correlation_id, + user_email=user_email, + user_id=user_id, + operation_type="http_request", + request_method=method, + request_path=path, + response_status_code=response.status_code, + user_agent=user_agent, + client_ip=client_ip, + duration_ms=duration_ms, + metadata={"event": "request_completed", "response_time_category": "fast" if duration_ms < 100 else "normal" if duration_ms < 1000 else "slow"}, + ) + except Exception as e: + logger.warning(f"Failed to log request completion: {e}") + + return response # Always log at INFO level for request payloads to ensure visibility log_level = logging.INFO @@ -171,13 +292,28 @@ async def dispatch(self, request: Request, call_next: Callable): # Mask sensitive headers masked_headers = mask_sensitive_headers(dict(request.headers)) - logger.log( - log_level, - f"📩 Incoming request: {request.method} {request.url.path}\n" - f"Query params: {dict(request.query_params)}\n" - f"Headers: {masked_headers}\n" - f"Body: {payload_str}{'... [truncated]' if truncated else ''}", - ) + # Get correlation ID for request tracking + request_id = get_correlation_id() + + # Try to log with extra parameter, fall back to without if not supported + try: + logger.log( + log_level, + f"📩 Incoming request: {request.method} {request.url.path}\n" + f"Query params: {dict(request.query_params)}\n" + f"Headers: {masked_headers}\n" + f"Body: {payload_str}{'... [truncated]' if truncated else ''}", + extra={"request_id": request_id}, + ) + except TypeError: + # Fall back for test loggers that don't accept extra parameter + logger.log( + log_level, + f"📩 Incoming request: {request.method} {request.url.path}\n" + f"Query params: {dict(request.query_params)}\n" + f"Headers: {masked_headers}\n" + f"Body: {payload_str}{'... 
[truncated]' if truncated else ''}", + ) except Exception as e: logger.warning(f"Failed to log request body: {e}") @@ -195,5 +331,80 @@ async def receive(): new_scope = request.scope.copy() new_request = Request(new_scope, receive=receive) - response: Response = await call_next(new_request) + # Process request + try: + response: Response = await call_next(new_request) + status_code = response.status_code + except Exception as e: + duration_ms = (time.time() - start_time) * 1000 + + # Log request failed + if should_log_boundary: + try: + structured_logger.log( + level="ERROR", + message=f"Request failed: {method} {path}", + component="gateway", + correlation_id=correlation_id, + user_email=user_email, + user_id=user_id, + operation_type="http_request", + request_method=method, + request_path=path, + user_agent=user_agent, + client_ip=client_ip, + duration_ms=duration_ms, + error=e, + metadata={"event": "request_failed"}, + ) + except Exception as log_error: + logger.warning(f"Failed to log request failure: {log_error}") + + raise + + # Calculate total duration + duration_ms = (time.time() - start_time) * 1000 + + # Log gateway request completed + if should_log_boundary: + try: + log_level = "ERROR" if status_code >= 500 else "WARNING" if status_code >= 400 else "INFO" + + structured_logger.log( + level=log_level, + message=f"Request completed: {method} {path} - {status_code}", + component="gateway", + correlation_id=correlation_id, + user_email=user_email, + user_id=user_id, + operation_type="http_request", + request_method=method, + request_path=path, + response_status_code=status_code, + user_agent=user_agent, + client_ip=client_ip, + duration_ms=duration_ms, + metadata={"event": "request_completed", "response_time_category": self._categorize_response_time(duration_ms)}, + ) + except Exception as e: + logger.warning(f"Failed to log request completion: {e}") + return response + + @staticmethod + def _categorize_response_time(duration_ms: float) -> str: + """Categorize response time for analytics. + + Args: + duration_ms: Response time in milliseconds + + Returns: + Category string + """ + if duration_ms < 100: + return "fast" + if duration_ms < 500: + return "normal" + if duration_ms < 2000: + return "slow" + return "very_slow" diff --git a/mcpgateway/observability.py b/mcpgateway/observability.py index 6714dd392..25b28ef5c 100644 --- a/mcpgateway/observability.py +++ b/mcpgateway/observability.py @@ -15,7 +15,7 @@ import os from typing import Any, Callable, cast, Dict, Optional -# Try to import OpenTelemetry core components - make them truly optional +# Third-Party - Try to import OpenTelemetry core components - make them truly optional OTEL_AVAILABLE = False try: # Third-Party @@ -93,6 +93,9 @@ class _ConsoleSpanExporterStub: # pragma: no cover - test patch replaces this # Shimming is a non-critical, best-effort step for tests; log and continue. 
logging.getLogger(__name__).debug("Skipping OpenTelemetry shim setup: %s", exc) +# First-Party +from mcpgateway.utils.correlation_id import get_correlation_id # noqa: E402 # pylint: disable=wrong-import-position + # Try to import optional exporters try: OTLP_SPAN_EXPORTER = getattr(_im("opentelemetry.exporter.otlp.proto.grpc.trace_exporter"), "OTLPSpanExporter") @@ -440,6 +443,21 @@ def create_span(name: str, attributes: Optional[Dict[str, Any]] = None) -> Any: # Return a no-op context manager if tracing is not configured or available return nullcontext() + # Auto-inject correlation ID into all spans for request tracing + try: + correlation_id = get_correlation_id() + if correlation_id: + if attributes is None: + attributes = {} + # Add correlation ID if not already present + if "correlation_id" not in attributes: + attributes["correlation_id"] = correlation_id + if "request_id" not in attributes: + attributes["request_id"] = correlation_id # Alias for compatibility + except Exception as exc: + # Correlation ID not available or error getting it, continue without it + logger.debug("Failed to add correlation_id to span: %s", exc) + # Start span and return the context manager span_context = _TRACER.start_as_current_span(name) diff --git a/mcpgateway/plugins/framework/external/mcp/server/runtime.py b/mcpgateway/plugins/framework/external/mcp/server/runtime.py index d13835a81..8ba31618d 100644 --- a/mcpgateway/plugins/framework/external/mcp/server/runtime.py +++ b/mcpgateway/plugins/framework/external/mcp/server/runtime.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # -*- coding: utf-8 -*- """Location: ./mcpgateway/plugins/framework/external/mcp/server/runtime.py Copyright 2025 @@ -12,6 +13,46 @@ - Reads configuration from PLUGINS_SERVER_* environment variables or uses configurations the plugin config.yaml - Implements all plugin hook tools (get_plugin_configs, tool_pre_invoke, etc.) + +Examples: + Create an SSL-capable FastMCP server: + + >>> from mcpgateway.plugins.framework.models import MCPServerConfig + >>> config = MCPServerConfig(host="localhost", port=8000) + >>> server = SSLCapableFastMCP(server_config=config, name="TestServer") + >>> server.settings.host + 'localhost' + >>> server.settings.port + 8000 + + Check SSL configuration returns empty dict when TLS is not configured: + + >>> from mcpgateway.plugins.framework.models import MCPServerConfig + >>> config = MCPServerConfig(host="127.0.0.1", port=8000, tls=None) + >>> server = SSLCapableFastMCP(server_config=config, name="NoTLSServer") + >>> ssl_config = server._get_ssl_config() + >>> ssl_config + {} + + Verify server configuration is accessible: + + >>> from mcpgateway.plugins.framework.models import MCPServerConfig + >>> config = MCPServerConfig(host="localhost", port=9000) + >>> server = SSLCapableFastMCP(server_config=config, name="ConfigTest") + >>> server.server_config.host + 'localhost' + >>> server.server_config.port + 9000 + + Settings are properly passed to FastMCP: + + >>> from mcpgateway.plugins.framework.models import MCPServerConfig + >>> config = MCPServerConfig(host="0.0.0.0", port=8080) + >>> server = SSLCapableFastMCP(server_config=config, name="SettingsTest") + >>> server.settings.host + '0.0.0.0' + >>> server.settings.port + 8080 """ # Standard @@ -45,6 +86,15 @@ async def get_plugin_configs() -> list[dict]: Raises: RuntimeError: If plugin server not initialized. 
+ + Examples: + Function raises RuntimeError when server is not initialized: + + >>> import asyncio + >>> asyncio.run(get_plugin_configs()) # doctest: +SKIP + Traceback (most recent call last): + ... + RuntimeError: Plugin server not initialized """ if not SERVER: raise RuntimeError("Plugin server not initialized") @@ -62,6 +112,13 @@ async def get_plugin_config(name: str) -> dict: Raises: RuntimeError: If plugin server not initialized. + + Examples: + Function returns empty dict when result is None: + + >>> result = None + >>> result if result is not None else {} + {} """ if not SERVER: raise RuntimeError("Plugin server not initialized") @@ -85,6 +142,15 @@ async def invoke_hook(hook_type: str, plugin_name: str, payload: Dict[str, Any], Raises: RuntimeError: If plugin server not initialized. + + Examples: + Function raises RuntimeError when server is not initialized: + + >>> import asyncio + >>> asyncio.run(invoke_hook("hook", "plugin", {}, {})) # doctest: +SKIP + Traceback (most recent call last): + ... + RuntimeError: Plugin server not initialized """ if not SERVER: raise RuntimeError("Plugin server not initialized") @@ -92,7 +158,19 @@ async def invoke_hook(hook_type: str, plugin_name: str, payload: Dict[str, Any], class SSLCapableFastMCP(FastMCP): - """FastMCP server with SSL/TLS support using MCPServerConfig.""" + """FastMCP server with SSL/TLS support using MCPServerConfig. + + Examples: + Create an SSL-capable FastMCP server: + + >>> from mcpgateway.plugins.framework.models import MCPServerConfig + >>> config = MCPServerConfig(host="127.0.0.1", port=8000) + >>> server = SSLCapableFastMCP(server_config=config, name="TestServer") + >>> server.settings.host + '127.0.0.1' + >>> server.settings.port + 8000 + """ def __init__(self, server_config: MCPServerConfig, *args, **kwargs): """Initialize an SSL capable Fast MCP server. @@ -101,6 +179,15 @@ def __init__(self, server_config: MCPServerConfig, *args, **kwargs): server_config: the MCP server configuration including mTLS information. *args: Additional positional arguments passed to FastMCP. **kwargs: Additional keyword arguments passed to FastMCP. + + Examples: + >>> from mcpgateway.plugins.framework.models import MCPServerConfig + >>> config = MCPServerConfig(host="0.0.0.0", port=9000) + >>> server = SSLCapableFastMCP(server_config=config, name="PluginServer") + >>> server.server_config.host + '0.0.0.0' + >>> server.server_config.port + 9000 """ # Load server config from environment @@ -118,6 +205,14 @@ def _get_ssl_config(self) -> dict: Returns: Dictionary of SSL configuration parameters for uvicorn. + + Examples: + >>> from mcpgateway.plugins.framework.models import MCPServerConfig + >>> config = MCPServerConfig(host="127.0.0.1", port=8000, tls=None) + >>> server = SSLCapableFastMCP(server_config=config, name="TestServer") + >>> ssl_config = server._get_ssl_config() + >>> ssl_config + {} """ ssl_config = {} @@ -155,6 +250,18 @@ async def _start_health_check_server(self, health_port: int) -> None: Args: health_port: Port number for the health check server. + + Examples: + Health check endpoint returns expected JSON response: + + >>> import asyncio + >>> from starlette.responses import JSONResponse + >>> from starlette.requests import Request + >>> async def health_check(_request: Request): + ... 
return JSONResponse({"status": "healthy"}) + >>> response = asyncio.run(health_check(None)) + >>> response.status_code + 200 """ # Third-Party from starlette.applications import Starlette # pylint: disable=import-outside-toplevel @@ -184,7 +291,19 @@ async def health_check(_request: Request): await server.serve() async def run_streamable_http_async(self) -> None: - """Run the server using StreamableHTTP transport with optional SSL/TLS.""" + """Run the server using StreamableHTTP transport with optional SSL/TLS. + + Examples: + Server uses configured host and port: + + >>> from mcpgateway.plugins.framework.models import MCPServerConfig + >>> config = MCPServerConfig(host="0.0.0.0", port=9000) + >>> server = SSLCapableFastMCP(server_config=config, name="HTTPServer") + >>> server.settings.host + '0.0.0.0' + >>> server.settings.port + 9000 + """ starlette_app = self.streamable_http_app() # Add health check endpoint to main app @@ -247,6 +366,20 @@ async def run() -> None: Raises: Exception: If plugin server initialization or execution fails. + + Examples: + SERVER module variable starts as None: + + >>> SERVER is None + True + + FastMCP server names are defined as constants: + + >>> from mcpgateway.plugins.framework.constants import MCP_SERVER_NAME + >>> isinstance(MCP_SERVER_NAME, str) + True + >>> len(MCP_SERVER_NAME) > 0 + True """ global SERVER # pylint: disable=global-statement diff --git a/mcpgateway/plugins/framework/external/mcp/server/server.py b/mcpgateway/plugins/framework/external/mcp/server/server.py index a6e283f1f..ace45e331 100644 --- a/mcpgateway/plugins/framework/external/mcp/server/server.py +++ b/mcpgateway/plugins/framework/external/mcp/server/server.py @@ -5,6 +5,60 @@ Authors: Fred Araujo, Teryl Taylor Module that contains plugin MCP server code to serve external plugins. 
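The _start_health_check_server helper above serves a one-route Starlette app on a dedicated port. A standalone sketch of the same shape (the runtime code builds a uvicorn.Server and awaits serve(); uvicorn.run and the port below are illustrative simplifications):

    # Sketch: minimal standalone health-check app, mirroring the helper above.
    import uvicorn
    from starlette.applications import Starlette
    from starlette.requests import Request
    from starlette.responses import JSONResponse
    from starlette.routing import Route

    async def health_check(_request: Request) -> JSONResponse:
        return JSONResponse({"status": "healthy"})

    app = Starlette(routes=[Route("/health", health_check, methods=["GET"])])

    if __name__ == "__main__":
        uvicorn.run(app, host="127.0.0.1", port=8081)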
+ +Examples: + Create an external plugin server with a configuration file: + + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + >>> server is not None + True + >>> isinstance(server._config_path, str) + True + + Get server configuration with defaults: + + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + >>> config = server.get_server_config() + >>> config.host == '127.0.0.1' + True + >>> config.port == 8000 + True + + Verify plugin manager is initialized: + + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + >>> server._plugin_manager is not None + True + >>> server._config is not None + True + + Multiple servers can be created: + + >>> server1 = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + >>> server2 = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + >>> server1._config_path != server2._config_path + True + + Configuration is loaded from file: + + >>> import asyncio + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + >>> plugins = asyncio.run(server.get_plugin_configs()) + >>> isinstance(plugins, list) + True + >>> len(plugins) >= 1 + True + + Server configuration defaults are sensible: + + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + >>> config = server.get_server_config() + >>> isinstance(config.host, str) + True + >>> isinstance(config.port, int) + True + >>> config.port > 0 + True """ # Standard @@ -58,6 +112,21 @@ async def get_plugin_configs(self) -> list[dict]: >>> plugins = asyncio.run(server.get_plugin_configs()) >>> len(plugins) > 0 True + + Returns empty list when no plugins configured: + + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + >>> server._config.plugins = None + >>> plugins = asyncio.run(server.get_plugin_configs()) + >>> plugins + [] + + Each plugin config is a dictionary: + + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + >>> plugins = asyncio.run(server.get_plugin_configs()) + >>> all(isinstance(p, dict) for p in plugins) + True """ plugins: list[dict] = [] if self._config.plugins: @@ -82,6 +151,21 @@ async def get_plugin_config(self, name: str) -> dict | None: True >>> c["name"] == "DenyListPlugin" True + + Returns None when plugin not found: + + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + >>> c = asyncio.run(server.get_plugin_config(name="NonExistentPlugin")) + >>> c is None + True + + Case-insensitive plugin name lookup: + + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + >>> c1 = asyncio.run(server.get_plugin_config(name="ReplaceBadWordsPlugin")) + >>> c2 = asyncio.run(server.get_plugin_config(name="replacebadwordsplugin")) + >>> c1 == c2 + True """ if self._config.plugins: for plug in self._config.plugins: @@ -145,12 +229,28 @@ async def initialize(self) -> bool: Returns: A boolean indicating the intialization status of the server. 
+ + Examples: + >>> import asyncio + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + >>> result = asyncio.run(server.initialize()) + >>> result + True + >>> asyncio.run(server.shutdown()) """ await self._plugin_manager.initialize() return self._plugin_manager.initialized async def shutdown(self) -> None: - """Shutdown the plugin server.""" + """Shutdown the plugin server. + + Examples: + >>> import asyncio + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + >>> asyncio.run(server.initialize()) + True + >>> asyncio.run(server.shutdown()) + """ if self._plugin_manager.initialized: await self._plugin_manager.shutdown() @@ -159,5 +259,15 @@ def get_server_config(self) -> MCPServerConfig: Returns: A server configuration including host, port, and TLS information. + + Examples: + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + >>> config = server.get_server_config() + >>> isinstance(config, MCPServerConfig) + True + >>> config.host + '127.0.0.1' + >>> config.port + 8000 """ return self._config.server_settings or MCPServerConfig.from_env() or MCPServerConfig() diff --git a/mcpgateway/plugins/framework/external/mcp/tls_utils.py b/mcpgateway/plugins/framework/external/mcp/tls_utils.py index 91b04cfb0..c56da52c1 100644 --- a/mcpgateway/plugins/framework/external/mcp/tls_utils.py +++ b/mcpgateway/plugins/framework/external/mcp/tls_utils.py @@ -9,6 +9,57 @@ This module provides utilities for creating and configuring SSL contexts for secure communication with external MCP plugin servers. It implements the certificate validation logic that is tested in test_client_certificate_validation.py. 
+ +Examples: + Create a basic SSL context with default settings: + + >>> from mcpgateway.plugins.framework.models import MCPClientTLSConfig + >>> import ssl + >>> config = MCPClientTLSConfig() + >>> ctx = create_ssl_context(config, "ExamplePlugin") + >>> ctx.verify_mode == ssl.CERT_REQUIRED + True + + Create an SSL context with hostname verification disabled: + + >>> config = MCPClientTLSConfig(verify=True, check_hostname=False) + >>> ctx = create_ssl_context(config, "NoHostnamePlugin") + >>> ctx.verify_mode == ssl.CERT_REQUIRED + True + >>> ctx.check_hostname + False + + Verify that TLS version is enforced: + + >>> config = MCPClientTLSConfig(verify=True) + >>> ctx = create_ssl_context(config, "VersionTestPlugin") + >>> ctx.minimum_version >= ssl.TLSVersion.TLSv1_2 + True + + All SSL contexts have TLS 1.2 minimum: + + >>> config1 = MCPClientTLSConfig(verify=True) + >>> config2 = MCPClientTLSConfig(verify=False) + >>> ctx1 = create_ssl_context(config1, "Plugin1") + >>> ctx2 = create_ssl_context(config2, "Plugin2") + >>> ctx1.minimum_version == ctx2.minimum_version + True + >>> ctx1.minimum_version.name + 'TLSv1_2' + + Verify mode differs based on configuration: + + >>> config_secure = MCPClientTLSConfig(verify=True) + >>> config_insecure = MCPClientTLSConfig(verify=False) + >>> ctx_secure = create_ssl_context(config_secure, "SecureP") + >>> ctx_insecure = create_ssl_context(config_insecure, "InsecureP") + >>> ctx_secure.verify_mode != ctx_insecure.verify_mode + True + >>> import ssl + >>> ctx_secure.verify_mode == ssl.CERT_REQUIRED + True + >>> ctx_insecure.verify_mode == ssl.CERT_NONE + True """ # Standard @@ -57,16 +108,75 @@ def create_ssl_context(tls_config: MCPClientTLSConfig, plugin_name: str) -> ssl. Raises: PluginError: If SSL context configuration fails - Example: - >>> tls_config = MCPClientTLSConfig( # doctest: +SKIP - ... ca_bundle="/path/to/ca.crt", - ... certfile="/path/to/client.crt", - ... keyfile="/path/to/client.key", - ... verify=True, - ... check_hostname=True - ... 
) - >>> ssl_context = create_ssl_context(tls_config, "MyPlugin") # doctest: +SKIP - >>> # Use ssl_context with httpx or other SSL connections + Examples: + Create SSL context with verification enabled (default secure mode): + + >>> from mcpgateway.plugins.framework.models import MCPClientTLSConfig + >>> tls_config = MCPClientTLSConfig(verify=True) + >>> ssl_context = create_ssl_context(tls_config, "TestPlugin") + >>> ssl_context.verify_mode == 2 # ssl.CERT_REQUIRED + True + >>> ssl_context.check_hostname + True + + Create SSL context with verification disabled (development/testing): + + >>> tls_config = MCPClientTLSConfig(verify=False, check_hostname=False) + >>> ssl_context = create_ssl_context(tls_config, "DevPlugin") + >>> ssl_context.verify_mode == 0 # ssl.CERT_NONE + True + >>> ssl_context.check_hostname + False + + Verify TLS 1.2 minimum version enforcement: + + >>> tls_config = MCPClientTLSConfig(verify=True) + >>> ssl_context = create_ssl_context(tls_config, "SecurePlugin") + >>> ssl_context.minimum_version.name + 'TLSv1_2' + + Mixed security settings (verify enabled, hostname check disabled): + + >>> tls_config = MCPClientTLSConfig(verify=True, check_hostname=False) + >>> ssl_context = create_ssl_context(tls_config, "MixedPlugin") + >>> ssl_context.verify_mode == 2 # ssl.CERT_REQUIRED + True + >>> ssl_context.check_hostname + False + + Default configuration is secure: + + >>> tls_config = MCPClientTLSConfig() + >>> ssl_context = create_ssl_context(tls_config, "DefaultPlugin") + >>> ssl_context.verify_mode == 2 # ssl.CERT_REQUIRED + True + >>> ssl_context.check_hostname + True + >>> ssl_context.minimum_version.name + 'TLSv1_2' + + Test error handling with invalid certificate file: + + >>> import tempfile + >>> import os + >>> tmp_dir = tempfile.mkdtemp() + >>> bad_cert = os.path.join(tmp_dir, "bad.pem") + >>> with open(bad_cert, 'w') as f: + ... _ = f.write("INVALID CERT") + >>> tls_config = MCPClientTLSConfig(certfile=bad_cert, keyfile=bad_cert, verify=False) + >>> try: + ... ssl_context = create_ssl_context(tls_config, "BadCertPlugin") + ... except PluginError as e: + ... "Failed to configure SSL context" in e.error.message + True + + Verify logging occurs for different configurations: + + >>> import logging + >>> tls_config = MCPClientTLSConfig(verify=False) + >>> ssl_context = create_ssl_context(tls_config, "LogTestPlugin") + >>> ssl_context is not None + True """ try: # Create SSL context with secure defaults @@ -86,7 +196,7 @@ def create_ssl_context(tls_config: MCPClientTLSConfig, plugin_name: str) -> ssl. # Disable certificate verification (not recommended for production) logger.warning(f"Certificate verification disabled for plugin '{plugin_name}'. This is not recommended for production use.") ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE # noqa: DUO122 + ssl_context.verify_mode = ssl.CERT_NONE # nosec B502 # noqa: DUO122 else: # Enable strict certificate verification (production mode) # Load CA certificate bundle for server certificate validation diff --git a/mcpgateway/routers/log_search.py b/mcpgateway/routers/log_search.py new file mode 100644 index 000000000..5023ea614 --- /dev/null +++ b/mcpgateway/routers/log_search.py @@ -0,0 +1,754 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/routers/log_search.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Log Search API Router. + +This module provides REST API endpoints for searching and analyzing structured logs, +security events, audit trails, and performance metrics. 
+""" + +# Standard +from datetime import datetime, timedelta, timezone +import logging +from typing import Any, Dict, List, Optional, Tuple + +# Third-Party +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel, Field +from sqlalchemy import and_, delete, desc, or_, select +from sqlalchemy.orm import Session +from sqlalchemy.sql import func as sa_func + +# First-Party +from mcpgateway.config import settings +from mcpgateway.db import ( + AuditTrail, + get_db, + PerformanceMetric, + SecurityEvent, + StructuredLogEntry, +) +from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission +from mcpgateway.services.log_aggregator import get_log_aggregator + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/logs", tags=["logs"]) + +MIN_PERFORMANCE_RANGE_HOURS = 5.0 / 60.0 +_DEFAULT_AGGREGATION_KEY = "5m" +_AGGREGATION_LEVELS: Dict[str, Dict[str, Any]] = { + "5m": {"minutes": 5, "label": "5-minute windows"}, + "24h": {"minutes": 24 * 60, "label": "24-hour windows"}, +} + + +def _align_to_window(dt: datetime, window_minutes: int) -> datetime: + """Align a datetime down to the nearest aggregation window boundary. + + Args: + dt: Datetime to align + window_minutes: Aggregation window size in minutes + + Returns: + datetime: Aligned datetime at window boundary + """ + timestamp = dt.astimezone(timezone.utc) + total_minutes = int(timestamp.timestamp() // 60) + aligned_minutes = (total_minutes // window_minutes) * window_minutes + return datetime.fromtimestamp(aligned_minutes * 60, tz=timezone.utc) + + +def _deduplicate_metrics(metrics: List[PerformanceMetric]) -> List[PerformanceMetric]: + """Ensure a single metric per component/operation/window. + + Args: + metrics: List of performance metrics to deduplicate + + Returns: + List[PerformanceMetric]: Deduplicated metrics sorted by window_start + """ + if not metrics: + return [] + + deduped: Dict[Tuple[str, str, datetime], PerformanceMetric] = {} + for metric in metrics: + component = metric.component or "" + operation = metric.operation_type or "" + key = (component, operation, metric.window_start) + existing = deduped.get(key) + if existing is None or metric.timestamp > existing.timestamp: + deduped[key] = metric + + return sorted(deduped.values(), key=lambda m: m.window_start, reverse=True) + + +def _aggregate_custom_windows( + aggregator, + window_minutes: int, + db: Session, +) -> None: + """Aggregate metrics using custom window duration. 
+ + Args: + aggregator: Log aggregator instance + window_minutes: Window size in minutes + db: Database session + """ + window_delta = timedelta(minutes=window_minutes) + window_duration_seconds = window_minutes * 60 + + sample_row = db.execute( + select(PerformanceMetric.window_start, PerformanceMetric.window_end) + .where(PerformanceMetric.window_duration_seconds == window_duration_seconds) + .order_by(desc(PerformanceMetric.window_start)) + .limit(1) + ).first() + + needs_rebuild = False + if sample_row: + sample_start, sample_end = sample_row + if sample_start is not None and sample_end is not None: + start_utc = sample_start if sample_start.tzinfo else sample_start.replace(tzinfo=timezone.utc) + end_utc = sample_end if sample_end.tzinfo else sample_end.replace(tzinfo=timezone.utc) + duration = int((end_utc - start_utc).total_seconds()) + if duration != window_duration_seconds: + needs_rebuild = True + aligned_start = _align_to_window(start_utc, window_minutes) + if aligned_start != start_utc: + needs_rebuild = True + + if needs_rebuild: + db.execute(delete(PerformanceMetric).where(PerformanceMetric.window_duration_seconds == window_duration_seconds)) + db.commit() + sample_row = None + + max_existing = None + if not needs_rebuild: + max_existing = db.execute(select(sa_func.max(PerformanceMetric.window_start)).where(PerformanceMetric.window_duration_seconds == window_duration_seconds)).scalar() + + if max_existing: + current_start = max_existing if max_existing.tzinfo else max_existing.replace(tzinfo=timezone.utc) + current_start = current_start + window_delta + else: + earliest_log = db.execute(select(sa_func.min(StructuredLogEntry.timestamp))).scalar() + if not earliest_log: + return + if earliest_log.tzinfo is None: + earliest_log = earliest_log.replace(tzinfo=timezone.utc) + current_start = _align_to_window(earliest_log, window_minutes) + + reference_end = datetime.now(timezone.utc) + + while current_start < reference_end: + current_end = current_start + window_delta + aggregator.aggregate_all_components( + window_start=current_start, + window_end=current_end, + db=db, + ) + current_start = current_end + + +# Request/Response Models +class LogSearchRequest(BaseModel): + """Log search request parameters.""" + + search_text: Optional[str] = Field(None, description="Text search query") + level: Optional[List[str]] = Field(None, description="Log levels to filter") + component: Optional[List[str]] = Field(None, description="Components to filter") + category: Optional[List[str]] = Field(None, description="Categories to filter") + correlation_id: Optional[str] = Field(None, description="Correlation ID to filter") + user_id: Optional[str] = Field(None, description="User ID to filter") + start_time: Optional[datetime] = Field(None, description="Start timestamp") + end_time: Optional[datetime] = Field(None, description="End timestamp") + min_duration_ms: Optional[float] = Field(None, description="Minimum duration") + max_duration_ms: Optional[float] = Field(None, description="Maximum duration") + has_error: Optional[bool] = Field(None, description="Filter for errors") + limit: int = Field(100, ge=1, le=1000, description="Maximum results") + offset: int = Field(0, ge=0, description="Result offset") + sort_by: str = Field("timestamp", description="Field to sort by") + sort_order: str = Field("desc", description="Sort order (asc/desc)") + + +class LogEntry(BaseModel): + """Log entry response model.""" + + id: str + timestamp: datetime + level: str + component: str + message: str + 
correlation_id: Optional[str] = None + user_id: Optional[str] = None + user_email: Optional[str] = None + duration_ms: Optional[float] = None + operation_type: Optional[str] = None + request_path: Optional[str] = None + request_method: Optional[str] = None + is_security_event: bool = False + error_details: Optional[Dict[str, Any]] = None + + class Config: + """Pydantic configuration.""" + + from_attributes = True + + +class LogSearchResponse(BaseModel): + """Log search response.""" + + total: int + results: List[LogEntry] + + +class CorrelationTraceRequest(BaseModel): + """Correlation trace request.""" + + correlation_id: str + + +class CorrelationTraceResponse(BaseModel): + """Correlation trace response with all related logs.""" + + correlation_id: str + total_duration_ms: Optional[float] + log_count: int + error_count: int + logs: List[LogEntry] + security_events: List[Dict[str, Any]] + audit_trails: List[Dict[str, Any]] + performance_metrics: Optional[Dict[str, Any]] + + +class SecurityEventResponse(BaseModel): + """Security event response model.""" + + id: str + timestamp: datetime + event_type: str + severity: str + category: str + user_id: Optional[str] + client_ip: str + description: str + threat_score: float + action_taken: Optional[str] + resolved: bool + + class Config: + """Pydantic configuration.""" + + from_attributes = True + + +class AuditTrailResponse(BaseModel): + """Audit trail response model.""" + + id: str + timestamp: datetime + correlation_id: Optional[str] = None + action: str + resource_type: str + resource_id: Optional[str] + resource_name: Optional[str] = None + user_id: str + user_email: Optional[str] = None + success: bool + requires_review: bool + data_classification: Optional[str] + + class Config: + """Pydantic configuration.""" + + from_attributes = True + + +class PerformanceMetricResponse(BaseModel): + """Performance metric response model.""" + + id: str + timestamp: datetime + component: str + operation_type: str + window_start: datetime + window_end: datetime + request_count: int + error_count: int + error_rate: float + avg_duration_ms: float + min_duration_ms: float + max_duration_ms: float + p50_duration_ms: float + p95_duration_ms: float + p99_duration_ms: float + + class Config: + """Pydantic configuration.""" + + from_attributes = True + + +# API Endpoints +@router.post("/search", response_model=LogSearchResponse) +@require_permission("logs:read") +async def search_logs(request: LogSearchRequest, user=Depends(get_current_user_with_permissions), db: Session = Depends(get_db)) -> LogSearchResponse: + """Search structured logs with filters and pagination. 
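With the router mounted under /api/logs as declared above, a search is a single authenticated POST whose body mirrors LogSearchRequest. A hedged client sketch (the base URL and bearer token are placeholders; a token with the logs:read permission is assumed):

    # Sketch: call POST /api/logs/search with filters and pagination.
    import httpx

    payload = {
        "search_text": "timeout",
        "level": ["ERROR", "WARNING"],
        "component": ["gateway"],
        "start_time": "2025-01-01T00:00:00Z",
        "limit": 50,
        "offset": 0,
        "sort_by": "timestamp",
        "sort_order": "desc",
    }

    async def search_logs(base_url: str, token: str) -> dict:
        async with httpx.AsyncClient(base_url=base_url) as client:
            resp = await client.post(
                "/api/logs/search",
                json=payload,
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            return resp.json()  # {"total": ..., "results": [...]}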
+ + Args: + request: Search parameters + user: Current authenticated user + db: Database session + + Returns: + Search results with pagination + + Raises: + HTTPException: On database or validation errors + """ + try: + # Build base query + stmt = select(StructuredLogEntry) + + # Apply filters + conditions = [] + + if request.search_text: + conditions.append(or_(StructuredLogEntry.message.ilike(f"%{request.search_text}%"), StructuredLogEntry.component.ilike(f"%{request.search_text}%"))) + + if request.level: + conditions.append(StructuredLogEntry.level.in_(request.level)) + + if request.component: + conditions.append(StructuredLogEntry.component.in_(request.component)) + + # Note: category field doesn't exist in StructuredLogEntry + # if request.category: + # conditions.append(StructuredLogEntry.category.in_(request.category)) + + if request.correlation_id: + conditions.append(StructuredLogEntry.correlation_id == request.correlation_id) + + if request.user_id: + conditions.append(StructuredLogEntry.user_id == request.user_id) + + if request.start_time: + conditions.append(StructuredLogEntry.timestamp >= request.start_time) + + if request.end_time: + conditions.append(StructuredLogEntry.timestamp <= request.end_time) + + if request.min_duration_ms is not None: + conditions.append(StructuredLogEntry.duration_ms >= request.min_duration_ms) + + if request.max_duration_ms is not None: + conditions.append(StructuredLogEntry.duration_ms <= request.max_duration_ms) + + if request.has_error is not None: + if request.has_error: + conditions.append(StructuredLogEntry.error_details.isnot(None)) + else: + conditions.append(StructuredLogEntry.error_details.is_(None)) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + # Get total count + count_stmt = select(sa_func.count()).select_from(stmt.subquery()) + total = db.execute(count_stmt).scalar() or 0 + + # Apply sorting + sort_column = getattr(StructuredLogEntry, request.sort_by, StructuredLogEntry.timestamp) + if request.sort_order == "desc": + stmt = stmt.order_by(desc(sort_column)) + else: + stmt = stmt.order_by(sort_column) + + # Apply pagination + stmt = stmt.limit(request.limit).offset(request.offset) + + # Execute query + results = db.execute(stmt).scalars().all() + + # Convert to response models + log_entries = [ + LogEntry( + id=str(log.id), + timestamp=log.timestamp, + level=log.level, + component=log.component, + message=log.message, + correlation_id=log.correlation_id, + user_id=log.user_id, + user_email=log.user_email, + duration_ms=log.duration_ms, + operation_type=log.operation_type, + request_path=log.request_path, + request_method=log.request_method, + is_security_event=log.is_security_event, + error_details=log.error_details, + ) + for log in results + ] + + return LogSearchResponse(total=total, results=log_entries) + + except Exception as e: + logger.error(f"Log search failed: {e}") + raise HTTPException(status_code=500, detail="Log search failed") + + +@router.get("/trace/{correlation_id}", response_model=CorrelationTraceResponse) +@require_permission("logs:read") +async def trace_correlation_id(correlation_id: str, user=Depends(get_current_user_with_permissions), db: Session = Depends(get_db)) -> CorrelationTraceResponse: + """Get all logs and events for a correlation ID. 
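search_logs above layers four steps on one SELECT: collected filter conditions, a count over the filtered subquery, ordering, and limit/offset pagination. A self-contained sketch of that SQLAlchemy 2.0 pattern against an in-memory SQLite table (the Log model is illustrative, not the gateway's StructuredLogEntry):

    # Sketch: filtered select + total count over the same filters + pagination.
    from sqlalchemy import Column, Integer, String, and_, create_engine, desc, func, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Log(Base):
        __tablename__ = "logs"
        id = Column(Integer, primary_key=True)
        level = Column(String)
        message = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as db:
        db.add_all([Log(level="ERROR", message="boom"), Log(level="INFO", message="ok")])
        db.commit()

        stmt = select(Log).where(and_(Log.level.in_(["ERROR"]), Log.message.ilike("%boo%")))
        total = db.execute(select(func.count()).select_from(stmt.subquery())).scalar() or 0
        page = db.execute(stmt.order_by(desc(Log.id)).limit(10).offset(0)).scalars().all()
        print(total, [log.message for log in page])  # 1 ['boom']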
+ + Args: + correlation_id: Correlation ID to trace + user: Current authenticated user + db: Database session + + Returns: + Complete trace of all related logs and events + + Raises: + HTTPException: On database or validation errors + """ + try: + # Get structured logs + log_stmt = select(StructuredLogEntry).where(StructuredLogEntry.correlation_id == correlation_id).order_by(StructuredLogEntry.timestamp) + + logs = db.execute(log_stmt).scalars().all() + + # Get security events + security_stmt = select(SecurityEvent).where(SecurityEvent.correlation_id == correlation_id).order_by(SecurityEvent.timestamp) + + security_events = db.execute(security_stmt).scalars().all() + + # Get audit trails + audit_stmt = select(AuditTrail).where(AuditTrail.correlation_id == correlation_id).order_by(AuditTrail.timestamp) + + audit_trails = db.execute(audit_stmt).scalars().all() + + # Calculate metrics + durations = [log.duration_ms for log in logs if log.duration_ms is not None] + total_duration = sum(durations) if durations else None + error_count = sum(1 for log in logs if log.error_details) + + # Get performance metrics (if any aggregations exist) + perf_metrics = None + if logs: + component = logs[0].component + operation = logs[0].operation_type + if component and operation: + perf_stmt = ( + select(PerformanceMetric) + .where(and_(PerformanceMetric.component == component, PerformanceMetric.operation_type == operation)) + .order_by(desc(PerformanceMetric.window_start)) + .limit(1) + ) + + perf = db.execute(perf_stmt).scalar_one_or_none() + if perf: + perf_metrics = { + "avg_duration_ms": perf.avg_duration_ms, + "p95_duration_ms": perf.p95_duration_ms, + "p99_duration_ms": perf.p99_duration_ms, + "error_rate": perf.error_rate, + } + + return CorrelationTraceResponse( + correlation_id=correlation_id, + total_duration_ms=total_duration, + log_count=len(logs), + error_count=error_count, + logs=[ + LogEntry( + id=str(log.id), + timestamp=log.timestamp, + level=log.level, + component=log.component, + message=log.message, + correlation_id=log.correlation_id, + user_id=log.user_id, + user_email=log.user_email, + duration_ms=log.duration_ms, + operation_type=log.operation_type, + request_path=log.request_path, + request_method=log.request_method, + is_security_event=log.is_security_event, + error_details=log.error_details, + ) + for log in logs + ], + security_events=[ + { + "id": str(event.id), + "timestamp": event.timestamp.isoformat(), + "event_type": event.event_type, + "severity": event.severity, + "description": event.description, + "threat_score": event.threat_score, + } + for event in security_events + ], + audit_trails=[ + { + "id": str(audit.id), + "timestamp": audit.timestamp.isoformat(), + "action": audit.action, + "resource_type": audit.resource_type, + "resource_id": audit.resource_id, + "success": audit.success, + } + for audit in audit_trails + ], + performance_metrics=perf_metrics, + ) + + except Exception as e: + logger.error(f"Correlation trace failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Correlation trace failed: {str(e)}") + + +@router.get("/security-events", response_model=List[SecurityEventResponse]) +@require_permission("security:read") +async def get_security_events( + severity: Optional[List[str]] = Query(None), + event_type: Optional[List[str]] = Query(None), + resolved: Optional[bool] = Query(None), + start_time: Optional[datetime] = Query(None), + end_time: Optional[datetime] = Query(None), + limit: int = Query(100, ge=1, le=1000), + offset: int = 
Query(0, ge=0), + user=Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> List[SecurityEventResponse]: + """Get security events with filters. + + Args: + severity: Filter by severity levels + event_type: Filter by event types + resolved: Filter by resolution status + start_time: Start timestamp + end_time: End timestamp + limit: Maximum results + offset: Result offset + user: Current authenticated user + db: Database session + + Returns: + List of security events + + Raises: + HTTPException: On database or validation errors + """ + try: + stmt = select(SecurityEvent) + + conditions = [] + if severity: + conditions.append(SecurityEvent.severity.in_(severity)) + if event_type: + conditions.append(SecurityEvent.event_type.in_(event_type)) + if resolved is not None: + conditions.append(SecurityEvent.resolved == resolved) + if start_time: + conditions.append(SecurityEvent.timestamp >= start_time) + if end_time: + conditions.append(SecurityEvent.timestamp <= end_time) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + stmt = stmt.order_by(desc(SecurityEvent.timestamp)).limit(limit).offset(offset) + + events = db.execute(stmt).scalars().all() + + return [ + SecurityEventResponse( + id=str(event.id), + timestamp=event.timestamp, + event_type=event.event_type, + severity=event.severity, + category=event.category, + user_id=event.user_id, + client_ip=event.client_ip, + description=event.description, + threat_score=event.threat_score, + action_taken=event.action_taken, + resolved=event.resolved, + ) + for event in events + ] + + except Exception as e: + logger.error(f"Security events query failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Security events query failed: {str(e)}") + + +@router.get("/audit-trails", response_model=List[AuditTrailResponse]) +@require_permission("audit:read") +async def get_audit_trails( + action: Optional[List[str]] = Query(None), + resource_type: Optional[List[str]] = Query(None), + user_id: Optional[str] = Query(None), + requires_review: Optional[bool] = Query(None), + start_time: Optional[datetime] = Query(None), + end_time: Optional[datetime] = Query(None), + limit: int = Query(100, ge=1, le=1000), + offset: int = Query(0, ge=0), + user=Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> List[AuditTrailResponse]: + """Get audit trails with filters. 
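get_security_events above takes list-valued query parameters, so a client repeats the same parameter name for each value. A hedged httpx sketch (base URL and token are placeholders; a token with the security:read permission is assumed):

    # Sketch: GET /api/logs/security-events with repeated filter parameters.
    # httpx encodes list values as ?severity=critical&severity=high.
    import httpx

    async def recent_security_events(base_url: str, token: str) -> list:
        params = {"severity": ["critical", "high"], "resolved": "false", "limit": 25}
        async with httpx.AsyncClient(base_url=base_url) as client:
            resp = await client.get(
                "/api/logs/security-events",
                params=params,
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            return resp.json()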
+ + Args: + action: Filter by actions + resource_type: Filter by resource types + user_id: Filter by user ID + requires_review: Filter by review requirement + start_time: Start timestamp + end_time: End timestamp + limit: Maximum results + offset: Result offset + user: Current authenticated user + db: Database session + + Returns: + List of audit trail entries + + Raises: + HTTPException: On database or validation errors + """ + try: + stmt = select(AuditTrail) + + conditions = [] + if action: + conditions.append(AuditTrail.action.in_(action)) + if resource_type: + conditions.append(AuditTrail.resource_type.in_(resource_type)) + if user_id: + conditions.append(AuditTrail.user_id == user_id) + if requires_review is not None: + conditions.append(AuditTrail.requires_review == requires_review) + if start_time: + conditions.append(AuditTrail.timestamp >= start_time) + if end_time: + conditions.append(AuditTrail.timestamp <= end_time) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + stmt = stmt.order_by(desc(AuditTrail.timestamp)).limit(limit).offset(offset) + + trails = db.execute(stmt).scalars().all() + + return [ + AuditTrailResponse( + id=str(trail.id), + timestamp=trail.timestamp, + correlation_id=trail.correlation_id, + action=trail.action, + resource_type=trail.resource_type, + resource_id=trail.resource_id, + resource_name=trail.resource_name, + user_id=trail.user_id, + user_email=trail.user_email, + success=trail.success, + requires_review=trail.requires_review, + data_classification=trail.data_classification, + ) + for trail in trails + ] + + except Exception as e: + logger.error(f"Audit trails query failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Audit trails query failed: {str(e)}") + + +@router.get("/performance-metrics", response_model=List[PerformanceMetricResponse]) +@require_permission("metrics:read") +async def get_performance_metrics( + component: Optional[str] = Query(None), + operation: Optional[str] = Query(None), + hours: float = Query(24.0, ge=MIN_PERFORMANCE_RANGE_HOURS, le=1000.0, description="Historical window to display"), + aggregation: str = Query(_DEFAULT_AGGREGATION_KEY, regex="^(5m|24h)$", description="Aggregation level for metrics"), + user=Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> List[PerformanceMetricResponse]: + """Get performance metrics. 
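Only the 5m and 24h aggregation levels are defined in _AGGREGATION_LEVELS above, and the metrics query selects stored rows by the matching window_duration_seconds. A small sketch of that mapping (the helper name is illustrative):

    # Sketch: aggregation level -> stored window duration in seconds.
    AGGREGATION_LEVELS = {
        "5m": {"minutes": 5, "label": "5-minute windows"},
        "24h": {"minutes": 24 * 60, "label": "24-hour windows"},
    }

    def window_seconds(aggregation: str, default: str = "5m") -> int:
        cfg = AGGREGATION_LEVELS.get(aggregation, AGGREGATION_LEVELS[default])
        return cfg["minutes"] * 60

    print(window_seconds("5m"))   # 300
    print(window_seconds("24h"))  # 86400
    print(window_seconds("1h"))   # unknown level falls back to the 5m default -> 300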
+ + Args: + component: Filter by component + operation: Filter by operation + aggregation: Aggregation level (5m, 1h, 1d, 7d) + hours: Hours of history + user: Current authenticated user + db: Database session + + Returns: + List of performance metrics + + Raises: + HTTPException: On database or validation errors + """ + try: + aggregation_config = _AGGREGATION_LEVELS.get(aggregation, _AGGREGATION_LEVELS[_DEFAULT_AGGREGATION_KEY]) + window_minutes = aggregation_config["minutes"] + window_duration_seconds = window_minutes * 60 + + if settings.metrics_aggregation_enabled: + try: + aggregator = get_log_aggregator() + if aggregation == "5m": + aggregator.backfill(hours=hours, db=db) + else: + _aggregate_custom_windows( + aggregator=aggregator, + window_minutes=window_minutes, + db=db, + ) + except Exception as agg_error: # pragma: no cover - defensive logging + logger.warning("On-demand metrics aggregation failed: %s", agg_error) + + stmt = select(PerformanceMetric).where(PerformanceMetric.window_duration_seconds == window_duration_seconds) + + if component: + stmt = stmt.where(PerformanceMetric.component == component) + if operation: + stmt = stmt.where(PerformanceMetric.operation_type == operation) + + stmt = stmt.order_by(desc(PerformanceMetric.window_start), desc(PerformanceMetric.timestamp)) + + metrics = db.execute(stmt).scalars().all() + + metrics = _deduplicate_metrics(metrics) + + return [ + PerformanceMetricResponse( + id=str(metric.id), + timestamp=metric.timestamp, + component=metric.component, + operation_type=metric.operation_type, + window_start=metric.window_start, + window_end=metric.window_end, + request_count=metric.request_count, + error_count=metric.error_count, + error_rate=metric.error_rate, + avg_duration_ms=metric.avg_duration_ms, + min_duration_ms=metric.min_duration_ms, + max_duration_ms=metric.max_duration_ms, + p50_duration_ms=metric.p50_duration_ms, + p95_duration_ms=metric.p95_duration_ms, + p99_duration_ms=metric.p99_duration_ms, + ) + for metric in metrics + ] + + except Exception as e: + logger.error(f"Performance metrics query failed: {e}") + raise HTTPException(status_code=500, detail="Performance metrics query failed") diff --git a/mcpgateway/services/a2a_service.py b/mcpgateway/services/a2a_service.py index 33f3468d0..6fa2b5774 100644 --- a/mcpgateway/services/a2a_service.py +++ b/mcpgateway/services/a2a_service.py @@ -26,8 +26,10 @@ from mcpgateway.db import A2AAgentMetric, EmailTeam from mcpgateway.schemas import A2AAgentCreate, A2AAgentMetrics, A2AAgentRead, A2AAgentUpdate from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.services.tool_service import ToolService +from mcpgateway.utils.correlation_id import get_correlation_id from mcpgateway.utils.create_slug import slugify from mcpgateway.utils.services_auth import encode_auth # ,decode_auth @@ -35,6 +37,9 @@ logging_service = LoggingService() logger = logging_service.get_logger(__name__) +# Initialize structured logger for A2A lifecycle tracking +structured_logger = get_structured_logger("a2a_service") + class A2AAgentError(Exception): """Base class for A2A agent-related errors. 
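_deduplicate_metrics above keeps a single row per (component, operation, window_start), preferring the most recently written one. The same rule, shown with plain dicts standing in for ORM rows:

    # Sketch: one metric per (component, operation, window_start), latest write wins.
    from datetime import datetime, timezone

    rows = [
        {"component": "gateway", "operation": "http_request",
         "window_start": datetime(2025, 1, 2, 10, 5, tzinfo=timezone.utc),
         "timestamp": datetime(2025, 1, 2, 10, 11, tzinfo=timezone.utc), "p95": 120.0},
        {"component": "gateway", "operation": "http_request",
         "window_start": datetime(2025, 1, 2, 10, 5, tzinfo=timezone.utc),
         "timestamp": datetime(2025, 1, 2, 10, 12, tzinfo=timezone.utc), "p95": 135.0},
    ]

    deduped = {}
    for row in rows:
        key = (row["component"], row["operation"], row["window_start"])
        current = deduped.get(key)
        if current is None or row["timestamp"] > current["timestamp"]:
            deduped[key] = row

    result = sorted(deduped.values(), key=lambda r: r["window_start"], reverse=True)
    print(len(result), result[0]["p95"])  # 1 135.0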
@@ -279,6 +284,25 @@ async def register_agent( ) logger.info(f"Registered new A2A agent: {new_agent.name} (ID: {new_agent.id})") + + # Log A2A agent registration for lifecycle tracking + structured_logger.info( + f"A2A agent '{new_agent.name}' registered successfully", + user_id=created_by, + user_email=owner_email, + team_id=team_id, + resource_type="a2a_agent", + resource_id=str(new_agent.id), + resource_action="create", + custom_fields={ + "agent_name": new_agent.name, + "agent_type": new_agent.agent_type, + "protocol_version": new_agent.protocol_version, + "visibility": visibility, + "endpoint_url": new_agent.endpoint_url, + }, + ) + return self._db_to_schema(db=db, db_agent=new_agent) except A2AAgentNameConflictError as ie: @@ -716,6 +740,21 @@ async def toggle_agent_status(self, db: Session, agent_id: str, activate: bool, status = "activated" if activate else "deactivated" logger.info(f"A2A agent {status}: {agent.name} (ID: {agent.id})") + structured_logger.log( + level="INFO", + message=f"A2A agent {status}", + event_type="a2a_agent_status_changed", + component="a2a_service", + user_email=user_email, + resource_type="a2a_agent", + resource_id=str(agent.id), + custom_fields={ + "agent_name": agent.name, + "enabled": agent.enabled, + "reachable": agent.reachable, + }, + ) + return self._db_to_schema(db=db, db_agent=agent) async def delete_agent(self, db: Session, agent_id: str, user_email: Optional[str] = None) -> None: @@ -751,11 +790,31 @@ async def delete_agent(self, db: Session, agent_id: str, user_email: Optional[st db.commit() logger.info(f"Deleted A2A agent: {agent_name} (ID: {agent_id})") + + structured_logger.log( + level="INFO", + message="A2A agent deleted", + event_type="a2a_agent_deleted", + component="a2a_service", + user_email=user_email, + resource_type="a2a_agent", + resource_id=str(agent_id), + custom_fields={"agent_name": agent_name}, + ) except PermissionError: db.rollback() raise - async def invoke_agent(self, db: Session, agent_name: str, parameters: Dict[str, Any], interaction_type: str = "query") -> Dict[str, Any]: + async def invoke_agent( + self, + db: Session, + agent_name: str, + parameters: Dict[str, Any], + interaction_type: str = "query", + *, + user_id: Optional[str] = None, + user_email: Optional[str] = None, + ) -> Dict[str, Any]: """Invoke an A2A agent. Args: @@ -763,6 +822,8 @@ async def invoke_agent(self, db: Session, agent_name: str, parameters: Dict[str, agent_name: Name of the agent to invoke. parameters: Parameters for the interaction. interaction_type: Type of interaction. + user_id: Identifier of the user initiating the call. + user_email: Email of the user initiating the call. Returns: Agent response. 
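The new user_id and user_email parameters on invoke_agent sit after a bare *, so existing positional call sites keep working and the identity fields can only be passed by name. A stripped-down sketch of that signature shape (the body is a placeholder, not the service logic):

    # Sketch: keyword-only identity parameters behind a bare *.
    def invoke_agent(agent_name, parameters, interaction_type="query", *,
                     user_id=None, user_email=None):
        return {"agent": agent_name, "user": user_email or "anonymous"}

    print(invoke_agent("weather", {"q": "NYC"}))                                 # old call style still works
    print(invoke_agent("weather", {"q": "NYC"}, user_email="ops@example.com"))   # new fields passed by name
    # invoke_agent("weather", {"q": "NYC"}, "query", "u1")  # TypeError: too many positional arguments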
@@ -803,13 +864,64 @@ async def invoke_agent(self, db: Session, agent_name: str, parameters: Dict[str, if token_value: headers["Authorization"] = f"Bearer {token_value}" + # Add correlation ID to outbound headers for distributed tracing + correlation_id = get_correlation_id() + if correlation_id: + headers["X-Correlation-ID"] = correlation_id + + # Log A2A external call start + call_start_time = datetime.now(timezone.utc) + structured_logger.log( + level="INFO", + message=f"A2A external call started: {agent_name}", + component="a2a_service", + user_id=user_id, + user_email=user_email, + correlation_id=correlation_id, + metadata={ + "event": "a2a_call_started", + "agent_name": agent_name, + "agent_id": agent.id, + "endpoint_url": agent.endpoint_url, + "interaction_type": interaction_type, + "protocol_version": agent.protocol_version, + }, + ) + http_response = await client.post(agent.endpoint_url, json=request_data, headers=headers) + call_duration_ms = (datetime.now(timezone.utc) - call_start_time).total_seconds() * 1000 if http_response.status_code == 200: response = http_response.json() success = True + + # Log successful A2A call + structured_logger.log( + level="INFO", + message=f"A2A external call completed: {agent_name}", + component="a2a_service", + user_id=user_id, + user_email=user_email, + correlation_id=correlation_id, + duration_ms=call_duration_ms, + metadata={"event": "a2a_call_completed", "agent_name": agent_name, "agent_id": agent.id, "status_code": http_response.status_code, "success": True}, + ) else: error_message = f"HTTP {http_response.status_code}: {http_response.text}" + + # Log failed A2A call + structured_logger.log( + level="ERROR", + message=f"A2A external call failed: {agent_name}", + component="a2a_service", + user_id=user_id, + user_email=user_email, + correlation_id=correlation_id, + duration_ms=call_duration_ms, + error_details={"error_type": "A2AHTTPError", "error_message": error_message}, + metadata={"event": "a2a_call_failed", "agent_name": agent_name, "agent_id": agent.id, "status_code": http_response.status_code}, + ) + raise A2AAgentError(error_message) except Exception as e: diff --git a/mcpgateway/services/audit_trail_service.py b/mcpgateway/services/audit_trail_service.py new file mode 100644 index 000000000..3d9023bfe --- /dev/null +++ b/mcpgateway/services/audit_trail_service.py @@ -0,0 +1,445 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/audit_trail_service.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Audit Trail Service. + +This module provides audit trail management for CRUD operations, +data access tracking, and compliance logging. 
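The invoke_agent changes above forward the active correlation ID as an X-Correlation-ID header on the outbound POST and time the call for the structured log records. A self-contained sketch of those two pieces (the endpoint URL is a placeholder, and the correlation ID is passed in explicitly here instead of being read from context):

    # Sketch: propagate the correlation ID downstream and measure call duration.
    from datetime import datetime, timezone
    from typing import Optional
    import httpx

    async def call_agent(endpoint_url: str, payload: dict, correlation_id: Optional[str]) -> dict:
        headers = {"Content-Type": "application/json"}
        if correlation_id:
            headers["X-Correlation-ID"] = correlation_id  # distributed tracing across services

        started = datetime.now(timezone.utc)
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(endpoint_url, json=payload, headers=headers)
        duration_ms = (datetime.now(timezone.utc) - started).total_seconds() * 1000

        return {"status": resp.status_code, "duration_ms": duration_ms}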
+""" + +# Standard +from datetime import datetime, timezone +from enum import Enum +import logging +from typing import Any, Dict, Optional + +# Third-Party +from sqlalchemy import select +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.db import AuditTrail, SessionLocal +from mcpgateway.utils.correlation_id import get_or_generate_correlation_id + +logger = logging.getLogger(__name__) + + +class AuditAction(str, Enum): + """Audit trail action types.""" + + CREATE = "CREATE" + READ = "READ" + UPDATE = "UPDATE" + DELETE = "DELETE" + EXECUTE = "EXECUTE" + ACCESS = "ACCESS" + EXPORT = "EXPORT" + IMPORT = "IMPORT" + + +class DataClassification(str, Enum): + """Data classification levels.""" + + PUBLIC = "public" + INTERNAL = "internal" + CONFIDENTIAL = "confidential" + RESTRICTED = "restricted" + + +REVIEW_REQUIRED_ACTIONS = { + "delete_server", + "delete_tool", + "delete_resource", + "delete_gateway", + "update_sensitive_config", + "bulk_delete", +} + + +class AuditTrailService: + """Service for managing audit trails and compliance logging. + + Provides comprehensive audit trail management with data classification, + change tracking, and compliance reporting capabilities. + """ + + def __init__(self): + """Initialize audit trail service.""" + + def log_action( # pylint: disable=too-many-positional-arguments + self, + action: str, + resource_type: str, + resource_id: str, + user_id: str, + user_email: Optional[str] = None, + team_id: Optional[str] = None, + resource_name: Optional[str] = None, + client_ip: Optional[str] = None, + user_agent: Optional[str] = None, + request_path: Optional[str] = None, + request_method: Optional[str] = None, + old_values: Optional[Dict[str, Any]] = None, + new_values: Optional[Dict[str, Any]] = None, + changes: Optional[Dict[str, Any]] = None, + data_classification: Optional[str] = None, + requires_review: Optional[bool] = None, + success: bool = True, + error_message: Optional[str] = None, + context: Optional[Dict[str, Any]] = None, + details: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + db: Optional[Session] = None, + ) -> Optional[AuditTrail]: + """Log an audit trail entry. + + Args: + action: Action performed (CREATE, READ, UPDATE, DELETE, etc.) + resource_type: Type of resource (tool, server, prompt, etc.) 
+ resource_id: ID of the resource + user_id: User who performed the action + user_email: User's email address + team_id: Team ID if applicable + resource_name: Name of the resource + client_ip: Client IP address + user_agent: Client user agent + request_path: HTTP request path + request_method: HTTP request method + old_values: Previous values before change + new_values: New values after change + changes: Specific changes made + data_classification: Data classification level + requires_review: Whether this action requires review (None = auto) + success: Whether the action succeeded + error_message: Error message if failed + context: Additional context + details: Extra key/value payload (stored under context.details) + metadata: Extra metadata payload (stored under context.metadata) + db: Optional database session + + Returns: + Created AuditTrail entry or None if logging disabled + """ + correlation_id = get_or_generate_correlation_id() + + # Use provided session or create new one + close_db = False + if db is None: + db = SessionLocal() + close_db = True + + try: + context_payload: Dict[str, Any] = dict(context) if context else {} + if details: + context_payload["details"] = details + if metadata: + context_payload["metadata"] = metadata + context_value = context_payload if context_payload else None + + requires_review_flag = self._determine_requires_review( + action=action, + data_classification=data_classification, + requires_review_param=requires_review, + ) + + # Create audit trail entry + audit_entry = AuditTrail( + timestamp=datetime.now(timezone.utc), + correlation_id=correlation_id, + action=action, + resource_type=resource_type, + resource_id=resource_id, + resource_name=resource_name, + user_id=user_id, + user_email=user_email, + team_id=team_id, + client_ip=client_ip, + user_agent=user_agent, + request_path=request_path, + request_method=request_method, + old_values=old_values, + new_values=new_values, + changes=changes, + data_classification=data_classification, + requires_review=requires_review_flag, + success=success, + error_message=error_message, + context=context_value, + ) + + db.add(audit_entry) + db.commit() + db.refresh(audit_entry) + + logger.debug( + f"Audit trail logged: {action} {resource_type}/{resource_id} by {user_id}", + extra={"correlation_id": correlation_id, "action": action, "resource_type": resource_type, "resource_id": resource_id, "user_id": user_id, "success": success}, + ) + + return audit_entry + + except Exception as e: + logger.error(f"Failed to log audit trail: {e}", exc_info=True, extra={"correlation_id": correlation_id, "action": action, "resource_type": resource_type, "resource_id": resource_id}) + if close_db: + db.rollback() + return None + + finally: + if close_db: + db.close() + + def _determine_requires_review( + self, + action: Optional[str], + data_classification: Optional[str], + requires_review_param: Optional[bool], + ) -> bool: + """Resolve whether an audit entry should require review. 
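log_action above accepts an optional Session: when the caller supplies one it is reused as-is, otherwise the service opens its own SessionLocal and is solely responsible for rollback and close. A dependency-free sketch of that ownership rule (session_factory and write_entry are stand-ins for SessionLocal and the audit-row insert):

    # Sketch: borrowed vs. owned database session, as in log_action above.
    def log_with_optional_session(write_entry, session_factory, db=None):
        close_db = db is None
        if db is None:
            db = session_factory()      # we own this session
        try:
            entry = write_entry(db)
            db.commit()
            return entry
        except Exception:
            if close_db:
                db.rollback()           # only roll back sessions we created
            return None
        finally:
            if close_db:
                db.close()              # never close a caller-provided session

    class FakeSession:
        def commit(self): print("commit")
        def rollback(self): print("rollback")
        def close(self): print("close")

    print(log_with_optional_session(lambda db: "entry", FakeSession))  # commit, close, entry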
+ + Args: + action: Action being performed + data_classification: Data classification level + requires_review_param: Explicit review requirement + + Returns: + bool: Whether the audit entry requires review + """ + if requires_review_param is not None: + return requires_review_param + + if data_classification in {DataClassification.CONFIDENTIAL.value, DataClassification.RESTRICTED.value}: + return True + + normalized_action = (action or "").lower() + if normalized_action in REVIEW_REQUIRED_ACTIONS: + return True + + return False + + def log_crud_operation( + self, + operation: str, + resource_type: str, + resource_id: str, + user_id: str, + user_email: Optional[str] = None, + team_id: Optional[str] = None, + resource_name: Optional[str] = None, + old_values: Optional[Dict[str, Any]] = None, + new_values: Optional[Dict[str, Any]] = None, + success: bool = True, + error_message: Optional[str] = None, + db: Optional[Session] = None, + **kwargs, + ) -> Optional[AuditTrail]: + """Log a CRUD operation with change tracking. + + Args: + operation: CRUD operation (CREATE, READ, UPDATE, DELETE) + resource_type: Type of resource + resource_id: ID of the resource + user_id: User who performed the operation + user_email: User's email + team_id: Team ID if applicable + resource_name: Name of the resource + old_values: Previous values (for UPDATE/DELETE) + new_values: New values (for CREATE/UPDATE) + success: Whether the operation succeeded + error_message: Error message if failed + db: Optional database session + **kwargs: Additional arguments passed to log_action + + Returns: + Created AuditTrail entry + """ + # Calculate changes for UPDATE operations + changes = None + if operation == "UPDATE" and old_values and new_values: + changes = {} + for key in set(old_values.keys()) | set(new_values.keys()): + old_val = old_values.get(key) + new_val = new_values.get(key) + if old_val != new_val: + changes[key] = {"old": old_val, "new": new_val} + + # Determine data classification based on resource type + data_classification = None + if resource_type in ["user", "team", "token", "credential"]: + data_classification = DataClassification.CONFIDENTIAL.value + elif resource_type in ["tool", "server", "prompt", "resource"]: + data_classification = DataClassification.INTERNAL.value + + # Determine if review is required + requires_review = False + if data_classification == DataClassification.CONFIDENTIAL.value: + requires_review = True + if operation == "DELETE" and resource_type in ["tool", "server", "gateway"]: + requires_review = True + + return self.log_action( + action=operation, + resource_type=resource_type, + resource_id=resource_id, + user_id=user_id, + user_email=user_email, + team_id=team_id, + resource_name=resource_name, + old_values=old_values, + new_values=new_values, + changes=changes, + data_classification=data_classification, + requires_review=requires_review, + success=success, + error_message=error_message, + db=db, + **kwargs, + ) + + def log_data_access( + self, + resource_type: str, + resource_id: str, + user_id: str, + access_type: str = "READ", + user_email: Optional[str] = None, + team_id: Optional[str] = None, + resource_name: Optional[str] = None, + data_classification: Optional[str] = None, + db: Optional[Session] = None, + **kwargs, + ) -> Optional[AuditTrail]: + """Log data access for compliance tracking. + + Args: + resource_type: Type of resource accessed + resource_id: ID of the resource + user_id: User who accessed the data + access_type: Type of access (READ, EXPORT, etc.) 
+ user_email: User's email + team_id: Team ID if applicable + resource_name: Name of the resource + data_classification: Data classification level + db: Optional database session + **kwargs: Additional arguments passed to log_action + + Returns: + Created AuditTrail entry + """ + requires_review = data_classification in [DataClassification.CONFIDENTIAL.value, DataClassification.RESTRICTED.value] + + return self.log_action( + action=access_type, + resource_type=resource_type, + resource_id=resource_id, + user_id=user_id, + user_email=user_email, + team_id=team_id, + resource_name=resource_name, + data_classification=data_classification, + requires_review=requires_review, + success=True, + db=db, + **kwargs, + ) + + def log_audit( + self, user_id: str, resource_type: str, resource_id: str, action: str, user_email: Optional[str] = None, description: Optional[str] = None, db: Optional[Session] = None, **kwargs + ) -> Optional[AuditTrail]: + """Convenience method for simple audit logging. + + Args: + user_id: User who performed the action + resource_type: Type of resource + resource_id: ID of the resource + action: Action performed + user_email: User's email + description: Description of the action + db: Optional database session + **kwargs: Additional arguments passed to log_action + + Returns: + Created AuditTrail entry + """ + # Build context if description provided + context = kwargs.pop("context", {}) + if description: + context["description"] = description + + return self.log_action(action=action, resource_type=resource_type, resource_id=resource_id, user_id=user_id, user_email=user_email, context=context if context else None, db=db, **kwargs) + + def get_audit_trail( + self, + resource_type: Optional[str] = None, + resource_id: Optional[str] = None, + user_id: Optional[str] = None, + action: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + limit: int = 100, + offset: int = 0, + db: Optional[Session] = None, + ) -> list[AuditTrail]: + """Query audit trail entries. + + Args: + resource_type: Filter by resource type + resource_id: Filter by resource ID + user_id: Filter by user ID + action: Filter by action + start_time: Filter by start time + end_time: Filter by end time + limit: Maximum number of results + offset: Offset for pagination + db: Optional database session + + Returns: + List of AuditTrail entries + """ + close_db = False + if db is None: + db = SessionLocal() + close_db = True + + try: + query = select(AuditTrail) + + if resource_type: + query = query.where(AuditTrail.resource_type == resource_type) + if resource_id: + query = query.where(AuditTrail.resource_id == resource_id) + if user_id: + query = query.where(AuditTrail.user_id == user_id) + if action: + query = query.where(AuditTrail.action == action) + if start_time: + query = query.where(AuditTrail.timestamp >= start_time) + if end_time: + query = query.where(AuditTrail.timestamp <= end_time) + + query = query.order_by(AuditTrail.timestamp.desc()) + query = query.limit(limit).offset(offset) + + result = db.execute(query) + return list(result.scalars().all()) + + finally: + if close_db: + db.close() + + +# Singleton instance +_audit_trail_service: Optional[AuditTrailService] = None + + +def get_audit_trail_service() -> AuditTrailService: + """Get or create the singleton audit trail service instance. 
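For UPDATE audits, log_crud_operation above reduces old_values and new_values to a per-field changes map. A worked example with plain dicts:

    # Sketch: per-field change diff between old and new values.
    old_values = {"name": "fast-time", "url": "http://old.example/mcp", "enabled": True}
    new_values = {"name": "fast-time", "url": "http://new.example/mcp", "enabled": False}

    changes = {}
    for key in set(old_values) | set(new_values):
        old_val, new_val = old_values.get(key), new_values.get(key)
        if old_val != new_val:
            changes[key] = {"old": old_val, "new": new_val}

    print(changes)  # url and enabled differ; the unchanged name field is omitted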
+ + Returns: + AuditTrailService instance + """ + global _audit_trail_service # pylint: disable=global-statement + if _audit_trail_service is None: + _audit_trail_service = AuditTrailService() + return _audit_trail_service diff --git a/mcpgateway/services/export_service.py b/mcpgateway/services/export_service.py index d5806dd59..78a5a5763 100644 --- a/mcpgateway/services/export_service.py +++ b/mcpgateway/services/export_service.py @@ -399,7 +399,7 @@ async def _export_servers(self, db: Session, tags: Optional[List[str]], include_ "websocket_endpoint": f"{root_path}/servers/{server.id}/ws", "jsonrpc_endpoint": f"{root_path}/servers/{server.id}/jsonrpc", "capabilities": {"tools": {"list_changed": True}, "prompts": {"list_changed": True}}, - "is_active": server.is_active, + "is_active": getattr(server, "enabled", getattr(server, "is_active", False)), "tags": server.tags or [], } @@ -469,7 +469,7 @@ async def _export_resources(self, db: Session, tags: Optional[List[str]], includ "description": resource.description, "mime_type": resource.mime_type, "tags": resource.tags or [], - "is_active": resource.is_active, + "is_active": getattr(resource, "enabled", getattr(resource, "is_active", False)), "last_modified": resource.updated_at.isoformat() if resource.updated_at else None, } diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index 9df0ae8c2..9d47fea1e 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -80,11 +80,13 @@ from mcpgateway.db import Tool as DbTool from mcpgateway.observability import create_span from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, PromptCreate, ResourceCreate, ToolCreate -from mcpgateway.services.event_service import EventService # logging.getLogger("httpx").setLevel(logging.WARNING) # Disables httpx logs for regular health checks +from mcpgateway.services.audit_trail_service import get_audit_trail_service +from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.services.tool_service import ToolService from mcpgateway.utils.create_slug import slugify @@ -98,6 +100,10 @@ logging_service = LoggingService() logger = logging_service.get_logger(__name__) +# Initialize structured logger and audit trail for gateway operations +structured_logger = get_structured_logger("gateway_service") +audit_trail = get_audit_trail_service() + GW_FAILURE_THRESHOLD = settings.unhealthy_threshold GW_HEALTH_CHECK_INTERVAL = settings.health_check_interval @@ -809,6 +815,54 @@ async def register_gateway( # Notify subscribers await self._notify_gateway_added(db_gateway) + logger.info(f"Registered gateway: {gateway.name}") + + # Structured logging: Audit trail for gateway creation + audit_trail.log_action( + user_id=created_by or "system", + action="create_gateway", + resource_type="gateway", + resource_id=str(db_gateway.id), + resource_name=db_gateway.name, + user_email=owner_email, + team_id=team_id, + client_ip=created_from_ip, + user_agent=created_user_agent, + new_values={ + "name": db_gateway.name, + "url": db_gateway.url, + "visibility": visibility, + "transport": db_gateway.transport, + "tools_count": len(tools), + "resources_count": len(db_resources), + "prompts_count": len(db_prompts), 
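As a reference for other call sites, a minimal sketch of calling the new audit trail service directly, using only the method signatures defined above (the resource and user identifiers are hypothetical examples; get_audit_trail opens its own session when db is omitted):

    # First-Party
    from mcpgateway.services.audit_trail_service import get_audit_trail_service

    audit = get_audit_trail_service()

    # Record an update with before/after values; log_crud_operation derives the
    # per-field change set and marks confidential resource types for review.
    audit.log_crud_operation(
        operation="UPDATE",
        resource_type="tool",
        resource_id="tool-123",           # hypothetical ID
        user_id="alice@example.com",      # hypothetical user
        old_values={"description": "old"},
        new_values={"description": "new"},
    )

    # Read the trail back for that resource.
    entries = audit.get_audit_trail(resource_type="tool", resource_id="tool-123", limit=10)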
+ }, + context={ + "created_via": created_via, + }, + db=db, + ) + + # Structured logging: Log successful gateway creation + structured_logger.log( + level="INFO", + message="Gateway created successfully", + event_type="gateway_created", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + team_id=team_id, + resource_type="gateway", + resource_id=str(db_gateway.id), + custom_fields={ + "gateway_name": db_gateway.name, + "gateway_url": normalized_url, + "visibility": visibility, + "transport": db_gateway.transport, + }, + db=db, + ) + # Add team name for response db_gateway.team = self._get_team_name(db, db_gateway.team_id) return GatewayRead.model_validate(self._prepare_gateway_for_read(db_gateway)).masked() @@ -816,31 +870,101 @@ async def register_gateway( if TYPE_CHECKING: ge: ExceptionGroup[GatewayConnectionError] logger.error(f"GatewayConnectionError in group: {ge.exceptions}") + + structured_logger.log( + level="ERROR", + message="Gateway creation failed due to connection error", + event_type="gateway_creation_failed", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + error=ge.exceptions[0], + custom_fields={"gateway_name": gateway.name, "gateway_url": str(gateway.url)}, + db=db, + ) raise ge.exceptions[0] except* GatewayNameConflictError as gnce: # pragma: no mutate if TYPE_CHECKING: gnce: ExceptionGroup[GatewayNameConflictError] logger.error(f"GatewayNameConflictError in group: {gnce.exceptions}") + + structured_logger.log( + level="WARNING", + message="Gateway creation failed due to name conflict", + event_type="gateway_name_conflict", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + custom_fields={"gateway_name": gateway.name, "visibility": visibility}, + db=db, + ) raise gnce.exceptions[0] except* GatewayDuplicateConflictError as guce: # pragma: no mutate if TYPE_CHECKING: guce: ExceptionGroup[GatewayDuplicateConflictError] logger.error(f"GatewayDuplicateConflictError in group: {guce.exceptions}") + + structured_logger.log( + level="WARNING", + message="Gateway creation failed due to duplicate", + event_type="gateway_duplicate_conflict", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + custom_fields={"gateway_name": gateway.name}, + db=db, + ) raise guce.exceptions[0] except* ValueError as ve: # pragma: no mutate if TYPE_CHECKING: ve: ExceptionGroup[ValueError] logger.error(f"ValueErrors in group: {ve.exceptions}") + + structured_logger.log( + level="ERROR", + message="Gateway creation failed due to validation error", + event_type="gateway_creation_failed", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + error=ve.exceptions[0], + custom_fields={"gateway_name": gateway.name}, + db=db, + ) raise ve.exceptions[0] except* RuntimeError as re: # pragma: no mutate if TYPE_CHECKING: re: ExceptionGroup[RuntimeError] logger.error(f"RuntimeErrors in group: {re.exceptions}") + + structured_logger.log( + level="ERROR", + message="Gateway creation failed due to runtime error", + event_type="gateway_creation_failed", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + error=re.exceptions[0], + custom_fields={"gateway_name": gateway.name}, + db=db, + ) raise re.exceptions[0] except* IntegrityError as ie: # pragma: no mutate if TYPE_CHECKING: ie: ExceptionGroup[IntegrityError] logger.error(f"IntegrityErrors in group: {ie.exceptions}") + + structured_logger.log( + level="ERROR", + message="Gateway creation failed 
due to database integrity error", + event_type="gateway_creation_failed", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + error=ie.exceptions[0], + custom_fields={"gateway_name": gateway.name}, + db=db, + ) raise ie.exceptions[0] except* BaseException as other: # catches every other sub-exception # pragma: no mutate if TYPE_CHECKING: @@ -1461,6 +1585,47 @@ async def update_gateway( await self._notify_gateway_updated(gateway) logger.info(f"Updated gateway: {gateway.name}") + + # Structured logging: Audit trail for gateway update + audit_trail.log_action( + user_id=user_email or modified_by or "system", + action="update_gateway", + resource_type="gateway", + resource_id=str(gateway.id), + resource_name=gateway.name, + user_email=user_email, + team_id=gateway.team_id, + client_ip=modified_from_ip, + user_agent=modified_user_agent, + new_values={ + "name": gateway.name, + "url": gateway.url, + "version": gateway.version, + }, + context={ + "modified_via": modified_via, + }, + db=db, + ) + + # Structured logging: Log successful gateway update + structured_logger.log( + level="INFO", + message="Gateway updated successfully", + event_type="gateway_updated", + component="gateway_service", + user_id=modified_by, + user_email=user_email, + team_id=gateway.team_id, + resource_type="gateway", + resource_id=str(gateway.id), + custom_fields={ + "gateway_name": gateway.name, + "version": gateway.version, + }, + db=db, + ) + gateway.team = self._get_team_name(db, getattr(gateway, "team_id", None)) return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)) @@ -1468,18 +1633,78 @@ async def update_gateway( return None except GatewayNameConflictError as ge: logger.error(f"GatewayNameConflictError in group: {ge}") + + structured_logger.log( + level="WARNING", + message="Gateway update failed due to name conflict", + event_type="gateway_name_conflict", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=ge, + db=db, + ) raise ge except GatewayNotFoundError as gnfe: logger.error(f"GatewayNotFoundError: {gnfe}") + + structured_logger.log( + level="ERROR", + message="Gateway update failed - gateway not found", + event_type="gateway_not_found", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=gnfe, + db=db, + ) raise gnfe except IntegrityError as ie: logger.error(f"IntegrityErrors in group: {ie}") + + structured_logger.log( + level="ERROR", + message="Gateway update failed due to database integrity error", + event_type="gateway_update_failed", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=ie, + db=db, + ) raise ie - except PermissionError: + except PermissionError as pe: db.rollback() + + structured_logger.log( + level="WARNING", + message="Gateway update failed due to permission error", + event_type="gateway_update_permission_denied", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=pe, + db=db, + ) raise except Exception as e: db.rollback() + + structured_logger.log( + level="ERROR", + message="Gateway update failed", + event_type="gateway_update_failed", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=e, + db=db, + ) raise GatewayError(f"Failed to update gateway: {str(e)}") async def get_gateway(self, db: Session, gateway_id: 
str, include_inactive: bool = True) -> GatewayRead: @@ -1542,6 +1767,24 @@ async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool if gateway.enabled or include_inactive: gateway.team = self._get_team_name(db, getattr(gateway, "team_id", None)) + + # Structured logging: Log gateway view + structured_logger.log( + level="INFO", + message="Gateway retrieved successfully", + event_type="gateway_viewed", + component="gateway_service", + team_id=getattr(gateway, "team_id", None), + resource_type="gateway", + resource_id=str(gateway.id), + custom_fields={ + "gateway_name": gateway.name, + "gateway_url": gateway.url, + "include_inactive": include_inactive, + }, + db=db, + ) + return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked() raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") @@ -1689,13 +1932,76 @@ async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bo logger.info(f"Gateway status: {gateway.name} - {'enabled' if activate else 'disabled'} and {'accessible' if reachable else 'inaccessible'}") + # Structured logging: Audit trail for gateway status toggle + audit_trail.log_action( + user_id=user_email or "system", + action="toggle_gateway_status", + resource_type="gateway", + resource_id=str(gateway.id), + resource_name=gateway.name, + user_email=user_email, + team_id=gateway.team_id, + new_values={ + "enabled": gateway.enabled, + "reachable": gateway.reachable, + }, + context={ + "action": "activate" if activate else "deactivate", + "only_update_reachable": only_update_reachable, + }, + db=db, + ) + + # Structured logging: Log successful gateway status toggle + structured_logger.log( + level="INFO", + message=f"Gateway {'activated' if activate else 'deactivated'} successfully", + event_type="gateway_status_toggled", + component="gateway_service", + user_email=user_email, + team_id=gateway.team_id, + resource_type="gateway", + resource_id=str(gateway.id), + custom_fields={ + "gateway_name": gateway.name, + "enabled": gateway.enabled, + "reachable": gateway.reachable, + }, + db=db, + ) + gateway.team = self._get_team_name(db, getattr(gateway, "team_id", None)) return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked() except PermissionError as e: + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Gateway status toggle failed due to permission error", + event_type="gateway_toggle_permission_denied", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=e, + db=db, + ) raise e except Exception as e: db.rollback() + + # Structured logging: Log generic gateway status toggle failure + structured_logger.log( + level="ERROR", + message="Gateway status toggle failed", + event_type="gateway_toggle_failed", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=e, + db=db, + ) raise GatewayError(f"Failed to toggle gateway status: {str(e)}") async def _notify_gateway_updated(self, gateway: DbGateway) -> None: @@ -1765,6 +2071,8 @@ async def delete_gateway(self, db: Session, gateway_id: str, user_email: Optiona # Store gateway info for notification before deletion gateway_info = {"id": gateway.id, "name": gateway.name, "url": gateway.url} + gateway_name = gateway.name + gateway_team_id = gateway.team_id # Hard delete gateway db.delete(gateway) @@ -1778,11 +2086,70 @@ async def delete_gateway(self, db: Session, gateway_id: 
str, user_email: Optiona logger.info(f"Permanently deleted gateway: {gateway.name}") - except PermissionError: + # Structured logging: Audit trail for gateway deletion + audit_trail.log_action( + user_id=user_email or "system", + action="delete_gateway", + resource_type="gateway", + resource_id=str(gateway_info["id"]), + resource_name=gateway_name, + user_email=user_email, + team_id=gateway_team_id, + old_values={ + "name": gateway_name, + "url": gateway_info["url"], + }, + db=db, + ) + + # Structured logging: Log successful gateway deletion + structured_logger.log( + level="INFO", + message="Gateway deleted successfully", + event_type="gateway_deleted", + component="gateway_service", + user_email=user_email, + team_id=gateway_team_id, + resource_type="gateway", + resource_id=str(gateway_info["id"]), + custom_fields={ + "gateway_name": gateway_name, + "gateway_url": gateway_info["url"], + }, + db=db, + ) + + except PermissionError as pe: db.rollback() + + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Gateway deletion failed due to permission error", + event_type="gateway_delete_permission_denied", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=pe, + db=db, + ) raise except Exception as e: db.rollback() + + # Structured logging: Log generic gateway deletion failure + structured_logger.log( + level="ERROR", + message="Gateway deletion failed", + event_type="gateway_deletion_failed", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=e, + db=db, + ) raise GatewayError(f"Failed to delete gateway: {str(e)}") async def forward_request( @@ -3354,26 +3721,26 @@ def get_httpx_client_factory( except Exception as e: logger.warning(f"Failed to fetch resources: {e}") - # resource template URI - try: - response_templates = await session.list_resource_templates() - raw_resources_templates = response_templates.resourceTemplates - resource_templates = [] - for resource_template in raw_resources_templates: - resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True) + # resource template URI + try: + response_templates = await session.list_resource_templates() + raw_resources_templates = response_templates.resourceTemplates + resource_templates = [] + for resource_template in raw_resources_templates: + resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True) - if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"): - resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"]) - resource_template_data["uri"] = str(resource_template_data["uriTemplate"]) + if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"): + resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"]) + resource_template_data["uri"] = str(resource_template_data["uriTemplate"]) - if "content" not in resource_template_data: - resource_template_data["content"] = "" + if "content" not in resource_template_data: + resource_template_data["content"] = "" - resources.append(ResourceCreate.model_validate(resource_template_data)) - resource_templates.append(ResourceCreate.model_validate(resource_template_data)) - logger.info(f"Fetched {len(raw_resources_templates)} resource templates from gateway") - except Exception as ei: - 
logger.warning(f"Failed to fetch resource templates: {ei}") + resources.append(ResourceCreate.model_validate(resource_template_data)) + resource_templates.append(ResourceCreate.model_validate(resource_template_data)) + logger.info(f"Fetched {len(raw_resources_templates)} resource templates from gateway") + except Exception as ei: + logger.warning(f"Failed to fetch resource templates: {ei}") # Fetch prompts if supported prompts = [] diff --git a/mcpgateway/services/log_aggregator.py b/mcpgateway/services/log_aggregator.py new file mode 100644 index 000000000..2d7f0f293 --- /dev/null +++ b/mcpgateway/services/log_aggregator.py @@ -0,0 +1,526 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/log_aggregator.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Log Aggregation Service. + +This module provides aggregation of performance metrics from structured logs +into time-windowed statistics for analysis and monitoring. +""" + +# Standard +from datetime import datetime, timedelta, timezone +import logging +import math +import statistics +from typing import Any, Dict, List, Optional, Tuple + +# Third-Party +from sqlalchemy import and_, select +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.config import settings +from mcpgateway.db import PerformanceMetric, SessionLocal, StructuredLogEntry + +logger = logging.getLogger(__name__) + + +class LogAggregator: + """Aggregates structured logs into performance metrics.""" + + def __init__(self): + """Initialize log aggregator.""" + self.aggregation_window_minutes = getattr(settings, "metrics_aggregation_window_minutes", 5) + self.enabled = getattr(settings, "metrics_aggregation_enabled", True) + + def aggregate_performance_metrics( + self, component: Optional[str], operation_type: Optional[str], window_start: Optional[datetime] = None, window_end: Optional[datetime] = None, db: Optional[Session] = None + ) -> Optional[PerformanceMetric]: + """Aggregate performance metrics for a component and operation. 
+ + Args: + component: Component name + operation_type: Operation name + window_start: Start of aggregation window (defaults to N minutes ago) + window_end: End of aggregation window (defaults to now) + db: Optional database session + + Returns: + Created PerformanceMetric or None if no data + """ + if not self.enabled: + return None + if not component or not operation_type: + return None + + window_start, window_end = self._resolve_window_bounds(window_start, window_end) + + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + # Query structured logs for this component/operation in time window + stmt = select(StructuredLogEntry).where( + and_( + StructuredLogEntry.component == component, + StructuredLogEntry.operation_type == operation_type, + StructuredLogEntry.timestamp >= window_start, + StructuredLogEntry.timestamp < window_end, + StructuredLogEntry.duration_ms.isnot(None), + ) + ) + + results = db.execute(stmt).scalars().all() + + if not results: + return None + + # Extract durations + durations = sorted(r.duration_ms for r in results if r.duration_ms is not None) + + if not durations: + return None + + # Calculate statistics + count = len(durations) + avg_duration = statistics.fmean(durations) if hasattr(statistics, "fmean") else statistics.mean(durations) + min_duration = durations[0] + max_duration = durations[-1] + + # Calculate percentiles + p50 = self._percentile(durations, 0.50) + p95 = self._percentile(durations, 0.95) + p99 = self._percentile(durations, 0.99) + + # Count errors + error_count = self._calculate_error_count(results) + error_rate = error_count / count if count > 0 else 0.0 + + metric = self._upsert_metric( + component=component, + operation_type=operation_type, + window_start=window_start, + window_end=window_end, + request_count=count, + error_count=error_count, + error_rate=error_rate, + avg_duration_ms=avg_duration, + min_duration_ms=min_duration, + max_duration_ms=max_duration, + p50_duration_ms=p50, + p95_duration_ms=p95, + p99_duration_ms=p99, + metric_metadata={ + "sample_size": count, + "generated_at": datetime.now(timezone.utc).isoformat(), + }, + db=db, + ) + + logger.info(f"Aggregated performance metrics for {component}.{operation_type}: " f"{count} requests, {avg_duration:.2f}ms avg, {error_rate:.2%} error rate") + + return metric + + except Exception as e: + logger.error(f"Failed to aggregate performance metrics: {e}") + if db: + db.rollback() + return None + + finally: + if should_close: + db.close() + + def aggregate_all_components(self, window_start: Optional[datetime] = None, window_end: Optional[datetime] = None, db: Optional[Session] = None) -> List[PerformanceMetric]: + """Aggregate metrics for all components and operations. 
+ + Args: + window_start: Start of aggregation window + window_end: End of aggregation window + db: Optional database session + + Returns: + List of created PerformanceMetric records + """ + if not self.enabled: + return [] + + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + window_start, window_end = self._resolve_window_bounds(window_start, window_end) + + stmt = ( + select(StructuredLogEntry.component, StructuredLogEntry.operation_type) + .where( + and_( + StructuredLogEntry.timestamp >= window_start, + StructuredLogEntry.timestamp < window_end, + StructuredLogEntry.duration_ms.isnot(None), + StructuredLogEntry.operation_type.isnot(None), + ) + ) + .distinct() + ) + + pairs = db.execute(stmt).all() + + metrics = [] + for component, operation in pairs: + if component and operation: + metric = self.aggregate_performance_metrics(component=component, operation_type=operation, window_start=window_start, window_end=window_end, db=db) + if metric: + metrics.append(metric) + + return metrics + + finally: + if should_close: + db.close() + + def get_recent_metrics(self, component: Optional[str] = None, operation: Optional[str] = None, hours: int = 24, db: Optional[Session] = None) -> List[PerformanceMetric]: + """Get recent performance metrics. + + Args: + component: Optional component filter + operation: Optional operation filter + hours: Hours of history to retrieve + db: Optional database session + + Returns: + List of PerformanceMetric records + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + since = datetime.now(timezone.utc) - timedelta(hours=hours) + + stmt = select(PerformanceMetric).where(PerformanceMetric.window_start >= since) + + if component: + stmt = stmt.where(PerformanceMetric.component == component) + if operation: + stmt = stmt.where(PerformanceMetric.operation_type == operation) + + stmt = stmt.order_by(PerformanceMetric.window_start.desc()) + + return db.execute(stmt).scalars().all() + + finally: + if should_close: + db.close() + + def get_degradation_alerts(self, threshold_multiplier: float = 1.5, hours: int = 24, db: Optional[Session] = None) -> List[Dict[str, Any]]: + """Identify performance degradations by comparing recent vs baseline. 
+ + Args: + threshold_multiplier: Alert if recent is X times slower than baseline + hours: Hours of recent data to check + db: Optional database session + + Returns: + List of degradation alerts with details + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + recent_cutoff = datetime.now(timezone.utc) - timedelta(hours=hours) + baseline_cutoff = recent_cutoff - timedelta(hours=hours * 2) + + # Get unique component/operation pairs + stmt = select(PerformanceMetric.component, PerformanceMetric.operation_type).distinct() + + pairs = db.execute(stmt).all() + + alerts = [] + for component, operation in pairs: + # Get recent metrics + recent_stmt = select(PerformanceMetric).where( + and_(PerformanceMetric.component == component, PerformanceMetric.operation_type == operation, PerformanceMetric.window_start >= recent_cutoff) + ) + recent_metrics = db.execute(recent_stmt).scalars().all() + + # Get baseline metrics + baseline_stmt = select(PerformanceMetric).where( + and_( + PerformanceMetric.component == component, + PerformanceMetric.operation_type == operation, + PerformanceMetric.window_start >= baseline_cutoff, + PerformanceMetric.window_start < recent_cutoff, + ) + ) + baseline_metrics = db.execute(baseline_stmt).scalars().all() + + if not recent_metrics or not baseline_metrics: + continue + + recent_avg = statistics.mean([m.avg_duration_ms for m in recent_metrics]) + baseline_avg = statistics.mean([m.avg_duration_ms for m in baseline_metrics]) + + if recent_avg > baseline_avg * threshold_multiplier: + alerts.append( + { + "component": component, + "operation": operation, + "recent_avg_ms": recent_avg, + "baseline_avg_ms": baseline_avg, + "degradation_ratio": recent_avg / baseline_avg, + "recent_error_rate": statistics.mean([m.error_rate for m in recent_metrics]), + "baseline_error_rate": statistics.mean([m.error_rate for m in baseline_metrics]), + } + ) + + return alerts + + finally: + if should_close: + db.close() + + def backfill(self, hours: float, db: Optional[Session] = None) -> int: + """Backfill metrics for a historical time range. + + Args: + hours: Number of hours of history to aggregate (supports fractional hours) + db: Optional shared database session + + Returns: + Count of performance metric windows processed + """ + if not self.enabled or hours <= 0: + return 0 + + window_minutes = self.aggregation_window_minutes + window_delta = timedelta(minutes=window_minutes) + total_windows = max(1, math.ceil((hours * 60) / window_minutes)) + + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + _, latest_end = self._resolve_window_bounds(None, None) + current_start = latest_end - (window_delta * total_windows) + processed = 0 + + while current_start < latest_end: + current_end = current_start + window_delta + created = self.aggregate_all_components( + window_start=current_start, + window_end=current_end, + db=db, + ) + if created: + processed += 1 + current_start = current_end + + return processed + + finally: + if should_close: + db.close() + + @staticmethod + def _percentile(sorted_values: List[float], percentile: float) -> float: + """Calculate percentile from sorted values. 
+ + Args: + sorted_values: Sorted list of values + percentile: Percentile to calculate (0.0 to 1.0) + + Returns: + float: Calculated percentile value + """ + if not sorted_values: + return 0.0 + + if len(sorted_values) == 1: + return float(sorted_values[0]) + + k = (len(sorted_values) - 1) * percentile + f = math.floor(k) + c = math.ceil(k) + + if f == c: + return float(sorted_values[int(k)]) + + d0 = sorted_values[f] * (c - k) + d1 = sorted_values[c] * (k - f) + return float(d0 + d1) + + @staticmethod + def _calculate_error_count(entries: List[StructuredLogEntry]) -> int: + """Calculate error occurrences for a batch of log entries. + + Args: + entries: List of log entries to analyze + + Returns: + int: Count of error entries + """ + error_levels = {"ERROR", "CRITICAL"} + return sum(1 for entry in entries if (entry.level and entry.level.upper() in error_levels) or entry.error_details) + + def _resolve_window_bounds( + self, + window_start: Optional[datetime], + window_end: Optional[datetime], + ) -> Tuple[datetime, datetime]: + """Resolve and normalize aggregation window bounds. + + Args: + window_start: Start of window or None to calculate + window_end: End of window or None for current time + + Returns: + Tuple[datetime, datetime]: Resolved window start and end + """ + window_delta = timedelta(minutes=self.aggregation_window_minutes) + + if window_start is not None and window_end is not None: + resolved_start = window_start.astimezone(timezone.utc) + resolved_end = window_end.astimezone(timezone.utc) + if resolved_end <= resolved_start: + resolved_end = resolved_start + window_delta + return resolved_start, resolved_end + + if window_end is None: + reference = datetime.now(timezone.utc) + else: + reference = window_end.astimezone(timezone.utc) + + reference = reference.replace(second=0, microsecond=0) + minutes_offset = reference.minute % self.aggregation_window_minutes + if window_end is None and minutes_offset: + reference = reference - timedelta(minutes=minutes_offset) + + resolved_end = reference if window_end is None else reference + + if window_start is None: + resolved_start = resolved_end - window_delta + else: + resolved_start = window_start.astimezone(timezone.utc) + + if resolved_end <= resolved_start: + resolved_start = resolved_end - window_delta + + return resolved_start, resolved_end + + def _upsert_metric( + self, + component: str, + operation_type: str, + window_start: datetime, + window_end: datetime, + request_count: int, + error_count: int, + error_rate: float, + avg_duration_ms: float, + min_duration_ms: float, + max_duration_ms: float, + p50_duration_ms: float, + p95_duration_ms: float, + p99_duration_ms: float, + metric_metadata: Optional[Dict[str, Any]], + db: Session, + ) -> PerformanceMetric: + """Create or update a performance metric window. 
+ + Args: + component: Component name + operation_type: Operation type + window_start: Window start time + window_end: Window end time + request_count: Total request count + error_count: Total error count + error_rate: Error rate (0.0-1.0) + avg_duration_ms: Average duration in milliseconds + min_duration_ms: Minimum duration in milliseconds + max_duration_ms: Maximum duration in milliseconds + p50_duration_ms: 50th percentile duration + p95_duration_ms: 95th percentile duration + p99_duration_ms: 99th percentile duration + metric_metadata: Additional metadata + db: Database session + + Returns: + PerformanceMetric: Created or updated metric + """ + + existing_stmt = select(PerformanceMetric).where( + and_( + PerformanceMetric.component == component, + PerformanceMetric.operation_type == operation_type, + PerformanceMetric.window_start == window_start, + PerformanceMetric.window_end == window_end, + ) + ) + + existing_metrics = db.execute(existing_stmt).scalars().all() + metric = existing_metrics[0] if existing_metrics else None + + if len(existing_metrics) > 1: + logger.warning( + "Found %s duplicate performance metric rows for %s.%s window %s-%s; pruning extras", + len(existing_metrics), + component, + operation_type, + window_start.isoformat(), + window_end.isoformat(), + ) + for duplicate in existing_metrics[1:]: + db.delete(duplicate) + + if metric is None: + metric = PerformanceMetric( + component=component, + operation_type=operation_type, + window_start=window_start, + window_end=window_end, + window_duration_seconds=int((window_end - window_start).total_seconds()), + ) + db.add(metric) + + metric.request_count = request_count + metric.error_count = error_count + metric.error_rate = error_rate + metric.avg_duration_ms = avg_duration_ms + metric.min_duration_ms = min_duration_ms + metric.max_duration_ms = max_duration_ms + metric.p50_duration_ms = p50_duration_ms + metric.p95_duration_ms = p95_duration_ms + metric.p99_duration_ms = p99_duration_ms + metric.metric_metadata = metric_metadata + + db.commit() + db.refresh(metric) + return metric + + +# Global log aggregator instance +_log_aggregator: Optional[LogAggregator] = None + + +def get_log_aggregator() -> LogAggregator: + """Get or create the global log aggregator instance. 
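A minimal usage sketch for the aggregator defined in this file, based on the signatures above (the component name and hour values are illustrative):

    # First-Party
    from mcpgateway.services.log_aggregator import get_log_aggregator

    aggregator = get_log_aggregator()

    # Roll up the most recent window for every component/operation pair.
    metrics = aggregator.aggregate_all_components()

    # Inspect the last 24 hours of metrics for a single component.
    recent = aggregator.get_recent_metrics(component="gateway_service", hours=24)

    # Flag operations whose recent average is at least 1.5x slower than the baseline window.
    alerts = aggregator.get_degradation_alerts(threshold_multiplier=1.5, hours=24)

    # Rebuild two hours of history, e.g. after the periodic aggregation was down.
    processed_windows = aggregator.backfill(hours=2)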
+ + Returns: + Global LogAggregator instance + """ + global _log_aggregator # pylint: disable=global-statement + if _log_aggregator is None: + _log_aggregator = LogAggregator() + return _log_aggregator diff --git a/mcpgateway/services/logging_service.py b/mcpgateway/services/logging_service.py index 4f21111c0..f18f826f9 100644 --- a/mcpgateway/services/logging_service.py +++ b/mcpgateway/services/logging_service.py @@ -16,6 +16,7 @@ import logging from logging.handlers import RotatingFileHandler import os +import socket from typing import Any, AsyncGenerator, Dict, List, NotRequired, Optional, TextIO, TypedDict # Third-Party @@ -25,10 +26,18 @@ from mcpgateway.common.models import LogLevel from mcpgateway.config import settings from mcpgateway.services.log_storage_service import LogStorageService +from mcpgateway.utils.correlation_id import get_correlation_id + +# Optional OpenTelemetry support (Third-Party) +try: + # Third-Party + from opentelemetry import trace # type: ignore[import-untyped] +except ImportError: + trace = None # type: ignore[assignment] AnyioClosedResourceError: Optional[type] # pylint: disable=invalid-name try: - # Optional import; only used for filtering a known benign upstream error + # Optional import; only used for filtering a known benign upstream error (Third-Party) # Third-Party from anyio import ClosedResourceError as AnyioClosedResourceError # pylint: disable=invalid-name except Exception: # pragma: no cover - environment without anyio @@ -38,8 +47,52 @@ # Create a text formatter text_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") -# Create a JSON formatter -json_formatter = jsonlogger.JsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s") + +class CorrelationIdJsonFormatter(jsonlogger.JsonFormatter): + """JSON formatter that includes correlation ID and OpenTelemetry trace context.""" + + def add_fields(self, log_record: dict, record: logging.LogRecord, message_dict: dict) -> None: # pylint: disable=arguments-renamed + """Add custom fields to the log record. 
+ + Args: + log_record: The dictionary that will be logged as JSON + record: The original LogRecord + message_dict: Additional message fields + + """ + super().add_fields(log_record, record, message_dict) + + # Add timestamp in ISO 8601 format with 'Z' suffix for UTC + dt = datetime.fromtimestamp(record.created, tz=timezone.utc) + log_record["@timestamp"] = dt.isoformat().replace("+00:00", "Z") + + # Add hostname and process ID for log aggregation + log_record["hostname"] = socket.gethostname() + log_record["process_id"] = os.getpid() + + # Add correlation ID from context + correlation_id = get_correlation_id() + if correlation_id: + log_record["request_id"] = correlation_id + + # Add OpenTelemetry trace context if available + if trace is not None: + try: + span = trace.get_current_span() + if span and span.is_recording(): + span_context = span.get_span_context() + if span_context.is_valid: + # Format trace_id and span_id as hex strings + log_record["trace_id"] = format(span_context.trace_id, "032x") + log_record["span_id"] = format(span_context.span_id, "016x") + log_record["trace_flags"] = format(span_context.trace_flags, "02x") + except Exception: # nosec B110 - intentionally catching all exceptions for optional tracing + # Error accessing span context, continue without trace fields + pass + + +# Create a JSON formatter with correlation ID support +json_formatter = CorrelationIdJsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s") # Note: Don't use basicConfig here as it conflicts with our custom dual logging setup # The LoggingService.initialize() method will properly configure all handlers diff --git a/mcpgateway/services/performance_tracker.py b/mcpgateway/services/performance_tracker.py new file mode 100644 index 000000000..dcf813979 --- /dev/null +++ b/mcpgateway/services/performance_tracker.py @@ -0,0 +1,304 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/performance_tracker.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Performance Tracking Service. + +This module provides performance tracking and analytics for all operations +across the MCP Gateway, enabling identification of bottlenecks and +optimization opportunities. +""" + +# Standard +from collections import defaultdict +from contextlib import contextmanager +import logging +import statistics +import time +from typing import Any, Dict, Generator, List, Optional + +# First-Party +from mcpgateway.config import settings +from mcpgateway.utils.correlation_id import get_correlation_id + +logger = logging.getLogger(__name__) + + +class PerformanceTracker: + """Tracks and analyzes performance metrics across requests. + + Provides context managers for tracking operation timing, + aggregation of metrics, and threshold-based alerting. 
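A short sketch of attaching the correlation-aware JSON formatter to a standard logging handler (the handler and logger names are illustrative; inside the gateway itself, LoggingService.initialize() configures the handlers):

    # Standard
    import logging

    # First-Party
    from mcpgateway.services.logging_service import CorrelationIdJsonFormatter

    handler = logging.StreamHandler()
    handler.setFormatter(CorrelationIdJsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s"))

    log = logging.getLogger("example")  # illustrative logger name
    log.addHandler(handler)
    log.setLevel(logging.INFO)

    # Besides the format fields, each JSON record carries "@timestamp", "hostname"
    # and "process_id"; "request_id" is added when a correlation ID is in context,
    # and trace_id/span_id/trace_flags when an OpenTelemetry span is recording.
    log.info("hello from the correlation-aware formatter")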
+ """ + + def __init__(self): + """Initialize performance tracker.""" + self.operation_timings: Dict[str, List[float]] = defaultdict(list) + + # Performance thresholds (seconds) from settings or defaults + self.performance_thresholds = { + "database_query": getattr(settings, "perf_threshold_database_query", 0.1), + "tool_invocation": getattr(settings, "perf_threshold_tool_invocation", 2.0), + "authentication": getattr(settings, "perf_threshold_authentication", 0.5), + "cache_operation": getattr(settings, "perf_threshold_cache_operation", 0.01), + "a2a_task": getattr(settings, "perf_threshold_a2a_task", 5.0), + "request_total": getattr(settings, "perf_threshold_request_total", 10.0), + "resource_fetch": getattr(settings, "perf_threshold_resource_fetch", 1.0), + "prompt_processing": getattr(settings, "perf_threshold_prompt_processing", 0.5), + } + + # Max buffer size per operation type + self.max_samples = getattr(settings, "perf_max_samples_per_operation", 1000) + + @contextmanager + def track_operation(self, operation_name: str, component: Optional[str] = None, log_slow: bool = True, extra_context: Optional[Dict[str, Any]] = None) -> Generator[None, None, None]: + """Context manager to track operation performance. + + Args: + operation_name: Name of the operation being tracked + component: Component/module name for context + log_slow: Whether to log operations exceeding thresholds + extra_context: Additional context to include in logs + + Yields: + None + + Raises: + Exception: Any exception from the tracked operation is re-raised + + Example: + >>> tracker = PerformanceTracker() + >>> with tracker.track_operation("database_query", component="tool_service"): + ... # Perform database operation + ... pass + """ + start_time = time.time() + correlation_id = get_correlation_id() + error_occurred = False + + try: + yield + except Exception: + error_occurred = True + raise + finally: + duration = time.time() - start_time + + # Record timing + self.operation_timings[operation_name].append(duration) + + # Limit buffer size + if len(self.operation_timings[operation_name]) > self.max_samples: + self.operation_timings[operation_name].pop(0) + + # Check threshold and log if needed + threshold = self.performance_thresholds.get(operation_name, float("inf")) + threshold_exceeded = duration > threshold + + if log_slow and threshold_exceeded: + context = { + "operation": operation_name, + "duration_ms": duration * 1000, + "threshold_ms": threshold * 1000, + "exceeded_by_ms": (duration - threshold) * 1000, + "component": component, + "correlation_id": correlation_id, + "error_occurred": error_occurred, + } + if extra_context: + context.update(extra_context) + + logger.warning(f"Slow operation detected: {operation_name} took {duration*1000:.2f}ms " f"(threshold: {threshold*1000:.2f}ms)", extra=context) + + def record_timing(self, operation_name: str, duration: float, component: Optional[str] = None, extra_context: Optional[Dict[str, Any]] = None) -> None: + """Manually record a timing measurement. 
+ + Args: + operation_name: Name of the operation + duration: Duration in seconds + component: Component/module name + extra_context: Additional context + """ + self.operation_timings[operation_name].append(duration) + + # Limit buffer size + if len(self.operation_timings[operation_name]) > self.max_samples: + self.operation_timings[operation_name].pop(0) + + # Check threshold + threshold = self.performance_thresholds.get(operation_name, float("inf")) + if duration > threshold: + context = { + "operation": operation_name, + "duration_ms": duration * 1000, + "threshold_ms": threshold * 1000, + "component": component, + "correlation_id": get_correlation_id(), + } + if extra_context: + context.update(extra_context) + + logger.warning(f"Slow operation: {operation_name} took {duration*1000:.2f}ms", extra=context) + + def get_performance_summary(self, operation_name: Optional[str] = None, min_samples: int = 1) -> Dict[str, Any]: + """Get performance summary for analytics. + + Args: + operation_name: Specific operation to summarize (None for all) + min_samples: Minimum samples required to include in summary + + Returns: + Dictionary containing performance statistics + + Example: + >>> tracker = PerformanceTracker() + >>> summary = tracker.get_performance_summary() + >>> isinstance(summary, dict) + True + """ + summary = {} + + operations = {operation_name: self.operation_timings[operation_name]} if operation_name and operation_name in self.operation_timings else self.operation_timings + + for op_name, timings in operations.items(): + if len(timings) < min_samples: + continue + + # Calculate percentiles + sorted_timings = sorted(timings) + count = len(sorted_timings) + + def percentile(p: float, *, sorted_vals=sorted_timings, n=count) -> float: + """Calculate percentile value. + + Args: + p: Percentile to calculate (0.0 to 1.0) + sorted_vals: Sorted list of values + n: Number of values + + Returns: + float: Calculated percentile value + """ + k = (n - 1) * p + f = int(k) + c = k - f + if f + 1 < n: + return sorted_vals[f] * (1 - c) + sorted_vals[f + 1] * c + return sorted_vals[f] + + summary[op_name] = { + "count": count, + "avg_duration_ms": statistics.mean(timings) * 1000, + "min_duration_ms": min(timings) * 1000, + "max_duration_ms": max(timings) * 1000, + "p50_duration_ms": percentile(0.5) * 1000, + "p95_duration_ms": percentile(0.95) * 1000, + "p99_duration_ms": percentile(0.99) * 1000, + "threshold_ms": self.performance_thresholds.get(op_name, float("inf")) * 1000, + "threshold_violations": sum(1 for t in timings if t > self.performance_thresholds.get(op_name, float("inf"))), + "violation_rate": sum(1 for t in timings if t > self.performance_thresholds.get(op_name, float("inf"))) / count, + } + + return summary + + def get_operation_stats(self, operation_name: str) -> Optional[Dict[str, Any]]: + """Get statistics for a specific operation. 
+ + Args: + operation_name: Name of the operation + + Returns: + Statistics dictionary or None if no data + """ + if operation_name not in self.operation_timings: + return None + + timings = self.operation_timings[operation_name] + if not timings: + return None + + return { + "operation": operation_name, + "sample_count": len(timings), + "avg_duration_ms": statistics.mean(timings) * 1000, + "min_duration_ms": min(timings) * 1000, + "max_duration_ms": max(timings) * 1000, + "total_time_ms": sum(timings) * 1000, + "threshold_ms": self.performance_thresholds.get(operation_name, float("inf")) * 1000, + } + + def clear_stats(self, operation_name: Optional[str] = None) -> None: + """Clear performance statistics. + + Args: + operation_name: Specific operation to clear (None for all) + """ + if operation_name: + if operation_name in self.operation_timings: + self.operation_timings[operation_name].clear() + else: + self.operation_timings.clear() + + def set_threshold(self, operation_name: str, threshold_seconds: float) -> None: + """Set or update performance threshold for an operation. + + Args: + operation_name: Name of the operation + threshold_seconds: Threshold in seconds + """ + self.performance_thresholds[operation_name] = threshold_seconds + + def check_performance_degradation(self, operation_name: str, baseline_multiplier: float = 2.0) -> Dict[str, Any]: + """Check if performance has degraded compared to baseline. + + Args: + operation_name: Name of the operation to check + baseline_multiplier: Multiplier for degradation detection + + Returns: + Dictionary with degradation analysis + """ + if operation_name not in self.operation_timings: + return {"degraded": False, "reason": "no_data"} + + timings = self.operation_timings[operation_name] + if len(timings) < 10: + return {"degraded": False, "reason": "insufficient_samples"} + + # Compare recent timings to overall average + recent_count = min(10, len(timings)) + recent_timings = timings[-recent_count:] + historical_timings = timings[:-recent_count] if len(timings) > recent_count else timings + + if not historical_timings: + return {"degraded": False, "reason": "insufficient_historical_data"} + + recent_avg = statistics.mean(recent_timings) + historical_avg = statistics.mean(historical_timings) + + degraded = recent_avg > (historical_avg * baseline_multiplier) + + return { + "degraded": degraded, + "recent_avg_ms": recent_avg * 1000, + "historical_avg_ms": historical_avg * 1000, + "multiplier": recent_avg / historical_avg if historical_avg > 0 else 0, + "threshold_multiplier": baseline_multiplier, + } + + +# Global performance tracker instance +_performance_tracker: Optional[PerformanceTracker] = None + + +def get_performance_tracker() -> PerformanceTracker: + """Get or create the global performance tracker instance. 
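A minimal sketch of the tracker defined in this file, mirroring the doctest in track_operation (the operation and component names reuse the examples from the code; the recorded duration is illustrative):

    # First-Party
    from mcpgateway.services.performance_tracker import get_performance_tracker

    tracker = get_performance_tracker()

    # Time a block; a warning is logged if it exceeds the configured threshold.
    with tracker.track_operation("database_query", component="tool_service"):
        pass  # run the query here

    # Record a measurement taken elsewhere (duration is in seconds).
    tracker.record_timing("tool_invocation", 0.42, component="tool_service")

    # Summarize collected timings and check for recent slowdowns.
    summary = tracker.get_performance_summary(min_samples=1)
    degradation = tracker.check_performance_degradation("database_query")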
+ + Returns: + Global PerformanceTracker instance + """ + global _performance_tracker # pylint: disable=global-statement + if _performance_tracker is None: + _performance_tracker = PerformanceTracker() + return _performance_tracker diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index dac07f888..cd3841563 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -37,9 +37,11 @@ from mcpgateway.observability import create_span from mcpgateway.plugins.framework import GlobalContext, PluginContextTable, PluginManager, PromptHookType, PromptPosthookPayload, PromptPrehookPayload from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate, TopPerformer +from mcpgateway.services.audit_trail_service import get_audit_trail_service from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.observability_service import current_trace_id, ObservabilityService +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.pagination import decode_cursor, encode_cursor from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr @@ -48,6 +50,10 @@ logging_service = LoggingService() logger = logging_service.get_logger(__name__) +# Initialize structured logger and audit trail for prompt operations +structured_logger = get_structured_logger("prompt_service") +audit_trail = get_audit_trail_service() + class PromptError(Exception): """Base class for prompt-related errors.""" @@ -401,18 +407,95 @@ async def register_prompt( await self._notify_prompt_added(db_prompt) logger.info(f"Registered prompt: {prompt.name}") + + # Structured logging: Audit trail for prompt creation + audit_trail.log_action( + user_id=created_by or "system", + action="create_prompt", + resource_type="prompt", + resource_id=str(db_prompt.id), + resource_name=db_prompt.name, + user_email=owner_email, + team_id=team_id, + client_ip=created_from_ip, + user_agent=created_user_agent, + new_values={ + "name": db_prompt.name, + "visibility": visibility, + }, + context={ + "created_via": created_via, + "import_batch_id": import_batch_id, + "federation_source": federation_source, + }, + db=db, + ) + + # Structured logging: Log successful prompt creation + structured_logger.log( + level="INFO", + message="Prompt created successfully", + event_type="prompt_created", + component="prompt_service", + user_id=created_by, + user_email=owner_email, + team_id=team_id, + resource_type="prompt", + resource_id=str(db_prompt.id), + custom_fields={ + "prompt_name": db_prompt.name, + "visibility": visibility, + }, + db=db, + ) + db_prompt.team = self._get_team_name(db, db_prompt.team_id) prompt_dict = self._convert_db_prompt(db_prompt) return PromptRead.model_validate(prompt_dict) except IntegrityError as ie: logger.error(f"IntegrityErrors in group: {ie}") + + structured_logger.log( + level="ERROR", + message="Prompt creation failed due to database integrity error", + event_type="prompt_creation_failed", + component="prompt_service", + user_id=created_by, + user_email=owner_email, + error=ie, + custom_fields={"prompt_name": prompt.name}, + db=db, + ) raise ie except PromptNameConflictError as se: db.rollback() + + structured_logger.log( + level="WARNING", + message="Prompt creation failed due to name conflict", + event_type="prompt_name_conflict", + component="prompt_service", + 
user_id=created_by, + user_email=owner_email, + custom_fields={"prompt_name": prompt.name, "visibility": visibility}, + db=db, + ) raise se except Exception as e: db.rollback() + + structured_logger.log( + level="ERROR", + message="Prompt creation failed", + event_type="prompt_creation_failed", + component="prompt_service", + user_id=created_by, + user_email=owner_email, + error=e, + custom_fields={"prompt_name": prompt.name}, + db=db, + ) raise PromptError(f"Failed to register prompt: {str(e)}") async def list_prompts(self, db: Session, include_inactive: bool = False, cursor: Optional[str] = None, tags: Optional[List[str]] = None) -> tuple[List[PromptRead], Optional[str]]: @@ -826,6 +909,43 @@ async def get_prompt( # Use modified payload if provided result = post_result.modified_payload.result if post_result.modified_payload else result + arguments_supplied = bool(arguments) + + audit_trail.log_action( + user_id=user or "anonymous", + action="view_prompt", + resource_type="prompt", + resource_id=str(prompt.id), + resource_name=prompt.name, + team_id=prompt.team_id, + context={ + "tenant_id": tenant_id, + "server_id": server_id, + "arguments_provided": arguments_supplied, + "request_id": request_id, + }, + db=db, + ) + + structured_logger.log( + level="INFO", + message="Prompt retrieved successfully", + event_type="prompt_viewed", + component="prompt_service", + user_id=user, + team_id=prompt.team_id, + resource_type="prompt", + resource_id=str(prompt.id), + request_id=request_id, + custom_fields={ + "prompt_name": prompt.name, + "arguments_provided": arguments_supplied, + "tenant_id": tenant_id, + "server_id": server_id, + }, + db=db, + ) + # Set success attributes on span if span: span.set_attribute("success", True) @@ -990,26 +1110,117 @@ async def update_prompt( db.refresh(prompt) await self._notify_prompt_updated(prompt) + + # Structured logging: Audit trail for prompt update + audit_trail.log_action( + user_id=user_email or modified_by or "system", + action="update_prompt", + resource_type="prompt", + resource_id=str(prompt.id), + resource_name=prompt.name, + user_email=user_email, + team_id=prompt.team_id, + client_ip=modified_from_ip, + user_agent=modified_user_agent, + new_values={"name": prompt.name, "version": prompt.version}, + context={"modified_via": modified_via}, + db=db, + ) + + structured_logger.log( + level="INFO", + message="Prompt updated successfully", + event_type="prompt_updated", + component="prompt_service", + user_id=modified_by, + user_email=user_email, + team_id=prompt.team_id, + resource_type="prompt", + resource_id=str(prompt.id), + custom_fields={"prompt_name": prompt.name, "version": prompt.version}, + db=db, + ) + prompt.team = self._get_team_name(db, prompt.team_id) return PromptRead.model_validate(self._convert_db_prompt(prompt)) - except PermissionError: + except PermissionError as pe: db.rollback() + + structured_logger.log( + level="WARNING", + message="Prompt update failed due to permission error", + event_type="prompt_update_permission_denied", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=pe, + db=db, + ) raise except IntegrityError as ie: db.rollback() logger.error(f"IntegrityErrors in group: {ie}") + + structured_logger.log( + level="ERROR", + message="Prompt update failed due to database integrity error", + event_type="prompt_update_failed", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=ie, + db=db, + 
) raise ie except PromptNotFoundError as e: db.rollback() logger.error(f"Prompt not found: {e}") + + structured_logger.log( + level="ERROR", + message="Prompt update failed - prompt not found", + event_type="prompt_not_found", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=e, + db=db, + ) raise e except PromptNameConflictError as pnce: db.rollback() logger.error(f"Prompt name conflict: {pnce}") + + structured_logger.log( + level="WARNING", + message="Prompt update failed due to name conflict", + event_type="prompt_name_conflict", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=pnce, + db=db, + ) raise pnce except Exception as e: db.rollback() + + structured_logger.log( + level="ERROR", + message="Prompt update failed", + event_type="prompt_update_failed", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=e, + db=db, + ) raise PromptError(f"Failed to update prompt: {str(e)}") async def toggle_prompt_status(self, db: Session, prompt_id: int, activate: bool, user_email: Optional[str] = None) -> PromptRead: @@ -1071,12 +1282,63 @@ async def toggle_prompt_status(self, db: Session, prompt_id: int, activate: bool else: await self._notify_prompt_deactivated(prompt) logger.info(f"Prompt {prompt.name} {'activated' if activate else 'deactivated'}") + + # Structured logging: Audit trail for prompt status toggle + audit_trail.log_action( + user_id=user_email or "system", + action="toggle_prompt_status", + resource_type="prompt", + resource_id=str(prompt.id), + resource_name=prompt.name, + user_email=user_email, + team_id=prompt.team_id, + new_values={"enabled": prompt.enabled}, + context={"action": "activate" if activate else "deactivate"}, + db=db, + ) + + structured_logger.log( + level="INFO", + message=f"Prompt {'activated' if activate else 'deactivated'} successfully", + event_type="prompt_status_toggled", + component="prompt_service", + user_email=user_email, + team_id=prompt.team_id, + resource_type="prompt", + resource_id=str(prompt.id), + custom_fields={"prompt_name": prompt.name, "enabled": prompt.enabled}, + db=db, + ) + prompt.team = self._get_team_name(db, prompt.team_id) return PromptRead.model_validate(self._convert_db_prompt(prompt)) except PermissionError as e: + structured_logger.log( + level="WARNING", + message="Prompt status toggle failed due to permission error", + event_type="prompt_toggle_permission_denied", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=e, + db=db, + ) raise e except Exception as e: db.rollback() + + structured_logger.log( + level="ERROR", + message="Prompt status toggle failed", + event_type="prompt_toggle_failed", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=e, + db=db, + ) raise PromptError(f"Failed to toggle prompt status: {str(e)}") # Get prompt details for admin ui @@ -1113,7 +1375,35 @@ async def get_prompt_details(self, db: Session, prompt_id: Union[int, str], incl raise PromptNotFoundError(f"Prompt not found: {prompt_id}") # Return the fully converted prompt including metrics prompt.team = self._get_team_name(db, prompt.team_id) - return self._convert_db_prompt(prompt) + prompt_data = self._convert_db_prompt(prompt) + + audit_trail.log_action( + user_id="system", + action="view_prompt_details", + 
resource_type="prompt", + resource_id=str(prompt.id), + resource_name=prompt.name, + team_id=prompt.team_id, + context={"include_inactive": include_inactive}, + db=db, + ) + + structured_logger.log( + level="INFO", + message="Prompt details retrieved", + event_type="prompt_details_viewed", + component="prompt_service", + resource_type="prompt", + resource_id=str(prompt.id), + team_id=prompt.team_id, + custom_fields={ + "prompt_name": prompt.name, + "include_inactive": include_inactive, + }, + db=db, + ) + + return prompt_data async def delete_prompt(self, db: Session, prompt_id: Union[int, str], user_email: Optional[str] = None) -> None: """ @@ -1161,17 +1451,85 @@ async def delete_prompt(self, db: Session, prompt_id: Union[int, str], user_emai raise PermissionError("Only the owner can delete this prompt") prompt_info = {"id": prompt.id, "name": prompt.name} + prompt_name = prompt.name + prompt_team_id = prompt.team_id + db.delete(prompt) db.commit() await self._notify_prompt_deleted(prompt_info) logger.info(f"Deleted prompt: {prompt_info['name']}") - except PermissionError: + + # Structured logging: Audit trail for prompt deletion + audit_trail.log_action( + user_id=user_email or "system", + action="delete_prompt", + resource_type="prompt", + resource_id=str(prompt_info["id"]), + resource_name=prompt_name, + user_email=user_email, + team_id=prompt_team_id, + old_values={"name": prompt_name}, + db=db, + ) + + # Structured logging: Log successful prompt deletion + structured_logger.log( + level="INFO", + message="Prompt deleted successfully", + event_type="prompt_deleted", + component="prompt_service", + user_email=user_email, + team_id=prompt_team_id, + resource_type="prompt", + resource_id=str(prompt_info["id"]), + custom_fields={"prompt_name": prompt_name}, + db=db, + ) + except PermissionError as pe: db.rollback() + + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Prompt deletion failed due to permission error", + event_type="prompt_delete_permission_denied", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=pe, + db=db, + ) raise except Exception as e: db.rollback() if isinstance(e, PromptNotFoundError): + # Structured logging: Log not found error + structured_logger.log( + level="ERROR", + message="Prompt deletion failed - prompt not found", + event_type="prompt_not_found", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=e, + db=db, + ) raise e + + # Structured logging: Log generic prompt deletion failure + structured_logger.log( + level="ERROR", + message="Prompt deletion failed", + event_type="prompt_deletion_failed", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=e, + db=db, + ) raise PromptError(f"Failed to delete prompt: {str(e)}") async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]: diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index ced8901ba..1b8136a51 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -51,10 +51,12 @@ from mcpgateway.db import server_resource_association from mcpgateway.observability import create_span from mcpgateway.schemas import ResourceCreate, ResourceMetrics, ResourceRead, ResourceSubscription, ResourceUpdate, TopPerformer +from mcpgateway.services.audit_trail_service import 
get_audit_trail_service from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager from mcpgateway.services.observability_service import current_trace_id, ObservabilityService +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.pagination import decode_cursor, encode_cursor from mcpgateway.utils.services_auth import decode_auth @@ -74,6 +76,10 @@ logging_service = LoggingService() logger = logging_service.get_logger(__name__) +# Initialize structured logger and audit trail for resource operations +structured_logger = get_structured_logger("resource_service") +audit_trail = get_audit_trail_service() + class ResourceError(Exception): """Base class for resource-related errors.""" @@ -240,6 +246,17 @@ def _convert_resource_to_read(self, resource: DbResource) -> ResourceRead: resource_dict.pop("_sa_instance_state", None) resource_dict.pop("metrics", None) + # Ensure required base fields are present even if SQLAlchemy hasn't loaded them into __dict__ yet + resource_dict["id"] = getattr(resource, "id", resource_dict.get("id")) + resource_dict["uri"] = getattr(resource, "uri", resource_dict.get("uri")) + resource_dict["name"] = getattr(resource, "name", resource_dict.get("name")) + resource_dict["description"] = getattr(resource, "description", resource_dict.get("description")) + resource_dict["mime_type"] = getattr(resource, "mime_type", resource_dict.get("mime_type")) + resource_dict["size"] = getattr(resource, "size", resource_dict.get("size")) + resource_dict["created_at"] = getattr(resource, "created_at", resource_dict.get("created_at")) + resource_dict["updated_at"] = getattr(resource, "updated_at", resource_dict.get("updated_at")) + resource_dict["is_active"] = getattr(resource, "is_active", resource_dict.get("is_active")) + # Compute aggregated metrics from the resource's metrics list. 
total = len(resource.metrics) if hasattr(resource, "metrics") and resource.metrics is not None else 0 successful = sum(1 for m in resource.metrics if m.is_success) if total > 0 else 0 @@ -397,16 +414,106 @@ async def register_resource( await self._notify_resource_added(db_resource) logger.info(f"Registered resource: {resource.uri}") + + # Structured logging: Audit trail for resource creation + audit_trail.log_action( + user_id=created_by or "system", + action="create_resource", + resource_type="resource", + resource_id=str(db_resource.id), + resource_name=db_resource.name, + user_email=owner_email, + team_id=team_id, + client_ip=created_from_ip, + user_agent=created_user_agent, + new_values={ + "uri": db_resource.uri, + "name": db_resource.name, + "visibility": visibility, + "mime_type": db_resource.mime_type, + }, + context={ + "created_via": created_via, + "import_batch_id": import_batch_id, + "federation_source": federation_source, + }, + db=db, + ) + + # Structured logging: Log successful resource creation + structured_logger.log( + level="INFO", + message="Resource created successfully", + event_type="resource_created", + component="resource_service", + user_id=created_by, + user_email=owner_email, + team_id=team_id, + resource_type="resource", + resource_id=str(db_resource.id), + custom_fields={ + "resource_uri": db_resource.uri, + "resource_name": db_resource.name, + "visibility": visibility, + }, + db=db, + ) + db_resource.team = self._get_team_name(db, db_resource.team_id) return self._convert_resource_to_read(db_resource) except IntegrityError as ie: logger.error(f"IntegrityErrors in group: {ie}") + + # Structured logging: Log database integrity error + structured_logger.log( + level="ERROR", + message="Resource creation failed due to database integrity error", + event_type="resource_creation_failed", + component="resource_service", + user_id=created_by, + user_email=owner_email, + error=ie, + custom_fields={ + "resource_uri": resource.uri, + }, + db=db, + ) raise ie except ResourceURIConflictError as rce: logger.error(f"ResourceURIConflictError in group: {resource.uri}") + + # Structured logging: Log URI conflict error + structured_logger.log( + level="WARNING", + message="Resource creation failed due to URI conflict", + event_type="resource_uri_conflict", + component="resource_service", + user_id=created_by, + user_email=owner_email, + custom_fields={ + "resource_uri": resource.uri, + "visibility": visibility, + }, + db=db, + ) raise rce except Exception as e: db.rollback() + + # Structured logging: Log generic resource creation failure + structured_logger.log( + level="ERROR", + message="Resource creation failed", + event_type="resource_creation_failed", + component="resource_service", + user_id=created_by, + user_email=owner_email, + error=e, + custom_fields={ + "resource_uri": resource.uri, + }, + db=db, + ) raise ResourceError(f"Failed to register resource: {str(e)}") async def list_resources(self, db: Session, include_inactive: bool = False, cursor: Optional[str] = None, tags: Optional[List[str]] = None) -> tuple[List[ResourceRead], Optional[str]]: @@ -814,7 +921,6 @@ async def invoke_resource(self, db: Session, resource_id: str, resource_uri: str 'using template: /template' """ - uri = None if resource_uri and resource_template_uri: uri = resource_template_uri @@ -1464,12 +1570,72 @@ async def toggle_resource_status(self, db: Session, resource_id: int, activate: logger.info(f"Resource {resource.uri} {'activated' if activate else 'deactivated'}") + # Structured logging: 
Audit trail for resource status toggle + audit_trail.log_action( + user_id=user_email or "system", + action="toggle_resource_status", + resource_type="resource", + resource_id=str(resource.id), + resource_name=resource.name, + user_email=user_email, + team_id=resource.team_id, + new_values={ + "enabled": resource.enabled, + }, + context={ + "action": "activate" if activate else "deactivate", + }, + db=db, + ) + + # Structured logging: Log successful resource status toggle + structured_logger.log( + level="INFO", + message=f"Resource {'activated' if activate else 'deactivated'} successfully", + event_type="resource_status_toggled", + component="resource_service", + user_email=user_email, + team_id=resource.team_id, + resource_type="resource", + resource_id=str(resource.id), + custom_fields={ + "resource_uri": resource.uri, + "enabled": resource.enabled, + }, + db=db, + ) + resource.team = self._get_team_name(db, resource.team_id) return self._convert_resource_to_read(resource) except PermissionError as e: + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Resource status toggle failed due to permission error", + event_type="resource_toggle_permission_denied", + component="resource_service", + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=e, + db=db, + ) raise e except Exception as e: db.rollback() + + # Structured logging: Log generic resource status toggle failure + structured_logger.log( + level="ERROR", + message="Resource status toggle failed", + event_type="resource_toggle_failed", + component="resource_service", + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=e, + db=db, + ) raise ResourceError(f"Failed to toggle resource status: {str(e)}") async def subscribe_resource(self, db: Session, subscription: ResourceSubscription) -> None: @@ -1685,21 +1851,138 @@ async def update_resource( await self._notify_resource_updated(resource) logger.info(f"Updated resource: {resource.uri}") + + # Structured logging: Audit trail for resource update + changes = [] + if resource_update.uri: + changes.append(f"uri: {resource_update.uri}") + if resource_update.visibility: + changes.append(f"visibility: {resource_update.visibility}") + if resource_update.description: + changes.append("description updated") + + audit_trail.log_action( + user_id=user_email or modified_by or "system", + action="update_resource", + resource_type="resource", + resource_id=str(resource.id), + resource_name=resource.name, + user_email=user_email, + team_id=resource.team_id, + client_ip=modified_from_ip, + user_agent=modified_user_agent, + new_values={ + "uri": resource.uri, + "name": resource.name, + "version": resource.version, + }, + context={ + "modified_via": modified_via, + "changes": ", ".join(changes) if changes else "metadata only", + }, + db=db, + ) + + # Structured logging: Log successful resource update + structured_logger.log( + level="INFO", + message="Resource updated successfully", + event_type="resource_updated", + component="resource_service", + user_id=modified_by, + user_email=user_email, + team_id=resource.team_id, + resource_type="resource", + resource_id=str(resource.id), + custom_fields={ + "resource_uri": resource.uri, + "version": resource.version, + }, + db=db, + ) + return self._convert_resource_to_read(resource) - except PermissionError: + except PermissionError as pe: db.rollback() + + # Structured logging: Log permission error + structured_logger.log( + 
level="WARNING", + message="Resource update failed due to permission error", + event_type="resource_update_permission_denied", + component="resource_service", + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=pe, + db=db, + ) raise except IntegrityError as ie: db.rollback() logger.error(f"IntegrityErrors in group: {ie}") + + # Structured logging: Log database integrity error + structured_logger.log( + level="ERROR", + message="Resource update failed due to database integrity error", + event_type="resource_update_failed", + component="resource_service", + user_id=modified_by, + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=ie, + db=db, + ) raise ie except ResourceURIConflictError as pe: logger.error(f"Resource URI conflict: {pe}") + + # Structured logging: Log URI conflict error + structured_logger.log( + level="WARNING", + message="Resource update failed due to URI conflict", + event_type="resource_uri_conflict", + component="resource_service", + user_id=modified_by, + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=pe, + db=db, + ) raise pe except Exception as e: db.rollback() if isinstance(e, ResourceNotFoundError): + # Structured logging: Log not found error + structured_logger.log( + level="ERROR", + message="Resource update failed - resource not found", + event_type="resource_not_found", + component="resource_service", + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=e, + db=db, + ) raise e + + # Structured logging: Log generic resource update failure + structured_logger.log( + level="ERROR", + message="Resource update failed", + event_type="resource_update_failed", + component="resource_service", + user_id=modified_by, + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=e, + db=db, + ) raise ResourceError(f"Failed to update resource: {str(e)}") async def delete_resource(self, db: Session, resource_id: Union[int, str], user_email: Optional[str] = None) -> None: @@ -1758,6 +2041,10 @@ async def delete_resource(self, db: Session, resource_id: Union[int, str], user_ db.execute(delete(DbSubscription).where(DbSubscription.resource_id == resource.id)) # Hard delete the resource. 
+ resource_uri = resource.uri + resource_name = resource.name + resource_team_id = resource.team_id + db.delete(resource) db.commit() @@ -1766,14 +2053,84 @@ async def delete_resource(self, db: Session, resource_id: Union[int, str], user_ logger.info(f"Permanently deleted resource: {resource.uri}") - except PermissionError: + # Structured logging: Audit trail for resource deletion + audit_trail.log_action( + user_id=user_email or "system", + action="delete_resource", + resource_type="resource", + resource_id=str(resource_info["id"]), + resource_name=resource_name, + user_email=user_email, + team_id=resource_team_id, + old_values={ + "uri": resource_uri, + "name": resource_name, + }, + db=db, + ) + + # Structured logging: Log successful resource deletion + structured_logger.log( + level="INFO", + message="Resource deleted successfully", + event_type="resource_deleted", + component="resource_service", + user_email=user_email, + team_id=resource_team_id, + resource_type="resource", + resource_id=str(resource_info["id"]), + custom_fields={ + "resource_uri": resource_uri, + }, + db=db, + ) + + except PermissionError as pe: db.rollback() + + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Resource deletion failed due to permission error", + event_type="resource_delete_permission_denied", + component="resource_service", + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=pe, + db=db, + ) raise - except ResourceNotFoundError: + except ResourceNotFoundError as rnfe: # ResourceNotFoundError is re-raised to be handled in the endpoint. + # Structured logging: Log not found error + structured_logger.log( + level="ERROR", + message="Resource deletion failed - resource not found", + event_type="resource_not_found", + component="resource_service", + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=rnfe, + db=db, + ) raise except Exception as e: db.rollback() + + # Structured logging: Log generic resource deletion failure + structured_logger.log( + level="ERROR", + message="Resource deletion failed", + event_type="resource_deletion_failed", + component="resource_service", + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=e, + db=db, + ) raise ResourceError(f"Failed to delete resource: {str(e)}") async def get_resource_by_id(self, db: Session, resource_id: str, include_inactive: bool = False) -> ResourceRead: @@ -1820,7 +2177,24 @@ async def get_resource_by_id(self, db: Session, resource_id: str, include_inacti raise ResourceNotFoundError(f"Resource not found: {resource_id}") - return self._convert_resource_to_read(resource) + resource_read = self._convert_resource_to_read(resource) + + structured_logger.log( + level="INFO", + message="Resource retrieved successfully", + event_type="resource_viewed", + component="resource_service", + team_id=getattr(resource, "team_id", None), + resource_type="resource", + resource_id=str(resource.id), + custom_fields={ + "resource_uri": resource.uri, + "include_inactive": include_inactive, + }, + db=db, + ) + + return resource_read async def _notify_resource_activated(self, resource: DbResource) -> None: """ @@ -1969,8 +2343,10 @@ async def _read_template_resource(self, db: Session, uri: str, include_inactive: # # Handle binary template raise NotImplementedError("Binary resource templates not yet supported") + except ResourceNotFoundError: + raise except Exception as e: - raise ResourceError(f"Failed 
to process template: {str(e)}") + raise ResourceError(f"Failed to process template: {str(e)}") from e def _build_regex(self, template: str) -> re.Pattern: """ diff --git a/mcpgateway/services/security_logger.py b/mcpgateway/services/security_logger.py new file mode 100644 index 000000000..1b2470691 --- /dev/null +++ b/mcpgateway/services/security_logger.py @@ -0,0 +1,597 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/security_logger.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Security Logger Service. + +This module provides specialized logging for security events, threat detection, +and audit trail management with automated threat analysis and alerting. +""" + +# Standard +from datetime import datetime, timedelta, timezone +from enum import Enum +import logging +from typing import Any, Dict, Optional + +# Third-Party +from sqlalchemy import func, select +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.config import settings +from mcpgateway.db import AuditTrail, SecurityEvent, SessionLocal +from mcpgateway.utils.correlation_id import get_correlation_id + +logger = logging.getLogger(__name__) + + +class SecuritySeverity(str, Enum): + """Security event severity levels.""" + + LOW = "LOW" + MEDIUM = "MEDIUM" + HIGH = "HIGH" + CRITICAL = "CRITICAL" + + +class SecurityEventType(str, Enum): + """Types of security events.""" + + AUTHENTICATION_FAILURE = "authentication_failure" + AUTHENTICATION_SUCCESS = "authentication_success" + AUTHORIZATION_FAILURE = "authorization_failure" + SUSPICIOUS_ACTIVITY = "suspicious_activity" + RATE_LIMIT_EXCEEDED = "rate_limit_exceeded" + BRUTE_FORCE_ATTEMPT = "brute_force_attempt" + TOKEN_MANIPULATION = "token_manipulation" # nosec B105 - Not a password, security event type constant + DATA_EXFILTRATION = "data_exfiltration" + PRIVILEGE_ESCALATION = "privilege_escalation" + INJECTION_ATTEMPT = "injection_attempt" + ANOMALOUS_BEHAVIOR = "anomalous_behavior" + + +class SecurityLogger: + """Specialized logger for security events and audit trails. + + Provides threat detection, security event logging, and audit trail + management with automated analysis and alerting capabilities. + """ + + def __init__(self): + """Initialize security logger.""" + self.failed_auth_threshold = getattr(settings, "security_failed_auth_threshold", 5) + self.threat_score_alert_threshold = getattr(settings, "security_threat_score_alert", 0.7) + self.rate_limit_window_minutes = getattr(settings, "security_rate_limit_window", 5) + + def log_authentication_attempt( + self, + user_id: str, + user_email: Optional[str], + auth_method: str, + success: bool, + client_ip: str, + user_agent: Optional[str] = None, + failure_reason: Optional[str] = None, + additional_context: Optional[Dict[str, Any]] = None, + db: Optional[Session] = None, + ) -> Optional[SecurityEvent]: + """Log authentication attempts with security analysis. 
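By way of illustration, a minimal sketch of how a caller might record a failed login through this method; the account, IP, and client values below are hypothetical, and omitting db lets the logger open and close its own session:

    from mcpgateway.services.security_logger import SecurityLogger

    security_logger = SecurityLogger()
    event = security_logger.log_authentication_attempt(
        user_id="alice",                      # hypothetical account
        user_email="alice@example.com",
        auth_method="basic",
        success=False,
        client_ip="203.0.113.7",
        user_agent="curl/8.5.0",
        failure_reason="invalid password",
    )
    if event is not None:                     # the event may be None if it could not be persisted
        print(event.severity, event.threat_score)
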
+ + Args: + user_id: User identifier + user_email: User email address + auth_method: Authentication method used + success: Whether authentication succeeded + client_ip: Client IP address + user_agent: Client user agent + failure_reason: Reason for failure if applicable + additional_context: Additional event context + db: Optional database session + + Returns: + Created SecurityEvent or None if logging disabled + """ + correlation_id = get_correlation_id() + + # Count recent failed attempts + failed_attempts = self._count_recent_failures(user_id=user_id, client_ip=client_ip, db=db) + + # Calculate threat score + threat_score = self._calculate_auth_threat_score(success=success, failed_attempts=failed_attempts, auth_method=auth_method) + + # Determine severity + if not success: + if failed_attempts >= self.failed_auth_threshold: + severity = SecuritySeverity.HIGH + elif failed_attempts >= 3: + severity = SecuritySeverity.MEDIUM + else: + severity = SecuritySeverity.LOW + else: + severity = SecuritySeverity.LOW + + # Build event description + description = f"Authentication {'successful' if success else 'failed'} for user {user_id}" + if not success and failure_reason: + description += f": {failure_reason}" + + # Build context + context = {"auth_method": auth_method, "failed_attempts_recent": failed_attempts, "user_agent": user_agent, **(additional_context or {})} + + # Create security event + event = self._create_security_event( + event_type=SecurityEventType.AUTHENTICATION_SUCCESS if success else SecurityEventType.AUTHENTICATION_FAILURE, + severity=severity, + category="authentication", + user_id=user_id, + user_email=user_email, + client_ip=client_ip, + user_agent=user_agent, + description=description, + threat_score=threat_score, + failed_attempts_count=failed_attempts, + context=context, + action_taken="allowed" if success else "denied", + correlation_id=correlation_id, + db=db, + ) + + # Log to standard logger as well + log_level = logging.WARNING if not success else logging.INFO + logger.log( + log_level, + f"Authentication attempt: {description}", + extra={ + "security_event": True, + "event_type": event.event_type if event else None, + "severity": severity.value, + "threat_score": threat_score, + "correlation_id": correlation_id, + }, + ) + + return event + + def log_data_access( # pylint: disable=too-many-positional-arguments + self, + action: str, + resource_type: str, + resource_id: str, + resource_name: Optional[str], + user_id: str, + user_email: Optional[str], + team_id: Optional[str], + client_ip: Optional[str], + user_agent: Optional[str], + success: bool, + data_classification: Optional[str] = None, + old_values: Optional[Dict[str, Any]] = None, + new_values: Optional[Dict[str, Any]] = None, + error_message: Optional[str] = None, + additional_context: Optional[Dict[str, Any]] = None, + db: Optional[Session] = None, + ) -> Optional[AuditTrail]: + """Log data access for audit trails. 
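Similarly, a hedged sketch of an audit-trail call for an update to a confidential resource (all values hypothetical). Supplying old_values and new_values lets the method compute a per-field changes diff, and a confidential, restricted, or sensitive classification additionally raises a companion security event:

    from mcpgateway.services.security_logger import SecurityLogger

    audit = SecurityLogger().log_data_access(
        action="update",
        resource_type="prompt",
        resource_id="42",
        resource_name="welcome-prompt",
        user_id="alice",
        user_email="alice@example.com",
        team_id="team-1",
        client_ip="203.0.113.7",
        user_agent="mcp-cli/0.1",             # hypothetical client
        success=True,
        data_classification="confidential",
        old_values={"visibility": "private"},
        new_values={"visibility": "team"},
    )
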
+ + Args: + action: Action performed (create, read, update, delete, execute) + resource_type: Type of resource accessed + resource_id: Resource identifier + resource_name: Resource name + user_id: User performing the action + user_email: User email + team_id: Team context + client_ip: Client IP address + user_agent: Client user agent + success: Whether action succeeded + data_classification: Data sensitivity classification + old_values: Previous values (for updates) + new_values: New values (for updates/creates) + error_message: Error message if failed + additional_context: Additional context + db: Optional database session + + Returns: + Created AuditTrail entry or None + """ + correlation_id = get_correlation_id() + + # Determine if audit requires review + requires_review = self._requires_audit_review(action=action, resource_type=resource_type, data_classification=data_classification, success=success) + + # Calculate changes + changes = None + if old_values and new_values: + changes = {k: {"old": old_values.get(k), "new": new_values.get(k)} for k in set(old_values.keys()) | set(new_values.keys()) if old_values.get(k) != new_values.get(k)} + + # Create audit trail + audit = self._create_audit_trail( + action=action, + resource_type=resource_type, + resource_id=resource_id, + resource_name=resource_name, + user_id=user_id, + user_email=user_email, + team_id=team_id, + client_ip=client_ip, + user_agent=user_agent, + success=success, + old_values=old_values, + new_values=new_values, + changes=changes, + data_classification=data_classification, + requires_review=requires_review, + error_message=error_message, + context=additional_context, + correlation_id=correlation_id, + db=db, + ) + + # Log sensitive data access as security event + if data_classification in ["confidential", "restricted", "sensitive"]: + self._create_security_event( + event_type="data_access", + severity=SecuritySeverity.MEDIUM if success else SecuritySeverity.HIGH, + category="data_access", + user_id=user_id, + user_email=user_email, + client_ip=client_ip or "unknown", + user_agent=user_agent, + description=f"Access to {data_classification} {resource_type}: {resource_name or resource_id}", + threat_score=0.3 if success else 0.6, + context={ + "action": action, + "resource_type": resource_type, + "resource_id": resource_id, + "data_classification": data_classification, + }, + correlation_id=correlation_id, + db=db, + ) + + return audit + + def log_suspicious_activity( + self, + activity_type: str, + description: str, + user_id: Optional[str], + user_email: Optional[str], + client_ip: str, + user_agent: Optional[str], + threat_score: float, + severity: SecuritySeverity, + threat_indicators: Dict[str, Any], + action_taken: str, + additional_context: Optional[Dict[str, Any]] = None, + db: Optional[Session] = None, + ) -> Optional[SecurityEvent]: + """Log suspicious activity with threat analysis. 
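A sketch of how a detector might report a brute-force pattern through this method (values hypothetical; severity comes from the SecuritySeverity enum defined earlier in this module):

    from mcpgateway.services.security_logger import SecurityLogger, SecuritySeverity

    SecurityLogger().log_suspicious_activity(
        activity_type="brute_force_attempt",
        description="11 failed logins from a single IP within 5 minutes",
        user_id="alice",
        user_email="alice@example.com",
        client_ip="203.0.113.7",
        user_agent="python-requests/2.32",
        threat_score=0.8,
        severity=SecuritySeverity.HIGH,
        threat_indicators={"failed_attempts": 11, "window_minutes": 5},
        action_taken="blocked",
    )
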
+ + Args: + activity_type: Type of suspicious activity + description: Event description + user_id: User identifier (if known) + user_email: User email (if known) + client_ip: Client IP address + user_agent: Client user agent + threat_score: Calculated threat score (0.0-1.0) + severity: Event severity + threat_indicators: Dictionary of threat indicators + action_taken: Action taken in response + additional_context: Additional context + db: Optional database session + + Returns: + Created SecurityEvent or None + """ + correlation_id = get_correlation_id() + + event = self._create_security_event( + event_type=SecurityEventType.SUSPICIOUS_ACTIVITY, + severity=severity, + category="suspicious_activity", + user_id=user_id, + user_email=user_email, + client_ip=client_ip, + user_agent=user_agent, + description=description, + threat_score=threat_score, + threat_indicators=threat_indicators, + action_taken=action_taken, + context=additional_context, + correlation_id=correlation_id, + db=db, + ) + + logger.warning( + f"Suspicious activity detected: {description}", + extra={ + "security_event": True, + "activity_type": activity_type, + "severity": severity.value, + "threat_score": threat_score, + "action_taken": action_taken, + "correlation_id": correlation_id, + }, + ) + + return event + + def _count_recent_failures(self, user_id: Optional[str] = None, client_ip: Optional[str] = None, minutes: Optional[int] = None, db: Optional[Session] = None) -> int: + """Count recent authentication failures. + + Args: + user_id: User identifier + client_ip: Client IP address + minutes: Time window in minutes + db: Optional database session + + Returns: + Count of recent failures + """ + if not user_id and not client_ip: + return 0 + + window_minutes = minutes or self.rate_limit_window_minutes + since = datetime.now(timezone.utc) - timedelta(minutes=window_minutes) + + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + stmt = select(func.count(SecurityEvent.id)).where(SecurityEvent.event_type == SecurityEventType.AUTHENTICATION_FAILURE, SecurityEvent.timestamp >= since) # pylint: disable=not-callable + + if user_id: + stmt = stmt.where(SecurityEvent.user_id == user_id) + if client_ip: + stmt = stmt.where(SecurityEvent.client_ip == client_ip) + + result = db.execute(stmt).scalar() + return result or 0 + + finally: + if should_close: + db.close() + + def _calculate_auth_threat_score(self, success: bool, failed_attempts: int, auth_method: str) -> float: # pylint: disable=unused-argument + """Calculate threat score for authentication attempt. + + Args: + success: Whether authentication succeeded + failed_attempts: Count of recent failures + auth_method: Authentication method used + + Returns: + Threat score from 0.0 to 1.0 + """ + if success: + return 0.0 + + # Base score for failure + score = 0.3 + + # Increase based on failed attempts + if failed_attempts >= 10: + score += 0.5 + elif failed_attempts >= 5: + score += 0.3 + elif failed_attempts >= 3: + score += 0.2 + + # Cap at 1.0 + return min(score, 1.0) + + def _requires_audit_review(self, action: str, resource_type: str, data_classification: Optional[str], success: bool) -> bool: + """Determine if audit entry requires manual review. 
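The review rules implemented just below can be summarised with a few hypothetical calls (this pokes a private helper purely for illustration): failed or destructive actions on confidential or restricted data are flagged, as is any change to roles, permissions, or team membership:

    from mcpgateway.services.security_logger import SecurityLogger

    checker = SecurityLogger()
    print(checker._requires_audit_review("delete", "prompt", "confidential", True))  # True  - deleting confidential data
    print(checker._requires_audit_review("read", "resource", "restricted", False))   # True  - failed access to restricted data
    print(checker._requires_audit_review("update", "team_member", None, True))       # True  - privilege-related resource
    print(checker._requires_audit_review("read", "prompt", "public", True))          # False
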
+ + Args: + action: Action performed + resource_type: Resource type + data_classification: Data classification + success: Whether action succeeded + + Returns: + True if review required + """ + # Failed actions on sensitive data require review + if not success and data_classification in ["confidential", "restricted"]: + return True + + # Deletions of sensitive data require review + if action == "delete" and data_classification in ["confidential", "restricted"]: + return True + + # Privilege modifications require review + if resource_type in ["role", "permission", "team_member"]: + return True + + return False + + def _create_security_event( + self, + event_type: str, + severity: SecuritySeverity, + category: str, + client_ip: str, + description: str, + threat_score: float, + user_id: Optional[str] = None, + user_email: Optional[str] = None, + user_agent: Optional[str] = None, + action_taken: Optional[str] = None, + failed_attempts_count: int = 0, + threat_indicators: Optional[Dict[str, Any]] = None, + context: Optional[Dict[str, Any]] = None, + correlation_id: Optional[str] = None, + db: Optional[Session] = None, + ) -> Optional[SecurityEvent]: + """Create a security event record. + + Args: + event_type: Type of security event + severity: Event severity + category: Event category + client_ip: Client IP address + description: Event description + threat_score: Threat score (0.0-1.0) + user_id: User identifier + user_email: User email + user_agent: User agent string + action_taken: Action taken + failed_attempts_count: Failed attempts count + threat_indicators: Threat indicators + context: Additional context + correlation_id: Correlation ID + db: Optional database session + + Returns: + Created SecurityEvent or None + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + event = SecurityEvent( + event_type=event_type, + severity=severity.value, + category=category, + user_id=user_id, + user_email=user_email, + client_ip=client_ip, + user_agent=user_agent, + description=description, + action_taken=action_taken, + threat_score=threat_score, + threat_indicators=threat_indicators or {}, + failed_attempts_count=failed_attempts_count, + context=context, + correlation_id=correlation_id, + ) + + db.add(event) + db.commit() + db.refresh(event) + + return event + + except Exception as e: + logger.error(f"Failed to create security event: {e}") + db.rollback() + return None + + finally: + if should_close: + db.close() + + def _create_audit_trail( # pylint: disable=too-many-positional-arguments + self, + action: str, + resource_type: str, + user_id: str, + success: bool, + resource_id: Optional[str] = None, + resource_name: Optional[str] = None, + user_email: Optional[str] = None, + team_id: Optional[str] = None, + client_ip: Optional[str] = None, + user_agent: Optional[str] = None, + old_values: Optional[Dict[str, Any]] = None, + new_values: Optional[Dict[str, Any]] = None, + changes: Optional[Dict[str, Any]] = None, + data_classification: Optional[str] = None, + requires_review: bool = False, + error_message: Optional[str] = None, + context: Optional[Dict[str, Any]] = None, + correlation_id: Optional[str] = None, + db: Optional[Session] = None, + ) -> Optional[AuditTrail]: + """Create an audit trail record. 
+ + Args: + action: Action performed + resource_type: Resource type + user_id: User performing action + success: Whether action succeeded + resource_id: Resource identifier + resource_name: Resource name + user_email: User email + team_id: Team context + client_ip: Client IP + user_agent: User agent + old_values: Previous values + new_values: New values + changes: Calculated changes + data_classification: Data classification + requires_review: Whether manual review needed + error_message: Error message if failed + context: Additional context + correlation_id: Correlation ID + db: Optional database session + + Returns: + Created AuditTrail or None + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + audit = AuditTrail( + action=action, + resource_type=resource_type, + resource_id=resource_id, + resource_name=resource_name, + user_id=user_id, + user_email=user_email, + team_id=team_id, + client_ip=client_ip, + user_agent=user_agent, + old_values=old_values, + new_values=new_values, + changes=changes, + data_classification=data_classification, + requires_review=requires_review, + success=success, + error_message=error_message, + context=context, + correlation_id=correlation_id, + ) + + db.add(audit) + db.commit() + db.refresh(audit) + + return audit + + except Exception as e: + logger.error(f"Failed to create audit trail: {e}") + db.rollback() + return None + + finally: + if should_close: + db.close() + + +# Global security logger instance +_security_logger: Optional[SecurityLogger] = None + + +def get_security_logger() -> SecurityLogger: + """Get or create the global security logger instance. + + Returns: + Global SecurityLogger instance + """ + global _security_logger # pylint: disable=global-statement + if _security_logger is None: + _security_logger = SecurityLogger() + return _security_logger diff --git a/mcpgateway/services/server_service.py b/mcpgateway/services/server_service.py index e7f8aae4d..01f20b304 100644 --- a/mcpgateway/services/server_service.py +++ b/mcpgateway/services/server_service.py @@ -33,7 +33,10 @@ from mcpgateway.db import ServerMetric from mcpgateway.db import Tool as DbTool from mcpgateway.schemas import ServerCreate, ServerMetrics, ServerRead, ServerUpdate, TopPerformer +from mcpgateway.services.audit_trail_service import get_audit_trail_service from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.performance_tracker import get_performance_tracker +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr @@ -130,6 +133,9 @@ def __init__(self) -> None: """ self._event_subscribers: List[asyncio.Queue] = [] self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify) + self._structured_logger = get_structured_logger("server_service") + self._audit_trail = get_audit_trail_service() + self._performance_tracker = get_performance_tracker() async def initialize(self) -> None: """Initialize the server service.""" @@ -394,7 +400,7 @@ async def register_server( Examples: >>> from mcpgateway.services.server_service import ServerService - >>> from unittest.mock import MagicMock, AsyncMock + >>> from unittest.mock import MagicMock, AsyncMock, patch >>> from mcpgateway.schemas import ServerRead >>> service = 
ServerService() >>> db = MagicMock() @@ -406,6 +412,8 @@ async def register_server( >>> db.refresh = MagicMock() >>> service._notify_server_added = AsyncMock() >>> service._convert_server_to_read = MagicMock(return_value='server_read') + >>> service._structured_logger = MagicMock() # Mock structured logger to prevent database writes + >>> service._audit_trail = MagicMock() # Mock audit trail to prevent database writes >>> ServerRead.model_validate = MagicMock(return_value='server_read') >>> import asyncio >>> asyncio.run(service.register_server(db, server_in)) @@ -549,17 +557,91 @@ async def register_server( logger.debug(f"Server Data: {server_data}") await self._notify_server_added(db_server) logger.info(f"Registered server: {server_in.name}") + + # Structured logging: Audit trail for server creation + self._audit_trail.log_action( + user_id=created_by or "system", + action="create_server", + resource_type="server", + resource_id=db_server.id, + details={ + "server_name": db_server.name, + "visibility": visibility, + "team_id": team_id, + "associated_tools_count": len(db_server.tools), + "associated_resources_count": len(db_server.resources), + "associated_prompts_count": len(db_server.prompts), + "associated_a2a_agents_count": len(db_server.a2a_agents), + }, + metadata={ + "created_from_ip": created_from_ip, + "created_via": created_via, + "created_user_agent": created_user_agent, + }, + ) + + # Structured logging: Log successful server creation + self._structured_logger.log( + level="INFO", + message="Server created successfully", + event_type="server_created", + component="server_service", + server_id=db_server.id, + server_name=db_server.name, + visibility=visibility, + created_by=created_by, + user_email=created_by, + ) + db_server.team = self._get_team_name(db, db_server.team_id) return self._convert_server_to_read(db_server) except IntegrityError as ie: db.rollback() logger.error(f"IntegrityErrors in group: {ie}") + + # Structured logging: Log database integrity error + self._structured_logger.log( + level="ERROR", + message="Server creation failed due to database integrity error", + event_type="server_creation_failed", + component="server_service", + server_name=server_in.name, + error_type="IntegrityError", + error_message=str(ie), + created_by=created_by, + user_email=created_by, + ) raise ie except ServerNameConflictError as se: db.rollback() + + # Structured logging: Log name conflict error + self._structured_logger.log( + level="WARNING", + message="Server creation failed due to name conflict", + event_type="server_name_conflict", + component="server_service", + server_name=server_in.name, + visibility=visibility, + created_by=created_by, + user_email=created_by, + ) raise se except Exception as ex: db.rollback() + + # Structured logging: Log generic server creation failure + self._structured_logger.log( + level="ERROR", + message="Server creation failed", + event_type="server_creation_failed", + component="server_service", + server_name=server_in.name, + error_type=type(ex).__name__, + error_message=str(ex), + created_by=created_by, + user_email=created_by, + ) raise ServerError(f"Failed to register server: {str(ex)}") async def list_servers(self, db: Session, include_inactive: bool = False, tags: Optional[List[str]] = None) -> List[ServerRead]: @@ -731,7 +813,39 @@ async def get_server(self, db: Session, server_id: str) -> ServerRead: } logger.debug(f"Server Data: {server_data}") server.team = self._get_team_name(db, server.team_id) if server else None - return 
self._convert_server_to_read(server) + server_read = self._convert_server_to_read(server) + + self._structured_logger.log( + level="INFO", + message="Server retrieved successfully", + event_type="server_viewed", + component="server_service", + server_id=server.id, + server_name=server.name, + team_id=getattr(server, "team_id", None), + resource_type="server", + resource_id=server.id, + custom_fields={ + "enabled": server.enabled, + "tool_count": len(getattr(server, "tools", []) or []), + "resource_count": len(getattr(server, "resources", []) or []), + "prompt_count": len(getattr(server, "prompts", []) or []), + }, + db=db, + ) + + self._audit_trail.log_action( + action="view_server", + resource_type="server", + resource_id=server.id, + resource_name=server.name, + user_id="system", + team_id=getattr(server, "team_id", None), + context={"enabled": server.enabled}, + db=db, + ) + + return server_read async def update_server( self, @@ -769,7 +883,7 @@ async def update_server( Examples: >>> from mcpgateway.services.server_service import ServerService - >>> from unittest.mock import MagicMock, AsyncMock + >>> from unittest.mock import MagicMock, AsyncMock, patch >>> from mcpgateway.schemas import ServerRead >>> service = ServerService() >>> db = MagicMock() @@ -783,6 +897,8 @@ async def update_server( >>> db.refresh = MagicMock() >>> db.execute.return_value.scalar_one_or_none.return_value = None >>> service._convert_server_to_read = MagicMock(return_value='server_read') + >>> service._structured_logger = MagicMock() # Mock structured logger to prevent database writes + >>> service._audit_trail = MagicMock() # Mock audit trail to prevent database writes >>> ServerRead.model_validate = MagicMock(return_value='server_read') >>> server_update = MagicMock() >>> server_update.id = None # No UUID change @@ -927,6 +1043,44 @@ async def update_server( await self._notify_server_updated(server) logger.info(f"Updated server: {server.name}") + # Structured logging: Audit trail for server update + changes = [] + if server_update.name: + changes.append(f"name: {server_update.name}") + if server_update.visibility: + changes.append(f"visibility: {server_update.visibility}") + if server_update.team_id: + changes.append(f"team_id: {server_update.team_id}") + + self._audit_trail.log_action( + user_id=user_email or "system", + action="update_server", + resource_type="server", + resource_id=server.id, + details={ + "server_name": server.name, + "changes": ", ".join(changes) if changes else "metadata only", + "version": server.version, + }, + metadata={ + "modified_from_ip": modified_from_ip, + "modified_via": modified_via, + "modified_user_agent": modified_user_agent, + }, + ) + + # Structured logging: Log successful server update + self._structured_logger.log( + level="INFO", + message="Server updated successfully", + event_type="server_updated", + component="server_service", + server_id=server.id, + server_name=server.name, + modified_by=user_email, + user_email=user_email, + ) + # Build a dictionary with associated IDs server_data = { "id": server.id, @@ -946,13 +1100,50 @@ async def update_server( except IntegrityError as ie: db.rollback() logger.error(f"IntegrityErrors in group: {ie}") + + # Structured logging: Log database integrity error + self._structured_logger.log( + level="ERROR", + message="Server update failed due to database integrity error", + event_type="server_update_failed", + component="server_service", + server_id=server_id, + error_type="IntegrityError", + error_message=str(ie), + 
modified_by=user_email, + user_email=user_email, + ) raise ie except ServerNameConflictError as snce: db.rollback() logger.error(f"Server name conflict: {snce}") + + # Structured logging: Log name conflict error + self._structured_logger.log( + level="WARNING", + message="Server update failed due to name conflict", + event_type="server_name_conflict", + component="server_service", + server_id=server_id, + modified_by=user_email, + user_email=user_email, + ) raise snce except Exception as e: db.rollback() + + # Structured logging: Log generic server update failure + self._structured_logger.log( + level="ERROR", + message="Server update failed", + event_type="server_update_failed", + component="server_service", + server_id=server_id, + error_type=type(e).__name__, + error_message=str(e), + modified_by=user_email, + user_email=user_email, + ) raise ServerError(f"Failed to update server: {str(e)}") async def toggle_server_status(self, db: Session, server_id: str, activate: bool, user_email: Optional[str] = None) -> ServerRead: @@ -974,7 +1165,7 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool Examples: >>> from mcpgateway.services.server_service import ServerService - >>> from unittest.mock import MagicMock, AsyncMock + >>> from unittest.mock import MagicMock, AsyncMock, patch >>> from mcpgateway.schemas import ServerRead >>> service = ServerService() >>> db = MagicMock() @@ -985,6 +1176,8 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool >>> service._notify_server_activated = AsyncMock() >>> service._notify_server_deactivated = AsyncMock() >>> service._convert_server_to_read = MagicMock(return_value='server_read') + >>> service._structured_logger = MagicMock() # Mock structured logger to prevent database writes + >>> service._audit_trail = MagicMock() # Mock audit trail to prevent database writes >>> ServerRead.model_validate = MagicMock(return_value='server_read') >>> import asyncio >>> asyncio.run(service.toggle_server_status(db, 'server_id', True)) @@ -1014,6 +1207,31 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool await self._notify_server_deactivated(server) logger.info(f"Server {server.name} {'activated' if activate else 'deactivated'}") + # Structured logging: Audit trail for server status toggle + self._audit_trail.log_action( + user_id=user_email or "system", + action="activate_server" if activate else "deactivate_server", + resource_type="server", + resource_id=server.id, + details={ + "server_name": server.name, + "new_status": "active" if activate else "inactive", + }, + ) + + # Structured logging: Log server status change + self._structured_logger.log( + level="INFO", + message=f"Server {'activated' if activate else 'deactivated'}", + event_type="server_status_changed", + component="server_service", + server_id=server.id, + server_name=server.name, + new_status="active" if activate else "inactive", + changed_by=user_email, + user_email=user_email, + ) + server_data = { "id": server.id, "name": server.name, @@ -1030,9 +1248,30 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool logger.info(f"Server Data: {server_data}") return self._convert_server_to_read(server) except PermissionError as e: + # Structured logging: Log permission error + self._structured_logger.log( + level="WARNING", + message="Server status toggle failed due to insufficient permissions", + event_type="server_status_toggle_permission_denied", + component="server_service", + 
server_id=server_id, + user_email=user_email, + ) raise e except Exception as e: db.rollback() + + # Structured logging: Log generic server status toggle failure + self._structured_logger.log( + level="ERROR", + message="Server status toggle failed", + event_type="server_status_toggle_failed", + component="server_service", + server_id=server_id, + error_type=type(e).__name__, + error_message=str(e), + user_email=user_email, + ) raise ServerError(f"Failed to toggle server status: {str(e)}") async def delete_server(self, db: Session, server_id: str, user_email: Optional[str] = None) -> None: @@ -1050,7 +1289,7 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[ Examples: >>> from mcpgateway.services.server_service import ServerService - >>> from unittest.mock import MagicMock, AsyncMock + >>> from unittest.mock import MagicMock, AsyncMock, patch >>> service = ServerService() >>> db = MagicMock() >>> server = MagicMock() @@ -1058,6 +1297,8 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[ >>> db.delete = MagicMock() >>> db.commit = MagicMock() >>> service._notify_server_deleted = AsyncMock() + >>> service._structured_logger = MagicMock() # Mock structured logger to prevent database writes + >>> service._audit_trail = MagicMock() # Mock audit trail to prevent database writes >>> import asyncio >>> asyncio.run(service.delete_server(db, 'server_id', 'user@example.com')) """ @@ -1081,11 +1322,56 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[ await self._notify_server_deleted(server_info) logger.info(f"Deleted server: {server_info['name']}") - except PermissionError: + + # Structured logging: Audit trail for server deletion + self._audit_trail.log_action( + user_id=user_email or "system", + action="delete_server", + resource_type="server", + resource_id=server_info["id"], + details={ + "server_name": server_info["name"], + }, + ) + + # Structured logging: Log successful server deletion + self._structured_logger.log( + level="INFO", + message="Server deleted successfully", + event_type="server_deleted", + component="server_service", + server_id=server_info["id"], + server_name=server_info["name"], + deleted_by=user_email, + user_email=user_email, + ) + except PermissionError as pe: db.rollback() - raise + + # Structured logging: Log permission error + self._structured_logger.log( + level="WARNING", + message="Server deletion failed due to insufficient permissions", + event_type="server_deletion_permission_denied", + component="server_service", + server_id=server_id, + user_email=user_email, + ) + raise pe except Exception as e: db.rollback() + + # Structured logging: Log generic server deletion failure + self._structured_logger.log( + level="ERROR", + message="Server deletion failed", + event_type="server_deletion_failed", + component="server_service", + server_id=server_id, + error_type=type(e).__name__, + error_message=str(e), + user_email=user_email, + ) raise ServerError(f"Failed to delete server: {str(e)}") async def _publish_event(self, event: Dict[str, Any]) -> None: diff --git a/mcpgateway/services/structured_logger.py b/mcpgateway/services/structured_logger.py new file mode 100644 index 000000000..0d8a4a599 --- /dev/null +++ b/mcpgateway/services/structured_logger.py @@ -0,0 +1,441 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/structured_logger.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Structured Logger Service. 
+ +This module provides comprehensive structured logging with component-based loggers, +automatic enrichment, intelligent routing, and database persistence. +""" + +# Standard +from datetime import datetime, timezone +from enum import Enum +import logging +import os +import socket +import traceback +from typing import Any, Dict, List, Optional, Union + +# Third-Party +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.config import settings +from mcpgateway.db import SessionLocal, StructuredLogEntry +from mcpgateway.services.performance_tracker import get_performance_tracker +from mcpgateway.utils.correlation_id import get_correlation_id + +logger = logging.getLogger(__name__) + + +class LogLevel(str, Enum): + """Log levels matching Python logging.""" + + DEBUG = "DEBUG" + INFO = "INFO" + WARNING = "WARNING" + ERROR = "ERROR" + CRITICAL = "CRITICAL" + + +class LogCategory(str, Enum): + """Log categories for classification.""" + + APPLICATION = "application" + REQUEST = "request" + SECURITY = "security" + PERFORMANCE = "performance" + DATABASE = "database" + AUTHENTICATION = "authentication" + AUTHORIZATION = "authorization" + EXTERNAL_SERVICE = "external_service" + BUSINESS_LOGIC = "business_logic" + SYSTEM = "system" + + +class LogEnricher: + """Enriches log entries with contextual information.""" + + @staticmethod + def enrich(entry: Dict[str, Any]) -> Dict[str, Any]: + """Enrich log entry with system and context information. + + Args: + entry: Base log entry + + Returns: + Enriched log entry + """ + # Get correlation ID + correlation_id = get_correlation_id() + if correlation_id: + entry["correlation_id"] = correlation_id + + # Add hostname and process info + entry.setdefault("hostname", socket.gethostname()) + entry.setdefault("process_id", os.getpid()) + + # Add timestamp if not present + if "timestamp" not in entry: + entry["timestamp"] = datetime.now(timezone.utc) + + # Add performance metrics if available + try: + perf_tracker = get_performance_tracker() + if correlation_id and perf_tracker and hasattr(perf_tracker, "get_current_operations"): + current_ops = perf_tracker.get_current_operations(correlation_id) # pylint: disable=no-member + if current_ops: + entry["active_operations"] = len(current_ops) + except Exception: # nosec B110 - Graceful degradation if performance tracker unavailable + # Silently skip if performance tracker is unavailable or method doesn't exist + pass + + # Add OpenTelemetry trace context if available + try: + # Third-Party + from opentelemetry import trace # pylint: disable=import-outside-toplevel + + span = trace.get_current_span() + if span and span.get_span_context().is_valid: + ctx = span.get_span_context() + entry["trace_id"] = format(ctx.trace_id, "032x") + entry["span_id"] = format(ctx.span_id, "016x") + except (ImportError, Exception): + pass + + return entry + + +class LogRouter: + """Routes log entries to appropriate destinations.""" + + def __init__(self): + """Initialize log router.""" + self.database_enabled = getattr(settings, "structured_logging_database_enabled", True) + self.external_enabled = getattr(settings, "structured_logging_external_enabled", False) + + def route(self, entry: Dict[str, Any], db: Optional[Session] = None) -> None: + """Route log entry to configured destinations. 
+ + Args: + entry: Log entry to route + db: Optional database session + """ + # Always log to standard Python logger + self._log_to_python_logger(entry) + + # Persist to database if enabled + if self.database_enabled: + self._persist_to_database(entry, db) + + # Send to external systems if enabled + if self.external_enabled: + self._send_to_external(entry) + + def _log_to_python_logger(self, entry: Dict[str, Any]) -> None: + """Log to standard Python logger. + + Args: + entry: Log entry + """ + level_str = entry.get("level", "INFO") + level = getattr(logging, level_str, logging.INFO) + + message = entry.get("message", "") + component = entry.get("component", "") + + log_message = f"[{component}] {message}" if component else message + + # Build extra dict for structured logging + extra = {k: v for k, v in entry.items() if k not in ["message", "level"]} + + logger.log(level, log_message, extra=extra) + + def _persist_to_database(self, entry: Dict[str, Any], db: Optional[Session] = None) -> None: + """Persist log entry to database. + + Args: + entry: Log entry + db: Optional database session + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + # Build error_details JSON from error-related fields + error_details = None + if any([entry.get("error_type"), entry.get("error_message"), entry.get("error_stack_trace"), entry.get("error_context")]): + error_details = { + "error_type": entry.get("error_type"), + "error_message": entry.get("error_message"), + "error_stack_trace": entry.get("error_stack_trace"), + "error_context": entry.get("error_context"), + } + + # Build performance_metrics JSON from performance-related fields + performance_metrics = None + perf_fields = { + "database_query_count": entry.get("database_query_count"), + "database_query_duration_ms": entry.get("database_query_duration_ms"), + "cache_hits": entry.get("cache_hits"), + "cache_misses": entry.get("cache_misses"), + "external_api_calls": entry.get("external_api_calls"), + "external_api_duration_ms": entry.get("external_api_duration_ms"), + "memory_usage_mb": entry.get("memory_usage_mb"), + "cpu_usage_percent": entry.get("cpu_usage_percent"), + } + if any(v is not None for v in perf_fields.values()): + performance_metrics = {k: v for k, v in perf_fields.items() if v is not None} + + # Build threat_indicators JSON from security-related fields + threat_indicators = None + security_fields = { + "security_event_type": entry.get("security_event_type"), + "security_threat_score": entry.get("security_threat_score"), + "security_action_taken": entry.get("security_action_taken"), + } + if any(v is not None for v in security_fields.values()): + threat_indicators = {k: v for k, v in security_fields.items() if v is not None} + + # Build context JSON from remaining fields + context_fields = { + "team_id": entry.get("team_id"), + "request_query": entry.get("request_query"), + "request_headers": entry.get("request_headers"), + "request_body_size": entry.get("request_body_size"), + "response_status_code": entry.get("response_status_code"), + "response_body_size": entry.get("response_body_size"), + "response_headers": entry.get("response_headers"), + "business_event_type": entry.get("business_event_type"), + "business_entity_type": entry.get("business_entity_type"), + "business_entity_id": entry.get("business_entity_id"), + "resource_type": entry.get("resource_type"), + "resource_id": entry.get("resource_id"), + "resource_action": entry.get("resource_action"), + "category": entry.get("category"), 
+ "custom_fields": entry.get("custom_fields"), + "tags": entry.get("tags"), + "metadata": entry.get("metadata"), + } + context = {k: v for k, v in context_fields.items() if v is not None} + + # Determine if this is a security event + is_security_event = entry.get("is_security_event", False) or bool(threat_indicators) + security_severity = entry.get("security_severity") + + log_entry = StructuredLogEntry( + timestamp=entry.get("timestamp", datetime.now(timezone.utc)), + level=entry.get("level", "INFO"), + component=entry.get("component"), + message=entry.get("message", ""), + correlation_id=entry.get("correlation_id"), + request_id=entry.get("request_id"), + trace_id=entry.get("trace_id"), + span_id=entry.get("span_id"), + user_id=entry.get("user_id"), + user_email=entry.get("user_email"), + client_ip=entry.get("client_ip"), + user_agent=entry.get("user_agent"), + request_method=entry.get("request_method"), + request_path=entry.get("request_path"), + duration_ms=entry.get("duration_ms"), + operation_type=entry.get("operation_type"), + is_security_event=is_security_event, + security_severity=security_severity, + threat_indicators=threat_indicators, + context=context if context else None, + error_details=error_details, + performance_metrics=performance_metrics, + hostname=entry.get("hostname"), + process_id=entry.get("process_id"), + thread_id=entry.get("thread_id"), + environment=entry.get("environment", getattr(settings, "environment", "development")), + version=entry.get("version", getattr(settings, "version", "unknown")), + ) + + db.add(log_entry) + db.commit() + + except Exception as e: + logger.error(f"Failed to persist log entry to database: {e}", exc_info=True) + # Also print to console for immediate visibility + print(f"ERROR persisting log to database: {e}") + traceback.print_exc() + if db: + db.rollback() + + finally: + if should_close: + db.close() + + def _send_to_external(self, entry: Dict[str, Any]) -> None: + """Send log entry to external systems. + + Args: + entry: Log entry + """ + # Placeholder for external logging integration + # Will be implemented in log exporters + + +class StructuredLogger: + """Main structured logger with enrichment and routing.""" + + def __init__(self, component: str): + """Initialize structured logger. + + Args: + component: Component name for log entries + """ + self.component = component + self.enricher = LogEnricher() + self.router = LogRouter() + + def log( + self, + level: Union[LogLevel, str], + message: str, + category: Optional[Union[LogCategory, str]] = None, + user_id: Optional[str] = None, + user_email: Optional[str] = None, + team_id: Optional[str] = None, + error: Optional[Exception] = None, + duration_ms: Optional[float] = None, + custom_fields: Optional[Dict[str, Any]] = None, + tags: Optional[List[str]] = None, + db: Optional[Session] = None, + **kwargs: Any, + ) -> None: + """Log a structured message. 
+ + Args: + level: Log level + message: Log message + category: Log category + user_id: User identifier + user_email: User email + team_id: Team identifier + error: Exception object + duration_ms: Operation duration + custom_fields: Additional custom fields + tags: Log tags + db: Optional database session + **kwargs: Additional fields to include + """ + # Build base entry + entry: Dict[str, Any] = { + "level": level.value if isinstance(level, LogLevel) else level, + "component": self.component, + "message": message, + "category": category.value if isinstance(category, LogCategory) and category else category if category else None, + "user_id": user_id, + "user_email": user_email, + "team_id": team_id, + "duration_ms": duration_ms, + "custom_fields": custom_fields, + "tags": tags, + } + + # Add error information if present + if error: + entry["error_type"] = type(error).__name__ + entry["error_message"] = str(error) + entry["error_stack_trace"] = "".join(traceback.format_exception(type(error), error, error.__traceback__)) + + # Add any additional kwargs + entry.update(kwargs) + + # Enrich entry with context + entry = self.enricher.enrich(entry) + + # Route to destinations + self.router.route(entry, db) + + def debug(self, message: str, **kwargs: Any) -> None: + """Log debug message. + + Args: + message: Log message + **kwargs: Additional context fields + """ + self.log(LogLevel.DEBUG, message, **kwargs) + + def info(self, message: str, **kwargs: Any) -> None: + """Log info message. + + Args: + message: Log message + **kwargs: Additional context fields + """ + self.log(LogLevel.INFO, message, **kwargs) + + def warning(self, message: str, **kwargs: Any) -> None: + """Log warning message. + + Args: + message: Log message + **kwargs: Additional context fields + """ + self.log(LogLevel.WARNING, message, **kwargs) + + def error(self, message: str, error: Optional[Exception] = None, **kwargs: Any) -> None: + """Log error message. + + Args: + message: Log message + error: Exception object if available + **kwargs: Additional context fields + """ + self.log(LogLevel.ERROR, message, error=error, **kwargs) + + def critical(self, message: str, error: Optional[Exception] = None, **kwargs: Any) -> None: + """Log critical message. + + Args: + message: Log message + error: Exception object if available + **kwargs: Additional context fields + """ + self.log(LogLevel.CRITICAL, message, error=error, **kwargs) + + +class ComponentLogger: + """Logger factory for component-specific loggers.""" + + _loggers: Dict[str, StructuredLogger] = {} + + @classmethod + def get_logger(cls, component: str) -> StructuredLogger: + """Get or create a logger for a specific component. + + Args: + component: Component name + + Returns: + StructuredLogger instance for the component + """ + if component not in cls._loggers: + cls._loggers[component] = StructuredLogger(component) + return cls._loggers[component] + + @classmethod + def clear_loggers(cls) -> None: + """Clear all cached loggers (useful for testing).""" + cls._loggers.clear() + + +# Global structured logger instance for backward compatibility +def get_structured_logger(component: str = "mcpgateway") -> StructuredLogger: + """Get a structured logger instance. 
+ + Args: + component: Component name + + Returns: + StructuredLogger instance + """ + return ComponentLogger.get_logger(component) diff --git a/mcpgateway/services/token_storage_service.py b/mcpgateway/services/token_storage_service.py index ef6b18e66..c2585843f 100644 --- a/mcpgateway/services/token_storage_service.py +++ b/mcpgateway/services/token_storage_service.py @@ -280,10 +280,10 @@ def _is_token_expired(self, token_record: OAuthToken, threshold_seconds: int = 3 >>> svc._is_token_expired(rec_past, threshold_seconds=0) True >>> svc._is_token_expired(SimpleNamespace(expires_at=None)) - True + False """ if not token_record.expires_at: - return True + return False expires_at = token_record.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=timezone.utc) diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 731b67e0a..5616e0ff4 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -63,10 +63,14 @@ ) from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA from mcpgateway.schemas import ToolCreate, ToolRead, ToolUpdate, TopPerformer +from mcpgateway.services.audit_trail_service import get_audit_trail_service from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager +from mcpgateway.services.performance_tracker import get_performance_tracker +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.team_management_service import TeamManagementService +from mcpgateway.utils.correlation_id import get_correlation_id from mcpgateway.utils.create_slug import slugify from mcpgateway.utils.display_name import generate_display_name from mcpgateway.utils.metrics_common import build_top_performers @@ -81,6 +85,11 @@ logging_service = LoggingService() logger = logging_service.get_logger(__name__) +# Initialize performance tracker, structured logger, and audit trail for tool operations +perf_tracker = get_performance_tracker() +structured_logger = get_structured_logger("tool_service") +audit_trail = get_audit_trail_service() + def extract_using_jq(data, jq_filter=""): """ @@ -710,17 +719,109 @@ async def register_tool( db.commit() db.refresh(db_tool) await self._notify_tool_added(db_tool) + + # Structured logging: Audit trail for tool creation + audit_trail.log_action( + user_id=created_by or "system", + action="create_tool", + resource_type="tool", + resource_id=db_tool.id, + resource_name=db_tool.name, + user_email=owner_email, + team_id=team_id, + client_ip=created_from_ip, + user_agent=created_user_agent, + new_values={ + "name": db_tool.name, + "display_name": db_tool.display_name, + "visibility": visibility, + "integration_type": db_tool.integration_type, + }, + context={ + "created_via": created_via, + "import_batch_id": import_batch_id, + "federation_source": federation_source, + }, + db=db, + ) + + # Structured logging: Log successful tool creation + structured_logger.log( + level="INFO", + message="Tool created successfully", + event_type="tool_created", + component="tool_service", + user_id=created_by, + user_email=owner_email, + team_id=team_id, + resource_type="tool", + resource_id=db_tool.id, + custom_fields={ + "tool_name": db_tool.name, + "visibility": visibility, + "integration_type": db_tool.integration_type, + }, + db=db, + ) + + # Refresh db_tool after logging commits (they expire the 
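# Illustrative sketch (not part of the patch): the logging pattern this diff
# introduces in tool_service, using only call shapes that appear in the diff.
# The component name, event type, and field values are hypothetical, and letting
# the audit call manage its own session via db=None is an assumption here.
from mcpgateway.services.audit_trail_service import get_audit_trail_service
from mcpgateway.services.structured_logger import get_structured_logger

example_logger = get_structured_logger("example_service")  # cached per component
example_audit = get_audit_trail_service()

example_logger.log(
    level="INFO",
    message="Example entity created",
    event_type="example_created",           # extra kwargs are merged into the entry
    component="example_service",
    user_email="user@example.com",
    custom_fields={"entity_name": "demo"},
)
example_audit.log_action(
    user_id="user@example.com",
    action="create_example",
    resource_type="example",
    resource_id="42",
    resource_name="demo",
    db=None,
)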
session objects) + db.refresh(db_tool) return self._convert_tool_to_read(db_tool) except IntegrityError as ie: db.rollback() logger.error(f"IntegrityError during tool registration: {ie}") + + # Structured logging: Log database integrity error + structured_logger.log( + level="ERROR", + message="Tool creation failed due to database integrity error", + event_type="tool_creation_failed", + component="tool_service", + user_id=created_by, + user_email=owner_email, + error=ie, + custom_fields={ + "tool_name": tool.name, + }, + db=db, + ) raise ie except ToolNameConflictError as tnce: db.rollback() logger.error(f"ToolNameConflictError during tool registration: {tnce}") + + # Structured logging: Log name conflict error + structured_logger.log( + level="WARNING", + message="Tool creation failed due to name conflict", + event_type="tool_name_conflict", + component="tool_service", + user_id=created_by, + user_email=owner_email, + custom_fields={ + "tool_name": tool.name, + "visibility": visibility, + }, + db=db, + ) raise tnce except Exception as e: db.rollback() + + # Structured logging: Log generic tool creation failure + structured_logger.log( + level="ERROR", + message="Tool creation failed", + event_type="tool_creation_failed", + component="tool_service", + user_id=created_by, + user_email=owner_email, + error=e, + custom_fields={ + "tool_name": tool.name, + }, + db=db, + ) raise ToolError(f"Failed to register tool: {str(e)}") async def list_tools( @@ -1009,7 +1110,25 @@ async def get_tool(self, db: Session, tool_id: str) -> ToolRead: if not tool: raise ToolNotFoundError(f"Tool not found: {tool_id}") tool.team = self._get_team_name(db, getattr(tool, "team_id", None)) - return self._convert_tool_to_read(tool) + + tool_read = self._convert_tool_to_read(tool) + + structured_logger.log( + level="INFO", + message="Tool retrieved successfully", + event_type="tool_viewed", + component="tool_service", + team_id=getattr(tool, "team_id", None), + resource_type="tool", + resource_id=str(tool.id), + custom_fields={ + "tool_name": tool.name, + "include_metrics": bool(getattr(tool_read, "metrics", {})), + }, + db=db, + ) + + return tool_read async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str] = None) -> None: """ @@ -1053,15 +1172,75 @@ async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str] raise PermissionError("Only the owner can delete this tool") tool_info = {"id": tool.id, "name": tool.name} + tool_name = tool.name + tool_team_id = tool.team_id + db.delete(tool) db.commit() await self._notify_tool_deleted(tool_info) logger.info(f"Permanently deleted tool: {tool_info['name']}") - except PermissionError: + + # Structured logging: Audit trail for tool deletion + audit_trail.log_action( + user_id=user_email or "system", + action="delete_tool", + resource_type="tool", + resource_id=tool_info["id"], + resource_name=tool_name, + user_email=user_email, + team_id=tool_team_id, + old_values={ + "name": tool_name, + }, + db=db, + ) + + # Structured logging: Log successful tool deletion + structured_logger.log( + level="INFO", + message="Tool deleted successfully", + event_type="tool_deleted", + component="tool_service", + user_email=user_email, + team_id=tool_team_id, + resource_type="tool", + resource_id=tool_info["id"], + custom_fields={ + "tool_name": tool_name, + }, + db=db, + ) + except PermissionError as pe: db.rollback() + + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Tool deletion failed due to 
permission error", + event_type="tool_delete_permission_denied", + component="tool_service", + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=pe, + db=db, + ) raise except Exception as e: db.rollback() + + # Structured logging: Log generic tool deletion failure + structured_logger.log( + level="ERROR", + message="Tool deletion failed", + event_type="tool_deletion_failed", + component="tool_service", + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=e, + db=db, + ) raise ToolError(f"Failed to delete tool: {str(e)}") async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, reachable: bool, user_email: Optional[str] = None) -> ToolRead: @@ -1140,11 +1319,74 @@ async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, re await self._notify_tool_activated(tool) logger.info(f"Tool: {tool.name} is {'enabled' if activate else 'disabled'}{' and accessible' if reachable else ' but inaccessible'}") + + # Structured logging: Audit trail for tool status toggle + audit_trail.log_action( + user_id=user_email or "system", + action="toggle_tool_status", + resource_type="tool", + resource_id=tool.id, + resource_name=tool.name, + user_email=user_email, + team_id=tool.team_id, + new_values={ + "enabled": tool.enabled, + "reachable": tool.reachable, + }, + context={ + "action": "activate" if activate else "deactivate", + }, + db=db, + ) + + # Structured logging: Log successful tool status toggle + structured_logger.log( + level="INFO", + message=f"Tool {'activated' if activate else 'deactivated'} successfully", + event_type="tool_status_toggled", + component="tool_service", + user_email=user_email, + team_id=tool.team_id, + resource_type="tool", + resource_id=tool.id, + custom_fields={ + "tool_name": tool.name, + "enabled": tool.enabled, + "reachable": tool.reachable, + }, + db=db, + ) + return self._convert_tool_to_read(tool) except PermissionError as e: + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Tool status toggle failed due to permission error", + event_type="tool_toggle_permission_denied", + component="tool_service", + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=e, + db=db, + ) raise e except Exception as e: db.rollback() + + # Structured logging: Log generic tool status toggle failure + structured_logger.log( + level="ERROR", + message="Tool status toggle failed", + event_type="tool_toggle_failed", + component="tool_service", + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=e, + db=db, + ) raise ToolError(f"Failed to toggle tool status: {str(e)}") async def invoke_tool( @@ -1182,15 +1424,17 @@ async def invoke_tool( Examples: >>> from mcpgateway.services.tool_service import ToolService - >>> from unittest.mock import MagicMock + >>> from unittest.mock import MagicMock, patch >>> service = ToolService() >>> db = MagicMock() >>> tool = MagicMock() >>> db.execute.return_value.scalar_one_or_none.side_effect = [tool, None] >>> tool.reachable = True >>> import asyncio - >>> result = asyncio.run(service.invoke_tool(db, 'tool_name', {})) - >>> isinstance(result, object) + >>> # Mock structured_logger to prevent database writes during doctest + >>> with patch('mcpgateway.services.tool_service.structured_logger'): + ... result = asyncio.run(service.invoke_tool(db, 'tool_name', {})) + ... 
isinstance(result, object) True """ # pylint: disable=comparison-with-callable @@ -1224,7 +1468,8 @@ async def invoke_tool( global_context.server_id = gateway_id else: # Create new context (fallback when middleware didn't run) - request_id = uuid.uuid4().hex + # Use correlation ID from context if available, otherwise generate new one + request_id = get_correlation_id() or uuid.uuid4().hex gateway_id = getattr(tool, "gateway_id", "unknown") server_id = gateway_id if isinstance(gateway_id, str) else "unknown" global_context = GlobalContext(request_id=request_id, server_id=server_id, tenant_id=None, user=app_user_email) @@ -1445,12 +1690,58 @@ async def connect_to_sse_server(server_url: str, headers: dict = headers): Returns: ToolResult: Result of tool call + + Raises: + Exception: On connection or communication errors """ - async with sse_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - tool_call_result = await session.call_tool(tool.original_name, arguments) - return tool_call_result + # Get correlation ID for distributed tracing + correlation_id = get_correlation_id() + + # Add correlation ID to headers + if correlation_id and headers: + headers["X-Correlation-ID"] = correlation_id + + # Log MCP call start + mcp_start_time = time.time() + structured_logger.log( + level="INFO", + message=f"MCP tool call started: {tool.original_name}", + component="tool_service", + correlation_id=correlation_id, + metadata={"event": "mcp_call_started", "tool_name": tool.original_name, "tool_id": tool.id, "server_url": server_url, "transport": "sse"}, + ) + + try: + async with sse_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + tool_call_result = await session.call_tool(tool.original_name, arguments) + + # Log successful MCP call + mcp_duration_ms = (time.time() - mcp_start_time) * 1000 + structured_logger.log( + level="INFO", + message=f"MCP tool call completed: {tool.original_name}", + component="tool_service", + correlation_id=correlation_id, + duration_ms=mcp_duration_ms, + metadata={"event": "mcp_call_completed", "tool_name": tool.original_name, "tool_id": tool.id, "transport": "sse", "success": True}, + ) + + return tool_call_result + except Exception as e: + # Log failed MCP call + mcp_duration_ms = (time.time() - mcp_start_time) * 1000 + structured_logger.log( + level="ERROR", + message=f"MCP tool call failed: {tool.original_name}", + component="tool_service", + correlation_id=correlation_id, + duration_ms=mcp_duration_ms, + error_details={"error_type": type(e).__name__, "error_message": str(e)}, + metadata={"event": "mcp_call_failed", "tool_name": tool.original_name, "tool_id": tool.id, "transport": "sse"}, + ) + raise async def connect_to_streamablehttp_server(server_url: str, headers: dict = headers): """Connect to an MCP server running with Streamable HTTP transport. 
@@ -1461,12 +1752,58 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head Returns: ToolResult: Result of tool call + + Raises: + Exception: On connection or communication errors """ - async with streamablehttp_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - tool_call_result = await session.call_tool(tool.original_name, arguments) - return tool_call_result + # Get correlation ID for distributed tracing + correlation_id = get_correlation_id() + + # Add correlation ID to headers + if correlation_id and headers: + headers["X-Correlation-ID"] = correlation_id + + # Log MCP call start + mcp_start_time = time.time() + structured_logger.log( + level="INFO", + message=f"MCP tool call started: {tool.original_name}", + component="tool_service", + correlation_id=correlation_id, + metadata={"event": "mcp_call_started", "tool_name": tool.original_name, "tool_id": tool.id, "server_url": server_url, "transport": "streamablehttp"}, + ) + + try: + async with streamablehttp_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + tool_call_result = await session.call_tool(tool.original_name, arguments) + + # Log successful MCP call + mcp_duration_ms = (time.time() - mcp_start_time) * 1000 + structured_logger.log( + level="INFO", + message=f"MCP tool call completed: {tool.original_name}", + component="tool_service", + correlation_id=correlation_id, + duration_ms=mcp_duration_ms, + metadata={"event": "mcp_call_completed", "tool_name": tool.original_name, "tool_id": tool.id, "transport": "streamablehttp", "success": True}, + ) + + return tool_call_result + except Exception as e: + # Log failed MCP call + mcp_duration_ms = (time.time() - mcp_start_time) * 1000 + structured_logger.log( + level="ERROR", + message=f"MCP tool call failed: {tool.original_name}", + component="tool_service", + correlation_id=correlation_id, + duration_ms=mcp_duration_ms, + error_details={"error_type": type(e).__name__, "error_message": str(e)}, + metadata={"event": "mcp_call_failed", "tool_name": tool.original_name, "tool_id": tool.id, "transport": "streamablehttp"}, + ) + raise tool_gateway_id = tool.gateway_id tool_gateway = db.execute(select(DbGateway).where(DbGateway.id == tool_gateway_id).where(DbGateway.enabled)).scalar_one_or_none() @@ -1546,12 +1883,44 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head span.set_attribute("error.message", str(e)) raise ToolInvocationError(f"Tool invocation failed: {error_message}") finally: + # Calculate duration + duration_ms = (time.monotonic() - start_time) * 1000 + # Add final span attributes if span: span.set_attribute("success", success) - span.set_attribute("duration.ms", (time.monotonic() - start_time) * 1000) + span.set_attribute("duration.ms", duration_ms) + + # Record tool metric await self._record_tool_metric(db, tool, start_time, success, error_message) + # Log structured message with performance tracking + if success: + structured_logger.info( + f"Tool '{name}' invoked successfully", + user_id=app_user_email, + resource_type="tool", + resource_id=str(tool.id), + resource_action="invoke", + duration_ms=duration_ms, + custom_fields={"tool_name": name, 
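# Illustrative sketch (not part of the patch): the correlation-ID propagation
# used by the SSE and streamable-HTTP wrappers above. The ID from the current
# request context is reused when present, otherwise a fresh one is minted, and
# it is forwarded to the downstream MCP server so both sides can be joined in
# the trace view. The outbound headers dict is hypothetical.
import uuid

from mcpgateway.utils.correlation_id import get_correlation_id

correlation_id = get_correlation_id() or uuid.uuid4().hex
outbound_headers = {"Authorization": "Bearer example-token"}
outbound_headers["X-Correlation-ID"] = correlation_id  # same header the diff injects
# ...make the downstream call, then log with correlation_id=correlation_id so the
# entry appears under /api/logs/trace/{correlation_id}.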
"integration_type": tool.integration_type, "arguments_count": len(arguments) if arguments else 0}, + ) + else: + structured_logger.error( + f"Tool '{name}' invocation failed", + error=Exception(error_message) if error_message else None, + user_id=app_user_email, + resource_type="tool", + resource_id=str(tool.id), + resource_action="invoke", + duration_ms=duration_ms, + custom_fields={"tool_name": name, "integration_type": tool.integration_type, "error_message": error_message}, + ) + + # Track performance with threshold checking + with perf_tracker.track_operation("tool_invocation", name): + pass # Duration already captured above + async def update_tool( self, db: Session, @@ -1696,24 +2065,142 @@ async def update_tool( db.refresh(tool) await self._notify_tool_updated(tool) logger.info(f"Updated tool: {tool.name}") + + # Structured logging: Audit trail for tool update + changes = [] + if tool_update.name: + changes.append(f"name: {tool_update.name}") + if tool_update.visibility: + changes.append(f"visibility: {tool_update.visibility}") + if tool_update.description: + changes.append("description updated") + + audit_trail.log_action( + user_id=user_email or modified_by or "system", + action="update_tool", + resource_type="tool", + resource_id=tool.id, + resource_name=tool.name, + user_email=user_email, + team_id=tool.team_id, + client_ip=modified_from_ip, + user_agent=modified_user_agent, + new_values={ + "name": tool.name, + "display_name": tool.display_name, + "version": tool.version, + }, + context={ + "modified_via": modified_via, + "changes": ", ".join(changes) if changes else "metadata only", + }, + db=db, + ) + + # Structured logging: Log successful tool update + structured_logger.log( + level="INFO", + message="Tool updated successfully", + event_type="tool_updated", + component="tool_service", + user_id=modified_by, + user_email=user_email, + team_id=tool.team_id, + resource_type="tool", + resource_id=tool.id, + custom_fields={ + "tool_name": tool.name, + "version": tool.version, + }, + db=db, + ) + return self._convert_tool_to_read(tool) - except PermissionError: + except PermissionError as pe: db.rollback() + + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Tool update failed due to permission error", + event_type="tool_update_permission_denied", + component="tool_service", + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=pe, + db=db, + ) raise except IntegrityError as ie: db.rollback() logger.error(f"IntegrityError during tool update: {ie}") + + # Structured logging: Log database integrity error + structured_logger.log( + level="ERROR", + message="Tool update failed due to database integrity error", + event_type="tool_update_failed", + component="tool_service", + user_id=modified_by, + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=ie, + db=db, + ) raise ie except ToolNotFoundError as tnfe: db.rollback() logger.error(f"Tool not found during update: {tnfe}") + + # Structured logging: Log not found error + structured_logger.log( + level="ERROR", + message="Tool update failed - tool not found", + event_type="tool_not_found", + component="tool_service", + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=tnfe, + db=db, + ) raise tnfe except ToolNameConflictError as tnce: db.rollback() logger.error(f"Tool name conflict during update: {tnce}") + + # Structured logging: Log name conflict error + structured_logger.log( + level="WARNING", + 
message="Tool update failed due to name conflict", + event_type="tool_name_conflict", + component="tool_service", + user_id=modified_by, + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=tnce, + db=db, + ) raise tnce except Exception as ex: db.rollback() + + # Structured logging: Log generic tool update failure + structured_logger.log( + level="ERROR", + message="Tool update failed", + event_type="tool_update_failed", + component="tool_service", + user_id=modified_by, + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=ex, + db=db, + ) raise ToolError(f"Failed to update tool: {str(ex)}") async def _notify_tool_updated(self, tool: DbTool) -> None: diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index 2d5b72563..f3e73d21a 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -300,6 +300,8 @@ function validateInputName(name, type = "input") { /** * Extracts content from various formats with fallback */ + +/** function extractContent(content, fallback = "") { if (typeof content === "object" && content !== null) { if (content.text !== undefined && content.text !== null) { @@ -314,6 +316,7 @@ function extractContent(content, fallback = "") { } return String(content || fallback); } + */ /** * SECURITY: Validate URL inputs @@ -705,6 +708,8 @@ function closeModal(modalId, clearId = null) { cleanupToolTestModal(); // ADD THIS LINE } else if (modalId === "prompt-test-modal") { cleanupPromptTestModal(); + } else if (modalId === "resource-test-modal") { + cleanupResourceTestModal(); } modal.classList.add("hidden"); @@ -3569,6 +3574,295 @@ function toggleA2AAuthFields(authType) { } } +// -------------------- Resource Testing ------------------ // + +// ----- URI Template Parsing -------------- // +function parseUriTemplate(template) { + const regex = /{([^}]+)}/g; + const fields = []; + let match; + + while ((match = regex.exec(template)) !== null) { + fields.push(match[1]); // capture inside {} + } + return fields; +} + +async function testResource(resourceId) { + try { + console.log(`Testing the resource: ${resourceId}`); + + const response = await fetchWithTimeout( + `${window.ROOT_PATH}/admin/resources/${encodeURIComponent(resourceId)}`, + ); + + if (!response.ok) { + let errorDetail = ""; + try { + const errorJson = await response.json(); + errorDetail = errorJson.detail || ""; + } catch (_) {} + + throw new Error( + `HTTP ${response.status}: ${errorDetail || response.statusText}`, + ); + } + + const data = await response.json(); + const resource = data.resource; + // console.log("Resource JSON:\n", JSON.stringify(resource, null, 2)); + openResourceTestModal(resource); + } catch (error) { + console.error("Error fetching resource details:", error); + const errorMessage = handleFetchError(error, "load resource details"); + showErrorMessage(errorMessage); + } +} + +function openResourceTestModal(resource) { + const title = document.getElementById("resource-test-modal-title"); + const fieldsContainer = document.getElementById( + "resource-test-form-fields", + ); + const resultBox = document.getElementById("resource-test-result"); + + title.textContent = `Test Resource: ${resource.name}`; + + fieldsContainer.innerHTML = ""; + resultBox.textContent = "Fill the fields and click Invoke Resource"; + + // 1️⃣ Build form fields ONLY if uriTemplate exists + if (resource.uriTemplate) { + const fieldNames = parseUriTemplate(resource.uriTemplate); + + fieldNames.forEach((name) => { + const div = 
document.createElement("div"); + div.className = "space-y-1"; + + div.innerHTML = ` + + + `; + + fieldsContainer.appendChild(div); + }); + } else { + // 2️⃣ If no template → show a simple message + fieldsContainer.innerHTML = ` +
+ This resource has no URI template. + Click "Invoke Resource" to test directly. +
+ `; + } + + window.CurrentResourceUnderTest = resource; + openModal("resource-test-modal"); +} + +async function runResourceTest() { + const resource = window.CurrentResourceUnderTest; + if (!resource) { + return; + } + + let finalUri = ""; + + if (resource.uriTemplate) { + finalUri = resource.uriTemplate; + + const fieldNames = parseUriTemplate(resource.uriTemplate); + fieldNames.forEach((name) => { + const value = document.getElementById( + `resource-field-${name}`, + ).value; + finalUri = finalUri.replace(`{${name}}`, encodeURIComponent(value)); + }); + } else { + finalUri = resource.uri; // direct test + } + + console.log("Final URI:", finalUri); + + const response = await fetchWithTimeout( + `${window.ROOT_PATH}/admin/resources/test/${encodeURIComponent(finalUri)}`, + ); + + const json = await response.json(); + + const resultBox = document.getElementById("resource-test-result"); + resultBox.innerHTML = ""; // clear previous + + const container = document.createElement("div"); + resultBox.appendChild(container); + + // Extract the content text (fallback if missing) + const content = json.content || {}; + let contentStr = content.text || JSON.stringify(content, null, 2); + + // Try to prettify JSON content + try { + const parsed = JSON.parse(contentStr); + contentStr = JSON.stringify(parsed, null, 2); + } catch (_) {} + + // ---- Content Section (same as prompt tester) ---- + const contentSection = document.createElement("div"); + contentSection.className = "mt-4"; + + // Header + const contentHeader = document.createElement("div"); + contentHeader.className = + "flex items-center justify-between cursor-pointer select-none p-2 bg-gray-200 dark:bg-gray-700 rounded"; + contentSection.appendChild(contentHeader); + + // Title + const contentTitle = document.createElement("strong"); + contentTitle.textContent = "Content"; + contentHeader.appendChild(contentTitle); + + // Right controls (arrow/copy/fullscreen/download) + const headerRight = document.createElement("div"); + headerRight.className = "flex items-center space-x-2"; + contentHeader.appendChild(headerRight); + + // Arrow icon + const toggleIcon = document.createElement("span"); + toggleIcon.innerHTML = "▶"; + toggleIcon.className = "transform transition-transform text-xs"; + headerRight.appendChild(toggleIcon); + + // Copy button + const copyBtn = document.createElement("button"); + copyBtn.textContent = "Copy"; + copyBtn.className = + "text-xs px-2 py-1 rounded bg-gray-300 dark:bg-gray-600 hover:bg-gray-400 dark:hover:bg-gray-500"; + headerRight.appendChild(copyBtn); + + // Fullscreen button + const fullscreenBtn = document.createElement("button"); + fullscreenBtn.textContent = "Fullscreen"; + fullscreenBtn.className = + "text-xs px-2 py-1 rounded bg-blue-300 dark:bg-blue-600 hover:bg-blue-400 dark:hover:bg-blue-500"; + headerRight.appendChild(fullscreenBtn); + + // Download button + const downloadBtn = document.createElement("button"); + downloadBtn.textContent = "Download"; + downloadBtn.className = + "text-xs px-2 py-1 rounded bg-green-300 dark:bg-green-600 hover:bg-green-400 dark:hover:bg-green-500"; + headerRight.appendChild(downloadBtn); + + // Collapsible body + const contentBody = document.createElement("div"); + contentBody.className = "hidden mt-2"; + contentSection.appendChild(contentBody); + + // Pre block + const contentPre = document.createElement("pre"); + contentPre.className = + "bg-gray-100 p-2 rounded overflow-auto max-h-80 dark:bg-gray-800 dark:text-gray-100 text-sm whitespace-pre-wrap"; + 
contentPre.textContent = contentStr; + contentBody.appendChild(contentPre); + + // Auto-collapse if too large + const lineCount = contentStr.split("\n").length; + + if (lineCount > 30) { + contentBody.classList.add("hidden"); + toggleIcon.style.transform = "rotate(0deg)"; + contentTitle.textContent = "Content (Large - Click to expand)"; + } else { + contentBody.classList.remove("hidden"); + toggleIcon.style.transform = "rotate(90deg)"; + } + + // Toggle expand/collapse + contentHeader.onclick = () => { + contentBody.classList.toggle("hidden"); + toggleIcon.style.transform = contentBody.classList.contains("hidden") + ? "rotate(0deg)" + : "rotate(90deg)"; + }; + + // Copy button + copyBtn.onclick = (event) => { + event.stopPropagation(); + navigator.clipboard.writeText(contentStr).then(() => { + copyBtn.textContent = "Copied!"; + setTimeout(() => (copyBtn.textContent = "Copy"), 1200); + }); + }; + + // Fullscreen mode + fullscreenBtn.onclick = (event) => { + event.stopPropagation(); + + const overlay = document.createElement("div"); + overlay.className = + "fixed inset-0 bg-black bg-opacity-70 z-[9999] flex items-center justify-center p-4"; + + const box = document.createElement("div"); + box.className = + "bg-white dark:bg-gray-900 rounded-lg w-full h-full p-4 overflow-auto"; + + const closeBtn = document.createElement("button"); + closeBtn.textContent = "Close"; + closeBtn.className = + "text-xs px-3 py-1 mb-2 rounded bg-red-400 hover:bg-red-500 dark:bg-red-700 dark:hover:bg-red-600"; + + closeBtn.onclick = () => overlay.remove(); + + const fsPre = document.createElement("pre"); + fsPre.className = + "bg-gray-100 p-4 rounded overflow-auto h-full dark:bg-gray-800 dark:text-gray-100 text-sm whitespace-pre-wrap"; + fsPre.textContent = contentStr; + + box.appendChild(closeBtn); + box.appendChild(fsPre); + overlay.appendChild(box); + document.body.appendChild(overlay); + }; + + // Download + downloadBtn.onclick = (event) => { + event.stopPropagation(); + + let blob; + let filename; + + // JSON? + try { + JSON.parse(contentStr); + blob = new Blob([contentStr], { type: "application/json" }); + filename = "resource.json"; + } catch (_) { + blob = new Blob([contentStr], { type: "text/plain" }); + filename = "resource.txt"; + } + + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + a.download = filename; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + }; + + container.appendChild(contentSection); + + // resultBox.textContent = JSON.stringify(json, null, 2); +} + +// -------------------- Resource Testing ------------------ // + /** * SECURE: View Resource function with safe display */ @@ -3594,7 +3888,9 @@ async function viewResource(resourceId) { const data = await response.json(); const resource = data.resource; - const content = data.content; + + // console.log("Resource JSON:\n", JSON.stringify(resource, null, 2)); + // const content = data.content; const resourceDetailsDiv = safeGetElement("resource-details"); if (resourceDetailsDiv) { @@ -3649,39 +3945,41 @@ async function viewResource(resourceId) { statusStrong.textContent = "Status: "; statusP.appendChild(statusStrong); + const isActive = resource.enabled === true; const statusSpan = document.createElement("span"); statusSpan.className = `px-2 inline-flex text-xs leading-5 font-semibold rounded-full ${ - resource.isActive + isActive ? 
"bg-green-100 text-green-800" : "bg-red-100 text-red-800" }`; - statusSpan.textContent = resource.isActive ? "Active" : "Inactive"; + statusSpan.textContent = isActive ? "Active" : "Inactive"; + statusP.appendChild(statusSpan); container.appendChild(statusP); // Content display - safely handle different types - const contentDiv = document.createElement("div"); - const contentStrong = document.createElement("strong"); - contentStrong.textContent = "Content:"; - contentDiv.appendChild(contentStrong); + // const contentDiv = document.createElement("div"); + // const contentStrong = document.createElement("strong"); + // contentStrong.textContent = "Content:"; + // contentDiv.appendChild(contentStrong); - const contentPre = document.createElement("pre"); - contentPre.className = - "mt-1 bg-gray-100 p-2 rounded overflow-auto max-h-80 dark:bg-gray-800 dark:text-gray-100"; + // const contentPre = document.createElement("pre"); + // contentPre.className = + // "mt-1 bg-gray-100 p-2 rounded overflow-auto max-h-80 dark:bg-gray-800 dark:text-gray-100"; - // Handle content display - extract actual content from object if needed - let contentStr = extractContent( - content, - resource.description || "No content available", - ); + // // Handle content display - extract actual content from object if needed + // let contentStr = extractContent( + // content, + // resource.description || "No content available", + // ); - if (!contentStr.trim()) { - contentStr = resource.description || "No content available"; - } + // if (!contentStr.trim()) { + // contentStr = resource.description || "No content available"; + // } - contentPre.textContent = contentStr; - contentDiv.appendChild(contentPre); - container.appendChild(contentDiv); + // contentPre.textContent = contentStr; + // contentDiv.appendChild(contentPre); + // container.appendChild(contentDiv); // Metrics display if (resource.metrics) { @@ -3855,21 +4153,29 @@ async function viewResource(resourceId) { /** * SECURE: Edit Resource function with validation */ -async function editResource(resourceUri) { +async function editResource(resourceId) { try { - console.log(`Editing resource: ${resourceUri}`); + console.log(`Editing resource: ${resourceId}`); const response = await fetchWithTimeout( - `${window.ROOT_PATH}/admin/resources/${encodeURIComponent(resourceUri)}`, + `${window.ROOT_PATH}/admin/resources/${encodeURIComponent(resourceId)}`, ); if (!response.ok) { - throw new Error(`HTTP ${response.status}: ${response.statusText}`); + let errorDetail = ""; + try { + const errorJson = await response.json(); + errorDetail = errorJson.detail || ""; + } catch (_) {} + + throw new Error( + `HTTP ${response.status}: ${errorDetail || response.statusText}`, + ); } const data = await response.json(); const resource = data.resource; - const content = data.content; + // const content = data.content; // Ensure hidden inactive flag is preserved const isInactiveCheckedBool = isInactiveChecked("resources"); let hiddenField = safeGetElement("edit-resource-show-inactive"); @@ -3917,7 +4223,7 @@ async function editResource(resourceUri) { // Set form action and populate fields with validation if (editForm) { - editForm.action = `${window.ROOT_PATH}/admin/resources/${encodeURIComponent(resourceUri)}/edit`; + editForm.action = `${window.ROOT_PATH}/admin/resources/${encodeURIComponent(resourceId)}/edit`; } // Validate inputs @@ -3928,7 +4234,7 @@ async function editResource(resourceUri) { const nameField = safeGetElement("edit-resource-name"); const descField = 
safeGetElement("edit-resource-description"); const mimeField = safeGetElement("edit-resource-mime-type"); - const contentField = safeGetElement("edit-resource-content"); + // const contentField = safeGetElement("edit-resource-content"); if (uriField && uriValidation.valid) { uriField.value = uriValidation.value; @@ -3956,33 +4262,33 @@ async function editResource(resourceUri) { tagsField.value = rawTags.join(", "); } - if (contentField) { - let contentStr = extractContent( - content, - resource.description || "No content available", - ); + // if (contentField) { + // let contentStr = extractContent( + // content, + // resource.description || "No content available", + // ); - if (!contentStr.trim()) { - contentStr = resource.description || "No content available"; - } + // if (!contentStr.trim()) { + // contentStr = resource.description || "No content available"; + // } - contentField.value = contentStr; - } + // contentField.value = contentStr; + // } - // Update CodeMirror editor if it exists - if (window.editResourceContentEditor) { - let contentStr = extractContent( - content, - resource.description || "No content available", - ); + // // Update CodeMirror editor if it exists + // if (window.editResourceContentEditor) { + // let contentStr = extractContent( + // content, + // resource.description || "No content available", + // ); - if (!contentStr.trim()) { - contentStr = resource.description || "No content available"; - } + // if (!contentStr.trim()) { + // contentStr = resource.description || "No content available"; + // } - window.editResourceContentEditor.setValue(contentStr); - window.editResourceContentEditor.refresh(); - } + // window.editResourceContentEditor.setValue(contentStr); + // window.editResourceContentEditor.refresh(); + // } openModal("resource-edit-modal"); @@ -6267,6 +6573,14 @@ function showTab(tabName) { initializeLLMChat(); } + if (tabName === "logs") { + // Load structured logs when tab is first opened + const logsTbody = safeGetElement("logs-tbody"); + if (logsTbody && logsTbody.children.length === 0) { + searchStructuredLogs(); + } + } + if (tabName === "teams") { // Load Teams list if not already loaded const teamsList = safeGetElement("teams-list"); @@ -9277,19 +9591,22 @@ async function loadTools() { console.log("Loading tools..."); try { if (toolBody !== null) { - toolBody.innerHTML = ` + toolBody.innerHTML = ` Loading tools... `; - const response = await fetch(`${window.ROOT_PATH}/tools`, { + const response = await fetch(`${window.ROOT_PATH}/admin/tools`, { method: "GET", }); if (!response.ok) { throw new Error("Failed to load tools"); } - const tools = await response.json(); // 👈 expect JSON array + let tools = await response.json(); // 👈 expect JSON array + if ("data" in tools) { + tools = tools.data; + } console.log("Fetched tools:", tools); // document.getElementById("temp_lable").innerText = `Loaded ${tools.length} tools`; @@ -11688,6 +12005,42 @@ async function runPromptTest() { } } +/** + * Clean up resource test modal state + */ +function cleanupResourceTestModal() { + try { + // Clear stored state + window.CurrentResourceUnderTest = null; + + // Reset form fields container + const fieldsContainer = safeGetElement("resource-test-form-fields"); + if (fieldsContainer) { + fieldsContainer.innerHTML = ""; + } + + // Reset result box + const resultBox = safeGetElement("resource-test-result"); + if (resultBox) { + resultBox.innerHTML = ` +
+ Fill the fields and click Invoke Resource +
+ `; + } + + // Hide loading if exists + const loading = safeGetElement("resource-test-loading"); + if (loading) { + loading.classList.add("hidden"); + } + + console.log("✓ Resource test modal cleaned up"); + } catch (err) { + console.error("Error cleaning up resource test modal:", err); + } +} + /** * Clean up prompt test modal state */ @@ -15452,6 +15805,8 @@ window.editTool = editTool; window.testTool = testTool; window.validateTool = validateTool; window.viewResource = viewResource; +window.runResourceTest = runResourceTest; +window.testResource = testResource; window.editResource = editResource; window.viewPrompt = viewPrompt; window.editPrompt = editPrompt; @@ -18159,11 +18514,29 @@ async function getAuthToken() { if (!token) { token = localStorage.getItem("auth_token"); } - console.log("MY TOKEN GENERATED:", token); - return token || ""; } +/** + * Fetch helper that always includes auth context. + * Ensures HTTP-only cookies are sent even when JS cannot read them. + */ +async function fetchWithAuth(url, options = {}) { + const opts = { ...options }; + // Always send same-origin cookies unless caller overrides explicitly + opts.credentials = options.credentials || "same-origin"; + + // Clone headers to avoid mutating caller-provided object + const headers = new Headers(options.headers || {}); + const token = await getAuthToken(); + if (token) { + headers.set("Authorization", `Bearer ${token}`); + } + opts.headers = headers; + + return fetch(url, opts); +} + // Expose token management functions to global scope window.loadTokensList = loadTokensList; window.setupCreateTokenForm = setupCreateTokenForm; @@ -23632,89 +24005,460 @@ function updateEntityStatus(type, data) { updateEntityActionButtons(actionCell, type, data.id, isEnabled); } } +// ============================================================================ +// Structured Logging UI Functions +// ============================================================================ -/** - * Generates the HTML for the status badge (Active/Inactive/Offline) - */ -function generateStatusBadgeHtml(enabled, reachable, typeLabel) { - const label = typeLabel - ? typeLabel.charAt(0).toUpperCase() + typeLabel.slice(1) - : "Item"; +// Current log search state +let currentLogPage = 0; +const currentLogLimit = 50; +// eslint-disable-next-line no-unused-vars +let currentLogFilters = {}; +const PERFORMANCE_HISTORY_HOURS = 24; +const PERFORMANCE_AGGREGATION_OPTIONS = { + "5m": { label: "5-minute aggregation", query: "5m" }, + "24h": { label: "24-hour aggregation", query: "24h" }, +}; +let currentPerformanceAggregationKey = "5m"; - if (!enabled) { - // CASE 1: Inactive (Manually disabled) -> RED - return ` -
- - Inactive - - - -
`; - } else if (!reachable) { - // CASE 2: Offline (Enabled but Unreachable/Health Check Failed) -> YELLOW - return ` -
- - Offline - - - -
`; - } else { - // CASE 3: Active (Enabled and Reachable) -> GREEN - return ` -
- - Active - - - -
`; - } +function getPerformanceAggregationConfig( + rangeKey = currentPerformanceAggregationKey, +) { + return ( + PERFORMANCE_AGGREGATION_OPTIONS[rangeKey] || + PERFORMANCE_AGGREGATION_OPTIONS["5m"] + ); } -/** - * Dynamically updates the action buttons (Activate/Deactivate) inside the table cell - */ +function getPerformanceAggregationLabel( + rangeKey = currentPerformanceAggregationKey, +) { + return getPerformanceAggregationConfig(rangeKey).label; +} -function updateEntityActionButtons(cell, type, id, isEnabled) { - // We look for the form that toggles activation inside the cell - const form = cell.querySelector('form[action*="/toggle"]'); - if (!form) { - return; - } +function getPerformanceAggregationQuery( + rangeKey = currentPerformanceAggregationKey, +) { + return getPerformanceAggregationConfig(rangeKey).query; +} - // The HTML structure for the button - // Ensure we are flipping the button state correctly based on isEnabled +function syncPerformanceAggregationSelect() { + const select = document.getElementById("performance-aggregation-select"); + if (select && select.value !== currentPerformanceAggregationKey) { + select.value = currentPerformanceAggregationKey; + } +} - if (isEnabled) { - // If Enabled -> Show Deactivate Button - form.innerHTML = ` - - - `; +function setPerformanceAggregationVisibility(shouldShow) { + const controls = document.getElementById( + "performance-aggregation-controls", + ); + if (!controls) { + return; + } + if (shouldShow) { + controls.classList.remove("hidden"); } else { - // If Disabled -> Show Activate Button - form.innerHTML = ` - - - `; + controls.classList.add("hidden"); } } -// CRITICAL DEBUG AND FIX FOR MCP SERVERS SEARCH -console.log("🔧 LOADING MCP SERVERS SEARCH DEBUG FUNCTIONS..."); +function setLogFiltersVisibility(shouldShow) { + const filters = document.getElementById("log-filters"); + if (!filters) { + return; + } + if (shouldShow) { + filters.classList.remove("hidden"); + } else { + filters.classList.add("hidden"); + } +} -// Emergency fix function for MCP Servers search -window.emergencyFixMCPSearch = function () { - console.log("🚨 EMERGENCY FIX: Attempting to fix MCP Servers search..."); +function handlePerformanceAggregationChange(event) { + const selectedKey = event?.target?.value; + if (selectedKey && PERFORMANCE_AGGREGATION_OPTIONS[selectedKey]) { + showPerformanceMetrics(selectedKey); + } +} - // Find the search input +/** + * Search structured logs with filters + */ +async function searchStructuredLogs() { + setPerformanceAggregationVisibility(false); + setLogFiltersVisibility(true); + const levelFilter = document.getElementById("log-level-filter")?.value; + const componentFilter = document.getElementById( + "log-component-filter", + )?.value; + const searchQuery = document.getElementById("log-search")?.value; + + // Restore default log table headers (in case we're coming from performance metrics view) + restoreLogTableHeaders(); + + // Build search request + const searchRequest = { + limit: currentLogLimit, + offset: currentLogPage * currentLogLimit, + sort_by: "timestamp", + sort_order: "desc", + }; + + // Only add filters if they have actual values (not empty strings) + if (searchQuery && searchQuery.trim() !== "") { + const trimmedSearch = searchQuery.trim(); + // Check if search is a correlation ID (32 hex chars or UUID format) or text search + const correlationIdPattern = + /^([0-9a-f]{32}|[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})$/i; + if (correlationIdPattern.test(trimmedSearch)) { + 
searchRequest.correlation_id = trimmedSearch; + } else { + searchRequest.search_text = trimmedSearch; + } + } + if (levelFilter && levelFilter !== "") { + searchRequest.level = [levelFilter]; + } + if (componentFilter && componentFilter !== "") { + searchRequest.component = [componentFilter]; + } + + // Store filters for pagination + currentLogFilters = searchRequest; + + try { + const response = await fetchWithAuth( + `${getRootPath()}/api/logs/search`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(searchRequest), + }, + ); + + if (!response.ok) { + const errorText = await response.text(); + console.error("API Error Response:", errorText); + throw new Error( + `Failed to search logs: ${response.statusText} - ${errorText}`, + ); + } + + const data = await response.json(); + displayLogResults(data); + } catch (error) { + console.error("Error searching logs:", error); + showToast("Failed to search logs: " + error.message, "error"); + document.getElementById("logs-tbody").innerHTML = ` + + ❌ Error: ${escapeHtml(error.message)} + + `; + } +} + +/** + * Display log search results + */ +function displayLogResults(data) { + const tbody = document.getElementById("logs-tbody"); + const logCount = document.getElementById("log-count"); + const logStats = document.getElementById("log-stats"); + const prevButton = document.getElementById("prev-page"); + const nextButton = document.getElementById("next-page"); + + // Ensure default headers are shown for log view + restoreLogTableHeaders(); + + if (!data.results || data.results.length === 0) { + tbody.innerHTML = ` + + 📭 No logs found matching your criteria + + `; + logCount.textContent = "0 logs"; + logStats.innerHTML = 'No results'; + return; + } + + // Update stats + logCount.textContent = `${data.total.toLocaleString()} logs`; + const start = currentLogPage * currentLogLimit + 1; + const end = Math.min(start + data.results.length - 1, data.total); + logStats.innerHTML = ` + + Showing ${start}-${end} of ${data.total.toLocaleString()} logs + + `; + + // Update pagination buttons + prevButton.disabled = currentLogPage === 0; + nextButton.disabled = end >= data.total; + + // Render log entries + tbody.innerHTML = data.results + .map((log) => { + const levelClass = getLogLevelClass(log.level); + const durationDisplay = log.duration_ms + ? `${log.duration_ms.toFixed(2)}ms` + : "-"; + const correlationId = log.correlation_id || "-"; + const userDisplay = log.user_email || log.user_id || "-"; + + return ` + + + ${formatTimestamp(log.timestamp)} + + + + ${log.level} + + + + ${escapeHtml(log.component || "-")} + + + ${escapeHtml(truncateText(log.message, 80))} + ${log.error_details ? '⚠️' : ""} + + + ${escapeHtml(userDisplay)} + + + ${durationDisplay} + + + ${ + correlationId !== "-" + ? 
` + + ` + : "-" + } + + + `; + }) + .join(""); +} + +/** + * Get CSS class for log level badge + */ +function getLogLevelClass(level) { + const classes = { + DEBUG: "bg-gray-200 text-gray-800 dark:bg-gray-600 dark:text-gray-200", + INFO: "bg-blue-200 text-blue-800 dark:bg-blue-800 dark:text-blue-200", + WARNING: + "bg-yellow-200 text-yellow-800 dark:bg-yellow-800 dark:text-yellow-200", + ERROR: "bg-red-200 text-red-800 dark:bg-red-800 dark:text-red-200", + CRITICAL: + "bg-purple-200 text-purple-800 dark:bg-purple-800 dark:text-purple-200", + }; + return classes[level] || classes.INFO; +} + +/** + * Format timestamp for display + */ +function formatTimestamp(timestamp) { + const date = new Date(timestamp); + return date.toLocaleString("en-US", { + month: "short", + day: "numeric", + hour: "2-digit", + minute: "2-digit", + second: "2-digit", + }); +} + +/** + * Truncate text with ellipsis + */ +function truncateText(text, maxLength) { + if (!text) { + return ""; + } + return text.length > maxLength + ? text.substring(0, maxLength) + "..." + : text; +} + +/** + * Show detailed log entry (future enhancement - modal) + */ +function showLogDetails(logId, correlationId) { + if (correlationId) { + showCorrelationTrace(correlationId); + } else { + console.log("Log details:", logId); + showToast("Full log details view coming soon", "info"); + } +} + +/** + * Restore default log table headers + */ +function restoreLogTableHeaders() { + const thead = document.getElementById("logs-thead"); + if (thead) { + thead.innerHTML = ` + + + Time + + + Level + + + Component + + + Message + + + User + + + Duration + + + Correlation ID + + + `; + } +} + +/** + * Trace all logs for a correlation ID + */ +async function showCorrelationTrace(correlationId) { + setPerformanceAggregationVisibility(false); + setLogFiltersVisibility(true); + if (!correlationId) { + const searchInput = document.getElementById("log-search"); + correlationId = prompt( + "Enter Correlation ID to trace:", + searchInput?.value || "", + ); + if (!correlationId) { + return; + } + } + + try { + const response = await fetchWithAuth( + `${getRootPath()}/api/logs/trace/${encodeURIComponent(correlationId)}`, + { + method: "GET", + }, + ); + + if (!response.ok) { + throw new Error(`Failed to fetch trace: ${response.statusText}`); + } + + const trace = await response.json(); + displayCorrelationTrace(trace); + } catch (error) { + console.error("Error fetching correlation trace:", error); + showToast( + "Failed to fetch correlation trace: " + error.message, + "error", + ); + } +} + +/** + * Generates the HTML for the status badge (Active/Inactive/Offline) + */ +function generateStatusBadgeHtml(enabled, reachable, typeLabel) { + const label = typeLabel + ? typeLabel.charAt(0).toUpperCase() + typeLabel.slice(1) + : "Item"; + + if (!enabled) { + // CASE 1: Inactive (Manually disabled) -> RED + return ` +
+ + Inactive + + + +
`; + } else if (!reachable) { + // CASE 2: Offline (Enabled but Unreachable/Health Check Failed) -> YELLOW + return ` +
+ + Offline + + + +
`; + } else { + // CASE 3: Active (Enabled and Reachable) -> GREEN + return ` +
+ + Active + + + +
`; + } +} + +/** + * Dynamically updates the action buttons (Activate/Deactivate) inside the table cell + */ +function updateEntityActionButtons(cell, type, id, isEnabled) { + // We look for the form that toggles activation inside the cell + const form = cell.querySelector('form[action*="/toggle"]'); + if (!form) { + return; + } + + // The HTML structure for the button + // Ensure we are flipping the button state correctly based on isEnabled + + if (isEnabled) { + // If Enabled -> Show Deactivate Button + form.innerHTML = ` + + + `; + } else { + // If Disabled -> Show Activate Button + form.innerHTML = ` + + + `; + } +} + +// CRITICAL DEBUG AND FIX FOR MCP SERVERS SEARCH +console.log("🔧 LOADING MCP SERVERS SEARCH DEBUG FUNCTIONS..."); + +// Emergency fix function for MCP Servers search +window.emergencyFixMCPSearch = function () { + console.log("🚨 EMERGENCY FIX: Attempting to fix MCP Servers search..."); + + // Find the search input const searchInput = document.getElementById("gateways-search-input"); if (!searchInput) { console.error("❌ Cannot find gateways-search-input element"); @@ -23790,3 +24534,719 @@ console.log("🔧 MCP SERVERS SEARCH DEBUG FUNCTIONS LOADED!"); console.log("💡 Use: window.emergencyFixMCPSearch() to fix search"); console.log("💡 Use: window.testMCPSearchManually('github') to test search"); console.log("💡 Use: window.debugMCPSearchState() to check current state"); + +/** + * Display correlation trace results + */ +function displayCorrelationTrace(trace) { + const tbody = document.getElementById("logs-tbody"); + const thead = document.getElementById("logs-thead"); + const logCount = document.getElementById("log-count"); + const logStats = document.getElementById("log-stats"); + + // Calculate total events + const totalEvents = + (trace.logs?.length || 0) + + (trace.security_events?.length || 0) + + (trace.audit_trails?.length || 0); + + // Update table headers for trace view + if (thead) { + thead.innerHTML = ` + + + Time + + + Event Type + + + Component + + + Message/Description + + + User + + + Duration + + + Status/Severity + + + `; + } + + // Update stats + logCount.textContent = `${totalEvents} events`; + logStats.innerHTML = ` +
+
+ Correlation ID:
+ ${escapeHtml(trace.correlation_id)} +
+
+ Logs: ${trace.log_count || 0} +
+
+ Security: ${trace.security_events?.length || 0} +
+
+ Audit: ${trace.audit_trails?.length || 0} +
+
+ Duration: ${trace.total_duration_ms ? trace.total_duration_ms.toFixed(2) + "ms" : "N/A"} +
+
+ `; + + if (totalEvents === 0) { + tbody.innerHTML = ` + + 📭 No events found for this correlation ID + + `; + return; + } + + // Combine all events into a unified timeline + const allEvents = []; + + // Add logs + (trace.logs || []).forEach((log) => { + const levelClass = getLogLevelClass(log.level); + allEvents.push({ + timestamp: new Date(log.timestamp), + html: ` + + + ${formatTimestamp(log.timestamp)} + + + + 📝 Log + + + + ${escapeHtml(log.component || "-")} + + + ${escapeHtml(log.message)} + ${log.error_details ? `
⚠️ ${escapeHtml(log.error_details.error_message || JSON.stringify(log.error_details))}` : ""} + + + ${escapeHtml(log.user_email || log.user_id || "-")} + + + ${log.duration_ms ? log.duration_ms.toFixed(2) + "ms" : "-"} + + + + ${log.level} + + + + `, + }); + }); + + // Add security events + (trace.security_events || []).forEach((event) => { + const severityClass = getSeverityClass(event.severity); + const threatScore = event.threat_score + ? (event.threat_score * 100).toFixed(0) + : 0; + allEvents.push({ + timestamp: new Date(event.timestamp), + html: ` + + + ${formatTimestamp(event.timestamp)} + + + + 🛡️ Security + + + + ${escapeHtml(event.event_type || "-")} + + + ${escapeHtml(event.description || "-")} + + + ${escapeHtml(event.user_email || event.user_id || "-")} + + + - + + +
+ + ${event.severity} + +
+ Threat: +
+
+
+ ${threatScore}% +
+
+ + + `, + }); + }); + + // Add audit trails + (trace.audit_trails || []).forEach((audit) => { + const actionBadgeColors = { + create: "bg-green-200 text-green-800", + update: "bg-blue-200 text-blue-800", + delete: "bg-red-200 text-red-800", + read: "bg-gray-200 text-gray-800", + }; + const actionBadge = + actionBadgeColors[audit.action?.toLowerCase()] || + "bg-purple-200 text-purple-800"; + const statusIcon = audit.success ? "✓" : "✗"; + const statusClass = audit.success ? "text-green-600" : "text-red-600"; + const statusBg = audit.success + ? "bg-green-100 dark:bg-green-900" + : "bg-red-100 dark:bg-red-900"; + + allEvents.push({ + timestamp: new Date(audit.timestamp), + html: ` + + + ${formatTimestamp(audit.timestamp)} + + + + 📋 ${audit.action?.toUpperCase()} + + + + ${escapeHtml(audit.resource_type || "-")} + + + ${audit.action}: ${audit.resource_type} + ${escapeHtml(audit.resource_id || "-")} + + + ${escapeHtml(audit.user_email || audit.user_id || "-")} + + + - + + + + ${statusIcon} ${audit.success ? "Success" : "Failed"} + + + + `, + }); + }); + + // Sort all events chronologically + allEvents.sort((a, b) => a.timestamp - b.timestamp); + + // Render sorted events + tbody.innerHTML = allEvents.map((event) => event.html).join(""); +} + +/** + * Show security events + */ +async function showSecurityEvents() { + setPerformanceAggregationVisibility(false); + setLogFiltersVisibility(false); + try { + const response = await fetchWithAuth( + `${getRootPath()}/api/logs/security-events?limit=50&resolved=false`, + { + method: "GET", + }, + ); + + if (!response.ok) { + throw new Error( + `Failed to fetch security events: ${response.statusText}`, + ); + } + + const events = await response.json(); + displaySecurityEvents(events); + } catch (error) { + console.error("Error fetching security events:", error); + showToast("Failed to fetch security events: " + error.message, "error"); + } +} + +/** + * Display security events + */ +function displaySecurityEvents(events) { + const tbody = document.getElementById("logs-tbody"); + const thead = document.getElementById("logs-thead"); + const logCount = document.getElementById("log-count"); + const logStats = document.getElementById("log-stats"); + + // Update table headers for security events + if (thead) { + thead.innerHTML = ` + + + Time + + + Severity + + + Event Type + + + Description + + + User/Source + + + Threat Score + + + Correlation ID + + + `; + } + + logCount.textContent = `${events.length} security events`; + logStats.innerHTML = ` + + 🛡️ Unresolved Security Events + + `; + + if (events.length === 0) { + tbody.innerHTML = ` + + ✅ No unresolved security events + + `; + return; + } + + tbody.innerHTML = events + .map((event) => { + const severityClass = getSeverityClass(event.severity); + const threatScore = (event.threat_score * 100).toFixed(0); + + return ` + + + ${formatTimestamp(event.timestamp)} + + + + ${event.severity} + + + + ${escapeHtml(event.event_type)} + + + ${escapeHtml(event.description)} + + + ${escapeHtml(event.user_email || event.user_id || "-")} + + +
+
+
+
+ ${threatScore}% +
+ + + ${ + event.correlation_id + ? ` + + ` + : "-" + } + + + `; + }) + .join(""); +} + +/** + * Get CSS class for severity badge + */ +function getSeverityClass(severity) { + const classes = { + LOW: "bg-blue-200 text-blue-800 dark:bg-blue-800 dark:text-blue-200", + MEDIUM: "bg-yellow-200 text-yellow-800 dark:bg-yellow-800 dark:text-yellow-200", + HIGH: "bg-orange-200 text-orange-800 dark:bg-orange-800 dark:text-orange-200", + CRITICAL: "bg-red-200 text-red-800 dark:bg-red-800 dark:text-red-200", + }; + return classes[severity] || classes.MEDIUM; +} + +/** + * Show audit trail + */ +async function showAuditTrail() { + setPerformanceAggregationVisibility(false); + setLogFiltersVisibility(false); + try { + const response = await fetchWithAuth( + `${getRootPath()}/api/logs/audit-trails?limit=50&requires_review=true`, + { + method: "GET", + }, + ); + + if (!response.ok) { + throw new Error( + `Failed to fetch audit trails: ${response.statusText}`, + ); + } + + const trails = await response.json(); + displayAuditTrail(trails); + } catch (error) { + console.error("Error fetching audit trails:", error); + showToast("Failed to fetch audit trails: " + error.message, "error"); + } +} + +/** + * Display audit trail entries + */ +function displayAuditTrail(trails) { + const tbody = document.getElementById("logs-tbody"); + const thead = document.getElementById("logs-thead"); + const logCount = document.getElementById("log-count"); + const logStats = document.getElementById("log-stats"); + + // Update table headers for audit trail + if (thead) { + thead.innerHTML = ` + + + Time + + + Action + + + Resource Type + + + Resource + + + User + + + Status + + + Correlation ID + + + `; + } + + logCount.textContent = `${trails.length} audit entries`; + logStats.innerHTML = ` + + 📝 Audit Trail Entries Requiring Review + + `; + + if (trails.length === 0) { + tbody.innerHTML = ` + + ✅ No audit entries require review + + `; + return; + } + + tbody.innerHTML = trails + .map((trail) => { + const actionClass = trail.success + ? "text-green-600" + : "text-red-600"; + const actionIcon = trail.success ? "✓" : "✗"; + + // Determine action badge color + const actionBadgeColors = { + create: "bg-green-200 text-green-800 dark:bg-green-800 dark:text-green-200", + update: "bg-blue-200 text-blue-800 dark:bg-blue-800 dark:text-blue-200", + delete: "bg-red-200 text-red-800 dark:bg-red-800 dark:text-red-200", + read: "bg-gray-200 text-gray-800 dark:bg-gray-600 dark:text-gray-200", + activate: + "bg-teal-200 text-teal-800 dark:bg-teal-800 dark:text-teal-200", + deactivate: + "bg-orange-200 text-orange-800 dark:bg-orange-800 dark:text-orange-200", + }; + const actionBadge = + actionBadgeColors[trail.action.toLowerCase()] || + "bg-purple-200 text-purple-800 dark:bg-purple-800 dark:text-purple-200"; + + // Format resource name with ID + const resourceName = + trail.resource_name || trail.resource_id || "-"; + const resourceDisplay = ` +
${escapeHtml(resourceName)}
+ ${trail.resource_id && trail.resource_name ? `
UUID: ${escapeHtml(trail.resource_id)}
` : ""} + ${trail.data_classification ? `
🔒 ${escapeHtml(trail.data_classification)}
` : ""} + `; + + return ` + + + ${formatTimestamp(trail.timestamp)} + + + + ${trail.action.toUpperCase()} + + + + ${escapeHtml(trail.resource_type || "-")} + + + ${resourceDisplay} + + + ${escapeHtml(trail.user_email || trail.user_id || "-")} + + + ${actionIcon} ${trail.success ? "Success" : "Failed"} + + + ${ + trail.correlation_id + ? ` + + ` + : "-" + } + + + `; + }) + .join(""); +} + +/** + * Show performance metrics + */ +async function showPerformanceMetrics(rangeKey) { + if (rangeKey && PERFORMANCE_AGGREGATION_OPTIONS[rangeKey]) { + currentPerformanceAggregationKey = rangeKey; + } else { + const select = document.getElementById( + "performance-aggregation-select", + ); + if (select?.value && PERFORMANCE_AGGREGATION_OPTIONS[select.value]) { + currentPerformanceAggregationKey = select.value; + } + } + + syncPerformanceAggregationSelect(); + setPerformanceAggregationVisibility(true); + setLogFiltersVisibility(false); + const hoursParam = encodeURIComponent(PERFORMANCE_HISTORY_HOURS.toString()); + const aggregationParam = encodeURIComponent( + getPerformanceAggregationQuery(), + ); + + try { + const response = await fetchWithAuth( + `${getRootPath()}/api/logs/performance-metrics?hours=${hoursParam}&aggregation=${aggregationParam}`, + { + method: "GET", + }, + ); + + if (!response.ok) { + throw new Error( + `Failed to fetch performance metrics: ${response.statusText}`, + ); + } + + const metrics = await response.json(); + displayPerformanceMetrics(metrics); + } catch (error) { + console.error("Error fetching performance metrics:", error); + showToast( + "Failed to fetch performance metrics: " + error.message, + "error", + ); + } +} + +/** + * Display performance metrics + */ +function displayPerformanceMetrics(metrics) { + const tbody = document.getElementById("logs-tbody"); + const thead = document.getElementById("logs-thead"); + const logCount = document.getElementById("log-count"); + const logStats = document.getElementById("log-stats"); + const aggregationLabel = getPerformanceAggregationLabel(); + + // Update table headers for performance metrics + if (thead) { + thead.innerHTML = ` + + + Time + + + Component + + + Operation + + + Avg Duration + + + Requests + + + Error Rate + + + P99 Duration + + + `; + } + + logCount.textContent = `${metrics.length} metrics`; + logStats.innerHTML = ` + + ⚡ Performance Metrics (${aggregationLabel}) + + `; + + if (metrics.length === 0) { + tbody.innerHTML = ` + + 📊 No performance metrics available for ${aggregationLabel.toLowerCase()} + + `; + return; + } + + tbody.innerHTML = metrics + .map((metric) => { + const errorRatePercent = (metric.error_rate * 100).toFixed(2); + const errorClass = + metric.error_rate > 0.1 ? "text-red-600" : "text-green-600"; + + return ` + + + ${formatTimestamp(metric.window_start)} + + + ${escapeHtml(metric.component || "-")} + + + ${escapeHtml(metric.operation_type || "-")} + + +
+
Avg: ${metric.avg_duration_ms.toFixed(2)}ms
+
P95: ${metric.p95_duration_ms.toFixed(2)}ms
+
+ + + ${metric.request_count.toLocaleString()} requests + + + ${errorRatePercent}% + ${metric.error_rate > 0.1 ? "⚠️" : ""} + + +
+ P99: ${metric.p99_duration_ms.toFixed(2)}ms +
+ + + `; + }) + .join(""); +} + +/** + * Navigate to previous log page + */ +function previousLogPage() { + if (currentLogPage > 0) { + currentLogPage--; + searchStructuredLogs(); + } +} + +/** + * Navigate to next log page + */ +function nextLogPage() { + currentLogPage++; + searchStructuredLogs(); +} + +/** + * Get root path for API calls + */ +function getRootPath() { + return window.ROOT_PATH || ""; +} + +/** + * Show toast notification + */ +function showToast(message, type = "info") { + // Check if showMessage function exists (from existing admin.js) + if (typeof showMessage === "function") { + // eslint-disable-next-line no-undef + showMessage(message, type === "error" ? "danger" : type); + } else { + console.log(`[${type.toUpperCase()}] ${message}`); + } +} + +// Make functions globally available for HTML onclick handlers +window.searchStructuredLogs = searchStructuredLogs; +window.showCorrelationTrace = showCorrelationTrace; +window.showSecurityEvents = showSecurityEvents; +window.showAuditTrail = showAuditTrail; +window.showPerformanceMetrics = showPerformanceMetrics; +window.handlePerformanceAggregationChange = handlePerformanceAggregationChange; +window.previousLogPage = previousLogPage; +window.nextLogPage = nextLogPage; +window.showLogDetails = showLogDetails; diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index 96027f5fc..b54474f16 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -590,11 +590,65 @@

- System Logs + 📋 System Logs

+ +
+ + + + + +
+ + + + -
+
Component
Search / Correlation ID
- -
- - - - - -
-
- Loading stats... + Loading...
@@ -694,28 +717,43 @@ - + + + + Virtual MCP Servers
- +
@@ -2577,7 +2615,7 @@

MCP Tools

Clear
- +
@@ -3575,7 +3613,7 @@

MCP Resources

Clear
- + @@ -3795,7 +3833,7 @@

MCP Prompts

Clear - + @@ -4046,7 +4084,7 @@

Clear - + @@ -5463,7 +5501,7 @@

Clear - + @@ -7595,6 +7633,97 @@

+ + + + + + + + +
class="mt-1 px-1.5 block w-full rounded-md border border-gray-300 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 dark:bg-gray-900 dark:placeholder-gray-300 dark:text-gray-300" />
+ + + + + + -->
@@ -10989,131 +11123,22 @@
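Side note (not part of the diff): the filter toolbar above drives the structured-log endpoints that the new admin.js functions call (`POST /api/logs/search`, `GET /api/logs/trace/<correlation_id>`). Below is a minimal Python sketch of the same calls, handy for exercising the API outside the admin UI. The base URL, the bearer-token header, and the `limit`/`offset` fields are assumptions; the other fields and the `results`/`total` and `logs`/`security_events`/`audit_trails` response keys mirror the JavaScript.

```
# Illustrative client for the structured-log API used by the admin UI.
# Assumptions: gateway reachable at http://localhost:4444, bearer-token auth,
# and limit/offset pagination in the search body (inferred from the UI paging).
import requests

BASE_URL = "http://localhost:4444"             # assumed; prepend ROOT_PATH if configured
HEADERS = {"Authorization": "Bearer <token>"}  # assumed auth scheme


def search_logs(text="", level=None, component=None, limit=100, offset=0):
    """POST a search payload shaped like the one searchStructuredLogs() builds."""
    body = {"limit": limit, "offset": offset}
    if text:
        body["search_text"] = text      # or "correlation_id" for an exact ID lookup
    if level:
        body["level"] = [level]         # the UI sends levels as a list
    if component:
        body["component"] = [component]
    resp = requests.post(f"{BASE_URL}/api/logs/search", json=body,
                         headers=HEADERS, timeout=30)
    resp.raise_for_status()
    return resp.json()                  # expected keys: results, total


def trace(correlation_id):
    """Fetch all logs, security events and audit trails for one correlation ID."""
    resp = requests.get(f"{BASE_URL}/api/logs/trace/{correlation_id}",
                        headers=HEADERS, timeout=30)
    resp.raise_for_status()
    return resp.json()                  # keys: logs, security_events, audit_trails


if __name__ == "__main__":
    print(search_logs(level="ERROR")["total"])
```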

Breakdown by Type:

let currentLogPage = 0; const logsPerPage = 100; - async function refreshLogs() { - const level = document.getElementById("log-level-filter").value; - const entityType = document.getElementById("log-entity-filter").value; - const search = document.getElementById("log-search").value; - - const params = new URLSearchParams({ - limit: logsPerPage, - offset: currentLogPage * logsPerPage, - order: "desc", - }); - - if (level) params.append("level", level); - if (entityType) params.append("entity_type", entityType); - if (search) params.append("search", search); - - try { - const headers = {}; - const token = localStorage.getItem("token"); - if (token) { - headers["Authorization"] = `Bearer ${token}`; - } - - const response = await fetch( - `${window.ROOT_PATH || ""}/admin/logs?${params}`, - { - headers: headers, - credentials: "same-origin", - }, - ); - - if (!response.ok) throw new Error(`HTTP ${response.status}`); - - const data = await response.json(); - displayLogs(data.logs); - updateLogStats(data.stats); - updateLogCount(data.total); - } catch (error) { - console.error("Error fetching logs:", error); - showErrorMessage("Failed to fetch logs"); - } - } - - function displayLogs(logs) { - const tbody = document.getElementById("logs-tbody"); - tbody.innerHTML = ""; - - logs.forEach((log) => { - const row = document.createElement("tr"); - row.className = "hover:bg-gray-50 dark:hover:bg-gray-700"; - - const timestamp = new Date(log.timestamp).toLocaleString(); - const levelClass = getLevelClass(log.level); - const entity = log.entity_name || log.entity_id || "-"; - - row.innerHTML = ` -

- - - - `; - - tbody.appendChild(row); - }); - } - - function getLevelClass(level) { - switch (level) { - case "debug": - return "bg-gray-100 text-gray-800"; - case "info": - return "bg-blue-100 text-blue-800"; - case "warning": - return "bg-yellow-100 text-yellow-800"; - case "error": - return "bg-red-100 text-red-800"; - case "critical": - return "bg-red-600 text-white"; - default: - return "bg-gray-100 text-gray-800"; - } - } - - function updateLogStats(stats) { - if (!stats) return; - - const statsDiv = document.getElementById("log-stats"); - const levelDist = stats.level_distribution || {}; - const entityDist = stats.entity_distribution || {}; - - let html = ` -
- Buffer: ${stats.usage_percent || 0}% (${stats.buffer_size_mb || 0}/${stats.max_size_mb || 0} MB) - Total: ${stats.total_logs || 0} logs - `; - - if (Object.keys(levelDist).length > 0) { - html += "Levels: "; - for (const [level, count] of Object.entries(levelDist)) { - html += `${level}(${count}) `; - } - html += ""; - } - - html += "
"; - statsDiv.innerHTML = html; - } - - function updateLogCount(total) { - document.getElementById("log-count").textContent = `${total} logs`; - - // Update pagination buttons - document.getElementById("prev-page").disabled = currentLogPage === 0; - document.getElementById("next-page").disabled = - (currentLogPage + 1) * logsPerPage >= total; - } + // Main search function for structured logs + // Note: Structured logging functions are defined in admin.js which loads below: + // - searchStructuredLogs() - Search logs with filters + // - displayLogResults() - Display log table + // - showCorrelationTrace() - Show correlation trace modal + // - displayCorrelationTrace() - Display trace results + // - Helper functions: getLevelClass(), formatDuration(), getDurationClass() + // + // Keeping all structured logging UI logic centralized in admin.js to avoid + // duplication and maintenance issues. admin.js loads last and provides the + // definitive implementations. + + // Note: showSecurityEvents, showAuditTrail, showPerformanceMetrics, + // updateLogStats, and updateLogCount are also defined in admin.js which + // is loaded below and overrides any inline definitions. + // Keeping functions centralized in admin.js to avoid duplication and maintenance issues. function toggleLogStream() { const button = document.getElementById("stream-toggle"); @@ -11207,7 +11232,7 @@

Breakdown by Type:

try { // Use the same auth approach as other admin endpoints const headers = {}; - const token = localStorage.getItem("token"); + const token = getAuthToken(); if (token) { headers["Authorization"] = `Bearer ${token}`; } @@ -11258,7 +11283,7 @@

Breakdown by Type:

async function showLogFiles() { try { const headers = {}; - const token = localStorage.getItem("token"); + const token = getAuthToken(); if (token) { headers["Authorization"] = `Bearer ${token}`; } @@ -11326,7 +11351,7 @@

Available Log Files

async function downloadLogFile(filename) { try { const headers = {}; - const token = localStorage.getItem("token"); + const token = getAuthToken(); if (token) { headers["Authorization"] = `Bearer ${token}`; } @@ -11387,7 +11412,7 @@

Available Log Files

document.addEventListener("DOMContentLoaded", () => { const logFilters = [ "log-level-filter", - "log-entity-filter", + "log-component-filter", "log-search", ]; logFilters.forEach((id) => { diff --git a/mcpgateway/templates/resources_partial.html b/mcpgateway/templates/resources_partial.html index 209df6880..9306b029e 100644 --- a/mcpgateway/templates/resources_partial.html +++ b/mcpgateway/templates/resources_partial.html @@ -30,6 +30,7 @@
Time Level - Entity + Component Message + User + + Duration + + Correlation ID +
- ${timestamp} - - - ${log.level} - - - ${log.entity_type ? `${log.entity_type}: ${entity}` : entity} - - ${escapeHtml(log.message)} - {{ 'Active' if resource.enabled else 'Inactive' }}
+
diff --git a/mcpgateway/toolops/README.md b/mcpgateway/toolops/README.md index 61bdc31e3..4142525ed 100644 --- a/mcpgateway/toolops/README.md +++ b/mcpgateway/toolops/README.md @@ -1,14 +1,15 @@ ### Starting MCP context forge from git repo * Use `make venv` to create virtual environment (tested with python 3.12) -* Install MCP-CF and toolops dependencies using `make install install-dev install-toolops`. Please check if all the packages are installed in the created virtual environment. +* Install MCP-CF and dependencies using `make install install-dev` +* Install toolops and other dependencies using `uv pip install .'[toolops,grpc]'`.Please check if all the packages are installed in the created virtual environment. * `uvicorn mcpgateway.main:app --host 0.0.0.0 --port 4444 --workers 2 --env-file .env` will start Context forge UI and APIs at http://localhost:4444/docs and toolops API endpoints will be shown. ### Important NOTE: * Please provide all configurations such as LLM provider, api keys etc., in `.env` file. And you need to set `TOOLOPS_ENABLED=true` for enabling toolops functionality` * While selecting LLM model , please use the model that supports instruction following (IF) text generation tasks and tool-calling capabilities for executing tools in chat mode. For example `granite4:micro` , `llama-3-3-70b-instruct` etc., * Toolops depends on `agent life cycle toolkit(ALTK)` which is specified in `pyproject.toml` required packages, to install ALTK please set-up github public key SSH if required. -* For toolops developement (Caution) : Only if required to re-install of latest version of `agent life cycle toolkit(ALTK)` from git repo in case of fixes/updates please use pip install via git ssh url. +* For toolops developement (Caution) : Only if required to re-install of latest version of `agent life cycle toolkit(ALTK)` from git repo in case of fixes/updates please use pip install via git ssh url. ### Testing toolops requires MCP server running to set up MCP server using OAPI specification ``` @@ -17,4 +18,4 @@ python3 -m mcpgateway.translate \ --expose-sse \ --expose-streamable-http \ --port 9000 -``` \ No newline at end of file +``` diff --git a/mcpgateway/tools/builder/__init__.py b/mcpgateway/tools/builder/__init__.py new file mode 100644 index 000000000..ec309d8bd --- /dev/null +++ b/mcpgateway/tools/builder/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/tools/builder/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Builder Package. +""" diff --git a/mcpgateway/tools/builder/cli.py b/mcpgateway/tools/builder/cli.py new file mode 100644 index 000000000..0fdfaebfd --- /dev/null +++ b/mcpgateway/tools/builder/cli.py @@ -0,0 +1,337 @@ +# -*- coding: utf-8 -*- +""" +Location: ./mcpgateway/tools/builder/cli.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +MCP Stack Deployment Tool - Hybrid Dagger/Python Implementation + +This script can run in two modes: +1. Plain Python mode (default) - No external dependencies +2. 
Dagger mode (opt-in) - Requires dagger-io package, auto-downloads CLI + +Usage: + # Local execution (plain Python mode) + cforge deploy deploy.yaml + + # Use Dagger mode for optimization (requires dagger-io, auto-downloads CLI) + cforge --dagger deploy deploy.yaml + + # Inside container + docker run -v $PWD:/workspace mcpgateway/mcp-builder:latest deploy deploy.yaml + +Features: + - Validates deploy.yaml configuration + - Builds plugin containers from git repos + - Generates mTLS certificates + - Deploys to Kubernetes or Docker Compose + - Integrates with CI/CD vault secrets + +Examples: + >>> # Test that IN_CONTAINER detection works + >>> import os + >>> isinstance(IN_CONTAINER, bool) + True + + >>> # Test that BUILDER_DIR is a Path + >>> from pathlib import Path + >>> isinstance(BUILDER_DIR, Path) + True + + >>> # Test IMPL_MODE is set + >>> isinstance(IMPL_MODE, str) + True +""" + +# Standard +import asyncio +import os +from pathlib import Path +import sys +from typing import Optional + +# Third-Party +from rich.console import Console +from rich.panel import Panel +import typer +from typing_extensions import Annotated + +# First-Party +from mcpgateway.tools.builder.factory import DeployFactory + +app = typer.Typer( + help="Command line tools for deploying the gateway and plugins via a config file.", +) + +console = Console() + +deployer = None + +IN_CONTAINER = os.path.exists("/.dockerenv") or os.environ.get("CONTAINER") == "true" +BUILDER_DIR = Path(__file__).parent / "builder" +IMPL_MODE = "plain" + + +@app.callback() +def cli( + ctx: typer.Context, + dagger: Annotated[bool, typer.Option("--dagger", help="Use Dagger mode (requires dagger-io package)")] = False, + verbose: Annotated[bool, typer.Option("--verbose", "-v", help="Verbose output")] = False, +): + """MCP Stack deployment tool + + Deploys MCP Gateway + external plugins from a single YAML configuration. + + By default, uses plain Python mode. Use --dagger to enable Dagger optimization. 
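+
+    For example, ``cforge --dagger -v deploy deploy.yaml`` runs a verbose,
+    Dagger-mode deployment of the stack described in ``deploy.yaml``
+    (illustrative invocation; the options are defined below).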
+ + Args: + ctx: Typer context object + dagger: Enable Dagger mode (requires dagger-io package and auto-downloads CLI) + verbose: Enable verbose output + """ + ctx.ensure_object(dict) + ctx.obj["verbose"] = verbose + ctx.obj["dagger"] = dagger + + if ctx.invoked_subcommand != "version": + # Show execution mode - default to Python, opt-in to Dagger + mode = "dagger" if dagger else "python" + ctx.obj["deployer"], ctx.obj["mode"] = DeployFactory.create_deployer(mode, verbose) + mode_color = "green" if ctx.obj["mode"] == "dagger" else "yellow" + env_text = "container" if IN_CONTAINER else "local" + + if verbose: + console.print(Panel(f"[bold]Mode:[/bold] [{mode_color}]{ctx.obj['mode']}[/{mode_color}]\n" f"[bold]Environment:[/bold] {env_text}\n", title="MCP Deploy", border_style=mode_color)) + + +@app.command() +def validate(ctx: typer.Context, config_file: Annotated[Path, typer.Argument(help="The deployment configuration file.")]): + """Validate mcp-stack.yaml configuration + + Args: + ctx: Typer context object + config_file: Path to the deployment configuration file + """ + impl = ctx.obj["deployer"] + + try: + impl.validate(config_file) + console.print("[green]✓ Configuration valid[/green]") + except Exception as e: + console.print(f"[red]✗ Validation failed: {e}[/red]") + sys.exit(1) + + +@app.command() +def build( + ctx: typer.Context, + config_file: Annotated[Path, typer.Argument(help="The deployment configuration file")], + plugins_only: Annotated[bool, typer.Option("--plugins-only", help="Only build plugin containers")] = False, + plugin: Annotated[Optional[list[str]], typer.Option("--plugin", "-p", help="Build specific plugin(s)")] = None, + no_cache: Annotated[bool, typer.Option("--no-cache", help="Disable build cache")] = False, + copy_env_templates: Annotated[bool, typer.Option("--copy-env-templates", help="Copy .env.template files from plugin repos")] = True, +): + """Build containers + + Args: + ctx: Typer context object + config_file: Path to the deployment configuration file + plugins_only: Only build plugin containers, skip gateway + plugin: List of specific plugin names to build + no_cache: Disable build cache + copy_env_templates: Copy .env.template files from plugin repos + """ + impl = ctx.obj["deployer"] + + try: + asyncio.run(impl.build(config_file, plugins_only=plugins_only, specific_plugins=list(plugin) if plugin else None, no_cache=no_cache, copy_env_templates=copy_env_templates)) + console.print("[green]✓ Build complete[/green]") + + if copy_env_templates: + console.print("[yellow]⚠ IMPORTANT: Review .env files in deploy/env/ before deploying![/yellow]") + console.print("[yellow] Update any required configuration values.[/yellow]") + except Exception as e: + console.print(f"[red]✗ Build failed: {e}[/red]") + sys.exit(1) + + +@app.command() +def certs(ctx: typer.Context, config_file: Annotated[Path, typer.Argument(help="The deployment configuration file")]): + """Generate mTLS certificates + + Args: + ctx: Typer context object + config_file: Path to the deployment configuration file + """ + impl = ctx.obj["deployer"] + + try: + asyncio.run(impl.generate_certificates(config_file)) + console.print("[green]✓ Certificates generated[/green]") + except Exception as e: + console.print(f"[red]✗ Certificate generation failed: {e}[/red]") + sys.exit(1) + + +@app.command() +def deploy( + ctx: typer.Context, + config_file: Annotated[Path, typer.Argument(help="The deployment configuration file")], + output_dir: Annotated[Optional[Path], typer.Option("--output-dir", "-o", help="The 
deployment configuration file")] = None, + dry_run: Annotated[bool, typer.Option("--dry-run", help="Generate manifests without deploying")] = False, + skip_build: Annotated[bool, typer.Option("--skip-build", help="Skip building containers")] = False, + skip_certs: Annotated[bool, typer.Option("--skip-certs", help="Skip certificate generation")] = False, +): + """Deploy MCP stack + + Args: + ctx: Typer context object + config_file: Path to the deployment configuration file + output_dir: Custom output directory for manifests + dry_run: Generate manifests without deploying + skip_build: Skip building containers + skip_certs: Skip certificate generation + """ + impl = ctx.obj["deployer"] + + try: + asyncio.run(impl.deploy(config_file, dry_run=dry_run, skip_build=skip_build, skip_certs=skip_certs, output_dir=output_dir)) + if dry_run: + console.print("[yellow]✓ Dry-run complete (no changes made)[/yellow]") + else: + console.print("[green]✓ Deployment complete[/green]") + except Exception as e: + console.print(f"[red]✗ Deployment failed: {e}[/red]") + sys.exit(1) + + +@app.command() +def verify( + ctx: typer.Context, + config_file: Annotated[Path, typer.Argument(help="The deployment configuration file")], + wait: Annotated[bool, typer.Option("--wait", help="Wait for deployment to be ready")] = True, + timeout: Annotated[int, typer.Option("--timeout", help="Wait timeout in seconds")] = 300, +): + """Verify deployment health + + Args: + ctx: Typer context object + config_file: Path to the deployment configuration file + wait: Wait for deployment to be ready + timeout: Wait timeout in seconds + """ + impl = ctx.obj["deployer"] + + try: + asyncio.run(impl.verify(config_file, wait=wait, timeout=timeout)) + console.print("[green]✓ Deployment healthy[/green]") + except Exception as e: + console.print(f"[red]✗ Verification failed: {e}[/red]") + sys.exit(1) + + +@app.command() +def destroy( + ctx: typer.Context, + config_file: Annotated[Path, typer.Argument(help="The deployment configuration file")], + force: Annotated[bool, typer.Option("--force", help="Force destruction without confirmation")] = False, +): + """Destroy deployed MCP stack + + Args: + ctx: Typer context object + config_file: Path to the deployment configuration file + force: Force destruction without confirmation + """ + impl = ctx.obj["deployer"] + + if not force: + if not typer.confirm("Are you sure you want to destroy the deployment?"): + console.print("[yellow]Aborted[/yellow]") + return + + try: + asyncio.run(impl.destroy(config_file)) + console.print("[green]✓ Deployment destroyed[/green]") + except Exception as e: + console.print(f"[red]✗ Destruction failed: {e}[/red]") + sys.exit(1) + + +@app.command() +def version(): + """Show version information + + Examples: + >>> # Test that version function exists + >>> callable(version) + True + + >>> # Test that it accesses module constants + >>> IMPL_MODE in ['plain', 'dagger'] + True + """ + console.print( + Panel(f"[bold]MCP Deploy[/bold]\n" f"Version: 1.0.0\n" f"Mode: {IMPL_MODE}\n" f"Environment: {'container' if IN_CONTAINER else 'local'}\n", title="Version Info", border_style="blue") + ) + + +@app.command() +def generate( + ctx: typer.Context, + config_file: Annotated[Path, typer.Argument(help="The deployment configuration file")], + output: Annotated[Optional[Path], typer.Option("--output", "-o", help="Output directory for manifests")] = None, +): + """Generate deployment manifests (k8s or compose) + + Args: + ctx: Typer context object + config_file: Path to the deployment 
configuration file + output: Output directory for manifests + """ + impl = ctx.obj["deployer"] + + try: + manifests_dir = impl.generate_manifests(config_file, output_dir=output) + console.print(f"[green]✓ Manifests generated: {manifests_dir}[/green]") + except Exception as e: + console.print(f"[red]✗ Manifest generation failed: {e}[/red]") + sys.exit(1) + + +def main(): + """Main entry point + + Raises: + Exception: Any unhandled exception from subcommands (re-raised in debug mode) + + Examples: + >>> # Test that main function exists and is callable + >>> callable(main) + True + + >>> # Test that app is a Typer instance + >>> import typer + >>> isinstance(app, typer.Typer) + True + + >>> # Test that console is available + >>> from rich.console import Console + >>> isinstance(console, Console) + True + """ + try: + app(obj={}) + except KeyboardInterrupt: + console.print("\n[yellow]Interrupted by user[/yellow]") + sys.exit(130) + except Exception as e: + console.print(f"[red]Fatal error: {e}[/red]") + if os.environ.get("MCP_DEBUG"): + raise + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/mcpgateway/tools/builder/common.py b/mcpgateway/tools/builder/common.py new file mode 100644 index 000000000..940652d6d --- /dev/null +++ b/mcpgateway/tools/builder/common.py @@ -0,0 +1,1268 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/tools/builder/common.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Common utilities shared between Dagger and plain Python implementations. + +This module contains shared functionality to avoid code duplication between +the Dagger-based (dagger_module.py) and plain Python (plain_deploy.py) +implementations of the MCP Stack deployment system. + +Shared functions: +- load_config: Load and parse YAML configuration file +- generate_plugin_config: Generate plugins-config.yaml for gateway from mcp-stack.yaml +- generate_kubernetes_manifests: Generate Kubernetes deployment manifests +- generate_compose_manifests: Generate Docker Compose manifest +- copy_env_template: Copy .env.template from plugin repo to env.d/ directory +- handle_registry_operations: Tag and push images to container registry +- get_docker_compose_command: Detect available docker compose command +- run_compose: Run docker compose with error handling +- deploy_compose: Deploy using docker compose up -d +- verify_compose: Verify deployment with docker compose ps +- destroy_compose: Destroy deployment with docker compose down -v +- deploy_kubernetes: Deploy to Kubernetes using kubectl +- verify_kubernetes: Verify Kubernetes deployment health +- destroy_kubernetes: Destroy Kubernetes deployment with kubectl delete +""" + +# Standard +import base64 +import os +from pathlib import Path +import shutil +import subprocess # nosec B404 +from typing import List + +# Third-Party +from jinja2 import Environment, FileSystemLoader +from rich.console import Console +import yaml + +# First-Party +from mcpgateway.tools.builder.schema import MCPStackConfig + +console = Console() + + +def get_deploy_dir() -> Path: + """Get deployment directory from environment variable or default. + + Checks MCP_DEPLOY_DIR environment variable, defaults to './deploy'. 
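+    The variable is read on every call, and the same directory is used by the
+    env-file helpers below (e.g. ``{deploy_dir}/env/.env.<plugin>``), so it can
+    be pointed at a different location per invocation.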
+ + Returns: + Path to deployment directory + + Examples: + >>> # Test with default value (when MCP_DEPLOY_DIR is not set) + >>> import os + >>> old_value = os.environ.pop("MCP_DEPLOY_DIR", None) + >>> result = get_deploy_dir() + >>> isinstance(result, Path) + True + >>> str(result) + 'deploy' + + >>> # Test with custom environment variable + >>> os.environ["MCP_DEPLOY_DIR"] = "/custom/deploy" + >>> result = get_deploy_dir() + >>> str(result) + '/custom/deploy' + + >>> # Cleanup: restore original value + >>> if old_value is not None: + ... os.environ["MCP_DEPLOY_DIR"] = old_value + ... else: + ... _ = os.environ.pop("MCP_DEPLOY_DIR", None) + """ + deploy_dir = os.environ.get("MCP_DEPLOY_DIR", "./deploy") + return Path(deploy_dir) + + +def load_config(config_file: str) -> MCPStackConfig: + """Load and parse YAML configuration file into validated Pydantic model. + + Args: + config_file: Path to mcp-stack.yaml configuration file + + Returns: + Validated MCPStackConfig Pydantic model + + Raises: + FileNotFoundError: If configuration file doesn't exist + ValidationError: If configuration validation fails + + Examples: + >>> # Test with non-existent file + >>> try: + ... load_config("/nonexistent/path/config.yaml") + ... except FileNotFoundError as e: + ... "Configuration file not found" in str(e) + True + + >>> # Test that function returns MCPStackConfig type + >>> from mcpgateway.tools.builder.schema import MCPStackConfig + >>> # Actual file loading would require a real file: + >>> # config = load_config("mcp-stack.yaml") + >>> # assert isinstance(config, MCPStackConfig) + """ + config_path = Path(config_file) + if not config_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_file}") + + with open(config_path, encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + + # Validate and return Pydantic model + return MCPStackConfig.model_validate(config_dict) + + +def generate_plugin_config(config: MCPStackConfig, output_dir: Path, verbose: bool = False) -> Path: + """Generate plugin config.yaml for gateway from mcp-stack.yaml. + + This function is shared between Dagger and plain Python implementations + to avoid code duplication. + + Args: + config: Validated MCPStackConfig Pydantic model + output_dir: Output directory for generated config + verbose: Print verbose output + + Returns: + Path to generated plugins-config.yaml file + + Raises: + FileNotFoundError: If template directory not found + + Examples: + >>> from pathlib import Path + >>> from mcpgateway.tools.builder.schema import MCPStackConfig, DeploymentConfig, GatewayConfig + >>> import tempfile + >>> # Test with minimal config + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... output = Path(tmpdir) + ... config = MCPStackConfig( + ... deployment=DeploymentConfig(type="compose"), + ... gateway=GatewayConfig(image="test:latest"), + ... plugins=[] + ... ) + ... result = generate_plugin_config(config, output, verbose=False) + ... 
result.name + 'plugins-config.yaml' + + >>> # Test return type + >>> # result_path = generate_plugin_config(config, output_dir) + >>> # isinstance(result_path, Path) + >>> # True + """ + + deployment_type = config.deployment.type + plugins = config.plugins + + # Load template + template_dir = Path(__file__).parent / "templates" + if not template_dir.exists(): + raise FileNotFoundError(f"Template directory not found: {template_dir}") + + # YAML files should not use HTML autoescape + env = Environment(loader=FileSystemLoader(str(template_dir)), autoescape=False) # nosec B701 + template = env.get_template("plugins-config.yaml.j2") + + # Prepare plugin data with computed URLs + plugin_data = [] + for plugin in plugins: + plugin_name = plugin.name + port = plugin.port or 8000 + + # Determine URL based on deployment type + if deployment_type == "compose": + # Use container hostname (lowercase) + hostname = plugin_name.lower() + # Use HTTPS if mTLS is enabled + protocol = "https" if plugin.mtls_enabled else "http" + url = f"{protocol}://{hostname}:{port}/mcp" + else: # kubernetes + # Use Kubernetes service DNS + namespace = config.deployment.namespace or "mcp-gateway" + service_name = f"mcp-plugin-{plugin_name.lower()}" + protocol = "https" if plugin.mtls_enabled else "http" + url = f"{protocol}://{service_name}.{namespace}.svc:{port}/mcp" + + # Build plugin entry with computed URL + plugin_entry = { + "name": plugin_name, + "port": port, + "url": url, + } + + # Merge plugin_overrides (client-side config only, excludes 'config') + # Allowed client-side fields that plugin manager uses + if plugin.plugin_overrides: + overrides = plugin.plugin_overrides + allowed_fields = ["priority", "mode", "description", "version", "author", "hooks", "tags", "conditions"] + for field in allowed_fields: + if field in overrides: + plugin_entry[field] = overrides[field] + + plugin_data.append(plugin_entry) + + # Render template + rendered = template.render(plugins=plugin_data) + + # Write config file + config_path = output_dir / "plugins-config.yaml" + config_path.write_text(rendered) + + if verbose: + print(f"✓ Plugin config generated: {config_path}") + + return config_path + + +def generate_kubernetes_manifests(config: MCPStackConfig, output_dir: Path, verbose: bool = False) -> None: + """Generate Kubernetes manifests from configuration. 
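+
+    Depending on the configuration this emits mTLS certificate Secrets (or
+    cert-manager Certificate resources), optional PostgreSQL/Redis deployments,
+    a plugins ConfigMap, the gateway deployment with wait-for-* init containers,
+    one deployment per plugin, and an optional OpenShift Route.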
+ + Args: + config: Validated MCPStackConfig Pydantic model + output_dir: Output directory for manifests + verbose: Print verbose output + + Raises: + FileNotFoundError: If template directory not found + + Examples: + >>> from pathlib import Path + >>> import inspect + >>> # Test function signature + >>> sig = inspect.signature(generate_kubernetes_manifests) + >>> list(sig.parameters.keys()) + ['config', 'output_dir', 'verbose'] + + >>> # Test that verbose parameter has default + >>> sig.parameters['verbose'].default + False + + >>> # Actual usage requires valid config and templates: + >>> # from mcpgateway.tools.builder.schema import MCPStackConfig + >>> # generate_kubernetes_manifests(config, Path("./output")) + """ + + # Load templates + template_dir = Path(__file__).parent / "templates" / "kubernetes" + if not template_dir.exists(): + raise FileNotFoundError(f"Template directory not found: {template_dir}") + + # Auto-detect and assign env files if not specified + _auto_detect_env_files(config, output_dir, verbose=verbose) + + env = Environment(loader=FileSystemLoader(str(template_dir)), autoescape=True) # nosec B701 + + # Generate namespace + namespace = config.deployment.namespace or "mcp-gateway" + + # Generate mTLS certificate resources if enabled + gateway_mtls = config.gateway.mtls_enabled if config.gateway.mtls_enabled is not None else True + cert_config = config.certificates + use_cert_manager = cert_config.use_cert_manager if cert_config else False + + if gateway_mtls: + if use_cert_manager: + # Generate cert-manager Certificate CRDs + cert_manager_template = env.get_template("cert-manager-certificates.yaml.j2") + + # Calculate duration and renewBefore in hours + validity_days = cert_config.validity_days or 825 + duration_hours = validity_days * 24 + # Renew at 2/3 of lifetime (cert-manager default) + renew_before_hours = int(duration_hours * 2 / 3) + + # Prepare certificate data + cert_data = { + "namespace": namespace, + "gateway_name": "mcpgateway", + "issuer_name": cert_config.cert_manager_issuer or "mcp-ca-issuer", + "issuer_kind": cert_config.cert_manager_kind or "Issuer", + "duration": duration_hours, + "renew_before": renew_before_hours, + "plugins": [], + } + + # Add plugins with mTLS enabled + for plugin in config.plugins: + if plugin.mtls_enabled if plugin.mtls_enabled is not None else True: + cert_data["plugins"].append({"name": f"mcp-plugin-{plugin.name.lower()}"}) + + # Generate cert-manager certificates manifest + cert_manager_manifest = cert_manager_template.render(**cert_data) + (output_dir / "cert-manager-certificates.yaml").write_text(cert_manager_manifest) + if verbose: + print(" ✓ cert-manager Certificate CRDs manifest generated") + + else: + # Generate traditional certificate secrets (backward compatibility) + cert_secrets_template = env.get_template("cert-secrets.yaml.j2") + + # Prepare certificate data + cert_data = {"namespace": namespace, "gateway_name": "mcpgateway", "plugins": []} + + # Read and encode CA certificate + ca_cert_path = Path("certs/mcp/ca/ca.crt") + if ca_cert_path.exists(): + cert_data["ca_cert_b64"] = base64.b64encode(ca_cert_path.read_bytes()).decode("utf-8") + else: + if verbose: + print(f"[yellow]Warning: CA certificate not found at {ca_cert_path}[/yellow]") + + # Read and encode gateway certificates + gateway_cert_path = Path("certs/mcp/gateway/client.crt") + gateway_key_path = Path("certs/mcp/gateway/client.key") + if gateway_cert_path.exists() and gateway_key_path.exists(): + cert_data["gateway_cert_b64"] = 
base64.b64encode(gateway_cert_path.read_bytes()).decode("utf-8") + cert_data["gateway_key_b64"] = base64.b64encode(gateway_key_path.read_bytes()).decode("utf-8") + else: + if verbose: + print("[yellow]Warning: Gateway certificates not found[/yellow]") + + # Read and encode plugin certificates + for plugin in config.plugins: + if plugin.mtls_enabled if plugin.mtls_enabled is not None else True: + plugin_name = plugin.name + plugin_cert_path = Path(f"certs/mcp/plugins/{plugin_name}/server.crt") + plugin_key_path = Path(f"certs/mcp/plugins/{plugin_name}/server.key") + + if plugin_cert_path.exists() and plugin_key_path.exists(): + cert_data["plugins"].append( + { + "name": f"mcp-plugin-{plugin_name.lower()}", + "cert_b64": base64.b64encode(plugin_cert_path.read_bytes()).decode("utf-8"), + "key_b64": base64.b64encode(plugin_key_path.read_bytes()).decode("utf-8"), + } + ) + else: + if verbose: + print(f"[yellow]Warning: Plugin {plugin_name} certificates not found[/yellow]") + + # Generate certificate secrets manifest + if "ca_cert_b64" in cert_data: + cert_secrets_manifest = cert_secrets_template.render(**cert_data) + (output_dir / "cert-secrets.yaml").write_text(cert_secrets_manifest) + if verbose: + print(" ✓ mTLS certificate secrets manifest generated") + + # Generate infrastructure manifests (postgres, redis) if enabled + infrastructure = config.infrastructure + + # PostgreSQL + if infrastructure and infrastructure.postgres and infrastructure.postgres.enabled: + postgres_config = infrastructure.postgres + postgres_template = env.get_template("postgres.yaml.j2") + postgres_manifest = postgres_template.render( + namespace=namespace, + image=postgres_config.image or "quay.io/sclorg/postgresql-15-c9s:latest", + database=postgres_config.database or "mcp", + user=postgres_config.user or "postgres", + password=postgres_config.password or "mysecretpassword", + storage_size=postgres_config.storage_size or "10Gi", + storage_class=postgres_config.storage_class, + ) + (output_dir / "postgres-deployment.yaml").write_text(postgres_manifest) + if verbose: + print(" ✓ PostgreSQL deployment manifest generated") + + # Redis + if infrastructure and infrastructure.redis and infrastructure.redis.enabled: + redis_config = infrastructure.redis + redis_template = env.get_template("redis.yaml.j2") + redis_manifest = redis_template.render(namespace=namespace, image=redis_config.image or "redis:latest") + (output_dir / "redis-deployment.yaml").write_text(redis_manifest) + if verbose: + print(" ✓ Redis deployment manifest generated") + + # Generate plugins ConfigMap if plugins are configured + if config.plugins and len(config.plugins) > 0: + configmap_template = env.get_template("plugins-configmap.yaml.j2") + # Read the generated plugins-config.yaml file + plugins_config_path = output_dir / "plugins-config.yaml" + if plugins_config_path.exists(): + plugins_config_content = plugins_config_path.read_text() + configmap_manifest = configmap_template.render(namespace=namespace, plugins_config=plugins_config_content) + (output_dir / "plugins-configmap.yaml").write_text(configmap_manifest) + if verbose: + print(" ✓ Plugins ConfigMap manifest generated") + + # Generate gateway deployment + gateway_template = env.get_template("deployment.yaml.j2") + # Convert Pydantic model to dict for template rendering + gateway_dict = config.gateway.model_dump(exclude_none=True) + gateway_dict["name"] = "mcpgateway" + gateway_dict["namespace"] = namespace + gateway_dict["has_plugins"] = config.plugins and len(config.plugins) > 0 + + # Update 
image to use full registry path if registry is enabled + if config.gateway.registry and config.gateway.registry.enabled: + base_image_name = config.gateway.image.split(":")[0].split("/")[-1] + image_version = config.gateway.image.split(":")[-1] if ":" in config.gateway.image else "latest" + gateway_dict["image"] = f"{config.gateway.registry.url}/{config.gateway.registry.namespace}/{base_image_name}:{image_version}" + # Set imagePullPolicy from registry config + if config.gateway.registry.image_pull_policy: + gateway_dict["image_pull_policy"] = config.gateway.registry.image_pull_policy + + # Add DATABASE_URL and REDIS_URL to gateway environment if infrastructure is enabled + if "env_vars" not in gateway_dict: + gateway_dict["env_vars"] = {} + + # Enable plugins if any are configured + if config.plugins and len(config.plugins) > 0: + gateway_dict["env_vars"]["PLUGINS_ENABLED"] = "true" + gateway_dict["env_vars"]["PLUGIN_CONFIG_FILE"] = "/app/config/plugins.yaml" + + # Add init containers to wait for infrastructure services + init_containers = [] + + if infrastructure and infrastructure.postgres and infrastructure.postgres.enabled: + postgres = infrastructure.postgres + db_user = postgres.user or "postgres" + db_password = postgres.password or "mysecretpassword" + db_name = postgres.database or "mcp" + gateway_dict["env_vars"]["DATABASE_URL"] = f"postgresql://{db_user}:{db_password}@postgres:5432/{db_name}" + + # Add init container to wait for PostgreSQL + init_containers.append({"name": "wait-for-postgres", "image": "busybox:1.36", "command": ["sh", "-c", "until nc -z postgres 5432; do echo waiting for postgres; sleep 2; done"]}) + + if infrastructure and infrastructure.redis and infrastructure.redis.enabled: + gateway_dict["env_vars"]["REDIS_URL"] = "redis://redis:6379/0" + + # Add init container to wait for Redis + init_containers.append({"name": "wait-for-redis", "image": "busybox:1.36", "command": ["sh", "-c", "until nc -z redis 6379; do echo waiting for redis; sleep 2; done"]}) + + # Add init containers to wait for plugins to be ready + if config.plugins and len(config.plugins) > 0: + for plugin in config.plugins: + plugin_service_name = f"mcp-plugin-{plugin.name.lower()}" + plugin_port = plugin.port or 8000 + # Wait for plugin service to be available + init_containers.append( + { + "name": f"wait-for-{plugin.name.lower()}", + "image": "busybox:1.36", + "command": ["sh", "-c", f"until nc -z {plugin_service_name} {plugin_port}; do echo waiting for {plugin_service_name}; sleep 2; done"], + } + ) + + if init_containers: + gateway_dict["init_containers"] = init_containers + + gateway_manifest = gateway_template.render(**gateway_dict) + (output_dir / "gateway-deployment.yaml").write_text(gateway_manifest) + + # Generate OpenShift Route if configured + if config.deployment.openshift and config.deployment.openshift.create_routes: + route_template = env.get_template("route.yaml.j2") + openshift_config = config.deployment.openshift + + # Auto-detect OpenShift apps domain if not specified + openshift_domain = openshift_config.domain + if not openshift_domain: + try: + # Try to get domain from OpenShift cluster info + result = subprocess.run( + ["kubectl", "get", "ingresses.config.openshift.io", "cluster", "-o", "jsonpath={.spec.domain}"], capture_output=True, text=True, check=False + ) # nosec B603, B607 + if result.returncode == 0 and result.stdout.strip(): + openshift_domain = result.stdout.strip() + if verbose: + console.print(f"[dim]Auto-detected OpenShift domain: {openshift_domain}[/dim]") + 
else: + # Fallback to common OpenShift Local domain + openshift_domain = "apps-crc.testing" + if verbose: + console.print(f"[yellow]Could not auto-detect OpenShift domain, using default: {openshift_domain}[/yellow]") + except Exception: + # Fallback to common OpenShift Local domain + openshift_domain = "apps-crc.testing" + if verbose: + console.print(f"[yellow]Could not auto-detect OpenShift domain, using default: {openshift_domain}[/yellow]") + + route_manifest = route_template.render(namespace=namespace, openshift_domain=openshift_domain, tls_termination=openshift_config.tls_termination) + (output_dir / "gateway-route.yaml").write_text(route_manifest) + if verbose: + print(" ✓ OpenShift Route manifest generated") + + # Generate plugin deployments + for plugin in config.plugins: + # Convert Pydantic model to dict for template rendering + plugin_dict = plugin.model_dump(exclude_none=True) + plugin_dict["name"] = f"mcp-plugin-{plugin.name.lower()}" + plugin_dict["namespace"] = namespace + + # Update image to use full registry path if registry is enabled + if plugin.registry and plugin.registry.enabled: + base_image_name = plugin.image.split(":")[0].split("/")[-1] + image_version = plugin.image.split(":")[-1] if ":" in plugin.image else "latest" + plugin_dict["image"] = f"{plugin.registry.url}/{plugin.registry.namespace}/{base_image_name}:{image_version}" + # Set imagePullPolicy from registry config + if plugin.registry.image_pull_policy: + plugin_dict["image_pull_policy"] = plugin.registry.image_pull_policy + + plugin_manifest = gateway_template.render(**plugin_dict) + (output_dir / f"plugin-{plugin.name.lower()}-deployment.yaml").write_text(plugin_manifest) + + if verbose: + print(f"✓ Kubernetes manifests generated in {output_dir}") + + +def generate_compose_manifests(config: MCPStackConfig, output_dir: Path, verbose: bool = False) -> None: + """Generate Docker Compose manifest from configuration. + + Args: + config: Validated MCPStackConfig Pydantic model + output_dir: Output directory for manifests + verbose: Print verbose output + + Raises: + FileNotFoundError: If template directory not found + + Examples: + >>> from pathlib import Path + >>> import inspect + >>> # Test function signature + >>> sig = inspect.signature(generate_compose_manifests) + >>> list(sig.parameters.keys()) + ['config', 'output_dir', 'verbose'] + + >>> # Test default parameters + >>> sig.parameters['verbose'].default + False + + >>> # Actual execution requires templates and config: + >>> # from mcpgateway.tools.builder.schema import MCPStackConfig + >>> # generate_compose_manifests(config, Path("./output")) + """ + + # Load templates + template_dir = Path(__file__).parent / "templates" / "compose" + if not template_dir.exists(): + raise FileNotFoundError(f"Template directory not found: {template_dir}") + + # Auto-detect and assign env files if not specified + _auto_detect_env_files(config, output_dir, verbose=verbose) + + # Auto-assign host_ports if expose_port is true but host_port not specified + next_host_port = 8000 + for plugin in config.plugins: + # Port defaults are handled by Pydantic defaults in schema + + # Auto-assign host_port if expose_port is true + if plugin.expose_port and not plugin.host_port: + plugin.host_port = next_host_port # type: ignore + next_host_port += 1 + + # Compute relative certificate paths (from output_dir to project root certs/) + # Certificates are at: ./certs/mcp/... + # Output dir is at: ./deploy/manifests/ + # So relative path is: ../../certs/mcp/... 
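+    # Illustrative check of that computation (assuming the default layout):
+    #   os.path.relpath("./certs", "./deploy/manifests") == "../../certs"
+    # which lets the generated compose file mount the cert directories with
+    # paths relative to the manifest location.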
+ certs_base = Path.cwd() / "certs" + certs_rel_base = os.path.relpath(certs_base, output_dir) + + # Add computed cert paths to context for template + cert_paths = { + "certs_base": certs_rel_base, + "gateway_cert_dir": os.path.join(certs_rel_base, "mcp/gateway"), + "ca_cert_file": os.path.join(certs_rel_base, "mcp/ca/ca.crt"), + "plugins_cert_base": os.path.join(certs_rel_base, "mcp/plugins"), + } + + env = Environment(loader=FileSystemLoader(str(template_dir)), autoescape=True) # nosec B701 + + # Generate compose file + compose_template = env.get_template("docker-compose.yaml.j2") + # Convert Pydantic model to dict for template rendering + config_dict = config.model_dump(exclude_none=True) + compose_manifest = compose_template.render(**config_dict, cert_paths=cert_paths) + (output_dir / "docker-compose.yaml").write_text(compose_manifest) + + if verbose: + print(f"✓ Compose manifest generated in {output_dir}") + + +def _auto_detect_env_files(config: MCPStackConfig, output_dir: Path, verbose: bool = False) -> None: + """Auto-detect and assign env files if not explicitly specified. + + If env_file is not specified in the config, check if {deploy_dir}/env/.env.{name} + exists and use it. Warn the user when auto-detection is used. + + Args: + config: MCPStackConfig Pydantic model (modified in-place via attribute assignment) + output_dir: Output directory where manifests will be generated (for relative paths) + verbose: Print verbose output + + Examples: + >>> from pathlib import Path + >>> from mcpgateway.tools.builder.schema import MCPStackConfig, DeploymentConfig, GatewayConfig + >>> import tempfile + >>> # Test function modifies config in place + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... output = Path(tmpdir) + ... config = MCPStackConfig( + ... deployment=DeploymentConfig(type="compose"), + ... gateway=GatewayConfig(image="test:latest"), + ... plugins=[] + ... ) + ... # Function modifies config if env files exist + ... _auto_detect_env_files(config, output, verbose=False) + ... # Config object is modified in place + ... 
isinstance(config, MCPStackConfig) + True + + >>> # Test function signature + >>> import inspect + >>> sig = inspect.signature(_auto_detect_env_files) + >>> 'verbose' in sig.parameters + True + """ + deploy_dir = get_deploy_dir() + env_dir = deploy_dir / "env" + + # Check gateway - since we need to modify the model, we access env_file directly + # Note: Pydantic models allow attribute assignment after creation + if not hasattr(config.gateway, "env_file") or not config.gateway.env_file: + gateway_env = env_dir / ".env.gateway" + if gateway_env.exists(): + # Make path relative to output_dir (where docker-compose.yaml will be) + relative_path = os.path.relpath(gateway_env, output_dir) + config.gateway.env_file = relative_path # type: ignore + print(f"⚠ Auto-detected env file: {gateway_env}") + if verbose: + print(" (Gateway env_file not specified in config)") + + # Check plugins + for plugin in config.plugins: + plugin_name = plugin.name + if not hasattr(plugin, "env_file") or not plugin.env_file: + plugin_env = env_dir / f".env.{plugin_name}" + if plugin_env.exists(): + # Make path relative to output_dir (where docker-compose.yaml will be) + relative_path = os.path.relpath(plugin_env, output_dir) + plugin.env_file = relative_path # type: ignore + print(f"⚠ Auto-detected env file: {plugin_env}") + if verbose: + print(f" (Plugin {plugin_name} env_file not specified in config)") + + +def copy_env_template(plugin_name: str, plugin_build_dir: Path, verbose: bool = False) -> None: + """Copy .env.template from plugin repo to {deploy_dir}/env/ directory. + + Uses MCP_DEPLOY_DIR environment variable if set, defaults to './deploy'. + This function is shared between Dagger and plain Python implementations. + + Args: + plugin_name: Name of the plugin + plugin_build_dir: Path to plugin build directory (contains .env.template) + verbose: Print verbose output + + Examples: + >>> from pathlib import Path + >>> import tempfile + >>> import os + >>> # Test with non-existent template (should return early) + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... build_dir = Path(tmpdir) + ... # No .env.template exists, function returns early + ... copy_env_template("test-plugin", build_dir, verbose=False) + + >>> # Test directory creation + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... os.environ["MCP_DEPLOY_DIR"] = tmpdir + ... build_dir = Path(tmpdir) / "build" + ... build_dir.mkdir() + ... template = build_dir / ".env.template" + ... _ = template.write_text("TEST=value") + ... copy_env_template("test", build_dir, verbose=False) + ... env_file = Path(tmpdir) / "env" / ".env.test" + ... 
env_file.exists() + True + + >>> # Cleanup + >>> _ = os.environ.pop("MCP_DEPLOY_DIR", None) + """ + # Create {deploy_dir}/env directory if it doesn't exist + deploy_dir = get_deploy_dir() + env_dir = deploy_dir / "env" + env_dir.mkdir(parents=True, exist_ok=True) + + # Look for .env.template in plugin build directory + template_file = plugin_build_dir / ".env.template" + if not template_file.exists(): + if verbose: + print(f"No .env.template found in {plugin_name}") + return + + # Target file path + target_file = env_dir / f".env.{plugin_name}" + + # Only copy if target doesn't exist (don't overwrite user edits) + if target_file.exists(): + if verbose: + print(f"⚠ {target_file} already exists, skipping") + return + + # Copy template + shutil.copy2(template_file, target_file) + if verbose: + print(f"✓ Copied .env.template -> {target_file}") + + +def handle_registry_operations(component, component_name: str, image_tag: str, container_runtime: str, verbose: bool = False) -> str: + """Handle registry tagging and pushing for a built component. + + This function is shared between Dagger and plain Python implementations. + It tags the locally built image with the registry path and optionally pushes it. + + Args: + component: BuildableConfig component (GatewayConfig or PluginConfig) + component_name: Name of the component (gateway or plugin name) + image_tag: Current local image tag + container_runtime: Container runtime to use ("docker" or "podman") + verbose: Print verbose output + + Returns: + Final image tag (registry path if registry enabled, otherwise original tag) + + Raises: + TypeError: If component is not a BuildableConfig instance + ValueError: If registry enabled but missing required configuration + subprocess.CalledProcessError: If tag or push command fails + + Examples: + >>> from mcpgateway.tools.builder.schema import GatewayConfig, RegistryConfig + >>> # Test with registry disabled (returns original tag) + >>> gateway = GatewayConfig(image="test:latest") + >>> result = handle_registry_operations(gateway, "gateway", "test:latest", "docker") + >>> result + 'test:latest' + + >>> # Test type checking - wrong type raises TypeError + >>> try: + ... handle_registry_operations("not a config", "test", "tag:latest", "docker") + ... except TypeError as e: + ... "BuildableConfig" in str(e) + True + + >>> # Test validation error - registry enabled but missing config + >>> from mcpgateway.tools.builder.schema import GatewayConfig, RegistryConfig + >>> gateway_bad = GatewayConfig( + ... image="test:latest", + ... registry=RegistryConfig(enabled=True, url="docker.io") # missing namespace + ... ) + >>> try: + ... handle_registry_operations(gateway_bad, "gateway", "test:latest", "docker") + ... except ValueError as e: + ... "missing" in str(e) and "namespace" in str(e) + True + + >>> # Test validation error - missing URL + >>> gateway_bad2 = GatewayConfig( + ... image="test:latest", + ... registry=RegistryConfig(enabled=True, namespace="myns") # missing url + ... ) + >>> try: + ... handle_registry_operations(gateway_bad2, "gateway", "test:latest", "docker") + ... except ValueError as e: + ... 
"missing" in str(e) and "url" in str(e) + True + + >>> # Test function signature + >>> import inspect + >>> sig = inspect.signature(handle_registry_operations) + >>> list(sig.parameters.keys()) + ['component', 'component_name', 'image_tag', 'container_runtime', 'verbose'] + + >>> # Test return type + >>> sig.return_annotation + + """ + # First-Party + from mcpgateway.tools.builder.schema import BuildableConfig + + # Type check for better error messages + if not isinstance(component, BuildableConfig): + raise TypeError(f"Component must be a BuildableConfig instance, got {type(component)}") + + # Check if registry is enabled + if not component.registry or not component.registry.enabled: + return image_tag + + registry_config = component.registry + + # Validate registry configuration + if not registry_config.url or not registry_config.namespace: + raise ValueError(f"Registry enabled for {component_name} but missing 'url' or 'namespace' configuration") + + # Construct registry image path + # Format: {registry_url}/{namespace}/{image_name}:{tag} + base_image_name = image_tag.split(":")[0].split("/")[-1] # Extract base name (e.g., "mcpgateway-gateway") + image_version = image_tag.split(":")[-1] if ":" in image_tag else "latest" # Extract tag + registry_image = f"{registry_config.url}/{registry_config.namespace}/{base_image_name}:{image_version}" + + # Tag image for registry + if verbose: + console.print(f"[dim]Tagging {image_tag} as {registry_image}[/dim]") + tag_cmd = [container_runtime, "tag", image_tag, registry_image] + result = subprocess.run(tag_cmd, capture_output=True, text=True, check=True) # nosec B603, B607 + if result.stdout and verbose: + console.print(result.stdout) + + # Push to registry if enabled + if registry_config.push: + if verbose: + console.print(f"[blue]Pushing {registry_image} to registry...[/blue]") + + # Build push command with TLS options + push_cmd = [container_runtime, "push"] + + # For podman, add --tls-verify=false for registries with self-signed certs + # This is common for OpenShift internal registries and local development + if container_runtime == "podman": + push_cmd.append("--tls-verify=false") + + push_cmd.append(registry_image) + + try: + result = subprocess.run(push_cmd, capture_output=True, text=True, check=True) # nosec B603, B607 + if result.stdout and verbose: + console.print(result.stdout) + console.print(f"[green]✓ Pushed to registry: {registry_image}[/green]") + except subprocess.CalledProcessError as e: + console.print(f"[red]✗ Failed to push to registry: {e}[/red]") + if e.stderr: + console.print(f"[red]Error output: {e.stderr}[/red]") + console.print("[yellow]Tip: Authenticate to the registry first:[/yellow]") + console.print(f" {container_runtime} login {registry_config.url}") + raise + + # Update component image reference to use registry path for manifests + component.image = registry_image + + return registry_image + + +# Docker Compose Utilities + + +def get_docker_compose_command() -> List[str]: + """Detect and return available docker compose command. + + Tries to detect docker compose plugin first, then falls back to + standalone docker-compose command. + + Returns: + Command to use: ["docker", "compose"] or ["docker-compose"] + + Raises: + RuntimeError: If neither command is available + + Examples: + >>> # Test that function returns a list + >>> try: + ... cmd = get_docker_compose_command() + ... isinstance(cmd, list) + ... except RuntimeError: + ... # Docker compose not installed in test environment + ... 
True + True + + >>> # Test that it returns valid command formats + >>> try: + ... cmd = get_docker_compose_command() + ... # Should be either ["docker", "compose"] or ["docker-compose"] + ... cmd in [["docker", "compose"], ["docker-compose"]] + ... except RuntimeError: + ... # Docker compose not installed + ... True + True + + >>> # Test error case (requires mocking, shown for documentation) + >>> # from unittest.mock import patch + >>> # with patch('shutil.which', return_value=None): + >>> # try: + >>> # get_docker_compose_command() + >>> # except RuntimeError as e: + >>> # "Docker Compose not found" in str(e) + >>> # True + """ + # Try docker compose (new plugin) first + if shutil.which("docker"): + try: + subprocess.run(["docker", "compose", "version"], capture_output=True, check=True) # nosec B603, B607 + return ["docker", "compose"] + except (subprocess.CalledProcessError, FileNotFoundError): + pass + + # Fall back to standalone docker-compose + if shutil.which("docker-compose"): + return ["docker-compose"] + + raise RuntimeError("Docker Compose not found. Install docker compose plugin or docker-compose.") + + +def run_compose(compose_file: Path, args: List[str], verbose: bool = False, check: bool = True) -> subprocess.CompletedProcess: + """Run docker compose command with given arguments. + + Args: + compose_file: Path to docker-compose.yaml + args: Arguments to pass to compose (e.g., ["up", "-d"]) + verbose: Print verbose output + check: Raise exception on non-zero exit code + + Returns: + CompletedProcess instance + + Raises: + FileNotFoundError: If compose_file doesn't exist + RuntimeError: If docker compose command fails (when check=True) + + Examples: + >>> from pathlib import Path + >>> import tempfile + >>> # Test with non-existent file + >>> try: + ... run_compose(Path("/nonexistent/docker-compose.yaml"), ["ps"]) + ... except FileNotFoundError as e: + ... "Compose file not found" in str(e) + True + + >>> # Test that args are properly formatted + >>> args = ["up", "-d"] + >>> isinstance(args, list) + True + >>> all(isinstance(arg, str) for arg in args) + True + + >>> # Real execution would require docker compose installed: + >>> # with tempfile.NamedTemporaryFile(suffix=".yaml") as f: + >>> # result = run_compose(Path(f.name), ["--version"], check=False) + >>> # isinstance(result, subprocess.CompletedProcess) + """ + if not compose_file.exists(): + raise FileNotFoundError(f"Compose file not found: {compose_file}") + + compose_cmd = get_docker_compose_command() + full_cmd = compose_cmd + ["-f", str(compose_file)] + args + + if verbose: + console.print(f"[dim]Running: {' '.join(full_cmd)}[/dim]") + + try: + result = subprocess.run(full_cmd, capture_output=True, text=True, check=check) # nosec B603, B607 + return result + except subprocess.CalledProcessError as e: + console.print("\n[red bold]Docker Compose command failed:[/red bold]") + if e.stdout: + console.print(f"[yellow]Output:[/yellow]\n{e.stdout}") + if e.stderr: + console.print(f"[red]Error:[/red]\n{e.stderr}") + raise RuntimeError(f"Docker Compose failed with exit code {e.returncode}") from e + + +def deploy_compose(compose_file: Path, verbose: bool = False) -> None: + """Deploy using docker compose up -d. 
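+
+ Runs "up -d" against the given compose file using the command detected by
+ get_docker_compose_command(), so both the docker compose plugin and the
+ standalone docker-compose binary are supported.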
+ + Args: + compose_file: Path to docker-compose.yaml + verbose: Print verbose output + + Raises: + RuntimeError: If deployment fails + + Examples: + >>> from pathlib import Path + >>> # Test that function signature is correct + >>> import inspect + >>> sig = inspect.signature(deploy_compose) + >>> 'compose_file' in sig.parameters + True + >>> 'verbose' in sig.parameters + True + + >>> # Test with non-existent file (would fail at run_compose) + >>> # deploy_compose(Path("/nonexistent.yaml")) # Raises FileNotFoundError + """ + result = run_compose(compose_file, ["up", "-d"], verbose=verbose) + if result.stdout and verbose: + console.print(result.stdout) + console.print("[green]✓ Deployed with Docker Compose[/green]") + + +def verify_compose(compose_file: Path, verbose: bool = False) -> str: + """Verify Docker Compose deployment with ps command. + + Args: + compose_file: Path to docker-compose.yaml + verbose: Print verbose output + + Returns: + Output from docker compose ps command + + Examples: + >>> from pathlib import Path + >>> # Test return type + >>> import inspect + >>> sig = inspect.signature(verify_compose) + >>> sig.return_annotation + + + >>> # Test parameters + >>> list(sig.parameters.keys()) + ['compose_file', 'verbose'] + + >>> # Actual execution requires docker compose: + >>> # output = verify_compose(Path("docker-compose.yaml")) + >>> # isinstance(output, str) + """ + result = run_compose(compose_file, ["ps"], verbose=verbose, check=False) + return result.stdout + + +def destroy_compose(compose_file: Path, verbose: bool = False) -> None: + """Destroy Docker Compose deployment with down -v. + + Args: + compose_file: Path to docker-compose.yaml + verbose: Print verbose output + + Raises: + RuntimeError: If destruction fails + + Examples: + >>> from pathlib import Path + >>> # Test with non-existent file (graceful handling) + >>> destroy_compose(Path("/nonexistent/docker-compose.yaml"), verbose=False) + Compose file not found: /nonexistent/docker-compose.yaml + Nothing to destroy + + >>> # Test function signature + >>> import inspect + >>> sig = inspect.signature(destroy_compose) + >>> 'verbose' in sig.parameters + True + """ + if not compose_file.exists(): + console.print(f"[yellow]Compose file not found: {compose_file}[/yellow]") + console.print("[yellow]Nothing to destroy[/yellow]") + return + + result = run_compose(compose_file, ["down", "-v"], verbose=verbose) + if result.stdout and verbose: + console.print(result.stdout) + console.print("[green]✓ Destroyed Docker Compose deployment[/green]") + + +# Kubernetes kubectl utilities + + +def deploy_kubernetes(manifests_dir: Path, verbose: bool = False) -> None: + """Deploy to Kubernetes using kubectl. + + Applies manifests in correct order: + 1. Deployments (creates namespaces) + 2. Certificate resources (secrets or cert-manager CRDs) + 3. ConfigMaps (plugins configuration) + 4. Infrastructure (PostgreSQL, Redis) + 5. OpenShift Routes (if configured) + + Excludes plugins-config.yaml (not a Kubernetes resource). + + Args: + manifests_dir: Path to directory containing Kubernetes manifests + verbose: Print verbose output + + Raises: + RuntimeError: If kubectl not found or deployment fails + + Examples: + >>> from pathlib import Path + >>> import shutil + >>> # Test that function checks for kubectl + >>> if not shutil.which("kubectl"): + ... # Would raise RuntimeError + ... print("kubectl not found") + ... else: + ... print("kubectl available") + kubectl... 
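+
+ >>> # Illustrative end-to-end call (requires a cluster; assumes manifests were
+ >>> # generated under the default ./deploy/manifests/kubernetes directory):
+ >>> # deploy_kubernetes(Path("./deploy/manifests/kubernetes"), verbose=True)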
+ + >>> # Test function signature + >>> import inspect + >>> sig = inspect.signature(deploy_kubernetes) + >>> list(sig.parameters.keys()) + ['manifests_dir', 'verbose'] + """ + if not shutil.which("kubectl"): + raise RuntimeError("kubectl not found. Cannot deploy to Kubernetes.") + + # Get all manifest files, excluding plugins-config.yaml (not a Kubernetes resource) + all_manifests = sorted(manifests_dir.glob("*.yaml")) + all_manifests = [m for m in all_manifests if m.name != "plugins-config.yaml"] + + # Identify different types of manifests + cert_secrets = manifests_dir / "cert-secrets.yaml" + cert_manager_certs = manifests_dir / "cert-manager-certificates.yaml" + postgres_deploy = manifests_dir / "postgres-deployment.yaml" + redis_deploy = manifests_dir / "redis-deployment.yaml" + plugins_configmap = manifests_dir / "plugins-configmap.yaml" + + # 1. Apply all deployments first (creates namespaces) + deployment_files = [m for m in all_manifests if m.name.endswith("-deployment.yaml") and m not in [cert_secrets, postgres_deploy, redis_deploy]] + + # Apply deployment files (this creates the namespace) + for manifest in deployment_files: + result = subprocess.run(["kubectl", "apply", "-f", str(manifest)], capture_output=True, text=True, check=False) # nosec B603, B607 + if result.stdout and verbose: + console.print(result.stdout) + if result.returncode != 0: + raise RuntimeError(f"kubectl apply failed: {result.stderr}") + + # 2. Apply certificate resources (now namespace exists) + # Check for both cert-secrets.yaml (local mode) and cert-manager-certificates.yaml (cert-manager mode) + if cert_manager_certs.exists(): + result = subprocess.run(["kubectl", "apply", "-f", str(cert_manager_certs)], capture_output=True, text=True, check=False) # nosec B603, B607 + if result.stdout and verbose: + console.print(result.stdout) + if result.returncode != 0: + raise RuntimeError(f"kubectl apply failed: {result.stderr}") + elif cert_secrets.exists(): + result = subprocess.run(["kubectl", "apply", "-f", str(cert_secrets)], capture_output=True, text=True, check=False) # nosec B603, B607 + if result.stdout and verbose: + console.print(result.stdout) + if result.returncode != 0: + raise RuntimeError(f"kubectl apply failed: {result.stderr}") + + # 3. Apply ConfigMaps (needed by deployments) + if plugins_configmap.exists(): + result = subprocess.run(["kubectl", "apply", "-f", str(plugins_configmap)], capture_output=True, text=True, check=False) # nosec B603, B607 + if result.stdout and verbose: + console.print(result.stdout) + if result.returncode != 0: + raise RuntimeError(f"kubectl apply failed: {result.stderr}") + + # 4. Apply infrastructure + for infra_file in [postgres_deploy, redis_deploy]: + if infra_file.exists(): + result = subprocess.run(["kubectl", "apply", "-f", str(infra_file)], capture_output=True, text=True, check=False) # nosec B603, B607 + if result.stdout and verbose: + console.print(result.stdout) + if result.returncode != 0: + raise RuntimeError(f"kubectl apply failed: {result.stderr}") + + # 5. 
Apply OpenShift Routes (if configured) + gateway_route = manifests_dir / "gateway-route.yaml" + if gateway_route.exists(): + result = subprocess.run(["kubectl", "apply", "-f", str(gateway_route)], capture_output=True, text=True, check=False) # nosec B603, B607 + if result.stdout and verbose: + console.print(result.stdout) + if result.returncode != 0: + # Don't fail on Route errors (may not be on OpenShift) + if verbose: + console.print(f"[yellow]Warning: Could not apply Route (may not be on OpenShift): {result.stderr}[/yellow]") + + console.print("[green]✓ Deployed to Kubernetes[/green]") + + +def verify_kubernetes(namespace: str, wait: bool = False, timeout: int = 300, verbose: bool = False) -> str: + """Verify Kubernetes deployment health. + + Args: + namespace: Kubernetes namespace to check + wait: Wait for pods to be ready + timeout: Wait timeout in seconds + verbose: Print verbose output + + Returns: + String output from kubectl get pods + + Raises: + RuntimeError: If kubectl not found or verification fails + + Examples: + >>> # Test function signature and return type + >>> import inspect + >>> sig = inspect.signature(verify_kubernetes) + >>> sig.return_annotation + + + >>> # Test parameters + >>> params = list(sig.parameters.keys()) + >>> 'namespace' in params and 'wait' in params and 'timeout' in params + True + + >>> # Test default timeout value + >>> sig.parameters['timeout'].default + 300 + """ + if not shutil.which("kubectl"): + raise RuntimeError("kubectl not found. Cannot verify Kubernetes deployment.") + + # Get pod status + result = subprocess.run(["kubectl", "get", "pods", "-n", namespace], capture_output=True, text=True, check=False) # nosec B603, B607 + output = result.stdout if result.stdout else "" + if result.returncode != 0: + raise RuntimeError(f"kubectl get pods failed: {result.stderr}") + + # Wait for pods if requested + if wait: + result = subprocess.run(["kubectl", "wait", "--for=condition=Ready", "pod", "--all", "-n", namespace, f"--timeout={timeout}s"], capture_output=True, text=True, check=False) # nosec B603, B607 + if result.stdout and verbose: + console.print(result.stdout) + if result.returncode != 0: + raise RuntimeError(f"kubectl wait failed: {result.stderr}") + + return output + + +def destroy_kubernetes(manifests_dir: Path, verbose: bool = False) -> None: + """Destroy Kubernetes deployment. + + Args: + manifests_dir: Path to directory containing Kubernetes manifests + verbose: Print verbose output + + Raises: + RuntimeError: If kubectl not found or destruction fails + + Examples: + >>> from pathlib import Path + >>> # Test with non-existent directory (graceful handling) + >>> import shutil + >>> if shutil.which("kubectl"): + ... destroy_kubernetes(Path("/nonexistent/manifests"), verbose=False) + ... else: + ... print("kubectl not available") + Manifests directory not found: /nonexistent/manifests + Nothing to destroy + + >>> # Test function signature + >>> import inspect + >>> sig = inspect.signature(destroy_kubernetes) + >>> list(sig.parameters.keys()) + ['manifests_dir', 'verbose'] + """ + if not shutil.which("kubectl"): + raise RuntimeError("kubectl not found. 
Cannot destroy Kubernetes deployment.") + + if not manifests_dir.exists(): + console.print(f"[yellow]Manifests directory not found: {manifests_dir}[/yellow]") + console.print("[yellow]Nothing to destroy[/yellow]") + return + + # Delete all manifests except plugins-config.yaml + all_manifests = sorted(manifests_dir.glob("*.yaml")) + all_manifests = [m for m in all_manifests if m.name != "plugins-config.yaml"] + + for manifest in all_manifests: + result = subprocess.run(["kubectl", "delete", "-f", str(manifest), "--ignore-not-found=true"], capture_output=True, text=True, check=False) # nosec B603, B607 + if result.stdout and verbose: + console.print(result.stdout) + if result.returncode != 0 and "NotFound" not in result.stderr: + console.print(f"[yellow]Warning: {result.stderr}[/yellow]") + + console.print("[green]✓ Destroyed Kubernetes deployment[/green]") diff --git a/mcpgateway/tools/builder/dagger_deploy.py b/mcpgateway/tools/builder/dagger_deploy.py new file mode 100644 index 000000000..81367625d --- /dev/null +++ b/mcpgateway/tools/builder/dagger_deploy.py @@ -0,0 +1,557 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/tools/builder/dagger_deploy.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Dagger-based MCP Stack Deployment Module + +This module provides optimized build and deployment using Dagger. + +Features: +- Automatic caching and parallelization +- Content-addressable storage +- Efficient multi-stage builds +- Built-in layer caching +""" + +# Standard +from pathlib import Path +from typing import List, Optional + +try: + # Third-Party + import dagger + from dagger import dag + + DAGGER_AVAILABLE = True +except ImportError: + DAGGER_AVAILABLE = False + dagger = None # type: ignore + dag = None # type: ignore + +# Third-Party +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn + +# First-Party +from mcpgateway.tools.builder.common import ( + deploy_compose, + deploy_kubernetes, + destroy_compose, + destroy_kubernetes, + generate_compose_manifests, + generate_kubernetes_manifests, + generate_plugin_config, + get_deploy_dir, + handle_registry_operations, + load_config, + verify_compose, + verify_kubernetes, +) +from mcpgateway.tools.builder.common import copy_env_template as copy_template +from mcpgateway.tools.builder.pipeline import CICDModule +from mcpgateway.tools.builder.schema import BuildableConfig, MCPStackConfig + +console = Console() + + +class MCPStackDagger(CICDModule): + """Dagger-based implementation of MCP Stack deployment.""" + + def __init__(self, verbose: bool = False): + """Initialize MCPStackDagger instance. + + Args: + verbose: Enable verbose output + + Raises: + ImportError: If dagger is not installed + """ + if not DAGGER_AVAILABLE: + raise ImportError("Dagger is not installed. Install with: pip install dagger-io\n" "Alternatively, use the plain Python deployer with --deployer=python") + super().__init__(verbose) + + async def build(self, config_file: str, plugins_only: bool = False, specific_plugins: Optional[List[str]] = None, no_cache: bool = False, copy_env_templates: bool = False) -> None: + """Build gateway and plugin containers using Dagger. 
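+
+ Components that specify a repo are fetched through Dagger's git API, built
+ with docker_build(), exported to the local docker/podman daemon, and
+ optionally tagged and pushed to a registry via handle_registry_operations().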
+ + Args: + config_file: Path to mcp-stack.yaml + plugins_only: Only build plugins, skip gateway + specific_plugins: List of specific plugin names to build + no_cache: Disable Dagger cache + copy_env_templates: Copy .env.template files from cloned repos + + Raises: + Exception: If build fails for any component + """ + config = load_config(config_file) + + async with dagger.connection(dagger.Config(workdir=str(Path.cwd()))): + # Build gateway (unless plugins_only=True) + if not plugins_only: + gateway = config.gateway + if gateway.repo: + with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=self.console) as progress: + task = progress.add_task("Building gateway...", total=None) + try: + await self._build_component_with_dagger(gateway, "gateway", no_cache=no_cache) + progress.update(task, completed=1, description="[green]✓ Built gateway[/green]") + except Exception as e: + progress.update(task, completed=1, description="[red]✗ Failed gateway[/red]") + # Print full error after progress bar closes + self.console.print("\n[red bold]Gateway build failed:[/red bold]") + self.console.print(f"[red]{type(e).__name__}: {str(e)}[/red]") + if self.verbose: + # Standard + import traceback + + self.console.print(f"[dim]{traceback.format_exc()}[/dim]") + raise + elif self.verbose: + self.console.print("[dim]Skipping gateway build (using pre-built image)[/dim]") + + # Build plugins + plugins = config.plugins + + if specific_plugins: + plugins = [p for p in plugins if p.name in specific_plugins] + + if not plugins: + self.console.print("[yellow]No plugins to build[/yellow]") + return + + with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=self.console) as progress: + + for plugin in plugins: + plugin_name = plugin.name + + # Skip if pre-built image specified + if plugin.image and not plugin.repo: + task = progress.add_task(f"Skipping {plugin_name} (using pre-built image)", total=1) + progress.update(task, completed=1) + continue + + task = progress.add_task(f"Building {plugin_name}...", total=None) + + try: + await self._build_component_with_dagger(plugin, plugin_name, no_cache=no_cache, copy_env_templates=copy_env_templates) + progress.update(task, completed=1, description=f"[green]✓ Built {plugin_name}[/green]") + except Exception as e: + progress.update(task, completed=1, description=f"[red]✗ Failed {plugin_name}[/red]") + # Print full error after progress bar closes + self.console.print(f"\n[red bold]Plugin '{plugin_name}' build failed:[/red bold]") + self.console.print(f"[red]{type(e).__name__}: {str(e)}[/red]") + if self.verbose: + # Standard + import traceback + + self.console.print(f"[dim]{traceback.format_exc()}[/dim]") + raise + + async def generate_certificates(self, config_file: str) -> None: + """Generate mTLS certificates for plugins. + + Supports two modes: + 1. Local generation (use_cert_manager=false): Uses Dagger to generate certificates locally + 2. 
cert-manager (use_cert_manager=true): Skips local generation, cert-manager will create certificates + + Args: + config_file: Path to mcp-stack.yaml + + Raises: + dagger.ExecError: If certificate generation command fails (when using local generation) + dagger.QueryError: If Dagger query fails (when using local generation) + """ + config = load_config(config_file) + + # Check if using cert-manager + cert_config = config.certificates + use_cert_manager = cert_config.use_cert_manager if cert_config else False + validity_days = cert_config.validity_days if cert_config else 825 + + if use_cert_manager: + # Skip local generation - cert-manager will handle certificate creation + if self.verbose: + self.console.print("[blue]Using cert-manager for certificate management[/blue]") + self.console.print("[dim]Skipping local certificate generation (cert-manager will create certificates)[/dim]") + return + + # Local certificate generation (backward compatibility) + if self.verbose: + self.console.print("[blue]Generating mTLS certificates locally...[/blue]") + + # Use Dagger container to run certificate generation + async with dagger.connection(dagger.Config(workdir=str(Path.cwd()))): + # Mount current directory + source = dag.host().directory(".") + try: + # Use Alpine with openssl + container = ( + dag.container() + .from_("alpine:latest") + .with_exec(["apk", "add", "--no-cache", "openssl", "python3", "py3-pip", "make", "bash"]) + .with_mounted_directory("/workspace", source) + .with_workdir("/workspace") + # .with_exec(["python3", "-m", "venv", ".venv"]) + # .with_exec(["sh", "-c", "source .venv/bin/activate && pip install pyyaml"]) + # .with_exec(["pip", "install", "pyyaml"]) + ) + + # Generate CA + container = container.with_exec(["sh", "-c", f"make certs-mcp-ca MCP_CERT_DAYS={validity_days}"]) + + # Generate gateway cert + container = container.with_exec(["sh", "-c", f"make certs-mcp-gateway MCP_CERT_DAYS={validity_days}"]) + + # Generate plugin certificates + plugins = config.plugins + for plugin in plugins: + plugin_name = plugin.name + container = container.with_exec(["sh", "-c", f"make certs-mcp-plugin PLUGIN_NAME={plugin_name} MCP_CERT_DAYS={validity_days}"]) + + # Export certificates back to host + output = container.directory("/workspace/certs") + await output.export("./certs") + except dagger.ExecError as e: + self.console.print(f"Dagger Exec Error: {e.message}") + self.console.print(f"Exit Code: {e.exit_code}") + self.console.print(f"Stderr: {e.stderr}") + raise + except dagger.QueryError as e: + self.console.print(f"Dagger Query Error: {e.errors}") + self.console.print(f"Debug Query: {e.debug_query()}") + raise + except Exception as e: + self.console.print(f"An unexpected error occurred: {e}") + raise + + if self.verbose: + self.console.print("[green]✓ Certificates generated locally[/green]") + + async def deploy(self, config_file: str, dry_run: bool = False, skip_build: bool = False, skip_certs: bool = False, output_dir: Optional[str] = None) -> None: + """Deploy MCP stack. 
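+
+ Orchestrates the full pipeline: builds images (unless skip_build), generates
+ mTLS certificates when enabled (unless skip_certs), renders manifests, and
+ applies them with kubectl or docker compose depending on deployment.type.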
+ + Args: + config_file: Path to mcp-stack.yaml + dry_run: Generate manifests without deploying + skip_build: Skip building containers + skip_certs: Skip certificate generation + output_dir: Output directory for manifests (default: ./deploy) + + Raises: + ValueError: If unsupported deployment type specified + dagger.ExecError: If deployment command fails + dagger.QueryError: If Dagger query fails + """ + config = load_config(config_file) + + # Build containers + if not skip_build: + await self.build(config_file) + + # Generate certificates (only if mTLS is enabled) + gateway_mtls = config.gateway.mtls_enabled if config.gateway.mtls_enabled is not None else True + plugin_mtls = any((p.mtls_enabled if p.mtls_enabled is not None else True) for p in config.plugins) + mtls_needed = gateway_mtls or plugin_mtls + + if not skip_certs and mtls_needed: + await self.generate_certificates(config_file) + elif not skip_certs and not mtls_needed: + if self.verbose: + self.console.print("[dim]Skipping certificate generation (mTLS disabled)[/dim]") + + # Generate manifests + manifests_dir = self.generate_manifests(config_file, output_dir=output_dir) + + if dry_run: + self.console.print(f"[yellow]Dry-run: Manifests generated in {manifests_dir}[/yellow]") + return + + # Apply deployment + deployment_type = config.deployment.type + + async with dagger.connection(dagger.Config(workdir=str(Path.cwd()))): + try: + if deployment_type == "kubernetes": + await self._deploy_kubernetes(manifests_dir) + elif deployment_type == "compose": + await self._deploy_compose(manifests_dir) + else: + raise ValueError(f"Unsupported deployment type: {deployment_type}") + except dagger.ExecError as e: + self.console.print(f"Dagger Exec Error: {e.message}") + self.console.print(f"Exit Code: {e.exit_code}") + self.console.print(f"Stderr: {e.stderr}") + raise + except dagger.QueryError as e: + self.console.print(f"Dagger Query Error: {e.errors}") + self.console.print(f"Debug Query: {e.debug_query()}") + raise + except Exception as e: + # Extract detailed error from Dagger exception + error_msg = str(e) + self.console.print("\n[red bold]Deployment failed:[/red bold]") + self.console.print(f"[red]{error_msg}[/red]") + + # Check if it's a compose-specific error and try to provide more context + if "compose" in error_msg.lower() and self.verbose: + self.console.print("\n[yellow]Hint:[/yellow] Check the generated docker-compose.yaml:") + self.console.print(f"[dim] {manifests_dir}/docker-compose.yaml[/dim]") + self.console.print("[yellow]Try running manually:[/yellow]") + self.console.print(f"[dim] cd {manifests_dir} && docker compose up[/dim]") + + raise + + async def verify(self, config_file: str, wait: bool = False, timeout: int = 300) -> None: + """Verify deployment health. + + Args: + config_file: Path to mcp-stack.yaml + wait: Wait for deployment to be ready + timeout: Wait timeout in seconds + """ + config = load_config(config_file) + deployment_type = config.deployment.type + + if self.verbose: + self.console.print("[blue]Verifying deployment...[/blue]") + + async with dagger.connection(dagger.Config(workdir=str(Path.cwd()))): + if deployment_type == "kubernetes": + await self._verify_kubernetes(config, wait=wait, timeout=timeout) + elif deployment_type == "compose": + await self._verify_compose(config, wait=wait, timeout=timeout) + + async def destroy(self, config_file: str) -> None: + """Destroy deployed MCP stack. 
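+
+ Dispatches to the Kubernetes or Compose teardown helpers from common.py
+ based on the configured deployment type.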
+ + Args: + config_file: Path to mcp-stack.yaml + """ + config = load_config(config_file) + deployment_type = config.deployment.type + + if self.verbose: + self.console.print("[blue]Destroying deployment...[/blue]") + + async with dagger.connection(dagger.Config(workdir=str(Path.cwd()))): + if deployment_type == "kubernetes": + await self._destroy_kubernetes(config) + elif deployment_type == "compose": + await self._destroy_compose(config) + + def generate_manifests(self, config_file: str, output_dir: Optional[str] = None) -> Path: + """Generate deployment manifests. + + Args: + config_file: Path to mcp-stack.yaml + output_dir: Output directory for manifests + + Returns: + Path to generated manifests directory + + Raises: + ValueError: If unsupported deployment type specified + """ + config = load_config(config_file) + deployment_type = config.deployment.type + + if output_dir is None: + deploy_dir = get_deploy_dir() + # Separate subdirectories for kubernetes and compose + manifests_path = deploy_dir / "manifests" / deployment_type + else: + manifests_path = Path(output_dir) + + manifests_path.mkdir(parents=True, exist_ok=True) + + # Store output dir for later use + self._last_output_dir = manifests_path + + # Generate plugin config.yaml for gateway (shared function) + generate_plugin_config(config, manifests_path, verbose=self.verbose) + + if deployment_type == "kubernetes": + generate_kubernetes_manifests(config, manifests_path, verbose=self.verbose) + elif deployment_type == "compose": + generate_compose_manifests(config, manifests_path, verbose=self.verbose) + else: + raise ValueError(f"Unsupported deployment type: {deployment_type}") + + return manifests_path + + # Private helper methods + + async def _build_component_with_dagger(self, component: BuildableConfig, component_name: str, no_cache: bool = False, copy_env_templates: bool = False) -> None: + """Build a component (gateway or plugin) container using Dagger. + + Args: + component: Component configuration (GatewayConfig or PluginConfig) + component_name: Name of the component (gateway or plugin name) + no_cache: Disable cache + copy_env_templates: Copy .env.template from repo if it exists + + Raises: + ValueError: If component has no repo field + Exception: If build or export fails + """ + repo = component.repo + + if not repo: + raise ValueError(f"Component '{component_name}' has no 'repo' field") + + # Clone repository to local directory for env template access + git_ref = component.ref or "main" + clone_dir = Path(f"./build/{component_name}") + + # For Dagger, we still need local clone if copying env templates + if copy_env_templates: + # Standard + import subprocess # nosec B404 + + clone_dir.mkdir(parents=True, exist_ok=True) + + if (clone_dir / ".git").exists(): + subprocess.run(["git", "fetch", "origin", git_ref], cwd=clone_dir, check=True, capture_output=True) # nosec B603, B607 + # Checkout what we just fetched (FETCH_HEAD) + subprocess.run(["git", "checkout", "FETCH_HEAD"], cwd=clone_dir, check=True, capture_output=True) # nosec B603, B607 + else: + subprocess.run(["git", "clone", "--branch", git_ref, "--depth", "1", repo, str(clone_dir)], check=True, capture_output=True) # nosec B603, B607 + + # Determine build context + build_context = component.context or "." 
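+ # Note: this local clone is only used to locate .env.template; the actual
+ # image build below fetches the source again through Dagger's git API.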
+ build_dir = clone_dir / build_context + + # Copy env template using shared function + copy_template(component_name, build_dir, verbose=self.verbose) + + # Use Dagger for the actual build + source = dag.git(repo).branch(git_ref).tree() + + # If component has context subdirectory, navigate to it + build_context = component.context or "." + if build_context != ".": + source = source.directory(build_context) + + # Detect Containerfile/Dockerfile + containerfile = component.containerfile or "Containerfile" + + # Build container - determine image tag + if component.image: + # Use explicitly specified image name + image_tag = component.image + else: + # Generate default image name based on component type + image_tag = f"mcpgateway-{component_name.lower()}:latest" + + # Build with optional target stage for multi-stage builds + build_kwargs = {"dockerfile": containerfile} + if component.target: + build_kwargs["target"] = component.target + + # Use docker_build on the directory + container = source.docker_build(**build_kwargs) + + # Export image to Docker daemon (always export, Dagger handles caching) + # Workaround for dagger-io 0.19.0 bug: export_image returns None instead of Void + # The export actually works, but beartype complains about the return type + try: + await container.export_image(image_tag) + except Exception as e: + # Ignore beartype validation error - the export actually succeeds + if "BeartypeCallHintReturnViolation" not in str(type(e)): + raise + + # Handle registry operations (tag and push if enabled) + # Note: Dagger exports to local docker/podman, so we need to detect which runtime to use + # Standard + import shutil + + container_runtime = "docker" if shutil.which("docker") else "podman" + image_tag = handle_registry_operations(component, component_name, image_tag, container_runtime, verbose=self.verbose) + + if self.verbose: + self.console.print(f"[green]✓ Built {component_name} -> {image_tag}[/green]") + + async def _deploy_kubernetes(self, manifests_dir: Path) -> None: + """Deploy to Kubernetes using kubectl. + + Uses shared deploy_kubernetes() from common.py to avoid code duplication. + + Args: + manifests_dir: Path to directory containing Kubernetes manifests + """ + deploy_kubernetes(manifests_dir, verbose=self.verbose) + + async def _deploy_compose(self, manifests_dir: Path) -> None: + """Deploy using Docker Compose. + + Uses shared deploy_compose() from common.py to avoid code duplication. + + Args: + manifests_dir: Path to directory containing compose manifest + """ + compose_file = manifests_dir / "docker-compose.yaml" + deploy_compose(compose_file, verbose=self.verbose) + + async def _verify_kubernetes(self, config: MCPStackConfig, wait: bool = False, timeout: int = 300) -> None: + """Verify Kubernetes deployment health. + + Uses shared verify_kubernetes() from common.py to avoid code duplication. + + Args: + config: Parsed configuration Pydantic model + wait: Wait for pods to be ready + timeout: Wait timeout in seconds + """ + namespace = config.deployment.namespace or "mcp-gateway" + output = verify_kubernetes(namespace, wait=wait, timeout=timeout, verbose=self.verbose) + self.console.print(output) + + async def _verify_compose(self, config: MCPStackConfig, wait: bool = False, timeout: int = 300) -> None: + """Verify Docker Compose deployment health. + + Uses shared verify_compose() from common.py to avoid code duplication. 
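+
+ The wait and timeout arguments are accepted for interface parity with the
+ Kubernetes path but are not currently used.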
+ + Args: + config: Parsed configuration Pydantic model + wait: Wait for containers to be ready + timeout: Wait timeout in seconds + """ + _ = config, wait, timeout # Reserved for future use + # Use the same manifests directory as generate_manifests + deploy_dir = get_deploy_dir() + output_dir = getattr(self, "_last_output_dir", deploy_dir / "manifests" / "compose") + compose_file = output_dir / "docker-compose.yaml" + output = verify_compose(compose_file, verbose=self.verbose) + self.console.print(output) + + async def _destroy_kubernetes(self, config: MCPStackConfig) -> None: + """Destroy Kubernetes deployment. + + Uses shared destroy_kubernetes() from common.py to avoid code duplication. + + Args: + config: Parsed configuration Pydantic model + """ + _ = config # Reserved for future use (namespace, labels, etc.) + # Use the same manifests directory as generate_manifests + deploy_dir = get_deploy_dir() + manifests_dir = getattr(self, "_last_output_dir", deploy_dir / "manifests" / "kubernetes") + destroy_kubernetes(manifests_dir, verbose=self.verbose) + + async def _destroy_compose(self, config: MCPStackConfig) -> None: + """Destroy Docker Compose deployment. + + Uses shared destroy_compose() from common.py to avoid code duplication. + + Args: + config: Parsed configuration Pydantic model + """ + _ = config # Reserved for future use (project name, networks, etc.) + # Use the same manifests directory as generate_manifests + deploy_dir = get_deploy_dir() + output_dir = getattr(self, "_last_output_dir", deploy_dir / "manifests" / "compose") + compose_file = output_dir / "docker-compose.yaml" + destroy_compose(compose_file, verbose=self.verbose) diff --git a/mcpgateway/tools/builder/factory.py b/mcpgateway/tools/builder/factory.py new file mode 100644 index 000000000..1353bd733 --- /dev/null +++ b/mcpgateway/tools/builder/factory.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/tools/builder/factory.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Factory for creating MCP Stack deployment implementations. + +This module provides a factory pattern for creating the appropriate deployment +implementation (Dagger or Plain Python) based on availability and user preference. + +The factory handles graceful fallback from Dagger to Python if dependencies are +unavailable, ensuring the deployment system works in various environments. + +Example: + >>> deployer, mode = DeployFactory.create_deployer("dagger", verbose=False) + ⚠ Dagger not installed. Using plain python. + >>> # Validate configuration (output varies by config) + >>> # deployer.validate("mcp-stack.yaml") +""" + +# Standard +from enum import Enum + +# Third-Party +from rich.console import Console + +# First-Party +from mcpgateway.tools.builder.pipeline import CICDModule + + +class CICDTypes(str, Enum): + """Deployment implementation types. 
+ + Attributes: + DAGGER: Dagger-based implementation (optimal performance) + PYTHON: Plain Python implementation (fallback, no dependencies) + + Examples: + >>> # Test enum values + >>> CICDTypes.DAGGER.value + 'dagger' + >>> CICDTypes.PYTHON.value + 'python' + + >>> # Test enum comparison + >>> CICDTypes.DAGGER == "dagger" + True + >>> CICDTypes.PYTHON == "python" + True + + >>> # Test enum membership + >>> "dagger" in [t.value for t in CICDTypes] + True + >>> "python" in [t.value for t in CICDTypes] + True + + >>> # Test enum iteration + >>> types = list(CICDTypes) + >>> len(types) + 2 + >>> CICDTypes.DAGGER in types + True + """ + + DAGGER = "dagger" + PYTHON = "python" + + +console = Console() + + +class DeployFactory: + """Factory for creating MCP Stack deployment implementations. + + This factory implements the Strategy pattern, allowing dynamic selection + between Dagger and Python implementations based on availability. + """ + + @staticmethod + def create_deployer(deployer: str, verbose: bool = False) -> tuple[CICDModule, CICDTypes]: + """Create a deployment implementation instance. + + Attempts to load the requested deployer type with automatic fallback + to Python implementation if dependencies are missing. + + Args: + deployer: Deployment type to create ("dagger" or "python") + verbose: Enable verbose logging during creation + + Returns: + tuple: (deployer_instance, actual_type) + - deployer_instance: Instance of MCPStackDagger or MCPStackPython + - actual_type: CICDTypes enum indicating which implementation was loaded + + Raises: + RuntimeError: If no implementation can be loaded (critical failure) + + Example: + >>> # Try to load Dagger, fall back to Python if unavailable + >>> deployer, mode = DeployFactory.create_deployer("dagger", verbose=False) + ⚠ Dagger not installed. Using plain python. + >>> if mode == CICDTypes.DAGGER: + ... print("Using optimized Dagger implementation") + ... else: + ... print("Using fallback Python implementation") + Using fallback Python implementation + """ + # Attempt to load Dagger implementation first if requested + if deployer == "dagger": + try: + # First-Party + from mcpgateway.tools.builder.dagger_deploy import DAGGER_AVAILABLE, MCPStackDagger + + # Check if dagger is actually available (not just the module) + if not DAGGER_AVAILABLE: + raise ImportError("Dagger SDK not installed") + + if verbose: + console.print("[green]✓ Dagger module loaded[/green]") + + return (MCPStackDagger(verbose), CICDTypes.DAGGER) + + except ImportError: + # Dagger dependencies not available, fall back to Python + console.print("[yellow]⚠ Dagger not installed. Using plain python.[/yellow]") + + # Load plain Python implementation (fallback or explicitly requested) + try: + # First-Party + from mcpgateway.tools.builder.python_deploy import MCPStackPython + + if verbose and deployer != "dagger": + console.print("[blue]Using plain Python implementation[/blue]") + + return (MCPStackPython(verbose), CICDTypes.PYTHON) + + except ImportError as e: + # Critical failure - neither implementation can be loaded + console.print("[red]✗ ERROR: Cannot import deployment modules[/red]") + console.print(f"[red] Details: {e}[/red]") + console.print("[yellow] Make sure you're running from the project root[/yellow]") + console.print("[yellow] and PYTHONPATH is set correctly[/yellow]") + + # This should never be reached if PYTHONPATH is set correctly + raise RuntimeError(f"Unable to load deployer of type '{deployer}'. 
") diff --git a/mcpgateway/tools/builder/pipeline.py b/mcpgateway/tools/builder/pipeline.py new file mode 100644 index 000000000..e7fcd098c --- /dev/null +++ b/mcpgateway/tools/builder/pipeline.py @@ -0,0 +1,367 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/tools/builder/pipeline.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Abstract base class for MCP Stack deployment implementations. + +This module defines the CICDModule interface that all deployment implementations +must implement. It provides a common API for building, deploying, and managing +MCP Gateway stacks with external plugin servers. + +The base class implements shared functionality (validation) while requiring +subclasses to implement deployment-specific logic (build, deploy, etc.). + +Design Pattern: + Strategy Pattern - Different implementations (Dagger vs Python) can be + swapped transparently via the DeployFactory. + +Example: + >>> from mcpgateway.tools.builder.factory import DeployFactory + >>> deployer, mode = DeployFactory.create_deployer("dagger", verbose=False) + ⚠ Dagger not installed. Using plain python. + >>> # Validate configuration (output varies by config) + >>> # deployer.validate("mcp-stack.yaml") + >>> # Async methods must be called with await (see method examples below) +""" + +# Standard +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional + +# Third-Party +from pydantic import ValidationError +from rich.console import Console +import yaml + +# First-Party +from mcpgateway.tools.builder.schema import MCPStackConfig + +# Shared console instance for consistent output formatting +console = Console() + + +class CICDModule(ABC): + """Abstract base class for MCP Stack deployment implementations. + + This class defines the interface that all deployment implementations must + implement. It provides common initialization and validation logic while + deferring implementation-specific details to subclasses. + + Attributes: + verbose (bool): Enable verbose output during operations + console (Console): Rich console for formatted output + + Implementations: + - MCPStackDagger: High-performance implementation using Dagger SDK + - MCPStackPython: Fallback implementation using plain Python + Docker/Podman + + Examples: + >>> # Test that CICDModule is abstract + >>> from abc import ABC + >>> issubclass(CICDModule, ABC) + True + + >>> # Test initialization with defaults + >>> class TestDeployer(CICDModule): + ... async def build(self, config_file: str, **kwargs) -> None: + ... pass + ... async def generate_certificates(self, config_file: str) -> None: + ... pass + ... async def deploy(self, config_file: str, **kwargs) -> None: + ... pass + ... async def verify(self, config_file: str, **kwargs) -> None: + ... pass + ... async def destroy(self, config_file: str) -> None: + ... pass + ... def generate_manifests(self, config_file: str, **kwargs) -> Path: + ... return Path(".") + >>> deployer = TestDeployer() + >>> deployer.verbose + False + + >>> # Test initialization with verbose=True + >>> verbose_deployer = TestDeployer(verbose=True) + >>> verbose_deployer.verbose + True + + >>> # Test that console is available + >>> hasattr(deployer, 'console') + True + """ + + def __init__(self, verbose: bool = False): + """Initialize the deployment module. + + Args: + verbose: Enable verbose output during all operations + + Examples: + >>> # Cannot instantiate abstract class directly + >>> try: + ... CICDModule() + ... except TypeError as e: + ... 
"abstract" in str(e).lower() + True + """ + self.verbose = verbose + self.console = console + + def validate(self, config_file: str) -> None: + """Validate mcp-stack.yaml configuration using Pydantic schemas. + + This method provides comprehensive validation of the MCP stack configuration + using Pydantic models defined in schema.py. It validates: + - Required sections (deployment, gateway, plugins) + - Deployment type (kubernetes or compose) + - Gateway image specification + - Plugin configurations (name, repo/image, etc.) + - Custom business rules (unique names, valid combinations) + + Args: + config_file: Path to mcp-stack.yaml configuration file + + Raises: + ValueError: If configuration is invalid, with formatted error details + ValidationError: If Pydantic schema validation fails + FileNotFoundError: If config_file does not exist + + Examples: + >>> import tempfile + >>> import yaml + >>> from pathlib import Path + >>> # Create a test deployer + >>> class TestDeployer(CICDModule): + ... async def build(self, config_file: str, **kwargs) -> None: + ... pass + ... async def generate_certificates(self, config_file: str) -> None: + ... pass + ... async def deploy(self, config_file: str, **kwargs) -> None: + ... pass + ... async def verify(self, config_file: str, **kwargs) -> None: + ... pass + ... async def destroy(self, config_file: str) -> None: + ... pass + ... def generate_manifests(self, config_file: str, **kwargs) -> Path: + ... return Path(".") + >>> deployer = TestDeployer(verbose=False) + + >>> # Test with valid minimal config + >>> with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + ... config = { + ... 'deployment': {'type': 'compose'}, + ... 'gateway': {'image': 'test:latest'}, + ... 'plugins': [] + ... } + ... yaml.dump(config, f) + ... config_path = f.name + >>> deployer.validate(config_path) + >>> import os + >>> os.unlink(config_path) + + >>> # Test with missing file + >>> try: + ... deployer.validate("/nonexistent/config.yaml") + ... except FileNotFoundError as e: + ... "config.yaml" in str(e) + True + + >>> # Test with invalid config (missing required fields) + >>> with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + ... bad_config = {'deployment': {'type': 'compose'}} + ... yaml.dump(bad_config, f) + ... bad_path = f.name + >>> try: + ... deployer.validate(bad_path) + ... except ValueError as e: + ... "validation failed" in str(e).lower() + True + >>> os.unlink(bad_path) + """ + if self.verbose: + self.console.print(f"[blue]Validating {config_file}...[/blue]") + + # Load YAML configuration + with open(config_file, "r") as f: + config_dict = yaml.safe_load(f) + + # Validate using Pydantic schema + try: + # Local + + MCPStackConfig(**config_dict) + except ValidationError as e: + # Format validation errors for better readability + error_msg = "Configuration validation failed:\n" + for error in e.errors(): + # Join the error location path (e.g., plugins -> 0 -> name) + loc = " -> ".join(str(x) for x in error["loc"]) + error_msg += f" • {loc}: {error['msg']}\n" + raise ValueError(error_msg) from e + + if self.verbose: + self.console.print("[green]✓ Configuration valid[/green]") + + @abstractmethod + async def build(self, config_file: str, plugins_only: bool = False, specific_plugins: Optional[list[str]] = None, no_cache: bool = False, copy_env_templates: bool = False) -> None: + """Build container images for plugins and/or gateway. 
+ + Subclasses must implement this to build Docker/Podman images from + Git repositories or use pre-built images. + + Args: + config_file: Path to mcp-stack.yaml + plugins_only: Only build plugins, skip gateway + specific_plugins: List of specific plugin names to build (optional) + no_cache: Disable build cache for fresh builds + copy_env_templates: Copy .env.template files from cloned repos + + Raises: + RuntimeError: If build fails + ValueError: If plugin configuration is invalid + + Example: + # await deployer.build("mcp-stack.yaml", plugins_only=True) + # ✓ Built OPAPluginFilter + # ✓ Built LLMGuardPlugin + """ + + @abstractmethod + async def generate_certificates(self, config_file: str) -> None: + """Generate mTLS certificates for gateway and plugins. + + Creates a certificate authority (CA) and issues certificates for: + - Gateway (client certificates for connecting to plugins) + - Each plugin (server certificates for accepting connections) + + Certificates are stored in the paths defined in the config's + certificates section (default: ./certs/mcp/). + + Args: + config_file: Path to mcp-stack.yaml + + Raises: + RuntimeError: If certificate generation fails + FileNotFoundError: If required tools (openssl) are not available + + Example: + # await deployer.generate_certificates("mcp-stack.yaml") + # ✓ Certificates generated + """ + + @abstractmethod + async def deploy(self, config_file: str, dry_run: bool = False, skip_build: bool = False, skip_certs: bool = False) -> None: + """Deploy the MCP stack to Kubernetes or Docker Compose. + + This is the main deployment method that orchestrates: + 1. Building containers (unless skip_build=True) + 2. Generating mTLS certificates (unless skip_certs=True or mTLS disabled) + 3. Generating manifests (Kubernetes YAML or docker-compose.yaml) + 4. Applying the deployment (unless dry_run=True) + + Args: + config_file: Path to mcp-stack.yaml + dry_run: Generate manifests without actually deploying + skip_build: Skip building containers (use existing images) + skip_certs: Skip certificate generation (use existing certs) + + Raises: + RuntimeError: If deployment fails at any stage + ValueError: If configuration is invalid + + Example: + # Full deployment + # await deployer.deploy("mcp-stack.yaml") + # ✓ Build complete + # ✓ Certificates generated + # ✓ Deployment complete + + # Dry run (generate manifests only) + # await deployer.deploy("mcp-stack.yaml", dry_run=True) + # ✓ Dry-run complete (no changes made) + """ + + @abstractmethod + async def verify(self, config_file: str, wait: bool = False, timeout: int = 300) -> None: + """Verify deployment health and readiness. + + Checks that all deployed services are healthy and ready: + - Kubernetes: Checks pod status, optionally waits for Ready + - Docker Compose: Checks container status + + Args: + config_file: Path to mcp-stack.yaml + wait: Wait for deployment to become ready + timeout: Maximum time to wait in seconds (default: 300) + + Raises: + RuntimeError: If verification fails or timeout is reached + TimeoutError: If wait=True and deployment doesn't become ready + + Example: + # Quick health check + # await deployer.verify("mcp-stack.yaml") + # NAME READY STATUS RESTARTS AGE + # mcpgateway-xxx 1/1 Running 0 2m + # mcp-plugin-opa-xxx 1/1 Running 0 2m + + # Wait for ready state + # await deployer.verify("mcp-stack.yaml", wait=True, timeout=600) + # ✓ Deployment healthy + """ + + @abstractmethod + async def destroy(self, config_file: str) -> None: + """Destroy the deployed MCP stack. 
+ + Removes all deployed resources: + - Kubernetes: Deletes all resources in the namespace + - Docker Compose: Stops and removes containers, networks, volumes + + WARNING: This is destructive and cannot be undone! + + Args: + config_file: Path to mcp-stack.yaml + + Raises: + RuntimeError: If destruction fails + + Example: + # await deployer.destroy("mcp-stack.yaml") + # ✓ Deployment destroyed + """ + + @abstractmethod + def generate_manifests(self, config_file: str, output_dir: Optional[str] = None) -> Path: + """Generate deployment manifests (Kubernetes YAML or docker-compose.yaml). + + Creates deployment manifests based on configuration: + - Kubernetes: Generates Deployment, Service, ConfigMap, Secret YAML files + - Docker Compose: Generates docker-compose.yaml with all services + + Also generates: + - plugins-config.yaml: Plugin manager configuration for gateway + - Environment files: .env files for each service + + Args: + config_file: Path to mcp-stack.yaml + output_dir: Output directory for manifests (default: ./deploy/manifests) + + Returns: + Path: Directory containing generated manifests + + Raises: + ValueError: If configuration is invalid + OSError: If output directory cannot be created + + Example: + # manifests_path = deployer.generate_manifests("mcp-stack.yaml") + # print(f"Manifests generated in: {manifests_path}") + # Manifests generated in: /path/to/deploy/manifests + + # Custom output directory + # deployer.generate_manifests("mcp-stack.yaml", output_dir="./my-manifests") + # ✓ Manifests generated: ./my-manifests + """ diff --git a/mcpgateway/tools/builder/python_deploy.py b/mcpgateway/tools/builder/python_deploy.py new file mode 100644 index 000000000..a07dc938d --- /dev/null +++ b/mcpgateway/tools/builder/python_deploy.py @@ -0,0 +1,603 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/tools/builder/python_deploy.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Plain Python MCP Stack Deployment Module + +This module provides deployment functionality using only standard Python +and system commands (docker/podman, kubectl, docker-compose). + +This is the fallback implementation when Dagger is not available. +""" + +# Standard +from pathlib import Path +import shutil +import subprocess # nosec B404 +from typing import List, Optional + +# Third-Party +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn + +# First-Party +from mcpgateway.tools.builder.common import ( + deploy_compose, + deploy_kubernetes, + destroy_compose, + destroy_kubernetes, + generate_compose_manifests, + generate_kubernetes_manifests, + generate_plugin_config, + get_deploy_dir, + handle_registry_operations, + load_config, + verify_compose, + verify_kubernetes, +) +from mcpgateway.tools.builder.common import copy_env_template as copy_template +from mcpgateway.tools.builder.pipeline import CICDModule +from mcpgateway.tools.builder.schema import BuildableConfig, MCPStackConfig + +console = Console() + + +class MCPStackPython(CICDModule): + """Plain Python implementation of MCP Stack deployment. + + This implementation uses standard Python and system commands (docker/podman, + kubectl, docker-compose) without requiring additional dependencies like Dagger. 
+ + Examples: + >>> # Test class instantiation + >>> deployer = MCPStackPython(verbose=False) + >>> deployer.verbose + False + + >>> # Test with verbose mode + >>> deployer_verbose = MCPStackPython(verbose=True) + >>> deployer_verbose.verbose + True + + >>> # Test that console is available + >>> hasattr(deployer, 'console') + True + + >>> # Test that it's a CICDModule subclass + >>> from mcpgateway.tools.builder.pipeline import CICDModule + >>> isinstance(deployer, CICDModule) + True + """ + + async def build(self, config_file: str, plugins_only: bool = False, specific_plugins: Optional[List[str]] = None, no_cache: bool = False, copy_env_templates: bool = False) -> None: + """Build gateway and plugin containers using docker/podman. + + Args: + config_file: Path to mcp-stack.yaml + plugins_only: Only build plugins, skip gateway + specific_plugins: List of specific plugin names to build + no_cache: Disable build cache + copy_env_templates: Copy .env.template files from cloned repos + + Raises: + Exception: If build fails for any component + """ + config = load_config(config_file) + + # Build gateway (unless plugins_only=True) + if not plugins_only: + gateway = config.gateway + if gateway.repo: + with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=self.console) as progress: + task = progress.add_task("Building gateway...", total=None) + try: + self._build_component(gateway, config, "gateway", no_cache=no_cache) + progress.update(task, completed=1, description="[green]✓ Built gateway[/green]") + except Exception as e: + progress.update(task, completed=1, description="[red]✗ Failed gateway[/red]") + # Print full error after progress bar closes + self.console.print("\n[red bold]Gateway build failed:[/red bold]") + self.console.print(f"[red]{type(e).__name__}: {str(e)}[/red]") + if self.verbose: + # Standard + import traceback + + self.console.print(f"[dim]{traceback.format_exc()}[/dim]") + raise + elif self.verbose: + self.console.print("[dim]Skipping gateway build (using pre-built image)[/dim]") + + # Build plugins + plugins = config.plugins + + if specific_plugins: + plugins = [p for p in plugins if p.name in specific_plugins] + + if not plugins: + self.console.print("[yellow]No plugins to build[/yellow]") + return + + with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=self.console) as progress: + + for plugin in plugins: + plugin_name = plugin.name + + # Skip if pre-built image specified + if plugin.image and not plugin.repo: + task = progress.add_task(f"Skipping {plugin_name} (using pre-built image)", total=1) + progress.update(task, completed=1) + continue + + task = progress.add_task(f"Building {plugin_name}...", total=None) + + try: + self._build_component(plugin, config, plugin_name, no_cache=no_cache, copy_env_templates=copy_env_templates) + progress.update(task, completed=1, description=f"[green]✓ Built {plugin_name}[/green]") + except Exception as e: + progress.update(task, completed=1, description=f"[red]✗ Failed {plugin_name}[/red]") + # Print full error after progress bar closes + self.console.print(f"\n[red bold]Plugin '{plugin_name}' build failed:[/red bold]") + self.console.print(f"[red]{type(e).__name__}: {str(e)}[/red]") + if self.verbose: + # Standard + import traceback + + self.console.print(f"[dim]{traceback.format_exc()}[/dim]") + raise + + async def generate_certificates(self, config_file: str) -> None: + """Generate mTLS certificates for plugins. + + Supports two modes: + 1. 
Local generation (use_cert_manager=false): Uses Makefile to generate certificates locally + 2. cert-manager (use_cert_manager=true): Skips local generation, cert-manager will create certificates + + Args: + config_file: Path to mcp-stack.yaml + + Raises: + RuntimeError: If make command not found (when using local generation) + """ + config = load_config(config_file) + + # Check if using cert-manager + cert_config = config.certificates + use_cert_manager = cert_config.use_cert_manager if cert_config else False + validity_days = cert_config.validity_days if cert_config else 825 + + if use_cert_manager: + # Skip local generation - cert-manager will handle certificate creation + if self.verbose: + self.console.print("[blue]Using cert-manager for certificate management[/blue]") + self.console.print("[dim]Skipping local certificate generation (cert-manager will create certificates)[/dim]") + return + + # Local certificate generation (backward compatibility) + if self.verbose: + self.console.print("[blue]Generating mTLS certificates locally...[/blue]") + + # Check if make is available + if not shutil.which("make"): + raise RuntimeError("'make' command not found. Cannot generate certificates.") + + # Generate CA + self._run_command(["make", "certs-mcp-ca", f"MCP_CERT_DAYS={validity_days}"]) + + # Generate gateway cert + self._run_command(["make", "certs-mcp-gateway", f"MCP_CERT_DAYS={validity_days}"]) + + # Generate plugin certificates + plugins = config.plugins + for plugin in plugins: + plugin_name = plugin.name + self._run_command(["make", "certs-mcp-plugin", f"PLUGIN_NAME={plugin_name}", f"MCP_CERT_DAYS={validity_days}"]) + + if self.verbose: + self.console.print("[green]✓ Certificates generated locally[/green]") + + async def deploy(self, config_file: str, dry_run: bool = False, skip_build: bool = False, skip_certs: bool = False, output_dir: Optional[str] = None) -> None: + """Deploy MCP stack. + + Args: + config_file: Path to mcp-stack.yaml + dry_run: Generate manifests without deploying + skip_build: Skip building containers + skip_certs: Skip certificate generation + output_dir: Output directory for manifests (default: ./deploy) + + Raises: + ValueError: If unsupported deployment type specified + """ + config = load_config(config_file) + + # Build containers + if not skip_build: + await self.build(config_file) + + # Generate certificates (only if mTLS is enabled) + gateway_mtls = config.gateway.mtls_enabled if config.gateway.mtls_enabled is not None else True + plugin_mtls = any((p.mtls_enabled if p.mtls_enabled is not None else True) for p in config.plugins) + mtls_needed = gateway_mtls or plugin_mtls + + if not skip_certs and mtls_needed: + await self.generate_certificates(config_file) + elif not skip_certs and not mtls_needed: + if self.verbose: + self.console.print("[dim]Skipping certificate generation (mTLS disabled)[/dim]") + + # Generate manifests + manifests_dir = self.generate_manifests(config_file, output_dir=output_dir) + + if dry_run: + self.console.print(f"[yellow]Dry-run: Manifests generated in {manifests_dir}[/yellow]") + return + + # Apply deployment + deployment_type = config.deployment.type + + if deployment_type == "kubernetes": + self._deploy_kubernetes(manifests_dir) + elif deployment_type == "compose": + self._deploy_compose(manifests_dir) + else: + raise ValueError(f"Unsupported deployment type: {deployment_type}") + + async def verify(self, config_file: str, wait: bool = False, timeout: int = 300) -> None: + """Verify deployment health. 
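+
+        Depending on deployment.type this delegates to the shared
+        verify_kubernetes() or verify_compose() helpers from common.py
+        (see the private _verify_* wrappers below).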
+ + Args: + config_file: Path to mcp-stack.yaml + wait: Wait for deployment to be ready + timeout: Wait timeout in seconds + """ + config = load_config(config_file) + deployment_type = config.deployment.type + + if self.verbose: + self.console.print("[blue]Verifying deployment...[/blue]") + + if deployment_type == "kubernetes": + self._verify_kubernetes(config, wait=wait, timeout=timeout) + elif deployment_type == "compose": + self._verify_compose(config, wait=wait, timeout=timeout) + + async def destroy(self, config_file: str) -> None: + """Destroy deployed MCP stack. + + Args: + config_file: Path to mcp-stack.yaml + """ + config = load_config(config_file) + deployment_type = config.deployment.type + + if self.verbose: + self.console.print("[blue]Destroying deployment...[/blue]") + + if deployment_type == "kubernetes": + self._destroy_kubernetes(config) + elif deployment_type == "compose": + self._destroy_compose(config) + + def generate_manifests(self, config_file: str, output_dir: Optional[str] = None) -> Path: + """Generate deployment manifests. + + Args: + config_file: Path to mcp-stack.yaml + output_dir: Output directory for manifests + + Returns: + Path to generated manifests directory + + Raises: + ValueError: If unsupported deployment type specified + + Examples: + >>> import tempfile + >>> import yaml + >>> from pathlib import Path + >>> deployer = MCPStackPython(verbose=False) + + >>> # Test method signature and return type + >>> import inspect + >>> sig = inspect.signature(deployer.generate_manifests) + >>> 'config_file' in sig.parameters + True + >>> 'output_dir' in sig.parameters + True + >>> sig.return_annotation + + + >>> # Test that method exists and is callable + >>> callable(deployer.generate_manifests) + True + """ + config = load_config(config_file) + deployment_type = config.deployment.type + + if output_dir is None: + deploy_dir = get_deploy_dir() + # Separate subdirectories for kubernetes and compose + output_dir = deploy_dir / "manifests" / deployment_type + else: + output_dir = Path(output_dir) + + output_dir.mkdir(parents=True, exist_ok=True) + + # Store output dir for later use + self._last_output_dir = output_dir + + # Generate plugin config.yaml for gateway (shared function) + generate_plugin_config(config, output_dir, verbose=self.verbose) + + if deployment_type == "kubernetes": + generate_kubernetes_manifests(config, output_dir, verbose=self.verbose) + elif deployment_type == "compose": + generate_compose_manifests(config, output_dir, verbose=self.verbose) + else: + raise ValueError(f"Unsupported deployment type: {deployment_type}") + + return output_dir + + # Private helper methods + + def _detect_container_engine(self, config: MCPStackConfig) -> str: + """Detect available container engine (docker or podman). + + Supports both engine names ("docker", "podman") and full paths ("/opt/podman/bin/podman"). + + Args: + config: MCP Stack configuration containing deployment settings + + Returns: + Name or full path to available engine + + Raises: + RuntimeError: If no container engine found + + Examples: + >>> from mcpgateway.tools.builder.schema import MCPStackConfig, DeploymentConfig, GatewayConfig + >>> deployer = MCPStackPython(verbose=False) + + >>> # Test with docker specified + >>> config = MCPStackConfig( + ... deployment=DeploymentConfig(type="compose", container_engine="docker"), + ... gateway=GatewayConfig(image="test:latest"), + ... plugins=[] + ... 
) + >>> result = deployer._detect_container_engine(config) + >>> result in ["docker", "podman"] # Returns available engine + True + + >>> # Test that method returns a string + >>> import shutil + >>> if shutil.which("docker") or shutil.which("podman"): + ... config = MCPStackConfig( + ... deployment=DeploymentConfig(type="compose"), + ... gateway=GatewayConfig(image="test:latest"), + ... plugins=[] + ... ) + ... engine = deployer._detect_container_engine(config) + ... isinstance(engine, str) + ... else: + ... True # Skip test if no container engine available + True + """ + if config.deployment.container_engine: + engine = config.deployment.container_engine + + # Check if it's a full path + if "/" in engine: + if Path(engine).exists() and Path(engine).is_file(): + return engine + else: + raise RuntimeError(f"Specified container engine path does not exist: {engine}") + + # Otherwise treat as command name and check PATH + if shutil.which(engine): + return engine + else: + raise RuntimeError(f"Unable to find specified container engine: {engine}") + + # Auto-detect + if shutil.which("docker"): + return "docker" + elif shutil.which("podman"): + return "podman" + else: + raise RuntimeError("No container engine found. Install docker or podman.") + + def _run_command(self, cmd: List[str], cwd: Optional[Path] = None, capture_output: bool = False) -> subprocess.CompletedProcess: + """Run a shell command. + + Args: + cmd: Command and arguments + cwd: Working directory + capture_output: Capture stdout/stderr + + Returns: + CompletedProcess instance + + Raises: + subprocess.CalledProcessError: If command fails + """ + if self.verbose: + self.console.print(f"[dim]Running: {' '.join(cmd)}[/dim]") + + result = subprocess.run(cmd, cwd=cwd, capture_output=capture_output, text=True, check=True) # nosec B603, B607 + + return result + + def _build_component(self, component: BuildableConfig, config: MCPStackConfig, component_name: str, no_cache: bool = False, copy_env_templates: bool = False) -> None: + """Build a component (gateway or plugin) container using docker/podman. + + Args: + component: Component configuration (GatewayConfig or PluginConfig) + config: Overall stack configuration + component_name: Name of the component (gateway or plugin name) + no_cache: Disable cache + copy_env_templates: Copy .env.template from repo if it exists + + Raises: + ValueError: If component has no repo field + FileNotFoundError: If build context or containerfile not found + """ + repo = component.repo + + container_engine = self._detect_container_engine(config) + + if not repo: + raise ValueError(f"Component '{component_name}' has no 'repo' field") + + # Clone repository + git_ref = component.ref or "main" + clone_dir = Path(f"./build/{component_name}") + clone_dir.mkdir(parents=True, exist_ok=True) + + # Clone or update repo + if (clone_dir / ".git").exists(): + if self.verbose: + self.console.print(f"[dim]Updating {component_name} repository...[/dim]") + self._run_command(["git", "fetch", "origin", git_ref], cwd=clone_dir) + # Checkout what we just fetched (FETCH_HEAD) + self._run_command(["git", "checkout", "FETCH_HEAD"], cwd=clone_dir) + else: + if self.verbose: + self.console.print(f"[dim]Cloning {component_name} repository...[/dim]") + self._run_command(["git", "clone", "--branch", git_ref, "--depth", "1", repo, str(clone_dir)]) + + # Determine build context (subdirectory within repo) + build_context = component.context or "." 
+ build_dir = clone_dir / build_context + + if not build_dir.exists(): + raise FileNotFoundError(f"Build context not found: {build_dir}") + + # Detect Containerfile/Dockerfile + containerfile = component.containerfile or "Containerfile" + containerfile_path = build_dir / containerfile + + if not containerfile_path.exists(): + containerfile = "Dockerfile" + containerfile_path = build_dir / containerfile + if not containerfile_path.exists(): + raise FileNotFoundError(f"No Containerfile or Dockerfile found in {build_dir}") + + # Build container - determine image tag + if component.image: + # Use explicitly specified image name + image_tag = component.image + else: + # Generate default image name based on component type + image_tag = f"mcpgateway-{component_name.lower()}:latest" + + build_cmd = [container_engine, "build", "-f", containerfile, "-t", image_tag] + + if no_cache: + build_cmd.append("--no-cache") + + # Add target stage if specified (for multi-stage builds) + if component.target: + build_cmd.extend(["--target", component.target]) + + # For Docker, add --load to ensure image is loaded into daemon + # (needed for buildx/docker-container driver) + if container_engine == "docker": + build_cmd.append("--load") + + build_cmd.append(".") + + self._run_command(build_cmd, cwd=build_dir) + + # Handle registry operations (tag and push if enabled) + image_tag = handle_registry_operations(component, component_name, image_tag, container_engine, verbose=self.verbose) + + # Copy .env.template if requested and exists + if copy_env_templates: + copy_template(component_name, build_dir, verbose=self.verbose) + + if self.verbose: + self.console.print(f"[green]✓ Built {component_name} -> {image_tag}[/green]") + + def _deploy_kubernetes(self, manifests_dir: Path) -> None: + """Deploy to Kubernetes using kubectl. + + Uses shared deploy_kubernetes() from common.py to avoid code duplication. + + Args: + manifests_dir: Path to directory containing Kubernetes manifests + """ + deploy_kubernetes(manifests_dir, verbose=self.verbose) + + def _deploy_compose(self, manifests_dir: Path) -> None: + """Deploy using Docker Compose. + + Uses shared deploy_compose() from common.py to avoid code duplication. + + Args: + manifests_dir: Path to directory containing compose manifest + """ + compose_file = manifests_dir / "docker-compose.yaml" + deploy_compose(compose_file, verbose=self.verbose) + + def _verify_kubernetes(self, config: MCPStackConfig, wait: bool = False, timeout: int = 300) -> None: + """Verify Kubernetes deployment health. + + Uses shared verify_kubernetes() from common.py to avoid code duplication. + + Args: + config: Parsed configuration Pydantic model + wait: Wait for pods to be ready + timeout: Wait timeout in seconds + """ + namespace = config.deployment.namespace or "mcp-gateway" + output = verify_kubernetes(namespace, wait=wait, timeout=timeout, verbose=self.verbose) + self.console.print(output) + + def _verify_compose(self, config: MCPStackConfig, wait: bool = False, timeout: int = 300) -> None: + """Verify Docker Compose deployment health. + + Uses shared verify_compose() from common.py to avoid code duplication. 
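+        The docker-compose.yaml is looked up in the directory produced by the
+        most recent generate_manifests() call, falling back to the default
+        manifests/compose subdirectory under the deploy directory.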
+ + Args: + config: Parsed configuration Pydantic model + wait: Wait for containers to be ready + timeout: Wait timeout in seconds + """ + _ = config, wait, timeout # Reserved for future use + # Use the same manifests directory as generate_manifests + deploy_dir = get_deploy_dir() + output_dir = getattr(self, "_last_output_dir", deploy_dir / "manifests" / "compose") + compose_file = output_dir / "docker-compose.yaml" + output = verify_compose(compose_file, verbose=self.verbose) + self.console.print(output) + + def _destroy_kubernetes(self, config: MCPStackConfig) -> None: + """Destroy Kubernetes deployment. + + Uses shared destroy_kubernetes() from common.py to avoid code duplication. + + Args: + config: Parsed configuration Pydantic model + """ + _ = config # Reserved for future use (namespace, labels, etc.) + # Use the same manifests directory as generate_manifests + deploy_dir = get_deploy_dir() + manifests_dir = getattr(self, "_last_output_dir", deploy_dir / "manifests" / "kubernetes") + destroy_kubernetes(manifests_dir, verbose=self.verbose) + + def _destroy_compose(self, config: MCPStackConfig) -> None: + """Destroy Docker Compose deployment. + + Uses shared destroy_compose() from common.py to avoid code duplication. + + Args: + config: Parsed configuration Pydantic model + """ + _ = config # Reserved for future use (project name, networks, etc.) + # Use the same manifests directory as generate_manifests + deploy_dir = get_deploy_dir() + output_dir = getattr(self, "_last_output_dir", deploy_dir / "manifests" / "compose") + compose_file = output_dir / "docker-compose.yaml" + destroy_compose(compose_file, verbose=self.verbose) diff --git a/mcpgateway/tools/builder/schema.py b/mcpgateway/tools/builder/schema.py new file mode 100644 index 000000000..657398a74 --- /dev/null +++ b/mcpgateway/tools/builder/schema.py @@ -0,0 +1,475 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/tools/builder/schema.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Pydantic schemas for MCP Stack configuration validation""" + +# Standard +from typing import Any, Dict, List, Literal, Optional + +# Third-Party +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class OpenShiftConfig(BaseModel): + """OpenShift-specific configuration. + + Routes are OpenShift's native way of exposing services externally (predates Kubernetes Ingress). + They provide built-in TLS termination and are integrated with OpenShift's router/HAProxy infrastructure. + + Attributes: + create_routes: Create OpenShift Route resources for external access (default: False) + domain: OpenShift apps domain for route hostnames (default: auto-detected from cluster) + tls_termination: TLS termination mode - edge, passthrough, or reencrypt (default: edge) + + Examples: + >>> # Test with default values + >>> config = OpenShiftConfig() + >>> config.create_routes + False + >>> config.tls_termination + 'edge' + + >>> # Test with custom values + >>> config = OpenShiftConfig( + ... create_routes=True, + ... domain="apps.example.com", + ... tls_termination="passthrough" + ... ) + >>> config.create_routes + True + >>> config.domain + 'apps.example.com' + >>> config.tls_termination + 'passthrough' + + >>> # Test valid TLS termination modes + >>> for mode in ["edge", "passthrough", "reencrypt"]: + ... cfg = OpenShiftConfig(tls_termination=mode) + ... 
cfg.tls_termination == mode + True + True + True + """ + + create_routes: bool = Field(False, description="Create OpenShift Route resources") + domain: Optional[str] = Field(None, description="OpenShift apps domain (e.g., apps-crc.testing)") + tls_termination: Literal["edge", "passthrough", "reencrypt"] = Field("edge", description="TLS termination mode") + + +class DeploymentConfig(BaseModel): + """Deployment configuration + + Examples: + >>> # Test compose deployment + >>> config = DeploymentConfig(type="compose", project_name="test-project") + >>> config.type + 'compose' + >>> config.project_name + 'test-project' + + >>> # Test kubernetes deployment + >>> config = DeploymentConfig(type="kubernetes", namespace="mcp-test") + >>> config.type + 'kubernetes' + >>> config.namespace + 'mcp-test' + + >>> # Test container engine options + >>> config = DeploymentConfig(type="compose", container_engine="podman") + >>> config.container_engine + 'podman' + + >>> # Test with OpenShift config + >>> config = DeploymentConfig( + ... type="kubernetes", + ... namespace="test", + ... openshift=OpenShiftConfig(create_routes=True) + ... ) + >>> config.openshift.create_routes + True + """ + + type: Literal["kubernetes", "compose"] = Field(..., description="Deployment type") + container_engine: Optional[str] = Field(default=None, description="Container engine: 'podman', 'docker', or full path (e.g., '/opt/podman/bin/podman')") + project_name: Optional[str] = Field(None, description="Project name for compose") + namespace: Optional[str] = Field(None, description="Namespace for Kubernetes") + openshift: Optional[OpenShiftConfig] = Field(None, description="OpenShift-specific configuration") + + +class RegistryConfig(BaseModel): + """Container registry configuration. + + Optional configuration for pushing built images to a container registry. + When enabled, images will be tagged with the full registry path and optionally pushed. + + Authentication: + Users must authenticate to the registry before running the build: + - Docker Hub: `docker login` + - Quay.io: `podman login quay.io` + - OpenShift internal: `podman login $(oc registry info) -u $(oc whoami) -p $(oc whoami -t)` + - Private registry: `podman login your-registry.com -u username` + + Attributes: + enabled: Enable registry integration (default: False) + url: Registry URL (e.g., "docker.io", "quay.io", "default-route-openshift-image-registry.apps-crc.testing") + namespace: Registry namespace/organization/project (e.g., "myorg", "mcp-gateway-test") + push: Push image after build (default: True) + image_pull_policy: Kubernetes imagePullPolicy (default: "IfNotPresent") + + Examples: + >>> # Test with defaults (registry disabled) + >>> config = RegistryConfig() + >>> config.enabled + False + >>> config.push + True + >>> config.image_pull_policy + 'IfNotPresent' + + >>> # Test Docker Hub configuration + >>> config = RegistryConfig( + ... enabled=True, + ... url="docker.io", + ... namespace="myusername" + ... ) + >>> config.enabled + True + >>> config.url + 'docker.io' + >>> config.namespace + 'myusername' + + >>> # Test with custom pull policy + >>> config = RegistryConfig( + ... enabled=True, + ... url="quay.io", + ... namespace="myorg", + ... image_pull_policy="Always" + ... ) + >>> config.image_pull_policy + 'Always' + + >>> # Test tag-only mode (no push) + >>> config = RegistryConfig( + ... enabled=True, + ... url="registry.local", + ... namespace="test", + ... push=False + ... 
) + >>> config.push + False + """ + + enabled: bool = Field(False, description="Enable registry push") + url: Optional[str] = Field(None, description="Registry URL (e.g., docker.io, quay.io, or internal registry)") + namespace: Optional[str] = Field(None, description="Registry namespace/organization/project") + push: bool = Field(True, description="Push image after build") + image_pull_policy: Optional[str] = Field("IfNotPresent", description="Kubernetes imagePullPolicy (IfNotPresent, Always, Never)") + + +class BuildableConfig(BaseModel): + """Base class for components that can be built from source or use pre-built images. + + This base class provides common configuration for both gateway and plugins, + supporting two build modes: + 1. Pre-built image: Specify only 'image' field + 2. Build from source: Specify 'repo' and optionally 'ref', 'context', 'containerfile', 'target' + + Attributes: + image: Pre-built Docker image name (e.g., "mcpgateway/mcpgateway:latest") + repo: Git repository URL to build from + ref: Git branch/tag/commit to checkout (default: "main") + context: Build context subdirectory within repo (default: ".") + containerfile: Path to Containerfile/Dockerfile (default: "Containerfile") + target: Target stage for multi-stage builds (optional) + host_port: Host port mapping for direct access (optional) + env_vars: Environment variables for container + env_file: Path to environment file (.env) + mtls_enabled: Enable mutual TLS authentication (default: True) + """ + + # Allow attribute assignment after model creation (needed for auto-detection of env_file) + model_config = ConfigDict(validate_assignment=True) + + # Build configuration + image: Optional[str] = Field(None, description="Pre-built Docker image") + repo: Optional[str] = Field(None, description="Git repository URL") + ref: Optional[str] = Field("main", description="Git branch/tag/commit") + context: Optional[str] = Field(".", description="Build context subdirectory") + containerfile: Optional[str] = Field("Containerfile", description="Containerfile path") + target: Optional[str] = Field(None, description="Multi-stage build target") + + # Runtime configuration + host_port: Optional[int] = Field(None, description="Host port mapping") + env_vars: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Environment variables") + env_file: Optional[str] = Field(None, description="Path to environment file (.env)") + mtls_enabled: Optional[bool] = Field(True, description="Enable mTLS") + + # Registry configuration + registry: Optional[RegistryConfig] = Field(None, description="Container registry configuration") + + def model_post_init(self, _: Any) -> None: + """Validate that either image or repo is specified + + Raises: + ValueError: If neither image nor repo is specified + + Examples: + >>> # Test that error is raised when neither image nor repo specified + >>> try: + ... # BuildableConfig can't be instantiated directly, use GatewayConfig + ... from mcpgateway.tools.builder.schema import GatewayConfig + ... GatewayConfig() + ... except ValueError as e: + ... 
"must specify either 'image' or 'repo'" in str(e) + True + + >>> # Test valid config with image + >>> from mcpgateway.tools.builder.schema import GatewayConfig + >>> config = GatewayConfig(image="mcpgateway:latest") + >>> config.image + 'mcpgateway:latest' + + >>> # Test valid config with repo + >>> from mcpgateway.tools.builder.schema import GatewayConfig + >>> config = GatewayConfig(repo="https://github.com/example/repo") + >>> config.repo + 'https://github.com/example/repo' + """ + if not self.image and not self.repo: + component_type = self.__class__.__name__.replace("Config", "") + raise ValueError(f"{component_type} must specify either 'image' or 'repo'") + + +class GatewayConfig(BuildableConfig): + """Gateway configuration. + + Extends BuildableConfig to support either pre-built gateway images or + building the gateway from source repository. + + Attributes: + port: Gateway internal port (default: 4444) + + Examples: + >>> # Test with pre-built image + >>> config = GatewayConfig(image="mcpgateway:latest") + >>> config.image + 'mcpgateway:latest' + >>> config.port + 4444 + + >>> # Test with custom port + >>> config = GatewayConfig(image="mcpgateway:latest", port=8080) + >>> config.port + 8080 + + >>> # Test with source repository + >>> config = GatewayConfig( + ... repo="https://github.com/example/gateway", + ... ref="v1.0.0" + ... ) + >>> config.repo + 'https://github.com/example/gateway' + >>> config.ref + 'v1.0.0' + + >>> # Test with environment variables + >>> config = GatewayConfig( + ... image="mcpgateway:latest", + ... env_vars={"LOG_LEVEL": "DEBUG", "PORT": "4444"} + ... ) + >>> config.env_vars['LOG_LEVEL'] + 'DEBUG' + + >>> # Test with mTLS enabled + >>> config = GatewayConfig(image="mcpgateway:latest", mtls_enabled=True) + >>> config.mtls_enabled + True + """ + + port: Optional[int] = Field(4444, description="Gateway port") + + +class PluginConfig(BuildableConfig): + """Plugin configuration. + + Extends BuildableConfig to support plugin-specific configuration while + inheriting common build and runtime capabilities. + + Attributes: + name: Unique plugin identifier + port: Plugin internal port (default: 8000) + expose_port: Whether to expose plugin port on host (default: False) + plugin_overrides: Plugin-specific override configuration + """ + + name: str = Field(..., description="Plugin name") + port: Optional[int] = Field(8000, description="Plugin port") + expose_port: Optional[bool] = Field(False, description="Expose port on host") + plugin_overrides: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Plugin overrides") + + @field_validator("name") + @classmethod + def validate_name(cls, v: str) -> str: + """Validate plugin name is non-empty + + Args: + v: Plugin name value to validate + + Returns: + Validated plugin name + + Raises: + ValueError: If plugin name is empty or whitespace only + + Examples: + >>> # Test valid plugin names + >>> PluginConfig.validate_name("my-plugin") + 'my-plugin' + >>> PluginConfig.validate_name("plugin_123") + 'plugin_123' + >>> PluginConfig.validate_name("TestPlugin") + 'TestPlugin' + + >>> # Test empty name raises error + >>> try: + ... PluginConfig.validate_name("") + ... except ValueError as e: + ... "cannot be empty" in str(e) + True + + >>> # Test whitespace-only name raises error + >>> try: + ... PluginConfig.validate_name(" ") + ... except ValueError as e: + ... 
"cannot be empty" in str(e) + True + """ + if not v or not v.strip(): + raise ValueError("Plugin name cannot be empty") + return v + + +class CertificatesConfig(BaseModel): + """Certificate configuration. + + Supports two modes: + 1. Local certificate generation (use_cert_manager=false, default): + - Certificates generated locally using OpenSSL (via Makefile) + - Deployed to Kubernetes as secrets via kubectl + - Manual rotation required before expiry + + 2. cert-manager integration (use_cert_manager=true, Kubernetes only): + - Certificates managed by cert-manager controller + - Automatic renewal before expiry (default: at 2/3 of lifetime) + - Native Kubernetes Certificate resources + - Requires cert-manager to be installed in cluster + + Attributes: + validity_days: Certificate validity period in days (default: 825 ≈ 2.25 years) + auto_generate: Auto-generate certificates locally (default: True) + use_cert_manager: Use cert-manager for certificate management (default: False, Kubernetes only) + cert_manager_issuer: Name of cert-manager Issuer/ClusterIssuer (default: "mcp-ca-issuer") + cert_manager_kind: Type of issuer - Issuer or ClusterIssuer (default: "Issuer") + ca_path: Path to CA certificates for local generation (default: "./certs/mcp/ca") + gateway_path: Path to gateway certificates for local generation (default: "./certs/mcp/gateway") + plugins_path: Path to plugin certificates for local generation (default: "./certs/mcp/plugins") + """ + + validity_days: Optional[int] = Field(825, description="Certificate validity in days") + auto_generate: Optional[bool] = Field(True, description="Auto-generate certificates locally") + + # cert-manager integration (Kubernetes only) + use_cert_manager: Optional[bool] = Field(False, description="Use cert-manager for certificate management (Kubernetes only)") + cert_manager_issuer: Optional[str] = Field("mcp-ca-issuer", description="cert-manager Issuer/ClusterIssuer name") + cert_manager_kind: Optional[Literal["Issuer", "ClusterIssuer"]] = Field("Issuer", description="cert-manager issuer kind") + + ca_path: Optional[str] = Field("./certs/mcp/ca", description="CA certificate path") + gateway_path: Optional[str] = Field("./certs/mcp/gateway", description="Gateway cert path") + plugins_path: Optional[str] = Field("./certs/mcp/plugins", description="Plugins cert path") + + +class PostgresConfig(BaseModel): + """PostgreSQL database configuration""" + + enabled: Optional[bool] = Field(True, description="Enable PostgreSQL deployment") + image: Optional[str] = Field("quay.io/sclorg/postgresql-15-c9s:latest", description="PostgreSQL image (default is OpenShift-compatible)") + database: Optional[str] = Field("mcp", description="Database name") + user: Optional[str] = Field("postgres", description="Database user") + password: Optional[str] = Field("mysecretpassword", description="Database password") + storage_size: Optional[str] = Field("10Gi", description="Persistent volume size (Kubernetes only)") + storage_class: Optional[str] = Field(None, description="Storage class name (Kubernetes only)") + + +class RedisConfig(BaseModel): + """Redis cache configuration""" + + enabled: Optional[bool] = Field(True, description="Enable Redis deployment") + image: Optional[str] = Field("redis:latest", description="Redis image") + + +class InfrastructureConfig(BaseModel): + """Infrastructure services configuration""" + + postgres: Optional[PostgresConfig] = Field(default_factory=PostgresConfig) + redis: Optional[RedisConfig] = Field(default_factory=RedisConfig) + + +class 
MCPStackConfig(BaseModel): + """Complete MCP Stack configuration""" + + deployment: DeploymentConfig + gateway: GatewayConfig + plugins: List[PluginConfig] = Field(default_factory=list) + certificates: Optional[CertificatesConfig] = Field(default_factory=CertificatesConfig) + infrastructure: Optional[InfrastructureConfig] = Field(default_factory=InfrastructureConfig) + + @field_validator("plugins") + @classmethod + def validate_plugin_names_unique(cls, v: List[PluginConfig]) -> List[PluginConfig]: + """Ensure plugin names are unique + + Args: + v: List of plugin configurations to validate + + Returns: + Validated list of plugin configurations + + Raises: + ValueError: If duplicate plugin names are found + + Examples: + >>> from mcpgateway.tools.builder.schema import PluginConfig + >>> # Test with unique names (valid) + >>> plugins = [ + ... PluginConfig(name="plugin1", image="img1:latest"), + ... PluginConfig(name="plugin2", image="img2:latest") + ... ] + >>> result = MCPStackConfig.validate_plugin_names_unique(plugins) + >>> len(result) == 2 + True + + >>> # Test with duplicate names (invalid) + >>> try: + ... duplicates = [ + ... PluginConfig(name="duplicate", image="img1:latest"), + ... PluginConfig(name="duplicate", image="img2:latest") + ... ] + ... MCPStackConfig.validate_plugin_names_unique(duplicates) + ... except ValueError as e: + ... "Duplicate plugin names found" in str(e) + True + + >>> # Test with empty list (valid) + >>> empty = MCPStackConfig.validate_plugin_names_unique([]) + >>> len(empty) == 0 + True + """ + names = [p.name for p in v] + if len(names) != len(set(names)): + duplicates = [name for name in names if names.count(name) > 1] + raise ValueError(f"Duplicate plugin names found: {duplicates}") + return v diff --git a/mcpgateway/tools/builder/templates/compose/docker-compose.yaml.j2 b/mcpgateway/tools/builder/templates/compose/docker-compose.yaml.j2 new file mode 100644 index 000000000..aaf2fc04e --- /dev/null +++ b/mcpgateway/tools/builder/templates/compose/docker-compose.yaml.j2 @@ -0,0 +1,198 @@ +# Location: ./mcpgateway/tools/builder/templates/compose/docker-compose.yaml.j2 +# Copyright 2025 +# SPDX-License-Identifier: Apache-2.0 +# Authors: Teryl Taylor +# Docker Compose manifest for MCP Stack +# Generated from mcp-stack.yaml + +version: '3.8' + +networks: + mcp-network: + driver: bridge + +volumes: + gateway-data: + driver: local + pgdata: + driver: local +{% for plugin in plugins %} + {{ plugin.name | lower }}-data: + driver: local +{% endfor %} + +services: + # MCP Gateway + mcpgateway: + image: {{ gateway.image }} + container_name: mcpgateway + hostname: mcpgateway + + {% if gateway.env_file is defined %} + env_file: + - {{ gateway.env_file }} + {% endif %} + + environment: + {% if gateway.env_vars is defined and gateway.env_vars %} + # User-defined environment variables + {% for key, value in gateway.env_vars.items() %} + - {{ key }}={{ value }} + {% endfor %} + {% endif %} + # Database configuration + - DATABASE_URL=postgresql://postgres:$${POSTGRES_PASSWORD:-mysecretpassword}@postgres:5432/mcp + - REDIS_URL=redis://redis:6379/0 + {% if gateway.mtls_enabled | default(true) %} + # mTLS client configuration (gateway connects to external plugins) + - PLUGINS_CLIENT_MTLS_CA_BUNDLE=/app/certs/mcp/ca/ca.crt + - PLUGINS_CLIENT_MTLS_CERTFILE=/app/certs/mcp/gateway/client.crt + - PLUGINS_CLIENT_MTLS_KEYFILE=/app/certs/mcp/gateway/client.key + - PLUGINS_CLIENT_MTLS_VERIFY={{ gateway.mtls_verify | default('true') }} + - PLUGINS_CLIENT_MTLS_CHECK_HOSTNAME={{ 
gateway.mtls_check_hostname | default('false') }} + {% endif %} + + ports: + - "{{ gateway.host_port | default(4444) }}:{{ gateway.port | default(4444) }}" + + volumes: + - gateway-data:/app/data + {% if gateway.mtls_enabled | default(true) %} + - {{ cert_paths.gateway_cert_dir }}:/app/certs/mcp/gateway:ro + - {{ cert_paths.ca_cert_file }}:/app/certs/mcp/ca/ca.crt:ro + {% endif %} + # Auto-generated plugin configuration + - ./plugins-config.yaml:/app/config/plugins.yaml:ro + + networks: + - mcp-network + + restart: unless-stopped + + healthcheck: + test: ["CMD", "python3", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:{{ gateway.port | default(4444) }}/health').read()"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_started +{% for plugin in plugins %} {{ plugin.name | lower }}: + condition: service_started +{% endfor %} + +{% for plugin in plugins %} + # Plugin: {{ plugin.name }} + {{ plugin.name | lower }}: + image: {{ plugin.image | default('mcpgateway-' + plugin.name | lower + ':latest') }} + container_name: mcp-plugin-{{ plugin.name | lower }} + hostname: {{ plugin.name | lower }} + + {% if plugin.env_file is defined %} + env_file: + - {{ plugin.env_file }} + {% endif %} + + environment: + {% if plugin.env_vars is defined and plugin.env_vars %} + # User-defined environment variables + {% for key, value in plugin.env_vars.items() %} + - {{ key }}={{ value }} + {% endfor %} + {% endif %} + {% if plugin.mtls_enabled | default(true) %} + # mTLS server configuration (plugin accepts gateway connections) + - PLUGINS_TRANSPORT=http + - PLUGINS_SERVER_HOST=0.0.0.0 + - PLUGINS_SERVER_PORT={{ plugin.port | default(8000) }} + - PLUGINS_SERVER_SSL_ENABLED=true + - PLUGINS_SERVER_SSL_KEYFILE=/app/certs/mcp/server.key + - PLUGINS_SERVER_SSL_CERTFILE=/app/certs/mcp/server.crt + - PLUGINS_SERVER_SSL_CA_CERTS=/app/certs/mcp/ca.crt + - PLUGINS_SERVER_SSL_CERT_REQS=2 # CERT_REQUIRED - enforce client certificates + {% endif %} + + {% if plugin.expose_port | default(false) %} + ports: + - "{{ plugin.host_port }}:{{ plugin.port | default(8000) }}" + {% endif %} + + volumes: + - {{ plugin.name | lower }}-data:/app/data + {% if plugin.mtls_enabled | default(true) %} + - {{ cert_paths.plugins_cert_base }}/{{ plugin.name }}:/app/certs/mcp:ro + {% endif %} + + networks: + - mcp-network + + restart: unless-stopped + + healthcheck: + {% if plugin.mtls_enabled | default(true) %} + # When mTLS is enabled, health check uses separate HTTP server on port+1000 + test: ["CMD", "python3", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:{{ (plugin.port | default(8000)) + 1000 }}/health').read()"] + {% else %} + # When mTLS is disabled, health check uses main server + test: ["CMD", "python3", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:{{ plugin.port | default(8000) }}/health').read()"] + {% endif %} + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + + {% if plugin.depends_on is defined %} + depends_on: + {% for dep in plugin.depends_on %} + - {{ dep }} + {% endfor %} + {% endif %} + +{% endfor %} + # PostgreSQL Database + postgres: + image: postgres:17 + container_name: mcp-postgres + hostname: postgres + + environment: + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=$${POSTGRES_PASSWORD:-mysecretpassword} + - POSTGRES_DB=mcp + + ports: + - "5432:5432" + + volumes: + - pgdata:/var/lib/postgresql/data + + networks: + - mcp-network + 
+ restart: unless-stopped + + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres"] + interval: 30s + timeout: 5s + retries: 5 + start_period: 20s + + # Redis Cache + redis: + image: redis:latest + container_name: mcp-redis + hostname: redis + + ports: + - "6379:6379" + + networks: + - mcp-network + + restart: unless-stopped + diff --git a/mcpgateway/tools/builder/templates/kubernetes/cert-manager-certificates.yaml.j2 b/mcpgateway/tools/builder/templates/kubernetes/cert-manager-certificates.yaml.j2 new file mode 100644 index 000000000..e11963573 --- /dev/null +++ b/mcpgateway/tools/builder/templates/kubernetes/cert-manager-certificates.yaml.j2 @@ -0,0 +1,62 @@ +# Location: ./mcpgateway/tools/builder/templates/kubernetes/cert-manager-certificates.yaml.j2 +# Copyright 2025 +# SPDX-License-Identifier: Apache-2.0 +# Authors: Teryl Taylor +# cert-manager Certificate Resources +# Gateway Certificate +apiVersion: cert-manager.io/v1 +kind: Certificate +metadata: + name: mcp-{{ gateway_name }}-cert + namespace: {{ namespace }} +spec: + secretName: mcp-{{ gateway_name }}-server-cert + duration: {{ duration }}h + renewBefore: {{ renew_before }}h + isCA: false + privateKey: + algorithm: RSA + size: 2048 + usages: + - digital signature + - key encipherment + - server auth + - client auth + dnsNames: + - {{ gateway_name }} + - {{ gateway_name }}.{{ namespace }} + - {{ gateway_name }}.{{ namespace }}.svc + - {{ gateway_name }}.{{ namespace }}.svc.cluster.local + issuerRef: + name: {{ issuer_name }} + kind: {{ issuer_kind }} +{% for plugin in plugins %} +--- +# Plugin {{ plugin.name }} Certificate +apiVersion: cert-manager.io/v1 +kind: Certificate +metadata: + name: mcp-{{ plugin.name }}-cert + namespace: {{ namespace }} +spec: + secretName: mcp-{{ plugin.name }}-server-cert + duration: {{ duration }}h + renewBefore: {{ renew_before }}h + isCA: false + privateKey: + algorithm: RSA + size: 2048 + usages: + - digital signature + - key encipherment + - server auth + - client auth + dnsNames: + - {{ plugin.name }} + - {{ plugin.name }}.{{ namespace }} + - {{ plugin.name }}.{{ namespace }}.svc + - {{ plugin.name }}.{{ namespace }}.svc.cluster.local + issuerRef: + name: {{ issuer_name }} + kind: {{ issuer_kind }} +{% endfor %} diff --git a/mcpgateway/tools/builder/templates/kubernetes/cert-secrets.yaml.j2 b/mcpgateway/tools/builder/templates/kubernetes/cert-secrets.yaml.j2 new file mode 100644 index 000000000..67e5a1e87 --- /dev/null +++ b/mcpgateway/tools/builder/templates/kubernetes/cert-secrets.yaml.j2 @@ -0,0 +1,38 @@ +# Location: ./mcpgateway/tools/builder/templates/kubernetes/cert-secrets.yaml.j2 +# Copyright 2025 +# SPDX-License-Identifier: Apache-2.0 +# Authors: Teryl Taylor +# mTLS Certificate Secrets +# CA Certificate (shared by all components) +apiVersion: v1 +kind: Secret +metadata: + name: mcp-ca-secret + namespace: {{ namespace }} +type: Opaque +data: + ca.crt: {{ ca_cert_b64 }} +--- +# Gateway Client Certificate +apiVersion: v1 +kind: Secret +metadata: + name: mcp-{{ gateway_name }}-server-cert + namespace: {{ namespace }} +type: kubernetes.io/tls +data: + tls.crt: {{ gateway_cert_b64 }} + tls.key: {{ gateway_key_b64 }} +{% for plugin in plugins %} +--- +# Plugin {{ plugin.name }} Server Certificate +apiVersion: v1 +kind: Secret +metadata: + name: mcp-{{ plugin.name }}-server-cert + namespace: {{ namespace }} +type: kubernetes.io/tls +data: + tls.crt: {{ plugin.cert_b64 }} + tls.key: {{ plugin.key_b64 }} +{% endfor %} diff --git 
a/mcpgateway/tools/builder/templates/kubernetes/deployment.yaml.j2 b/mcpgateway/tools/builder/templates/kubernetes/deployment.yaml.j2 new file mode 100644 index 000000000..843bb5fd4 --- /dev/null +++ b/mcpgateway/tools/builder/templates/kubernetes/deployment.yaml.j2 @@ -0,0 +1,248 @@ +# Location: ./mcpgateway/tools/builder/templates/kubernetes/deployment.yaml.j2 +# Copyright 2025 +# SPDX-License-Identifier: Apache-2.0 +# Authors: Teryl Taylor +# Kubernetes Deployment for {{ name }} +apiVersion: v1 +kind: Namespace +metadata: + name: {{ namespace }} +--- +apiVersion: v1 +kind: Secret +metadata: + name: {{ name }}-env + namespace: {{ namespace }} +type: Opaque +stringData: +{% if env_vars is defined and env_vars %} + # Environment variables + # NOTE: In production, these should come from CI/CD vault secrets +{% for key, value in env_vars.items() %} + {{ key }}: "{{ value }}" +{% endfor %} +{% endif %} +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ name }} + namespace: {{ namespace }} + labels: + app: {{ name }} + component: {% if name == 'mcpgateway' %}gateway{% else %}plugin{% endif %} +spec: + replicas: {{ replicas | default(1) }} + selector: + matchLabels: + app: {{ name }} + template: + metadata: + labels: + app: {{ name }} + component: {% if name == 'mcpgateway' %}gateway{% else %}plugin{% endif %} + spec: + {% if image_pull_secret is defined %} + imagePullSecrets: + - name: {{ image_pull_secret }} + {% endif %} + + {% if init_containers is defined %} + initContainers: + {% for init_container in init_containers %} + - name: {{ init_container.name }} + image: {{ init_container.image }} + command: {{ init_container.command | tojson }} + {% endfor %} + {% endif %} + + containers: + - name: {{ name }} + image: {{ image }} + imagePullPolicy: {{ image_pull_policy | default('IfNotPresent') }} + + ports: + - name: http + containerPort: {{ port | default(8000) }} + protocol: TCP + {% if mtls_enabled | default(true) and name != 'mcpgateway' %} + - name: health + containerPort: 9000 + protocol: TCP + {% endif %} + + env: + {% if mtls_enabled | default(true) %} + {% if name == 'mcpgateway' %} + # mTLS client configuration (gateway connects to plugins) + - name: PLUGINS_CLIENT_MTLS_CA_BUNDLE + value: "/app/certs/ca/ca.crt" + - name: PLUGINS_CLIENT_MTLS_CERTFILE + value: "/app/certs/mcp/tls.crt" + - name: PLUGINS_CLIENT_MTLS_KEYFILE + value: "/app/certs/mcp/tls.key" + - name: PLUGINS_CLIENT_MTLS_VERIFY + value: "{{ mtls_verify | default('true') }}" + - name: PLUGINS_CLIENT_MTLS_CHECK_HOSTNAME + value: "{{ mtls_check_hostname | default('false') }}" + {% else %} + # mTLS server configuration (plugin accepts gateway connections) + - name: PLUGINS_TRANSPORT + value: "http" + - name: PLUGINS_SERVER_HOST + value: "0.0.0.0" + - name: PLUGINS_SERVER_PORT + value: "{{ port | default(8000) }}" + - name: PLUGINS_SERVER_SSL_ENABLED + value: "true" + - name: PLUGINS_SERVER_SSL_KEYFILE + value: "/app/certs/mcp/tls.key" + - name: PLUGINS_SERVER_SSL_CERTFILE + value: "/app/certs/mcp/tls.crt" + - name: PLUGINS_SERVER_SSL_CA_CERTS + value: "/app/certs/ca/ca.crt" + - name: PLUGINS_SERVER_SSL_CERT_REQS + value: "2" # CERT_REQUIRED + {% endif %} + {% endif %} + + envFrom: + - secretRef: + name: {{ name }}-env + + {% if health_check | default(true) %} + livenessProbe: + httpGet: + path: /health + {% if mtls_enabled | default(true) and name != 'mcpgateway' %} + # Plugin with mTLS: use separate health check server on port 9000 + port: health + scheme: HTTP + {% else %} + # Gateway or non-mTLS: health 
check on main HTTP port + port: http + scheme: HTTP + {% endif %} + initialDelaySeconds: 30 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 3 + + readinessProbe: + httpGet: + path: /health + {% if mtls_enabled | default(true) and name != 'mcpgateway' %} + # Plugin with mTLS: use separate health check server on port 9000 + port: health + scheme: HTTP + {% else %} + # Gateway or non-mTLS: health check on main HTTP port + port: http + scheme: HTTP + {% endif %} + initialDelaySeconds: 10 + periodSeconds: 5 + timeoutSeconds: 3 + failureThreshold: 3 + {% endif %} + + resources: + requests: + memory: "{{ memory_request | default('256Mi') }}" + cpu: "{{ cpu_request | default('100m') }}" + limits: + memory: "{{ memory_limit | default('512Mi') }}" + cpu: "{{ cpu_limit | default('500m') }}" + + volumeMounts: + {% if mtls_enabled | default(true) %} + - name: server-cert + mountPath: /app/certs/mcp + readOnly: true + - name: ca-cert + mountPath: /app/certs/ca + readOnly: true + {% endif %} + {% if name == 'mcpgateway' and has_plugins | default(false) %} + - name: plugins-config + mountPath: /app/config + readOnly: true + {% endif %} + + {% if volume_mounts is defined %} + {% for mount in volume_mounts %} + - name: {{ mount.name }} + mountPath: {{ mount.path }} + {% if mount.readonly | default(false) %} + readOnly: true + {% endif %} + {% endfor %} + {% endif %} + + securityContext: + runAsNonRoot: true + {% if run_as_user is defined %} + runAsUser: {{ run_as_user }} + {% endif %} + allowPrivilegeEscalation: false + capabilities: + drop: + - ALL + readOnlyRootFilesystem: false + + volumes: + {% if mtls_enabled | default(true) %} + - name: server-cert + secret: + secretName: mcp-{{ name }}-server-cert + defaultMode: 0444 + - name: ca-cert + secret: + secretName: mcp-ca-secret + defaultMode: 0444 + {% endif %} + {% if name == 'mcpgateway' and has_plugins | default(false) %} + - name: plugins-config + configMap: + name: plugins-config + defaultMode: 0444 + {% endif %} + + {% if volumes is defined %} + {% for volume in volumes %} + - name: {{ volume.name }} + {% if volume.type == 'secret' %} + secret: + secretName: {{ volume.secret_name }} + {% if volume.default_mode is defined %} + defaultMode: {{ volume.default_mode }} + {% endif %} + {% elif volume.type == 'configmap' %} + configMap: + name: {{ volume.configmap_name }} + {% elif volume.type == 'persistentVolumeClaim' %} + persistentVolumeClaim: + claimName: {{ volume.claim_name }} + {% endif %} + {% endfor %} + {% endif %} +--- +apiVersion: v1 +kind: Service +metadata: + name: {{ name }} + namespace: {{ namespace }} + labels: + app: {{ name }} +spec: + type: {{ service_type | default('ClusterIP') }} + ports: + - name: http + port: {{ port | default(8000) }} + targetPort: http + protocol: TCP + {% if service_type == 'NodePort' and node_port is defined %} + nodePort: {{ node_port }} + {% endif %} + selector: + app: {{ name }} diff --git a/mcpgateway/tools/builder/templates/kubernetes/plugins-configmap.yaml.j2 b/mcpgateway/tools/builder/templates/kubernetes/plugins-configmap.yaml.j2 new file mode 100644 index 000000000..d517d8459 --- /dev/null +++ b/mcpgateway/tools/builder/templates/kubernetes/plugins-configmap.yaml.j2 @@ -0,0 +1,13 @@ +# Location: ./mcpgateway/tools/builder/templates/kubernetes/plugins-configmap.yaml.j2 +# Copyright 2025 +# SPDX-License-Identifier: Apache-2.0 +# Authors: Teryl Taylor +# ConfigMap for plugins configuration +apiVersion: v1 +kind: ConfigMap +metadata: + name: plugins-config + namespace: {{ namespace }} +data: + 
plugins.yaml: | +{{ plugins_config | safe | indent(4, first=True) }} diff --git a/mcpgateway/tools/builder/templates/kubernetes/postgres.yaml.j2 b/mcpgateway/tools/builder/templates/kubernetes/postgres.yaml.j2 new file mode 100644 index 000000000..de58a288e --- /dev/null +++ b/mcpgateway/tools/builder/templates/kubernetes/postgres.yaml.j2 @@ -0,0 +1,125 @@ +# Location: ./mcpgateway/tools/builder/templates/kubernetes/postgres.yaml.j2 +# Copyright 2025 +# SPDX-License-Identifier: Apache-2.0 +# Authors: Teryl Taylor +# PostgreSQL Database for MCP Gateway +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: postgres-pvc + namespace: {{ namespace }} +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ storage_size }} + {% if storage_class %} + storageClassName: {{ storage_class }} + {% endif %} +--- +apiVersion: v1 +kind: Secret +metadata: + name: postgres-secret + namespace: {{ namespace }} +type: Opaque +stringData: + # Official PostgreSQL image variables + POSTGRES_USER: {{ user }} + POSTGRES_PASSWORD: {{ password }} + POSTGRES_DB: {{ database }} + # Red Hat/SCL PostgreSQL image variables (OpenShift-compatible) + POSTGRESQL_USER: {{ user }} + POSTGRESQL_PASSWORD: {{ password }} + POSTGRESQL_DATABASE: {{ database }} +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: postgres + namespace: {{ namespace }} + labels: + app: postgres + component: database +spec: + replicas: 1 + selector: + matchLabels: + app: postgres + template: + metadata: + labels: + app: postgres + component: database + spec: + containers: + - name: postgres + image: {{ image }} + imagePullPolicy: IfNotPresent + + ports: + - name: postgres + containerPort: 5432 + protocol: TCP + + envFrom: + - secretRef: + name: postgres-secret + + volumeMounts: + - name: postgres-data + mountPath: /var/lib/postgresql/data + subPath: postgres + + livenessProbe: + exec: + command: + - /bin/sh + - -c + - pg_isready -U {{ user }} + initialDelaySeconds: 30 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 3 + + readinessProbe: + exec: + command: + - /bin/sh + - -c + - pg_isready -U {{ user }} + initialDelaySeconds: 10 + periodSeconds: 5 + timeoutSeconds: 3 + failureThreshold: 3 + + resources: + requests: + memory: "256Mi" + cpu: "100m" + limits: + memory: "512Mi" + cpu: "500m" + + volumes: + - name: postgres-data + persistentVolumeClaim: + claimName: postgres-pvc +--- +apiVersion: v1 +kind: Service +metadata: + name: postgres + namespace: {{ namespace }} + labels: + app: postgres +spec: + type: ClusterIP + ports: + - name: postgres + port: 5432 + targetPort: postgres + protocol: TCP + selector: + app: postgres diff --git a/mcpgateway/tools/builder/templates/kubernetes/redis.yaml.j2 b/mcpgateway/tools/builder/templates/kubernetes/redis.yaml.j2 new file mode 100644 index 000000000..340e2c71a --- /dev/null +++ b/mcpgateway/tools/builder/templates/kubernetes/redis.yaml.j2 @@ -0,0 +1,76 @@ +# Location: ./mcpgateway/tools/builder/templates/kubernetes/redis.yaml.j2 +# Copyright 2025 +# SPDX-License-Identifier: Apache-2.0 +# Authors: Teryl Taylor +# Redis Cache for MCP Gateway +apiVersion: apps/v1 +kind: Deployment +metadata: + name: redis + namespace: {{ namespace }} + labels: + app: redis + component: cache +spec: + replicas: 1 + selector: + matchLabels: + app: redis + template: + metadata: + labels: + app: redis + component: cache + spec: + containers: + - name: redis + image: {{ image }} + imagePullPolicy: IfNotPresent + + ports: + - name: redis + containerPort: 6379 + protocol: TCP + + 
livenessProbe: + tcpSocket: + port: redis + initialDelaySeconds: 30 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 3 + + readinessProbe: + exec: + command: + - redis-cli + - ping + initialDelaySeconds: 10 + periodSeconds: 5 + timeoutSeconds: 3 + failureThreshold: 3 + + resources: + requests: + memory: "128Mi" + cpu: "50m" + limits: + memory: "256Mi" + cpu: "200m" +--- +apiVersion: v1 +kind: Service +metadata: + name: redis + namespace: {{ namespace }} + labels: + app: redis +spec: + type: ClusterIP + ports: + - name: redis + port: 6379 + targetPort: redis + protocol: TCP + selector: + app: redis diff --git a/mcpgateway/tools/builder/templates/kubernetes/route.yaml.j2 b/mcpgateway/tools/builder/templates/kubernetes/route.yaml.j2 new file mode 100644 index 000000000..815ace26d --- /dev/null +++ b/mcpgateway/tools/builder/templates/kubernetes/route.yaml.j2 @@ -0,0 +1,25 @@ +# Location: ./mcpgateway/tools/builder/templates/kubernetes/route.yaml.j2 +# Copyright 2025 +# SPDX-License-Identifier: Apache-2.0 +# Authors: Teryl Taylor +# OpenShift Route for external access to MCP Gateway +apiVersion: route.openshift.io/v1 +kind: Route +metadata: + name: mcpgateway-admin + namespace: {{ namespace }} + labels: + app: mcpgateway + component: gateway +spec: + host: mcpgateway-admin-{{ namespace }}.{{ openshift_domain }} + path: / + to: + kind: Service + name: mcpgateway + weight: 100 + port: + targetPort: http + tls: + termination: {{ tls_termination }} + wildcardPolicy: None diff --git a/mcpgateway/tools/builder/templates/plugins-config.yaml.j2 b/mcpgateway/tools/builder/templates/plugins-config.yaml.j2 new file mode 100644 index 000000000..a8221873a --- /dev/null +++ b/mcpgateway/tools/builder/templates/plugins-config.yaml.j2 @@ -0,0 +1,49 @@ +# Location: ./mcpgateway/tools/builder/templates/compose/plugins-config.yaml.j2 +# Copyright 2025 +# SPDX-License-Identifier: Apache-2.0 +# Authors: Teryl Taylor +# Plugin configuration for MCP Gateway +# Auto-generated from mcp-stack.yaml + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 120 + fail_on_plugin_error: false + enable_plugin_api: true + plugin_health_check_interval: 60 + +# External plugin connections +plugins: +{% for plugin in plugins -%} +- name: {{ plugin.name }} + kind: external +{%- if plugin.description %} + description: "{{ plugin.description }}" +{%- endif %} +{%- if plugin.version %} + version: "{{ plugin.version }}" +{%- endif %} +{%- if plugin.author %} + author: "{{ plugin.author }}" +{%- endif %} +{%- if plugin.hooks %} + hooks: {{ plugin.hooks }} +{%- endif %} +{%- if plugin.tags %} + tags: {{ plugin.tags }} +{%- endif %} +{%- if plugin.mode %} + mode: "{{ plugin.mode }}" +{%- endif %} +{%- if plugin.priority %} + priority: {{ plugin.priority }} +{%- endif %} +{%- if plugin.conditions %} + conditions: {{ plugin.conditions }} +{%- endif %} + mcp: + proto: STREAMABLEHTTP + url: {{ plugin.url }} + +{% endfor %} diff --git a/mcpgateway/tools/cli.py b/mcpgateway/tools/cli.py new file mode 100644 index 000000000..d7a869c77 --- /dev/null +++ b/mcpgateway/tools/cli.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/tools/cli.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +cforge CLI ─ command line tools for building and deploying the +MCP Gateway and its plugins. 
+ +This module is exposed as a **console-script** via: + + [project.scripts] + cforge = "mcpgateway.tools.cli:main" + +so that a user can simply type `cforge ...` to use the CLI. + +Features +───────── +* plugin: + - bootstrap: Creates a new plugin project from template │ + - install: Installs plugins into a Python environment │ + - package: Builds an MCP server to serve plugins as tools +* gateway: + - Validates deploy.yaml configuration + - Builds plugin containers from git repos + - Generates mTLS certificates + - Deploys to Kubernetes or Docker Compose + - Integrates with CI/CD vault secrets + + +Typical usage +───────────── +```console +$ cforge --help +``` +""" + +# Third-Party +import typer + +# First-Party +import mcpgateway.plugins.tools.cli as plugins +import mcpgateway.tools.builder.cli as builder + +app = typer.Typer(help="Command line tools for building, deploying, and interacting with the ContextForge MCP Gateway") + +app.add_typer(plugins.app, name="plugin", help="Manage the plugin lifecycle") +app.add_typer(builder.app, name="gateway", help="Manage the building and deployment of the gateway") + + +def main() -> None: # noqa: D401 - imperative mood is fine here + """Entry point for the *cforge* console script.""" + app(obj={}) + + +if __name__ == "__main__": + main() diff --git a/mcpgateway/translate_grpc.py b/mcpgateway/translate_grpc.py index f1aa9a4ba..58b80ab8f 100644 --- a/mcpgateway/translate_grpc.py +++ b/mcpgateway/translate_grpc.py @@ -173,7 +173,7 @@ async def _discover_service_details(self, stub, service_name: str) -> None: # Add to pool (ignore if already exists) try: self._pool.Add(file_desc_proto) - except Exception as e: # noqa: B110 + except Exception as e: # pylint: disable=broad-except # Descriptor already in pool, safe to skip logger.debug(f"Descriptor already in pool: {e}") diff --git a/mcpgateway/utils/correlation_id.py b/mcpgateway/utils/correlation_id.py new file mode 100644 index 000000000..6701405e3 --- /dev/null +++ b/mcpgateway/utils/correlation_id.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/utils/correlation_id.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: MCP Gateway Contributors + +Correlation ID (Request ID) Utilities. + +This module provides async-safe utilities for managing correlation IDs (also known as +request IDs) throughout the request lifecycle using Python's contextvars. + +The correlation ID is a unique identifier that tracks a single request as it flows +through all components of the system (HTTP → Middleware → Services → Plugins → Logs). + +Key concepts: +- ContextVar provides per-request isolation in async environments +- Correlation IDs can be client-provided (X-Correlation-ID header) or auto-generated +- The same ID is used as request_id throughout logs, services, and plugin contexts +- Thread-safe and async-safe (no cross-contamination between concurrent requests) +""" + +# Standard +from contextvars import ContextVar +import logging +from typing import Dict, Optional +import uuid + +logger = logging.getLogger(__name__) + +# Context variable for storing correlation ID (request ID) per-request +# This is async-safe and provides automatic isolation between concurrent requests +_correlation_id_context: ContextVar[Optional[str]] = ContextVar("correlation_id", default=None) + + +def get_correlation_id() -> Optional[str]: + """Get the current correlation ID (request ID) from context. + + Returns the correlation ID for the current async task/request. 
Each request + has its own isolated context, so concurrent requests won't interfere. + + Returns: + Optional[str]: The correlation ID if set, None otherwise + """ + return _correlation_id_context.get() + + +def set_correlation_id(correlation_id: str) -> None: + """Set the correlation ID (request ID) for the current context. + + Stores the correlation ID in a context variable that's automatically isolated + per async task. This ID will be used as request_id throughout the system. + + Args: + correlation_id: The correlation ID to set (typically a UUID or client-provided ID) + """ + _correlation_id_context.set(correlation_id) + + +def clear_correlation_id() -> None: + """Clear the correlation ID (request ID) from the current context. + + Should be called at the end of request processing to clean up context. + In practice, FastAPI middleware automatically handles context cleanup. + + Note: This is optional as ContextVar automatically cleans up when the + async task completes. + """ + _correlation_id_context.set(None) + + +def generate_correlation_id() -> str: + """Generate a new correlation ID (UUID4 hex format). + + Creates a new random UUID suitable for use as a correlation ID. + Uses UUID4 which provides 122 bits of randomness. + + Returns: + str: A new UUID in hex format (32 characters, no hyphens) + """ + return uuid.uuid4().hex + + +def extract_correlation_id_from_headers(headers: Dict[str, str], header_name: str = "X-Correlation-ID") -> Optional[str]: + """Extract correlation ID from HTTP headers. + + Searches for the correlation ID header (case-insensitive) and returns its value. + Validates that the value is non-empty after stripping whitespace. + + Args: + headers: Dictionary of HTTP headers + header_name: Name of the correlation ID header (default: X-Correlation-ID) + + Returns: + Optional[str]: The correlation ID if found and valid, None otherwise + + Example: + >>> headers = {"X-Correlation-ID": "abc-123"} + >>> extract_correlation_id_from_headers(headers) + 'abc-123' + + >>> headers = {"x-correlation-id": "def-456"} # Case insensitive + >>> extract_correlation_id_from_headers(headers) + 'def-456' + """ + # Headers can be accessed case-insensitively in FastAPI/Starlette + for key, value in headers.items(): + if key.lower() == header_name.lower(): + correlation_id = value.strip() + if correlation_id: + return correlation_id + return None + + +def get_or_generate_correlation_id() -> str: + """Get the current correlation ID or generate a new one if not set. + + This is a convenience function that ensures you always have a correlation ID. + If the current context doesn't have a correlation ID, it generates and sets + a new one. + + Returns: + str: The correlation ID (either existing or newly generated) + + Example: + >>> # First call generates new ID + >>> id1 = get_or_generate_correlation_id() + >>> # Second call returns same ID + >>> id2 = get_or_generate_correlation_id() + >>> assert id1 == id2 + """ + correlation_id = get_correlation_id() + if not correlation_id: + correlation_id = generate_correlation_id() + set_correlation_id(correlation_id) + return correlation_id + + +def validate_correlation_id(correlation_id: Optional[str], max_length: int = 255) -> bool: + """Validate a correlation ID for safety and length. 
+ + Checks that the correlation ID is: + - Non-empty after stripping whitespace + - Within the maximum length limit + - Contains only safe characters (alphanumeric, hyphens, underscores) + + Args: + correlation_id: The correlation ID to validate + max_length: Maximum allowed length (default: 255) + + Returns: + bool: True if valid, False otherwise + + Example: + >>> validate_correlation_id("abc-123") + True + >>> validate_correlation_id("abc 123") # Spaces not allowed + False + >>> validate_correlation_id("a" * 300) # Too long + False + """ + if not correlation_id or not correlation_id.strip(): + return False + + correlation_id = correlation_id.strip() + + if len(correlation_id) > max_length: + logger.warning(f"Correlation ID too long: {len(correlation_id)} > {max_length}") + return False + + # Allow alphanumeric, hyphens, and underscores only + if not all(c.isalnum() or c in ("-", "_") for c in correlation_id): + logger.warning(f"Correlation ID contains invalid characters: {correlation_id}") + return False + + return True diff --git a/mcpgateway/utils/retry_manager.py b/mcpgateway/utils/retry_manager.py index c8cb8283f..3e721167e 100644 --- a/mcpgateway/utils/retry_manager.py +++ b/mcpgateway/utils/retry_manager.py @@ -301,7 +301,7 @@ async def _sleep_with_jitter(self, base: float, jitter_range: float): True """ # random.uniform() is safe here as jitter is only used for retry timing, not security - delay = base + random.uniform(0, jitter_range) # noqa: DUO102 # nosec B311 + delay = base + random.uniform(0, jitter_range) # nosec B311 # noqa: DUO102 # Ensure delay doesn't exceed the max allowed delay = min(delay, self.max_delay) await asyncio.sleep(delay) diff --git a/mcpgateway/utils/ssl_key_manager.py b/mcpgateway/utils/ssl_key_manager.py new file mode 100644 index 000000000..8c4fa4533 --- /dev/null +++ b/mcpgateway/utils/ssl_key_manager.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/utils/ssl_key_manager.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Keval Mahajan + +SSL key management utilities for handling passphrase-protected keys. + +This module provides utilities for managing SSL private keys, including support +for passphrase-protected keys. It handles decryption and secure temporary file +management for use with Gunicorn and other servers that don't natively support +passphrase-protected keys. +""" + +# Standard +import atexit +import logging +import os +from pathlib import Path +import tempfile +from typing import Optional + +# Third-Party +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.serialization import load_pem_private_key + +logger = logging.getLogger(__name__) + + +class SSLKeyManager: + """Manages SSL private keys, including passphrase-protected keys. + + This class handles the decryption of passphrase-protected private keys + and creates temporary unencrypted key files for use with servers that + don't support passphrase-protected keys directly (like Gunicorn). + + The temporary files are created with secure permissions (0o600) and are + automatically cleaned up on process exit. 
+ + Examples: + >>> manager = SSLKeyManager() + >>> key_path = manager.prepare_key_file("certs/key.pem") # doctest: +SKIP + >>> # Use key_path with Gunicorn + >>> manager.cleanup() # doctest: +SKIP + """ + + def __init__(self): + """Initialize the SSL key manager.""" + self._temp_key_file: Optional[Path] = None + + def prepare_key_file( + self, + key_file: str | Path, + password: Optional[str] = None, + ) -> str: + """Prepare a key file for use with Gunicorn. + + If the key is passphrase-protected, decrypt it and write to a + temporary file with secure permissions. Otherwise, return the + original path. + + Args: + key_file: Path to the private key file + password: Optional passphrase for encrypted key + + Returns: + Path to the usable key file (original or temporary) + + Raises: + FileNotFoundError: If the key file doesn't exist + ValueError: If decryption fails (wrong passphrase, invalid key, etc.) + + Examples: + >>> manager = SSLKeyManager() + >>> # Unencrypted key - returns original path + >>> path = manager.prepare_key_file("certs/key.pem") # doctest: +SKIP + >>> # Encrypted key - returns temporary decrypted path + >>> path = manager.prepare_key_file("certs/key-enc.pem", "secret") # doctest: +SKIP + """ + key_path = Path(key_file) + + if not key_path.exists(): + raise FileNotFoundError(f"Key file not found: {key_file}") + + # If no password, use the key as-is + if not password: + logger.info(f"Using unencrypted key file: {key_file}") + return str(key_path) + + # Decrypt the key and write to temporary file + logger.info("Decrypting passphrase-protected key...") + + try: + # Read and decrypt the key + with open(key_path, "rb") as f: + key_data = f.read() + + private_key = load_pem_private_key( + key_data, + password=password.encode() if password else None, + ) + + # Serialize to unencrypted PEM + unencrypted_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + # Write to temporary file with secure permissions + fd, temp_path = tempfile.mkstemp(suffix=".pem", prefix="ssl_key_") + self._temp_key_file = Path(temp_path) + + # Set restrictive permissions (owner read/write only) + os.chmod(temp_path, 0o600) + + # Write the decrypted key + with os.fdopen(fd, "wb") as f: + f.write(unencrypted_pem) + + logger.info(f"Decrypted key written to temporary file: {temp_path}") + + # Register cleanup on exit + atexit.register(self.cleanup) + + return temp_path + + except Exception as e: + logger.error(f"Failed to decrypt key: {e}") + self.cleanup() + raise ValueError("Failed to decrypt private key. Check that the passphrase is correct.") from e + + def cleanup(self): + """Remove temporary key file if it exists. + + This method is automatically called on process exit via atexit, + but can also be called manually for explicit cleanup. + """ + if self._temp_key_file and self._temp_key_file.exists(): + try: + self._temp_key_file.unlink() + logger.info(f"Cleaned up temporary key file: {self._temp_key_file}") + except Exception as e: + logger.warning(f"Failed to clean up temporary key file: {e}") + finally: + self._temp_key_file = None + + +# Global instance for convenience +_key_manager = SSLKeyManager() + + +def prepare_ssl_key(key_file: str, password: Optional[str] = None) -> str: + """Prepare an SSL key file for use with Gunicorn. + + This is a convenience function that uses the global key manager instance. 
+ + Args: + key_file: Path to the private key file + password: Optional passphrase for encrypted key + + Returns: + Path to the usable key file (original or temporary) + + Raises: + FileNotFoundError: If the key file doesn't exist + ValueError: If decryption fails + + Examples: + >>> from mcpgateway.utils.ssl_key_manager import prepare_ssl_key + >>> key_path = prepare_ssl_key("certs/key.pem") # doctest: +SKIP + >>> key_path = prepare_ssl_key("certs/key-enc.pem", "secret") # doctest: +SKIP + """ + return _key_manager.prepare_key_file(key_file, password) diff --git a/mcpgateway/utils/validate_signature.py b/mcpgateway/utils/validate_signature.py index 35216d5c4..647b37fb7 100755 --- a/mcpgateway/utils/validate_signature.py +++ b/mcpgateway/utils/validate_signature.py @@ -115,6 +115,14 @@ def validate_signature(data: bytes, signature: bytes | str, public_key_pem: str) >>> # Test invalid signature >>> validate_signature(b"wrong data", signature, public_pem) False + >>> + >>> # Test with string data (gets encoded) + >>> validate_signature("test message", signature, public_pem) + True + >>> + >>> # Test invalid hex signature format + >>> validate_signature(data, "not-valid-hex", public_pem) + False """ if isinstance(data, str): data = data.encode() @@ -182,6 +190,20 @@ def resign_data( >>> new_sig = resign_data(data, old_public_pem, "", new_private_pem) >>> isinstance(new_sig, str) True + >>> + >>> # Test re-signing with valid old signature + >>> old_sig = old_private.sign(data) + >>> new_sig2 = resign_data(data, old_public_pem, old_sig, new_private_pem) + >>> isinstance(new_sig2, str) + True + >>> new_sig2 != old_sig.hex() # New signature should be different + True + >>> + >>> # Test with invalid old signature (should return None) + >>> bad_sig = b"invalid signature bytes" + >>> result = resign_data(data, old_public_pem, bad_sig, new_private_pem) + >>> result is None + True """ # Handle first-time signing (no old signature) if not old_signature: diff --git a/plugins/config.yaml b/plugins/config.yaml index 7c821daf6..e36d45fe1 100644 --- a/plugins/config.yaml +++ b/plugins/config.yaml @@ -755,7 +755,7 @@ plugins: hooks: ["tool_pre_invoke"] tags: ["security", "vault", "OAUTH2"] # mode: "permissive" - mode: "disabled" + mode: "permissive" priority: 10 conditions: - prompts: [] diff --git a/plugins/external/cedar/.dockerignore b/plugins/external/cedar/.dockerignore new file mode 100644 index 000000000..e9a71f900 --- /dev/null +++ b/plugins/external/cedar/.dockerignore @@ -0,0 +1,363 @@ +# syntax=docker/dockerfile:1 +#---------------------------------------------------------------------- +# Docker Build Context Optimization +# +# This .dockerignore file excludes unnecessary files from the Docker +# build context to improve build performance and security. +#---------------------------------------------------------------------- + +#---------------------------------------------------------------------- +# 1. Development and source directories (not needed in production) +#---------------------------------------------------------------------- +agent_runtimes/ +charts/ +deployment/ +docs/ +deployment/k8s/ +mcp-servers/ +tests/ +test/ +attic/ +*.md +.benchmarks/ + +# Development environment directories +.devcontainer/ +.github/ +.vscode/ +.idea/ + +#---------------------------------------------------------------------- +# 2. 
Version control +#---------------------------------------------------------------------- +.git/ +.gitignore +.gitattributes +.gitmodules + +#---------------------------------------------------------------------- +# 3. Python build artifacts and caches +#---------------------------------------------------------------------- +# Byte-compiled files +__pycache__/ +*.py[cod] +*.pyc +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +.wily/ + +# PyInstaller +*.manifest +*.spec + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ +.pytype/ + +# Cython debug symbols +cython_debug/ + +#---------------------------------------------------------------------- +# 4. Virtual environments +#---------------------------------------------------------------------- +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +.python37/ +.python39/ +.python-version + +# PDM +pdm.lock +.pdm.toml +.pdm-python + +#---------------------------------------------------------------------- +# 5. Package managers and dependencies +#---------------------------------------------------------------------- +# Node.js +node_modules/ +npm-debug.log* +yarn-debug.log* +yarn-error.log* +.npm +.yarn + +# pip +pip-log.txt +pip-delete-this-directory.txt + +#---------------------------------------------------------------------- +# 6. Docker and container files (avoid recursive copies) +#---------------------------------------------------------------------- +Dockerfile +Dockerfile.* +Containerfile +Containerfile.* +docker-compose.yml +docker-compose.*.yml +podman-compose*.yaml +.dockerignore + +#---------------------------------------------------------------------- +# 7. IDE and editor files +#---------------------------------------------------------------------- +# JetBrains +.idea/ +*.iml +*.iws +*.ipr + +# VSCode +.vscode/ +*.code-workspace + +# Vim +*.swp +*.swo +*~ + +# Emacs +*~ +\#*\# +.\#* + +# macOS +.DS_Store +.AppleDouble +.LSOverride + +#---------------------------------------------------------------------- +# 8. Build tools and CI/CD configurations +#---------------------------------------------------------------------- +# Testing configurations +.coveragerc +.pylintrc +.flake8 +pytest.ini +tox.ini +.pytest.ini + +# Linting and formatting +.hadolint.yaml +.pre-commit-config.yaml +.pycodestyle +.pyre_configuration +.pyspelling.yaml +.ruff.toml +.shellcheckrc + +# Build configurations +Makefile +setup.cfg +pyproject.toml.bak +MANIFEST.in + +# CI/CD +.travis.* +.gitlab-ci.yml +.circleci/ +.github/ +azure-pipelines.yml +Jenkinsfile + +# Code quality +sonar-code.properties +sonar-project.properties +.scannerwork/ +whitesource.config +.whitesource + +# Other tools +.bumpversion.cfg +.editorconfig +mypy.ini + +#---------------------------------------------------------------------- +# 9. 
Application runtime files (should not be in image) +#---------------------------------------------------------------------- +# Databases +*.db +*.sqlite +*.sqlite3 +mcp.db +db.sqlite3 + +# Logs +*.log +logs/ +log/ + +# Certificates and secrets +certs/ +*.pem +*.key +*.crt +*.csr +.env +.env.* + +# Generated files +public/ +static/ +media/ + +# Application instances +instance/ +local_settings.py + +#---------------------------------------------------------------------- +# 10. Framework-specific files +#---------------------------------------------------------------------- +# Django +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal +media/ + +# Flask +instance/ +.webassets-cache + +# Scrapy +.scrapy + +# Sphinx documentation +docs/_build/ +docs/build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +*.ipynb + +# IPython +profile_default/ +ipython_config.py + +# celery +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +#---------------------------------------------------------------------- +# 11. Backup and temporary files +#---------------------------------------------------------------------- +*.bak +*.backup +*.tmp +*.temp +*.orig +*.rej +.backup/ +backup/ +tmp/ +temp/ + +#---------------------------------------------------------------------- +# 12. Documentation and miscellaneous +#---------------------------------------------------------------------- +*.md +!README.md +LICENSE +CHANGELOG +AUTHORS +CONTRIBUTORS +TODO +TODO.md +DEVELOPING.md +CONTRIBUTING.md + +# Spelling +.spellcheck-en.txt +*.dic + +# Shell scripts (if not needed in container) +test.sh +scripts/test/ +scripts/dev/ + +#---------------------------------------------------------------------- +# 13. OS-specific files +#---------------------------------------------------------------------- +# Windows +Thumbs.db +ehthumbs.db +Desktop.ini +$RECYCLE.BIN/ + +# Linux +*~ +.fuse_hidden* +.directory +.Trash-* +.nfs* + +#---------------------------------------------------------------------- +# End of .dockerignore +#---------------------------------------------------------------------- diff --git a/plugins/external/cedar/.env.template b/plugins/external/cedar/.env.template new file mode 100644 index 000000000..5dbc57403 --- /dev/null +++ b/plugins/external/cedar/.env.template @@ -0,0 +1,63 @@ +##################################### +# Plugins Settings +##################################### + +# Enable the plugin framework +PLUGINS_ENABLED=false + +# Enable auto-completion for plugins CLI +PLUGINS_CLI_COMPLETION=false + +# default host port to listen on +PLUGINS_SERVER_HOST=0.0.0.0 + +# Set markup mode for plugins CLI +# Valid options: +# rich: use rich markup +# markdown: allow markdown in help strings +# disabled: disable markup +# If unset (commented out), uses "rich" if rich is detected, otherwise disables it. 
+PLUGINS_CLI_MARKUP_MODE=rich + +# Configuration path for plugin loader +PLUGINS_CONFIG=./resources/plugins/config.yaml + +# Configuration path for chuck mcp runtime +CHUK_MCP_CONFIG_PATH=./resources/runtime/config.yaml + +# Configuration for plugins transport +PLUGINS_TRANSPORT=streamablehttp + +##################################### +# MCP External Plugin Server - mTLS Configuration +##################################### + +# Enable SSL/TLS for external plugin MCP server +# Options: true, false (default) +# When true: Enables HTTPS and optionally mTLS for the plugin MCP server +MCP_SSL_ENABLED=false + +# SSL/TLS Certificate Files +# Path to server private key (required when MCP_SSL_ENABLED=true) +# Generate with: openssl genrsa -out certs/mcp/server.key 2048 +# MCP_SSL_KEYFILE=certs/mcp/server.key + +# Path to server certificate (required when MCP_SSL_ENABLED=true) +# Generate with: openssl req -new -x509 -key certs/mcp/server.key -out certs/mcp/server.crt -days 365 +# MCP_SSL_CERTFILE=certs/mcp/server.crt + +# Optional password for encrypted private key +# MCP_SSL_KEYFILE_PASSWORD= + +# mTLS (Mutual TLS) Configuration +# Client certificate verification mode: +# 0 (CERT_NONE): No client certificate required - standard TLS (default) +# 1 (CERT_OPTIONAL): Client certificate optional - validate if provided +# 2 (CERT_REQUIRED): Client certificate required - full mTLS +# Default: 0 (standard TLS without client verification) +MCP_SSL_CERT_REQS=0 + +# CA certificate bundle for verifying client certificates +# Required when MCP_SSL_CERT_REQS=1 or MCP_SSL_CERT_REQS=2 +# Can be a single CA file or a bundle containing multiple CAs +# MCP_SSL_CA_CERTS=certs/mcp/ca.crt diff --git a/plugins/external/cedar/.ruff.toml b/plugins/external/cedar/.ruff.toml new file mode 100644 index 000000000..443a275df --- /dev/null +++ b/plugins/external/cedar/.ruff.toml @@ -0,0 +1,63 @@ +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", + "docs", + "test" +] + +# 200 line length +line-length = 200 +indent-width = 4 + +# Assume Python 3.11 +target-version = "py311" + +[lint] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +select = ["E4", "E7", "E9", "F"] +ignore = [] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. 
+line-ending = "auto" diff --git a/plugins/external/cedar/Containerfile b/plugins/external/cedar/Containerfile new file mode 100644 index 000000000..d2d5f6748 --- /dev/null +++ b/plugins/external/cedar/Containerfile @@ -0,0 +1,47 @@ +# syntax=docker/dockerfile:1.7 +ARG UBI=python-312-minimal + +FROM registry.access.redhat.com/ubi9/${UBI} AS builder + +ARG PYTHON_VERSION=3.12 + +ARG VERSION +ARG COMMIT_ID +ARG SKILLS_SDK_COMMIT_ID +ARG SKILLS_SDK_VERSION +ARG BUILD_TIME_SKILLS_INSTALL + +ENV APP_HOME=/app + +USER 0 + +# Image pre-requisites +RUN INSTALL_PKGS="git make gcc gcc-c++ python${PYTHON_VERSION}-devel" && \ + microdnf -y --setopt=tsflags=nodocs --setopt=install_weak_deps=0 install $INSTALL_PKGS && \ + microdnf -y clean all --enablerepo='*' + +# Setup alias from HOME to APP_HOME +RUN mkdir -p ${APP_HOME} && \ + chown -R 1001:0 ${APP_HOME} && \ + ln -s ${HOME} ${APP_HOME} && \ + mkdir -p ${HOME}/resources/config && \ + chown -R 1001:0 ${HOME}/resources/config + +USER 1001 + +# Install plugin package +COPY . . +RUN pip install --no-cache-dir uv && python -m uv pip install . + +# Make default cache directory writable +RUN mkdir -p -m 0776 ${HOME}/.cache + +# Update labels +LABEL maintainer="Context Forge MCP Gateway Team" \ + name="mcp/mcppluginserver" \ + version="${VERSION}" \ + url="https://github.com/IBM/mcp-context-forge" \ + description="MCP Plugin Server for the Context Forge MCP Gateway" + +# App entrypoint +ENTRYPOINT ["sh", "-c", "${HOME}/run-server.sh"] diff --git a/plugins/external/cedar/MANIFEST.in b/plugins/external/cedar/MANIFEST.in new file mode 100644 index 000000000..1fb92c60a --- /dev/null +++ b/plugins/external/cedar/MANIFEST.in @@ -0,0 +1,67 @@ +# ────────────────────────────────────────────────────────────── +# MANIFEST.in - source-distribution contents for cedarpolicyplugin +# ────────────────────────────────────────────────────────────── + +# 1️⃣ Core project files that SDists/Wheels should always carry +include LICENSE +include README.md +include pyproject.toml +include Containerfile + +# 2️⃣ Top-level config, examples and helper scripts +include *.py +include *.md +include *.example +include *.lock +include *.properties +include *.toml +include *.yaml +include *.yml +include *.json +include *.sh +include *.txt +recursive-include tests/async *.py +recursive-include tests/async *.yaml + +# 3️⃣ Tooling/lint configuration dot-files (explicit so they're not lost) +include .env.make +include .interrogaterc +include .jshintrc +include whitesource.config +include .darglint +include .dockerignore +include .flake8 +include .htmlhintrc +include .pycodestyle +include .pylintrc +include .whitesource +include .coveragerc +# include .gitignore # purely optional but many projects ship it +include .bumpversion.cfg +include .yamllint +include .editorconfig +include .snyk + +# 4️⃣ Runtime data that lives *inside* the package at import time +recursive-include resources/plugins *.yaml +recursive-include cedarpolicyplugin *.yaml + +# 5️⃣ (Optional) include MKDocs-based docs in the sdist +# graft docs + +# 6️⃣ Never publish caches, compiled or build outputs, deployment, agent_runtimes, etc. 
+global-exclude __pycache__ *.py[cod] *.so *.dylib +prune build +prune dist +prune .eggs +prune *.egg-info +prune charts +prune k8s +prune .devcontainer +exclude CLAUDE.* +exclude llms-full.txt + +# Exclude deployment, mcp-servers and agent_runtimes +prune deployment +prune mcp-servers +prune agent_runtimes diff --git a/plugins/external/cedar/Makefile b/plugins/external/cedar/Makefile new file mode 100644 index 000000000..a6855e6e3 --- /dev/null +++ b/plugins/external/cedar/Makefile @@ -0,0 +1,449 @@ + +REQUIRED_BUILD_BINS := uv + +SHELL := /bin/bash +.SHELLFLAGS := -eu -o pipefail -c + +# Project variables +PACKAGE_NAME = cedarpolicyplugin +PROJECT_NAME = cedarpolicyplugin +TARGET ?= cedarpolicyplugin + +# Virtual-environment variables +VENVS_DIR ?= $(HOME)/.venv +VENV_DIR ?= $(VENVS_DIR)/$(PROJECT_NAME) + +# ============================================================================= +# Linters +# ============================================================================= + +black: + @echo "🎨 black $(TARGET)..." && $(VENV_DIR)/bin/black -l 200 $(TARGET) + +black-check: + @echo "🎨 black --check $(TARGET)..." && $(VENV_DIR)/bin/black -l 200 --check --diff $(TARGET) + +ruff: + @echo "⚡ ruff $(TARGET)..." && $(VENV_DIR)/bin/ruff check $(TARGET) && $(VENV_DIR)/bin/ruff format $(TARGET) + +ruff-check: + @echo "⚡ ruff check $(TARGET)..." && $(VENV_DIR)/bin/ruff check $(TARGET) + +ruff-fix: + @echo "⚡ ruff check --fix $(TARGET)..." && $(VENV_DIR)/bin/ruff check --fix $(TARGET) + +ruff-format: + @echo "⚡ ruff format $(TARGET)..." && $(VENV_DIR)/bin/ruff format $(TARGET) + +# ============================================================================= +# Container runtime configuration and operations +# ============================================================================= + +# Container resource limits +CONTAINER_MEMORY = 2048m +CONTAINER_CPUS = 2 + +# Auto-detect container runtime if not specified - DEFAULT TO DOCKER +CONTAINER_RUNTIME ?= $(shell command -v docker >/dev/null 2>&1 && echo docker || echo podman) + +# Alternative: Always default to docker unless explicitly overridden +# CONTAINER_RUNTIME ?= docker + +# Container port +CONTAINER_PORT ?= 8000 +CONTAINER_INTERNAL_PORT ?= 8000 + +print-runtime: + @echo Using container runtime: $(CONTAINER_RUNTIME) + +# Base image name (without any prefix) +IMAGE_BASE ?= mcpgateway/$(PROJECT_NAME) +IMAGE_TAG ?= latest + +# Handle runtime-specific image naming +ifeq ($(CONTAINER_RUNTIME),podman) + # Podman adds localhost/ prefix for local builds + IMAGE_LOCAL := localhost/$(IMAGE_BASE):$(IMAGE_TAG) + IMAGE_LOCAL_DEV := localhost/$(IMAGE_BASE)-dev:$(IMAGE_TAG) + IMAGE_PUSH := $(IMAGE_BASE):$(IMAGE_TAG) +else + # Docker doesn't add prefix + IMAGE_LOCAL := $(IMAGE_BASE):$(IMAGE_TAG) + IMAGE_LOCAL_DEV := $(IMAGE_BASE)-dev:$(IMAGE_TAG) + IMAGE_PUSH := $(IMAGE_BASE):$(IMAGE_TAG) +endif + +print-image: + @echo "🐳 Container Runtime: $(CONTAINER_RUNTIME)" + @echo "Using image: $(IMAGE_LOCAL)" + @echo "Development image: $(IMAGE_LOCAL_DEV)" + @echo "Push image: $(IMAGE_PUSH)" + + + +# Function to get the actual image name as it appears in image list +define get_image_name +$(shell $(CONTAINER_RUNTIME) images --format "{{.Repository}}:{{.Tag}}" | grep -E "(localhost/)?$(IMAGE_BASE):$(IMAGE_TAG)" | head -1) +endef + +# Function to normalize image name for operations +define normalize_image +$(if $(findstring localhost/,$(1)),$(1),$(if $(filter podman,$(CONTAINER_RUNTIME)),localhost/$(1),$(1))) +endef + +# Containerfile to use (can be overridden) 
+#CONTAINER_FILE ?= Containerfile +CONTAINER_FILE ?= $(shell [ -f "Containerfile" ] && echo "Containerfile" || echo "Dockerfile") + +# Define COMMA for the conditional Z flag +COMMA := , + +container-info: + @echo "🐳 Container Runtime Configuration" + @echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + @echo "Runtime: $(CONTAINER_RUNTIME)" + @echo "Base Image: $(IMAGE_BASE)" + @echo "Tag: $(IMAGE_TAG)" + @echo "Local Image: $(IMAGE_LOCAL)" + @echo "Push Image: $(IMAGE_PUSH)" + @echo "Actual Image: $(call get_image_name)" + @echo "Container File: $(CONTAINER_FILE)" + @echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +# Auto-detect platform based on uname +PLATFORM ?= linux/$(shell uname -m | sed 's/x86_64/amd64/;s/aarch64/arm64/') + +container-build: + @echo "🔨 Building with $(CONTAINER_RUNTIME) for platform $(PLATFORM)..." + $(CONTAINER_RUNTIME) build \ + --platform=$(PLATFORM) \ + -f $(CONTAINER_FILE) \ + --tag $(IMAGE_BASE):$(IMAGE_TAG) \ + . + @echo "✅ Built image: $(call get_image_name)" + $(CONTAINER_RUNTIME) images $(IMAGE_BASE):$(IMAGE_TAG) + +container-run: container-check-image + @echo "🚀 Running with $(CONTAINER_RUNTIME)..." + -$(CONTAINER_RUNTIME) stop $(PROJECT_NAME) 2>/dev/null || true + -$(CONTAINER_RUNTIME) rm $(PROJECT_NAME) 2>/dev/null || true + $(CONTAINER_RUNTIME) run --name $(PROJECT_NAME) \ + --env-file=.env \ + -p $(CONTAINER_PORT):$(CONTAINER_INTERNAL_PORT) \ + --restart=always \ + --memory=$(CONTAINER_MEMORY) --cpus=$(CONTAINER_CPUS) \ + --health-cmd="curl --fail http://localhost:$(CONTAINER_INTERNAL_PORT)/health || exit 1" \ + --health-interval=1m --health-retries=3 \ + --health-start-period=30s --health-timeout=10s \ + -d $(call get_image_name) + @sleep 2 + @echo "✅ Container started" + @echo "🔍 Health check status:" + @$(CONTAINER_RUNTIME) inspect $(PROJECT_NAME) --format='{{.State.Health.Status}}' 2>/dev/null || echo "No health check configured" + +container-run-host: container-check-image + @echo "🚀 Running with $(CONTAINER_RUNTIME)..." + -$(CONTAINER_RUNTIME) stop $(PROJECT_NAME) 2>/dev/null || true + -$(CONTAINER_RUNTIME) rm $(PROJECT_NAME) 2>/dev/null || true + $(CONTAINER_RUNTIME) run --name $(PROJECT_NAME) \ + --env-file=.env \ + --network=host \ + -p $(CONTAINER_PORT):$(CONTAINER_INTERNAL_PORT) \ + --restart=always \ + --memory=$(CONTAINER_MEMORY) --cpus=$(CONTAINER_CPUS) \ + --health-cmd="curl --fail http://localhost:$(CONTAINER_INTERNAL_PORT)/health || exit 1" \ + --health-interval=1m --health-retries=3 \ + --health-start-period=30s --health-timeout=10s \ + -d $(call get_image_name) + @sleep 2 + @echo "✅ Container started" + @echo "🔍 Health check status:" + @$(CONTAINER_RUNTIME) inspect $(PROJECT_NAME) --format='{{.State.Health.Status}}' 2>/dev/null || echo "No health check configured" + +container-push: container-check-image + @echo "📤 Preparing to push image..." + @# For Podman, we need to remove localhost/ prefix for push + @if [ "$(CONTAINER_RUNTIME)" = "podman" ]; then \ + actual_image=$$($(CONTAINER_RUNTIME) images --format "{{.Repository}}:{{.Tag}}" | grep -E "$(IMAGE_BASE):$(IMAGE_TAG)" | head -1); \ + if echo "$$actual_image" | grep -q "^localhost/"; then \ + echo "🏷️ Tagging for push (removing localhost/ prefix)..."; \ + $(CONTAINER_RUNTIME) tag "$$actual_image" $(IMAGE_PUSH); \ + fi; \ + fi + $(CONTAINER_RUNTIME) push $(IMAGE_PUSH) + @echo "✅ Pushed: $(IMAGE_PUSH)" + +container-check-image: + @echo "🔍 Checking for image..." + @if [ "$(CONTAINER_RUNTIME)" = "podman" ]; then \ + if ! $(CONTAINER_RUNTIME) image exists $(IMAGE_LOCAL) 2>/dev/null && \ + ! 
$(CONTAINER_RUNTIME) image exists $(IMAGE_BASE):$(IMAGE_TAG) 2>/dev/null; then \ + echo "❌ Image not found: $(IMAGE_LOCAL)"; \ + echo "💡 Run 'make container-build' first"; \ + exit 1; \ + fi; \ + else \ + if ! $(CONTAINER_RUNTIME) images -q $(IMAGE_LOCAL) 2>/dev/null | grep -q . && \ + ! $(CONTAINER_RUNTIME) images -q $(IMAGE_BASE):$(IMAGE_TAG) 2>/dev/null | grep -q .; then \ + echo "❌ Image not found: $(IMAGE_LOCAL)"; \ + echo "💡 Run 'make container-build' first"; \ + exit 1; \ + fi; \ + fi + @echo "✅ Image found" + +container-stop: + @echo "🛑 Stopping container..." + -$(CONTAINER_RUNTIME) stop $(PROJECT_NAME) 2>/dev/null || true + -$(CONTAINER_RUNTIME) rm $(PROJECT_NAME) 2>/dev/null || true + @echo "✅ Container stopped and removed" + +container-logs: + @echo "📜 Streaming logs (Ctrl+C to exit)..." + $(CONTAINER_RUNTIME) logs -f $(PROJECT_NAME) + +container-shell: + @echo "🔧 Opening shell in container..." + @if ! $(CONTAINER_RUNTIME) ps -q -f name=$(PROJECT_NAME) | grep -q .; then \ + echo "❌ Container $(PROJECT_NAME) is not running"; \ + echo "💡 Run 'make container-run' first"; \ + exit 1; \ + fi + @$(CONTAINER_RUNTIME) exec -it $(PROJECT_NAME) /bin/bash 2>/dev/null || \ + $(CONTAINER_RUNTIME) exec -it $(PROJECT_NAME) /bin/sh + +container-health: + @echo "🏥 Checking container health..." + @if ! $(CONTAINER_RUNTIME) ps -q -f name=$(PROJECT_NAME) | grep -q .; then \ + echo "❌ Container $(PROJECT_NAME) is not running"; \ + exit 1; \ + fi + @echo "Status: $$($(CONTAINER_RUNTIME) inspect $(PROJECT_NAME) --format='{{.State.Health.Status}}' 2>/dev/null || echo 'No health check')" + @echo "Logs:" + @$(CONTAINER_RUNTIME) inspect $(PROJECT_NAME) --format='{{range .State.Health.Log}}{{.Output}}{{end}}' 2>/dev/null || true + +container-build-multi: + @echo "🔨 Building multi-architecture image..." + @if [ "$(CONTAINER_RUNTIME)" = "docker" ]; then \ + if ! docker buildx inspect $(PROJECT_NAME)-builder >/dev/null 2>&1; then \ + echo "📦 Creating buildx builder..."; \ + docker buildx create --name $(PROJECT_NAME)-builder; \ + fi; \ + docker buildx use $(PROJECT_NAME)-builder; \ + docker buildx build \ + --platform=linux/amd64,linux/arm64 \ + -f $(CONTAINER_FILE) \ + --tag $(IMAGE_BASE):$(IMAGE_TAG) \ + --push \ + .; \ + elif [ "$(CONTAINER_RUNTIME)" = "podman" ]; then \ + echo "📦 Building manifest with Podman..."; \ + $(CONTAINER_RUNTIME) build --platform=linux/amd64,linux/arm64 \ + -f $(CONTAINER_FILE) \ + --manifest $(IMAGE_BASE):$(IMAGE_TAG) \ + .; \ + echo "💡 To push: podman manifest push $(IMAGE_BASE):$(IMAGE_TAG)"; \ + else \ + echo "❌ Multi-arch builds require Docker buildx or Podman"; \ + exit 1; \ + fi + +# Helper targets for debugging image issues +image-list: + @echo "📋 Images matching $(IMAGE_BASE):" + @$(CONTAINER_RUNTIME) images --format "table {{.Repository}}:{{.Tag}}\t{{.ID}}\t{{.Created}}\t{{.Size}}" | \ + grep -E "(IMAGE|$(IMAGE_BASE))" || echo "No matching images found" + +image-clean: + @echo "🧹 Removing all $(IMAGE_BASE) images..." + @$(CONTAINER_RUNTIME) images --format "{{.Repository}}:{{.Tag}}" | \ + grep -E "(localhost/)?$(IMAGE_BASE)" | \ + xargs $(XARGS_FLAGS) $(CONTAINER_RUNTIME) rmi -f 2>/dev/null + @echo "✅ Images cleaned" + +# Fix image naming issues +image-retag: + @echo "🏷️ Retagging images for consistency..." 
+ @if [ "$(CONTAINER_RUNTIME)" = "podman" ]; then \ + if $(CONTAINER_RUNTIME) image exists $(IMAGE_BASE):$(IMAGE_TAG) 2>/dev/null; then \ + $(CONTAINER_RUNTIME) tag $(IMAGE_BASE):$(IMAGE_TAG) $(IMAGE_LOCAL) 2>/dev/null || true; \ + fi; \ + else \ + if $(CONTAINER_RUNTIME) images -q $(IMAGE_LOCAL) 2>/dev/null | grep -q .; then \ + $(CONTAINER_RUNTIME) tag $(IMAGE_LOCAL) $(IMAGE_BASE):$(IMAGE_TAG) 2>/dev/null || true; \ + fi; \ + fi + @echo "✅ Images retagged" # This always shows success + +# Runtime switching helpers +use-docker: + @echo "export CONTAINER_RUNTIME=docker" + @echo "💡 Run: export CONTAINER_RUNTIME=docker" + +use-podman: + @echo "export CONTAINER_RUNTIME=podman" + @echo "💡 Run: export CONTAINER_RUNTIME=podman" + +show-runtime: + @echo "Current runtime: $(CONTAINER_RUNTIME)" + @echo "Detected from: $$(command -v $(CONTAINER_RUNTIME) || echo 'not found')" # Added + @echo "To switch: make use-docker or make use-podman" + + + +# ============================================================================= +# Targets +# ============================================================================= + +.PHONY: venv +venv: + @rm -Rf "$(VENV_DIR)" + @test -d "$(VENVS_DIR)" || mkdir -p "$(VENVS_DIR)" + @python3 -m venv "$(VENV_DIR)" + @/bin/bash -c "source $(VENV_DIR)/bin/activate && python3 -m pip install --upgrade pip setuptools pdm uv" + @echo -e "✅ Virtual env created.\n💡 Enter it with:\n . $(VENV_DIR)/bin/activate\n" + +.PHONY: install +install: venv + $(foreach bin,$(REQUIRED_BUILD_BINS), $(if $(shell command -v $(bin) 2> /dev/null),,$(error Couldn't find `$(bin)`))) + @/bin/bash -c "source $(VENV_DIR)/bin/activate && python3 -m uv pip install ." + +.PHONY: install-dev +install-dev: venv + $(foreach bin,$(REQUIRED_BUILD_BINS), $(if $(shell command -v $(bin) 2> /dev/null),,$(error Couldn't find `$(bin)`))) + @/bin/bash -c "source $(VENV_DIR)/bin/activate && python3 -m uv pip install -e .[dev]" + +.PHONY: install-editable +install-editable: venv + $(foreach bin,$(REQUIRED_BUILD_BINS), $(if $(shell command -v $(bin) 2> /dev/null),,$(error Couldn't find `$(bin)`))) + @/bin/bash -c "source $(VENV_DIR)/bin/activate && python3 -m uv pip install -e .[dev]" + +.PHONY: uninstall +uninstall: + pip uninstall $(PACKAGE_NAME) + +.PHONY: dist +dist: clean ## Build wheel + sdist into ./dist + @test -d "$(VENV_DIR)" || $(MAKE) --no-print-directory venv + @/bin/bash -eu -c "\ + source $(VENV_DIR)/bin/activate && \ + python3 -m pip install --quiet --upgrade pip build && \ + python3 -m build" + @echo '🛠 Wheel & sdist written to ./dist' + +.PHONY: wheel +wheel: ## Build wheel only + @test -d "$(VENV_DIR)" || $(MAKE) --no-print-directory venv + @/bin/bash -eu -c "\ + source $(VENV_DIR)/bin/activate && \ + python3 -m pip install --quiet --upgrade pip build && \ + python3 -m build -w" + @echo '🛠 Wheel written to ./dist' + +.PHONY: sdist +sdist: ## Build source distribution only + @test -d "$(VENV_DIR)" || $(MAKE) --no-print-directory venv + @/bin/bash -eu -c "\ + source $(VENV_DIR)/bin/activate && \ + python3 -m pip install --quiet --upgrade pip build && \ + python3 -m build -s" + @echo '🛠 Source distribution written to ./dist' + +.PHONY: verify +verify: dist ## Build, run metadata & manifest checks + @/bin/bash -c "source $(VENV_DIR)/bin/activate && \ + twine check dist/* && \ + check-manifest && \ + pyroma -d ." + @echo "✅ Package verified - ready to publish." 
+ +.PHONY: lint-fix +lint-fix: + @# Handle file arguments + @target_file="$(word 2,$(MAKECMDGOALS))"; \ + if [ -n "$$target_file" ] && [ "$$target_file" != "" ]; then \ + actual_target="$$target_file"; \ + else \ + actual_target="$(TARGET)"; \ + fi; \ + for target in $$(echo $$actual_target); do \ + if [ ! -e "$$target" ]; then \ + echo "❌ File/directory not found: $$target"; \ + exit 1; \ + fi; \ + done; \ + echo "🔧 Fixing lint issues in $$actual_target..."; \ + $(MAKE) --no-print-directory black TARGET="$$actual_target"; \ + $(MAKE) --no-print-directory ruff-fix TARGET="$$actual_target" + +.PHONY: lint-check +lint-check: + @# Handle file arguments + @target_file="$(word 2,$(MAKECMDGOALS))"; \ + if [ -n "$$target_file" ] && [ "$$target_file" != "" ]; then \ + actual_target="$$target_file"; \ + else \ + actual_target="$(TARGET)"; \ + fi; \ + for target in $$(echo $$actual_target); do \ + if [ ! -e "$$target" ]; then \ + echo "❌ File/directory not found: $$target"; \ + exit 1; \ + fi; \ + done; \ + echo "🔧 Fixing lint issues in $$actual_target..."; \ + $(MAKE) --no-print-directory black-check TARGET="$$actual_target"; \ + $(MAKE) --no-print-directory ruff-check TARGET="$$actual_target" + +.PHONY: lock +lock: + $(foreach bin,$(REQUIRED_BUILD_BINS), $(if $(shell command -v $(bin) 2> /dev/null),,$(error Couldn't find `$(bin)`. Please run `make init`))) + uv lock + +.PHONY: test +test: + pytest tests + +.PHONY: serve +serve: + @echo "Implement me." + +.PHONY: build +build: + @$(MAKE) container-build + +.PHONY: start +start: + @$(MAKE) container-run + +.PHONY: stop +stop: + @$(MAKE) container-stop + +.PHONY: clean +clean: + find . -type f -name '*.py[co]' -delete -o -type d -name __pycache__ -delete + rm -rf *.egg-info .pytest_cache tests/.pytest_cache build dist .ruff_cache .coverage + +.PHONY: help +help: + @echo "This Makefile is offered for convenience." + @echo "" + @echo "The following are the valid targets for this Makefile:" + @echo "...install Install package from sources" + @echo "...install-dev Install package from sources with dev packages" + @echo "...install-editable Install package from sources in editabled mode" + @echo "...uninstall Uninstall package" + @echo "...dist Clean-build wheel *and* sdist into ./dist" + @echo "...wheel Build wheel only" + @echo "...sdist Build source distribution only" + @echo "...verify Build + twine + check-manifest + pyroma (no upload)" + @echo "...serve Start API server locally" + @echo "...build Build API server container image" + @echo "...start Start the API server container" + @echo "...start Stop the API server container" + @echo "...lock Lock dependencies" + @echo "...lint-fix Check and fix lint errors" + @echo "...lint-check Check for lint errors" + @echo "...test Run all tests" + @echo "...clean Remove all artifacts and builds" diff --git a/plugins/external/cedar/README.md b/plugins/external/cedar/README.md new file mode 100644 index 000000000..4d99b6e09 --- /dev/null +++ b/plugins/external/cedar/README.md @@ -0,0 +1,332 @@ +# Cedar RBAC Plugin for MCP Gateway + +> Author: Shriti Priya +> Version: 0.1.0 + +A plugin that evaluates Cedar policies and user‑friendly custom-DSL policies on incoming requests, and then allows or denies those requests using RBAC-based decisions which are enforced in cedar language and using library `cedarpy`. + +## Cedar Language + +Cedar is an open-source language and specification for defining and evaluating permission policies. 
It allows you to specify who is authorized to perform which actions within your application.
+For more details: https://www.cedarpolicy.com/en
+
+## RBAC
+
+Role-based access control (RBAC) is an authorization model where permissions are attached to roles (like admin, manager, viewer), and users are assigned to those roles instead of getting permissions directly. This makes access control easier to manage and reason about in larger systems.
+
+## CedarPolicyPlugin
+
+This plugin supports two ways of defining policies in the configuration file, controlled by the `policy_lang` parameter.
+
+### Cedar Mode
+
+`plugins/external/cedar/resources/config.yaml`
+
+When `policy_lang` is set to `cedar`, policies are written in the Cedar language under the `policy` key, using the following structure:
+
+```yaml
+  - id: allow-employee-basic-access
+    effect: Permit
+    principal: Role::"employee"
+    action:
+      - Action::"get_leave_balance" #tool name
+      - Action::"request_certificate"
+    resource:
+      - Server::"askHR" # mcp-server name
+      - Agent::"employee_agent" # agent name
+```
+1. **id** is a unique string identifier for the policy.
+2. **effect** can be either Permit or Forbid and determines whether matching requests are allowed or denied.
+3. **principal** specifies who the policy applies to; here it targets the employee role.
+4. **action** lists one or more tools that the principal is attempting to invoke. It can also list actions controlling the visibility of output, i.e. whether the full output or a redacted output is shown based on the user role.
+5. **resource** lists the servers, agents, prompts and resources that the actions can target.
+
+### Custom DSL mode
+
+`plugins/external/cedar/examples/config-dsl.yaml`
+
+When `policy_lang` is set to `custom_dsl`, policies are written in a compact, human-readable mini-language as a YAML multiline string. This allows non-experts to define role, resource, and action in a single, easy-to-scan block.
+
+
+## Syntax
+
+Policies use the following basic pattern:
+
+```
+[role:<role_name>:<resource_type>/<resource_name>]
+<action>
+<action>
+```
+
+For example:
+
+```yaml
+  [role:hr:server/hr_tool]
+  update_payroll
+```
+
+In this example, the role is hr and the resource is the server hr_tool. The line update_payroll represents the specific action being authorized for that role–resource pair.
+
+
+## Configuration
+
+1. **policy_lang**: Specifies the policy language used, `cedar` or `custom_dsl`.
+2. **policy_output_keywords**: Defines keywords for output views such as `view_full_output` and `view_redacted_output` which can be used in policies or applications to control the output visibility.
+3. **policy_redaction_spec**: Contains a regex pattern for redaction; in this case, the pattern matches currency-like strings (e.g., "$123,456") for potential redaction in the policy output, protecting sensitive information.
+4. **policy**: Defines the RBAC policies.
+
+## Installation
+
+1. In the folder `plugins/external/cedar`, copy `.env.template` to a `.env` file.
+2. 
If you are using `policy_lang` to be `cedar`, add the plugin configuration to `plugins/external/cedar/resources/plugins/config.yaml`: + +```yaml +plugins: + - name: "CedarPolicyPlugin" + kind: "cedarpolicyplugin.plugin.CedarPolicyPlugin" + description: "A plugin that does policy decision and enforcement using cedar" + version: "0.1.0" + author: "Shriti Priya" + hooks: ["prompt_pre_fetch", "prompt_post_fetch", "tool_pre_invoke", "tool_post_invoke", "resource_pre_fetch", "resource_post_fetch"] + tags: ["plugin"] + mode: "enforce" # enforce | permissive | disabled + priority: 150 + conditions: + # Apply to specific tools/servers + - server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + policy_lang: cedar + policy_output_keywords: + view_full: "view_full_output" + view_redacted: "view_redacted_output" + policy_redaction_spec: + pattern: '"\$\d{1,}(,\d{1,})*"' # provide regex, if none, then replace all + policy: + ### Tool invocation policies ### + - id: allow-employee-basic-access + effect: Permit + principal: Role::"employee" + action: + - Action::"get_leave_balance" #tool name + - Action::"request_certificate" + resource: + - Server::"askHR" # mcp-server name + - Agent::"employee_agent" # agent name + + - id: allow-manager-full-access + effect: Permit + principal: Role::"manager" + action: + - Action::"get_leave_balance" + - Action::"approve_leave" + - Action::"promote_employee" + - Action::"view_performance" + - Action::"view_full_output" + resource: + - Agent::"manager_agent" + - Server::"payroll_tool" + + - id: allow-hr-hr_tool + effect: Permit + principal: Role::"hr" + action: + - Action::"update_payroll" + - Action::"view_performance" + - Action::"view_full_output" + resource: Server::"hr_tool" + + - id: redact-non-manager-views + effect: Permit + principal: Role::"employee" + action: Action::"view_redacted_output" + resource: + - Server::"payroll_tool" + - Agent::"manager_agent" + - Server::"askHR" + + ### Resource invocation policies ### + - id: allow-admin-resources # policy for resources + effect: Permit + principal: Role::"admin" + action: + - Action::"view_full_output" + resource: Resource::""https://example.com/data"" #Resource:: + + - id: allow-employee-redacted-resources # policy for resources + effect: Permit + principal: Role::"employee" + action: + - Action::"view_redacted_output" + resource: Resource::""https://example.com/data"" #Resource:: + + ### Prompt invocation policies ### + - id: allow-admin-prompts # policy for resources + effect: Permit + principal: Role::"admin" + action: + - Action::"view_full_output" + resource: Prompt::"judge_prompts" #Prompt:: + + + - id: allow-employee-redacted-prompts # policy for resources + effect: Permit + principal: Role::"employee" + action: + - Action::"view_redacted_output" + resource: Prompt::"judge_prompts" #Prompt:: + +``` + +#### Tool Invocation Policies + +For the RBAC policy related to `tool_pre_invoke` and `tool_post_invoke` +Example: +```yaml + - id: allow-employee-basic-access + effect: Permit + principal: Role::"employee" + action: + - Action::"get_leave_balance" #tool name + - Action::"request_certificate" + resource: + - Server::"askHR" # mcp-server name + - Agent::"employee_agent" # agent name +``` + +Here, user with role `employee` (**Role**) is only allowed to invoke tool `get_leave_balance` (**Action**) belonging to the MCP server or (**Server**). 
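+Under the hood, CedarPolicyPlugin evaluates these rules with the `cedarpy` library. The snippet below is a minimal, illustrative sketch of what such a check looks like in raw Cedar terms; the policy text and the exact `is_authorized` call shape are assumptions based on `cedarpy`'s public API, not code taken from this plugin:
+
+```python
+from cedarpy import Decision, is_authorized
+
+# Hypothetical Cedar equivalent of the allow-employee-basic-access rule above
+POLICIES = """
+permit(
+    principal == Role::"employee",
+    action in [Action::"get_leave_balance", Action::"request_certificate"],
+    resource in [Server::"askHR", Agent::"employee_agent"]
+);
+"""
+
+# An employee invoking the get_leave_balance tool on the askHR server
+request = {
+    "principal": 'Role::"employee"',
+    "action": 'Action::"get_leave_balance"',
+    "resource": 'Server::"askHR"',
+    "context": {},
+}
+
+result = is_authorized(request, POLICIES, [])  # no entity hierarchy needed for this flat example
+print(result.decision == Decision.Allow)  # expected: True
+```
+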
+
+In another policy defined for tools:
+
+```yaml
+
+  - id: allow-hr-hr_tool
+    effect: Permit
+    principal: Role::"hr"
+    action:
+      - Action::"update_payroll"
+      - Action::"view_performance"
+      - Action::"view_full_output"
+    resource: Server::"hr_tool"
+
+  - id: redact-non-manager-views
+    effect: Permit
+    principal: Role::"employee"
+    action: Action::"view_redacted_output"
+    resource:
+      - Server::"payroll_tool"
+      - Agent::"manager_agent"
+      - Server::"askHR"
+```
+
+
+Actions like `view_full_output` and `view_redacted_output` control the level of output visible to the user. In the first policy, a user with the `hr` role is allowed to view the full output of `update_payroll`. Similarly, in the second policy, a user with the `employee` role is only allowed to view the redacted output of the tool.
+
+
+#### Prompt Invocation Policies
+
+
+```yaml
+
+  ### Prompt invocation policies ###
+  - id: allow-admin-prompts # policy for prompts
+    effect: Permit
+    principal: Role::"admin"
+    action:
+      - Action::"view_full_output"
+    resource: Prompt::"judge_prompts" #Prompt::<prompt_name>
+
+
+  - id: allow-employee-redacted-prompts # policy for prompts
+    effect: Permit
+    principal: Role::"employee"
+    action:
+      - Action::"view_redacted_output"
+    resource: Prompt::"judge_prompts" #Prompt::<prompt_name>
+```
+
+In the above policies, given a prompt template `judge_prompts`, a user with the `admin` role is allowed to view the full prompt, while a user with the `employee` role can only see a redacted version of the prompt.
+
+
+#### Resource Invocation Policies
+
+**NOTE:** Do not confuse the term resource in Cedar with the term resource in MCP ContextForge.
+
+```yaml
+
+  - id: allow-admin-resources # policy for resources
+    effect: Permit
+    principal: Role::"admin"
+    action:
+      - Action::"view_full_output"
+    resource: Resource::"https://example.com/data" #Resource::<uri>
+
+  - id: allow-employee-redacted-resources # policy for resources
+    effect: Permit
+    principal: Role::"employee"
+    action:
+      - Action::"view_redacted_output"
+    resource: Resource::"https://example.com/data" #Resource::<uri>
+```
+
+The `Resource` entity type is used when resource hooks are invoked. In the above policies, a user with the `admin` role is allowed to view the full output of the URI `https://example.com/data`, while a user with the `employee` role can only see the redacted version of the resource output.
+
+
+#### policy_output_keywords
+
+```
+  view_full: "view_full_output"
+  view_redacted: "view_redacted_output"
+```
+
+This mapping is provided so that every time a user defines a policy and wants to control the output visibility of a tool, prompt, resource or agent in the MCP gateway, they use the keywords declared in `policy_output_keywords`. CedarPolicyPlugin will internally use this mapping to redact or fully display the tool, prompt or resource response in the post hooks.
+
+
+
+
+3. Now that the policy and plugin configurations are defined in the `resources/config.yaml` file, the next step is to build this as an external MCP server.
+
+* `make venv`: This will create a virtual environment to develop or build your plugin.
+* `make install && make install-dev`: To install all the required libraries in the environment.
+* `make build`: This will build a docker image named `mcpgateway/cedarpolicyplugin`.
+* `make start`: This will start the cedarpolicyplugin container.
+
+The following output confirms that your container is running fine:
+```
+WARNING:mcpgateway.observability:OpenTelemetry not installed. Proceeding with graceful fallbacks.
Proceeding with graceful fallbacks.
+INFO: Started server process [9]
+INFO: Waiting for application startup.
+INFO: Application startup complete.
+INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
+INFO: 127.0.0.1:55196 - "GET /health HTTP/1.1" 200 OK
+
+```
+
+4. Now, enable the plugin in the gateway by setting `PLUGINS_ENABLED=true` and adding the following block to the `plugins/config.yaml` file. This indicates that the CedarPolicyPlugin is running as an external MCP server.
+
+   ```yaml
+   - name: "CedarPolicyPlugin"
+     kind: "external"
+     priority: 10 # adjust the priority
+     mcp:
+       proto: STREAMABLEHTTP
+       url: http://127.0.0.1:8000/mcp
+   ```
+
+## Testing
+
+There is a set of test cases in the `cedar/tests` folder. The file `test_cedarpolicyplugin.py` contains detailed test cases for RBAC policies enforced on tools, prompts and resources.
+Run `make test` to run all the test cases.
+
+## Difference from OPAPlugin
+
+The OPA plugin runs an OPA server to enforce policies, whereas the Cedar plugin uses the `cedarpy` library and performs policy enforcement locally without requiring an external service.
+The OPA plugin requires users to know `rego` to define policies, while Cedar plugin policies can be written either in `cedar` or in the user-friendly `custom_dsl` language.
+Right now, the Cedar plugin enforces RBAC policies; it could be extended to enforce ABAC policies using the same plugin.
diff --git a/plugins/external/cedar/cedarpolicyplugin/__init__.py b/plugins/external/cedar/cedarpolicyplugin/__init__.py
new file mode 100644
index 000000000..52cdda086
--- /dev/null
+++ b/plugins/external/cedar/cedarpolicyplugin/__init__.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+"""MCP Gateway CedarPolicyPlugin Plugin - A plugin that does policy decision and enforcement using cedar.
+
+Copyright 2025
+SPDX-License-Identifier: Apache-2.0
+Authors: Shriti Priya
+
+"""
+
+import importlib.metadata
+
+# Package version
+try:
+    __version__ = importlib.metadata.version("cedarpolicyplugin")
+except Exception:
+    __version__ = "0.1.0"
+
+__author__ = "Shriti Priya"
+__copyright__ = "Copyright 2025"
+__license__ = "Apache 2.0"
+__description__ = "A plugin that does policy decision and enforcement using cedar"
+__url__ = "https://ibm.github.io/mcp-context-forge/"
+__download_url__ = "https://github.com/IBM/mcp-context-forge"
+__packages__ = ["cedarpolicyplugin"]
diff --git a/plugins/external/cedar/cedarpolicyplugin/plugin-manifest.yaml b/plugins/external/cedar/cedarpolicyplugin/plugin-manifest.yaml
new file mode 100644
index 000000000..38ec3ceea
--- /dev/null
+++ b/plugins/external/cedar/cedarpolicyplugin/plugin-manifest.yaml
@@ -0,0 +1,9 @@
+description: "A plugin that does policy decision and enforcement using cedar"
+author: "Shriti Priya"
+version: "0.1.0"
+available_hooks:
+  - "prompt_pre_hook"
+  - "prompt_post_hook"
+  - "tool_pre_hook"
+  - "tool_post_hook"
+default_configs:
diff --git a/plugins/external/cedar/cedarpolicyplugin/plugin.py b/plugins/external/cedar/cedarpolicyplugin/plugin.py
new file mode 100644
index 000000000..dd98f5f1b
--- /dev/null
+++ b/plugins/external/cedar/cedarpolicyplugin/plugin.py
@@ -0,0 +1,678 @@
+# -*- coding: utf-8 -*-
+"""A plugin that does policy decision and enforcement using cedar.
+
+Copyright 2025
+SPDX-License-Identifier: Apache-2.0
+Authors: Shriti Priya
+
+This module loads configurations for plugins.
+""" + +# Standard +from enum import Enum +import re +from typing import Any +from urllib.parse import urlparse + +# Third-Party +from cedarpolicyplugin.schema import CedarConfig, CedarInput +from cedarpy import AuthzResult, Decision, is_authorized + +# First-Party +from mcpgateway.plugins.framework import ( + Plugin, + PluginConfig, + PluginContext, + PluginError, + PluginErrorModel, + PluginViolation, + PromptPosthookPayload, + PromptPosthookResult, + PromptPrehookPayload, + PromptPrehookResult, + ToolPostInvokePayload, + ToolPostInvokeResult, + ToolPreInvokePayload, + ToolPreInvokeResult, +) +from mcpgateway.plugins.framework.hooks.resources import ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, ResourcePreFetchResult +from mcpgateway.services.logging_service import LoggingService + +# Initialize logging service first +logging_service = LoggingService() +logger = logging_service.get_logger(__name__) + + +class CedarCodes(str, Enum): + """CedarCodes implementation.""" + + ALLOW_CODE = "ALLOW" + DENIAL_CODE = "DENY" + AUDIT_CODE = "AUDIT" + REQUIRES_HUMAN_APPROVAL_CODE = "REQUIRES_APPROVAL" + + +class CedarResponseTemplates(str, Enum): + """CedarResponseTemplates implementation.""" + + CEDAR_REASON = "Cedar policy denied for {hook_type}" + CEDAR_DESC = "{hook_type} not allowed" + + +class CedarResourceTemplates(str, Enum): + """CedarResourceTemplates implementation.""" + + SERVER = 'Server::"{resource_type}"' + AGENT = 'Agent::"{resource_type}"' + PROMPT = 'Prompt::"{resource_type}"' + RESOURCE = 'Resource::"{resource_type}"' + + +class CedarErrorCodes(str, Enum): + """CedarPolicyPlugin errors""" + + UNSUPPORTED_RESOURCE_TYPE = "Unspecified resource types, accepted resources server, prompt, agent and resource" + UNSPECIFIED_USER_ROLE = "User role is not defined" + UNSPECIFIED_POLICY = "No policy has been provided" + UNSPECIFIED_OUTPUT_ACTION = "Unspecified output action in policy configuration" + UNSPECIFIED_SERVER = "Unspecified server for tool request" + UNSUPPORTED_CONTENT_TYPE = "Unsupported content type" + + +class CedarPolicyPlugin(Plugin): + """A plugin that does policy decision and enforcement using cedar.""" + + def __init__(self, config: PluginConfig): + """Entry init block for plugin. + + Args: + logger: logger that the skill can make use of + config: the skill configuration + """ + super().__init__(config) + self.cedar_config = CedarConfig.model_validate(self._config.config) + self.cedar_context_key = "cedar_policy_context" + self.jwt_info = {} + logger.info(f"CedarPolicyPlugin initialised with configuration {self.cedar_config}") + + def _set_jwt_info(self, user_role_mapping: dict) -> None: + """Sets user role mapping information from jwt tokens + + Args: + info(dict): with user mappings + """ + self.jwt_info["users"] = user_role_mapping + + def _extract_payload_key(self, content: Any = None, key: str = None, result: dict[str, list] = None) -> None: + """Function to extract values of passed in key in the payload recursively based on if the content is of type list, dict + str or pydantic structure. The value is inplace updated in result. + + Args: + content: The content of post hook results. + key: The key for which value needs to be extracted for. + result: A list of all the values for a key. 
+ """ + if isinstance(content, list): + for element in content: + if isinstance(element, dict) and key in element: + self._extract_payload_key(element, key, result) + elif isinstance(content, dict): + if key in content or hasattr(content, key): + result[key].append(content[key]) + elif isinstance(content, str): + result[key].append(content) + elif hasattr(content, key): + result[key].append(getattr(content, key)) + else: + logger.error(f"{CedarErrorCodes.UNSUPPORTED_CONTENT_TYPE.value}: {type(content)}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSUPPORTED_CONTENT_TYPE.value, plugin_name="CedarPolicyPlugin")) + + def _evaluate_policy(self, request: dict, policy_expr: str) -> str: + """Function that evaluates and enforce cedar policy using is_authorized function in cedarpy library + Args: + request(dict): The request dict consisting of principal, action, resource or context keys. + policy_exp(str): The policy expression to evaluate the request on + + Returns: + decision(str): "Allow" or "Deny" + """ + result: AuthzResult = is_authorized(request, policy_expr, []) + decision = "Allow" if result.decision == Decision.Allow else "Deny" + return decision + + def _yamlpolicy2text(self, policies: list) -> str: + """Function to convert yaml representation of policies to text + Args: + policies(list): A list of cedar policies with dict values consisting of individual policies + + Returns: + cedar_policy_text(str): string representation of policy + """ + cedar_policy_text = "" + for policy in policies: + actions = policy["action"] if isinstance(policy["action"], list) else [policy["action"]] + resources = policy["resource"] if isinstance(policy["resource"], list) else [policy["resource"]] + + for res in resources: + actions_str = ", ".join(actions) + cedar_policy_text += "permit(\n" + cedar_policy_text += f' principal == {policy["principal"]},\n' + cedar_policy_text += f" action in [{actions_str}],\n" + cedar_policy_text += f" resource == {res}\n" + cedar_policy_text += ");\n\n" + + return cedar_policy_text + + def _dsl2cedar(self, policy_string: str) -> str: + """Function to convert custom dsl representation of policies to cedar + Args: + policy_string: string representation of policies + + Returns: + cedar_policy_text(str): string representation of policy + """ + lines = [line.strip() for line in policy_string.splitlines() if line.strip()] + policies = [] + current_role = None + current_actions = [] + resource_category = None + resource_name = None + + pattern = r"\[role:([A-Za-z0-9_]+):(resource|prompt|server|agent)/([^\]]+)\]" + for line in lines: + match = re.match(pattern, line) + if match: + if current_role and resource_category and resource_name and current_actions: + resource_category = resource_category.capitalize() + policies.append( + { + "id": f"allow-{current_role}-{resource_category}", + "effect": "Permit", + "principal": f'Role::"{current_role}"', + "action": [f'Action::"{a}"' for a in current_actions], + "resource": f'{resource_category}::"{resource_name}"', + } + ) + current_role, resource_category, resource_name = match.groups() + current_actions = [] + else: + current_actions.append(line) + if current_role and resource_category and resource_name and current_actions: + resource_category = resource_category.capitalize() + policies.append( + { + "id": f"allow-{current_role}-{resource_category}", + "effect": "Permit", + "principal": f'Role::"{current_role}"', + "action": [f'Action::"{a}"' for a in current_actions], + "resource": 
f'{resource_category}::"{resource_name}"', + } + ) + + cedar_policy_text = self._yamlpolicy2text(policies) + return cedar_policy_text + + def _preprocess_request(self, user: str, action: str, resource: str, hook_type: str) -> CedarInput: + """Function to pre process request into a format that cedar accepts + Args: + user(str): name of the user + action(str): action requested by the user + resource(str): resource requested by the user + hook_type(str): the hook type on which invocation is made + + Returns: + request(CedarInput): pydantic representation of request as excpected by cedar policy + """ + user_role = "" + if hook_type in ["tool_post_invoke", "tool_pre_invoke"]: + resource_expr = CedarResourceTemplates.SERVER.format(resource_type=resource) + elif hook_type in ["agent_post_invoke", "agent_pre_invoke"]: + resource_expr = CedarResourceTemplates.AGENT.format(resource_type=resource) + elif hook_type in ["resource_post_fetch", "resource_pre_fetch"]: + resource_expr = CedarResourceTemplates.RESOURCE.format(resource_type=resource) + elif hook_type in ["prompt_post_fetch", "prompt_pre_fetch"]: + resource_expr = CedarResourceTemplates.PROMPT.format(resource_type=resource) + else: + logger.error(f"{CedarErrorCodes.UNSUPPORTED_RESOURCE_TYPE.value}: {hook_type}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSUPPORTED_RESOURCE_TYPE.value, plugin_name="CedarPolicyPlugin")) + + if len(self.jwt_info) > 0 and "users" in self.jwt_info: + user_role = self.jwt_info["users"].get(user) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_USER_ROLE.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_USER_ROLE.value, plugin_name="CedarPolicyPlugin")) + + principal_expr = f'Role::"{user_role}"' + action_expr = f'Action::"{action}"' + request = CedarInput(principal=principal_expr, action=action_expr, resource=resource_expr, context={}).model_dump() + return request + + def _redact_output(self, payload: str, pattern: str) -> str: + """Function that redacts the output of prompt, tool or resource + NOTE: It's an extremely simple logic for redaction, could be replaced with more advanced + as per need. + Args: + payload(str): payload or output + pattern(str): regex expression to replace + Returns: + redacted_text(str): redacted representation of payload string + """ + redacted_text = "" + if not pattern: + redacted_text = payload + elif pattern == "all": + redacted_text = "[REDACTED]" + else: + redacted_text = re.sub(pattern, "[REDACTED]", payload) + return redacted_text + + async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: + """The plugin hook run before a prompt is retrieved and rendered. + + Args: + payload: The prompt payload to be analyzed. + context: contextual information about the hook call. + + Returns: + The result of the plugin's analysis, including whether the prompt can proceed. 
+ """ + hook_type = "prompt_pre_fetch" + logger.info(f"Processing {hook_type} for '{payload.args}' with {len(payload.args) if payload.args else 0}") + logger.info(f"Processing context {context}") + + if not payload.args: + return PromptPrehookResult() + + policy = None + user = "" + result_full = None + result_redacted = None + + if self.cedar_config.policy_lang == "cedar": + if self.cedar_config.policy: + policy = self._yamlpolicy2text(self.cedar_config.policy) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_POLICY.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_POLICY.value, plugin_name="CedarPolicyPlugin")) + if self.cedar_config.policy_lang == "custom_dsl": + if self.cedar_config.policy: + policy = self._dsl2cedar(self.cedar_config.policy) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_POLICY.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_POLICY.value, plugin_name="CedarPolicyPlugin")) + + if context.global_context.user: + user = context.global_context.user + + if self.cedar_config.policy_output_keywords: + view_full = self.cedar_config.policy_output_keywords.get("view_full", None) + view_redacted = self.cedar_config.policy_output_keywords.get("view_redacted", None) + if not view_full and not view_redacted: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_OUTPUT_ACTION.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_OUTPUT_ACTION.value, plugin_name="CedarPolicyPlugin")) + if view_full and policy: + request = self._preprocess_request(user, view_full, payload.prompt_id, hook_type) + result_full = self._evaluate_policy(request, policy) + if view_redacted and policy: + request = self._preprocess_request(user, view_redacted, payload.prompt_id, hook_type) + result_redacted = self._evaluate_policy(request, policy) + + if result_full == Decision.Deny.value and result_redacted == Decision.Deny.value: + violation = PluginViolation( + reason=CedarResponseTemplates.CEDAR_REASON.format(hook_type=hook_type), + description=CedarResponseTemplates.CEDAR_DESC.format(hook_type=hook_type), + code=CedarCodes.DENIAL_CODE, + details={}, + ) + return PromptPrehookResult(modified_payload=payload, violation=violation, continue_processing=False) + return PromptPrehookResult(continue_processing=True) + + async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: + """Plugin hook run after a prompt is rendered. + + Args: + payload: The prompt payload to be analyzed. + context: Contextual information about the hook call. + + Returns: + The result of the plugin's analysis, including whether the prompt can proceed. 
+ """ + hook_type = "prompt_post_fetch" + logger.info(f"Processing {hook_type} for '{payload.result}'") + logger.info(f"Processing context {context}") + + if not payload.result: + return PromptPosthookResult() + + policy = None + user = "" + result_full = None + result_redacted = None + + if self.cedar_config.policy_lang == "cedar": + if self.cedar_config.policy: + policy = self._yamlpolicy2text(self.cedar_config.policy) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_POLICY.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_POLICY.value, plugin_name="CedarPolicyPlugin")) + if self.cedar_config.policy_lang == "custom_dsl": + if self.cedar_config.policy: + policy = self._dsl2cedar(self.cedar_config.policy) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_POLICY.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_POLICY.value, plugin_name="CedarPolicyPlugin")) + + if context.global_context.user: + user = context.global_context.user + + if self.cedar_config.policy_output_keywords: + view_full = self.cedar_config.policy_output_keywords.get("view_full", None) + view_redacted = self.cedar_config.policy_output_keywords.get("view_redacted", None) + if not view_full and not view_redacted: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_OUTPUT_ACTION.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_OUTPUT_ACTION.value, plugin_name="CedarPolicyPlugin")) + if view_full and policy: + request = self._preprocess_request(user, view_full, payload.prompt_id, hook_type) + result_full = self._evaluate_policy(request, policy) + if view_redacted and policy: + request = self._preprocess_request(user, view_redacted, payload.prompt_id, hook_type) + result_redacted = self._evaluate_policy(request, policy) + + if result_full == Decision.Allow.value: + return PromptPosthookResult(continue_processing=True) + + if result_redacted == Decision.Allow.value: + if payload.result.messages: + for index, message in enumerate(payload.result.messages): + value = self._redact_output(message.content.text, self.cedar_config.policy_redaction_spec.pattern) + payload.result.messages[index].content.text = value + return PromptPosthookResult(modified_payload=payload, continue_processing=True) + + violation = PluginViolation( + reason=CedarResponseTemplates.CEDAR_REASON.format(hook_type=hook_type), + description=CedarResponseTemplates.CEDAR_DESC.format(hook_type=hook_type), + code=CedarCodes.DENIAL_CODE, + details={}, + ) + return PromptPosthookResult(modified_payload=payload, violation=violation, continue_processing=False) + return PromptPosthookResult(continue_processing=True) + + async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: + """Plugin hook run before a tool is invoked. + + Args: + payload: The tool payload to be analyzed. + context: Contextual information about the hook call. + + Returns: + The result of the plugin's analysis, including whether the tool can proceed. 
+ """ + hook_type = "tool_pre_invoke" + logger.info(f"Processing {hook_type} for '{payload.args}' with {len(payload.args) if payload.args else 0}") + logger.info(f"Processing context {context}") + + if not payload.args: + return ToolPreInvokeResult() + + policy = None + user = "" + server_id = "" + + if self.cedar_config.policy_lang == "cedar": + if self.cedar_config.policy: + policy = self._yamlpolicy2text(self.cedar_config.policy) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_POLICY.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_POLICY.value, plugin_name="CedarPolicyPlugin")) + if self.cedar_config.policy_lang == "custom_dsl": + if self.cedar_config.policy: + policy = self._dsl2cedar(self.cedar_config.policy) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_POLICY.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_POLICY.value, plugin_name="CedarPolicyPlugin")) + + if context.global_context.user: + user = context.global_context.user + server_id = context.global_context.server_id + + if server_id: + request = self._preprocess_request(user, payload.name, server_id, hook_type) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_SERVER.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_SERVER.value, plugin_name="CedarPolicyPlugin")) + + if policy: + decision = self._evaluate_policy(request, policy) + if decision == Decision.Deny.value: + violation = PluginViolation( + reason=CedarResponseTemplates.CEDAR_REASON.format(hook_type=hook_type), + description=CedarResponseTemplates.CEDAR_DESC.format(hook_type=hook_type), + code=CedarCodes.DENIAL_CODE, + details={}, + ) + return ToolPreInvokeResult(modified_payload=payload, violation=violation, continue_processing=False) + return ToolPreInvokeResult(continue_processing=True) + + async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: + """Plugin hook run after a tool is invoked. + + Args: + payload: The tool result payload to be analyzed. + context: Contextual information about the hook call. + + Returns: + The result of the plugin's analysis, including whether the tool result should proceed. 
+ """ + + hook_type = "tool_post_invoke" + logger.info(f"Processing {hook_type} for '{payload.result}' with {len(payload.result) if payload.result else 0}") + logger.info(f"Processing context {context}") + + if not payload.result: + return ToolPostInvokeResult() + + policy = None + user = "" + server_id = "" + result_full = None + result_redacted = None + + if self.cedar_config.policy_lang == "cedar": + if self.cedar_config.policy: + policy = self._yamlpolicy2text(self.cedar_config.policy) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_POLICY.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_POLICY.value, plugin_name="CedarPolicyPlugin")) + if self.cedar_config.policy_lang == "custom_dsl": + if self.cedar_config.policy: + policy = self._dsl2cedar(self.cedar_config.policy) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_POLICY.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_POLICY.value, plugin_name="CedarPolicyPlugin")) + + if context.global_context.user: + user = context.global_context.user + server_id = context.global_context.server_id + + if self.cedar_config.policy_output_keywords: + view_full = self.cedar_config.policy_output_keywords.get("view_full", None) + view_redacted = self.cedar_config.policy_output_keywords.get("view_redacted", None) + if not view_full and not view_redacted: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_OUTPUT_ACTION.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_OUTPUT_ACTION.value, plugin_name="CedarPolicyPlugin")) + if view_full and policy: + request = self._preprocess_request(user, view_full, server_id, hook_type) + result_full = self._evaluate_policy(request, policy) + if view_redacted and policy: + request = self._preprocess_request(user, view_redacted, server_id, hook_type) + result_redacted = self._evaluate_policy(request, policy) + + # Evaluate Policy and based on that redact output + if policy: + request = self._preprocess_request(user, payload.name, server_id, hook_type) + result_action = self._evaluate_policy(request, policy) + # Check if full output view is allowed by policy + if result_action == Decision.Allow.value: + if result_full == Decision.Allow.value: + return ToolPostInvokeResult(continue_processing=True) + if result_redacted == Decision.Allow.value: + if payload.result and isinstance(payload.result, dict): + for key in payload.result: + if isinstance(payload.result[key], str): + value = self._redact_output(payload.result[key], self.cedar_config.policy_redaction_spec.pattern) + payload.result[key] = value + elif payload.result and isinstance(payload.result, str): + payload.result = self._redact_output(payload.result, self.cedar_config.policy_redaction_spec.pattern) + return ToolPostInvokeResult(continue_processing=True, modified_payload=payload) + # If none of the redacted or full output views are allowed by policy then deny + else: + violation = PluginViolation( + reason=CedarResponseTemplates.CEDAR_REASON.format(hook_type=hook_type), + description=CedarResponseTemplates.CEDAR_DESC.format(hook_type=hook_type), + code=CedarCodes.DENIAL_CODE, + details={}, + ) + return ToolPostInvokeResult(modified_payload=payload, violation=violation, continue_processing=False) + return ToolPostInvokeResult(continue_processing=True) + + async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: PluginContext) -> ResourcePreFetchResult: + """OPA Plugin hook that runs after resource pre fetch. 
This hook takes in payload and context and further evaluates rego + policies on the input by sending the request to opa server. + + Args: + payload: The resource pre fetch input or payload to be analyzed. + context: Contextual information about the hook call. + + Returns: + The result of the plugin's analysis, including whether the resource input can be passed further. + """ + + hook_type = "resource_pre_fetch" + logger.info(f"Processing {hook_type} for '{payload.uri}'") + logger.info(f"Processing context {context}") + + if not payload.uri: + return ResourcePreFetchResult() + + try: + parsed = urlparse(payload.uri) + except Exception as e: + violation = PluginViolation(reason="Invalid URI", description=f"Could not parse resource URI: {e}", code="INVALID_URI", details={"uri": payload.uri, "error": str(e)}) + return ResourcePreFetchResult(continue_processing=False, violation=violation) + + # Check if URI has a scheme + if not parsed.scheme: + violation = PluginViolation(reason="Invalid URI format", description="URI must have a valid scheme (protocol)", code="INVALID_URI", details={"uri": payload.uri}) + return ResourcePreFetchResult(continue_processing=False, violation=violation) + + policy = None + user = "" + result_full = None + result_redacted = None + + if self.cedar_config.policy_lang == "cedar": + if self.cedar_config.policy: + policy = self._yamlpolicy2text(self.cedar_config.policy) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_POLICY.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_POLICY.value, plugin_name="CedarPolicyPlugin")) + if self.cedar_config.policy_lang == "custom_dsl": + if self.cedar_config.policy: + policy = self._dsl2cedar(self.cedar_config.policy) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_POLICY.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_POLICY.value, plugin_name="CedarPolicyPlugin")) + + if context.global_context.user: + user = context.global_context.user + + if self.cedar_config.policy_output_keywords: + view_full = self.cedar_config.policy_output_keywords.get("view_full", None) + view_redacted = self.cedar_config.policy_output_keywords.get("view_redacted", None) + if not view_full and not view_redacted: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_OUTPUT_ACTION.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_OUTPUT_ACTION.value, plugin_name="CedarPolicyPlugin")) + if view_full and policy: + request = self._preprocess_request(user, view_full, payload.uri, hook_type) + result_full = self._evaluate_policy(request, policy) + if view_redacted and policy: + request = self._preprocess_request(user, view_redacted, payload.uri, hook_type) + result_redacted = self._evaluate_policy(request, policy) + + if result_full == Decision.Deny.value and result_redacted == Decision.Deny.value: + violation = PluginViolation( + reason=CedarResponseTemplates.CEDAR_REASON.format(hook_type=hook_type), + description=CedarResponseTemplates.CEDAR_DESC.format(hook_type=hook_type), + code=CedarCodes.DENIAL_CODE, + details={}, + ) + return ResourcePreFetchResult(modified_payload=payload, violation=violation, continue_processing=False) + return ResourcePreFetchResult(continue_processing=True) + + async def resource_post_fetch(self, payload: ResourcePostFetchPayload, context: PluginContext) -> ResourcePostFetchResult: + """OPA Plugin hook that runs after resource post fetch. 
This hook takes in payload and context and further evaluates rego + policies on the output by sending the request to opa server. + + Args: + payload: The resource post fetch output or payload to be analyzed. + context: Contextual information about the hook call. + + Returns: + The result of the plugin's analysis, including whether the resource output can be passed further. + """ + hook_type = "resource_post_fetch" + logger.info(f"Processing {hook_type} for '{payload.uri}'") + logger.info(f"Processing context {context}") + + policy = None + user = "" + result_full = None + result_redacted = None + + if self.cedar_config.policy_lang == "cedar": + if self.cedar_config.policy: + policy = self._yamlpolicy2text(self.cedar_config.policy) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_POLICY.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_POLICY.value, plugin_name="CedarPolicyPlugin")) + if self.cedar_config.policy_lang == "custom_dsl": + if self.cedar_config.policy: + policy = self._dsl2cedar(self.cedar_config.policy) + else: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_POLICY.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_POLICY.value, plugin_name="CedarPolicyPlugin")) + + if context.global_context.user: + user = context.global_context.user + + if self.cedar_config.policy_output_keywords: + view_full = self.cedar_config.policy_output_keywords.get("view_full", None) + view_redacted = self.cedar_config.policy_output_keywords.get("view_redacted", None) + if not view_full and not view_redacted: + logger.error(f"{CedarErrorCodes.UNSPECIFIED_OUTPUT_ACTION.value}") + raise PluginError(PluginErrorModel(message=CedarErrorCodes.UNSPECIFIED_OUTPUT_ACTION.value, plugin_name="CedarPolicyPlugin")) + if view_full and policy: + request = self._preprocess_request(user, view_full, payload.uri, hook_type) + result_full = self._evaluate_policy(request, policy) + if view_redacted and policy: + request = self._preprocess_request(user, view_redacted, payload.uri, hook_type) + result_redacted = self._evaluate_policy(request, policy) + + if result_full == Decision.Allow.value: + return ResourcePostFetchResult(continue_processing=True) + + if result_redacted == Decision.Allow.value: + if payload.content: + if hasattr(payload.content, "text"): + value = self._redact_output(payload.content.text, self.cedar_config.policy_redaction_spec.pattern) + payload.content.text = value + return ResourcePostFetchResult(modified_payload=payload, continue_processing=True) + + violation = PluginViolation( + reason=CedarResponseTemplates.CEDAR_REASON.format(hook_type=hook_type), + description=CedarResponseTemplates.CEDAR_DESC.format(hook_type=hook_type), + code=CedarCodes.DENIAL_CODE, + details={}, + ) + return ResourcePostFetchResult(modified_payload=payload, violation=violation, continue_processing=False) + return ResourcePostFetchResult(continue_processing=True) diff --git a/plugins/external/cedar/cedarpolicyplugin/schema.py b/plugins/external/cedar/cedarpolicyplugin/schema.py new file mode 100644 index 000000000..9274e7674 --- /dev/null +++ b/plugins/external/cedar/cedarpolicyplugin/schema.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +"""A schema file for OPA plugin. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Shriti Priya + +This module defines schema for Cedar plugin. 
+""" + +# Standard +from typing import Any, Optional, Union + +# Third-Party +from pydantic import BaseModel + + +class CedarInput(BaseModel): + """BaseOPAInputKeys + + Attributes: + user (str) : specifying the user + action (str): specifies the action + resource (str): specifies the resource + context (Optional[dict[str, Any]]) : context provided for policy evaluation. + """ + + principal: str = "" + action: str = "" + resource: str = "" + context: Optional[dict[Any, Any]] = None + + +class Redaction(BaseModel): + """Configuration for Redaction + + Attributes: + pattern (str) : pattern detected in output to redact + """ + + pattern: str = "" + + +class CedarConfig(BaseModel): + """Configuration for the Cedar plugin. + + Attributes: + policy_land (str) : cedar or custom_dsl. If policy is represented in cedar mode or custom_dsl mode + policy (Union[list, str]): RBAC policy defined + policy_output_keywords (dict): this is to internally check if certain type of views are allowed for outputs + policy_redaction_spec (Redaction) : pattern or other parameters provided to redact the output + """ + + policy_lang: str = "None" + policy: Union[list, str] = None + policy_output_keywords: Optional[dict] = None + policy_redaction_spec: Optional[Redaction] = None diff --git a/plugins/external/cedar/examples/config-dsl.yaml b/plugins/external/cedar/examples/config-dsl.yaml new file mode 100644 index 000000000..e1584c8af --- /dev/null +++ b/plugins/external/cedar/examples/config-dsl.yaml @@ -0,0 +1,43 @@ +plugins: + - name: "CedarPolicyPlugin" + kind: "cedarpolicyplugin.plugin.CedarPolicyPlugin" + description: "A plugin that does policy decision and enforcement using cedar" + version: "0.1.0" + author: "Shriti Priya" + hooks: ["prompt_pre_fetch", "prompt_post_fetch", "tool_pre_invoke", "tool_post_invoke"] + tags: ["plugin"] + mode: "enforce" # enforce | permissive | disabled + priority: 150 + conditions: + # Apply to specific tools/servers + - server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + policy_lang: custom_dsl + policy_output_keywords: + view_full: "view_full_output" + view_redacted: "view_redacted_output" + policy_redaction_spec: + pattern: '"\$\d{1,}(,\d{1,})*"' # provide regex, if none, then replace all + policy: | + [role:hr:server/hr_tool] + update_payroll + + [role:admin:resource/example.com/data] + view_full_output + + [role:admin:prompt/judge_prompts] + view_full_output + + +# Plugin directories to scan +plugin_dirs: + - "cedarpolicyplugin" + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: false + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/plugins/external/cedar/pyproject.toml b/plugins/external/cedar/pyproject.toml new file mode 100644 index 000000000..334583a7a --- /dev/null +++ b/plugins/external/cedar/pyproject.toml @@ -0,0 +1,99 @@ +# ---------------------------------------------------------------- +# 💡 Build system (PEP 517) +# - setuptools ≥ 77 gives SPDX licence support (PEP 639) +# - wheel is needed by most build front-ends +# ---------------------------------------------------------------- +[build-system] +requires = ["setuptools>=77", "wheel"] +build-backend = "setuptools.build_meta" + +# ---------------------------------------------------------------- +# 📦 Core project metadata (PEP 621) +# ---------------------------------------------------------------- +[project] +name = "cedarpolicyplugin" +version = "0.1.0" +description = "A 
plugin that does policy decision and enforcement using cedar" +keywords = ["MCP","API","gateway","tools", + "agents","agentic ai","model context protocol","multi-agent","fastapi", + "json-rpc","sse","websocket","federation","security","authentication" +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Framework :: FastAPI", + "Framework :: AsyncIO", + "Topic :: Internet :: WWW/HTTP :: WSGI :: Application", + "Topic :: Software Development :: Libraries :: Application Frameworks" +] +readme = "README.md" +requires-python = ">=3.11,<3.14" +license = "Apache-2.0" +license-files = ["LICENSE"] + +maintainers = [ + {name = "Shriti Priya", email = "shritip@ibm.com"} +] + +authors = [ + {name = "Shriti Priya", email = "shritip@ibm.com"} +] + +dependencies = [ + "mcp>=1.16.0", + "mcp-contextforge-gateway", + "cedarpy>=4.1.0" +] + +# URLs +[project.urls] +Homepage = "https://ibm.github.io/mcp-context-forge/" +Documentation = "https://ibm.github.io/mcp-context-forge/" +Repository = "https://github.com/IBM/mcp-context-forge" +"Bug Tracker" = "https://github.com/IBM/mcp-context-forge/issues" +Changelog = "https://github.com/IBM/mcp-context-forge/blob/main/CHANGELOG.md" + +[tool.uv.sources] +mcp-contextforge-gateway = { git = "https://github.com/IBM/mcp-context-forge.git", rev = "main" } + +# ---------------------------------------------------------------- +# Optional dependency groups (extras) +# ---------------------------------------------------------------- +[project.optional-dependencies] +dev = [ + "black>=25.1.0", + "pytest>=8.4.1", + "pytest-asyncio>=1.1.0", + "pytest-cov>=6.2.1", + "pytest-dotenv>=0.5.2", + "pytest-env>=1.1.5", + "pytest-examples>=0.0.18", + "pytest-md-report>=0.7.0", + "pytest-rerunfailures>=15.1", + "pytest-trio>=0.8.0", + "pytest-xdist>=3.8.0", + "ruff>=0.12.9", + "unimport>=1.2.1", + "uv>=0.8.11", +] + +# -------------------------------------------------------------------- +# 🔧 setuptools-specific configuration +# -------------------------------------------------------------------- +[tool.setuptools] +include-package-data = true # ensure wheels include the data files + +# Automatic discovery: keep every package that starts with "cedarpolicyplugin" +[tool.setuptools.packages.find] +include = ["cedarpolicyplugin*"] +exclude = ["tests*"] + +## Runtime data files ------------------------------------------------ +[tool.setuptools.package-data] +cedarpolicyplugin = [ + "resources/plugins/config.yaml", +] diff --git a/plugins/external/cedar/resources/plugins/config.yaml b/plugins/external/cedar/resources/plugins/config.yaml new file mode 100644 index 000000000..23d048311 --- /dev/null +++ b/plugins/external/cedar/resources/plugins/config.yaml @@ -0,0 +1,102 @@ +plugins: + - name: "CedarPolicyPlugin" + kind: "cedarpolicyplugin.plugin.CedarPolicyPlugin" + description: "A plugin that does policy decision and enforcement using cedar" + version: "0.1.0" + author: "Shriti Priya" + hooks: ["prompt_pre_fetch", "prompt_post_fetch", "tool_pre_invoke", "tool_post_invoke"] + tags: ["plugin"] + mode: "enforce" # enforce | permissive | disabled + priority: 150 + conditions: + # Apply to specific tools/servers + - server_ids: [] # Apply to all servers + tenant_ids: [] # Apply to all tenants + config: + policy_lang: cedar + policy_output_keywords: + view_full: "view_full_output" + 
view_redacted: "view_redacted_output" + policy_redaction_spec: + pattern: '"\$\d{1,}(,\d{1,})*"' # provide regex, if none, then replace all + policy: + - id: allow-employee-basic-access + effect: Permit + principal: Role::"employee" + action: + - Action::"get_leave_balance" #tool name + - Action::"request_certificate" + resource: + - Server::"askHR" # mcp-server name + - Agent::"employee_agent" # agent name + + - id: allow-manager-full-access + effect: Permit + principal: Role::"manager" + action: + - Action::"get_leave_balance" + - Action::"approve_leave" + - Action::"promote_employee" + - Action::"view_performance" + - Action::"view_full_output" + resource: + - Agent::"manager_agent" + - Server::"payroll_tool" + + - id: allow-hr-hr_tool + effect: Permit + principal: Role::"hr" + action: + - Action::"update_payroll" + - Action::"view_performance" + - Action::"view_full_output" + resource: Server::"hr_tool" + + - id: redact-non-manager-views + effect: Permit + principal: Role::"employee" + action: Action::"view_redacted_output" + resource: + - Server::"payroll_tool" + - Agent::"manager_agent" + - Server::"askHR" + + - id: allow-admin-resources # policy for resources + effect: Permit + principal: Role::"admin" + action: + - Action::"view_full_output" + resource: Resource::""https://example.com/data"" #Resource:: + + - id: allow-employee-redacted-resources # policy for resources + effect: Permit + principal: Role::"employee" + action: + - Action::"view_redacted_output" + resource: Resource::""https://example.com/data"" #Resource:: + + - id: allow-admin-prompts # policy for resources + effect: Permit + principal: Role::"admin" + action: + - Action::"view_full_output" + resource: Prompts::"judge_prompts" #Prompt:: + + - id: allow-employee-redacted-prompts # policy for resources + effect: Permit + principal: Role::"employee" + action: + - Action::"view_redacted_output" + resource: Prompts::"judge_prompts" #Prompt:: + +# Plugin directories to scan +plugin_dirs: + - "cedarpolicyplugin" + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: false + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/plugins/external/cedar/resources/runtime/config.yaml b/plugins/external/cedar/resources/runtime/config.yaml new file mode 100644 index 000000000..5b26791f5 --- /dev/null +++ b/plugins/external/cedar/resources/runtime/config.yaml @@ -0,0 +1,71 @@ +# config.yaml +host: + name: "cedarpolicyplugin" + log_level: "INFO" + +server: + type: "streamable-http" # "stdio" or "sse" or "streamable-http" + #auth: "bearer" # this line is needed to enable bearer auth + +# Logging configuration - controls all logging behavior +logging: + level: "WARNING" # Changed from INFO to WARNING for quieter default + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + reset_handlers: true + quiet_libraries: true + + # Specific logger overrides to silence noisy components + loggers: + # Your existing overrides + "chuk_mcp_runtime.proxy": "WARNING" + "chuk_mcp_runtime.proxy.manager": "WARNING" + "chuk_mcp_runtime.proxy.tool_wrapper": "WARNING" + "chuk_tool_processor.mcp.stream_manager": "WARNING" + "chuk_tool_processor.mcp.register": "WARNING" + "chuk_tool_processor.mcp.setup_stdio": "WARNING" + "chuk_mcp_runtime.common.tool_naming": "WARNING" + "chuk_mcp_runtime.common.openai_compatibility": "WARNING" + + # NEW: Add the noisy loggers you're seeing + "chuk_sessions.session_manager": "ERROR" + "chuk_mcp_runtime.session.native": 
"ERROR" + "chuk_mcp_runtime.tools.artifacts": "ERROR" + "chuk_mcp_runtime.tools.session": "ERROR" + "chuk_artifacts.store": "ERROR" + "chuk_mcp_runtime.entry": "WARNING" # Keep some info but less chatty + "chuk_mcp_runtime.server": "WARNING" # Server start/stop messages + +# optional overrides +sse: + host: "0.0.0.0" + port: 8000 + sse_path: "/sse" + message_path: "/messages/" + health_path: "/health" + log_level: "info" + access_log: true + +streamable-http: + host: "0.0.0.0" + port: 8000 + mcp_path: "/mcp" + stateless: true + json_response: true + health_path: "/health" + log_level: "info" + access_log: true + +proxy: + enabled: false + namespace: "proxy" + openai_compatible: false # ← set to true if you want underscores + +# Session tools (disabled by default - must enable explicitly) +session_tools: + enabled: false # Must explicitly enable + +# Artifact storage (disabled by default - must enable explicitly) +artifacts: + enabled: false # Must explicitly enable + storage_provider: "filesystem" + session_provider: "memory" diff --git a/plugins/external/cedar/run-server.sh b/plugins/external/cedar/run-server.sh new file mode 100755 index 000000000..d73f57de5 --- /dev/null +++ b/plugins/external/cedar/run-server.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +#─────────────────────────────────────────────────────────────────────────────── +# Script : run-server.sh +# Purpose: Launch the MCP Gateway's Plugin API +# +# Description: +# This script launches an API server using +# chuck runtime. +# +# Environment Variables: +# API_SERVER_SCRIPT : Path to the server script (optional, auto-detected) +# PLUGINS_CONFIG_PATH : Path to the plugin config (optional, default: ./resources/plugins/config.yaml) +# CHUK_MCP_CONFIG_PATH : Path to the chuck-mcp-runtime config (optional, default: ./resources/runtime/config.yaml) +# +# Usage: +# ./run-server.sh # Run server +#─────────────────────────────────────────────────────────────────────────────── + +# Exit immediately on error, undefined variable, or pipe failure +set -euo pipefail + +#──────────────────────────────────────────────────────────────────────────────── +# SECTION 1: Script Location Detection +# Determine the absolute path of the API server script +#──────────────────────────────────────────────────────────────────────────────── +if [[ -z "${API_SERVER_SCRIPT:-}" ]]; then + API_SERVER_SCRIPT="$(python -c 'import mcpgateway.plugins.framework.external.mcp.server.runtime as server; print(server.__file__)')" + echo "✓ API server script path auto-detected: ${API_SERVER_SCRIPT}" +else + echo "✓ Using provided API server script path: ${API_SERVER_SCRIPT}" +fi + +#──────────────────────────────────────────────────────────────────────────────── +# SECTION 2: Run the API server +# Run the API server from configuration +#──────────────────────────────────────────────────────────────────────────────── + +PLUGINS_CONFIG_PATH=${PLUGINS_CONFIG_PATH:-./resources/plugins/config.yaml} +CHUK_MCP_CONFIG_PATH=${CHUK_MCP_CONFIG_PATH:-./resources/runtime/config.yaml} + +echo "✓ Using plugin config from: ${PLUGINS_CONFIG_PATH}" +echo "✓ Running API server with config from: ${CHUK_MCP_CONFIG_PATH}" +python ${API_SERVER_SCRIPT} diff --git a/plugins/external/cedar/tests/__init__.py b/plugins/external/cedar/tests/__init__.py new file mode 100644 index 000000000..2e033f69b --- /dev/null +++ b/plugins/external/cedar/tests/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Shriti 
Priya +""" diff --git a/plugins/external/cedar/tests/pytest.ini b/plugins/external/cedar/tests/pytest.ini new file mode 100644 index 000000000..ff60648e6 --- /dev/null +++ b/plugins/external/cedar/tests/pytest.ini @@ -0,0 +1,13 @@ +[pytest] +log_cli = false +log_cli_level = INFO +log_cli_format = %(asctime)s [%(module)s] [%(levelname)s] %(message)s +log_cli_date_format = %Y-%m-%d %H:%M:%S +log_level = INFO +log_format = %(asctime)s [%(module)s] [%(levelname)s] %(message)s +log_date_format = %Y-%m-%d %H:%M:%S +addopts = --cov --cov-report term-missing +env_files = .env +pythonpath = . src +filterwarnings = + ignore::DeprecationWarning:pydantic.* diff --git a/plugins/external/cedar/tests/test_cedarpolicyplugin.py b/plugins/external/cedar/tests/test_cedarpolicyplugin.py new file mode 100644 index 000000000..52e897408 --- /dev/null +++ b/plugins/external/cedar/tests/test_cedarpolicyplugin.py @@ -0,0 +1,642 @@ +# -*- coding: utf-8 -*- +"""Tests for plugin.""" + +# Third-Party +from cedarpolicyplugin.plugin import CedarPolicyPlugin +import pytest + +# First-Party +from mcpgateway.common.models import Message, PromptResult, ResourceContent, Role, TextContent +from mcpgateway.plugins.framework.hooks.prompts import PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework.hooks.resources import ResourcePostFetchPayload, ResourcePreFetchPayload +from mcpgateway.plugins.framework.hooks.tools import ToolPostInvokePayload, ToolPreInvokePayload +from mcpgateway.plugins.framework.models import ( + GlobalContext, + PluginConfig, + PluginContext, +) + + +# This test case is responsible for verifying cedarplugin functionality for post tool hooks in cdear native mode +@pytest.mark.asyncio +async def test_cedarpolicyplugin_post_tool_invoke_rbac(): + """Test plugin for post tool invocation""" + policy_config = [ + { + "id": "allow-employee-basic-access", + "effect": "Permit", + "principal": 'Role::"employee"', + "action": ['Action::"get_leave_balance"', 'Action::"request_certificate"'], + "resource": ['Server::"askHR"', 'Agent::"employee_agent"'], + }, + { + "id": "allow-manager-full-access", + "effect": "Permit", + "principal": 'Role::"manager"', + "action": ['Action::"get_leave_balance"', 'Action::"approve_leave"', 'Action::"promote_employee"', 'Action::"view_performance"', 'Action::"view_full_output"'], + "resource": ['Agent::"manager_agent"', 'Server::"payroll_tool"'], + }, + { + "id": "allow-hr-hr_tool", + "effect": "Permit", + "principal": 'Role::"hr"', + "action": ['Action::"update_payroll"', 'Action::"view_performance"', 'Action::"view_full_output"'], + "resource": ['Server::"hr_tool"'], + }, + { + "id": "redact-non-manager-views", + "effect": "Permit", + "principal": 'Role::"employee"', + "action": ['Action::"view_redacted_output"'], + "resource": ['Server::"payroll_tool"', 'Agent::"manager_agent"', 'Server::"askHR"'], + }, + ] + + policy_output_keywords = {"view_full": "view_full_output", "view_redacted": "view_redacted_output"} + policy_redaction_spec = {"pattern": r"\$\d{1,}(,\d{1,})*"} + config = PluginConfig( + name="test", + kind="cedarpolicyplugin.CedarPolicyPlugin", + hooks=["tool_pre_invoke"], + config={"policy_lang": "cedar", "policy": policy_config, "policy_output_keywords": policy_output_keywords, "policy_redaction_spec": policy_redaction_spec}, + ) + plugin = CedarPolicyPlugin(config) + info = {"alice": "employee", "bob": "manager", "carol": "hr", "robert": "admin"} + plugin._set_jwt_info(info) + requests = [ + {"user": "alice", "action": "get_leave_balance", 
"resource": "askHR"}, + {"user": "bob", "action": "view_performance", "resource": "payroll_tool"}, + {"user": "carol", "action": "update_payroll", "resource": "hr_tool"}, + {"user": "alice", "action": "update_payroll", "resource": "hr_tool"}, + ] + + redact_count = 0 + allow_count = 0 + deny_count = 0 + for req in requests: + payload = ToolPostInvokePayload(name=req["action"], result={"text": "Alice has a salary of $250,000"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id=req["resource"], user=req["user"])) + result = await plugin.tool_post_invoke(payload, context) + if result.modified_payload and "[REDACTED]" in result.modified_payload.result["text"]: + redact_count += 1 + if result.continue_processing: + allow_count += 1 + if not result.continue_processing: + deny_count += 1 + + assert redact_count == 1 + assert allow_count == 3 + assert deny_count == 1 + + +# This test case is responsible for verifying cedarplugin functionality for post tool invocation with policy in custom dsl mode +@pytest.mark.asyncio +async def test_cedarpolicyplugin_post_tool_invoke_custom_dsl_rbac(): + """Test plugin for post tool invocation""" + policy_config = "[role:employee:server/askHR]\nget_leave_balance\nrequest_certificate\n\n\ + [role:employee:agent/employee_agent]\nget_leave_balance\nrequest_certificate\n\n[role:manager:agent/manager_agent]\nget_leave_balance\napprove_leave\npromote_employee\nview_performance\nview_full_output\n\n[role:manager:server/payroll_tool]\ + \nget_leave_balance\napprove_leave\npromote_employee\nview_performance\nview_full_output\n\n[role:hr:server/hr_tool]\nupdate_payroll\nview_performance\nview_full_output\n\n[role:employee:server/payroll_tool]\nview_redacted_output\n\n[role:employee:agent/manager_agent]\nview_redacted_output\n\n\ + [role:employee:server/askHR]\nview_redacted_output" + + policy_output_keywords = {"view_full": "view_full_output", "view_redacted": "view_redacted_output"} + policy_redaction_spec = {"pattern": r"\$\d{1,}(,\d{1,})*"} + config = PluginConfig( + name="test", + kind="cedarpolicyplugin.CedarPolicyPlugin", + hooks=["tool_pre_invoke"], + config={"policy_lang": "custom_dsl", "policy": policy_config, "policy_output_keywords": policy_output_keywords, "policy_redaction_spec": policy_redaction_spec}, + ) + plugin = CedarPolicyPlugin(config) + info = {"alice": "employee", "bob": "manager", "carol": "hr", "robert": "admin"} + plugin._set_jwt_info(info) + requests = [ + {"user": "alice", "action": "get_leave_balance", "resource": "askHR"}, + {"user": "bob", "action": "view_performance", "resource": "payroll_tool"}, + {"user": "carol", "action": "update_payroll", "resource": "hr_tool"}, + {"user": "alice", "action": "update_payroll", "resource": "hr_tool"}, + ] + + redact_count = 0 + allow_count = 0 + deny_count = 0 + for req in requests: + payload = ToolPostInvokePayload(name=req["action"], result={"text": "Alice has a salary of $250,000"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id=req["resource"], user=req["user"])) + result = await plugin.tool_post_invoke(payload, context) + if result.modified_payload and "[REDACTED]" in result.modified_payload.result["text"]: + redact_count += 1 + if result.continue_processing: + allow_count += 1 + if not result.continue_processing: + deny_count += 1 + + assert redact_count == 1 + assert allow_count == 3 + assert deny_count == 1 + + +# This test case is responsible for verifying cedarplugin functionality for tool pre invoke in cedar native mode 
+@pytest.mark.asyncio +async def test_cedarpolicyplugin_pre_tool_invoke_cedar_rbac(): + """Test plugin tool pre invoke hook.""" + policy_config = [ + { + "id": "allow-employee-basic-access", + "effect": "Permit", + "principal": 'Role::"employee"', + "action": ['Action::"get_leave_balance"', 'Action::"request_certificate"'], + "resource": ['Server::"askHR"', 'Agent::"employee_agent"'], + }, + { + "id": "allow-manager-full-access", + "effect": "Permit", + "principal": 'Role::"manager"', + "action": ['Action::"get_leave_balance"', 'Action::"approve_leave"', 'Action::"promote_employee"', 'Action::"view_performance"', 'Action::"view_full_output"'], + "resource": ['Agent::"manager_agent"', 'Server::"payroll_tool"'], + }, + { + "id": "allow-hr-hr_tool", + "effect": "Permit", + "principal": 'Role::"hr"', + "action": ['Action::"update_payroll"', 'Action::"view_performance"', 'Action::"view_full_output"'], + "resource": ['Server::"hr_tool"'], + }, + { + "id": "redact-non-manager-views", + "effect": "Permit", + "principal": 'Role::"employee"', + "action": ['Action::"view_redacted_output"'], + "resource": ['Server::"payroll_tool"', 'Agent::"manager_agent"', 'Server::"askHR"'], + }, + ] + + policy_output_keywords = {"view_full": "view_full_output", "view_redacted": "view_redacted_output"} + policy_redaction_spec = {"pattern": r"\$\d{1,}(,\d{1,})*"} + config = PluginConfig( + name="test", + kind="cedarpolicyplugin.CedarPolicyPlugin", + hooks=["tool_pre_invoke"], + config={"policy_lang": "cedar", "policy": policy_config, "policy_output_keywords": policy_output_keywords, "policy_redaction_spec": policy_redaction_spec}, + ) + plugin = CedarPolicyPlugin(config) + info = {"alice": "employee", "bob": "manager", "carol": "hr", "robert": "admin"} + plugin._set_jwt_info(info) + requests = [ + {"user": "alice", "action": "get_leave_balance", "resource": "askHR"}, + {"user": "bob", "action": "view_performance", "resource": "payroll_tool"}, + {"user": "carol", "action": "update_payroll", "resource": "hr_tool"}, + {"user": "alice", "action": "update_payroll", "resource": "hr_tool"}, + ] + + allow_count = 0 + deny_count = 0 + for req in requests: + payload = ToolPreInvokePayload(name=req["action"], args={"arg1": "sample arg"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id=req["resource"], user=req["user"])) + result = await plugin.tool_pre_invoke(payload, context) + if result.continue_processing: + allow_count += 1 + if not result.continue_processing: + deny_count += 1 + + assert allow_count == 3 + assert deny_count == 1 + + +# This test case is responsible for verifying cedarplugin functionality for tool pre invoke in custom dsl mode +@pytest.mark.asyncio +async def test_cedarpolicyplugin_pre_tool_invoke_custom_dsl_rbac(): + """Test plugin tool pre invoke.""" + policy_config = "[role:employee:server/askHR]\nget_leave_balance\nrequest_certificate\n\n[role:employee:agent/employee_agent]\n\ + get_leave_balance\nrequest_certificate\n\n[role:manager:agent/manager_agent]\nget_leave_balance\napprove_leave\npromote_employee\n\ + view_performance\nview_full_output\n\n[role:manager:server/payroll_tool]\nget_leave_balance\napprove_leave\npromote_employee\nview_performance\n\ + view_full_output\n\n[role:hr:server/hr_tool]\nupdate_payroll\nview_performance\nview_full_output\n\n[role:employee:server/payroll_tool]\n\ + view_redacted_output\n\n[role:employee:agent/manager_agent]\nview_redacted_output\n\n[role:employee:server/askHR]\nview_redacted_output" + policy_output_keywords = {"view_full": 
"view_full_output", "view_redacted": "view_redacted_output"} + policy_redaction_spec = {"pattern": r"\$\d{1,}(,\d{1,})*"} + config = PluginConfig( + name="test", + kind="cedarpolicyplugin.CedarPolicyPlugin", + hooks=["tool_pre_invoke"], + config={"policy_lang": "custom_dsl", "policy": policy_config, "policy_output_keywords": policy_output_keywords, "policy_redaction_spec": policy_redaction_spec}, + ) + plugin = CedarPolicyPlugin(config) + info = {"alice": "employee", "bob": "manager", "carol": "hr", "robert": "admin"} + plugin._set_jwt_info(info) + requests = [ + {"user": "alice", "action": "get_leave_balance", "resource": "askHR"}, + {"user": "bob", "action": "view_performance", "resource": "payroll_tool"}, + {"user": "carol", "action": "update_payroll", "resource": "hr_tool"}, + {"user": "alice", "action": "update_payroll", "resource": "hr_tool"}, + ] + + allow_count = 0 + deny_count = 0 + for req in requests: + payload = ToolPreInvokePayload(name=req["action"], args={"arg1": "sample arg"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id=req["resource"], user=req["user"])) + result = await plugin.tool_pre_invoke(payload, context) + if result.continue_processing: + allow_count += 1 + if not result.continue_processing: + deny_count += 1 + + assert allow_count == 3 + assert deny_count == 1 + + +# This test case is responsible for verifying cedarplugin functionality for prompt pre fetch in cedar mode +@pytest.mark.asyncio +async def test_cedarpolicyplugin_prompt_pre_fetch_rbac(): + """Test plugin prompt prefetch hook.""" + policy_config = [ + {"id": "redact-non-admin-views", "effect": "Permit", "principal": 'Role::"employee"', "action": ['Action::"view_redacted_output"'], "resource": 'Prompt::"judge_prompts"'}, + { + "id": "allow-admin-prompts", # policy for resources + "effect": "Permit", + "principal": 'Role::"admin"', + "action": ['Action::"view_full_output"'], + "resource": 'Prompt::"judge_prompts"', # Prompt:: + }, + ] + + policy_output_keywords = {"view_full": "view_full_output", "view_redacted": "view_redacted_output"} + policy_redaction_spec = {"pattern": "all"} + config = PluginConfig( + name="test", + kind="cedarpolicyplugin.CedarPolicyPlugin", + hooks=["tool_pre_invoke"], + config={"policy_lang": "cedar", "policy": policy_config, "policy_output_keywords": policy_output_keywords, "policy_redaction_spec": policy_redaction_spec}, + ) + plugin = CedarPolicyPlugin(config) + info = {"alice": "employee", "bob": "manager", "carol": "hr", "robert": "admin"} + plugin._set_jwt_info(info) + requests = [ + {"user": "alice", "resource": "judge_prompts"}, # allow + {"user": "robert", "resource": "judge_prompts"}, # allow + {"user": "carol", "resource": "judge_prompts"}, # deny + ] + + allow_count = 0 + deny_count = 0 + + for req in requests: + + # Prompt pre hook input + payload = PromptPrehookPayload(prompt_id=req["resource"], args={"text": "You are curseword"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2", user=req["user"])) + result = await plugin.prompt_pre_fetch(payload, context) + if result.continue_processing: + allow_count += 1 + if not result.continue_processing: + deny_count += 1 + + assert allow_count == 2 + assert deny_count == 1 + + +# This test case is responsible for verifying cedarplugin functionality for prompt pre fetch in custom dsl mode +@pytest.mark.asyncio +async def test_cedarpolicyplugin_prompt_pre_fetch_custom_dsl_rbac(): + """Test plugin prompt prefetch hook.""" + policy_config = 
"[role:employee:prompt/judge_prompts]\nview_redacted_output\n\n[role:admin:prompt/judge_prompts]\nview_full_output" + + policy_output_keywords = {"view_full": "view_full_output", "view_redacted": "view_redacted_output"} + policy_redaction_spec = {"pattern": "all"} + config = PluginConfig( + name="test", + kind="cedarpolicyplugin.CedarPolicyPlugin", + hooks=["tool_pre_invoke"], + config={"policy_lang": "custom_dsl", "policy": policy_config, "policy_output_keywords": policy_output_keywords, "policy_redaction_spec": policy_redaction_spec}, + ) + plugin = CedarPolicyPlugin(config) + info = {"alice": "employee", "bob": "manager", "carol": "hr", "robert": "admin"} + plugin._set_jwt_info(info) + requests = [ + {"user": "alice", "resource": "judge_prompts"}, # allow + {"user": "robert", "resource": "judge_prompts"}, # allow + {"user": "carol", "resource": "judge_prompts"}, # deny + ] + + allow_count = 0 + deny_count = 0 + + for req in requests: + + # Prompt pre hook input + payload = PromptPrehookPayload(prompt_id=req["resource"], args={"text": "You are curseword"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2", user=req["user"])) + result = await plugin.prompt_pre_fetch(payload, context) + if result.continue_processing: + allow_count += 1 + if not result.continue_processing: + deny_count += 1 + + assert allow_count == 2 + assert deny_count == 1 + + +# This test case is responsible for verifying cedarplugin functionality for prompt post fetch in cedar native mode +@pytest.mark.asyncio +async def test_cedarpolicyplugin_prompt_post_fetch_cedar_rbac(): + """Test plugin prompt postfetch hook.""" + policy_config = [ + {"id": "redact-non-admin-views", "effect": "Permit", "principal": 'Role::"employee"', "action": ['Action::"view_redacted_output"'], "resource": 'Prompt::"judge_prompts"'}, + { + "id": "allow-admin-prompts", # policy for resources + "effect": "Permit", + "principal": 'Role::"admin"', + "action": ['Action::"view_full_output"'], + "resource": 'Prompt::"judge_prompts"', # Prompt:: + }, + ] + + policy_output_keywords = {"view_full": "view_full_output", "view_redacted": "view_redacted_output"} + policy_redaction_spec = {"pattern": "all"} + config = PluginConfig( + name="test", + kind="cedarpolicyplugin.CedarPolicyPlugin", + hooks=["tool_pre_invoke"], + config={"policy_lang": "cedar", "policy": policy_config, "policy_output_keywords": policy_output_keywords, "policy_redaction_spec": policy_redaction_spec}, + ) + plugin = CedarPolicyPlugin(config) + info = {"alice": "employee", "bob": "manager", "carol": "hr", "robert": "admin"} + plugin._set_jwt_info(info) + requests = [ + {"user": "alice", "resource": "judge_prompts"}, # allow + {"user": "robert", "resource": "judge_prompts"}, # allow + {"user": "carol", "resource": "judge_prompts"}, # deny + ] + + allow_count = 0 + deny_count = 0 + redact_count = 0 + + for req in requests: + + # Prompt post hook output + message = Message(content=TextContent(type="text", text="abc"), role=Role.USER) + prompt_result = PromptResult(messages=[message]) + payload = PromptPosthookPayload(prompt_id=req["resource"], result=prompt_result) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2", user=req["user"])) + result = await plugin.prompt_post_fetch(payload, context) + if result.continue_processing: + allow_count += 1 + if result.modified_payload and "[REDACTED]" in result.modified_payload.result.messages[0].content.text: + redact_count += 1 + if not result.continue_processing: + deny_count += 1 
+ + assert allow_count == 2 + assert deny_count == 1 + assert redact_count == 1 + + +# This test case is responsible for verifying cedarplugin functionality for prompt post fetch in custom dsl mode +@pytest.mark.asyncio +async def test_cedarpolicyplugin_prompt_post_fetch_custom_dsl_rbac(): + """Test plugin prompt postfetch hook.""" + policy_config = "[role:employee:prompt/judge_prompts]\nview_redacted_output\n\n[role:admin:prompt/judge_prompts]\nview_full_output" + + policy_output_keywords = {"view_full": "view_full_output", "view_redacted": "view_redacted_output"} + policy_redaction_spec = {"pattern": "all"} + config = PluginConfig( + name="test", + kind="cedarpolicyplugin.CedarPolicyPlugin", + hooks=["tool_pre_invoke"], + config={"policy_lang": "custom_dsl", "policy": policy_config, "policy_output_keywords": policy_output_keywords, "policy_redaction_spec": policy_redaction_spec}, + ) + plugin = CedarPolicyPlugin(config) + info = {"alice": "employee", "bob": "manager", "carol": "hr", "robert": "admin"} + plugin._set_jwt_info(info) + requests = [ + {"user": "alice", "resource": "judge_prompts"}, # allow + {"user": "robert", "resource": "judge_prompts"}, # allow + {"user": "carol", "resource": "judge_prompts"}, # deny + ] + + allow_count = 0 + deny_count = 0 + redact_count = 0 + + for req in requests: + + # Prompt post hook output + message = Message(content=TextContent(type="text", text="abc"), role=Role.USER) + prompt_result = PromptResult(messages=[message]) + payload = PromptPosthookPayload(prompt_id=req["resource"], result=prompt_result) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2", user=req["user"])) + result = await plugin.prompt_post_fetch(payload, context) + if result.continue_processing: + allow_count += 1 + if result.modified_payload and "[REDACTED]" in result.modified_payload.result.messages[0].content.text: + redact_count += 1 + if not result.continue_processing: + deny_count += 1 + + assert allow_count == 2 + assert deny_count == 1 + assert redact_count == 1 + + +# This test case is responsible for verifying cedarplugin functionality for resource pre fetch in cedar native mode +@pytest.mark.asyncio +async def test_cedarpolicyplugin_resource_pre_fetch_cedar_rbac(): + """Test plugin resource prefetch hook.""" + policy_config = [ + { + "id": "redact-non-admin-resource-views", + "effect": "Permit", + "principal": 'Role::"employee"', + "action": ['Action::"view_redacted_output"'], + "resource": 'Resource::"https://example.com/data"', + }, + { + "id": "allow-admin-resources", # policy for resources + "effect": "Permit", + "principal": 'Role::"admin"', + "action": ['Action::"view_full_output"'], + "resource": 'Resource::"https://example.com/data"', + }, + ] + + policy_output_keywords = {"view_full": "view_full_output", "view_redacted": "view_redacted_output"} + policy_redaction_spec = {"pattern": "[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}"} + config = PluginConfig( + name="test", + kind="cedarpolicyplugin.CedarPolicyPlugin", + hooks=["tool_pre_invoke"], + config={"policy_lang": "cedar", "policy": policy_config, "policy_output_keywords": policy_output_keywords, "policy_redaction_spec": policy_redaction_spec}, + ) + plugin = CedarPolicyPlugin(config) + info = {"alice": "employee", "bob": "manager", "carol": "hr", "robert": "admin"} + plugin._set_jwt_info(info) + requests = [ + {"user": "alice", "resource": "https://example.com/data"}, # allow + {"user": "robert", "resource": "https://example.com/data"}, # allow + {"user": "carol", 
"resource": "https://example.com/data"}, # deny + ] + + allow_count = 0 + deny_count = 0 + + for req in requests: + + # Prompt post hook output + payload = ResourcePreFetchPayload(uri="https://example.com/data", metadata={}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2", user=req["user"])) + result = await plugin.resource_pre_fetch(payload, context) + if result.continue_processing: + allow_count += 1 + if not result.continue_processing: + deny_count += 1 + + assert allow_count == 2 + assert deny_count == 1 + + +# This test case is responsible for verifying cedarplugin functionality for resource pre fetch in custom dsl mode +@pytest.mark.asyncio +async def test_cedarpolicyplugin_resource_pre_fetch_custom_dsl_rbac(): + """Test plugin resource prefetch hook.""" + policy_config = "[role:employee:resource/https://example.com/data]\nview_redacted_output\n\n[role:admin:resource/https://example.com/data]\nview_full_output" + + policy_output_keywords = {"view_full": "view_full_output", "view_redacted": "view_redacted_output"} + policy_redaction_spec = {"pattern": "[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}"} + config = PluginConfig( + name="test", + kind="cedarpolicyplugin.CedarPolicyPlugin", + hooks=["tool_pre_invoke"], + config={"policy_lang": "custom_dsl", "policy": policy_config, "policy_output_keywords": policy_output_keywords, "policy_redaction_spec": policy_redaction_spec}, + ) + plugin = CedarPolicyPlugin(config) + info = {"alice": "employee", "bob": "manager", "carol": "hr", "robert": "admin"} + plugin._set_jwt_info(info) + requests = [ + {"user": "alice", "resource": "https://example.com/data"}, # allow + {"user": "robert", "resource": "https://example.com/data"}, # allow + {"user": "carol", "resource": "https://example.com/data"}, # deny + ] + + allow_count = 0 + deny_count = 0 + + for req in requests: + + # Prompt post hook output + payload = ResourcePreFetchPayload(uri="https://example.com/data", metadata={}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2", user=req["user"])) + result = await plugin.resource_pre_fetch(payload, context) + if result.continue_processing: + allow_count += 1 + if not result.continue_processing: + deny_count += 1 + + assert allow_count == 2 + assert deny_count == 1 + + +# This test case is responsible for verifying cedarplugin functionality for resource post fetch in cedar native mode +@pytest.mark.asyncio +async def test_cedarpolicyplugin_resource_post_fetch_cedar_rbac(): + """Test plugin resource post fetch.""" + policy_config = [ + { + "id": "redact-non-admin-resource-views", + "effect": "Permit", + "principal": 'Role::"employee"', + "action": ['Action::"view_redacted_output"'], + "resource": 'Resource::"https://example.com/data"', + }, + { + "id": "allow-admin-resources", # policy for resources + "effect": "Permit", + "principal": 'Role::"admin"', + "action": ['Action::"view_full_output"'], + "resource": 'Resource::"https://example.com/data"', + }, + ] + + policy_output_keywords = {"view_full": "view_full_output", "view_redacted": "view_redacted_output"} + policy_redaction_spec = {"pattern": "[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}"} + config = PluginConfig( + name="test", + kind="cedarpolicyplugin.CedarPolicyPlugin", + hooks=["tool_pre_invoke"], + config={"policy_lang": "cedar", "policy": policy_config, "policy_output_keywords": policy_output_keywords, "policy_redaction_spec": policy_redaction_spec}, + ) + plugin = CedarPolicyPlugin(config) + info = {"alice": 
"employee", "bob": "manager", "carol": "hr", "robert": "admin"} + plugin._set_jwt_info(info) + requests = [ + {"user": "alice", "resource": "https://example.com/data"}, # allow + {"user": "robert", "resource": "https://example.com/data"}, # allow + {"user": "carol", "resource": "https://example.com/data"}, # deny + ] + + allow_count = 0 + deny_count = 0 + redact_count = 0 + + for req in requests: + + # Prompt post hook output + content = ResourceContent(type="resource", uri="test://large", text="test://abc@example.com", id="1") + payload = ResourcePostFetchPayload(uri="https://example.com/data", content=content) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2", user=req["user"])) + result = await plugin.resource_post_fetch(payload, context) + if result.continue_processing: + allow_count += 1 + if result.modified_payload and "[REDACTED]" in result.modified_payload.content.text: + redact_count += 1 + if not result.continue_processing: + deny_count += 1 + + assert allow_count == 2 + assert deny_count == 1 + assert redact_count == 1 + + +# This test case is responsible for verifying cedarplugin functionality for resource post fetch in custom dsl mode +@pytest.mark.asyncio +async def test_cedarpolicyplugin_resource_post_fetch_custom_dsl_rbac(): + """Test plugin resource postfetch hook.""" + policy_config = "[role:employee:resource/https://example.com/data]\nview_redacted_output\n\n[role:admin:resource/https://example.com/data]\nview_full_output" + policy_output_keywords = {"view_full": "view_full_output", "view_redacted": "view_redacted_output"} + policy_redaction_spec = {"pattern": "[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}"} + config = PluginConfig( + name="test", + kind="cedarpolicyplugin.CedarPolicyPlugin", + hooks=["tool_pre_invoke"], + config={"policy_lang": "custom_dsl", "policy": policy_config, "policy_output_keywords": policy_output_keywords, "policy_redaction_spec": policy_redaction_spec}, + ) + plugin = CedarPolicyPlugin(config) + info = {"alice": "employee", "bob": "manager", "carol": "hr", "robert": "admin"} + plugin._set_jwt_info(info) + requests = [ + {"user": "alice", "resource": "https://example.com/data"}, # allow + {"user": "robert", "resource": "https://example.com/data"}, # allow + {"user": "carol", "resource": "https://example.com/data"}, # deny + ] + + allow_count = 0 + deny_count = 0 + redact_count = 0 + + for req in requests: + + # Prompt post hook output + content = ResourceContent(type="resource", uri="test://large", text="test://abc@example.com", id="1") + payload = ResourcePostFetchPayload(uri="https://example.com/data", content=content) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2", user=req["user"])) + result = await plugin.resource_post_fetch(payload, context) + if result.continue_processing: + allow_count += 1 + if result.modified_payload and "[REDACTED]" in result.modified_payload.content.text: + redact_count += 1 + if not result.continue_processing: + deny_count += 1 + + assert allow_count == 2 + assert deny_count == 1 + assert redact_count == 1 diff --git a/plugins/vault/README.md b/plugins/vault/README.md index 7f61bb01b..15f3440a2 100644 --- a/plugins/vault/README.md +++ b/plugins/vault/README.md @@ -222,4 +222,3 @@ curl -s -X POST -H "Authorization: Bearer $MCPGATEWAY_BEARER_TOKEN" \ }' \ http://localhost:4444/tools/invoke ``` - diff --git a/pyproject.toml b/pyproject.toml index c8afa56f6..aa8c35355 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -282,6 +282,7 @@ Changelog 
= "https://github.com/IBM/mcp-context-forge/blob/main/CHANGELOG.md" [project.scripts] mcpgateway = "mcpgateway.cli:main" mcpplugins = "mcpgateway.plugins.tools.cli:main" +cforge = "mcpgateway.tools.cli:main" # -------------------------------------------------------------------- # 🔧 setuptools-specific configuration @@ -300,6 +301,9 @@ exclude = ["tests*"] # - templates -> Jinja2 templates shipped at runtime [tool.setuptools.package-data] mcpgateway = [ + "tools/builder/templates/*.yaml.j2", + "tools/builder/templates/compose/*.yaml.j2", + "tools/builder/templates/kubernetes/*.yaml.j2", "py.typed", "static/*.css", "static/*.js", @@ -674,7 +678,11 @@ omit = [ "*/test_*.py", "*/__init__.py", "*/alembic/*", - "*/version.py" + "*/version.py", + # Builder deployment files - require external tools (docker, kubectl, templates) + "mcpgateway/tools/builder/common.py", + "mcpgateway/tools/builder/dagger_deploy.py", + "mcpgateway/tools/builder/python_deploy.py" ] # -------------------------------------------------------------------- diff --git a/run-gunicorn.sh b/run-gunicorn.sh index 61addbbb1..e20e3ea3b 100755 --- a/run-gunicorn.sh +++ b/run-gunicorn.sh @@ -278,6 +278,13 @@ echo " Developer Mode: ${GUNICORN_DEV_MODE}" SSL=${SSL:-false} # Enable/disable SSL (default: false) CERT_FILE=${CERT_FILE:-certs/cert.pem} # Path to SSL certificate file KEY_FILE=${KEY_FILE:-certs/key.pem} # Path to SSL private key file +KEY_FILE_PASSWORD=${KEY_FILE_PASSWORD:-} # Optional passphrase for encrypted key +CERT_PASSPHRASE=${CERT_PASSPHRASE:-} # Alternative name for passphrase + +# Use CERT_PASSPHRASE if KEY_FILE_PASSWORD is not set (for compatibility) +if [[ -z "${KEY_FILE_PASSWORD}" && -n "${CERT_PASSPHRASE}" ]]; then + KEY_FILE_PASSWORD="${CERT_PASSPHRASE}" +fi # Verify SSL settings if enabled if [[ "${SSL}" == "true" ]]; then @@ -305,9 +312,22 @@ if [[ "${SSL}" == "true" ]]; then exit 1 fi + # Check if passphrase is provided + if [[ -n "${KEY_FILE_PASSWORD}" ]]; then + echo "🔑 Passphrase-protected key detected" + echo " Note: Key will be decrypted by Python SSL key manager" + # Export for Python to access + export SSL_KEY_PASSWORD="${KEY_FILE_PASSWORD}" + fi + echo "✓ TLS enabled - using:" echo " Certificate: ${CERT_FILE}" echo " Private Key: ${KEY_FILE}" + if [[ -n "${KEY_FILE_PASSWORD}" ]]; then + echo " Passphrase: ******** (protected)" + else + echo " Passphrase: (none)" + fi else echo "🔓 Running without TLS (HTTP only)" fi @@ -381,6 +401,7 @@ fi # Add SSL arguments if enabled if [[ "${SSL}" == "true" ]]; then cmd+=( --certfile "${CERT_FILE}" --keyfile "${KEY_FILE}" ) + # If passphrase is set, it will be available to Python via SSL_KEY_PASSWORD env var fi # Add the application module diff --git a/tests/conftest.py b/tests/conftest.py index 5c813749f..69b3a0e31 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -121,8 +121,20 @@ def app(): import mcpgateway.main as main_mod mp.setattr(main_mod, "SessionLocal", TestSessionLocal, raising=False) - # (patch engine too if your code references it) - mp.setattr(main_mod, "engine", engine, raising=False) + + # Also patch security_logger and auth_middleware's SessionLocal + # First-Party + import mcpgateway.middleware.auth_middleware as auth_middleware_mod + import mcpgateway.services.security_logger as sec_logger_mod + import mcpgateway.services.structured_logger as struct_logger_mod + import mcpgateway.services.audit_trail_service as audit_trail_mod + import mcpgateway.services.log_aggregator as log_aggregator_mod + + mp.setattr(auth_middleware_mod, 
"SessionLocal", TestSessionLocal, raising=False) + mp.setattr(sec_logger_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(struct_logger_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(audit_trail_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(log_aggregator_mod, "SessionLocal", TestSessionLocal, raising=False) # 4) create schema db_mod.Base.metadata.create_all(bind=engine) @@ -186,8 +198,20 @@ def app_with_temp_db(): import mcpgateway.main as main_mod mp.setattr(main_mod, "SessionLocal", TestSessionLocal, raising=False) - # (patch engine too if your code references it) - mp.setattr(main_mod, "engine", engine, raising=False) + + # Also patch security_logger and auth_middleware's SessionLocal + # First-Party + import mcpgateway.middleware.auth_middleware as auth_middleware_mod + import mcpgateway.services.security_logger as sec_logger_mod + import mcpgateway.services.structured_logger as struct_logger_mod + import mcpgateway.services.audit_trail_service as audit_trail_mod + import mcpgateway.services.log_aggregator as log_aggregator_mod + + mp.setattr(auth_middleware_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(sec_logger_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(struct_logger_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(audit_trail_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(log_aggregator_mod, "SessionLocal", TestSessionLocal, raising=False) # 4) create schema db_mod.Base.metadata.create_all(bind=engine) diff --git a/tests/e2e/test_main_apis.py b/tests/e2e/test_main_apis.py index bb8c6e29c..ab22b1126 100644 --- a/tests/e2e/test_main_apis.py +++ b/tests/e2e/test_main_apis.py @@ -218,9 +218,18 @@ def mock_get_permission_service(*args, **kwargs): app.dependency_overrides[get_permission_service] = mock_get_permission_service app.dependency_overrides[get_db] = override_get_db + # Mock security_logger to prevent database access issues + mock_sec_logger = MagicMock() + mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None) + mock_sec_logger.log_security_event = MagicMock(return_value=None) + # Patch at the middleware level where security_logger is used + sec_patcher = patch("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger) + sec_patcher.start() + yield engine # Cleanup + sec_patcher.stop() app.dependency_overrides.clear() os.close(db_fd) os.unlink(db_path) diff --git a/tests/fuzz/conftest.py b/tests/fuzz/conftest.py index 6b9326b4b..a92cd87ca 100644 --- a/tests/fuzz/conftest.py +++ b/tests/fuzz/conftest.py @@ -7,13 +7,68 @@ Fuzzing test configuration. """ +# Standard +import os +import tempfile + # Third-Party +from _pytest.monkeypatch import MonkeyPatch from hypothesis import HealthCheck, settings, Verbosity import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool # Mark all tests in this directory as fuzz tests pytestmark = pytest.mark.fuzz + +@pytest.fixture(autouse=True) +def mock_logging_services(monkeypatch): + """Mock logging services to prevent database access during fuzz tests. + + This fixture patches SessionLocal in the db module and all modules that + import it, ensuring they use a test database with all tables created. 
+ """ + # Create a temp database for the fuzz tests + fd, path = tempfile.mkstemp(suffix=".db") + url = f"sqlite:///{path}" + + engine = create_engine(url, connect_args={"check_same_thread": False}, poolclass=StaticPool) + TestSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + # First-Party + import mcpgateway.db as db_mod + from mcpgateway.db import Base + import mcpgateway.main as main_mod + import mcpgateway.middleware.auth_middleware as auth_middleware_mod + import mcpgateway.services.security_logger as sec_logger_mod + import mcpgateway.services.structured_logger as struct_logger_mod + + # Patch the core db module + monkeypatch.setattr(db_mod, "engine", engine) + monkeypatch.setattr(db_mod, "SessionLocal", TestSessionLocal) + + # Patch main module's SessionLocal (it imports SessionLocal from db) + monkeypatch.setattr(main_mod, "SessionLocal", TestSessionLocal) + + # Patch auth_middleware's SessionLocal + monkeypatch.setattr(auth_middleware_mod, "SessionLocal", TestSessionLocal) + + # Patch security_logger and structured_logger SessionLocal + monkeypatch.setattr(sec_logger_mod, "SessionLocal", TestSessionLocal) + monkeypatch.setattr(struct_logger_mod, "SessionLocal", TestSessionLocal) + + # Create all tables + Base.metadata.create_all(bind=engine) + + yield + + # Cleanup + engine.dispose() + os.close(fd) + os.unlink(path) + # Configure Hypothesis profiles for different environments settings.register_profile("dev", max_examples=100, verbosity=Verbosity.normal, suppress_health_check=[HealthCheck.too_slow]) diff --git a/tests/fuzz/test_schema_validation_fuzz.py b/tests/fuzz/test_schema_validation_fuzz.py index fcd6ea479..4c75f2a62 100644 --- a/tests/fuzz/test_schema_validation_fuzz.py +++ b/tests/fuzz/test_schema_validation_fuzz.py @@ -135,9 +135,9 @@ def test_tool_create_tags_field(self, tags): """Test tags field with various lists.""" try: tool = ToolCreate(name="test", url="http://example.com", tags=tags) - # If validation succeeds, tags should be list of strings + # If validation succeeds, tags should be list of dicts with id/label keys assert isinstance(tool.tags, list) - assert all(isinstance(tag, str) for tag in tool.tags) + assert all(isinstance(tag, dict) and "id" in tag and "label" in tag for tag in tool.tags) except ValidationError: # Expected for invalid tag structures pass diff --git a/tests/fuzz/test_security_fuzz.py b/tests/fuzz/test_security_fuzz.py index 7da56e4c5..b4494d5d9 100644 --- a/tests/fuzz/test_security_fuzz.py +++ b/tests/fuzz/test_security_fuzz.py @@ -99,7 +99,7 @@ def test_integer_overflow_handling(self, large_int): response = client.post("/admin/tools", json=payload, headers={"Authorization": "Basic YWRtaW46Y2hhbmdlbWU="}) - assert response.status_code in [200, 201, 400, 422] + assert response.status_code in [200, 201, 400, 401, 422] def test_path_traversal_resistance(self): """Test resistance to path traversal attacks.""" @@ -330,4 +330,4 @@ def test_rate_limiting_behavior(self): # Should either accept all or start rate limiting # Rate limiting typically returns 429 for status in responses: - assert status in [200, 201, 400, 422, 429, 409] + assert status in [200, 201, 400, 401, 422, 429, 409] diff --git a/tests/security/test_rpc_endpoint_validation.py b/tests/security/test_rpc_endpoint_validation.py index 2ec390eee..71af40285 100644 --- a/tests/security/test_rpc_endpoint_validation.py +++ b/tests/security/test_rpc_endpoint_validation.py @@ -14,6 +14,7 @@ # Standard import logging +from unittest.mock import MagicMock, patch # 
Third-Party from fastapi.testclient import TestClient @@ -37,9 +38,14 @@ class TestRPCEndpointValidation: """ @pytest.fixture - def client(self): - """Create a test client for the FastAPI app.""" - return TestClient(app) + def client(self, app): + """Create a test client for the FastAPI app with mocked security_logger.""" + # Mock security_logger to prevent database access + mock_sec_logger = MagicMock() + mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None) + mock_sec_logger.log_security_event = MagicMock(return_value=None) + with patch("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger): + yield TestClient(app) @pytest.fixture def auth_headers(self): @@ -269,8 +275,14 @@ class TestRPCValidationBypass: """Test various techniques to bypass RPC validation.""" @pytest.fixture - def client(self): - return TestClient(app) + def client(self, app): + """Create a test client for the FastAPI app with mocked security_logger.""" + # Mock security_logger to prevent database access + mock_sec_logger = MagicMock() + mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None) + mock_sec_logger.log_security_event = MagicMock(return_value=None) + with patch("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger): + yield TestClient(app) def test_bypass_techniques(self, client): """Test various bypass techniques.""" diff --git a/tests/unit/mcpgateway/middleware/test_auth_middleware.py b/tests/unit/mcpgateway/middleware/test_auth_middleware.py index cf8b85aa3..5882e20af 100644 --- a/tests/unit/mcpgateway/middleware/test_auth_middleware.py +++ b/tests/unit/mcpgateway/middleware/test_auth_middleware.py @@ -103,10 +103,18 @@ async def test_authentication_failure(monkeypatch): request.url.path = "/api/data" request.cookies = {"jwt_token": "bad_token"} request.headers = {} + # Mock request.client for security_logger + request.client = MagicMock() + request.client.host = "127.0.0.1" + + # Mock security_logger to prevent database operations + mock_security_logger = MagicMock() + mock_security_logger.log_authentication_attempt = MagicMock(return_value=None) with patch("mcpgateway.middleware.auth_middleware.SessionLocal", return_value=MagicMock()) as mock_session, \ patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(side_effect=Exception("Invalid token"))), \ - patch("mcpgateway.middleware.auth_middleware.logger") as mock_logger: + patch("mcpgateway.middleware.auth_middleware.logger") as mock_logger, \ + patch("mcpgateway.middleware.auth_middleware.security_logger", mock_security_logger): response = await middleware.dispatch(request, call_next) call_next.assert_awaited_once_with(request) diff --git a/tests/unit/mcpgateway/middleware/test_correlation_id.py b/tests/unit/mcpgateway/middleware/test_correlation_id.py new file mode 100644 index 000000000..029d482fc --- /dev/null +++ b/tests/unit/mcpgateway/middleware/test_correlation_id.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- +"""Tests for correlation ID middleware.""" + +import pytest +from unittest.mock import Mock, patch +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from mcpgateway.middleware.correlation_id import CorrelationIDMiddleware +from mcpgateway.utils.correlation_id import get_correlation_id + + +@pytest.fixture +def app(): + """Create a test FastAPI app with correlation ID middleware.""" + test_app = FastAPI() + + # Add the correlation ID middleware + test_app.add_middleware(CorrelationIDMiddleware) + + 
@test_app.get("/test") + async def test_endpoint(request: Request): + # Get correlation ID from context + correlation_id = get_correlation_id() + return {"correlation_id": correlation_id} + + return test_app + + +@pytest.fixture +def client(app): + """Create a test client.""" + return TestClient(app) + + +def test_middleware_generates_correlation_id_when_not_provided(client): + """Test that middleware generates a correlation ID when not provided by client.""" + response = client.get("/test") + + assert response.status_code == 200 + data = response.json() + + # Should have a correlation ID in response body + assert "correlation_id" in data + assert data["correlation_id"] is not None + assert len(data["correlation_id"]) == 32 # UUID hex format + + # Should have correlation ID in response headers + assert "X-Correlation-ID" in response.headers + assert response.headers["X-Correlation-ID"] == data["correlation_id"] + + +def test_middleware_preserves_client_correlation_id(client): + """Test that middleware preserves correlation ID from client.""" + client_id = "client-provided-id-123" + + response = client.get("/test", headers={"X-Correlation-ID": client_id}) + + assert response.status_code == 200 + data = response.json() + + # Should use the client-provided ID + assert data["correlation_id"] == client_id + + # Should echo it back in response headers + assert response.headers["X-Correlation-ID"] == client_id + + +def test_middleware_case_insensitive_header(client): + """Test that middleware handles case-insensitive headers.""" + client_id = "lowercase-header-id" + + response = client.get("/test", headers={"x-correlation-id": client_id}) + + assert response.status_code == 200 + data = response.json() + + # Should use the client-provided ID regardless of case + assert data["correlation_id"] == client_id + + +def test_middleware_strips_whitespace_from_header(client): + """Test that middleware strips whitespace from correlation ID header.""" + client_id = " whitespace-id " + + response = client.get("/test", headers={"X-Correlation-ID": client_id}) + + assert response.status_code == 200 + data = response.json() + + # Should strip whitespace + assert data["correlation_id"] == "whitespace-id" + + +def test_middleware_clears_correlation_id_after_request(app): + """Test that middleware clears correlation ID after request completes.""" + client = TestClient(app) + + # Make a request + response = client.get("/test") + assert response.status_code == 200 + + # After request completes, correlation ID should be cleared + # (Note: This happens in a different context, so we can't directly test it here, + # but we verify that multiple requests get different IDs) + response2 = client.get("/test") + assert response2.status_code == 200 + + # Two requests without client-provided IDs should have different correlation IDs + assert response.json()["correlation_id"] != response2.json()["correlation_id"] + + +def test_middleware_handles_empty_header(client): + """Test that middleware generates new ID when header is empty.""" + response = client.get("/test", headers={"X-Correlation-ID": ""}) + + assert response.status_code == 200 + data = response.json() + + # Should generate a new ID when header is empty + assert data["correlation_id"] is not None + assert len(data["correlation_id"]) == 32 + + +def test_middleware_with_custom_settings(monkeypatch): + """Test middleware with custom configuration settings.""" + # Create a mock settings object + mock_settings = Mock() + mock_settings.correlation_id_header = "X-Request-ID" + 
mock_settings.correlation_id_preserve = False + mock_settings.correlation_id_response_header = False + + # Create app with custom settings + app = FastAPI() + + # Patch settings at module level + with patch("mcpgateway.middleware.correlation_id.settings", mock_settings): + app.add_middleware(CorrelationIDMiddleware) + + @app.get("/test") + async def test_endpoint(): + return {"correlation_id": get_correlation_id()} + + client = TestClient(app) + + # Test with custom header name + response = client.get("/test", headers={"X-Request-ID": "custom-id"}) + + assert response.status_code == 200 + + # When preserve=False, should always generate new ID (not use client's) + # When response_header=False, should not include in response headers + assert "X-Request-ID" not in response.headers + + +def test_middleware_integration_with_multiple_requests(client): + """Test middleware properly isolates correlation IDs across multiple requests.""" + ids = [] + + for i in range(5): + response = client.get("/test", headers={"X-Correlation-ID": f"request-{i}"}) + assert response.status_code == 200 + ids.append(response.json()["correlation_id"]) + + # Each request should have its unique correlation ID + assert len(ids) == 5 + assert len(set(ids)) == 5 # All unique + for i, correlation_id in enumerate(ids): + assert correlation_id == f"request-{i}" + + +def test_middleware_context_isolation(): + """Test that correlation ID is properly isolated per request context.""" + app = FastAPI() + app.add_middleware(CorrelationIDMiddleware) + + correlation_ids_seen = [] + + @app.get("/capture") + async def capture_endpoint(): + # Capture the correlation ID during request handling + correlation_id = get_correlation_id() + correlation_ids_seen.append(correlation_id) + return {"captured": correlation_id} + + client = TestClient(app) + + # Make multiple concurrent-like requests + for i in range(3): + response = client.get("/capture", headers={"X-Correlation-ID": f"id-{i}"}) + assert response.status_code == 200 + + # Each request should have captured its own unique ID + assert len(correlation_ids_seen) == 3 + assert correlation_ids_seen[0] == "id-0" + assert correlation_ids_seen[1] == "id-1" + assert correlation_ids_seen[2] == "id-2" + + +def test_middleware_preserves_correlation_id_through_request_lifecycle(): + """Test that correlation ID remains consistent throughout entire request.""" + captured_ids = [] + + app = FastAPI() + + @app.middleware("http") + async def capture_middleware(request: Request, call_next): + # Capture ID at middleware level (after CorrelationIDMiddleware sets it) + captured_ids.append(("middleware", get_correlation_id())) + response = await call_next(request) + return response + + # Add CorrelationIDMiddleware last so it executes first (LIFO) + app.add_middleware(CorrelationIDMiddleware) + + @app.get("/test") + async def test_endpoint(): + # Capture ID at endpoint level + captured_ids.append(("endpoint", get_correlation_id())) + return {"ok": True} + + client = TestClient(app) + response = client.get("/test", headers={"X-Correlation-ID": "consistent-id"}) + + assert response.status_code == 200 + + # Both captures should have the same correlation ID + assert len(captured_ids) == 2 + assert captured_ids[0][1] == "consistent-id" # Middleware capture + assert captured_ids[1][1] == "consistent-id" # Endpoint capture diff --git a/tests/unit/mcpgateway/middleware/test_request_logging_middleware.py b/tests/unit/mcpgateway/middleware/test_request_logging_middleware.py index 30a2a3c26..e905d9716 100644 --- 
a/tests/unit/mcpgateway/middleware/test_request_logging_middleware.py +++ b/tests/unit/mcpgateway/middleware/test_request_logging_middleware.py @@ -7,6 +7,7 @@ """ import json import pytest +from unittest.mock import MagicMock from fastapi import Request, Response from starlette.datastructures import Headers from starlette.types import Scope @@ -28,7 +29,7 @@ def __init__(self): def isEnabledFor(self, level): return self.enabled - def log(self, level, msg): + def log(self, level, msg, extra=None): self.logged.append((level, msg)) def warning(self, msg): @@ -40,6 +41,15 @@ def dummy_logger(monkeypatch): monkeypatch.setattr("mcpgateway.middleware.request_logging_middleware.logger", logger) return logger + +@pytest.fixture +def mock_structured_logger(monkeypatch): + """Mock the structured_logger to prevent database writes.""" + mock_logger = MagicMock() + mock_logger.log = MagicMock() + monkeypatch.setattr("mcpgateway.middleware.request_logging_middleware.structured_logger", mock_logger) + return mock_logger + @pytest.fixture def dummy_call_next(): async def _call_next(request): @@ -112,8 +122,8 @@ def test_mask_sensitive_headers_non_sensitive(): # --- RequestLoggingMiddleware tests --- @pytest.mark.asyncio -async def test_dispatch_logs_json_body(dummy_logger, dummy_call_next): - middleware = RequestLoggingMiddleware(app=None) +async def test_dispatch_logs_json_body(dummy_logger, mock_structured_logger, dummy_call_next): + middleware = RequestLoggingMiddleware(app=None, enable_gateway_logging=False, log_detailed_requests=True) body = json.dumps({"password": "123", "data": "ok"}).encode() request = make_request(body=body, headers={"Authorization": "Bearer abc"}) response = await middleware.dispatch(request, dummy_call_next) @@ -122,8 +132,8 @@ async def test_dispatch_logs_json_body(dummy_logger, dummy_call_next): assert "******" in dummy_logger.logged[0][1] @pytest.mark.asyncio -async def test_dispatch_logs_non_json_body(dummy_logger, dummy_call_next): - middleware = RequestLoggingMiddleware(app=None) +async def test_dispatch_logs_non_json_body(dummy_logger, mock_structured_logger, dummy_call_next): + middleware = RequestLoggingMiddleware(app=None, enable_gateway_logging=False, log_detailed_requests=True) body = b"token=abc" request = make_request(body=body) response = await middleware.dispatch(request, dummy_call_next) @@ -131,8 +141,8 @@ async def test_dispatch_logs_non_json_body(dummy_logger, dummy_call_next): assert any("" in msg for _, msg in dummy_logger.logged) @pytest.mark.asyncio -async def test_dispatch_large_body_truncated(dummy_logger, dummy_call_next): - middleware = RequestLoggingMiddleware(app=None, max_body_size=10) +async def test_dispatch_large_body_truncated(dummy_logger, mock_structured_logger, dummy_call_next): + middleware = RequestLoggingMiddleware(app=None, enable_gateway_logging=False, log_detailed_requests=True, max_body_size=10) body = b"{" + b"a" * 100 + b"}" request = make_request(body=body) response = await middleware.dispatch(request, dummy_call_next) @@ -140,8 +150,8 @@ async def test_dispatch_large_body_truncated(dummy_logger, dummy_call_next): assert any("[truncated]" in msg for _, msg in dummy_logger.logged) @pytest.mark.asyncio -async def test_dispatch_logging_disabled(dummy_logger, dummy_call_next): - middleware = RequestLoggingMiddleware(app=None, log_requests=False) +async def test_dispatch_logging_disabled(dummy_logger, mock_structured_logger, dummy_call_next): + middleware = RequestLoggingMiddleware(app=None, enable_gateway_logging=False, 
log_detailed_requests=False) body = b"{}" request = make_request(body=body) response = await middleware.dispatch(request, dummy_call_next) @@ -149,9 +159,9 @@ async def test_dispatch_logging_disabled(dummy_logger, dummy_call_next): assert dummy_logger.logged == [] @pytest.mark.asyncio -async def test_dispatch_logger_disabled(dummy_logger, dummy_call_next): +async def test_dispatch_logger_disabled(dummy_logger, mock_structured_logger, dummy_call_next): dummy_logger.enabled = False - middleware = RequestLoggingMiddleware(app=None) + middleware = RequestLoggingMiddleware(app=None, enable_gateway_logging=False, log_detailed_requests=True) body = b"{}" request = make_request(body=body) response = await middleware.dispatch(request, dummy_call_next) @@ -159,12 +169,12 @@ async def test_dispatch_logger_disabled(dummy_logger, dummy_call_next): assert dummy_logger.logged == [] @pytest.mark.asyncio -async def test_dispatch_exception_handling(dummy_logger, dummy_call_next, monkeypatch): +async def test_dispatch_exception_handling(dummy_logger, mock_structured_logger, dummy_call_next, monkeypatch): async def bad_body(): raise ValueError("fail") request = make_request() monkeypatch.setattr(request, "body", bad_body) - middleware = RequestLoggingMiddleware(app=None) + middleware = RequestLoggingMiddleware(app=None, enable_gateway_logging=False, log_detailed_requests=True) response = await middleware.dispatch(request, dummy_call_next) assert response.status_code == 200 assert any("Failed to log request body" in msg for msg in dummy_logger.warnings) diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py index 1d675a70f..b7c25724b 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py @@ -9,6 +9,7 @@ # Standard import asyncio +import json # Third-Party import pytest diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_server.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_server.py new file mode 100644 index 000000000..0c171d7cc --- /dev/null +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_server.py @@ -0,0 +1,421 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_server.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Fred Araujo + +Comprehensive unit tests for ExternalPluginServer. 
+""" + +# Standard +import os +from unittest.mock import Mock, patch + +# Third-Party +import pytest + +# First-Party +from mcpgateway.common.models import Message, PromptResult, Role, TextContent +from mcpgateway.plugins.framework import ( + GlobalContext, + PluginContext, + PromptHookType, + PromptPosthookPayload, + PromptPrehookPayload, + ToolHookType, + ToolPreInvokePayload, +) +from mcpgateway.plugins.framework.errors import PluginError +from mcpgateway.plugins.framework.external.mcp.server.server import ExternalPluginServer +from mcpgateway.plugins.framework.models import MCPServerConfig, PluginErrorModel + + +@pytest.fixture +def server_with_plugins(): + """Create a server with valid plugin configuration.""" + return ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + + +@pytest.fixture +async def initialized_server(server_with_plugins): + """Create and initialize a server.""" + await server_with_plugins.initialize() + yield server_with_plugins + await server_with_plugins.shutdown() + + +class TestExternalPluginServerInit: + """Tests for ExternalPluginServer initialization.""" + + def test_init_with_config_path(self): + """Test initialization with explicit config path.""" + server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + assert server._config_path == "./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml" + assert server._config is not None + assert server._plugin_manager is not None + + def test_init_with_env_var(self): + """Test initialization using PLUGINS_CONFIG_PATH environment variable.""" + os.environ["PLUGINS_CONFIG_PATH"] = "./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml" + try: + server = ExternalPluginServer() + assert server._config_path == "./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml" + assert server._config is not None + finally: + if "PLUGINS_CONFIG_PATH" in os.environ: + del os.environ["PLUGINS_CONFIG_PATH"] + + def test_init_with_default_path(self): + """Test initialization with default config path.""" + # Temporarily remove env var if it exists + env_backup = os.environ.pop("PLUGINS_CONFIG_PATH", None) + try: + with patch("os.path.join", return_value="./resources/plugins/config.yaml"): + with patch("mcpgateway.plugins.framework.loader.config.ConfigLoader.load_config") as mock_load: + mock_load.return_value = Mock(plugins=[], server_settings=None) + server = ExternalPluginServer() + assert "./resources/plugins/config.yaml" in server._config_path + finally: + if env_backup: + os.environ["PLUGINS_CONFIG_PATH"] = env_backup + + def test_init_with_invalid_config(self): + """Test initialization with invalid config path uses defaults or raises error.""" + # ConfigLoader may handle missing files by returning empty config + # This test verifies the server can be instantiated (or raises if validation fails) + try: + server = ExternalPluginServer(config_path="./nonexistent/path/config.yaml") + # If it succeeds, just verify server was created + assert server is not None + except Exception: + # If it raises, that's also acceptable behavior + pass + + +class TestGetPluginConfigs: + """Tests for get_plugin_configs method.""" + + @pytest.mark.asyncio + async def test_get_plugin_configs_multiple(self, server_with_plugins): + """Test getting multiple plugin configurations.""" + configs = await server_with_plugins.get_plugin_configs() + assert isinstance(configs, 
list) + assert len(configs) > 0 + # Verify each config is a dict with expected keys + for config in configs: + assert isinstance(config, dict) + assert "name" in config + + @pytest.mark.asyncio + async def test_get_plugin_configs_single(self): + """Test getting plugin configs with single plugin.""" + server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + configs = await server.get_plugin_configs() + assert len(configs) == 1 + assert configs[0]["name"] == "ReplaceBadWordsPlugin" + + @pytest.mark.asyncio + async def test_get_plugin_configs_empty(self): + """Test getting plugin configs when no plugins configured.""" + server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + # Mock empty plugins list + server._config.plugins = None + configs = await server.get_plugin_configs() + assert configs == [] + + +class TestGetPluginConfig: + """Tests for get_plugin_config method.""" + + @pytest.mark.asyncio + async def test_get_plugin_config_found(self, server_with_plugins): + """Test getting a specific plugin config by name.""" + config = await server_with_plugins.get_plugin_config(name="DenyListPlugin") + assert config is not None + assert config["name"] == "DenyListPlugin" + + @pytest.mark.asyncio + async def test_get_plugin_config_case_insensitive(self, server_with_plugins): + """Test that plugin name lookup is case-insensitive.""" + config = await server_with_plugins.get_plugin_config(name="denylistplugin") + assert config is not None + assert config["name"] == "DenyListPlugin" + + @pytest.mark.asyncio + async def test_get_plugin_config_not_found(self, server_with_plugins): + """Test getting a non-existent plugin config returns None.""" + config = await server_with_plugins.get_plugin_config(name="NonExistentPlugin") + assert config is None + + @pytest.mark.asyncio + async def test_get_plugin_config_empty_plugins(self): + """Test getting plugin config when no plugins configured.""" + server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + server._config.plugins = None + config = await server.get_plugin_config(name="AnyPlugin") + assert config is None + + +class TestInvokeHook: + """Tests for invoke_hook method.""" + + @pytest.mark.asyncio + async def test_invoke_hook_success(self, initialized_server): + """Test successful hook invocation.""" + payload = PromptPrehookPayload(prompt_id="123", name="test_prompt", args={"user": "This is so innovative"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + + result = await initialized_server.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, "DenyListPlugin", payload.model_dump(), context.model_dump()) + + assert result is not None + assert "plugin_name" in result + assert result["plugin_name"] == "DenyListPlugin" + assert "result" in result + assert result["result"]["continue_processing"] is False + + @pytest.mark.asyncio + async def test_invoke_hook_with_context_update(self, initialized_server): + """Test that hook invocation includes updated context in response.""" + payload = PromptPrehookPayload(prompt_id="123", name="test_prompt", args={"user": "normal text"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + + result = await initialized_server.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, "DenyListPlugin", payload.model_dump(), context.model_dump()) + + assert result is not None + assert 
"plugin_name" in result + # Context may or may not be included depending on whether it was modified + + @pytest.mark.asyncio + async def test_invoke_hook_plugin_error(self, initialized_server): + """Test hook invocation when plugin raises PluginError.""" + with patch("mcpgateway.plugins.framework.manager.PluginManager.invoke_hook_for_plugin") as mock_invoke: + # Simulate a PluginError + error = PluginErrorModel(message="Test error", plugin_name="TestPlugin", code="TEST_ERROR") + mock_invoke.side_effect = PluginError(error=error) + + payload = PromptPrehookPayload(prompt_id="123", args={}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + + result = await initialized_server.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, "DenyListPlugin", payload.model_dump(), context.model_dump()) + + assert result is not None + assert "error" in result + # error is a PluginErrorModel object, not a dict + error_obj = result["error"] + assert isinstance(error_obj, PluginErrorModel) + assert error_obj.message == "Test error" + assert error_obj.plugin_name == "TestPlugin" + + @pytest.mark.asyncio + async def test_invoke_hook_generic_exception(self, initialized_server): + """Test hook invocation when plugin raises generic exception.""" + with patch("mcpgateway.plugins.framework.manager.PluginManager.invoke_hook_for_plugin") as mock_invoke: + # Simulate a generic exception + mock_invoke.side_effect = ValueError("Unexpected error") + + payload = PromptPrehookPayload(prompt_id="123", args={}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + + result = await initialized_server.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, "DenyListPlugin", payload.model_dump(), context.model_dump()) + + assert result is not None + assert "error" in result + assert "Unexpected error" in result["error"]["message"] + assert result["error"]["plugin_name"] == "DenyListPlugin" + + @pytest.mark.asyncio + async def test_invoke_hook_invalid_context(self, initialized_server): + """Test hook invocation with invalid context data returns error.""" + payload = PromptPrehookPayload(prompt_id="123", args={}) + # Invalid context dict + invalid_context = {"invalid": "data"} + + # The method catches exceptions and returns them in the result + result = await initialized_server.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, "DenyListPlugin", payload.model_dump(), invalid_context) + + # Should return an error result instead of raising + assert result is not None + assert "error" in result + + @pytest.mark.asyncio + async def test_invoke_hook_tool_hooks(self, initialized_server): + """Test invoking tool pre/post hooks.""" + # Test tool pre-invoke + payload = ToolPreInvokePayload(name="test_tool", args={"arg": "value"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + + result = await initialized_server.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, "ReplaceBadWordsPlugin", payload.model_dump(), context.model_dump()) + + assert result is not None + assert "plugin_name" in result + assert result["plugin_name"] == "ReplaceBadWordsPlugin" + + @pytest.mark.asyncio + async def test_invoke_hook_prompt_post_fetch(self, initialized_server): + """Test invoking prompt post-fetch hook.""" + message = Message(content=TextContent(type="text", text="test content"), role=Role.USER) + prompt_result = PromptResult(messages=[message]) + payload = PromptPosthookPayload(prompt_id="123", result=prompt_result) + context = 
PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + + result = await initialized_server.invoke_hook(PromptHookType.PROMPT_POST_FETCH, "ReplaceBadWordsPlugin", payload.model_dump(), context.model_dump()) + + assert result is not None + assert "plugin_name" in result + assert result["plugin_name"] == "ReplaceBadWordsPlugin" + + +class TestInitializeShutdown: + """Tests for initialize and shutdown methods.""" + + @pytest.mark.asyncio + async def test_initialize_success(self, server_with_plugins): + """Test successful initialization.""" + result = await server_with_plugins.initialize() + assert result is True + assert server_with_plugins._plugin_manager.initialized is True + await server_with_plugins.shutdown() + + @pytest.mark.asyncio + async def test_initialize_idempotent(self, server_with_plugins): + """Test that multiple initializations are safe.""" + await server_with_plugins.initialize() + await server_with_plugins.initialize() + # Should still return True + assert server_with_plugins._plugin_manager.initialized is True + await server_with_plugins.shutdown() + + @pytest.mark.asyncio + async def test_shutdown_when_initialized(self, initialized_server): + """Test shutdown on initialized server.""" + assert initialized_server._plugin_manager.initialized is True + await initialized_server.shutdown() + assert initialized_server._plugin_manager.initialized is False + + @pytest.mark.asyncio + async def test_shutdown_when_not_initialized(self, server_with_plugins): + """Test shutdown on non-initialized server (should be safe).""" + assert server_with_plugins._plugin_manager.initialized is False + # Should not raise an error + await server_with_plugins.shutdown() + assert server_with_plugins._plugin_manager.initialized is False + + @pytest.mark.asyncio + async def test_shutdown_idempotent(self, initialized_server): + """Test that multiple shutdowns are safe.""" + await initialized_server.shutdown() + # Second shutdown should be safe + await initialized_server.shutdown() + + +class TestGetServerConfig: + """Tests for get_server_config method.""" + + def test_get_server_config_with_settings(self): + """Test getting server config when server_settings is configured.""" + server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + + # Mock server settings + expected_config = MCPServerConfig(host="0.0.0.0", port=8080, tls_enabled=False) + server._config.server_settings = expected_config + + config = server.get_server_config() + assert config == expected_config + assert config.host == "0.0.0.0" + assert config.port == 8080 + + def test_get_server_config_from_env(self): + """Test getting server config from environment variables.""" + server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + server._config.server_settings = None + + # Set environment variables + os.environ["MCP_SERVER_HOST"] = "127.0.0.1" + os.environ["MCP_SERVER_PORT"] = "9090" + + try: + config = server.get_server_config() + assert config is not None + # Should have loaded from env or defaults + finally: + # Cleanup + os.environ.pop("MCP_SERVER_HOST", None) + os.environ.pop("MCP_SERVER_PORT", None) + + def test_get_server_config_defaults(self): + """Test getting server config with defaults.""" + server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + server._config.server_settings = None + + config = server.get_server_config() 
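+        # Assumption: with server_settings cleared and no env overrides set, get_server_config()
+        # falls back to the framework's built-in defaults, so only the type is asserted below.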
+ assert config is not None + assert isinstance(config, MCPServerConfig) + + def test_get_server_config_with_tls(self, tmp_path): + """Test getting server config with TLS enabled.""" + # First-Party + from mcpgateway.plugins.framework.models import MCPServerTLSConfig + + server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_single_plugin.yaml") + + # Create dummy cert files for validation + cert_file = tmp_path / "cert.pem" + key_file = tmp_path / "key.pem" + cert_file.write_text("cert") + key_file.write_text("key") + + tls_settings = MCPServerTLSConfig(certfile=str(cert_file), keyfile=str(key_file)) + tls_config = MCPServerConfig(host="0.0.0.0", port=8443, tls=tls_settings) + server._config.server_settings = tls_config + + config = server.get_server_config() + assert config.tls is not None + assert config.tls.certfile == str(cert_file) + assert config.tls.keyfile == str(key_file) + + +class TestEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_doctest_example(self): + """Test the doctest example from __init__.""" + server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + assert server is not None + + @pytest.mark.asyncio + async def test_doctest_get_plugin_configs(self): + """Test the doctest example from get_plugin_configs.""" + server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + plugins = await server.get_plugin_configs() + assert len(plugins) > 0 + + @pytest.mark.asyncio + async def test_doctest_get_plugin_config(self): + """Test the doctest example from get_plugin_config.""" + server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + config = await server.get_plugin_config(name="DenyListPlugin") + assert config is not None + assert config["name"] == "DenyListPlugin" + + @pytest.mark.asyncio + async def test_invoke_hook_with_empty_payload(self, initialized_server): + """Test hook invocation with minimal/empty payload.""" + payload = PromptPrehookPayload(prompt_id="123", args={}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + + result = await initialized_server.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, "DenyListPlugin", payload.model_dump(), context.model_dump()) + + assert result is not None + assert "plugin_name" in result + + @pytest.mark.asyncio + async def test_invoke_hook_with_complex_payload(self, initialized_server): + """Test hook invocation with multiple arguments.""" + # PromptPrehookPayload args values must be strings + payload = PromptPrehookPayload(prompt_id="123", args={"user": "test message", "system": "system prompt", "context": "additional context"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + + result = await initialized_server.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, "DenyListPlugin", payload.model_dump(), context.model_dump()) + + assert result is not None + assert "plugin_name" in result diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_tls_utils.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_tls_utils.py new file mode 100644 index 000000000..751045d33 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_tls_utils.py @@ -0,0 +1,369 @@ +# -*- coding: utf-8 -*- +"""Location: 
./tests/unit/mcpgateway/plugins/framework/external/mcp/test_tls_utils.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Fred Araujo + +Additional unit tests for TLS utilities to improve code coverage. +""" + +# Standard +import ssl +from unittest.mock import patch + +# Third-Party +import pytest + +# First-Party +from mcpgateway.plugins.framework.errors import PluginError +from mcpgateway.plugins.framework.external.mcp.tls_utils import create_ssl_context +from mcpgateway.plugins.framework.models import MCPClientTLSConfig + + +class TestCreateSSLContextBasicConfig: + """Tests for basic SSL context configuration.""" + + def test_create_ssl_context_minimal_config(self): + """Test creating SSL context with minimal configuration.""" + tls_config = MCPClientTLSConfig(verify=True) + + ssl_context = create_ssl_context(tls_config, "MinimalPlugin") + + assert ssl_context is not None + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + assert ssl_context.check_hostname is True + assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 + + def test_create_ssl_context_verify_disabled(self): + """Test creating SSL context with verification disabled.""" + tls_config = MCPClientTLSConfig(verify=False, check_hostname=False) + + ssl_context = create_ssl_context(tls_config, "InsecurePlugin") + + assert ssl_context is not None + assert ssl_context.verify_mode == ssl.CERT_NONE + assert ssl_context.check_hostname is False + + def test_create_ssl_context_with_ca_bundle(self, tmp_path): + """Test creating SSL context with CA bundle.""" + # Create a temporary CA file + ca_file = tmp_path / "ca.pem" + ca_file.write_text("-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----") + + tls_config = MCPClientTLSConfig(ca_bundle=str(ca_file), verify=True) + + # Will fail to load the invalid cert but we're testing the path is used + with pytest.raises(PluginError): + create_ssl_context(tls_config, "TestPlugin") + + def test_create_ssl_context_hostname_check_disabled(self): + """Test creating SSL context with hostname checking disabled but verify enabled.""" + tls_config = MCPClientTLSConfig(verify=True, check_hostname=False) + + ssl_context = create_ssl_context(tls_config, "NoHostnameCheckPlugin") + + assert ssl_context is not None + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + assert ssl_context.check_hostname is False + + +class TestCreateSSLContextClientCertificates: + """Tests for SSL context with client certificates (mTLS).""" + + def test_create_ssl_context_with_client_cert(self, tmp_path): + """Test creating SSL context with client certificate.""" + cert_file = tmp_path / "client.crt" + key_file = tmp_path / "client.key" + cert_file.write_text("-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----") + key_file.write_text("-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----") + + tls_config = MCPClientTLSConfig(certfile=str(cert_file), keyfile=str(key_file), verify=False) + + # Will fail to load the invalid cert but we're testing the path is used + with pytest.raises(PluginError): + create_ssl_context(tls_config, "mTLSPlugin") + + def test_create_ssl_context_with_cert_no_key(self, tmp_path): + """Test creating SSL context with cert but no key (should use same file).""" + cert_file = tmp_path / "combined.pem" + cert_file.write_text("-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----") + + tls_config = MCPClientTLSConfig(certfile=str(cert_file), keyfile=None, verify=False) + + # Will fail to load the invalid cert + with pytest.raises(PluginError): + 
create_ssl_context(tls_config, "CombinedPEMPlugin") + + def test_create_ssl_context_with_encrypted_key(self, tmp_path): + """Test creating SSL context with encrypted private key.""" + cert_file = tmp_path / "client.crt" + key_file = tmp_path / "client.key" + cert_file.write_text("-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----") + key_file.write_text("-----BEGIN ENCRYPTED PRIVATE KEY-----\ntest\n-----END ENCRYPTED PRIVATE KEY-----") + + tls_config = MCPClientTLSConfig(certfile=str(cert_file), keyfile=str(key_file), keyfile_password="secret123", verify=False) + + # Will fail to load the invalid cert + with pytest.raises(PluginError): + create_ssl_context(tls_config, "EncryptedKeyPlugin") + + +class TestCreateSSLContextSecuritySettings: + """Tests for SSL context security settings.""" + + def test_ssl_context_enforces_tls_1_2_minimum(self): + """Test that SSL context enforces TLS 1.2 as minimum version.""" + tls_config = MCPClientTLSConfig(verify=True) + + ssl_context = create_ssl_context(tls_config, "SecurePlugin") + + assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 + # Ensure weak protocols are not allowed + assert ssl_context.minimum_version > ssl.TLSVersion.TLSv1_1 + + def test_ssl_context_uses_default_context_security(self): + """Test that ssl.create_default_context() security settings are preserved.""" + tls_config = MCPClientTLSConfig(verify=True) + + ssl_context = create_ssl_context(tls_config, "DefaultSecurityPlugin") + + # create_default_context() sets secure defaults + # Verify CERT_REQUIRED is set (from create_default_context) + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + + +class TestCreateSSLContextErrorHandling: + """Tests for error handling in create_ssl_context.""" + + def test_create_ssl_context_invalid_ca_bundle(self, tmp_path): + """Test that invalid CA bundle content raises PluginError.""" + # Create a file with invalid certificate content + ca_file = tmp_path / "invalid_ca.pem" + ca_file.write_text("INVALID CERTIFICATE CONTENT") + + tls_config = MCPClientTLSConfig(ca_bundle=str(ca_file), verify=True) + + with pytest.raises(PluginError) as exc_info: + create_ssl_context(tls_config, "InvalidCAPlugin") + + assert "InvalidCAPlugin" in str(exc_info.value) + assert "Failed to configure SSL context" in str(exc_info.value) + + def test_create_ssl_context_invalid_client_cert(self, tmp_path): + """Test that invalid client certificate content raises PluginError.""" + # Create files with invalid certificate/key content + cert_file = tmp_path / "invalid_cert.pem" + key_file = tmp_path / "invalid_key.pem" + cert_file.write_text("INVALID CERT") + key_file.write_text("INVALID KEY") + + tls_config = MCPClientTLSConfig(certfile=str(cert_file), keyfile=str(key_file), verify=False) + + with pytest.raises(PluginError) as exc_info: + create_ssl_context(tls_config, "InvalidCertPlugin") + + assert "InvalidCertPlugin" in str(exc_info.value) + assert "Failed to configure SSL context" in str(exc_info.value) + + def test_create_ssl_context_exception_includes_plugin_name(self, tmp_path): + """Test that PluginError includes the plugin name in error details.""" + # Create a file with invalid content + ca_file = tmp_path / "bad_ca.pem" + ca_file.write_text("BAD CONTENT") + + tls_config = MCPClientTLSConfig(ca_bundle=str(ca_file), verify=True) + + with pytest.raises(PluginError) as exc_info: + create_ssl_context(tls_config, "MyTestPlugin") + + error = exc_info.value + assert error.error.plugin_name == "MyTestPlugin" + assert "MyTestPlugin" in error.error.message + + 
def test_create_ssl_context_generic_exception_handling(self): + """Test that any exception during SSL context creation is caught and wrapped.""" + tls_config = MCPClientTLSConfig(verify=True) + + with patch("ssl.create_default_context") as mock_create: + mock_create.side_effect = RuntimeError("SSL initialization failed") + + with pytest.raises(PluginError) as exc_info: + create_ssl_context(tls_config, "FailingPlugin") + + assert "Failed to configure SSL context" in str(exc_info.value) + assert "FailingPlugin" in str(exc_info.value) + + +class TestCreateSSLContextLogging: + """Tests for logging in create_ssl_context.""" + + def test_create_ssl_context_logs_verification_disabled(self): + """Test that disabling verification logs a warning.""" + tls_config = MCPClientTLSConfig(verify=False) + + with patch("mcpgateway.plugins.framework.external.mcp.tls_utils.logger") as mock_logger: + create_ssl_context(tls_config, "InsecurePlugin") + + # Should log warning about disabled verification + assert mock_logger.warning.called + warning_calls = [call for call in mock_logger.warning.call_args_list] + assert any("verification disabled" in str(call).lower() for call in warning_calls) + + def test_create_ssl_context_logs_hostname_check_disabled(self): + """Test that disabling hostname checking logs a warning.""" + tls_config = MCPClientTLSConfig(verify=True, check_hostname=False) + + with patch("mcpgateway.plugins.framework.external.mcp.tls_utils.logger") as mock_logger: + create_ssl_context(tls_config, "NoHostnamePlugin") + + # Should log warning about disabled hostname verification + assert mock_logger.warning.called + warning_calls = [call for call in mock_logger.warning.call_args_list] + assert any("hostname" in str(call).lower() for call in warning_calls) + + def test_create_ssl_context_logs_mtls_enabled(self, tmp_path): + """Test that mTLS configuration is logged.""" + cert_file = tmp_path / "client.crt" + key_file = tmp_path / "client.key" + # Create minimal valid-looking PEM files + cert_file.write_text("-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----") + key_file.write_text("-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----") + + tls_config = MCPClientTLSConfig(certfile=str(cert_file), keyfile=str(key_file), verify=False) + + with patch("mcpgateway.plugins.framework.external.mcp.tls_utils.logger"): + # Will fail but we can check if debug logging was attempted + try: + create_ssl_context(tls_config, "mTLSPlugin") + except PluginError: + pass # Expected to fail with invalid cert + + # Should have attempted to log debug message about mTLS + # (even though it failed) + + def test_create_ssl_context_logs_debug_info(self): + """Test that SSL context configuration is logged at debug level.""" + tls_config = MCPClientTLSConfig(verify=True) + + with patch("mcpgateway.plugins.framework.external.mcp.tls_utils.logger") as mock_logger: + create_ssl_context(tls_config, "DebugPlugin") + + # Should log debug message with context details + assert mock_logger.debug.called + + def test_create_ssl_context_logs_error_on_failure(self, tmp_path): + """Test that errors are logged.""" + # Create a file with invalid content + ca_file = tmp_path / "bad.pem" + ca_file.write_text("INVALID") + + tls_config = MCPClientTLSConfig(ca_bundle=str(ca_file), verify=True) + + with patch("mcpgateway.plugins.framework.external.mcp.tls_utils.logger") as mock_logger: + with pytest.raises(PluginError): + create_ssl_context(tls_config, "ErrorPlugin") + + # Should log error + assert mock_logger.error.called + + +class 
TestCreateSSLContextIntegration: + """Integration tests for create_ssl_context.""" + + def test_create_ssl_context_production_like_config(self): + """Test creating SSL context with production-like configuration.""" + tls_config = MCPClientTLSConfig(verify=True, check_hostname=True) + + ssl_context = create_ssl_context(tls_config, "ProductionPlugin") + + # Verify all security features are enabled + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + assert ssl_context.check_hostname is True + assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 + + def test_create_ssl_context_development_config(self): + """Test creating SSL context with development/testing configuration.""" + tls_config = MCPClientTLSConfig(verify=False, check_hostname=False) + + ssl_context = create_ssl_context(tls_config, "DevPlugin") + + # Verify security is relaxed + assert ssl_context.verify_mode == ssl.CERT_NONE + assert ssl_context.check_hostname is False + + def test_create_ssl_context_mixed_security_config(self): + """Test creating SSL context with mixed security settings.""" + # Verify enabled but hostname check disabled + tls_config = MCPClientTLSConfig(verify=True, check_hostname=False) + + ssl_context = create_ssl_context(tls_config, "MixedPlugin") + + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + assert ssl_context.check_hostname is False + + +class TestCreateSSLContextCompliance: + """Tests for SSL context compliance with security standards.""" + + def test_ssl_context_meets_tls_requirements(self): + """Test that SSL context meets modern TLS requirements.""" + tls_config = MCPClientTLSConfig(verify=True) + + ssl_context = create_ssl_context(tls_config, "CompliancePlugin") + + # Modern security requirements + assert ssl_context.minimum_version >= ssl.TLSVersion.TLSv1_2 + assert ssl_context.verify_mode in [ssl.CERT_REQUIRED, ssl.CERT_OPTIONAL] + + def test_ssl_context_default_is_secure(self): + """Test that default SSL context configuration is secure.""" + tls_config = MCPClientTLSConfig() # All defaults + + ssl_context = create_ssl_context(tls_config, "DefaultPlugin") + + # Defaults should be secure + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + assert ssl_context.check_hostname is True + assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 + + +class TestCreateSSLContextEdgeCases: + """Tests for edge cases in create_ssl_context.""" + + def test_create_ssl_context_empty_plugin_name(self): + """Test creating SSL context with empty plugin name.""" + tls_config = MCPClientTLSConfig(verify=True) + + ssl_context = create_ssl_context(tls_config, "") + + assert ssl_context is not None + + def test_create_ssl_context_special_chars_in_plugin_name(self): + """Test creating SSL context with special characters in plugin name.""" + tls_config = MCPClientTLSConfig(verify=True) + + ssl_context = create_ssl_context(tls_config, "Plugin-Name_123!@#") + + assert ssl_context is not None + + def test_create_ssl_context_unicode_plugin_name(self): + """Test creating SSL context with unicode characters in plugin name.""" + tls_config = MCPClientTLSConfig(verify=True) + + ssl_context = create_ssl_context(tls_config, "プラグイン") + + assert ssl_context is not None + + def test_create_ssl_context_verify_true_hostname_false(self): + """Test the combination of verify=True with check_hostname=False.""" + tls_config = MCPClientTLSConfig(verify=True, check_hostname=False) + + with patch("mcpgateway.plugins.framework.external.mcp.tls_utils.logger") as mock_logger: + ssl_context = create_ssl_context(tls_config, 
"PartialSecurityPlugin") + + # Should warn about hostname verification being disabled + assert mock_logger.warning.called + # Should still have CERT_REQUIRED + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + # But hostname check should be disabled + assert ssl_context.check_hostname is False diff --git a/tests/unit/mcpgateway/services/test_a2a_service.py b/tests/unit/mcpgateway/services/test_a2a_service.py index 34a2e34b2..0b45fe87c 100644 --- a/tests/unit/mcpgateway/services/test_a2a_service.py +++ b/tests/unit/mcpgateway/services/test_a2a_service.py @@ -21,6 +21,21 @@ from mcpgateway.schemas import A2AAgentCreate, A2AAgentUpdate from mcpgateway.services.a2a_service import A2AAgentError, A2AAgentNameConflictError, A2AAgentNotFoundError, A2AAgentService + +@pytest.fixture(autouse=True) +def mock_logging_services(): + """Mock structured_logger and audit_trail to prevent database writes during tests.""" + with patch("mcpgateway.services.a2a_service.structured_logger") as mock_a2a_logger, \ + patch("mcpgateway.services.tool_service.structured_logger") as mock_tool_logger, \ + patch("mcpgateway.services.tool_service.audit_trail") as mock_tool_audit: + mock_a2a_logger.log = MagicMock(return_value=None) + mock_a2a_logger.info = MagicMock(return_value=None) + mock_tool_logger.log = MagicMock(return_value=None) + mock_tool_logger.info = MagicMock(return_value=None) + mock_tool_audit.log_action = MagicMock(return_value=None) + yield {"structured_logger": mock_a2a_logger, "tool_logger": mock_tool_logger, "tool_audit": mock_tool_audit} + + class TestA2AAgentService: """Test suite for A2A Agent Service.""" diff --git a/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py b/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py new file mode 100644 index 000000000..337e23f27 --- /dev/null +++ b/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py @@ -0,0 +1,307 @@ +# -*- coding: utf-8 -*- +"""Tests for correlation ID JSON formatter.""" + +import json +import logging +from datetime import datetime, timezone +from io import StringIO +from unittest.mock import Mock, patch + +import pytest + +from mcpgateway.services.logging_service import CorrelationIdJsonFormatter +from mcpgateway.utils.correlation_id import set_correlation_id, clear_correlation_id + + +@pytest.fixture +def formatter(): + """Create a test JSON formatter.""" + return CorrelationIdJsonFormatter() + + +@pytest.fixture +def logger_with_formatter(formatter): + """Create a test logger with JSON formatter.""" + logger = logging.getLogger("test_correlation_logger") + logger.setLevel(logging.DEBUG) + logger.handlers.clear() + + # Add string stream handler + stream = StringIO() + handler = logging.StreamHandler(stream) + handler.setFormatter(formatter) + logger.addHandler(handler) + + return logger, stream + + +def test_formatter_includes_correlation_id(logger_with_formatter): + """Test that formatter includes correlation ID in log records.""" + logger, stream = logger_with_formatter + + # Set correlation ID + test_id = "test-correlation-123" + set_correlation_id(test_id) + + # Log a message + logger.info("Test message") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Should include correlation ID + assert "request_id" in log_record + assert log_record["request_id"] == test_id + + clear_correlation_id() + + +def test_formatter_without_correlation_id(logger_with_formatter): + """Test formatter when correlation ID is not set.""" + logger, 
stream = logger_with_formatter + + # Clear any existing correlation ID + clear_correlation_id() + + # Log a message + logger.info("Test message without correlation ID") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # request_id should not be present + assert "request_id" not in log_record or log_record.get("request_id") is None + + +def test_formatter_includes_standard_fields(logger_with_formatter): + """Test that formatter includes standard log fields.""" + logger, stream = logger_with_formatter + + # Log a message + logger.info("Standard fields test") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Check for standard fields + assert "message" in log_record + assert log_record["message"] == "Standard fields test" + assert "@timestamp" in log_record + assert "hostname" in log_record + assert "process_id" in log_record + # Note: levelname is included by the JsonFormatter format string if specified + + +def test_formatter_includes_opentelemetry_trace_context(logger_with_formatter): + """Test that formatter includes OpenTelemetry trace context when available.""" + logger, stream = logger_with_formatter + + # Mock OpenTelemetry span + mock_span_context = Mock() + mock_span_context.trace_id = 0x1234567890abcdef1234567890abcdef + mock_span_context.span_id = 0x1234567890abcdef + mock_span_context.trace_flags = 0x01 + mock_span_context.is_valid = True + + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_span.get_span_context.return_value = mock_span_context + + with patch("mcpgateway.services.logging_service.trace") as mock_trace: + mock_trace.get_current_span.return_value = mock_span + + # Log a message + logger.info("Test with trace context") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Should include trace context + assert "trace_id" in log_record + assert "span_id" in log_record + assert "trace_flags" in log_record + + # Verify hex formatting + assert log_record["trace_id"] == "1234567890abcdef1234567890abcdef" + assert log_record["span_id"] == "1234567890abcdef" + assert log_record["trace_flags"] == "01" + + +def test_formatter_handles_missing_opentelemetry(logger_with_formatter): + """Test that formatter gracefully handles missing OpenTelemetry.""" + logger, stream = logger_with_formatter + + # Simulate ImportError for opentelemetry + import sys + with patch.dict(sys.modules, {"opentelemetry.trace": None}): + # Log a message + logger.info("Test without OpenTelemetry") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Should not fail, just exclude trace fields + assert "trace_id" not in log_record + assert "span_id" not in log_record + assert "message" in log_record + + +def test_formatter_timestamp_format(logger_with_formatter): + """Test that timestamp is in ISO 8601 format with 'Z' suffix.""" + logger, stream = logger_with_formatter + + # Log a message + logger.info("Timestamp test") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Check timestamp format + assert "@timestamp" in log_record + timestamp = log_record["@timestamp"] + + # Should end with 'Z' (Zulu/UTC time) + assert timestamp.endswith("Z") + + # Should be parseable as ISO 8601 + # Remove 'Z' and parse + datetime.fromisoformat(timestamp.replace("Z", "+00:00")) + + +def test_formatter_with_extra_fields(logger_with_formatter): + 
"""Test that formatter includes extra fields from log record.""" + logger, stream = logger_with_formatter + + # Log with extra fields + logger.info("Extra fields test", extra={"user_id": "user-123", "action": "login"}) + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Should include extra fields + assert log_record.get("user_id") == "user-123" + assert log_record.get("action") == "login" + + +def test_formatter_correlation_id_with_trace_context(logger_with_formatter): + """Test that both correlation ID and trace context coexist.""" + logger, stream = logger_with_formatter + + # Set correlation ID + set_correlation_id("both-test-id") + + # Mock OpenTelemetry span + mock_span_context = Mock() + mock_span_context.trace_id = 0xabcdef + mock_span_context.span_id = 0x123456 + mock_span_context.trace_flags = 0x01 + mock_span_context.is_valid = True + + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_span.get_span_context.return_value = mock_span_context + + with patch("mcpgateway.services.logging_service.trace") as mock_trace: + mock_trace.get_current_span.return_value = mock_span + + # Log a message + logger.info("Test with both IDs") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Should include both correlation ID and trace context + assert log_record.get("request_id") == "both-test-id" + assert "trace_id" in log_record + assert "span_id" in log_record + + clear_correlation_id() + + +def test_formatter_multiple_log_entries(logger_with_formatter): + """Test that formatter handles multiple log entries correctly.""" + logger, stream = logger_with_formatter + + # Log multiple messages with different correlation IDs + set_correlation_id("first-id") + logger.info("First message") + + set_correlation_id("second-id") + logger.info("Second message") + + clear_correlation_id() + logger.info("Third message") + + # Get all logged output + output = stream.getvalue() + log_lines = output.strip().split("\n") + + assert len(log_lines) == 3 + + # Parse each line + first_record = json.loads(log_lines[0]) + second_record = json.loads(log_lines[1]) + third_record = json.loads(log_lines[2]) + + # Verify correlation IDs + assert first_record.get("request_id") == "first-id" + assert second_record.get("request_id") == "second-id" + assert "request_id" not in third_record or third_record.get("request_id") is None + + +def test_formatter_process_id_and_hostname(logger_with_formatter): + """Test that formatter includes process ID and hostname.""" + logger, stream = logger_with_formatter + + # Log a message + logger.info("Process info test") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Check process_id and hostname + assert "process_id" in log_record + assert isinstance(log_record["process_id"], int) + assert log_record["process_id"] > 0 + + assert "hostname" in log_record + assert isinstance(log_record["hostname"], str) + assert len(log_record["hostname"]) > 0 + + +def test_formatter_handles_invalid_span_context(logger_with_formatter): + """Test that formatter handles invalid span context gracefully.""" + logger, stream = logger_with_formatter + + # Mock span with invalid context + mock_span_context = Mock() + mock_span_context.is_valid = False + + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_span.get_span_context.return_value = mock_span_context + + with 
patch("mcpgateway.services.logging_service.trace") as mock_trace: + mock_trace.get_current_span.return_value = mock_span + + # Log a message + logger.info("Test with invalid span") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Should not include trace context when invalid + assert "trace_id" not in log_record + assert "span_id" not in log_record + # But message should still be logged + assert log_record["message"] == "Test with invalid span" diff --git a/tests/unit/mcpgateway/services/test_export_service.py b/tests/unit/mcpgateway/services/test_export_service.py index 0c60f803b..4f921a140 100644 --- a/tests/unit/mcpgateway/services/test_export_service.py +++ b/tests/unit/mcpgateway/services/test_export_service.py @@ -726,7 +726,7 @@ async def test_export_servers_with_data(export_service, mock_db): mock_server.name = "test_server" mock_server.description = "Test server" mock_server.associated_tools = ["tool1", "tool2"] - mock_server.is_active = True + mock_server.enabled = True mock_server.tags = ["test", "api"] export_service.server_service.list_servers.return_value = [mock_server] @@ -803,7 +803,7 @@ async def test_export_resources_with_data(export_service, mock_db): mock_resource.uri = "file:///workspace/test.txt" mock_resource.description = "Test resource file" mock_resource.mime_type = "text/plain" - mock_resource.is_active = True + mock_resource.enabled = True mock_resource.tags = ["file", "text"] mock_resource.updated_at = datetime.now(timezone.utc) diff --git a/tests/unit/mcpgateway/services/test_gateway_service.py b/tests/unit/mcpgateway/services/test_gateway_service.py index 0a6d4e06f..7e443279b 100644 --- a/tests/unit/mcpgateway/services/test_gateway_service.py +++ b/tests/unit/mcpgateway/services/test_gateway_service.py @@ -67,6 +67,16 @@ def _make_execute_result(*, scalar: _R | None = None, scalars_list: list[_R] | N return result +@pytest.fixture(autouse=True) +def mock_logging_services(): + """Mock audit_trail and structured_logger to prevent database writes during tests.""" + with patch("mcpgateway.services.gateway_service.audit_trail") as mock_audit, \ + patch("mcpgateway.services.gateway_service.structured_logger") as mock_logger: + mock_audit.log_action = MagicMock(return_value=None) + mock_logger.log = MagicMock(return_value=None) + yield {"audit_trail": mock_audit, "structured_logger": mock_logger} + + @pytest.fixture(autouse=True) def _bypass_gatewayread_validation(monkeypatch): """ diff --git a/tests/unit/mcpgateway/services/test_gateway_service_extended.py b/tests/unit/mcpgateway/services/test_gateway_service_extended.py index 55a2cba62..b1196ad8d 100644 --- a/tests/unit/mcpgateway/services/test_gateway_service_extended.py +++ b/tests/unit/mcpgateway/services/test_gateway_service_extended.py @@ -1363,4 +1363,4 @@ async def test_helper_methods_complete_removal_scenario(self): assert len(prompts_to_remove) == 1 assert tools_to_remove[0].original_name == "old_tool" assert resources_to_remove[0].uri == "file:///old.txt" - assert prompts_to_remove[0].name == "old_prompt" \ No newline at end of file + assert prompts_to_remove[0].name == "old_prompt" diff --git a/tests/unit/mcpgateway/services/test_prompt_service.py b/tests/unit/mcpgateway/services/test_prompt_service.py index 8282f5c19..dbbcb806a 100644 --- a/tests/unit/mcpgateway/services/test_prompt_service.py +++ b/tests/unit/mcpgateway/services/test_prompt_service.py @@ -44,6 +44,16 @@ # 
--------------------------------------------------------------------------- +@pytest.fixture(autouse=True) +def mock_logging_services(): + """Mock audit_trail and structured_logger to prevent database writes during tests.""" + with patch("mcpgateway.services.prompt_service.audit_trail") as mock_audit, \ + patch("mcpgateway.services.prompt_service.structured_logger") as mock_logger: + mock_audit.log_action = MagicMock(return_value=None) + mock_logger.log = MagicMock(return_value=None) + yield {"audit_trail": mock_audit, "structured_logger": mock_logger} + + @pytest.fixture def mock_prompt(): """Create a mock prompt model.""" @@ -225,7 +235,7 @@ async def test_get_prompt_rendered(self, prompt_service, test_db): db_prompt = _build_db_prompt(template="Hello, {{ name }}!") test_db.execute = Mock(return_value=_make_execute_result(scalar=db_prompt)) - pr: PromptResult = await prompt_service.get_prompt(test_db, 1, {"name": "Alice"}) + pr: PromptResult = await prompt_service.get_prompt(test_db, "1", {"name": "Alice"}) assert isinstance(pr, PromptResult) assert len(pr.messages) == 1 @@ -239,7 +249,7 @@ async def test_get_prompt_not_found(self, prompt_service, test_db): test_db.execute = Mock(return_value=_make_execute_result(scalar=None)) with pytest.raises(PromptNotFoundError): - await prompt_service.get_prompt(test_db, 999) + await prompt_service.get_prompt(test_db, "999") @pytest.mark.asyncio async def test_get_prompt_inactive(self, prompt_service, test_db): @@ -251,7 +261,7 @@ async def test_get_prompt_inactive(self, prompt_service, test_db): ] ) with pytest.raises(PromptNotFoundError) as exc_info: - await prompt_service.get_prompt(test_db, 1) + await prompt_service.get_prompt(test_db, "1") assert "inactive" in str(exc_info.value) @pytest.mark.asyncio @@ -260,7 +270,7 @@ async def test_get_prompt_render_error(self, prompt_service, test_db): test_db.execute = Mock(return_value=_make_execute_result(scalar=db_prompt)) db_prompt.validate_arguments.side_effect = Exception("bad args") with pytest.raises(PromptError) as exc_info: - await prompt_service.get_prompt(test_db, 1, {"name": "Alice"}) + await prompt_service.get_prompt(test_db, "1", {"name": "Alice"}) assert "Failed to process prompt" in str(exc_info.value) @pytest.mark.asyncio diff --git a/tests/unit/mcpgateway/services/test_resource_ownership.py b/tests/unit/mcpgateway/services/test_resource_ownership.py index 6c70cb399..c3e4f82d6 100644 --- a/tests/unit/mcpgateway/services/test_resource_ownership.py +++ b/tests/unit/mcpgateway/services/test_resource_ownership.py @@ -26,6 +26,26 @@ from mcpgateway.services.a2a_service import A2AAgentService +@pytest.fixture(autouse=True) +def mock_logging_services(): + """Mock audit_trail and structured_logger to prevent database writes during tests.""" + with patch("mcpgateway.services.gateway_service.audit_trail") as mock_gw_audit, \ + patch("mcpgateway.services.gateway_service.structured_logger") as mock_gw_logger, \ + patch("mcpgateway.services.tool_service.audit_trail") as mock_tool_audit, \ + patch("mcpgateway.services.tool_service.structured_logger") as mock_tool_logger, \ + patch("mcpgateway.services.resource_service.audit_trail") as mock_res_audit, \ + patch("mcpgateway.services.resource_service.structured_logger") as mock_res_logger, \ + patch("mcpgateway.services.prompt_service.audit_trail") as mock_prompt_audit, \ + patch("mcpgateway.services.prompt_service.structured_logger") as mock_prompt_logger, \ + patch("mcpgateway.services.a2a_service.structured_logger") as mock_a2a_logger: + for mock in 
[mock_gw_audit, mock_tool_audit, mock_res_audit, mock_prompt_audit]: + mock.log_action = MagicMock(return_value=None) + for mock in [mock_gw_logger, mock_tool_logger, mock_res_logger, mock_prompt_logger, mock_a2a_logger]: + mock.log = MagicMock(return_value=None) + mock.info = MagicMock(return_value=None) + yield + + @pytest.fixture def mock_db_session(): """Create a mock database session.""" diff --git a/tests/unit/mcpgateway/services/test_resource_service.py b/tests/unit/mcpgateway/services/test_resource_service.py index 632b35cc1..2c9cb1517 100644 --- a/tests/unit/mcpgateway/services/test_resource_service.py +++ b/tests/unit/mcpgateway/services/test_resource_service.py @@ -37,6 +37,16 @@ # --------------------------------------------------------------------------- # +@pytest.fixture(autouse=True) +def mock_logging_services(): + """Mock audit_trail and structured_logger to prevent database writes during tests.""" + with patch("mcpgateway.services.resource_service.audit_trail") as mock_audit, \ + patch("mcpgateway.services.resource_service.structured_logger") as mock_logger: + mock_audit.log_action = MagicMock(return_value=None) + mock_logger.log = MagicMock(return_value=None) + yield {"audit_trail": mock_audit, "structured_logger": mock_logger} + + @pytest.fixture def resource_service(monkeypatch): """Create a ResourceService instance.""" @@ -1588,7 +1598,7 @@ async def test_read_template_resource_not_found(self): # One template in cache — but it does NOT match URI template_obj = ResourceTemplate( - id=1, + id="1", uriTemplate="file://search/{query}", name="search_template", description="Template for performing a file search", diff --git a/tests/unit/mcpgateway/services/test_resource_service_plugins.py b/tests/unit/mcpgateway/services/test_resource_service_plugins.py index 33856187f..d609c6a15 100644 --- a/tests/unit/mcpgateway/services/test_resource_service_plugins.py +++ b/tests/unit/mcpgateway/services/test_resource_service_plugins.py @@ -177,7 +177,7 @@ async def test_read_resource_with_pre_fetch_hook(self, mock_ssl, resource_servic mock_ctx = MagicMock() mock_ssl.return_value = mock_ctx - + # Mock DB row returned by scalar_one_or_none mock_db_row = MagicMock() mock_db_row.content = fake_resource_content @@ -288,7 +288,7 @@ async def test_read_resource_uri_modified_by_plugin(self, mock_ssl, mock_db, res mock_db_row.content = fake_resource_content mock_db_row.uri = fake_resource_content.uri mock_db_row.uri_template = None - + mock_ctx = MagicMock() mock_ssl.return_value = mock_ctx @@ -616,7 +616,7 @@ async def test_read_resource_no_request_id(self, mock_ssl,resource_service_with_ mock_ctx = MagicMock() mock_ssl.return_value = mock_ctx - + # Setup mock resource mock_resource = MagicMock() mock_resource.content = ResourceContent(type="resource", id="test://resource", uri="test://resource", text="Test") diff --git a/tests/unit/mcpgateway/services/test_server_service.py b/tests/unit/mcpgateway/services/test_server_service.py index 333cae283..27ff9ef0f 100644 --- a/tests/unit/mcpgateway/services/test_server_service.py +++ b/tests/unit/mcpgateway/services/test_server_service.py @@ -516,16 +516,16 @@ async def test_update_server(self, server_service, mock_server, test_db, mock_to side_effect=lambda cls, _id: ( mock_server if (cls, _id) == (DbServer, 1) - else None + else None ) ) # FIX: Configure db.execute to handle both the conflict check and the bulk item fetches mock_db_result = MagicMock() - + # 1. 
Handle name conflict check: scalar_one_or_none() -> None mock_db_result.scalar_one_or_none.return_value = None - + # 2. Handle bulk fetches: scalars().all() -> lists of items # The code executes bulk queries in this order: Tools -> Resources -> Prompts mock_db_result.scalars.return_value.all.side_effect = [ @@ -533,9 +533,9 @@ async def test_update_server(self, server_service, mock_server, test_db, mock_to [new_resource], # Second call: select(DbResource)... [new_prompt] # Third call: select(DbPrompt)... ] - + test_db.execute = Mock(return_value=mock_db_result) - + test_db.commit = Mock() test_db.refresh = Mock() @@ -553,7 +553,7 @@ async def test_update_server(self, server_service, mock_server, test_db, mock_to mock_tools.__iter__ = Mock(return_value=iter(tool_items)) mock_resources.__iter__ = Mock(return_value=iter(resource_items)) mock_prompts.__iter__ = Mock(return_value=iter(prompt_items)) - + # Capture assignment to the lists (since the new code does server.tools = list(...)) mock_server.tools = tool_items mock_server.resources = resource_items diff --git a/tests/unit/mcpgateway/services/test_tool_service.py b/tests/unit/mcpgateway/services/test_tool_service.py index 5beeeab27..6f6fe4e90 100644 --- a/tests/unit/mcpgateway/services/test_tool_service.py +++ b/tests/unit/mcpgateway/services/test_tool_service.py @@ -36,6 +36,16 @@ from mcpgateway.utils.services_auth import encode_auth +@pytest.fixture(autouse=True) +def mock_logging_services(): + """Mock audit_trail and structured_logger to prevent database writes during tests.""" + with patch("mcpgateway.services.tool_service.audit_trail") as mock_audit, \ + patch("mcpgateway.services.tool_service.structured_logger") as mock_logger: + mock_audit.log_action = MagicMock(return_value=None) + mock_logger.log = MagicMock(return_value=None) + yield {"audit_trail": mock_audit, "structured_logger": mock_logger} + + @pytest.fixture def tool_service(): """Create a tool service instance.""" @@ -290,7 +300,8 @@ async def test_register_tool(self, tool_service, mock_tool, test_db): # Verify DB operations test_db.add.assert_called_once() test_db.commit.assert_called_once() - test_db.refresh.assert_called_once() + # refresh is called twice: once after commit and once after logging commits + assert test_db.refresh.call_count == 2 # Verify result assert result.name == "test-gateway-test-tool" @@ -1875,7 +1886,7 @@ async def test_aggregate_metrics_no_data(self, tool_service): "avg_response_time": None, "last_execution_time": None, } - + # Verify optimization assert mock_db.execute.call_count == 1 @@ -1987,7 +1998,7 @@ async def test_get_top_tools(self, tool_service, test_db): with patch("mcpgateway.services.tool_service.build_top_performers") as mock_build: mock_build.return_value = ["top_performer1", "top_performer2"] - + # Run the method result = await tool_service.get_top_tools(test_db, limit=5) @@ -1996,7 +2007,7 @@ async def test_get_top_tools(self, tool_service, test_db): # Assert build_top_performers was called with the mock results mock_build.assert_called_once_with(mock_results) - + # Verify that the execute method was called once test_db.execute.assert_called_once() diff --git a/tests/unit/mcpgateway/test_admin.py b/tests/unit/mcpgateway/test_admin.py index 21b5791f1..5eb16a540 100644 --- a/tests/unit/mcpgateway/test_admin.py +++ b/tests/unit/mcpgateway/test_admin.py @@ -804,17 +804,19 @@ async def test_admin_list_resources_with_complex_data(self, mock_list_resources, @patch.object(ResourceService, "get_resource_by_id") 
@patch.object(ResourceService, "read_resource") async def test_admin_get_resource_with_read_error(self, mock_read_resource, mock_get_resource, mock_db): - """Test getting resource when content read fails.""" - # Resource exists + """Test that admin_get_resource returns the resource without calling read_resource.""" + mock_resource = MagicMock() mock_resource.model_dump.return_value = {"id": 1, "uri": "/test/resource"} mock_get_resource.return_value = mock_resource - # But reading content fails mock_read_resource.side_effect = IOError("Cannot read resource content") - with pytest.raises(IOError): - await admin_get_resource("1", mock_db, "test-user") + result = await admin_get_resource("1", mock_db, "test-user") + + assert result["resource"]["id"] == 1 + mock_read_resource.assert_not_called() + @patch.object(ResourceService, "register_resource") async def test_admin_add_resource_with_valid_mime_type(self, mock_register_resource, mock_request, mock_db): @@ -1357,7 +1359,8 @@ async def test_admin_test_gateway_various_methods(self): mock_client_class.return_value = mock_client - result = await admin_test_gateway(request, "test-user") + mock_db = MagicMock() + result = await admin_test_gateway(request, None, "test-user", mock_db) assert result.status_code == 200 mock_client.request.assert_called_once() @@ -1396,7 +1399,8 @@ async def test_admin_test_gateway_url_construction(self): mock_client_class.return_value = mock_client - await admin_test_gateway(request, "test-user") + mock_db = MagicMock() + await admin_test_gateway(request, None, "test-user", mock_db) call_args = mock_client.request.call_args assert call_args[1]["url"] == expected_url @@ -1422,7 +1426,8 @@ async def test_admin_test_gateway_timeout_handling(self): mock_client_class.return_value = mock_client - result = await admin_test_gateway(request, "test-user") + mock_db = MagicMock() + result = await admin_test_gateway(request, None, "test-user", mock_db) assert result.status_code == 502 assert "Request timed out" in str(result.body) @@ -1459,7 +1464,8 @@ async def test_admin_test_gateway_non_json_response(self): mock_client_class.return_value = mock_client - result = await admin_test_gateway(request, "test-user") + mock_db = MagicMock() + result = await admin_test_gateway(request, None, "test-user", mock_db) assert result.status_code == 200 assert result.body["details"] == response_text diff --git a/tests/unit/mcpgateway/test_main.py b/tests/unit/mcpgateway/test_main.py index 94e0beb80..fabf401bc 100644 --- a/tests/unit/mcpgateway/test_main.py +++ b/tests/unit/mcpgateway/test_main.py @@ -193,13 +193,20 @@ def test_client(app): # Patch the auth function used by DocsAuthMiddleware # Standard - from unittest.mock import patch + from unittest.mock import MagicMock, patch # Third-Party from fastapi import HTTPException, status # First-Party + # Mock security_logger to prevent database access + mock_sec_logger = MagicMock() + mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None) + mock_sec_logger.log_security_event = MagicMock(return_value=None) + sec_patcher = patch("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger) + sec_patcher.start() + # Create a mock that validates JWT tokens properly async def mock_require_auth_override(auth_header=None, jwt_token=None): # Third-Party @@ -270,6 +277,7 @@ async def mock_check_permission(self, user_email: str, permission: str, resource app.dependency_overrides.pop(get_current_user, None) app.dependency_overrides.pop(get_current_user_with_permissions, None) patcher.stop() # Stop the
require_auth_override patch + sec_patcher.stop() # Stop the security_logger patch if hasattr(PermissionService, "_original_check_permission"): PermissionService.check_permission = PermissionService._original_check_permission diff --git a/tests/unit/mcpgateway/test_main_extended.py b/tests/unit/mcpgateway/test_main_extended.py index ceb40f763..2079b692d 100644 --- a/tests/unit/mcpgateway/test_main_extended.py +++ b/tests/unit/mcpgateway/test_main_extended.py @@ -324,7 +324,7 @@ def test_server_toggle_edge_cases(self, test_client, auth_headers): def test_client(app): """Test client with auth override for testing protected endpoints.""" # Standard - from unittest.mock import patch + from unittest.mock import MagicMock, patch # First-Party from mcpgateway.auth import get_current_user @@ -341,6 +341,13 @@ def test_client(app): auth_provider="test", ) + # Mock security_logger to prevent database access + mock_sec_logger = MagicMock() + mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None) + mock_sec_logger.log_security_event = MagicMock(return_value=None) + sec_patcher = patch("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger) + sec_patcher.start() + # Mock require_auth_override function def mock_require_auth_override(user: str) -> str: return user @@ -390,6 +397,7 @@ async def mock_check_permission( app.dependency_overrides.pop(get_current_user, None) app.dependency_overrides.pop(get_current_user_with_permissions, None) patcher.stop() # Stop the require_auth_override patch + sec_patcher.stop() # Stop the security_logger patch if hasattr(PermissionService, "_original_check_permission"): PermissionService.check_permission = PermissionService._original_check_permission diff --git a/tests/unit/mcpgateway/test_oauth_manager.py b/tests/unit/mcpgateway/test_oauth_manager.py index 3431119a6..01bc68614 100644 --- a/tests/unit/mcpgateway/test_oauth_manager.py +++ b/tests/unit/mcpgateway/test_oauth_manager.py @@ -2168,7 +2168,7 @@ def test_is_token_expired_no_expires_at(self): result = service._is_token_expired(token_record) - assert result is True + assert result is False def test_is_token_expired_past_expiry(self): """Test _is_token_expired with past expiration.""" diff --git a/tests/unit/mcpgateway/test_translate_stdio_endpoint.py b/tests/unit/mcpgateway/test_translate_stdio_endpoint.py index 708d605ed..a2464b33d 100644 --- a/tests/unit/mcpgateway/test_translate_stdio_endpoint.py +++ b/tests/unit/mcpgateway/test_translate_stdio_endpoint.py @@ -315,7 +315,7 @@ async def test_empty_env_vars(self, echo_script): await endpoint.send("hello world\n") # Wait for response - await asyncio.sleep(0.1) + await asyncio.sleep(0.5) # Check that process was started assert endpoint._proc is not None diff --git a/tests/unit/mcpgateway/tools/__init__.py b/tests/unit/mcpgateway/tools/__init__.py new file mode 100644 index 000000000..eee1aa024 --- /dev/null +++ b/tests/unit/mcpgateway/tools/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/tools/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor +""" diff --git a/tests/unit/mcpgateway/tools/builder/__init__.py b/tests/unit/mcpgateway/tools/builder/__init__.py new file mode 100644 index 000000000..e63d648ed --- /dev/null +++ b/tests/unit/mcpgateway/tools/builder/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/tools/builder/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor 
+""" diff --git a/tests/unit/mcpgateway/tools/builder/test_cli.py b/tests/unit/mcpgateway/tools/builder/test_cli.py new file mode 100644 index 000000000..5328f03c3 --- /dev/null +++ b/tests/unit/mcpgateway/tools/builder/test_cli.py @@ -0,0 +1,509 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/tools/builder/test_cli.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for builder CLI commands. +""" + +# Standard +import os +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +# Third-Party +import pytest +import typer +from typer.testing import CliRunner + +# First-Party +from mcpgateway.tools.builder.cli import app, main + + +@pytest.fixture +def runner(): + """Create CLI test runner.""" + return CliRunner() + + +@pytest.fixture +def mock_deployer(): + """Create mock deployer instance.""" + deployer = MagicMock() + deployer.validate = MagicMock() + deployer.build = AsyncMock() + deployer.generate_certificates = AsyncMock() + deployer.deploy = AsyncMock() + deployer.verify = AsyncMock() + deployer.destroy = AsyncMock() + deployer.generate_manifests = MagicMock(return_value=Path("/tmp/manifests")) + return deployer + + +class TestCLICallback: + """Test CLI callback initialization.""" + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_cli_callback_default(self, mock_factory, runner): + """Test CLI callback with default options (Python mode by default).""" + mock_deployer = MagicMock() + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_cli_callback_verbose(self, mock_factory, runner): + """Test CLI callback with verbose flag (Python mode by default).""" + mock_deployer = MagicMock() + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["--verbose", "--help"]) + assert result.exit_code == 0 + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_cli_callback_with_dagger(self, mock_factory, runner, tmp_path): + """Test CLI callback with --dagger flag (opt-in).""" + mock_deployer = MagicMock() + mock_deployer.validate = MagicMock() + mock_factory.return_value = (mock_deployer, "dagger") + + config_file = tmp_path / "test-config.yaml" + config_file.write_text("deployment:\n type: compose\n") + + # Use validate command which invokes the callback + result = runner.invoke(app, ["--dagger", "validate", str(config_file)]) + assert result.exit_code == 0 + # Verify dagger mode was requested + mock_factory.assert_called_once_with("dagger", False) + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_cli_callback_default_python(self, mock_factory, runner, tmp_path): + """Test CLI callback defaults to Python mode.""" + mock_deployer = MagicMock() + mock_deployer.validate = MagicMock() + mock_factory.return_value = (mock_deployer, "python") + + config_file = tmp_path / "test-config.yaml" + config_file.write_text("deployment:\n type: compose\n") + + # Use validate command without --dagger flag to test default + result = runner.invoke(app, ["validate", str(config_file)]) + assert result.exit_code == 0 + # Verify python mode was requested (default) + mock_factory.assert_called_once_with("python", False) + + +class TestValidateCommand: + """Test validate command.""" + + 
@patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_validate_success(self, mock_factory, runner, tmp_path, mock_deployer): + """Test successful configuration validation.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + mock_deployer.validate.return_value = None + + result = runner.invoke(app, ["validate", str(config_file)]) + assert result.exit_code == 0 + assert "Configuration valid" in result.stdout + mock_deployer.validate.assert_called_once() + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_validate_failure(self, mock_factory, runner, tmp_path, mock_deployer): + """Test validation failure.""" + config_file = tmp_path / "invalid-config.yaml" + config_file.write_text("invalid: yaml\n") + + mock_factory.return_value = (mock_deployer, "python") + mock_deployer.validate.side_effect = ValueError("Invalid configuration") + + result = runner.invoke(app, ["validate", str(config_file)]) + assert result.exit_code == 1 + assert "Validation failed" in result.stdout + + +class TestBuildCommand: + """Test build command.""" + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_build_success(self, mock_factory, runner, tmp_path, mock_deployer): + """Test successful build.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("gateway:\n image: test:latest\n") + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["build", str(config_file)]) + assert result.exit_code == 0 + assert "Build complete" in result.stdout + mock_deployer.build.assert_called_once() + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_build_plugins_only(self, mock_factory, runner, tmp_path, mock_deployer): + """Test building only plugins.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("plugins:\n - name: TestPlugin\n") + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["build", str(config_file), "--plugins-only"]) + assert result.exit_code == 0 + # Verify plugins_only flag was passed + call_kwargs = mock_deployer.build.call_args[1] + assert call_kwargs["plugins_only"] is True + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_build_specific_plugins(self, mock_factory, runner, tmp_path, mock_deployer): + """Test building specific plugins.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("plugins:\n - name: Plugin1\n - name: Plugin2\n") + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke( + app, ["build", str(config_file), "--plugin", "Plugin1", "--plugin", "Plugin2"] + ) + assert result.exit_code == 0 + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_build_no_cache(self, mock_factory, runner, tmp_path, mock_deployer): + """Test building with --no-cache flag.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("gateway:\n image: test:latest\n") + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["build", str(config_file), "--no-cache"]) + assert result.exit_code == 0 + call_kwargs = mock_deployer.build.call_args[1] + assert call_kwargs["no_cache"] is True + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_build_failure(self, mock_factory, runner, tmp_path, 
mock_deployer): + """Test build failure.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("gateway:\n image: test:latest\n") + + mock_factory.return_value = (mock_deployer, "python") + mock_deployer.build.side_effect = RuntimeError("Build failed") + + result = runner.invoke(app, ["build", str(config_file)]) + assert result.exit_code == 1 + assert "Build failed" in result.stdout + + +class TestCertsCommand: + """Test certs command.""" + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_certs_success(self, mock_factory, runner, tmp_path, mock_deployer): + """Test successful certificate generation.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("plugins:\n - name: TestPlugin\n") + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["certs", str(config_file)]) + assert result.exit_code == 0 + assert "Certificates generated" in result.stdout + mock_deployer.generate_certificates.assert_called_once() + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_certs_failure(self, mock_factory, runner, tmp_path, mock_deployer): + """Test certificate generation failure.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("plugins:\n - name: TestPlugin\n") + + mock_factory.return_value = (mock_deployer, "python") + mock_deployer.generate_certificates.side_effect = RuntimeError("Cert generation failed") + + result = runner.invoke(app, ["certs", str(config_file)]) + assert result.exit_code == 1 + assert "Certificate generation failed" in result.stdout + + +class TestDeployCommand: + """Test deploy command.""" + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_deploy_success(self, mock_factory, runner, tmp_path, mock_deployer): + """Test successful deployment.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["deploy", str(config_file)]) + assert result.exit_code == 0 + assert "Deployment complete" in result.stdout + mock_deployer.deploy.assert_called_once() + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_deploy_dry_run(self, mock_factory, runner, tmp_path, mock_deployer): + """Test dry-run deployment.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["deploy", str(config_file), "--dry-run"]) + assert result.exit_code == 0 + assert "Dry-run complete" in result.stdout + call_kwargs = mock_deployer.deploy.call_args[1] + assert call_kwargs["dry_run"] is True + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_deploy_skip_build(self, mock_factory, runner, tmp_path, mock_deployer): + """Test deployment with --skip-build.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["deploy", str(config_file), "--skip-build"]) + assert result.exit_code == 0 + call_kwargs = mock_deployer.deploy.call_args[1] + assert call_kwargs["skip_build"] is True + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_deploy_skip_certs(self, mock_factory, runner, tmp_path, mock_deployer): + """Test deployment with 
--skip-certs.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["deploy", str(config_file), "--skip-certs"]) + assert result.exit_code == 0 + call_kwargs = mock_deployer.deploy.call_args[1] + assert call_kwargs["skip_certs"] is True + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_deploy_custom_output_dir(self, mock_factory, runner, tmp_path, mock_deployer): + """Test deployment with custom output directory.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + output_dir = tmp_path / "custom-output" + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["deploy", str(config_file), "--output-dir", str(output_dir)]) + assert result.exit_code == 0 + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_deploy_failure(self, mock_factory, runner, tmp_path, mock_deployer): + """Test deployment failure.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + mock_deployer.deploy.side_effect = RuntimeError("Deployment failed") + + result = runner.invoke(app, ["deploy", str(config_file)]) + assert result.exit_code == 1 + assert "Deployment failed" in result.stdout + + +class TestVerifyCommand: + """Test verify command.""" + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_verify_success(self, mock_factory, runner, tmp_path, mock_deployer): + """Test successful deployment verification.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["verify", str(config_file)]) + assert result.exit_code == 0 + assert "Deployment healthy" in result.stdout + mock_deployer.verify.assert_called_once() + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_verify_with_wait(self, mock_factory, runner, tmp_path, mock_deployer): + """Test verification with default wait behavior (wait=True by default).""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + + # Default wait is True, so just run verify without any flags + result = runner.invoke(app, ["verify", str(config_file)]) + assert result.exit_code == 0 + call_kwargs = mock_deployer.verify.call_args[1] + assert call_kwargs["wait"] is True + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_verify_with_timeout(self, mock_factory, runner, tmp_path, mock_deployer): + """Test verification with custom timeout.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["verify", str(config_file), "--timeout", "600"]) + assert result.exit_code == 0 + call_kwargs = mock_deployer.verify.call_args[1] + assert call_kwargs["timeout"] == 600 + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_verify_failure(self, mock_factory, runner, tmp_path, mock_deployer): + """Test verification failure.""" + config_file = tmp_path / "mcp-stack.yaml" + 
config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + mock_deployer.verify.side_effect = RuntimeError("Verification failed") + + result = runner.invoke(app, ["verify", str(config_file)]) + assert result.exit_code == 1 + assert "Verification failed" in result.stdout + + +class TestDestroyCommand: + """Test destroy command.""" + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_destroy_with_force(self, mock_factory, runner, tmp_path, mock_deployer): + """Test destroy with --force flag.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["destroy", str(config_file), "--force"]) + assert result.exit_code == 0 + assert "Deployment destroyed" in result.stdout + mock_deployer.destroy.assert_called_once() + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_destroy_with_confirmation(self, mock_factory, runner, tmp_path, mock_deployer): + """Test destroy with user confirmation.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + + # Simulate user confirming "yes" + result = runner.invoke(app, ["destroy", str(config_file)], input="y\n") + assert result.exit_code == 0 + assert "Deployment destroyed" in result.stdout + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_destroy_abort(self, mock_factory, runner, tmp_path, mock_deployer): + """Test aborting destroy command.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + + # Simulate user declining "no" + result = runner.invoke(app, ["destroy", str(config_file)], input="n\n") + assert "Aborted" in result.stdout + mock_deployer.destroy.assert_not_called() + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_destroy_failure(self, mock_factory, runner, tmp_path, mock_deployer): + """Test destroy failure.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + mock_deployer.destroy.side_effect = RuntimeError("Destruction failed") + + result = runner.invoke(app, ["destroy", str(config_file), "--force"]) + assert result.exit_code == 1 + assert "Destruction failed" in result.stdout + + +class TestGenerateCommand: + """Test generate command.""" + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_generate_success(self, mock_factory, runner, tmp_path, mock_deployer): + """Test successful manifest generation.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["generate", str(config_file)]) + assert result.exit_code == 0 + assert "Manifests generated" in result.stdout + mock_deployer.generate_manifests.assert_called_once() + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_generate_with_output_dir(self, mock_factory, runner, tmp_path, mock_deployer): + """Test manifest generation with custom output directory.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: 
compose\n") + output_dir = tmp_path / "custom-manifests" + + mock_factory.return_value = (mock_deployer, "python") + + result = runner.invoke(app, ["generate", str(config_file), "--output", str(output_dir)]) + assert result.exit_code == 0 + + @patch("mcpgateway.tools.builder.cli.DeployFactory.create_deployer") + def test_generate_failure(self, mock_factory, runner, tmp_path, mock_deployer): + """Test manifest generation failure.""" + config_file = tmp_path / "mcp-stack.yaml" + config_file.write_text("deployment:\n type: compose\n") + + mock_factory.return_value = (mock_deployer, "python") + mock_deployer.generate_manifests.side_effect = ValueError("Generation failed") + + result = runner.invoke(app, ["generate", str(config_file)]) + assert result.exit_code == 1 + assert "Manifest generation failed" in result.stdout + + +class TestVersionCommand: + """Test version command.""" + + def test_version(self, runner): + """Test version command.""" + result = runner.invoke(app, ["version"]) + assert result.exit_code == 0 + assert "MCP Deploy" in result.stdout + assert "Version" in result.stdout + + +class TestMainFunction: + """Test main entry point.""" + + @patch("mcpgateway.tools.builder.cli.app") + def test_main_success(self, mock_app): + """Test successful main execution.""" + mock_app.return_value = None + main() + mock_app.assert_called_once() + + @patch("mcpgateway.tools.builder.cli.app") + def test_main_keyboard_interrupt(self, mock_app): + """Test main with keyboard interrupt.""" + mock_app.side_effect = KeyboardInterrupt() + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 130 + + @patch("mcpgateway.tools.builder.cli.app") + def test_main_exception_no_debug(self, mock_app): + """Test main with exception (no debug mode).""" + mock_app.side_effect = RuntimeError("Test error") + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + @patch("mcpgateway.tools.builder.cli.app") + @patch.dict(os.environ, {"MCP_DEBUG": "1"}) + def test_main_exception_debug_mode(self, mock_app): + """Test main with exception (debug mode enabled).""" + mock_app.side_effect = RuntimeError("Test error") + with pytest.raises(RuntimeError, match="Test error"): + main() diff --git a/tests/unit/mcpgateway/tools/builder/test_common.py b/tests/unit/mcpgateway/tools/builder/test_common.py new file mode 100644 index 000000000..fdfc26036 --- /dev/null +++ b/tests/unit/mcpgateway/tools/builder/test_common.py @@ -0,0 +1,994 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/tools/builder/test_common.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for builder common utilities. 
+""" + +# Standard +import os +from pathlib import Path +import shutil +import subprocess +from unittest.mock import MagicMock, Mock, patch +from mcpgateway.tools.builder.schema import MCPStackConfig + +# Third-Party +import pytest +import yaml + +# First-Party +from mcpgateway.tools.builder.common import ( + copy_env_template, + deploy_compose, + deploy_kubernetes, + destroy_compose, + destroy_kubernetes, + generate_compose_manifests, + generate_kubernetes_manifests, + generate_plugin_config, + get_deploy_dir, + get_docker_compose_command, + load_config, + run_compose, + verify_compose, + verify_kubernetes, +) + + +class TestGetDeployDir: + """Test get_deploy_dir function.""" + + def test_default_deploy_dir(self): + """Test default deploy directory.""" + with patch.dict(os.environ, {}, clear=True): + result = get_deploy_dir() + assert result == Path("./deploy") + + def test_custom_deploy_dir(self): + """Test custom deploy directory from environment variable.""" + with patch.dict(os.environ, {"MCP_DEPLOY_DIR": "/custom/deploy"}): + result = get_deploy_dir() + assert result == Path("/custom/deploy") + + +class TestLoadConfig: + """Test load_config function.""" + + def test_load_valid_config(self, tmp_path): + """Test loading valid YAML configuration.""" + config_file = tmp_path / "mcp-stack.yaml" + config_data = { + "deployment": {"type": "compose", "project_name": "test"}, + "gateway": {"image": "mcpgateway:latest"}, + "plugins": [], + } + config_file.write_text(yaml.dump(config_data)) + + result = load_config(str(config_file)) + assert result.deployment.type == "compose" + assert result.gateway.image == "mcpgateway:latest" + + def test_load_nonexistent_config(self): + """Test loading non-existent configuration file.""" + with pytest.raises(FileNotFoundError, match="Configuration file not found"): + load_config("/nonexistent/config.yaml") + + +class TestGeneratePluginConfig: + """Test generate_plugin_config function.""" + + @patch("mcpgateway.tools.builder.common.Environment") + def test_generate_plugin_config_compose(self, mock_env_class, tmp_path): + """Test generating plugin config for Docker Compose deployment.""" + # Setup mock template + mock_template = MagicMock() + mock_template.render.return_value = "plugins:\n - name: TestPlugin\n" + mock_env = MagicMock() + mock_env.get_template.return_value = mock_template + mock_env_class.return_value = mock_env + + # Create fake template directory + template_dir = tmp_path / "templates" + template_dir.mkdir() + + config = MCPStackConfig.model_validate({ + "gateway": {"image": "mcpgateway:latest"}, + "deployment": {"type": "compose"}, + "plugins": [ + {"name": "TestPlugin", "port": 8000, "mtls_enabled": True, "repo": "https://github.com/test/plugin.git"} + ], + }) + + with patch("mcpgateway.tools.builder.common.Path") as mock_path: + mock_path.return_value.__truediv__.return_value = template_dir + output_dir = tmp_path / "output" + output_dir.mkdir() + + result = generate_plugin_config(config, output_dir) + + # Verify template was called + mock_env.get_template.assert_called_once_with("plugins-config.yaml.j2") + assert result == output_dir / "plugins-config.yaml" + + @patch("mcpgateway.tools.builder.common.Environment") + def test_generate_plugin_config_kubernetes(self, mock_env_class, tmp_path): + """Test generating plugin config for Kubernetes deployment.""" + # Setup mock template + mock_template = MagicMock() + mock_template.render.return_value = "plugins:\n - name: TestPlugin\n" + mock_env = MagicMock() + 
mock_env.get_template.return_value = mock_template + mock_env_class.return_value = mock_env + + # Create fake template directory + template_dir = tmp_path / "templates" + template_dir.mkdir() + + config = MCPStackConfig.model_validate({ + "gateway": {"image": "mcpgateway:latest"}, + "deployment": {"type": "kubernetes", "namespace": "test-ns"}, + "plugins": [ + {"name": "TestPlugin", "port": 8000, "mtls_enabled": False, "repo": "https://github.com/test/plugin1.git"} + ], + }) + + with patch("mcpgateway.tools.builder.common.Path") as mock_path: + mock_path.return_value.__truediv__.return_value = template_dir + output_dir = tmp_path / "output" + output_dir.mkdir() + + result = generate_plugin_config(config, output_dir) + + # Verify template was called + assert mock_env.get_template.called + assert result == output_dir / "plugins-config.yaml" + + @patch("mcpgateway.tools.builder.common.Environment") + def test_generate_plugin_config_with_overrides(self, mock_env_class, tmp_path): + """Test generating plugin config with plugin_overrides.""" + # Setup mock template + mock_template = MagicMock() + mock_template.render.return_value = "plugins:\n - name: TestPlugin\n" + mock_env = MagicMock() + mock_env.get_template.return_value = mock_template + mock_env_class.return_value = mock_env + + # Create fake template directory + template_dir = tmp_path / "templates" + template_dir.mkdir() + + config = MCPStackConfig.model_validate({ + "deployment": {"type": "compose"}, + "gateway": {"image": "mcpgateway:latest"}, + "plugins": [ + { + "name": "TestPlugin", + "port": 8000, + "plugin_overrides": { + "priority": 10, + "mode": "enforce", + "tags": ["security"], + }, + "repo": "https://github.com/test/plugin1.git" + } + ], + }) + + with patch("mcpgateway.tools.builder.common.Path") as mock_path: + mock_path.return_value.__truediv__.return_value = template_dir + output_dir = tmp_path / "output" + output_dir.mkdir() + + result = generate_plugin_config(config, output_dir) + assert result == output_dir / "plugins-config.yaml" + + +class TestCopyEnvTemplate: + """Test copy_env_template function.""" + + def test_copy_env_template_success(self, tmp_path): + """Test successful copying of .env.template.""" + # Create plugin build dir with .env.template + plugin_dir = tmp_path / "plugin" + plugin_dir.mkdir() + template_file = plugin_dir / ".env.template" + template_file.write_text("TEST_VAR=value\n") + + # Setup deploy dir + deploy_dir = tmp_path / "deploy" + + with patch("mcpgateway.tools.builder.common.get_deploy_dir", return_value=deploy_dir): + copy_env_template("TestPlugin", plugin_dir) + + target_file = deploy_dir / "env" / ".env.TestPlugin" + assert target_file.exists() + assert target_file.read_text() == "TEST_VAR=value\n" + + def test_copy_env_template_no_template(self, tmp_path): + """Test when .env.template doesn't exist.""" + plugin_dir = tmp_path / "plugin" + plugin_dir.mkdir() + + deploy_dir = tmp_path / "deploy" + + with patch("mcpgateway.tools.builder.common.get_deploy_dir", return_value=deploy_dir): + # Should not raise error, just skip + copy_env_template("TestPlugin", plugin_dir, verbose=True) + + def test_copy_env_template_target_exists(self, tmp_path): + """Test when target file already exists.""" + # Create plugin build dir with .env.template + plugin_dir = tmp_path / "plugin" + plugin_dir.mkdir() + template_file = plugin_dir / ".env.template" + template_file.write_text("NEW_VAR=newvalue\n") + + # Setup deploy dir with existing target + deploy_dir = tmp_path / "deploy" + deploy_dir.mkdir() + 
env_dir = deploy_dir / "env" + env_dir.mkdir() + target_file = env_dir / ".env.TestPlugin" + target_file.write_text("OLD_VAR=oldvalue\n") + + with patch("mcpgateway.tools.builder.common.get_deploy_dir", return_value=deploy_dir): + copy_env_template("TestPlugin", plugin_dir) + + # Should not overwrite + assert target_file.read_text() == "OLD_VAR=oldvalue\n" + + +class TestGetDockerComposeCommand: + """Test get_docker_compose_command function.""" + + @patch("mcpgateway.tools.builder.common.shutil.which") + @patch("mcpgateway.tools.builder.common.subprocess.run") + def test_docker_compose_plugin(self, mock_run, mock_which): + """Test detecting docker compose plugin.""" + mock_which.return_value = "/usr/bin/docker" + mock_run.return_value = Mock(returncode=0) + + result = get_docker_compose_command() + assert result == ["docker", "compose"] + + @patch("mcpgateway.tools.builder.common.shutil.which") + @patch("mcpgateway.tools.builder.common.subprocess.run") + def test_docker_compose_standalone(self, mock_run, mock_which): + """Test detecting standalone docker-compose.""" + + def which_side_effect(cmd): + if cmd == "docker": + return "/usr/bin/docker" + elif cmd == "docker-compose": + return "/usr/bin/docker-compose" + return None + + mock_which.side_effect = which_side_effect + mock_run.side_effect = subprocess.CalledProcessError(1, "cmd") + + result = get_docker_compose_command() + assert result == ["docker-compose"] + + @patch("mcpgateway.tools.builder.common.shutil.which") + def test_docker_compose_not_found(self, mock_which): + """Test when docker compose is not available.""" + mock_which.return_value = None + + with pytest.raises(RuntimeError, match="Docker Compose not found"): + get_docker_compose_command() + + +class TestRunCompose: + """Test run_compose function.""" + + @patch("mcpgateway.tools.builder.common.get_docker_compose_command") + @patch("mcpgateway.tools.builder.common.subprocess.run") + def test_run_compose_success(self, mock_run, mock_get_cmd, tmp_path): + """Test successful compose command execution.""" + compose_file = tmp_path / "docker-compose.yaml" + compose_file.write_text("services:\n test: {}\n") + + mock_get_cmd.return_value = ["docker", "compose"] + mock_run.return_value = Mock(returncode=0, stdout="Success", stderr="") + + result = run_compose(compose_file, ["ps"]) + assert result.returncode == 0 + mock_run.assert_called_once() + + @patch("mcpgateway.tools.builder.common.get_docker_compose_command") + def test_run_compose_file_not_found(self, mock_get_cmd, tmp_path): + """Test run_compose with non-existent file.""" + compose_file = tmp_path / "nonexistent.yaml" + mock_get_cmd.return_value = ["docker", "compose"] + + with pytest.raises(FileNotFoundError, match="Compose file not found"): + run_compose(compose_file, ["ps"]) + + @patch("mcpgateway.tools.builder.common.get_docker_compose_command") + @patch("mcpgateway.tools.builder.common.subprocess.run") + def test_run_compose_command_failure(self, mock_run, mock_get_cmd, tmp_path): + """Test run_compose command failure.""" + compose_file = tmp_path / "docker-compose.yaml" + compose_file.write_text("services:\n test: {}\n") + + mock_get_cmd.return_value = ["docker", "compose"] + mock_run.side_effect = subprocess.CalledProcessError( + 1, "cmd", output="", stderr="Error" + ) + + with pytest.raises(RuntimeError, match="Docker Compose failed"): + run_compose(compose_file, ["up", "-d"]) + + +class TestDeployCompose: + """Test deploy_compose function.""" + + @patch("mcpgateway.tools.builder.common.run_compose") + def 
test_deploy_compose_success(self, mock_run, tmp_path): + """Test successful Docker Compose deployment.""" + compose_file = tmp_path / "docker-compose.yaml" + mock_run.return_value = Mock(stdout="Deployed", stderr="") + + deploy_compose(compose_file) + mock_run.assert_called_once_with(compose_file, ["up", "-d"], verbose=False) + + +class TestVerifyCompose: + """Test verify_compose function.""" + + @patch("mcpgateway.tools.builder.common.run_compose") + def test_verify_compose(self, mock_run, tmp_path): + """Test verifying Docker Compose deployment.""" + compose_file = tmp_path / "docker-compose.yaml" + mock_run.return_value = Mock(stdout="test-service running", stderr="") + + result = verify_compose(compose_file) + assert "test-service running" in result + mock_run.assert_called_once_with(compose_file, ["ps"], verbose=False, check=False) + + +class TestDestroyCompose: + """Test destroy_compose function.""" + + @patch("mcpgateway.tools.builder.common.run_compose") + def test_destroy_compose_success(self, mock_run, tmp_path): + """Test successful Docker Compose destruction.""" + compose_file = tmp_path / "docker-compose.yaml" + compose_file.write_text("services:\n test: {}\n") + mock_run.return_value = Mock(stdout="Removed", stderr="") + + destroy_compose(compose_file) + mock_run.assert_called_once_with(compose_file, ["down", "-v"], verbose=False) + + def test_destroy_compose_file_not_found(self, tmp_path): + """Test destroying with non-existent compose file.""" + compose_file = tmp_path / "nonexistent.yaml" + + # Should not raise error, just print warning + destroy_compose(compose_file) + + +class TestDeployKubernetes: + """Test deploy_kubernetes function.""" + + @patch("mcpgateway.tools.builder.common.shutil.which") + @patch("mcpgateway.tools.builder.common.subprocess.run") + def test_deploy_kubernetes_success(self, mock_run, mock_which, tmp_path): + """Test successful Kubernetes deployment.""" + mock_which.return_value = "/usr/bin/kubectl" + mock_run.return_value = Mock(returncode=0, stdout="created", stderr="") + + manifests_dir = tmp_path / "manifests" + manifests_dir.mkdir() + (manifests_dir / "gateway-deployment.yaml").write_text("apiVersion: v1\n") + (manifests_dir / "plugins-config.yaml").write_text("plugins: []\n") + + deploy_kubernetes(manifests_dir) + assert mock_run.called + + @patch("mcpgateway.tools.builder.common.shutil.which") + def test_deploy_kubernetes_kubectl_not_found(self, mock_which, tmp_path): + """Test deployment when kubectl is not available.""" + mock_which.return_value = None + manifests_dir = tmp_path / "manifests" + + with pytest.raises(RuntimeError, match="kubectl not found"): + deploy_kubernetes(manifests_dir) + + @patch("mcpgateway.tools.builder.common.shutil.which") + @patch("mcpgateway.tools.builder.common.subprocess.run") + def test_deploy_kubernetes_with_certs(self, mock_run, mock_which, tmp_path): + """Test Kubernetes deployment with certificate secrets.""" + mock_which.return_value = "/usr/bin/kubectl" + mock_run.return_value = Mock(returncode=0, stdout="created", stderr="") + + manifests_dir = tmp_path / "manifests" + manifests_dir.mkdir() + (manifests_dir / "gateway-deployment.yaml").write_text("apiVersion: v1\n") + (manifests_dir / "cert-secrets.yaml").write_text("apiVersion: v1\n") + + deploy_kubernetes(manifests_dir) + assert mock_run.called + + +class TestVerifyKubernetes: + """Test verify_kubernetes function.""" + + @patch("mcpgateway.tools.builder.common.shutil.which") + @patch("mcpgateway.tools.builder.common.subprocess.run") + def 
test_verify_kubernetes_success(self, mock_run, mock_which): + """Test successful Kubernetes verification.""" + mock_which.return_value = "/usr/bin/kubectl" + mock_run.return_value = Mock( + returncode=0, stdout="pod-1 Running\npod-2 Running", stderr="" + ) + + result = verify_kubernetes("test-ns") + assert "Running" in result + mock_run.assert_called_once() + + @patch("mcpgateway.tools.builder.common.shutil.which") + def test_verify_kubernetes_kubectl_not_found(self, mock_which): + """Test verification when kubectl is not available.""" + mock_which.return_value = None + + with pytest.raises(RuntimeError, match="kubectl not found"): + verify_kubernetes("test-ns") + + @patch("mcpgateway.tools.builder.common.shutil.which") + @patch("mcpgateway.tools.builder.common.subprocess.run") + def test_verify_kubernetes_with_wait(self, mock_run, mock_which): + """Test Kubernetes verification with wait.""" + mock_which.return_value = "/usr/bin/kubectl" + mock_run.return_value = Mock(returncode=0, stdout="Ready", stderr="") + + result = verify_kubernetes("test-ns", wait=True, timeout=60) + assert mock_run.call_count >= 1 + + +class TestDestroyKubernetes: + """Test destroy_kubernetes function.""" + + @patch("mcpgateway.tools.builder.common.shutil.which") + @patch("mcpgateway.tools.builder.common.subprocess.run") + def test_destroy_kubernetes_success(self, mock_run, mock_which, tmp_path): + """Test successful Kubernetes destruction.""" + mock_which.return_value = "/usr/bin/kubectl" + mock_run.return_value = Mock(returncode=0, stdout="deleted", stderr="") + + manifests_dir = tmp_path / "manifests" + manifests_dir.mkdir() + (manifests_dir / "gateway-deployment.yaml").write_text("apiVersion: v1\n") + (manifests_dir / "plugins-config.yaml").write_text("plugins: []\n") + + destroy_kubernetes(manifests_dir) + assert mock_run.called + + @patch("mcpgateway.tools.builder.common.shutil.which") + def test_destroy_kubernetes_kubectl_not_found(self, mock_which, tmp_path): + """Test destruction when kubectl is not available.""" + mock_which.return_value = None + manifests_dir = tmp_path / "manifests" + + with pytest.raises(RuntimeError, match="kubectl not found"): + destroy_kubernetes(manifests_dir) + + def test_destroy_kubernetes_dir_not_found(self, tmp_path): + """Test destroying with non-existent manifests directory.""" + manifests_dir = tmp_path / "nonexistent" + + with patch("mcpgateway.tools.builder.common.shutil.which", return_value="/usr/bin/kubectl"): + # Should not raise error, just print warning + destroy_kubernetes(manifests_dir) + + +class TestGenerateKubernetesManifests: + """Test generate_kubernetes_manifests function with real template rendering.""" + + def test_generate_manifests_gateway_only(self, tmp_path): + """Test generating Kubernetes manifests for gateway only.""" + output_dir = tmp_path / "manifests" + output_dir.mkdir() + + config = MCPStackConfig.model_validate({ + "deployment": {"type": "kubernetes", "namespace": "test-ns"}, + "gateway": { + "image": "mcpgateway:latest", + "port": 4444, + "mtls_enabled": False, + }, + "plugins": [], + }) + + generate_kubernetes_manifests(config, output_dir) + + # Verify gateway deployment was created + gateway_file = output_dir / "gateway-deployment.yaml" + assert gateway_file.exists() + + # Parse and validate YAML + with open(gateway_file) as f: + docs = list(yaml.safe_load_all(f)) + + # Should have Deployment and Service + assert len(docs) >= 2 + + # Validate Deployment + deployment = next((d for d in docs if d.get("kind") == "Deployment"), None) + assert 
deployment is not None + assert deployment["metadata"]["name"] == "mcpgateway" + assert deployment["metadata"]["namespace"] == "test-ns" + assert deployment["spec"]["template"]["spec"]["containers"][0]["image"] == "mcpgateway:latest" + + # Validate Service + service = next((d for d in docs if d.get("kind") == "Service"), None) + assert service is not None + assert service["metadata"]["name"] == "mcpgateway" + assert service["spec"]["ports"][0]["port"] == 4444 + + def test_generate_manifests_with_plugins(self, tmp_path): + """Test generating Kubernetes manifests with plugins.""" + output_dir = tmp_path / "manifests" + output_dir.mkdir() + + config = MCPStackConfig.model_validate({ + "deployment": {"type": "kubernetes", "namespace": "mcp-test"}, + "gateway": { + "image": "mcpgateway:latest", + "port": 4444, + "mtls_enabled": False, + }, + "plugins": [ + { + "name": "TestPlugin", + "image": "test-plugin:v1", + "port": 8000, + "mtls_enabled": False, + }, + { + "name": "AnotherPlugin", + "image": "another-plugin:v2", + "port": 8001, + "mtls_enabled": False, + }, + ], + }) + + generate_kubernetes_manifests(config, output_dir) + + # Verify plugin deployments were created + plugin1_file = output_dir / "plugin-testplugin-deployment.yaml" + plugin2_file = output_dir / "plugin-anotherplugin-deployment.yaml" + + assert plugin1_file.exists() + assert plugin2_file.exists() + + # Parse and validate first plugin + with open(plugin1_file) as f: + docs = list(yaml.safe_load_all(f)) + + deployment = next((d for d in docs if d.get("kind") == "Deployment"), None) + assert deployment is not None + assert deployment["metadata"]["name"] == "mcp-plugin-testplugin" + assert deployment["metadata"]["namespace"] == "mcp-test" + assert deployment["spec"]["template"]["spec"]["containers"][0]["image"] == "test-plugin:v1" + + def test_generate_manifests_with_mtls(self, tmp_path): + """Test generating Kubernetes manifests with mTLS enabled.""" + # Change to tmp_path to ensure we have a valid working directory + original_dir = None + try: + original_dir = os.getcwd() + except (FileNotFoundError, OSError): + pass # Current directory doesn't exist + + os.chdir(tmp_path) + + try: + output_dir = tmp_path / "manifests" + output_dir.mkdir() + + # Create fake certificate files in the actual location where the code looks + certs_dir = Path("certs/mcp") + ca_dir = certs_dir / "ca" + gateway_dir = certs_dir / "gateway" + plugin_dir = certs_dir / "plugins" / "SecurePlugin" + + ca_dir.mkdir(parents=True, exist_ok=True) + gateway_dir.mkdir(parents=True, exist_ok=True) + plugin_dir.mkdir(parents=True, exist_ok=True) + + (ca_dir / "ca.crt").write_bytes(b"fake-ca-cert") + (gateway_dir / "client.crt").write_bytes(b"fake-gateway-cert") + (gateway_dir / "client.key").write_bytes(b"fake-gateway-key") + (plugin_dir / "server.crt").write_bytes(b"fake-plugin-cert") + (plugin_dir / "server.key").write_bytes(b"fake-plugin-key") + + config = MCPStackConfig.model_validate({ + "deployment": {"type": "kubernetes", "namespace": "secure-ns"}, + "gateway": { + "image": "mcpgateway:latest", + "port": 4444, + "mtls_enabled": True, + }, + "plugins": [ + { + "name": "SecurePlugin", + "image": "secure-plugin:v1", + "port": 8000, + "mtls_enabled": True, + } + ], + }) + + generate_kubernetes_manifests(config, output_dir) + finally: + # Clean up created certificate files + if Path("certs").exists(): + shutil.rmtree("certs") + + # Restore original directory if it exists + if original_dir and Path(original_dir).exists(): + os.chdir(original_dir) + + # Verify 
certificate secrets were created + cert_secrets_file = output_dir / "cert-secrets.yaml" + assert cert_secrets_file.exists() + + # Parse and validate secrets + with open(cert_secrets_file) as f: + docs = list(yaml.safe_load_all(f)) + + # Should have secrets for CA, gateway, and plugin + secrets = [d for d in docs if d.get("kind") == "Secret"] + assert len(secrets) >= 2 # At least gateway and plugin secrets + + def test_generate_manifests_with_infrastructure(self, tmp_path): + """Test generating Kubernetes manifests with PostgreSQL and Redis.""" + output_dir = tmp_path / "manifests" + output_dir.mkdir() + + config = MCPStackConfig.model_validate({ + "deployment": {"type": "kubernetes", "namespace": "infra-ns"}, + "gateway": { + "image": "mcpgateway:latest", + "port": 4444, + "mtls_enabled": False, + }, + "plugins": [], + "infrastructure": { + "postgres": { + "enabled": True, + "image": "postgres:17", + "database": "testdb", + "user": "testuser", + "password": "testpass", + }, + "redis": { + "enabled": True, + "image": "redis:alpine", + }, + }, + }) + + generate_kubernetes_manifests(config, output_dir) + + # Verify infrastructure manifests were created + postgres_file = output_dir / "postgres-deployment.yaml" + redis_file = output_dir / "redis-deployment.yaml" + + assert postgres_file.exists() + assert redis_file.exists() + + # Parse and validate PostgreSQL + with open(postgres_file) as f: + docs = list(yaml.safe_load_all(f)) + + postgres_deployment = next((d for d in docs if d.get("kind") == "Deployment"), None) + assert postgres_deployment is not None + assert postgres_deployment["metadata"]["name"] == "postgres" + assert postgres_deployment["spec"]["template"]["spec"]["containers"][0]["image"] == "postgres:17" + + # Parse and validate Redis + with open(redis_file) as f: + docs = list(yaml.safe_load_all(f)) + + redis_deployment = next((d for d in docs if d.get("kind") == "Deployment"), None) + assert redis_deployment is not None + assert redis_deployment["metadata"]["name"] == "redis" + + # Verify gateway has database environment variables in Secret + gateway_file = output_dir / "gateway-deployment.yaml" + with open(gateway_file) as f: + docs = list(yaml.safe_load_all(f)) + + # Find the Secret containing environment variables + secret = next((d for d in docs if d.get("kind") == "Secret" and d["metadata"]["name"] == "mcpgateway-env"), None) + assert secret is not None + assert "stringData" in secret + + string_data = secret["stringData"] + + # Check DATABASE_URL is set + assert "DATABASE_URL" in string_data + assert "postgresql://" in string_data["DATABASE_URL"] + assert "testuser:testpass" in string_data["DATABASE_URL"] + + # Check REDIS_URL is set + assert "REDIS_URL" in string_data + assert "redis://redis:6379" in string_data["REDIS_URL"] + + # Verify deployment references the Secret via envFrom + gateway_deployment = next((d for d in docs if d.get("kind") == "Deployment"), None) + assert gateway_deployment is not None + env_from = gateway_deployment["spec"]["template"]["spec"]["containers"][0]["envFrom"] + assert any(ref.get("secretRef", {}).get("name") == "mcpgateway-env" for ref in env_from) + + +class TestGenerateComposeManifests: + """Test generate_compose_manifests function with real template rendering.""" + + def test_generate_compose_gateway_only(self, tmp_path): + """Test generating Docker Compose manifest for gateway only.""" + output_dir = tmp_path / "manifests" + output_dir.mkdir() + + config = MCPStackConfig.model_validate({ + "deployment": {"type": "compose", "project_name": 
"test-mcp"}, + "gateway": { + "image": "mcpgateway:latest", + "port": 4444, + "host_port": 4444, + "mtls_enabled": False, + }, + "plugins": [], + }) + + with patch("mcpgateway.tools.builder.common.Path.cwd", return_value=tmp_path): + generate_compose_manifests(config, output_dir) + + # Verify compose file was created + compose_file = output_dir / "docker-compose.yaml" + assert compose_file.exists() + + # Parse and validate + with open(compose_file) as f: + compose_data = yaml.safe_load(f) + + assert "services" in compose_data + assert "mcpgateway" in compose_data["services"] + + gateway = compose_data["services"]["mcpgateway"] + assert gateway["image"] == "mcpgateway:latest" + assert gateway["ports"] == ["4444:4444"] + + def test_generate_compose_with_plugins(self, tmp_path): + """Test generating Docker Compose manifest with plugins.""" + output_dir = tmp_path / "manifests" + output_dir.mkdir() + + config = MCPStackConfig.model_validate({ + "deployment": {"type": "compose", "project_name": "mcp-stack"}, + "gateway": { + "image": "mcpgateway:latest", + "port": 4444, + "host_port": 4444, + "mtls_enabled": False, + }, + "plugins": [ + { + "name": "Plugin1", + "image": "plugin1:v1", + "port": 8000, + "expose_port": True, + "host_port": 8000, + "mtls_enabled": False, + }, + { + "name": "Plugin2", + "image": "plugin2:v1", + "port": 8001, + "expose_port": False, + "mtls_enabled": False, + }, + ], + }) + + with patch("mcpgateway.tools.builder.common.Path.cwd", return_value=tmp_path): + generate_compose_manifests(config, output_dir) + + # Parse and validate + compose_file = output_dir / "docker-compose.yaml" + with open(compose_file) as f: + compose_data = yaml.safe_load(f) + + # Verify plugins are in services + assert "plugin1" in compose_data["services"] + assert "plugin2" in compose_data["services"] + + plugin1 = compose_data["services"]["plugin1"] + assert plugin1["image"] == "plugin1:v1" + assert "8000:8000" in plugin1["ports"] # Exposed + + plugin2 = compose_data["services"]["plugin2"] + assert plugin2["image"] == "plugin2:v1" + # Plugin2 should not have host port mapping since expose_port is False + + def test_generate_compose_with_mtls(self, tmp_path): + """Test generating Docker Compose manifest with mTLS certificates.""" + output_dir = tmp_path / "manifests" + output_dir.mkdir() + + # Create fake certificate structure + certs_dir = tmp_path / "certs" / "mcp" + ca_dir = certs_dir / "ca" + gateway_dir = certs_dir / "gateway" + plugin_dir = certs_dir / "plugins" / "SecurePlugin" + + ca_dir.mkdir(parents=True) + gateway_dir.mkdir(parents=True) + plugin_dir.mkdir(parents=True) + + (ca_dir / "ca.crt").write_text("fake-ca") + (gateway_dir / "client.crt").write_text("fake-cert") + (gateway_dir / "client.key").write_text("fake-key") + (plugin_dir / "server.crt").write_text("fake-plugin-cert") + (plugin_dir / "server.key").write_text("fake-plugin-key") + + config = MCPStackConfig.model_validate({ + "deployment": {"type": "compose"}, + "gateway": { + "image": "mcpgateway:latest", + "port": 4444, + "host_port": 4444, + "mtls_enabled": True, + }, + "plugins": [ + { + "name": "SecurePlugin", + "image": "secure:v1", + "port": 8000, + "mtls_enabled": True, + } + ], + }) + + with patch("mcpgateway.tools.builder.common.Path.cwd", return_value=tmp_path): + generate_compose_manifests(config, output_dir) + + # Parse and validate + compose_file = output_dir / "docker-compose.yaml" + with open(compose_file) as f: + compose_data = yaml.safe_load(f) + + # Verify gateway has certificate volumes + gateway = 
compose_data["services"]["mcpgateway"] + assert "volumes" in gateway + # Should have volume mounts for certificates + volumes = gateway["volumes"] + assert any("certs" in str(v) or "ca.crt" in str(v) for v in volumes) + + # Verify plugin has certificate volumes + plugin = compose_data["services"]["secureplugin"] + assert "volumes" in plugin + + def test_generate_compose_with_env_files(self, tmp_path): + """Test generating Docker Compose manifest with environment files.""" + output_dir = tmp_path / "manifests" + output_dir.mkdir() + + # Create env files + deploy_dir = tmp_path / "deploy" + env_dir = deploy_dir / "env" + env_dir.mkdir(parents=True) + (env_dir / ".env.gateway").write_text("GATEWAY_VAR=value1\n") + (env_dir / ".env.TestPlugin").write_text("PLUGIN_VAR=value2\n") + + config = MCPStackConfig.model_validate({ + "deployment": {"type": "compose"}, + "gateway": { + "image": "mcpgateway:latest", + "port": 4444, + "mtls_enabled": False, + }, + "plugins": [ + { + "name": "TestPlugin", + "image": "test:v1", + "port": 8000, + "mtls_enabled": False, + } + ], + }) + + with patch("mcpgateway.tools.builder.common.get_deploy_dir", return_value=deploy_dir): + with patch("mcpgateway.tools.builder.common.Path.cwd", return_value=tmp_path): + generate_compose_manifests(config, output_dir) + + # Parse and validate + compose_file = output_dir / "docker-compose.yaml" + with open(compose_file) as f: + compose_data = yaml.safe_load(f) + + # Verify env_file is set + gateway = compose_data["services"]["mcpgateway"] + assert "env_file" in gateway + + plugin = compose_data["services"]["testplugin"] + assert "env_file" in plugin + + def test_generate_compose_with_infrastructure(self, tmp_path): + """Test generating Docker Compose manifest with PostgreSQL and Redis. + + Note: Currently the template uses hardcoded infrastructure images/config. + Infrastructure customization is not yet implemented for Docker Compose. 
+ """ + output_dir = tmp_path / "manifests" + output_dir.mkdir() + + config = MCPStackConfig.model_validate({ + "deployment": {"type": "compose"}, + "gateway": { + "image": "mcpgateway:latest", + "port": 4444, + "mtls_enabled": False, + }, + "plugins": [], + "infrastructure": { + "postgres": { + "enabled": True, + "image": "postgres:17", + "database": "mcpdb", + "user": "mcpuser", + "password": "secret123", + }, + "redis": { + "enabled": True, + "image": "redis:7-alpine", + }, + }, + }) + + with patch("mcpgateway.tools.builder.common.Path.cwd", return_value=tmp_path): + generate_compose_manifests(config, output_dir) + + # Parse and validate + compose_file = output_dir / "docker-compose.yaml" + with open(compose_file) as f: + compose_data = yaml.safe_load(f) + + # Verify PostgreSQL service exists + # Note: Template uses hardcoded "postgres:17" and "mcp" database + assert "postgres" in compose_data["services"] + postgres = compose_data["services"]["postgres"] + assert postgres["image"] == "postgres:17" # Hardcoded in template + assert "environment" in postgres + + # Verify database name is "mcp" (hardcoded default, not "mcpdb" from config) + env = postgres["environment"] + if isinstance(env, list): + assert any("POSTGRES_DB=mcp" in str(e) for e in env) + else: + assert env["POSTGRES_DB"] == "mcp" + + # Verify Redis service exists + # Note: Template uses hardcoded "redis:latest" + assert "redis" in compose_data["services"] + redis = compose_data["services"]["redis"] + assert redis["image"] == "redis:latest" # Hardcoded in template + + # Verify gateway has database environment variables + gateway = compose_data["services"]["mcpgateway"] + assert "environment" in gateway + env = gateway["environment"] + + # Should have DATABASE_URL with default values + if isinstance(env, list): + db_url = next((e for e in env if "DATABASE_URL" in str(e)), None) + else: + db_url = env.get("DATABASE_URL") + assert db_url is not None + assert "postgresql://" in str(db_url) diff --git a/tests/unit/mcpgateway/tools/builder/test_dagger_deploy.py b/tests/unit/mcpgateway/tools/builder/test_dagger_deploy.py new file mode 100644 index 000000000..bc0f8ee87 --- /dev/null +++ b/tests/unit/mcpgateway/tools/builder/test_dagger_deploy.py @@ -0,0 +1,451 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/tools/builder/test_dagger_deploy.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for Dagger-based MCP Stack deployment. + +These tests are skipped if Dagger is not installed. 
+""" + +# Standard +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +# Third-Party +import pytest + +# Check if dagger is available +try: + import dagger + + DAGGER_AVAILABLE = True +except ImportError: + DAGGER_AVAILABLE = False + +# Skip all tests in this module if Dagger is not available +pytestmark = pytest.mark.skipif(not DAGGER_AVAILABLE, reason="Dagger not installed") + +# Conditional import to avoid errors when Dagger is not installed +if DAGGER_AVAILABLE: + # First-Party + from mcpgateway.tools.builder.dagger_deploy import MCPStackDagger +else: + # Create a dummy class to avoid NameError in decorators + MCPStackDagger = type("MCPStackDagger", (), {}) + + +@pytest.fixture +def mock_dagger_connection(tmp_path): + """Fixture to mock Dagger connection and dag.""" + with patch("mcpgateway.tools.builder.dagger_deploy.dagger.connection") as mock_conn: + with patch("mcpgateway.tools.builder.dagger_deploy.dag") as mock_dag: + with patch("mcpgateway.tools.builder.dagger_deploy.Path.cwd") as mock_cwd: + # Mock Path.cwd() to return a valid temporary directory + mock_cwd.return_value = tmp_path + + # Mock the async context manager + mock_conn_ctx = AsyncMock() + mock_conn.return_value = mock_conn_ctx + mock_conn_ctx.__aenter__.return_value = None + mock_conn_ctx.__aexit__.return_value = None + + # Setup dag mocks (use regular Mock for synchronous Dagger API) + mock_git = Mock() + mock_tree = Mock() + mock_container = Mock() + mock_container.export_image = AsyncMock() # Only export_image is async + mock_host = Mock() + mock_dir = Mock() + mock_dir.export = AsyncMock() # export is async + + # Set up the method chain for git operations + mock_dag.git.return_value = mock_git + mock_git.branch.return_value = mock_git + mock_git.tree.return_value = mock_tree + mock_tree.docker_build.return_value = mock_container + + # Set up container operations + mock_dag.container.return_value = mock_container + mock_container.from_.return_value = mock_container + mock_container.with_exec.return_value = mock_container + mock_container.with_mounted_directory.return_value = mock_container + mock_container.with_workdir.return_value = mock_container + mock_container.directory.return_value = mock_dir + + # Set up host operations + mock_dag.host.return_value = mock_host + mock_host.directory.return_value = mock_dir + + yield {"connection": mock_conn, "dag": mock_dag, "container": mock_container} + + +class TestMCPStackDaggerInit: + """Test MCPStackDagger initialization.""" + + def test_init_default(self): + """Test default initialization.""" + stack = MCPStackDagger() + assert stack.verbose is False + + def test_init_verbose(self): + """Test initialization with verbose flag.""" + stack = MCPStackDagger(verbose=True) + assert stack.verbose is True + + +class TestMCPStackDaggerBuild: + """Test MCPStackDagger build method.""" + + @patch("mcpgateway.tools.builder.dagger_deploy.get_deploy_dir") + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @pytest.mark.asyncio + async def test_build_gateway_only(self, mock_load, mock_get_deploy, mock_dagger_connection, tmp_path): + """Test building gateway container with Dagger.""" + mock_load.return_value = { + "gateway": {"repo": "https://github.com/test/gateway.git", "ref": "main"}, + "plugins": [], + } + mock_get_deploy.return_value = tmp_path / "deploy" + + stack = MCPStackDagger() + await stack.build("test-config.yaml") + + mock_load.assert_called_once_with("test-config.yaml") + + 
@patch("mcpgateway.tools.builder.dagger_deploy.get_deploy_dir") + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @pytest.mark.asyncio + async def test_build_plugins_only(self, mock_load, mock_get_deploy, mock_dagger_connection, tmp_path): + """Test building only plugins.""" + mock_load.return_value = { + "gateway": {"repo": "https://github.com/test/gateway.git"}, + "plugins": [ + {"name": "Plugin1", "repo": "https://github.com/test/plugin1.git"} + ], + } + mock_get_deploy.return_value = tmp_path / "deploy" + + stack = MCPStackDagger() + await stack.build("test-config.yaml", plugins_only=True) + + mock_load.assert_called_once() + + @patch("mcpgateway.tools.builder.dagger_deploy.get_deploy_dir") + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @pytest.mark.asyncio + async def test_build_specific_plugins(self, mock_load, mock_get_deploy, mock_dagger_connection, tmp_path): + """Test building specific plugins only.""" + mock_load.return_value = { + "gateway": {"image": "mcpgateway:latest"}, + "plugins": [ + {"name": "Plugin1", "repo": "https://github.com/test/plugin1.git"}, + {"name": "Plugin2", "repo": "https://github.com/test/plugin2.git"}, + ], + } + mock_get_deploy.return_value = tmp_path / "deploy" + + stack = MCPStackDagger() + await stack.build("test-config.yaml", specific_plugins=["Plugin1"]) + + mock_load.assert_called_once() + + @patch("mcpgateway.tools.builder.dagger_deploy.get_deploy_dir") + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @pytest.mark.asyncio + async def test_build_no_plugins(self, mock_load, mock_get_deploy, mock_dagger_connection, tmp_path): + """Test building when no plugins are defined.""" + mock_load.return_value = { + "gateway": {"image": "mcpgateway:latest"}, + "plugins": [], + } + mock_get_deploy.return_value = tmp_path / "deploy" + + stack = MCPStackDagger() + # Should not raise error + await stack.build("test-config.yaml", plugins_only=True) + + +class TestMCPStackDaggerGenerateCertificates: + """Test MCPStackDagger generate_certificates method.""" + + @patch("mcpgateway.tools.builder.dagger_deploy.get_deploy_dir") + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @pytest.mark.asyncio + async def test_generate_certificates(self, mock_load, mock_get_deploy, mock_dagger_connection, tmp_path): + """Test certificate generation with Dagger.""" + mock_load.return_value = { + "plugins": [ + {"name": "Plugin1"}, + {"name": "Plugin2"}, + ] + } + mock_get_deploy.return_value = tmp_path / "deploy" + + stack = MCPStackDagger() + await stack.generate_certificates("test-config.yaml") + + mock_load.assert_called_once() + + +class TestMCPStackDaggerDeploy: + """Test MCPStackDagger deploy method.""" + + @patch("mcpgateway.tools.builder.dagger_deploy.get_deploy_dir") + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @patch.object(MCPStackDagger, "build") + @patch.object(MCPStackDagger, "generate_certificates") + @patch.object(MCPStackDagger, "generate_manifests") + @patch.object(MCPStackDagger, "_deploy_compose") + @pytest.mark.asyncio + async def test_deploy_compose_full( + self, mock_deploy, mock_gen_manifests, mock_certs, mock_build, mock_load, mock_get_deploy, mock_dagger_connection, tmp_path + ): + """Test full Docker Compose deployment with Dagger.""" + mock_load.return_value = { + "deployment": {"type": "compose", "project_name": "test"}, + "gateway": {"repo": "https://github.com/test/gateway.git", "mtls_enabled": True}, + "plugins": [], + } + mock_gen_manifests.return_value = 
Path("/tmp/manifests") + mock_get_deploy.return_value = tmp_path / "deploy" + + stack = MCPStackDagger() + await stack.deploy("test-config.yaml") + + mock_build.assert_called_once() + mock_certs.assert_called_once() + mock_gen_manifests.assert_called_once() + mock_deploy.assert_called_once() + + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @patch.object(MCPStackDagger, "generate_manifests") + @pytest.mark.asyncio + async def test_deploy_dry_run(self, mock_gen_manifests, mock_load, mock_dagger_connection, tmp_path): + """Test dry-run deployment with Dagger.""" + mock_load.return_value = { + "deployment": {"type": "compose"}, + "gateway": {"image": "mcpgateway:latest"}, + "plugins": [], + } + mock_gen_manifests.return_value = Path("/tmp/manifests") + + stack = MCPStackDagger() + await stack.deploy("test-config.yaml", dry_run=True, skip_build=True, skip_certs=True) + + mock_gen_manifests.assert_called_once() + + @patch("mcpgateway.tools.builder.dagger_deploy.get_deploy_dir") + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @patch.object(MCPStackDagger, "generate_manifests") + @patch.object(MCPStackDagger, "_deploy_kubernetes") + @pytest.mark.asyncio + async def test_deploy_kubernetes(self, mock_deploy, mock_gen_manifests, mock_load, mock_get_deploy, mock_dagger_connection, tmp_path): + """Test Kubernetes deployment with Dagger.""" + mock_load.return_value = { + "deployment": {"type": "kubernetes", "namespace": "test-ns"}, + "gateway": {"image": "mcpgateway:latest", "mtls_enabled": False}, + "plugins": [], + } + mock_gen_manifests.return_value = Path("/tmp/manifests") + mock_get_deploy.return_value = tmp_path / "deploy" + + stack = MCPStackDagger() + await stack.deploy("test-config.yaml", skip_build=True, skip_certs=True) + + mock_deploy.assert_called_once() + + +class TestMCPStackDaggerVerify: + """Test MCPStackDagger verify method.""" + + @patch("mcpgateway.tools.builder.dagger_deploy.get_deploy_dir") + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @patch.object(MCPStackDagger, "_verify_kubernetes") + @pytest.mark.asyncio + async def test_verify_kubernetes(self, mock_verify_kubernetes, mock_load, mock_get_deploy, mock_dagger_connection, tmp_path): + """Test Kubernetes deployment verification with Dagger.""" + mock_load.return_value = { + "deployment": {"type": "kubernetes", "namespace": "test-ns"} + } + mock_get_deploy.return_value = tmp_path / "deploy" + + stack = MCPStackDagger() + await stack.verify("test-config.yaml") + + mock_verify_kubernetes.assert_called_once() + + @patch("mcpgateway.tools.builder.dagger_deploy.get_deploy_dir") + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @patch.object(MCPStackDagger, "_verify_compose") + @pytest.mark.asyncio + async def test_verify_compose(self, mock_verify_compose, mock_load, mock_get_deploy, mock_dagger_connection, tmp_path): + """Test Docker Compose deployment verification with Dagger.""" + mock_load.return_value = {"deployment": {"type": "compose"}} + mock_get_deploy.return_value = tmp_path / "deploy" + + stack = MCPStackDagger() + await stack.verify("test-config.yaml") + + mock_verify_compose.assert_called_once() + + +class TestMCPStackDaggerDestroy: + """Test MCPStackDagger destroy method.""" + + @patch("mcpgateway.tools.builder.dagger_deploy.get_deploy_dir") + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @patch.object(MCPStackDagger, "_destroy_kubernetes") + @pytest.mark.asyncio + async def test_destroy_kubernetes(self, mock_destroy_kubernetes, mock_load, 
mock_get_deploy, mock_dagger_connection, tmp_path): + """Test Kubernetes deployment destruction with Dagger.""" + mock_load.return_value = {"deployment": {"type": "kubernetes"}} + mock_get_deploy.return_value = tmp_path / "deploy" + + stack = MCPStackDagger() + await stack.destroy("test-config.yaml") + + mock_destroy_kubernetes.assert_called_once() + + @patch("mcpgateway.tools.builder.dagger_deploy.get_deploy_dir") + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @patch.object(MCPStackDagger, "_destroy_compose") + @pytest.mark.asyncio + async def test_destroy_compose(self, mock_destroy_compose, mock_load, mock_get_deploy, mock_dagger_connection, tmp_path): + """Test Docker Compose deployment destruction with Dagger.""" + mock_load.return_value = {"deployment": {"type": "compose"}} + mock_get_deploy.return_value = tmp_path / "deploy" + + stack = MCPStackDagger() + await stack.destroy("test-config.yaml") + + mock_destroy_compose.assert_called_once() + + +class TestMCPStackDaggerGenerateManifests: + """Test MCPStackDagger generate_manifests method.""" + + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @patch("mcpgateway.tools.builder.dagger_deploy.generate_plugin_config") + @patch("mcpgateway.tools.builder.dagger_deploy.generate_kubernetes_manifests") + def test_generate_manifests_kubernetes( + self, mock_k8s_gen, mock_plugin_gen, mock_load, tmp_path + ): + """Test generating Kubernetes manifests with Dagger.""" + mock_load.return_value = { + "deployment": {"type": "kubernetes", "namespace": "test-ns"}, + "gateway": {"image": "mcpgateway:latest"}, + "plugins": [], + } + + stack = MCPStackDagger() + result = stack.generate_manifests("test-config.yaml", output_dir=str(tmp_path)) + + mock_plugin_gen.assert_called_once() + mock_k8s_gen.assert_called_once() + assert result == tmp_path + + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + @patch("mcpgateway.tools.builder.dagger_deploy.generate_plugin_config") + @patch("mcpgateway.tools.builder.dagger_deploy.generate_compose_manifests") + def test_generate_manifests_compose( + self, mock_compose_gen, mock_plugin_gen, mock_load, tmp_path + ): + """Test generating Docker Compose manifests with Dagger.""" + mock_load.return_value = { + "deployment": {"type": "compose"}, + "gateway": {"image": "mcpgateway:latest"}, + "plugins": [], + } + + stack = MCPStackDagger() + result = stack.generate_manifests("test-config.yaml", output_dir=str(tmp_path)) + + mock_plugin_gen.assert_called_once() + mock_compose_gen.assert_called_once() + assert result == tmp_path + + @patch("mcpgateway.tools.builder.dagger_deploy.get_deploy_dir") + @patch("mcpgateway.tools.builder.dagger_deploy.load_config") + def test_generate_manifests_invalid_type(self, mock_load, mock_get_deploy, tmp_path): + """Test generating manifests with invalid deployment type.""" + mock_load.return_value = { + "deployment": {"type": "invalid"}, + "gateway": {"image": "mcpgateway:latest"}, + } + mock_get_deploy.return_value = tmp_path / "deploy" + + stack = MCPStackDagger() + with pytest.raises(ValueError, match="Unsupported deployment type"): + stack.generate_manifests("test-config.yaml") + + +class TestMCPStackDaggerBuildComponent: + """Test MCPStackDagger _build_component_with_dagger method.""" + + @pytest.mark.asyncio + async def test_build_component_basic(self, mock_dagger_connection, tmp_path): + """Test basic component build with Dagger.""" + component = { + "repo": "https://github.com/test/component.git", + "ref": "main", + "context": ".", + 
"containerfile": "Containerfile", + "image": "test-component:latest", + } + + stack = MCPStackDagger() + await stack._build_component_with_dagger(component, "test-component") + + # Verify Dagger operations were called (using mocks from fixture) + mock_dag = mock_dagger_connection["dag"] + mock_dag.git.assert_called_once() + + # Get the mock git object + mock_git = mock_dag.git.return_value + mock_git.branch.assert_called_with("main") + + # Get the mock tree object + mock_tree = mock_git.tree.return_value + mock_tree.docker_build.assert_called_once() + + @pytest.mark.asyncio + async def test_build_component_with_target(self, mock_dagger_connection, tmp_path): + """Test component build with multi-stage target.""" + component = { + "repo": "https://github.com/test/component.git", + "ref": "main", + "context": ".", + "image": "test:latest", + "target": "production", + } + + stack = MCPStackDagger() + await stack._build_component_with_dagger(component, "test") + + # Verify docker_build was called with target parameter + mock_dag = mock_dagger_connection["dag"] + mock_git = mock_dag.git.return_value + mock_tree = mock_git.tree.return_value + call_args = mock_tree.docker_build.call_args + assert "target" in call_args[1] or call_args[0] + + @pytest.mark.asyncio + async def test_build_component_with_env_vars(self, mock_dagger_connection, tmp_path): + """Test component build with environment variables.""" + component = { + "repo": "https://github.com/test/component.git", + "ref": "main", + "image": "test:latest", + "env_vars": {"BUILD_ENV": "production", "VERSION": "1.0"}, + } + + stack = MCPStackDagger() + await stack._build_component_with_dagger(component, "test") + + # Verify docker_build was called + mock_dag = mock_dagger_connection["dag"] + mock_git = mock_dag.git.return_value + mock_tree = mock_git.tree.return_value + mock_tree.docker_build.assert_called_once() diff --git a/tests/unit/mcpgateway/tools/builder/test_python_deploy.py b/tests/unit/mcpgateway/tools/builder/test_python_deploy.py new file mode 100644 index 000000000..1f46d1601 --- /dev/null +++ b/tests/unit/mcpgateway/tools/builder/test_python_deploy.py @@ -0,0 +1,294 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/tools/builder/test_python_deploy.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for plain Python MCP Stack deployment. 
+""" + +# Standard +from pathlib import Path +import re +import subprocess +from unittest.mock import MagicMock, Mock, patch, call + +# Third-Party +import pytest +from pydantic import ValidationError + +# First-Party +from mcpgateway.tools.builder.python_deploy import MCPStackPython +from mcpgateway.tools.builder.schema import BuildableConfig, MCPStackConfig + + +class TestMCPStackPython: + """Test MCPStackPython deployment class.""" + + @patch("mcpgateway.tools.builder.python_deploy.shutil.which") + @patch("mcpgateway.tools.builder.python_deploy.load_config") + @pytest.mark.asyncio + async def test_build_no_plugins(self, mock_load, mock_which): + """Test building when no plugins are defined.""" + mock_which.return_value = "/usr/bin/docker" + mock_load.return_value = MCPStackConfig.model_validate({ + "deployment": {"type": "compose"}, + "gateway": {"image": "mcpgateway:latest"}, + "plugins": [], + }) + + stack = MCPStackPython() + # Should not raise error + await stack.build("test-config.yaml", plugins_only=True) + + @patch("mcpgateway.tools.builder.python_deploy.shutil.which") + @patch("mcpgateway.tools.builder.python_deploy.load_config") + @patch("mcpgateway.tools.builder.python_deploy.shutil.which", return_value="/usr/bin/make") + @patch.object(MCPStackPython, "_run_command") + @pytest.mark.asyncio + async def test_generate_certificates(self, mock_run, mock_make, mock_load, mock_which_runtime): + """Test certificate generation.""" + mock_which_runtime.return_value = "/usr/bin/docker" + mock_load.return_value = MCPStackConfig.model_validate({ + "gateway": {"image": "mcpgateway:latest"}, + "deployment": {"type": "compose"}, + "plugins": [ + {"name": "Plugin1", "repo": "https://github.com/test/plugin1.git"}, + {"name": "Plugin2", "repo": "https://github.com/test/plugin2.git"}, + ] + }) + + stack = MCPStackPython() + await stack.generate_certificates("test-config.yaml") + + # Should call make commands for CA, gateway, and each plugin + assert mock_run.call_count == 4 # CA + gateway + 2 plugins + + @patch("mcpgateway.tools.builder.python_deploy.shutil.which") + @patch("mcpgateway.tools.builder.python_deploy.load_config") + @patch.object(MCPStackPython, "build") + @patch.object(MCPStackPython, "generate_certificates") + @patch.object(MCPStackPython, "generate_manifests") + @patch.object(MCPStackPython, "_deploy_compose") + @pytest.mark.asyncio + async def test_deploy_compose( + self, mock_deploy, mock_gen_manifests, mock_certs, mock_build, mock_load, mock_which + ): + """Test full compose deployment.""" + mock_which.return_value = "/usr/bin/docker" + mock_load.return_value = MCPStackConfig.model_validate({ + "deployment": {"type": "compose", "project_name": "test"}, + "gateway": {"image": "mcpgateway:latest", "mtls_enabled": True}, + "plugins": [], + }) + mock_gen_manifests.return_value = Path("/tmp/manifests") + + stack = MCPStackPython() + await stack.deploy("test-config.yaml") + + mock_build.assert_called_once() + mock_certs.assert_called_once() + mock_gen_manifests.assert_called_once() + mock_deploy.assert_called_once() + + @patch("mcpgateway.tools.builder.python_deploy.shutil.which") + @patch("mcpgateway.tools.builder.python_deploy.load_config") + @patch.object(MCPStackPython, "build") + @patch.object(MCPStackPython, "generate_manifests") + @pytest.mark.asyncio + async def test_deploy_dry_run(self, mock_gen_manifests, mock_build, mock_load, mock_which): + """Test dry-run deployment.""" + mock_which.return_value = "/usr/bin/docker" + mock_load.return_value = 
+        mock_load.return_value = MCPStackConfig.model_validate({
+            "deployment": {"type": "compose"},
+            "gateway": {"image": "mcpgateway:latest"},
+            "plugins": [],
+        })
+        mock_gen_manifests.return_value = Path("/tmp/manifests")
+
+        stack = MCPStackPython()
+        await stack.deploy("test-config.yaml", dry_run=True, skip_build=True, skip_certs=True)
+
+        mock_gen_manifests.assert_called_once()
+        # Should not call actual deployment
+
+    @patch("mcpgateway.tools.builder.python_deploy.shutil.which")
+    @patch("mcpgateway.tools.builder.python_deploy.load_config")
+    @patch.object(MCPStackPython, "generate_manifests")
+    @pytest.mark.asyncio
+    async def test_deploy_skip_certs_mtls_disabled(self, mock_gen_manifests, mock_load, mock_which):
+        """Test deployment with mTLS disabled."""
+        mock_which.return_value = "/usr/bin/docker"
+        mock_load.return_value = MCPStackConfig.model_validate({
+            "deployment": {"type": "compose"},
+            "gateway": {"image": "mcpgateway:latest", "mtls_enabled": False},
+            "plugins": [],
+        })
+        mock_gen_manifests.return_value = Path("/tmp/manifests")
+
+        stack = MCPStackPython()
+        with patch.object(stack, "generate_certificates") as mock_certs:
+            await stack.deploy("test-config.yaml", dry_run=True, skip_build=True)
+
+            # Certificates should not be generated
+            mock_certs.assert_not_called()
+
+    @patch("mcpgateway.tools.builder.python_deploy.shutil.which")
+    @patch("mcpgateway.tools.builder.python_deploy.load_config")
+    @patch.object(MCPStackPython, "_verify_kubernetes")
+    @pytest.mark.asyncio
+    async def test_verify_kubernetes(self, mock_verify, mock_load, mock_which):
+        """Test Kubernetes deployment verification."""
+        mock_which.return_value = "/usr/bin/docker"
+        mock_load.return_value = MCPStackConfig.model_validate({
+            "gateway": {"image": "mcpgateway:latest", "mtls_enabled": False},
+            "deployment": {"type": "kubernetes", "namespace": "test-ns"}
+        })
+
+        stack = MCPStackPython()
+        await stack.verify("test-config.yaml")
+
+        mock_verify.assert_called_once()
+
+    @patch("mcpgateway.tools.builder.python_deploy.shutil.which")
+    @patch("mcpgateway.tools.builder.python_deploy.load_config")
+    @patch.object(MCPStackPython, "_verify_compose")
+    @pytest.mark.asyncio
+    async def test_verify_compose(self, mock_verify, mock_load, mock_which):
+        """Test Docker Compose deployment verification."""
+        mock_which.return_value = "/usr/bin/docker"
+        mock_load.return_value = MCPStackConfig.model_validate({
+            "deployment": {"type": "compose"},
+            "gateway": {"image": "mcpgateway:latest", "mtls_enabled": False},
+        })
+
+        stack = MCPStackPython()
+        await stack.verify("test-config.yaml")
+
+        mock_verify.assert_called_once()
+
+    @patch("mcpgateway.tools.builder.python_deploy.shutil.which")
+    @patch("mcpgateway.tools.builder.python_deploy.load_config")
+    @patch.object(MCPStackPython, "_destroy_kubernetes")
+    @pytest.mark.asyncio
+    async def test_destroy_kubernetes(self, mock_destroy, mock_load, mock_which):
+        """Test Kubernetes deployment destruction."""
+        mock_which.return_value = "/usr/bin/docker"
+        mock_load.return_value = MCPStackConfig.model_validate({
+            "deployment": {"type": "kubernetes"},
+            "gateway": {"image": "mcpgateway:latest", "mtls_enabled": False},
+        })
+
+        stack = MCPStackPython()
+        await stack.destroy("test-config.yaml")
+
+        mock_destroy.assert_called_once()
+
+    @patch("mcpgateway.tools.builder.python_deploy.shutil.which")
+    @patch("mcpgateway.tools.builder.python_deploy.load_config")
+    @patch.object(MCPStackPython, "_destroy_compose")
+    @pytest.mark.asyncio
+    async def test_destroy_compose(self, mock_destroy, mock_load, mock_which):
"""Test Docker Compose deployment destruction.""" + mock_which.return_value = "/usr/bin/docker" + mock_load.return_value = MCPStackConfig.model_validate({"deployment": {"type": "compose"}, + "gateway": {"image": "mcpgateway:latest", "mtls_enabled": False}, + }) + + stack = MCPStackPython() + await stack.destroy("test-config.yaml") + + mock_destroy.assert_called_once() + + @patch("mcpgateway.tools.builder.python_deploy.shutil.which") + @patch("mcpgateway.tools.builder.python_deploy.load_config") + @patch("mcpgateway.tools.builder.python_deploy.generate_plugin_config") + @patch("mcpgateway.tools.builder.python_deploy.generate_kubernetes_manifests") + def test_generate_manifests_kubernetes( + self, mock_k8s_gen, mock_plugin_gen, mock_load, mock_which, tmp_path + ): + """Test generating Kubernetes manifests.""" + mock_which.return_value = "/usr/bin/docker" + mock_load.return_value = MCPStackConfig.model_validate({ + "deployment": {"type": "kubernetes", "namespace": "test-ns"}, + "gateway": {"image": "mcpgateway:latest"}, + "plugins": [], + }) + + stack = MCPStackPython() + result = stack.generate_manifests("test-config.yaml", output_dir=str(tmp_path)) + + mock_plugin_gen.assert_called_once() + mock_k8s_gen.assert_called_once() + assert result == tmp_path + + @patch("mcpgateway.tools.builder.python_deploy.shutil.which") + @patch("mcpgateway.tools.builder.python_deploy.load_config") + @patch("mcpgateway.tools.builder.python_deploy.generate_plugin_config") + @patch("mcpgateway.tools.builder.python_deploy.generate_compose_manifests") + def test_generate_manifests_compose( + self, mock_compose_gen, mock_plugin_gen, mock_load, mock_which, tmp_path + ): + """Test generating Docker Compose manifests.""" + mock_which.return_value = "/usr/bin/docker" + mock_load.return_value = MCPStackConfig.model_validate({ + "deployment": {"type": "compose"}, + "gateway": {"image": "mcpgateway:latest"}, + "plugins": [], + }) + + stack = MCPStackPython() + result = stack.generate_manifests("test-config.yaml", output_dir=str(tmp_path)) + + mock_plugin_gen.assert_called_once() + mock_compose_gen.assert_called_once() + assert result == tmp_path + + @patch("mcpgateway.tools.builder.python_deploy.shutil.which") + @patch("mcpgateway.tools.builder.python_deploy.load_config") + @patch("mcpgateway.tools.builder.python_deploy.get_deploy_dir") + def test_generate_manifests_invalid_type(self, mock_get_deploy, mock_load, mock_which, tmp_path): + """Test generating manifests with invalid deployment type.""" + mock_which.return_value = "/usr/bin/docker" + with pytest.raises(ValidationError, match=re.escape("1 validation error for MCPStackConfig\ndeployment.type\n Input should be 'kubernetes' or 'compose' [type=literal_error, input_value='invalid', input_type=str]\n For further information visit https://errors.pydantic.dev/2.12/v/literal_error")): + mock_load.return_value = MCPStackConfig.model_validate({ + "deployment": {"type": "invalid"}, + "gateway": {"image": "mcpgateway:latest"}, + }) + +class TestRunCommand: + """Test _run_command method.""" + + @patch("mcpgateway.tools.builder.python_deploy.shutil.which") + @patch("mcpgateway.tools.builder.python_deploy.subprocess.run") + def test_run_command_success(self, mock_run, mock_which): + """Test successful command execution.""" + mock_which.return_value = "/usr/bin/docker" + mock_run.return_value = Mock(returncode=0, stdout="Success", stderr="") + + stack = MCPStackPython() + result = stack._run_command(["echo", "test"]) + + assert result.returncode == 0 + 
+        mock_run.assert_called_once()
+
+    @patch("mcpgateway.tools.builder.python_deploy.shutil.which")
+    @patch("mcpgateway.tools.builder.python_deploy.subprocess.run")
+    def test_run_command_failure(self, mock_run, mock_which):
+        """Test command execution failure."""
+        mock_which.return_value = "/usr/bin/docker"
+        mock_run.side_effect = subprocess.CalledProcessError(1, "cmd")
+
+        stack = MCPStackPython()
+        with pytest.raises(subprocess.CalledProcessError):
+            stack._run_command(["false"])
+
+    @patch("mcpgateway.tools.builder.python_deploy.shutil.which")
+    @patch("mcpgateway.tools.builder.python_deploy.subprocess.run")
+    def test_run_command_with_cwd(self, mock_run, mock_which, tmp_path):
+        """Test command execution with working directory."""
+        mock_which.return_value = "/usr/bin/docker"
+        mock_run.return_value = Mock(returncode=0)
+
+        stack = MCPStackPython()
+        stack._run_command(["ls"], cwd=tmp_path)
+
+        assert mock_run.call_args[1]["cwd"] == tmp_path
diff --git a/tests/unit/mcpgateway/tools/builder/test_schema.py b/tests/unit/mcpgateway/tools/builder/test_schema.py
new file mode 100644
index 000000000..63897bdaf
--- /dev/null
+++ b/tests/unit/mcpgateway/tools/builder/test_schema.py
@@ -0,0 +1,330 @@
+# -*- coding: utf-8 -*-
+"""Location: ./tests/unit/mcpgateway/tools/builder/test_schema.py
+Copyright 2025
+SPDX-License-Identifier: Apache-2.0
+Authors: Teryl Taylor
+
+Unit tests for builder schema validation (Pydantic models).
+"""
+
+# Third-Party
+import pytest
+from pydantic import ValidationError
+
+# First-Party
+from mcpgateway.tools.builder.schema import (
+    BuildableConfig,
+    CertificatesConfig,
+    DeploymentConfig,
+    GatewayConfig,
+    InfrastructureConfig,
+    MCPStackConfig,
+    PluginConfig,
+    PostgresConfig,
+    RedisConfig,
+)
+
+
+class TestDeploymentConfig:
+    """Test DeploymentConfig validation."""
+
+    def test_valid_kubernetes_deployment(self):
+        """Test valid Kubernetes deployment configuration."""
+        config = DeploymentConfig(type="kubernetes", namespace="test-ns")
+        assert config.type == "kubernetes"
+        assert config.namespace == "test-ns"
+        assert config.project_name is None
+
+    def test_valid_compose_deployment(self):
+        """Test valid Docker Compose deployment configuration."""
+        config = DeploymentConfig(type="compose", project_name="test-project")
+        assert config.type == "compose"
+        assert config.project_name == "test-project"
+        assert config.namespace is None
+
+    def test_invalid_deployment_type(self):
+        """Test invalid deployment type."""
+        with pytest.raises(ValidationError):
+            DeploymentConfig(type="invalid")
+
+
+class TestGatewayConfig:
+    """Test GatewayConfig validation."""
+
+    def test_gateway_with_image(self):
+        """Test gateway config with pre-built image."""
+        config = GatewayConfig(image="mcpgateway:latest", port=4444)
+        assert config.image == "mcpgateway:latest"
+        assert config.port == 4444
+        assert config.repo is None
+
+    def test_gateway_with_repo(self):
+        """Test gateway config with repository build."""
+        config = GatewayConfig(
+            repo="https://github.com/org/repo.git",
+            ref="main",
+            context=".",
+            port=4444
+        )
+        assert config.repo == "https://github.com/org/repo.git"
+        assert config.ref == "main"
+        assert config.image is None
+
+    def test_gateway_without_image_or_repo(self):
+        """Test that gateway requires either image or repo."""
+        with pytest.raises(ValueError, match="must specify either 'image' or 'repo'"):
+            GatewayConfig(port=4444)
+
+    def test_gateway_defaults(self):
+        """Test gateway default values."""
+        config = GatewayConfig(image="test:latest")
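+        # Values asserted below are the schema defaults applied when only an image is given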
+        assert config.port == 4444
+        assert config.mtls_enabled is True
+        assert config.ref == "main"
+        assert config.context == "."
+        assert config.containerfile == "Containerfile"
+
+
+class TestPluginConfig:
+    """Test PluginConfig validation."""
+
+    def test_plugin_with_image(self):
+        """Test plugin config with pre-built image."""
+        config = PluginConfig(name="TestPlugin", image="test:latest")
+        assert config.name == "TestPlugin"
+        assert config.image == "test:latest"
+        assert config.repo is None
+
+    def test_plugin_with_repo(self):
+        """Test plugin config with repository build."""
+        config = PluginConfig(
+            name="TestPlugin",
+            repo="https://github.com/org/plugin.git",
+            ref="v1.0.0",
+            context="plugins/test"
+        )
+        assert config.name == "TestPlugin"
+        assert config.repo == "https://github.com/org/plugin.git"
+        assert config.ref == "v1.0.0"
+        assert config.context == "plugins/test"
+
+    def test_plugin_without_name(self):
+        """Test that plugin requires name."""
+        with pytest.raises(ValidationError):
+            PluginConfig(image="test:latest")
+
+    def test_plugin_empty_name(self):
+        """Test that plugin name cannot be empty."""
+        with pytest.raises(ValidationError, match="Plugin name cannot be empty"):
+            PluginConfig(name="", image="test:latest")
+
+    def test_plugin_whitespace_name(self):
+        """Test that plugin name cannot be whitespace only."""
+        with pytest.raises(ValidationError, match="Plugin name cannot be empty"):
+            PluginConfig(name=" ", image="test:latest")
+
+    def test_plugin_defaults(self):
+        """Test plugin default values."""
+        config = PluginConfig(name="TestPlugin", image="test:latest")
+        assert config.port == 8000
+        assert config.expose_port is False
+        assert config.mtls_enabled is True
+        assert config.plugin_overrides == {}
+
+    def test_plugin_overrides(self):
+        """Test plugin with overrides."""
+        config = PluginConfig(
+            name="TestPlugin",
+            image="test:latest",
+            plugin_overrides={
+                "priority": 10,
+                "mode": "enforce",
+                "tags": ["security", "filter"]
+            }
+        )
+        assert config.plugin_overrides["priority"] == 10
+        assert config.plugin_overrides["mode"] == "enforce"
+        assert config.plugin_overrides["tags"] == ["security", "filter"]
+
+
+class TestCertificatesConfig:
+    """Test CertificatesConfig validation."""
+
+    def test_certificates_defaults(self):
+        """Test certificates default values."""
+        config = CertificatesConfig()
+        assert config.validity_days == 825
+        assert config.auto_generate is True
+        assert config.ca_path == "./certs/mcp/ca"
+        assert config.gateway_path == "./certs/mcp/gateway"
+        assert config.plugins_path == "./certs/mcp/plugins"
+
+    def test_certificates_custom_values(self):
+        """Test certificates with custom values."""
+        config = CertificatesConfig(
+            validity_days=365,
+            auto_generate=False,
+            ca_path="/custom/ca",
+            gateway_path="/custom/gateway",
+            plugins_path="/custom/plugins"
+        )
+        assert config.validity_days == 365
+        assert config.auto_generate is False
+        assert config.ca_path == "/custom/ca"
+
+
+class TestInfrastructureConfig:
+    """Test InfrastructureConfig validation."""
+
+    def test_postgres_defaults(self):
+        """Test PostgreSQL default configuration."""
+        config = PostgresConfig()
+        assert config.enabled is True
+        assert config.image == "quay.io/sclorg/postgresql-15-c9s:latest"
+        assert config.database == "mcp"
+        assert config.user == "postgres"
+        assert config.password == "mysecretpassword"
+        assert config.storage_size == "10Gi"
+
+    def test_postgres_custom(self):
+        """Test PostgreSQL custom configuration."""
+        config = PostgresConfig(
+            enabled=True,
image="postgres:16", + database="customdb", + user="customuser", + password="custompass", + storage_size="20Gi", + storage_class="fast-ssd" + ) + assert config.image == "postgres:16" + assert config.database == "customdb" + assert config.storage_class == "fast-ssd" + + def test_redis_defaults(self): + """Test Redis default configuration.""" + config = RedisConfig() + assert config.enabled is True + assert config.image == "redis:latest" + + def test_infrastructure_defaults(self): + """Test infrastructure with default values.""" + config = InfrastructureConfig() + assert config.postgres.enabled is True + assert config.redis.enabled is True + + +class TestMCPStackConfig: + """Test complete MCPStackConfig validation.""" + + def test_minimal_config(self): + """Test minimal valid configuration.""" + config = MCPStackConfig( + deployment=DeploymentConfig(type="compose", project_name="test"), + gateway=GatewayConfig(image="mcpgateway:latest") + ) + assert config.deployment.type == "compose" + assert config.gateway.image == "mcpgateway:latest" + assert config.plugins == [] + + def test_full_config(self): + """Test full configuration with all options.""" + config = MCPStackConfig( + deployment=DeploymentConfig(type="kubernetes", namespace="prod"), + gateway=GatewayConfig( + image="mcpgateway:latest", + port=4444, + mtls_enabled=True + ), + plugins=[ + PluginConfig(name="Plugin1", image="plugin1:latest"), + PluginConfig(name="Plugin2", image="plugin2:latest") + ], + certificates=CertificatesConfig(validity_days=365), + infrastructure=InfrastructureConfig() + ) + assert config.deployment.namespace == "prod" + assert len(config.plugins) == 2 + assert config.certificates.validity_days == 365 + + def test_duplicate_plugin_names(self): + """Test that duplicate plugin names are rejected.""" + with pytest.raises(ValidationError, match="Duplicate plugin names found"): + MCPStackConfig( + deployment=DeploymentConfig(type="compose"), + gateway=GatewayConfig(image="test:latest"), + plugins=[ + PluginConfig(name="DuplicatePlugin", image="plugin1:latest"), + PluginConfig(name="DuplicatePlugin", image="plugin2:latest") + ] + ) + + def test_unique_plugin_names(self): + """Test that unique plugin names are accepted.""" + config = MCPStackConfig( + deployment=DeploymentConfig(type="compose"), + gateway=GatewayConfig(image="test:latest"), + plugins=[ + PluginConfig(name="Plugin1", image="plugin1:latest"), + PluginConfig(name="Plugin2", image="plugin2:latest"), + PluginConfig(name="Plugin3", image="plugin3:latest") + ] + ) + assert len(config.plugins) == 3 + assert [p.name for p in config.plugins] == ["Plugin1", "Plugin2", "Plugin3"] + + def test_config_with_repo_builds(self): + """Test configuration with repository builds.""" + config = MCPStackConfig( + deployment=DeploymentConfig(type="compose"), + gateway=GatewayConfig( + repo="https://github.com/org/gateway.git", + ref="v2.0.0" + ), + plugins=[ + PluginConfig( + name="BuiltPlugin", + repo="https://github.com/org/plugin.git", + ref="main", + context="plugins/src" + ) + ] + ) + assert config.gateway.repo is not None + assert config.gateway.ref == "v2.0.0" + assert config.plugins[0].repo is not None + assert config.plugins[0].context == "plugins/src" + + +class TestBuildableConfig: + """Test BuildableConfig base class validation.""" + + def test_mtls_defaults(self): + """Test mTLS default settings.""" + config = GatewayConfig(image="test:latest") + assert config.mtls_enabled is True + + def test_mtls_disabled(self): + """Test mTLS can be disabled.""" + config = 
GatewayConfig(image="test:latest", mtls_enabled=False) + assert config.mtls_enabled is False + + def test_env_vars(self): + """Test environment variables.""" + config = GatewayConfig( + image="test:latest", + env_vars={"LOG_LEVEL": "DEBUG", "PORT": "4444"} + ) + assert config.env_vars["LOG_LEVEL"] == "DEBUG" + assert config.env_vars["PORT"] == "4444" + + def test_multi_stage_build(self): + """Test multi-stage build target.""" + config = PluginConfig( + name="TestPlugin", + repo="https://github.com/org/plugin.git", + containerfile="Dockerfile", + target="production" + ) + assert config.containerfile == "Dockerfile" + assert config.target == "production" diff --git a/tests/unit/mcpgateway/utils/test_correlation_id.py b/tests/unit/mcpgateway/utils/test_correlation_id.py new file mode 100644 index 000000000..6b80ae163 --- /dev/null +++ b/tests/unit/mcpgateway/utils/test_correlation_id.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +"""Tests for correlation ID utilities.""" + +import asyncio +import pytest +from mcpgateway.utils.correlation_id import ( + clear_correlation_id, + extract_correlation_id_from_headers, + generate_correlation_id, + get_correlation_id, + get_or_generate_correlation_id, + set_correlation_id, + validate_correlation_id, +) + + +def test_generate_correlation_id(): + """Test correlation ID generation.""" + id1 = generate_correlation_id() + id2 = generate_correlation_id() + + assert id1 is not None + assert id2 is not None + assert id1 != id2 + assert len(id1) == 32 # UUID4 hex is 32 characters + assert len(id2) == 32 + + +def test_set_and_get_correlation_id(): + """Test setting and getting correlation ID.""" + test_id = "test-correlation-123" + + set_correlation_id(test_id) + retrieved_id = get_correlation_id() + + assert retrieved_id == test_id + + clear_correlation_id() + + +def test_clear_correlation_id(): + """Test clearing correlation ID.""" + test_id = "test-correlation-456" + + set_correlation_id(test_id) + assert get_correlation_id() == test_id + + clear_correlation_id() + assert get_correlation_id() is None + + +def test_get_correlation_id_returns_none_when_not_set(): + """Test getting correlation ID when not set.""" + clear_correlation_id() + assert get_correlation_id() is None + + +def test_extract_correlation_id_from_headers(): + """Test extracting correlation ID from headers.""" + headers = {"X-Correlation-ID": "header-correlation-789"} + + correlation_id = extract_correlation_id_from_headers(headers) + assert correlation_id == "header-correlation-789" + + +def test_extract_correlation_id_from_headers_case_insensitive(): + """Test case-insensitive header extraction.""" + headers = {"x-correlation-id": "lowercase-id"} + + correlation_id = extract_correlation_id_from_headers(headers) + assert correlation_id == "lowercase-id" + + +def test_extract_correlation_id_from_headers_custom_header(): + """Test extracting from custom header name.""" + headers = {"X-Request-ID": "custom-request-id"} + + correlation_id = extract_correlation_id_from_headers(headers, "X-Request-ID") + assert correlation_id == "custom-request-id" + + +def test_extract_correlation_id_from_headers_not_found(): + """Test when correlation ID header is not present.""" + headers = {"Content-Type": "application/json"} + + correlation_id = extract_correlation_id_from_headers(headers) + assert correlation_id is None + + +def test_extract_correlation_id_from_headers_empty_value(): + """Test when correlation ID header has empty value.""" + headers = {"X-Correlation-ID": " "} + + correlation_id = 
+    correlation_id = extract_correlation_id_from_headers(headers)
+    assert correlation_id is None
+
+
+def test_get_or_generate_correlation_id_when_not_set():
+    """Test get_or_generate when ID is not set."""
+    clear_correlation_id()
+
+    correlation_id = get_or_generate_correlation_id()
+
+    assert correlation_id is not None
+    assert len(correlation_id) == 32
+    assert get_correlation_id() == correlation_id  # Should be stored
+
+    clear_correlation_id()
+
+
+def test_get_or_generate_correlation_id_when_already_set():
+    """Test get_or_generate when ID is already set."""
+    test_id = "existing-correlation-id"
+    set_correlation_id(test_id)
+
+    correlation_id = get_or_generate_correlation_id()
+
+    assert correlation_id == test_id
+
+    clear_correlation_id()
+
+
+def test_validate_correlation_id_valid():
+    """Test validation of valid correlation IDs."""
+    assert validate_correlation_id("abc-123") is True
+    assert validate_correlation_id("test_id_456") is True
+    assert validate_correlation_id("UPPER-lower-123_mix") is True
+
+
+def test_validate_correlation_id_invalid():
+    """Test validation of invalid correlation IDs."""
+    assert validate_correlation_id(None) is False
+    assert validate_correlation_id("") is False
+    assert validate_correlation_id(" ") is False
+    assert validate_correlation_id("id with spaces") is False
+    assert validate_correlation_id("id@special!chars") is False
+
+
+def test_validate_correlation_id_too_long():
+    """Test validation rejects overly long IDs."""
+    long_id = "a" * 256  # Default max is 255
+
+    assert validate_correlation_id(long_id) is False
+    assert validate_correlation_id(long_id, max_length=300) is True
+
+
+@pytest.mark.asyncio
+async def test_correlation_id_isolation_between_async_tasks():
+    """Test that correlation IDs are isolated between concurrent async tasks."""
+    results = []
+
+    async def task_with_id(task_id: str):
+        set_correlation_id(task_id)
+        await asyncio.sleep(0.01)  # Simulate async work
+        retrieved_id = get_correlation_id()
+        results.append((task_id, retrieved_id))
+        clear_correlation_id()
+
+    # Run multiple tasks concurrently
+    await asyncio.gather(
+        task_with_id("task-1"),
+        task_with_id("task-2"),
+        task_with_id("task-3"),
+    )
+
+    # Each task should have retrieved its own ID
+    assert len(results) == 3
+    for task_id, retrieved_id in results:
+        assert task_id == retrieved_id
+
+
+@pytest.mark.asyncio
+async def test_correlation_id_inheritance_in_nested_tasks():
+    """Test that correlation ID is inherited by child async tasks."""
+
+    async def parent_task():
+        set_correlation_id("parent-id")
+        parent_id = get_correlation_id()
+
+        async def child_task():
+            return get_correlation_id()
+
+        child_id = await child_task()
+
+        clear_correlation_id()
+        return parent_id, child_id
+
+    parent_id, child_id = await parent_task()
+
+    # Child should inherit parent's correlation ID
+    assert parent_id == "parent-id"
+    assert child_id == "parent-id"
+
+
+def test_correlation_id_context_isolation():
+    """Test that correlation ID is properly isolated per context."""
+    clear_correlation_id()
+
+    # Set ID in one context
+    set_correlation_id("context-1")
+    assert get_correlation_id() == "context-1"
+
+    # Overwrite with new ID
+    set_correlation_id("context-2")
+    assert get_correlation_id() == "context-2"
+
+    clear_correlation_id()
+    assert get_correlation_id() is None
+
+
+def test_extract_correlation_id_strips_whitespace():
+    """Test that extracted correlation ID has whitespace stripped."""
+    headers = {"X-Correlation-ID": " trimmed-id "}
+
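+    # Surrounding whitespace should be stripped from the extracted value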
+    correlation_id = extract_correlation_id_from_headers(headers)
+    assert correlation_id == "trimmed-id"
diff --git a/tests/unit/mcpgateway/utils/test_ssl_key_manager.py b/tests/unit/mcpgateway/utils/test_ssl_key_manager.py
new file mode 100644
index 000000000..b1a4291c4
--- /dev/null
+++ b/tests/unit/mcpgateway/utils/test_ssl_key_manager.py
@@ -0,0 +1,253 @@
+# -*- coding: utf-8 -*-
+"""Location: ./tests/unit/mcpgateway/utils/test_ssl_key_manager.py
+Copyright 2025
+SPDX-License-Identifier: Apache-2.0
+Authors: Keval Mahajan
+
+Unit tests for SSL key manager utility.
+"""
+
+# Standard
+import os
+from pathlib import Path
+import tempfile
+
+# Third-Party
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+import pytest
+
+# First-Party
+from mcpgateway.utils.ssl_key_manager import SSLKeyManager, prepare_ssl_key
+
+
+@pytest.fixture
+def temp_cert_dir(tmp_path):
+    """Create a temporary directory for test certificates."""
+    cert_dir = tmp_path / "certs"
+    cert_dir.mkdir()
+    return cert_dir
+
+
+@pytest.fixture
+def unencrypted_key(temp_cert_dir):
+    """Generate an unencrypted RSA private key for testing."""
+    # Generate a test RSA key
+    private_key = rsa.generate_private_key(
+        public_exponent=65537,
+        key_size=2048,
+    )
+
+    # Save as unencrypted PEM
+    key_path = temp_cert_dir / "key.pem"
+    with open(key_path, "wb") as f:
+        f.write(
+            private_key.private_bytes(
+                encoding=serialization.Encoding.PEM,
+                format=serialization.PrivateFormat.TraditionalOpenSSL,
+                encryption_algorithm=serialization.NoEncryption(),
+            )
+        )
+
+    return key_path
+
+
+@pytest.fixture
+def encrypted_key(temp_cert_dir):
+    """Generate a passphrase-protected RSA private key for testing."""
+    # Generate a test RSA key
+    private_key = rsa.generate_private_key(
+        public_exponent=65537,
+        key_size=2048,
+    )
+
+    # Save as encrypted PEM with passphrase "test123"
+    key_path = temp_cert_dir / "key-encrypted.pem"
+    with open(key_path, "wb") as f:
+        f.write(
+            private_key.private_bytes(
+                encoding=serialization.Encoding.PEM,
+                format=serialization.PrivateFormat.TraditionalOpenSSL,
+                encryption_algorithm=serialization.BestAvailableEncryption(b"test123"),
+            )
+        )
+
+    return key_path, "test123"
+
+
+class TestSSLKeyManager:
+    """Test suite for SSLKeyManager class."""
+
+    def test_prepare_key_file_unencrypted(self, unencrypted_key):
+        """Test that unencrypted keys are returned as-is."""
+        manager = SSLKeyManager()
+
+        result = manager.prepare_key_file(str(unencrypted_key))
+
+        # Should return the original path
+        assert result == str(unencrypted_key)
+
+        # No temporary file should be created
+        assert manager._temp_key_file is None
+
+    def test_prepare_key_file_encrypted(self, encrypted_key):
+        """Test that encrypted keys are decrypted to temporary files."""
+        key_path, passphrase = encrypted_key
+        manager = SSLKeyManager()
+
+        result = manager.prepare_key_file(str(key_path), passphrase)
+
+        # Should return a different path (temporary file)
+        assert result != str(key_path)
+
+        # Temporary file should exist
+        temp_path = Path(result)
+        assert temp_path.exists()
+
+        # Temporary file should have restrictive permissions (0o600)
+        stat_info = os.stat(result)
+        permissions = stat_info.st_mode & 0o777
+        assert permissions == 0o600
+
+        # Temporary file should be tracked
+        assert manager._temp_key_file == temp_path
+
+        # Verify the decrypted key is valid
+        with open(result, "rb") as f:
+            key_data = f.read()
+            # Should be able to load without password
+            from cryptography.hazmat.primitives.serialization import load_pem_private_key
+            private_key = load_pem_private_key(key_data, password=None)
+            assert private_key is not None
+
+        # Cleanup
+        manager.cleanup()
+        assert not temp_path.exists()
+
+    def test_prepare_key_file_wrong_passphrase(self, encrypted_key):
+        """Test that wrong passphrase raises ValueError."""
+        key_path, _ = encrypted_key
+        manager = SSLKeyManager()
+
+        with pytest.raises(ValueError, match="Failed to decrypt private key"):
+            manager.prepare_key_file(str(key_path), "wrong_password")
+
+        # Ensure cleanup was called
+        assert manager._temp_key_file is None
+
+    def test_prepare_key_file_missing_file(self, temp_cert_dir):
+        """Test that missing key file raises FileNotFoundError."""
+        manager = SSLKeyManager()
+        missing_path = temp_cert_dir / "nonexistent.pem"
+
+        with pytest.raises(FileNotFoundError, match="Key file not found"):
+            manager.prepare_key_file(str(missing_path))
+
+    def test_cleanup_removes_temp_file(self, encrypted_key):
+        """Test that cleanup removes temporary files."""
+        key_path, passphrase = encrypted_key
+        manager = SSLKeyManager()
+
+        # Create temporary file
+        temp_path = manager.prepare_key_file(str(key_path), passphrase)
+        assert Path(temp_path).exists()
+
+        # Cleanup should remove it
+        manager.cleanup()
+        assert not Path(temp_path).exists()
+        assert manager._temp_key_file is None
+
+    def test_cleanup_idempotent(self):
+        """Test that cleanup can be called multiple times safely."""
+        manager = SSLKeyManager()
+
+        # Should not raise even if no temp file exists
+        manager.cleanup()
+        manager.cleanup()
+
+    def test_prepare_ssl_key_convenience_function(self, unencrypted_key):
+        """Test the convenience function prepare_ssl_key."""
+        result = prepare_ssl_key(str(unencrypted_key))
+
+        # Should work the same as the manager method
+        assert result == str(unencrypted_key)
+
+    def test_prepare_ssl_key_with_passphrase(self, encrypted_key):
+        """Test convenience function with passphrase."""
+        key_path, passphrase = encrypted_key
+
+        result = prepare_ssl_key(str(key_path), passphrase)
+
+        # Should return a temporary file path
+        assert result != str(key_path)
+        assert Path(result).exists()
+
+        # Verify it's a valid unencrypted key
+        with open(result, "rb") as f:
+            key_data = f.read()
+            from cryptography.hazmat.primitives.serialization import load_pem_private_key
+            private_key = load_pem_private_key(key_data, password=None)
+            assert private_key is not None
+
+
+class TestSSLKeyManagerIntegration:
+    """Integration tests for SSL key manager."""
+
+    def test_atexit_cleanup(self, encrypted_key):
+        """Test that atexit handler is registered for cleanup."""
+        import atexit
+
+        key_path, passphrase = encrypted_key
+        manager = SSLKeyManager()
+
+        # Get initial atexit handlers count
+        initial_handlers = len(atexit._exithandlers) if hasattr(atexit, '_exithandlers') else 0
+
+        # Prepare key (should register cleanup)
+        temp_path = manager.prepare_key_file(str(key_path), passphrase)
+
+        # Verify atexit handler was registered
+        # Note: This is implementation-dependent and may vary by Python version
+        if hasattr(atexit, '_exithandlers'):
+            assert len(atexit._exithandlers) > initial_handlers
+
+        # Manual cleanup for test
+        manager.cleanup()
+
+    def test_multiple_keys(self, temp_cert_dir):
+        """Test handling multiple keys (should only track the last one)."""
+        # Generate two encrypted keys
+        key1 = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+        key2 = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+
+        key1_path = temp_cert_dir / "key1.pem"
temp_cert_dir / "key2.pem" + + for key, path in [(key1, key1_path), (key2, key2_path)]: + with open(path, "wb") as f: + f.write( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.BestAvailableEncryption(b"test"), + ) + ) + + manager = SSLKeyManager() + + # Prepare first key + temp1 = manager.prepare_key_file(str(key1_path), "test") + temp1_path = Path(temp1) + assert temp1_path.exists() + + # Prepare second key (should replace the first) + temp2 = manager.prepare_key_file(str(key2_path), "test") + temp2_path = Path(temp2) + assert temp2_path.exists() + + # Only the second temp file should be tracked + assert manager._temp_key_file == temp2_path + + # Cleanup should only remove the second file + manager.cleanup() + assert not temp2_path.exists() diff --git a/tests/unit/mcpgateway/utils/test_verify_credentials.py b/tests/unit/mcpgateway/utils/test_verify_credentials.py index dabf49f63..942f2f800 100644 --- a/tests/unit/mcpgateway/utils/test_verify_credentials.py +++ b/tests/unit/mcpgateway/utils/test_verify_credentials.py @@ -281,9 +281,16 @@ async def test_require_auth_override_basic_auth_disabled(monkeypatch): @pytest.fixture -def test_client(): - if app is None: - pytest.skip("FastAPI app not importable") +def test_client(app, monkeypatch): + """Create a test client with the properly configured app fixture from conftest.""" + from unittest.mock import MagicMock + + # Patch security_logger at the middleware level where it's imported and called + mock_sec_logger = MagicMock() + mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None) + mock_sec_logger.log_security_event = MagicMock(return_value=None) + monkeypatch.setattr("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger) + return TestClient(app)