Skip to content

Commit eb986d0

Browse files
committed
fix: add direct_complete check
1 parent 92d9664 commit eb986d0

File tree

1 file changed

+49
-28
lines changed

1 file changed

+49
-28
lines changed

analysis/get_results.py

100644 → 100755
Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,22 @@
1010
import itertools
1111
import math
1212
from datasets import Dataset
13+
from transformers import AutoTokenizer
14+
15+
16+
def update_model_info(model_info):
17+
for model, info in model_info.items():
18+
if "https://huggingface.co/" in info["link"]:
19+
hf_model = info["link"].split("https://huggingface.co/")[-1]
20+
tokenizer = AutoTokenizer.from_pretrained(hf_model, trust_remote_code=True)
21+
if tokenizer.chat_template is None:
22+
model_info[model]["direct_complete"] = True
23+
else:
24+
model_info[model]["direct_complete"] = False
25+
else:
26+
model_info[model]["direct_complete"] = False
27+
28+
return model_info
1329

1430

1531
def get_results():
@@ -26,16 +42,18 @@ def get_results():
2642
},
2743
"prompted": info["prompted"],
2844
"size": info["size"],
45+
"direct_complete": info["direct_complete"],
2946
}
3047

3148
for model, info in model_info.items():
3249
model = model.replace("/", "--")
50+
hf_model = ""
3351
if "https://huggingface.co/" in info["link"]:
34-
model = info["link"].split("https://huggingface.co/")[-1].replace("/", "--")
52+
hf_model = info["link"].split("https://huggingface.co/")[-1]
53+
model = hf_model.replace("/", "--")
3554
files = glob(f"results/{model}--bigcodebench-*.json")
3655
assert files, f"No files found for results/{model}--bigcodebench-*.json"
3756
for file in files:
38-
# print(file)
3957
_, suffix = os.path.basename(file).split("--bigcodebench-")
4058
status = []
4159
with open("results/"+model+"--bigcodebench-"+suffix, "r") as f:
@@ -57,36 +75,29 @@ def get_results():
5775
mode = ""
5876
if "-sanitized-calibrate" in file:
5977
mode = "-cal"
60-
78+
6179
results[info["name"]][f"pass@1"][f"{task}{mode}"] = round(mean(status)*100,1)
62-
if not info["prompted"]:
80+
if not info["prompted"] or info["direct_complete"]:
6381
results[info["name"]][f"pass@1"][f"{task}-cal"] = round(mean(status)*100,1)
64-
82+
6583
for model, result in results.items():
6684
for task in ["complete"]:
6785
origin = result["pass@1"].pop(task)
6886
assert origin, f"Missing original complete results for {model}"
6987
calibrate = result["pass@1"].pop(f"{task}-cal")
70-
assert calibrate, f"Missing calibrated complete results for {model}"
71-
if calibrate - origin > 1:
72-
results[model]["lazy"] = True
88+
if calibrate:
89+
if calibrate - origin > 1:
90+
results[model]["lazy"] = True
91+
else:
92+
results[model]["lazy"] = False
93+
results[model]["pass@1"][task] = calibrate
7394
else:
7495
results[model]["lazy"] = False
75-
results[model]["pass@1"][task] = calibrate
96+
results[model]["pass@1"][task] = origin
7697
calibrate_instruct = result["pass@1"].pop(f"instruct-cal")
7798
result["pass@1"]["instruct"] = calibrate_instruct
7899
return results
79100

80-
81-
def compute_diff(results):
82-
diffs = []
83-
for model, info in model_info.items():
84-
if not info["prompted"]:
85-
continue
86-
diff = results[info["name"]]["pass@1"]["complete"] - results[info["name"]]["pass@1"]["complete-cal"]
87-
diffs.append(diff)
88-
print("Mean diff:", mean(diffs))
89-
90101

91102
def check_valid(results):
92103
for model, result in results.items():
@@ -104,8 +115,8 @@ def split_gen():
104115
os.makedirs("sanitized_calibrated_samples/instruct", exist_ok=True)
105116
for model, info in model_info.items():
106117
model = model.replace("/", "--")
107-
files = glob(f"clean_results/{model}--bigcodebench-*.jsonl")
108-
if "https://huggingface.co/" in info["link"]:
118+
files = glob(f"results/{model}--bigcodebench-*.jsonl")
119+
if info["link"].startswith("https://huggingface.co/"):
109120
model = info["link"].split("https://huggingface.co/")[-1].replace("/", "--")
110121

