From a2c3dbb664ba7c94d2a003f60e56411acfe93d60 Mon Sep 17 00:00:00 2001 From: braisedpork1964 <497494458@qq.com> Date: Tue, 24 Jun 2025 06:07:06 +0000 Subject: [PATCH 1/5] support MCP clients --- .../run_async_agent_api_model_with_mcp.py | 127 ++++++++++++++ lagent/actions/__init__.py | 2 + lagent/actions/mcp_client.py | 159 ++++++++++++++++++ requirements/runtime.txt | 3 +- 4 files changed, 290 insertions(+), 1 deletion(-) create mode 100644 examples/run_async_agent_api_model_with_mcp.py create mode 100644 lagent/actions/mcp_client.py diff --git a/examples/run_async_agent_api_model_with_mcp.py b/examples/run_async_agent_api_model_with_mcp.py new file mode 100644 index 0000000..029827b --- /dev/null +++ b/examples/run_async_agent_api_model_with_mcp.py @@ -0,0 +1,127 @@ +import asyncio +import json +import os +import time +from typing import List + +from volcenginesdkarkruntime import Ark, AsyncArk + +from lagent.actions import AsyncActionExecutor, AsyncMCPClient +from lagent.agents import AsyncAgentForInternLM +from lagent.agents.aggregator import InternLMToolAggregator +from lagent.agents.stream import get_plugin_prompt +from lagent.llms import AsyncGPTAPI +from lagent.prompts import PluginParser +from lagent.schema import AgentMessage + + +class LcAsyncAPI(AsyncGPTAPI): + def __init__(self, model, api_key=None, max_tokens=4096, max_retries=4, **gen_params): + if api_key is None: + raise ValueError("api_key is required") + self.model = model + self.max_tokens = max_tokens + self.retry = max_retries + self.client = AsyncArk( + api_key=api_key, timeout=900, max_retries=self.retry, base_url='https://ark.cn-beijing.volces.com/api/v3' + ) + super().__init__(**gen_params) + + async def _chat(self, messages: List[dict], **gen_params) -> str: + """Generate completion from a list of templates. + + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string. + """ + assert isinstance(messages, list) + + max_num_retries = 0 + while max_num_retries < self.retry: + try: + response = await self.client.chat.completions.create( + model=self.model, + messages=messages, + max_tokens=self.max_tokens, + temperature=gen_params.get('temperature', 0.8), + top_p=gen_params.get('top_p', 0.95), + stream=True, + ) + reasoning_content = "" + content = "" + async for chunk in response: + if ( + hasattr(chunk.choices[0].delta, 'reasoning_content') + and chunk.choices[0].delta.reasoning_content + ): + reasoning_content += chunk.choices[0].delta.reasoning_content + # print(chunk.choices[0].delta.reasoning_content, end="") + else: + content += chunk.choices[0].delta.content + # print(chunk.choices[0].delta.content, end="") + # response = json.loads(response.json()) + # reasoning_content = response['choices'][0]['message']['reasoning_content'].strip() + # content = response['choices'][0]['message']['content'].strip() + return content if reasoning_content == "" else "" + reasoning_content + "\n" + content + + except Exception as error: + self.logger.error(str(error)) + time.sleep(20) + max_num_retries += 1 + + raise RuntimeError( + 'Calling OpenAI failed after retrying for ' f'{max_num_retries} times. Check the logs for ' 'details.' + ) + + +TEMPLATE = ( + "You have access to the following tools:\n{tool_description}\nPlease provide" + " your thought process when you need to use a tool, followed by the call statement in this format:" + "\n{invocation_format}" +) +llm = dict(type=LcAsyncAPI, model=None, api_key=None, top_p=0.95, temperature=0.6, max_tokens=16384, max_retries=50) +plugin = dict( + type=AsyncMCPClient, + name='PlayWright', + server_type='stdio', + command='npx', + args=["@playwright/mcp@latest", '--isolated', '--no-sandbox'], +) +agent = AsyncAgentForInternLM( + llm, + plugin, + template=TEMPLATE.format( + tool_description=get_plugin_prompt(plugin), + invocation_format='```json\n{"name": {{tool name}}, "parameters": {{keyword arguments}}}\n```\n', + ), + output_format=PluginParser(begin="```json\n", end="\n```\n", validate=lambda x: json.loads(x.rstrip('`'))), + aggregator=InternLMToolAggregator(environment_role='system'), +) +msg = AgentMessage( + sender='user', + content='解释一下MCP中Sampling Flow的工作机制,参考https://modelcontextprotocol.io/docs/concepts/sampling', +) + +# proj_dir = os.path.dirname(os.path.dirname(__file__)) +# executor = AsyncActionExecutor( +# dict( +# type=AsyncMCPClient, +# name='FS', +# server_type='stdio', +# command='npx', +# args=['-y', '@modelcontextprotocol/server-filesystem', os.path.join(proj_dir, 'docs')], +# ) +# ) +# msg = AgentMessage( +# sender='assistant', +# content=dict( +# name='FS.read_file', +# parameters=dict(path=os.path.join(proj_dir, 'docs/en/get_started/install.md')), +# ), +# ) +loop = asyncio.get_event_loop() +res = loop.run_until_complete(agent(msg)) +print(res.content) diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py index b75a226..0398d82 100644 --- a/lagent/actions/__init__.py +++ b/lagent/actions/__init__.py @@ -8,6 +8,7 @@ from .ipython_interactive import AsyncIPythonInteractive, IPythonInteractive from .ipython_interpreter import AsyncIPythonInterpreter, IPythonInterpreter from .ipython_manager import IPythonInteractiveManager +from .mcp_client import AsyncMCPClient from .parser import BaseParser, JsonParser, TupleParser from .ppt import PPT, AsyncPPT from .python_interpreter import AsyncPythonInterpreter, PythonInterpreter @@ -39,6 +40,7 @@ 'AsyncPPT', 'WebBrowser', 'AsyncWebBrowser', + 'AsyncMCPClient', 'BaseParser', 'JsonParser', 'TupleParser', diff --git a/lagent/actions/mcp_client.py b/lagent/actions/mcp_client.py new file mode 100644 index 0000000..5994401 --- /dev/null +++ b/lagent/actions/mcp_client.py @@ -0,0 +1,159 @@ +import asyncio +import logging +from contextlib import AsyncExitStack +from typing import Literal, TypeAlias + +from lagent.actions.base_action import BaseAction +from lagent.actions.parser import JsonParser, ParseError +from lagent.schema import ActionReturn, ActionStatusCode + +ServerType: TypeAlias = Literal["stdio", "sse", "http"] + +logger = logging.getLogger(__name__) + + +class AsyncMCPClient(BaseAction): + """Model Context Protocol (MCP) Client for asynchronous communication with MCP servers. + + Args: + name (str): The name of the action. Make sure it is unique among all actions. + server_type (ServerType): The type of MCP server to connect to. Options are "stdio", "sse", or "http". + **server_params: Additional parameters for the server connection, which may include: + - For stdio servers: + - command (str): The command to run the MCP server. + - args (list, optional): Additional arguments for the command. + - env (dict, optional): Environment variables for the command. + - cwd (str, optional): Current working directory for the command. + - For sse servers: + - url (str): The URL of the MCP server. + - headers (dict, optional): Headers to include in the request. + - timeout (int, optional): Timeout for the request. + - sse_read_timeout (int, optional): Timeout for reading SSE events. + - For http servers: + - url (str): The URL of the MCP server. + - headers (dict, optional): Headers to include in the request. + - timeout (int, optional): Timeout for the request. + - sse_read_timeout (int, optional): Timeout for reading SSE events. + - terminate_on_close (bool, optional): Whether to terminate the connection on close. + """ + + def __init__(self, name: str, server_type: ServerType, **server_params): + self._is_toolkit = True + self._session = None + self.server_type = server_type + self.server_params = server_params + self.exit_stack = AsyncExitStack() + # initialze the session + loop = asyncio.get_event_loop() + if loop.is_running(): + fut = asyncio.run_coroutine_threadsafe(self.list_tools(), loop) + tools = fut.result() + else: + tools = loop.run_until_complete(self.list_tools()) + self._api_names = {tool.name for tool in tools} + super().__init__( + description=dict( + name=name, + api_list=[ + { + 'name': tool.name, + 'description': tool.description, + 'parameters': [ + {'name': k, 'type': v['type'].upper(), 'description': v.get('description', '')} + for k, v in tool.inputSchema['properties'].items() + ], + 'required': tool.inputSchema.get('required', []), + } + for tool in tools + ], + ), + parser=JsonParser, + ) + + async def _initialize(self): + """Initialize the MCP client and connect to the server.""" + if self._session is not None: + return + + from mcp import ClientSession, StdioServerParameters + + if self.server_type == "stdio": + from mcp.client.stdio import stdio_client + + logger.info( + f"Connecting to stdio MCP server with command: {self.server_params['command']} " + f"{self.server_params.get('args', [])}" + ) + + client_kwargs = {"command": self.server_params["command"]} + for key in ["args", "env", "cwd"]: + if self.server_params.get(key) is not None: + client_kwargs[key] = self.server_params[key] + server_params = StdioServerParameters(**client_kwargs) + read, write = await self.exit_stack.enter_async_context(stdio_client(server_params)) + elif self.server_type == "sse": + from mcp.client.sse import sse_client + + logger.info(f"Connecting to SSE MCP server at: {self.server_params['url']}") + + client_kwargs = {"url": self.server_params["url"]} + for key in ["headers", "timeout", "sse_read_timeout"]: + if self.server_params.get(key) is not None: + client_kwargs[key] = self.server_params[key] + read, write = await self.exit_stack.enter_async_context(sse_client(**client_kwargs)) + elif self.server_type == "http": + from mcp.client.streamable_http import streamablehttp_client + + logger.info(f"Connecting to StreamableHTTP MCP server at: {self.server_params['url']}") + + client_kwargs = {"url": self.server_params["url"]} + for key in ["headers", "timeout", "sse_read_timeout", "terminate_on_close"]: + if self.server_params.get(key) is not None: + client_kwargs[key] = self.server_params[key] + read, write, _ = await self.exit_stack.enter_async_context(streamablehttp_client(**client_kwargs)) + else: + raise ValueError(f"Unsupported server type: {self.server_type}") + + self._session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + + async def list_tools(self) -> list: + await self._initialize() + return (await self._session.list_tools()).tools + + async def cleanup(self): + await self.exit_stack.aclose() + + def __del__(self): + loop = asyncio.get_event_loop() + if loop.is_running(): + fut = asyncio.run_coroutine_threadsafe(self.cleanup(), loop) + fut.result() + else: + loop.run_until_complete(self.cleanup()) + + async def __call__(self, inputs: str, name: str) -> ActionReturn: + fallback_args = {'inputs': inputs, 'name': name} + if name not in self._api_names: + return ActionReturn( + fallback_args, type=self.name, errmsg=f'invalid API: {name}', state=ActionStatusCode.API_ERROR + ) + try: + inputs = self._parser.parse_inputs(inputs, name) + except ParseError as exc: + return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR) + try: + await self._initialize() + outputs = await self._session.call_tool(name, inputs) + outputs = outputs.content[0].text + except Exception as exc: + return ActionReturn(inputs, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR) + if isinstance(outputs, ActionReturn): + action_return = outputs + if not action_return.args: + action_return.args = inputs + if not action_return.type: + action_return.type = self.name + else: + result = self._parser.parse_outputs(outputs) + action_return = ActionReturn(inputs, type=self.name, result=result) + return action_return diff --git a/requirements/runtime.txt b/requirements/runtime.txt index ac0b85c..14930de 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -12,7 +12,8 @@ jsonschema jupyter==1.0.0 jupyter_client==8.6.2 jupyter_core==5.7.2 -pydantic==2.6.4 +mcp +pydantic requests tenacity termcolor From 5266b908ff968cb4943a05e0692af187f99cef19 Mon Sep 17 00:00:00 2001 From: braisedpork1964 <497494458@qq.com> Date: Tue, 24 Jun 2025 12:22:01 +0000 Subject: [PATCH 2/5] add session management in concurrency --- lagent/actions/mcp_client.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/lagent/actions/mcp_client.py b/lagent/actions/mcp_client.py index 5994401..9586c48 100644 --- a/lagent/actions/mcp_client.py +++ b/lagent/actions/mcp_client.py @@ -39,7 +39,7 @@ class AsyncMCPClient(BaseAction): def __init__(self, name: str, server_type: ServerType, **server_params): self._is_toolkit = True - self._session = None + self._sessions: dict = {} self.server_type = server_type self.server_params = server_params self.exit_stack = AsyncExitStack() @@ -70,10 +70,10 @@ def __init__(self, name: str, server_type: ServerType, **server_params): parser=JsonParser, ) - async def _initialize(self): + async def initialize(self, session_id): """Initialize the MCP client and connect to the server.""" - if self._session is not None: - return + if session_id in self._sessions: + return self._sessions[session_id] from mcp import ClientSession, StdioServerParameters @@ -114,15 +114,18 @@ async def _initialize(self): else: raise ValueError(f"Unsupported server type: {self.server_type}") - self._session = await self.exit_stack.enter_async_context(ClientSession(read, write)) - - async def list_tools(self) -> list: - await self._initialize() - return (await self._session.list_tools()).tools + session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + self._sessions[session_id] = session + return session async def cleanup(self): await self.exit_stack.aclose() + async def list_tools(self, session_id=0) -> list: + session = await self.initialize(session_id=session_id) + return (await session.list_tools()).tools + def __del__(self): loop = asyncio.get_event_loop() if loop.is_running(): @@ -142,8 +145,8 @@ async def __call__(self, inputs: str, name: str) -> ActionReturn: except ParseError as exc: return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR) try: - await self._initialize() - outputs = await self._session.call_tool(name, inputs) + session = await self.initialize(inputs.pop('session_id', 0)) + outputs = await session.call_tool(name, inputs) outputs = outputs.content[0].text except Exception as exc: return ActionReturn(inputs, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR) From fdadebb594571690b7c4055c3fb32fc77b2fdb86 Mon Sep 17 00:00:00 2001 From: braisedpork1964 <497494458@qq.com> Date: Wed, 25 Jun 2025 11:28:37 +0000 Subject: [PATCH 3/5] fix acquiring the loop in action initialization --- lagent/actions/mcp_client.py | 37 ++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/lagent/actions/mcp_client.py b/lagent/actions/mcp_client.py index 9586c48..9d5683a 100644 --- a/lagent/actions/mcp_client.py +++ b/lagent/actions/mcp_client.py @@ -10,6 +10,32 @@ ServerType: TypeAlias = Literal["stdio", "sse", "http"] logger = logging.getLogger(__name__) +_loop = None + + +def _get_event_loop(): + try: + event_loop = asyncio.get_event_loop() + except Exception: + logger.warning('Can not found event loop in current thread. Create a new event loop.') + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + + if event_loop.is_running(): + global _loop + if _loop: + return _loop + + from threading import Thread + + def _start_loop(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + event_loop = asyncio.new_event_loop() + Thread(target=_start_loop, args=(event_loop,), daemon=True).start() + _loop = event_loop + return event_loop class AsyncMCPClient(BaseAction): @@ -37,14 +63,16 @@ class AsyncMCPClient(BaseAction): - terminate_on_close (bool, optional): Whether to terminate the connection on close. """ + is_stateful = True + def __init__(self, name: str, server_type: ServerType, **server_params): self._is_toolkit = True self._sessions: dict = {} self.server_type = server_type self.server_params = server_params self.exit_stack = AsyncExitStack() - # initialze the session - loop = asyncio.get_event_loop() + # get the list of tools from the MCP server + loop = _get_event_loop() if loop.is_running(): fut = asyncio.run_coroutine_threadsafe(self.list_tools(), loop) tools = fut.result() @@ -127,7 +155,7 @@ async def list_tools(self, session_id=0) -> list: return (await session.list_tools()).tools def __del__(self): - loop = asyncio.get_event_loop() + loop = _get_event_loop() if loop.is_running(): fut = asyncio.run_coroutine_threadsafe(self.cleanup(), loop) fut.result() @@ -135,6 +163,7 @@ def __del__(self): loop.run_until_complete(self.cleanup()) async def __call__(self, inputs: str, name: str) -> ActionReturn: + session_id = inputs.pop('session_id', 0) if isinstance(inputs, dict) else 0 fallback_args = {'inputs': inputs, 'name': name} if name not in self._api_names: return ActionReturn( @@ -145,7 +174,7 @@ async def __call__(self, inputs: str, name: str) -> ActionReturn: except ParseError as exc: return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR) try: - session = await self.initialize(inputs.pop('session_id', 0)) + session = await self.initialize(session_id) outputs = await session.call_tool(name, inputs) outputs = outputs.content[0].text except Exception as exc: From faa4d40ff75607af8869f38b28686f58dd89e10c Mon Sep 17 00:00:00 2001 From: braisedpork1964 <497494458@qq.com> Date: Fri, 27 Jun 2025 03:27:24 +0000 Subject: [PATCH 4/5] inject session id into stateful action calling --- lagent/actions/base_action.py | 2 + lagent/actions/ipython_interpreter.py | 106 +++++++++++--------------- lagent/hooks/action_preprocessor.py | 55 ++++++++----- 3 files changed, 82 insertions(+), 81 deletions(-) diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py index b42036a..ed31b22 100644 --- a/lagent/actions/base_action.py +++ b/lagent/actions/base_action.py @@ -340,6 +340,8 @@ def sub(self, a, b): action = Calculator() """ + is_stateful = False + def __init__( self, description: Optional[dict] = None, diff --git a/lagent/actions/ipython_interpreter.py b/lagent/actions/ipython_interpreter.py index 68e9a0d..a022e4a 100644 --- a/lagent/actions/ipython_interpreter.py +++ b/lagent/actions/ipython_interpreter.py @@ -52,9 +52,7 @@ async def async_run_code( assert iopub_timeout > interrupt_after try: - async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient, - *, - timeout=None): + async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient, *, timeout=None): loop = asyncio.get_running_loop() dead_fut = loop.create_future() @@ -71,8 +69,7 @@ def dead(): km.add_restart_callback(restarting, "restart") km.add_restart_callback(dead, "dead") try: - done, _ = await asyncio.wait( - [dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED) + done, _ = await asyncio.wait([dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED) if dead_fut in done: raise KernelDeath() assert msg_task in done @@ -88,13 +85,21 @@ async def send_interrupt(): await km.interrupt_kernel() @retry( - retry=retry_if_result(lambda ret: ret[-1].strip() in [ - 'KeyboardInterrupt', - f"Kernel didn't respond in {wait_for_ready_timeout} seconds", - ] if isinstance(ret, tuple) else False), + retry=retry_if_result( + lambda ret: ( + ret[-1].strip() + in [ + 'KeyboardInterrupt', + f"Kernel didn't respond in {wait_for_ready_timeout} seconds", + ] + if isinstance(ret, tuple) + else False + ) + ), stop=stop_after_attempt(3), wait=wait_fixed(1), - retry_error_callback=lambda state: state.outcome.result()) + retry_error_callback=lambda state: state.outcome.result(), + ) async def run(): execute_result = None error_traceback = None @@ -106,11 +111,9 @@ async def run(): await kc.wait_for_ready(timeout=wait_for_ready_timeout) msg_id = kc.execute(code) while True: - message = await get_iopub_msg_with_death_detection( - kc, timeout=iopub_timeout) + message = await get_iopub_msg_with_death_detection(kc, timeout=iopub_timeout) if logger.isEnabledFor(logging.DEBUG): - logger.debug( - json.dumps(message, indent=2, default=str)) + logger.debug(json.dumps(message, indent=2, default=str)) assert message["parent_header"]["msg_id"] == msg_id msg_type = message["msg_type"] if msg_type == "status": @@ -136,8 +139,7 @@ async def run(): if interrupt_after: run_task = asyncio.create_task(run()) send_interrupt_task = asyncio.create_task(send_interrupt()) - done, _ = await asyncio.wait([run_task, send_interrupt_task], - return_when=asyncio.FIRST_COMPLETED) + done, _ = await asyncio.wait([run_task, send_interrupt_task], return_when=asyncio.FIRST_COMPLETED) if run_task in done: send_interrupt_task.cancel() else: @@ -216,13 +218,10 @@ def reset(self): if not self._initialized: self.initialize() else: - code = "get_ipython().run_line_magic('reset', '-f')\n" + \ - START_CODE.format(self.user_data_dir) + code = "get_ipython().run_line_magic('reset', '-f')\n" + START_CODE.format(self.user_data_dir) self._call(code, None) - def _call(self, - command: str, - timeout: Optional[int] = None) -> Tuple[str, bool]: + def _call(self, command: str, timeout: Optional[int] = None) -> Tuple[str, bool]: self.initialize() command = extract_code(command) @@ -261,16 +260,14 @@ def _inner_call(): text = msg['content']['data'].get('text/plain', '') if 'image/png' in msg['content']['data']: image_b64 = msg['content']['data']['image/png'] - image_url = publish_image_to_local( - image_b64, self.work_dir) + image_url = publish_image_to_local(image_b64, self.work_dir) image_idx += 1 image = '![fig-%03d](%s)' % (image_idx, image_url) elif msg_type == 'display_data': if 'image/png' in msg['content']['data']: image_b64 = msg['content']['data']['image/png'] - image_url = publish_image_to_local( - image_b64, self.work_dir) + image_url = publish_image_to_local(image_b64, self.work_dir) image_idx += 1 image = '![fig-%03d](%s)' % (image_idx, image_url) @@ -281,8 +278,7 @@ def _inner_call(): text = msg['content']['text'] elif msg_type == 'error': succeed = False - text = escape_ansi('\n'.join( - msg['content']['traceback'])) + text = escape_ansi('\n'.join(msg['content']['traceback'])) if 'M6_CODE_INTERPRETER_TIMEOUT' in text: text = f'Timeout. No response after {timeout} seconds.' # noqa except queue.Empty: @@ -349,8 +345,7 @@ def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn: # text=result['text'], image=result.get('image', [])[0]) tool_return.state = ActionStatusCode.SUCCESS else: - tool_return.errmsg = result.get('text', '') if isinstance( - result, dict) else result + tool_return.errmsg = result.get('text', '') if isinstance(result, dict) else result tool_return.state = ActionStatusCode.API_ERROR return tool_return @@ -371,6 +366,7 @@ class AsyncIPythonInterpreter(AsyncActionMixin, IPythonInterpreter): action's inputs and outputs. Defaults to :class:`JsonParser`. """ + is_stateful = True _UNBOUND_KERNEL_CLIENTS = asyncio.Queue() def __init__( @@ -390,8 +386,7 @@ def __init__( c = Config() c.KernelManager.transport = 'ipc' - self._amkm = AsyncMultiKernelManager( - config=c, connection_dir=connection_dir) + self._amkm = AsyncMultiKernelManager(config=c, connection_dir=connection_dir) self._max_kernels = max_kernels self._reuse_kernel = reuse_kernel self._sem = asyncio.Semaphore(startup_rate) @@ -403,25 +398,23 @@ async def initialize(self, session_id: str): if session_id in self._KERNEL_CLIENTS: return self._KERNEL_CLIENTS[session_id] if self._reuse_kernel and not self._UNBOUND_KERNEL_CLIENTS.empty(): - self._KERNEL_CLIENTS[ - session_id] = await self._UNBOUND_KERNEL_CLIENTS.get() + self._KERNEL_CLIENTS[session_id] = await self._UNBOUND_KERNEL_CLIENTS.get() return self._KERNEL_CLIENTS[session_id] async with self._sem: - if self._max_kernels is None or len( - self._KERNEL_CLIENTS - ) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels: + if ( + self._max_kernels is None + or len(self._KERNEL_CLIENTS) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels + ): kernel_id = None try: kernel_id = await self._amkm.start_kernel() kernel = self._amkm.get_kernel(kernel_id) client = kernel.client() _, error_stacktrace, stream_text = await async_run_code( - kernel, - START_CODE.format(self.user_data_dir), - shutdown_kernel=False) + kernel, START_CODE.format(self.user_data_dir), shutdown_kernel=False + ) # check if the output of START_CODE meets expectations - if not (error_stacktrace is None - and stream_text == ''): + if not (error_stacktrace is None and stream_text == ''): raise RuntimeError except Exception as e: print(f'Starting kernel error: {e}') @@ -431,15 +424,11 @@ async def initialize(self, session_id: str): await asyncio.sleep(1) continue if self._max_kernels is None: - self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel, - client) + self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel, client) return kernel_id, kernel, client async with self._lock: - if len(self._KERNEL_CLIENTS - ) + self._UNBOUND_KERNEL_CLIENTS.qsize( - ) < self._max_kernels: - self._KERNEL_CLIENTS[session_id] = (kernel_id, - kernel, client) + if len(self._KERNEL_CLIENTS) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels: + self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel, client) return kernel_id, kernel, client await self._amkm.shutdown_kernel(kernel_id) self._amkm.remove_kernel(kernel_id) @@ -450,8 +439,7 @@ async def reset(self, session_id: str): if session_id not in self._KERNEL_CLIENTS: return _, kernel, _ = self._KERNEL_CLIENTS[session_id] - code = "get_ipython().run_line_magic('reset', '-f')\n" + \ - START_CODE.format(self.user_data_dir) + code = "get_ipython().run_line_magic('reset', '-f')\n" + START_CODE.format(self.user_data_dir) await async_run_code(kernel, code, shutdown_kernel=False) async def shutdown(self, session_id: str): @@ -467,18 +455,15 @@ async def close_session(self, session_id: str): if self._reuse_kernel: if session_id in self._KERNEL_CLIENTS: await self.reset(session_id) - await self._UNBOUND_KERNEL_CLIENTS.put( - self._KERNEL_CLIENTS.pop(session_id)) + await self._UNBOUND_KERNEL_CLIENTS.put(self._KERNEL_CLIENTS.pop(session_id)) else: await self.shutdown(session_id) async def _call(self, command, timeout=None, session_id=None): _, kernel, _ = await self.initialize(str(session_id)) result = await async_run_code( - kernel, - extract_code(command), - interrupt_after=timeout or self.timeout, - shutdown_kernel=False) + kernel, extract_code(command), interrupt_after=timeout or self.timeout, shutdown_kernel=False + ) execute_result, error_stacktrace, stream_text = result if error_stacktrace is not None: ret = re.sub('^-*\n', '', escape_ansi(error_stacktrace)) @@ -492,10 +477,7 @@ async def _call(self, command, timeout=None, session_id=None): return status, ret @tool_api - async def run(self, - command: str, - timeout: Optional[int] = None, - session_id: Optional[str] = None) -> ActionReturn: + async def run(self, command: str, timeout: Optional[int] = None, session_id: Optional[str] = None) -> ActionReturn: r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail. Args: @@ -516,8 +498,7 @@ async def run(self, # text=result['text'], image=result.get('image', [])[0]) tool_return.state = ActionStatusCode.SUCCESS else: - tool_return.errmsg = result.get('text', '') if isinstance( - result, dict) else result + tool_return.errmsg = result.get('text', '') if isinstance(result, dict) else result tool_return.state = ActionStatusCode.API_ERROR return tool_return @@ -549,6 +530,7 @@ def escape_ansi(line): def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'): import PIL.Image + image_file = str(uuid.uuid4()) + '.png' local_image_file = os.path.join(work_dir, image_file) diff --git a/lagent/hooks/action_preprocessor.py b/lagent/hooks/action_preprocessor.py index 51083aa..09f6ec6 100644 --- a/lagent/hooks/action_preprocessor.py +++ b/lagent/hooks/action_preprocessor.py @@ -1,3 +1,4 @@ +import inspect from copy import deepcopy from lagent.schema import ActionReturn, ActionStatusCode, FunctionCall @@ -11,17 +12,20 @@ class ActionPreprocessor(Hook): """ def before_action(self, executor, message, session_id): - assert isinstance(message.formatted, FunctionCall) or ( - isinstance(message.formatted, dict) and 'name' in message.content - and 'parameters' in message.formatted) or ( + assert ( + isinstance(message.formatted, FunctionCall) + or ( + isinstance(message.formatted, dict) and 'name' in message.content and 'parameters' in message.formatted + ) + or ( 'action' in message.formatted and 'parameters' in message.formatted['action'] - and 'name' in message.formatted['action']) + and 'name' in message.formatted['action'] + ) + ) if isinstance(message.formatted, dict): - name = message.formatted.get('name', - message.formatted['action']['name']) - parameters = message.formatted.get( - 'parameters', message.formatted['action']['parameters']) + name = message.formatted.get('name', message.formatted['action']['name']) + parameters = message.formatted.get('parameters', message.formatted['action']['parameters']) else: name = message.formatted.name parameters = message.formatted.parameters @@ -48,15 +52,28 @@ def __init__(self, code_parameter: str = 'command'): def before_action(self, executor, message, session_id): message = deepcopy(message) - assert isinstance(message.formatted, dict) and set( - message.formatted).issuperset( - {'tool_type', 'thought', 'action', 'status'}) - if isinstance(message.formatted['action'], str): - # encapsulate code interpreter arguments - action_name = next(iter(executor.actions)) - parameters = {self.code_parameter: message.formatted['action']} - if action_name in ['AsyncIPythonInterpreter']: - parameters['session_id'] = session_id - message.formatted['action'] = dict( - name=action_name, parameters=parameters) + assert isinstance(message.formatted, dict) and set(message.formatted).issuperset( + {'tool_type', 'thought', 'action', 'status'} + ) + if message.formatted['tool_type'] == 'interpreter' and isinstance(message.formatted['action'], str): + for action in executor.actions.values(): + if hasattr(action, 'run') and callable(action.run): + param = inspect.signature(action.run).parameters + if self.code_parameter in param: + # encapsulate code interpreter arguments + message.formatted['action'] = dict( + name=action.name, parameters={self.code_parameter: message.formatted['action']} + ) + break + else: + raise ValueError( + f"Action '{message.formatted['action']}' is not supported by any action in the executor." + ) + tool_call = message.formatted['action'] + if ( + isinstance(tool_call, dict) + and isinstance(tool_call.get('parameters', {}), dict) + and executor.actions[tool_call['name'].split('.')[0]].is_stateful + ): + tool_call['parameters']['session_id'] = session_id return super().before_action(executor, message, session_id) From a656e604d283654c44b10c154d70b3daa4f797f7 Mon Sep 17 00:00:00 2001 From: braisedpork1964 <497494458@qq.com> Date: Tue, 1 Jul 2025 07:31:48 +0000 Subject: [PATCH 5/5] update examples --- .../run_async_agent_api_model_with_mcp.py | 70 +------------------ 1 file changed, 1 insertion(+), 69 deletions(-) diff --git a/examples/run_async_agent_api_model_with_mcp.py b/examples/run_async_agent_api_model_with_mcp.py index 029827b..ff07f21 100644 --- a/examples/run_async_agent_api_model_with_mcp.py +++ b/examples/run_async_agent_api_model_with_mcp.py @@ -1,10 +1,5 @@ import asyncio import json -import os -import time -from typing import List - -from volcenginesdkarkruntime import Ark, AsyncArk from lagent.actions import AsyncActionExecutor, AsyncMCPClient from lagent.agents import AsyncAgentForInternLM @@ -14,75 +9,12 @@ from lagent.prompts import PluginParser from lagent.schema import AgentMessage - -class LcAsyncAPI(AsyncGPTAPI): - def __init__(self, model, api_key=None, max_tokens=4096, max_retries=4, **gen_params): - if api_key is None: - raise ValueError("api_key is required") - self.model = model - self.max_tokens = max_tokens - self.retry = max_retries - self.client = AsyncArk( - api_key=api_key, timeout=900, max_retries=self.retry, base_url='https://ark.cn-beijing.volces.com/api/v3' - ) - super().__init__(**gen_params) - - async def _chat(self, messages: List[dict], **gen_params) -> str: - """Generate completion from a list of templates. - - Args: - messages (List[dict]): a list of prompt dictionaries - gen_params: additional generation configuration - - Returns: - str: The generated string. - """ - assert isinstance(messages, list) - - max_num_retries = 0 - while max_num_retries < self.retry: - try: - response = await self.client.chat.completions.create( - model=self.model, - messages=messages, - max_tokens=self.max_tokens, - temperature=gen_params.get('temperature', 0.8), - top_p=gen_params.get('top_p', 0.95), - stream=True, - ) - reasoning_content = "" - content = "" - async for chunk in response: - if ( - hasattr(chunk.choices[0].delta, 'reasoning_content') - and chunk.choices[0].delta.reasoning_content - ): - reasoning_content += chunk.choices[0].delta.reasoning_content - # print(chunk.choices[0].delta.reasoning_content, end="") - else: - content += chunk.choices[0].delta.content - # print(chunk.choices[0].delta.content, end="") - # response = json.loads(response.json()) - # reasoning_content = response['choices'][0]['message']['reasoning_content'].strip() - # content = response['choices'][0]['message']['content'].strip() - return content if reasoning_content == "" else "" + reasoning_content + "\n" + content - - except Exception as error: - self.logger.error(str(error)) - time.sleep(20) - max_num_retries += 1 - - raise RuntimeError( - 'Calling OpenAI failed after retrying for ' f'{max_num_retries} times. Check the logs for ' 'details.' - ) - - TEMPLATE = ( "You have access to the following tools:\n{tool_description}\nPlease provide" " your thought process when you need to use a tool, followed by the call statement in this format:" "\n{invocation_format}" ) -llm = dict(type=LcAsyncAPI, model=None, api_key=None, top_p=0.95, temperature=0.6, max_tokens=16384, max_retries=50) +llm = dict(type=AsyncGPTAPI, model_type=None, retry=50, key=None, top_p=0.95, temperature=0.6, max_new_tokens=16384) plugin = dict( type=AsyncMCPClient, name='PlayWright',