3 changes: 2 additions & 1 deletion agents/abstract.py
@@ -190,7 +190,7 @@ async def handler(self) -> Dict[str, Union[str, BaseModel]]:
except ValidationError as e:
# Case: Handle pydantic validation errors by passing them back to the
# model to correct
logger.warning("Failed Pydantic Validation.")
logger.debug("Failed Pydantic Validation.")
res = str(e)

return self._construct_return_message(self.id, res)
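
The `except ValidationError` branch stringifies the parse failure and returns it to the model so it can self-correct on the next turn. A minimal sketch of that pattern, with a hypothetical `Answer` schema standing in for the real response model:

```python
from pydantic import BaseModel, ValidationError

class Answer(BaseModel):  # hypothetical schema, not the project's model
    score: int

def parse_or_report(raw: str) -> "Answer | str":
    try:
        return Answer.model_validate_json(raw)
    except ValidationError as e:
        # Same idea as the hunk above: hand the error text back to the
        # caller, which forwards it to the model as the next message.
        return str(e)
```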
@@ -259,6 +259,7 @@ class _Agent(Observable, metaclass=abc.ABCMeta):
callback_output: list
tool_res_payload: List[Message]
provider: _Provider
placeholder: Optional[Any]

def __init__(
self,
4 changes: 2 additions & 2 deletions agents/agent/base.py
@@ -139,7 +139,7 @@ def _check_stop_condition(self, response):
if (answer := self.stopping_condition(self, response)) is not None:
self.answer = answer
self.terminated = True
logger.info("Stopping condition signaled, terminating.")
logger.debug("Stopping condition signaled, terminating.")

async def step(self):
"""
@@ -258,7 +258,7 @@ async def _handle_tool_calls(self, response):
for payload, result in zip(tool_calls, tool_call_results):
# Log it
toolcall_str = f"{payload.func_name}({str(payload.kwargs)[:100] + '...(trunc)' if len(str(payload.kwargs)) > 100 else str(payload.kwargs)})"
logger.info(f"Got tool call: {toolcall_str}")
logger.debug(f"Got tool call: {toolcall_str}")
self.scratchpad += f"\t=> {toolcall_str}\n"
self.scratchpad += "\t\t"

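`_check_stop_condition` treats any non-`None` return from `stopping_condition(self, response)` as the final answer. A sketch of a compatible callable (the two-argument signature is inferred from the call site; the tag format is illustrative):

```python
import re
from typing import Optional

def answer_tag_condition(agent, response: str) -> Optional[str]:
    # Return the extracted answer to terminate, or None to keep stepping,
    # matching `if (answer := self.stopping_condition(self, response)) is not None`.
    if (m := re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)) is not None:
        return m.group(1).strip()
    return None
```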
15 changes: 12 additions & 3 deletions agents/processors.py
@@ -101,6 +101,10 @@ def _spawn_agent(self, batch: DataInput, **kwargs) -> A:
out = self.agent_class(
provider=self.provider, **self.agent_kwargs, batch=batch_str, **kwargs
)
# TODO: Hacky, patching issue where placeholder can't be composed from within agent_handler
# because we only have str-formatted batch at that point rather than the original data
out.placeholder = self._placeholder(batch)

return out

def _load_inqueue(self):
@@ -238,7 +242,7 @@ async def _worker(self, worker_name: str):
logger.error(
f"[_worker - {worker_name}]: Task {id} failed {self.n_retry} times and will not be retried"
)
agent.answer = self._placeholder(agent.fmt_kwargs["batch"])
agent.answer = agent.placeholder
self.error_tasks += 1
else:
# Send data back to queue to retry processing
@@ -353,7 +357,12 @@ async def process(self):
try:
while not self.in_q.empty():
for idx, retries_left, agent in self.dequeue(self.in_q):
workers.append(asyncio.create_task(self._agent_handler(agent, idx, retries_left), name=f"batch-worker-{idx}"))
workers.append(
asyncio.create_task(
self._agent_handler(agent, idx, retries_left),
name=f"batch-worker-{idx}",
)
)
# Wait for all agents to complete
await asyncio.wait(workers)
finally:
@@ -401,7 +410,7 @@ async def _agent_handler(self, agent: Agent, id: int, retry_left: int) -> None:
logger.error(
f"[_agent_handler]: Task {id} failed {self.n_retry} times and will not be retried"
)
agent.answer = self._placeholder(agent.fmt_kwargs["batch"])
agent.answer = agent.placeholder
self.error_tasks += 1
else:
# Send data back to queue to retry processing
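The processors change computes the failure placeholder once at spawn time, while the original `batch` object is still in scope, instead of re-deriving it from the str-formatted `agent.fmt_kwargs["batch"]` at failure time. A condensed sketch of the new flow (the `_placeholder` body here is illustrative):

```python
from typing import Any, List

class _ProcessorSketch:
    def __init__(self, agent_class):
        self.agent_class = agent_class

    def _placeholder(self, batch: List[Any]) -> List[None]:
        # Illustrative: one null per input row, so failed tasks still
        # line up with their inputs downstream.
        return [None] * len(batch)

    def _spawn_agent(self, batch: List[Any]):
        agent = self.agent_class(batch=str(batch))
        # Patched on at spawn time because only the str-formatted batch
        # survives into the handler (see the TODO in the hunk above).
        agent.placeholder = self._placeholder(batch)
        return agent
```

On terminal failure, both `_worker` and `_agent_handler` now just assign `agent.answer = agent.placeholder`.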
191 changes: 83 additions & 108 deletions agents/providers/openai.py
@@ -115,6 +115,19 @@ def register_provider(self, provider: "AzureOpenAIBatchProvider"):
self.lock = asyncio.Semaphore(self.n_workers)
self.task = asyncio.create_task(self._batcher(), name="OpenAIBatchHelper")

def _batch_handler_callback(self, task: asyncio.Task):
"""
Simple callback handler for batch tasks
"""
try:
task.result()
except asyncio.CancelledError:
logger.info("Batch task was cancelled.")
except Exception as e:
logger.warning(f"Batch task resulted in an error: {str(e)}")
finally:
self.batch_tasks.remove(task)

async def _batcher(self):
"""
Batch loop
@@ -129,54 +142,36 @@ async def _batcher(self):
try:
# Define our batch and start the clock
batch = []
batch_poll_start = asyncio.get_running_loop().time()
# Wait until first query comes in
req = await self.provider.batch_q.get()
batch.append(req)

# Await new messages to load into the batch
# - Until we hit our max batch size, or
# - Until we've waited for the time indicated (default 2s)
while len(batch) < self.batch_size:
time_remaining = self.timeout - (
asyncio.get_running_loop().time() - batch_poll_start
)
if time_remaining <= 0.0:
# We've waited long enough, send what we have in a batch
break
try:
req = await asyncio.wait_for(
self.provider.batch_q.get(), timeout=0.1
self.provider.batch_q.get(), timeout=self.timeout
)
except asyncio.TimeoutError:
continue
break
batch.append(req)
self.provider.batch_q.task_done()

# Wait on the batch tasks
if len(self.batch_tasks):
finished_batches, processing_batches = await asyncio.wait(
self.batch_tasks, timeout=self.timeout
)

# Check for exceptions and remove any completed batches so we don't keep checking
for batch_task in finished_batches:
if (batch_task_err := batch_task.exception()) is not None:
logger.error(
f"[OpenAIBatchAPIHelper]: Batch task failed! {str(batch_task_err)}"
)
# Remove completed batches
self.batch_tasks.remove(batch_task)

# Case: Nothing to submit or less than max items and we're still waiting for
# a semaphore
if len(batch) == 0 or (
len(batch) < self.batch_size and self.lock.locked()
):
continue

self.batch_tasks.append(asyncio.create_task(self._batch_handler(batch)))
# Wait for semaphore to send off batch task
await self.lock.acquire()

batch_task = asyncio.create_task(self._batch_handler(batch))
self.batch_tasks.append(batch_task)

# Batch task should remove itself from the list once it's done
batch_task.add_done_callback(self._batch_handler_callback)

except (asyncio.CancelledError, GeneratorExit):
# If the task was cancelled, we should exit the loop
logger.info("OpenAIBatchHelper closing.")

break

async def _batch_handler(self, batch: List[BatchRequestInput]) -> None:
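The rewritten loop drops the manual deadline arithmetic and the in-loop `asyncio.wait` bookkeeping: it blocks on the first request, drains the queue until `batch_size` items or one `timeout` of silence, gates dispatch on the worker semaphore, and lets a done-callback prune `batch_tasks`. A standalone sketch of that shape (queue and handler names are placeholders, not the provider's API):

```python
import asyncio
from typing import Any, Awaitable, Callable, List

async def micro_batcher(
    q: "asyncio.Queue[Any]",
    handle: Callable[[List[Any]], Awaitable[None]],
    batch_size: int = 16,
    timeout: float = 2.0,
    n_workers: int = 2,
) -> None:
    lock = asyncio.Semaphore(n_workers)
    tasks: List[asyncio.Task] = []

    async def run(batch: List[Any]) -> None:
        try:
            await handle(batch)
        finally:
            lock.release()  # mirrors _batch_handler's finally block

    while True:
        batch = [await q.get()]  # block until the first request arrives
        while len(batch) < batch_size:
            try:
                batch.append(await asyncio.wait_for(q.get(), timeout=timeout))
            except asyncio.TimeoutError:
                break  # one quiet `timeout` flushes whatever we have
        await lock.acquire()  # backpressure: at most n_workers batches in flight
        task = asyncio.create_task(run(batch))
        tasks.append(task)
        task.add_done_callback(tasks.remove)  # self-pruning, as in the diff
```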
@@ -185,47 +180,51 @@ async def _batch_handler(self, batch: List[BatchRequestInput]) -> None:
when finished.
"""
# Create batch file, send to OpenAI and execute
async with self.lock:
try:
batch_file = await self.provider.send_batch(batch)
self.pbar.update(1)
self.pbar.refresh()
batch_task = await self.provider.create_batch_task(
batch_file, timeout=self.api_timeout
try:
batch_file = await self.provider.send_batch(batch)
self.pbar.update(1)
self.pbar.refresh()
batch_task = await self.provider.create_batch_task(
batch_file, timeout=self.api_timeout
)

if batch_task.errors is not None and batch_task.errors.data is not None:
# Batch returned an error. Raise
errors = "\n".join(
f"[{err.code}]: {err.message}" for err in batch_task.errors.data
)
logger.error(f"Batch {batch_task.id} returned an error:\n{errors}")
raise RuntimeError(
f"Batch {batch_task.id} returned an error:\n{errors}"
)

# Get results
results = await self.provider.get_batch_results(batch_task)

if batch_task.errors is not None and batch_task.errors.data is not None:
# Batch returned an error. Raise
errors = "\n".join(
f"[{err.code}]: {err.message}" for err in batch_task.errors.data
)
logger.error(f"Batch {batch_task.id} returned an error:\n{errors}")
raise RuntimeError(
f"Batch {batch_task.id} returned an error:\n{errors}"
)

# Get results
results = await self.provider.get_batch_results(batch_task)

# Write out results to dict for agents to pick up
for result in results:
self.provider.batch_out[result["custom_id"]].set_result(
ChatCompletion.model_validate(result["response"]["body"])
)
except Exception as e:
# propagate the exception to the futures
for batch_item in batch:
fut = self.provider.batch_out[batch_item["custom_id"]]
if not fut.done():
# If the future is not done, set it to an exception
fut.set_exception(e)

# Signal to batcher as well
raise e

finally:
self.pbar.update(-1)
self.pbar.refresh()
# Write out results to dict for agents to pick up
for result in results:
self.provider.batch_out[result["custom_id"]].set_result(
ChatCompletion.model_validate(result["response"]["body"])
)

# Log that we're done
logger.info(f"Batch [{batch_task.id}] completed.")

except Exception as e:
# propagate the exception to the futures
for batch_item in batch:
fut = self.provider.batch_out[batch_item["custom_id"]]
if not fut.done():
# If the future is not done, set it to an exception
fut.set_exception(e)

# Signal to batcher as well
raise e

finally:
self.lock.release()
self.pbar.update(-1)
self.pbar.refresh()


class OpenAIObservable(Observable[CompletionUsage]):
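
`_batch_handler` now settles exactly one future per `custom_id`: results via `set_result`, and on any failure the same exception is set on every not-yet-done future so no awaiting agent hangs forever. The fan-out pattern in isolation (function and parameter names are assumptions):

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Tuple

async def settle_batch(
    pending: Dict[str, asyncio.Future],
    batch_ids: List[str],
    fetch_results: Callable[[], Awaitable[List[Tuple[str, Any]]]],
) -> None:
    try:
        for custom_id, body in await fetch_results():
            pending[custom_id].set_result(body)
    except Exception as e:
        for custom_id in batch_ids:
            fut = pending[custom_id]
            if not fut.done():
                fut.set_exception(e)  # wake every waiter with the same error
        raise  # re-raise so the batcher's done-callback can log it too
```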
Expand Down Expand Up @@ -368,7 +367,7 @@ async def prompt_agent(
# NOTE: This has to come before the next step of parsing
ag.tool_res_payload.append(out.message.model_dump())

logger.info(f"Received response: {out.message.content}")
logger.debug(f"Received response: {out.message.content}")

if out.finish_reason == "length":
ag.truncated = True
@@ -467,7 +466,7 @@ async def query_batch_mode(
task = {
"custom_id": task_id,
"method": "POST",
"url": "/v1/chat/completions",
"url": "/chat/completions",
"body": {"model": model, **kwargs, "messages": messages},
}

@@ -478,32 +477,7 @@ async def query_batch_mode(
await self.batch_q.put(task)

# Await the result
done = False
while not done:
try:
# Poll health of batcher
await asyncio.wait_for(
asyncio.shield(self.batch_handler.task), timeout=0.5
)
except asyncio.TimeoutError:
# No issues
pass
except Exception as e:
# TODO: Write exception class so I can make this halting
logger.error(
"[AzureOpenAIBatchProvider]: BatchAPI Helper task raised an exception before query was complete!\n{str(e)}"
)
raise e

# Poll every 1s for the query result
try:
out = await asyncio.wait_for(
asyncio.shield(self.batch_out[task_id]), timeout=1
)
done = True
except asyncio.TimeoutError:
# Our task isn't done, retry
continue
out = await self.batch_out[task_id]

# remove the future from the output dict
self.batch_out.pop(task_id, None)
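
Because the handler is now guaranteed to settle every future, the shield-and-poll loop in `query_batch_mode` collapses to a single `await`; a batch failure surfaces here as a raised exception instead of a health-poll timeout. A self-contained illustration of the producer/consumer contract:

```python
import asyncio

async def demo() -> None:
    loop = asyncio.get_running_loop()
    fut: asyncio.Future = loop.create_future()

    # Producer side (what _batch_handler does): settle the future.
    loop.call_later(0.1, fut.set_result, "chat-completion-body")
    # On failure it would instead call fut.set_exception(...).

    # Consumer side (what query_batch_mode now does): one await, no polling.
    print(await fut)

asyncio.run(demo())
```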
@@ -551,10 +525,15 @@ async def send_batch(
file_name, file_content, mime_type = await asyncio.to_thread(
self._create_batch_file, tasks
)
return await self.llm.files.create(

file = await self.llm.files.create(
file=(file_name, file_content, mime_type), purpose="batch", **kwargs
)

logger.info(f"Created file [{file.id}] with {len(tasks)} queries.")

return file

async def create_batch_task(
self, batch_file: FileObject, timeout: int = 30, **kwargs
) -> Batch:
@@ -602,21 +581,17 @@ async def get_batch_results(self, batch: Batch) -> List[Dict]:

:return: A list of results from the batch
"""
# TODO: This is kind of silly and I don't know how useful this is
# over just manually writing the increment step
return await self.round_trip_increment(self._get_batch_results)(batch)

async def _get_batch_results(self, batch: Batch) -> List[Dict]:
"""
Technical implementation of the wrapped function above
"""
if batch.status != "completed" or batch.output_file_id is None:
raise ValueError("Batch status was not 'completed'! Got: " + batch.status)

result_stream = await self.llm.files.content(batch.output_file_id)
results = await asyncio.to_thread(
self._response_from_bytes, result_stream.content
)
# TODO: Now this diverges from how we do it with a chat endpoint
# but maybe no reason to overcomplicate things.
self.round_trips += 1

return results

@staticmethod
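On the `round_trip_increment` TODO: the wrapper centralizes the counting point across endpoints at the cost of some indirection, while the manual alternative the comment mentions is a single `self.round_trips += 1`. The decorator's implementation isn't shown in this diff, so this shape is an assumption:

```python
import functools
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")

class _CountedSketch:
    round_trips: int = 0

    def round_trip_increment(
        self, fn: Callable[..., Awaitable[T]]
    ) -> Callable[..., Awaitable[T]]:
        # Assumed shape: count once per completed round trip.
        @functools.wraps(fn)
        async def wrapper(*args, **kwargs) -> T:
            result = await fn(*args, **kwargs)
            self.round_trips += 1
            return result
        return wrapper
```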