From 685ed6940e67dc8cd353d09369a4942c25945525 Mon Sep 17 00:00:00 2001 From: Aviral Jain Date: Sun, 5 Jan 2025 11:50:22 +0530 Subject: [PATCH 1/3] add tool calling and handle human input handlers --- pyproject.toml | 2 +- src/compextAI/__init__.py | 4 + src/compextAI/default_tools.py | 13 +++ src/compextAI/execution.py | 161 ++++++++++++++++++++++++++++++++- src/compextAI/tools.py | 58 ++++++++++++ 5 files changed, 234 insertions(+), 4 deletions(-) create mode 100644 src/compextAI/default_tools.py create mode 100644 src/compextAI/tools.py diff --git a/pyproject.toml b/pyproject.toml index 0c394d4..a130072 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "CompextAI" -version = "0.0.14" +version = "0.0.24" authors = [ { name="burnerlee", email="avi.aviral140@gmail.com" }, ] diff --git a/src/compextAI/__init__.py b/src/compextAI/__init__.py index e69de29..25d3b2a 100644 --- a/src/compextAI/__init__.py +++ b/src/compextAI/__init__.py @@ -0,0 +1,4 @@ +from compextAI.default_tools import HumanInTheLoop +from compextAI.tools import register_tool + +register_tool(HumanInTheLoop) diff --git a/src/compextAI/default_tools.py b/src/compextAI/default_tools.py new file mode 100644 index 0000000..d3c63e3 --- /dev/null +++ b/src/compextAI/default_tools.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel, Field +from typing import Type + +class HumanInTheLoopInputSchema(BaseModel): + message: str = Field(description="The message to ask the human for input") + +class HumanInTheLoop: + name = "human_in_the_loop" + description = "Use this tool to ask the human for input" + input_schema: Type[BaseModel] = HumanInTheLoopInputSchema + + def _run(self, input:dict): + pass diff --git a/src/compextAI/execution.py b/src/compextAI/execution.py index cf9b8a9..3e40653 100644 --- a/src/compextAI/execution.py +++ b/src/compextAI/execution.py @@ -1,6 +1,9 @@ from compextAI.api.api import APIClient from compextAI.messages import Message from compextAI.threads import ThreadExecutionResponse +from compextAI.tools import get_tool +import time +import queue class ThreadExecutionStatus: status: str @@ -34,8 +37,55 @@ def get_thread_execution_response(client:APIClient, thread_execution_id:str) -> class ExecuteMessagesResponse: thread_execution_id: str - def __init__(self, thread_execution_id:str): + def __init__(self, thread_execution_id:str, thread_execution_param_id:str, messages:list[Message], system_prompt:str, append_assistant_response:bool, metadata:dict): self.thread_execution_id = thread_execution_id + self.thread_execution_param_id = thread_execution_param_id + self.messages = messages + self.system_prompt = system_prompt + self.append_assistant_response = append_assistant_response + self.metadata = metadata + + def poll_thread_execution(self, client:APIClient) -> any: + while True: + try: + thread_run_status = get_thread_execution_status( + client=client, + thread_execution_id=self.thread_execution_id + ).status + except Exception as e: + print(e) + raise Exception("failed to get thread execution status") + if thread_run_status == "completed": + break + elif thread_run_status == "failed": + raise Exception("Thread run failed") + elif thread_run_status == "in_progress": + print("thread run in progress") + time.sleep(3) + else: + raise Exception(f"Unknown thread run status: {thread_run_status}") + + return get_thread_execution_response( + client=client, + thread_execution_id=self.thread_execution_id + ) + +class Tool: + name: str + description: str + input_schema: dict + + def __init__(self, name:str, description:str, input_schema:dict): + self.name = name + self.description = description + self.input_schema = input_schema + + def to_dict(self) -> dict: + return { + "name": self.name, + "description": self.description, + "input_schema": self.input_schema + } def execute_messages(client:APIClient, thread_execution_param_id:str, messages:list[Message],system_prompt:str="", append_assistant_response:bool=True, metadata:dict={}) -> ThreadExecutionResponse: thread_id = "compext_thread_null" @@ -44,7 +94,7 @@ def execute_messages(client:APIClient, thread_execution_param_id:str, messages:l "append_assistant_response": append_assistant_response, "thread_execution_system_prompt": system_prompt, "messages": [message.to_dict() for message in messages], - "metadata": metadata + "metadata": metadata, }) status_code: int = response["status"] @@ -53,4 +103,109 @@ def execute_messages(client:APIClient, thread_execution_param_id:str, messages:l if status_code != 200: raise Exception(f"Failed to execute thread, status code: {status_code}, response: {data}") - return ExecuteMessagesResponse(data["identifier"]) + return ExecuteMessagesResponse(data["identifier"], thread_execution_param_id, messages, system_prompt, append_assistant_response, metadata) + +class ExecuteMessagesWithToolsResponse(ExecuteMessagesResponse): + tools: list[Tool] + messages: list[Message] + human_in_the_loop: bool + human_intervention_handler: callable + def __init__(self, thread_execution_id:str, thread_execution_param_id:str, messages:list[Message], system_prompt:str, append_assistant_response:bool, metadata:dict, tools:list[str], human_in_the_loop:bool=False, human_intervention_handler:callable=None): + super().__init__(thread_execution_id, thread_execution_param_id, messages, system_prompt, append_assistant_response, metadata) + if human_in_the_loop: + if human_intervention_handler is None: + raise Exception("Human intervention handler is required when human_in_the_loop is True") + self.tools = tools + self.human_in_the_loop = human_in_the_loop + self.human_intervention_handler = human_intervention_handler + + def poll_until_completion(self, client:APIClient, execution_queue:queue.Queue=None) -> any: + while True: + response = self.poll_thread_execution(client) + if response['response']['stop_reason'] == "tool_use": + for msg in response['response']['content']: + if msg['type'] == "tool_use": + tool_name = msg['name'] + tool_input = msg['input'] + tool_use_id = msg['id'] + if execution_queue: + execution_queue.put({ + "type": "tool_use", + "content": { + "tool_name": tool_name, + "tool_input": tool_input, + "tool_use_id": tool_use_id + } + }) + print("tool return", msg) + + try: + if tool_name == "human_in_the_loop": + tool_result = self.human_intervention_handler(**tool_input) + else: + tool_result = get_tool(tool_name)(**tool_input) + except Exception as e: + print(f"Error executing tool {tool_name}: {e}") + raise Exception(f"Error executing tool {tool_name}: {e}") + + # handle tool result + print(f"Tool {tool_name} returned: {tool_result}") + if execution_queue: + execution_queue.put({ + "type": "tool_result", + "content": { + "tool_use_id": tool_use_id, + "result": tool_result + } + }) + self.messages.append(Message( + role="assistant", + content=response['response']['content'], + )) + self.messages.append(Message( + role="user" , + content=[ + { + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": tool_result + } + ] + )) + # start a new execution with the new messages + new_execution = execute_messages_with_tools( + client=client, + thread_execution_param_id=self.thread_execution_param_id, + messages=self.messages, + system_prompt=self.system_prompt, + append_assistant_response=self.append_assistant_response, + metadata=self.metadata, + tool_list=self.tools + ) + self.thread_execution_id = new_execution.thread_execution_id + else: + break + + return response + + +def execute_messages_with_tools(client:APIClient, thread_execution_param_id:str, messages:list[Message],system_prompt:str="", append_assistant_response:bool=True, metadata:dict={}, tool_list:list[str]=[], human_in_the_loop:bool=False, human_intervention_handler:callable=None) -> ExecuteMessagesWithToolsResponse: + if human_in_the_loop: + tool_list.append("human_in_the_loop") + thread_id = "compext_thread_null" + response = client.post(f"/thread/{thread_id}/execute", data={ + "thread_execution_param_id": thread_execution_param_id, + "append_assistant_response": append_assistant_response, + "thread_execution_system_prompt": system_prompt, + "messages": [message.to_dict() for message in messages], + "metadata": metadata, + "tools": [get_tool(tool).to_dict() for tool in tool_list] + }) + + status_code: int = response["status"] + data: dict = response["data"] + + if status_code != 200: + raise Exception(f"Failed to execute thread, status code: {status_code}, response: {data}") + + return ExecuteMessagesWithToolsResponse(data["identifier"], thread_execution_param_id, messages, system_prompt, append_assistant_response, metadata, tool_list, human_in_the_loop, human_intervention_handler) diff --git a/src/compextAI/tools.py b/src/compextAI/tools.py new file mode 100644 index 0000000..b4e9402 --- /dev/null +++ b/src/compextAI/tools.py @@ -0,0 +1,58 @@ +from typing import Callable +from typing import Any, Dict, get_origin, get_args +from pydantic import BaseModel +from typing import Type + +ToolRegistry = {} + +def get_input_schema_dict(input_schema:Type[BaseModel]) -> dict: + input_schema_dict = input_schema.model_json_schema() + return input_schema_dict + +class Tool: + name:str + description:str + input_schema:dict + func:Callable + def __init__(self, name:str, description:str, input_schema:dict): + self.name = name + self.description = description + self.input_schema = input_schema + + def __call__(self, *args, **kwargs): + return self.func(self.tool_class(), *args, **kwargs) + + def to_dict(self): + return { + "name": self.name, + "description": self.description, + "input_schema": self.input_schema + } + def __str__(self): + return f"Tool(name={self.name}, description={self.description}, input_schema={self.input_schema})" + +def register_tool(cls): + if not hasattr(cls, "name"): + raise Exception(f"Tool {cls.__name__} does not have a name") + if not hasattr(cls, "description"): + raise Exception(f"Tool {cls.__name__} does not have a description") + if not hasattr(cls, "input_schema"): + raise Exception(f"Tool {cls.__name__} does not have an input schema") + if not hasattr(cls, "_run"): + raise Exception(f"Tool {cls.__name__} does not have a _run method") + + input_schema = get_input_schema_dict(cls.input_schema) + tool_instance = Tool(name=cls.name, description=cls.description, input_schema=input_schema) + tool_instance.func = cls._run + tool_instance.tool_class = cls + ToolRegistry[cls.name] = tool_instance + + return cls + +def get_tool(name:str) -> Tool: + if name not in ToolRegistry: + raise Exception(f"Tool {name} not found, please register the tool first") + return ToolRegistry[name] + +def get_tool_names() -> list[str]: + return ToolRegistry.keys() From 483e6af7042de042cc750bef09a336a2454a45f7 Mon Sep 17 00:00:00 2001 From: Aviral Jain Date: Thu, 9 Jan 2025 23:15:33 +0530 Subject: [PATCH 2/3] change logic for openai specs --- pyproject.toml | 2 +- src/compextAI/execution.py | 132 ++++++++++++++++++++----------------- src/compextAI/messages.py | 34 ++++++++-- 3 files changed, 103 insertions(+), 65 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a130072..ae4ff70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "CompextAI" -version = "0.0.24" +version = "0.0.27" authors = [ { name="burnerlee", email="avi.aviral140@gmail.com" }, ] diff --git a/src/compextAI/execution.py b/src/compextAI/execution.py index 3e40653..b48d8d3 100644 --- a/src/compextAI/execution.py +++ b/src/compextAI/execution.py @@ -4,6 +4,7 @@ from compextAI.tools import get_tool import time import queue +import json class ThreadExecutionStatus: status: str @@ -69,6 +70,10 @@ def poll_thread_execution(self, client:APIClient) -> any: client=client, thread_execution_id=self.thread_execution_id ) + + def get_stop_reason(self, client:APIClient) -> str: + response = self.poll_thread_execution(client) + return response['response']['choices'][0]['finish_reason'] class Tool: name: str @@ -122,67 +127,74 @@ def __init__(self, thread_execution_id:str, thread_execution_param_id:str, messa def poll_until_completion(self, client:APIClient, execution_queue:queue.Queue=None) -> any: while True: response = self.poll_thread_execution(client) - if response['response']['stop_reason'] == "tool_use": - for msg in response['response']['content']: - if msg['type'] == "tool_use": - tool_name = msg['name'] - tool_input = msg['input'] - tool_use_id = msg['id'] - if execution_queue: - execution_queue.put({ - "type": "tool_use", - "content": { - "tool_name": tool_name, - "tool_input": tool_input, - "tool_use_id": tool_use_id - } - }) - print("tool return", msg) - - try: - if tool_name == "human_in_the_loop": - tool_result = self.human_intervention_handler(**tool_input) - else: - tool_result = get_tool(tool_name)(**tool_input) - except Exception as e: - print(f"Error executing tool {tool_name}: {e}") - raise Exception(f"Error executing tool {tool_name}: {e}") - - # handle tool result - print(f"Tool {tool_name} returned: {tool_result}") - if execution_queue: - execution_queue.put({ - "type": "tool_result", - "content": { - "tool_use_id": tool_use_id, - "result": tool_result - } - }) - self.messages.append(Message( - role="assistant", - content=response['response']['content'], - )) + if self.get_stop_reason(client) == "tool_calls": + content_msg = response['response']['choices'][0]['message']['content'] + tool_calls = response['response']['choices'][0]['message']['tool_calls'] + self.messages.append(Message( + **response['response']['choices'][0]['message'], + )) + for tool_call in tool_calls: + tool_name = tool_call['function']['name'] + tool_input = tool_call['function']['arguments'] + tool_use_id = tool_call['id'] + + if tool_name == "json_tool_call": self.messages.append(Message( - role="user" , - content=[ - { - "type": "tool_result", - "tool_use_id": tool_use_id, - "content": tool_result - } - ] - )) - # start a new execution with the new messages - new_execution = execute_messages_with_tools( - client=client, - thread_execution_param_id=self.thread_execution_param_id, - messages=self.messages, - system_prompt=self.system_prompt, - append_assistant_response=self.append_assistant_response, - metadata=self.metadata, - tool_list=self.tools - ) - self.thread_execution_id = new_execution.thread_execution_id + role="tool" , + content=tool_call['function']['arguments'], + tool_call_id=tool_use_id + )) + continue + + if execution_queue: + execution_queue.put({ + "type": "tool_use", + "content": { + "tool_name": tool_name, + "tool_input": tool_input, + "tool_use_id": tool_use_id + } + }) + tool_input_dict = json.loads(tool_input) + print("tool return", tool_call) + print("tool input", tool_input_dict) + print("tool input type", type(tool_input_dict)) + try: + if tool_name == "human_in_the_loop": + tool_result = self.human_intervention_handler(**tool_input_dict) + else: + tool_result = get_tool(tool_name)(**tool_input_dict) + except Exception as e: + print(f"Error executing tool {tool_name}: {e}") + raise Exception(f"Error executing tool {tool_name}: {e}") + + # handle tool result + print(f"Tool {tool_name} returned: {tool_result}") + if execution_queue: + execution_queue.put({ + "type": "tool_result", + "content": { + "tool_use_id": tool_use_id, + "result": tool_result + } + }) + print("response assistant appending", response['response']['choices'][0]['message']) + self.messages.append(Message( + role="tool" , + content=tool_result, + tool_call_id=tool_use_id + )) + # start a new execution with the new messages + new_execution = execute_messages_with_tools( + client=client, + thread_execution_param_id=self.thread_execution_param_id, + messages=self.messages, + system_prompt=self.system_prompt, + append_assistant_response=self.append_assistant_response, + metadata=self.metadata, + tool_list=self.tools + ) + self.thread_execution_id = new_execution.thread_execution_id else: break diff --git a/src/compextAI/messages.py b/src/compextAI/messages.py index c4e302a..a0b0b48 100644 --- a/src/compextAI/messages.py +++ b/src/compextAI/messages.py @@ -9,8 +9,14 @@ class Message: metadata: dict created_at: datetime updated_at: datetime + tool_call_id: str + tool_calls: any + function_call: any - def __init__(self,content:any, role:str, message_id:str='', thread_id:str='', metadata:dict={}, created_at:datetime=None, updated_at:datetime=None): + def __init__(self,content:any, role:str, message_id:str='', thread_id:str='', metadata:dict={}, created_at:datetime=None, updated_at:datetime=None, **kwargs): + self.tool_call_id = None + self.tool_calls = None + self.function_call = None self.message_id = message_id self.thread_id = thread_id self.content = content @@ -18,9 +24,15 @@ def __init__(self,content:any, role:str, message_id:str='', thread_id:str='', me self.metadata = metadata self.created_at = created_at self.updated_at = updated_at + if "tool_call_id" in kwargs: + self.tool_call_id = kwargs["tool_call_id"] + if "tool_calls" in kwargs: + self.tool_calls = kwargs["tool_calls"] + if "function_call" in kwargs: + self.function_call = kwargs["function_call"] def __str__(self): - return f"Message(message_id={self.message_id}, thread_id={self.thread_id}, content={self.content}, role={self.role}, metadata={self.metadata})" + return f"Message(message_id={self.message_id}, thread_id={self.thread_id}, content={self.content}, role={self.role}, metadata={self.metadata}, tool_call_id={self.tool_call_id}, tool_calls={self.tool_calls}, function_call={self.function_call})" def to_dict(self) -> dict: return { @@ -30,11 +42,25 @@ def to_dict(self) -> dict: "message_id": self.message_id, "thread_id": self.thread_id, "created_at": self.created_at, - "updated_at": self.updated_at + "updated_at": self.updated_at, + "tool_call_id": self.tool_call_id if self.tool_call_id else None, + "tool_calls": self.tool_calls if self.tool_calls else None, + "function_call": self.function_call if self.function_call else None } def get_message_object_from_dict(data:dict) -> Message: - return Message(data["content"], data["role"], data["identifier"], data["thread_id"], data["metadata"], data["created_at"], data["updated_at"]) + return Message( + content=data["content"], + role=data["role"], + message_id=data["identifier"], + thread_id=data["thread_id"], + metadata=data["metadata"], + created_at=data["created_at"], + updated_at=data["updated_at"], + tool_call_id=data["tool_call_id"] if "tool_call_id" in data else None, + tool_calls=data["tool_calls"] if "tool_calls" in data else None, + function_call=data["function_call"] if "function_call" in data else None + ) def list_all(client:APIClient, thread_id:str) -> list[Message]: response = client.get(f"/message/thread/{thread_id}") From ef99b7a3e00bd53cd045c2a1d800e163e1d4cf4e Mon Sep 17 00:00:00 2001 From: Aviral Jain Date: Thu, 23 Jan 2025 17:12:10 +0530 Subject: [PATCH 3/3] update toolclass --- pyproject.toml | 2 +- src/compextAI/execution.py | 6 +++--- src/compextAI/tools.py | 16 +++++++++++++++- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ae4ff70..c5c5c60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "CompextAI" -version = "0.0.27" +version = "0.0.28" authors = [ { name="burnerlee", email="avi.aviral140@gmail.com" }, ] diff --git a/src/compextAI/execution.py b/src/compextAI/execution.py index b48d8d3..c9ef60d 100644 --- a/src/compextAI/execution.py +++ b/src/compextAI/execution.py @@ -169,19 +169,19 @@ def poll_until_completion(self, client:APIClient, execution_queue:queue.Queue=No raise Exception(f"Error executing tool {tool_name}: {e}") # handle tool result - print(f"Tool {tool_name} returned: {tool_result}") + print(f"Tool {tool_name} returned: {tool_result.get_content()}") if execution_queue: execution_queue.put({ "type": "tool_result", "content": { "tool_use_id": tool_use_id, - "result": tool_result + "result": tool_result.get_content() } }) print("response assistant appending", response['response']['choices'][0]['message']) self.messages.append(Message( role="tool" , - content=tool_result, + content=tool_result.get_result(), tool_call_id=tool_use_id )) # start a new execution with the new messages diff --git a/src/compextAI/tools.py b/src/compextAI/tools.py index b4e9402..3cd1b41 100644 --- a/src/compextAI/tools.py +++ b/src/compextAI/tools.py @@ -9,6 +9,20 @@ def get_input_schema_dict(input_schema:Type[BaseModel]) -> dict: input_schema_dict = input_schema.model_json_schema() return input_schema_dict +class ToolResult: + result: str + content: str + + def __init__(self, result:str, content:str): + self.result = result + self.content = content + + def get_result(self): + return self.result + + def get_content(self): + return self.content + class Tool: name:str description:str @@ -19,7 +33,7 @@ def __init__(self, name:str, description:str, input_schema:dict): self.description = description self.input_schema = input_schema - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> ToolResult: return self.func(self.tool_class(), *args, **kwargs) def to_dict(self):