Skip to content

Commit 1a1be72

Browse files
authored
Merge pull request #23 from redis-applied-ai/feat/new-tool-provider-unblocks-mcp
refactor(tools): Refactor tool provider system for MCP support
2 parents 7954236 + 938cec5 commit 1a1be72

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+3853
-3603
lines changed

docs/how-to/tool-providers.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ This is an early release with an initial set of built-in providers. More provide
1414
- Config: `TOOLS_PROMETHEUS_URL`, `TOOLS_PROMETHEUS_DISABLE_SSL`
1515
- **Loki logs**: `redis_sre_agent.tools.logs.loki.provider.LokiToolProvider`
1616
- Config: `TOOLS_LOKI_URL`, `TOOLS_LOKI_TENANT_ID`, `TOOLS_LOKI_TIMEOUT`
17-
- **Redis CLI diagnostics**: `redis_sre_agent.tools.diagnostics.redis_cli.provider.RedisCliToolProvider`
18-
- Runs Redis CLI commands against target instances
17+
- **Redis command diagnostics**: `redis_sre_agent.tools.diagnostics.redis_command.provider.RedisCommandToolProvider`
18+
- Runs Redis commands against target instances
1919
- **Host telemetry**: `redis_sre_agent.tools.host_telemetry.provider.HostTelemetryToolProvider`
2020
- System-level metrics and diagnostics
2121

@@ -41,18 +41,22 @@ Implement a ToolProvider subclass that defines tool schemas and resolves calls.
4141
### Minimal skeleton
4242

4343
```python
44-
from typing import Any, Dict, List, Optional
44+
from typing import Any, Dict, List
4545
from redis_sre_agent.tools.protocols import ToolProvider
46-
from redis_sre_agent.tools.tool_definition import ToolDefinition
46+
from redis_sre_agent.tools.models import ToolDefinition, ToolCapability
47+
4748

4849
class MyMetricsProvider(ToolProvider):
49-
provider_name = "my_metrics"
50+
@property
51+
def provider_name(self) -> str:
52+
return "my_metrics"
5053

5154
def create_tool_schemas(self) -> List[ToolDefinition]:
5255
return [
5356
ToolDefinition(
5457
name=self._make_tool_name("query"),
5558
description="Query my metrics backend using a query string.",
59+
capability=ToolCapability.METRICS,
5660
parameters={
5761
"type": "object",
5862
"properties": {"query": {"type": "string"}},
@@ -61,17 +65,13 @@ class MyMetricsProvider(ToolProvider):
6165
)
6266
]
6367

64-
async def resolve_tool_call(self, tool_name: str, args: Dict[str, Any]):
65-
op = self.resolve_operation(tool_name)
66-
if op == "query":
67-
return await self.query(**args)
68-
raise ValueError(f"Unknown operation: {op}")
69-
7068
async def query(self, query: str) -> Dict[str, Any]:
7169
# Implement your backend call
7270
return {"status": "success", "query": query, "data": []}
7371
```
7472

73+
The base class `tools()` method automatically wires tool names to provider methods. When an LLM invokes `my_metrics_{hash}_query`, the framework calls `self.query(**args)` directly. No manual `resolve_tool_call()` implementation is required.
74+
7575
### Register your provider
7676

7777
- Install your package into the same environment as the agent (e.g., `pip install -e /path/to/pkg`)

pyproject.toml

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,26 @@ dependencies = [
5959
"opentelemetry-instrumentation-openai>=0.47.5",
6060
]
6161

62+
[dependency-groups]
63+
dev = [
64+
"pytest>=7.0.0",
65+
"pytest-asyncio>=0.21.0",
66+
"pytest-cov>=4.1.0",
67+
"ruff>=0.3.0",
68+
"black>=23.0.0",
69+
"mypy>=1.8.0",
70+
"testcontainers>=3.7.0",
71+
"pre-commit>=3.6.0",
72+
"safety>=3.0.0",
73+
"bandit>=1.7.0",
74+
# OpenAPI client generator
75+
"openapi-python-client>=0.21.0",
76+
"mkdocs>=1.6.1",
77+
"mkdocs-material>=9.6.22",
78+
"mfcqi>=0.0.4",
79+
]
80+
81+
6282
[project.scripts]
6383
redis-sre-agent = "redis_sre_agent.cli:main"
6484