111122
for file in files:
@@ -139,16 +150,19 @@ def read_task_perf(task="complete"):
139150

140151
task_perf = {f"BigCodeBench/{task_id}": 0 for task_id in range(1140)}
141152
model = model.replace("/", "--")
142-
if "https://huggingface.co/" in info["link"]:
153+
if info["link"].startswith("https://huggingface.co/"):
143154
model = info["link"].split("https://huggingface.co/")[-1].replace("/", "--")
144155
try:
145-
if info["prompted"]:
146-
file = glob(f"results/{model}--bigcodebench-{task}*-0-1-sanitized-calibrated_eval_results.json")[0]
156+
if info["prompted"] and not info["direct_complete"]:
157+
files = glob(f"results/{model}--bigcodebench-{task}*-0-1-sanitized-calibrated_eval_results.json")
158+
if files:
159+
file = files[0]
160+
else:
161+
file = glob(f"results/{model}--bigcodebench-{task}*-0-1-sanitized_eval_results.json")[0]
147162
else:
148163
file = glob(f"results/{model}--bigcodebench-{task}*-0-1-sanitized_eval_results.json")[0]
149164
except:
150165
continue
151-
raise ValueError(f"Missing results/{model}--bigcodebench-{task}*-0-1-sanitized_eval_results.json")
152166

153167
with open(file, "r") as f:
154168
data = json.load(f)
@@ -260,7 +274,7 @@ def get_solve_rate(data_dict, task="complete"):
260274

261275

262276
def get_hf_ds(results):
263-
hf_dataset = {"model": [], "link": [], "size": [], "type": [], "lazy": [],
277+
hf_dataset = {"model": [], "link": [], "size": [], "type": [], "lazy": [], "direct_complete": [],
264278
"complete": [], "instruct": [], "elo_mle": []}
265279

266280
for model, result in results.items():
@@ -271,7 +285,9 @@ def get_hf_ds(results):
271285
hf_dataset["lazy"].append(result["lazy"])
272286
hf_dataset["complete"].append(result["pass@1"]["complete"])
273287
hf_dataset["instruct"].append(result["pass@1"]["instruct"])
288+
hf_dataset["direct_complete"].append(result["direct_complete"])
274289
hf_dataset["elo_mle"].append(result["elo_mle"])
290+
275291
return Dataset.from_dict(hf_dataset)
276292

277293
def get_bootstrap_scores(df):
@@ -294,6 +310,8 @@ def push_ds(ds, path, local=False):
294310

295311

296312
if __name__ == "__main__":
313+
314+
model_info = update_model_info(model_info)
297315
results = get_results()
298316
complete_data = read_task_perf("complete")
299317
instruct_data = read_task_perf("instruct")
@@ -302,14 +320,17 @@ def push_ds(ds, path, local=False):
302320
push_ds(complete_solve_rate, "bigcode/bigcodebench-complete-solve-rate")
303321
push_ds(instruct_solve_rate, "bigcode/bigcodebench-instruct-solve-rate")
304322

305-
battles = get_winner_df(complete_data, "complete")
323+
task_level = True
324+
no_tie = True
325+
battles = get_winner_df(complete_data, "complete", task_level=task_level, no_tie=no_tie)
306326
elo_mle_bootstrap = get_bootstrap_result(battles, get_elo_mle, 500)
307327
bootstrap_lu_median = elo_mle_bootstrap.median().reset_index().set_axis(["model", "Elo rating"], axis=1)
308328
bootstrap_lu_median["Elo rating"] = (bootstrap_lu_median["Elo rating"] + 0.5).astype(int)
309329
bootstrap_lu_median_dict = bootstrap_lu_median.set_index("model")["Elo rating"].to_dict()
310330
elo = get_bootstrap_scores(elo_mle_bootstrap)
311331
push_ds(elo, "bigcode/bigcodebench-elo")
312-
332+
# push_ds(elo, "bigcode/bigcodebench-elo-model-with-tie")
333+
313334
results = update_elo_rating(results, bootstrap_lu_median_dict)
314335
with open("results.json", "w") as f:
315336
json.dump(results, f, indent=4)

0 commit comments

Comments
 (0)