2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "CompextAI"
version = "0.0.14"
version = "0.0.28"
authors = [
{ name="burnerlee", email="avi.aviral140@gmail.com" },
]
4 changes: 4 additions & 0 deletions src/compextAI/__init__.py
@@ -0,0 +1,4 @@
from compextAI.default_tools import HumanInTheLoop
from compextAI.tools import register_tool

register_tool(HumanInTheLoop)
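Importing the package now registers the built-in tool as a side effect, so it can be looked up from the registry straight away. A minimal sketch of that behaviour, assuming the package is installed as compextAI:

import compextAI  # runs __init__.py, which calls register_tool(HumanInTheLoop)
from compextAI.tools import get_tool

tool = get_tool("human_in_the_loop")
print(tool)  # Tool(name=human_in_the_loop, description=Use this tool to ask the human for input, input_schema={...})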
13 changes: 13 additions & 0 deletions src/compextAI/default_tools.py
@@ -0,0 +1,13 @@
from pydantic import BaseModel, Field
from typing import Type

class HumanInTheLoopInputSchema(BaseModel):
message: str = Field(description="The message to ask the human for input")

class HumanInTheLoop:
name = "human_in_the_loop"
description = "Use this tool to ask the human for input"
input_schema: Type[BaseModel] = HumanInTheLoopInputSchema

def _run(self, input:dict):
pass  # intentional no-op: execution.py routes human_in_the_loop calls to the user-supplied human_intervention_handler
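Note that register_tool (added in tools.py below) serialises input_schema to a JSON-schema dict via pydantic before it is sent to the API. A hedged sketch of what that dict looks like for this schema; the exact keys depend on the pydantic version:

from compextAI.default_tools import HumanInTheLoopInputSchema

print(HumanInTheLoopInputSchema.model_json_schema())
# Roughly:
# {"properties": {"message": {"description": "The message to ask the human for input",
#                             "title": "Message", "type": "string"}},
#  "required": ["message"], "title": "HumanInTheLoopInputSchema", "type": "object"}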
173 changes: 170 additions & 3 deletions src/compextAI/execution.py
@@ -1,6 +1,10 @@
from compextAI.api.api import APIClient
from compextAI.messages import Message
from compextAI.threads import ThreadExecutionResponse
from compextAI.tools import get_tool
import time
import queue
import json

class ThreadExecutionStatus:
status: str
@@ -34,8 +38,59 @@ def get_thread_execution_response(client:APIClient, thread_execution_id:str) ->
class ExecuteMessagesResponse:
thread_execution_id: str

def __init__(self, thread_execution_id:str):
def __init__(self, thread_execution_id:str, thread_execution_param_id:str, messages:list[Message], system_prompt:str, append_assistant_response:bool, metadata:dict):
self.thread_execution_id = thread_execution_id
self.thread_execution_param_id = thread_execution_param_id
self.messages = messages
self.system_prompt = system_prompt
self.append_assistant_response = append_assistant_response
self.metadata = metadata

def poll_thread_execution(self, client:APIClient) -> any:
while True:
try:
thread_run_status = get_thread_execution_status(
client=client,
thread_execution_id=self.thread_execution_id
).status
except Exception as e:
print(e)
raise Exception("failed to get thread execution status")
if thread_run_status == "completed":
break
elif thread_run_status == "failed":
raise Exception("Thread run failed")
elif thread_run_status == "in_progress":
print("thread run in progress")
time.sleep(3)
else:
raise Exception(f"Unknown thread run status: {thread_run_status}")

return get_thread_execution_response(
client=client,
thread_execution_id=self.thread_execution_id
)

def get_stop_reason(self, client:APIClient) -> str:
response = self.poll_thread_execution(client)
return response['response']['choices'][0]['finish_reason']

class Tool:
name: str
description: str
input_schema: dict

def __init__(self, name:str, description:str, input_schema:dict):
self.name = name
self.description = description
self.input_schema = input_schema

def to_dict(self) -> dict:
return {
"name": self.name,
"description": self.description,
"input_schema": self.input_schema
}

def execute_messages(client:APIClient, thread_execution_param_id:str, messages:list[Message],system_prompt:str="", append_assistant_response:bool=True, metadata:dict={}) -> ThreadExecutionResponse:
thread_id = "compext_thread_null"
@@ -44,7 +99,7 @@ def execute_messages(client:APIClient, thread_execution_param_id:str, messages:l
"append_assistant_response": append_assistant_response,
"thread_execution_system_prompt": system_prompt,
"messages": [message.to_dict() for message in messages],
"metadata": metadata
"metadata": metadata,
})

status_code: int = response["status"]
@@ -53,4 +108,116 @@ def execute_messages(client:APIClient, thread_execution_param_id:str, messages:l
if status_code != 200:
raise Exception(f"Failed to execute thread, status code: {status_code}, response: {data}")

return ExecuteMessagesResponse(data["identifier"])
return ExecuteMessagesResponse(data["identifier"], thread_execution_param_id, messages, system_prompt, append_assistant_response, metadata)

class ExecuteMessagesWithToolsResponse(ExecuteMessagesResponse):
tools: list[Tool]
messages: list[Message]
human_in_the_loop: bool
human_intervention_handler: callable
def __init__(self, thread_execution_id:str, thread_execution_param_id:str, messages:list[Message], system_prompt:str, append_assistant_response:bool, metadata:dict, tools:list[str], human_in_the_loop:bool=False, human_intervention_handler:callable=None):
super().__init__(thread_execution_id, thread_execution_param_id, messages, system_prompt, append_assistant_response, metadata)
if human_in_the_loop:
if human_intervention_handler is None:
raise Exception("Human intervention handler is required when human_in_the_loop is True")
self.tools = tools
self.human_in_the_loop = human_in_the_loop
self.human_intervention_handler = human_intervention_handler

def poll_until_completion(self, client:APIClient, execution_queue:queue.Queue=None) -> any:
while True:
response = self.poll_thread_execution(client)
if self.get_stop_reason(client) == "tool_calls":
content_msg = response['response']['choices'][0]['message']['content']
tool_calls = response['response']['choices'][0]['message']['tool_calls']
self.messages.append(Message(
**response['response']['choices'][0]['message'],
))
for tool_call in tool_calls:
tool_name = tool_call['function']['name']
tool_input = tool_call['function']['arguments']
tool_use_id = tool_call['id']

if tool_name == "json_tool_call":
self.messages.append(Message(
role="tool" ,
content=tool_call['function']['arguments'],
tool_call_id=tool_use_id
))
continue

if execution_queue:
execution_queue.put({
"type": "tool_use",
"content": {
"tool_name": tool_name,
"tool_input": tool_input,
"tool_use_id": tool_use_id
}
})
tool_input_dict = json.loads(tool_input)
print("tool return", tool_call)
print("tool input", tool_input_dict)
print("tool input type", type(tool_input_dict))
try:
if tool_name == "human_in_the_loop":
tool_result = self.human_intervention_handler(**tool_input_dict)
else:
tool_result = get_tool(tool_name)(**tool_input_dict)
except Exception as e:
print(f"Error executing tool {tool_name}: {e}")
raise Exception(f"Error executing tool {tool_name}: {e}")

# handle tool result
print(f"Tool {tool_name} returned: {tool_result.get_content()}")
if execution_queue:
execution_queue.put({
"type": "tool_result",
"content": {
"tool_use_id": tool_use_id,
"result": tool_result.get_content()
}
})
print("response assistant appending", response['response']['choices'][0]['message'])
self.messages.append(Message(
role="tool" ,
content=tool_result.get_result(),
tool_call_id=tool_use_id
))
# start a new execution with the new messages
new_execution = execute_messages_with_tools(
client=client,
thread_execution_param_id=self.thread_execution_param_id,
messages=self.messages,
system_prompt=self.system_prompt,
append_assistant_response=self.append_assistant_response,
metadata=self.metadata,
tool_list=self.tools
)
self.thread_execution_id = new_execution.thread_execution_id
else:
break

return response


def execute_messages_with_tools(client:APIClient, thread_execution_param_id:str, messages:list[Message],system_prompt:str="", append_assistant_response:bool=True, metadata:dict={}, tool_list:list[str]=[], human_in_the_loop:bool=False, human_intervention_handler:callable=None) -> ExecuteMessagesWithToolsResponse:
if human_in_the_loop:
tool_list = tool_list + ["human_in_the_loop"]  # copy instead of append, to avoid mutating the caller's list (or the shared default [])
thread_id = "compext_thread_null"
response = client.post(f"/thread/{thread_id}/execute", data={
"thread_execution_param_id": thread_execution_param_id,
"append_assistant_response": append_assistant_response,
"thread_execution_system_prompt": system_prompt,
"messages": [message.to_dict() for message in messages],
"metadata": metadata,
"tools": [get_tool(tool).to_dict() for tool in tool_list]
})

status_code: int = response["status"]
data: dict = response["data"]

if status_code != 200:
raise Exception(f"Failed to execute thread, status code: {status_code}, response: {data}")

return ExecuteMessagesWithToolsResponse(data["identifier"], thread_execution_param_id, messages, system_prompt, append_assistant_response, metadata, tool_list, human_in_the_loop, human_intervention_handler)
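Taken together, execute_messages_with_tools starts the run and poll_until_completion drives the tool loop, re-executing with the accumulated messages until the model stops requesting tools. A usage sketch under stated assumptions: the APIClient is constructed elsewhere, the thread_execution_param_id value is illustrative, and the handler contract (returning a ToolResult) is inferred from how poll_until_completion consumes its return value:

from compextAI.api.api import APIClient
from compextAI.execution import execute_messages_with_tools
from compextAI.messages import Message
from compextAI.tools import ToolResult

def ask_human(message: str) -> ToolResult:
    # Receives the tool-call arguments as keyword args; poll_until_completion calls
    # get_content()/get_result() on the return value, so wrap the answer in a ToolResult.
    answer = input(f"{message}\n> ")
    return ToolResult(result=answer, content=answer)

def run(client: APIClient, param_id: str):
    execution = execute_messages_with_tools(
        client=client,
        thread_execution_param_id=param_id,  # an existing params identifier (illustrative)
        messages=[Message(role="user", content="Should I restart the staging service?")],
        system_prompt="You are a careful operations assistant.",
        human_in_the_loop=True,
        human_intervention_handler=ask_human,
    )
    # Blocks until the final (non tool_calls) response is available.
    return execution.poll_until_completion(client)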
34 changes: 30 additions & 4 deletions src/compextAI/messages.py
@@ -9,18 +9,30 @@ class Message:
metadata: dict
created_at: datetime
updated_at: datetime
tool_call_id: str
tool_calls: any
function_call: any

def __init__(self,content:any, role:str, message_id:str='', thread_id:str='', metadata:dict={}, created_at:datetime=None, updated_at:datetime=None):
def __init__(self,content:any, role:str, message_id:str='', thread_id:str='', metadata:dict={}, created_at:datetime=None, updated_at:datetime=None, **kwargs):
self.tool_call_id = None
self.tool_calls = None
self.function_call = None
self.message_id = message_id
self.thread_id = thread_id
self.content = content
self.role = role
self.metadata = metadata
self.created_at = created_at
self.updated_at = updated_at
if "tool_call_id" in kwargs:
self.tool_call_id = kwargs["tool_call_id"]
if "tool_calls" in kwargs:
self.tool_calls = kwargs["tool_calls"]
if "function_call" in kwargs:
self.function_call = kwargs["function_call"]

def __str__(self):
return f"Message(message_id={self.message_id}, thread_id={self.thread_id}, content={self.content}, role={self.role}, metadata={self.metadata})"
return f"Message(message_id={self.message_id}, thread_id={self.thread_id}, content={self.content}, role={self.role}, metadata={self.metadata}, tool_call_id={self.tool_call_id}, tool_calls={self.tool_calls}, function_call={self.function_call})"

def to_dict(self) -> dict:
return {
@@ -30,11 +42,25 @@ def to_dict(self) -> dict:
"message_id": self.message_id,
"thread_id": self.thread_id,
"created_at": self.created_at,
"updated_at": self.updated_at
"updated_at": self.updated_at,
"tool_call_id": self.tool_call_id if self.tool_call_id else None,
"tool_calls": self.tool_calls if self.tool_calls else None,
"function_call": self.function_call if self.function_call else None
}

def get_message_object_from_dict(data:dict) -> Message:
return Message(data["content"], data["role"], data["identifier"], data["thread_id"], data["metadata"], data["created_at"], data["updated_at"])
return Message(
content=data["content"],
role=data["role"],
message_id=data["identifier"],
thread_id=data["thread_id"],
metadata=data["metadata"],
created_at=data["created_at"],
updated_at=data["updated_at"],
tool_call_id=data["tool_call_id"] if "tool_call_id" in data else None,
tool_calls=data["tool_calls"] if "tool_calls" in data else None,
function_call=data["function_call"] if "function_call" in data else None
)

def list_all(client:APIClient, thread_id:str) -> list[Message]:
response = client.get(f"/message/thread/{thread_id}")
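The new keyword fields let a Message represent both the assistant's tool_calls turn and the follow-up tool result. A small illustrative sketch (the id and content are made up):

from compextAI.messages import Message

tool_reply = Message(
    role="tool",
    content='{"status": "ok"}',
    tool_call_id="call_abc123",  # hypothetical id copied from the model's tool_calls entry
)
print(tool_reply.to_dict()["tool_call_id"])  # -> call_abc123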
72 changes: 72 additions & 0 deletions src/compextAI/tools.py
@@ -0,0 +1,72 @@
from typing import Callable
from typing import Any, Dict, get_origin, get_args
from pydantic import BaseModel
from typing import Type

ToolRegistry = {}

def get_input_schema_dict(input_schema:Type[BaseModel]) -> dict:
input_schema_dict = input_schema.model_json_schema()
return input_schema_dict

class ToolResult:
result: str
content: str

def __init__(self, result:str, content:str):
self.result = result
self.content = content

def get_result(self):
return self.result

def get_content(self):
return self.content

class Tool:
name:str
description:str
input_schema:dict
func:Callable
def __init__(self, name:str, description:str, input_schema:dict):
self.name = name
self.description = description
self.input_schema = input_schema

def __call__(self, *args, **kwargs) -> ToolResult:
return self.func(self.tool_class(), *args, **kwargs)

def to_dict(self):
return {
"name": self.name,
"description": self.description,
"input_schema": self.input_schema
}
def __str__(self):
return f"Tool(name={self.name}, description={self.description}, input_schema={self.input_schema})"

def register_tool(cls):
if not hasattr(cls, "name"):
raise Exception(f"Tool {cls.__name__} does not have a name")
if not hasattr(cls, "description"):
raise Exception(f"Tool {cls.__name__} does not have a description")
if not hasattr(cls, "input_schema"):
raise Exception(f"Tool {cls.__name__} does not have an input schema")
if not hasattr(cls, "_run"):
raise Exception(f"Tool {cls.__name__} does not have a _run method")

input_schema = get_input_schema_dict(cls.input_schema)
tool_instance = Tool(name=cls.name, description=cls.description, input_schema=input_schema)
tool_instance.func = cls._run
tool_instance.tool_class = cls
ToolRegistry[cls.name] = tool_instance

return cls

def get_tool(name:str) -> Tool:
if name not in ToolRegistry:
raise Exception(f"Tool {name} not found, please register the tool first")
return ToolRegistry[name]

def get_tool_names() -> list[str]:
return ToolRegistry.keys()
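End to end, the registry works like this: register_tool validates the class, snapshots its pydantic schema, and stores a callable Tool wrapper; get_tool returns that wrapper, and calling it instantiates the class and invokes _run. A hedged sketch with a made-up calculator tool (not part of this PR):

from pydantic import BaseModel, Field
from compextAI.tools import ToolResult, get_tool, get_tool_names, register_tool

class AddInput(BaseModel):
    a: int = Field(description="First addend")
    b: int = Field(description="Second addend")

@register_tool  # usable as a decorator because register_tool returns cls
class AddTool:
    name = "add"
    description = "Add two integers"
    input_schema = AddInput

    def _run(self, a: int, b: int) -> ToolResult:
        total = str(a + b)
        return ToolResult(result=total, content=total)

print(get_tool("add")(a=2, b=3).get_result())  # -> "5" (Tool.__call__ -> AddTool()._run)
print(list(get_tool_names()))  # registered names, e.g. ['human_in_the_loop', 'add']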