Commit 7c5c3d0
refactor(gen): update model provider
1 parent e51bd31
9 files changed: +91 -66 lines changed

bigcodebench/provider/__init__.py

Lines changed: 17 additions & 10 deletions
@@ -8,19 +8,20 @@ def make_model(
     split: str,
     dataset: str = "bigcodebench",
     temperature: float = 0.0,
+    max_new_tokens: int = 1280,
     # instruction model only
-    instruction_prefix=None,
-    response_prefix=None,
+    instruction_prefix: str = None,
+    response_prefix: str = None,
     # vllm only
-    tp=1,
-    direct_completion=False,
-    base_url=None,
-    trust_remote_code=False,
+    tp: int = 1,
+    direct_completion: bool = False,
+    base_url: str = None,
+    trust_remote_code: bool = False,
     # hf only
-    attn_implementation="eager",
+    attn_implementation: str = "eager",
     # tokenizer
-    tokenizer_name=None,
-    tokenizer_kwargs=None,
+    tokenizer_name: str = None,
+    tokenizer_legacy: bool = True,
 ) -> DecoderBase:
     if backend == "vllm":
         from bigcodebench.provider.vllm import VllmDecoder
@@ -30,9 +31,10 @@ def make_model(
             subset=subset,
             split=split,
             temperature=temperature,
+            max_new_tokens=max_new_tokens,
             dataset=dataset,
             direct_completion=direct_completion,
-            tensor_parallel_size=tp,
+            tp=tp,
             instruction_prefix=instruction_prefix,
             response_prefix=response_prefix,
         )
@@ -44,6 +46,7 @@ def make_model(
             subset=subset,
             split=split,
             temperature=temperature,
+            max_new_tokens=max_new_tokens,
             dataset=dataset,
             direct_completion=direct_completion,
             instruction_prefix=instruction_prefix,
@@ -59,6 +62,7 @@ def make_model(
             subset=subset,
             split=split,
             temperature=temperature,
+            max_new_tokens=max_new_tokens,
             base_url=base_url,
             instruction_prefix=instruction_prefix,
             response_prefix=response_prefix,
@@ -71,6 +75,7 @@ def make_model(
             subset=subset,
             split=split,
             temperature=temperature,
+            max_new_tokens=max_new_tokens,
             instruction_prefix=instruction_prefix,
             response_prefix=response_prefix,
         )
@@ -83,6 +88,7 @@ def make_model(
             subset=subset,
             split=split,
             temperature=temperature,
+            max_new_tokens=max_new_tokens,
             instruction_prefix=instruction_prefix,
             response_prefix=response_prefix,
         )
@@ -95,6 +101,7 @@ def make_model(
             subset=subset,
             split=split,
             temperature=temperature,
+            max_new_tokens=max_new_tokens,
             instruction_prefix=instruction_prefix,
             response_prefix=response_prefix,
         )
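For orientation, the updated factory can be exercised roughly as below. This is a minimal sketch, not part of the commit: only the keyword names visible in the hunks above come from the diff; the model-name argument and the concrete prefix strings are assumptions for illustration.

# Hypothetical usage sketch of make_model after this change; the `model` keyword
# and the prefix strings are assumed, not taken from the diff.
from bigcodebench.provider import make_model

decoder = make_model(
    model="bigcode/starcoder2-7b",      # assumed argument name and value
    backend="vllm",
    subset="full",
    split="complete",
    temperature=0.0,
    max_new_tokens=1280,                # new parameter, threaded to every backend
    tp=1,                               # forwarded to the vLLM decoder as tp=tp
    instruction_prefix="Write a solution to the following problem:",  # placeholder
    response_prefix="```python",        # placeholder
)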

bigcodebench/provider/anthropic.py

Lines changed: 5 additions & 7 deletions
@@ -1,9 +1,10 @@
 import os
 from typing import List
+from tqdm import tqdm
 
 import anthropic
 
-from bigcodebench.gen.util import anthropic_request
+from bigcodebench.gen.util.anthropic_request import make_auto_request
 from bigcodebench.provider.base import DecoderBase
 from bigcodebench.provider.utility import make_raw_chat_prompt
 
@@ -18,15 +19,12 @@ def codegen(
         if do_sample:
             assert self.temperature > 0, "Temperature must be positive for sampling"
 
-        if not do_sample:
-            assert batch_size == 1, "Sampling only supports batch size of 1"
-
         all_outputs = []
         for prompt in tqdm(prompts):
            outputs = []
 
            for _ in range(num_samples):
-                message = anthropic_request.make_auto_request(
+                ret = make_auto_request(
                    client=self.client,
                    model=self.name,
                    messages=[
@@ -46,9 +44,9 @@ def codegen(
                    temperature=self.temperature,
                    stop_sequences=self.eos,
                )
-                outputs.append(message.content[0].text)
+                outputs.append(ret.content[0].text)
            all_outputs.append(outputs)
-        return outputs
+        return all_outputs
 
     def is_direct_completion(self) -> bool:
         return False
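The net effect of this hunk is that codegen now collects one list of completions per prompt and returns the full nested structure instead of only the last prompt's outputs. A minimal shape sketch under that reading (illustrative values only, not from the diff):

# Illustrative only: the nesting that `return all_outputs` now produces,
# i.e. len(all_outputs) == len(prompts), each inner list holding num_samples items.
prompts = ["task A", "task B"]
num_samples = 3
all_outputs = [[f"{p} / sample {k}" for k in range(num_samples)] for p in prompts]
assert len(all_outputs) == len(prompts)
assert all(len(samples) == num_samples for samples in all_outputs)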

bigcodebench/provider/base.py

Lines changed: 6 additions & 2 deletions
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import List
 
-from evalplus.provider.utility import EOS
+from bigcodebench.provider.utility import EOS
 
 
 class DecoderBase(ABC):
@@ -11,12 +11,14 @@ def __init__(
         subset: str,
         split: str,
         temperature: float = 0.8,
-        max_new_tokens: int = 5120,
+        max_new_tokens: int = 1280,
         dtype: str = "bfloat16",  # default
         direct_completion: bool = False,
         trust_remote_code: bool = False,
         tokenizer_name: str = None,
         tokenizer_legacy: bool = False,
+        instruction_prefix: str = None,
+        response_prefix: str = None,
     ) -> None:
         print("Initializing a decoder model: {} ...".format(name))
         self.name = name
@@ -31,6 +33,8 @@ def __init__(
         self.trust_remote_code = trust_remote_code
         self.tokenizer_name = tokenizer_name
         self.tokenizer_legacy = tokenizer_legacy
+        self.instruction_prefix = instruction_prefix
+        self.response_prefix = response_prefix
 
     @abstractmethod
     def codegen(

bigcodebench/provider/google.py

Lines changed: 8 additions & 10 deletions
@@ -1,9 +1,9 @@
 import os
 from typing import List
+from tqdm import tqdm
 
 import google.generativeai as genai
 
-
 from bigcodebench.provider.base import DecoderBase
 from bigcodebench.gen.util.google_request import make_auto_request
 from bigcodebench.provider.utility import make_raw_chat_prompt
@@ -24,7 +24,7 @@ def codegen(
         all_outputs = []
 
         for prompt in tqdm(prompts):
-            ret_texts = []
+            outputs = []
            message = make_raw_chat_prompt(
                task_prompt=prompt,
                subset=self.subset,
@@ -33,25 +33,23 @@ def codegen(
                response_prefix=self.response_prefix,
                tokenizer=None,
            )
-            replies = make_auto_request(
+            ret = make_auto_request(
                self.client,
                message,
                self.name,
-                n=batch_size,
+                n=num_samples,
                max_tokens=self.max_new_tokens,
                temperature=self.temperature,
            )
-            for candidate in replies.candidates:
+            for candidate in ret.candidates:
                parts = candidate.content.parts
                if parts:
-                    ret_texts.append(parts[0].text)
+                    outputs.append(parts[0].text)
                else:
                    print("Empty response!")
-                    ret_texts.append("")
+                    outputs.append("")
                    print(f"{candidate.safety_ratings = }")
-            ret_texts.append("")
-            all_outputs.append(ret_texts + [""] * (batch_size - len(ret_texts)))
-
+            all_outputs.append(outputs)
         return all_outputs
 
     def is_direct_completion(self) -> bool:

bigcodebench/provider/hf.py

Lines changed: 30 additions & 20 deletions
@@ -4,8 +4,8 @@
 from stop_sequencer import StopSequencer
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from evalplus.provider.base import DecoderBase
-from evalplus.provider.utility import (
+from bigcodebench.provider.base import DecoderBase
+from bigcodebench.provider.utility import (
     extra_eos_for_direct_completion,
     make_raw_chat_prompt,
 )
@@ -33,14 +33,17 @@ def __init__(
         print(f"{kwargs = }")
 
         self.tokenizer = AutoTokenizer.from_pretrained(name, use_fast=False, legacy=self.tokenizer_legacy)
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        # assume the model is decoder-only
+        self.tokenizer.padding_side = 'left'
+
         if self.is_direct_completion():  # no chat template
             self.eos += extra_eos_for_direct_completion(dataset)
         else:  # with chat template
             self.eos += ["\n```\n"]
 
         print(f"{self.eos = }")
         self.model = AutoModelForCausalLM.from_pretrained(name, **kwargs)
-        self.model = self.model.to(self.device)
 
     def is_direct_completion(self) -> bool:
         return self.direct_completion or self.tokenizer.chat_template is None
@@ -61,15 +64,16 @@ def codegen(
            )
            for prompt in prompts
        ]
-        input_tokens = self.tokenizer.encode(prompts, return_tensors="pt").to(
+
+        input_tokens = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(
            self.device
-        )
+        )["input_ids"]
+
        kwargs = {}
        if do_sample:
            kwargs["top_p"] = 0.95
            kwargs["temperature"] = self.temperature
-
-        outputs = self.model.generate(
+        ret = self.model.generate(
            input_tokens,
            max_new_tokens=self.max_new_tokens,
            do_sample=do_sample,
@@ -79,17 +83,23 @@ def codegen(
            tokenizer=self.tokenizer,
            **kwargs,
        )
+
+        # Reshape ret into a list of lists, each sublist containing num_samples elements
+        ret_chunks = [ret[i:i + num_samples] for i in range(0, len(ret), num_samples)]
 
-        gen_strs = self.tokenizer.batch_decode(
-            outputs[:, input_tokens.size(-1) :],
-            skip_special_tokens=self.skip_special_tokens,
-        )
-        outputs = []
-        # removes eos tokens.
-        for output in gen_strs:
-            min_index = 10000
-            for eos in self.eos:
-                if eos in output:
-                    min_index = min(min_index, output.index(eos))
-            outputs.append(output[:min_index].replace("\t", "    "))
-        return outputs
+        all_outputs = []
+        # Process each chunk in ret_chunks
+        for i, ret_chunk in enumerate(ret_chunks):
+            gen_strs = self.tokenizer.batch_decode(
+                ret_chunk[:, input_tokens[i].size(-1):],
+                skip_special_tokens=self.skip_special_tokens,
+            )
+            outputs = []
+            for output in gen_strs:
+                min_index = 10000
+                for eos in self.eos:
+                    if eos in output:
+                        min_index = min(min_index, output.index(eos))
+                outputs.append(output[:min_index].replace("\t", "    "))
+            all_outputs.append(outputs)
+        return all_outputs
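The ret_chunks slicing above regroups the flat batch returned by generate (prompts x num_samples rows) into one chunk per prompt before decoding. A small standalone sketch of the same indexing on plain lists; the helper name is illustrative, not part of the repository:

# Hypothetical helper mirroring ret_chunks = [ret[i:i + num_samples] ...] on lists.
def group_by_prompt(flat_outputs, num_samples):
    """Split a flat sequence of generations into one sublist per prompt."""
    return [flat_outputs[i:i + num_samples]
            for i in range(0, len(flat_outputs), num_samples)]

# Two prompts, two samples each -> [['p0_s0', 'p0_s1'], ['p1_s0', 'p1_s1']]
print(group_by_prompt(["p0_s0", "p0_s1", "p1_s0", "p1_s1"], num_samples=2))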

bigcodebench/provider/mistral.py

Lines changed: 9 additions & 7 deletions
@@ -1,18 +1,21 @@
 import os
 from typing import List
+from tqdm import tqdm
 
-import anthropic
+from mistralai.client import MistralClient
+from mistralai.models.chat_completion import ChatMessage
 
 from bigcodebench.provider.base import DecoderBase
+from bigcodebench.gen.util.mistral_request import make_auto_request
 from bigcodebench.provider.utility import make_raw_chat_prompt
 
-class MistralDecoder(DecoderBase):
+class MistralChatDecoder(DecoderBase):
     def __init__(self, name: str, **kwargs) -> None:
         super().__init__(name, **kwargs)
-        self.client = mistral.Mistral(api_key=os.getenv("MISTRAL_API_KEY"))
+        self.client = MistralClient(api_key=os.getenv("MISTRAL_API_KEY"))
 
     def codegen(
-        self, prompt: str, do_sample: bool = True, num_samples: int = 200
+        self, prompts: List[str], do_sample: bool = True, num_samples: int = 200
     ) -> List[str]:
         if do_sample:
             assert self.temperature > 0, "Temperature must be positive for sampling"
@@ -22,7 +25,7 @@ def codegen(
            outputs = []
 
            for _ in range(num_samples):
-                message = mistral_request.make_auto_request(
+                ret = make_auto_request(
                    client=self.client,
                    model=self.name,
                    messages=[
@@ -40,9 +43,8 @@ def codegen(
                        )
                    ],
                    max_tokens=self.max_new_tokens,
-                    **kwargs,
                )
-                outputs.append(message.content[0].text)
+                outputs.append(ret.choices[0].message.content)
            all_outputs.append(outputs)
        return all_outputs
 
bigcodebench/provider/openai.py

Lines changed: 5 additions & 3 deletions
@@ -1,10 +1,11 @@
 import os
 from typing import List
+from tqdm import tqdm
 
 import openai
 
-from evalplus.gen.util import openai_request
-from evalplus.provider.base import DecoderBase
+from bigcodebench.provider.base import DecoderBase
+from bigcodebench.gen.util.openai_request import make_auto_request
 from bigcodebench.provider.utility import make_raw_chat_prompt
 
 class OpenAIChatDecoder(DecoderBase):
@@ -21,6 +22,7 @@ def codegen(
            assert self.temperature > 0, "Temperature must be positive for sampling"
        all_outputs = []
        for prompt in tqdm(prompts):
+            outputs = []
            message = make_raw_chat_prompt(
                task_prompt=prompt,
                subset=self.subset,
@@ -29,7 +31,7 @@ def codegen(
                response_prefix=self.response_prefix,
                tokenizer=None,
            )
-            ret = openai_request.make_auto_request(
+            ret = make_auto_request(
                self.client,
                message=message,
                model=self.name,