Commit 133b067

feat: add customized tokenizer
1 parent: bbc070e

File tree (3 files changed: 44 additions, 7 deletions)

  bigcodebench/evaluate.py
  bigcodebench/generate.py
  bigcodebench/model.py

bigcodebench/evaluate.py

Lines changed: 26 additions & 0 deletions
@@ -281,6 +281,32 @@ def stucking_checker():
     if not os.path.isfile(result_path):
         with open(result_path, "w") as f:
             json.dump(results, f, indent=2)
+
+    pass_at_k_path = result_path.replace("_eval_results.json", "_pass_at_k.json")
+    pass_at_k["model"] = flags.samples.split("/")[-1].replace(".jsonl", "")
+    pass_at_k["subset"] = flags.subset
+
+    def save_pass_at_k():
+        with open(pass_at_k_path, "w") as f:
+            json.dump(pass_at_k, f, indent=2)
+
+    if os.path.isfile(pass_at_k_path):
+        saved_pass_at_k = json.load(open(pass_at_k_path, "r"))
+        # compare saved_pass_at_k with pass_at_k
+        for k in saved_pass_at_k.keys():
+            if pass_at_k[k] != saved_pass_at_k[k]:
+                cprint(f"Warning: {k} is different from the saved one", "yellow")
+
+        # ask user whether to save the pass@k
+        decision = ""
+        while decision.lower() not in ["y", "n"]:
+            print(f"Save pass@k to {pass_at_k_path}? [Y/N]")
+            decision = input()
+        if decision.lower() == "y":
+            save_pass_at_k()
+
+    else:
+        save_pass_at_k()
 
 
 def main():
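For reference, the new block persists the pass@k summary in a sidecar JSON next to the evaluation results, warns when values drift from a previous run, and asks before overwriting. A minimal standalone sketch of the same flow; the function name `persist_pass_at_k` is ours, and `cprint` is assumed to come from the termcolor package, matching its use in the diff:

    import json
    import os

    from termcolor import cprint


    def persist_pass_at_k(pass_at_k: dict, result_path: str) -> None:
        # Sidecar path derived from the results file, as in the diff.
        pass_at_k_path = result_path.replace("_eval_results.json", "_pass_at_k.json")

        def save() -> None:
            with open(pass_at_k_path, "w") as f:
                json.dump(pass_at_k, f, indent=2)

        if not os.path.isfile(pass_at_k_path):
            save()  # first run: nothing to compare against
            return

        with open(pass_at_k_path, "r") as f:
            saved = json.load(f)

        # Warn on any key whose value drifted since the saved run. Using
        # .get() avoids a KeyError when the saved file has keys the current
        # run lacks; the committed code indexes pass_at_k[k] directly.
        for k, v in saved.items():
            if pass_at_k.get(k) != v:
                cprint(f"Warning: {k} is different from the saved one", "yellow")

        # Ask before overwriting, mirroring the interactive prompt above.
        decision = ""
        while decision.lower() not in ["y", "n"]:
            decision = input(f"Save pass@k to {pass_at_k_path}? [Y/N] ")
        if decision.lower() == "y":
            save()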

bigcodebench/generate.py

Lines changed: 3 additions & 1 deletion
@@ -119,6 +119,7 @@ def main():
     parser.add_argument("--base_url", default=None, type=str)
     parser.add_argument("--tp", default=1, type=int)
     parser.add_argument("--trust_remote_code", action="store_true")
+    parser.add_argument("--tokenizer_name", default=None, type=str)
     args = parser.parse_args()
 
     if args.greedy and (args.temperature != 0 or args.bs != 1 or args.n_samples != 1)\
@@ -142,7 +143,8 @@
         temperature=args.temperature,
         base_url=args.base_url,
         tp=args.tp,
-        trust_remote_code=args.trust_remote_code
+        trust_remote_code=args.trust_remote_code,
+        tokenizer_name=args.tokenizer_name
     )
 
     extra = "-" + args.subset if args.subset != "full" else ""
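In practice the new flag just threads through to `make_model`, so the tokenizer can also be overridden when calling the API directly. A hedged usage sketch: the checkpoint path, tokenizer name, and dataset value are placeholders, and the leading `model`/`backend`/`dataset` parameter names are assumptions, since the diff only shows the trailing parameters:

    from bigcodebench.model import make_model

    # Hypothetical pairing for illustration: a local fine-tune that did not
    # ship tokenizer files, reusing its base model's tokenizer instead.
    runner = make_model(
        model="./checkpoints/my-finetune",       # placeholder; parameter name assumed
        backend="vllm",
        dataset="bigcodebench",                  # placeholder value
        temperature=0.0,
        tokenizer_name="bigcode/starcoder2-3b",  # new flag; falls back to the model when None
    )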

bigcodebench/model.py

Lines changed: 15 additions & 6 deletions
@@ -91,6 +91,7 @@ def __init__(
         max_new_tokens: int = 1280,
         dtype: str = "bfloat16",  # default
         trust_remote_code: bool = False,
+        tokenizer_name: str = None,
     ) -> None:
         print("Initializing a decoder model: {} ...".format(name))
         self.name = name
@@ -101,6 +102,7 @@ def __init__(
         self.max_new_tokens = max_new_tokens
         self.dtype = dtype
         self.trust_remote_code = trust_remote_code
+        self.tokenizer_name = tokenizer_name
 
     @abstractmethod
     def codegen(
@@ -128,8 +130,10 @@ def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
             "dtype": self.dtype,
             "trust_remote_code": self.trust_remote_code,
         }
-
-        self.tokenizer = AutoTokenizer.from_pretrained(self.name, legacy=False, **kwargs)
+        if self.tokenizer_name is None:
+            self.tokenizer_name = self.name
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, legacy=False, **kwargs)
         if self.tokenizer.chat_template is None:
             self.eos += extra_eos_for_direct_completion(dataset)
         self.llm = LLM(model=name, max_model_len=2048, **kwargs)
@@ -185,9 +189,11 @@ def __init__(self, name: str, dataset: str, **kwargs):
         kwargs["torch_dtype"] = getattr(torch, self.dtype)
         self.skip_special_tokens = True
 
-        print(f"{kwargs = }")
-
-        self.tokenizer = AutoTokenizer.from_pretrained(name, legacy=False, **kwargs)
+        print(f"{kwargs = }", self.tokenizer_name)
+        if self.tokenizer_name is None:
+            self.tokenizer_name = self.name
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, legacy=False, **kwargs)
         if self.tokenizer.chat_template is None:
             self.eos += extra_eos_for_direct_completion(dataset)
 
@@ -242,7 +248,7 @@ def __init__(self, name: str, **kwargs):
         super().__init__(name=name, **kwargs)
         self.eos += ["\n```\n"]
         print(f"EOS strings: {self.eos}")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.name, legacy=False, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name if self.tokenizer_name else self.name, **kwargs)
 
     def codegen(
         self, prompt: str, do_sample: bool = True, num_samples: int = 200
@@ -475,6 +481,7 @@ def make_model(
     tp=1,
     base_url=None,
     trust_remote_code=False,
+    tokenizer_name=None,
 ):
     if backend == "vllm":
         return GeneralVllmDecoder(
@@ -484,6 +491,7 @@ def make_model(
             dataset=dataset,
             tp=tp,
             trust_remote_code=trust_remote_code,
+            tokenizer_name=tokenizer_name,
         )
     elif backend == "hf":
         return GenenralHfTorchDecoder(
@@ -492,6 +500,7 @@ def make_model(
             temperature=temperature,
             dataset=dataset,
             trust_remote_code=trust_remote_code,
+            tokenizer_name=tokenizer_name,
         )
     elif backend == "openai":
         return OpenAIChatDecoder(
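The change repeated across the decoder classes is a single fall-back rule: load the tokenizer from `tokenizer_name` when it is given, otherwise from the model checkpoint itself. In isolation the pattern looks like this; a sketch assuming the transformers package is installed, with placeholder model names:

    from transformers import AutoTokenizer


    def load_tokenizer(model_name: str, tokenizer_name: str = None):
        # Prefer the explicitly supplied tokenizer, as the decoders above do;
        # otherwise fall back to the tokenizer bundled with the checkpoint.
        return AutoTokenizer.from_pretrained(tokenizer_name or model_name, legacy=False)


    # Placeholder names for illustration:
    tok = load_tokenizer("bigcode/starcoder2-3b")                   # bundled tokenizer
    tok = load_tokenizer("./my-finetune", "bigcode/starcoder2-3b")  # explicit override

One detail visible in the last per-class hunk: the rewritten call there drops `legacy=False`, unlike the vLLM and HF loaders.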
