
Commit 9009723

add to inspect (#1065)
* adding lcb to inspect-ai
* adding lcb
* adds aime
* add ifeval french
* move prompts to respective eval file
* rename mmlu prompt functoin
* move prompt function to task files
* move prompt function to task files
* prompt function names end with _prompt
* fix missing
* prompt function names end with _prompt
* remove LETTER_INDEX
* fix tests
* revert lcb
* add to inspect
* add mixeval and musr
* Apply suggestion from @NathanHB
* Apply style fixes
* adding tasks to inspect
* adding tasks to inspect
* adds tasks that are defined in inspect

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 943c4c3 commit 9009723

25 files changed (+727, -54 lines)
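The pattern this commit applies across the task files: a task becomes runnable on the inspect_ai backend once its LightevalTaskConfig also carries sample_fields (a record-to-Sample mapper), a solver, and a scorer, next to the prompt_function used by the other backends. Below is a minimal sketch of that shape for a hypothetical dataset whose records have "question" and "answer" fields; the task name, repo, and field names are illustrative and not part of this commit.

from inspect_ai.dataset import Sample
from inspect_ai.solver import generate

from lighteval.metrics.metrics import Metrics, math_scorer
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc


def my_task_prompt(line, task_name: str = None):
    # prompt function used by the non-inspect backends
    return Doc(task_name=task_name, query=line["question"], choices=[str(line["answer"])], gold_index=0)


def record_to_sample(record):
    # maps one dataset record to an inspect_ai Sample
    return Sample(input=record["question"], target=str(record["answer"]))


my_task = LightevalTaskConfig(
    name="my_task",                    # hypothetical task name
    prompt_function=my_task_prompt,
    sample_fields=record_to_sample,    # enables the inspect_ai backend
    solver=[generate(cache=True)],
    scorer=math_scorer(),
    hf_repo="org/dataset",             # hypothetical repo
    hf_subset="",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    metrics=[Metrics.exact_match],
    stop_sequence=["\n"],
    version=0,
)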

src/lighteval/main_inspect.py

Lines changed: 36 additions & 7 deletions
@@ -53,6 +53,11 @@ def get_inspect_ai_task(
     name = lighteval_task_config.name
     sample_fields = lighteval_task_config.sample_fields
 
+    if sample_fields is None:
+        raise ValueError(
+            f"Task {name} is not supported by inspect_ai yet. You can either define it or use a different backend, `lighteval --help`"
+        )
+
     dataset_repo = lighteval_task_config.hf_repo
     dataset_subset = lighteval_task_config.hf_subset
     dataset_split = lighteval_task_config.evaluation_splits[0]
@@ -528,12 +533,36 @@ def bundle(log_dir: str, output_dir: str, overwrite: bool = True, repo_id: str |
 
 
 if __name__ == "__main__":
-    task = "lighteval|gsm8k|5,lighteval|gsm8k|1,lighteval|gsm8k|0"
-    task = "lighteval|agieval|0"
-    task = "lighteval|hle|0"
-    task = "lighteval|ifeval|0"
-    task = "lighteval|gpqa|0"
-    task = "lighteval|ifbench_test|0"
-    task = "lighteval|mmlu_pro|0"
+    tasks = [
+        "gsm8k",
+        "agieval",
+        "hle",
+        "ifeval",
+        "gpqa",
+        "ifbench_test",
+        "mmlu_pro",
+        "mixeval",
+        "aimo",
+        "anli",
+        "arc",
+        "arithmetic",
+        "asdiv",
+        "babi_qa",
+        "bbq",
+        "bigbench",
+        "bigbench_hard",
+        "blimp",
+        "bold",
+        "boolq",
+        "civil_comments",
+        "commonsenseqa",
+        "covid_dialog",
+        "dyck_language",
+        "math_500",
+        "musr",
+        "olympiad_bench",
+        "simpleqa",
+        "tiny_benchmarks",
+    ]
     model = "hf-inference-providers/meta-llama/Llama-3.1-8B-Instruct:nebius"
     eval(models=[model], tasks=task)
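The unchanged final line still passes task to eval(), while the hunk only introduces the new tasks list; the glue that turns list entries into the pipe-delimited "suite|task|fewshot" spec used by the removed assignments is not shown here. One purely illustrative reconstruction, an assumption rather than part of the commit:

# assumed glue, not shown in the hunk: build "lighteval|<name>|0" specs from the list
task = ",".join(f"lighteval|{name}|0" for name in tasks)
eval(models=[model], tasks=task)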

src/lighteval/tasks/tasks/aimo.py

Lines changed: 11 additions & 1 deletion
@@ -17,7 +17,10 @@
 paper:
 """
 
-from lighteval.metrics.metrics import Metrics
+from inspect_ai.dataset import Sample
+from inspect_ai.solver import generate
+
+from lighteval.metrics.metrics import Metrics, math_scorer
 from lighteval.metrics.normalizations import math_normalizer
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
@@ -32,9 +35,16 @@ def aimo_prompt(line, task_name: str = None):
     )
 
 
+def record_to_sample(record):
+    return Sample(input=record["problem"], target=str(record["answer"]))
+
+
 task = LightevalTaskConfig(
     name="aimo_progress_prize_1",
     prompt_function=aimo_prompt,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=math_scorer(),
     hf_subset="",
     hf_repo="lighteval/aimo_progress_prize_1",
     hf_avail_splits=["train"],

src/lighteval/tasks/tasks/anli.py

Lines changed: 21 additions & 0 deletions
@@ -22,6 +22,12 @@
 https://arxiv.org/abs/1910.14599
 """
 
+from string import ascii_uppercase
+
+from inspect_ai.dataset import Sample
+from inspect_ai.scorer import choice
+from inspect_ai.solver import multiple_choice
+
 from lighteval.metrics.metrics import Metrics
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
@@ -36,6 +42,12 @@ def anli_prompt(line, task_name: str = None):
     )
 
 
+def record_to_sample(record):
+    choices = ["True", "Neither", "False"]
+    query = f"{record['premise']}\nQuestion: {record['hypothesis']}"
+    return Sample(input=query, target=ascii_uppercase[record["label"]], choices=choices)
+
+
 anli_r1 = LightevalTaskConfig(
     name="anli:r1",
     prompt_function=anli_prompt,
@@ -49,6 +61,9 @@ def anli_prompt(line, task_name: str = None):
     metrics=[Metrics.loglikelihood_acc],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )
 
 
@@ -65,6 +80,9 @@ def anli_prompt(line, task_name: str = None):
     metrics=[Metrics.loglikelihood_acc],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )
 
 
@@ -81,6 +99,9 @@ def anli_prompt(line, task_name: str = None):
     metrics=[Metrics.loglikelihood_acc],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )
 
 TASKS_TABLE = [
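The ANLI mapper builds a multiple-choice Sample whose letter target indexes into ["True", "Neither", "False"], matching ANLI's 0/1/2 labels (entailment/neutral/contradiction). A quick illustrative check, with an invented record:

from string import ascii_uppercase

from inspect_ai.dataset import Sample

record = {"premise": "A dog is running.", "hypothesis": "An animal is moving.", "label": 0}
choices = ["True", "Neither", "False"]
query = f"{record['premise']}\nQuestion: {record['hypothesis']}"
sample = Sample(input=query, target=ascii_uppercase[record["label"]], choices=choices)
# sample.target == "A", i.e. the "True" choice for an entailed hypothesis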

src/lighteval/tasks/tasks/arc.py

Lines changed: 18 additions & 0 deletions
@@ -22,6 +22,10 @@
 https://arxiv.org/abs/1803.05457
 """
 
+from inspect_ai.dataset import Sample
+from inspect_ai.scorer import choice
+from inspect_ai.solver import multiple_choice
+
 from lighteval.metrics.metrics import Metrics
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
@@ -36,6 +40,14 @@ def arc_prompt(line, task_name: str = None):
     )
 
 
+def record_to_sample(record):
+    query = record["question"].strip()
+    target = record["answerKey"]
+    choices = record["choices"]["text"]
+
+    return Sample(input=query, target=target, choices=choices)
+
+
 arc_challenge = LightevalTaskConfig(
     name="arc:challenge",
     prompt_function=arc_prompt,
@@ -51,6 +63,9 @@ def arc_prompt(line, task_name: str = None):
     ],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )
 
 arc_easy = LightevalTaskConfig(
@@ -68,6 +83,9 @@ def arc_prompt(line, task_name: str = None):
     ],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )
 
 TASKS_TABLE = [arc_challenge, arc_easy]

src/lighteval/tasks/tasks/arithmetic.py

Lines changed: 41 additions & 1 deletion
@@ -19,15 +19,25 @@
 https://arxiv.org/abs/2005.14165
 """
 
-from lighteval.metrics.metrics import Metrics
+from inspect_ai.dataset import Sample
+from inspect_ai.solver import generate
+
+from lighteval.metrics.metrics import Metrics, math_scorer
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
 
 
+# TODO: convert dataset to parquet
+
+
 def arithmetic_prompt(line, task_name: str = None):
     return Doc(task_name=task_name, query=line["context"], choices=[line["completion"]], gold_index=[0])
 
 
+def record_to_sample(record):
+    return Sample(input=record["context"], target=record["completion"])
+
+
 arithmetic_1dc = LightevalTaskConfig(
     name="arithmetic:1dc",
     prompt_function=arithmetic_prompt,
@@ -41,6 +51,9 @@ def arithmetic_prompt(line, task_name: str = None):
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=math_scorer(),
 )
 
 arithmetic_2da = LightevalTaskConfig(
@@ -56,6 +69,9 @@ def arithmetic_prompt(line, task_name: str = None):
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=math_scorer(),
 )
 
 arithmetic_2dm = LightevalTaskConfig(
@@ -71,6 +87,9 @@ def arithmetic_prompt(line, task_name: str = None):
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=math_scorer(),
 )
 
 arithmetic_2ds = LightevalTaskConfig(
@@ -86,6 +105,9 @@ def arithmetic_prompt(line, task_name: str = None):
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=math_scorer(),
 )
 
 arithmetic_3da = LightevalTaskConfig(
@@ -101,6 +123,9 @@ def arithmetic_prompt(line, task_name: str = None):
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=math_scorer(),
 )
 
 arithmetic_3ds = LightevalTaskConfig(
@@ -116,6 +141,9 @@ def arithmetic_prompt(line, task_name: str = None):
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=math_scorer(),
 )
 
 arithmetic_4da = LightevalTaskConfig(
@@ -131,6 +159,9 @@ def arithmetic_prompt(line, task_name: str = None):
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=math_scorer(),
 )
 
 arithmetic_4ds = LightevalTaskConfig(
@@ -146,6 +177,9 @@ def arithmetic_prompt(line, task_name: str = None):
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=math_scorer(),
 )
 
 arithmetic_5da = LightevalTaskConfig(
@@ -161,6 +195,9 @@ def arithmetic_prompt(line, task_name: str = None):
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=math_scorer(),
 )
 
 arithmetic_5ds = LightevalTaskConfig(
@@ -176,6 +213,9 @@ def arithmetic_prompt(line, task_name: str = None):
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=math_scorer(),
 )
 
 TASKS_TABLE = [

src/lighteval/tasks/tasks/asdiv.py

Lines changed: 13 additions & 1 deletion
@@ -19,7 +19,10 @@
 https://arxiv.org/abs/2410.12853
 """
 
-from lighteval.metrics.metrics import Metrics
+from inspect_ai.dataset import Sample
+from inspect_ai.solver import generate
+
+from lighteval.metrics.metrics import Metrics, math_scorer
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
 
@@ -33,6 +36,12 @@ def asdiv_prompt(line, task_name: str = None):
     )
 
 
+def record_to_sample(record):
+    query = f"{record['body']}\n{record['question']}"
+    target = record["answer"].split(" (")[0]
+    return Sample(input=query, target=target)
+
+
 asdiv = LightevalTaskConfig(
     name="asdiv",
     prompt_function=asdiv_prompt,
@@ -46,6 +55,9 @@ def asdiv_prompt(line, task_name: str = None):
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=math_scorer(),
 )
 
 TASKS_TABLE = [asdiv]
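ASDiv answer strings typically carry a bracketed unit after the number, so the mapper keeps only the numeric part by splitting on " (". An illustrative record, with invented values:

record = {
    "body": "Tom has 3 apples and buys 2 more.",
    "question": "How many apples does Tom have now?",
    "answer": "5 (apples)",
}
query = f"{record['body']}\n{record['question']}"   # same join as record_to_sample
target = record["answer"].split(" (")[0]            # "5 (apples)" -> "5"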

src/lighteval/tasks/tasks/babi_qa.py

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,9 @@
 from lighteval.tasks.requests import Doc
 
 
+# TODO: clean dataset and convert to inspect-ai
+
+
 def babi_qa_prompt(line, task_name: str = None):
     def process_path(path: str) -> str:
         steps = path.split(",")
