File tree Expand file tree Collapse file tree 1 file changed +22
-2
lines changed
Expand file tree Collapse file tree 1 file changed +22
-2
lines changed Original file line number Diff line number Diff line change @@ -281,8 +281,28 @@ def stucking_checker():
281281 pass_at_k_path = result_path .replace ("_eval_results.json" , "_pass_at_k.json" )
282282 pass_at_k ["model" ] = flags .samples .split ("/" )[- 1 ].replace (".jsonl" , "" )
283283 pass_at_k ["subset" ] = flags .subset
284- with open (pass_at_k_path , "w" ) as f :
285- json .dump (pass_at_k , f , indent = 2 )
284+
285+ def save_pass_at_k ():
286+ with open (pass_at_k_path , "w" ) as f :
287+ json .dump (pass_at_k , f , indent = 2 )
288+
289+ if os .path .isfile (pass_at_k_path ):
290+ saved_pass_at_k = json .load (open (pass_at_k_path , "r" ))
291+ # compare saved_pass_at_k with pass_at_k
292+ for k in saved_pass_at_k .keys ():
293+ if pass_at_k [k ] != saved_pass_at_k [k ]:
294+ cprint (f"Warning: { k } is different from the saved one" , "yellow" )
295+
296+ # ask user whether to save the pass@k
297+ decision = ""
298+ while decision .lower () not in ["y" , "n" ]:
299+ print (f"Save pass@k to { pass_at_k_path } ? [Y/N]" )
300+ decision = input ()
301+ if decision .lower () == "y" :
302+ save_pass_at_k ()
303+
304+ else :
305+ save_pass_at_k ()
286306
287307
288308def main ():
You can’t perform that action at this time.
0 commit comments