3434Result = Tuple [str , List [bool ]]
3535
3636
37- def get_groundtruth (problems , hashcode , check_gt_only , max_as_limit , max_data_limit , max_stack_limit ):
37+ def get_groundtruth (n_workers , problems , hashcode , check_gt_only , max_as_limit , max_data_limit , max_stack_limit ):
3838 cache_file = os .path .join (CACHE_DIR , f"{ hashcode } .pkl" )
3939 if os .path .exists (cache_file ):
4040 if check_gt_only :
@@ -47,16 +47,29 @@ def get_groundtruth(problems, hashcode, check_gt_only, max_as_limit, max_data_li
4747 os .makedirs (CACHE_DIR , exist_ok = True )
4848 print ("\n Asserting the groundtruth..." )
4949 tbegin = time .time ()
50- expected_time = {}
51- for task_id , problem in tqdm (problems .items ()):
52- expected_time [task_id ] = trusted_check (
53- problem ["complete_prompt" ] + "\n " + problem ["canonical_solution" ],
54- problem ["test" ],
55- problem ["task_id" ],
56- max_as_limit ,
57- max_data_limit ,
58- max_stack_limit
59- )
50+
51+ with ProcessPoolExecutor (max_workers = n_workers ) as executor :
52+ futures = []
53+ n_samples = 0
54+ expected_time = dict ()
55+
56+ for problem in problems .values ():
57+ args = (
58+ problem ["complete_prompt" ] + "\n " + problem ["canonical_solution" ],
59+ problem ["test" ],
60+ problem ["task_id" ],
61+ max_as_limit ,
62+ max_data_limit ,
63+ max_stack_limit
64+ )
65+
66+ futures .append (executor .submit (trusted_check , * args ))
67+ n_samples += 1
68+
69+ for future in tqdm (as_completed (futures ), total = n_samples ):
70+ result = future .result ()
71+ expected_time [result ["task_id" ]] = result ["time" ]
72+
6073 print (f"Expected outputs computed in { time .time () - tbegin :.2f} s" )
6174
6275 with open (cache_file , "wb" ) as f :
0 commit comments