Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions examples/run_async_agent_api_model_with_mcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import asyncio
import json

from lagent.actions import AsyncActionExecutor, AsyncMCPClient
from lagent.agents import AsyncAgentForInternLM
from lagent.agents.aggregator import InternLMToolAggregator
from lagent.agents.stream import get_plugin_prompt
from lagent.llms import AsyncGPTAPI
from lagent.prompts import PluginParser
from lagent.schema import AgentMessage

TEMPLATE = (
"You have access to the following tools:\n{tool_description}\nPlease provide"
" your thought process when you need to use a tool, followed by the call statement in this format:"
"\n{invocation_format}"
)
llm = dict(type=AsyncGPTAPI, model_type=None, retry=50, key=None, top_p=0.95, temperature=0.6, max_new_tokens=16384)
plugin = dict(
type=AsyncMCPClient,
name='PlayWright',
server_type='stdio',
command='npx',
args=["@playwright/mcp@latest", '--isolated', '--no-sandbox'],
)
agent = AsyncAgentForInternLM(
llm,
plugin,
template=TEMPLATE.format(
tool_description=get_plugin_prompt(plugin),
invocation_format='```json\n{"name": {{tool name}}, "parameters": {{keyword arguments}}}\n```\n',
),
output_format=PluginParser(begin="```json\n", end="\n```\n", validate=lambda x: json.loads(x.rstrip('`'))),
aggregator=InternLMToolAggregator(environment_role='system'),
)
msg = AgentMessage(
sender='user',
content='解释一下MCP中Sampling Flow的工作机制,参考https://modelcontextprotocol.io/docs/concepts/sampling',
)

