39 changes: 19 additions & 20 deletions py/autoevals/llm.py

@@ -49,7 +49,6 @@
 import re
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Dict, List, Optional

 import chevron
 import yaml
@@ -126,9 +125,9 @@ def build_classification_tools(useCoT, choice_strings):
 class OpenAIScorer(ScorerWithPartial):
     def __init__(
         self,
-        api_key: Optional[str] = None,
-        base_url: Optional[str] = None,
-        client: Optional[Client] = None,
+        api_key: str | None = None,
+        base_url: str | None = None,
+        client: Client | None = None,
     ) -> None:
         self.extra_args = {}
         if api_key:
@@ -142,10 +141,10 @@ def __init__(
 class OpenAILLMScorer(OpenAIScorer):
     def __init__(
         self,
-        temperature: Optional[float] = None,
-        api_key: Optional[str] = None,
-        base_url: Optional[str] = None,
-        client: Optional[Client] = None,
+        temperature: float | None = None,
+        api_key: str | None = None,
+        base_url: str | None = None,
+        client: Client | None = None,
     ) -> None:
         super().__init__(
             api_key=api_key,
@@ -159,7 +158,7 @@ class OpenAILLMClassifier(OpenAILLMScorer):
     def __init__(
         self,
         name: str,
-        messages: List,
+        messages: list,
         model,
         choice_scores,
         classification_tools,
@@ -169,7 +168,7 @@ def __init__(
         engine=None,
         api_key=None,
         base_url=None,
-        client: Optional[Client] = None,
+        client: Client | None = None,
     ):
         super().__init__(
             client=client,
@@ -264,11 +263,11 @@ def _run_eval_sync(self, output, expected, **kwargs):
 @dataclass
 class ModelGradedSpec:
     prompt: str
-    choice_scores: Dict[str, float]
-    model: Optional[str] = None
-    engine: Optional[str] = None
-    use_cot: Optional[bool] = None
-    temperature: Optional[float] = None
+    choice_scores: dict[str, float]
+    model: str | None = None
+    engine: str | None = None
+    use_cot: bool | None = None
+    temperature: float | None = None


 class LLMClassifier(OpenAILLMClassifier):
@@ -316,7 +315,7 @@ class LLMClassifier(OpenAILLMClassifier):
         **extra_render_args: Additional template variables
     """

-    _SPEC_FILE_CONTENTS: Dict[str, str] = defaultdict(str)
+    _SPEC_FILE_CONTENTS: dict[str, str] = defaultdict(str)

     def __init__(
         self,
@@ -330,7 +329,7 @@ def __init__(
         engine=None,
         api_key=None,
         base_url=None,
-        client: Optional[Client] = None,
+        client: Client | None = None,
         **extra_render_args,
     ):
         choice_strings = list(choice_scores.keys())
@@ -359,11 +358,11 @@ def __init__(
         )

     @classmethod
-    def from_spec(cls, name: str, spec: ModelGradedSpec, client: Optional[Client] = None, **kwargs):
+    def from_spec(cls, name: str, spec: ModelGradedSpec, client: Client | None = None, **kwargs):
         return cls(name, spec.prompt, spec.choice_scores, client=client, **kwargs)

     @classmethod
-    def from_spec_file(cls, name: str, path: str, client: Optional[Client] = None, **kwargs):
+    def from_spec_file(cls, name: str, path: str, client: Client | None = None, **kwargs):
         if cls._SPEC_FILE_CONTENTS[name] == "":
             with open(path) as f:
                 cls._SPEC_FILE_CONTENTS[name] = f.read()
@@ -381,7 +380,7 @@ def __new__(
         temperature=None,
         api_key=None,
         base_url=None,
-        client: Optional[Client] = None,
+        client: Client | None = None,
     ):
         kwargs = {}
         if model is not None:
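Across all three files the migration is the same two-part modernization: PEP 585 builtin generics (`dict[str, float]`, `list`, `type[Exception]`, `tuple[...]`) replace their `typing` counterparts, and PEP 604 unions (`X | None`) replace `Optional[X]` and `Union[...]`. A minimal sketch of the equivalence, mirroring the `ModelGradedSpec` change above (`SpecLike` is a hypothetical stand-in, not repo code; assumes Python 3.10+):

```python
# Before/after sketch of the typing migration in this PR.
# Runs as-is on Python 3.10+, where `X | None` is valid in
# evaluated annotations such as dataclass fields.
from dataclasses import dataclass


@dataclass
class SpecLike:  # hypothetical stand-in for ModelGradedSpec
    prompt: str
    choice_scores: dict[str, float]   # was: Dict[str, float]
    model: str | None = None          # was: Optional[str] = None
    temperature: float | None = None  # was: Optional[float] = None


spec = SpecLike(prompt="Is the answer correct?", choice_scores={"Y": 1.0, "N": 0.0})
assert spec.model is None and spec.choice_scores["Y"] == 1.0
```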
4 changes: 1 addition & 3 deletions py/autoevals/moderation.py

@@ -1,5 +1,3 @@
-from typing import Optional
-
 from autoevals.llm import OpenAIScorer

 from .oai import Client, arun_cached_request, run_cached_request
@@ -50,7 +48,7 @@ def __init__(
         threshold=None,
         api_key=None,
         base_url=None,
-        client: Optional[Client] = None,
+        client: Client | None = None,
     ):
         """Initialize a Moderation scorer.

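One reviewer-facing caveat (a compatibility sketch, not part of the PR): these `X | None` annotations sit in live signatures and class bodies, so they are evaluated at import time, and the package effectively requires Python 3.10+ unless each touched module also adds `from __future__ import annotations`. The `score` helper below is hypothetical, not the autoevals API:

```python
# Compatibility sketch. With the future import, `float | None` stays an
# unevaluated string, so this module also imports cleanly on Python 3.7-3.9;
# without it, the same annotation raises TypeError at import time pre-3.10.
from __future__ import annotations


def score(output: str, threshold: float | None = None) -> dict[str, float]:
    # Dummy scoring logic purely for the syntax demonstration.
    flagged = 0.0 if threshold is None else float(len(output) > 10)
    return {"flagged": flagged}


print(score("hello world!"))  # {'flagged': 0.0} when threshold is None
```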
57 changes: 29 additions & 28 deletions py/autoevals/oai.py

@@ -4,9 +4,10 @@
 import textwrap
 import time
 import warnings
+from collections.abc import Callable
 from contextvars import ContextVar
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Type, TypeVar, Union, cast, runtime_checkable
+from typing import Any, Optional, Protocol, TypeVar, Union, cast, runtime_checkable

 PROXY_URL = "https://api.braintrust.dev/v1/proxy"

@@ -50,10 +51,10 @@ def moderations(self) -> Moderations: ...
         def api_key(self) -> str: ...

         @property
-        def organization(self) -> Optional[str]: ...
+        def organization(self) -> str | None: ...

         @property
-        def base_url(self) -> Union[str, Any, None]: ...
+        def base_url(self) -> str | Any | None: ...

     class AsyncOpenAI(OpenAI): ...

@@ -75,18 +76,18 @@ class Moderation(Protocol):
         acreate: Callable[..., Any]
         create: Callable[..., Any]

-        api_key: Optional[str]
-        api_base: Optional[str]
-        base_url: Optional[str]
+        api_key: str | None
+        api_base: str | None
+        base_url: str | None

     class error(Protocol):
         class RateLimitError(Exception): ...


-_openai_module: Optional[Union[OpenAIV1Module, OpenAIV0Module]] = None
+_openai_module: OpenAIV1Module | OpenAIV0Module | None = None


-def get_openai_module() -> Union[OpenAIV1Module, OpenAIV0Module]:
+def get_openai_module() -> OpenAIV1Module | OpenAIV0Module:
     global _openai_module

     if _openai_module is not None:
@@ -150,11 +151,11 @@ def complete(self, **kwargs):
     ```
     """

-    openai: Union[OpenAIV0Module, OpenAIV1Module.OpenAI]
+    openai: OpenAIV0Module | OpenAIV1Module.OpenAI
     complete: Callable[..., Any] = None  # type: ignore # Set in __post_init__
     embed: Callable[..., Any] = None  # type: ignore # Set in __post_init__
     moderation: Callable[..., Any] = None  # type: ignore # Set in __post_init__
-    RateLimitError: Type[Exception] = None  # type: ignore # Set in __post_init__
+    RateLimitError: type[Exception] = None  # type: ignore # Set in __post_init__
     is_async: bool = False
     _is_wrapped: bool = False

@@ -199,11 +200,11 @@ def is_wrapped(self) -> bool:

 T = TypeVar("T")

-_named_wrapper: Optional[Type[Any]] = None
-_wrap_openai: Optional[Callable[[Any], Any]] = None
+_named_wrapper: type[Any] | None = None
+_wrap_openai: Callable[[Any], Any] | None = None


-def get_openai_wrappers() -> Tuple[Type[Any], Callable[[Any], Any]]:
+def get_openai_wrappers() -> tuple[type[Any], Callable[[Any], Any]]:
     global _named_wrapper, _wrap_openai

     if _named_wrapper is not None and _wrap_openai is not None:
@@ -213,7 +214,7 @@ def get_openai_wrappers() -> Tuple[Type[Any], Callable[[Any], Any]]:
         from braintrust.oai import NamedWrapper as BraintrustNamedWrapper  # type: ignore
         from braintrust.oai import wrap_openai  # type: ignore

-        _named_wrapper = cast(Type[Any], BraintrustNamedWrapper)
+        _named_wrapper = cast(type[Any], BraintrustNamedWrapper)
     except ImportError:

         class NamedWrapper:
@@ -237,7 +238,7 @@ def resolve_client(client: Client, is_async: bool = False) -> LLMClient:
     return LLMClient(openai=client, is_async=is_async)


-def init(client: Optional[Client] = None, is_async: bool = False):
+def init(client: Client | None = None, is_async: bool = False):
     """Initialize Autoevals with an optional custom LLM client.

     This function sets up the global client context for Autoevals to use. If no client is provided,
@@ -259,10 +260,10 @@ def init(


 def prepare_openai(
-    client: Optional[Client] = None,
+    client: Client | None = None,
     is_async: bool = False,
-    api_key: Optional[str] = None,
-    base_url: Optional[str] = None,
+    api_key: str | None = None,
+    base_url: str | None = None,
 ):
     """Prepares and configures an OpenAI client for use with AutoEval.

@@ -348,7 +349,7 @@ def prepare_openai(
     return LLMClient(openai=openai_obj, is_async=is_async)


-def post_process_response(resp: Any) -> Dict[str, Any]:
+def post_process_response(resp: Any) -> dict[str, Any]:
     # This normalizes against craziness in OpenAI v0 vs. v1
     if hasattr(resp, "to_dict"):
         # v0
@@ -358,18 +359,18 @@ def post_process_response(resp: Any) -> Dict[str, Any]:
         return resp.dict()


-def set_span_purpose(kwargs: Dict[str, Any]) -> None:
+def set_span_purpose(kwargs: dict[str, Any]) -> None:
     kwargs.setdefault("span_info", {}).setdefault("span_attributes", {})["purpose"] = "scorer"


 def run_cached_request(
     *,
-    client: Optional[LLMClient] = None,
+    client: LLMClient | None = None,
     request_type: str = "complete",
-    api_key: Optional[str] = None,
-    base_url: Optional[str] = None,
+    api_key: str | None = None,
+    base_url: str | None = None,
     **kwargs: Any,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     wrapper = prepare_openai(client=client, is_async=False, api_key=api_key, base_url=base_url)
     if wrapper.is_wrapped:
         set_span_purpose(kwargs)
@@ -393,12 +394,12 @@ def run_cached_request(

 async def arun_cached_request(
     *,
-    client: Optional[LLMClient] = None,
+    client: LLMClient | None = None,
     request_type: str = "complete",
-    api_key: Optional[str] = None,
-    base_url: Optional[str] = None,
+    api_key: str | None = None,
+    base_url: str | None = None,
     **kwargs: Any,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     wrapper = prepare_openai(client=client, is_async=True, api_key=api_key, base_url=base_url)
     if wrapper.is_wrapped:
         set_span_purpose(kwargs)
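The one import that moved rather than disappeared is `Callable`: since Python 3.9, PEP 585 deprecates `typing.Callable` in favor of the subscriptable `collections.abc.Callable`, which is why oai.py gains a `collections.abc` import while shrinking its `typing` line. A quick sketch of the interchangeability (the `retry` helper is hypothetical, not from this repo):

```python
# collections.abc.Callable is subscriptable on Python 3.9+ and is
# the preferred spelling; typing.Callable remains only as an alias.
from collections.abc import Callable
from typing import Any


def retry(fn: Callable[..., Any], attempts: int = 2) -> Any:
    # Hypothetical helper: call fn, retrying on exceptions until
    # the final attempt, which re-raises.
    for i in range(attempts):
        try:
            return fn()
        except Exception:
            if i == attempts - 1:
                raise


print(retry(lambda: "ok"))  # ok
```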