From 2c6d63852c1d61e3380a305866639138da08ec82 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 13 Nov 2023 20:55:24 -0800 Subject: [PATCH 1/2] (feat) allow setting base url for openai --- py/autoevals/oai.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/py/autoevals/oai.py b/py/autoevals/oai.py index a2789ef..293ecfd 100644 --- a/py/autoevals/oai.py +++ b/py/autoevals/oai.py @@ -55,13 +55,14 @@ def prepare_openai_complete(is_async=False, api_key=None): openai_obj = openai is_v1 = False + base_url = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1") if hasattr(openai, "OpenAI"): # This is the new v1 API is_v1 = True if is_async: - openai_obj = openai.AsyncOpenAI(api_key=api_key) + openai_obj = openai.AsyncOpenAI(api_key=api_key, base_url=base_url) else: - openai_obj = openai.OpenAI(api_key=api_key) + openai_obj = openai.OpenAI(api_key=api_key, base_url=base_url) try: from braintrust.oai import wrap_openai From 5635603d91e6512943e1d460d67c86a583aff94a Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 13 Nov 2023 21:35:07 -0800 Subject: [PATCH 2/2] (fix) pass base_url as arg --- py/autoevals/llm.py | 9 ++++++++- py/autoevals/oai.py | 12 ++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/py/autoevals/llm.py b/py/autoevals/llm.py index c85a8a3..26bebde 100644 --- a/py/autoevals/llm.py +++ b/py/autoevals/llm.py @@ -79,6 +79,7 @@ def __init__( temperature=None, engine=None, api_key=None, + base_url=None, ): self.name = name self.model = model @@ -92,6 +93,8 @@ def __init__( self.extra_args["max_tokens"] = max(max_tokens, 5) if api_key: self.extra_args["api_key"] = api_key + if base_url: + self.extra_args["base_url"] = base_url self.render_args = {} if render_args: @@ -199,6 +202,7 @@ def __init__( temperature=0, engine=None, api_key=None, + base_url=None, ): choice_strings = list(choice_scores.keys()) @@ -220,6 +224,7 @@ def __init__( temperature=temperature, engine=engine, api_key=api_key, + base_url=base_url, render_args={"__choices": choice_strings}, ) @@ -235,7 +240,7 @@ def from_spec_file(cls, name: str, path: str, **kwargs): class SpecFileClassifier(LLMClassifier): - def __new__(cls, model=None, engine=None, use_cot=None, max_tokens=None, temperature=None, api_key=None): + def __new__(cls, model=None, engine=None, use_cot=None, max_tokens=None, temperature=None, api_key=None, base_url=None): kwargs = {} if model is not None: kwargs["model"] = model @@ -249,6 +254,8 @@ def __new__(cls, model=None, engine=None, use_cot=None, max_tokens=None, tempera kwargs["temperature"] = temperature if api_key is not None: kwargs["api_key"] = api_key + if base_url is not None: + kwargs["base_url"] = base_url # convert FooBar to foo_bar template_name = re.sub(r"(?