# proj_dir = os.path.dirname(os.path.dirname(__file__))
# executor = AsyncActionExecutor(
# dict(
# type=AsyncMCPClient,
# name='FS',
# server_type='stdio',
# command='npx',
# args=['-y', '@modelcontextprotocol/server-filesystem', os.path.join(proj_dir, 'docs')],
# )
# )
# msg = AgentMessage(
# sender='assistant',
# content=dict(
# name='FS.read_file',
# parameters=dict(path=os.path.join(proj_dir, 'docs/en/get_started/install.md')),
# ),
# )
loop = asyncio.get_event_loop()
res = loop.run_until_complete(agent(msg))
print(res.content)
2 changes: 2 additions & 0 deletions lagent/actions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .ipython_interactive import AsyncIPythonInteractive, IPythonInteractive
from .ipython_interpreter import AsyncIPythonInterpreter, IPythonInterpreter
from .ipython_manager import IPythonInteractiveManager
from .mcp_client import AsyncMCPClient
from .parser import BaseParser, JsonParser, TupleParser
from .ppt import PPT, AsyncPPT
from .python_interpreter import AsyncPythonInterpreter, PythonInterpreter
Expand Down Expand Up @@ -39,6 +40,7 @@
'AsyncPPT',
'WebBrowser',
'AsyncWebBrowser',
'AsyncMCPClient',
'BaseParser',
'JsonParser',
'TupleParser',
Expand Down
2 changes: 2 additions & 0 deletions lagent/actions/base_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ def sub(self, a, b):
action = Calculator()
"""

is_stateful = False

def __init__(
self,
description: Optional[dict] = None,
Expand Down
106 changes: 44 additions & 62 deletions lagent/actions/ipython_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ async def async_run_code(
assert iopub_timeout > interrupt_after
try:

async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient,
*,
timeout=None):
async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient, *, timeout=None):
loop = asyncio.get_running_loop()
dead_fut = loop.create_future()

Expand All @@ -71,8 +69,7 @@ def dead():
km.add_restart_callback(restarting, "restart")
km.add_restart_callback(dead, "dead")
try:
done, _ = await asyncio.wait(
[dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED)
done, _ = await asyncio.wait([dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED)
if dead_fut in done:
raise KernelDeath()
assert msg_task in done
Expand All @@ -88,13 +85,21 @@ async def send_interrupt():
await km.interrupt_kernel()

@retry(
retry=retry_if_result(lambda ret: ret[-1].strip() in [
'KeyboardInterrupt',
f"Kernel didn't respond in {wait_for_ready_timeout} seconds",
] if isinstance(ret, tuple) else False),
retry=retry_if_result(
lambda ret: (
ret[-1].strip()
in [
'KeyboardInterrupt',
f"Kernel didn't respond in {wait_for_ready_timeout} seconds",
]
if isinstance(ret, tuple)
else False
)
),
stop=stop_after_attempt(3),
wait=wait_fixed(1),
retry_error_callback=lambda state: state.outcome.result())
retry_error_callback=lambda state: state.outcome.result(),
)
async def run():
execute_result = None
error_traceback = None
Expand All @@ -106,11 +111,9 @@ async def run():
await kc.wait_for_ready(timeout=wait_for_ready_timeout)
msg_id = kc.execute(code)
while True:
message = await get_iopub_msg_with_death_detection(
kc, timeout=iopub_timeout)
message = await get_iopub_msg_with_death_detection(kc, timeout=iopub_timeout)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
json.dumps(message, indent=2, default=str))
logger.debug(json.dumps(message, indent=2, default=str))
assert message["parent_header"]["msg_id"] == msg_id
msg_type = message["msg_type"]
if msg_type == "status":
Expand All @@ -136,8 +139,7 @@ async def run():
if interrupt_after:
run_task = asyncio.create_task(run())
send_interrupt_task = asyncio.create_task(send_interrupt())
done, _ = await asyncio.wait([run_task, send_interrupt_task],
return_when=asyncio.FIRST_COMPLETED)
done, _ = await asyncio.wait([run_task, send_interrupt_task], return_when=asyncio.FIRST_COMPLETED)
if run_task in done:
send_interrupt_task.cancel()
else:
Expand Down Expand Up @@ -216,13 +218,10 @@ def reset(self):
if not self._initialized:
self.initialize()
else:
code = "get_ipython().run_line_magic('reset', '-f')\n" + \
START_CODE.format(self.user_data_dir)
code = "get_ipython().run_line_magic('reset', '-f')\n" + START_CODE.format(self.user_data_dir)
self._call(code, None)

def _call(self,
command: str,
timeout: Optional[int] = None) -> Tuple[str, bool]:
def _call(self, command: str, timeout: Optional[int] = None) -> Tuple[str, bool]:
self.initialize()
command = extract_code(command)

Expand Down Expand Up @@ -261,16 +260,14 @@ def _inner_call():
text = msg['content']['data'].get('text/plain', '')
if 'image/png' in msg['content']['data']:
image_b64 = msg['content']['data']['image/png']
image_url = publish_image_to_local(
image_b64, self.work_dir)
image_url = publish_image_to_local(image_b64, self.work_dir)
image_idx += 1
image = '![fig-%03d](%s)' % (image_idx, image_url)

elif msg_type == 'display_data':
if 'image/png' in msg['content']['data']:
image_b64 = msg['content']['data']['image/png']
image_url = publish_image_to_local(
image_b64, self.work_dir)
image_url = publish_image_to_local(image_b64, self.work_dir)
image_idx += 1
image = '![fig-%03d](%s)' % (image_idx, image_url)

Expand All @@ -281,8 +278,7 @@ def _inner_call():
text = msg['content']['text']
elif msg_type == 'error':
succeed = False
text = escape_ansi('\n'.join(
msg['content']['traceback']))
text = escape_ansi('\n'.join(msg['content']['traceback']))
if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
text = f'Timeout. No response after {timeout} seconds.' # noqa
except queue.Empty:
Expand Down Expand Up @@ -349,8 +345,7 @@ def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
# text=result['text'], image=result.get('image', [])[0])
tool_return.state = ActionStatusCode.SUCCESS
else:
tool_return.errmsg = result.get('text', '') if isinstance(
result, dict) else result
tool_return.errmsg = result.get('text', '') if isinstance(result, dict) else result
tool_return.state = ActionStatusCode.API_ERROR
return tool_return

Expand All @@ -371,6 +366,7 @@ class AsyncIPythonInterpreter(AsyncActionMixin, IPythonInterpreter):
action's inputs and outputs. Defaults to :class:`JsonParser`.
"""

is_stateful = True
_UNBOUND_KERNEL_CLIENTS = asyncio.Queue()

