Commit 1d9ea6a

feat: batch o1 and deepseek-chat via concurrency

1 parent 4920808 commit 1d9ea6a

3 files changed: 77 additions & 37 deletions
bigcodebench/gen/util/openai_request.py

Lines changed: 7 additions & 23 deletions
@@ -1,4 +1,3 @@
-import signal
 import time
 
 import openai
@@ -14,53 +13,38 @@ def make_request(
     n: int = 1,
     **kwargs
 ) -> ChatCompletion:
-    system_msg = "You are a helpful assistant good at coding."
-    if (
-        kwargs.get("response_format", None)
-        and kwargs["response_format"]["type"] == "json_object"
-    ):
-        system_msg = "You are a helpful assistant designed to output JSON."
-
+    kwargs["top_p"] = 0.95
+    kwargs["max_completion_tokens"] = max_tokens
+    if model.startswith("o1-"): # pop top-p and max_completion_tokens
+        kwargs.pop("top_p")
+        kwargs.pop("max_completion_tokens")
+
     return client.chat.completions.create(
         model=model,
         messages=[
-            {"role": "system", "content": system_msg},
             {"role": "user", "content": message},
         ],
-        max_tokens=max_tokens,
         temperature=temperature,
         n=n,
         **kwargs
     )
 
 
-def handler(signum, frame):
-    # swallow signum and frame
-    raise Exception("end of time")
-
-
 def make_auto_request(*args, **kwargs) -> ChatCompletion:
     ret = None
     while ret is None:
         try:
-            signal.signal(signal.SIGALRM, handler)
-            signal.alarm(100)
             ret = make_request(*args, **kwargs)
-            signal.alarm(0)
         except openai.RateLimitError:
             print("Rate limit exceeded. Waiting...")
-            signal.alarm(0)
             time.sleep(5)
         except openai.APIConnectionError:
             print("API connection error. Waiting...")
-            signal.alarm(0)
             time.sleep(5)
         except openai.APIError as e:
             print(e)
-            signal.alarm(0)
         except Exception as e:
             print("Unknown error. Waiting...")
             print(e)
-            signal.alarm(0)
            time.sleep(1)
-    return ret
+    return ret
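
A minimal sketch of the request-side change, restated outside the diff: the helper name normalize_request_kwargs below is hypothetical and exists only to illustrate the kwargs handling added above. Every request now carries top_p=0.95 and a max_completion_tokens cap, except for o1-* models, where both are popped before the call. Dropping the signal-based 100-second alarm also keeps make_auto_request safe to call from the worker threads introduced later in this commit, since signal.signal can only be called from the main thread.

# Illustrative sketch only; mirrors the kwargs handling in make_request above.
# normalize_request_kwargs is a hypothetical name, not part of the patch.
def normalize_request_kwargs(model: str, max_tokens: int, **kwargs) -> dict:
    kwargs["top_p"] = 0.95
    kwargs["max_completion_tokens"] = max_tokens
    if model.startswith("o1-"):
        # o1-* requests go out without either knob
        kwargs.pop("top_p")
        kwargs.pop("max_completion_tokens")
    return kwargs

# normalize_request_kwargs("gpt-4o", 2048)  -> keeps both keys
# normalize_request_kwargs("o1-mini", 2048) -> sends neither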

bigcodebench/provider/openai.py

Lines changed: 62 additions & 13 deletions
@@ -1,12 +1,12 @@
 import os
 from typing import List
-from tqdm import tqdm
 
 import openai
 
-from bigcodebench.provider.base import DecoderBase
 from bigcodebench.gen.util.openai_request import make_auto_request
 from bigcodebench.provider.utility import make_raw_chat_prompt
+from bigcodebench.provider.base import DecoderBase
+from bigcodebench.provider.utility import concurrent_call
 
 class OpenAIChatDecoder(DecoderBase):
     def __init__(self, name: str, base_url=None, **kwargs) -> None:
@@ -15,34 +15,83 @@ def __init__(self, name: str, base_url=None, **kwargs) -> None:
             api_key=os.getenv("OPENAI_API_KEY", "none"), base_url=base_url
         )
 
+    # def codegen(
+    #     self, prompts: List[str], do_sample: bool = True, num_samples: int = 200
+    # ) -> List[str]:
+    #     if do_sample:
+    #         assert self.temperature > 0, "Temperature must be positive for sampling"
+    #     all_outputs = []
+    #     for prompt in tqdm(prompts):
+    #         outputs = []
+    #         message = make_raw_chat_prompt(
+    #             task_prompt=prompt,
+    #             subset=self.subset,
+    #             split=self.split,
+    #             instruction_prefix=self.instruction_prefix,
+    #             response_prefix=self.response_prefix,
+    #             tokenizer=None,
+    #         )
+    #         ret = make_auto_request(
+    #             self.client,
+    #             message=message,
+    #             model=self.name,
+    #             max_tokens=self.max_new_tokens,
+    #             temperature=self.temperature,
+    #             n=num_samples,
+    #         )
+    #         for item in ret.choices:
+    #             outputs.append(item.message.content)
+    #         all_outputs.append(outputs)
+    #     return all_outputs
+
+    # def is_direct_completion(self) -> bool:
+    #     return False
+
     def codegen(
         self, prompts: List[str], do_sample: bool = True, num_samples: int = 200
     ) -> List[str]:
         if do_sample:
             assert self.temperature > 0, "Temperature must be positive for sampling"
+        messages = [make_raw_chat_prompt(
+            task_prompt=prompt,
+            subset=self.subset,
+            split=self.split,
+            instruction_prefix=self.instruction_prefix,
+            response_prefix=self.response_prefix,
+            tokenizer=None,
+        ) for prompt in prompts]
+        # use concurrency based batching for o1 and deepseek models
+        if self.name.startswith("o1-") or self.name == "deepseek-chat":
+            return self._codegen_batch_via_concurrency(messages, num_samples)
+
+        return self._codegen_api_batch(messages, num_samples)
+
+    def _codegen_api_batch(self, messages: List[str], num_samples: int) -> List[str]:
+        client = openai.OpenAI(
+            api_key=os.getenv("OPENAI_API_KEY", "none"), base_url=self.base_url
+        )
+
         all_outputs = []
-        for prompt in tqdm(prompts):
-            outputs = []
-            message = make_raw_chat_prompt(
-                task_prompt=prompt,
-                subset=self.subset,
-                split=self.split,
-                instruction_prefix=self.instruction_prefix,
-                response_prefix=self.response_prefix,
-                tokenizer=None,
-            )
+        for message in messages:
             ret = make_auto_request(
-                self.client,
+                client,
                 message=message,
                 model=self.name,
                 max_tokens=self.max_new_tokens,
                 temperature=self.temperature,
                 n=num_samples,
             )
+            outputs = []
             for item in ret.choices:
                 outputs.append(item.message.content)
             all_outputs.append(outputs)
         return all_outputs
 
+    def _codegen_batch_via_concurrency(self, messages: List[str], num_samples: int) -> List[str]:
+        batches = concurrent_call(
+            num_samples, self._codegen_api_batch, messages, num_samples=1
+        )
+        return [b[0] for b in batches]
+
     def is_direct_completion(self) -> bool:
         return False
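
Reading the new control flow: codegen builds every chat message up front, then either sends one n=num_samples request per message via _codegen_api_batch, or, for o1-* and deepseek-chat, fans out across threads with a single sample per request. A rough trace with made-up values follows; nothing below appears verbatim in the patch.

# Illustrative trace only; the model name and sample count are made up.
#
#   decoder.codegen(prompts, num_samples=5)          # self.name == "o1-mini"
#     -> self._codegen_batch_via_concurrency(messages, 5)
#     -> concurrent_call(5, self._codegen_api_batch, messages, num_samples=1)
#
# i.e. five worker threads, each issuing one n=1 request per message through
# make_auto_request, instead of a single request asking for n=5 samples.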

bigcodebench/provider/utility.py

Lines changed: 8 additions & 1 deletion
@@ -1,5 +1,6 @@
 from typing import List
 from transformers import AutoTokenizer
+from concurrent.futures import ThreadPoolExecutor
 
 EOS = [
     "<|endoftext|>",
@@ -64,4 +65,10 @@ def make_raw_chat_prompt(
         ],
         tokenize=False,
     ).split(_MAGIC_SPLITTER_)[0]
-    return task_prompt
+    return task_prompt
+
+
+def concurrent_call(n, callback, /, *args, **kwargs):
+    with ThreadPoolExecutor(max_workers=n) as executor:
+        futures = [executor.submit(callback, *args, **kwargs) for _ in range(n)]
+        return [future.result() for future in futures]
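
concurrent_call simply runs the same callable n times on a thread pool and returns the results in submission order. A small usage sketch, assuming the package is importable; the fake_batch callback is a stand-in for _codegen_api_batch and is not part of the repo.

from bigcodebench.provider.utility import concurrent_call

def fake_batch(messages, num_samples=1):
    # stand-in for _codegen_api_batch: one "sample" per message
    return [[f"answer to: {m}"] for m in messages]

# three worker threads, each calling fake_batch with identical arguments
results = concurrent_call(3, fake_batch, ["hi", "bye"], num_samples=1)
assert len(results) == 3
assert results[0] == [["answer to: hi"], ["answer to: bye"]]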
