diff --git a/pyproject.toml b/pyproject.toml
index 589e24f5e..63ad62b3a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,7 +70,7 @@ dependencies = [
     "pydantic",
     "numpy>=2", # pinned to avoid incompatibilities
     "hf-xet>=1.1.8", # pinned to avoid failing test suite
-    # Prettiness
+    "scipy>=1.7.0", # for the sparse matrix handling specific to the scicode benchmark
     "typer>=0.20.0",
     "termcolor==2.3.0",
     "pytablewriter",
@@ -90,6 +90,7 @@ dependencies = [
     "httpx>=0.27.2",
     "latex2sympy2_extended==1.0.6",
    "langcodes",
+    "h5py", # for handling HDF5 files, e.g. for the scicode benchmark
 ]
 
 [project.optional-dependencies]
diff --git a/src/lighteval/tasks/tasks/scicode/__init__.py b/src/lighteval/tasks/tasks/scicode/__init__.py
new file mode 100644
index 000000000..a6faba3f7
--- /dev/null
+++ b/src/lighteval/tasks/tasks/scicode/__init__.py
@@ -0,0 +1,10 @@
+"""SciCode benchmark implementation for Lighteval.
+
+Based on the original SciCode implementation:
+https://github.com/scicode-bench/SciCode/blob/main/eval/inspect_ai/scicode.py
+"""
+
+from lighteval.tasks.tasks.scicode.main import TASKS_TABLE, scicode
+
+
+__all__ = ["scicode", "TASKS_TABLE"]
diff --git a/src/lighteval/tasks/tasks/scicode/main.py b/src/lighteval/tasks/tasks/scicode/main.py
new file mode 100644
index 000000000..e47f266b7
--- /dev/null
+++ b/src/lighteval/tasks/tasks/scicode/main.py
@@ -0,0 +1,106 @@
+"""
+name:
+SciCode
+
+dataset:
+SciCode1/SciCode
+
+abstract:
+SciCode is a challenging benchmark designed to evaluate the capabilities of language models (LMs)
+in generating code that solves realistic scientific research problems. It covers 16 subdomains
+across five domains: Physics, Math, Material Science, Biology, and Chemistry. Unlike previous
+benchmarks that consist of exam-like question-answer pairs, SciCode is converted from real
+research problems. SciCode problems naturally factorize into multiple subproblems, each involving
+knowledge recall, reasoning, and code synthesis. In total, SciCode contains 338 subproblems
+decomposed from 80 challenging main problems.
+
+languages:
+english
+
+tags:
+code-generation, scientific-computing
+
+paper:
+https://arxiv.org/abs/2407.13168
+
+starred:
+true
+"""
+
+from typing import Any
+
+from inspect_ai.dataset import Sample
+
+from lighteval.tasks.lighteval_task import LightevalTaskConfig
+from lighteval.tasks.requests import Doc
+from lighteval.tasks.tasks.scicode.prompts import prepare_scicode_prompt
+from lighteval.tasks.tasks.scicode.scorer import scicode_scorer
+from lighteval.tasks.tasks.scicode.solver import scicode_solver
+from lighteval.tasks.tasks.scicode.utils import _extract_first_step_metadata
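+
+# The prompt function and sample converter below assume dataset records shaped roughly
+# as follows. This is a sketch inferred from the fields accessed in this module,
+# solver.py and utils.py, not an official schema; the field values are illustrative.
+#
+#   {
+#       "problem_id": "1",
+#       "required_dependencies": "import numpy as np",
+#       "sub_steps": [
+#           {
+#               "step_number": "1.1",
+#               "step_description_prompt": "...",
+#               "step_background": "...",       # optional
+#               "function_header": "def f(x):",
+#               "return_line": "    return y",  # optional
+#               "test_cases": ["assert f(0) == 0"],
+#           },
+#       ],
+#   }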
+
+
+def scicode_prompt(line: dict[str, Any], task_name: str = "scicode") -> Doc:
+    """Convert a dataset record to a Doc for evaluation.
+
+    For multi-step evaluation, this returns the first step's prompt;
+    the solver handles the subsequent steps.
+    """
+    step_metadata = _extract_first_step_metadata(line)
+    step_data = step_metadata["step_data"]
+    query = prepare_scicode_prompt(step_data, line, with_background=False)
+
+    return Doc(
+        task_name=task_name,
+        query=query,
+        choices=[""],
+        gold_index=0,
+        specific={
+            "test_cases": step_metadata["test_cases"],
+            "function_header": step_metadata["function_header"],
+            "fn_name": step_metadata["fn_name"],
+            "step_number": step_metadata["step_number"],
+            "problem_id": line.get("problem_id"),
+            "required_dependencies": line.get("required_dependencies", ""),
+        },
+    )
+
+
+def record_to_sample(record: dict[str, Any]) -> Sample:
+    """Convert a dataset record to an inspect-ai Sample object.
+
+    Includes ALL sub_steps in the metadata for multi-step processing.
+    """
+    step_metadata = _extract_first_step_metadata(record)
+    step_data = step_metadata["step_data"]
+
+    metadata = dict(record)
+    metadata.update(
+        {
+            "test_cases": step_metadata["test_cases"],
+            "function_header": step_metadata["function_header"],
+            "fn_name": step_metadata["fn_name"],
+            "step_number": step_metadata["step_number"],
+        }
+    )
+
+    prompt = prepare_scicode_prompt(step_data, record, with_background=False)
+
+    return Sample(input=prompt, metadata=metadata)
+
+
+scicode = LightevalTaskConfig(
+    name="scicode",
+    prompt_function=scicode_prompt,
+    sample_fields=record_to_sample,
+    solver=scicode_solver(with_background=False),
+    scorer=scicode_scorer(),
+    hf_repo="SciCode1/SciCode",
+    hf_subset="default",
+    hf_avail_splits=["test", "validation"],
+    evaluation_splits=["test"],
+    generation_size=32768,
+    metrics=[],  # metrics are defined in the scorer decorator for inspect_ai
+    stop_sequence=[],  # no stop sequence, will use the EOS token
+    version=0,
+)
+
+TASKS_TABLE = [scicode]
diff --git a/src/lighteval/tasks/tasks/scicode/parse.py b/src/lighteval/tasks/tasks/scicode/parse.py
new file mode 100644
index 000000000..69a38fc02
--- /dev/null
+++ b/src/lighteval/tasks/tasks/scicode/parse.py
@@ -0,0 +1,133 @@
+"""Parsing utilities for SciCode.
+
+Based on the original implementation:
+https://github.com/scicode-bench/SciCode
+"""
+
+import ast
+import re
+from pathlib import Path
+
+import h5py
+import scipy.sparse
+
+
+def extract_function_name(function_header: str) -> str:
+    """Extract the function or class name from a function header."""
+    pattern = r"\bdef\s+(\w+)\s*\("
+    match = re.search(pattern, function_header)
+    if match:
+        return match.group(1)
+
+    pattern = r"\bclass\s+(\w+)\s*[\(:]"
+    match = re.search(pattern, function_header)
+    if match:
+        return match.group(1)
+
+    raise ValueError(f"Function name or class name not found in: {function_header}")
+
+
+def get_function_from_code(code_string: str | None, function_name: str) -> str | None:
+    """Extract a specific function/class from code using the AST.
+
+    Returns the unparsed node when the name is found, None when it is not,
+    and falls back to the raw code string when the code cannot be parsed.
+    """
+    if code_string is None:
+        return ""
+    try:
+        tree = ast.parse(code_string)
+        for node in ast.walk(tree):
+            if isinstance(node, (ast.FunctionDef, ast.ClassDef)) and node.name == function_name:
+                return ast.unparse(node)
+    except Exception:
+        # Unparseable code: return it unchanged so the caller can still inspect it.
+        return code_string
+    return None
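+
+# Illustrative behaviour of the two helpers above (hypothetical inputs, for orientation only):
+#
+#   extract_function_name("def calc_energy(pos, mass):")  # -> "calc_energy"
+#   extract_function_name("class Integrator:")            # -> "Integrator"
+#   get_function_from_code(source, "calc_energy")         # -> the matching def/class block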
+
+
+def _process_hdf5_sparse_matrix(group: h5py.Group):
+    """Rebuild a scipy sparse matrix from an h5py Group.
+
+    The sparse format is inferred from the keys present: row/col means COO,
+    a blocksize means BSR, and indices/indptr alone means CSR.
+    """
+    data = group["data"][()]
+    shape = tuple(group["shape"][()])
+    if "row" in group and "col" in group:
+        row = group["row"][()]
+        col = group["col"][()]
+        return scipy.sparse.coo_matrix((data, (row, col)), shape=shape)
+    elif "blocksize" in group:
+        indices = group["indices"][()]
+        indptr = group["indptr"][()]
+        blocksize = tuple(group["blocksize"][()])
+        return scipy.sparse.bsr_matrix((data, indices, indptr), shape=shape, blocksize=blocksize)
+    else:
+        indices = group["indices"][()]
+        indptr = group["indptr"][()]
+        return scipy.sparse.csr_matrix((data, indices, indptr), shape=shape)
+
+
+def _process_hdf5_list(group: h5py.Group) -> list:
+    """Process an h5py Group containing list data."""
+    return [group[key][()] for key in group.keys()]
+
+
+def _process_hdf5_dict(group: h5py.Group) -> dict:
+    """Process an h5py Group into a dictionary."""
+    result_dict = {}
+    for key, obj in group.items():
+        if isinstance(obj, h5py.Group):
+            if "sparse_matrix" in obj:
+                result_dict[key] = _process_hdf5_sparse_matrix(obj["sparse_matrix"])
+            else:
+                result_dict[key] = _process_hdf5_datagroup(obj)
+        elif isinstance(obj, h5py.Dataset):
+            value = obj[()]
+            if isinstance(value, bytes):
+                result_dict[key] = value.decode("utf-8", errors="strict")
+            else:
+                # Numeric keys were serialised as strings; restore them as floats when possible.
+                try:
+                    result_dict[float(key)] = value
+                except ValueError:
+                    result_dict[key] = value
+    return result_dict
+
+
+def _process_hdf5_datagroup(group: h5py.Group):
+    """Process an h5py Group, dispatching on the special cases (list, sparse_matrix) or dict."""
+    if "list" in group:
+        return _process_hdf5_list(group["list"])
+    elif "sparse_matrix" in group:
+        return _process_hdf5_sparse_matrix(group["sparse_matrix"])
+    else:
+        return _process_hdf5_dict(group)
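+
+# The helpers above assume an HDF5 layout along these lines (inferred from the access
+# patterns in this module, not an official schema):
+#
+#   /<step_id>/test<i>/var1                    plain dataset (scalar or array target)
+#   /<step_id>/test<i>/var1/list/<k>           list elements, one dataset per entry
+#   /<step_id>/test<i>/var1/sparse_matrix/     "data" and "shape", plus "row"/"col" (COO),
+#                                              "indices"/"indptr"/"blocksize" (BSR), or
+#                                              "indices"/"indptr" (CSR)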
+
+
+def extract_targets(step_id: str | tuple, num_tests: int, h5py_file: str | Path) -> tuple:
+    """Extract the target values for a given step from the HDF5 file."""
+    # Normalise the step id to the "<problem>.<step>" string used as the group key.
+    if isinstance(step_id, tuple):
+        step_id = ".".join(str(x) for x in step_id)
+    elif not isinstance(step_id, str):
+        step_id = str(step_id)
+
+    targets = []
+    with h5py.File(h5py_file, "r") as f:
+        if step_id not in f:
+            raise ValueError(f"Step {step_id} not found in h5py file")
+        for i in range(1, num_tests + 1):
+            group_path = f"{step_id}/test{i}"
+            if group_path not in f:
+                continue
+
+            group = f[group_path]
+            if "var1" in group:
+                var1 = group["var1"]
+                if isinstance(var1, h5py.Dataset):
+                    targets.append(var1[()])
+                elif isinstance(var1, h5py.Group):
+                    targets.append(_process_hdf5_datagroup(var1))
+
+    return tuple(targets)
diff --git a/src/lighteval/tasks/tasks/scicode/prompts.py b/src/lighteval/tasks/tasks/scicode/prompts.py
new file mode 100644
index 000000000..a045ab9cc
--- /dev/null
+++ b/src/lighteval/tasks/tasks/scicode/prompts.py
@@ -0,0 +1,141 @@
+"""Prompt templates and prompt generation for SciCode."""
+
+from typing import Any
+
+
+SCICODE_PROMPT_TEMPLATE = """PROBLEM DESCRIPTION:
+You will be provided with the main description of the problem, previous steps, and the next step. Your task will be to generate the disciplinary knowledge necessary for solving the next step and then develop a Python solution focused on this step.
+
+PREVIOUS STEPS DESCRIPTION:
+{previous_steps_str}
+
+NEXT STEP - PROBLEM DESCRIPTION AND FUNCTION HEADER:
+This part will describe the next step in the problem-solving process. First, provide the necessary scientific background knowledge as a comment at the beginning of your response, starting with 'Background: '. Then, a function header will be provided, and your task is to develop the Python code for this next step based on the provided description and function header.
+
+{next_step_str}
+
+DEPENDENCIES:
+Use only the following dependencies in your solution. Do not include these dependencies at the beginning of your code.
+{dependencies}
+
+RESPONSE GUIDELINES:
+1. Start with the scientific background required for the next step, formatted as a comment.
+2. Then write the complete and executable Python program for the next step in a single block.
+3. Your response should focus exclusively on implementing the solution for the next step, adhering closely to the specified function header and the context provided by the initial steps.
+4. DO NOT include previous function code, example usage or test code in your response.
+5. Ensure your response is in the format of ```python``` and includes the necessary background as a comment at the top.
+
+Example:
+```python
+# Background: [Here, insert the necessary scientific knowledge required for the next step.]
+
+[Insert the Python code here based on the provided function header and dependencies.]
+```""".strip()
+
+
+def prepare_scicode_prompt(
+    step_data: dict[str, Any], problem_data: dict[str, Any], with_background: bool = False
+) -> str:
+    """Prepare the prompt for the first SciCode sub-step (no previous steps).
+
+    This function is used for the initial prompt generation before the solver runs.
+    For subsequent steps with previous context, use generate_prompt_with_steps() instead.
+    """
+    next_step_parts = [step_data["step_description_prompt"]]
+
+    if with_background and step_data.get("step_background"):
+        next_step_parts.append(step_data["step_background"])
+
+    next_step_parts.append(step_data["function_header"])
+
+    if step_data.get("return_line"):
+        next_step_parts.append(step_data["return_line"])
+
+    next_step_str = "\n\n".join(next_step_parts)
+    dependencies = problem_data.get("required_dependencies", "")
+    previous_steps_str = ""
+
+    return SCICODE_PROMPT_TEMPLATE.format(
+        previous_steps_str=previous_steps_str,
+        next_step_str=next_step_str,
+        dependencies=dependencies,
+    )
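+
+# Minimal usage sketch (hypothetical field values, for illustration only):
+#
+#   step_data = {
+#       "step_description_prompt": "Compute the pairwise distances between particles.",
+#       "function_header": "def pairwise_distances(points):",
+#       "return_line": "    return distances",
+#   }
+#   problem_data = {"required_dependencies": "import numpy as np"}
+#   prompt = prepare_scicode_prompt(step_data, problem_data)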
+
+
+def process_problem_code(prob_data: dict[str, Any], num_steps: int) -> str:
+    """Extract the function header and return line for a given step."""
+    header_docstring = prob_data["sub_steps"][num_steps - 1]["function_header"]
+    return_str = prob_data["sub_steps"][num_steps - 1].get("return_line", "")
+    if return_str:
+        return f"{header_docstring}\n\n{return_str}"
+    return header_docstring
+
+
+def process_problem_steps(
+    problem_data: dict[str, Any],
+    num_steps: int,
+    previous_llm_code: list[str | None],
+    with_background: bool = False,
+) -> tuple[str, str, str]:
+    """Process the problem data and return the previous steps and the next step.
+
+    Returns:
+        tuple: (previous_steps_str, next_step_str, previous_code_str)
+    """
+    output_lines = []
+    next_step = []
+    previous_code = []
+
+    for i in range(num_steps - 1):
+        step_desc = problem_data["sub_steps"][i]["step_description_prompt"]
+        if with_background and problem_data["sub_steps"][i].get("step_background"):
+            step_desc += "\n" + problem_data["sub_steps"][i]["step_background"]
+        output_lines.append(step_desc)
+
+        if previous_llm_code[i] is not None:
+            output_lines.append(previous_llm_code[i])
+            previous_code.append(previous_llm_code[i])
+        output_lines.append("------")  # step separator; the trailing one is dropped below
+
+    # Next step
+    step_desc = problem_data["sub_steps"][num_steps - 1]["step_description_prompt"]
+    if with_background and problem_data["sub_steps"][num_steps - 1].get("step_background"):
+        step_desc += "\n" + problem_data["sub_steps"][num_steps - 1]["step_background"]
+    next_step.append(step_desc)
+    next_step.append(process_problem_code(problem_data, num_steps))
+
+    output_str = "\n\n".join(output_lines[:-1])  # remove the trailing "------"
+    next_step_str = "\n\n".join(next_step)
+    previous_code_str = "\n".join(previous_code)
+
+    return output_str, next_step_str, previous_code_str
+
+
+def generate_prompt_with_steps(
+    prob_data: dict[str, Any],
+    num_steps: int,
+    previous_llm_code: list[str | None],
+    prompt_template: str = SCICODE_PROMPT_TEMPLATE,
+    with_background: bool = False,
+) -> tuple[str, str]:
+    """Generate the prompt for step num_steps, with the previous steps as context.
+
+    Returns:
+        tuple: (prompt, previous_code_str)
+    """
+    problem_steps_str, next_step_str, previous_code_str = process_problem_steps(
+        prob_data, num_steps, previous_llm_code, with_background
+    )
+    dependencies = prob_data.get("required_dependencies", "")
+
+    prompt = prompt_template.format(
+        previous_steps_str=problem_steps_str,
+        next_step_str=next_step_str,
+        dependencies=dependencies,
+    )
+
+    # The dependency block is prepended so the accumulated code runs without per-step imports.
+    previous_code_with_deps = f"{dependencies}\n{previous_code_str}\n" if previous_code_str else f"{dependencies}\n"
+
+    return prompt, previous_code_with_deps
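+
+# Sketch of the intended multi-step flow (a simplified view of what solver.py does):
+#
+#   previous_llm_code: list[str | None] = [None] * len(prob_data["sub_steps"])
+#   for idx in range(len(prob_data["sub_steps"])):
+#       prompt, prev_code = generate_prompt_with_steps(prob_data, idx + 1, previous_llm_code)
+#       # ...generate a completion for `prompt` and extract its code block...
+#       previous_llm_code[idx] = extracted_code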
diff --git a/src/lighteval/tasks/tasks/scicode/scorer.py b/src/lighteval/tasks/tasks/scicode/scorer.py
new file mode 100644
index 000000000..464fefed9
--- /dev/null
+++ b/src/lighteval/tasks/tasks/scicode/scorer.py
@@ -0,0 +1,247 @@
+"""Scorer and metrics for SciCode evaluation.
+
+Based on the original implementation:
+https://github.com/scicode-bench/SciCode
+"""
+
+import platform
+import resource
+import shutil
+import subprocess
+import tempfile
+import uuid
+from pathlib import Path
+
+from inspect_ai.scorer import Metric, Score, Target, mean, metric, scorer
+from inspect_ai.solver import TaskState
+
+from lighteval.tasks.tasks.lcb.codegen_metrics import extract_code
+from lighteval.tasks.tasks.scicode.parse import extract_targets
+from lighteval.tasks.tasks.scicode.solver import should_skip_step
+from lighteval.tasks.tasks.scicode.utils import get_h5py_file_path
+
+
+@metric
+def sub_problem_correctness() -> Metric:
+    """Metric computing the sub-problem (step-level) correctness rate."""
+
+    def metric_fn(scores: list[Score]) -> int | float:
+        total_correct = 0
+        total_steps = 0
+        for score in scores:
+            total_correct += score.value["Total Correct"]
+            total_steps += score.value["Total Steps"]
+        return total_correct / total_steps if total_steps > 0 else 0.0
+
+    return metric_fn
+
+
+def run_script(script_path: Path) -> int:
+    """Run a test script and return a status code.
+
+    0 = pass, 1 = fail, 2 = timeout
+
+    Note: resource limits are applied to restrict memory usage (4 GB max).
+    """
+    maximum_memory_bytes = 4 * 1024 * 1024 * 1024  # 4 GB
+
+    def set_resource_limits():
+        """Set resource limits in the child process before execution."""
+        resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
+        resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
+        if platform.system() != "Darwin":
+            resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
+
+    preexec_fn = set_resource_limits if platform.system() != "Windows" else None
+
+    process = None
+    try:
+        process = subprocess.Popen(
+            ["python", str(script_path)],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+            preexec_fn=preexec_fn,
+        )
+        process.communicate(timeout=1800)  # 30 minute timeout, matching the original implementation
+        return 0 if process.returncode == 0 else 1
+    except subprocess.TimeoutExpired:
+        if process is not None:
+            process.kill()
+            process.wait()
+        return 2
+    except Exception:
+        return 1
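+
+# Illustrative only: run_script(Path("1.1.py")) returns 0 when the step's test script
+# exits cleanly, 1 on any failure (non-zero exit or setup error), and 2 on timeout.
+# The scorer below counts a step as correct only for a return value of 0.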
+
+
+def _get_initial_score(sub_steps: list) -> Score | None:
+    """Check for early-return conditions.
+
+    Returns a Score when scoring should end early, or None when processing should continue.
+    """
+    if not sub_steps:
+        return Score(
+            value={
+                "Problem Correctness": 0,
+                "Total Correct": 0,
+                "Total Steps": 0,
+            },
+            explanation="No sub-steps found in metadata",
+        )
+    return None
+
+
+def _get_h5py_file_or_error_score(sub_steps: list) -> tuple[Path | None, Score | None]:
+    """Get the HDF5 file path, or an error Score if the lookup fails.
+
+    Returns a tuple (h5py_file, error_score) where error_score is None on success.
+    """
+    try:
+        h5py_file = get_h5py_file_path()
+        return h5py_file, None
+    except Exception as e:
+        error_score = Score(
+            value={
+                "Problem Correctness": 0,
+                "Total Correct": 0,
+                "Total Steps": len(sub_steps),
+            },
+            explanation=f"Failed to get h5py file: {e}",
+        )
+        return None, error_score
+
+
+def _get_code_content(step_id: str, generated_code_by_step: dict, state: TaskState) -> str | None:
+    """Get the code content for a step, falling back to the state output."""
+    code_content = generated_code_by_step.get(step_id, "")
+    if not code_content:
+        response = state.output.completion if hasattr(state, "output") else ""
+        code_content = extract_code(response) if response else ""
+    return code_content if code_content else None
+
+
+def _write_test_script(test_script: Path, code_content: str, targets: tuple, test_cases: list[str]) -> None:
+    """Write a test script file with imports, candidate code, targets, and test cases."""
+    with open(test_script, "w", encoding="utf-8") as f:
+        f.write("import numpy as np\n")
+        f.write("from numpy import array\n\n")
+        f.write(code_content)
+        f.write("\n\n")
+        f.write("targets = (\n")
+        for target in targets:
+            if hasattr(target, "tolist"):
+                # numpy values: serialise via tolist() so the script can rebuild them
+                f.write(f"    np.array({target.tolist()}),\n")
+            else:
+                # everything else (scalars, strings, lists, tuples) round-trips through repr()
+                f.write(f"    {repr(target)},\n")
+        f.write(")\n\n")
+        for i in range(len(test_cases)):
+            f.write(f"target = targets[{i}]\n\n")
+            f.write(test_cases[i])
+            f.write("\n\n")
+
+
+def _execute_and_aggregate_tests(sub_steps: list, problem_id: str, tmp_dir: Path) -> tuple[int, int]:
+    """Execute the per-step test scripts and aggregate the results.
+
+    Returns a tuple (total_correct, total_steps). Steps whose script was never
+    written (e.g. because no code was generated) still count towards total_steps.
+    """
+    total_correct = 0
+    total_steps = 0
+
+    for idx in range(len(sub_steps)):
+        if should_skip_step(problem_id, idx):
+            continue
+
+        step_data = sub_steps[idx]
+        step_id = step_data.get("step_number")
+
+        if not step_id:
+            continue
+
+        total_steps += 1
+        script_path = tmp_dir / f"{step_id}.py"
+
+        if not script_path.exists():
+            continue
+
+        ret = run_script(script_path)
+        if ret == 0:
+            total_correct += 1
+
+    return total_correct, total_steps
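+
+# A generated per-step test script looks roughly like this (a sketch, with
+# placeholders in angle brackets):
+#
+#   import numpy as np
+#   from numpy import array
+#
+#   <candidate code for the step>
+#
+#   targets = (
+#       np.array([...]),
+#   )
+#
+#   target = targets[0]
+#   <test case 1, asserting against `target`>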
+
+
+@scorer(
+    metrics=[
+        {"Problem Correctness": [mean()]},
+        sub_problem_correctness(),
+    ]
+)
+def scicode_scorer():
+    """Scorer for SciCode evaluation using inspect-ai.
+
+    Implements the full multi-step test execution with HDF5-backed targets.
+    """
+
+    async def score(state: TaskState, target: Target) -> Score:
+        metadata = state.metadata
+        sub_steps = metadata.get("sub_steps", [])
+        problem_id = metadata.get("problem_id")
+        generated_code_by_step = metadata.get("generated_code_by_step", {})
+
+        initial_score = _get_initial_score(sub_steps)
+        if initial_score is not None:
+            return initial_score
+
+        h5py_file, error_score = _get_h5py_file_or_error_score(sub_steps)
+        if error_score is not None:
+            return error_score
+
+        tmp_dir: Path | None = None
+        try:
+            tmp_dir = Path(tempfile.mkdtemp(prefix=f"scicode_test_{uuid.uuid4().hex}_"))
+            for idx in range(len(sub_steps)):
+                if should_skip_step(problem_id, idx):
+                    continue
+
+                step_data = sub_steps[idx]
+                step_id = step_data.get("step_number")
+                test_cases = step_data.get("test_cases", [])
+
+                if not step_id or not test_cases:
+                    continue
+
+                code_content = _get_code_content(step_id, generated_code_by_step, state)
+                if not code_content:
+                    continue
+
+                try:
+                    targets = extract_targets(step_id, len(test_cases), h5py_file)
+                except Exception:
+                    continue
+
+                test_script = tmp_dir / f"{step_id}.py"
+                _write_test_script(test_script, code_content, targets, test_cases)
+
+            total_correct, total_steps = _execute_and_aggregate_tests(sub_steps, problem_id, tmp_dir)
+
+            problem_correct = 1 if total_correct == total_steps and total_steps > 0 else 0
+
+            return Score(
+                value={
+                    "Problem Correctness": problem_correct,
+                    "Total Correct": total_correct,
+                    "Total Steps": total_steps,
+                },
+                explanation=f"Tested {total_steps} steps, {total_correct} passed",
+            )
+
+        finally:
+            if tmp_dir is not None and tmp_dir.exists():
+                shutil.rmtree(tmp_dir)
+
+    return score
diff --git a/src/lighteval/tasks/tasks/scicode/solver.py b/src/lighteval/tasks/tasks/scicode/solver.py
new file mode 100644
index 000000000..aa2227a90
--- /dev/null
+++ b/src/lighteval/tasks/tasks/scicode/solver.py
@@ -0,0 +1,80 @@
+"""Multi-step solver for SciCode.
+
+Based on the original implementation:
+https://github.com/scicode-bench/SciCode
+"""
+
+import copy
+from typing import Any
+
+from inspect_ai.solver import Generate, TaskState, solver
+
+from lighteval.tasks.tasks.scicode.prompts import SCICODE_PROMPT_TEMPLATE, generate_prompt_with_steps
+from lighteval.tasks.tasks.scicode.utils import extract_python_script
+
+
+def should_skip_step(problem_id: str, step_idx: int) -> bool:
+    """Check whether a step should be skipped.
+
+    Special cases carried over from the original implementation:
+    - Problem 13, step 6 (idx 5)
+    - Problem 62, step 1 (idx 0)
+    - Problem 76, step 3 (idx 2)
+    """
+    return (
+        (problem_id == "13" and step_idx == 5)
+        or (problem_id == "62" and step_idx == 0)
+        or (problem_id == "76" and step_idx == 2)
+    )
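+
+# e.g. should_skip_step("13", 5) -> True, while should_skip_step("13", 4) -> False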
+
+
+@solver
+def scicode_solver(**params: Any):
+    """Custom solver that processes all sub-steps sequentially."""
+
+    async def solve(state: TaskState, generate_fn: Generate) -> TaskState:
+        sub_steps = state.metadata.get("sub_steps", [])
+        problem_id = state.metadata.get("problem_id")
+        with_background = params.get("with_background", False)
+
+        if not sub_steps:
+            return state
+
+        if "generated_code_by_step" not in state.metadata:
+            state.metadata["generated_code_by_step"] = {}
+
+        tot_steps = len(sub_steps)
+        previous_llm_code: list[str | None] = [None] * tot_steps
+
+        for idx in range(len(sub_steps)):
+            if should_skip_step(problem_id, idx):
+                continue
+
+            num_steps = idx + 1
+
+            prompt, previous_code_str = generate_prompt_with_steps(
+                prob_data=state.metadata,
+                num_steps=num_steps,
+                previous_llm_code=previous_llm_code,
+                prompt_template=SCICODE_PROMPT_TEMPLATE,
+                with_background=with_background,
+            )
+
+            try:
+                # Generate on a deep copy so each step sees only its own prompt and the
+                # shared state is not polluted with intermediate conversations.
+                state.user_prompt.text = prompt
+                state_copy = copy.deepcopy(state)
+                result = await generate_fn(state=state_copy)
+                response_from_llm = result.output.completion
+            except Exception:
+                # Bail out and let the scorer evaluate whatever has been generated so far.
+                return state
+
+            extracted_code = extract_python_script(response_from_llm)
+            step_id = sub_steps[idx].get("step_number")
+
+            if step_id:
+                state.metadata["generated_code_by_step"][step_id] = extracted_code
+                previous_llm_code[idx] = extracted_code
+
+        return state
+
+    return solve
diff --git a/src/lighteval/tasks/tasks/scicode/utils.py b/src/lighteval/tasks/tasks/scicode/utils.py
new file mode 100644
index 000000000..7970a15f2
--- /dev/null
+++ b/src/lighteval/tasks/tasks/scicode/utils.py
@@ -0,0 +1,75 @@
+"""Utility functions for SciCode.
+
+Based on the original implementation:
+https://github.com/scicode-bench/SciCode
+"""
+
+import re
+from pathlib import Path
+from typing import Any
+
+
+def extract_python_script(response: str) -> str:
+    """Extract Python code from a markdown code block.
+
+    Import lines are stripped from the extracted code so that the required
+    dependencies can be supplied once, ahead of the accumulated step code
+    (see generate_prompt_with_steps in prompts.py).
+    """
+    if "```" in response:
+        if "```python" in response:
+            python_script = response.split("```python")[1].split("```")[0]
+        else:
+            python_script = response.split("```")[1].split("```")[0]
+    else:
+        python_script = response
+
+    python_script = re.sub(r"^\s*(import .*|from .*\s+import\s+.*)", "", python_script, flags=re.MULTILINE)
+    return python_script
+
+
+def _extract_first_step_metadata(record: dict[str, Any]) -> dict[str, Any]:
+    """Extract and validate the metadata of the first step of a record."""
+    if not record.get("sub_steps"):
+        raise ValueError("No sub-steps found in problem data")
+
+    step_data = record["sub_steps"][0]
+    function_header = step_data.get("function_header", "")
+
+    # Import here to avoid a circular dependency
+    from lighteval.tasks.tasks.scicode.parse import extract_function_name
+
+    fn_name = extract_function_name(function_header) if function_header else None
+
+    return {
+        "step_data": step_data,
+        "test_cases": step_data.get("test_cases", []),
+        "function_header": function_header,
+        "fn_name": fn_name,
+        "step_number": step_data.get("step_number"),
+    }
+
+
+def get_h5py_file_path() -> Path:
+    """Get the path to test_data.h5, downloading it from the Hugging Face Hub if necessary.
+
+    Note: the file is currently hosted at akshathmangudi/scicode-files. Once official
+    hosting is available, this will be updated to point at the official repository.
+    """
+    from huggingface_hub import hf_hub_download
+
+    repo_id = "akshathmangudi/scicode-files"
+    try:
+        h5py_file = hf_hub_download(
+            repo_id=repo_id,
+            filename="test_data.h5",
+            repo_type="dataset",
+        )
+        return Path(h5py_file)
+    except Exception as e:
+        # Fallback: check for a local copy next to this module
+        local_path = Path(__file__).parent / "test_data.h5"
+        if local_path.exists():
+            return local_path
+        raise FileNotFoundError(
+            f"Could not download test_data.h5 from {repo_id}. "
+            f"Please ensure it is available or place it at {local_path}. "
+            f"Error: {e}"
+        ) from e