Commit bbe93d6

feat: save pass@k results and use custom tokenizer

2 parents: 19ca466 + 7d9e4fc

File tree: 3 files changed (+48, -9 lines)

bigcodebench/evaluate.py

Lines changed: 26 additions & 0 deletions
@@ -277,6 +277,32 @@ def stucking_checker():
     if not os.path.isfile(result_path):
         with open(result_path, "w") as f:
             json.dump(results, f, indent=2)
+
+    pass_at_k_path = result_path.replace("_eval_results.json", "_pass_at_k.json")
+    pass_at_k["model"] = flags.samples.split("/")[-1].replace(".jsonl", "")
+    pass_at_k["subset"] = flags.subset
+
+    def save_pass_at_k():
+        with open(pass_at_k_path, "w") as f:
+            json.dump(pass_at_k, f, indent=2)
+
+    if os.path.isfile(pass_at_k_path):
+        saved_pass_at_k = json.load(open(pass_at_k_path, "r"))
+        # compare saved_pass_at_k with pass_at_k
+        for k in saved_pass_at_k.keys():
+            if pass_at_k[k] != saved_pass_at_k[k]:
+                cprint(f"Warning: {k} is different from the saved one", "yellow")
+
+        # ask user whether to save the pass@k
+        decision = ""
+        while decision.lower() not in ["y", "n"]:
+            print(f"Save pass@k to {pass_at_k_path}? [Y/N]")
+            decision = input()
+        if decision.lower() == "y":
+            save_pass_at_k()
+
+    else:
+        save_pass_at_k()
 
 
 def main():
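
For reference, the pass@k numbers persisted above are conventionally computed with the unbiased estimator from the Codex paper (Chen et al., 2021): pass@k = 1 - C(n-c, k) / C(n, k) for n generated samples of which c pass. A minimal sketch of that estimator, not taken from this diff (the function name is hypothetical):

    import numpy as np

    def estimate_pass_at_k(n: int, c: int, k: int) -> float:
        # Unbiased pass@k: 1 - C(n-c, k) / C(n, k), evaluated in a
        # numerically stable product form.
        if n - c < k:
            return 1.0
        return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

With the "model" and "subset" keys added in this hunk, the saved *_pass_at_k.json then holds entries such as {"pass@1": 0.42, "model": "<samples filename>", "subset": "instruct"} (values illustrative).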

bigcodebench/generate.py

Lines changed: 6 additions & 3 deletions
@@ -35,7 +35,7 @@ def codegen(
 
     if model.is_direct_completion() and subset == "instruct":
         raise Exception("Base model does not support direct completion for instruct tasks")
-
+
     # create save_path if it doesn't exist, e.g., a/b.jsonl
     dirname = os.path.dirname(save_path)
     if not os.path.exists(dirname) and dirname != "":

@@ -118,6 +118,8 @@ def main():
     parser.add_argument("--base_url", default=None, type=str)
     parser.add_argument("--tp", default=1, type=int)
     parser.add_argument("--trust_remote_code", action="store_true")
+    parser.add_argument("--tokenizer_name", default=None, type=str)
+
     args = parser.parse_args()
 
 
@@ -145,7 +147,8 @@
         temperature=args.temperature,
         base_url=args.base_url,
         tp=args.tp,
-        trust_remote_code=args.trust_remote_code
+        trust_remote_code=args.trust_remote_code,
+        tokenizer_name=args.tokenizer_name
     )
 
     if not args.save_path:

@@ -161,7 +164,7 @@
         strip_newlines=args.strip_newlines,
         n_samples=args.n_samples,
         resume=args.resume,
-        id_range=args.id_range,
+        id_range=args.id_range
     )
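
For context on why the flag exists: a checkpoint directory (for example a merged or converted model) may ship without tokenizer files, or one may deliberately want the tokenizer of a different base model. A standalone sketch of that situation, with hypothetical paths and model ids not taken from this diff:

    from transformers import AutoTokenizer

    model_path = "/checkpoints/my-merged-model"              # hypothetical: weights only
    tokenizer_repo = "deepseek-ai/deepseek-coder-6.7b-base"  # hypothetical tokenizer source

    # Before this commit the tokenizer was always loaded from the model name;
    # --tokenizer_name lets it come from a separate repo instead.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)
    print(tokenizer("def add(a, b):").input_ids[:8])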

bigcodebench/model.py

Lines changed: 16 additions & 6 deletions
@@ -92,6 +92,7 @@ def __init__(
         max_new_tokens: int = 1280,
         dtype: str = "bfloat16", # default
         trust_remote_code: bool = False,
+        tokenizer_name: str = None,
     ) -> None:
         print("Initializing a decoder model: {} ...".format(name))
         self.name = name

@@ -102,6 +103,7 @@ def __init__(
         self.max_new_tokens = max_new_tokens
         self.dtype = dtype
         self.trust_remote_code = trust_remote_code
+        self.tokenizer_name = tokenizer_name
 
     @abstractmethod
     def codegen(
@@ -129,11 +131,13 @@ def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
             "dtype": self.dtype,
             "trust_remote_code": self.trust_remote_code,
         }
-
-        self.tokenizer = AutoTokenizer.from_pretrained(self.name, **kwargs)
+        if self.tokenizer_name is None:
+            self.tokenizer_name = self.name
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, **kwargs)
         if self.tokenizer.chat_template is None:
             self.eos += extra_eos_for_direct_completion(dataset)
-        self.llm = LLM(model=name, max_model_len=2048, **kwargs)
+        self.llm = LLM(model=name, max_model_len=2048, tokenizer=self.tokenizer_name, **kwargs)
 
     def is_direct_completion(self) -> bool:
         return self.tokenizer.chat_template is None
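
The vLLM change relies on LLM accepting a tokenizer path distinct from the model path, which vLLM exposes via its tokenizer argument. A minimal sketch of the same fallback pattern outside the class, with a hypothetical model id:

    from vllm import LLM

    name = "bigcode/starcoder2-3b"  # hypothetical model id
    tokenizer_name = None           # value of the new --tokenizer_name flag

    # Same default as GeneralVllmDecoder above: fall back to the model's own tokenizer.
    if tokenizer_name is None:
        tokenizer_name = name

    llm = LLM(model=name, max_model_len=2048, tokenizer=tokenizer_name)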
@@ -185,9 +189,12 @@ def __init__(self, name: str, dataset: str, **kwargs):
         kwargs["torch_dtype"] = getattr(torch, self.dtype)
         self.skip_special_tokens = True
 
-        print(f"{kwargs = }")
+        print(f"{kwargs = }", self.tokenizer_name)
+
+        if self.tokenizer_name is None:
+            self.tokenizer_name = self.name
 
-        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, **kwargs)
         if self.tokenizer.chat_template is None:
             self.eos += extra_eos_for_direct_completion(dataset)
 

@@ -253,7 +260,7 @@ def __init__(self, name: str, **kwargs):
         super().__init__(name=name, **kwargs)
         self.eos += ["\n```\n"]
         print(f"EOS strings: {self.eos}")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.name, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name if self.tokenizer_name else self.name, **kwargs)
 
     def codegen(
         self, prompt: str, do_sample: bool = True, num_samples: int = 200

@@ -486,6 +493,7 @@ def make_model(
     tp=1,
     base_url=None,
     trust_remote_code=False,
+    tokenizer_name=None,
 ):
     if backend == "vllm":
         return GeneralVllmDecoder(

@@ -495,6 +503,7 @@ def make_model(
             dataset=dataset,
             tp=tp,
             trust_remote_code=trust_remote_code,
+            tokenizer_name=tokenizer_name,
         )
     elif backend == "hf":
         return GenenralHfTorchDecoder(

@@ -503,6 +512,7 @@ def make_model(
             temperature=temperature,
             dataset=dataset,
             trust_remote_code=trust_remote_code,
+            tokenizer_name=tokenizer_name,
         )
     elif backend == "openai":
         return OpenAIChatDecoder(
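
Putting the pieces together, a hedged sketch of threading the new parameter through make_model; only tp, base_url, trust_remote_code, and tokenizer_name are confirmed as parameter names by this diff, so the model, backend, and dataset keywords below are assumptions inferred from the visible call sites:

    decoder = make_model(
        model="bigcode/starcoder2-3b",  # hypothetical model id
        backend="vllm",                 # dispatches to GeneralVllmDecoder
        dataset="bigcodebench",
        tp=1,
        trust_remote_code=False,
        tokenizer_name=None,            # None -> use the model's own tokenizer
    )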
