Skip to content

Commit 0684dbc

Browse files
committed
refactor(tools): Refactor tool provider system for MCP support
- Introduce Tool model combining schema, metadata, and invoke callable - Move ToolCapability, ToolMetadata, and ToolDefinition to tools/models.py - Add automatic wiring in ToolProvider.tools() that binds tool names to methods - Update all providers to use create_tool_schemas() + method pattern - Standardize provider_name as @Property across all providers - Rename redis_cli provider to redis_command - Update documentation and tests to match new patterns - Fix pre-existing linting issues (N806, N814, E402)
1 parent 7954236 commit 0684dbc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+3781
-3460
lines changed

docs/how-to/tool-providers.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ This is an early release with an initial set of built-in providers. More provide
1414
- Config: `TOOLS_PROMETHEUS_URL`, `TOOLS_PROMETHEUS_DISABLE_SSL`
1515
- **Loki logs**: `redis_sre_agent.tools.logs.loki.provider.LokiToolProvider`
1616
- Config: `TOOLS_LOKI_URL`, `TOOLS_LOKI_TENANT_ID`, `TOOLS_LOKI_TIMEOUT`
17-
- **Redis CLI diagnostics**: `redis_sre_agent.tools.diagnostics.redis_cli.provider.RedisCliToolProvider`
18-
- Runs Redis CLI commands against target instances
17+
- **Redis command diagnostics**: `redis_sre_agent.tools.diagnostics.redis_command.provider.RedisCommandToolProvider`
18+
- Runs Redis commands against target instances
1919
- **Host telemetry**: `redis_sre_agent.tools.host_telemetry.provider.HostTelemetryToolProvider`
2020
- System-level metrics and diagnostics
2121

@@ -41,18 +41,22 @@ Implement a ToolProvider subclass that defines tool schemas and resolves calls.
4141
### Minimal skeleton
4242

4343
```python
44-
from typing import Any, Dict, List, Optional
44+
from typing import Any, Dict, List
4545
from redis_sre_agent.tools.protocols import ToolProvider
46-
from redis_sre_agent.tools.tool_definition import ToolDefinition
46+
from redis_sre_agent.tools.models import ToolDefinition, ToolCapability
47+
4748

4849
class MyMetricsProvider(ToolProvider):
49-
provider_name = "my_metrics"
50+
@property
51+
def provider_name(self) -> str:
52+
return "my_metrics"
5053

5154
def create_tool_schemas(self) -> List[ToolDefinition]:
5255
return [
5356
ToolDefinition(
5457
name=self._make_tool_name("query"),
5558
description="Query my metrics backend using a query string.",
59+
capability=ToolCapability.METRICS,
5660
parameters={
5761
"type": "object",
5862
"properties": {"query": {"type": "string"}},
@@ -61,17 +65,13 @@ class MyMetricsProvider(ToolProvider):
6165
)
6266
]
6367

64-
async def resolve_tool_call(self, tool_name: str, args: Dict[str, Any]):
65-
op = self.resolve_operation(tool_name)
66-
if op == "query":
67-
return await self.query(**args)
68-
raise ValueError(f"Unknown operation: {op}")
69-
7068
async def query(self, query: str) -> Dict[str, Any]:
7169
# Implement your backend call
7270
return {"status": "success", "query": query, "data": []}
7371
```
7472

73+
The base class `tools()` method automatically wires tool names to provider methods. When an LLM invokes `my_metrics_{hash}_query`, the framework calls `self.query(**args)` directly. No manual `resolve_tool_call()` implementation is required.
74+
7575
### Register your provider
7676

7777
- Install your package into the same environment as the agent (e.g., `pip install -e /path/to/pkg`)

pyproject.toml

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,26 @@ dependencies = [
5959
"opentelemetry-instrumentation-openai>=0.47.5",
6060
]
6161

62+
[dependency-groups]
63+
dev = [
64+
"pytest>=7.0.0",
65+
"pytest-asyncio>=0.21.0",
66+
"pytest-cov>=4.1.0",
67+
"ruff>=0.3.0",
68+
"black>=23.0.0",
69+
"mypy>=1.8.0",
70+
"testcontainers>=3.7.0",
71+
"pre-commit>=3.6.0",
72+
"safety>=3.0.0",
73+
"bandit>=1.7.0",
74+
# OpenAPI client generator
75+
"openapi-python-client>=0.21.0",
76+
"mkdocs>=1.6.1",
77+
"mkdocs-material>=9.6.22",
78+
"mfcqi>=0.0.4",
79+
]
80+
81+
6282
[project.scripts]
6383
redis-sre-agent = "redis_sre_agent.cli:main"
6484

