From a2c3dbb664ba7c94d2a003f60e56411acfe93d60 Mon Sep 17 00:00:00 2001
From: braisedpork1964 <497494458@qq.com>
Date: Tue, 24 Jun 2025 06:07:06 +0000
Subject: [PATCH 1/5] support MCP clients
---
.../run_async_agent_api_model_with_mcp.py | 127 ++++++++++++++
lagent/actions/__init__.py | 2 +
lagent/actions/mcp_client.py | 159 ++++++++++++++++++
requirements/runtime.txt | 3 +-
4 files changed, 290 insertions(+), 1 deletion(-)
create mode 100644 examples/run_async_agent_api_model_with_mcp.py
create mode 100644 lagent/actions/mcp_client.py
diff --git a/examples/run_async_agent_api_model_with_mcp.py b/examples/run_async_agent_api_model_with_mcp.py
new file mode 100644
index 0000000..029827b
--- /dev/null
+++ b/examples/run_async_agent_api_model_with_mcp.py
@@ -0,0 +1,127 @@
+import asyncio
+import json
+import os
+import time
+from typing import List
+
+from volcenginesdkarkruntime import Ark, AsyncArk
+
+from lagent.actions import AsyncActionExecutor, AsyncMCPClient
+from lagent.agents import AsyncAgentForInternLM
+from lagent.agents.aggregator import InternLMToolAggregator
+from lagent.agents.stream import get_plugin_prompt
+from lagent.llms import AsyncGPTAPI
+from lagent.prompts import PluginParser
+from lagent.schema import AgentMessage
+
+
+class LcAsyncAPI(AsyncGPTAPI):
+ def __init__(self, model, api_key=None, max_tokens=4096, max_retries=4, **gen_params):
+ if api_key is None:
+ raise ValueError("api_key is required")
+ self.model = model
+ self.max_tokens = max_tokens
+ self.retry = max_retries
+ self.client = AsyncArk(
+ api_key=api_key, timeout=900, max_retries=self.retry, base_url='https://ark.cn-beijing.volces.com/api/v3'
+ )
+ super().__init__(**gen_params)
+
+ async def _chat(self, messages: List[dict], **gen_params) -> str:
+ """Generate completion from a list of templates.
+
+ Args:
+ messages (List[dict]): a list of prompt dictionaries
+ gen_params: additional generation configuration
+
+ Returns:
+ str: The generated string.
+ """
+ assert isinstance(messages, list)
+
+ max_num_retries = 0
+ while max_num_retries < self.retry:
+ try:
+ response = await self.client.chat.completions.create(
+ model=self.model,
+ messages=messages,
+ max_tokens=self.max_tokens,
+ temperature=gen_params.get('temperature', 0.8),
+ top_p=gen_params.get('top_p', 0.95),
+ stream=True,
+ )
+ reasoning_content = ""
+ content = ""
+ async for chunk in response:
+ if (
+ hasattr(chunk.choices[0].delta, 'reasoning_content')
+ and chunk.choices[0].delta.reasoning_content
+ ):
+ reasoning_content += chunk.choices[0].delta.reasoning_content
+ # print(chunk.choices[0].delta.reasoning_content, end="")
+ else:
+ content += chunk.choices[0].delta.content
+ # print(chunk.choices[0].delta.content, end="")
+ # response = json.loads(response.json())
+ # reasoning_content = response['choices'][0]['message']['reasoning_content'].strip()
+ # content = response['choices'][0]['message']['content'].strip()
+ return content if reasoning_content == "" else "" + reasoning_content + "\n" + content
+
+ except Exception as error:
+ self.logger.error(str(error))
+ time.sleep(20)
+ max_num_retries += 1
+
+ raise RuntimeError(
+ 'Calling OpenAI failed after retrying for ' f'{max_num_retries} times. Check the logs for ' 'details.'
+ )
+
+
+TEMPLATE = (
+ "You have access to the following tools:\n{tool_description}\nPlease provide"
+ " your thought process when you need to use a tool, followed by the call statement in this format:"
+ "\n{invocation_format}"
+)
+llm = dict(type=LcAsyncAPI, model=None, api_key=None, top_p=0.95, temperature=0.6, max_tokens=16384, max_retries=50)
+plugin = dict(
+ type=AsyncMCPClient,
+ name='PlayWright',
+ server_type='stdio',
+ command='npx',
+ args=["@playwright/mcp@latest", '--isolated', '--no-sandbox'],
+)
+agent = AsyncAgentForInternLM(
+ llm,
+ plugin,
+ template=TEMPLATE.format(
+ tool_description=get_plugin_prompt(plugin),
+ invocation_format='```json\n{"name": {{tool name}}, "parameters": {{keyword arguments}}}\n```\n',
+ ),
+ output_format=PluginParser(begin="```json\n", end="\n```\n", validate=lambda x: json.loads(x.rstrip('`'))),
+ aggregator=InternLMToolAggregator(environment_role='system'),
+)
+msg = AgentMessage(
+ sender='user',
+ content='解释一下MCP中Sampling Flow的工作机制,参考https://modelcontextprotocol.io/docs/concepts/sampling',
+)
+
+# proj_dir = os.path.dirname(os.path.dirname(__file__))
+# executor = AsyncActionExecutor(
+# dict(
+# type=AsyncMCPClient,
+# name='FS',
+# server_type='stdio',
+# command='npx',
+# args=['-y', '@modelcontextprotocol/server-filesystem', os.path.join(proj_dir, 'docs')],
+# )
+# )
+# msg = AgentMessage(
+# sender='assistant',
+# content=dict(
+# name='FS.read_file',
+# parameters=dict(path=os.path.join(proj_dir, 'docs/en/get_started/install.md')),
+# ),
+# )
+loop = asyncio.get_event_loop()
+res = loop.run_until_complete(agent(msg))
+print(res.content)
diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py
index b75a226..0398d82 100644
--- a/lagent/actions/__init__.py
+++ b/lagent/actions/__init__.py
@@ -8,6 +8,7 @@
from .ipython_interactive import AsyncIPythonInteractive, IPythonInteractive
from .ipython_interpreter import AsyncIPythonInterpreter, IPythonInterpreter
from .ipython_manager import IPythonInteractiveManager
+from .mcp_client import AsyncMCPClient
from .parser import BaseParser, JsonParser, TupleParser
from .ppt import PPT, AsyncPPT
from .python_interpreter import AsyncPythonInterpreter, PythonInterpreter
@@ -39,6 +40,7 @@
'AsyncPPT',
'WebBrowser',
'AsyncWebBrowser',
+ 'AsyncMCPClient',
'BaseParser',
'JsonParser',
'TupleParser',
diff --git a/lagent/actions/mcp_client.py b/lagent/actions/mcp_client.py
new file mode 100644
index 0000000..5994401
--- /dev/null
+++ b/lagent/actions/mcp_client.py
@@ -0,0 +1,159 @@
+import asyncio
+import logging
+from contextlib import AsyncExitStack
+from typing import Literal, TypeAlias
+
+from lagent.actions.base_action import BaseAction
+from lagent.actions.parser import JsonParser, ParseError
+from lagent.schema import ActionReturn, ActionStatusCode
+
+ServerType: TypeAlias = Literal["stdio", "sse", "http"]
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncMCPClient(BaseAction):
+ """Model Context Protocol (MCP) Client for asynchronous communication with MCP servers.
+
+ Args:
+ name (str): The name of the action. Make sure it is unique among all actions.
+ server_type (ServerType): The type of MCP server to connect to. Options are "stdio", "sse", or "http".
+ **server_params: Additional parameters for the server connection, which may include:
+ - For stdio servers:
+ - command (str): The command to run the MCP server.
+ - args (list, optional): Additional arguments for the command.
+ - env (dict, optional): Environment variables for the command.
+ - cwd (str, optional): Current working directory for the command.
+ - For sse servers:
+ - url (str): The URL of the MCP server.
+ - headers (dict, optional): Headers to include in the request.
+ - timeout (int, optional): Timeout for the request.
+ - sse_read_timeout (int, optional): Timeout for reading SSE events.
+ - For http servers:
+ - url (str): The URL of the MCP server.
+ - headers (dict, optional): Headers to include in the request.
+ - timeout (int, optional): Timeout for the request.
+ - sse_read_timeout (int, optional): Timeout for reading SSE events.
+ - terminate_on_close (bool, optional): Whether to terminate the connection on close.
+ """
+
+ def __init__(self, name: str, server_type: ServerType, **server_params):
+ self._is_toolkit = True
+ self._session = None
+ self.server_type = server_type
+ self.server_params = server_params
+ self.exit_stack = AsyncExitStack()
+ # initialze the session
+ loop = asyncio.get_event_loop()
+ if loop.is_running():
+ fut = asyncio.run_coroutine_threadsafe(self.list_tools(), loop)
+ tools = fut.result()
+ else:
+ tools = loop.run_until_complete(self.list_tools())
+ self._api_names = {tool.name for tool in tools}
+ super().__init__(
+ description=dict(
+ name=name,
+ api_list=[
+ {
+ 'name': tool.name,
+ 'description': tool.description,
+ 'parameters': [
+ {'name': k, 'type': v['type'].upper(), 'description': v.get('description', '')}
+ for k, v in tool.inputSchema['properties'].items()
+ ],
+ 'required': tool.inputSchema.get('required', []),
+ }
+ for tool in tools
+ ],
+ ),
+ parser=JsonParser,
+ )
+
+ async def _initialize(self):
+ """Initialize the MCP client and connect to the server."""
+ if self._session is not None:
+ return
+
+ from mcp import ClientSession, StdioServerParameters
+
+ if self.server_type == "stdio":
+ from mcp.client.stdio import stdio_client
+
+ logger.info(
+ f"Connecting to stdio MCP server with command: {self.server_params['command']} "
+ f"{self.server_params.get('args', [])}"
+ )
+
+ client_kwargs = {"command": self.server_params["command"]}
+ for key in ["args", "env", "cwd"]:
+ if self.server_params.get(key) is not None:
+ client_kwargs[key] = self.server_params[key]
+ server_params = StdioServerParameters(**client_kwargs)
+ read, write = await self.exit_stack.enter_async_context(stdio_client(server_params))
+ elif self.server_type == "sse":
+ from mcp.client.sse import sse_client
+
+ logger.info(f"Connecting to SSE MCP server at: {self.server_params['url']}")
+
+ client_kwargs = {"url": self.server_params["url"]}
+ for key in ["headers", "timeout", "sse_read_timeout"]:
+ if self.server_params.get(key) is not None:
+ client_kwargs[key] = self.server_params[key]
+ read, write = await self.exit_stack.enter_async_context(sse_client(**client_kwargs))
+ elif self.server_type == "http":
+ from mcp.client.streamable_http import streamablehttp_client
+
+ logger.info(f"Connecting to StreamableHTTP MCP server at: {self.server_params['url']}")
+
+ client_kwargs = {"url": self.server_params["url"]}
+ for key in ["headers", "timeout", "sse_read_timeout", "terminate_on_close"]:
+ if self.server_params.get(key) is not None:
+ client_kwargs[key] = self.server_params[key]
+ read, write, _ = await self.exit_stack.enter_async_context(streamablehttp_client(**client_kwargs))
+ else:
+ raise ValueError(f"Unsupported server type: {self.server_type}")
+
+ self._session = await self.exit_stack.enter_async_context(ClientSession(read, write))
+
+ async def list_tools(self) -> list:
+ await self._initialize()
+ return (await self._session.list_tools()).tools
+
+ async def cleanup(self):
+ await self.exit_stack.aclose()
+
+ def __del__(self):
+ loop = asyncio.get_event_loop()
+ if loop.is_running():
+ fut = asyncio.run_coroutine_threadsafe(self.cleanup(), loop)
+ fut.result()
+ else:
+ loop.run_until_complete(self.cleanup())
+
+ async def __call__(self, inputs: str, name: str) -> ActionReturn:
+ fallback_args = {'inputs': inputs, 'name': name}
+ if name not in self._api_names:
+ return ActionReturn(
+ fallback_args, type=self.name, errmsg=f'invalid API: {name}', state=ActionStatusCode.API_ERROR
+ )
+ try:
+ inputs = self._parser.parse_inputs(inputs, name)
+ except ParseError as exc:
+ return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR)
+ try:
+ await self._initialize()
+ outputs = await self._session.call_tool(name, inputs)
+ outputs = outputs.content[0].text
+ except Exception as exc:
+ return ActionReturn(inputs, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR)
+ if isinstance(outputs, ActionReturn):
+ action_return = outputs
+ if not action_return.args:
+ action_return.args = inputs
+ if not action_return.type:
+ action_return.type = self.name
+ else:
+ result = self._parser.parse_outputs(outputs)
+ action_return = ActionReturn(inputs, type=self.name, result=result)
+ return action_return
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index ac0b85c..14930de 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -12,7 +12,8 @@ jsonschema
jupyter==1.0.0
jupyter_client==8.6.2
jupyter_core==5.7.2
-pydantic==2.6.4
+mcp
+pydantic
requests
tenacity
termcolor
From 5266b908ff968cb4943a05e0692af187f99cef19 Mon Sep 17 00:00:00 2001
From: braisedpork1964 <497494458@qq.com>
Date: Tue, 24 Jun 2025 12:22:01 +0000
Subject: [PATCH 2/5] add session management in concurrency
---
lagent/actions/mcp_client.py | 25 ++++++++++++++-----------
1 file changed, 14 insertions(+), 11 deletions(-)
diff --git a/lagent/actions/mcp_client.py b/lagent/actions/mcp_client.py
index 5994401..9586c48 100644
--- a/lagent/actions/mcp_client.py
+++ b/lagent/actions/mcp_client.py
@@ -39,7 +39,7 @@ class AsyncMCPClient(BaseAction):
def __init__(self, name: str, server_type: ServerType, **server_params):
self._is_toolkit = True
- self._session = None
+ self._sessions: dict = {}
self.server_type = server_type
self.server_params = server_params
self.exit_stack = AsyncExitStack()
@@ -70,10 +70,10 @@ def __init__(self, name: str, server_type: ServerType, **server_params):
parser=JsonParser,
)
- async def _initialize(self):
+ async def initialize(self, session_id):
"""Initialize the MCP client and connect to the server."""
- if self._session is not None:
- return
+ if session_id in self._sessions:
+ return self._sessions[session_id]
from mcp import ClientSession, StdioServerParameters
@@ -114,15 +114,18 @@ async def _initialize(self):
else:
raise ValueError(f"Unsupported server type: {self.server_type}")
- self._session = await self.exit_stack.enter_async_context(ClientSession(read, write))
-
- async def list_tools(self) -> list:
- await self._initialize()
- return (await self._session.list_tools()).tools
+ session = await self.exit_stack.enter_async_context(ClientSession(read, write))
+ await session.initialize()
+ self._sessions[session_id] = session
+ return session
async def cleanup(self):
await self.exit_stack.aclose()
+ async def list_tools(self, session_id=0) -> list:
+ session = await self.initialize(session_id=session_id)
+ return (await session.list_tools()).tools
+
def __del__(self):
loop = asyncio.get_event_loop()
if loop.is_running():
@@ -142,8 +145,8 @@ async def __call__(self, inputs: str, name: str) -> ActionReturn:
except ParseError as exc:
return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR)
try:
- await self._initialize()
- outputs = await self._session.call_tool(name, inputs)
+ session = await self.initialize(inputs.pop('session_id', 0))
+ outputs = await session.call_tool(name, inputs)
outputs = outputs.content[0].text
except Exception as exc:
return ActionReturn(inputs, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR)
From fdadebb594571690b7c4055c3fb32fc77b2fdb86 Mon Sep 17 00:00:00 2001
From: braisedpork1964 <497494458@qq.com>
Date: Wed, 25 Jun 2025 11:28:37 +0000
Subject: [PATCH 3/5] fix acquiring the loop in action initialization
---
lagent/actions/mcp_client.py | 37 ++++++++++++++++++++++++++++++++----
1 file changed, 33 insertions(+), 4 deletions(-)
diff --git a/lagent/actions/mcp_client.py b/lagent/actions/mcp_client.py
index 9586c48..9d5683a 100644
--- a/lagent/actions/mcp_client.py
+++ b/lagent/actions/mcp_client.py
@@ -10,6 +10,32 @@
ServerType: TypeAlias = Literal["stdio", "sse", "http"]
logger = logging.getLogger(__name__)
+_loop = None
+
+
+def _get_event_loop():
+ try:
+ event_loop = asyncio.get_event_loop()
+ except Exception:
+ logger.warning('Can not found event loop in current thread. Create a new event loop.')
+ event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(event_loop)
+
+ if event_loop.is_running():
+ global _loop
+ if _loop:
+ return _loop
+
+ from threading import Thread
+
+ def _start_loop(loop):
+ asyncio.set_event_loop(loop)
+ loop.run_forever()
+
+ event_loop = asyncio.new_event_loop()
+ Thread(target=_start_loop, args=(event_loop,), daemon=True).start()
+ _loop = event_loop
+ return event_loop
class AsyncMCPClient(BaseAction):
@@ -37,14 +63,16 @@ class AsyncMCPClient(BaseAction):
- terminate_on_close (bool, optional): Whether to terminate the connection on close.
"""
+ is_stateful = True
+
def __init__(self, name: str, server_type: ServerType, **server_params):
self._is_toolkit = True
self._sessions: dict = {}
self.server_type = server_type
self.server_params = server_params
self.exit_stack = AsyncExitStack()
- # initialze the session
- loop = asyncio.get_event_loop()
+ # get the list of tools from the MCP server
+ loop = _get_event_loop()
if loop.is_running():
fut = asyncio.run_coroutine_threadsafe(self.list_tools(), loop)
tools = fut.result()
@@ -127,7 +155,7 @@ async def list_tools(self, session_id=0) -> list:
return (await session.list_tools()).tools
def __del__(self):
- loop = asyncio.get_event_loop()
+ loop = _get_event_loop()
if loop.is_running():
fut = asyncio.run_coroutine_threadsafe(self.cleanup(), loop)
fut.result()
@@ -135,6 +163,7 @@ def __del__(self):
loop.run_until_complete(self.cleanup())
async def __call__(self, inputs: str, name: str) -> ActionReturn:
+ session_id = inputs.pop('session_id', 0) if isinstance(inputs, dict) else 0
fallback_args = {'inputs': inputs, 'name': name}
if name not in self._api_names:
return ActionReturn(
@@ -145,7 +174,7 @@ async def __call__(self, inputs: str, name: str) -> ActionReturn:
except ParseError as exc:
return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR)
try:
- session = await self.initialize(inputs.pop('session_id', 0))
+ session = await self.initialize(session_id)
outputs = await session.call_tool(name, inputs)
outputs = outputs.content[0].text
except Exception as exc:
From faa4d40ff75607af8869f38b28686f58dd89e10c Mon Sep 17 00:00:00 2001
From: braisedpork1964 <497494458@qq.com>
Date: Fri, 27 Jun 2025 03:27:24 +0000
Subject: [PATCH 4/5] inject session id into stateful action calling
---
lagent/actions/base_action.py | 2 +
lagent/actions/ipython_interpreter.py | 106 +++++++++++---------------
lagent/hooks/action_preprocessor.py | 55 ++++++++-----
3 files changed, 82 insertions(+), 81 deletions(-)
diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py
index b42036a..ed31b22 100644
--- a/lagent/actions/base_action.py
+++ b/lagent/actions/base_action.py
@@ -340,6 +340,8 @@ def sub(self, a, b):
action = Calculator()
"""
+ is_stateful = False
+
def __init__(
self,
description: Optional[dict] = None,
diff --git a/lagent/actions/ipython_interpreter.py b/lagent/actions/ipython_interpreter.py
index 68e9a0d..a022e4a 100644
--- a/lagent/actions/ipython_interpreter.py
+++ b/lagent/actions/ipython_interpreter.py
@@ -52,9 +52,7 @@ async def async_run_code(
assert iopub_timeout > interrupt_after
try:
- async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient,
- *,
- timeout=None):
+ async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient, *, timeout=None):
loop = asyncio.get_running_loop()
dead_fut = loop.create_future()
@@ -71,8 +69,7 @@ def dead():
km.add_restart_callback(restarting, "restart")
km.add_restart_callback(dead, "dead")
try:
- done, _ = await asyncio.wait(
- [dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED)
+ done, _ = await asyncio.wait([dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED)
if dead_fut in done:
raise KernelDeath()
assert msg_task in done
@@ -88,13 +85,21 @@ async def send_interrupt():
await km.interrupt_kernel()
@retry(
- retry=retry_if_result(lambda ret: ret[-1].strip() in [
- 'KeyboardInterrupt',
- f"Kernel didn't respond in {wait_for_ready_timeout} seconds",
- ] if isinstance(ret, tuple) else False),
+ retry=retry_if_result(
+ lambda ret: (
+ ret[-1].strip()
+ in [
+ 'KeyboardInterrupt',
+ f"Kernel didn't respond in {wait_for_ready_timeout} seconds",
+ ]
+ if isinstance(ret, tuple)
+ else False
+ )
+ ),
stop=stop_after_attempt(3),
wait=wait_fixed(1),
- retry_error_callback=lambda state: state.outcome.result())
+ retry_error_callback=lambda state: state.outcome.result(),
+ )
async def run():
execute_result = None
error_traceback = None
@@ -106,11 +111,9 @@ async def run():
await kc.wait_for_ready(timeout=wait_for_ready_timeout)
msg_id = kc.execute(code)
while True:
- message = await get_iopub_msg_with_death_detection(
- kc, timeout=iopub_timeout)
+ message = await get_iopub_msg_with_death_detection(kc, timeout=iopub_timeout)
if logger.isEnabledFor(logging.DEBUG):
- logger.debug(
- json.dumps(message, indent=2, default=str))
+ logger.debug(json.dumps(message, indent=2, default=str))
assert message["parent_header"]["msg_id"] == msg_id
msg_type = message["msg_type"]
if msg_type == "status":
@@ -136,8 +139,7 @@ async def run():
if interrupt_after:
run_task = asyncio.create_task(run())
send_interrupt_task = asyncio.create_task(send_interrupt())
- done, _ = await asyncio.wait([run_task, send_interrupt_task],
- return_when=asyncio.FIRST_COMPLETED)
+ done, _ = await asyncio.wait([run_task, send_interrupt_task], return_when=asyncio.FIRST_COMPLETED)
if run_task in done:
send_interrupt_task.cancel()
else:
@@ -216,13 +218,10 @@ def reset(self):
if not self._initialized:
self.initialize()
else:
- code = "get_ipython().run_line_magic('reset', '-f')\n" + \
- START_CODE.format(self.user_data_dir)
+ code = "get_ipython().run_line_magic('reset', '-f')\n" + START_CODE.format(self.user_data_dir)
self._call(code, None)
- def _call(self,
- command: str,
- timeout: Optional[int] = None) -> Tuple[str, bool]:
+ def _call(self, command: str, timeout: Optional[int] = None) -> Tuple[str, bool]:
self.initialize()
command = extract_code(command)
@@ -261,16 +260,14 @@ def _inner_call():
text = msg['content']['data'].get('text/plain', '')
if 'image/png' in msg['content']['data']:
image_b64 = msg['content']['data']['image/png']
- image_url = publish_image_to_local(
- image_b64, self.work_dir)
+ image_url = publish_image_to_local(image_b64, self.work_dir)
image_idx += 1
image = '' % (image_idx, image_url)
elif msg_type == 'display_data':
if 'image/png' in msg['content']['data']:
image_b64 = msg['content']['data']['image/png']
- image_url = publish_image_to_local(
- image_b64, self.work_dir)
+ image_url = publish_image_to_local(image_b64, self.work_dir)
image_idx += 1
image = '' % (image_idx, image_url)
@@ -281,8 +278,7 @@ def _inner_call():
text = msg['content']['text']
elif msg_type == 'error':
succeed = False
- text = escape_ansi('\n'.join(
- msg['content']['traceback']))
+ text = escape_ansi('\n'.join(msg['content']['traceback']))
if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
text = f'Timeout. No response after {timeout} seconds.' # noqa
except queue.Empty:
@@ -349,8 +345,7 @@ def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
# text=result['text'], image=result.get('image', [])[0])
tool_return.state = ActionStatusCode.SUCCESS
else:
- tool_return.errmsg = result.get('text', '') if isinstance(
- result, dict) else result
+ tool_return.errmsg = result.get('text', '') if isinstance(result, dict) else result
tool_return.state = ActionStatusCode.API_ERROR
return tool_return
@@ -371,6 +366,7 @@ class AsyncIPythonInterpreter(AsyncActionMixin, IPythonInterpreter):
action's inputs and outputs. Defaults to :class:`JsonParser`.
"""
+ is_stateful = True
_UNBOUND_KERNEL_CLIENTS = asyncio.Queue()
def __init__(
@@ -390,8 +386,7 @@ def __init__(
c = Config()
c.KernelManager.transport = 'ipc'
- self._amkm = AsyncMultiKernelManager(
- config=c, connection_dir=connection_dir)
+ self._amkm = AsyncMultiKernelManager(config=c, connection_dir=connection_dir)
self._max_kernels = max_kernels
self._reuse_kernel = reuse_kernel
self._sem = asyncio.Semaphore(startup_rate)
@@ -403,25 +398,23 @@ async def initialize(self, session_id: str):
if session_id in self._KERNEL_CLIENTS:
return self._KERNEL_CLIENTS[session_id]
if self._reuse_kernel and not self._UNBOUND_KERNEL_CLIENTS.empty():
- self._KERNEL_CLIENTS[
- session_id] = await self._UNBOUND_KERNEL_CLIENTS.get()
+ self._KERNEL_CLIENTS[session_id] = await self._UNBOUND_KERNEL_CLIENTS.get()
return self._KERNEL_CLIENTS[session_id]
async with self._sem:
- if self._max_kernels is None or len(
- self._KERNEL_CLIENTS
- ) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels:
+ if (
+ self._max_kernels is None
+ or len(self._KERNEL_CLIENTS) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels
+ ):
kernel_id = None
try:
kernel_id = await self._amkm.start_kernel()
kernel = self._amkm.get_kernel(kernel_id)
client = kernel.client()
_, error_stacktrace, stream_text = await async_run_code(
- kernel,
- START_CODE.format(self.user_data_dir),
- shutdown_kernel=False)
+ kernel, START_CODE.format(self.user_data_dir), shutdown_kernel=False
+ )
# check if the output of START_CODE meets expectations
- if not (error_stacktrace is None
- and stream_text == ''):
+ if not (error_stacktrace is None and stream_text == ''):
raise RuntimeError
except Exception as e:
print(f'Starting kernel error: {e}')
@@ -431,15 +424,11 @@ async def initialize(self, session_id: str):
await asyncio.sleep(1)
continue
if self._max_kernels is None:
- self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel,
- client)
+ self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel, client)
return kernel_id, kernel, client
async with self._lock:
- if len(self._KERNEL_CLIENTS
- ) + self._UNBOUND_KERNEL_CLIENTS.qsize(
- ) < self._max_kernels:
- self._KERNEL_CLIENTS[session_id] = (kernel_id,
- kernel, client)
+ if len(self._KERNEL_CLIENTS) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels:
+ self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel, client)
return kernel_id, kernel, client
await self._amkm.shutdown_kernel(kernel_id)
self._amkm.remove_kernel(kernel_id)
@@ -450,8 +439,7 @@ async def reset(self, session_id: str):
if session_id not in self._KERNEL_CLIENTS:
return
_, kernel, _ = self._KERNEL_CLIENTS[session_id]
- code = "get_ipython().run_line_magic('reset', '-f')\n" + \
- START_CODE.format(self.user_data_dir)
+ code = "get_ipython().run_line_magic('reset', '-f')\n" + START_CODE.format(self.user_data_dir)
await async_run_code(kernel, code, shutdown_kernel=False)
async def shutdown(self, session_id: str):
@@ -467,18 +455,15 @@ async def close_session(self, session_id: str):
if self._reuse_kernel:
if session_id in self._KERNEL_CLIENTS:
await self.reset(session_id)
- await self._UNBOUND_KERNEL_CLIENTS.put(
- self._KERNEL_CLIENTS.pop(session_id))
+ await self._UNBOUND_KERNEL_CLIENTS.put(self._KERNEL_CLIENTS.pop(session_id))
else:
await self.shutdown(session_id)
async def _call(self, command, timeout=None, session_id=None):
_, kernel, _ = await self.initialize(str(session_id))
result = await async_run_code(
- kernel,
- extract_code(command),
- interrupt_after=timeout or self.timeout,
- shutdown_kernel=False)
+ kernel, extract_code(command), interrupt_after=timeout or self.timeout, shutdown_kernel=False
+ )
execute_result, error_stacktrace, stream_text = result
if error_stacktrace is not None:
ret = re.sub('^-*\n', '', escape_ansi(error_stacktrace))
@@ -492,10 +477,7 @@ async def _call(self, command, timeout=None, session_id=None):
return status, ret
@tool_api
- async def run(self,
- command: str,
- timeout: Optional[int] = None,
- session_id: Optional[str] = None) -> ActionReturn:
+ async def run(self, command: str, timeout: Optional[int] = None, session_id: Optional[str] = None) -> ActionReturn:
r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
Args:
@@ -516,8 +498,7 @@ async def run(self,
# text=result['text'], image=result.get('image', [])[0])
tool_return.state = ActionStatusCode.SUCCESS
else:
- tool_return.errmsg = result.get('text', '') if isinstance(
- result, dict) else result
+ tool_return.errmsg = result.get('text', '') if isinstance(result, dict) else result
tool_return.state = ActionStatusCode.API_ERROR
return tool_return
@@ -549,6 +530,7 @@ def escape_ansi(line):
def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
import PIL.Image
+
image_file = str(uuid.uuid4()) + '.png'
local_image_file = os.path.join(work_dir, image_file)
diff --git a/lagent/hooks/action_preprocessor.py b/lagent/hooks/action_preprocessor.py
index 51083aa..09f6ec6 100644
--- a/lagent/hooks/action_preprocessor.py
+++ b/lagent/hooks/action_preprocessor.py
@@ -1,3 +1,4 @@
+import inspect
from copy import deepcopy
from lagent.schema import ActionReturn, ActionStatusCode, FunctionCall
@@ -11,17 +12,20 @@ class ActionPreprocessor(Hook):
"""
def before_action(self, executor, message, session_id):
- assert isinstance(message.formatted, FunctionCall) or (
- isinstance(message.formatted, dict) and 'name' in message.content
- and 'parameters' in message.formatted) or (
+ assert (
+ isinstance(message.formatted, FunctionCall)
+ or (
+ isinstance(message.formatted, dict) and 'name' in message.content and 'parameters' in message.formatted
+ )
+ or (
'action' in message.formatted
and 'parameters' in message.formatted['action']
- and 'name' in message.formatted['action'])
+ and 'name' in message.formatted['action']
+ )
+ )
if isinstance(message.formatted, dict):
- name = message.formatted.get('name',
- message.formatted['action']['name'])
- parameters = message.formatted.get(
- 'parameters', message.formatted['action']['parameters'])
+ name = message.formatted.get('name', message.formatted['action']['name'])
+ parameters = message.formatted.get('parameters', message.formatted['action']['parameters'])
else:
name = message.formatted.name
parameters = message.formatted.parameters
@@ -48,15 +52,28 @@ def __init__(self, code_parameter: str = 'command'):
def before_action(self, executor, message, session_id):
message = deepcopy(message)
- assert isinstance(message.formatted, dict) and set(
- message.formatted).issuperset(
- {'tool_type', 'thought', 'action', 'status'})
- if isinstance(message.formatted['action'], str):
- # encapsulate code interpreter arguments
- action_name = next(iter(executor.actions))
- parameters = {self.code_parameter: message.formatted['action']}
- if action_name in ['AsyncIPythonInterpreter']:
- parameters['session_id'] = session_id
- message.formatted['action'] = dict(
- name=action_name, parameters=parameters)
+ assert isinstance(message.formatted, dict) and set(message.formatted).issuperset(
+ {'tool_type', 'thought', 'action', 'status'}
+ )
+ if message.formatted['tool_type'] == 'interpreter' and isinstance(message.formatted['action'], str):
+ for action in executor.actions.values():
+ if hasattr(action, 'run') and callable(action.run):
+ param = inspect.signature(action.run).parameters
+ if self.code_parameter in param:
+ # encapsulate code interpreter arguments
+ message.formatted['action'] = dict(
+ name=action.name, parameters={self.code_parameter: message.formatted['action']}
+ )
+ break
+ else:
+ raise ValueError(
+ f"Action '{message.formatted['action']}' is not supported by any action in the executor."
+ )
+ tool_call = message.formatted['action']
+ if (
+ isinstance(tool_call, dict)
+ and isinstance(tool_call.get('parameters', {}), dict)
+ and executor.actions[tool_call['name'].split('.')[0]].is_stateful
+ ):
+ tool_call['parameters']['session_id'] = session_id
return super().before_action(executor, message, session_id)
From a656e604d283654c44b10c154d70b3daa4f797f7 Mon Sep 17 00:00:00 2001
From: braisedpork1964 <497494458@qq.com>
Date: Tue, 1 Jul 2025 07:31:48 +0000
Subject: [PATCH 5/5] update examples
---
.../run_async_agent_api_model_with_mcp.py | 70 +------------------
1 file changed, 1 insertion(+), 69 deletions(-)
diff --git a/examples/run_async_agent_api_model_with_mcp.py b/examples/run_async_agent_api_model_with_mcp.py
index 029827b..ff07f21 100644
--- a/examples/run_async_agent_api_model_with_mcp.py
+++ b/examples/run_async_agent_api_model_with_mcp.py
@@ -1,10 +1,5 @@
import asyncio
import json
-import os
-import time
-from typing import List
-
-from volcenginesdkarkruntime import Ark, AsyncArk
from lagent.actions import AsyncActionExecutor, AsyncMCPClient
from lagent.agents import AsyncAgentForInternLM
@@ -14,75 +9,12 @@
from lagent.prompts import PluginParser
from lagent.schema import AgentMessage
-
-class LcAsyncAPI(AsyncGPTAPI):
- def __init__(self, model, api_key=None, max_tokens=4096, max_retries=4, **gen_params):
- if api_key is None:
- raise ValueError("api_key is required")
- self.model = model
- self.max_tokens = max_tokens
- self.retry = max_retries
- self.client = AsyncArk(
- api_key=api_key, timeout=900, max_retries=self.retry, base_url='https://ark.cn-beijing.volces.com/api/v3'
- )
- super().__init__(**gen_params)
-
- async def _chat(self, messages: List[dict], **gen_params) -> str:
- """Generate completion from a list of templates.
-
- Args:
- messages (List[dict]): a list of prompt dictionaries
- gen_params: additional generation configuration
-
- Returns:
- str: The generated string.
- """
- assert isinstance(messages, list)
-
- max_num_retries = 0
- while max_num_retries < self.retry:
- try:
- response = await self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_tokens=self.max_tokens,
- temperature=gen_params.get('temperature', 0.8),
- top_p=gen_params.get('top_p', 0.95),
- stream=True,
- )
- reasoning_content = ""
- content = ""
- async for chunk in response:
- if (
- hasattr(chunk.choices[0].delta, 'reasoning_content')
- and chunk.choices[0].delta.reasoning_content
- ):
- reasoning_content += chunk.choices[0].delta.reasoning_content
- # print(chunk.choices[0].delta.reasoning_content, end="")
- else:
- content += chunk.choices[0].delta.content
- # print(chunk.choices[0].delta.content, end="")
- # response = json.loads(response.json())
- # reasoning_content = response['choices'][0]['message']['reasoning_content'].strip()
- # content = response['choices'][0]['message']['content'].strip()
- return content if reasoning_content == "" else "" + reasoning_content + "\n" + content
-
- except Exception as error:
- self.logger.error(str(error))
- time.sleep(20)
- max_num_retries += 1
-
- raise RuntimeError(
- 'Calling OpenAI failed after retrying for ' f'{max_num_retries} times. Check the logs for ' 'details.'
- )
-
-
TEMPLATE = (
"You have access to the following tools:\n{tool_description}\nPlease provide"
" your thought process when you need to use a tool, followed by the call statement in this format:"
"\n{invocation_format}"
)
-llm = dict(type=LcAsyncAPI, model=None, api_key=None, top_p=0.95, temperature=0.6, max_tokens=16384, max_retries=50)
+llm = dict(type=AsyncGPTAPI, model_type=None, retry=50, key=None, top_p=0.95, temperature=0.6, max_new_tokens=16384)
plugin = dict(
type=AsyncMCPClient,
name='PlayWright',