4 changes: 0 additions & 4 deletions agents/__init__.py
@@ -11,7 +11,6 @@
     ProcessorDF,
     ProcessorIterable,
 )
-from .providers import OpenAIProvider, AzureOpenAIProvider, AzureOpenAIBatchProvider
 from .stopping_conditions import (
     StoppingCondition,
     StopOnStep,
@@ -30,9 +29,6 @@
     "BatchProcessorIterable",
     "ProcessorDF",
     "BatchProcessorDF",
-    "OpenAIProvider",
-    "AzureOpenAIProvider",
-    "AzureOpenAIBatchProvider",
     "StoppingCondition",
     "StopOnStep",
     "StopOnDataModel",
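
Note on the hunk above: with the re-exports gone from the package root, the provider classes can no longer be imported as `from agents import AzureOpenAIProvider`. Assuming the `agents.providers` module itself still defines these classes (this diff only removes the root re-export), downstream imports would move to the subpackage:

    # before this change:
    #   from agents import OpenAIProvider, AzureOpenAIProvider
    # after (assumption: agents.providers is unchanged by this PR):
    from agents.providers import OpenAIProvider, AzureOpenAIProvider
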
27 changes: 17 additions & 10 deletions agents/abstract.py
@@ -10,6 +10,7 @@
 from asyncio import Task, create_task, to_thread
 from dataclasses import dataclass, field
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Callable,
@@ -23,13 +24,16 @@
     Union,
 )
 
-from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage
-from openai.types.chat.chat_completion import ChatCompletion, Choice
+if TYPE_CHECKING:
+    from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage
+    from openai.types.chat.chat_completion import ChatCompletion, Choice
+
 from pydantic import BaseModel, ValidationError
+from .observability import Observable
 
 logger = logging.getLogger(__name__)
 
-Message = Union[dict[str, str], ChatCompletionMessageParam]
+Message = Union[dict[str, str], "ChatCompletionMessageParam"]
 
 P = TypeVar("P", bound="_Provider")
 A = TypeVar("A", bound="_Agent")
@@ -186,13 +190,13 @@ async def handler(self) -> Dict[str, Union[str, BaseModel]]:
         except ValidationError as e:
             # Case: Handle pydantic validation errors by passing them back to the
             # model to correct
-            logger.warning("Failed Pydantic Validation.")
+            logger.debug("Failed Pydantic Validation.")
             res = str(e)
 
         return self._construct_return_message(self.id, res)
 
 
-class _Provider(Generic[A], metaclass=abc.ABCMeta):
+class _Provider(Observable, Generic[A], metaclass=abc.ABCMeta):
     """
     A LLM Provider which should provide the standard methods for prompting and agent
     authenticating, etc.
@@ -201,11 +205,12 @@ class _Provider(Generic[A], metaclass=abc.ABCMeta):
     "The tool_call class specific to this provider that will be used to evaluate any tool calls from the model"
     tool_call_wrapper: Type[_ToolCall]
     "The method that will be used to call the OpenAI API, e.g. openai.chat.completions.create"
-    endpoint_fn: Callable[..., Awaitable[ChatCompletion]]
+    endpoint_fn: Callable[..., Awaitable["ChatCompletion"]]
 
     mode: Literal["chat", "batch"]
 
     def __init__(self, model_name: str, **kwargs):
+        super().__init__()
         pass
 
     @abc.abstractmethod
@@ -235,11 +240,11 @@ class _StoppingCondition(metaclass=abc.ABCMeta):
     """
 
     @abc.abstractmethod
-    def __call__(self, cls: "_Agent", response: Choice) -> Optional[Any]:
+    def __call__(self, cls: "_Agent", response: "Choice") -> Optional[Any]:
         raise NotImplementedError()
 
 
-class _Agent(metaclass=abc.ABCMeta):
+class _Agent(Observable, metaclass=abc.ABCMeta):
     terminated: bool = False
     truncated: bool = False
     curr_step: int = 1
@@ -254,6 +259,7 @@ class _Agent(metaclass=abc.ABCMeta):
     callback_output: list
     tool_res_payload: List[Message]
     provider: _Provider
+    placeholder: Optional[Any]
 
     def __init__(
         self,
@@ -265,6 +271,7 @@ def __init__(
         oai_kwargs: Optional[dict[str, Any]] = None,
         **fmt_kwargs,
     ):
+        super().__init__()
         pass
 
     @abc.abstractmethod
@@ -277,7 +284,7 @@ async def step(self):
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def _check_stop_condition(self, response: ChatCompletionMessage) -> None:
+    def _check_stop_condition(self, response: "ChatCompletionMessage") -> None:
         """
         Called from within :func:`step()`.
         Checks whether our stop condition has been met and handles assignment of answer, if so.
@@ -298,7 +305,7 @@ def get_next_messages(self) -> List[Message]:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def _handle_tool_calls(self, response: Choice) -> None:
+    def _handle_tool_calls(self, response: "Choice") -> None:
         raise NotImplementedError()
 
     @property
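
The abstract.py hunks apply the standard `typing.TYPE_CHECKING` pattern: the `openai` type imports are evaluated only by static type checkers, and every annotation that names them becomes a forward-reference string, so importing the module no longer requires `openai` at runtime. A minimal, self-contained sketch of the pattern (generic names, not the project's classes):

    from typing import TYPE_CHECKING, Union

    if TYPE_CHECKING:
        # Seen by mypy/pyright only; no runtime dependency on openai
        from openai.types.chat import ChatCompletionMessageParam

    # The quoted name is a forward reference, so it is not resolved at import time
    Message = Union[dict[str, str], "ChatCompletionMessageParam"]

    def handle(msg: "ChatCompletionMessageParam") -> None:
        ...
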
14 changes: 12 additions & 2 deletions agents/agent/base.py
@@ -63,6 +63,16 @@ def __init__(
         :param dict[str, any] oai_kwargs: Dict of additional OpenAI arguments to pass thru to chat call
         :param fmt_kwargs: Additional named arguments which will be inserted into the :func:`BASE_PROMPT` via fstring
         """
+        super().__init__(
+            stopping_condition,
+            model_name=None,
+            provider=None,
+            tools=None,
+            callbacks=None,
+            oai_kwargs=None,
+            **fmt_kwargs,
+        )
+
         self.fmt_kwargs = fmt_kwargs
         self.stopping_condition = stopping_condition
         # We default to Azure OpenAI here, but
@@ -129,7 +139,7 @@ def _check_stop_condition(self, response):
         if (answer := self.stopping_condition(self, response)) is not None:
             self.answer = answer
             self.terminated = True
-            logger.info("Stopping condition signaled, terminating.")
+            logger.debug("Stopping condition signaled, terminating.")
 
     async def step(self):
         """
@@ -248,7 +258,7 @@ async def _handle_tool_calls(self, response):
         for payload, result in zip(tool_calls, tool_call_results):
             # Log it
             toolcall_str = f"{payload.func_name}({str(payload.kwargs)[:100] + '...(trunc)' if len(str(payload.kwargs)) > 100 else str(payload.kwargs)})"
-            logger.info(f"Got tool call: {toolcall_str}")
+            logger.debug(f"Got tool call: {toolcall_str}")
             self.scratchpad += f"\t=> {toolcall_str}\n"
             self.scratchpad += "\t\t"
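
Both `_Agent` and the concrete agent now call `super().__init__()`, which is what actually runs `Observable.__init__` once the mixin is in the MRO. A rough sketch of the cooperative-initialization shape this relies on (the real `Observable` lives in `agents.observability`; its internals are not shown in this diff and are assumed here):

    class Observable:
        def __init__(self, *args, **kwargs):
            # Hypothetical internals; the real setup is in agents.observability
            self._observers: list = []

    class BaseAgent(Observable):
        def __init__(self, stopping_condition, **kwargs):
            super().__init__()  # runs Observable.__init__ via the MRO
            self.stopping_condition = stopping_condition

    class Agent(BaseAgent):
        def __init__(self, stopping_condition, **fmt_kwargs):
            # Forward up the chain before setting subclass state,
            # mirroring the super().__init__(...) call added in base.py
            super().__init__(stopping_condition, **fmt_kwargs)
            self.fmt_kwargs = fmt_kwargs
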
6 changes: 3 additions & 3 deletions agents/agent/prediction.py
@@ -4,7 +4,7 @@
 """
 
 import logging
-from typing import Callable, List, Literal, Optional, Any, Type
+from typing import Callable, List, Literal, Optional, Any
 
 import pydantic
 
@@ -30,7 +30,7 @@ def __init__(
         expected_len: Optional[int] = None,
         stopping_condition: Optional[_StoppingCondition] = None,
         model_name: Optional[str] = None,
-        provider: Optional[Type[_Provider]] = None,
+        provider: Optional[_Provider] = None,
         tools: Optional[List[dict]] = None,
         callbacks: Optional[List[Callable]] = None,
         oai_kwargs: Optional[dict[str, Any]] = None,
@@ -44,7 +44,7 @@ def __init__(
         :param int expected_len: Optional length constraint on the response_model (OpenAI API doesn't allow maxItems parameter in schema so this is checked post-hoc in the Pydantic BaseModel)
         :param _StoppingCondition stopping_condition: A handler that signals when an Agent has completed the task (optional)
        :param str model_name: Name of model to use (or deployment name for AzureOpenAI) (optional if provider is passed)
-        :param Type[_Provider] provider: Instantiated OpenAI instance to use (optional)
+        :param _Provider provider: Instantiated OpenAI instance to use (optional)
         :param List[dict] tools: List of tools the agent can call via response (optional)
         :param List[Callable] callbacks: List of callbacks to evaluate at end of run (optional)
         :param dict[str, any] oai_kwargs: Dict of additional OpenAI arguments to pass thru to chat call
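
The `Type[_Provider]` to `_Provider` change aligns the hint with the docstring ("Instantiated ... instance to use"): the parameter is a provider object, not a provider class. The distinction, illustrated with a stand-in class:

    from typing import Optional, Type

    class Provider:  # stand-in for agents.abstract._Provider
        def __init__(self, model_name: str) -> None:
            self.model_name = model_name

    def takes_class(provider: Type[Provider]) -> Provider:
        # Caller passes the class itself; we construct it here
        return provider(model_name="example-model")

    def takes_instance(provider: Optional[Provider] = None) -> None:
        # Caller passes an already-constructed object (matches the new hint)
        ...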