@@ -77,22 +97,6 @@ include = [
7797

7898
[tool.uv]
7999
default-groups = []
80-
dev-dependencies = [
81-
"pytest>=7.0.0",
82-
"pytest-asyncio>=0.21.0",
83-
"pytest-cov>=4.1.0",
84-
"ruff>=0.3.0",
85-
"black>=23.0.0",
86-
"mypy>=1.8.0",
87-
"testcontainers>=3.7.0",
88-
"pre-commit>=3.6.0",
89-
"safety>=3.0.0",
90-
"bandit>=1.7.0",
91-
# OpenAPI client generator
92-
"openapi-python-client>=0.21.0",
93-
"mkdocs>=1.6.1",
94-
"mkdocs-material>=9.6.22",
95-
]
96100

97101
[tool.pytest.ini_options]
98102
testpaths = ["tests"]

redis_sre_agent/agent/helpers.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -85,23 +85,23 @@ def sanitize_messages_for_llm(msgs: List[Any]) -> List[Any]:
8585
for m in msgs:
8686
if isinstance(m, _AI):
8787
try:
88-
for tc in getattr(m, "tool_calls", []) or []:
88+
for tc in m.tool_calls or []:
8989
if isinstance(tc, dict):
9090
tid = tc.get("id") or tc.get("tool_call_id")
9191
if tid:
9292
seen_tool_ids.add(tid)
9393
except Exception:
9494
pass
9595
clean.append(m)
96-
elif isinstance(m, _TM) or getattr(m, "type", "") == "tool":
97-
tid = getattr(m, "tool_call_id", None)
96+
elif isinstance(m, _TM) or m.type == "tool":
97+
tid = m.tool_call_id
9898
if tid and tid in seen_tool_ids:
9999
clean.append(m)
100100
else:
101101
continue
102102
else:
103103
clean.append(m)
104-
while clean and (isinstance(clean[0], _TM) or getattr(clean[0], "type", "") == "tool"):
104+
while clean and (isinstance(clean[0], _TM) or clean[0].type == "tool"):
105105
clean = clean[1:]
106106
return clean
107107

@@ -122,26 +122,26 @@ def _compact_messages_tail(msgs: List[Any], limit: int = 6) -> List[Dict[str, An
122122
tail = msgs[-limit:] if msgs else []
123123
compact: List[Dict[str, Any]] = []
124124
for m in tail:
125-
role = getattr(m, "type", m.__class__.__name__.lower())
125+
role = m.type if m.type else m.__class__.__name__.lower()
126126
row: Dict[str, Any] = {"role": role}
127127
try:
128-
is_ai = (_AI is not None and isinstance(m, _AI)) or getattr(m, "type", "") in (
128+
is_ai = (_AI is not None and isinstance(m, _AI)) or m.type in (
129129
"ai",
130130
"assistant",
131131
)
132132
if is_ai:
133133
ids: List[str] = []
134-
for tc in getattr(m, "tool_calls", []) or []:
134+
for tc in m.tool_calls or []:
135135
if isinstance(tc, dict):
136136
tid = tc.get("id") or tc.get("tool_call_id")
137137
if tid:
138138
ids.append(tid)
139139
if ids:
140140
row["tool_calls"] = ids
141-
is_tool = (_TM is not None and isinstance(m, _TM)) or getattr(m, "type", "") == "tool"
141+
is_tool = (_TM is not None and isinstance(m, _TM)) or m.type == "tool"
142142
if is_tool:
143-
row["tool_call_id"] = getattr(m, "tool_call_id", None)
144-
name = getattr(m, "name", None)
143+
row["tool_call_id"] = m.tool_call_id
144+
name = m.name
145145
if name:
146146
row["name"] = name
147147
except Exception:
@@ -193,7 +193,7 @@ def build_result_envelope(
193193

194194
from .models import ResultEnvelope
195195

196-
content = getattr(tool_message, "content", None)
196+
content = tool_message.content
197197
data_obj = None
198198
if isinstance(content, str) and content:
199199
try:
@@ -210,9 +210,8 @@ def _extract_operation_from_tool_name(full: str) -> str:
210210
parts = full.split(".")
211211
return parts[-1] if parts else full
212212

213-
description = (
214-
getattr(tooldefs_by_name.get(tool_name), "description", None) if tool_name else None
215-
)
213+
tdef = tooldefs_by_name.get(tool_name) if tool_name else None
214+
description = tdef.description if tdef else None
216215
env = ResultEnvelope(
217216
tool_key=tool_name or "tool",
218217
name=_extract_operation_from_tool_name(tool_name or "tool"),
@@ -224,13 +223,15 @@ def _extract_operation_from_tool_name(full: str) -> str:
224223
return env.model_dump()
225224

226225

227-
async def build_adapters_for_tooldefs(
228-
tool_manager: Any, tooldefs: List[Any]
229-
) -> tuple[list[dict], list[Any]]:
230-
"""Create OpenAI tool schemas and LangChain StructuredTool adapters for ToolDefinitions.
226+
async def build_adapters_for_tooldefs(tool_manager: Any, tooldefs: List[Any]) -> list[Any]:
227+
"""Create LangChain StructuredTool adapters for ToolDefinitions.
231228
232-
Returns (tool_schemas, adapters)
229+
Each adapter wraps :meth:`ToolManager.resolve_tool_call` so that tools can
230+
be executed either via LangGraph's :class:`ToolNode` or directly via the
231+
manager. The same adapters can also be passed to ``ChatOpenAI.bind_tools``
232+
so we do not need to maintain separate OpenAI-specific tool schemas.
233233
"""
234+
234235
try:
235236
from typing import Any as _Any
236237

@@ -241,7 +242,7 @@ async def build_adapters_for_tooldefs(
241242
from pydantic import create_model as _create_model
242243
except Exception:
243244
# Best-effort fallback (should not happen in runtime)
244-
return [], []
245+
return []
245246

246247
def _args_model_from_parameters(tool_name: str, params: dict) -> type[_BaseModel]:
247248
props = (params or {}).get("properties", {}) or {}
@@ -261,20 +262,19 @@ def _args_model_from_parameters(tool_name: str, params: dict) -> type[_BaseModel
261262
pass
262263
return args_model
263264

264-
tool_schemas: list[dict] = [t.to_openai_schema() for t in (tooldefs or [])]
265265
adapters: list[_StructuredTool] = []
266266
for tdef in tooldefs or []:
267267

268268
async def _exec_fn(_name=tdef.name, **kwargs):
269269
return await tool_manager.resolve_tool_call(_name, kwargs or {})
270270

271-
ArgsModel = _args_model_from_parameters(tdef.name, getattr(tdef, "parameters", {}) or {}) # noqa: N806
271+
args_model = _args_model_from_parameters(tdef.name, tdef.parameters or {})
272272
adapters.append(
273273
_StructuredTool.from_function(
274274
coroutine=_exec_fn,
275275
name=tdef.name,
276-
description=getattr(tdef, "description", "") or "",
277-
args_schema=ArgsModel,
276+
description=tdef.description or "",
277+
args_schema=args_model,
278278
)
279279
)
280-
return tool_schemas, adapters
280+
return adapters

redis_sre_agent/agent/knowledge_agent.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -97,18 +97,13 @@ def __init__(self, progress_callback: Optional[Callable[[str, str], Awaitable[No
9797

9898
logger.info("Knowledge-only agent initialized (tools loaded per-query)")
9999

100-
def _build_workflow(self, tool_mgr: ToolManager) -> StateGraph:
100+
def _build_workflow(self, tool_mgr: ToolManager, llm_with_tools: ChatOpenAI) -> StateGraph:
101101
"""Build the LangGraph workflow for knowledge-only queries.
102102
103103
Args:
104104
tool_mgr: ToolManager instance with knowledge tools loaded
105105
"""
106106

107-
# Bind tools to LLM for this workflow
108-
tools = tool_mgr.get_tools()
109-
tool_schemas = [tool.to_openai_schema() for tool in tools]
110-
llm_with_tools = self.llm.bind_tools(tool_schemas)
111-
112107
async def agent_node(state: KnowledgeAgentState) -> KnowledgeAgentState:
113108
"""Main agent node for knowledge queries."""
114109
messages = state["messages"]
@@ -131,25 +126,23 @@ def _sanitize_messages_for_llm(msgs: list[BaseMessage]) -> list[BaseMessage]:
131126
for m in msgs:
132127
if isinstance(m, AIMessage):
133128
try:
134-
for tc in getattr(m, "tool_calls", []) or []:
129+
for tc in m.tool_calls or []:
135130
if isinstance(tc, dict):
136131
tid = tc.get("id") or tc.get("tool_call_id")
137132
if tid:
138133
seen_tool_ids.add(tid)
139134
except Exception:
140135
pass
141136
clean.append(m)
142-
elif isinstance(m, _TM) or getattr(m, "type", "") == "tool":
143-
tid = getattr(m, "tool_call_id", None)
137+
elif isinstance(m, _TM) or m.type == "tool":
138+
tid = m.tool_call_id
144139
if tid and tid in seen_tool_ids:
145140
clean.append(m)
146141
else:
147142
continue
148143
else:
149144
clean.append(m)
150-
while clean and (
151-
isinstance(clean[0], _TM) or getattr(clean[0], "type", "") == "tool"
152-
):
145+
while clean and (isinstance(clean[0], _TM) or clean[0].type == "tool"):
153146
clean = clean[1:]
154147
return clean
155148

@@ -192,8 +185,8 @@ def _sanitize_messages_for_llm(msgs: list[BaseMessage]) -> list[BaseMessage]:
192185
)
193186
# Coerce non-Message responses (e.g., simple mocks) into AIMessage
194187
if not isinstance(response, BaseMessage):
195-
content = getattr(response, "content", None)
196-
tool_calls = getattr(response, "tool_calls", None)
188+
content = response.content
189+
tool_calls = response.tool_calls
197190
response = AIMessage(
198191
content=str(content) if content is not None else "", tool_calls=tool_calls
199192
)
@@ -205,7 +198,7 @@ def _sanitize_messages_for_llm(msgs: list[BaseMessage]) -> list[BaseMessage]:
205198
state["messages"] = messages + [response]
206199

207200
# Store tool calls for potential execution
208-
if hasattr(response, "tool_calls") and response.tool_calls:
201+
if response.tool_calls:
209202
state["current_tool_calls"] = [
210203
{
211204
"id": tc.get("id", ""),
@@ -235,14 +228,14 @@ async def safe_tool_node(state: KnowledgeAgentState) -> KnowledgeAgentState:
235228
last_message = messages[-1] if messages else None
236229

237230
# Verify we have tool calls to execute
238-
if not (hasattr(last_message, "tool_calls") and last_message.tool_calls):
231+
if not (last_message and last_message.tool_calls):
239232
logger.warning("safe_tool_node called without tool_calls in last message")
240233
return state
241234

242235
try:
243236
# Emit provider-supplied status updates before executing tools
244237
try:
245-
pending = getattr(last_message, "tool_calls", []) or []
238+
pending = last_message.tool_calls or []
246239
if self.progress_callback:
247240
for tc in pending:
248241
tool_name = tc.get("name")
@@ -255,9 +248,7 @@ async def safe_tool_node(state: KnowledgeAgentState) -> KnowledgeAgentState:
255248
pass
256249

257250
# Execute tools using ToolManager (wrapped in OTel span)
258-
_tool_names = [
259-
tc.get("name", "") for tc in (getattr(last_message, "tool_calls", []) or [])
260-
]
251+
_tool_names = [tc.get("name", "") for tc in (last_message.tool_calls or [])]
261252
with tracer.start_as_current_span(
262253
"knowledge.tools.execute",
263254
attributes={
@@ -347,7 +338,7 @@ def should_continue(state: KnowledgeAgentState) -> str:
347338
try:
348339
from redis_sre_agent.core.config import settings as _settings
349340

350-
_budget = int(getattr(_settings, "max_tool_calls_per_stage", 3))
341+
_budget = int(_settings.max_tool_calls_per_stage)
351342
except Exception:
352343
_budget = 3
353344
prev_exec = int(state.get("tool_calls_executed", 0) or 0)
@@ -359,11 +350,7 @@ def should_continue(state: KnowledgeAgentState) -> str:
359350
return END
360351

361352
# If the last message has tool calls, execute them
362-
if (
363-
hasattr(last_message, "tool_calls")
364-
and last_message.tool_calls
365-
and len(last_message.tool_calls) > 0
366-
):
353+
if last_message.tool_calls and len(last_message.tool_calls) > 0:
367354
return "tools"
368355

369356
return END
@@ -434,10 +421,17 @@ async def process_query(
434421

435422
# Create ToolManager with Redis instance-independent tools
436423
async with ToolManager(redis_instance=None) as tool_mgr:
437-
logger.info(f"Loaded {len(tool_mgr.get_tools())} usable without Redis instance details")
424+
tools = tool_mgr.get_tools()
425+
logger.info(f"Loaded {len(tools)} tools usable without Redis instance details")
426+
427+
# Build StructuredTool adapters and bind them to the LLM
428+
from .helpers import build_adapters_for_tooldefs as _build_adapters
429+
430+
adapters = await _build_adapters(tool_mgr, tools)
431+
llm_with_tools = self.llm.bind_tools(adapters)
438432

439-
# Build workflow with tools
440-
workflow = self._build_workflow(tool_mgr)
433+
# Build workflow with tools and bound LLM
434+
workflow = self._build_workflow(tool_mgr, llm_with_tools)
441435

442436
# Create initial state with conversation history
443437
initial_messages = []

0 commit comments

Comments
 (0)