def __init__(
Expand All @@ -390,8 +386,7 @@ def __init__(

c = Config()
c.KernelManager.transport = 'ipc'
self._amkm = AsyncMultiKernelManager(
config=c, connection_dir=connection_dir)
self._amkm = AsyncMultiKernelManager(config=c, connection_dir=connection_dir)
self._max_kernels = max_kernels
self._reuse_kernel = reuse_kernel
self._sem = asyncio.Semaphore(startup_rate)
Expand All @@ -403,25 +398,23 @@ async def initialize(self, session_id: str):
if session_id in self._KERNEL_CLIENTS:
return self._KERNEL_CLIENTS[session_id]
if self._reuse_kernel and not self._UNBOUND_KERNEL_CLIENTS.empty():
self._KERNEL_CLIENTS[
session_id] = await self._UNBOUND_KERNEL_CLIENTS.get()
self._KERNEL_CLIENTS[session_id] = await self._UNBOUND_KERNEL_CLIENTS.get()
return self._KERNEL_CLIENTS[session_id]
async with self._sem:
if self._max_kernels is None or len(
self._KERNEL_CLIENTS
) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels:
if (
self._max_kernels is None
or len(self._KERNEL_CLIENTS) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels
):
kernel_id = None
try:
kernel_id = await self._amkm.start_kernel()
kernel = self._amkm.get_kernel(kernel_id)
client = kernel.client()
_, error_stacktrace, stream_text = await async_run_code(
kernel,
START_CODE.format(self.user_data_dir),
shutdown_kernel=False)
kernel, START_CODE.format(self.user_data_dir), shutdown_kernel=False
)
# check if the output of START_CODE meets expectations
if not (error_stacktrace is None
and stream_text == ''):
if not (error_stacktrace is None and stream_text == ''):
raise RuntimeError
except Exception as e:
print(f'Starting kernel error: {e}')
Expand All @@ -431,15 +424,11 @@ async def initialize(self, session_id: str):
await asyncio.sleep(1)
continue
if self._max_kernels is None:
self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel,
client)
self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel, client)
return kernel_id, kernel, client
async with self._lock:
if len(self._KERNEL_CLIENTS
) + self._UNBOUND_KERNEL_CLIENTS.qsize(
) < self._max_kernels:
self._KERNEL_CLIENTS[session_id] = (kernel_id,
kernel, client)
if len(self._KERNEL_CLIENTS) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels:
self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel, client)
return kernel_id, kernel, client
await self._amkm.shutdown_kernel(kernel_id)
self._amkm.remove_kernel(kernel_id)
Expand All @@ -450,8 +439,7 @@ async def reset(self, session_id: str):
if session_id not in self._KERNEL_CLIENTS:
return
_, kernel, _ = self._KERNEL_CLIENTS[session_id]
code = "get_ipython().run_line_magic('reset', '-f')\n" + \
START_CODE.format(self.user_data_dir)
code = "get_ipython().run_line_magic('reset', '-f')\n" + START_CODE.format(self.user_data_dir)
await async_run_code(kernel, code, shutdown_kernel=False)

async def shutdown(self, session_id: str):
Expand All @@ -467,18 +455,15 @@ async def close_session(self, session_id: str):
if self._reuse_kernel:
if session_id in self._KERNEL_CLIENTS:
await self.reset(session_id)
await self._UNBOUND_KERNEL_CLIENTS.put(
self._KERNEL_CLIENTS.pop(session_id))
await self._UNBOUND_KERNEL_CLIENTS.put(self._KERNEL_CLIENTS.pop(session_id))
else:
await self.shutdown(session_id)

async def _call(self, command, timeout=None, session_id=None):
_, kernel, _ = await self.initialize(str(session_id))
result = await async_run_code(
kernel,
extract_code(command),
interrupt_after=timeout or self.timeout,
shutdown_kernel=False)
kernel, extract_code(command), interrupt_after=timeout or self.timeout, shutdown_kernel=False
)
execute_result, error_stacktrace, stream_text = result
if error_stacktrace is not None:
ret = re.sub('^-*\n', '', escape_ansi(error_stacktrace))
Expand All @@ -492,10 +477,7 @@ async def _call(self, command, timeout=None, session_id=None):
return status, ret

@tool_api
async def run(self,
command: str,
timeout: Optional[int] = None,
session_id: Optional[str] = None) -> ActionReturn:
async def run(self, command: str, timeout: Optional[int] = None, session_id: Optional[str] = None) -> ActionReturn:
r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.

Args:
Expand All @@ -516,8 +498,7 @@ async def run(self,
# text=result['text'], image=result.get('image', [])[0])
tool_return.state = ActionStatusCode.SUCCESS
else:
tool_return.errmsg = result.get('text', '') if isinstance(
result, dict) else result
tool_return.errmsg = result.get('text', '') if isinstance(result, dict) else result
tool_return.state = ActionStatusCode.API_ERROR
return tool_return

Expand Down Expand Up @@ -549,6 +530,7 @@ def escape_ansi(line):

def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
import PIL.Image

image_file = str(uuid.uuid4()) + '.png'
local_image_file = os.path.join(work_dir, image_file)

Expand Down
Loading