Skip to content

Commit 7d9e4fc

Browse files
committed
ask user whether to save pass@k
1 parent d1cb2fe commit 7d9e4fc

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

bigcodebench/evaluate.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,28 @@ def stucking_checker():
281281
pass_at_k_path = result_path.replace("_eval_results.json", "_pass_at_k.json")
282282
pass_at_k["model"] = flags.samples.split("/")[-1].replace(".jsonl", "")
283283
pass_at_k["subset"] = flags.subset
284-
with open(pass_at_k_path, "w") as f:
285-
json.dump(pass_at_k, f, indent=2)
284+
285+
def save_pass_at_k():
286+
with open(pass_at_k_path, "w") as f:
287+
json.dump(pass_at_k, f, indent=2)
288+
289+
if os.path.isfile(pass_at_k_path):
290+
saved_pass_at_k = json.load(open(pass_at_k_path, "r"))
291+
# compare saved_pass_at_k with pass_at_k
292+
for k in saved_pass_at_k.keys():
293+
if pass_at_k[k] != saved_pass_at_k[k]:
294+
cprint(f"Warning: {k} is different from the saved one", "yellow")
295+
296+
# ask user whether to save the pass@k
297+
decision = ""
298+
while decision.lower() not in ["y", "n"]:
299+
print(f"Save pass@k to {pass_at_k_path}? [Y/N]")
300+
decision = input()
301+
if decision.lower() == "y":
302+
save_pass_at_k()
303+
304+
else:
305+
save_pass_at_k()
286306

287307

288308
def main():

0 commit comments

Comments
 (0)