@@ -77,22 +97,6 @@ include = [
7797

7898
[tool.uv]
7999
default-groups = []
80-
dev-dependencies = [
81-
"pytest>=7.0.0",
82-
"pytest-asyncio>=0.21.0",
83-
"pytest-cov>=4.1.0",
84-
"ruff>=0.3.0",
85-
"black>=23.0.0",
86-
"mypy>=1.8.0",
87-
"testcontainers>=3.7.0",
88-
"pre-commit>=3.6.0",
89-
"safety>=3.0.0",
90-
"bandit>=1.7.0",
91-
# OpenAPI client generator
92-
"openapi-python-client>=0.21.0",
93-
"mkdocs>=1.6.1",
94-
"mkdocs-material>=9.6.22",
95-
]
96100

97101
[tool.pytest.ini_options]
98102
testpaths = ["tests"]

redis_sre_agent/agent/helpers.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -85,23 +85,23 @@ def sanitize_messages_for_llm(msgs: List[Any]) -> List[Any]:
8585
for m in msgs:
8686
if isinstance(m, _AI):
8787
try:
88-
for tc in getattr(m, "tool_calls", []) or []:
88+
for tc in m.tool_calls or []:
8989
if isinstance(tc, dict):
9090
tid = tc.get("id") or tc.get("tool_call_id")
9191
if tid:
9292
seen_tool_ids.add(tid)
9393
except Exception:
9494
pass
9595
clean.append(m)
96-
elif isinstance(m, _TM) or getattr(m, "type", "") == "tool":
97-
tid = getattr(m, "tool_call_id", None)
96+
elif isinstance(m, _TM) or m.type == "tool":
97+
tid = m.tool_call_id
9898
if tid and tid in seen_tool_ids:
9999
clean.append(m)
100100
else:
101101
continue
102102
else:
103103
clean.append(m)
104-
while clean and (isinstance(clean[0], _TM) or getattr(clean[0], "type", "") == "tool"):
104+
while clean and (isinstance(clean[0], _TM) or clean[0].type == "tool"):
105105
clean = clean[1:]
106106
return clean
107107

@@ -122,26 +122,26 @@ def _compact_messages_tail(msgs: List[Any], limit: int = 6) -> List[Dict[str, An
122122
tail = msgs[-limit:] if msgs else []
123123
compact: List[Dict[str, Any]] = []
124124
for m in tail:
125-
role = getattr(m, "type", m.__class__.__name__.lower())
125+
role = m.type if m.type else m.__class__.__name__.lower()
126126
row: Dict[str, Any] = {"role": role}
127127
try:
128-
is_ai = (_AI is not None and isinstance(m, _AI)) or getattr(m, "type", "") in (
128+
is_ai = (_AI is not None and isinstance(m, _AI)) or m.type in (
129129
"ai",
130130
"assistant",
131131
)
132132
if is_ai:
133133
ids: List[str] = []
134-
for tc in getattr(m, "tool_calls", []) or []:
134+
for tc in m.tool_calls or []:
135135
if isinstance(tc, dict):
136136
tid = tc.get("id") or tc.get("tool_call_id")
137137
if tid:
138138
ids.append(tid)
139139
if ids:
140140
row["tool_calls"] = ids
141-
is_tool = (_TM is not None and isinstance(m, _TM)) or getattr(m, "type", "") == "tool"
141+
is_tool = (_TM is not None and isinstance(m, _TM)) or m.type == "tool"
142142
if is_tool:
143-
row["tool_call_id"] = getattr(m, "tool_call_id", None)
144-
name = getattr(m, "name", None)
143+
row["tool_call_id"] = m.tool_call_id
144+
name = m.name
145145
if name:
146146
row["name"] = name
147147
except Exception:
@@ -193,7 +193,7 @@ def build_result_envelope(
193193

194194
from .models import ResultEnvelope
195195

196-
content = getattr(tool_message, "content", None)
196+
content = tool_message.content
197197
data_obj = None
198198
if isinstance(content, str) and content:
199199
try:
@@ -210,9 +210,8 @@ def _extract_operation_from_tool_name(full: str) -> str:
210210
parts = full.split(".")
211211
return parts[-1] if parts else full
212212

213-
description = (
214-
getattr(tooldefs_by_name.get(tool_name), "description", None) if tool_name else None
215-
)
213+
tdef = tooldefs_by_name.get(tool_name) if tool_name else None
214+
description = tdef.description if tdef else None
216215
env = ResultEnvelope(
217216
tool_key=tool_name or "tool",
218217
name=_extract_operation_from_tool_name(tool_name or "tool"),
@@ -224,13 +223,15 @@ def _extract_operation_from_tool_name(full: str) -> str:
224223
return env.model_dump()
225224

226225

227-
async def build_adapters_for_tooldefs(
228-
tool_manager: Any, tooldefs: List[Any]
229-
) -> tuple[list[dict], list[Any]]:
230-
"""Create OpenAI tool schemas and LangChain StructuredTool adapters for ToolDefinitions.
226+
async def build_adapters_for_tooldefs(tool_manager: Any, tooldefs: List[Any]) -> list[Any]:
227+
"""Create LangChain StructuredTool adapters for ToolDefinitions.
231228
232-
Returns (tool_schemas, adapters)
229+
Each adapter wraps :meth:`ToolManager.resolve_tool_call` so that tools can
230+
be executed either via LangGraph's :class:`ToolNode` or directly via the
231+
manager. The same adapters can also be passed to ``ChatOpenAI.bind_tools``
232+
so we do not need to maintain separate OpenAI-specific tool schemas.
233233
"""
234+
234235
try:
235236
from typing import Any as _Any
236237

@@ -241,7 +242,7 @@ async def build_adapters_for_tooldefs(
241242
from pydantic import create_model as _create_model
242243
except Exception:
243244
# Best-effort fallback (should not happen in runtime)
244-
return [], []
245+
return []
245246

246247
def _args_model_from_parameters(tool_name: str, params: dict) -> type[_BaseModel]:
247248
props = (params or {}).get("properties", {}) or {}
@@ -261,20 +262,19 @@ def _args_model_from_parameters(tool_name: str, params: dict) -> type[_BaseModel
261262
pass
262263
return args_model
263264

264-
tool_schemas: list[dict] = [t.to_openai_schema() for t in (tooldefs or [])]
265265
adapters: list[_StructuredTool] = []
266266
for tdef in tooldefs or []:
267267

268268
async def _exec_fn(_name=tdef.name, **kwargs):
269269
return await tool_manager.resolve_tool_call(_name, kwargs or {})
270270

271-
ArgsModel = _args_model_from_parameters(tdef.name, getattr(tdef, "parameters", {}) or {}) # noqa: N806
271+
args_model = _args_model_from_parameters(tdef.name, tdef.parameters or {})
272272
adapters.append(
273273
_StructuredTool.from_function(
274274
coroutine=_exec_fn,
275275
name=tdef.name,
276-
description=getattr(tdef, "description", "") or "",
277-
args_schema=ArgsModel,
276+
description=tdef.description or "",
277+
args_schema=args_model,
278278
)
279279
)
280-
return tool_schemas, adapters
280+
return adapters

0 commit comments

Comments
 (0)