diff --git a/.gitignore b/.gitignore index f9133b2c..a90fd6b8 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,8 @@ artifacts/ ui/node_modules/ ui/ui-kit/node_modules/ ui/test-results/ + +# SSL certificates (generated locally) +monitoring/nginx/certs/ +config.yaml +eval_reports diff --git a/Dockerfile b/Dockerfile index b184694a..76071e07 100644 --- a/Dockerfile +++ b/Dockerfile @@ -47,6 +47,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # This is the final image. It will be much smaller. FROM python:3.12-slim +# Copy uv from the official image for runtime use (needed by entrypoint) +COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv + WORKDIR /app # Install ONLY runtime system dependencies diff --git a/config.yaml.example b/config.yaml.example new file mode 100644 index 00000000..4f9e2c4f --- /dev/null +++ b/config.yaml.example @@ -0,0 +1,107 @@ +# Redis SRE Agent Configuration +# Copy this file to config.yaml and customize for your environment. +# +# Settings can be loaded from (priority order): +# 1. Environment variables (highest priority) +# 2. .env file +# 3. config.yaml (this file) +# 4. Default values (lowest priority) +# +# Set SRE_AGENT_CONFIG environment variable to use a custom path. + +# Application settings +# debug: false +# log_level: INFO + +# Server settings +# host: "0.0.0.0" +# port: 8000 + +# MCP (Model Context Protocol) servers configuration +# This is the primary use case for YAML config - complex nested structures +mcp_servers: + # Memory server for long-term agent memory + redis-memory-server: + command: uv + args: + - tool + - run + - --from + - agent-memory-server + - agent-memory + - mcp + env: + REDIS_URL: redis://localhost:6399 + tools: + get_current_datetime: + description: | + Get the current date and time. Use this when you need to + record timestamps for Redis instance events or incidents. + + {original} + create_long_term_memories: + description: | + Save long-term memories about Redis instances. Use this to + record: past incidents and their resolutions, configuration + changes, performance baselines, known issues, maintenance + history, and lessons learned. Always include the instance_id + in the memory text for future retrieval. + + {original} + search_long_term_memory: + description: | + Search saved memories about Redis instances. ALWAYS use this + before troubleshooting a Redis instance to recall past issues, + solutions, and context. Search by instance_id, error patterns, + or symptoms. + + {original} + get_long_term_memory: + description: | + Retrieve a specific memory by ID. Use this to get full details + of a memory found via search. + + {original} + edit_long_term_memory: + description: | + Update an existing memory. Use this to add new information to + a past incident record, update resolution status, or correct + outdated information. + + {original} + delete_long_term_memories: + description: | + Delete memories that are no longer relevant. Use sparingly - + prefer editing to add context rather than deleting. + + {original} + + # GitHub MCP server for repository operations + # Option 1: Local Docker (requires Docker to be running) + github: + command: docker + args: + - run + - -i + - --rm + - -e + - GITHUB_PERSONAL_ACCESS_TOKEN + - ghcr.io/github/github-mcp-server + env: + # Set your GitHub Personal Access Token here or via environment variable + GITHUB_PERSONAL_ACCESS_TOKEN: ${GITHUB_PERSONAL_ACCESS_TOKEN} + + # Option 2: Remote GitHub MCP server (recommended, no Docker needed) + # Uncomment the following and comment out the local Docker option above: + # github: + # url: "https://api.githubcopilot.com/mcp/" + # headers: + # Authorization: "Bearer ${GITHUB_PERSONAL_ACCESS_TOKEN}" + # # transport: streamable_http # default, uses Streamable HTTP protocol + +# Tool providers configuration (fully qualified class paths) +# tool_providers: +# - redis_sre_agent.tools.metrics.prometheus.provider.PrometheusToolProvider +# - redis_sre_agent.tools.diagnostics.redis_command.provider.RedisCommandToolProvider +# - redis_sre_agent.tools.logs.loki.provider.LokiToolProvider +# - redis_sre_agent.tools.host_telemetry.provider.HostTelemetryToolProvider diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 6268f5f5..e5cc9dc3 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -87,7 +87,7 @@ services: context: . dockerfile: Dockerfile ports: - - "8000:8000" + - "8080:8000" environment: - REDIS_URL=redis://redis-demo:6379/0 # Internal container port stays 6379 - PROMETHEUS_URL=http://prometheus:9090 diff --git a/docker-compose.yml b/docker-compose.yml index 20f913be..999e1c7d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,10 +9,12 @@ services: - ./monitoring/redis.conf:/usr/local/etc/redis/redis.conf command: redis-server /usr/local/etc/redis/redis.conf healthcheck: - test: ["CMD", "redis-cli", "ping"] - interval: 10s + # Wait for Redis to finish loading before marking healthy + test: ["CMD-SHELL", "redis-cli ping | grep -q PONG && redis-cli INFO persistence | grep -q 'loading:0'"] + interval: 5s timeout: 5s - retries: 3 + retries: 10 + start_period: 10s networks: - sre-network @@ -273,7 +275,7 @@ services: context: . dockerfile: Dockerfile ports: - - "8000:8000" + - "8080:8000" environment: - REDIS_URL=redis://redis:6379/0 # Internal container port stays 6379 - TOOLS_PROMETHEUS_URL=http://prometheus:9090 @@ -327,6 +329,71 @@ services: networks: - sre-network + # GitHub MCP Server - Exposes GitHub tools via MCP + # This runs the GitHub MCP server behind an SSE/HTTP proxy so the sre-worker + # can connect to it without needing Docker-in-Docker permissions. + github-mcp: + image: ghcr.io/sparfenyuk/mcp-proxy:latest + ports: + - "8082:8082" + environment: + - GITHUB_PERSONAL_ACCESS_TOKEN=${GITHUB_PERSONAL_ACCESS_TOKEN} + command: > + --pass-environment + --port=8082 + --host=0.0.0.0 + docker run -i --rm -e GITHUB_PERSONAL_ACCESS_TOKEN ghcr.io/github/github-mcp-server + volumes: + - /var/run/docker.sock:/var/run/docker.sock + networks: + - sre-network + profiles: + - mcp # Start with: docker compose --profile mcp up + + # SRE Agent MCP Server - Exposes agent capabilities via Model Context Protocol + # Connect Claude to this via: Settings > Connectors > Add Custom Connector + # HTTP: http://localhost:8081/mcp + # HTTPS: https://localhost:8450/mcp (requires running scripts/generate-mcp-certs.sh first) + sre-mcp: + build: + context: . + dockerfile: Dockerfile + ports: + - "8081:8081" + environment: + - REDIS_URL=redis://redis:6379/0 + - REDIS_SRE_MASTER_KEY=${REDIS_SRE_MASTER_KEY} + - TOOLS_PROMETHEUS_URL=http://prometheus:9090 + - TOOLS_LOKI_URL=http://loki:3100 + depends_on: + redis: + condition: service_healthy + volumes: + - .env:/app/.env + - ./redis_sre_agent:/app/redis_sre_agent + command: uv run redis-sre-agent mcp serve --transport http --host 0.0.0.0 --port 8081 + networks: + - sre-network + profiles: + - mcp # Start with: docker compose --profile mcp up + - ssl # Or with SSL: docker compose --profile ssl up + + # MCP SSL Proxy - HTTPS termination for MCP server + # Run scripts/generate-mcp-certs.sh first to generate self-signed certs + sre-mcp-ssl: + image: nginx:alpine + ports: + - "8450:443" + volumes: + - ./monitoring/nginx/mcp-ssl.conf:/etc/nginx/conf.d/default.conf:ro + - ./monitoring/nginx/certs:/etc/nginx/certs:ro + depends_on: + - sre-mcp + networks: + - sre-network + profiles: + - ssl # Only start with: docker compose --profile ssl up + # SRE Agent UI sre-ui: build: diff --git a/docs/concepts/core.md b/docs/concepts/core.md index 70654503..abcb5533 100644 --- a/docs/concepts/core.md +++ b/docs/concepts/core.md @@ -27,7 +27,7 @@ This section explains the core ideas behind Redis SRE Agent and how pieces fit t When you create a task, the API creates or reuses a thread to store the execution history. You can: - Poll the task for status: `GET /api/v1/tasks/{task_id}` - Read the thread for results: `GET /api/v1/threads/{thread_id}` - - Stream updates via WebSocket: `ws://localhost:8000/api/v1/ws/tasks/{thread_id}` + - Stream updates via WebSocket: `ws://localhost:8080/api/v1/ws/tasks/{thread_id}` (Docker Compose) or port 8000 (local) - **Jobs** - Ad-hoc jobs: On-demand via CLI or API. Each run creates a task and streams results to a thread. diff --git a/docs/how-to/api.md b/docs/how-to/api.md index 62de4a33..98c39eff 100644 --- a/docs/how-to/api.md +++ b/docs/how-to/api.md @@ -6,6 +6,8 @@ This guide shows how to use the HTTP API end-to-end: check health, add an instan - Services running (Docker Compose or local uvicorn + worker) - If you enabled auth in your environment, include your API key header as needed +**Port Note**: Docker Compose exposes the API on port **8080**, while local uvicorn uses port **8000**. Examples below use port 8080 (Docker Compose). Replace with 8000 if running locally. + ### 1) Start services (choose one) - Docker Compose ```bash @@ -26,20 +28,21 @@ uv run redis-sre-agent worker --concurrency 4 ### 2) Health and readiness ```bash # Root health (fast) -curl -fsS http://localhost:8000/ +# Use port 8080 for Docker Compose, port 8000 for local uvicorn +curl -fsS http://localhost:8080/ # Detailed health (Redis, vector index, workers) -curl -fsS http://localhost:8000/api/v1/health | jq +curl -fsS http://localhost:8080/api/v1/health | jq # Prometheus metrics (scrape this) -curl -fsS http://localhost:8000/api/v1/metrics | head -n 20 +curl -fsS http://localhost:8080/api/v1/metrics | head -n 20 ``` ### 3) Manage Redis instances Create the instance the agent will triage, then verify a connection. ```bash # Create instance -curl -fsS -X POST http://localhost:8000/api/v1/instances \ +curl -fsS -X POST http://localhost:8080/api/v1/instances \ -H 'Content-Type: application/json' \ -d '{ "name": "prod-cache", @@ -50,14 +53,14 @@ curl -fsS -X POST http://localhost:8000/api/v1/instances \ }' | jq # List & inspect -curl -fsS http://localhost:8000/api/v1/instances | jq -curl -fsS http://localhost:8000/api/v1/instances/ | jq +curl -fsS http://localhost:8080/api/v1/instances | jq +curl -fsS http://localhost:8080/api/v1/instances/ | jq # Test connection (by ID) -curl -fsS -X POST http://localhost:8000/api/v1/instances//test-connection | jq +curl -fsS -X POST http://localhost:8080/api/v1/instances//test-connection | jq # Test a raw URL (without saving) -curl -fsS -X POST http://localhost:8000/api/v1/instances/test-connection-url \ +curl -fsS -X POST http://localhost:8080/api/v1/instances/test-connection-url \ -H 'Content-Type: application/json' \ -d '{"connection_url": "redis://host:6379/0"}' | jq ``` @@ -69,14 +72,16 @@ curl -fsS -X POST http://localhost:8000/api/v1/instances/test-connection-url \ ### 4) Triage with tasks and threads Simplest: create a task with your question. The API will create a thread if you omit `thread_id`. + +> **Note**: Triage performs comprehensive analysis (metrics, logs, knowledge base, multi-topic recommendations) and typically takes **2-10 minutes** to complete. Poll the task status or use WebSocket for real-time updates. ```bash # Create a task (no instance) -curl -fsS -X POST http://localhost:8000/api/v1/tasks \ +curl -fsS -X POST http://localhost:8080/api/v1/tasks \ -H 'Content-Type: application/json' \ -d '{"message": "Explain high memory usage signals in Redis"}' | jq # Create a task (target a specific instance) -curl -fsS -X POST http://localhost:8000/api/v1/tasks \ +curl -fsS -X POST http://localhost:8080/api/v1/tasks \ -H 'Content-Type: application/json' \ -d '{ "message": "Check memory pressure and slow ops", @@ -86,15 +91,15 @@ curl -fsS -X POST http://localhost:8000/api/v1/tasks \ Poll task or inspect the thread: ```bash # Poll task status -curl -fsS http://localhost:8000/api/v1/tasks/ | jq +curl -fsS http://localhost:8080/api/v1/tasks/ | jq # Get the thread state (messages, updates, result) -curl -fsS http://localhost:8000/api/v1/threads/ | jq +curl -fsS http://localhost:8080/api/v1/threads/ | jq ``` Real-time updates via WebSocket: ```bash # Requires a thread_id; use any ws client (wscat, websocat) -wscat -c ws://localhost:8000/api/v1/ws/tasks/ +wscat -c ws://localhost:8080/api/v1/ws/tasks/ # You will receive an initial_state event and subsequent progress updates ``` @@ -103,12 +108,12 @@ wscat -c ws://localhost:8000/api/v1/ws/tasks/ Alternative flow: create a thread first, then submit a task on that thread. ```bash # Create thread -curl -fsS -X POST http://localhost:8000/api/v1/threads \ +curl -fsS -X POST http://localhost:8080/api/v1/threads \ -H 'Content-Type: application/json' \ -d '{"user_id": "u1", "subject": "Prod triage"}' | jq # Submit a task to that thread -curl -fsS -X POST http://localhost:8000/api/v1/tasks \ +curl -fsS -X POST http://localhost:8080/api/v1/tasks \ -H 'Content-Type: application/json' \ -d '{ "thread_id": "", @@ -121,20 +126,20 @@ curl -fsS -X POST http://localhost:8000/api/v1/tasks \ Run an ingestion job, then search to confirm content is available. ```bash # Start pipeline job (ingest existing artifacts or run full if configured) -curl -fsS -X POST http://localhost:8000/api/v1/knowledge/ingest/pipeline \ +curl -fsS -X POST http://localhost:8080/api/v1/knowledge/ingest/pipeline \ -H 'Content-Type: application/json' \ -d '{"operation": "ingest", "artifacts_path": "./artifacts"}' | jq # List jobs & check individual job status -curl -fsS http://localhost:8000/api/v1/knowledge/jobs | jq -curl -fsS http://localhost:8000/api/v1/knowledge/jobs/ | jq +curl -fsS http://localhost:8080/api/v1/knowledge/jobs | jq +curl -fsS http://localhost:8080/api/v1/knowledge/jobs/ | jq # Search knowledge -curl -fsS 'http://localhost:8000/api/v1/knowledge/search?query=redis%20eviction%20policy' | jq +curl -fsS 'http://localhost:8080/api/v1/knowledge/search?query=redis%20eviction%20policy' | jq ``` Optional single-document ingestion: ```bash -curl -fsS -X POST http://localhost:8000/api/v1/knowledge/ingest/document \ +curl -fsS -X POST http://localhost:8080/api/v1/knowledge/ingest/document \ -H 'Content-Type: application/json' \ -d '{ "title": "Redis memory troubleshooting", @@ -148,7 +153,7 @@ curl -fsS -X POST http://localhost:8000/api/v1/knowledge/ingest/document \ Create a schedule to run instructions periodically, optionally bound to an instance. ```bash # Create schedule (daily) -curl -fsS -X POST http://localhost:8000/api/v1/schedules/ \ +curl -fsS -X POST http://localhost:8080/api/v1/schedules/ \ -H 'Content-Type: application/json' \ -d '{ "name": "daily-triage", @@ -161,20 +166,20 @@ curl -fsS -X POST http://localhost:8000/api/v1/schedules/ \ }' | jq # List/get -curl -fsS http://localhost:8000/api/v1/schedules/ | jq -curl -fsS http://localhost:8000/api/v1/schedules/ | jq +curl -fsS http://localhost:8080/api/v1/schedules/ | jq +curl -fsS http://localhost:8080/api/v1/schedules/ | jq # Trigger now (manual run) -curl -fsS -X POST http://localhost:8000/api/v1/schedules//trigger | jq +curl -fsS -X POST http://localhost:8080/api/v1/schedules//trigger | jq # View runs for a schedule -curl -fsS http://localhost:8000/api/v1/schedules//runs | jq +curl -fsS http://localhost:8080/api/v1/schedules//runs | jq ``` ### 7) Tasks, threads, and streaming - Tasks: `GET /api/v1/tasks/{task_id}` - Threads: `GET /api/v1/threads`, `GET /api/v1/threads/{thread_id}` -- WebSocket: `ws://localhost:8000/api/v1/ws/tasks/{thread_id}` +- WebSocket: `ws://localhost:8080/api/v1/ws/tasks/{thread_id}` ### 8) Observability - Prometheus scrape: `GET /api/v1/metrics` diff --git a/docs/how-to/cli.md b/docs/how-to/cli.md index 3bd4c7a7..cd1c6366 100644 --- a/docs/how-to/cli.md +++ b/docs/how-to/cli.md @@ -15,7 +15,7 @@ docker compose up -d \ prometheus grafana \ loki promtail ``` - - API: http://localhost:8000 + - API: http://localhost:8080 (Docker Compose exposes port 8080) - Local processes (no Docker) ```bash # API @@ -122,4 +122,4 @@ uv run redis-sre-agent thread sources ### Tips - Use the Docker stack to get Prometheus/Loki; set TOOLS_PROMETHEUS_URL and TOOLS_LOKI_URL so the agent can fetch metrics/logs. - Prefer `docker compose exec -T sre-agent uv run ...` inside containers when running in Docker (uses in-cluster addresses). -- Health endpoints: `curl http://localhost:8000/` and `/api/v1/health` to verify API and worker availability. +- Health endpoints: `curl http://localhost:8080/` (Docker Compose) or `http://localhost:8000/` (local uvicorn) and `/api/v1/health` to verify API and worker availability. diff --git a/docs/how-to/configuration.md b/docs/how-to/configuration.md index 531cee24..215a3fad 100644 --- a/docs/how-to/configuration.md +++ b/docs/how-to/configuration.md @@ -6,9 +6,68 @@ This guide explains how the Redis SRE Agent is configured, what the required and Configuration values are loaded from these sources (highest precedence first): -- Environment variables (recommended for prod) -- `.env` file (loaded automatically in dev if present) -- Code defaults in `redis_sre_agent/core/config.py` +1. Environment variables (recommended for prod) +2. `.env` file (loaded automatically in dev if present) +3. **YAML config file** (for complex nested configurations like MCP servers) +4. Code defaults in `redis_sre_agent/core/config.py` + +### YAML configuration + +For complex nested settings like MCP server configurations, you can use a YAML config file. This is particularly useful for configuring multiple MCP servers with tool descriptions. + +**Config file discovery order:** + +1. Path specified in `SRE_AGENT_CONFIG` environment variable +2. `config.yaml` in the current working directory +3. `config.yml` in the current working directory +4. `sre_agent_config.yaml` in the current working directory +5. `sre_agent_config.yml` in the current working directory + +**Example `config.yaml`:** + +```yaml +# Application settings +debug: false +log_level: INFO + +# MCP (Model Context Protocol) servers configuration +mcp_servers: + # Memory server for long-term agent memory + redis-memory-server: + command: uv + args: + - tool + - run + - --from + - agent-memory-server + - agent-memory + - mcp + env: + REDIS_URL: redis://localhost:6399 + tools: + search_long_term_memory: + description: | + Search saved memories about Redis instances. ALWAYS use this + before troubleshooting to recall past issues and solutions. + {original} + + # GitHub MCP server (remote) - uses GitHub's hosted MCP endpoint + # Requires a GitHub Personal Access Token with appropriate permissions + # Uses Streamable HTTP transport (default for URL-based connections) + github: + url: "https://api.githubcopilot.com/mcp/" + headers: + Authorization: "Bearer ${GITHUB_PERSONAL_ACCESS_TOKEN}" + # transport: streamable_http # default, can also be 'sse' for legacy servers +``` + +See `config.yaml.example` for a complete example with all available options. + +**Using a custom config path:** + +```bash +export SRE_AGENT_CONFIG=/path/to/my-config.yaml +``` ### Required diff --git a/docs/how-to/local-dev.md b/docs/how-to/local-dev.md index a9b1034c..2af4b395 100644 --- a/docs/how-to/local-dev.md +++ b/docs/how-to/local-dev.md @@ -52,6 +52,7 @@ docker compose up -d \ # Logs docker compose logs -f sre-agent ``` +**Note**: Docker Compose exposes the API on port **8080** (http://localhost:8080), while local uvicorn uses port 8000. ### 5) Create a demo instance (optional) ```bash diff --git a/docs/how-to/scheduling-flows.md b/docs/how-to/scheduling-flows.md index 92fa826f..594c07da 100644 --- a/docs/how-to/scheduling-flows.md +++ b/docs/how-to/scheduling-flows.md @@ -24,7 +24,7 @@ uv run redis-sre-agent schedule run-now ### 2) Create a schedule (API) ```bash -curl -X POST http://localhost:8000/api/v1/schedules \ +curl -X POST http://localhost:8080/api/v1/schedules \ -H "Content-Type: application/json" \ -d '{ "name": "redis-health", @@ -38,20 +38,20 @@ curl -X POST http://localhost:8000/api/v1/schedules \ List schedules: ```bash -curl http://localhost:8000/api/v1/schedules/ +curl http://localhost:8080/api/v1/schedules/ ``` Get a schedule: ```bash -curl http://localhost:8000/api/v1/schedules/{schedule_id} +curl http://localhost:8080/api/v1/schedules/{schedule_id} ``` Trigger a run immediately: ```bash -curl -X POST http://localhost:8000/api/v1/schedules/{schedule_id}/trigger +curl -X POST http://localhost:8080/api/v1/schedules/{schedule_id}/trigger ``` List recent runs: ```bash -curl http://localhost:8000/api/v1/schedules/{schedule_id}/runs +curl http://localhost:8080/api/v1/schedules/{schedule_id}/runs ``` diff --git a/docs/operations/observability.md b/docs/operations/observability.md index 185871bc..b6c14881 100644 --- a/docs/operations/observability.md +++ b/docs/operations/observability.md @@ -11,13 +11,13 @@ The docker-compose stack includes Prometheus, Grafana, Loki, and Tempo for local ### Quick health check Fast endpoint for load balancers (no external dependencies): ```bash -curl -fsS http://localhost:8000/ +curl -fsS http://localhost:8080/ ``` ### Detailed health check Checks Redis connectivity, vector index, and worker availability: ```bash -curl -fsS http://localhost:8000/api/v1/health | jq +curl -fsS http://localhost:8080/api/v1/health | jq ``` Returns status and component details. Status may be `degraded` if workers aren't running. @@ -43,7 +43,7 @@ The agent exposes Prometheus metrics at `/api/v1/metrics` for scraping. ### Scrape the API ```bash -curl -fsS http://localhost:8000/api/v1/metrics | head -n 30 +curl -fsS http://localhost:8080/api/v1/metrics | head -n 30 ``` ### Prometheus configuration diff --git a/docs/quickstarts/local.md b/docs/quickstarts/local.md index f7dc85ce..0ebf741e 100644 --- a/docs/quickstarts/local.md +++ b/docs/quickstarts/local.md @@ -23,16 +23,16 @@ docker compose up -d \ sre-agent sre-worker sre-ui ``` Notes: -- API: http://localhost:8000 +- API: http://localhost:8080 - Grafana: http://localhost:3001 (admin/admin) - Experimental UI: http://localhost:3002 (proxied to API) ### 3) Check status ```bash # API root health -curl http://localhost:8000/ +curl http://localhost:8080/ # Detailed health (Redis, Docket/worker availability, etc.) -curl http://localhost:8000/api/v1/health +curl http://localhost:8080/api/v1/health # Prometheus curl http://localhost:9090/-/ready ``` diff --git a/docs/reference/cli.md b/docs/reference/cli.md index b0ae488c..07a234bd 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -68,6 +68,53 @@ Generated from the Click command tree. - runbook evaluate — Evaluate existing runbooks in the source documents directory. - runbook generate — Generate a new Redis SRE runbook for the specified topic. - query — Execute an agent query. + + Supports conversation threads for multi-turn interactions. Use --thread-id + to continue an existing conversation, or omit it to start a new one. + +  + The agent is automatically selected based on the query, or use --agent: + - knowledge: General Redis questions (no instance needed) + - chat: Quick questions with a Redis instance + - triage: Full health checks and diagnostics + - auto: Let the router decide (default) - worker — Start the background worker. +- mcp — MCP server commands - expose agent capabilities via Model Context Protocol. +- mcp list-tools — List available MCP tools. +- mcp serve — Start the MCP server. + + The MCP server exposes the Redis SRE Agent's capabilities to other + MCP-compatible AI agents. + +  + Available tools: + - triage: Start a Redis troubleshooting session + - get_task_status: Check if a triage task is complete + - get_thread: Get the full results from a triage + - knowledge_search: Search Redis documentation and runbooks + - list_instances: List configured Redis instances + - create_instance: Register a new Redis instance + +  + Examples: + # Run in stdio mode (for Claude Desktop local config) + redis-sre-agent mcp serve + +  + # Run in HTTP mode (for Claude remote connector - RECOMMENDED) + redis-sre-agent mcp serve --transport http --port 8081 + # Then add in Claude: Settings > Connectors > Add Custom Connector + # URL: http://your-host:8081/mcp + +  + # Run in SSE mode (legacy, for older clients) + redis-sre-agent mcp serve --transport sse --port 8081 +- index — RediSearch index management commands. +- index list — List all SRE agent indices and their status. +- index recreate — Drop and recreate RediSearch indices. + + This is useful when the schema has changed (e.g., new fields added). + WARNING: This will delete all indexed data. The underlying Redis keys + remain, but you'll need to re-index documents. See How-to guides for examples. diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index 07f47ed6..9088f1eb 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -2,10 +2,34 @@ Key environment variables and pointers. For step-by-step setup, see: how-to/configuration.md -- OPENAI_API_KEY: LLM access -- REDIS_SRE_MASTER_KEY: 32-byte base64 master key for envelope encryption -- TOOLS_PROMETHEUS_URL, TOOLS_LOKI_URL: Provider endpoints -- REDIS_URL: Agent storage Redis URL (for local/dev) +### Environment Variables + +- `OPENAI_API_KEY`: LLM access (required) +- `REDIS_SRE_MASTER_KEY`: 32-byte base64 master key for envelope encryption +- `TOOLS_PROMETHEUS_URL`, `TOOLS_LOKI_URL`: Provider endpoints +- `REDIS_URL`: Agent storage Redis URL (for local/dev) +- `SRE_AGENT_CONFIG`: Path to YAML config file (optional) + +### YAML Configuration + +For complex nested settings, use a YAML config file (`config.yaml`): + +```yaml +mcp_servers: + server-name: + command: string # Command to run (e.g., "npx", "docker", "uv") + args: [string] # Command arguments + env: {key: value} # Environment variables + url: string # Optional: URL for HTTP-based servers + headers: {key: value} # Optional: Headers for HTTP transport (e.g., Authorization) + transport: string # Optional: 'streamable_http' (default) or 'sse' + tools: # Optional: Tool-specific configurations + tool-name: + description: string # Override tool description ({original} for default) + capability: string # Tool capability category +``` + +See `config.yaml.example` for a complete example. ### See also diff --git a/monitoring/nginx/mcp-ssl.conf b/monitoring/nginx/mcp-ssl.conf new file mode 100644 index 00000000..8d23c217 --- /dev/null +++ b/monitoring/nginx/mcp-ssl.conf @@ -0,0 +1,33 @@ +server { + listen 443 ssl; + server_name localhost; + + ssl_certificate /etc/nginx/certs/server.crt; + ssl_certificate_key /etc/nginx/certs/server.key; + + ssl_protocols TLSv1.2 TLSv1.3; + ssl_ciphers HIGH:!aNULL:!MD5; + + # Use Docker's internal DNS resolver for dynamic resolution + resolver 127.0.0.11 valid=30s ipv6=off; + + location / { + # Use variable to force runtime DNS resolution + set $upstream http://sre-mcp:8081; + proxy_pass $upstream; + + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + + # SSE/streaming support + proxy_buffering off; + proxy_cache off; + proxy_read_timeout 86400s; + proxy_send_timeout 86400s; + } +} diff --git a/pyproject.toml b/pyproject.toml index 9bb954d4..8904f651 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,8 @@ dependencies = [ "opentelemetry-instrumentation-httpx>=0.57b0", "opentelemetry-instrumentation-aiohttp-client>=0.57b0", "opentelemetry-instrumentation-openai>=0.47.5", + "mcp>=1.23.3", + "nltk>=3.9.1", ] [dependency-groups] diff --git a/redis_sre_agent/agent/__init__.py b/redis_sre_agent/agent/__init__.py index adc00945..e23a19fb 100644 --- a/redis_sre_agent/agent/__init__.py +++ b/redis_sre_agent/agent/__init__.py @@ -1,5 +1,6 @@ """SRE Agent module.""" +from .chat_agent import ChatAgent, get_chat_agent from .langgraph_agent import SRELangGraphAgent, get_sre_agent -__all__ = ["SRELangGraphAgent", "get_sre_agent"] +__all__ = ["SRELangGraphAgent", "get_sre_agent", "ChatAgent", "get_chat_agent"] diff --git a/redis_sre_agent/agent/chat_agent.py b/redis_sre_agent/agent/chat_agent.py new file mode 100644 index 00000000..0257aefd --- /dev/null +++ b/redis_sre_agent/agent/chat_agent.py @@ -0,0 +1,490 @@ +""" +Lightweight Chat Agent for fast Redis instance interaction. + +This agent is designed for quick Q&A when a Redis instance is available +but the user doesn't need a full health check or triage. It has access +to all Redis tools but uses a simpler workflow without deep research +or safety-evaluation chains. +""" + +import json +import logging +from typing import Any, Awaitable, Callable, Dict, List, Optional, TypedDict + +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.tools import StructuredTool +from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import END, StateGraph +from langgraph.prebuilt import ToolNode as LGToolNode +from opentelemetry import trace + +from redis_sre_agent.core.config import settings +from redis_sre_agent.core.instances import RedisInstance +from redis_sre_agent.core.progress import ( + CallbackEmitter, + NullEmitter, + ProgressEmitter, +) +from redis_sre_agent.tools.manager import ToolManager +from redis_sre_agent.tools.models import ToolCapability + +from .helpers import build_result_envelope + +logger = logging.getLogger(__name__) +tracer = trace.get_tracer(__name__) + + +CHAT_SYSTEM_PROMPT = """You are a Redis SRE agent. A user is asking about a specific Redis deployment. +You have access to the full toolset needed to inspect the deployment and answer questions about how Redis behaves in this context. + +## Your Approach +- Respond quickly and directly to the user's question +- Use tools to gather the specific information needed +- Don't perform exhaustive diagnostics unless asked +- Focus on answering what was asked, not a full health assessment + +## Tool Usage - BATCH YOUR CALLS +**CRITICAL: Call multiple tools in a single response whenever possible.** + +When you need to gather information, request ALL relevant tools at once: +- ❌ WRONG: Call one tool, wait, call another, wait... +- ✅ CORRECT: Call get_detailed_redis_diagnostics, get_cluster_info, and search_knowledge_base together in one turn + +Think about what information you'll need and request it all at once. This is much faster. + +## Guidelines +- Call tools as needed to answer the question +- Keep responses concise and actionable +- Cite specific data from tool results +- If the user wants a comprehensive health check, suggest they ask for a "full triage" instead + +## Redis Enterprise / Redis Cloud Notes +- For managed Redis (Enterprise or Cloud), INFO output can be misleading +- Use the Admin REST API tools for accurate configuration details +- Don't suggest CONFIG SET for managed deployments +""" + + +class ChatAgentState(TypedDict): + """State for the chat agent.""" + + messages: List[BaseMessage] + session_id: str + user_id: str + current_tool_calls: List[Dict[str, Any]] + iteration_count: int + max_iterations: int + # Accumulated tool result envelopes for context management + signals_envelopes: List[Dict[str, Any]] + + +class ChatAgent: + """Lightweight LangGraph-based agent for quick Redis Q&A. + + This agent has access to all Redis tools but uses a simpler workflow + optimized for fast, targeted responses rather than comprehensive triage. + """ + + # Threshold for summarizing tool outputs (chars) + ENVELOPE_SUMMARY_THRESHOLD = 500 + + def __init__( + self, + redis_instance: Optional[RedisInstance] = None, + progress_emitter: Optional[ProgressEmitter] = None, + progress_callback: Optional[Callable[[str, str], Awaitable[None]]] = None, + exclude_mcp_categories: Optional[List["ToolCapability"]] = None, + ): + """Initialize the Chat agent. + + Args: + redis_instance: Optional Redis instance for context + progress_emitter: Emitter for progress/notification updates + progress_callback: DEPRECATED - use progress_emitter instead + exclude_mcp_categories: Optional list of MCP tool capability categories to exclude. + Use this to filter out specific types of MCP tools. Common categories: + METRICS, LOGS, TICKETS, REPOS, TRACES, DIAGNOSTICS, KNOWLEDGE, UTILITIES. + """ + self.settings = settings + self.redis_instance = redis_instance + self.exclude_mcp_categories = exclude_mcp_categories + + # Handle emitter (prefer progress_emitter, fall back to callback wrapper) + if progress_emitter is not None: + self._emitter = progress_emitter + elif progress_callback is not None: + self._emitter = CallbackEmitter(progress_callback) + else: + self._emitter = NullEmitter() + + self.llm = ChatOpenAI( + model=self.settings.openai_model, + openai_api_key=self.settings.openai_api_key, + ) + self.mini_llm = ChatOpenAI( + model=self.settings.openai_model_mini, + openai_api_key=self.settings.openai_api_key, + ) + + logger.info( + f"Chat agent initialized (instance: {redis_instance.name if redis_instance else 'none'})" + ) + + def _build_expand_evidence_tool( + self, + original_envelopes: List[Dict[str, Any]], + ) -> Dict[str, Any]: + """Build a tool that allows the LLM to retrieve full tool output details. + + When we summarize tool outputs, the LLM only sees condensed versions. + This tool lets the LLM request the full original output for any tool_key + if it needs more detail. + + Args: + original_envelopes: The original (unsummarized) envelopes + + Returns: + A dict with name, description, func, and parameters for creating a tool + """ + originals_by_key = {e.get("tool_key"): e for e in original_envelopes} + available_keys = list(originals_by_key.keys()) + + def expand_evidence(tool_key: str) -> Dict[str, Any]: + """Retrieve the full, unsummarized output from a previous tool call.""" + if tool_key not in originals_by_key: + return { + "status": "error", + "error": f"Unknown tool_key: {tool_key}. Available: {available_keys}", + } + original = originals_by_key[tool_key] + return { + "status": "success", + "tool_key": tool_key, + "name": original.get("name"), + "full_data": original.get("data"), + } + + return { + "name": "expand_evidence", + "description": ( + "Retrieve the full, unsummarized output from a previous tool call. " + "Use this when the summary doesn't have enough detail. " + f"Available tool_keys: {available_keys}" + ), + "func": expand_evidence, + "parameters": { + "type": "object", + "properties": { + "tool_key": { + "type": "string", + "description": "The tool_key from a summarized evidence item", + } + }, + "required": ["tool_key"], + }, + } + + def _summarize_envelope_sync(self, env: Dict[str, Any]) -> Dict[str, Any]: + """Synchronously truncate large envelope data (simple fallback). + + For chat agent, we use simple truncation rather than LLM summarization + to keep things fast. + """ + data_str = json.dumps(env.get("data", {}), default=str) + if len(data_str) <= self.ENVELOPE_SUMMARY_THRESHOLD: + return env + + # Truncate large data + return { + "tool_key": env.get("tool_key"), + "name": env.get("name"), + "description": env.get("description"), + "args": env.get("args"), + "status": env.get("status"), + "data": { + "summary": data_str[: self.ENVELOPE_SUMMARY_THRESHOLD] + "...", + "note": "Data truncated. Use expand_evidence tool to get full output.", + }, + } + + def _build_workflow( + self, + tool_mgr: ToolManager, + llm_with_tools: ChatOpenAI, + adapters: List[Any], + emitter: Optional[ProgressEmitter] = None, + ) -> StateGraph: + """Build the LangGraph workflow for chat interactions. + + Args: + tool_mgr: ToolManager instance for resolving tool calls + llm_with_tools: LLM instance with tools bound + adapters: List of tool adapters for the ToolNode + emitter: Optional progress emitter for status updates + """ + tooldefs_by_name = {t.name: t for t in tool_mgr.get_tools()} + + # We'll dynamically add expand_evidence tool when envelopes are available + # For now, track state needed for dynamic tool injection + expand_tool_added = {"value": False} + current_adapters = list(adapters) + + async def agent_node(state: ChatAgentState) -> Dict[str, Any]: + """Main agent node - invokes LLM with tools.""" + messages = state["messages"] + iteration_count = state.get("iteration_count", 0) + envelopes = state.get("signals_envelopes") or [] + + # If we have envelopes and haven't added expand_evidence yet, add it + nonlocal current_adapters, expand_tool_added + if envelopes and not expand_tool_added["value"]: + expand_spec = self._build_expand_evidence_tool(envelopes) + expand_tool = StructuredTool.from_function( + func=expand_spec["func"], + name=expand_spec["name"], + description=expand_spec["description"], + ) + current_adapters = list(adapters) + [expand_tool] + expand_tool_added["value"] = True + # Rebind tools to LLM with expand_evidence + bound_llm = self.llm.bind_tools(current_adapters) + else: + bound_llm = llm_with_tools + + with tracer.start_as_current_span("chat_agent_node"): + response = await bound_llm.ainvoke(messages) + + new_messages = list(messages) + [response] + return { + "messages": new_messages, + "iteration_count": iteration_count + 1, + "current_tool_calls": response.tool_calls + if hasattr(response, "tool_calls") + else [], + } + + async def tool_node(state: ChatAgentState) -> Dict[str, Any]: + """Execute tool calls from the agent.""" + messages = state["messages"] + envelopes = list(state.get("signals_envelopes") or []) + + # Get pending tool calls from the last AI message + last_msg = messages[-1] if messages else None + tool_calls = [] + if isinstance(last_msg, AIMessage) and hasattr(last_msg, "tool_calls"): + tool_calls = last_msg.tool_calls or [] + + # Emit progress updates for each tool call + if emitter and tool_calls: + for tc in tool_calls: + tool_name = ( + tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", None) + ) + tool_args = ( + tc.get("args") if isinstance(tc, dict) else getattr(tc, "args", {}) + ) or {} + if tool_name: + # Try to get provider-supplied status message + status_msg = tool_mgr.get_status_update(tool_name, tool_args) + if status_msg: + await emitter.emit(status_msg, "tool_call") + else: + # Default status message + await emitter.emit(f"Executing tool: {tool_name}", "tool_call") + + with tracer.start_as_current_span("chat_tool_node"): + nonlocal current_adapters + lg_tool_node = LGToolNode(current_adapters) + out = await lg_tool_node.ainvoke({"messages": messages}) + out_messages = out.get("messages", []) + new_tool_messages = [m for m in out_messages if isinstance(m, ToolMessage)] + + # Build envelopes for each tool call result + for idx, tc in enumerate(tool_calls): + tool_name = ( + tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", None) + ) + tool_args = ( + tc.get("args") if isinstance(tc, dict) else getattr(tc, "args", {}) + ) or {} + + # Skip expand_evidence calls - they don't need envelope tracking + if tool_name == "expand_evidence": + continue + + tm = new_tool_messages[idx] if idx < len(new_tool_messages) else None + env_dict = build_result_envelope( + tool_name or f"tool_{idx + 1}", tool_args, tm, tooldefs_by_name + ) + # Summarize if large + env_dict = self._summarize_envelope_sync(env_dict) + envelopes.append(env_dict) + + return { + "messages": list(messages) + new_tool_messages, + "current_tool_calls": [], + "signals_envelopes": envelopes, + } + + def should_continue(state: ChatAgentState) -> str: + """Decide whether to continue with tools or end.""" + messages = state["messages"] + iteration_count = state.get("iteration_count", 0) + max_iterations = state.get("max_iterations", 10) + + if iteration_count >= max_iterations: + logger.warning(f"Chat agent reached max iterations ({max_iterations})") + return END + + if messages and isinstance(messages[-1], AIMessage) and messages[-1].tool_calls: + return "tools" + + if state.get("current_tool_calls"): + return "tools" + + return END + + workflow = StateGraph(ChatAgentState) + workflow.add_node("agent", agent_node) + workflow.add_node("tools", tool_node) + workflow.set_entry_point("agent") + workflow.add_conditional_edges("agent", should_continue, {"tools": "tools", END: END}) + workflow.add_edge("tools", "agent") + + return workflow + + async def process_query( + self, + query: str, + session_id: str, + user_id: str, + max_iterations: int = 10, + context: Optional[Dict[str, Any]] = None, + progress_emitter: Optional[ProgressEmitter] = None, + progress_callback: Optional[Callable[[str, str], Awaitable[None]]] = None, + conversation_history: Optional[List[BaseMessage]] = None, + ) -> str: + """Process a query with quick tool access. + + Args: + query: User's question + session_id: Session identifier + user_id: User identifier + max_iterations: Maximum agent iterations (default 10) + context: Additional context (e.g., instance_id) + progress_emitter: Emitter for progress/notification updates + progress_callback: DEPRECATED - use progress_emitter instead + conversation_history: Optional previous messages for context + + Returns: + Agent's response as a string + """ + logger.info(f"Chat agent processing query for user {user_id}") + + # Use provided emitter, or fall back to instance emitter + if progress_emitter is not None: + emitter = progress_emitter + elif progress_callback is not None: + emitter = CallbackEmitter(progress_callback) + else: + emitter = self._emitter + + # Create ToolManager with Redis instance for full tool access + async with ToolManager( + redis_instance=self.redis_instance, + exclude_mcp_categories=self.exclude_mcp_categories, + ) as tool_mgr: + tools = tool_mgr.get_tools() + logger.info(f"Chat agent loaded {len(tools)} tools") + + from .helpers import build_adapters_for_tooldefs as _build_adapters + + adapters = await _build_adapters(tool_mgr, tools) + llm_with_tools = self.llm.bind_tools(adapters) + + workflow = self._build_workflow(tool_mgr, llm_with_tools, adapters, emitter) + + checkpointer = MemorySaver() + app = workflow.compile(checkpointer=checkpointer) + + # Build initial messages with instance context + initial_messages: List[BaseMessage] = [SystemMessage(content=CHAT_SYSTEM_PROMPT)] + + # Add instance context to the query if available + enhanced_query = query + if self.redis_instance: + repo_context = "" + if self.redis_instance.repo_url: + repo_context = f"""- Repository URL: {self.redis_instance.repo_url} + +If you have GitHub tools available, you can search the repository for code, configuration, or documentation related to this Redis instance. +""" + instance_context = f""" +INSTANCE CONTEXT: This query is about Redis instance: +- Instance Name: {self.redis_instance.name} +- Environment: {self.redis_instance.environment} +- Usage: {self.redis_instance.usage} +- Instance Type: {self.redis_instance.instance_type} +{repo_context} +Your diagnostic tools are PRE-CONFIGURED for this instance. + +User Query: {query}""" + enhanced_query = instance_context + + if conversation_history: + initial_messages.extend(conversation_history) + + initial_messages.append(HumanMessage(content=enhanced_query)) + + initial_state: ChatAgentState = { + "messages": initial_messages, + "session_id": session_id, + "user_id": user_id, + "current_tool_calls": [], + "iteration_count": 0, + "max_iterations": max_iterations, + "signals_envelopes": [], # Track tool outputs for expand_evidence + } + + thread_config = {"configurable": {"thread_id": session_id}} + + try: + await emitter.emit("Chat agent processing your question...", "agent_start") + + final_state = await app.ainvoke(initial_state, config=thread_config) + + messages = final_state.get("messages", []) + if messages: + last_message = messages[-1] + if isinstance(last_message, AIMessage): + return last_message.content + return str(last_message.content) + + return "I couldn't process that query. Please try rephrasing." + + except Exception as e: + logger.exception(f"Chat agent error: {e}") + return f"Error processing query: {e}" + + +# Singleton cache keyed by instance name +_chat_agents: Dict[str, ChatAgent] = {} + + +def get_chat_agent(redis_instance: Optional[RedisInstance] = None) -> ChatAgent: + """Get or create a chat agent, optionally for a specific Redis instance. + + Args: + redis_instance: Optional Redis instance for context + + Returns: + ChatAgent instance + """ + global _chat_agents + key = redis_instance.name if redis_instance else "__no_instance__" + + if key not in _chat_agents: + _chat_agents[key] = ChatAgent(redis_instance=redis_instance) + + return _chat_agents[key] diff --git a/redis_sre_agent/agent/knowledge_agent.py b/redis_sre_agent/agent/knowledge_agent.py index 60a5e21e..a21c8221 100644 --- a/redis_sre_agent/agent/knowledge_agent.py +++ b/redis_sre_agent/agent/knowledge_agent.py @@ -18,6 +18,11 @@ from opentelemetry import trace from redis_sre_agent.core.config import settings +from redis_sre_agent.core.progress import ( + CallbackEmitter, + NullEmitter, + ProgressEmitter, +) from redis_sre_agent.tools.manager import ToolManager logger = logging.getLogger(__name__) @@ -80,10 +85,26 @@ class KnowledgeOnlyAgent: It's designed for general Q&A when no Redis instance is specified. """ - def __init__(self, progress_callback: Optional[Callable[[str, str], Awaitable[None]]] = None): - """Initialize the Knowledge-only SRE agent.""" + def __init__( + self, + progress_emitter: Optional[ProgressEmitter] = None, + progress_callback: Optional[Callable[[str, str], Awaitable[None]]] = None, + ): + """Initialize the Knowledge-only SRE agent. + + Args: + progress_emitter: Emitter for progress/notification updates + progress_callback: DEPRECATED - use progress_emitter instead + """ self.settings = settings - self.progress_callback = progress_callback + + # Handle emitter (prefer progress_emitter, fall back to callback wrapper) + if progress_emitter is not None: + self._emitter = progress_emitter + elif progress_callback is not None: + self._emitter = CallbackEmitter(progress_callback) + else: + self._emitter = NullEmitter() # LLM optimized for knowledge tasks self.llm = ChatOpenAI( @@ -97,11 +118,15 @@ def __init__(self, progress_callback: Optional[Callable[[str, str], Awaitable[No logger.info("Knowledge-only agent initialized (tools loaded per-query)") - def _build_workflow(self, tool_mgr: ToolManager, llm_with_tools: ChatOpenAI) -> StateGraph: + def _build_workflow( + self, tool_mgr: ToolManager, llm_with_tools: ChatOpenAI, emitter: ProgressEmitter + ) -> StateGraph: """Build the LangGraph workflow for knowledge-only queries. Args: tool_mgr: ToolManager instance with knowledge tools loaded + llm_with_tools: LLM with tools bound + emitter: Emitter for progress notifications """ async def agent_node(state: KnowledgeAgentState) -> KnowledgeAgentState: @@ -268,11 +293,11 @@ async def safe_tool_node(state: KnowledgeAgentState) -> KnowledgeAgentState: "source": doc.get("source"), } ) - if fragments and self.progress_callback: - await self.progress_callback( - f"Found {len(fragments)} knowledge fragments", # message - "knowledge_sources", # update_type - {"fragments": fragments}, # metadata + if fragments: + await emitter.emit( + f"Found {len(fragments)} knowledge fragments", + "knowledge_sources", + {"fragments": fragments}, ) except Exception: # Don't let telemetry failures break tool handling @@ -378,7 +403,8 @@ async def process_query( user_id: str, max_iterations: int = 5, context: Optional[Dict[str, Any]] = None, - progress_callback=None, + progress_emitter: Optional[ProgressEmitter] = None, + progress_callback: Optional[Callable[[str, str], Awaitable[None]]] = None, conversation_history: Optional[List[BaseMessage]] = None, ) -> str: """ @@ -390,7 +416,8 @@ async def process_query( user_id: User identifier max_iterations: Maximum number of agent iterations context: Additional context (currently ignored for knowledge-only agent) - progress_callback: Optional callback for progress updates + progress_emitter: Emitter for progress/notification updates + progress_callback: DEPRECATED - use progress_emitter instead conversation_history: Optional list of previous messages for context Returns: @@ -398,9 +425,13 @@ async def process_query( """ logger.info(f"Processing knowledge query for user {user_id}") - # Set progress callback for this query - if progress_callback: - self.progress_callback = progress_callback + # Use provided emitter, or fall back to instance emitter + if progress_emitter is not None: + emitter = progress_emitter + elif progress_callback is not None: + emitter = CallbackEmitter(progress_callback) + else: + emitter = self._emitter # Create ToolManager with Redis instance-independent tools async with ToolManager(redis_instance=None) as tool_mgr: @@ -414,7 +445,7 @@ async def process_query( llm_with_tools = self.llm.bind_tools(adapters) # Build workflow with tools and bound LLM - workflow = self._build_workflow(tool_mgr, llm_with_tools) + workflow = self._build_workflow(tool_mgr, llm_with_tools, emitter) # Create initial state with conversation history initial_messages = [] @@ -447,11 +478,10 @@ async def process_query( } try: - # Progress callback for start - if self.progress_callback: - await self.progress_callback( - "Knowledge agent starting to process your query...", "agent_start" - ) + # Emit start notification + await emitter.emit( + "Knowledge agent starting to process your query...", "agent_start" + ) # Run the workflow (with recursion limit to match settings) final_state = await app.ainvoke(initial_state, config=thread_config) @@ -468,11 +498,10 @@ async def process_query( else: response = "I apologize, but I wasn't able to process your query. Please try asking a more specific question about SRE practices or troubleshooting." - # Progress callback for completion - if self.progress_callback: - await self.progress_callback( - "Knowledge agent has completed processing your query.", "agent_complete" - ) + # Emit completion notification + await emitter.emit( + "Knowledge agent has completed processing your query.", "agent_complete" + ) logger.info(f"Knowledge query completed for user {user_id}") return response @@ -481,10 +510,7 @@ async def process_query( logger.error(f"Knowledge agent processing failed: {e}") error_response = f"I encountered an error while processing your knowledge query: {str(e)}. Please try asking a more specific question about SRE practices, troubleshooting methodologies, or system reliability concepts." - if self.progress_callback: - await self.progress_callback( - f"Knowledge agent encountered an error: {str(e)}", "agent_error" - ) + await emitter.emit(f"Knowledge agent encountered an error: {str(e)}", "agent_error") return error_response diff --git a/redis_sre_agent/agent/langgraph_agent.py b/redis_sre_agent/agent/langgraph_agent.py index 36e2c739..b79f5f43 100644 --- a/redis_sre_agent/agent/langgraph_agent.py +++ b/redis_sre_agent/agent/langgraph_agent.py @@ -29,6 +29,7 @@ get_instances, save_instances, ) +from ..core.progress import CallbackEmitter, NullEmitter, ProgressEmitter from ..tools.manager import ToolManager from .helpers import build_adapters_for_tooldefs as _build_adapters from .helpers import log_preflight_messages @@ -313,10 +314,28 @@ class SREToolCall(BaseModel): class SRELangGraphAgent: """LangGraph-based SRE Agent with multi-turn conversation and tool calling.""" - def __init__(self, progress_callback=None): - """Initialize the SRE LangGraph agent.""" + def __init__( + self, + progress_callback=None, + progress_emitter: Optional[ProgressEmitter] = None, + ): + """Initialize the SRE LangGraph agent. + + Args: + progress_callback: Deprecated. Legacy callback for progress updates. + Use progress_emitter instead. + progress_emitter: ProgressEmitter instance for emitting status updates. + If not provided but progress_callback is, wraps callback + in a CallbackEmitter for backward compatibility. + """ self.settings = settings - self.progress_callback = progress_callback + # Support both new emitter and legacy callback + if progress_emitter is not None: + self._progress_emitter: ProgressEmitter = progress_emitter + elif progress_callback is not None: + self._progress_emitter = CallbackEmitter(progress_callback) + else: + self._progress_emitter = NullEmitter() # LLM with both reasoning and function calling capabilities self.llm = ChatOpenAI( model=self.settings.openai_model, @@ -520,6 +539,196 @@ async def _compose_final_markdown( return content + async def _summarize_envelopes_for_reasoning( + self, + envelopes: List[Dict[str, Any]], + max_data_chars: int = 500, + ) -> List[Dict[str, Any]]: + """Summarize tool output envelopes to reduce context size for reasoning. + + For envelopes with large data payloads, uses the mini LLM to extract + key findings. Small payloads are kept as-is. + + Args: + envelopes: List of ResultEnvelope dicts from tool executions + max_data_chars: Threshold above which to summarize (default 500 chars) + + Returns: + List of summarized envelope dicts with condensed data + """ + if not envelopes: + return [] + + summarized = [] + to_summarize = [] + to_summarize_indices = [] + + # Identify which envelopes need summarization + for i, env in enumerate(envelopes): + data = env.get("data", {}) + data_str = json.dumps(data, default=str) if data else "" + + if len(data_str) > max_data_chars: + to_summarize.append(env) + to_summarize_indices.append(i) + else: + summarized.append((i, env)) + + # Batch summarize large envelopes + if to_summarize: + logger.info( + f"Reasoning: summarizing {len(to_summarize)} envelopes " + f"(>{max_data_chars} chars each)" + ) + + # Build batch prompt for efficiency + batch_prompt = ( + "You are summarizing tool outputs for an SRE agent. " + "For each tool result below, extract ONLY the key findings in 2-3 sentences. " + "Focus on: errors, warnings, anomalies, key metrics, and actionable insights. " + "Preserve exact numbers, error messages, and metric values. " + "Return a JSON array with one summary object per input.\n\n" + ) + + for j, env in enumerate(to_summarize): + tool_name = env.get("name", "tool") + data = env.get("data", {}) + batch_prompt += f"--- Tool {j + 1}: {tool_name} ---\n" + batch_prompt += json.dumps(data, default=str)[:2000] # Cap individual items + batch_prompt += "\n\n" + + batch_prompt += ( + 'Return JSON array format: [{"summary": "key findings..."}, {"summary": "..."}]' + ) + + try: + summary_response = await self._ainvoke_memo( + "envelope_summarizer", + self.mini_llm, + [HumanMessage(content=batch_prompt)], + ) + content = summary_response.content or "" + + # Parse summaries from response + summaries = [] + try: + # Try to extract JSON array from response + import re + + json_match = re.search(r"\[[\s\S]*\]", content) + if json_match: + summaries = json.loads(json_match.group()) + except Exception: + pass + + # Apply summaries to envelopes + for j, (orig_idx, env) in enumerate(zip(to_summarize_indices, to_summarize)): + summary_text = ( + summaries[j].get("summary", "") + if j < len(summaries) and isinstance(summaries[j], dict) + else "" + ) + if not summary_text: + # Fallback: truncate data + data_str = json.dumps(env.get("data", {}), default=str) + summary_text = data_str[:max_data_chars] + "..." + + condensed_env = { + "tool_key": env.get("tool_key"), + "name": env.get("name"), + "description": env.get("description"), + "args": env.get("args"), + "status": env.get("status"), + "data": {"summary": summary_text}, + } + summarized.append((orig_idx, condensed_env)) + + except Exception as e: + logger.warning(f"Envelope summarization failed, using truncation: {e}") + # Fallback: truncate all large envelopes + for orig_idx, env in zip(to_summarize_indices, to_summarize): + data_str = json.dumps(env.get("data", {}), default=str) + condensed_env = { + "tool_key": env.get("tool_key"), + "name": env.get("name"), + "description": env.get("description"), + "args": env.get("args"), + "status": env.get("status"), + "data": {"truncated": data_str[:max_data_chars] + "..."}, + } + summarized.append((orig_idx, condensed_env)) + + # Sort by original index to preserve order + summarized.sort(key=lambda x: x[0]) + return [env for _, env in summarized] + + def _build_expand_evidence_tool( + self, + original_envelopes: List[Dict[str, Any]], + ) -> Dict[str, Any]: + """Build a tool that allows the LLM to retrieve full tool output details. + + When we summarize tool outputs, the LLM only sees condensed versions. + This tool lets the LLM request the full original output for any tool_key + if it needs more detail. + + Args: + original_envelopes: The original (unsummarized) envelopes + + Returns: + A LangChain-compatible tool dict that can be bound to an LLM + """ + # Build lookup from tool_key to original envelope + originals_by_key = {e.get("tool_key"): e for e in original_envelopes} + available_keys = list(originals_by_key.keys()) + + def expand_evidence(tool_key: str) -> Dict[str, Any]: + """Retrieve the full, unsummarized output from a previous tool call. + + Use this when you need more detail than the summary provides. + Only call this for tool_keys that appear in the evidence summaries. + + Args: + tool_key: The tool_key from a summarized evidence item + + Returns: + The full original tool output with all details + """ + if tool_key not in originals_by_key: + return { + "status": "error", + "error": f"Unknown tool_key: {tool_key}. Available keys: {available_keys}", + } + original = originals_by_key[tool_key] + return { + "status": "success", + "tool_key": tool_key, + "name": original.get("name"), + "description": original.get("description"), + "full_data": original.get("data"), + } + + # Return as a LangChain tool-compatible format + return { + "name": "expand_evidence", + "description": ( + "Retrieve the full, unsummarized output from a previous tool call. " + "Use this when the summary doesn't have enough detail for your analysis. " + f"Available tool_keys: {available_keys}" + ), + "func": expand_evidence, + "parameters": { + "type": "object", + "properties": { + "tool_key": { + "type": "string", + "description": "The tool_key from a summarized evidence item", + } + }, + "required": ["tool_key"], + }, + } + def _build_workflow( self, tool_mgr: ToolManager, target_instance: Optional[Any] = None ) -> StateGraph: @@ -785,12 +994,12 @@ async def tool_node(state: AgentState) -> AgentState: try: tool_name = tc.get("name") tool_args = tc.get("args") or {} - if self.progress_callback and tool_name: + if tool_name: status_msg = tool_mgr.get_status_update( tool_name, tool_args ) or self._generate_tool_reflection(tool_name, tool_args) if status_msg: - await self.progress_callback(status_msg, "agent_reflection") + await self._progress_emitter.emit(status_msg, "agent_reflection") except Exception: pass @@ -839,8 +1048,7 @@ async def tool_node(state: AgentState) -> AgentState: # Knowledge fragments progress (best-effort) try: if ( - self.progress_callback - and isinstance(data_obj, dict) + isinstance(data_obj, dict) and isinstance(tool_name, str) and tool_name.startswith("knowledge_") and tool_name.endswith("_search") @@ -863,7 +1071,7 @@ async def tool_node(state: AgentState) -> AgentState: f"Failed to build fragment from knowledge search result: {e}" ) if fragments: - await self.progress_callback( + await self._progress_emitter.emit( "Retrieved knowledge fragments", "knowledge_sources", {"fragments": fragments}, @@ -912,9 +1120,16 @@ def _parse_tool_json_blocks(tool_msg_text: str) -> Optional[dict]: except Exception: return None - # New path: topic extraction with structured output based on full envelopes + # New path: topic extraction with structured output based on summarized envelopes envelopes = state.get("signals_envelopes") or [] logger.info(f"Reasoning: envelopes captured={len(envelopes)}") + + # Summarize large envelopes to reduce context size + summarized_envelopes = await self._summarize_envelopes_for_reasoning( + envelopes, max_data_chars=500 + ) + logger.info(f"Reasoning: envelopes after summarization={len(summarized_envelopes)}") + topics: List[Dict[str, Any]] = [] try: from .models import TopicsList @@ -927,13 +1142,13 @@ def _parse_tool_json_blocks(tool_msg_text: str) -> Optional[dict]: "name": target_instance.name, } preface = ( - "About this JSON: signals from upstream tool calls (each has a tool description, args, and raw JSON results).\n" + "About this JSON: summarized signals from upstream tool calls (each has a tool description, args, and key findings).\n" "Use only these as evidence. Return a list of topics with evidence_keys referencing tool_key.\n" "For EACH topic, include: id, title, category, scope, narrative, evidence_keys, and severity.\n" "severity MUST be one of: critical | high | medium | low, based on operational risk/impact/urgency.\n" "Order the topics by severity (critical->low)." ) - payload = json.dumps(envelopes, default=str) + payload = json.dumps(summarized_envelopes, default=str) human = HumanMessage( content=( preface @@ -975,6 +1190,8 @@ def _sev_score(t: dict) -> int: # If we have extracted topics, run dynamic per-topic recommendation workers if topics: + from langchain_core.tools import StructuredTool + from .subgraphs.recommendation_worker import build_recommendation_worker rec_tasks = [] @@ -988,32 +1205,50 @@ def _sev_score(t: dict) -> int: # Use all knowledge tools for the mini knowledge agent; no op-level filtering. knowledge_tools = tool_mgr.get_tools_for_capability(_ToolCap.KNOWLEDGE) knowledge_adapters = await _build_adapters(tool_mgr, knowledge_tools) - if knowledge_adapters: - knowledge_llm = self.mini_llm.bind_tools(knowledge_adapters) - if knowledge_adapters: + # Build expand_evidence tool so LLM can retrieve full details if needed + # This gives the LLM access to original (unsummarized) tool outputs + expand_tool_spec = self._build_expand_evidence_tool(envelopes) + expand_tool = StructuredTool.from_function( + func=expand_tool_spec["func"], + name=expand_tool_spec["name"], + description=expand_tool_spec["description"], + ) + # Add expand_evidence to the available tools + all_adapters = list(knowledge_adapters) + [expand_tool] + + if all_adapters: + knowledge_llm = self.mini_llm.bind_tools(all_adapters) + + if all_adapters: logger.info( - f"Reasoning: knowledge adapters available={len(knowledge_adapters)}; topics to run={len(topics)}" + f"Reasoning: knowledge adapters available={len(all_adapters)} " + f"(includes expand_evidence tool); topics to run={len(topics)}" ) worker = build_recommendation_worker( knowledge_llm, - knowledge_adapters, + all_adapters, max_tool_steps=self.settings.max_tool_calls_per_stage, memoize=self._ainvoke_memo, ) - env_by_key = { - e.get("tool_key"): e for e in (state.get("signals_envelopes") or []) - } + # Use summarized envelopes for recommendation workers + # LLM can call expand_evidence to get full details if needed + env_by_key = {e.get("tool_key"): e for e in summarized_envelopes} for t in topics: ev_keys = [k for k in (t.get("evidence_keys") or []) if isinstance(k, str)] ev = [env_by_key[k] for k in ev_keys if k in env_by_key] inp = { "messages": [ SystemMessage( - content="You will research and then synthesize recommendations for the given topic." + content=( + "You will research and then synthesize recommendations for the given topic. " + "The evidence provided contains summaries of tool outputs. " + "If you need more detail from any evidence item, use the expand_evidence tool " + "with the tool_key to retrieve the full original output." + ) ), HumanMessage( - content=f"Topic: {json.dumps(t, default=str)}\nInstance: {json.dumps(instance_ctx, default=str)}\nEvidence: {json.dumps(ev, default=str)}" + content=f"Topic: {json.dumps(t, default=str)}\nInstance: {json.dumps(instance_ctx, default=str)}\nEvidence (summaries): {json.dumps(ev, default=str)}" ), ], "budget": int(self.settings.max_tool_calls_per_stage), @@ -1178,6 +1413,7 @@ async def _process_query( context: Optional[Dict[str, Any]] = None, progress_callback: Optional[Callable[[str, str], Awaitable[None]]] = None, conversation_history: Optional[List[BaseMessage]] = None, + progress_emitter: Optional[ProgressEmitter] = None, ) -> str: """Process a single SRE query through the LangGraph workflow. @@ -1187,15 +1423,19 @@ async def _process_query( user_id: User identifier max_iterations: Maximum number of workflow iterations context: Additional context including instance_id if specified + progress_callback: Deprecated. Use progress_emitter instead. + progress_emitter: ProgressEmitter for status updates during this query. Returns: Agent's response as a string """ logger.info(f"Processing SRE query for user {user_id}, session {session_id}") - # Set progress callback for this query - if progress_callback: - self.progress_callback = progress_callback + # Set progress emitter for this query (prefer emitter over callback) + if progress_emitter is not None: + self._progress_emitter = progress_emitter + elif progress_callback is not None: + self._progress_emitter = CallbackEmitter(progress_callback) # Determine target Redis instance from context target_instance = None @@ -1214,6 +1454,12 @@ async def _process_query( f"Found target instance: {target_instance.name} ({target_instance.connection_url})" ) # Add instance context to the query + repo_context = "" + if target_instance.repo_url: + repo_context = f"""- Repository URL: {target_instance.repo_url} + +If you have repository tools available (e.g., GitHub MCP), you can use them to access code, configuration files, or documentation related to this instance. +""" enhanced_query = f"""User Query: {query} IMPORTANT CONTEXT: This query is specifically about Redis instance: @@ -1222,7 +1468,7 @@ async def _process_query( - Connection URL: {target_instance.connection_url} - Environment: {target_instance.environment} - Usage: {target_instance.usage} - +{repo_context} Your diagnostic tools are PRE-CONFIGURED for this instance. You do NOT need to specify redis_url or instance details - they are already set. Just call the tools directly. SAFETY REQUIREMENT: You MUST verify you can connect to and gather data from this specific Redis instance before making any recommendations. If you cannot get basic metrics like maxmemory, connected_clients, or keyspace info, you lack sufficient information to make recommendations. @@ -1311,6 +1557,12 @@ async def _process_query( f"Auto-detected single Redis instance: {target_instance.name} ({redis_url})" ) + repo_context = "" + if target_instance.repo_url: + repo_context = f"""- Repository URL: {target_instance.repo_url} + +If you have repository tools available (e.g., GitHub MCP), you can use them to access code, configuration files, or documentation related to this instance. +""" enhanced_query = f"""User Query: {query} AUTO-DETECTED CONTEXT: Since no specific Redis instance was mentioned, I am analyzing the available Redis instance: @@ -1319,7 +1571,7 @@ async def _process_query( - Port: {port} - Environment: {target_instance.environment} - Usage: {target_instance.usage} - +{repo_context} When using Redis diagnostic tools, use this Redis URL: {redis_url} SAFETY REQUIREMENT: You MUST verify you can connect to and gather data from this Redis instance before making any recommendations. If you cannot get basic metrics like maxmemory, connected_clients, or keyspace info, you lack sufficient information to make recommendations.""" @@ -1675,8 +1927,20 @@ async def process_query( context: Optional[Dict[str, Any]] = None, progress_callback: Optional[Callable[[str, str], Awaitable[None]]] = None, conversation_history: Optional[List[BaseMessage]] = None, + progress_emitter: Optional[ProgressEmitter] = None, ) -> str: - """Process a query once, then attach Safety and Fact-Checking notes.""" + """Process a query once, then attach Safety and Fact-Checking notes. + + Args: + query: User's SRE question or request + session_id: Session identifier for conversation context + user_id: User identifier + max_iterations: Maximum number of workflow iterations + context: Additional context including instance_id if specified + progress_callback: Deprecated. Use progress_emitter instead. + conversation_history: Optional list of previous messages for context + progress_emitter: ProgressEmitter for status updates during this query. + """ # Initialize in-run caches (LLM memo; tool cache is per-ToolManager context) self._begin_run_cache() try: @@ -1689,6 +1953,7 @@ async def process_query( context, progress_callback, conversation_history, + progress_emitter, ) # Skip correction if this message isn't about Redis @@ -1771,6 +2036,7 @@ async def process_query( "version": inst.version, "memory": inst.memory, "connections": inst.connections, + "repo_url": inst.repo_url, } except Exception: instance_ctx = {} diff --git a/redis_sre_agent/agent/prompts.py b/redis_sre_agent/agent/prompts.py index 6d75ab15..9027fd89 100644 --- a/redis_sre_agent/agent/prompts.py +++ b/redis_sre_agent/agent/prompts.py @@ -16,6 +16,27 @@ 3. **Search your knowledge** when you need specific troubleshooting steps 4. **Give them a clear plan** - actionable steps they can take right now +## Tool Usage - BATCH YOUR CALLS + +**CRITICAL: Call multiple tools in a single response whenever possible.** + +When you need to gather information, request ALL relevant tools at once rather than one at a time: + +❌ **WRONG** (sequential - slow): +``` +Turn 1: Call get_detailed_redis_diagnostics +Turn 2: Call get_cluster_info +Turn 3: Call list_nodes +Turn 4: Call search_knowledge_base +``` + +✅ **CORRECT** (parallel - fast): +``` +Turn 1: Call get_detailed_redis_diagnostics, get_cluster_info, list_nodes, search_knowledge_base all together +``` + +Think about what information you'll need upfront and request it all in one turn. This significantly speeds up analysis. + ## Writing Style Write like you're updating a colleague on what you found. Use natural language: diff --git a/redis_sre_agent/agent/router.py b/redis_sre_agent/agent/router.py index b27fa585..c2890143 100644 --- a/redis_sre_agent/agent/router.py +++ b/redis_sre_agent/agent/router.py @@ -20,8 +20,12 @@ class AgentType(Enum): """Types of available agents.""" - REDIS_FOCUSED = "redis_focused" - KNOWLEDGE_ONLY = "knowledge_only" + REDIS_TRIAGE = "redis_triage" # Full triage/health check agent + REDIS_CHAT = "redis_chat" # Lightweight chat agent for quick Q&A + KNOWLEDGE_ONLY = "knowledge_only" # No instance, general knowledge + + # Keep old value for backward compatibility + REDIS_FOCUSED = "redis_triage" # Alias for REDIS_TRIAGE async def route_to_appropriate_agent( @@ -32,6 +36,11 @@ async def route_to_appropriate_agent( """ Route a query to the appropriate agent using a fast LLM categorization. + Routing logic: + - No Redis instance: KNOWLEDGE_ONLY (general knowledge questions) + - Has Redis instance + asks for full/comprehensive health check or triage: REDIS_TRIAGE + - Has Redis instance + quick question: REDIS_CHAT (fast diagnostic loop) + Args: query: The user's query text context: Additional context including instance_id, priority, etc. @@ -42,48 +51,89 @@ async def route_to_appropriate_agent( """ logger.info(f"Routing query: {query[:100]}...") - # 1. Check for explicit Redis instance context - if context and context.get("instance_id"): - logger.info("Query has explicit Redis instance context - routing to Redis-focused agent") - return AgentType.REDIS_FOCUSED + has_instance = context and context.get("instance_id") + + # 1. No instance context - route to knowledge agent + if not has_instance: + # Use LLM to decide if query needs instance access or is knowledge-only + try: + llm = ChatOpenAI( + model=settings.openai_model_nano, + api_key=settings.openai_api_key, + timeout=10.0, + temperature=0, + ) + + system_prompt = """You are a query categorization system for a Redis SRE agent. + +Categorize if this query requires access to a live Redis instance or is just seeking general knowledge. + +1. NEEDS_INSTANCE: Queries that require access to a specific Redis instance for diagnostics, monitoring, or troubleshooting. + Examples: "Check my Redis memory", "Why is Redis slow?", "Show me the slowlog" + +2. KNOWLEDGE_ONLY: Queries seeking general knowledge, best practices, or guidance. + Examples: "What are Redis best practices?", "How does Redis replication work?" + +Respond with ONLY one word: either "NEEDS_INSTANCE" or "KNOWLEDGE_ONLY".""" + + messages = [ + SystemMessage(content=system_prompt), + HumanMessage(content=f"Categorize this query: {query}"), + ] + + response = await llm.ainvoke(messages) + category = response.content.strip().upper() - # 2. Check user preferences + if "NEEDS_INSTANCE" in category: + logger.info("Query needs instance but none provided - routing to KNOWLEDGE_ONLY") + else: + logger.info("LLM categorized query as KNOWLEDGE_ONLY") + + return AgentType.KNOWLEDGE_ONLY + + except Exception as e: + logger.error(f"Error during LLM routing: {e}, defaulting to KNOWLEDGE_ONLY") + return AgentType.KNOWLEDGE_ONLY + + # 2. Has instance - decide between triage (full) and chat (quick) + # Check user preferences first if user_preferences and user_preferences.get("preferred_agent"): preferred = user_preferences["preferred_agent"] if preferred in [agent.value for agent in AgentType]: logger.info(f"Using user preference: {preferred}") return AgentType(preferred) - # 3. Use fast LLM to categorize the query + # 3. Use LLM to categorize triage vs chat try: llm = ChatOpenAI( model=settings.openai_model_nano, api_key=settings.openai_api_key, - timeout=10.0, # Fast timeout for categorization - temperature=0, # Deterministic categorization + timeout=10.0, + temperature=0, ) system_prompt = """You are a query categorization system for a Redis SRE agent. -Your task is to categorize user queries into one of two categories: +The user has a Redis instance available. Determine what kind of agent should handle their query: -1. REDIS_FOCUSED: Queries that require access to a specific Redis instance for diagnostics, monitoring, or troubleshooting. +1. TRIAGE: Full health check, comprehensive diagnostics, or in-depth analysis. + Trigger words: "full health check", "triage", "comprehensive", "full analysis", "complete diagnostic", "thorough check", "audit" Examples: - - "Check the memory usage of my Redis instance" - - "Why is Redis slow?" - - "Show me the slowlog" - - "What's the current connection count?" - - "Diagnose performance issues" + - "Run a full health check on my Redis" + - "I need a comprehensive triage of this instance" + - "Do a complete diagnostic" + - "Give me a thorough analysis" -2. KNOWLEDGE_ONLY: Queries seeking general knowledge, best practices, or guidance that don't require instance access. +2. CHAT: Quick questions, specific lookups, or targeted queries. Examples: - - "What are Redis best practices?" - - "How does Redis replication work?" - - "Explain Redis persistence options" - - "What is an SRE runbook?" - - "How should I configure Redis for high availability?" + - "What do you know about this instance?" + - "Check the memory usage" + - "Show me the slowlog" + - "How many connections are there?" + - "What's the current ops/sec?" + - "Is replication working?" -Respond with ONLY one word: either "REDIS_FOCUSED" or "KNOWLEDGE_ONLY".""" +Respond with ONLY one word: either "TRIAGE" or "CHAT".""" messages = [ SystemMessage(content=system_prompt), @@ -93,18 +143,13 @@ async def route_to_appropriate_agent( response = await llm.ainvoke(messages) category = response.content.strip().upper() - if "REDIS_FOCUSED" in category: - logger.info("LLM categorized query as REDIS_FOCUSED") - return AgentType.REDIS_FOCUSED - elif "KNOWLEDGE_ONLY" in category: - logger.info("LLM categorized query as KNOWLEDGE_ONLY") - return AgentType.KNOWLEDGE_ONLY + if "TRIAGE" in category: + logger.info("LLM categorized query as REDIS_TRIAGE (full health check)") + return AgentType.REDIS_TRIAGE else: - logger.warning( - f"LLM returned unexpected category: {category}, defaulting to KNOWLEDGE_ONLY" - ) - return AgentType.KNOWLEDGE_ONLY + logger.info("LLM categorized query as REDIS_CHAT (quick Q&A)") + return AgentType.REDIS_CHAT except Exception as e: - logger.error(f"Error during LLM routing: {e}, defaulting to KNOWLEDGE_ONLY") - return AgentType.KNOWLEDGE_ONLY + logger.error(f"Error during LLM routing: {e}, defaulting to REDIS_CHAT") + return AgentType.REDIS_CHAT diff --git a/redis_sre_agent/api/schemas.py b/redis_sre_agent/api/schemas.py index 4bf5a89c..f68296bf 100644 --- a/redis_sre_agent/api/schemas.py +++ b/redis_sre_agent/api/schemas.py @@ -95,6 +95,9 @@ class TaskResponse(BaseModel): updates: List[Dict[str, Any]] = Field(default_factory=list) result: Optional[Dict[str, Any]] = None error_message: Optional[str] = None + subject: Optional[str] = None + created_at: Optional[str] = None + updated_at: Optional[str] = None # Thread schemas @@ -126,6 +129,12 @@ class ThreadAppendMessagesRequest(BaseModel): class ThreadResponse(BaseModel): + """Response model for thread data. + + Updates, result, and error_message are fetched from the latest task + associated with this thread to support real-time UI updates. + """ + thread_id: str user_id: Optional[str] = None priority: int = 0 @@ -133,10 +142,11 @@ class ThreadResponse(BaseModel): subject: Optional[str] = None context: Optional[Dict[str, Any]] = None messages: List[Message] = Field(default_factory=list) - # New fields to expose full thread state for UI streaming - updates: List[Dict[str, Any]] = Field(default_factory=list) - result: Optional[Dict[str, Any]] = None - error_message: Optional[str] = None metadata: Optional[Dict[str, Any]] = None created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + # Task-level fields for real-time updates + updates: List[Dict[str, Any]] = Field(default_factory=list) + result: Optional[Dict[str, Any]] = None + error_message: Optional[str] = None + status: Optional[str] = None diff --git a/redis_sre_agent/api/tasks.py b/redis_sre_agent/api/tasks.py index 5c8c6f2a..e3664abb 100644 --- a/redis_sre_agent/api/tasks.py +++ b/redis_sre_agent/api/tasks.py @@ -67,4 +67,7 @@ async def get_task(task_id: str) -> TaskResponse: updates=[u.model_dump() for u in state.updates], result=state.result, error_message=state.error_message, + subject=state.metadata.subject if state.metadata else None, + created_at=state.metadata.created_at if state.metadata else None, + updated_at=state.metadata.updated_at if state.metadata else None, ) diff --git a/redis_sre_agent/api/threads.py b/redis_sre_agent/api/threads.py index 6cbc4b22..d56ed926 100644 --- a/redis_sre_agent/api/threads.py +++ b/redis_sre_agent/api/threads.py @@ -76,7 +76,7 @@ async def create_thread(req: ThreadCreateRequest) -> ThreadResponse: thread_id = await tm.create_thread( user_id=req.user_id, session_id=req.session_id, - initial_context=req.context or {"messages": []}, + initial_context=req.context or {}, tags=req.tags or [], ) if req.subject: @@ -91,12 +91,18 @@ async def create_thread(req: ThreadCreateRequest) -> ThreadResponse: await tm.append_messages(thread_id, [m.model_dump() for m in req.messages]) state = await tm.get_thread(thread_id) - messages = state.context.get("messages", []) if state else [] - # Return created thread state (messages + context) + if not state: + raise HTTPException(status_code=500, detail="Failed to retrieve created thread") + + # Convert Message objects to API schema + messages = [ + Message(role=m.role, content=m.content, metadata=m.metadata) for m in state.messages + ] + return ThreadResponse( thread_id=thread_id, - messages=[Message(**m) for m in messages] if messages else [], - context=state.context if state else {}, + messages=messages, + context=state.context, ) except Exception as e: logger.error(f"Failed to create thread: {e}") @@ -111,9 +117,6 @@ async def get_thread(thread_id: str) -> ThreadResponse: if not state: raise HTTPException(status_code=404, detail="Thread not found") - # Extract messages from context if present - messages = state.context.get("messages", []) - # Build metadata dict compatible with UI expectations try: metadata = state.metadata.model_dump() @@ -123,18 +126,50 @@ async def get_thread(thread_id: str) -> ThreadResponse: except Exception: metadata = None + # Convert Message objects to API schema + messages = [ + Message(role=m.role, content=m.content, metadata=m.metadata) for m in state.messages + ] + + # Fetch the latest task's updates, result, and status for real-time UI display + updates = [] + result = None + error_message = None + task_status = None + + try: + from redis_sre_agent.core.keys import RedisKeys + from redis_sre_agent.core.tasks import TaskManager + + task_manager = TaskManager(redis_client=rc) + # Get the latest task for this thread + latest_task_ids = await rc.zrevrange(RedisKeys.thread_tasks_index(thread_id), 0, 0) + if latest_task_ids: + latest_task_id = latest_task_ids[0] + if isinstance(latest_task_id, bytes): + latest_task_id = latest_task_id.decode() + task_state = await task_manager.get_task_state(latest_task_id) + if task_state: + updates = [u.model_dump() for u in (task_state.updates or [])] + result = task_state.result + error_message = task_state.error_message + task_status = task_state.status + except Exception as e: + logger.warning(f"Failed to fetch task updates for thread {thread_id}: {e}") + return ThreadResponse( thread_id=thread_id, user_id=(metadata.get("user_id") if metadata else None), priority=(metadata.get("priority", 0) if metadata else 0), tags=(metadata.get("tags", []) if metadata else []), subject=(metadata.get("subject") if metadata else None), - messages=[Message(**m) for m in messages] if messages else [], + messages=messages, context=state.context, - updates=[u.model_dump() for u in state.updates] if state.updates else [], - result=state.result, - error_message=state.error_message, metadata=metadata, + updates=updates, + result=result, + error_message=error_message, + status=task_status, ) diff --git a/redis_sre_agent/api/websockets.py b/redis_sre_agent/api/websockets.py index 40e94dfd..2b7aedc4 100644 --- a/redis_sre_agent/api/websockets.py +++ b/redis_sre_agent/api/websockets.py @@ -215,13 +215,35 @@ async def websocket_task_status(websocket: WebSocket, thread_id: str): if len(_active_connections[thread_id]) == 1: await _stream_manager.start_consumer(thread_id) - # Send current thread state immediately (no thread status) + # Get the latest task for this thread to send updates/result/error + from redis_sre_agent.core.tasks import TaskManager + + task_manager = TaskManager(redis_client=redis_client) + latest_task_ids = await redis_client.zrevrange( + RedisKeys.thread_tasks_index(thread_id), 0, 0 + ) + + updates = [] + result = None + error_message = None + + if latest_task_ids: + latest_task_id = latest_task_ids[0] + if isinstance(latest_task_id, bytes): + latest_task_id = latest_task_id.decode() + task_state = await task_manager.get_task_state(latest_task_id) + if task_state: + updates = task_state.updates[-10:] if task_state.updates else [] + result = task_state.result + error_message = task_state.error_message + + # Send current state immediately initial_event = InitialStateEvent( update_type="initial_state", thread_id=thread_id, - updates=thread_state.updates[-10:], # Last 10 updates - result=thread_state.result, - error_message=thread_state.error_message, + updates=updates, + result=result, + error_message=error_message, timestamp=datetime.now(timezone.utc).isoformat(), ) await websocket.send_text(initial_event.model_dump_json()) diff --git a/redis_sre_agent/cli/index.py b/redis_sre_agent/cli/index.py new file mode 100644 index 00000000..60d33b0c --- /dev/null +++ b/redis_sre_agent/cli/index.py @@ -0,0 +1,156 @@ +"""Index management CLI commands.""" + +from __future__ import annotations + +import asyncio +import json as _json + +import click +from rich.console import Console +from rich.table import Table + + +@click.group() +def index(): + """RediSearch index management commands.""" + pass + + +@index.command("list") +@click.option("--json", "as_json", is_flag=True, help="Output JSON") +def index_list(as_json: bool): + """List all SRE agent indices and their status.""" + + async def _run(): + from redis_sre_agent.core.redis import ( + SRE_INSTANCES_INDEX, + SRE_KNOWLEDGE_INDEX, + SRE_SCHEDULES_INDEX, + SRE_TASKS_INDEX, + SRE_THREADS_INDEX, + get_instances_index, + get_knowledge_index, + get_schedules_index, + get_tasks_index, + get_threads_index, + ) + + console = Console() + indices = [ + ("knowledge", SRE_KNOWLEDGE_INDEX, get_knowledge_index), + ("schedules", SRE_SCHEDULES_INDEX, get_schedules_index), + ("threads", SRE_THREADS_INDEX, get_threads_index), + ("tasks", SRE_TASKS_INDEX, get_tasks_index), + ("instances", SRE_INSTANCES_INDEX, get_instances_index), + ] + + results = [] + for name, index_name, get_fn in indices: + try: + idx = await get_fn() + exists = await idx.exists() + info = {} + if exists: + try: + # Get index info to show field count + client = idx._redis_client + raw_info = await client.execute_command("FT.INFO", index_name) + # Parse the flat list into a dict + info_dict = {} + for i in range(0, len(raw_info), 2): + key = raw_info[i] + if isinstance(key, bytes): + key = key.decode() + info_dict[key] = raw_info[i + 1] + num_docs = info_dict.get("num_docs", 0) + if isinstance(num_docs, bytes): + num_docs = num_docs.decode() + info["num_docs"] = int(num_docs) + except Exception: + info["num_docs"] = "?" + + results.append( + { + "name": name, + "index_name": index_name, + "exists": exists, + "num_docs": info.get("num_docs", 0) if exists else 0, + } + ) + except Exception as e: + results.append( + { + "name": name, + "index_name": index_name, + "exists": False, + "error": str(e), + } + ) + + if as_json: + print(_json.dumps(results, indent=2)) + return + + table = Table(title="RediSearch Indices") + table.add_column("Name", no_wrap=True) + table.add_column("Index Name", no_wrap=True) + table.add_column("Exists", no_wrap=True) + table.add_column("Documents", no_wrap=True) + + for r in results: + exists_str = "✅" if r["exists"] else "❌" + docs = str(r.get("num_docs", 0)) if r["exists"] else "-" + if r.get("error"): + docs = f"Error: {r['error']}" + table.add_row(r["name"], r["index_name"], exists_str, docs) + + console.print(table) + + asyncio.run(_run()) + + +@index.command("recreate") +@click.option( + "--index-name", + type=click.Choice(["knowledge", "schedules", "threads", "tasks", "instances", "all"]), + default="all", + help="Which index to recreate (default: all)", +) +@click.option("-y", "--yes", is_flag=True, help="Skip confirmation prompt") +@click.option("--json", "as_json", is_flag=True, help="Output JSON") +def index_recreate(index_name: str, yes: bool, as_json: bool): + """Drop and recreate RediSearch indices. + + This is useful when the schema has changed (e.g., new fields added). + WARNING: This will delete all indexed data. The underlying Redis keys + remain, but you'll need to re-index documents. + """ + + async def _run(): + from redis_sre_agent.core.redis import recreate_indices + + console = Console() + + if not yes and not as_json: + console.print( + "[yellow]Warning:[/yellow] This will drop and recreate indices. " + "Indexed data will need to be re-ingested." + ) + if not click.confirm("Continue?"): + console.print("Aborted.") + return + + result = await recreate_indices(index_name if index_name != "all" else None) + + if as_json: + print(_json.dumps(result, indent=2)) + return + + if result.get("success"): + console.print("[green]✅ Successfully recreated indices[/green]") + for idx_name, status in result.get("indices", {}).items(): + console.print(f" - {idx_name}: {status}") + else: + console.print(f"[red]❌ Failed to recreate indices: {result.get('error')}[/red]") + + asyncio.run(_run()) diff --git a/redis_sre_agent/cli/knowledge.py b/redis_sre_agent/cli/knowledge.py index dabb2d2c..2b53df0a 100644 --- a/redis_sre_agent/cli/knowledge.py +++ b/redis_sre_agent/cli/knowledge.py @@ -25,8 +25,9 @@ def knowledge(): @knowledge.command("search") @click.argument("query", nargs=-1) -@click.option("--limit", "-l", default=5, help="Number of results to return") @click.option("--category", "-c", type=str, help="Filter by category") +@click.option("--limit", "-l", default=10, help="Number of results to return") +@click.option("--offset", "-o", default=0, help="Offset for pagination") @click.option("--distance-threshold", "-d", type=float, help="Cosine distance threshold") @click.option( "--hybrid-search", @@ -35,11 +36,14 @@ def knowledge(): default=False, help="Use hybrid search (vector + full-text)", ) +@click.option("--version", "-v", type=str, default="latest", help="Redis version filter") def knowledge_search( - limit: int, category: Optional[str], + limit: int, + offset: int, distance_threshold: Optional[float], - hybrid_search: bool = False, + hybrid_search: bool, + version: Optional[str], query: str = "*", ): """Search the knowledge base (query helpers group).""" @@ -48,6 +52,7 @@ async def _run(): kwargs = { "query": " ".join(query), "limit": limit, + "offset": offset, "distance_threshold": distance_threshold, "hybrid_search": hybrid_search, } @@ -58,6 +63,9 @@ async def _run(): click.echo(f"📂 Category filter: {category}") if distance_threshold: click.echo(f"📏 Distance threshold: {distance_threshold}") + if version: + kwargs["version"] = version + click.echo(f"🔢 Version filter: {version}") click.echo(f"🔢 Limit: {limit}") result = await search_knowledge_base_helper(**kwargs) @@ -73,6 +81,7 @@ async def _run(): click.echo(f"Title: {doc.get('title', 'Unknown')}") click.echo(f"Source: {doc.get('source', 'Unknown')}") click.echo(f"Category: {doc.get('category', 'general')}") + click.echo(f"Version: {doc.get('version', 'None')}") content = doc.get("content", "") if len(content) > 1000: content = content[:1000] + "..." diff --git a/redis_sre_agent/cli/main.py b/redis_sre_agent/cli/main.py index 303da357..6d377e5c 100644 --- a/redis_sre_agent/cli/main.py +++ b/redis_sre_agent/cli/main.py @@ -15,6 +15,8 @@ "runbook": "redis_sre_agent.cli.runbook:runbook", "query": "redis_sre_agent.cli.query:query", "worker": "redis_sre_agent.cli.worker:worker", + "mcp": "redis_sre_agent.cli.mcp:mcp", + "index": "redis_sre_agent.cli.index:index", } diff --git a/redis_sre_agent/cli/mcp.py b/redis_sre_agent/cli/mcp.py new file mode 100644 index 00000000..285bbfad --- /dev/null +++ b/redis_sre_agent/cli/mcp.py @@ -0,0 +1,87 @@ +"""MCP server CLI commands.""" + +import click + + +@click.group() +def mcp(): + """MCP server commands - expose agent capabilities via Model Context Protocol.""" + pass + + +@mcp.command("serve") +@click.option( + "--transport", + type=click.Choice(["stdio", "http", "sse"]), + default="stdio", + help="Transport mode: stdio (local), http (remote/recommended), or sse (legacy)", +) +@click.option( + "--host", + default="0.0.0.0", + help="Host to bind to (http/sse mode only)", +) +@click.option( + "--port", + default=8081, + type=int, + help="Port to bind to (http/sse mode only)", +) +def serve(transport: str, host: str, port: int): + """Start the MCP server. + + The MCP server exposes the Redis SRE Agent's capabilities to other + MCP-compatible AI agents. + + \b + Available tools: + - triage: Start a Redis troubleshooting session + - get_task_status: Check if a triage task is complete + - get_thread: Get the full results from a triage + - knowledge_search: Search Redis documentation and runbooks + - list_instances: List configured Redis instances + - create_instance: Register a new Redis instance + + \b + Examples: + # Run in stdio mode (for Claude Desktop local config) + redis-sre-agent mcp serve + + \b + # Run in HTTP mode (for Claude remote connector - RECOMMENDED) + redis-sre-agent mcp serve --transport http --port 8081 + # Then add in Claude: Settings > Connectors > Add Custom Connector + # URL: http://your-host:8081/mcp + + \b + # Run in SSE mode (legacy, for older clients) + redis-sre-agent mcp serve --transport sse --port 8081 + """ + from redis_sre_agent.mcp_server.server import run_http, run_sse, run_stdio + + if transport == "stdio": + # Don't print anything to stdout in stdio mode - it corrupts the JSON-RPC stream + run_stdio() + elif transport == "http": + click.echo(f"Starting MCP server in HTTP mode on {host}:{port}...") + click.echo(f"MCP endpoint: http://{host}:{port}/mcp") + click.echo("Add this URL as a Custom Connector in Claude settings.") + run_http(host=host, port=port) + else: + click.echo(f"Starting MCP server in SSE mode on {host}:{port}...") + run_sse(host=host, port=port) + + +@mcp.command("list-tools") +def list_tools(): + """List available MCP tools.""" + from redis_sre_agent.mcp_server.server import mcp as mcp_server + + click.echo("Available MCP tools:\n") + for tool in mcp_server._tool_manager._tools.values(): + click.echo(f" {tool.name}") + if tool.description: + # Get first line of description + first_line = tool.description.split("\n")[0].strip() + click.echo(f" {first_line}") + click.echo() diff --git a/redis_sre_agent/cli/query.py b/redis_sre_agent/cli/query.py index 0cbefce4..c04675f0 100644 --- a/redis_sre_agent/cli/query.py +++ b/redis_sre_agent/cli/query.py @@ -6,54 +6,172 @@ from typing import Optional import click +from langchain_core.messages import AIMessage, HumanMessage +from rich.console import Console +from rich.markdown import Markdown +from redis_sre_agent.agent.chat_agent import get_chat_agent from redis_sre_agent.agent.knowledge_agent import get_knowledge_agent from redis_sre_agent.agent.langgraph_agent import get_sre_agent +from redis_sre_agent.agent.router import AgentType, route_to_appropriate_agent from redis_sre_agent.core.config import settings from redis_sre_agent.core.instances import get_instance_by_id +from redis_sre_agent.core.redis import get_redis_client +from redis_sre_agent.core.threads import ThreadManager @click.command() @click.argument("query") @click.option("--redis-instance-id", "-r", help="Redis instance ID to investigate") -def query(query: str, redis_instance_id: Optional[str]): - """Execute an agent query.""" +@click.option("--thread-id", "-t", help="Thread ID to continue an existing conversation") +@click.option( + "--agent", + "-a", + type=click.Choice(["auto", "triage", "chat", "knowledge"], case_sensitive=False), + default="auto", + help="Agent to use (default: auto-select based on query)", +) +def query(query: str, redis_instance_id: Optional[str], thread_id: Optional[str], agent: str): + """Execute an agent query. + + Supports conversation threads for multi-turn interactions. Use --thread-id + to continue an existing conversation, or omit it to start a new one. + + \b + The agent is automatically selected based on the query, or use --agent: + - knowledge: General Redis questions (no instance needed) + - chat: Quick questions with a Redis instance + - triage: Full health checks and diagnostics + - auto: Let the router decide (default) + """ async def _query(): + console = Console() + redis_client = get_redis_client() + thread_manager = ThreadManager(redis_client=redis_client) + + # Resolve instance if provided + instance = None if redis_instance_id: instance = await get_instance_by_id(redis_instance_id) if not instance: - click.echo(f"❌ Instance not found: {redis_instance_id}") + console.print(f"[red]❌ Instance not found: {redis_instance_id}[/red]") exit(1) + + # Get or create thread + active_thread_id = thread_id + conversation_history = [] + + if thread_id: + # Continue existing thread + thread = await thread_manager.get_thread(thread_id) + if not thread: + console.print(f"[red]❌ Thread not found: {thread_id}[/red]") + exit(1) + + console.print(f"[dim]📎 Continuing thread: {thread_id}[/dim]") + + # Load conversation history + for msg in thread.messages: + if msg.role == "user": + conversation_history.append(HumanMessage(content=msg.content)) + elif msg.role == "assistant": + conversation_history.append(AIMessage(content=msg.content)) + + # Use instance from thread context if not provided + if not instance and thread.context.get("instance_id"): + instance = await get_instance_by_id(thread.context["instance_id"]) + if instance: + console.print(f"[dim]🔗 Using instance from thread: {instance.name}[/dim]") + else: - instance = None + # Create new thread + initial_context = {} + if instance: + initial_context["instance_id"] = instance.id - click.echo(f"🔍 Query: {query}") + active_thread_id = await thread_manager.create_thread( + user_id="cli_user", + session_id="cli", + initial_context=initial_context, + tags=["cli"], + ) + await thread_manager.update_thread_subject(active_thread_id, query) + console.print(f"[dim]📎 Created thread: {active_thread_id}[/dim]") + + console.print(f"[bold]🔍 Query:[/bold] {query}") if instance: - click.echo(f"🔗 Redis instance: {instance.name}") - agent = get_sre_agent() + console.print(f"[dim]🔗 Redis instance: {instance.name}[/dim]") + + # Build context for routing + routing_context = {"instance_id": instance.id} if instance else None + + # Map CLI agent choice to AgentType + agent_choice_map = { + "triage": AgentType.REDIS_TRIAGE, + "chat": AgentType.REDIS_CHAT, + "knowledge": AgentType.KNOWLEDGE_ONLY, + } + + # Determine which agent to use + if agent != "auto": + agent_type = agent_choice_map[agent.lower()] + agent_label = agent.capitalize() + console.print(f"[dim]🔧 Agent: {agent_label} (selected)[/dim]") + else: + agent_type = await route_to_appropriate_agent( + query=query, + context=routing_context, + ) + agent_label = { + AgentType.REDIS_TRIAGE: "Triage", + AgentType.REDIS_CHAT: "Chat", + AgentType.KNOWLEDGE_ONLY: "Knowledge", + }.get(agent_type, agent_type.value) + console.print(f"[dim]🔧 Agent: {agent_label}[/dim]") + + # Get the appropriate agent instance + if agent_type == AgentType.REDIS_TRIAGE: + selected_agent = get_sre_agent() + elif agent_type == AgentType.REDIS_CHAT: + selected_agent = get_chat_agent(redis_instance=instance) else: - agent = get_knowledge_agent() + selected_agent = get_knowledge_agent() try: context = {"instance_id": instance.id} if instance else None - response = await agent.process_query( + + # Run the agent + response = await selected_agent.process_query( query, session_id="cli", user_id="cli_user", max_iterations=settings.max_iterations, context=context, + conversation_history=conversation_history if conversation_history else None, ) - from rich.console import Console - from rich.markdown import Markdown + # Save messages to thread + await thread_manager.append_messages( + active_thread_id, + [ + {"role": "user", "content": query}, + {"role": "assistant", "content": str(response)}, + ], + ) - console = Console() - console.print("\n✅ Response:\n") + console.print("\n[bold green]✅ Response:[/bold green]\n") console.print(Markdown(str(response))) + + # Show thread ID for follow-up queries + console.print("\n[dim]💡 To continue this conversation:[/dim]") + console.print( + f'[dim] redis-sre-agent query --thread-id {active_thread_id} "your follow-up"[/dim]' + ) + except Exception as e: - click.echo(f"❌ Error: {e}") + console.print(f"[red]❌ Error: {e}[/red]") exit(1) asyncio.run(_query()) diff --git a/redis_sre_agent/cli/threads.py b/redis_sre_agent/cli/threads.py index d61759c7..9d65c350 100644 --- a/redis_sre_agent/cli/threads.py +++ b/redis_sre_agent/cli/threads.py @@ -153,26 +153,22 @@ async def _get(): table.add_row("Tags", ", ".join(meta.tags or []) or "-") table.add_row("Instance", ctx.get("instance_name") or ctx.get("instance_id") or "-") table.add_row("Priority", str(meta.priority)) + table.add_row("Messages", str(len(state.messages))) console.print(table) - # Updates - if state.updates: - ut = Table(title="Updates") - ut.add_column("Time", no_wrap=True) - ut.add_column("Type", no_wrap=True) - ut.add_column("Message") - for u in state.updates[:20]: - ut.add_row(u.timestamp or "-", u.update_type or "-", u.message or "-") - console.print(ut) - - # Result - if state.result: - rt = Table(title="Result") - rt.add_column("Key", no_wrap=True) - rt.add_column("Value") - for k, v in (state.result or {}).items(): - rt.add_row(str(k), str(v)) - console.print(rt) + # Messages (conversation history) + if state.messages: + mt = Table(title="Messages (Conversation)") + mt.add_column("#", no_wrap=True) + mt.add_column("Role", no_wrap=True) + mt.add_column("Content") + for i, m in enumerate(state.messages, 1): + # Truncate long messages for display + content = m.content + if len(content) > 200: + content = content[:197] + "..." + mt.add_row(str(i), m.role, content) + console.print(mt) asyncio.run(_get()) @@ -185,7 +181,12 @@ def thread_sources(thread_id: str, task_id: str | None, as_json: bool): """List knowledge fragments retrieved for a thread (optionally a specific turn).""" async def _run(): - tm = ThreadManager(redis_client=get_redis_client()) + from redis_sre_agent.core.tasks import TaskManager + + client = get_redis_client() + tm = ThreadManager(redis_client=client) + task_manager = TaskManager(redis_client=client) + state = await tm.get_thread(thread_id) if not state: payload = {"error": "Thread not found", "thread_id": thread_id} @@ -195,30 +196,44 @@ async def _run(): click.echo(f"❌ Thread not found: {thread_id}") return - # Collect knowledge_sources updates + # Get tasks for this thread and collect knowledge_sources updates from them items = [] - for u in state.updates or []: - try: - if (u.update_type or "") != "knowledge_sources": - continue - md = u.metadata or {} - if task_id and (md.get("task_id") != task_id): - continue - for frag in md.get("fragments") or []: - items.append( - { - "timestamp": u.timestamp, - "task_id": md.get("task_id"), - "id": frag.get("id"), - "document_hash": frag.get("document_hash"), - "chunk_index": frag.get("chunk_index"), - "title": frag.get("title"), - "source": frag.get("source"), - } - ) - except Exception: + + # Get all tasks for this thread + from redis_sre_agent.core.keys import RedisKeys + + task_ids = await client.zrange(RedisKeys.thread_tasks_index(thread_id), 0, -1) + + for tid in task_ids: + if isinstance(tid, bytes): + tid = tid.decode() + if task_id and tid != task_id: + continue + + task_state = await task_manager.get_task_state(tid) + if not task_state: continue + for u in task_state.updates or []: + try: + if (u.update_type or "") != "knowledge_sources": + continue + md = u.metadata or {} + for frag in md.get("fragments") or []: + items.append( + { + "timestamp": u.timestamp, + "task_id": tid, + "id": frag.get("id"), + "document_hash": frag.get("document_hash"), + "chunk_index": frag.get("chunk_index"), + "title": frag.get("title"), + "source": frag.get("source"), + } + ) + except Exception: + continue + if as_json: print( json.dumps( diff --git a/redis_sre_agent/cli/worker.py b/redis_sre_agent/cli/worker.py index 1c6ac58b..f9b04da7 100644 --- a/redis_sre_agent/cli/worker.py +++ b/redis_sre_agent/cli/worker.py @@ -86,6 +86,19 @@ async def _worker(): except Exception as _e: logger.warning(f"Failed to start Prometheus metrics server in worker: {_e}") + # Initialize Redis infrastructure (creates indices if they don't exist) + try: + from redis_sre_agent.core.redis import create_indices + + indices_created = await create_indices() + if indices_created: + logger.info("✅ Redis indices initialized") + else: + logger.warning("⚠️ Failed to create some Redis indices") + except Exception as e: + logger.error(f"Failed to initialize Redis indices: {e}") + # Continue anyway - some functionality may still work + try: # Register tasks first (support both sync and async implementations) reg = register_sre_tasks() @@ -109,7 +122,7 @@ async def _worker(): try: asyncio.run(_worker()) except KeyboardInterrupt: - click.echo("\n\ud83d\udc4b SRE worker stopped by user") + click.echo("\nSRE worker stopped by user") except Exception as e: - click.echo(f"\ud83d\udca5 Unexpected worker error: {e}") + click.echo(f"Unexpected worker error: {e}") raise diff --git a/redis_sre_agent/core/config.py b/redis_sre_agent/core/config.py index 87300e14..9dee27bb 100644 --- a/redis_sre_agent/core/config.py +++ b/redis_sre_agent/core/config.py @@ -1,27 +1,153 @@ """Configuration management using Pydantic Settings.""" -from typing import TYPE_CHECKING, List, Optional +import os +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union + +from dotenv import load_dotenv +from pydantic import BaseModel, Field, SecretStr +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, + YamlConfigSettingsSource, +) -from pydantic import Field, SecretStr -from pydantic_settings import BaseSettings, SettingsConfigDict +from redis_sre_agent.tools.models import ToolCapability if TYPE_CHECKING: pass -# Load environment variables from .env file if it exists -# In Docker/production, environment variables are set directly -from pathlib import Path -from dotenv import load_dotenv +class MCPToolConfig(BaseModel): + """Configuration for a specific tool exposed by an MCP server. + + This allows overriding or constraining how the agent sees and uses + a specific MCP tool. + + Example: + MCPToolConfig( + capability=ToolCapability.LOGS, + description="Use this tool when searching for memories..." + ) + """ + + capability: Optional[ToolCapability] = Field( + default=None, + description="The capability category for this tool (e.g., LOGS, METRICS). " + "If not specified, defaults to UTILITIES.", + ) + description: Optional[str] = Field( + default=None, + description="Alternative description for this tool. " + "If provided, the agent sees this instead of the MCP server's description.", + ) + + +class MCPServerConfig(BaseModel): + """Configuration for an MCP server. + + This follows the standard MCP configuration format used by Claude, VS Code, + and other tools, with additional fields for tool constraints. + + Example: + MCPServerConfig( + command="npx", + args=["-y", "@modelcontextprotocol/server-memory"], + tools={ + "search_memories": MCPToolConfig(capability=ToolCapability.LOGS), + "create_memories": MCPToolConfig(description="Use this tool when..."), + } + ) + """ + + # Standard MCP configuration fields + command: Optional[str] = Field( + default=None, + description="Command to launch the MCP server (for stdio transport).", + ) + args: Optional[List[str]] = Field( + default=None, + description="Arguments to pass to the MCP server command.", + ) + env: Optional[Dict[str, str]] = Field( + default=None, + description="Environment variables to set when launching the server.", + ) + + # URL-based transport (alternative to command-based) + url: Optional[str] = Field( + default=None, + description="URL for SSE or HTTP-based MCP transport.", + ) + # Headers for HTTP/SSE transport (e.g., Authorization) + headers: Optional[Dict[str, str]] = Field( + default=None, + description="Headers to send with HTTP/SSE requests (e.g., Authorization).", + ) + + # Transport type for URL-based connections + transport: Optional[str] = Field( + default=None, + description="Transport type for URL-based connections: 'sse' for Server-Sent Events " + "(legacy), 'streamable_http' for Streamable HTTP (recommended for modern servers like " + "GitHub's remote MCP). If not specified, defaults to 'streamable_http' for better " + "compatibility with modern MCP servers.", + ) + + # Tool constraints - if provided, only these tools are exposed to the agent + tools: Optional[Dict[str, MCPToolConfig]] = Field( + default=None, + description="Optional mapping of tool names to their configurations. " + "If provided, only these tools are exposed to the agent from the MCP server. " + "Each tool can have a custom capability and/or description override.", + ) + + +# Load environment variables from .env file if it exists +# In Docker/production, environment variables are set directly ENV_FILE_OPT: str | None = None TWENTY_MINUTES_IN_SECONDS = 1200 # Only load .env if it exists (for local development) +# In Docker/production, environment variables are set directly. +# We check existence before calling load_dotenv to avoid FileNotFoundError. _env_path = Path(".env") -if _env_path.exists(): +if _env_path.is_file(): load_dotenv(dotenv_path=_env_path) ENV_FILE_OPT = str(_env_path) +else: + # Try loading from default locations without erroring if not found + load_dotenv() + + +# Default config file paths (checked in order) +# SRE_AGENT_CONFIG environment variable takes precedence if set +DEFAULT_CONFIG_PATHS = [ + "config.yaml", + "config.yml", + "sre_agent_config.yaml", + "sre_agent_config.yml", +] + + +def _get_yaml_config_path() -> str | list[str] | None: + """Get the YAML config file path to use. + + Returns: + - The path from SRE_AGENT_CONFIG env var if set + - Or the list of default paths to check + - Or None if SRE_AGENT_CONFIG is set to a nonexistent file + """ + config_path = os.environ.get("SRE_AGENT_CONFIG") + + if config_path: + # If explicitly specified, use it (pydantic will handle missing files) + return config_path + + # Return list of default paths - pydantic-settings will check each in order + return DEFAULT_CONFIG_PATHS class Settings(BaseSettings): @@ -29,6 +155,18 @@ class Settings(BaseSettings): Loads settings from environment variables. In local development, these can be provided via a .env file. In Docker/production, they should be set directly. + + Configuration can also be loaded from YAML files. The following paths are checked + (first match wins): + - Path specified in SRE_AGENT_CONFIG environment variable + - config.yaml, config.yml, sre_agent_config.yaml, sre_agent_config.yml + + Priority (highest to lowest): + 1. Values passed to Settings() constructor + 2. Environment variables + 3. .env file + 4. YAML config file + 5. Default values """ model_config = SettingsConfigDict( @@ -38,6 +176,8 @@ class Settings(BaseSettings): extra="ignore", # Don't error if .env file is missing (Docker/production use env vars directly) env_ignore_empty=True, + # Note: yaml_file is set dynamically in settings_customise_sources + # to support SRE_AGENT_CONFIG env var being set after module import ) # Application @@ -70,6 +210,10 @@ class Settings(BaseSettings): default="text-embedding-3-small", description="OpenAI embedding model" ) vector_dim: int = Field(default=1536, description="Vector dimensions") + embeddings_cache_ttl: Optional[int] = Field( + default=86400 * 7, # 7 days + description="TTL in seconds for cached embeddings. None means no expiration.", + ) # Docket Task Queue task_queue_name: str = Field(default="sre_agent_tasks", description="Task queue name") @@ -136,6 +280,107 @@ class Settings(BaseSettings): "Example: redis_sre_agent.tools.metrics.prometheus.PrometheusToolProvider", ) + # MCP Server Configuration + # Uses "uv tool run" (equivalent to uvx) to auto-install the package from PyPI. + # Override via MCP_SERVERS environment variable (JSON) if needed. + mcp_servers: Dict[str, Union[MCPServerConfig, Dict[str, Any]]] = Field( + default_factory=lambda: { + "redis-memory-server": { + "command": "uv", + "args": [ + "tool", + "run", + "--from", + "agent-memory-server", + "agent-memory", + "mcp", + ], + "env": {"REDIS_URL": "redis://localhost:6399"}, + # Only include specific tools, with context-aware descriptions. + # Use {original} to include the tool's original description. + "tools": { + "get_current_datetime": { + "description": ( + "Get the current date and time. Use this when you need to " + "record timestamps for Redis instance events or incidents.\n\n" + "{original}" + ), + }, + "create_long_term_memories": { + "description": ( + "Save long-term memories about Redis instances. Use this to " + "record: past incidents and their resolutions, configuration " + "changes, performance baselines, known issues, maintenance " + "history, and lessons learned. Always include the instance_id " + "in the memory text for future retrieval.\n\n{original}" + ), + }, + "search_long_term_memory": { + "description": ( + "Search saved memories about Redis instances. ALWAYS use this " + "before troubleshooting a Redis instance to recall past issues, " + "solutions, and context. Search by instance_id, error patterns, " + "or symptoms.\n\n{original}" + ), + }, + "get_long_term_memory": { + "description": ( + "Retrieve a specific memory by ID. Use this to get full details " + "of a memory found via search.\n\n{original}" + ), + }, + "edit_long_term_memory": { + "description": ( + "Update an existing memory. Use this to add new information to " + "a past incident record, update resolution status, or correct " + "outdated information.\n\n{original}" + ), + }, + "delete_long_term_memories": { + "description": ( + "Delete memories that are no longer relevant. Use sparingly - " + "prefer editing to add context rather than deleting.\n\n{original}" + ), + }, + }, + } + }, + description="MCP (Model Context Protocol) servers to connect to. " + "Each key is the server name, and the value is the server configuration. " + "Example: {'memory': {'command': 'npx', 'args': ['-y', '@modelcontextprotocol/server-memory'], " + "'tools': {'search_memories': {'capability': 'logs'}}}}", + ) + + @classmethod + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> Tuple[PydanticBaseSettingsSource, ...]: + """Customize settings sources to include YAML config file. + + Priority (highest to lowest): + 1. init_settings (passed to Settings()) + 2. env_settings (environment variables) + 3. dotenv_settings (.env file) + 4. yaml_settings (config.yaml file) + 5. file_secret_settings (Docker secrets) + """ + # Use the built-in YamlConfigSettingsSource from pydantic-settings + # Get the yaml_file path dynamically to respect SRE_AGENT_CONFIG env var + # set after module import + yaml_file = _get_yaml_config_path() + return ( + init_settings, + env_settings, + dotenv_settings, + YamlConfigSettingsSource(settings_cls, yaml_file=yaml_file), + file_secret_settings, + ) + # Global settings instance settings = Settings() diff --git a/redis_sre_agent/core/docket_tasks.py b/redis_sre_agent/core/docket_tasks.py index f65098bf..b6700756 100644 --- a/redis_sre_agent/core/docket_tasks.py +++ b/redis_sre_agent/core/docket_tasks.py @@ -9,22 +9,24 @@ from ulid import ULID from redis_sre_agent.agent import get_sre_agent +from redis_sre_agent.agent.chat_agent import get_chat_agent from redis_sre_agent.agent.knowledge_agent import get_knowledge_agent from redis_sre_agent.agent.langgraph_agent import ( _extract_instance_details_from_message, ) from redis_sre_agent.agent.router import AgentType, route_to_appropriate_agent from redis_sre_agent.core.config import settings -from redis_sre_agent.core.instances import create_instance +from redis_sre_agent.core.instances import create_instance, get_instance_by_id from redis_sre_agent.core.knowledge_helpers import ( ingest_sre_document_helper, search_knowledge_base_helper, ) +from redis_sre_agent.core.progress import TaskEmitter from redis_sre_agent.core.redis import ( get_redis_client, ) from redis_sre_agent.core.tasks import TaskManager, TaskStatus -from redis_sre_agent.core.threads import ThreadManager +from redis_sre_agent.core.threads import Message, ThreadManager logger = logging.getLogger(__name__) @@ -130,6 +132,187 @@ async def ingest_sre_document( raise +@sre_task +async def process_chat_turn( + query: str, + task_id: str, + thread_id: str, + instance_id: Optional[str] = None, + user_id: Optional[str] = None, + exclude_mcp_categories: Optional[List[str]] = None, + retry: Retry = Retry(attempts=2, delay=timedelta(seconds=2)), +) -> Dict[str, Any]: + """ + Process a chat query using the ChatAgent (background task). + + This runs the lightweight ChatAgent for quick Q&A about Redis instances. + Notifications are emitted to the task, and the result is stored on both + the task and the thread. + + Args: + query: User's question + task_id: Task ID for notifications and result storage + thread_id: Thread ID for conversation context and result storage + instance_id: Optional Redis instance ID + user_id: Optional user ID for tracking + exclude_mcp_categories: Optional list of MCP tool category names to exclude. + Valid values: "metrics", "logs", "tickets", "repos", "traces", + "diagnostics", "knowledge", "utilities". + retry: Retry configuration + + Returns: + Dictionary with the chat response + """ + from redis_sre_agent.agent.chat_agent import ChatAgent + from redis_sre_agent.tools.models import ToolCapability + + logger.info(f"Processing chat turn for task {task_id}") + + redis_client = get_redis_client() + task_manager = TaskManager(redis_client=redis_client) + thread_manager = ThreadManager(redis_client=redis_client) + + # Mark task as in progress + await task_manager.update_task_status(task_id, TaskStatus.IN_PROGRESS) + + # Convert string category names to ToolCapability enums + mcp_categories: Optional[List[ToolCapability]] = None + if exclude_mcp_categories: + mcp_categories = [] + for cat_name in exclude_mcp_categories: + try: + mcp_categories.append(ToolCapability(cat_name.lower())) + except ValueError: + logger.warning(f"Unknown MCP category to exclude: {cat_name}") + + try: + # Create task emitter for notifications + emitter = TaskEmitter(task_manager=task_manager, task_id=task_id) + + # Get Redis instance if specified + redis_instance = None + if instance_id: + redis_instance = await get_instance_by_id(instance_id) + if not redis_instance: + raise ValueError(f"Instance not found: {instance_id}") + + # Run chat agent + agent = ChatAgent( + redis_instance=redis_instance, + progress_emitter=emitter, + exclude_mcp_categories=mcp_categories, + ) + response = await agent.process_query( + query=query, + session_id=thread_id, + user_id=user_id or "mcp-user", + progress_emitter=emitter, + ) + + # Store result on task + result = { + "response": response, + "instance_id": instance_id, + } + await task_manager.set_task_result(task_id, result) + await task_manager.update_task_status(task_id, TaskStatus.DONE) + + # Add response to thread as assistant message + await thread_manager.append_messages( + thread_id, + [ + { + "role": "assistant", + "content": response, + "metadata": {"task_id": task_id, "agent": "chat"}, + } + ], + ) + + return result + + except Exception as e: + logger.error(f"Chat turn failed: {e}") + await task_manager.set_task_error(task_id, str(e)) + raise + + +@sre_task +async def process_knowledge_query( + query: str, + task_id: str, + thread_id: str, + user_id: Optional[str] = None, + retry: Retry = Retry(attempts=2, delay=timedelta(seconds=2)), +) -> Dict[str, Any]: + """ + Process a knowledge query using the KnowledgeOnlyAgent (background task). + + This runs the KnowledgeOnlyAgent for general SRE knowledge questions. + Notifications are emitted to the task, and the result is stored on both + the task and the thread. + + Args: + query: User's question about SRE practices or Redis + task_id: Task ID for notifications and result storage + thread_id: Thread ID for conversation context and result storage + user_id: Optional user ID for tracking + retry: Retry configuration + + Returns: + Dictionary with the knowledge agent response + """ + from redis_sre_agent.agent.knowledge_agent import KnowledgeOnlyAgent + + logger.info(f"Processing knowledge query for task {task_id}") + + redis_client = get_redis_client() + task_manager = TaskManager(redis_client=redis_client) + thread_manager = ThreadManager(redis_client=redis_client) + + # Mark task as in progress + await task_manager.update_task_status(task_id, TaskStatus.IN_PROGRESS) + + try: + # Create task emitter for notifications + emitter = TaskEmitter(task_manager=task_manager, task_id=task_id) + + # Run knowledge agent + agent = KnowledgeOnlyAgent(progress_emitter=emitter) + response = await agent.process_query( + query=query, + session_id=thread_id, + user_id=user_id or "mcp-user", + progress_emitter=emitter, + ) + + # Store result on task + result = { + "response": response, + } + await task_manager.set_task_result(task_id, result) + await task_manager.update_task_status(task_id, TaskStatus.DONE) + + # Add response to thread as assistant message + await thread_manager.append_messages( + thread_id, + [ + { + "role": "assistant", + "content": response, + "metadata": {"task_id": task_id, "agent": "knowledge"}, + } + ], + ) + + return result + + except Exception as e: + logger.error(f"Knowledge query failed: {e}") + await task_manager.set_task_error(task_id, str(e)) + raise + + @sre_task async def scheduler_task( global_limit="scheduler", # Need a sentinel value for concurrency limit argument @@ -474,15 +657,32 @@ async def process_agent_turn( logger.info(f"Routing query to {agent_type.value} agent") - # Import and initialize the appropriate agent - if agent_type == AgentType.REDIS_FOCUSED: + # Import and initialize the appropriate agent based on routing decision + # REDIS_TRIAGE = full triage agent (heavy, comprehensive) + # REDIS_CHAT = lightweight chat agent (fast, targeted) + # KNOWLEDGE_ONLY = knowledge agent (no instance needed) + if agent_type == AgentType.REDIS_TRIAGE: agent = get_sre_agent() + elif agent_type == AgentType.REDIS_CHAT: + # Get the target instance for the chat agent + target_instance = ( + await get_instance_by_id(active_instance_id) if active_instance_id else None + ) + agent = get_chat_agent(redis_instance=target_instance) else: agent = get_knowledge_agent() - # Prepare the conversation state with thread context - messages = thread.context.get("messages", []) - logger.debug(f"Loaded {len(messages)} messages from thread context") + # Prepare the conversation state with thread messages + # Convert Message objects to dicts for agent processing + messages = [ + { + "role": m.role, + "content": m.content, + **({"metadata": m.metadata} if m.metadata else {}), + } + for m in thread.messages + ] + logger.debug(f"Loaded {len(messages)} messages from thread") conversation_state = { "messages": messages, @@ -492,11 +692,12 @@ async def process_agent_turn( logger.debug(f"conversation_state messages type: {type(conversation_state['messages'])}") # Add the new user message + user_msg_timestamp = datetime.now(timezone.utc).isoformat() conversation_state["messages"].append( { "role": "user", "content": message, - "timestamp": datetime.now(timezone.utc).isoformat(), + "timestamp": user_msg_timestamp, } ) @@ -508,7 +709,7 @@ async def process_agent_turn( { "role": "user", "content": message, - "timestamp": conversation_state["messages"][-1]["timestamp"], + "metadata": {"timestamp": user_msg_timestamp}, } ], ) @@ -517,21 +718,12 @@ async def process_agent_turn( # Agent will post its own reflections as it works - # Create a progress callback for the agent - async def progress_callback( - update_message: str, - update_type: str = "progress", - metadata: Optional[Dict[str, Any]] = None, - ): - # Include task_id in thread-level metadata for easier grouping - md = dict(metadata or {}) - md.setdefault("task_id", task_id) - await thread_manager.add_thread_update(thread_id, update_message, update_type, md) - try: - await task_manager.add_task_update(task_id, update_message, update_type, metadata) - except Exception: - # Best-effort: do not fail the turn if per-task update logging fails - pass + # Create a task emitter for agent notifications + # Notifications go to the task only; the final result goes to both task and thread + progress_emitter = TaskEmitter( + task_manager=task_manager, + task_id=task_id, + ) # Run the appropriate agent if agent_type == AgentType.KNOWLEDGE_ONLY: @@ -559,7 +751,7 @@ async def progress_callback( session_id=thread.metadata.session_id or thread_id, max_iterations=_k_max_iters, context=None, - progress_callback=progress_callback, + progress_emitter=progress_emitter, conversation_history=lc_history if lc_history else None, ) @@ -567,10 +759,41 @@ async def progress_callback( "response": response_text, "metadata": {"agent_type": "knowledge_only"}, } + elif agent_type == AgentType.REDIS_CHAT: + # Use lightweight chat agent with process_query interface + await thread_manager.add_thread_update( + thread_id, "Processing query with chat agent", "agent_processing" + ) + + # Convert conversation history to LangChain messages + lc_history = [] + for msg in conversation_state["messages"][:-1]: # Exclude the latest message + if msg["role"] == "user": + lc_history.append(HumanMessage(content=msg["content"])) + elif msg["role"] == "assistant": + lc_history.append(AIMessage(content=msg["content"])) + + # Chat agent uses a reasonable iteration cap for quick responses + _chat_max_iters = min(int(settings.max_iterations or 15), 10) + + response_text = await agent.process_query( + query=message, + user_id=thread.metadata.user_id or "unknown", + session_id=thread.metadata.session_id or thread_id, + max_iterations=_chat_max_iters, + context=routing_context, + progress_emitter=progress_emitter, + conversation_history=lc_history if lc_history else None, + ) + + agent_response = { + "response": response_text, + "metadata": {"agent_type": "redis_chat"}, + } else: - # Use Redis-focused agent with full conversation state + # Use full Redis triage agent with full conversation state agent_response = await run_agent_with_progress( - agent, conversation_state, progress_callback, thread + agent, conversation_state, progress_emitter, thread ) # Add agent response to conversation @@ -593,48 +816,54 @@ async def progress_callback( ] # Persist agent reflections/status updates for this turn as chat messages + # Note: Updates are now stored on TaskState, not Thread try: - fresh_state = await thread_manager.get_thread(thread_id) - updates = list(fresh_state.updates or []) - # Keep only updates from this task/turn and relevant types - relevant_types = {"agent_reflection", "agent_processing", "agent_start"} - turn_updates = [ - u - for u in updates - if (u.metadata or {}).get("task_id") == task_id - and u.update_type in relevant_types - and u.message - ] - # Order chronologically - turn_updates.sort(key=lambda u: u.timestamp) - reflection_messages = [ - { - "role": "assistant", - "content": u.message, - "timestamp": u.timestamp, - "metadata": {"update_type": u.update_type, **(u.metadata or {})}, - } - for u in turn_updates - ] - if reflection_messages: - # Insert reflections before the final assistant message for this turn - if clean_messages: - final_msg = clean_messages[-1] - base_msgs = clean_messages[:-1] - # Deduplicate by content - seen = set(m.get("content") for m in base_msgs) - merged = ( - base_msgs - + [m for m in reflection_messages if m["content"] not in seen] - + [final_msg] - ) - clean_messages = merged - else: - clean_messages = reflection_messages + task_state = await task_manager.get_task_state(task_id) + if task_state and task_state.updates: + # Keep only relevant types of updates + relevant_types = {"agent_reflection", "agent_processing", "agent_start"} + turn_updates = [ + u for u in task_state.updates if u.update_type in relevant_types and u.message + ] + # Order chronologically + turn_updates.sort(key=lambda u: u.timestamp) + reflection_messages = [ + { + "role": "assistant", + "content": u.message, + "timestamp": u.timestamp, + "metadata": {"update_type": u.update_type, **(u.metadata or {})}, + } + for u in turn_updates + ] + if reflection_messages: + # Insert reflections before the final assistant message for this turn + if clean_messages: + final_msg = clean_messages[-1] + base_msgs = clean_messages[:-1] + # Deduplicate by content + seen = set(m.get("content") for m in base_msgs) + merged = ( + base_msgs + + [m for m in reflection_messages if m["content"] not in seen] + + [final_msg] + ) + clean_messages = merged + else: + clean_messages = reflection_messages except Exception as e: logger.warning(f"Failed to merge reflection updates into transcript: {e}") - thread.context["messages"] = clean_messages + # Convert clean_messages dicts to Message objects for thread storage + thread.messages = [ + Message( + role=m.get("role", "user"), + content=m.get("content", ""), + metadata={k: v for k, v in m.items() if k not in ("role", "content")} or None, + ) + for m in clean_messages + if m.get("content") + ] thread.context["last_updated"] = datetime.now(timezone.utc).isoformat() # If the subject is empty/placeholder, set an optimistic subject from original_query or first user message @@ -651,13 +880,9 @@ async def progress_callback( candidate = oq.strip() else: # Find the first user message content - for m in clean_messages: - if ( - isinstance(m, dict) - and m.get("role") == "user" - and (m.get("content") or "").strip() - ): - candidate = m.get("content").strip() + for m in thread.messages: + if m.role == "user" and m.content.strip(): + candidate = m.content.strip() break if candidate: # Normalize to a single line and cap length @@ -668,13 +893,13 @@ async def progress_callback( except Exception as e: logger.warning(f"Failed to set optimistic subject for thread {thread_id}: {e}") - # Save the updated context to Redis + # Save the updated thread state to Redis await thread_manager._save_thread_state(thread) logger.info( - f"Saved conversation history: {len(clean_messages)} user/assistant messages (filtered from {len(conversation_state['messages'])} total)" + f"Saved conversation history: {len(thread.messages)} user/assistant messages (filtered from {len(conversation_state['messages'])} total)" ) - # Set the final result + # Set the final result on the task (not the thread - results belong on tasks) result = { "response": agent_response.get("response", ""), "metadata": agent_response.get("metadata", {}), @@ -685,9 +910,12 @@ async def progress_callback( await task_manager.set_task_result(task_id, result) await task_manager.update_task_status(task_id, TaskStatus.DONE) - await thread_manager.set_thread_result(thread_id, result) - await thread_manager.add_thread_update( - thread_id, f"Task {task_id} completed successfully", "turn_complete" + + # Publish completion to stream for WebSocket updates (deprecated methods but still publish) + await thread_manager._publish_stream_update( + thread_id, + "turn_complete", + {"task_id": task_id, "message": "Task completed successfully"}, ) # End root span if present @@ -728,17 +956,17 @@ async def progress_callback( async def run_agent_with_progress( - agent, conversation_state: Dict[str, Any], progress_callback, thread_state=None + agent, conversation_state: Dict[str, Any], progress_emitter, thread_state=None ): """ Run the LangGraph agent with progress updates. - This creates a new agent instance with progress callback support and runs it. + This creates a new agent instance with progress emitter support and runs it. Args: agent: The agent instance (currently unused, kept for compatibility) conversation_state: Dictionary containing messages and thread_id - progress_callback: Async callback function for progress updates + progress_emitter: ProgressEmitter instance for progress updates thread_state: Optional thread state object containing metadata and context """ try: @@ -749,10 +977,10 @@ async def run_agent_with_progress( if not messages: raise ValueError("No messages in conversation") - # Create a new agent instance with progress callback + # Create a new agent instance with progress emitter from redis_sre_agent.agent.langgraph_agent import SRELangGraphAgent - progress_agent = SRELangGraphAgent(progress_callback=progress_callback) + progress_agent = SRELangGraphAgent(progress_emitter=progress_emitter) # Convert conversation messages to LangChain format # We only store user/assistant messages, tool messages are internal to LangGraph @@ -802,7 +1030,7 @@ async def run_agent_with_progress( user_id=thread_state.metadata.user_id if thread_state else "system", max_iterations=settings.max_iterations, context=agent_context, - progress_callback=progress_callback, + progress_emitter=progress_emitter, conversation_history=lc_messages[:-1] if lc_messages else None, # Exclude the latest message (it's added in process_query) @@ -810,7 +1038,7 @@ async def run_agent_with_progress( # Create a mock final state for compatibility - await progress_callback("Agent workflow completed", "agent_complete") + await progress_emitter.emit("Agent workflow completed", "agent_complete") # The response is already the final agent response agent_response = response @@ -825,7 +1053,7 @@ async def run_agent_with_progress( } except Exception as e: - await progress_callback(f"Agent error: {str(e)}", "error") + await progress_emitter.emit(f"Agent error: {str(e)}", "error") raise diff --git a/redis_sre_agent/core/keys.py b/redis_sre_agent/core/keys.py index b2ddedc3..10d5e810 100644 --- a/redis_sre_agent/core/keys.py +++ b/redis_sre_agent/core/keys.py @@ -23,9 +23,17 @@ def thread_status(thread_id: str) -> str: @staticmethod def thread_updates(thread_id: str) -> str: - """Key for thread updates list.""" + """Key for thread updates list. + + DEPRECATED: Use task_updates() instead. Progress updates belong on tasks. + """ return f"sre:thread:{thread_id}:updates" + @staticmethod + def thread_messages(thread_id: str) -> str: + """Key for thread messages list (conversation history).""" + return f"sre:thread:{thread_id}:messages" + @staticmethod def thread_context(thread_id: str) -> str: """Key for thread context (conversation history, etc.).""" @@ -166,9 +174,10 @@ def all_thread_keys(thread_id: str) -> dict[str, str]: """ return { "status": RedisKeys.thread_status(thread_id), - "updates": RedisKeys.thread_updates(thread_id), + "messages": RedisKeys.thread_messages(thread_id), + "updates": RedisKeys.thread_updates(thread_id), # DEPRECATED "context": RedisKeys.thread_context(thread_id), "metadata": RedisKeys.thread_metadata(thread_id), - "result": RedisKeys.thread_result(thread_id), - "error": RedisKeys.thread_error(thread_id), + "result": RedisKeys.thread_result(thread_id), # DEPRECATED + "error": RedisKeys.thread_error(thread_id), # DEPRECATED } diff --git a/redis_sre_agent/core/knowledge_helpers.py b/redis_sre_agent/core/knowledge_helpers.py index a9678af1..719fbb87 100644 --- a/redis_sre_agent/core/knowledge_helpers.py +++ b/redis_sre_agent/core/knowledge_helpers.py @@ -26,8 +26,10 @@ async def search_knowledge_base_helper( query: str, category: Optional[str] = None, limit: int = 10, + offset: int = 0, distance_threshold: Optional[float] = 0.5, hybrid_search: bool = False, + version: Optional[str] = "latest", ) -> Dict[str, Any]: """Search the SRE knowledge base. @@ -37,18 +39,24 @@ async def search_knowledge_base_helper( Behavior: - Default: distance_threshold=0.5 (filters by cosine distance) - Explicit None: disables threshold (pure KNN, return top-k regardless of distance) + - Default version: "latest" (filters to unversioned/latest docs) + - Explicit version: Filter to specific version (e.g., "7.8", "7.4") + - version=None: Return all versions (no version filtering) Args: query: Search query text category: Optional category filter (incident, maintenance, monitoring, etc.) limit: Maximum number of results + offset: Number of results to skip (for pagination) distance_threshold: Cosine distance cutoff; None disables threshold hybrid_search: Whether to use hybrid search (vector + full-text) + version: Version filter - "latest" (default), specific version like "7.8", + or None to return all versions Returns: Dictionary with search results including task_id, query, results, etc. """ - logger.info(f"Searching SRE knowledge: '{query}'") + logger.info(f"Searching SRE knowledge: '{query}' (version={version}, offset={offset})") index = await get_knowledge_index() return_fields = [ "id", @@ -59,8 +67,18 @@ async def search_knowledge_base_helper( "source", "category", "severity", + "version", ] + # Build version filter expression if version is specified + from redisvl.query.filter import Tag + + filter_expr = None + if version is not None: + # Filter by specific version (e.g., "latest", "7.8", "7.4") + filter_expr = Tag("version") == version + logger.debug(f"Applying version filter: {version}") + # Always use vector search (tests rely on embedding being used) vectorizer = get_vectorizer() @@ -73,6 +91,10 @@ async def search_knowledge_base_helper( query_vector = vectors[0] if vectors else [] + # We need to fetch more results if there's an offset, then slice + # This is because RedisVL vector queries don't support offset directly + fetch_limit = limit + offset + if hybrid_search: logger.info(f"Using hybrid search (vector + full-text) for query: {query}") query_obj = HybridQuery( @@ -80,8 +102,9 @@ async def search_knowledge_base_helper( vector_field_name="vector", text_field_name="content", text=query, - num_results=limit, + num_results=fetch_limit, return_fields=return_fields, + filter_expression=filter_expr, ) else: # Build pure vector query @@ -92,7 +115,7 @@ async def search_knowledge_base_helper( vector=query_vector, vector_field_name="vector", return_fields=return_fields, - num_results=limit, + num_results=fetch_limit, distance_threshold=effective_threshold, ) else: @@ -100,26 +123,37 @@ async def search_knowledge_base_helper( vector=query_vector, vector_field_name="vector", return_fields=return_fields, - num_results=limit, + num_results=fetch_limit, ) + if filter_expr is not None: + query_obj.set_filter(filter_expr) - # Perform vector search (no category filter) + # Perform vector search _t2 = time.monotonic() with tracer.start_as_current_span("knowledge.index.query") as _span: _span.set_attribute("limit", int(limit)) + _span.set_attribute("offset", int(offset)) _span.set_attribute("hybrid_search", bool(hybrid_search)) + _span.set_attribute("version", version or "all") _span.set_attribute( "distance_threshold", float(distance_threshold) if distance_threshold is not None else -1.0, ) - results = await index.query(query_obj) + all_results = await index.query(query_obj) _t3 = time.monotonic() + # Apply offset by slicing results + results = all_results[offset:] if offset > 0 else all_results + search_result = { "query": query, "category": category, + "version": version, + "offset": offset, + "limit": limit, "timestamp": datetime.now(timezone.utc).isoformat(), "results_count": len(results), + "total_fetched": len(all_results), "results": [ { "id": doc.get("id", ""), @@ -130,6 +164,7 @@ async def search_knowledge_base_helper( "content": doc.get("content", ""), "source": doc.get("source", ""), "category": doc.get("category", ""), + "version": doc.get("version", "latest"), # RedisVL returns distance when return_score=True (default). Some versions # expose it as 'score' and others as 'vector_distance' or 'distance'. # Normalize to float. diff --git a/redis_sre_agent/core/progress.py b/redis_sre_agent/core/progress.py new file mode 100644 index 00000000..10d08b70 --- /dev/null +++ b/redis_sre_agent/core/progress.py @@ -0,0 +1,445 @@ +"""Progress emission abstraction for agent status updates. + +This module provides a ProgressEmitter protocol that abstracts how progress/status +updates (notifications) are emitted during agent execution. Different implementations +can send updates to different destinations: + +- TaskEmitter: Persists notifications to a Task in Redis. Clients poll the task + for status and notifications. This is the primary implementation for + both REST and MCP paths. +- MCPEmitter: Sends MCP protocol progress notifications (for synchronous MCP tools) +- CompositeEmitter: Combines multiple emitters for simultaneous delivery +- NullEmitter: No-op emitter for testing or batch jobs +- LoggingEmitter: Logs updates for debugging + +Architecture: + - Notifications (tool reflections, progress) → Task updates (via TaskEmitter) + - Final result → Task result AND Thread message (handled by docket_tasks) + - Clients (REST or MCP) poll get_task_status() for notifications and status + +Example: + # Docket worker path (REST and MCP both use this) + emitter = TaskEmitter(task_manager, task_id) + agent = SRELangGraphAgent(progress_emitter=emitter) +""" + +from __future__ import annotations + +import asyncio +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Protocol, runtime_checkable + +if TYPE_CHECKING: + from redis_sre_agent.core.tasks import TaskManager + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Progress Counter (for MCP's monotonically increasing requirement) +# --------------------------------------------------------------------------- + + +class ProgressCounter(ABC): + """Abstract counter for generating monotonically increasing progress values.""" + + @abstractmethod + async def next(self) -> int: + """Get the next progress value. Must always return a value > previous.""" + ... + + +class LocalProgressCounter(ProgressCounter): + """Thread-safe monotonic counter for single-process scenarios. + + Uses an asyncio.Lock to ensure concurrent calls always get increasing values. + """ + + def __init__(self, start: int = 0): + self._value = start + self._lock = asyncio.Lock() + + async def next(self) -> int: + async with self._lock: + self._value += 1 + return self._value + + +# --------------------------------------------------------------------------- +# ProgressEmitter Protocol and Implementations +# --------------------------------------------------------------------------- + + +@runtime_checkable +class ProgressEmitter(Protocol): + """Protocol for emitting progress/status updates during agent execution. + + Implementations of this protocol handle where and how progress updates + are delivered (Redis persistence, MCP notifications, logging, etc.). + """ + + async def emit( + self, + message: str, + update_type: str = "progress", + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Emit a progress update. + + Args: + message: Human-readable status message + update_type: Category of update (e.g., "progress", "agent_reflection", + "knowledge_sources", "tool_call") + metadata: Optional additional data (e.g., fragments, tool args) + """ + ... + + +class NullEmitter: + """No-op emitter that discards all updates. Useful for testing or batch jobs.""" + + async def emit( + self, + message: str, + update_type: str = "progress", + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + pass + + +class LoggingEmitter: + """Emitter that logs updates. Useful for debugging.""" + + def __init__(self, logger_name: str = __name__, level: int = logging.INFO): + self._logger = logging.getLogger(logger_name) + self._level = level + + async def emit( + self, + message: str, + update_type: str = "progress", + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + self._logger.log(self._level, f"[{update_type}] {message}") + + +class CLIEmitter: + """Emitter that prints notifications to the terminal for CLI usage. + + Formats output with colors/symbols based on update_type for better + readability in terminal environments. + """ + + # ANSI color codes + COLORS = { + "reset": "\033[0m", + "dim": "\033[2m", + "bold": "\033[1m", + "blue": "\033[34m", + "green": "\033[32m", + "yellow": "\033[33m", + "cyan": "\033[36m", + "magenta": "\033[35m", + } + + # Symbols and colors for different update types + TYPE_STYLES = { + "agent_start": ("🚀", "green"), + "agent_complete": ("✅", "green"), + "agent_error": ("❌", "yellow"), + "agent_reflection": ("💭", "cyan"), + "agent_processing": ("⚙️ ", "blue"), + "tool_call": ("🔧", "magenta"), + "knowledge_sources": ("📚", "blue"), + "progress": ("→", "dim"), + "instance_context": ("🔗", "cyan"), + "instance_created": ("➕", "green"), + "instance_error": ("⚠️ ", "yellow"), + "task_start": ("📋", "blue"), + "error": ("❌", "yellow"), + } + + def __init__(self, use_colors: bool = True, file=None): + """Initialize CLI emitter. + + Args: + use_colors: Whether to use ANSI colors (disable for non-TTY output) + file: Output file (defaults to sys.stderr) + """ + import sys + + self._use_colors = use_colors and (file or sys.stderr).isatty() + self._file = file or sys.stderr + + def _colorize(self, text: str, color: str) -> str: + """Apply ANSI color to text if colors are enabled.""" + if not self._use_colors or color not in self.COLORS: + return text + return f"{self.COLORS[color]}{text}{self.COLORS['reset']}" + + async def emit( + self, + message: str, + update_type: str = "progress", + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Print notification to terminal.""" + symbol, color = self.TYPE_STYLES.get(update_type, ("•", "dim")) + formatted = f"{symbol} {self._colorize(message, color)}" + print(formatted, file=self._file, flush=True) + + +class TaskEmitter: + """Emitter that persists notifications to a Task in Redis. + + Notifications (tool reflections, progress updates) are stored on the Task, + not the Thread. Clients (REST or MCP) can poll the Task for notifications + and status updates. + + The Thread is only updated with the final result (as a message), which is + handled separately by the task completion logic, not by this emitter. + """ + + def __init__( + self, + task_manager: "TaskManager", + task_id: str, + ): + self._task_manager = task_manager + self._task_id = task_id + + @property + def task_id(self) -> str: + """Return the task ID this emitter is writing to.""" + return self._task_id + + async def emit( + self, + message: str, + update_type: str = "progress", + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Emit notification to task storage.""" + try: + await self._task_manager.add_task_update(self._task_id, message, update_type, metadata) + except Exception as e: + # Best-effort: don't fail the agent if notification logging fails + logger.warning(f"Failed to emit task notification: {e}") + + +class MCPEmitter: + """Emitter that sends MCP protocol progress notifications. + + This implementation is used when the agent is invoked via MCP, sending + real-time progress updates to the MCP client (e.g., Claude Desktop). + + The MCP spec requires progress values to be monotonically increasing, + so this emitter uses a ProgressCounter to generate sequence numbers. + + IMPORTANT: For MCP progress to work, the agent must run synchronously + within the MCP tool call - not in a background worker like Docket. + + Example using FastMCP Context: + from fastmcp import Context + from redis_sre_agent.core.progress import MCPEmitter + + @mcp.tool + async def triage_sync(query: str, ctx: Context) -> Dict[str, Any]: + emitter = MCPEmitter.from_fastmcp_context(ctx) + agent = SRELangGraphAgent(progress_emitter=emitter) + response = await agent.process_query(...) + return {"response": response} + """ + + def __init__( + self, + send_progress: Any, # Callable to send MCP progress notification + counter: Optional[ProgressCounter] = None, + ): + """Initialize MCP emitter. + + Args: + send_progress: Async callable that sends MCP notifications. + Signature: (progress: float, total: float | None) -> None + counter: Optional custom counter; defaults to LocalProgressCounter + """ + self._send_progress = send_progress + self._counter = counter or LocalProgressCounter() + + @classmethod + def from_fastmcp_context(cls, ctx: Any) -> "MCPEmitter": + """Create an MCPEmitter from a FastMCP Context object. + + Args: + ctx: FastMCP Context object (from tool function parameter) + + Returns: + MCPEmitter configured to use the context's report_progress method + """ + return cls(send_progress=ctx.report_progress) + + async def emit( + self, + message: str, + update_type: str = "progress", + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Emit progress via MCP notification. + + Note: MCP progress notifications don't have a message field in + report_progress, but we log the message and use the counter for + the progress value. Clients will see increasing progress numbers. + """ + try: + progress = await self._counter.next() + # FastMCP's report_progress takes (progress, total) + # We use indeterminate progress (no total) since we don't know + # how many updates there will be + await self._send_progress(progress=progress, total=None) + # Also log the message for debugging + logger.debug(f"MCP progress {progress}: [{update_type}] {message}") + except Exception as e: + # Don't fail the agent if MCP notification fails + logger.warning(f"Failed to send MCP progress notification: {e}") + + +class CompositeEmitter: + """Emitter that forwards updates to multiple child emitters. + + Useful when you want updates delivered to multiple destinations, + e.g., both MCP notifications and Redis persistence for debugging. + """ + + def __init__(self, emitters: List[ProgressEmitter]): + self._emitters = emitters + + async def emit( + self, + message: str, + update_type: str = "progress", + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Emit to all child emitters concurrently.""" + if not self._emitters: + return + + await asyncio.gather( + *[e.emit(message, update_type, metadata) for e in self._emitters], + return_exceptions=True, # Don't fail if one emitter fails + ) + + +class CallbackEmitter: + """Emitter that wraps a legacy callback function. + + Provides backward compatibility for code that still uses the old + progress_callback signature: async def callback(message, update_type, metadata) + """ + + def __init__(self, callback): + """Initialize with a legacy callback. + + Args: + callback: Async callable with signature (str, str, Optional[Dict]) -> None + """ + self._callback = callback + + async def emit( + self, + message: str, + update_type: str = "progress", + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Forward to the legacy callback.""" + if self._callback: + try: + await self._callback(message, update_type, metadata) + except TypeError: + # Some callbacks may not accept metadata + await self._callback(message, update_type) + + +# --------------------------------------------------------------------------- +# Emitter Factory - context-aware emitter creation +# --------------------------------------------------------------------------- + + +def create_emitter( + *, + task_id: Optional[str] = None, + task_manager: Optional["TaskManager"] = None, + cli: bool = False, + cli_colors: bool = True, + additional_emitters: Optional[List[ProgressEmitter]] = None, +) -> ProgressEmitter: + """Create the appropriate emitter based on context. + + This factory function returns the right emitter for the execution context: + - If task_id/task_manager provided: TaskEmitter (writes to task) + - If cli=True: CLIEmitter (prints to terminal) + - Can combine multiple emitters via CompositeEmitter + + Args: + task_id: Task ID to emit notifications to (requires task_manager) + task_manager: TaskManager instance for persisting to Redis + cli: Whether to emit to CLI (terminal output) + cli_colors: Whether to use colors in CLI output + additional_emitters: Extra emitters to include + + Returns: + ProgressEmitter instance (may be composite if multiple destinations) + + Examples: + # Task context (REST API, MCP via Docket) + emitter = create_emitter(task_id=task_id, task_manager=task_manager) + + # CLI context + emitter = create_emitter(cli=True) + + # Both task and CLI (debugging) + emitter = create_emitter(task_id=task_id, task_manager=task_manager, cli=True) + """ + emitters: List[ProgressEmitter] = [] + + # Add task emitter if in task context + if task_id and task_manager: + emitters.append(TaskEmitter(task_manager=task_manager, task_id=task_id)) + + # Add CLI emitter if requested + if cli: + emitters.append(CLIEmitter(use_colors=cli_colors)) + + # Add any additional emitters + if additional_emitters: + emitters.extend(additional_emitters) + + # Return appropriate emitter + if not emitters: + return NullEmitter() + elif len(emitters) == 1: + return emitters[0] + else: + return CompositeEmitter(emitters) + + +async def create_emitter_for_task( + task_id: str, + redis_client=None, +) -> ProgressEmitter: + """Convenience function to create a TaskEmitter for a given task_id. + + This is useful when you have a task_id but not a TaskManager instance. + It creates the TaskManager internally. + + Args: + task_id: The task ID to emit notifications to + redis_client: Optional Redis client (uses default if not provided) + + Returns: + TaskEmitter configured for the given task + """ + from redis_sre_agent.core.tasks import TaskManager + + task_manager = TaskManager(redis_client=redis_client) + return TaskEmitter(task_manager=task_manager, task_id=task_id) diff --git a/redis_sre_agent/core/redis.py b/redis_sre_agent/core/redis.py index eb6afa0f..58580de4 100644 --- a/redis_sre_agent/core/redis.py +++ b/redis_sre_agent/core/redis.py @@ -69,6 +69,10 @@ "name": "product_label_tags", "type": "tag", }, + { + "name": "version", + "type": "tag", + }, { "name": "created_at", "type": "numeric", @@ -211,6 +215,12 @@ def get_vectorizer() -> OpenAITextVectorizer: """Get OpenAI vectorizer with Redis-backed embeddings cache. Returns the native vectorizer; callers should use aembed/aembed_many. + + The embeddings cache uses a stable key namespace ("sre_embeddings_cache") + so that embeddings are shared across vectorizer instances. Cache keys + include the model name, so different models won't conflict. + + TTL is configurable via settings.embeddings_cache_ttl (default: 7 days). """ # Build Redis URL with password if needed (ensure cache can auth) redis_url = settings.redis_url.get_secret_value() @@ -219,7 +229,13 @@ def get_vectorizer() -> OpenAITextVectorizer: redis_url = redis_url.replace("redis://", f"redis://:{redis_password}@") # Name the cache to keep a stable key namespace - cache = EmbeddingsCache(name="sre_embeddings_cache", redis_url=redis_url) + # TTL prevents stale embeddings if model changes + cache = EmbeddingsCache( + name="sre_embeddings_cache", + redis_url=redis_url, + ttl=settings.embeddings_cache_ttl, + ) + logger.debug(f"Vectorizer created with embeddings cache (ttl={settings.embeddings_cache_ttl}s)") return OpenAITextVectorizer( model=settings.embedding_model, @@ -407,6 +423,58 @@ async def create_indices() -> bool: return False +async def recreate_indices(index_name: str | None = None) -> dict: + """Drop and recreate RediSearch indices. + + This is useful when the schema has changed (e.g., new fields added). + + Args: + index_name: Specific index to recreate ('knowledge', 'schedules', 'threads', + 'tasks', 'instances'), or None to recreate all. + + Returns: + Dictionary with success status and details for each index. + """ + result = {"success": True, "indices": {}} + + index_configs = [ + ("knowledge", SRE_KNOWLEDGE_INDEX, get_knowledge_index), + ("schedules", SRE_SCHEDULES_INDEX, get_schedules_index), + ("threads", SRE_THREADS_INDEX, get_threads_index), + ("tasks", SRE_TASKS_INDEX, get_tasks_index), + ("instances", SRE_INSTANCES_INDEX, get_instances_index), + ] + + for name, idx_name, get_fn in index_configs: + # Skip if a specific index was requested and this isn't it + if index_name and name != index_name: + continue + + try: + idx = await get_fn() + + # Drop index if it exists + if await idx.exists(): + try: + # Use FT.DROPINDEX to drop without deleting documents + await idx._redis_client.execute_command("FT.DROPINDEX", idx_name) + logger.info(f"Dropped index: {idx_name}") + except Exception as drop_err: + logger.warning(f"Could not drop index {idx_name}: {drop_err}") + + # Recreate with current schema + await idx.create() + logger.info(f"Created index: {idx_name}") + result["indices"][name] = "recreated" + + except Exception as e: + logger.error(f"Failed to recreate index {name}: {e}") + result["indices"][name] = f"error: {e}" + result["success"] = False + + return result + + async def initialize_redis() -> dict: """Initialize Redis infrastructure and return status.""" status = {} diff --git a/redis_sre_agent/core/task_events.py b/redis_sre_agent/core/task_events.py index 3f515795..03f304b9 100644 --- a/redis_sre_agent/core/task_events.py +++ b/redis_sre_agent/core/task_events.py @@ -10,7 +10,7 @@ from pydantic import BaseModel, ConfigDict, Field -from .threads import ThreadUpdate +from .tasks import TaskUpdate class TaskStreamEvent(BaseModel): @@ -28,8 +28,11 @@ class TaskStreamEvent(BaseModel): class InitialStateEvent(TaskStreamEvent): - """Initial snapshot event sent upon WebSocket connection.""" + """Initial snapshot event sent upon WebSocket connection. - updates: List[ThreadUpdate] = Field(default_factory=list) + Updates, result, and error_message come from the latest Task, not the Thread. + """ + + updates: List[TaskUpdate] = Field(default_factory=list) result: Optional[Dict[str, Any]] = None error_message: Optional[str] = None diff --git a/redis_sre_agent/core/tasks.py b/redis_sre_agent/core/tasks.py index 45572f50..f617dcee 100644 --- a/redis_sre_agent/core/tasks.py +++ b/redis_sre_agent/core/tasks.py @@ -228,11 +228,23 @@ async def get_task_state(self, task_id: str) -> Optional[TaskState]: except Exception: result = None - md = await self._redis.hgetall(RedisKeys.task_metadata(task_id)) + md_raw = await self._redis.hgetall(RedisKeys.task_metadata(task_id)) + # Decode byte keys/values from hgetall when decode_responses=False + md: Dict[str, Any] = {} + if isinstance(md_raw, dict): + for k, v in md_raw.items(): + key = k.decode("utf-8") if isinstance(k, bytes) else k + val = v.decode("utf-8") if isinstance(v, bytes) else v + md[key] = val + # thread_id stored in metadata for convenience - thread_id = md.get("thread_id") if isinstance(md, dict) else None - if isinstance(thread_id, bytes): - thread_id = thread_id.decode("utf-8") + thread_id = md.get("thread_id") + + # Handle error_message - decode if bytes + error_raw = await self._redis.get(RedisKeys.task_error(task_id)) + error_message = None + if error_raw: + error_message = error_raw.decode("utf-8") if isinstance(error_raw, bytes) else error_raw return TaskState( task_id=task_id, @@ -242,13 +254,12 @@ async def get_task_state(self, task_id: str) -> Optional[TaskState]: ), updates=updates, result=result, - error_message=(await self._redis.get(RedisKeys.task_error(task_id))) or None, + error_message=error_message, metadata=TaskMetadata( - created_at=(md.get("created_at") if isinstance(md, dict) else None) - or datetime.now(timezone.utc).isoformat(), - updated_at=(md.get("updated_at") if isinstance(md, dict) else None), - user_id=(md.get("user_id") if isinstance(md, dict) else None), - subject=(md.get("subject") if isinstance(md, dict) else None), + created_at=md.get("created_at") or datetime.now(timezone.utc).isoformat(), + updated_at=md.get("updated_at"), + user_id=md.get("user_id"), + subject=md.get("subject"), ), ) diff --git a/redis_sre_agent/core/threads.py b/redis_sre_agent/core/threads.py index ed0ab88e..1e12377f 100644 --- a/redis_sre_agent/core/threads.py +++ b/redis_sre_agent/core/threads.py @@ -20,7 +20,11 @@ class ThreadUpdate(BaseModel): - """Individual progress update within a thread.""" + """Individual progress update within a thread. + + DEPRECATED: Progress updates should be stored on TaskState, not Thread. + This class is kept for backward compatibility when reading old data. + """ timestamp: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) message: str @@ -28,6 +32,14 @@ class ThreadUpdate(BaseModel): metadata: Optional[Dict[str, Any]] = None +class Message(BaseModel): + """A single message in a thread conversation.""" + + role: str = Field(default="user", description="Message role: user|assistant|system") + content: str + metadata: Optional[Dict[str, Any]] = None + + class ThreadMetadata(BaseModel): """Thread metadata and configuration.""" @@ -41,14 +53,21 @@ class ThreadMetadata(BaseModel): class Thread(BaseModel): - """Complete thread state representation.""" + """Complete thread state representation. + + A Thread represents a conversation. It contains: + - messages: The conversation history (user, assistant, system messages) + - context: Additional context data (instance_id, original_query, etc.) + - metadata: Thread metadata (created_at, user_id, tags, etc.) + + Note: result, error_message, and progress updates belong on TaskState, + not Thread. Tasks represent individual agent turns within a thread. + """ thread_id: str = Field(default_factory=lambda: str(ULID())) - updates: List[ThreadUpdate] = Field(default_factory=list) + messages: List[Message] = Field(default_factory=list) context: Dict[str, Any] = Field(default_factory=dict) metadata: ThreadMetadata = Field(default_factory=ThreadMetadata) - result: Optional[Dict[str, Any]] = None - error_message: Optional[str] = None class ThreadManager: @@ -386,21 +405,19 @@ async def get_thread(self, thread_id: str) -> Optional[Thread]: if not await client.exists(keys["metadata"]): return None - # Load all thread data - updates_data = await client.lrange(keys["updates"], 0, -1) + # Load thread data + messages_data = await client.lrange(keys["messages"], 0, -1) context_data = await client.hgetall(keys["context"]) metadata_data = await client.hgetall(keys["metadata"]) - result_data = await client.get(keys["result"]) - error_data = await client.get(keys["error"]) - # Parse updates - updates = [] - for update_json in updates_data: + # Parse messages from dedicated list (FIFO order via RPUSH) + messages: List[Message] = [] + for msg_json in messages_data: try: - update_dict = json.loads(update_json) - updates.append(ThreadUpdate(**update_dict)) + msg_dict = json.loads(msg_json) + messages.append(Message(**msg_dict)) except (json.JSONDecodeError, Exception) as e: - logger.warning(f"Failed to parse update: {e}") + logger.warning(f"Failed to parse message: {e}") # Parse metadata metadata = ThreadMetadata() @@ -434,23 +451,25 @@ async def get_thread(self, thread_id: str) -> Optional[Thread]: # Fallback: just decode bytes to strings context = {k.decode(): v.decode() for k, v in context_data.items()} - # Parse result and error - result = None - if result_data: - try: - result = json.loads(result_data) - except json.JSONDecodeError: - result = {"raw": result_data.decode()} - - error_message = error_data.decode() if error_data else None + # BACKWARD COMPATIBILITY: If no messages in dedicated list, check context["messages"] + if not messages and isinstance(context.get("messages"), list): + for m in context["messages"]: + if isinstance(m, dict) and m.get("content"): + messages.append( + Message( + role=m.get("role", "user"), + content=m.get("content", ""), + metadata=m.get("metadata"), + ) + ) + # Remove messages from context since they're now in the messages field + context.pop("messages", None) return Thread( thread_id=thread_id, - updates=updates, + messages=messages, context=context, metadata=metadata, - result=result, - error_message=error_message, ) except Exception as e: @@ -464,26 +483,23 @@ async def add_thread_update( update_type: str = "progress", metadata: Optional[Dict[str, Any]] = None, ) -> bool: - """Add a progress update to the thread.""" - try: - client = await self._get_client() - keys = self._get_thread_keys(thread_id) - - update = ThreadUpdate(message=message, update_type=update_type, metadata=metadata) + """Add a progress update to the thread. - # Add to updates list - update_json = update.model_dump_json() - await client.lpush(keys["updates"], update_json) - - # Keep only last 100 updates - await client.ltrim(keys["updates"], 0, 99) + DEPRECATED: Progress updates should be stored on TaskState via TaskManager. + This method now only publishes to the stream for WebSocket updates. + """ + import warnings - # Update metadata timestamp - await client.hset( - keys["metadata"], "updated_at", datetime.now(timezone.utc).isoformat() - ) + warnings.warn( + "add_thread_update is deprecated. Use TaskManager.add_task_update instead.", + DeprecationWarning, + stacklevel=2, + ) - # Publish update to stream + try: + # Only publish to stream for real-time WebSocket updates + # Don't store on thread - updates belong on tasks + update = ThreadUpdate(message=message, update_type=update_type, metadata=metadata) await self._publish_stream_update( thread_id, "thread_update", @@ -495,43 +511,38 @@ async def add_thread_update( }, ) - # Update search index - await self._upsert_thread_search_doc(thread_id) - - logger.debug(f"Added update to thread {thread_id}: {message}") + logger.debug(f"Published update for thread {thread_id}: {message}") return True except Exception as e: - logger.error(f"Failed to add update to thread {thread_id}: {e}") + logger.error(f"Failed to publish update for thread {thread_id}: {e}") return False async def set_thread_result(self, thread_id: str, result: Dict[str, Any]) -> bool: - """Set the final result for a thread.""" - try: - client = await self._get_client() - keys = self._get_thread_keys(thread_id) + """Set the final result for a thread. - result_json = json.dumps(result) - await client.set(keys["result"], result_json) + DEPRECATED: Results should be stored on TaskState via TaskManager. + This method now only publishes to the stream for WebSocket updates. + """ + import warnings - # Update metadata timestamp - await client.hset( - keys["metadata"], "updated_at", datetime.now(timezone.utc).isoformat() - ) + warnings.warn( + "set_thread_result is deprecated. Use TaskManager.set_task_result instead.", + DeprecationWarning, + stacklevel=2, + ) - # Publish result to stream + try: + # Only publish to stream for real-time WebSocket updates await self._publish_stream_update( thread_id, "result_set", {"result": result, "message": "Task result available"} ) - # Update search index - await self._upsert_thread_search_doc(thread_id) - - logger.info(f"Set result for thread {thread_id}") + logger.info(f"Published result for thread {thread_id}") return True except Exception as e: - logger.error(f"Failed to set result for thread {thread_id}: {e}") + logger.error(f"Failed to publish result for thread {thread_id}: {e}") return False async def _publish_stream_update( @@ -550,19 +561,21 @@ async def _publish_stream_update( return False async def set_thread_error(self, thread_id: str, error_message: str) -> bool: - """Set error message and mark thread as failed.""" - try: - client = await self._get_client() - keys = self._get_thread_keys(thread_id) + """Set error message for a thread. - await client.set(keys["error"], error_message) + DEPRECATED: Errors should be stored on TaskState via TaskManager. + This method is now a no-op but kept for backward compatibility. + """ + import warnings - logger.error(f"Set error for thread {thread_id}: {error_message}") - return True + warnings.warn( + "set_thread_error is deprecated. Use TaskManager.set_task_error instead.", + DeprecationWarning, + stacklevel=2, + ) - except Exception as e: - logger.error(f"Failed to set error for thread {thread_id}: {e}") - return False + logger.warning(f"set_thread_error called (deprecated) for thread {thread_id}") + return True async def update_thread_context( self, thread_id: str, context_updates: Dict[str, Any], merge: bool = True @@ -633,33 +646,45 @@ async def update_thread_context( return False async def append_messages(self, thread_id: str, messages: List[Dict[str, Any]]) -> bool: - """Append messages to a thread's message list in context. + """Append messages to thread's message list. - This treats context["messages"] as a JSON-serializable list of {role, content, ...} dicts. + Messages are stored in a dedicated Redis list (RPUSH for FIFO order). + Each message should have {role, content, metadata?}. """ try: - # Load existing messages from thread state - state = await self.get_thread(thread_id) - existing = [] - if state and isinstance(state.context.get("messages"), list): - existing = state.context.get("messages") + client = await self._get_client() + keys = self._get_thread_keys(thread_id) - # Append new messages, minimal validation + # Append each message to the list (RPUSH for chronological order) for m in messages or []: if not isinstance(m, dict): continue - role = m.get("role") content = m.get("content") if not content: continue - if role not in ("user", "assistant", "system", None): + + role = m.get("role", "user") + if role not in ("user", "assistant", "system"): role = "user" - existing.append( - {k: v for k, v in m.items() if k in ("role", "content", "metadata") or True} + + msg = Message( + role=role, + content=content, + metadata=m.get("metadata"), ) + await client.rpush(keys["messages"], msg.model_dump_json()) + + # Update metadata timestamp + await client.hset( + keys["metadata"], "updated_at", datetime.now(timezone.utc).isoformat() + ) + + # Update search index + await self._upsert_thread_search_doc(thread_id) + + logger.debug(f"Appended {len(messages)} messages to thread {thread_id}") + return True - # Save back to context - return await self.update_thread_context(thread_id, {"messages": existing}, merge=True) except Exception as e: logger.error(f"Failed to append messages for thread {thread_id}: {e}") return False @@ -671,11 +696,20 @@ async def _save_thread_state(self, thread_state: Thread) -> bool: keys = self._get_thread_keys(thread_state.thread_id) async with client.pipeline(transaction=True) as pipe: - # Set context as hash + # Save messages to dedicated list (clear and rebuild for atomicity) + if thread_state.messages: + pipe.delete(keys["messages"]) + for msg in thread_state.messages: + pipe.rpush(keys["messages"], msg.model_dump_json()) + + # Set context as hash (excluding messages which are now separate) if thread_state.context: # Filter out None values and serialize complex objects as JSON clean_context = {} for k, v in thread_state.context.items(): + # Skip 'messages' key - messages are stored separately + if k == "messages": + continue if v is None: clean_context[k] = "" elif isinstance(v, (dict, list)): @@ -697,18 +731,6 @@ async def _save_thread_state(self, thread_state: Thread) -> bool: } pipe.hset(keys["metadata"], mapping=clean_metadata) - # Set result if exists - if thread_state.result: - pipe.set(keys["result"], json.dumps(thread_state.result)) - - # Set error if exists - if thread_state.error_message: - pipe.set(keys["error"], thread_state.error_message) - - # Add updates - for update in thread_state.updates: - pipe.lpush(keys["updates"], update.model_dump_json()) - # Set TTL (24 hours for thread data) for key in keys.values(): pipe.expire(key, 86400) diff --git a/redis_sre_agent/mcp_server/__init__.py b/redis_sre_agent/mcp_server/__init__.py new file mode 100644 index 00000000..f1ecf705 --- /dev/null +++ b/redis_sre_agent/mcp_server/__init__.py @@ -0,0 +1,24 @@ +"""MCP server for redis-sre-agent. + +This module exposes the agent's capabilities as an MCP server, allowing +other agents to use the Redis SRE Agent's tools via the Model Context Protocol. + +Exposed tools (all prefixed with redis_sre_): + +Task-based tools (require polling redis_sre_get_task_status): +- redis_sre_deep_triage: Comprehensive Redis issue analysis (2-10 min) +- redis_sre_general_chat: Quick Q&A with full toolset including external MCP tools +- redis_sre_database_chat: Redis-focused chat with selective MCP tool exclusion +- redis_sre_knowledge_query: Ask the Knowledge Agent a question + +Utility tools (return immediately): +- redis_sre_knowledge_search: Direct search of knowledge base docs +- redis_sre_list_instances: List configured Redis instances +- redis_sre_create_instance: Create a new Redis instance configuration +- redis_sre_get_task_status: Check task progress, notifications, and results +- redis_sre_get_thread: Get full conversation history and results +""" + +from redis_sre_agent.mcp_server.server import mcp + +__all__ = ["mcp"] diff --git a/redis_sre_agent/mcp_server/server.py b/redis_sre_agent/mcp_server/server.py new file mode 100644 index 00000000..c95a40af --- /dev/null +++ b/redis_sre_agent/mcp_server/server.py @@ -0,0 +1,935 @@ +"""MCP server implementation for redis-sre-agent. + +This module creates an MCP server using FastMCP that exposes the agent's +capabilities to other MCP clients. The server runs in stdio mode and +proxies requests to the running Redis SRE Agent HTTP API. + +This allows Claude to connect to an already-running agent via: +1. Start agent: docker compose up -d (API on port 8080) +2. Claude spawns this MCP server, which calls the HTTP API +""" + +import logging +from typing import Any, Dict, List, Optional + +from mcp.server.fastmcp import FastMCP + +logger = logging.getLogger(__name__) + +# Create the MCP server instance +mcp = FastMCP( + name="redis-sre-agent", + instructions="""Redis SRE Agent - An AI-powered Redis troubleshooting and operations assistant. + +## Task-Based Architecture + +This agent uses a **task-based workflow**. Most tools create a **Task** that runs in +the background. You MUST watch each task for: + +1. **Status changes**: queued → in_progress → done/failed +2. **Notifications**: Real-time updates showing what the agent is doing +3. **Final result**: The response when status="done" + +## Tools That Create Tasks (require polling) + +| Tool | Purpose | Typical Duration | +|------|---------|------------------| +| `redis_sre_deep_triage()` | Deep analysis of Redis issues | 2-10 minutes | +| `redis_sre_general_chat()` | Quick Q&A with full toolset (including external MCP tools) | 10-30 seconds | +| `redis_sre_database_chat()` | Redis-focused chat (no external MCP tools) | 10-30 seconds | +| `redis_sre_knowledge_query()` | Answer questions using knowledge base | 10-30 seconds | + +**Note**: Deep triage performs comprehensive analysis including metrics collection, log analysis, +knowledge base searches, and multi-topic recommendation synthesis. Complex queries or +instances with many data sources may take longer. + +After calling any of these, you MUST: +1. Get the `task_id` from the response +2. Poll `redis_sre_get_task_status(task_id)` until status is "done" or "failed" +3. Read the `result` field when done + +## Utility Tools (return immediately) + +| Tool | Purpose | +|------|---------| +| `redis_sre_knowledge_search()` | Direct search of docs (raw results) | +| `redis_sre_list_instances()` | List available Redis instances | +| `redis_sre_get_task_status()` | Check task progress | +| `redis_sre_get_thread()` | Get conversation history | + +## Standard Workflow + +``` +1. Call redis_sre_deep_triage(), redis_sre_general_chat(), or redis_sre_knowledge_query() + → Returns: task_id, thread_id, status="queued" + +2. Poll redis_sre_get_task_status(task_id) every 5 seconds + → status: "queued" → "in_progress" → "done" + → updates: Array of notifications (grows over time) + → result: Final answer (when status="done") + +3. When status="done", read result.response +``` + +## Example + +``` +# Step 1: Create task +response = redis_sre_deep_triage(query="High memory usage on prod-redis") +task_id = response.task_id + +# Step 2: Poll for completion +while True: + status = redis_sre_get_task_status(task_id) + if status.status == "done": + print(status.result.response) # The answer! + break + elif status.status == "failed": + print(status.error_message) + break + # Show progress to user + for update in status.updates: + print(update.message) + sleep(5) +``` + +## Tips + +- **Always poll redis_sre_get_task_status()** - results are on the task, not returned directly +- Use `redis_sre_knowledge_search()` for quick doc lookups (no polling needed) +- Use `redis_sre_list_instances()` to find instance IDs before calling other tools +- Check the `updates` array to show users what the agent is doing""", +) + + +@mcp.tool() +async def redis_sre_deep_triage( + query: str, + instance_id: Optional[str] = None, + user_id: Optional[str] = None, +) -> Dict[str, Any]: + """Create a deep triage task to analyze a Redis issue comprehensively. + + This creates a **Task** that runs in the background. You MUST watch the task + for status changes, notifications, and the final result. + + ## What This Tool Does + + Creates a deep analysis task that: + - Performs comprehensive multi-topic analysis (memory, connections, performance, etc.) + - Uses knowledge base, metrics, logs, traces, and diagnostics tools + - Synthesizes findings into actionable recommendations + - Emits notifications as it works (visible via redis_sre_get_task_status) + - Stores the final result on the task when complete + + ## How to Use the Task + + 1. **Call this tool** → Returns `task_id` (and `thread_id`) + 2. **Watch the task** → Poll `redis_sre_get_task_status(task_id)` every 5-10 seconds + - `status`: "queued" → "in_progress" → "done" or "failed" + - `updates`: Array of notifications showing what the agent is doing + - `result`: Final analysis (present when status="done") + 3. **Read the result** → When status="done", the `result` field has the response + + The task typically takes 2-10 minutes depending on complexity. + + Args: + query: The issue to analyze (e.g., "High memory usage on production Redis") + instance_id: Optional Redis instance ID (use redis_sre_list_instances to find IDs) + user_id: Optional user ID for tracking + + Returns: + task_id: Watch this task for status, notifications, and result + thread_id: Conversation thread (for multi-turn follow-ups) + status: Initial status (usually "queued") + """ + from docket import Docket + + from redis_sre_agent.core.docket_tasks import get_redis_url, process_agent_turn + from redis_sre_agent.core.redis import get_redis_client + from redis_sre_agent.core.tasks import create_task + + logger.info(f"MCP deep_triage request: {query[:100]}...") + + try: + redis_client = get_redis_client() + context: Dict[str, Any] = {} + if instance_id: + context["instance_id"] = instance_id + if user_id: + context["user_id"] = user_id + + result = await create_task( + message=query, + context=context, + redis_client=redis_client, + ) + + # Submit to Docket for processing (this is what the API does) + async with Docket(url=await get_redis_url(), name="sre_docket") as docket: + task_func = docket.add(process_agent_turn) + await task_func( + thread_id=result["thread_id"], + message=query, + context=context, + task_id=result["task_id"], + ) + + return { + "thread_id": result["thread_id"], + "task_id": result["task_id"], + "status": result["status"].value + if hasattr(result["status"], "value") + else str(result["status"]), + "message": result.get("message", "Triage queued for processing"), + } + + except Exception as e: + logger.error(f"Triage failed: {e}") + return { + "error": str(e), + "status": "failed", + "message": f"Failed to start triage: {e}", + } + + +@mcp.tool() +async def redis_sre_general_chat( + query: str, + instance_id: Optional[str] = None, + user_id: Optional[str] = None, +) -> Dict[str, Any]: + """Create a chat task for Redis Q&A with full tool access. + + This creates a **Task** that runs the chat agent with access to ALL tools including: + - Redis instance tools (INFO, SLOWLOG, CONFIG, CLIENT, etc.) + - Knowledge base tools (search documentation, runbooks) + - Utility tools (time conversion, formatting) + - External MCP tools (GitHub, Slack, Prometheus, Loki, etc. if configured) + + Use this for: + - Questions that may require external data (metrics, logs, tickets) + - Operations that span multiple systems + - Quick status checks with full observability context + + For Redis-only questions without external integrations, use redis_sre_database_chat(). + For complex issues requiring deep analysis, use redis_sre_deep_triage(). + + ## How to Use the Task + + 1. **Call this tool** → Returns `task_id` (and `thread_id`) + 2. **Watch the task** → Poll `redis_sre_get_task_status(task_id)` every 2-5 seconds + - Chat is faster than triage (typically 10-30 seconds) + - `status`: "queued" → "in_progress" → "done" or "failed" + - `updates`: Notifications showing what the agent is doing + - `result`: The answer (present when status="done") + + Args: + query: Your question (e.g., "What's the current memory usage?") + instance_id: Optional Redis instance ID (use redis_sre_list_instances to find IDs) + user_id: Optional user ID for tracking + + Returns: + task_id: Watch this task for status, notifications, and result + thread_id: Conversation thread (for follow-up questions) + status: Initial status (usually "queued") + """ + from docket import Docket + + from redis_sre_agent.core.docket_tasks import get_redis_url, process_chat_turn + from redis_sre_agent.core.redis import get_redis_client + from redis_sre_agent.core.tasks import create_task + + logger.info(f"MCP general_chat request: {query[:100]}...") + + try: + redis_client = get_redis_client() + context: Dict[str, Any] = {"agent_type": "chat"} + if instance_id: + context["instance_id"] = instance_id + if user_id: + context["user_id"] = user_id + + result = await create_task( + message=query, + context=context, + redis_client=redis_client, + ) + + # Submit to Docket for processing + async with Docket(url=await get_redis_url(), name="sre_docket") as docket: + task_func = docket.add(process_chat_turn) + await task_func( + query=query, + task_id=result["task_id"], + thread_id=result["thread_id"], + instance_id=instance_id, + user_id=user_id, + ) + + return { + "thread_id": result["thread_id"], + "task_id": result["task_id"], + "status": result["status"].value + if hasattr(result["status"], "value") + else str(result["status"]), + "message": "Chat task queued for processing", + } + + except Exception as e: + logger.error(f"Chat failed: {e}") + return { + "error": str(e), + "status": "failed", + "message": f"Failed to start chat: {e}", + } + + +@mcp.tool() +async def redis_sre_database_chat( + query: str, + instance_id: Optional[str] = None, + user_id: Optional[str] = None, + exclude_mcp_categories: Optional[List[str]] = None, +) -> Dict[str, Any]: + """Create a Redis-focused chat task with selective MCP tool access. + + Similar to redis_sre_general_chat(), but allows excluding specific categories of + MCP tools. By default, excludes all external MCP tools for focused Redis diagnostics. + + The agent has access to: + - Redis instance tools (INFO, SLOWLOG, CONFIG, CLIENT, etc.) + - Knowledge base tools (search documentation, runbooks) + - Utility tools (time conversion, formatting) + - MCP tools NOT in the excluded categories + + Use this when: + - You want focused Redis instance diagnostics without external integrations + - You need a lighter-weight agent that won't call out to certain MCP servers + - You want selective access to MCP tools (e.g., allow metrics but not tickets) + + ## Exclude Categories + + You can exclude specific MCP tool categories: + - "metrics": Prometheus, Grafana, etc. + - "logs": Loki, log aggregators, etc. + - "tickets": Jira, GitHub Issues, etc. + - "repos": GitHub, GitLab, etc. + - "traces": Jaeger, distributed tracing, etc. + - "diagnostics": External diagnostic tools + - "knowledge": External knowledge bases + - "utilities": External utility tools + + Pass None or empty list to include all MCP tools (same as redis_sre_general_chat). + Pass ["all"] to exclude all MCP tools. + + ## How to Use the Task + + 1. **Call this tool** → Returns `task_id` (and `thread_id`) + 2. **Watch the task** → Poll `redis_sre_get_task_status(task_id)` every 2-5 seconds + - `status`: "queued" → "in_progress" → "done" or "failed" + - `updates`: Notifications showing what the agent is doing + - `result`: The answer (present when status="done") + + Args: + query: Your question (e.g., "What's the current memory usage?") + instance_id: Optional Redis instance ID (use redis_sre_list_instances to find IDs) + user_id: Optional user ID for tracking + exclude_mcp_categories: Categories to exclude. Pass ["all"] to exclude all MCP tools. + Default: ["all"] (excludes all MCP tools for focused Redis chat) + + Returns: + task_id: Watch this task for status, notifications, and result + thread_id: Conversation thread (for follow-up questions) + status: Initial status (usually "queued") + """ + from docket import Docket + + from redis_sre_agent.core.docket_tasks import get_redis_url, process_chat_turn + from redis_sre_agent.core.redis import get_redis_client + from redis_sre_agent.core.tasks import create_task + from redis_sre_agent.tools.models import ToolCapability + + logger.info(f"MCP database_chat request: {query[:100]}...") + + # Default to excluding all MCP categories for focused Redis chat + if exclude_mcp_categories is None: + exclude_mcp_categories = ["all"] + + # Convert "all" to list of all categories + if "all" in exclude_mcp_categories: + exclude_mcp_categories = [cap.value for cap in ToolCapability] + + try: + redis_client = get_redis_client() + context: Dict[str, Any] = { + "agent_type": "chat", + "exclude_mcp_categories": exclude_mcp_categories, + } + if instance_id: + context["instance_id"] = instance_id + if user_id: + context["user_id"] = user_id + + result = await create_task( + message=query, + context=context, + redis_client=redis_client, + ) + + # Submit to Docket for processing with category exclusions + async with Docket(url=await get_redis_url(), name="sre_docket") as docket: + task_func = docket.add(process_chat_turn) + await task_func( + query=query, + task_id=result["task_id"], + thread_id=result["thread_id"], + instance_id=instance_id, + user_id=user_id, + exclude_mcp_categories=exclude_mcp_categories, + ) + + return { + "thread_id": result["thread_id"], + "task_id": result["task_id"], + "status": result["status"].value + if hasattr(result["status"], "value") + else str(result["status"]), + "message": f"Database chat task queued (excluded categories: {exclude_mcp_categories})", + } + + except Exception as e: + logger.error(f"Database chat failed: {e}") + return { + "error": str(e), + "status": "failed", + "message": f"Failed to start database chat: {e}", + } + + +@mcp.tool() +async def redis_sre_knowledge_search( + query: str, + limit: int = 10, + offset: int = 0, + category: Optional[str] = None, + version: Optional[str] = "latest", +) -> Dict[str, Any]: + """Search the Redis SRE knowledge base (returns raw results). + + This is a **direct search** that returns raw knowledge base results immediately. + Use this when you want to browse documentation or get specific content. + + For questions that need interpretation/reasoning, use `redis_sre_knowledge_query()` + instead, which creates a task that uses the Knowledge Agent to analyze and answer. + + Args: + query: Search query (e.g., "redis memory eviction policies") + limit: Maximum number of results (1-50, default 10) + offset: Number of results to skip for pagination (default 0) + category: Optional filter by category ('incident', 'maintenance', 'monitoring', etc.) + version: Redis documentation version filter. Defaults to "latest". + + Returns: + results: Array of matching documents with title, content, source, etc. + (Returns immediately - no task polling needed) + """ + from redis_sre_agent.core.knowledge_helpers import search_knowledge_base_helper + + logger.info(f"MCP knowledge_search: {query[:100]}... (version={version}, offset={offset})") + + try: + limit = max(1, min(50, limit)) + offset = max(0, offset) + kwargs: Dict[str, Any] = { + "query": query, + "limit": limit, + "offset": offset, + "version": version, + } + if category: + kwargs["category"] = category + + result = await search_knowledge_base_helper(**kwargs) + + results = [] + for item in result.get("results", []): + results.append( + { + "title": item.get("title", "Untitled"), + "content": item.get("content", ""), + "source": item.get("source"), + "category": item.get("category"), + "version": item.get("version", "latest"), + "score": item.get("score"), + } + ) + + return { + "query": query, + "version": version, + "offset": offset, + "limit": limit, + "results": results, + "total_results": len(results), + "has_more": len(results) == limit, # Hint for pagination + } + + except Exception as e: + logger.error(f"Knowledge search failed: {e}") + return { + "error": str(e), + "query": query, + "results": [], + "total_results": 0, + } + + +@mcp.tool() +async def redis_sre_knowledge_query( + query: str, + user_id: Optional[str] = None, +) -> Dict[str, Any]: + """Create a task to answer a question using the Knowledge Agent. + + This creates a **Task** that uses the Knowledge Agent to answer questions + about SRE practices, Redis best practices, and troubleshooting guidance. + The agent searches the knowledge base and synthesizes an answer. + + Use this for questions that need reasoning/interpretation. + Use `redis_sre_knowledge_search()` for direct document search. + + ## How to Use the Task + + 1. **Call this tool** → Returns `task_id` (and `thread_id`) + 2. **Watch the task** → Poll `redis_sre_get_task_status(task_id)` every 2-5 seconds + - `status`: "queued" → "in_progress" → "done" or "failed" + - `updates`: Notifications showing knowledge sources being searched + - `result`: The synthesized answer (present when status="done") + + Args: + query: Your question (e.g., "What are Redis memory eviction policies?") + user_id: Optional user ID for tracking + + Returns: + task_id: Watch this task for status, notifications, and result + thread_id: Conversation thread (for follow-up questions) + status: Initial status (usually "queued") + """ + from docket import Docket + + from redis_sre_agent.core.docket_tasks import get_redis_url, process_knowledge_query + from redis_sre_agent.core.redis import get_redis_client + from redis_sre_agent.core.tasks import create_task + + logger.info(f"MCP knowledge_query: {query[:100]}...") + + try: + redis_client = get_redis_client() + context: Dict[str, Any] = {"agent_type": "knowledge"} + if user_id: + context["user_id"] = user_id + + result = await create_task( + message=query, + context=context, + redis_client=redis_client, + ) + + # Submit to Docket for processing + async with Docket(url=await get_redis_url(), name="sre_docket") as docket: + task_func = docket.add(process_knowledge_query) + await task_func( + query=query, + task_id=result["task_id"], + thread_id=result["thread_id"], + user_id=user_id, + ) + + return { + "thread_id": result["thread_id"], + "task_id": result["task_id"], + "status": result["status"].value + if hasattr(result["status"], "value") + else str(result["status"]), + "message": "Knowledge query task queued for processing", + } + + except Exception as e: + logger.error(f"Knowledge query failed: {e}") + return { + "error": str(e), + "status": "failed", + "message": f"Failed to start knowledge query: {e}", + } + + +@mcp.tool() +async def redis_sre_get_thread(thread_id: str) -> Dict[str, Any]: + """Get the full conversation and results from a triage or chat thread. + + Call this AFTER redis_sre_get_task_status() shows status="done" to retrieve the + complete analysis. The thread contains: + + - All messages exchanged (user query, assistant responses) + - Tool calls made by the agent (metrics queries, log searches, etc.) + - The final result with findings and recommendations + + Workflow: + 1. redis_sre_deep_triage() or redis_sre_*_chat() → get thread_id and task_id + 2. redis_sre_get_task_status(task_id) → poll until status="done" + 3. redis_sre_get_thread(thread_id) → get full results (this tool) + + Args: + thread_id: The thread_id returned from the triage or chat tool + + Returns: + messages: List of conversation messages with role and content + result: Final analysis result (findings, recommendations, etc.) + updates: Progress updates that occurred during execution + error_message: Error details if the triage failed + """ + from redis_sre_agent.core.redis import get_redis_client + from redis_sre_agent.core.threads import ThreadManager + + logger.info(f"MCP get_thread: {thread_id}") + + try: + redis_client = get_redis_client() + tm = ThreadManager(redis_client=redis_client) + thread = await tm.get_thread(thread_id) + + if not thread: + return { + "error": f"Thread {thread_id} not found", + "thread_id": thread_id, + } + + # Format messages from thread.messages + formatted_messages = [] + for msg in thread.messages: + formatted_msg = { + "role": msg.role, + "content": msg.content, + } + # Include metadata if present + if msg.metadata: + formatted_msg["metadata"] = msg.metadata + formatted_messages.append(formatted_msg) + + # Get the latest task for updates/result/error + from redis_sre_agent.core.keys import RedisKeys + from redis_sre_agent.core.tasks import TaskManager + + task_manager = TaskManager(redis_client=redis_client) + latest_task_ids = await redis_client.zrevrange( + RedisKeys.thread_tasks_index(thread_id), 0, 0 + ) + + result = None + error_message = None + updates = [] + + if latest_task_ids: + latest_task_id = latest_task_ids[0] + if isinstance(latest_task_id, bytes): + latest_task_id = latest_task_id.decode() + task_state = await task_manager.get_task_state(latest_task_id) + if task_state: + result = task_state.result + error_message = task_state.error_message + updates = [u.model_dump() for u in task_state.updates] if task_state.updates else [] + + return { + "thread_id": thread_id, + "messages": formatted_messages, + "message_count": len(formatted_messages), + "result": result, + "error_message": error_message, + "updates": updates, + } + + except Exception as e: + logger.error(f"Get thread failed: {e}") + return { + "error": str(e), + "thread_id": thread_id, + } + + +@mcp.tool() +async def redis_sre_get_task_status(task_id: str) -> Dict[str, Any]: + """Watch a task for status, notifications, and result. + + After calling any task-based tool (redis_sre_deep_triage, redis_sre_*_chat, etc.), + poll this tool to watch your task. Check THREE things: + + ## 1. Status (is it done?) + + - "queued": Waiting to start + - "in_progress": Agent is working + - "done": Complete! Check the `result` field + - "failed": Error occurred - check `error_message` + + ## 2. Updates/Notifications (what is the agent doing?) + + The `updates` array shows real-time notifications: + ``` + updates: [ + {"timestamp": "...", "message": "Querying Redis INFO...", "type": "tool_call"}, + {"timestamp": "...", "message": "Memory usage is 85%...", "type": "agent_reflection"}, + {"timestamp": "...", "message": "Checking slow log...", "type": "tool_call"}, + ] + ``` + + This array grows as the agent works. Each entry shows what the agent + is doing or thinking. Use this to provide feedback to users. + + ## 3. Result (the final answer) + + When status="done", the `result` field contains: + ``` + result: { + "response": "Based on my analysis, the high memory...", + "metadata": {...} + } + ``` + + ## Polling Pattern + + Poll every 5-10 seconds until status is "done" or "failed": + - Show updates to user as they arrive + - When done, extract the result + + Args: + task_id: The task_id returned from triage or chat tools + + Returns: + status: Current status (queued/in_progress/done/failed) + updates: Array of notifications from the agent (grows over time) + result: Final response (only present when status="done") + error_message: Error details (only present when status="failed") + thread_id: For multi-turn follow-ups via redis_sre_get_thread() + """ + from redis_sre_agent.core.tasks import get_task_by_id + + logger.info(f"MCP get_task_status: {task_id}") + + try: + task = await get_task_by_id(task_id=task_id) + metadata = task.get("metadata", {}) or {} + + return { + "task_id": task_id, + "thread_id": task.get("thread_id"), + "status": task.get("status"), + "subject": metadata.get("subject"), + "created_at": metadata.get("created_at"), + "updated_at": metadata.get("updated_at"), + "updates": task.get("updates", []), + "result": task.get("result"), + "error_message": task.get("error_message"), + } + + except ValueError as e: + return { + "error": str(e), + "task_id": task_id, + "status": "not_found", + } + except Exception as e: + logger.error(f"Get task status failed: {e}") + return { + "error": str(e), + "task_id": task_id, + } + + +@mcp.tool() +async def redis_sre_list_instances() -> Dict[str, Any]: + """List all configured Redis instances. + + Returns a list of all Redis instances that have been configured + in the SRE agent. Sensitive information like connection URLs and + passwords are masked. + + Use this to find instance IDs before calling other tools like + redis_sre_deep_triage() or redis_sre_general_chat(). + + Returns: + Dictionary with list of instance information + """ + from redis_sre_agent.core.instances import get_instances + + logger.info("MCP list_instances request") + + try: + instances = await get_instances() + + instance_list = [] + for inst in instances: + instance_list.append( + { + "id": inst.id, + "name": inst.name, + "environment": inst.environment, + "usage": inst.usage, + "description": inst.description, + "instance_type": inst.instance_type, + "repo_url": inst.repo_url, + "status": getattr(inst, "status", None), + } + ) + + return { + "instances": instance_list, + "total": len(instance_list), + } + + except Exception as e: + logger.error(f"List instances failed: {e}") + return { + "error": str(e), + "instances": [], + "total": 0, + } + + +@mcp.tool() +async def redis_sre_create_instance( + name: str, + connection_url: str, + environment: str, + usage: str, + description: str, + repo_url: Optional[str] = None, + user_id: Optional[str] = None, +) -> Dict[str, Any]: + """Create a new Redis instance configuration. + + Registers a new Redis instance with the SRE agent. The instance can + then be used for triage, monitoring, and diagnostics via tools like + redis_sre_deep_triage() and redis_sre_general_chat(). + + Args: + name: Unique name for the instance + connection_url: Redis connection URL (redis://host:port or rediss://...) + environment: Environment type (development, staging, production, test) + usage: Usage type (cache, analytics, session, queue, custom) + description: Description of what this Redis instance is used for + repo_url: Optional GitHub repository URL associated with this instance + user_id: Optional user ID of who is creating this instance + + Returns: + Dictionary with the created instance ID and status + """ + from datetime import datetime + + from redis_sre_agent.core.instances import ( + RedisInstance, + get_instances, + save_instances, + ) + + logger.info(f"MCP create_instance: {name}") + + valid_envs = ["development", "staging", "production", "test"] + if environment.lower() not in valid_envs: + return { + "error": f"Invalid environment. Must be one of: {', '.join(valid_envs)}", + "status": "failed", + } + + valid_usages = ["cache", "analytics", "session", "queue", "custom"] + if usage.lower() not in valid_usages: + return { + "error": f"Invalid usage. Must be one of: {', '.join(valid_usages)}", + "status": "failed", + } + + try: + instances = await get_instances() + + if any(inst.name == name for inst in instances): + return { + "error": f"Instance with name '{name}' already exists", + "status": "failed", + } + + instance_id = f"redis-{environment.lower()}-{int(datetime.now().timestamp())}" + new_instance = RedisInstance( + id=instance_id, + name=name, + connection_url=connection_url, + environment=environment.lower(), + usage=usage.lower(), + description=description, + repo_url=repo_url, + instance_type="unknown", # Will be auto-detected on first connection + ) + + instances.append(new_instance) + if not await save_instances(instances): + return {"error": "Failed to save instance", "status": "failed"} + + logger.info(f"Created Redis instance: {name} ({instance_id})") + return { + "id": instance_id, + "name": name, + "repo_url": repo_url, + "status": "created", + "message": f"Successfully created instance '{name}'", + } + + except Exception as e: + logger.error(f"Create instance failed: {e}") + return {"error": str(e), "status": "failed"} + + +# ============================================================================ +# Server runners +# ============================================================================ + + +def run_stdio(): + """Run the MCP server in stdio mode.""" + mcp.run(transport="stdio") + + +def run_sse(host: str = "127.0.0.1", port: int = 8080): + """Run the MCP server in SSE mode (legacy, use HTTP instead).""" + mcp.run(transport="sse", host=host, port=port) + + +def run_http(host: str = "0.0.0.0", port: int = 8081): + """Run the MCP server in HTTP mode (Streamable HTTP). + + This is the recommended transport for remote access. Claude can connect + to this server via Settings > Connectors > Add Custom Connector with + the URL: http://:/mcp + + Args: + host: Host to bind to (default 0.0.0.0 for external access) + port: Port to listen on (default 8081) + """ + import asyncio + + mcp.settings.host = host + mcp.settings.port = port + asyncio.run(mcp.run_streamable_http_async()) + + +def get_http_app(): + """Get the ASGI app for the MCP server. + + Use this when deploying with uvicorn or other ASGI servers: + uvicorn redis_sre_agent.mcp_server.server:app --host 0.0.0.0 --port 8081 + + The MCP endpoint will be available at /mcp + """ + return mcp.streamable_http_app() + + +# ASGI app for uvicorn deployment +# Usage: uvicorn redis_sre_agent.mcp_server.server:app --host 0.0.0.0 --port 8081 +app = mcp.streamable_http_app() diff --git a/redis_sre_agent/pipelines/ingestion/deduplication.py b/redis_sre_agent/pipelines/ingestion/deduplication.py index c9f43c84..71e9ff12 100644 --- a/redis_sre_agent/pipelines/ingestion/deduplication.py +++ b/redis_sre_agent/pipelines/ingestion/deduplication.py @@ -316,6 +316,7 @@ async def replace_document_chunks(self, chunks: List[Dict[str, Any]], vectorizer "category": chunk["category"], "doc_type": chunk["doc_type"], "severity": chunk["severity"], + "version": chunk.get("version", "latest"), "chunk_index": chunk["chunk_index"], "vector": all_embeddings[i], "created_at": datetime.now(timezone.utc).timestamp(), diff --git a/redis_sre_agent/pipelines/ingestion/processor.py b/redis_sre_agent/pipelines/ingestion/processor.py index e10fd2c1..9de9a0c1 100644 --- a/redis_sre_agent/pipelines/ingestion/processor.py +++ b/redis_sre_agent/pipelines/ingestion/processor.py @@ -175,6 +175,9 @@ def _create_chunk( # Generate deterministic ID based on document hash and chunk index chunk_id = f"{document.content_hash}_{chunk_index}" + # Extract version from metadata, default to "latest" + version = document.metadata.get("version", "latest") + return { "id": chunk_id, "document_hash": document.content_hash, @@ -184,6 +187,7 @@ def _create_chunk( "category": document.category.value, "doc_type": document.doc_type.value, "severity": document.severity.value, + "version": version, "chunk_index": chunk_index, "metadata": { **document.metadata, diff --git a/redis_sre_agent/pipelines/scraper/redis_docs.py b/redis_sre_agent/pipelines/scraper/redis_docs.py index baaeeb62..e0a7b7cc 100644 --- a/redis_sre_agent/pipelines/scraper/redis_docs.py +++ b/redis_sre_agent/pipelines/scraper/redis_docs.py @@ -50,6 +50,31 @@ def _is_versioned_url(self, url: str) -> bool: except Exception: return False + def _extract_version_from_url(self, url: str) -> str: + """Extract version from URL path. + + Examples: + /rs/7.8/clusters/... -> "7.8" + /rs/7.4/clusters/... -> "7.4" + /rs/clusters/... -> "latest" + /latest/operate/... -> "latest" + + Returns: + Version string (e.g., "7.8", "7.4") or "latest" for unversioned docs. + """ + import re + from urllib.parse import urlparse + + try: + path = urlparse(url).path + # Match version patterns like /7.8/, /7.4/, /6.2/ + match = re.search(r"/(\d+\.\d+)/", path) + if match: + return match.group(1) + return "latest" + except Exception: + return "latest" + def get_source_name(self) -> str: return "redis_documentation" @@ -244,6 +269,12 @@ async def _scrape_section( # Extract main content main_content = await self._extract_page_content(soup, section_url) if main_content: + # Extract version from URL and add to metadata + version = self._extract_version_from_url(section_url) + metadata = { + **main_content["metadata"], + "version": version, + } doc = ScrapedDocument( title=main_content["title"], content=main_content["content"], @@ -251,7 +282,7 @@ async def _scrape_section( category=category, doc_type=doc_type, severity=severity, - metadata=main_content["metadata"], + metadata=metadata, ) documents.append(doc) diff --git a/redis_sre_agent/tools/knowledge/knowledge_base.py b/redis_sre_agent/tools/knowledge/knowledge_base.py index e96dd1d6..e63ccf9b 100644 --- a/redis_sre_agent/tools/knowledge/knowledge_base.py +++ b/redis_sre_agent/tools/knowledge/knowledge_base.py @@ -52,7 +52,8 @@ def create_tool_schemas(self) -> List[ToolDefinition]: "runbooks, Redis documentation, troubleshooting guides, and SRE procedures. " "Use this to find solutions to problems, understand Redis features, or get " "guidance on SRE best practices. Always cite the source document and title " - "when using information from search results." + "when using information from search results. By default, returns only the " + "latest version of documentation to avoid duplicates." ), capability=ToolCapability.KNOWLEDGE, parameters={ @@ -67,7 +68,23 @@ def create_tool_schemas(self) -> List[ToolDefinition]: "description": "Maximum number of results to return (default: 10)", "default": 10, "minimum": 1, - "maximum": 20, + "maximum": 50, + }, + "offset": { + "type": "integer", + "description": "Number of results to skip for pagination (default: 0)", + "default": 0, + "minimum": 0, + }, + "version": { + "type": "string", + "description": ( + "Redis documentation version filter. Defaults to 'latest' which " + "returns only the most current documentation. Available versions: " + "'latest' (default, recommended), '7.8', '7.4', '7.2'. " + "Set to null to return all versions (may include duplicates)." + ), + "default": "latest", }, "distance_threshold": { "type": "number", @@ -198,6 +215,8 @@ async def search( self, query: str, limit: int = 10, + offset: int = 0, + version: Optional[str] = "latest", distance_threshold: Optional[float] = None, ) -> Dict[str, Any]: """Search the knowledge base. @@ -205,17 +224,22 @@ async def search( Args: query: Search query limit: Maximum number of results + offset: Number of results to skip for pagination + version: Version filter - "latest" (default), specific version like "7.8", + or None to return all versions distance_threshold: Optional cosine distance threshold. If provided, overrides the backend default. Returns: Search results with relevant knowledge base content """ logger.info( - f"Knowledge base search: {query} (limit={limit}, distance_threshold={distance_threshold})" + f"Knowledge base search: {query} (limit={limit}, offset={offset}, version={version})" ) kwargs = { "query": query, "limit": limit, + "offset": offset, + "version": version, "distance_threshold": distance_threshold, } # OTel: instrument knowledge search without leaking raw query @@ -232,6 +256,8 @@ async def search( "query.len": len(query or ""), "query.sha1": _qhash, "limit": int(limit), + "offset": int(offset), + "version": version or "all", "distance_threshold.set": distance_threshold is not None, }, ): diff --git a/redis_sre_agent/tools/manager.py b/redis_sre_agent/tools/manager.py index 663b5642..fa3cf07d 100644 --- a/redis_sre_agent/tools/manager.py +++ b/redis_sre_agent/tools/manager.py @@ -53,13 +53,23 @@ class ToolManager: "redis_sre_agent.tools.utilities.provider.UtilitiesToolProvider", ] - def __init__(self, redis_instance: Optional[RedisInstance] = None): + def __init__( + self, + redis_instance: Optional[RedisInstance] = None, + exclude_mcp_categories: Optional[List[ToolCapability]] = None, + ): """Initialize tool manager. Args: redis_instance: Optional Redis instance to scope tools to + exclude_mcp_categories: Optional list of MCP tool categories to exclude. + Use [ToolCapability.UTILITIES] to exclude utility-only MCP tools, + or pass all capabilities to exclude all MCP tools. + Common categories: METRICS, LOGS, TICKETS, REPOS, TRACES, + DIAGNOSTICS, KNOWLEDGE, UTILITIES. """ self.redis_instance = redis_instance + self.exclude_mcp_categories = exclude_mcp_categories # Track loaded provider class paths to avoid duplicates self._loaded_provider_paths: set[str] = set() @@ -158,6 +168,10 @@ async def __aenter__(self) -> "ToolManager": else: logger.info("No redis_instance provided - loading only instance-independent providers") + # Load MCP servers (these are always-on and don't require redis_instance) + # Pass excluded categories to filter which MCP tools are loaded + await self._load_mcp_providers() + logger.info( f"ToolManager initialized with {len(self._tools)} tools " f"from {len(set(self._routing_table.values()))} providers" @@ -210,6 +224,91 @@ async def _load_provider(self, provider_path: str, always_on: bool = False) -> N logger.exception(f"Failed to load provider {provider_path}") # Don't fail entire manager if one provider fails + async def _load_mcp_providers(self) -> None: + """Load MCP tool providers based on configured mcp_servers. + + This method iterates through the mcp_servers configuration and creates + an MCPToolProvider for each configured server. Tools are filtered based + on exclude_mcp_categories if specified. + """ + from redis_sre_agent.core.config import MCPServerConfig, settings + + if not settings.mcp_servers: + return + + # Build set of excluded capabilities for fast lookup + excluded_caps = set(self.exclude_mcp_categories or []) + if excluded_caps: + logger.info( + f"MCP tools with these categories will be excluded: {[c.value for c in excluded_caps]}" + ) + + for server_name, server_config in settings.mcp_servers.items(): + try: + # Convert dict to MCPServerConfig if needed + if isinstance(server_config, dict): + server_config = MCPServerConfig.model_validate(server_config) + + # Skip if already loaded (use a synthetic path for tracking) + mcp_provider_path = f"mcp:{server_name}" + if mcp_provider_path in self._loaded_provider_paths: + logger.debug(f"MCP provider already loaded, skipping: {server_name}") + continue + + # Import and create the MCP provider + from redis_sre_agent.tools.mcp.provider import MCPToolProvider + + provider = MCPToolProvider( + server_name=server_name, + server_config=server_config, + redis_instance=None, # MCP providers don't use redis_instance + ) + + # Enter the provider's async context + provider = await self._stack.enter_async_context(provider) + + # Set back-reference + try: + setattr(provider, "_manager", self) + except Exception: + pass + + # Register tools, filtering by excluded categories + tools = provider.tools() + included_count = 0 + excluded_count = 0 + for tool in tools: + name = tool.metadata.name + if not name: + continue + # Skip tools whose capability is in the excluded list + if tool.metadata.capability in excluded_caps: + excluded_count += 1 + logger.debug( + f"Excluding MCP tool '{name}' (capability: {tool.metadata.capability.value})" + ) + continue + self._routing_table[name] = provider + self._tools.append(tool) + self._tool_by_name[name] = tool + included_count += 1 + + # Track provider + self._providers.append(provider) + self._loaded_provider_paths.add(mcp_provider_path) + + if excluded_count > 0: + logger.info( + f"Loaded MCP provider '{server_name}': {included_count} tools included, " + f"{excluded_count} excluded by category filter" + ) + else: + logger.info(f"Loaded MCP provider '{server_name}' with {included_count} tools") + + except Exception: + logger.exception(f"Failed to load MCP provider '{server_name}'") + # Don't fail entire manager if one MCP provider fails + @classmethod def _get_provider_class(cls, provider_path: str) -> type: """Get provider class from path, with caching. diff --git a/redis_sre_agent/tools/mcp/__init__.py b/redis_sre_agent/tools/mcp/__init__.py new file mode 100644 index 00000000..e59862e3 --- /dev/null +++ b/redis_sre_agent/tools/mcp/__init__.py @@ -0,0 +1,9 @@ +"""MCP (Model Context Protocol) tool provider integration. + +This module provides dynamic tool providers that connect to MCP servers +and expose their tools to the agent. +""" + +from redis_sre_agent.tools.mcp.provider import MCPToolProvider + +__all__ = ["MCPToolProvider"] diff --git a/redis_sre_agent/tools/mcp/provider.py b/redis_sre_agent/tools/mcp/provider.py new file mode 100644 index 00000000..e740cfab --- /dev/null +++ b/redis_sre_agent/tools/mcp/provider.py @@ -0,0 +1,390 @@ +"""MCP (Model Context Protocol) tool provider. + +This module provides a dynamic tool provider that connects to an MCP server +and exposes its tools to the agent. It supports tool filtering and description +overrides based on the MCPServerConfig. +""" + +import logging +import os +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from mcp import ClientSession, StdioServerParameters +from mcp import types as mcp_types +from mcp.client.sse import sse_client +from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import streamablehttp_client + +from redis_sre_agent.core.config import MCPServerConfig, MCPToolConfig +from redis_sre_agent.tools.models import Tool, ToolCapability, ToolDefinition, ToolMetadata +from redis_sre_agent.tools.protocols import ToolProvider + +if TYPE_CHECKING: + from redis_sre_agent.core.instances import RedisInstance + +logger = logging.getLogger(__name__) + + +class MCPToolProvider(ToolProvider): + """Dynamic tool provider that connects to an MCP server. + + This provider: + 1. Connects to an MCP server using the configured transport (stdio or HTTP) + 2. Discovers available tools from the server + 3. Optionally filters tools based on the config's `tools` mapping + 4. Applies capability and description overrides from the config + 5. Exposes the tools to the agent + + Example: + config = MCPServerConfig( + command="npx", + args=["-y", "@modelcontextprotocol/server-memory"], + tools={ + "search_memories": MCPToolConfig(capability=ToolCapability.LOGS), + } + ) + provider = MCPToolProvider( + server_name="memory", + server_config=config, + ) + async with provider: + tools = provider.tools() + """ + + # Default capability for MCP tools if not specified + DEFAULT_CAPABILITY = ToolCapability.UTILITIES + + def __init__( + self, + server_name: str, + server_config: MCPServerConfig, + redis_instance: Optional["RedisInstance"] = None, + ): + """Initialize the MCP tool provider. + + Args: + server_name: Name of the MCP server (used in tool naming) + server_config: Configuration for the MCP server + redis_instance: Optional Redis instance (not typically used by MCP) + """ + super().__init__(redis_instance=redis_instance) + self._server_name = server_name + self._server_config = server_config + self._session: Optional[ClientSession] = None + self._exit_stack: Optional[AsyncExitStack] = None + self._mcp_tools: List[mcp_types.Tool] = [] + self._tool_cache: List[Tool] = [] + + @property + def provider_name(self) -> str: + """Return the provider name based on the server name.""" + return f"mcp_{self._server_name}" + + async def __aenter__(self) -> "MCPToolProvider": + """Enter async context and connect to the MCP server.""" + await self._connect() + return self + + async def __aexit__(self, *args) -> None: + """Exit async context and disconnect from the MCP server.""" + await self._disconnect() + + async def _connect(self) -> None: + """Connect to the MCP server and discover tools. + + This method initializes the MCP client based on the transport type + (stdio command or HTTP URL) and fetches the available tools. + """ + try: + logger.info( + f"Connecting to MCP server '{self._server_name}' " + f"(command={self._server_config.command}, url={self._server_config.url})" + ) + + self._exit_stack = AsyncExitStack() + await self._exit_stack.__aenter__() + + # Determine transport type and connect + if self._server_config.command: + # Stdio transport - spawn a subprocess + # Merge parent environment with config-specified env so that + # env vars like OPENAI_API_KEY are inherited by the subprocess + merged_env = {**os.environ, **(self._server_config.env or {})} + server_params = StdioServerParameters( + command=self._server_config.command, + args=self._server_config.args or [], + env=merged_env, + ) + read_stream, write_stream = await self._exit_stack.enter_async_context( + stdio_client(server_params) + ) + elif self._server_config.url: + # URL-based transport (SSE or Streamable HTTP) + # Expand environment variables in headers (e.g., ${GITHUB_TOKEN}) + headers = None + if self._server_config.headers: + headers = {} + for key, value in self._server_config.headers.items(): + # Expand ${VAR} patterns from environment + expanded_value = os.path.expandvars(value) + headers[key] = expanded_value + + # Determine transport type - default to streamable_http for modern servers + transport_type = (self._server_config.transport or "streamable_http").lower() + + if transport_type == "sse": + # Legacy SSE transport + logger.info(f"Using SSE transport for '{self._server_name}'") + read_stream, write_stream = await self._exit_stack.enter_async_context( + sse_client(self._server_config.url, headers=headers) + ) + else: + # Streamable HTTP transport (default, works with GitHub remote MCP, etc.) + logger.info(f"Using Streamable HTTP transport for '{self._server_name}'") + ( + read_stream, + write_stream, + _get_session_id, + ) = await self._exit_stack.enter_async_context( + streamablehttp_client(self._server_config.url, headers=headers) + ) + else: + raise ValueError( + f"MCP server '{self._server_name}' must have either 'command' or 'url' configured" + ) + + # Create and initialize the session + self._session = await self._exit_stack.enter_async_context( + ClientSession(read_stream, write_stream) + ) + await self._session.initialize() + + # Discover tools from the server + tools_result = await self._session.list_tools() + self._mcp_tools = tools_result.tools + self._tool_cache = [] + + logger.info( + f"MCP server '{self._server_name}' connected with {len(self._mcp_tools)} tools: " + f"{[t.name for t in self._mcp_tools]}" + ) + + except Exception as e: + logger.error(f"Failed to connect to MCP server '{self._server_name}': {e}") + # Clean up on failure + if self._exit_stack: + await self._exit_stack.aclose() + self._exit_stack = None + raise + + async def _disconnect(self) -> None: + """Disconnect from the MCP server.""" + try: + if self._exit_stack: + logger.info(f"Disconnecting from MCP server '{self._server_name}'") + await self._exit_stack.aclose() + self._exit_stack = None + self._session = None + except Exception as e: + logger.warning(f"Error disconnecting from MCP server '{self._server_name}': {e}") + + def _get_tool_config(self, tool_name: str) -> Optional[MCPToolConfig]: + """Get the configuration for a specific tool, if any.""" + if self._server_config.tools: + return self._server_config.tools.get(tool_name) + return None + + def _should_include_tool(self, tool_name: str) -> bool: + """Check if a tool should be included based on the config. + + If `tools` is specified in the config, only those tools are included. + If `tools` is None, all tools from the server are included. + """ + if self._server_config.tools is None: + return True + return tool_name in self._server_config.tools + + def _get_capability(self, tool_name: str) -> ToolCapability: + """Get the capability for a tool, with config override support.""" + config = self._get_tool_config(tool_name) + if config and config.capability: + return config.capability + return self.DEFAULT_CAPABILITY + + def _get_description(self, tool_name: str, mcp_description: str) -> str: + """Get the description for a tool, with config override/template support. + + If the config provides a description, it can use {original} as a placeholder + for the MCP tool's original description. This allows adding context while + preserving the original tool documentation. + + Examples: + - No override: uses original MCP description + - Override without placeholder: "Custom description" -> replaces entirely + - Override with placeholder: "Context. {original}" -> prepends context + + Args: + tool_name: Name of the MCP tool + mcp_description: Original description from the MCP server + + Returns: + Final description (original, override, or templated) + """ + config = self._get_tool_config(tool_name) + if config and config.description: + # Support templating: {original} gets replaced with the MCP description + if "{original}" in config.description: + return config.description.replace("{original}", mcp_description) + return config.description + return mcp_description + + def create_tool_schemas(self) -> List[ToolDefinition]: + """Create tool schemas from the MCP server's tools. + + This method transforms MCP tool definitions into ToolDefinition objects, + applying any configured filters, capability overrides, and description + overrides. + """ + schemas: List[ToolDefinition] = [] + + for mcp_tool in self._mcp_tools: + tool_name = mcp_tool.name + if not tool_name: + continue + + # Check if tool should be included + if not self._should_include_tool(tool_name): + continue + + # Get description (with potential override) + mcp_description = mcp_tool.description or f"MCP tool: {tool_name}" + description = self._get_description(tool_name, mcp_description) + + # Get capability (with potential override) + capability = self._get_capability(tool_name) + + # Build parameters schema from MCP tool input schema + input_schema = mcp_tool.inputSchema or {} + parameters = { + "type": "object", + "properties": input_schema.get("properties", {}), + "required": input_schema.get("required", []), + } + + schema = ToolDefinition( + name=self._make_tool_name(tool_name), + description=description, + capability=capability, + parameters=parameters, + ) + schemas.append(schema) + + return schemas + + def tools(self) -> List[Tool]: + """Return the concrete tools exposed by this provider. + + This caches the tools list to avoid rebuilding on every call. + """ + if self._tool_cache: + return self._tool_cache + + schemas = self.create_tool_schemas() + tools: List[Tool] = [] + + for schema in schemas: + # Extract the original MCP tool name from our tool name + mcp_tool_name = self.resolve_operation(schema.name, {}) or "" + + meta = ToolMetadata( + name=schema.name, + description=schema.description, + capability=schema.capability, + provider_name=self.provider_name, + requires_instance=False, # MCP tools typically don't require Redis instance + ) + + # Create the invoke closure that calls the MCP server + async def _invoke( + args: Dict[str, Any], + _tool_name: str = mcp_tool_name, + ) -> Any: + return await self._call_mcp_tool(_tool_name, args) + + tools.append(Tool(metadata=meta, definition=schema, invoke=_invoke)) + + self._tool_cache = tools + return tools + + async def _call_mcp_tool(self, tool_name: str, args: Dict[str, Any]) -> Any: + """Call an MCP tool on the server. + + Args: + tool_name: The original MCP tool name (without provider prefix) + args: Arguments to pass to the tool + + Returns: + The tool's result from the MCP server + """ + if not self._session: + return { + "status": "error", + "error": f"MCP server '{self._server_name}' is not connected", + } + + try: + logger.info(f"Calling MCP tool '{tool_name}' with args: {args}") + result = await self._session.call_tool(tool_name, arguments=args) + + # Check for errors + if result.isError: + error_text = "" + for content in result.content: + if isinstance(content, mcp_types.TextContent): + error_text += content.text + return { + "status": "error", + "error": error_text or "Tool execution failed", + } + + # Extract the result content + response: Dict[str, Any] = {"status": "success"} + + # If there's structured content, use it + if result.structuredContent: + response["data"] = result.structuredContent + + # Also extract text content for compatibility + text_parts = [] + for content in result.content: + if isinstance(content, mcp_types.TextContent): + text_parts.append(content.text) + elif isinstance(content, mcp_types.ImageContent): + response.setdefault("images", []).append( + { + "mimeType": content.mimeType, + "data": content.data, + } + ) + elif isinstance(content, mcp_types.EmbeddedResource): + resource = content.resource + if isinstance(resource, mcp_types.TextResourceContents): + response.setdefault("resources", []).append( + { + "uri": str(resource.uri), + "text": resource.text, + } + ) + + if text_parts: + response["text"] = "\n".join(text_parts) + + return response + + except Exception as e: + logger.error(f"Error calling MCP tool '{tool_name}': {e}") + return { + "status": "error", + "error": str(e), + } diff --git a/scripts/generate-mcp-certs.sh b/scripts/generate-mcp-certs.sh new file mode 100755 index 00000000..fafc8955 --- /dev/null +++ b/scripts/generate-mcp-certs.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# Generate self-signed certificates for the MCP server + +CERT_DIR="monitoring/nginx/certs" +mkdir -p "$CERT_DIR" + +# Generate self-signed certificate valid for 365 days +openssl req -x509 -nodes -days 365 -newkey rsa:2048 \ + -keyout "$CERT_DIR/server.key" \ + -out "$CERT_DIR/server.crt" \ + -subj "/CN=localhost/O=Redis SRE Agent/C=US" \ + -addext "subjectAltName=DNS:localhost,DNS:sre-mcp,IP:127.0.0.1" + +echo "Certificates generated in $CERT_DIR/" +echo " - server.crt (certificate)" +echo " - server.key (private key)" +echo "" +echo "To trust this cert on macOS:" +echo " sudo security add-trusted-cert -d -r trustRoot -k /Library/Keychains/System.keychain $CERT_DIR/server.crt" diff --git a/tests/conftest.py b/tests/conftest.py index 0d2e645f..8c483a11 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,6 @@ """ import os -import subprocess -import time from typing import Any, Dict, List from unittest.mock import AsyncMock, Mock, patch @@ -67,34 +65,9 @@ def pytest_configure(config): os.environ["OPENAI_INTEGRATION_TESTS"] = "true" os.environ["AGENT_BEHAVIOR_TESTS"] = "true" os.environ["INTEGRATION_TESTS"] = "true" # Needed for redis_container fixture - - # If running full suite and INTEGRATION_TESTS requested, ensure docker compose is up - if os.environ.get("INTEGRATION_TESTS") and not os.environ.get("CI"): - try: - # Start only infra services to avoid building app images during tests - subprocess.run( - [ - "docker", - "compose", - "-f", - "docker-compose.yml", - "-f", - "docker-compose.test.yml", - "up", - "-d", - "redis", - "redis-exporter", - "prometheus", - "node-exporter", - "grafana", - ], - check=False, - ) - # Give services a moment to start - time.sleep(3) - except Exception: - # Non-fatal; testcontainers fallback will still work - pass + # Note: We intentionally do NOT start docker-compose here. + # Integration tests use testcontainers via the redis_container fixture, + # which manages Redis lifecycle automatically with docker-compose.integration.yml. def pytest_collection_modifyitems(config, items): diff --git a/tests/integration/tools/diagnostics/redis_command/test_redis_cli_integration.py b/tests/integration/tools/diagnostics/redis_command/test_redis_cli_integration.py index 115afac6..7316198c 100644 --- a/tests/integration/tools/diagnostics/redis_command/test_redis_cli_integration.py +++ b/tests/integration/tools/diagnostics/redis_command/test_redis_cli_integration.py @@ -1,22 +1,6 @@ """Integration tests for Redis Command Diagnostics provider with ToolManager.""" import pytest -from testcontainers.redis import RedisContainer - - -@pytest.fixture(scope="module") -def redis_container(): - """Start a Redis container for testing.""" - with RedisContainer("redis:8.2.1") as redis: - yield redis - - -@pytest.fixture -def redis_url(redis_container): - """Get Redis connection URL from container.""" - host = redis_container.get_container_host_ip() - port = redis_container.get_exposed_port(6379) - return f"redis://{host}:{port}" @pytest.mark.asyncio diff --git a/tests/integration/tools/diagnostics/redis_command/test_redis_command_provider.py b/tests/integration/tools/diagnostics/redis_command/test_redis_command_provider.py index 79d5643e..5d141b15 100644 --- a/tests/integration/tools/diagnostics/redis_command/test_redis_command_provider.py +++ b/tests/integration/tools/diagnostics/redis_command/test_redis_command_provider.py @@ -3,7 +3,6 @@ from unittest.mock import AsyncMock, patch import pytest -from testcontainers.redis import RedisContainer from redis_sre_agent.tools.diagnostics.redis_command import ( RedisCommandToolProvider, @@ -11,21 +10,6 @@ from redis_sre_agent.tools.protocols import ToolCapability -@pytest.fixture(scope="module") -def redis_container(): - """Start a Redis container for testing.""" - with RedisContainer("redis:8.2.1") as redis: - yield redis - - -@pytest.fixture -def redis_url(redis_container): - """Get Redis connection URL from container.""" - host = redis_container.get_container_host_ip() - port = redis_container.get_exposed_port(6379) - return f"redis://{host}:{port}" - - @pytest.mark.asyncio async def test_provider_initialization(redis_url): """Test that provider initializes correctly.""" diff --git a/tests/unit/agent/test_chat_agent.py b/tests/unit/agent/test_chat_agent.py new file mode 100644 index 00000000..5d6f1989 --- /dev/null +++ b/tests/unit/agent/test_chat_agent.py @@ -0,0 +1,266 @@ +"""Unit tests for the lightweight Chat Agent.""" + +from unittest.mock import MagicMock, patch + +from redis_sre_agent.agent.chat_agent import ( + CHAT_SYSTEM_PROMPT, + ChatAgent, + ChatAgentState, + get_chat_agent, +) +from redis_sre_agent.core.instances import RedisInstance +from redis_sre_agent.core.progress import ( + CallbackEmitter, + NullEmitter, +) + + +class TestChatAgentInitialization: + """Test ChatAgent initialization.""" + + @patch("redis_sre_agent.agent.chat_agent.ChatOpenAI") + def test_agent_initializes_without_instance(self, mock_chat_openai): + """Test that ChatAgent initializes correctly without a Redis instance.""" + mock_llm = MagicMock() + mock_chat_openai.return_value = mock_llm + + agent = ChatAgent() + + assert agent.llm is mock_llm + assert agent.mini_llm is mock_llm # Both use the same mock + assert agent.redis_instance is None + # Should have NullEmitter by default + assert isinstance(agent._emitter, NullEmitter) + # Now creates 2 LLM instances (llm and mini_llm) + assert mock_chat_openai.call_count == 2 + + @patch("redis_sre_agent.agent.chat_agent.ChatOpenAI") + def test_agent_initializes_with_instance(self, mock_chat_openai): + """Test that ChatAgent initializes correctly with a Redis instance.""" + mock_llm = MagicMock() + mock_chat_openai.return_value = mock_llm + + instance = RedisInstance( + id="test-id", + name="test-instance", + connection_url="redis://localhost:6379", + environment="development", + usage="cache", + description="Test instance", + instance_type="oss_single", + ) + + agent = ChatAgent(redis_instance=instance) + + assert agent.llm is mock_llm + assert agent.redis_instance is instance + assert agent.redis_instance.name == "test-instance" + + @patch("redis_sre_agent.agent.chat_agent.ChatOpenAI") + def test_agent_initializes_with_progress_emitter(self, mock_chat_openai): + """Test that ChatAgent accepts a progress_emitter.""" + mock_llm = MagicMock() + mock_chat_openai.return_value = mock_llm + + emitter = NullEmitter() + agent = ChatAgent(progress_emitter=emitter) + + assert agent._emitter is emitter + + @patch("redis_sre_agent.agent.chat_agent.ChatOpenAI") + def test_agent_initializes_with_progress_callback_deprecated(self, mock_chat_openai): + """Test that ChatAgent still accepts deprecated progress_callback.""" + mock_llm = MagicMock() + mock_chat_openai.return_value = mock_llm + + async def my_callback(msg, type): + pass + + agent = ChatAgent(progress_callback=my_callback) + + # Should wrap callback in CallbackEmitter + assert isinstance(agent._emitter, CallbackEmitter) + + @patch("redis_sre_agent.agent.chat_agent.ChatOpenAI") + def test_progress_emitter_takes_precedence_over_callback(self, mock_chat_openai): + """Test that progress_emitter takes precedence over progress_callback.""" + mock_llm = MagicMock() + mock_chat_openai.return_value = mock_llm + + emitter = NullEmitter() + + async def my_callback(msg, type): + pass + + agent = ChatAgent(progress_emitter=emitter, progress_callback=my_callback) + + # Should use the emitter, not the callback + assert agent._emitter is emitter + + @patch("redis_sre_agent.agent.chat_agent.ChatOpenAI") + def test_agent_no_temperature_parameter(self, mock_chat_openai): + """Test that ChatAgent doesn't use temperature parameter (reasoning models).""" + mock_llm = MagicMock() + mock_chat_openai.return_value = mock_llm + + ChatAgent() + + call_args = mock_chat_openai.call_args + assert "temperature" not in call_args.kwargs + + +class TestChatAgentSingleton: + """Test get_chat_agent singleton behavior.""" + + def test_get_chat_agent_without_instance(self): + """Test get_chat_agent returns agent without instance.""" + with patch("redis_sre_agent.agent.chat_agent.ChatAgent") as mock_agent_class: + mock_instance = MagicMock() + mock_agent_class.return_value = mock_instance + + # Clear cache + from redis_sre_agent.agent import chat_agent + + chat_agent._chat_agents.clear() + + agent = get_chat_agent() + + assert agent is mock_instance + mock_agent_class.assert_called_once_with(redis_instance=None) + + def test_get_chat_agent_caches_by_instance_name(self): + """Test get_chat_agent caches agents by instance name.""" + with patch("redis_sre_agent.agent.chat_agent.ChatAgent") as mock_agent_class: + mock_agent1 = MagicMock() + mock_agent2 = MagicMock() + mock_agent_class.side_effect = [mock_agent1, mock_agent2] + + # Clear cache + from redis_sre_agent.agent import chat_agent + + chat_agent._chat_agents.clear() + + instance1 = RedisInstance( + id="id-1", + name="instance-1", + connection_url="redis://localhost:6379", + environment="development", + usage="cache", + description="Test instance 1", + instance_type="oss_single", + ) + instance2 = RedisInstance( + id="id-2", + name="instance-2", + connection_url="redis://localhost:6380", + environment="development", + usage="cache", + description="Test instance 2", + instance_type="oss_single", + ) + + agent1 = get_chat_agent(redis_instance=instance1) + agent1_again = get_chat_agent(redis_instance=instance1) + agent2 = get_chat_agent(redis_instance=instance2) + + # Same instance name should return cached agent + assert agent1 is agent1_again + # Different instance name should return new agent + assert agent1 is not agent2 + assert mock_agent_class.call_count == 2 + + +class TestChatAgentSystemPrompt: + """Test the chat agent system prompt.""" + + def test_system_prompt_is_concise(self): + """Test that the system prompt is focused and concise.""" + assert "Redis SRE agent" in CHAT_SYSTEM_PROMPT + assert "quick" in CHAT_SYSTEM_PROMPT.lower() or "fast" in CHAT_SYSTEM_PROMPT.lower() + # Should mention full triage as alternative + assert "triage" in CHAT_SYSTEM_PROMPT.lower() + + def test_system_prompt_mentions_tools(self): + """Test that the system prompt mentions tool usage.""" + assert "tool" in CHAT_SYSTEM_PROMPT.lower() + + def test_system_prompt_warns_about_managed_redis(self): + """Test that the system prompt has Redis Enterprise/Cloud notes.""" + assert "Enterprise" in CHAT_SYSTEM_PROMPT or "Cloud" in CHAT_SYSTEM_PROMPT + assert "INFO" in CHAT_SYSTEM_PROMPT + + +class TestChatAgentState: + """Test the ChatAgentState TypedDict.""" + + def test_state_has_required_fields(self): + """Test that ChatAgentState has all required fields.""" + state: ChatAgentState = { + "messages": [], + "session_id": "test-session", + "user_id": "test-user", + "current_tool_calls": [], + "iteration_count": 0, + "max_iterations": 10, + "signals_envelopes": [], + } + + assert "messages" in state + assert "session_id" in state + assert "user_id" in state + assert "current_tool_calls" in state + assert "iteration_count" in state + assert "max_iterations" in state + assert "signals_envelopes" in state + + +class TestChatAgentWorkflowBuild: + """Test the _build_workflow method and emitter parameter.""" + + @patch("redis_sre_agent.agent.chat_agent.ChatOpenAI") + def test_build_workflow_accepts_emitter(self, mock_chat_openai): + """Test that _build_workflow accepts an emitter parameter.""" + mock_llm = MagicMock() + mock_chat_openai.return_value = mock_llm + + agent = ChatAgent() + + # Create a mock tool manager + mock_tool_mgr = MagicMock() + mock_tool_mgr.get_tools.return_value = [] + mock_tool_mgr.get_status_update.return_value = None + + # Create a mock emitter + emitter = NullEmitter() + + # Should not raise - emitter is now accepted + workflow = agent._build_workflow( + tool_mgr=mock_tool_mgr, + llm_with_tools=mock_llm, + adapters=[], + emitter=emitter, + ) + + assert workflow is not None + + @patch("redis_sre_agent.agent.chat_agent.ChatOpenAI") + def test_build_workflow_works_without_emitter(self, mock_chat_openai): + """Test that _build_workflow works when emitter is None.""" + mock_llm = MagicMock() + mock_chat_openai.return_value = mock_llm + + agent = ChatAgent() + + # Create a mock tool manager + mock_tool_mgr = MagicMock() + mock_tool_mgr.get_tools.return_value = [] + + # Should not raise when emitter is None + workflow = agent._build_workflow( + tool_mgr=mock_tool_mgr, + llm_with_tools=mock_llm, + adapters=[], + emitter=None, + ) + + assert workflow is not None diff --git a/tests/unit/agent/test_envelope_summarization.py b/tests/unit/agent/test_envelope_summarization.py new file mode 100644 index 00000000..542dde86 --- /dev/null +++ b/tests/unit/agent/test_envelope_summarization.py @@ -0,0 +1,221 @@ +"""Tests for envelope summarization and expand_evidence tool in the reasoning phase.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from redis_sre_agent.agent.langgraph_agent import SRELangGraphAgent + + +class TestEnvelopeSummarization: + """Test the _summarize_envelopes_for_reasoning method.""" + + @pytest.fixture + def agent(self): + """Create agent instance with mocked LLM.""" + with patch("redis_sre_agent.agent.langgraph_agent.ChatOpenAI"): + agent = SRELangGraphAgent() + # Mock the mini_llm + agent.mini_llm = MagicMock() + agent._llm_cache = {} + agent._run_cache_active = False + return agent + + @pytest.mark.asyncio + async def test_empty_envelopes_returns_empty(self, agent): + """Test that empty input returns empty output.""" + result = await agent._summarize_envelopes_for_reasoning([]) + assert result == [] + + @pytest.mark.asyncio + async def test_small_envelopes_unchanged(self, agent): + """Test that small envelopes are not summarized.""" + small_envelope = { + "tool_key": "test_tool", + "name": "test", + "description": "A test tool", + "args": {"param": "value"}, + "status": "success", + "data": {"result": "small data"}, # Well under 500 chars + } + + result = await agent._summarize_envelopes_for_reasoning([small_envelope]) + + assert len(result) == 1 + assert result[0]["data"] == {"result": "small data"} + + @pytest.mark.asyncio + async def test_large_envelopes_summarized(self, agent): + """Test that large envelopes are summarized via LLM.""" + # Create a large envelope (>500 chars in data) + large_data = {"metrics": "x" * 1000, "logs": "y" * 1000} + large_envelope = { + "tool_key": "redis_info", + "name": "info", + "description": "Get Redis INFO", + "args": {}, + "status": "success", + "data": large_data, + } + + # Mock LLM response + mock_response = MagicMock() + mock_response.content = '[{"summary": "Key finding: metrics show high load"}]' + agent.mini_llm.ainvoke = AsyncMock(return_value=mock_response) + + result = await agent._summarize_envelopes_for_reasoning([large_envelope]) + + assert len(result) == 1 + assert "summary" in result[0]["data"] + assert "high load" in result[0]["data"]["summary"] + # Original large data should be replaced + assert result[0]["data"] != large_data + + @pytest.mark.asyncio + async def test_mixed_envelopes_partial_summarization(self, agent): + """Test that only large envelopes are summarized.""" + small_envelope = { + "tool_key": "small_tool", + "name": "small", + "description": "Small tool", + "args": {}, + "status": "success", + "data": {"value": 42}, + } + large_envelope = { + "tool_key": "large_tool", + "name": "large", + "description": "Large tool", + "args": {}, + "status": "success", + "data": {"content": "x" * 1000}, + } + + # Mock LLM response for large envelope + mock_response = MagicMock() + mock_response.content = '[{"summary": "Large content summarized"}]' + agent.mini_llm.ainvoke = AsyncMock(return_value=mock_response) + + result = await agent._summarize_envelopes_for_reasoning([small_envelope, large_envelope]) + + assert len(result) == 2 + # Small envelope unchanged + assert result[0]["data"] == {"value": 42} + # Large envelope summarized + assert "summary" in result[1]["data"] + + @pytest.mark.asyncio + async def test_order_preserved(self, agent): + """Test that envelope order is preserved after summarization.""" + envelopes = [ + { + "tool_key": f"tool_{i}", + "name": f"t{i}", + "args": {}, + "status": "success", + "data": {"id": i, "content": "x" * (100 if i % 2 == 0 else 1000)}, + } + for i in range(5) + ] + + mock_response = MagicMock() + mock_response.content = '[{"summary": "s1"}, {"summary": "s2"}]' + agent.mini_llm.ainvoke = AsyncMock(return_value=mock_response) + + result = await agent._summarize_envelopes_for_reasoning(envelopes) + + # Check order by tool_key + assert [r["tool_key"] for r in result] == [f"tool_{i}" for i in range(5)] + + @pytest.mark.asyncio + async def test_llm_failure_fallback_truncation(self, agent): + """Test that LLM failure falls back to truncation.""" + large_envelope = { + "tool_key": "test", + "name": "test", + "description": "Test", + "args": {}, + "status": "success", + "data": {"content": "x" * 1000}, + } + + # Mock LLM to raise exception + agent.mini_llm.ainvoke = AsyncMock(side_effect=Exception("LLM error")) + + result = await agent._summarize_envelopes_for_reasoning([large_envelope]) + + assert len(result) == 1 + assert "truncated" in result[0]["data"] + assert result[0]["data"]["truncated"].endswith("...") + + +class TestExpandEvidenceTool: + """Test the expand_evidence tool for retrieving full tool outputs.""" + + @pytest.fixture + def agent(self): + """Create agent instance with mocked LLM.""" + with patch("redis_sre_agent.agent.langgraph_agent.ChatOpenAI"): + agent = SRELangGraphAgent() + return agent + + def test_expand_evidence_returns_full_data(self, agent): + """Test that expand_evidence returns the full original data.""" + envelopes = [ + { + "tool_key": "redis_info_123", + "name": "info", + "description": "Get Redis INFO", + "args": {"section": "all"}, + "status": "success", + "data": {"memory": "large data here", "clients": 100}, + }, + { + "tool_key": "slowlog_456", + "name": "slowlog", + "description": "Get slow queries", + "args": {}, + "status": "success", + "data": {"queries": ["query1", "query2"]}, + }, + ] + + tool_spec = agent._build_expand_evidence_tool(envelopes) + func = tool_spec["func"] + + # Call expand_evidence for first tool + result = func("redis_info_123") + assert result["status"] == "success" + assert result["tool_key"] == "redis_info_123" + assert result["full_data"] == {"memory": "large data here", "clients": 100} + + # Call for second tool + result = func("slowlog_456") + assert result["status"] == "success" + assert result["full_data"] == {"queries": ["query1", "query2"]} + + def test_expand_evidence_unknown_key(self, agent): + """Test that expand_evidence returns error for unknown tool_key.""" + envelopes = [ + {"tool_key": "known_key", "name": "test", "data": {"x": 1}}, + ] + + tool_spec = agent._build_expand_evidence_tool(envelopes) + func = tool_spec["func"] + + result = func("unknown_key") + assert result["status"] == "error" + assert "Unknown tool_key" in result["error"] + assert "known_key" in result["error"] # Should list available keys + + def test_expand_evidence_tool_schema(self, agent): + """Test that expand_evidence tool has correct schema.""" + envelopes = [{"tool_key": "test_key", "name": "test", "data": {}}] + + tool_spec = agent._build_expand_evidence_tool(envelopes) + + assert tool_spec["name"] == "expand_evidence" + assert "full" in tool_spec["description"].lower() + assert "test_key" in tool_spec["description"] # Lists available keys + assert tool_spec["parameters"]["properties"]["tool_key"]["type"] == "string" + assert "tool_key" in tool_spec["parameters"]["required"] diff --git a/tests/unit/agent/test_router.py b/tests/unit/agent/test_router.py new file mode 100644 index 00000000..a29ad715 --- /dev/null +++ b/tests/unit/agent/test_router.py @@ -0,0 +1,151 @@ +"""Unit tests for the agent router.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from redis_sre_agent.agent.router import AgentType, route_to_appropriate_agent + + +class TestAgentTypeEnum: + """Test the AgentType enum.""" + + def test_agent_types_exist(self): + """Test that all expected agent types exist.""" + assert AgentType.REDIS_TRIAGE.value == "redis_triage" + assert AgentType.REDIS_CHAT.value == "redis_chat" + assert AgentType.KNOWLEDGE_ONLY.value == "knowledge_only" + + def test_redis_focused_is_alias_for_triage(self): + """Test that REDIS_FOCUSED is an alias for REDIS_TRIAGE.""" + # In Python enums, same value = same member + assert AgentType.REDIS_FOCUSED is AgentType.REDIS_TRIAGE + assert AgentType.REDIS_FOCUSED.value == "redis_triage" + + +@pytest.mark.asyncio +class TestRouteToAppropriateAgent: + """Test the route_to_appropriate_agent function.""" + + async def test_no_instance_routes_to_knowledge(self): + """Test that queries without instance context route to knowledge agent.""" + with patch("redis_sre_agent.agent.router.ChatOpenAI") as mock_chat: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "KNOWLEDGE_ONLY" + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + mock_chat.return_value = mock_llm + + result = await route_to_appropriate_agent( + query="What are Redis best practices?", + context=None, + ) + + assert result == AgentType.KNOWLEDGE_ONLY + + async def test_instance_with_triage_request_routes_to_triage(self): + """Test that triage requests with instance route to triage agent.""" + with patch("redis_sre_agent.agent.router.ChatOpenAI") as mock_chat: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "TRIAGE" + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + mock_chat.return_value = mock_llm + + result = await route_to_appropriate_agent( + query="Run a full health check on my Redis", + context={"instance_id": "test-instance"}, + ) + + assert result == AgentType.REDIS_TRIAGE + + async def test_instance_with_quick_question_routes_to_chat(self): + """Test that quick questions with instance route to chat agent.""" + with patch("redis_sre_agent.agent.router.ChatOpenAI") as mock_chat: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "CHAT" + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + mock_chat.return_value = mock_llm + + result = await route_to_appropriate_agent( + query="What's the memory usage?", + context={"instance_id": "test-instance"}, + ) + + assert result == AgentType.REDIS_CHAT + + async def test_llm_error_with_instance_defaults_to_chat(self): + """Test that LLM errors with instance default to chat agent.""" + with patch("redis_sre_agent.agent.router.ChatOpenAI") as mock_chat: + mock_llm = MagicMock() + mock_llm.ainvoke = AsyncMock(side_effect=Exception("LLM error")) + mock_chat.return_value = mock_llm + + result = await route_to_appropriate_agent( + query="Check something", + context={"instance_id": "test-instance"}, + ) + + assert result == AgentType.REDIS_CHAT + + async def test_llm_error_without_instance_defaults_to_knowledge(self): + """Test that LLM errors without instance default to knowledge agent.""" + with patch("redis_sre_agent.agent.router.ChatOpenAI") as mock_chat: + mock_llm = MagicMock() + mock_llm.ainvoke = AsyncMock(side_effect=Exception("LLM error")) + mock_chat.return_value = mock_llm + + result = await route_to_appropriate_agent( + query="What is Redis?", + context=None, + ) + + assert result == AgentType.KNOWLEDGE_ONLY + + async def test_user_preference_respected(self): + """Test that user preferences are respected when instance exists.""" + with patch("redis_sre_agent.agent.router.ChatOpenAI") as mock_chat: + # LLM should not be called when preference is set + mock_chat.return_value = MagicMock() + + result = await route_to_appropriate_agent( + query="Some query", + context={"instance_id": "test-instance"}, + user_preferences={"preferred_agent": "redis_triage"}, + ) + + assert result == AgentType.REDIS_TRIAGE + + async def test_comprehensive_triggers_triage(self): + """Test that 'comprehensive' keyword triggers triage routing.""" + with patch("redis_sre_agent.agent.router.ChatOpenAI") as mock_chat: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "TRIAGE" + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + mock_chat.return_value = mock_llm + + result = await route_to_appropriate_agent( + query="Give me a comprehensive analysis", + context={"instance_id": "test-instance"}, + ) + + assert result == AgentType.REDIS_TRIAGE + + async def test_unexpected_llm_response_defaults_to_chat(self): + """Test that unexpected LLM responses default to chat when instance exists.""" + with patch("redis_sre_agent.agent.router.ChatOpenAI") as mock_chat: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "UNEXPECTED_VALUE" + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + mock_chat.return_value = mock_llm + + result = await route_to_appropriate_agent( + query="Some query", + context={"instance_id": "test-instance"}, + ) + + # Should default to CHAT when unexpected value with instance + assert result == AgentType.REDIS_CHAT diff --git a/tests/unit/api/test_tasks_api.py b/tests/unit/api/test_tasks_api.py index ae66fa83..3c122eba 100644 --- a/tests/unit/api/test_tasks_api.py +++ b/tests/unit/api/test_tasks_api.py @@ -55,7 +55,12 @@ def test_create_task_success(self, client): def test_get_task_success(self, client): """GET /api/v1/tasks/{task_id} returns 200 with state.""" - # Minimal TaskState-like object + # Minimal TaskState-like object with metadata + class Metadata: + subject = "Test subject" + created_at = "2024-01-01T00:00:00Z" + updated_at = "2024-01-01T00:01:00Z" + class S: task_id = "t1" thread_id = "th1" @@ -63,6 +68,7 @@ class S: updates = [] result = None error_message = None + metadata = Metadata() mock_tm = MagicMock() mock_tm.get_task_state = AsyncMock(return_value=S()) @@ -72,3 +78,6 @@ class S: data = resp.json() assert data["task_id"] == "t1" assert data["thread_id"] == "th1" + assert data["subject"] == "Test subject" + assert data["created_at"] == "2024-01-01T00:00:00Z" + assert data["updated_at"] == "2024-01-01T00:01:00Z" diff --git a/tests/unit/api/test_threads_api.py b/tests/unit/api/test_threads_api.py index 5b961f77..68d0ba00 100644 --- a/tests/unit/api/test_threads_api.py +++ b/tests/unit/api/test_threads_api.py @@ -75,19 +75,18 @@ def test_update_thread_success(self, client): def test_get_thread_success(self, client): """GET /api/v1/threads/{id} returns 200 with messages and metadata.""" + from redis_sre_agent.core.threads import Message, Thread, ThreadMetadata - # Minimal ThreadState-like object - class State: - context = {"messages": [{"role": "user", "content": "hi"}]} - action_items = [] - updates = [] - result = None - error_message = None - metadata = MagicMock() - metadata.model_dump = lambda: {"user_id": "u"} + # Create a proper Thread object matching the model + mock_thread = Thread( + thread_id="th1", + messages=[Message(role="user", content="hi")], + context={}, + metadata=ThreadMetadata(user_id="u"), + ) mock_tm = MagicMock() - mock_tm.get_thread = AsyncMock(return_value=State()) + mock_tm.get_thread = AsyncMock(return_value=mock_thread) with patch("redis_sre_agent.api.threads.ThreadManager", return_value=mock_tm): resp = client.get("/api/v1/threads/th1") assert resp.status_code == 200 diff --git a/tests/unit/api/test_websockets.py b/tests/unit/api/test_websockets.py index 6705feb3..e1985e11 100644 --- a/tests/unit/api/test_websockets.py +++ b/tests/unit/api/test_websockets.py @@ -201,6 +201,8 @@ async def test_websocket_connection_success(self, test_client): patch("redis_sre_agent.api.websockets._stream_manager") as mock_stream_manager, ): mock_redis = AsyncMock() + # Mock Redis operations that the websocket endpoint uses + mock_redis.zrevrange = AsyncMock(return_value=[]) # No latest task mock_get_redis.return_value = mock_redis mock_manager = AsyncMock() @@ -216,8 +218,8 @@ async def test_websocket_connection_success(self, test_client): assert data["update_type"] == "initial_state" assert data["thread_id"] == thread_id - assert len(data["updates"]) == 2 - assert data["updates"][0]["message"] == "Processing..." # Most recent first + # With no task, updates should be empty + assert data["updates"] == [] # Verify stream consumer was started mock_stream_manager.start_consumer.assert_called_once_with(thread_id) diff --git a/tests/unit/cli/test_cli_index.py b/tests/unit/cli/test_cli_index.py new file mode 100644 index 00000000..437cb00e --- /dev/null +++ b/tests/unit/cli/test_cli_index.py @@ -0,0 +1,178 @@ +"""Tests for the `index` CLI commands.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from click.testing import CliRunner + +from redis_sre_agent.cli.index import index + + +@pytest.fixture +def cli_runner(): + """Click CLI test runner.""" + return CliRunner() + + +class TestIndexListCLI: + """Test index list CLI command.""" + + def test_list_help_shows_options(self, cli_runner): + """Test that list command shows expected options in help.""" + result = cli_runner.invoke(index, ["list", "--help"]) + + assert result.exit_code == 0 + assert "--json" in result.output + + def test_list_displays_indices(self, cli_runner): + """Test that list command displays indices.""" + mock_index = MagicMock() + mock_index.exists = AsyncMock(return_value=True) + mock_index._redis_client = MagicMock() + mock_index._redis_client.execute_command = AsyncMock(return_value=[b"num_docs", b"100"]) + + with ( + patch( + "redis_sre_agent.core.redis.get_knowledge_index", + new_callable=AsyncMock, + return_value=mock_index, + ), + patch( + "redis_sre_agent.core.redis.get_schedules_index", + new_callable=AsyncMock, + return_value=mock_index, + ), + patch( + "redis_sre_agent.core.redis.get_threads_index", + new_callable=AsyncMock, + return_value=mock_index, + ), + patch( + "redis_sre_agent.core.redis.get_tasks_index", + new_callable=AsyncMock, + return_value=mock_index, + ), + patch( + "redis_sre_agent.core.redis.get_instances_index", + new_callable=AsyncMock, + return_value=mock_index, + ), + ): + result = cli_runner.invoke(index, ["list"]) + + assert result.exit_code == 0 + # Should show table with indices + assert "knowledge" in result.output or "RediSearch" in result.output + + def test_list_json_output(self, cli_runner): + """Test that --json flag outputs JSON.""" + mock_index = MagicMock() + mock_index.exists = AsyncMock(return_value=True) + mock_index._redis_client = MagicMock() + mock_index._redis_client.execute_command = AsyncMock(return_value=[b"num_docs", b"50"]) + + with ( + patch( + "redis_sre_agent.core.redis.get_knowledge_index", + new_callable=AsyncMock, + return_value=mock_index, + ), + patch( + "redis_sre_agent.core.redis.get_schedules_index", + new_callable=AsyncMock, + return_value=mock_index, + ), + patch( + "redis_sre_agent.core.redis.get_threads_index", + new_callable=AsyncMock, + return_value=mock_index, + ), + patch( + "redis_sre_agent.core.redis.get_tasks_index", + new_callable=AsyncMock, + return_value=mock_index, + ), + patch( + "redis_sre_agent.core.redis.get_instances_index", + new_callable=AsyncMock, + return_value=mock_index, + ), + ): + result = cli_runner.invoke(index, ["list", "--json"]) + + assert result.exit_code == 0 + import json + + output_data = json.loads(result.output) + assert isinstance(output_data, list) + assert len(output_data) == 5 # 5 indices + + +class TestIndexRecreateCLI: + """Test index recreate CLI command.""" + + def test_recreate_help_shows_options(self, cli_runner): + """Test that recreate command shows expected options in help.""" + result = cli_runner.invoke(index, ["recreate", "--help"]) + + assert result.exit_code == 0 + assert "--index-name" in result.output + assert "--yes" in result.output + assert "-y" in result.output + assert "--json" in result.output + assert "knowledge" in result.output + assert "schedules" in result.output + assert "all" in result.output + + def test_recreate_requires_confirmation(self, cli_runner): + """Test that recreate requires confirmation without -y.""" + result = cli_runner.invoke(index, ["recreate"], input="n\n") + + assert result.exit_code == 0 + assert "Aborted" in result.output + + def test_recreate_with_yes_flag(self, cli_runner): + """Test that -y flag skips confirmation.""" + mock_result = {"success": True, "indices": {"knowledge": "recreated"}} + + with patch( + "redis_sre_agent.core.redis.recreate_indices", + new_callable=AsyncMock, + return_value=mock_result, + ) as mock_recreate: + result = cli_runner.invoke(index, ["recreate", "-y"]) + + assert result.exit_code == 0 + mock_recreate.assert_called_once_with(None) # None means all + assert "Successfully" in result.output or "✅" in result.output + + def test_recreate_specific_index(self, cli_runner): + """Test recreating a specific index.""" + mock_result = {"success": True, "indices": {"knowledge": "recreated"}} + + with patch( + "redis_sre_agent.core.redis.recreate_indices", + new_callable=AsyncMock, + return_value=mock_result, + ) as mock_recreate: + result = cli_runner.invoke(index, ["recreate", "--index-name", "knowledge", "-y"]) + + assert result.exit_code == 0 + mock_recreate.assert_called_once_with("knowledge") + + def test_recreate_json_output(self, cli_runner): + """Test that --json flag outputs JSON.""" + mock_result = {"success": True, "indices": {"knowledge": "recreated"}} + + with patch( + "redis_sre_agent.core.redis.recreate_indices", + new_callable=AsyncMock, + return_value=mock_result, + ): + result = cli_runner.invoke(index, ["recreate", "--json"]) + + assert result.exit_code == 0 + import json + + output_data = json.loads(result.output) + assert output_data["success"] is True diff --git a/tests/unit/cli/test_cli_knowledge.py b/tests/unit/cli/test_cli_knowledge.py new file mode 100644 index 00000000..b17b7211 --- /dev/null +++ b/tests/unit/cli/test_cli_knowledge.py @@ -0,0 +1,383 @@ +"""Tests for the `knowledge` CLI commands.""" + +from unittest.mock import AsyncMock, patch + +import pytest +from click.testing import CliRunner + +from redis_sre_agent.cli.knowledge import knowledge + + +@pytest.fixture +def cli_runner(): + """Click CLI test runner.""" + return CliRunner() + + +class TestKnowledgeSearchCLI: + """Test knowledge search CLI command.""" + + def test_search_help_shows_offset_option(self, cli_runner): + """Test that --offset option is visible in help.""" + result = cli_runner.invoke(knowledge, ["search", "--help"]) + + assert result.exit_code == 0 + assert "--offset" in result.output + assert "-o" in result.output + + def test_search_help_shows_version_option(self, cli_runner): + """Test that --version option is visible in help.""" + result = cli_runner.invoke(knowledge, ["search", "--help"]) + + assert result.exit_code == 0 + assert "--version" in result.output + assert "-v" in result.output + + def test_search_passes_offset_to_helper(self, cli_runner): + """Test that offset parameter is passed to search helper.""" + mock_result = { + "query": "redis memory", + "results_count": 1, + "results": [ + { + "title": "Redis Memory Guide", + "content": "Redis uses memory...", + "source": "docs", + "category": "documentation", + "version": "latest", + } + ], + } + + with patch( + "redis_sre_agent.cli.knowledge.search_knowledge_base_helper", + new_callable=AsyncMock, + ) as mock_search: + mock_search.return_value = mock_result + + result = cli_runner.invoke(knowledge, ["search", "redis", "memory", "--offset", "5"]) + + assert result.exit_code == 0, result.output + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs["offset"] == 5 + + def test_search_passes_version_to_helper(self, cli_runner): + """Test that version parameter is passed to search helper.""" + mock_result = { + "query": "redis clustering", + "results_count": 1, + "results": [ + { + "title": "Clustering Guide", + "content": "How to set up clustering...", + "source": "docs", + "category": "documentation", + "version": "7.8", + } + ], + } + + with patch( + "redis_sre_agent.cli.knowledge.search_knowledge_base_helper", + new_callable=AsyncMock, + ) as mock_search: + mock_search.return_value = mock_result + + result = cli_runner.invoke( + knowledge, ["search", "redis", "clustering", "--version", "7.8"] + ) + + assert result.exit_code == 0, result.output + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs["version"] == "7.8" + + def test_search_default_version_is_latest(self, cli_runner): + """Test that version defaults to 'latest' when not specified.""" + mock_result = { + "query": "test", + "results_count": 0, + "results": [], + } + + with patch( + "redis_sre_agent.cli.knowledge.search_knowledge_base_helper", + new_callable=AsyncMock, + ) as mock_search: + mock_search.return_value = mock_result + + result = cli_runner.invoke(knowledge, ["search", "test"]) + + assert result.exit_code == 0, result.output + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs["version"] == "latest" + + def test_search_default_offset_is_zero(self, cli_runner): + """Test that offset defaults to 0 when not specified.""" + mock_result = { + "query": "test", + "results_count": 0, + "results": [], + } + + with patch( + "redis_sre_agent.cli.knowledge.search_knowledge_base_helper", + new_callable=AsyncMock, + ) as mock_search: + mock_search.return_value = mock_result + + result = cli_runner.invoke(knowledge, ["search", "test"]) + + assert result.exit_code == 0, result.output + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs["offset"] == 0 + + def test_search_with_all_options(self, cli_runner): + """Test search with offset, version, and other options combined.""" + mock_result = { + "query": "redis performance", + "results_count": 2, + "results": [ + { + "title": "Perf Guide", + "content": "Performance tips...", + "source": "docs", + "category": "performance", + "version": "7.4", + } + ], + } + + with patch( + "redis_sre_agent.cli.knowledge.search_knowledge_base_helper", + new_callable=AsyncMock, + ) as mock_search: + mock_search.return_value = mock_result + + result = cli_runner.invoke( + knowledge, + [ + "search", + "redis", + "performance", + "--offset", + "10", + "--version", + "7.4", + "--limit", + "5", + "--category", + "performance", + ], + ) + + assert result.exit_code == 0, result.output + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs["offset"] == 10 + assert call_kwargs["version"] == "7.4" + assert call_kwargs["limit"] == 5 + assert call_kwargs["category"] == "performance" + + +class TestKnowledgeFragmentsCLI: + """Test knowledge fragments CLI command.""" + + def test_fragments_help_shows_options(self, cli_runner): + """Test that fragments command shows expected options in help.""" + result = cli_runner.invoke(knowledge, ["fragments", "--help"]) + + assert result.exit_code == 0 + assert "DOCUMENT_HASH" in result.output + assert "--json" in result.output + assert "--include-metadata" in result.output + assert "--no-metadata" in result.output + + def test_fragments_passes_document_hash(self, cli_runner): + """Test that document_hash is passed to helper.""" + mock_result = { + "title": "Test Doc", + "source": "test", + "category": "general", + "fragments_count": 2, + "fragments": [ + {"chunk_index": 0, "content": "First chunk"}, + {"chunk_index": 1, "content": "Second chunk"}, + ], + } + + with patch( + "redis_sre_agent.cli.knowledge.get_all_document_fragments", + new_callable=AsyncMock, + ) as mock_get: + mock_get.return_value = mock_result + + result = cli_runner.invoke(knowledge, ["fragments", "abc123"]) + + assert result.exit_code == 0, result.output + mock_get.assert_called_once_with("abc123", include_metadata=True) + + def test_fragments_with_no_metadata(self, cli_runner): + """Test that --no-metadata flag is passed correctly.""" + mock_result = { + "fragments_count": 1, + "fragments": [{"chunk_index": 0, "content": "Content"}], + } + + with patch( + "redis_sre_agent.cli.knowledge.get_all_document_fragments", + new_callable=AsyncMock, + ) as mock_get: + mock_get.return_value = mock_result + + result = cli_runner.invoke(knowledge, ["fragments", "abc123", "--no-metadata"]) + + assert result.exit_code == 0, result.output + mock_get.assert_called_once_with("abc123", include_metadata=False) + + def test_fragments_json_output(self, cli_runner): + """Test that --json flag outputs JSON.""" + mock_result = { + "title": "Test Doc", + "fragments_count": 1, + "fragments": [{"chunk_index": 0, "content": "Content"}], + } + + with patch( + "redis_sre_agent.cli.knowledge.get_all_document_fragments", + new_callable=AsyncMock, + ) as mock_get: + mock_get.return_value = mock_result + + result = cli_runner.invoke(knowledge, ["fragments", "abc123", "--json"]) + + assert result.exit_code == 0, result.output + # JSON output should be parseable + import json + + output_data = json.loads(result.output) + assert output_data["title"] == "Test Doc" + + def test_fragments_handles_error(self, cli_runner): + """Test that errors are handled gracefully.""" + with patch( + "redis_sre_agent.cli.knowledge.get_all_document_fragments", + new_callable=AsyncMock, + ) as mock_get: + mock_get.side_effect = Exception("Document not found") + + result = cli_runner.invoke(knowledge, ["fragments", "nonexistent"]) + + assert result.exit_code == 0 # CLI doesn't exit with error code + assert "Error" in result.output or "error" in result.output + + +class TestKnowledgeRelatedCLI: + """Test knowledge related CLI command.""" + + def test_related_help_shows_options(self, cli_runner): + """Test that related command shows expected options in help.""" + result = cli_runner.invoke(knowledge, ["related", "--help"]) + + assert result.exit_code == 0 + assert "DOCUMENT_HASH" in result.output + assert "--chunk-index" in result.output + assert "--window" in result.output + assert "--json" in result.output + + def test_related_requires_chunk_index(self, cli_runner): + """Test that --chunk-index is required.""" + result = cli_runner.invoke(knowledge, ["related", "abc123"]) + + assert result.exit_code != 0 + assert "chunk-index" in result.output.lower() or "required" in result.output.lower() + + def test_related_passes_parameters(self, cli_runner): + """Test that parameters are passed to helper.""" + mock_result = { + "title": "Test Doc", + "source": "test", + "category": "general", + "target_chunk_index": 5, + "context_window": 2, + "related_fragments_count": 3, + "related_fragments": [ + {"chunk_index": 4, "content": "Before", "is_target_chunk": False}, + {"chunk_index": 5, "content": "Target", "is_target_chunk": True}, + {"chunk_index": 6, "content": "After", "is_target_chunk": False}, + ], + } + + with patch( + "redis_sre_agent.cli.knowledge.get_related_document_fragments", + new_callable=AsyncMock, + ) as mock_get: + mock_get.return_value = mock_result + + result = cli_runner.invoke(knowledge, ["related", "abc123", "--chunk-index", "5"]) + + assert result.exit_code == 0, result.output + mock_get.assert_called_once_with("abc123", current_chunk_index=5, context_window=2) + + def test_related_with_custom_window(self, cli_runner): + """Test that --window parameter is passed correctly.""" + mock_result = { + "target_chunk_index": 5, + "context_window": 4, + "related_fragments_count": 0, + "related_fragments": [], + } + + with patch( + "redis_sre_agent.cli.knowledge.get_related_document_fragments", + new_callable=AsyncMock, + ) as mock_get: + mock_get.return_value = mock_result + + result = cli_runner.invoke( + knowledge, ["related", "abc123", "--chunk-index", "5", "--window", "4"] + ) + + assert result.exit_code == 0, result.output + mock_get.assert_called_once_with("abc123", current_chunk_index=5, context_window=4) + + def test_related_json_output(self, cli_runner): + """Test that --json flag outputs JSON.""" + mock_result = { + "title": "Test Doc", + "target_chunk_index": 5, + "related_fragments_count": 1, + "related_fragments": [{"chunk_index": 5, "content": "Target"}], + } + + with patch( + "redis_sre_agent.cli.knowledge.get_related_document_fragments", + new_callable=AsyncMock, + ) as mock_get: + mock_get.return_value = mock_result + + result = cli_runner.invoke( + knowledge, ["related", "abc123", "--chunk-index", "5", "--json"] + ) + + assert result.exit_code == 0, result.output + import json + + output_data = json.loads(result.output) + assert output_data["target_chunk_index"] == 5 + + def test_related_handles_error(self, cli_runner): + """Test that errors are handled gracefully.""" + with patch( + "redis_sre_agent.cli.knowledge.get_related_document_fragments", + new_callable=AsyncMock, + ) as mock_get: + mock_get.side_effect = Exception("Document not found") + + result = cli_runner.invoke(knowledge, ["related", "nonexistent", "--chunk-index", "0"]) + + assert result.exit_code == 0 # CLI doesn't exit with error code + assert "Error" in result.output or "error" in result.output diff --git a/tests/unit/cli/test_cli_mcp.py b/tests/unit/cli/test_cli_mcp.py new file mode 100644 index 00000000..0d998892 --- /dev/null +++ b/tests/unit/cli/test_cli_mcp.py @@ -0,0 +1,96 @@ +"""Unit tests for MCP CLI commands.""" + +from unittest.mock import MagicMock, patch + +import pytest +from click.testing import CliRunner + +from redis_sre_agent.cli.mcp import mcp + + +@pytest.fixture +def cli_runner(): + """Create a CLI runner for testing.""" + return CliRunner() + + +class TestMCPServeCLI: + """Tests for the mcp serve command.""" + + def test_serve_help_shows_options(self, cli_runner): + """Test that serve help shows all options.""" + result = cli_runner.invoke(mcp, ["serve", "--help"]) + + assert result.exit_code == 0 + assert "--transport" in result.output + assert "--host" in result.output + assert "--port" in result.output + assert "stdio" in result.output + assert "http" in result.output + assert "sse" in result.output + + def test_serve_default_transport_is_stdio(self, cli_runner): + """Test that default transport is stdio.""" + with patch("redis_sre_agent.mcp_server.server.run_stdio") as mock_run: + cli_runner.invoke(mcp, ["serve"]) + + # stdio mode doesn't print anything + mock_run.assert_called_once() + + def test_serve_http_mode(self, cli_runner): + """Test serve in HTTP mode.""" + with patch("redis_sre_agent.mcp_server.server.run_http") as mock_run: + result = cli_runner.invoke(mcp, ["serve", "--transport", "http"]) + + assert result.exit_code == 0 + mock_run.assert_called_once_with(host="0.0.0.0", port=8081) + assert "HTTP mode" in result.output + + def test_serve_sse_mode(self, cli_runner): + """Test serve in SSE mode.""" + with patch("redis_sre_agent.mcp_server.server.run_sse") as mock_run: + result = cli_runner.invoke(mcp, ["serve", "--transport", "sse"]) + + assert result.exit_code == 0 + mock_run.assert_called_once_with(host="0.0.0.0", port=8081) + assert "SSE mode" in result.output + + def test_serve_custom_host_and_port(self, cli_runner): + """Test serve with custom host and port.""" + with patch("redis_sre_agent.mcp_server.server.run_http") as mock_run: + result = cli_runner.invoke( + mcp, ["serve", "--transport", "http", "--host", "127.0.0.1", "--port", "9000"] + ) + + assert result.exit_code == 0 + mock_run.assert_called_once_with(host="127.0.0.1", port=9000) + + +class TestMCPListToolsCLI: + """Tests for the mcp list-tools command.""" + + def test_list_tools_help(self, cli_runner): + """Test that list-tools help is available.""" + result = cli_runner.invoke(mcp, ["list-tools", "--help"]) + + assert result.exit_code == 0 + assert "List available MCP tools" in result.output + + def test_list_tools_displays_tools(self, cli_runner): + """Test that list-tools displays available tools.""" + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.description = "A test tool for testing" + + mock_mcp_server = MagicMock() + mock_mcp_server._tool_manager._tools = {"test_tool": mock_tool} + + # Patch at the import location inside the function + with patch( + "redis_sre_agent.mcp_server.server.mcp", + mock_mcp_server, + ): + result = cli_runner.invoke(mcp, ["list-tools"]) + + assert result.exit_code == 0 + assert "Available MCP tools" in result.output diff --git a/tests/unit/cli/test_cli_query.py b/tests/unit/cli/test_cli_query.py index 78ba605d..22aebc0c 100644 --- a/tests/unit/cli/test_cli_query.py +++ b/tests/unit/cli/test_cli_query.py @@ -2,21 +2,45 @@ from unittest.mock import AsyncMock, MagicMock, patch +import pytest from click.testing import CliRunner from redis_sre_agent.cli.query import query -def test_query_cli_help_shows_instance_option(): +@pytest.fixture +def mock_thread_manager(): + """Create a mock ThreadManager that doesn't require Redis.""" + mock_tm = MagicMock() + mock_tm.create_thread = AsyncMock(return_value="test-thread-id") + mock_tm.get_thread = AsyncMock(return_value=None) + mock_tm.update_thread_subject = AsyncMock() + mock_tm.append_messages = AsyncMock() + return mock_tm + + +@pytest.fixture +def mock_redis_client(): + """Create a mock Redis client.""" + return MagicMock() + + +def test_query_cli_help_shows_options(): runner = CliRunner() result = runner.invoke(query, ["--help"]) assert result.exit_code == 0 assert "--redis-instance-id" in result.output assert "-r" in result.output + assert "--agent" in result.output + assert "-a" in result.output + assert "auto" in result.output + assert "triage" in result.output + assert "chat" in result.output + assert "knowledge" in result.output -def test_query_without_instance_uses_knowledge_agent(): +def test_query_without_instance_uses_knowledge_agent(mock_thread_manager, mock_redis_client): runner = CliRunner() mock_agent = MagicMock() @@ -31,6 +55,8 @@ def test_query_without_instance_uses_knowledge_agent(): "redis_sre_agent.cli.query.get_instance_by_id", new=AsyncMock(), ) as mock_get_instance, + patch("redis_sre_agent.cli.query.get_redis_client", return_value=mock_redis_client), + patch("redis_sre_agent.cli.query.ThreadManager", return_value=mock_thread_manager), ): result = runner.invoke(query, ["What is Redis SRE?"]) @@ -41,19 +67,27 @@ def test_query_without_instance_uses_knowledge_agent(): mock_agent.process_query.assert_awaited_once() -def test_query_with_instance_uses_sre_agent_and_passes_instance_context(): +def test_query_with_instance_uses_sre_agent_and_passes_instance_context( + mock_thread_manager, mock_redis_client +): runner = CliRunner() class DummyInstance: def __init__(self, id: str, name: str): # noqa: A003 - keep click-style arg name self.id = id self.name = name + self.instance_type = "oss_single" # Required by ChatAgent system prompt + self.connection_url = "redis://localhost:6379" + self.environment = "development" + self.usage = "cache" instance = DummyInstance("redis-prod-123", "Haink Production") mock_sre_agent = MagicMock() mock_sre_agent.process_query = AsyncMock(return_value="ok") + from redis_sre_agent.agent.router import AgentType + with ( patch( "redis_sre_agent.cli.query.get_instance_by_id", @@ -63,6 +97,12 @@ def __init__(self, id: str, name: str): # noqa: A003 - keep click-style arg nam "redis_sre_agent.cli.query.get_sre_agent", return_value=mock_sre_agent ) as mock_get_sre, patch("redis_sre_agent.cli.query.get_knowledge_agent") as mock_get_knowledge, + patch( + "redis_sre_agent.cli.query.route_to_appropriate_agent", + new=AsyncMock(return_value=AgentType.REDIS_TRIAGE), + ), + patch("redis_sre_agent.cli.query.get_redis_client", return_value=mock_redis_client), + patch("redis_sre_agent.cli.query.ThreadManager", return_value=mock_thread_manager), ): # Use -r / --redis-instance-id option to select instance result = runner.invoke( @@ -91,7 +131,9 @@ def __init__(self, id: str, name: str): # noqa: A003 - keep click-style arg nam assert kwargs.get("context") == {"instance_id": instance.id} -def test_query_with_unknown_instance_exits_with_error_and_skips_agents(): +def test_query_with_unknown_instance_exits_with_error_and_skips_agents( + mock_thread_manager, mock_redis_client +): """If -r is provided but the instance does not exist, CLI should error and exit. This directly tests the new existence-check logic in redis_sre_agent.cli.query. @@ -113,6 +155,8 @@ def test_query_with_unknown_instance_exits_with_error_and_skips_agents(): "redis_sre_agent.cli.query.get_knowledge_agent", return_value=mock_agent ) as mock_get_knowledge, patch("redis_sre_agent.cli.query.get_sre_agent") as mock_get_sre, + patch("redis_sre_agent.cli.query.get_redis_client", return_value=mock_redis_client), + patch("redis_sre_agent.cli.query.ThreadManager", return_value=mock_thread_manager), ): result = runner.invoke( query, @@ -134,3 +178,171 @@ def test_query_with_unknown_instance_exits_with_error_and_skips_agents(): mock_get_knowledge.assert_not_called() mock_get_sre.assert_not_called() mock_agent.process_query.assert_not_awaited() + + +def test_query_with_agent_triage_forces_triage_agent(mock_thread_manager, mock_redis_client): + """Test that --agent triage forces use of the triage agent.""" + runner = CliRunner() + + mock_agent = MagicMock() + mock_agent.process_query = AsyncMock(return_value="triage result") + + with ( + patch("redis_sre_agent.cli.query.get_sre_agent", return_value=mock_agent) as mock_get_sre, + patch("redis_sre_agent.cli.query.get_knowledge_agent") as mock_get_knowledge, + patch("redis_sre_agent.cli.query.get_chat_agent") as mock_get_chat, + patch("redis_sre_agent.cli.query.route_to_appropriate_agent") as mock_router, + patch("redis_sre_agent.cli.query.get_redis_client", return_value=mock_redis_client), + patch("redis_sre_agent.cli.query.ThreadManager", return_value=mock_thread_manager), + ): + result = runner.invoke(query, ["--agent", "triage", "Check my Redis health"]) + + assert result.exit_code == 0, result.output + assert "Triage (selected)" in result.output + + # Triage agent should be used + mock_get_sre.assert_called_once() + mock_get_knowledge.assert_not_called() + mock_get_chat.assert_not_called() + + # Router should NOT be called when agent is explicitly specified + mock_router.assert_not_called() + + mock_agent.process_query.assert_awaited_once() + + +def test_query_with_agent_knowledge_forces_knowledge_agent(mock_thread_manager, mock_redis_client): + """Test that --agent knowledge forces use of the knowledge agent.""" + runner = CliRunner() + + mock_agent = MagicMock() + mock_agent.process_query = AsyncMock(return_value="knowledge result") + + with ( + patch( + "redis_sre_agent.cli.query.get_knowledge_agent", return_value=mock_agent + ) as mock_get_knowledge, + patch("redis_sre_agent.cli.query.get_sre_agent") as mock_get_sre, + patch("redis_sre_agent.cli.query.get_chat_agent") as mock_get_chat, + patch("redis_sre_agent.cli.query.route_to_appropriate_agent") as mock_router, + patch("redis_sre_agent.cli.query.get_redis_client", return_value=mock_redis_client), + patch("redis_sre_agent.cli.query.ThreadManager", return_value=mock_thread_manager), + ): + result = runner.invoke(query, ["-a", "knowledge", "What is Redis replication?"]) + + assert result.exit_code == 0, result.output + assert "Knowledge (selected)" in result.output + + # Knowledge agent should be used + mock_get_knowledge.assert_called_once() + mock_get_sre.assert_not_called() + mock_get_chat.assert_not_called() + + # Router should NOT be called + mock_router.assert_not_called() + + mock_agent.process_query.assert_awaited_once() + + +def test_query_with_agent_chat_forces_chat_agent(mock_thread_manager, mock_redis_client): + """Test that --agent chat forces use of the chat agent.""" + runner = CliRunner() + + class DummyInstance: + def __init__(self): + self.id = "test-instance" + self.name = "Test Instance" + self.instance_type = "oss_single" + self.connection_url = "redis://localhost:6379" + self.environment = "development" + self.usage = "cache" + + instance = DummyInstance() + + mock_agent = MagicMock() + mock_agent.process_query = AsyncMock(return_value="chat result") + + with ( + patch("redis_sre_agent.cli.query.get_chat_agent", return_value=mock_agent) as mock_get_chat, + patch("redis_sre_agent.cli.query.get_sre_agent") as mock_get_sre, + patch("redis_sre_agent.cli.query.get_knowledge_agent") as mock_get_knowledge, + patch( + "redis_sre_agent.cli.query.get_instance_by_id", + new=AsyncMock(return_value=instance), + ), + patch("redis_sre_agent.cli.query.route_to_appropriate_agent") as mock_router, + patch("redis_sre_agent.cli.query.get_redis_client", return_value=mock_redis_client), + patch("redis_sre_agent.cli.query.ThreadManager", return_value=mock_thread_manager), + ): + result = runner.invoke(query, ["--agent", "chat", "-r", "test-instance", "Quick question"]) + + assert result.exit_code == 0, result.output + assert "Chat (selected)" in result.output + + # Chat agent should be used + mock_get_chat.assert_called_once() + mock_get_sre.assert_not_called() + mock_get_knowledge.assert_not_called() + + # Router should NOT be called + mock_router.assert_not_called() + + mock_agent.process_query.assert_awaited_once() + + +def test_query_with_agent_auto_uses_router(mock_thread_manager, mock_redis_client): + """Test that --agent auto (default) uses the router to select agent.""" + runner = CliRunner() + + from redis_sre_agent.agent.router import AgentType + + mock_agent = MagicMock() + mock_agent.process_query = AsyncMock(return_value="routed result") + + with ( + patch( + "redis_sre_agent.cli.query.get_knowledge_agent", return_value=mock_agent + ) as mock_get_knowledge, + patch("redis_sre_agent.cli.query.get_sre_agent") as mock_get_sre, + patch( + "redis_sre_agent.cli.query.route_to_appropriate_agent", + new=AsyncMock(return_value=AgentType.KNOWLEDGE_ONLY), + ) as mock_router, + patch("redis_sre_agent.cli.query.get_redis_client", return_value=mock_redis_client), + patch("redis_sre_agent.cli.query.ThreadManager", return_value=mock_thread_manager), + ): + # Default is auto, so router should be called + result = runner.invoke(query, ["What is Redis?"]) + + assert result.exit_code == 0, result.output + # Should show "Knowledge" without "(selected)" since it was auto-routed + assert "Agent: Knowledge" in result.output + assert "(selected)" not in result.output + + # Router should be called + mock_router.assert_awaited_once() + + mock_get_knowledge.assert_called_once() + mock_get_sre.assert_not_called() + + +def test_query_agent_option_is_case_insensitive(mock_thread_manager, mock_redis_client): + """Test that --agent option accepts different cases.""" + runner = CliRunner() + + mock_agent = MagicMock() + mock_agent.process_query = AsyncMock(return_value="result") + + with ( + patch("redis_sre_agent.cli.query.get_knowledge_agent", return_value=mock_agent), + patch("redis_sre_agent.cli.query.route_to_appropriate_agent"), + patch("redis_sre_agent.cli.query.get_redis_client", return_value=mock_redis_client), + patch("redis_sre_agent.cli.query.ThreadManager", return_value=mock_thread_manager), + ): + # Test uppercase + result = runner.invoke(query, ["--agent", "KNOWLEDGE", "test query"]) + assert result.exit_code == 0, result.output + + # Test mixed case + result = runner.invoke(query, ["--agent", "Knowledge", "test query"]) + assert result.exit_code == 0, result.output diff --git a/tests/unit/cli/test_cli_runbook.py b/tests/unit/cli/test_cli_runbook.py new file mode 100644 index 00000000..88b9807e --- /dev/null +++ b/tests/unit/cli/test_cli_runbook.py @@ -0,0 +1,60 @@ +"""Unit tests for runbook CLI commands.""" + +import pytest +from click.testing import CliRunner + +from redis_sre_agent.cli.runbook import runbook + + +@pytest.fixture +def cli_runner(): + """Create a CLI runner for testing.""" + return CliRunner() + + +class TestRunbookGenerateCLI: + """Tests for the runbook generate command.""" + + def test_generate_help_shows_options(self, cli_runner): + """Test that generate help shows all options.""" + result = cli_runner.invoke(runbook, ["generate", "--help"]) + + assert result.exit_code == 0 + assert "--severity" in result.output or "-s" in result.output + assert "--category" in result.output or "-c" in result.output + assert "--output-file" in result.output or "-o" in result.output + assert "--requirements" in result.output or "-r" in result.output + assert "--max-iterations" in result.output + assert "--auto-save" in result.output + assert "critical" in result.output + assert "warning" in result.output + assert "info" in result.output + + def test_generate_requires_topic_and_description(self, cli_runner): + """Test that generate requires topic and scenario_description.""" + result = cli_runner.invoke(runbook, ["generate"]) + + assert result.exit_code != 0 + assert "Missing argument" in result.output or "Usage:" in result.output + + +class TestRunbookEvaluateCLI: + """Tests for the runbook evaluate command.""" + + def test_evaluate_help_shows_options(self, cli_runner): + """Test that evaluate help shows all options.""" + result = cli_runner.invoke(runbook, ["evaluate", "--help"]) + + assert result.exit_code == 0 + assert "--input-dir" in result.output or "-i" in result.output + assert "--output-file" in result.output or "-o" in result.output + # Default value may not be shown in help, just check the option exists + assert "Directory containing runbook" in result.output + + def test_evaluate_with_nonexistent_dir(self, cli_runner): + """Test evaluate with non-existent directory.""" + result = cli_runner.invoke(runbook, ["evaluate", "--input-dir", "/nonexistent/path"]) + + assert result.exit_code != 0 + # Click should report the path doesn't exist + assert "does not exist" in result.output or "Error" in result.output diff --git a/tests/unit/cli/test_cli_schedules.py b/tests/unit/cli/test_cli_schedules.py new file mode 100644 index 00000000..a77ae145 --- /dev/null +++ b/tests/unit/cli/test_cli_schedules.py @@ -0,0 +1,154 @@ +"""Unit tests for schedules CLI commands.""" + +from unittest.mock import AsyncMock, patch + +import pytest +from click.testing import CliRunner + +from redis_sre_agent.cli.schedules import schedule + + +@pytest.fixture +def cli_runner(): + """Create a CLI runner for testing.""" + return CliRunner() + + +class TestScheduleListCLI: + """Tests for the schedule list command.""" + + def test_list_help_shows_options(self, cli_runner): + """Test that list help shows all options.""" + result = cli_runner.invoke(schedule, ["list", "--help"]) + + assert result.exit_code == 0 + assert "--json" in result.output + assert "--tz" in result.output + assert "--limit" in result.output or "-l" in result.output + + def test_list_displays_schedules(self, cli_runner): + """Test that list displays schedules.""" + mock_schedules = [ + { + "id": "sched-1", + "name": "Test Schedule", + "enabled": True, + "interval_type": "hours", + "interval_value": 1, + "next_run": "2024-01-01T00:00:00Z", + "last_run": "2023-12-31T23:00:00Z", + } + ] + + with patch( + "redis_sre_agent.core.schedules.list_schedules", + new_callable=AsyncMock, + return_value=mock_schedules, + ): + result = cli_runner.invoke(schedule, ["list"]) + + assert result.exit_code == 0 + # Should show table with schedules + assert "Test Schedule" in result.output or "Schedules" in result.output + + def test_list_json_output(self, cli_runner): + """Test that --json flag outputs JSON.""" + mock_schedules = [ + { + "id": "sched-1", + "name": "Test Schedule", + "enabled": True, + "interval_type": "hours", + "interval_value": 1, + } + ] + + with patch( + "redis_sre_agent.core.schedules.list_schedules", + new_callable=AsyncMock, + return_value=mock_schedules, + ): + result = cli_runner.invoke(schedule, ["list", "--json"]) + + assert result.exit_code == 0 + import json + + output_data = json.loads(result.output) + assert isinstance(output_data, list) + assert len(output_data) == 1 + assert output_data[0]["name"] == "Test Schedule" + + def test_list_empty_schedules(self, cli_runner): + """Test list with no schedules.""" + with patch( + "redis_sre_agent.core.schedules.list_schedules", + new_callable=AsyncMock, + return_value=[], + ): + result = cli_runner.invoke(schedule, ["list"]) + + assert result.exit_code == 0 + assert "No schedules found" in result.output + + +class TestScheduleGetCLI: + """Tests for the schedule get command.""" + + def test_get_help_shows_options(self, cli_runner): + """Test that get help shows options.""" + result = cli_runner.invoke(schedule, ["get", "--help"]) + + assert result.exit_code == 0 + assert "SCHEDULE_ID" in result.output or "schedule_id" in result.output.lower() + + +class TestScheduleCreateCLI: + """Tests for the schedule create command.""" + + def test_create_help_shows_options(self, cli_runner): + """Test that create help shows all options.""" + result = cli_runner.invoke(schedule, ["create", "--help"]) + + assert result.exit_code == 0 + assert "--name" in result.output + assert "--instance" in result.output or "instance" in result.output.lower() + + +class TestScheduleEnableDisableCLI: + """Tests for schedule enable/disable commands.""" + + def test_enable_help(self, cli_runner): + """Test that enable help is available.""" + result = cli_runner.invoke(schedule, ["enable", "--help"]) + + assert result.exit_code == 0 + assert "SCHEDULE_ID" in result.output or "schedule_id" in result.output.lower() + + def test_disable_help(self, cli_runner): + """Test that disable help is available.""" + result = cli_runner.invoke(schedule, ["disable", "--help"]) + + assert result.exit_code == 0 + assert "SCHEDULE_ID" in result.output or "schedule_id" in result.output.lower() + + +class TestScheduleDeleteCLI: + """Tests for the schedule delete command.""" + + def test_delete_help(self, cli_runner): + """Test that delete help is available.""" + result = cli_runner.invoke(schedule, ["delete", "--help"]) + + assert result.exit_code == 0 + assert "SCHEDULE_ID" in result.output or "schedule_id" in result.output.lower() + + +class TestScheduleRunNowCLI: + """Tests for the schedule run-now command.""" + + def test_run_now_help(self, cli_runner): + """Test that run-now help is available.""" + result = cli_runner.invoke(schedule, ["run-now", "--help"]) + + assert result.exit_code == 0 + assert "SCHEDULE_ID" in result.output or "schedule_id" in result.output.lower() diff --git a/tests/unit/cli/test_cli_tasks.py b/tests/unit/cli/test_cli_tasks.py new file mode 100644 index 00000000..bc8ab483 --- /dev/null +++ b/tests/unit/cli/test_cli_tasks.py @@ -0,0 +1,129 @@ +"""Unit tests for tasks CLI commands.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from click.testing import CliRunner + +from redis_sre_agent.cli.tasks import task + + +@pytest.fixture +def cli_runner(): + """Create a CLI runner for testing.""" + return CliRunner() + + +class TestTaskListCLI: + """Tests for the task list command.""" + + def test_list_help_shows_options(self, cli_runner): + """Test that list help shows all options.""" + result = cli_runner.invoke(task, ["list", "--help"]) + + assert result.exit_code == 0 + assert "--user-id" in result.output + assert "--status" in result.output + assert "--all" in result.output + assert "--limit" in result.output or "-l" in result.output + assert "--tz" in result.output + # Status choices + assert "queued" in result.output + assert "in_progress" in result.output + assert "done" in result.output + assert "failed" in result.output + assert "cancelled" in result.output + + def test_list_displays_tasks(self, cli_runner): + """Test that list displays tasks.""" + mock_tasks = [ + { + "task_id": "task-1", + "status": "in_progress", + "created_at": "2024-01-01T00:00:00Z", + "user_id": "user-1", + } + ] + + mock_redis = MagicMock() + mock_redis.get = AsyncMock(return_value=None) + + with ( + patch( + "redis_sre_agent.core.tasks.list_tasks", + new_callable=AsyncMock, + return_value=mock_tasks, + ), + patch( + "redis_sre_agent.core.redis.get_redis_client", + return_value=mock_redis, + ), + ): + result = cli_runner.invoke(task, ["list"]) + + assert result.exit_code == 0 + + def test_list_empty_tasks(self, cli_runner): + """Test list with no tasks.""" + with patch( + "redis_sre_agent.core.tasks.list_tasks", + new_callable=AsyncMock, + return_value=[], + ): + result = cli_runner.invoke(task, ["list"]) + + assert result.exit_code == 0 + assert "No tasks found" in result.output + + def test_list_with_status_filter(self, cli_runner): + """Test list with status filter.""" + mock_tasks = [ + { + "task_id": "task-1", + "status": "done", + "created_at": "2024-01-01T00:00:00Z", + } + ] + + mock_redis = MagicMock() + mock_redis.get = AsyncMock(return_value=None) + + with ( + patch( + "redis_sre_agent.core.tasks.list_tasks", + new_callable=AsyncMock, + return_value=mock_tasks, + ) as mock_list, + patch( + "redis_sre_agent.core.redis.get_redis_client", + return_value=mock_redis, + ), + ): + result = cli_runner.invoke(task, ["list", "--status", "done"]) + + assert result.exit_code == 0 + # Verify the status filter was passed + mock_list.assert_called_once() + + +class TestTaskGetCLI: + """Tests for the task get command.""" + + def test_get_help_shows_options(self, cli_runner): + """Test that get help shows options.""" + result = cli_runner.invoke(task, ["get", "--help"]) + + assert result.exit_code == 0 + assert "TASK_ID" in result.output or "task_id" in result.output.lower() + + +class TestTaskPurgeCLI: + """Tests for the task purge command.""" + + def test_purge_help_shows_options(self, cli_runner): + """Test that purge help shows options.""" + result = cli_runner.invoke(task, ["purge", "--help"]) + + assert result.exit_code == 0 + # Should have options for purging + assert "--" in result.output or "purge" in result.output.lower() diff --git a/tests/unit/cli/test_cli_thread_sources.py b/tests/unit/cli/test_cli_thread_sources.py index 64efb852..546a2959 100644 --- a/tests/unit/cli/test_cli_thread_sources.py +++ b/tests/unit/cli/test_cli_thread_sources.py @@ -4,19 +4,30 @@ from click.testing import CliRunner from redis_sre_agent.cli.main import main as cli_main +from redis_sre_agent.core.tasks import TaskState, TaskStatus, TaskUpdate from redis_sre_agent.core.threads import ( Thread, ThreadMetadata, - ThreadUpdate, ) -def _make_state_with_sources(thread_id: str = "thread-1") -> Thread: - update = ThreadUpdate( +def _make_thread(thread_id: str = "thread-1") -> Thread: + """Create a minimal thread (updates are now on TaskState, not Thread).""" + return Thread( + thread_id=thread_id, + messages=[], + context={}, + metadata=ThreadMetadata(), + ) + + +def _make_task_with_sources(task_id: str = "task-abc", thread_id: str = "thread-1") -> TaskState: + """Create a task with knowledge_sources updates.""" + update = TaskUpdate( message="Found 1 knowledge fragments", update_type="knowledge_sources", metadata={ - "task_id": "task-abc", + "task_id": task_id, "fragments": [ { "id": "frag-1", @@ -28,25 +39,39 @@ def _make_state_with_sources(thread_id: str = "thread-1") -> Thread: ], }, ) - return Thread( + return TaskState( + task_id=task_id, thread_id=thread_id, + status=TaskStatus.DONE, updates=[update], - context={}, - metadata=ThreadMetadata(), - result=None, - error_message=None, ) def test_thread_sources_cli_json_output(monkeypatch): runner = CliRunner() - async def fake_get_thread_state(_self, thread_id: str): # noqa: ARG001 - return _make_state_with_sources(thread_id) + async def fake_get_thread(_self, thread_id: str): # noqa: ARG001 + return _make_thread(thread_id) + + async def fake_get_task_state(_self, task_id: str): # noqa: ARG001 + return _make_task_with_sources(task_id) - with patch( - "redis_sre_agent.core.threads.ThreadManager.get_thread", - new=fake_get_thread_state, + async def fake_zrange(_self, _key, _start, _end): + return [b"task-abc"] + + with ( + patch( + "redis_sre_agent.core.threads.ThreadManager.get_thread", + new=fake_get_thread, + ), + patch( + "redis_sre_agent.core.tasks.TaskManager.get_task_state", + new=fake_get_task_state, + ), + patch( + "redis.asyncio.Redis.zrange", + new=fake_zrange, + ), ): result = runner.invoke(cli_main, ["thread", "sources", "thread-1", "--json"]) @@ -66,12 +91,28 @@ async def fake_get_thread_state(_self, thread_id: str): # noqa: ARG001 def test_thread_sources_cli_human_output(monkeypatch): runner = CliRunner() - async def fake_get_thread_state(_self, thread_id: str): # noqa: ARG001 - return _make_state_with_sources(thread_id) + async def fake_get_thread(_self, thread_id: str): # noqa: ARG001 + return _make_thread(thread_id) + + async def fake_get_task_state(_self, task_id: str): # noqa: ARG001 + return _make_task_with_sources(task_id) + + async def fake_zrange(_self, _key, _start, _end): + return [b"task-abc"] - with patch( - "redis_sre_agent.core.threads.ThreadManager.get_thread", - new=fake_get_thread_state, + with ( + patch( + "redis_sre_agent.core.threads.ThreadManager.get_thread", + new=fake_get_thread, + ), + patch( + "redis_sre_agent.core.tasks.TaskManager.get_task_state", + new=fake_get_task_state, + ), + patch( + "redis.asyncio.Redis.zrange", + new=fake_zrange, + ), ): result = runner.invoke(cli_main, ["thread", "sources", "thread-1"]) # table output diff --git a/tests/unit/cli/test_cli_worker.py b/tests/unit/cli/test_cli_worker.py new file mode 100644 index 00000000..70dfc17d --- /dev/null +++ b/tests/unit/cli/test_cli_worker.py @@ -0,0 +1,49 @@ +"""Unit tests for worker CLI command.""" + +from unittest.mock import MagicMock, patch + +import pytest +from click.testing import CliRunner + +from redis_sre_agent.cli.worker import worker + + +@pytest.fixture +def cli_runner(): + """Create a CLI runner for testing.""" + return CliRunner() + + +class TestWorkerCLI: + """Tests for the worker command.""" + + def test_worker_help_shows_options(self, cli_runner): + """Test that worker help shows all options.""" + result = cli_runner.invoke(worker, ["--help"]) + + assert result.exit_code == 0 + assert "--concurrency" in result.output or "-c" in result.output + assert "Number of concurrent tasks" in result.output + + def test_worker_concurrency_option_exists(self, cli_runner): + """Test that concurrency option exists.""" + result = cli_runner.invoke(worker, ["--help"]) + + assert result.exit_code == 0 + # Verify the option exists + assert "--concurrency" in result.output or "-c" in result.output + assert "INTEGER" in result.output + + def test_worker_requires_redis_url(self, cli_runner): + """Test that worker requires Redis URL.""" + mock_settings = MagicMock() + mock_settings.redis_url = None + + with patch( + "redis_sre_agent.cli.worker.settings", + mock_settings, + ): + result = cli_runner.invoke(worker) + + # Should fail without Redis URL + assert result.exit_code != 0 or "Redis URL not configured" in result.output diff --git a/tests/unit/core/test_config.py b/tests/unit/core/test_config.py index c0c9540d..de14b855 100644 --- a/tests/unit/core/test_config.py +++ b/tests/unit/core/test_config.py @@ -1,10 +1,12 @@ """Unit tests for configuration management.""" import os +import tempfile from typing import Optional from unittest.mock import patch import pytest +import yaml # Import Settings in tests with mocked environment @@ -293,6 +295,135 @@ def test_extra_ignore_behavior(self): assert not hasattr(settings, "unknown_field") +class TestMCPConfiguration: + """Test MCP server configuration models.""" + + def test_mcp_tool_config_defaults(self): + """Test MCPToolConfig default values.""" + from redis_sre_agent.core.config import MCPToolConfig + + config = MCPToolConfig() + assert config.capability is None + assert config.description is None + + def test_mcp_tool_config_with_capability(self): + """Test MCPToolConfig with capability set.""" + from redis_sre_agent.core.config import MCPToolConfig + from redis_sre_agent.tools.models import ToolCapability + + config = MCPToolConfig(capability=ToolCapability.LOGS) + assert config.capability == ToolCapability.LOGS + assert config.description is None + + def test_mcp_tool_config_with_description(self): + """Test MCPToolConfig with description override.""" + from redis_sre_agent.core.config import MCPToolConfig + + config = MCPToolConfig(description="Use this tool when searching for memories...") + assert config.capability is None + assert config.description == "Use this tool when searching for memories..." + + def test_mcp_tool_config_with_both(self): + """Test MCPToolConfig with both capability and description.""" + from redis_sre_agent.core.config import MCPToolConfig + from redis_sre_agent.tools.models import ToolCapability + + config = MCPToolConfig( + capability=ToolCapability.METRICS, + description="Custom description for the tool", + ) + assert config.capability == ToolCapability.METRICS + assert config.description == "Custom description for the tool" + + def test_mcp_server_config_command_based(self): + """Test MCPServerConfig for command-based (stdio) transport.""" + from redis_sre_agent.core.config import MCPServerConfig + + config = MCPServerConfig( + command="npx", + args=["-y", "@modelcontextprotocol/server-memory"], + env={"DEBUG": "true"}, + ) + assert config.command == "npx" + assert config.args == ["-y", "@modelcontextprotocol/server-memory"] + assert config.env == {"DEBUG": "true"} + assert config.url is None + assert config.tools is None + + def test_mcp_server_config_url_based(self): + """Test MCPServerConfig for URL-based transport.""" + from redis_sre_agent.core.config import MCPServerConfig + + config = MCPServerConfig(url="http://localhost:3000/mcp") + assert config.command is None + assert config.args is None + assert config.url == "http://localhost:3000/mcp" + # Default transport should be None (provider defaults to streamable_http) + assert config.transport is None + + def test_mcp_server_config_url_with_transport(self): + """Test MCPServerConfig with explicit transport type.""" + from redis_sre_agent.core.config import MCPServerConfig + + # Test with streamable_http transport (for GitHub remote MCP) + config = MCPServerConfig( + url="https://api.githubcopilot.com/mcp/", + headers={"Authorization": "Bearer test-token"}, + transport="streamable_http", + ) + assert config.url == "https://api.githubcopilot.com/mcp/" + assert config.headers == {"Authorization": "Bearer test-token"} + assert config.transport == "streamable_http" + + # Test with legacy SSE transport + config_sse = MCPServerConfig( + url="http://localhost:3000/mcp", + transport="sse", + ) + assert config_sse.transport == "sse" + + def test_mcp_server_config_with_tool_constraints(self): + """Test MCPServerConfig with tool constraints.""" + from redis_sre_agent.core.config import MCPServerConfig, MCPToolConfig + from redis_sre_agent.tools.models import ToolCapability + + config = MCPServerConfig( + command="npx", + args=["-y", "@modelcontextprotocol/server-memory"], + tools={ + "search_memories": MCPToolConfig(capability=ToolCapability.LOGS), + "create_memories": MCPToolConfig(description="Use this tool when..."), + }, + ) + assert config.tools is not None + assert len(config.tools) == 2 + assert config.tools["search_memories"].capability == ToolCapability.LOGS + assert config.tools["create_memories"].description == "Use this tool when..." + + def test_settings_mcp_servers_default_empty(self): + """Test that mcp_servers defaults to empty dict.""" + from redis_sre_agent.core.config import settings + + # Default should be an empty dict + assert isinstance(settings.mcp_servers, dict) + + def test_mcp_server_config_from_dict(self): + """Test that MCPServerConfig can be created from a dict (for env var parsing).""" + from redis_sre_agent.core.config import MCPServerConfig + + config_dict = { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-memory"], + "tools": { + "search_memories": {"capability": "logs"}, + }, + } + config = MCPServerConfig.model_validate(config_dict) + assert config.command == "npx" + assert config.args == ["-y", "@modelcontextprotocol/server-memory"] + # Note: tools with string capability will need special handling in the provider + + class TestSettingsValidation: """Test settings validation logic.""" @@ -344,3 +475,260 @@ def test_positive_integer_fields(self): assert settings.task_timeout == 1 assert settings.max_iterations == 1 assert settings.tool_timeout == 1 + + +class TestYamlConfigLoading: + """Test YAML configuration file loading.""" + + def test_yaml_config_loads_mcp_servers(self): + """Test that MCP servers can be loaded from YAML config.""" + yaml_content = { + "mcp_servers": { + "test-server": { + "command": "echo", + "args": ["hello"], + }, + "github": { + "command": "docker", + "args": ["run", "-i", "ghcr.io/github/github-mcp-server"], + "env": {"GITHUB_TOKEN": "test-token"}, + }, + } + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(yaml_content, f) + config_path = f.name + + try: + with patch.dict( + os.environ, + {"SRE_AGENT_CONFIG": config_path, "OPENAI_API_KEY": "test-key"}, + clear=True, + ): + from redis_sre_agent.core.config import MCPServerConfig, Settings + + settings = Settings() + + assert "test-server" in settings.mcp_servers + assert "github" in settings.mcp_servers + # Values may be MCPServerConfig objects or dicts depending on validation + test_server = settings.mcp_servers["test-server"] + if isinstance(test_server, MCPServerConfig): + assert test_server.command == "echo" + else: + assert test_server["command"] == "echo" + + github_server = settings.mcp_servers["github"] + if isinstance(github_server, MCPServerConfig): + assert github_server.env["GITHUB_TOKEN"] == "test-token" + else: + assert github_server["env"]["GITHUB_TOKEN"] == "test-token" + finally: + os.unlink(config_path) + + def test_yaml_config_with_tool_descriptions(self): + """Test that tool descriptions in YAML are properly loaded.""" + yaml_content = { + "mcp_servers": { + "memory-server": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-memory"], + "tools": { + "search_memories": { + "description": "Search for memories about Redis instances.", + }, + }, + } + } + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(yaml_content, f) + config_path = f.name + + try: + with patch.dict( + os.environ, + {"SRE_AGENT_CONFIG": config_path, "OPENAI_API_KEY": "test-key"}, + clear=True, + ): + from redis_sre_agent.core.config import MCPServerConfig, Settings + + settings = Settings() + + assert "memory-server" in settings.mcp_servers + server = settings.mcp_servers["memory-server"] + if isinstance(server, MCPServerConfig): + assert server.tools is not None + assert "search_memories" in server.tools + else: + tools = server["tools"] + assert "search_memories" in tools + finally: + os.unlink(config_path) + + def test_env_vars_override_yaml_config(self): + """Test that environment variables take precedence over YAML config.""" + yaml_content = { + "debug": False, + "log_level": "WARNING", + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(yaml_content, f) + config_path = f.name + + try: + with patch.dict( + os.environ, + { + "SRE_AGENT_CONFIG": config_path, + "OPENAI_API_KEY": "test-key", + "DEBUG": "true", # Override YAML value + "LOG_LEVEL": "DEBUG", # Override YAML value + }, + clear=True, + ): + from redis_sre_agent.core.config import Settings + + settings = Settings() + + # Env vars should win + assert settings.debug is True + assert settings.log_level == "DEBUG" + finally: + os.unlink(config_path) + + def test_yaml_config_source_class(self): + """Test YamlConfigSettingsSource directly with pydantic-settings built-in source.""" + from pydantic_settings import YamlConfigSettingsSource + + from redis_sre_agent.core.config import Settings + + yaml_content = { + "debug": True, + "log_level": "DEBUG", + "mcp_servers": {"test": {"command": "echo"}}, + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(yaml_content, f) + config_path = f.name + + try: + # Use the built-in YamlConfigSettingsSource with explicit yaml_file + source = YamlConfigSettingsSource(Settings, yaml_file=config_path) + data = source() + + assert data["debug"] is True + assert data["log_level"] == "DEBUG" + assert "mcp_servers" in data + finally: + os.unlink(config_path) + + def test_default_config_paths_are_checked(self): + """Test that default config paths are checked when SRE_AGENT_CONFIG is not set.""" + from redis_sre_agent.core.config import DEFAULT_CONFIG_PATHS + + # Verify the default paths exist in the module + assert "config.yaml" in DEFAULT_CONFIG_PATHS + assert "config.yml" in DEFAULT_CONFIG_PATHS + assert "sre_agent_config.yaml" in DEFAULT_CONFIG_PATHS + + def test_yaml_with_simple_settings(self): + """Test loading simple settings from YAML. + + Note: We test app_name and debug which don't have values in the + workspace's config.yaml or .env files. + """ + yaml_content = { + "app_name": "test-app-from-yaml", + "debug": True, + "recursion_limit": 200, # Use a field not in .env + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(yaml_content, f) + config_path = f.name + + try: + with patch.dict( + os.environ, + {"SRE_AGENT_CONFIG": config_path, "OPENAI_API_KEY": "test-key"}, + clear=True, + ): + from redis_sre_agent.core.config import Settings + + settings = Settings() + + # These values should come from our test YAML + assert settings.app_name == "test-app-from-yaml" + assert settings.debug is True + assert settings.recursion_limit == 200 # Should override default of 100 + finally: + os.unlink(config_path) + + def test_yaml_with_list_settings(self): + """Test loading list settings from YAML. + + Note: We use tool_providers which can be overridden from YAML, + but allowed_hosts may be set in workspace's .env file. + """ + yaml_content = { + "tool_providers": [ + "custom.provider.MyProvider", + "another.provider.AnotherProvider", + ], + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(yaml_content, f) + config_path = f.name + + try: + with patch.dict( + os.environ, + {"SRE_AGENT_CONFIG": config_path, "OPENAI_API_KEY": "test-key"}, + clear=True, + ): + from redis_sre_agent.core.config import Settings + + settings = Settings() + + # Tool providers should be exactly what we specified in YAML + assert len(settings.tool_providers) == 2 + assert "custom.provider.MyProvider" in settings.tool_providers + assert "another.provider.AnotherProvider" in settings.tool_providers + finally: + os.unlink(config_path) + + def test_yaml_source_returns_empty_for_missing_config(self): + """Test that YamlConfigSettingsSource returns empty dict for missing config.""" + from redis_sre_agent.core.config import Settings, YamlConfigSettingsSource + + with patch.dict(os.environ, {"SRE_AGENT_CONFIG": "/nonexistent/config.yaml"}, clear=True): + source = YamlConfigSettingsSource(Settings) + data = source() + + # Should return empty dict, not error + assert data == {} + + def test_yaml_source_returns_empty_for_invalid_yaml(self): + """Test that YamlConfigSettingsSource returns empty dict for invalid YAML.""" + from redis_sre_agent.core.config import Settings, YamlConfigSettingsSource + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + # Write invalid YAML + f.write("invalid: yaml: content: [[[") + config_path = f.name + + try: + with patch.dict(os.environ, {"SRE_AGENT_CONFIG": config_path}, clear=True): + source = YamlConfigSettingsSource(Settings) + data = source() + + # Should return empty dict, not error + assert data == {} + finally: + os.unlink(config_path) diff --git a/tests/unit/core/test_progress.py b/tests/unit/core/test_progress.py new file mode 100644 index 00000000..42ba05d3 --- /dev/null +++ b/tests/unit/core/test_progress.py @@ -0,0 +1,288 @@ +"""Unit tests for the progress emission system.""" + +import asyncio +import logging +from io import StringIO +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from redis_sre_agent.core.progress import ( + CallbackEmitter, + CLIEmitter, + CompositeEmitter, + LocalProgressCounter, + LoggingEmitter, + NullEmitter, + ProgressEmitter, + TaskEmitter, + create_emitter, +) + + +class TestProgressEmitterProtocol: + """Test the ProgressEmitter protocol.""" + + def test_null_emitter_is_progress_emitter(self): + """NullEmitter should satisfy the ProgressEmitter protocol.""" + emitter = NullEmitter() + assert isinstance(emitter, ProgressEmitter) + + def test_logging_emitter_is_progress_emitter(self): + """LoggingEmitter should satisfy the ProgressEmitter protocol.""" + emitter = LoggingEmitter() + assert isinstance(emitter, ProgressEmitter) + + def test_cli_emitter_is_progress_emitter(self): + """CLIEmitter should satisfy the ProgressEmitter protocol.""" + emitter = CLIEmitter() + assert isinstance(emitter, ProgressEmitter) + + +class TestLocalProgressCounter: + """Test the LocalProgressCounter.""" + + @pytest.mark.asyncio + async def test_counter_starts_at_one(self): + """Counter should start at 1.""" + counter = LocalProgressCounter() + value = await counter.next() + assert value == 1 + + @pytest.mark.asyncio + async def test_counter_increments(self): + """Counter should increment on each call.""" + counter = LocalProgressCounter() + assert await counter.next() == 1 + assert await counter.next() == 2 + assert await counter.next() == 3 + + @pytest.mark.asyncio + async def test_counter_thread_safety(self): + """Counter should be thread-safe with asyncio.Lock.""" + counter = LocalProgressCounter() + results = [] + + async def increment(): + for _ in range(10): + results.append(await counter.next()) + + # Run multiple concurrent incrementers + await asyncio.gather(increment(), increment(), increment()) + + # Should have 30 unique, sequential values + assert len(results) == 30 + assert sorted(results) == list(range(1, 31)) + + +class TestNullEmitter: + """Test the NullEmitter.""" + + @pytest.mark.asyncio + async def test_emit_does_nothing(self): + """NullEmitter.emit should not raise and do nothing.""" + emitter = NullEmitter() + # Should not raise + await emitter.emit("test message", "progress", {"key": "value"}) + + +class TestLoggingEmitter: + """Test the LoggingEmitter.""" + + @pytest.mark.asyncio + async def test_emit_logs_message(self, caplog): + """LoggingEmitter should log messages.""" + emitter = LoggingEmitter(level=logging.INFO) + + with caplog.at_level(logging.INFO): + await emitter.emit("Test message", "tool_call") + + assert "[tool_call] Test message" in caplog.text + + +class TestCLIEmitter: + """Test the CLIEmitter.""" + + @pytest.mark.asyncio + async def test_emit_prints_to_file(self): + """CLIEmitter should print to the specified file.""" + output = StringIO() + emitter = CLIEmitter(use_colors=False, file=output) + + await emitter.emit("Test message", "progress") + + output.seek(0) + result = output.read() + assert "Test message" in result + + @pytest.mark.asyncio + async def test_emit_with_different_types(self): + """CLIEmitter should use different symbols for different types.""" + output = StringIO() + emitter = CLIEmitter(use_colors=False, file=output) + + await emitter.emit("Starting", "agent_start") + await emitter.emit("Tool", "tool_call") + await emitter.emit("Done", "agent_complete") + + output.seek(0) + result = output.read() + assert "🚀" in result # agent_start + assert "🔧" in result # tool_call + assert "✅" in result # agent_complete + + def test_colorize_disabled(self): + """Colors should be disabled when use_colors=False.""" + output = StringIO() + emitter = CLIEmitter(use_colors=False, file=output) + + result = emitter._colorize("test", "blue") + assert result == "test" # No ANSI codes + + +class TestTaskEmitter: + """Test the TaskEmitter.""" + + @pytest.mark.asyncio + async def test_emit_calls_task_manager(self): + """TaskEmitter should call task_manager.add_task_update.""" + mock_task_manager = MagicMock() + mock_task_manager.add_task_update = AsyncMock() + + emitter = TaskEmitter(task_manager=mock_task_manager, task_id="task-123") + + await emitter.emit("Progress update", "progress", {"key": "value"}) + + mock_task_manager.add_task_update.assert_called_once_with( + "task-123", "Progress update", "progress", {"key": "value"} + ) + + def test_task_id_property(self): + """TaskEmitter should expose task_id property.""" + mock_task_manager = MagicMock() + emitter = TaskEmitter(task_manager=mock_task_manager, task_id="task-456") + + assert emitter.task_id == "task-456" + + @pytest.mark.asyncio + async def test_emit_handles_errors_gracefully(self): + """TaskEmitter should not raise if task_manager fails.""" + mock_task_manager = MagicMock() + mock_task_manager.add_task_update = AsyncMock(side_effect=Exception("Redis error")) + + emitter = TaskEmitter(task_manager=mock_task_manager, task_id="task-123") + + # Should not raise + await emitter.emit("Progress update", "progress") + + +class TestCompositeEmitter: + """Test the CompositeEmitter.""" + + @pytest.mark.asyncio + async def test_emit_calls_all_emitters(self): + """CompositeEmitter should call emit on all child emitters.""" + emitter1 = MagicMock() + emitter1.emit = AsyncMock() + emitter2 = MagicMock() + emitter2.emit = AsyncMock() + + composite = CompositeEmitter([emitter1, emitter2]) + + await composite.emit("Test message", "progress", {"key": "value"}) + + emitter1.emit.assert_called_once_with("Test message", "progress", {"key": "value"}) + emitter2.emit.assert_called_once_with("Test message", "progress", {"key": "value"}) + + @pytest.mark.asyncio + async def test_emit_continues_on_error(self): + """CompositeEmitter should continue even if one emitter fails.""" + emitter1 = MagicMock() + emitter1.emit = AsyncMock(side_effect=Exception("Failed")) + emitter2 = MagicMock() + emitter2.emit = AsyncMock() + + composite = CompositeEmitter([emitter1, emitter2]) + + # Should not raise + await composite.emit("Test message", "progress") + + # Second emitter should still be called + emitter2.emit.assert_called_once() + + +class TestCallbackEmitter: + """Test the CallbackEmitter for backward compatibility.""" + + @pytest.mark.asyncio + async def test_emit_calls_callback(self): + """CallbackEmitter should forward to callback.""" + callback = AsyncMock() + emitter = CallbackEmitter(callback) + + await emitter.emit("Test message", "progress", {"key": "value"}) + + callback.assert_called_once_with("Test message", "progress", {"key": "value"}) + + @pytest.mark.asyncio + async def test_emit_handles_callback_without_metadata(self): + """CallbackEmitter should handle callbacks that don't accept metadata.""" + + async def simple_callback(msg, update_type): + pass + + emitter = CallbackEmitter(simple_callback) + + # Should not raise (falls back to 2-arg call) + await emitter.emit("Test message", "progress", {"key": "value"}) + + @pytest.mark.asyncio + async def test_emit_with_none_callback(self): + """CallbackEmitter should handle None callback gracefully.""" + emitter = CallbackEmitter(None) + + # Should not raise + await emitter.emit("Test message", "progress") + + +class TestCreateEmitterFactory: + """Test the create_emitter factory function.""" + + def test_returns_null_emitter_when_no_args(self): + """create_emitter with no args should return NullEmitter.""" + emitter = create_emitter() + assert isinstance(emitter, NullEmitter) + + def test_returns_cli_emitter_when_cli_true(self): + """create_emitter with cli=True should return CLIEmitter.""" + emitter = create_emitter(cli=True) + assert isinstance(emitter, CLIEmitter) + + def test_returns_task_emitter_when_task_args(self): + """create_emitter with task args should return TaskEmitter.""" + mock_task_manager = MagicMock() + emitter = create_emitter(task_id="task-123", task_manager=mock_task_manager) + assert isinstance(emitter, TaskEmitter) + + def test_returns_composite_when_multiple(self): + """create_emitter with multiple destinations should return CompositeEmitter.""" + mock_task_manager = MagicMock() + emitter = create_emitter( + task_id="task-123", + task_manager=mock_task_manager, + cli=True, + ) + assert isinstance(emitter, CompositeEmitter) + + def test_returns_single_emitter_when_one_destination(self): + """create_emitter should not wrap single emitter in CompositeEmitter.""" + emitter = create_emitter(cli=True) + # Should be CLIEmitter directly, not CompositeEmitter([CLIEmitter]) + assert isinstance(emitter, CLIEmitter) + assert not isinstance(emitter, CompositeEmitter) + + def test_includes_additional_emitters(self): + """create_emitter should include additional_emitters.""" + extra = NullEmitter() + emitter = create_emitter(cli=True, additional_emitters=[extra]) + assert isinstance(emitter, CompositeEmitter) diff --git a/tests/unit/core/test_tasks.py b/tests/unit/core/test_tasks.py index 6cac9b77..7eca9d66 100644 --- a/tests/unit/core/test_tasks.py +++ b/tests/unit/core/test_tasks.py @@ -17,7 +17,7 @@ class TestSRETaskCollection: def test_sre_task_collection_populated(self): """Test that SRE task collection contains expected tasks.""" - assert len(SRE_TASK_COLLECTION) == 4 + assert len(SRE_TASK_COLLECTION) == 6 task_names = [task.__name__ for task in SRE_TASK_COLLECTION] expected_tasks = [ @@ -25,6 +25,8 @@ def test_sre_task_collection_populated(self): "ingest_sre_document", "scheduler_task", "process_agent_turn", + "process_chat_turn", # New: MCP chat task + "process_knowledge_query", # New: MCP knowledge query task ] for expected_task in expected_tasks: diff --git a/tests/unit/core/test_thread_management.py b/tests/unit/core/test_thread_management.py index 88d33f1a..43329a8d 100644 --- a/tests/unit/core/test_thread_management.py +++ b/tests/unit/core/test_thread_management.py @@ -84,16 +84,11 @@ async def test_get_thread_state_success(self, thread_manager): """Test successful thread state retrieval.""" # Mock Redis data thread_manager._redis_client.exists.return_value = True - thread_manager._redis_client.get.side_effect = [ - None, # result - None, # error - ] thread_manager._redis_client.lrange.return_value = [ json.dumps( { - "timestamp": "2023-01-01T00:00:00Z", - "message": "Test update", - "update_type": "progress", + "role": "user", + "content": "Test message", "metadata": None, } ) @@ -113,30 +108,43 @@ async def test_get_thread_state_success(self, thread_manager): state = await thread_manager.get_thread("test_thread") assert state is not None - assert len(state.updates) == 1 - assert state.updates[0].message == "Test update" + assert len(state.messages) == 1 + assert state.messages[0].content == "Test message" + assert state.messages[0].role == "user" assert state.metadata.user_id == "test_user" @pytest.mark.asyncio - async def test_add_thread_update(self, thread_manager): - """Test adding thread updates.""" - result = await thread_manager.add_thread_update( - "test_thread", "Test progress message", "progress", {"tool": "test_tool"} - ) + async def test_add_thread_update_deprecated(self, thread_manager): + """Test that add_thread_update is deprecated but still works (publishes to stream).""" + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await thread_manager.add_thread_update( + "test_thread", "Test progress message", "progress", {"tool": "test_tool"} + ) + # Should have a deprecation warning + assert len(w) >= 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "deprecated" in str(w[0].message).lower() assert result is True - thread_manager._redis_client.lpush.assert_called() - thread_manager._redis_client.ltrim.assert_called() @pytest.mark.asyncio - async def test_set_thread_result(self, thread_manager): - """Test setting thread result.""" + async def test_set_thread_result_deprecated(self, thread_manager): + """Test that set_thread_result is deprecated but still works (publishes to stream).""" + import warnings + result_data = {"response": "Test response", "metadata": {}} - result = await thread_manager.set_thread_result("test_thread", result_data) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await thread_manager.set_thread_result("test_thread", result_data) + # Should have a deprecation warning + assert len(w) >= 1 + assert issubclass(w[0].category, DeprecationWarning) assert result is True - thread_manager._redis_client.set.assert_called() @pytest.mark.asyncio @pytest.mark.asyncio @@ -169,15 +177,18 @@ async def test_process_agent_turn_success(self): mock_get_redis.return_value = mock_redis # Mock thread manager + mock_manager = AsyncMock() mock_manager_class.return_value = mock_manager mock_manager.get_thread.return_value = Thread( thread_id="test_thread", - context={"messages": []}, + messages=[], + context={}, metadata=ThreadMetadata(), ) mock_manager.add_thread_update.return_value = True - mock_manager.set_thread_result.return_value = True + mock_manager._publish_stream_update.return_value = True + mock_manager._save_thread_state.return_value = True # Mock routing to use Redis-focused agent (not knowledge-only) from redis_sre_agent.agent.router import AgentType @@ -213,9 +224,8 @@ async def mock_route_func(*args, **kwargs): assert result["response"] == "Test response from agent" assert result["metadata"]["iterations"] == 2 - # Verify manager calls - mock_manager.add_thread_update.assert_called() - mock_manager.set_thread_result.assert_called() + # Verify thread manager saved state + mock_manager._save_thread_state.assert_called() @pytest.mark.asyncio async def test_process_agent_turn_thread_not_found(self): @@ -312,18 +322,20 @@ def test_thread_update_creation(self): assert update.timestamp is not None def test_thread_state_creation(self): - """Test ThreadState model creation.""" + """Test Thread model creation.""" + from redis_sre_agent.core.threads import Message + state = Thread( thread_id="test_thread", context={"query": "test"}, - updates=[ThreadUpdate(message="Test update")], + messages=[Message(role="user", content="Test message")], ) assert state.thread_id == "test_thread" assert state.context["query"] == "test" - assert len(state.updates) == 1 - assert state.result is None - assert state.error_message is None + assert len(state.messages) == 1 + assert state.messages[0].content == "Test message" + assert state.messages[0].role == "user" def test_thread_metadata_defaults(self): """Test ThreadMetadata default values.""" diff --git a/tests/unit/mcp_server/__init__.py b/tests/unit/mcp_server/__init__.py new file mode 100644 index 00000000..9c3e61b1 --- /dev/null +++ b/tests/unit/mcp_server/__init__.py @@ -0,0 +1 @@ +"""Tests for MCP server module.""" diff --git a/tests/unit/mcp_server/test_mcp_server.py b/tests/unit/mcp_server/test_mcp_server.py new file mode 100644 index 00000000..c628075b --- /dev/null +++ b/tests/unit/mcp_server/test_mcp_server.py @@ -0,0 +1,576 @@ +"""Tests for MCP server tools.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from redis_sre_agent.mcp_server.server import ( + mcp, + redis_sre_create_instance, + redis_sre_database_chat, + redis_sre_deep_triage, + redis_sre_general_chat, + redis_sre_get_task_status, + redis_sre_get_thread, + redis_sre_knowledge_query, + redis_sre_knowledge_search, + redis_sre_list_instances, +) + + +class TestMCPServerSetup: + """Test MCP server configuration.""" + + def test_mcp_server_name(self): + """Test that the MCP server has correct name.""" + assert mcp.name == "redis-sre-agent" + + def test_mcp_server_has_instructions(self): + """Test that the MCP server has instructions.""" + assert mcp.instructions is not None + assert "Redis SRE Agent" in mcp.instructions + + def test_mcp_server_has_tools(self): + """Test that all expected tools are registered.""" + tool_names = [t.name for t in mcp._tool_manager._tools.values()] + assert "redis_sre_deep_triage" in tool_names + assert "redis_sre_general_chat" in tool_names + assert "redis_sre_database_chat" in tool_names + assert "redis_sre_knowledge_search" in tool_names + assert "redis_sre_knowledge_query" in tool_names + assert "redis_sre_get_thread" in tool_names + assert "redis_sre_get_task_status" in tool_names + assert "redis_sre_list_instances" in tool_names + assert "redis_sre_create_instance" in tool_names + + +class TestDeepTriageTool: + """Test the redis_sre_deep_triage MCP tool.""" + + @pytest.mark.asyncio + async def test_deep_triage_success(self): + """Test successful deep triage request.""" + mock_result = { + "thread_id": "thread-123", + "task_id": "task-456", + "status": "queued", + "message": "Task created", + } + + with ( + patch("redis_sre_agent.core.redis.get_redis_client"), + patch("redis_sre_agent.core.tasks.create_task", new_callable=AsyncMock) as mock_create, + ): + mock_create.return_value = mock_result + + result = await redis_sre_deep_triage( + query="High memory usage on Redis", + instance_id="redis-prod-1", + user_id="user-123", + ) + + assert result["thread_id"] == "thread-123" + assert result["task_id"] == "task-456" + assert "status" in result + mock_create.assert_called_once() + + @pytest.mark.asyncio + async def test_deep_triage_error_handling(self): + """Test deep triage error handling.""" + with ( + patch("redis_sre_agent.core.redis.get_redis_client"), + patch("redis_sre_agent.core.tasks.create_task", new_callable=AsyncMock) as mock_create, + ): + mock_create.side_effect = Exception("Redis connection failed") + + result = await redis_sre_deep_triage(query="Test query") + + assert result["status"] == "failed" + assert "error" in result + + +class TestGeneralChatTool: + """Test the redis_sre_general_chat MCP tool. + + Note: The chat tool creates a task and returns task_id/thread_id + instead of running synchronously. This matches the triage pattern. + """ + + @pytest.mark.asyncio + async def test_general_chat_creates_task(self): + """Test that general_chat creates a task and returns task_id.""" + mock_result = { + "thread_id": "thread-123", + "task_id": "task-456", + "status": "queued", + "message": "Task created", + } + + with ( + patch("redis_sre_agent.core.redis.get_redis_client"), + patch("redis_sre_agent.core.tasks.create_task", new_callable=AsyncMock) as mock_create, + ): + mock_create.return_value = mock_result + + result = await redis_sre_general_chat(query="What's the memory usage?") + + assert result["thread_id"] == "thread-123" + assert result["task_id"] == "task-456" + assert "status" in result + mock_create.assert_called_once() + + @pytest.mark.asyncio + async def test_general_chat_with_instance_id(self): + """Test general_chat with a specific instance includes it in context.""" + mock_result = { + "thread_id": "thread-123", + "task_id": "task-456", + "status": "queued", + } + + with ( + patch("redis_sre_agent.core.redis.get_redis_client"), + patch("redis_sre_agent.core.tasks.create_task", new_callable=AsyncMock) as mock_create, + ): + mock_create.return_value = mock_result + + result = await redis_sre_general_chat(query="Check status", instance_id="redis-prod-1") + + assert result["task_id"] == "task-456" + # Verify instance_id was passed in context + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["context"]["instance_id"] == "redis-prod-1" + + @pytest.mark.asyncio + async def test_general_chat_error_handling(self): + """Test general_chat error handling.""" + with ( + patch("redis_sre_agent.core.redis.get_redis_client"), + patch("redis_sre_agent.core.tasks.create_task", new_callable=AsyncMock) as mock_create, + ): + mock_create.side_effect = Exception("Redis connection failed") + + result = await redis_sre_general_chat(query="Test query") + + assert result["status"] == "failed" + assert "error" in result + + +class TestDatabaseChatTool: + """Test the redis_sre_database_chat MCP tool with category exclusion.""" + + @pytest.mark.asyncio + async def test_database_chat_excludes_all_mcp_by_default(self): + """Test that database_chat excludes all MCP categories by default.""" + mock_result = { + "thread_id": "thread-123", + "task_id": "task-456", + "status": "queued", + } + + with ( + patch("redis_sre_agent.core.redis.get_redis_client"), + patch("redis_sre_agent.core.tasks.create_task", new_callable=AsyncMock) as mock_create, + ): + mock_create.return_value = mock_result + + result = await redis_sre_database_chat(query="What's the memory usage?") + + assert result["task_id"] == "task-456" + # Verify that exclude_mcp_categories is set in context + call_kwargs = mock_create.call_args.kwargs + assert "exclude_mcp_categories" in call_kwargs["context"] + + @pytest.mark.asyncio + async def test_database_chat_with_selective_exclusion(self): + """Test database_chat with selective category exclusion.""" + mock_result = { + "thread_id": "thread-123", + "task_id": "task-456", + "status": "queued", + } + + with ( + patch("redis_sre_agent.core.redis.get_redis_client"), + patch("redis_sre_agent.core.tasks.create_task", new_callable=AsyncMock) as mock_create, + ): + mock_create.return_value = mock_result + + # Only exclude tickets and repos + result = await redis_sre_database_chat( + query="Check status", + exclude_mcp_categories=["tickets", "repos"], + ) + + assert result["task_id"] == "task-456" + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["context"]["exclude_mcp_categories"] == ["tickets", "repos"] + + +class TestKnowledgeSearchTool: + """Test the redis_sre_knowledge_search MCP tool.""" + + @pytest.mark.asyncio + async def test_knowledge_search_success(self): + """Test successful knowledge search.""" + mock_result = { + "results": [ + { + "title": "Redis Memory Management", + "content": "Redis uses memory...", + "source": "docs", + "category": "documentation", + } + ] + } + + with patch( + "redis_sre_agent.core.knowledge_helpers.search_knowledge_base_helper", + new_callable=AsyncMock, + ) as mock_search: + mock_search.return_value = mock_result + + result = await redis_sre_knowledge_search(query="memory management", limit=5) + + assert result["query"] == "memory management" + assert len(result["results"]) == 1 + assert result["results"][0]["title"] == "Redis Memory Management" + mock_search.assert_called_once() + + @pytest.mark.asyncio + async def test_knowledge_search_limit_clamped(self): + """Test that limit is clamped to valid range.""" + with patch( + "redis_sre_agent.core.knowledge_helpers.search_knowledge_base_helper", + new_callable=AsyncMock, + ) as mock_search: + mock_search.return_value = {"results": []} + + # Test with too high limit (max is 50) + await redis_sre_knowledge_search(query="test", limit=100) + call_args = mock_search.call_args + assert call_args.kwargs["limit"] == 50 + + # Test with too low limit + await redis_sre_knowledge_search(query="test", limit=0) + call_args = mock_search.call_args + assert call_args.kwargs["limit"] == 1 + + @pytest.mark.asyncio + async def test_knowledge_search_error_handling(self): + """Test knowledge search error handling.""" + with patch( + "redis_sre_agent.core.knowledge_helpers.search_knowledge_base_helper", + new_callable=AsyncMock, + ) as mock_search: + mock_search.side_effect = Exception("Search failed") + + result = await redis_sre_knowledge_search(query="test") + + assert "error" in result + assert result["results"] == [] + assert result["total_results"] == 0 + + +class TestKnowledgeQueryTool: + """Test the redis_sre_knowledge_query MCP tool. + + The knowledge_query tool creates a task that uses the KnowledgeOnlyAgent + to answer questions about SRE practices and Redis. + """ + + @pytest.mark.asyncio + async def test_knowledge_query_creates_task(self): + """Test that knowledge_query creates a task and returns task_id.""" + mock_result = { + "thread_id": "thread-123", + "task_id": "task-456", + "status": "queued", + "message": "Task created", + } + + with ( + patch("redis_sre_agent.core.redis.get_redis_client"), + patch("redis_sre_agent.core.tasks.create_task", new_callable=AsyncMock) as mock_create, + ): + mock_create.return_value = mock_result + + result = await redis_sre_knowledge_query(query="What are Redis eviction policies?") + + assert result["thread_id"] == "thread-123" + assert result["task_id"] == "task-456" + assert "status" in result + mock_create.assert_called_once() + # Verify agent_type is set in context + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["context"]["agent_type"] == "knowledge" + + @pytest.mark.asyncio + async def test_knowledge_query_error_handling(self): + """Test knowledge_query error handling.""" + with ( + patch("redis_sre_agent.core.redis.get_redis_client"), + patch("redis_sre_agent.core.tasks.create_task", new_callable=AsyncMock) as mock_create, + ): + mock_create.side_effect = Exception("Redis connection failed") + + result = await redis_sre_knowledge_query(query="Test query") + + assert result["status"] == "failed" + assert "error" in result + + +class TestListInstancesTool: + """Test the redis_sre_list_instances MCP tool.""" + + @pytest.mark.asyncio + async def test_list_instances_success(self): + """Test successful instance listing.""" + mock_instance = MagicMock() + mock_instance.id = "redis-prod-1" + mock_instance.name = "Production Redis" + mock_instance.environment = "production" + mock_instance.usage = "cache" + mock_instance.description = "Main cache" + mock_instance.instance_type = "redis_cloud" + mock_instance.repo_url = "https://github.com/example/repo" + + with patch( + "redis_sre_agent.core.instances.get_instances", + new_callable=AsyncMock, + ) as mock_get: + mock_get.return_value = [mock_instance] + + result = await redis_sre_list_instances() + + assert result["total"] == 1 + assert result["instances"][0]["id"] == "redis-prod-1" + assert result["instances"][0]["name"] == "Production Redis" + assert result["instances"][0]["repo_url"] == "https://github.com/example/repo" + + @pytest.mark.asyncio + async def test_list_instances_empty(self): + """Test empty instance list.""" + with patch( + "redis_sre_agent.core.instances.get_instances", + new_callable=AsyncMock, + ) as mock_get: + mock_get.return_value = [] + + result = await redis_sre_list_instances() + + assert result["total"] == 0 + assert result["instances"] == [] + + @pytest.mark.asyncio + async def test_list_instances_error(self): + """Test list instances error handling.""" + with patch( + "redis_sre_agent.core.instances.get_instances", + new_callable=AsyncMock, + ) as mock_get: + mock_get.side_effect = Exception("Connection failed") + + result = await redis_sre_list_instances() + + assert "error" in result + assert result["instances"] == [] + + +class TestCreateInstanceTool: + """Test the redis_sre_create_instance MCP tool.""" + + @pytest.mark.asyncio + async def test_create_instance_success(self): + """Test successful instance creation.""" + with ( + patch( + "redis_sre_agent.core.instances.get_instances", + new_callable=AsyncMock, + ) as mock_get, + patch( + "redis_sre_agent.core.instances.save_instances", + new_callable=AsyncMock, + ) as mock_save, + ): + mock_get.return_value = [] + mock_save.return_value = True + + result = await redis_sre_create_instance( + name="test-redis", + connection_url="redis://localhost:6379", + environment="development", + usage="cache", + description="Test instance", + ) + + assert result["status"] == "created" + assert result["name"] == "test-redis" + assert "id" in result + + @pytest.mark.asyncio + async def test_create_instance_invalid_environment(self): + """Test create instance with invalid environment.""" + result = await redis_sre_create_instance( + name="test-redis", + connection_url="redis://localhost:6379", + environment="invalid", + usage="cache", + description="Test", + ) + + assert result["status"] == "failed" + assert "error" in result + assert "environment" in result["error"].lower() + + @pytest.mark.asyncio + async def test_create_instance_invalid_usage(self): + """Test create instance with invalid usage.""" + result = await redis_sre_create_instance( + name="test-redis", + connection_url="redis://localhost:6379", + environment="development", + usage="invalid", + description="Test", + ) + + assert result["status"] == "failed" + assert "error" in result + assert "usage" in result["error"].lower() + + @pytest.mark.asyncio + async def test_create_instance_duplicate_name(self): + """Test create instance with duplicate name.""" + from unittest.mock import MagicMock + + existing = MagicMock() + existing.name = "test-redis" + + with patch( + "redis_sre_agent.core.instances.get_instances", + new_callable=AsyncMock, + ) as mock_get: + mock_get.return_value = [existing] + + result = await redis_sre_create_instance( + name="test-redis", + connection_url="redis://localhost:6379", + environment="development", + usage="cache", + description="Test", + ) + + assert result["status"] == "failed" + assert "already exists" in result["error"] + + +class TestGetThreadTool: + """Test the redis_sre_get_thread MCP tool.""" + + @pytest.mark.asyncio + async def test_get_thread_success(self): + """Test successful thread retrieval.""" + from redis_sre_agent.core.threads import Message, Thread, ThreadMetadata + + # Create a proper Thread object with messages + mock_thread = Thread( + thread_id="thread-123", + messages=[ + Message(role="user", content="Check memory"), + Message(role="assistant", content="Analyzing..."), + ], + context={}, + metadata=ThreadMetadata(), + ) + + mock_redis = AsyncMock() + mock_redis.zrevrange = AsyncMock(return_value=[]) # No tasks + + with ( + patch("redis_sre_agent.core.redis.get_redis_client", return_value=mock_redis), + patch( + "redis_sre_agent.core.threads.ThreadManager.get_thread", + new_callable=AsyncMock, + ) as mock_get, + ): + mock_get.return_value = mock_thread + + result = await redis_sre_get_thread(thread_id="thread-123") + + assert result["thread_id"] == "thread-123" + assert result["message_count"] == 2 + assert result["messages"][0]["role"] == "user" + + @pytest.mark.asyncio + async def test_get_thread_not_found(self): + """Test thread not found.""" + with ( + patch("redis_sre_agent.core.redis.get_redis_client"), + patch( + "redis_sre_agent.core.threads.ThreadManager.get_thread", + new_callable=AsyncMock, + ) as mock_get, + ): + mock_get.return_value = None + + result = await redis_sre_get_thread(thread_id="nonexistent") + + assert "error" in result + assert "not found" in result["error"] + + +class TestGetTaskStatusTool: + """Test the redis_sre_get_task_status MCP tool.""" + + @pytest.mark.asyncio + async def test_get_task_status_success(self): + """Test successful task status retrieval.""" + # Mock returns data in the format that get_task_by_id actually returns + mock_task = { + "task_id": "task-123", + "thread_id": "thread-456", + "status": "done", + "updates": [ + {"timestamp": "2024-01-01T00:00:30Z", "message": "Processing", "type": "progress"} + ], + "result": {"summary": "Complete"}, + "error_message": None, + "metadata": { + "subject": "Health check", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:01:00Z", + "user_id": None, + }, + "context": {}, + } + + with patch( + "redis_sre_agent.core.tasks.get_task_by_id", + new_callable=AsyncMock, + ) as mock_get: + mock_get.return_value = mock_task + + result = await redis_sre_get_task_status(task_id="task-123") + + assert result["task_id"] == "task-123" + assert result["status"] == "done" + assert result["thread_id"] == "thread-456" + assert result["subject"] == "Health check" + assert result["created_at"] == "2024-01-01T00:00:00Z" + assert result["updated_at"] == "2024-01-01T00:01:00Z" + assert result["updates"] == mock_task["updates"] + assert result["result"] == {"summary": "Complete"} + + @pytest.mark.asyncio + async def test_get_task_status_not_found(self): + """Test task not found.""" + with patch( + "redis_sre_agent.core.tasks.get_task_by_id", + new_callable=AsyncMock, + ) as mock_get: + mock_get.side_effect = ValueError("Task task-999 not found") + + result = await redis_sre_get_task_status(task_id="task-999") + + assert result["status"] == "not_found" + assert "error" in result diff --git a/tests/unit/tools/mcp_provider/__init__.py b/tests/unit/tools/mcp_provider/__init__.py new file mode 100644 index 00000000..35969d59 --- /dev/null +++ b/tests/unit/tools/mcp_provider/__init__.py @@ -0,0 +1 @@ +"""Tests for MCP tool provider.""" diff --git a/tests/unit/tools/mcp_provider/test_mcp_provider.py b/tests/unit/tools/mcp_provider/test_mcp_provider.py new file mode 100644 index 00000000..67284e20 --- /dev/null +++ b/tests/unit/tools/mcp_provider/test_mcp_provider.py @@ -0,0 +1,176 @@ +"""Unit tests for MCP tool provider.""" + +import pytest + +from redis_sre_agent.core.config import MCPServerConfig, MCPToolConfig +from redis_sre_agent.tools.mcp.provider import MCPToolProvider +from redis_sre_agent.tools.models import ToolCapability + + +class TestMCPToolProvider: + """Test MCPToolProvider functionality.""" + + def test_provider_name(self): + """Test that provider name is based on server name.""" + config = MCPServerConfig(command="test") + provider = MCPToolProvider(server_name="memory", server_config=config) + assert provider.provider_name == "mcp_memory" + + def test_provider_name_with_special_chars(self): + """Test provider name with various server names.""" + config = MCPServerConfig(command="test") + + provider = MCPToolProvider(server_name="my_server", server_config=config) + assert provider.provider_name == "mcp_my_server" + + provider = MCPToolProvider(server_name="test123", server_config=config) + assert provider.provider_name == "mcp_test123" + + def test_should_include_tool_no_filter(self): + """Test that all tools are included when no filter is specified.""" + config = MCPServerConfig(command="test") + provider = MCPToolProvider(server_name="test", server_config=config) + + assert provider._should_include_tool("any_tool") is True + assert provider._should_include_tool("another_tool") is True + + def test_should_include_tool_with_filter(self): + """Test that only specified tools are included when filter is set.""" + config = MCPServerConfig( + command="test", + tools={ + "allowed_tool": MCPToolConfig(), + "another_allowed": MCPToolConfig(), + }, + ) + provider = MCPToolProvider(server_name="test", server_config=config) + + assert provider._should_include_tool("allowed_tool") is True + assert provider._should_include_tool("another_allowed") is True + assert provider._should_include_tool("not_allowed") is False + + def test_get_capability_default(self): + """Test that default capability is UTILITIES.""" + config = MCPServerConfig(command="test") + provider = MCPToolProvider(server_name="test", server_config=config) + + assert provider._get_capability("any_tool") == ToolCapability.UTILITIES + + def test_get_capability_with_override(self): + """Test that capability override is respected.""" + config = MCPServerConfig( + command="test", + tools={ + "search_tool": MCPToolConfig(capability=ToolCapability.LOGS), + "metrics_tool": MCPToolConfig(capability=ToolCapability.METRICS), + "no_override": MCPToolConfig(), + }, + ) + provider = MCPToolProvider(server_name="test", server_config=config) + + assert provider._get_capability("search_tool") == ToolCapability.LOGS + assert provider._get_capability("metrics_tool") == ToolCapability.METRICS + assert provider._get_capability("no_override") == ToolCapability.UTILITIES + assert provider._get_capability("unknown_tool") == ToolCapability.UTILITIES + + def test_get_description_default(self): + """Test that MCP description is used by default.""" + config = MCPServerConfig(command="test") + provider = MCPToolProvider(server_name="test", server_config=config) + + mcp_desc = "Original MCP description" + assert provider._get_description("any_tool", mcp_desc) == mcp_desc + + def test_get_description_with_override(self): + """Test that description override is respected.""" + config = MCPServerConfig( + command="test", + tools={ + "custom_tool": MCPToolConfig(description="Custom description"), + "no_override": MCPToolConfig(), + }, + ) + provider = MCPToolProvider(server_name="test", server_config=config) + + assert provider._get_description("custom_tool", "MCP desc") == "Custom description" + assert provider._get_description("no_override", "MCP desc") == "MCP desc" + assert provider._get_description("unknown", "MCP desc") == "MCP desc" + + def test_get_description_with_original_template(self): + """Test that {original} placeholder is replaced with MCP description.""" + config = MCPServerConfig( + command="test", + tools={ + "templated_tool": MCPToolConfig(description="Custom context. {original}"), + "prepended": MCPToolConfig( + description="WARNING: Use carefully. {original} See docs for details." + ), + }, + ) + provider = MCPToolProvider(server_name="test", server_config=config) + + # Template should replace {original} with the MCP description + assert ( + provider._get_description("templated_tool", "Original MCP description") + == "Custom context. Original MCP description" + ) + assert ( + provider._get_description("prepended", "Search for files.") + == "WARNING: Use carefully. Search for files. See docs for details." + ) + + def test_get_tool_config(self): + """Test getting tool config.""" + tool_config = MCPToolConfig( + capability=ToolCapability.LOGS, + description="Test description", + ) + config = MCPServerConfig( + command="test", + tools={"my_tool": tool_config}, + ) + provider = MCPToolProvider(server_name="test", server_config=config) + + assert provider._get_tool_config("my_tool") == tool_config + assert provider._get_tool_config("unknown") is None + + def test_get_tool_config_no_tools_defined(self): + """Test getting tool config when no tools are defined.""" + config = MCPServerConfig(command="test") + provider = MCPToolProvider(server_name="test", server_config=config) + + assert provider._get_tool_config("any_tool") is None + + +class TestMCPToolProviderAsync: + """Test async functionality of MCPToolProvider.""" + + @pytest.mark.asyncio + async def test_tools_returns_empty_list_without_connection(self): + """Test that tools() returns empty list when not connected.""" + config = MCPServerConfig(command="test") + provider = MCPToolProvider(server_name="test", server_config=config) + + # Without connecting, tools should be empty + tools = provider.tools() + assert tools == [] + + @pytest.mark.asyncio + async def test_create_tool_schemas_empty_without_connection(self): + """Test that create_tool_schemas returns empty when not connected.""" + config = MCPServerConfig(command="test") + provider = MCPToolProvider(server_name="test", server_config=config) + + # Without connecting, schemas should be empty + schemas = provider.create_tool_schemas() + assert schemas == [] + + @pytest.mark.asyncio + async def test_call_mcp_tool_not_connected(self): + """Test that _call_mcp_tool returns error when not connected.""" + config = MCPServerConfig(command="test") + provider = MCPToolProvider(server_name="test", server_config=config) + + result = await provider._call_mcp_tool("some_tool", {"arg": "value"}) + assert result["status"] == "error" + assert "not connected" in result["error"] diff --git a/tests/unit/tools/test_tool_manager_protocols.py b/tests/unit/tools/test_tool_manager_protocols.py index e7a12f8d..5be4c08f 100644 --- a/tests/unit/tools/test_tool_manager_protocols.py +++ b/tests/unit/tools/test_tool_manager_protocols.py @@ -33,9 +33,12 @@ async def test_protocol_selection_for_utilities_subset(): tools = mgr.get_tools_for_capability(ToolCapability.UTILITIES) assert tools, "Expected utilities tools for the allowed set" - # Ensure all returned tools are utilities_* and that the allowed subset is present + # Collect ops from utilities_* tools only (MCP tools may also have UTILITIES capability) ops_seen = set() for t in tools: + # Skip MCP tools which have a different naming convention (mcp_servername_hash_toolname) + if t.name.startswith("mcp_"): + continue assert t.name.startswith("utilities_"), f"Unexpected provider prefix: {t.name}" parts = t.name.split("_", 2) op = parts[2] if len(parts) >= 3 else parts[-1] diff --git a/ui/Dockerfile b/ui/Dockerfile index 2bb561a2..ed6c5cb1 100644 --- a/ui/Dockerfile +++ b/ui/Dockerfile @@ -9,13 +9,20 @@ FROM base AS development # Copy package files COPY package*.json ./ +COPY ui-kit/package*.json ./ui-kit/ -# Install all dependencies (including dev dependencies) -RUN npm ci +# Install root dependencies, skipping postinstall (we'll build ui-kit after copying source) +RUN npm ci --ignore-scripts + +# Install ui-kit dependencies +RUN npm --prefix ./ui-kit ci --ignore-scripts # Copy source code COPY . . +# Build ui-kit now that source files are present +RUN npm --prefix ./ui-kit run build + # Expose port EXPOSE 3000 @@ -27,15 +34,19 @@ FROM base AS build # Copy package files COPY package*.json ./ +COPY ui-kit/package*.json ./ui-kit/ + +# Install root dependencies, skipping postinstall +RUN npm ci --ignore-scripts -# Install dependencies -RUN npm ci +# Install ui-kit dependencies +RUN npm --prefix ./ui-kit ci --ignore-scripts # Copy source code COPY . . -# Build the application -RUN npm run build +# Build ui-kit first, then the main app +RUN npm --prefix ./ui-kit run build && npm run build # Production stage FROM nginx:alpine AS production diff --git a/ui/e2e/schedules.spec.ts b/ui/e2e/schedules.spec.ts new file mode 100644 index 00000000..441e6f5a --- /dev/null +++ b/ui/e2e/schedules.spec.ts @@ -0,0 +1,164 @@ +import { test, expect } from '@playwright/test'; + +// E2E tests for schedule creation/update functionality. +// Validates that the form correctly sends interval_type and interval_value to the backend. +// +// NOTE: These tests require: +// 1. Backend API running on port 8000 +// 2. Frontend dev server (npm run dev) running on port 3000 (or 3002 via docker) +// +// The tests validate the critical fix that the schedule form sends interval_type +// and interval_value instead of cron_expression. + +const API_BASE = 'http://localhost:8000/api/v1'; +const uniqueSuffix = () => `${Date.now()}`; + +test.describe('Schedules API payload validation', () => { + // This test validates the API contract directly without relying on the UI + // loading correctly - useful for CI environments where UI tests may be flaky + test('schedule API accepts interval_type and interval_value', async ({ request }) => { + const scheduleName = `E2E API Test ${uniqueSuffix()}`; + + // Create a schedule using the correct payload format + const createResponse = await request.post(`${API_BASE}/schedules/`, { + data: { + name: scheduleName, + interval_type: 'days', + interval_value: 1, + instructions: 'E2E test instructions', + enabled: true, + }, + }); + + expect(createResponse.ok()).toBeTruthy(); + const createdSchedule = await createResponse.json(); + expect(createdSchedule).toHaveProperty('id'); + expect(createdSchedule).toHaveProperty('name', scheduleName); + expect(createdSchedule).toHaveProperty('interval_type', 'days'); + expect(createdSchedule).toHaveProperty('interval_value', 1); + + // Cleanup + const deleteResponse = await request.delete(`${API_BASE}/schedules/${createdSchedule.id}`); + expect(deleteResponse.ok()).toBeTruthy(); + }); + + test('schedule API rejects payload with cron_expression but no interval fields', async ({ request }) => { + const scheduleName = `E2E Invalid Test ${uniqueSuffix()}`; + + // This payload matches what the bug was producing - cron_expression without interval fields + const createResponse = await request.post(`${API_BASE}/schedules/`, { + data: { + name: scheduleName, + cron_expression: '*/1 * * * *', // This was the bug - sending cron instead of interval + instructions: 'E2E test instructions', + enabled: true, + }, + }); + + // The API should reject this payload because interval_type and interval_value are required + expect(createResponse.ok()).toBeFalsy(); + expect(createResponse.status()).toBe(422); // Validation error + }); + + test('schedule update API accepts interval_type and interval_value', async ({ request }) => { + const scheduleName = `E2E Update API Test ${uniqueSuffix()}`; + + // First create a schedule + const createResponse = await request.post(`${API_BASE}/schedules/`, { + data: { + name: scheduleName, + interval_type: 'hours', + interval_value: 2, + instructions: 'Initial instructions', + enabled: true, + }, + }); + + expect(createResponse.ok()).toBeTruthy(); + const createdSchedule = await createResponse.json(); + + try { + // Update the schedule with new interval values + const updateResponse = await request.put(`${API_BASE}/schedules/${createdSchedule.id}`, { + data: { + name: scheduleName, + interval_type: 'days', + interval_value: 7, + instructions: 'Updated instructions', + enabled: true, + }, + }); + + expect(updateResponse.ok()).toBeTruthy(); + const updatedSchedule = await updateResponse.json(); + expect(updatedSchedule).toHaveProperty('interval_type', 'days'); + expect(updatedSchedule).toHaveProperty('interval_value', 7); + } finally { + // Cleanup + await request.delete(`${API_BASE}/schedules/${createdSchedule.id}`); + } + }); +}); + +test.describe('Schedules UI form', () => { + test.skip('create schedule form sends correct payload', async ({ page }) => { + // NOTE: This test is skipped because it requires the UI to load correctly, + // which depends on proper frontend/backend connectivity in the test environment. + // The API tests above validate the same functionality at the API level. + // + // To run this test locally: + // 1. Start the backend: uv run uvicorn redis_sre_agent.api.app:app --port 8000 + // 2. Start the frontend: cd ui && npm run dev + // 3. Run: cd ui && npm run e2e -- --grep "create schedule form" + + const scheduleName = `E2E UI Schedule ${uniqueSuffix()}`; + let scheduleId: string | undefined; + + await page.goto('/schedules'); + + // Wait for the page to load + await expect(page.getByRole('heading', { name: 'Schedules' })).toBeVisible({ timeout: 15_000 }); + + // Click Create Schedule button + await page.getByRole('button', { name: 'Create Schedule' }).first().click(); + + // Wait for modal + await expect(page.getByText('Create New Schedule')).toBeVisible(); + + // Fill form + await page.getByPlaceholder('e.g., Daily Health Check').fill(scheduleName); + await page.locator('select[name="interval_type"]').first().selectOption('days'); + await page.getByPlaceholder('e.g., 30').first().fill('1'); + await page.getByPlaceholder('Instructions for the agent to execute...').first().fill('E2E test'); + + // Intercept API request + const requestPromise = page.waitForRequest((req) => + req.url().includes('/api/v1/schedules') && req.method() === 'POST' + ); + + // Submit + await page.locator('form').getByRole('button', { name: 'Create Schedule' }).click(); + + // Validate payload + const request = await requestPromise; + const postData = request.postDataJSON(); + expect(postData).toHaveProperty('interval_type', 'days'); + expect(postData).toHaveProperty('interval_value', 1); + expect(postData).not.toHaveProperty('cron_expression'); + + // Get schedule ID for cleanup + const response = await page.waitForResponse((res) => + res.url().includes('/api/v1/schedules') && res.request().method() === 'POST' + ); + + if (response.ok()) { + const data = await response.json(); + scheduleId = data.id; + } + + // Cleanup + if (scheduleId) { + await page.request.delete(`${API_BASE}/schedules/${scheduleId}`); + } + }); +}); diff --git a/ui/e2e/support/cleanup.mjs b/ui/e2e/support/cleanup.mjs index e0625d76..5096ff92 100644 --- a/ui/e2e/support/cleanup.mjs +++ b/ui/e2e/support/cleanup.mjs @@ -1,6 +1,6 @@ const base = process.env.API_BASE_URL || 'http://localhost:8000/api/v1'; -const E2E_PATTERNS = [/^e2e\b/i, /^e2e\s+hello/i, /^e2e\s+streaming/i, /^e2e\s+persistence/i]; +const E2E_PATTERNS = [/^e2e\b/i, /^e2e\s+hello/i, /^e2e\s+streaming/i, /^e2e\s+persistence/i, /^e2e\s+schedule/i, /^e2e\s+update\s+test/i]; const withinHours = (iso, hours = 72) => { try { return Date.now() - new Date(iso).getTime() < hours * 3600 * 1000; } catch { return false; } @@ -21,7 +21,37 @@ async function deleteThread(id) { if (!res.ok) throw new Error(`Failed to delete ${id}: ${res.status} ${res.statusText}`); } +async function listSchedules() { + const res = await fetch(`${base}/schedules/`); + if (!res.ok) throw new Error(`Failed to list schedules: ${res.status} ${res.statusText}`); + return res.json(); +} + +async function deleteSchedule(id) { + const res = await fetch(`${base}/schedules/${id}`, { method: 'DELETE' }); + if (!res.ok) throw new Error(`Failed to delete schedule ${id}: ${res.status} ${res.statusText}`); +} + +async function cleanupSchedules() { + try { + const schedules = await listSchedules(); + let deleted = 0; + for (const s of schedules) { + const name = s.name || ''; + const recent = withinHours(s.updated_at || s.created_at || ''); + if (matchesE2E(name) || (name.toLowerCase().startsWith('e2e ') && recent)) { + try { await deleteSchedule(s.id); deleted++; } + catch (e) { console.warn(`Could not delete schedule ${s.id}: ${e.message}`); } + } + } + console.log(`[global cleanup] Deleted ${deleted} E2E schedules`); + } catch (e) { + console.warn(`[global cleanup] Schedule cleanup failed: ${e.message}`); + } +} + export default async function cleanup() { + // Clean up threads try { const threads = await listThreads(1000); let deleted = 0; @@ -35,6 +65,9 @@ export default async function cleanup() { } console.log(`[global cleanup] Deleted ${deleted} E2E threads`); } catch (e) { - console.warn(`[global cleanup] Failed: ${e.message}`); + console.warn(`[global cleanup] Thread cleanup failed: ${e.message}`); } + + // Clean up schedules + await cleanupSchedules(); } diff --git a/ui/package.json b/ui/package.json index fb04a039..8af05e2b 100644 --- a/ui/package.json +++ b/ui/package.json @@ -12,7 +12,7 @@ "e2e:ui": "playwright test --ui", "format": "npx prettier --write \"src/**/*.{ts,tsx,js,jsx,css,md}\"", "format:check": "npx prettier --check \"src/**/*.{ts,tsx,js,jsx,css,md}\"", - "postinstall": "npm --prefix ./ui-kit ci && npm --prefix ./ui-kit run build", + "postinstall": "npm --prefix ./ui-kit ci --ignore-scripts && npm --prefix ./ui-kit run build", "preview": "vite preview" }, "dependencies": { diff --git a/ui/scripts/cleanup-e2e.mjs b/ui/scripts/cleanup-e2e.mjs index 27309609..e4d2d311 100644 --- a/ui/scripts/cleanup-e2e.mjs +++ b/ui/scripts/cleanup-e2e.mjs @@ -1,4 +1,4 @@ -// Cleanup script for E2E-created threads in the Redis SRE Agent backend +// Cleanup script for E2E-created threads and schedules in the Redis SRE Agent backend // Usage: // API_BASE_URL=http://localhost:8000/api/v1 node scripts/cleanup-e2e.mjs // Defaults to localhost if API_BASE_URL not set @@ -10,6 +10,8 @@ const E2E_PATTERNS = [ /^e2e\s+hello/i, /^e2e\s+streaming/i, /^e2e\s+persistence/i, + /^e2e\s+schedule/i, + /^e2e\s+update\s+test/i, ]; const withinHours = (iso, hours = 24) => { @@ -40,7 +42,19 @@ async function deleteThread(id) { if (!res.ok) throw new Error(`Failed to delete ${id}: ${res.status} ${res.statusText}`); } +async function listSchedules() { + const res = await fetch(`${base}/schedules/`); + if (!res.ok) throw new Error(`Failed to list schedules: ${res.status} ${res.statusText}`); + return res.json(); +} + +async function deleteSchedule(id) { + const res = await fetch(`${base}/schedules/${id}`, { method: 'DELETE' }); + if (!res.ok) throw new Error(`Failed to delete schedule ${id}: ${res.status} ${res.statusText}`); +} + (async () => { + // Clean up threads try { const threads = await listThreads(1000); let deleted = 0; @@ -57,9 +71,30 @@ async function deleteThread(id) { } } } - console.log(`Cleanup complete. Deleted ${deleted} threads.`); + console.log(`Thread cleanup complete. Deleted ${deleted} threads.`); + } catch (e) { + console.error(`Thread cleanup failed: ${e.message}`); + } + + // Clean up schedules + try { + const schedules = await listSchedules(); + let deleted = 0; + for (const s of schedules) { + const name = s.name || ''; + const recent = withinHours(s.updated_at || s.created_at || '', 72); + if (matchesE2E(name) || (name.toLowerCase().startsWith('e2e ') && recent)) { + try { + await deleteSchedule(s.id); + deleted++; + console.log(`Deleted E2E schedule: ${s.id} (${name})`); + } catch (e) { + console.warn(`Could not delete schedule ${s.id}: ${e.message}`); + } + } + } + console.log(`Schedule cleanup complete. Deleted ${deleted} schedules.`); } catch (e) { - console.error(`Cleanup failed: ${e.message}`); - process.exitCode = 1; + console.error(`Schedule cleanup failed: ${e.message}`); } })(); diff --git a/ui/src/components/TaskMonitor.tsx b/ui/src/components/TaskMonitor.tsx index 43a900dc..7d64786c 100644 --- a/ui/src/components/TaskMonitor.tsx +++ b/ui/src/components/TaskMonitor.tsx @@ -53,10 +53,12 @@ const TaskMonitor: React.FC = ({ const refetchTimeoutRef = useRef(null); const mapThreadToMessages = (threadData: any): ChatMessage[] => { - // Start with any persisted transcript messages - const baseMsgs: any[] = Array.isArray(threadData?.context?.messages) - ? threadData.context.messages - : []; + // Messages are now at threadData.messages (top-level), with fallback to context.messages for old data + const baseMsgs: any[] = Array.isArray(threadData?.messages) + ? threadData.messages + : Array.isArray(threadData?.context?.messages) + ? threadData.context.messages + : []; const out: ChatMessage[] = []; @@ -79,17 +81,18 @@ const TaskMonitor: React.FC = ({ baseMsgs.forEach((msg: any, index: number) => { if (!msg || !msg.content) return; out.push({ - id: `${msg.role}-${index}-${msg.timestamp || index}`, + id: `${msg.role}-${index}-${msg.metadata?.timestamp || index}`, role: msg.role, content: msg.content, timestamp: - msg.timestamp || + msg.metadata?.timestamp || threadData?.metadata?.updated_at || new Date().toISOString(), }); }); - // Merge in live updates as assistant/user bubbles even when context.messages exists + // Merge in live updates as assistant/user bubbles even when messages exist + // Updates now come from the latest task, not the thread directly // This ensures reflections and interim responses are visible during the turn. const seen = new Set(out.map((m) => `${m.role}::${m.content}`)); const updates = Array.isArray(threadData?.updates) diff --git a/ui/src/pages/Dashboard.tsx b/ui/src/pages/Dashboard.tsx index 4801c8c7..5392fc06 100644 --- a/ui/src/pages/Dashboard.tsx +++ b/ui/src/pages/Dashboard.tsx @@ -67,10 +67,8 @@ const Dashboard = () => { const threadsPromise = sreAgentApi.listThreads(undefined, 10, 0); const instancesPromise = sreAgentApi.listInstances(); - const knowledgePromise = fetch("/api/v1/knowledge/stats").then((res) => - res.json(), - ); - const healthPromise = fetch("/api/v1/health").then((res) => res.json()); + const knowledgePromise = sreAgentApi.getKnowledgeStats(); + const healthPromise = sreAgentApi.getSystemHealth(); const [threadsRes, instancesRes, knowledgeRes, healthRes] = await Promise.allSettled([ diff --git a/ui/src/pages/Knowledge.tsx b/ui/src/pages/Knowledge.tsx index 25719958..d8e2b1b0 100644 --- a/ui/src/pages/Knowledge.tsx +++ b/ui/src/pages/Knowledge.tsx @@ -1,5 +1,6 @@ import { useState, useEffect } from "react"; import { Card, CardHeader, CardContent, Button } from "@radar/ui-kit"; +import { sreAgentApi } from "../services/sreAgentApi"; interface KnowledgeStats { total_documents: number; @@ -102,40 +103,16 @@ const Knowledge = () => { console.log("Loading knowledge data..."); - // Load real knowledge base data - const [statsResponse, jobsResponse] = await Promise.all([ - fetch("/api/v1/knowledge/stats"), - fetch("/api/v1/knowledge/jobs"), - ]); - - console.log("Response status:", { - stats: statsResponse.status, - jobs: jobsResponse.status, - }); - - if (!statsResponse.ok || !jobsResponse.ok) { - const errorDetails = { - stats: statsResponse.ok - ? "OK" - : `${statsResponse.status} ${statsResponse.statusText}`, - jobs: jobsResponse.ok - ? "OK" - : `${jobsResponse.status} ${jobsResponse.statusText}`, - }; - throw new Error( - `Failed to load knowledge data: ${JSON.stringify(errorDetails)}`, - ); - } - + // Load real knowledge base data using the API service const [statsData, jobsData] = await Promise.all([ - statsResponse.json(), - jobsResponse.json(), + sreAgentApi.getKnowledgeStats(), + sreAgentApi.getKnowledgeJobs(), ]); console.log("Data loaded successfully:", { statsData, jobsData }); - setStats(statsData); - setIngestionJobs(jobsData.jobs || []); + setStats(statsData as KnowledgeStats); + setIngestionJobs((jobsData as any).jobs || []); } catch (err) { console.error("Error loading knowledge data:", err); setError(err instanceof Error ? err.message : "Unknown error occurred"); @@ -151,23 +128,14 @@ const Knowledge = () => { } try { - const response = await fetch("/api/v1/knowledge/ingest/document", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ - title: "User Added Content", - content: ingestionText, - source: "web_ui", - category: "general", - severity: "info", - }), - }); - - if (!response.ok) { - throw new Error("Failed to ingest document"); - } - - const result = await response.json(); + const result = await sreAgentApi.ingestDocument( + "User Added Content", + ingestionText, + "general", + "runbook", + "info", + ); + console.log("Ingestion result:", result); setShowIngestionForm(false); @@ -182,7 +150,7 @@ const Knowledge = () => { const searchKnowledgeBase = async ( query?: string, - thresholdOverride?: number, + _thresholdOverride?: number, ) => { const queryToUse = query || searchQuery; if (!queryToUse.trim()) { @@ -194,30 +162,20 @@ const Knowledge = () => { setIsSearching(true); setError(null); - const thresholdToUse = - typeof thresholdOverride === "number" - ? thresholdOverride - : distanceThreshold; - const params = new URLSearchParams({ - query: queryToUse, - limit: "10", - distance_threshold: String(thresholdToUse), - }); - - if (searchCategory) { - params.append("category", searchCategory); - } - - const response = await fetch(`/api/v1/knowledge/search?${params}`); - - if (!response.ok) { - throw new Error("Failed to search knowledge base"); - } + const result = await sreAgentApi.searchKnowledge( + queryToUse, + 10, + searchCategory || undefined, + ); - const result: SearchResponse = await response.json(); console.log("Search result:", result); - setSearchResults(result.results || []); + setSearchResults( + (result.results || []).map((r) => ({ + ...r, + severity: "info", // Default severity since API doesn't return it + })), + ); setExpandedResults(new Set()); // Clear expanded state on new search } catch (err) { console.error("Search error:", err); diff --git a/ui/src/pages/Schedules.tsx b/ui/src/pages/Schedules.tsx index 67eadc8a..73881ab2 100644 --- a/ui/src/pages/Schedules.tsx +++ b/ui/src/pages/Schedules.tsx @@ -65,20 +65,16 @@ const Schedules = () => { const loadData = async () => { try { setError(null); - const schedulesPromise = fetch("/api/v1/schedules/"); - const instancesPromise = sreAgentApi.listInstances(); - const [schedulesRes, instancesRes] = await Promise.allSettled([ - schedulesPromise, - instancesPromise, + sreAgentApi.listSchedules(), + sreAgentApi.listInstances(), ]); - if (schedulesRes.status !== "fulfilled" || !schedulesRes.value.ok) { + if (schedulesRes.status !== "fulfilled") { throw new Error("Failed to load schedules"); } - const schedulesData = await schedulesRes.value.json(); - setSchedules(schedulesData); + setSchedules(schedulesRes.value); if (instancesRes.status === "fulfilled") { // Map API instances to minimal shape used by this page @@ -105,26 +101,15 @@ const Schedules = () => { setError(null); const scheduleData = { name: formData.get("name") as string, - description: (formData.get("description") as string) || undefined, interval_type: formData.get("interval_type") as string, - interval_value: parseInt(formData.get("interval_value") as string), + interval_value: parseInt(formData.get("interval_value") as string, 10), redis_instance_id: (formData.get("redis_instance_id") as string) || undefined, instructions: formData.get("instructions") as string, enabled: formData.get("enabled") === "on", }; - const response = await fetch("/api/v1/schedules/", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(scheduleData), - }); - - if (!response.ok) { - throw new Error("Failed to create schedule"); - } + await sreAgentApi.createSchedule(scheduleData); await loadData(); setShowCreateForm(false); @@ -143,26 +128,15 @@ const Schedules = () => { setError(null); const updateData = { name: formData.get("name") as string, - description: (formData.get("description") as string) || undefined, interval_type: formData.get("interval_type") as string, - interval_value: parseInt(formData.get("interval_value") as string), + interval_value: parseInt(formData.get("interval_value") as string, 10), redis_instance_id: (formData.get("redis_instance_id") as string) || undefined, instructions: formData.get("instructions") as string, enabled: formData.get("enabled") === "on", }; - const response = await fetch(`/api/v1/schedules/${scheduleId}`, { - method: "PUT", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(updateData), - }); - - if (!response.ok) { - throw new Error("Failed to update schedule"); - } + await sreAgentApi.updateSchedule(scheduleId, updateData); await loadData(); setEditingSchedule(null); @@ -187,14 +161,7 @@ const Schedules = () => { try { setError(null); - const response = await fetch(`/api/v1/schedules/${scheduleId}`, { - method: "DELETE", - }); - - if (!response.ok) { - throw new Error("Failed to delete schedule"); - } - + await sreAgentApi.deleteSchedule(scheduleId); await loadData(); } catch (err) { setError( @@ -206,14 +173,7 @@ const Schedules = () => { const handleTriggerSchedule = async (scheduleId: string) => { try { setError(null); - const response = await fetch(`/api/v1/schedules/${scheduleId}/trigger`, { - method: "POST", - }); - - if (!response.ok) { - throw new Error("Failed to trigger schedule"); - } - + await sreAgentApi.triggerSchedule(scheduleId); alert("Schedule triggered successfully!"); } catch (err) { setError( @@ -225,13 +185,7 @@ const Schedules = () => { const handleViewRuns = async (schedule: Schedule) => { try { setError(null); - const response = await fetch(`/api/v1/schedules/${schedule.id}/runs`); - - if (!response.ok) { - throw new Error("Failed to load schedule runs"); - } - - const runs = await response.json(); + const runs = await sreAgentApi.getScheduleRuns(schedule.id); setSelectedScheduleRuns(runs); setShowRunsModal(true); } catch (err) { diff --git a/ui/src/pages/Settings.tsx b/ui/src/pages/Settings.tsx index 111eea54..42486635 100644 --- a/ui/src/pages/Settings.tsx +++ b/ui/src/pages/Settings.tsx @@ -9,6 +9,7 @@ import { ErrorMessage, } from "@radar/ui-kit"; import Instances from "./Instances"; +import { sreAgentApi } from "../services/sreAgentApi"; interface KnowledgeSettings { chunk_size: number; @@ -39,12 +40,8 @@ const KnowledgeSettingsSection = () => { const loadSettings = async () => { try { setError(null); - const response = await fetch("/api/v1/knowledge/settings"); - if (!response.ok) { - throw new Error("Failed to load knowledge settings"); - } - const data = await response.json(); - setSettings(data); + const data = await sreAgentApi.getKnowledgeSettings(); + setSettings(data as KnowledgeSettings); } catch (err) { setError(err instanceof Error ? err.message : "Failed to load settings"); } finally { @@ -64,19 +61,8 @@ const KnowledgeSettingsSection = () => { setIsSaving(true); setError(null); - const response = await fetch("/api/v1/knowledge/settings", { - method: "PUT", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(pendingSettings), - }); - - if (!response.ok) { - throw new Error("Failed to update settings"); - } - - const updatedSettings = await response.json(); + const updatedSettings = + await sreAgentApi.updateKnowledgeSettings(pendingSettings); setSettings(updatedSettings); setShowConfirmDialog(false); setPendingSettings(null); @@ -99,15 +85,7 @@ const KnowledgeSettingsSection = () => { setIsSaving(true); setError(null); - const response = await fetch("/api/v1/knowledge/settings/reset", { - method: "POST", - }); - - if (!response.ok) { - throw new Error("Failed to reset settings"); - } - - const defaultSettings = await response.json(); + const defaultSettings = await sreAgentApi.resetKnowledgeSettings(); setSettings(defaultSettings); } catch (err) { setError(err instanceof Error ? err.message : "Failed to reset settings"); diff --git a/ui/src/services/sreAgentApi.ts b/ui/src/services/sreAgentApi.ts index 183963ea..e082c291 100644 --- a/ui/src/services/sreAgentApi.ts +++ b/ui/src/services/sreAgentApi.ts @@ -21,6 +21,13 @@ export interface TaskStatusResponse { | "done" | "failed" | "cancelled"; + // Messages are now at top level (conversation history) + messages: Array<{ + role: string; + content: string; + metadata?: Record; + }>; + // Updates come from the latest task (progress updates, not conversation) updates: TaskUpdate[]; result?: Record; error_message?: string; @@ -306,16 +313,28 @@ class SREAgentAPI { } const thread = await response.json(); - // Derive a task-like status from thread data - const status = thread?.error_message - ? "failed" - : thread?.result - ? "completed" - : "in_progress"; + + // Messages are now at top level; fall back to context.messages for old data + const messages = Array.isArray(thread.messages) + ? thread.messages + : Array.isArray(thread?.context?.messages) + ? thread.context.messages + : []; + + // Derive status: if we have messages, likely completed; otherwise in_progress + // Note: updates/result/error_message come from latest task, not thread + const hasResponse = messages.some((m: any) => m.role === "assistant"); + const status = hasResponse ? "completed" : "in_progress"; return { thread_id: thread.thread_id, status, + messages: messages.map((m: any) => ({ + role: m.role, + content: m.content, + metadata: m.metadata, + })), + // Updates may come from the API if backend provides them from latest task updates: Array.isArray(thread.updates) ? thread.updates.map((u: any) => ({ timestamp: u.timestamp, @@ -497,11 +516,20 @@ class SREAgentAPI { }; } - // Unified transcript helper: prefer context.messages; fallback to updates + // Unified transcript helper: prefer top-level messages; fallback to context.messages and updates async getTranscript(threadId: string): Promise { const status = await this.getTaskStatus(threadId); - // Preferred: context.messages contains the entire transcript + // Preferred: top-level messages contains the entire transcript + if (status.messages && status.messages.length > 0) { + return status.messages.map((msg: any) => ({ + role: msg.role, + content: msg.content, + timestamp: msg.metadata?.timestamp || status.metadata.updated_at, + })) as ChatMessage[]; + } + + // Fallback for old data: context.messages const ctxMsgs = Array.isArray(status?.context?.messages) ? status.context.messages : []; @@ -509,11 +537,11 @@ class SREAgentAPI { return ctxMsgs.map((msg: any) => ({ role: msg.role, content: msg.content, - timestamp: msg.timestamp, + timestamp: msg.timestamp || status.metadata.updated_at, })) as ChatMessage[]; } - // Fallback: reconstruct from updates and metadata + // Last resort: reconstruct from updates and metadata const messages: ChatMessage[] = []; const initial = (status.context as any)?.original_query || status.metadata.subject; @@ -861,6 +889,238 @@ class SREAgentAPI { return response.json(); } + + // Knowledge Base Methods + async getKnowledgeStats(): Promise<{ + total_documents: number; + total_chunks: number; + last_ingestion: string | null; + }> { + const response = await fetch(`${this.tasksBaseUrl}/knowledge/stats`); + if (!response.ok) { + throw new Error(`Failed to get knowledge stats: ${response.statusText}`); + } + return response.json(); + } + + async getKnowledgeJobs(): Promise { + const response = await fetch(`${this.tasksBaseUrl}/knowledge/jobs`); + if (!response.ok) { + throw new Error(`Failed to get knowledge jobs: ${response.statusText}`); + } + return response.json(); + } + + async searchKnowledge( + query: string, + limit: number = 10, + category?: string, + ): Promise<{ + query: string; + results: Array<{ + id: string; + title: string; + content: string; + source: string; + category: string; + score: number; + }>; + total_results: number; + }> { + const params = new URLSearchParams(); + params.append("query", query); + params.append("limit", String(limit)); + if (category) { + params.append("category", category); + } + + const response = await fetch( + `${this.tasksBaseUrl}/knowledge/search?${params}`, + ); + if (!response.ok) { + throw new Error( + `Failed to search knowledge base: ${response.statusText}`, + ); + } + return response.json(); + } + + async ingestDocument( + title: string, + content: string, + category: string = "general", + docType: string = "runbook", + severity: string = "info", + ): Promise<{ message: string; document_id?: string }> { + const response = await fetch( + `${this.tasksBaseUrl}/knowledge/ingest/document`, + { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + title, + content, + category, + doc_type: docType, + severity, + }), + }, + ); + + if (!response.ok) { + throw new Error(`Failed to ingest document: ${response.statusText}`); + } + return response.json(); + } + + // System Health Methods + async getSystemHealth(): Promise<{ + status: string; + components: Record; + version?: string; + }> { + const response = await fetch(`${this.tasksBaseUrl}/health`); + if (!response.ok) { + throw new Error(`Failed to get system health: ${response.statusText}`); + } + return response.json(); + } + + // Schedule Methods + async listSchedules(): Promise { + const response = await fetch(`${this.tasksBaseUrl}/schedules/`); + if (!response.ok) { + throw new Error(`Failed to list schedules: ${response.statusText}`); + } + return response.json(); + } + + async createSchedule(scheduleData: { + name: string; + interval_type: string; + interval_value: number; + redis_instance_id?: string; + instructions: string; + enabled: boolean; + }): Promise { + const response = await fetch(`${this.tasksBaseUrl}/schedules/`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(scheduleData), + }); + if (!response.ok) { + throw new Error(`Failed to create schedule: ${response.statusText}`); + } + return response.json(); + } + + async updateSchedule( + scheduleId: string, + updateData: { + name?: string; + interval_type?: string; + interval_value?: number; + redis_instance_id?: string; + instructions?: string; + enabled?: boolean; + }, + ): Promise { + const response = await fetch( + `${this.tasksBaseUrl}/schedules/${scheduleId}`, + { + method: "PUT", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(updateData), + }, + ); + if (!response.ok) { + throw new Error(`Failed to update schedule: ${response.statusText}`); + } + return response.json(); + } + + async deleteSchedule(scheduleId: string): Promise { + const response = await fetch( + `${this.tasksBaseUrl}/schedules/${scheduleId}`, + { + method: "DELETE", + }, + ); + if (!response.ok) { + throw new Error(`Failed to delete schedule: ${response.statusText}`); + } + } + + async triggerSchedule(scheduleId: string): Promise { + const response = await fetch( + `${this.tasksBaseUrl}/schedules/${scheduleId}/trigger`, + { method: "POST" }, + ); + if (!response.ok) { + throw new Error(`Failed to trigger schedule: ${response.statusText}`); + } + return response.json(); + } + + async getScheduleRuns(scheduleId: string): Promise { + const response = await fetch( + `${this.tasksBaseUrl}/schedules/${scheduleId}/runs`, + ); + if (!response.ok) { + throw new Error(`Failed to get schedule runs: ${response.statusText}`); + } + return response.json(); + } + + // Knowledge Settings Methods + async getKnowledgeSettings(): Promise<{ + chunk_size: number; + chunk_overlap: number; + splitting_strategy: string; + embedding_model: string; + }> { + const response = await fetch(`${this.tasksBaseUrl}/knowledge/settings`); + if (!response.ok) { + throw new Error( + `Failed to get knowledge settings: ${response.statusText}`, + ); + } + return response.json(); + } + + async updateKnowledgeSettings(settings: { + chunk_size?: number; + chunk_overlap?: number; + splitting_strategy?: string; + embedding_model?: string; + }): Promise { + const response = await fetch(`${this.tasksBaseUrl}/knowledge/settings`, { + method: "PUT", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(settings), + }); + if (!response.ok) { + throw new Error( + `Failed to update knowledge settings: ${response.statusText}`, + ); + } + return response.json(); + } + + async resetKnowledgeSettings(): Promise { + const response = await fetch( + `${this.tasksBaseUrl}/knowledge/settings/reset`, + { + method: "POST", + }, + ); + if (!response.ok) { + throw new Error( + `Failed to reset knowledge settings: ${response.statusText}`, + ); + } + return response.json(); + } } // Export singleton instance diff --git a/ui/ui-kit/package.json b/ui/ui-kit/package.json index 60ee2b5e..1a49d032 100644 --- a/ui/ui-kit/package.json +++ b/ui/ui-kit/package.json @@ -25,8 +25,8 @@ "type": "module", "exports": { ".": { - "import": "./dist/index.js", - "types": "./dist/index.d.ts" + "types": "./dist/index.d.ts", + "import": "./dist/index.js" }, "./styles": "./dist/styles.css" }, @@ -46,9 +46,6 @@ "format:check": "prettier --check \"src/**/*.{ts,tsx}\"", "lint": "eslint src --ext .ts,.tsx", "lint:fix": "eslint src --ext .ts,.tsx --fix", - "pre-commit": "pre-commit run --all-files", - "pre-commit:install": "pre-commit install --hook-type pre-commit --hook-type pre-push", - "prepare": "pre-commit install", "prepublishOnly": "npm run clean && npm run build", "storybook": "storybook dev -p 6006", "test": "vitest run", diff --git a/uv.lock b/uv.lock index 39ea1f0f..441e9e73 100644 --- a/uv.lock +++ b/uv.lock @@ -1118,6 +1118,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[[package]] +name = "httpx-sse" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" }, +] + [[package]] name = "huggingface-hub" version = "0.34.4" @@ -1810,6 +1819,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/27/1a/1f68f9ba0c207934b35b86a8ca3aad8395a3d6dd7921c0686e23853ff5a9/mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e", size = 7350, upload-time = "2022-01-24T01:14:49.62Z" }, ] +[[package]] +name = "mcp" +version = "1.23.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "jsonschema" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "pyjwt", extra = ["crypto"] }, + { name = "python-multipart" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/a4/d06a303f45997e266f2c228081abe299bbcba216cb806128e2e49095d25f/mcp-1.23.3.tar.gz", hash = "sha256:b3b0da2cc949950ce1259c7bfc1b081905a51916fcd7c8182125b85e70825201", size = 600697, upload-time = "2025-12-09T16:04:37.351Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/c6/13c1a26b47b3f3a3b480783001ada4268917c9f42d78a079c336da2e75e5/mcp-1.23.3-py3-none-any.whl", hash = "sha256:32768af4b46a1b4f7df34e2bfdf5c6011e7b63d7f1b0e321d0fdef4cd6082031", size = 231570, upload-time = "2025-12-09T16:04:35.56Z" }, +] + [[package]] name = "mdurl" version = "0.1.2" @@ -3220,6 +3254,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyjwt" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, +] + +[package.optional-dependencies] +crypto = [ + { name = "cryptography" }, +] + [[package]] name = "pylint" version = "4.0.4" @@ -3471,7 +3519,9 @@ dependencies = [ { name = "langgraph" }, { name = "langgraph-checkpoint-redis" }, { name = "markdownify" }, + { name = "mcp" }, { name = "nbformat" }, + { name = "nltk" }, { name = "openai" }, { name = "opentelemetry-api" }, { name = "opentelemetry-exporter-otlp-proto-http" }, @@ -3531,7 +3581,9 @@ requires-dist = [ { name = "langgraph", specifier = ">=0.2.0" }, { name = "langgraph-checkpoint-redis", specifier = ">=0.1.0" }, { name = "markdownify", specifier = ">=0.11.6" }, + { name = "mcp", specifier = ">=1.23.3" }, { name = "nbformat", specifier = ">=5.9.0" }, + { name = "nltk", specifier = ">=3.9.1" }, { name = "openai", specifier = ">=1.0.0" }, { name = "opentelemetry-api", specifier = ">=1.21.0" }, { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.34.0" }, @@ -4206,6 +4258,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b8/d9/13bdde6521f322861fab67473cec4b1cc8999f3871953531cf61945fad92/sqlalchemy-2.0.43-py3-none-any.whl", hash = "sha256:1681c21dd2ccee222c2fe0bef671d1aef7c504087c9c4e800371cfcc8ac966fc", size = 1924759, upload-time = "2025-08-11T15:39:53.024Z" }, ] +[[package]] +name = "sse-starlette" +version = "3.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/db/3c/fa6517610dc641262b77cc7bf994ecd17465812c1b0585fe33e11be758ab/sse_starlette-3.0.3.tar.gz", hash = "sha256:88cfb08747e16200ea990c8ca876b03910a23b547ab3bd764c0d8eb81019b971", size = 21943, upload-time = "2025-10-30T18:44:20.117Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/a0/984525d19ca5c8a6c33911a0c164b11490dd0f90ff7fd689f704f84e9a11/sse_starlette-3.0.3-py3-none-any.whl", hash = "sha256:af5bf5a6f3933df1d9c7f8539633dc8444ca6a97ab2e2a7cd3b6e431ac03a431", size = 11765, upload-time = "2025-10-30T18:44:18.834Z" }, +] + [[package]] name = "starlette" version = "0.47.2"