diff --git a/examples/internlm2_agent_web_demo.py b/examples/internlm2_agent_web_demo.py
index 61f4dccf..b37e31d4 100644
--- a/examples/internlm2_agent_web_demo.py
+++ b/examples/internlm2_agent_web_demo.py
@@ -5,7 +5,8 @@
 import streamlit as st
 
-from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
+from lagent.actions import PPT, ActionExecutor, ArxivSearch, BINGMap, GoogleScholar, IPythonInterpreter
+# from lagent.actions.agentlego_wrapper import AgentLegoToolkit
 from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
 from lagent.llms.lmdepoly_wrapper import LMDeployClient
 from lagent.llms.meta_template import INTERNLM2_META as META
@@ -23,6 +24,11 @@ def init_state(self):
         action_list = [
             ArxivSearch(),
+            PPT(),
+            BINGMap(key='Your api key'  # noqa
+                    ),
+            GoogleScholar(api_key='Your api key'  # noqa
+                          )
         ]
         st.session_state['plugin_map'] = {
             action.name: action
@@ -104,7 +110,7 @@ def setup_sidebar(self):
                     actions=[IPythonInterpreter()])
         else:
             st.session_state['chatbot']._interpreter_executor = None
-        st.session_state['chatbot']._protocol._meta_template = meta_prompt
+        st.session_state['chatbot']._protocol.meta_prompt = meta_prompt
         st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt
         st.session_state[
             'chatbot']._protocol.interpreter_prompt = da_prompt
@@ -141,8 +147,8 @@ def initialize_chatbot(self, model, plugin_action):
                     plugin='<|plugin|>',
                     interpreter='<|interpreter|>'),
                 belong='assistant',
                 end='<|action_end|>\n',
-            ), ),
-            max_turn=7)
+            )),
+            max_turn=15)
 
     def render_user(self, prompt: str):
         with st.chat_message('user'):
diff --git a/lagent/actions/agentlego_wrapper.py b/lagent/actions/agentlego_wrapper.py
new file mode 100644
index 00000000..e4c98255
--- /dev/null
+++ b/lagent/actions/agentlego_wrapper.py
@@ -0,0 +1,48 @@
+from typing import Optional
+
+# from agentlego.parsers import DefaultParser
+from agentlego.tools.remote import RemoteTool
+
+from lagent import BaseAction
+from lagent.actions.parser import JsonParser
+
+
+class AgentLegoToolkit(BaseAction):
+
+    def __init__(self,
+                 name: str,
+                 url: Optional[str] = None,
+                 text: Optional[str] = None,
+                 spec_dict: Optional[dict] = None,
+                 parser=JsonParser,
+                 enable: bool = True):
+
+        if url is not None:
+            spec = dict(url=url)
+        elif text is not None:
+            spec = dict(text=text)
+        else:
+            assert spec_dict is not None
+            spec = dict(spec_dict=spec_dict)
+        if url is not None and not url.endswith('.json'):
+            api_list = [RemoteTool.from_url(url).to_lagent()]
+        else:
+            api_list = [
+                api.to_lagent() for api in RemoteTool.from_openapi(**spec)
+            ]
+        api_desc = []
+        for api in api_list:
+            api_desc.append(api.description)
+        if len(api_list) > 1:
+            tool_description = dict(name=name, api_list=api_desc)
+            for func in api_list:
+                setattr(self, func.name, func.run)
+        else:
+            tool_description = api_desc[0]
+            setattr(self, 'run', api_list[0].run)
+        super().__init__(
+            description=tool_description, parser=parser, enable=enable)
+
+    @property
+    def is_toolkit(self):
+        return 'api_list' in self.description
diff --git a/lagent/actions/builtin_actions.py b/lagent/actions/builtin_actions.py
index 33702132..f336bdf2 100644
--- a/lagent/actions/builtin_actions.py
+++ b/lagent/actions/builtin_actions.py
@@ -37,7 +37,7 @@ def run(self, err_msg: Optional[str] = None) -> ActionReturn:
         action_return = ActionReturn(
             url=None,
             args=dict(text=err_msg),
-            errmsg=err_msg or self._err_msg,
+            errmsg=str(err_msg) or self._err_msg,
             type=self.name,
             valid=ActionValidCode.INVALID,
             state=ActionStatusCode.API_ERROR)
@@ -76,7 +76,7 @@ def run(self, err_msg: Optional[str] = None) -> ActionReturn:
             url=None,
             args=dict(text=err_msg),
             type=self.name,
-            errmsg=err_msg or self._err_msg,
+            errmsg=str(err_msg) or self._err_msg,
             valid=ActionValidCode.INVALID,
             state=ActionStatusCode.API_ERROR)
         return action_return
diff --git a/lagent/agents/internlm2_agent.py b/lagent/agents/internlm2_agent.py
index 6c5f4e08..ccf85a61 100644
--- a/lagent/agents/internlm2_agent.py
+++ b/lagent/agents/internlm2_agent.py
@@ -113,7 +113,6 @@ def format_plugin(message):
             else:
                 new_message.append(
                     dict(role=message['role'], content=new_content))
-
         return new_message
 
     def format(self,
@@ -125,9 +124,10 @@ def format(self,
         if self.meta_prompt:
             formatted.append(dict(role='system', content=self.meta_prompt))
         if interpreter_executor and self.interpreter_prompt:
-            interpreter_info = interpreter_executor.get_actions_info()[0]
-            interpreter_prompt = self.interpreter_prompt.format(
-                code_prompt=interpreter_info['description'])
+            # interpreter_info = interpreter_executor.get_actions_info()[0]
+            # interpreter_prompt = self.interpreter_prompt.format(
+            #     code_prompt=interpreter_info['description'])
+            interpreter_prompt = self.interpreter_prompt
             formatted.append(
                 dict(
                     role='system',
@@ -169,20 +169,30 @@ def parse(self, message, plugin_executor: ActionExecutor,
                 action = action.split(self.tool['end'].strip())[0]
             return 'plugin', message, action
         if self.tool['name_map']['interpreter'] in message:
-            message, code = message.split(
-                f"{self.tool['start_token']}"
-                f"{self.tool['name_map']['interpreter']}")
+            try:
+                message, code, *_ = message.split(
+                    f"{self.tool['start_token']}"
+                    f"{self.tool['name_map']['interpreter']}")
+                # message, code, *_ = message.split(f"{self.tool['start_token']}")
+                # _, code, *_ = code.split(f"{self.tool['name_map']['interpreter']}")
+            except ValueError:
+                message, code, *_ = message.split(
+                    self.tool['name_map']['interpreter'])
+                tool_start_idx = message.rfind(self.tool['start_token'])
+                if tool_start_idx != -1:
+                    message = message[:tool_start_idx]
+            message = message.strip()
             code = code.split(self.tool['end'].strip())[0].strip()
             return 'interpreter', message, dict(
-                name=interpreter_executor.action_names()[0],
-                parameters=dict(command=code))
+                name='IPythonInterpreter', parameters=dict(
+                    command=code)) if interpreter_executor else None
         return None, message.split(self.tool['start_token'])[0], None
 
     def format_response(self, action_return, name) -> dict:
         if action_return.state == ActionStatusCode.SUCCESS:
             response = action_return.format_result()
         else:
-            response = action_return.errmsg
+            response = str(action_return.errmsg)
         content = self.execute['begin'] + response + self.execute['end']
         if self.execute.get('fallback_role'):
             return dict(
diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py
index 91b3f236..ffc87ba4 100644
--- a/lagent/llms/openai.py
+++ b/lagent/llms/openai.py
@@ -1,6 +1,7 @@
 import json
 import os
 import time
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from logging import getLogger
 from threading import Lock
@@ -10,6 +11,8 @@
 
 from .base_api import BaseAPIModel
 
+warnings.simplefilter('default')
+
 OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'
 
 
@@ -45,6 +48,7 @@ def __init__(self,
                  model_type: str = 'gpt-3.5-turbo',
                  query_per_second: int = 1,
                  retry: int = 2,
+                 json_mode: bool = False,
                  key: Union[str, List[str]] = 'ENV',
                  org: Optional[Union[str, List[str]]] = None,
                  meta_template: Optional[Dict] = [
@@ -53,13 +57,19 @@ def __init__(self,
                      dict(role='assistant', api_role='assistant')
                  ],
                  openai_api_base: str = OPENAI_API_BASE,
+                 proxies: Optional[Dict] = None,
                  **gen_params):
+        if 'top_k' in gen_params:
+            warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
+                          DeprecationWarning)
+            gen_params.pop('top_k')
         super().__init__(
             model_type=model_type,
             meta_template=meta_template,
             query_per_second=query_per_second,
             retry=retry,
             **gen_params)
+        self.gen_params.pop('top_k')
         self.logger = getLogger(__name__)
 
         if isinstance(key, str):
@@ -79,16 +89,8 @@ def __init__(self,
             self.org_ctr = 0
         self.url = openai_api_base
         self.model_type = model_type
-
-        # max num token for gpt-3.5-turbo is 4097
-        context_window = 4096
-        if '32k' in self.model_type:
-            context_window = 32768
-        elif '16k' in self.model_type:
-            context_window = 16384
-        elif 'gpt-4' in self.model_type:
-            context_window = 8192
-        self.context_window = context_window
+        self.proxies = proxies
+        self.json_mode = json_mode
 
     def chat(
         self,
@@ -118,6 +120,27 @@
         ret = [task.result() for task in tasks]
         return ret[0] if isinstance(inputs[0], dict) else ret
 
+    def stream_chat(
+        self,
+        inputs: List[dict],
+        **gen_params,
+    ) -> str:
+        """Generate responses given the contexts.
+
+        Args:
+            inputs (List[dict]): a list of messages
+            gen_params: additional generation configuration
+
+        Returns:
+            str: generated string
+        """
+        assert isinstance(inputs, list)
+        if 'max_tokens' in gen_params:
+            raise NotImplementedError('unsupported parameter: max_tokens')
+        gen_params = {**self.gen_params, **gen_params}
+        gen_params['stream'] = True
+        yield from self._chat(inputs, **gen_params)
+
     def _chat(self, messages: List[dict], **gen_params) -> str:
         """Generate completion from a list of templates.
@@ -132,9 +155,7 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
         gen_params = gen_params.copy()
 
         # Hold out 100 tokens due to potential errors in tiktoken calculation
-        max_tokens = min(
-            gen_params.pop('max_new_tokens'),
-            self.context_window - len(self.tokenize(str(input))) - 100)
+        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
         if max_tokens <= 0:
             return ''
 
@@ -170,27 +191,49 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
                 header['OpenAI-Organization'] = self.orgs[self.org_ctr]
 
             try:
+                gen_params_new = gen_params.copy()
                 data = dict(
                     model=self.model_type,
                     messages=messages,
                     max_tokens=max_tokens,
                     n=1,
-                    stop=gen_params.pop('stop_words'),
-                    frequency_penalty=gen_params.pop('repetition_penalty'),
-                    **gen_params,
+                    stop=gen_params_new.pop('stop_words'),
+                    frequency_penalty=gen_params_new.pop('repetition_penalty'),
+                    **gen_params_new,
                 )
+                if self.json_mode:
+                    data['response_format'] = {'type': 'json_object'}
                 raw_response = requests.post(
-                    self.url, headers=header, data=json.dumps(data))
+                    self.url,
+                    headers=header,
+                    data=json.dumps(data),
+                    proxies=self.proxies)
+                if 'stream' not in data or not data['stream']:
+                    response = raw_response.json()
+                    return response['choices'][0]['message']['content'].strip()
+                else:
+                    resp = ''
+                    for chunk in raw_response.iter_lines(
+                            chunk_size=8192, decode_unicode=False,
+                            delimiter=b'\n'):
+                        if chunk:
+                            decoded = chunk.decode('utf-8')
+                            if decoded == 'data: [DONE]':
+                                return
+                            if decoded[:6] == 'data: ':
+                                decoded = decoded[6:]
+                            response = json.loads(decoded)
+                            choice = response['choices'][0]
+                            if choice['finish_reason'] == 'stop':
+                                return
+                            resp += choice['delta']['content'].strip()
+                            yield resp
             except requests.ConnectionError:
                 print('Got connection error, retrying...')
                 continue
-            try:
-                response = raw_response.json()
             except requests.JSONDecodeError:
                 print('JsonDecode error, got', str(raw_response.content))
                 continue
-            try:
-                return response['choices'][0]['message']['content'].strip()
             except KeyError:
                 if 'error' in response:
                     if response['error']['code'] == 'rate_limit_exceeded':
@@ -203,6 +246,8 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
                     print('Find error message in response: ',
                           str(response['error']))
+            except Exception as error:
+                print(str(error))
             max_num_retries += 1
 
         raise RuntimeError('Calling OpenAI failed after retrying for '
diff --git a/lagent/schema.py b/lagent/schema.py
index a7f8e0cd..94bcf370 100644
--- a/lagent/schema.py
+++ b/lagent/schema.py
@@ -1,6 +1,6 @@
 from dataclasses import asdict, dataclass, field
 from enum import IntEnum
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 
 def enum_dict_factory(inputs):
@@ -77,12 +77,17 @@ class AgentStatusCode(IntEnum):
     CODING = 6  # start python
     CODE_END = 7  # end python
     CODE_RETURN = 8  # python return
+    ANSWER_ING = 9  # final answer is in streaming
 
 
 @dataclass
 class AgentReturn:
+    type: str = ''
+    content: str = ''
     state: Union[AgentStatusCode, int] = AgentStatusCode.END
     actions: List[ActionReturn] = field(default_factory=list)
     response: str = ''
     inner_steps: List = field(default_factory=list)
+    nodes: Dict = None
+    adjacency_list: Dict = None
     errmsg: Optional[str] = None
diff --git a/requirements/optional.txt b/requirements/optional.txt
index 0fe2504d..4b576507 100644
--- a/requirements/optional.txt
+++ b/requirements/optional.txt
@@ -1,3 +1,4 @@
+agentlego
 google-search-results
 lmdeploy>=0.2.3
 pillow
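
A minimal usage sketch for the new AgentLegoToolkit wrapper, assuming an AgentLego tool server is reachable at a placeholder URL (the toolkit name, host, and port below are illustrative, not part of the patch):

# Wrap a remote AgentLego tool server as a lagent action. A URL ending
# in '.json' is treated as an OpenAPI spec (RemoteTool.from_openapi);
# any other URL goes through RemoteTool.from_url.
from lagent.actions import ActionExecutor
from lagent.actions.agentlego_wrapper import AgentLegoToolkit

toolkit = AgentLegoToolkit(
    name='remote_tools',  # hypothetical toolkit name
    url='http://localhost:16180/openapi.json',  # hypothetical server
)
# A spec exposing several APIs yields a toolkit with one bound method
# per tool; a single API is exposed directly as toolkit.run.
executor = ActionExecutor(actions=[toolkit])

Whether the wrapper acts as a toolkit or a single action can be checked through the new is_toolkit property.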
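
A sketch of the new streaming path in lagent/llms/openai.py, assuming the class is lagent's GPTAPI (its name sits outside this diff's context lines) and that OPENAI_API_KEY is set so the default key='ENV' lookup succeeds:

# stream_chat sets 'stream': True internally and yields the cumulative
# response text after each SSE chunk, so each printed value extends the
# previous one.
from lagent.llms.openai import GPTAPI

llm = GPTAPI(model_type='gpt-3.5-turbo')
messages = [dict(role='user', content='Explain SSE in one sentence.')]
for partial in llm.stream_chat(messages):
    print(partial)

Note that stream_chat is a generator despite the -> str annotation; callers must iterate it rather than use the return value as a string.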
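
And a sketch of how the reworked Internlm2Protocol.parse splits an interpreter call out of a model message. The special tokens follow the demo's configuration; that the protocol's defaults match those tokens, and that the two executor arguments can be passed by keyword, are assumptions:

# With interpreter_executor=None the patched branch returns action=None;
# with a real executor it returns dict(name='IPythonInterpreter', ...).
from lagent.agents.internlm2_agent import Internlm2Protocol

protocol = Internlm2Protocol()  # assumes default InternLM2 tool tokens
msg = ('Let me compute that.<|action_start|><|interpreter|>\n'
       '```python\nprint(1 + 1)\n```<|action_end|>\n')
kind, text, action = protocol.parse(
    msg, plugin_executor=None, interpreter_executor=None)
# Expected: kind == 'interpreter', text == 'Let me compute that.',
# and action is None because no interpreter executor was supplied.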