Commit 8993dca

refactor: update evaluate pipeline
1 parent 7c5c3d0 commit 8993dca

2 files changed: +21 / -15 lines changed

bigcodebench/evaluate.py

Lines changed: 3 additions & 2 deletions
@@ -127,7 +127,8 @@ def evaluate(
     check_gt_only: bool = False,
     no_gt: bool = False,
     **model_kwargs,
-):
+):
+
     if not samples and model_kwargs:
         samples = run_codegen(
             split=split,
@@ -164,7 +165,7 @@ def evaluate(

     else:

-        pass_k = [int(k.strip()) for k in pass_k.split(',') if k.strip().isdigit()]
+        pass_k = [int(k) for k in pass_k.split(",")]

     if parallel is None:
         n_workers = max(1, multiprocessing.cpu_count() // 2)
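
The substantive change here is the pass_k parsing: each comma-separated value is now cast directly, without per-item stripping or isdigit() filtering, so a malformed entry raises a ValueError instead of being silently dropped. A minimal standalone sketch of the two behaviours (illustrative only, not code from the repository):

```python
# Hypothetical input mimicking a --pass_k CLI string.
pass_k = "1, 5,10"

# Old behaviour: strip whitespace and silently skip non-numeric pieces.
old = [int(k.strip()) for k in pass_k.split(",") if k.strip().isdigit()]

# New behaviour: cast every piece directly; int() tolerates surrounding
# whitespace, but an empty or non-numeric piece raises ValueError.
new = [int(k) for k in pass_k.split(",")]

print(old)  # [1, 5, 10]
print(new)  # [1, 5, 10]
```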

bigcodebench/generate.py

Lines changed: 18 additions & 13 deletions
@@ -19,13 +19,13 @@ def codegen(
     model: DecoderBase,
     target_path: str,
     split: str,
-    subset="full",
-    greedy=False,
-    strip_newlines=False,
-    n_samples=1,
-    id_range=None,
-    resume=True,
-    batch_size: int=-1,
+    subset: str,
+    greedy: bool = False,
+    strip_newlines: bool = False,
+    n_samples: int = 1,
+    id_range: Tuple[int, int] = None,
+    resume: bool = True,
+    batch_size: int = -1,
 ):
     with Progress(
         TextColumn(f"BigCodeBench--{split.capitalize()} ({subset.capitalize()}) •" + "[progress.percentage]{task.percentage:>3.0f}%"),
@@ -51,12 +51,12 @@ def codegen(
         batch_entry_points = []

         # Read existing data once if resuming
-        existing_data = {}
+        task2nexist = {}
         if resume and os.path.exists(target_path):
             with open(target_path, "r") as f:
                 for line in f:
                     item = json.loads(line)
-                    existing_data[item["task_id"]] = existing_data.get(item["task_id"], 0) + 1
+                    task2nexist[item["task_id"]] = task2nexist.get(item["task_id"], 0) + 1

         for id_num, (task_id, task) in enumerate(p.track(dataset.items())):
             if id_range is not None:
@@ -69,7 +69,7 @@ def codegen(

             p_name = task_id.replace("/", "_")

-            n_existing = existing_data.get(task_id, 0)
+            n_existing = task2nexist.get(task_id, 0)
             nsamples = n_samples - n_existing

             try:
@@ -91,7 +91,7 @@ def codegen(
                 p.console.print(log)

             if (batch_size and len(batch_prompts) == batch_size) or id_num == len(dataset) - 1 or (id_range and id_num == id_range[1] - 1):
-                if not batch_prompts and id_num == len(dataset) - 1:
+                if not batch_prompts and (id_num == len(dataset) - 1 or (id_range and id_num == id_range[1] - 1)):
                     break
                 outputs = model.codegen(
                     batch_prompts,
@@ -130,6 +130,7 @@ def run_codegen(
     bs: Optional[int] = None,
     n_samples: int = 1,
     temperature: float = 0.0,
+    max_new_tokens: int = 1280,
     greedy: bool = False,
     strip_newlines: bool = False,
     direct_completion: bool = False,
@@ -147,7 +148,7 @@ def run_codegen(
         temperature = 0
         n_samples = 1
         greedy = True
-        print("Greedy decoding ON (--greedy): setting bs=1, n_samples=1, temperature=0")
+        print("Greedy decoding ON (--greedy): setting n_samples=1, temperature=0")

     if id_range is not None:
         assert len(id_range) == 2, "id_range must be a list of length 2"
@@ -167,6 +168,7 @@ def run_codegen(
         subset=subset,
         split=split,
         temperature=temperature,
+        max_new_tokens=max_new_tokens,
         instruction_prefix=instruction_prefix,
         response_prefix=response_prefix,
         base_url=base_url,
@@ -181,7 +183,10 @@ def run_codegen(
     identifier = model.replace("/", "--") + f"--bigcodebench{extra}-{split}--{backend}-{temperature}-{n_samples}-sanitized_calibrated.jsonl"

     target_path = os.path.join(root, identifier)
-
+
+    if not resume:
+        os.remove(target_path)
+
     codegen(
         model=model_runner,
         target_path=target_path,
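
In generate.py, resuming now counts existing samples per task (task2nexist) so only the missing completions are regenerated, while a non-resume run deletes the stale output file via os.remove(target_path) before generation starts. A minimal sketch of that bookkeeping, assuming a JSONL output file whose records carry a "task_id" field; count_existing is a hypothetical helper for illustration, not part of the repository:

```python
import json
import os

def count_existing(target_path: str, resume: bool) -> dict:
    """Return how many samples already exist per task_id in a JSONL output file."""
    task2nexist = {}
    if resume and os.path.exists(target_path):
        with open(target_path, "r") as f:
            for line in f:
                item = json.loads(line)
                task2nexist[item["task_id"]] = task2nexist.get(item["task_id"], 0) + 1
    elif not resume and os.path.exists(target_path):
        # Fresh run: drop the stale file, mirroring os.remove(target_path) in the
        # commit (the exists() guard here is an assumption of this sketch).
        os.remove(target_path)
    return task2nexist

# Usage sketch: generate only the samples still missing for a task.
# n_existing = task2nexist.get(task_id, 0)
# remaining = n_samples - n_existing
```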
