20 changes: 9 additions & 11 deletions dreadnode/airt/attack/base.py
@@ -5,15 +5,17 @@
from dreadnode.airt.target.base import Target
from dreadnode.eval.hooks.base import EvalHook
from dreadnode.meta import Config
from dreadnode.optimization.study import OutputT as Out
from dreadnode.optimization.study import Study
from dreadnode.optimization.trial import CandidateT as In
from dreadnode.task import Task

In = t.TypeVar("In")
Out = t.TypeVar("Out")


class Attack(Study[In, Out]):
"""
A declarative configuration for executing an AIRT attack.

Attack automatically derives its task from the target.
"""

model_config = ConfigDict(arbitrary_types_allowed=True, use_attribute_docstrings=True)
@@ -23,16 +25,12 @@ class Attack(Study[In, Out]):

tags: list[str] = Config(default_factory=lambda: ["attack"])
"""A list of tags associated with the attack for logging."""

hooks: list[EvalHook] = Field(default_factory=list, exclude=True, repr=False)
"""Hooks to run at various points in the attack lifecycle."""

# Override the task factory as the target will replace it.
task_factory: t.Callable[[In], Task[..., Out]] = Field( # type: ignore[assignment]
default_factory=lambda: None,
repr=False,
init=False,
)

def model_post_init(self, context: t.Any) -> None:
self.task_factory = self.target.task_factory
"""Initialize attack by deriving task from target."""
if self.task is None:
self.task = self.target.task # type: ignore[attr-defined]
super().model_post_init(context)
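The net effect of this hunk: `Attack` no longer overrides a `task_factory` field; when no task is given explicitly, `model_post_init` copies the target's task onto the study. A framework-free sketch of that behaviour (class and attribute names below are illustrative stand-ins, not the dreadnode API):

```python
import asyncio
from typing import Awaitable, Callable, Optional

TaskFn = Callable[[str], Awaitable[str]]


class FakeTarget:
    """Stand-in for a Target that exposes a `task`."""

    name = "fake-target"

    async def task(self, prompt: str) -> str:
        return f"response to: {prompt}"


class FakeAttack:
    """Stand-in for Attack: derives its task from the target after init."""

    def __init__(self, target: FakeTarget, task: Optional[TaskFn] = None) -> None:
        self.target = target
        self.task = task
        self._post_init()

    def _post_init(self) -> None:
        # Mirrors the new model_post_init: only fill in the task when the
        # caller did not supply one.
        if self.task is None:
            self.task = self.target.task


async def main() -> None:
    attack = FakeAttack(target=FakeTarget())
    print(await attack.task("hello"))  # -> "response to: hello"


asyncio.run(main())
```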
6 changes: 0 additions & 6 deletions dreadnode/airt/target/base.py
@@ -4,7 +4,6 @@
import typing_extensions as te

from dreadnode.meta import Model
from dreadnode.task import Task

In = te.TypeVar("In", default=t.Any)
Out = te.TypeVar("Out", default=t.Any)
@@ -18,8 +17,3 @@ class Target(Model, abc.ABC, t.Generic[In, Out]):
def name(self) -> str:
"""Returns the name of the target."""
raise NotImplementedError

@abc.abstractmethod
def task_factory(self, input: In) -> Task[..., Out]:
"""Creates a Task that will run the given input against the target."""
raise NotImplementedError
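With `task_factory` gone, the ABC only requires `name`; the `task` that `Attack` consumes lives on the concrete target classes (hence the `type: ignore[attr-defined]` in attack/base.py). A minimal sketch of a concrete target under the slimmer contract — `EchoTarget` and its behaviour are illustrative, and the decorator usage mirrors llm.py in this PR rather than any documented requirement:

```python
from __future__ import annotations

from functools import cached_property

from dreadnode import task as dn_task
from dreadnode.airt.target.base import Target
from dreadnode.task import Task


class EchoTarget(Target[str, str]):
    @property
    def name(self) -> str:
        return "echo"

    @cached_property
    def task(self) -> Task[[str], str]:
        @dn_task(name=f"target - {self.name}", tags=["target"])
        async def generate(prompt: str) -> str:
            return prompt  # stand-in for calling the system under test

        return generate
```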
8 changes: 1 addition & 7 deletions dreadnode/airt/target/custom.py
@@ -2,7 +2,7 @@

from pydantic import ConfigDict

from dreadnode.airt.target.base import In, Out, Target
from dreadnode.airt.target.base import Out, Target
from dreadnode.common_types import Unset
from dreadnode.meta import Config
from dreadnode.task import Task
@@ -39,9 +39,3 @@ def model_post_init(self, context: t.Any) -> None:

if self.input_param_name is None:
raise ValueError(f"Could not determine input parameter for {self.task!r}")

def task_factory(self, input: In) -> Task[..., Out]:
task = self.task
if self.input_param_name is not None:
task = self.task.configure(**{self.input_param_name: input})
return task.with_(tags=["target"], append=True)
28 changes: 6 additions & 22 deletions dreadnode/airt/target/llm.py
@@ -39,30 +39,14 @@ def generator(self) -> rg.Generator:
def name(self) -> str:
return self.generator.to_identifier(short=True).split("/")[-1]

def task_factory(self, input: DnMessage) -> Task[[], DnMessage]:
@cached_property
def task(self) -> Task[[DnMessage], DnMessage]:
"""
create a task that:
1. Takes dn.Message as input (auto-logged via to_serializable())
2. Converts to rg.Message only for LLM API call
3. Returns dn.Message with full multimodal content (text/images/audio/video)

Args:
input: The dn.Message to send to the LLM

Returns:
Task that executes the LLM call and returns dn.Message
Task for LLM generation.

Raises:
TypeError: If input is not a dn.Message
ValueError: If the message has no content
Message input will come from dataset (injected by Study),
not from task defaults.
"""
if not isinstance(input, DnMessage):
raise TypeError(f"Expected dn.Message, got {type(input).__name__}")

if not input.content:
raise ValueError("Message must have at least one content part")

dn_message = input
params = (
self.params
if isinstance(self.params, rg.GenerateParams)
@@ -73,7 +57,7 @@ def task_factory(self, input: DnMessage) -> Task[[], DnMessage]:

@task(name=f"target - {self.name}", tags=["target"])
async def generate(
message: DnMessage = dn_message,
message: DnMessage,
params: rg.GenerateParams = params,
) -> DnMessage:
"""Execute LLM generation task."""
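The switch from `task_factory(input)` to a `cached_property` task means the target builds one parameterized task and the Study injects each dataset message at call time, rather than a fresh task being configured per input (the per-input type and content checks are dropped along with it). A self-contained sketch of that shape, assuming made-up `Message`/`FakeLLMTarget` types rather than the dreadnode/rigging ones:

```python
import asyncio
from dataclasses import dataclass
from functools import cached_property
from typing import Awaitable, Callable


@dataclass(frozen=True)
class Message:
    role: str
    content: str


class FakeLLMTarget:
    name = "fake-llm"

    @cached_property
    def task(self) -> Callable[[Message], Awaitable[Message]]:
        async def generate(message: Message) -> Message:
            # A real target would convert to the provider's message type,
            # call the generator, and convert the reply back.
            return Message(role="assistant", content=f"echo: {message.content}")

        return generate


async def main() -> None:
    target = FakeLLMTarget()
    # The input arrives at call time (from the dataset, via the Study),
    # not as a default baked into the task definition.
    reply = await target.task(Message(role="user", content="hello"))
    print(reply.content)


asyncio.run(main())
```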
12 changes: 7 additions & 5 deletions dreadnode/eval/hooks/transforms.py
@@ -35,7 +35,7 @@ async def hook(event: "EvalEvent") -> "EvalReaction | None": # noqa: PLR0911
if create_task:
from dreadnode import task as dn_task

task_kwargs = event.task_kwargs
input_data = event.task_kwargs

@dn_task(
name=f"transform - input ({len(transforms)} transforms)",
@@ -44,11 +44,11 @@ async def hook(event: "EvalEvent") -> "EvalReaction | None": # noqa: PLR0911
log_output=True,
)
async def apply_task(
data: dict[str, t.Any] = task_kwargs, # Use extracted variable
data: dict[str, t.Any],
) -> dict[str, t.Any]:
return await apply_transforms_to_kwargs(data, transforms)

transformed = await apply_task()
transformed = await apply_task(input_data)
return ModifyInput(task_kwargs=transformed)

# Direct application
@@ -73,10 +73,12 @@ async def apply_task(
log_inputs=True,
log_output=True,
)
async def apply_task(data: t.Any = output_data) -> t.Any: # Use extracted variable
async def apply_task(
data: t.Any,
) -> t.Any:
return await apply_transforms_to_value(data, transforms)

transformed = await apply_task()
transformed = await apply_task(output_data)
return ModifyOutput(output=transformed)

# Direct application
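Both hunks make the same move: the transform data is passed to the inner task as a real argument instead of being bound as a default when the function is defined. A plain-Python illustration of the difference (no dreadnode imports; `apply_transforms` stands in for `apply_transforms_to_kwargs`/`apply_transforms_to_value`):

```python
import asyncio
from typing import Any


async def apply_transforms(data: dict[str, Any]) -> dict[str, Any]:
    # Stand-in for apply_transforms_to_kwargs(data, transforms).
    return {key: str(value).upper() for key, value in data.items()}


async def main() -> None:
    input_data = {"prompt": "hello"}

    # Old pattern: freeze the current kwargs into a default at definition time
    # and call with no arguments, so the data never appears in the call itself.
    async def apply_task_old(data: dict[str, Any] = input_data) -> dict[str, Any]:
        return await apply_transforms(data)

    # New pattern: accept the data as a parameter and pass it explicitly,
    # keeping the call site readable (and the input visible to anything that
    # inspects the task's arguments).
    async def apply_task(data: dict[str, Any]) -> dict[str, Any]:
        return await apply_transforms(data)

    print(await apply_task_old())
    print(await apply_task(input_data))


asyncio.run(main())
```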
13 changes: 13 additions & 0 deletions dreadnode/optimization/console.py
@@ -303,10 +303,23 @@ def _handle_event(self, event: StudyEvent[t.Any]) -> None: # noqa: PLR0912
self._trials_completed += 1
self._completed_evals += 1
self._total_cost += event.trial.cost

# Check if this trial is the new best (inline check to avoid stale display)
# This handles the case where NewBestTrialFound event comes after rendering
if (
not event.trial.is_probe
and event.trial.status == "finished"
and (self._best_trial is None or event.trial.score > self._best_trial.score)
):
self._best_trial = event.trial
elif isinstance(event, NewBestTrialFound):
self._best_trial = event.trial
elif isinstance(event, StudyEnd):
self._result = event.result
# Update best trial from final result in case some trials completed
# after stop condition but before we received their events
if event.result.best_trial:
self._best_trial = event.result.best_trial

self._progress.update(self._progress_task_id, completed=self._completed_evals)

4 changes: 1 addition & 3 deletions dreadnode/optimization/format.py
@@ -58,9 +58,7 @@ def format_study(study: "Study") -> RenderableType:
if isinstance(study, Attack):
details.add_row(Text("Target", justify="right"), repr(study.target))
else:
details.add_row(
Text("Task Factory", justify="right"), get_callable_name(study.task_factory)
)
details.add_row(Text("Task Factory", justify="right"), get_callable_name(study.task))

details.add_row(Text("Search Strategy", justify="right"), study.search_strategy.name)
