Commit 4b7e3af

Merge branch 'main' into hard
2 parents: bbc070e + 5b67995

File tree: 3 files changed, +41 −4 lines

bigcodebench/evaluate.py

Lines changed: 26 additions & 0 deletions
@@ -281,6 +281,32 @@ def stucking_checker():
     if not os.path.isfile(result_path):
         with open(result_path, "w") as f:
             json.dump(results, f, indent=2)
+
+    pass_at_k_path = result_path.replace("_eval_results.json", "_pass_at_k.json")
+    pass_at_k["model"] = flags.samples.split("/")[-1].replace(".jsonl", "")
+    pass_at_k["subset"] = flags.subset
+
+    def save_pass_at_k():
+        with open(pass_at_k_path, "w") as f:
+            json.dump(pass_at_k, f, indent=2)
+
+    if os.path.isfile(pass_at_k_path):
+        saved_pass_at_k = json.load(open(pass_at_k_path, "r"))
+        # compare saved_pass_at_k with pass_at_k
+        for k in saved_pass_at_k.keys():
+            if pass_at_k[k] != saved_pass_at_k[k]:
+                cprint(f"Warning: {k} is different from the saved one", "yellow")
+
+        # ask user whether to save the pass@k
+        decision = ""
+        while decision.lower() not in ["y", "n"]:
+            print(f"Save pass@k to {pass_at_k_path}? [Y/N]")
+            decision = input()
+        if decision.lower() == "y":
+            save_pass_at_k()
+
+    else:
+        save_pass_at_k()
 
 
 def main():
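Taken together, this hunk persists the pass@k scores next to the evaluation results and, when a previous pass@k file already exists, warns about any changed values and asks before overwriting. Below is a minimal standalone sketch of that save-or-compare flow; the file name and the contents of pass_at_k are illustrative stand-ins, and .get() is used defensively where the diff indexes directly.

import json
import os

# Illustrative values; in evaluate.py these come from the evaluation run.
pass_at_k = {"model": "example-model", "subset": "full", "pass@1": 0.42}
pass_at_k_path = "example_pass_at_k.json"  # hypothetical path

def save_pass_at_k():
    with open(pass_at_k_path, "w") as f:
        json.dump(pass_at_k, f, indent=2)

if os.path.isfile(pass_at_k_path):
    with open(pass_at_k_path, "r") as f:
        saved_pass_at_k = json.load(f)
    # Warn about any key whose value drifted since the last saved run.
    for k in saved_pass_at_k:
        if pass_at_k.get(k) != saved_pass_at_k[k]:
            print(f"Warning: {k} is different from the saved one")
    # Ask before overwriting the previously saved scores.
    decision = ""
    while decision.lower() not in ["y", "n"]:
        decision = input(f"Save pass@k to {pass_at_k_path}? [Y/N] ")
    if decision.lower() == "y":
        save_pass_at_k()
else:
    save_pass_at_k()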

bigcodebench/generate.py

Lines changed: 6 additions & 3 deletions
@@ -36,7 +36,7 @@ def codegen(
 
     if model.is_direct_completion() and split == "instruct":
         raise Exception("Base model does not support direct completion for instruct tasks")
-
+
     # create save_path if it doesn't exist, e.g., a/b.jsonl
     dirname = os.path.dirname(save_path)
     if not os.path.exists(dirname) and dirname != "":
@@ -119,6 +119,8 @@ def main():
     parser.add_argument("--base_url", default=None, type=str)
     parser.add_argument("--tp", default=1, type=int)
     parser.add_argument("--trust_remote_code", action="store_true")
+    parser.add_argument("--tokenizer_name", default=None, type=str)
+
     args = parser.parse_args()
 
     if args.greedy and (args.temperature != 0 or args.bs != 1 or args.n_samples != 1)\
@@ -142,7 +144,8 @@
         temperature=args.temperature,
         base_url=args.base_url,
         tp=args.tp,
-        trust_remote_code=args.trust_remote_code
+        trust_remote_code=args.trust_remote_code,
+        tokenizer_name=args.tokenizer_name
     )
 
     extra = "-" + args.subset if args.subset != "full" else ""
@@ -160,7 +163,7 @@
         strip_newlines=args.strip_newlines,
         n_samples=args.n_samples,
         resume=args.resume,
-        id_range=args.id_range,
+        id_range=args.id_range
     )
 
 
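The net effect is a new optional --tokenizer_name flag forwarded into make_model (the added trailing comma on trust_remote_code and the dropped one on id_range are cosmetic). A small sketch of the flag's intended fallback semantics, with hypothetical model and tokenizer ids:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, type=str)
parser.add_argument("--tokenizer_name", default=None, type=str)

# Hypothetical invocation: a fine-tuned checkpoint reusing its base tokenizer.
args = parser.parse_args(
    ["--model", "my-org/my-finetune", "--tokenizer_name", "my-org/base-model"]
)

# None means "fall back to the model name", mirroring bigcodebench/model.py below.
tokenizer_name = args.tokenizer_name or args.model
print(tokenizer_name)  # my-org/base-model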

bigcodebench/model.py

Lines changed: 9 additions & 1 deletion
@@ -91,6 +91,7 @@ def __init__(
         max_new_tokens: int = 1280,
         dtype: str = "bfloat16", # default
         trust_remote_code: bool = False,
+        tokenizer_name: str = None,
     ) -> None:
         print("Initializing a decoder model: {} ...".format(name))
         self.name = name
@@ -101,6 +102,7 @@ def __init__(
         self.max_new_tokens = max_new_tokens
         self.dtype = dtype
         self.trust_remote_code = trust_remote_code
+        self.tokenizer_name = tokenizer_name
 
     @abstractmethod
     def codegen(
@@ -185,7 +187,10 @@ def __init__(self, name: str, dataset: str, **kwargs):
         kwargs["torch_dtype"] = getattr(torch, self.dtype)
         self.skip_special_tokens = True
 
-        print(f"{kwargs = }")
+        print(f"{kwargs = }", self.tokenizer_name)
+
+        if self.tokenizer_name is None:
+            self.tokenizer_name = self.name
 
         self.tokenizer = AutoTokenizer.from_pretrained(name, legacy=False, **kwargs)
         if self.tokenizer.chat_template is None:
@@ -475,6 +480,7 @@ def make_model(
     tp=1,
     base_url=None,
     trust_remote_code=False,
+    tokenizer_name=None,
 ):
     if backend == "vllm":
         return GeneralVllmDecoder(
@@ -484,6 +490,7 @@
             dataset=dataset,
             tp=tp,
             trust_remote_code=trust_remote_code,
+            tokenizer_name=tokenizer_name,
         )
     elif backend == "hf":
         return GenenralHfTorchDecoder(
@@ -492,6 +499,7 @@
             temperature=temperature,
             dataset=dataset,
             trust_remote_code=trust_remote_code,
+            tokenizer_name=tokenizer_name,
         )
     elif backend == "openai":
         return OpenAIChatDecoder(
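These hunks thread tokenizer_name from make_model into the decoder constructors and default it to the model name when unset, so a checkpoint can be paired with a different tokenizer. Note that the displayed context line still loads from name; the switch to self.tokenizer_name presumably lands outside this hunk. A minimal sketch of the fallback pattern, assuming the transformers package and hypothetical model ids:

from transformers import AutoTokenizer

def load_tokenizer(name: str, tokenizer_name: str = None):
    # Same fallback as the diff: no override means "use the model's own tokenizer".
    if tokenizer_name is None:
        tokenizer_name = name
    return AutoTokenizer.from_pretrained(tokenizer_name, legacy=False)

# Hypothetical usage: a fine-tune that kept its base model's tokenizer.
# tok = load_tokenizer("my-org/my-finetune", tokenizer_name="my-org/base-model")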
