14 changes: 10 additions & 4 deletions examples/internlm2_agent_web_demo.py
@@ -5,7 +5,8 @@

import streamlit as st

from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
from lagent.actions import PPT, ActionExecutor, ArxivSearch, BINGMap, GoogleScholar, IPythonInterpreter
# from lagent.actions.agentlego_wrapper import AgentLegoToolkit
from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
from lagent.llms.lmdepoly_wrapper import LMDeployClient
from lagent.llms.meta_template import INTERNLM2_META as META
@@ -23,6 +24,11 @@ def init_state(self):

action_list = [
ArxivSearch(),
PPT(),
BINGMap(key='Your api key' # noqa
),
GoogleScholar(api_key='Your api key' # noqa
)
]
st.session_state['plugin_map'] = {
action.name: action
@@ -104,7 +110,7 @@ def setup_sidebar(self):
actions=[IPythonInterpreter()])
else:
st.session_state['chatbot']._interpreter_executor = None
st.session_state['chatbot']._protocol._meta_template = meta_prompt
st.session_state['chatbot']._protocol.meta_prompt = meta_prompt
st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt
st.session_state[
'chatbot']._protocol.interpreter_prompt = da_prompt
@@ -141,8 +147,8 @@ def initialize_chatbot(self, model, plugin_action):
plugin='<|plugin|>', interpreter='<|interpreter|>'),
belong='assistant',
end='<|action_end|>\n',
), ),
max_turn=7)
)),
max_turn=15)

def render_user(self, prompt: str):
with st.chat_message('user'):
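A minimal sketch of how the expanded `action_list` above might be configured with real credentials, reading keys from the environment instead of the `'Your api key'` placeholders. The environment variable names are illustrative, not part of the diff:

```python
import os

from lagent.actions import PPT, ArxivSearch, BINGMap, GoogleScholar

# Env var names are hypothetical; use whatever your deployment defines.
action_list = [
    ArxivSearch(),
    PPT(),
    BINGMap(key=os.environ['BING_MAP_KEY']),
    GoogleScholar(api_key=os.environ['GOOGLE_SCHOLAR_KEY']),
]
```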
48 changes: 48 additions & 0 deletions lagent/actions/agentlego_wrapper.py
@@ -0,0 +1,48 @@
from typing import Optional

# from agentlego.parsers import DefaultParser
from agentlego.tools.remote import RemoteTool

from lagent import BaseAction
from lagent.actions.parser import JsonParser


class AgentLegoToolkit(BaseAction):

def __init__(self,
name: str,
url: Optional[str] = None,
text: Optional[str] = None,
spec_dict: Optional[dict] = None,
parser=JsonParser,
enable: bool = True):

if url is not None:
spec = dict(url=url)
elif text is not None:
spec = dict(text=text)
else:
assert spec_dict is not None
spec = dict(spec_dict=spec_dict)
if url is not None and not url.endswith('.json'):
api_list = [RemoteTool.from_url(url).to_lagent()]
else:
api_list = [
api.to_lagent() for api in RemoteTool.from_openapi(**spec)
]
api_desc = []
for api in api_list:
api_desc.append(api.description)
if len(api_list) > 1:
tool_description = dict(name=name, api_list=api_desc)
for func in api_list:
setattr(self, func.name, func.run)
else:
tool_description = api_desc[0]
setattr(self, 'run', api_list[0].run)
super().__init__(
description=tool_description, parser=parser, enable=enable)

@property
def is_toolkit(self):
return 'api_list' in self.description
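A usage sketch for the new wrapper, assuming a tool server is reachable at the placeholder URLs below. Per the constructor logic above, a plain URL wraps a single endpoint and exposes it as `run`, while a URL ending in `.json` (or a raw spec) is treated as an OpenAPI document whose endpoints each become a method on the toolkit:

```python
from lagent.actions.agentlego_wrapper import AgentLegoToolkit

# Single remote tool: one endpoint, exposed directly as `tool.run`.
tool = AgentLegoToolkit(
    name='ImageCaption', url='http://localhost:16180/image_caption')
print(tool.is_toolkit)  # False for a single endpoint

# OpenAPI spec: one method per endpoint, named after each API.
toolkit = AgentLegoToolkit(
    name='MyTools', url='http://localhost:16180/openapi.json')
print(toolkit.is_toolkit)  # True once the spec describes more than one endpoint
```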
4 changes: 2 additions & 2 deletions lagent/actions/builtin_actions.py
@@ -37,7 +37,7 @@ def run(self, err_msg: Optional[str] = None) -> ActionReturn:
action_return = ActionReturn(
url=None,
args=dict(text=err_msg),
errmsg=err_msg or self._err_msg,
errmsg=str(err_msg) or self._err_msg,
type=self.name,
valid=ActionValidCode.INVALID,
state=ActionStatusCode.API_ERROR)
@@ -76,7 +76,7 @@ def run(self, err_msg: Optional[str] = None) -> ActionReturn:
url=None,
args=dict(text=err_msg),
type=self.name,
errmsg=err_msg or self._err_msg,
errmsg=str(err_msg) or self._err_msg,
valid=ActionValidCode.INVALID,
state=ActionStatusCode.API_ERROR)
return action_return
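The `str(...)` coercion matters because callers may hand `run` an exception object rather than a string; without it, the raw object lands in `errmsg`. One caveat worth noting: `str(None)` is the truthy string `'None'`, so after this change the `or self._err_msg` fallback no longer fires when `err_msg is None`; `str(err_msg) if err_msg else self._err_msg` would keep both behaviors. A small illustration:

```python
err = ValueError('missing parameter `query`')

# Before the change: the exception object itself is kept.
errmsg_old = err or 'default error message'       # ValueError instance
# After the change: the exception is coerced to text first.
errmsg_new = str(err) or 'default error message'  # 'missing parameter `query`'

# Caveat: None no longer falls through to the default.
print(str(None) or 'default error message')       # prints 'None'
```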
30 changes: 20 additions & 10 deletions lagent/agents/internlm2_agent.py
@@ -113,7 +113,6 @@ def format_plugin(message):
else:
new_message.append(
dict(role=message['role'], content=new_content))

return new_message

def format(self,
@@ -125,9 +124,10 @@ def format(self,
if self.meta_prompt:
formatted.append(dict(role='system', content=self.meta_prompt))
if interpreter_executor and self.interpreter_prompt:
interpreter_info = interpreter_executor.get_actions_info()[0]
interpreter_prompt = self.interpreter_prompt.format(
code_prompt=interpreter_info['description'])
# interpreter_info = interpreter_executor.get_actions_info()[0]
# interpreter_prompt = self.interpreter_prompt.format(
# code_prompt=interpreter_info['description'])
interpreter_prompt = self.interpreter_prompt
formatted.append(
dict(
role='system',
@@ -169,20 +169,30 @@ def parse(self, message, plugin_executor: ActionExecutor,
action = action.split(self.tool['end'].strip())[0]
return 'plugin', message, action
if self.tool['name_map']['interpreter'] in message:
message, code = message.split(
f"{self.tool['start_token']}"
f"{self.tool['name_map']['interpreter']}")
try:
message, code, *_ = message.split(
f"{self.tool['start_token']}"
f"{self.tool['name_map']['interpreter']}")
# message, code, *_ = message.split(f"{self.tool['start_token']}")
# _, code, *_ = code.split(f"{self.tool['name_map']['interpreter']}")
except ValueError:
message, code, *_ = message.split(
self.tool['name_map']['interpreter'])
tool_start_idx = message.rfind(self.tool['start_token'])
if tool_start_idx != -1:
message = message[:tool_start_idx]
message = message.strip()
code = code.split(self.tool['end'].strip())[0].strip()
return 'interpreter', message, dict(
name=interpreter_executor.action_names()[0],
parameters=dict(command=code))
name='IPythonInterpreter', parameters=dict(
command=code)) if interpreter_executor else None
return None, message.split(self.tool['start_token'])[0], None

def format_response(self, action_return, name) -> dict:
if action_return.state == ActionStatusCode.SUCCESS:
response = action_return.format_result()
else:
response = action_return.errmsg
response = str(action_return.errmsg)
content = self.execute['begin'] + response + self.execute['end']
if self.execute.get('fallback_role'):
return dict(
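A self-contained sketch of the split-with-fallback introduced in `parse`, using the InternLM2 special tokens from the demo above. A well-formed message splits on the combined `<|action_start|><|interpreter|>` marker; if that unpacking fails, the code splits on `<|interpreter|>` alone and trims everything from the dangling `<|action_start|>` onward:

```python
START, INTERP, END = '<|action_start|>', '<|interpreter|>', '<|action_end|>'

message = f'Let me compute that.{START}{INTERP}\nprint(1 + 1)\n{END}'
try:
    text, code, *_ = message.split(f'{START}{INTERP}')
except ValueError:  # marker malformed: fewer than two parts to unpack
    text, code, *_ = message.split(INTERP)
    idx = text.rfind(START)
    if idx != -1:
        text = text[:idx]
print(text.strip())                # Let me compute that.
print(code.split(END)[0].strip())  # print(1 + 1)
```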
87 changes: 66 additions & 21 deletions lagent/llms/openai.py
@@ -1,6 +1,7 @@
import json
import os
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from threading import Lock
@@ -10,6 +11,8 @@

from .base_api import BaseAPIModel

warnings.simplefilter('default')

OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'


@@ -45,6 +48,7 @@ def __init__(self,
model_type: str = 'gpt-3.5-turbo',
query_per_second: int = 1,
retry: int = 2,
json_mode: bool = False,
key: Union[str, List[str]] = 'ENV',
org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = [
@@ -53,13 +57,19 @@
dict(role='assistant', api_role='assistant')
],
openai_api_base: str = OPENAI_API_BASE,
proxies: Optional[Dict] = None,
**gen_params):
if 'top_k' in gen_params:
warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
DeprecationWarning)
gen_params.pop('top_k')
super().__init__(
model_type=model_type,
meta_template=meta_template,
query_per_second=query_per_second,
retry=retry,
**gen_params)
self.gen_params.pop('top_k')
self.logger = getLogger(__name__)

if isinstance(key, str):
@@ -79,16 +89,8 @@ def __init__(self,
self.org_ctr = 0
self.url = openai_api_base
self.model_type = model_type

# max num token for gpt-3.5-turbo is 4097
context_window = 4096
if '32k' in self.model_type:
context_window = 32768
elif '16k' in self.model_type:
context_window = 16384
elif 'gpt-4' in self.model_type:
context_window = 8192
self.context_window = context_window
self.proxies = proxies
self.json_mode = json_mode

def chat(
self,
@@ -118,6 +120,27 @@
ret = [task.result() for task in tasks]
return ret[0] if isinstance(inputs[0], dict) else ret

def stream_chat(
self,
inputs: List[dict],
**gen_params,
) -> str:
"""Generate responses given the contexts.

Args:
inputs (List[dict]): a list of messages
gen_params: additional generation configuration

Yields:
str: the response text accumulated so far
"""
assert isinstance(inputs, list)
if 'max_tokens' in gen_params:
raise NotImplementedError('unsupported parameter: max_tokens')
gen_params = {**self.gen_params, **gen_params}
gen_params['stream'] = True
yield from self._chat(inputs, **gen_params)

def _chat(self, messages: List[dict], **gen_params) -> str:
"""Generate completion from a list of templates.

@@ -132,9 +155,7 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
gen_params = gen_params.copy()

# Hold out 100 tokens due to potential errors in tiktoken calculation
max_tokens = min(
gen_params.pop('max_new_tokens'),
self.context_window - len(self.tokenize(str(input))) - 100)
max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
if max_tokens <= 0:
return ''

@@ -170,27 +191,49 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
header['OpenAI-Organization'] = self.orgs[self.org_ctr]

try:
gen_params_new = gen_params.copy()
data = dict(
model=self.model_type,
messages=messages,
max_tokens=max_tokens,
n=1,
stop=gen_params.pop('stop_words'),
frequency_penalty=gen_params.pop('repetition_penalty'),
**gen_params,
stop=gen_params_new.pop('stop_words'),
frequency_penalty=gen_params_new.pop('repetition_penalty'),
**gen_params_new,
)
if self.json_mode:
data['response_format'] = {'type': 'json_object'}
raw_response = requests.post(
self.url, headers=header, data=json.dumps(data))
self.url,
headers=header,
data=json.dumps(data),
proxies=self.proxies)
if 'stream' not in data or not data['stream']:
response = raw_response.json()
return response['choices'][0]['message']['content'].strip()
else:
resp = ''
for chunk in raw_response.iter_lines(
chunk_size=8192, decode_unicode=False,
delimiter=b'\n'):
if chunk:
decoded = chunk.decode('utf-8')
if decoded == 'data: [DONE]':
return
if decoded[:6] == 'data: ':
decoded = decoded[6:]
response = json.loads(decoded)
choice = response['choices'][0]
if choice['finish_reason'] == 'stop':
return
resp += choice['delta'].get('content', '')
yield resp
except requests.ConnectionError:
print('Got connection error, retrying...')
continue
try:
response = raw_response.json()
except requests.JSONDecodeError:
print('JsonDecode error, got', str(raw_response.content))
continue
try:
return response['choices'][0]['message']['content'].strip()
except KeyError:
if 'error' in response:
if response['error']['code'] == 'rate_limit_exceeded':
@@ -203,6 +246,8 @@

print('Find error message in response: ',
str(response['error']))
except Exception as error:
print(str(error))
max_num_retries += 1

raise RuntimeError('Calling OpenAI failed after retrying for '
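A usage sketch for the streaming and JSON-mode additions, assuming the class in this file is `GPTAPI` as in lagent and that an API key is available; the key and proxy values below are placeholders. Since `_chat` now contains `yield`, the sketch sticks to the streaming path via `stream_chat`, which per the loop above yields the accumulated response text rather than individual deltas:

```python
from lagent.llms.openai import GPTAPI

llm = GPTAPI(
    model_type='gpt-3.5-turbo',
    key='sk-...',                                # placeholder; 'ENV' reads the key from the environment
    json_mode=False,                             # True adds response_format={'type': 'json_object'}
    proxies={'https': 'http://127.0.0.1:7890'},  # placeholder proxy, forwarded to requests.post
)

messages = [dict(role='user', content='Say hello.')]
for partial in llm.stream_chat(messages):
    print(partial)  # each value is the full response accumulated so far
```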
7 changes: 6 additions & 1 deletion lagent/schema.py
@@ -1,6 +1,6 @@
from dataclasses import asdict, dataclass, field
from enum import IntEnum
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union


def enum_dict_factory(inputs):
@@ -77,12 +77,17 @@ class AgentStatusCode(IntEnum):
CODING = 6 # start python
CODE_END = 7 # end python
CODE_RETURN = 8 # python return
ANSWER_ING = 9 # the final answer is streaming


@dataclass
class AgentReturn:
type: str = ''
content: str = ''
state: Union[AgentStatusCode, int] = AgentStatusCode.END
actions: List[ActionReturn] = field(default_factory=list)
response: str = ''
inner_steps: List = field(default_factory=list)
nodes: Dict = None
adjacency_list: Dict = None
errmsg: Optional[str] = None
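A minimal sketch constructing the extended `AgentReturn` with the new streaming-oriented fields; the field values are illustrative:

```python
from lagent.schema import AgentReturn, AgentStatusCode

ret = AgentReturn(
    state=AgentStatusCode.ANSWER_ING,  # the final answer is still streaming
    response='The integral evaluates to',
    inner_steps=[dict(role='assistant', content='The integral evaluates to')],
)
print(ret.state, ret.response)
```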
1 change: 1 addition & 0 deletions requirements/optional.txt
@@ -1,3 +1,4 @@
agentlego
google-search-results
lmdeploy>=0.2.3
pillow