Skip to content

Commit c3bdeef

Browse files
committed
fix: run the trusted check in parallel worker processes
1 parent 747cf4e commit c3bdeef

File tree

2 files changed: +25 −12 lines changed

bigcodebench/evaluate.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
Result = Tuple[str, List[bool]]
3535

3636

37-
def get_groundtruth(problems, hashcode, check_gt_only, max_as_limit, max_data_limit, max_stack_limit):
37+
def get_groundtruth(n_workers, problems, hashcode, check_gt_only, max_as_limit, max_data_limit, max_stack_limit):
3838
cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
3939
if os.path.exists(cache_file):
4040
if check_gt_only:
@@ -47,16 +47,29 @@ def get_groundtruth(problems, hashcode, check_gt_only, max_as_limit, max_data_li
4747
os.makedirs(CACHE_DIR, exist_ok=True)
4848
print("\nAsserting the groundtruth...")
4949
tbegin = time.time()
50-
expected_time = {}
51-
for task_id, problem in tqdm(problems.items()):
52-
expected_time[task_id] = trusted_check(
53-
problem["complete_prompt"] + "\n" + problem["canonical_solution"],
54-
problem["test"],
55-
problem["task_id"],
56-
max_as_limit,
57-
max_data_limit,
58-
max_stack_limit
59-
)
50+
51+
with ProcessPoolExecutor(max_workers=n_workers) as executor:
52+
futures = []
53+
n_samples = 0
54+
expected_time = dict()
55+
56+
for problem in problems.values():
57+
args = (
58+
problem["complete_prompt"] + "\n" + problem["canonical_solution"],
59+
problem["test"],
60+
problem["task_id"],
61+
max_as_limit,
62+
max_data_limit,
63+
max_stack_limit
64+
)
65+
66+
futures.append(executor.submit(trusted_check, *args))
67+
n_samples += 1
68+
69+
for future in tqdm(as_completed(futures), total=n_samples):
70+
result = future.result()
71+
expected_time[result["task_id"]] = result["time"]
72+
6073
print(f"Expected outputs computed in {time.time() - tbegin:.2f}s")
6174

6275
with open(cache_file, "wb") as f:

bigcodebench/gen/util/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,4 @@ def trusted_check(
114114
else:
115115
times = times.value
116116

117-
return times
117+
return {"task_id": task_id, "time": times}

0 commit comments

Comments
 (0)