From cef0b0f2ffea6ddb5d8aee89f2b552e986650b19 Mon Sep 17 00:00:00 2001 From: Akshath Mangudi Date: Fri, 21 Nov 2025 16:26:40 +0530 Subject: [PATCH 1/4] ready for review --- .../tasks/long_horizon_execution/__init__.py | 0 .../tasks/long_horizon_execution/constants.py | 38 +++ .../tasks/long_horizon_execution/main.py | 38 +++ .../long_horizon_execution/multi_turn.py | 221 ++++++++++++++++++ .../long_horizon_execution/single_turn.py | 128 ++++++++++ .../tasks/long_horizon_execution/utils.py | 206 ++++++++++++++++ 6 files changed, 631 insertions(+) create mode 100644 src/lighteval/tasks/tasks/long_horizon_execution/__init__.py create mode 100644 src/lighteval/tasks/tasks/long_horizon_execution/constants.py create mode 100644 src/lighteval/tasks/tasks/long_horizon_execution/main.py create mode 100644 src/lighteval/tasks/tasks/long_horizon_execution/multi_turn.py create mode 100644 src/lighteval/tasks/tasks/long_horizon_execution/single_turn.py create mode 100644 src/lighteval/tasks/tasks/long_horizon_execution/utils.py diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/__init__.py b/src/lighteval/tasks/tasks/long_horizon_execution/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/constants.py b/src/lighteval/tasks/tasks/long_horizon_execution/constants.py new file mode 100644 index 000000000..dea40af8a --- /dev/null +++ b/src/lighteval/tasks/tasks/long_horizon_execution/constants.py @@ -0,0 +1,38 @@ +""" +Constants file reused within the Long Horizon Execution task. +""" + +PROMPT_TEMPLATE_SINGLE = """You are an AI assistant. I will provide you with a dictionary and then give you a list of keys. +Your task is to calculate the final cumulative sum after processing all keys in order. +For each key in the list, you need to: +1. Look up the value in the dictionary +2. Add it to the running sum +3. After processing all keys, output the final cumulative sum +Dictionary to use: +{dict_str} +Keys to process in order: +{keys_str} +Your task: Process all keys in order and calculate the final cumulative sum after processing all {num_keys} keys. +IMPORTANT: +- Output your answer as a single integer value inside tags +- Do not include any other text outside the answer tags +- Format: final_sum +- Example: If the final cumulative sum is 42, output: 42 +Your answer:""" + +PROMPT_TEMPLATE_MULTI_START = """You are an AI assistant. I will provide you with a dictionary and then give you keys in groups of {k}. +Your task is to keep a running total (starting from 0) by adding the values associated with the keys I provide. +In each turn, I'll provide {k} keys (comma-separated). +Respond with the current running sum, enclosed in tags. +Dictionary to maintain: +{dict_str} +Ready to start! +**User**: {keys_str} +**Assistant**:""" + +PROMPT_TEMPLATE_MULTI_FOLLOWUP = """Here are the next keys to process: +**User**: {keys_str} +**Assistant**:""" + +CONTEXT_SIZES = [1024, 2048, 4096, 8192, 16384, 32768, 65536] +TURN_COMPLEXITIES = [1, 2, 10] diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/main.py b/src/lighteval/tasks/tasks/long_horizon_execution/main.py new file mode 100644 index 000000000..3e8de0e38 --- /dev/null +++ b/src/lighteval/tasks/tasks/long_horizon_execution/main.py @@ -0,0 +1,38 @@ +""" +name: +Long Horizon Execution +dataset: +arvindh75/Long-Horizon-Execution +abstract: +Evaluation benchmark for long-context execution capabilities of language models. 
+Tests a model's ability to maintain state and perform cumulative operations over +long sequences of inputs. Supports both single-turn (all inputs at once) and +multi-turn (inputs provided incrementally) evaluation modes. +The task requires models to: +1. Maintain a dictionary mapping keys to values +2. Process a sequence of keys +3. Calculate cumulative sums after each key or group of keys +4. Handle varying context sizes and turn complexities +Single-turn evaluation (Section 3.3): Model outputs only the final cumulative sum +after processing all keys, allowing any aggregation strategy. +Multi-turn evaluation: Model processes keys in batches of K per turn, maintaining +conversation history and outputting cumulative sums incrementally. Evaluates +fractional accuracy (correct turns / total turns). +languages: +english +tags: +long-context, state-tracking, arithmetic, execution +paper: +https://arxiv.org/abs/2509.09677 +starred: +true +""" + +from lighteval.tasks.tasks.long_horizon_execution.multi_turn import create_multi_turn_tasks +from lighteval.tasks.tasks.long_horizon_execution.single_turn import create_single_turn_tasks + + +single_turn_tasks = create_single_turn_tasks() +multi_turn_tasks = create_multi_turn_tasks() + +TASKS_TABLE = single_turn_tasks + multi_turn_tasks diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/multi_turn.py b/src/lighteval/tasks/tasks/long_horizon_execution/multi_turn.py new file mode 100644 index 000000000..ec86a093a --- /dev/null +++ b/src/lighteval/tasks/tasks/long_horizon_execution/multi_turn.py @@ -0,0 +1,221 @@ +""" +Multi-turn implementation of the Long Horizon Execution task. +This implementation matches the multi-turn evaluation approach from the research paper, +where keys are provided in batches of K per turn, and the model maintains conversation +state to output cumulative sums after each turn. +""" + +import functools +import re + +from inspect_ai.dataset import Sample +from inspect_ai.scorer import Score, Target, accuracy, scorer, stderr +from inspect_ai.solver import TaskState, generate + +from lighteval.metrics.metrics import Metrics +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc +from lighteval.tasks.tasks.long_horizon_execution.constants import ( + CONTEXT_SIZES, + PROMPT_TEMPLATE_MULTI_FOLLOWUP, + TURN_COMPLEXITIES, +) +from lighteval.tasks.tasks.long_horizon_execution.utils import _build_multi_turn_prompts + + +def multi_turn_prompt_function(line, prompt_length=32768, k=1, task_name: str = None): + """ + Prompt function for non-inspect-ai backend for multi-turn evaluation. + Converts dataset record to Doc object. + Note: For multi-turn, this returns the first turn's prompt. + Subsequent turns are handled by the solver. + """ + initial_prompt, _, expected_per_turn, _ = _build_multi_turn_prompts(line, prompt_length=prompt_length, k=k) + + return Doc( + task_name=task_name, + query=initial_prompt, + choices=[str(expected_per_turn[-1])], # Final sum as choice + gold_index=0, + instruction=initial_prompt, + ) + + +def multi_turn_record_to_sample(record, prompt_length=32768, k=1): + """ + Converts dataset record to inspect-ai Sample object for multi-turn evaluation. + Stores all turn information in metadata for the solver to use. 
+ """ + initial_prompt, _, expected_per_turn, metadata = _build_multi_turn_prompts( + record, prompt_length=prompt_length, k=k + ) + + return Sample( + input=initial_prompt, + target=str(expected_per_turn[-1]), + metadata=metadata, + ) + + +def _extract_response_content(response): + """Extract content from model response object.""" + if hasattr(response, "content"): + return response.content + if hasattr(response, "completion"): + return response.completion + return str(response) + + +async def _process_single_turn(state, turn_chunk, config): + """Process a single turn: add user message, get model response, add assistant message.""" + keys_str = ", ".join(turn_chunk) + followup_prompt = PROMPT_TEMPLATE_MULTI_FOLLOWUP.format(keys_str=keys_str) + state.messages.append({"role": "user", "content": followup_prompt}) + + response = await state.model.generate(messages=state.messages, config=config) + turn_response = _extract_response_content(response) + + state.messages.append({"role": "assistant", "content": turn_response}) + return turn_response + + +async def multi_turn_solver(state: TaskState): + """ + Custom solver for multi-turn evaluation. + Loops through turns, calling the model for each turn while maintaining conversation history. + This implements offline evaluation: all turns are called, then evaluation happens. + """ + from inspect_ai.model import GenerateConfig, ModelOutput + + turn_chunks = state.metadata.get("turn_chunks", []) + + if not turn_chunks or len(turn_chunks) == 0: + return state + + # Initialize messages + if not hasattr(state, "messages") or state.messages is None: + state.messages = [] + + if not state.messages: + state.messages.append({"role": "user", "content": state.input}) + + all_turn_outputs = [] + + # Process all turns + if hasattr(state, "model") and state.model is not None: + config = GenerateConfig() + + # Process first turn (already in messages as initial prompt) + response = await state.model.generate(messages=state.messages, config=config) + turn_response = _extract_response_content(response) + all_turn_outputs.append(turn_response) + state.messages.append({"role": "assistant", "content": turn_response}) + + # Process remaining turns + for turn_idx in range(1, len(turn_chunks)): + if not hasattr(state, "model") or state.model is None: + break + turn_response = await _process_single_turn(state, turn_chunks[turn_idx], config) + all_turn_outputs.append(turn_response) + + state.metadata["all_turn_outputs"] = all_turn_outputs + + # Set final output + if all_turn_outputs: + if hasattr(state, "output") and state.output is not None: + state.output.completion = all_turn_outputs[-1] + else: + state.output = ModelOutput(completion=all_turn_outputs[-1]) + + return state + + +@scorer(metrics={"turn_accuracy": [accuracy(), stderr()], "fractional_accuracy": [accuracy(), stderr()]}) +def multi_turn_scorer(): + """ + Scorer for multi-turn Long Horizon Execution task. + Compares predicted cumulative sums at each turn with expected. + Returns fractional accuracy (correct turns / total turns). 
+ """ + + async def score(state: TaskState, target: Target): + # metadata stored by solver + all_turn_outputs = state.metadata.get("all_turn_outputs", []) + expected_per_turn = state.metadata.get("expected_per_turn", []) + + if not all_turn_outputs: + return Score(value=0.0, answer="", explanation="No turn outputs found in state.metadata") + + if len(all_turn_outputs) != len(expected_per_turn): + return Score( + value=0.0, + answer="", + explanation=f"Mismatch: {len(all_turn_outputs)} outputs vs {len(expected_per_turn)} expected turns", + ) + + parsed_outputs = [] + answer_pattern = re.compile(r"(.*?)", re.DOTALL) + + for turn_idx, turn_output in enumerate(all_turn_outputs): + match = answer_pattern.search(turn_output) + if match: + try: + parsed_value = int(match.group(1).strip()) + parsed_outputs.append(parsed_value) + except ValueError: + parsed_outputs.append(None) + else: + parsed_outputs.append(None) + + correct_turns = 0 + turn_results = [] + for turn_idx, (pred, exp) in enumerate(zip(parsed_outputs, expected_per_turn)): + is_correct = (pred is not None) and (pred == exp) + if is_correct: + correct_turns += 1 + turn_results.append({"turn": turn_idx + 1, "predicted": pred, "expected": exp, "correct": is_correct}) + + fractional_accuracy = correct_turns / len(expected_per_turn) if expected_per_turn else 0.0 + + return Score( + value={ + "turn_accuracy": fractional_accuracy, + "fractional_accuracy": fractional_accuracy, + "correct_turns": correct_turns, + "total_turns": len(expected_per_turn), + }, + answer=str(parsed_outputs), + explanation=f"Correct {correct_turns}/{len(expected_per_turn)} turns. Details: {turn_results}", + ) + + return score + + +def create_multi_turn_tasks(): + """ + Creates a list of LightevalTaskConfig objects for multi-turn Long Horizon Execution. + Each task corresponds to a different combination of context size and turn complexity (K). + """ + tasks = [] + + for context_size in CONTEXT_SIZES: + for k in TURN_COMPLEXITIES: + task_name = f"long_horizon_execution:multi:{context_size}:k{k}" + prompt_fn = functools.partial(multi_turn_prompt_function, prompt_length=context_size, k=k) + sample_fn = functools.partial(multi_turn_record_to_sample, prompt_length=context_size, k=k) + + task = LightevalTaskConfig( + name=task_name, + prompt_function=prompt_fn, + sample_fields=sample_fn, + solver=[multi_turn_solver, generate(cache=True)], + scorer=multi_turn_scorer(), + hf_repo="arvindh75/Long-Horizon-Execution", + hf_subset="default", + evaluation_splits=("test",), + generation_size=context_size, + metrics=[Metrics.exact_match], + ) + tasks.append(task) + + return tasks diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/single_turn.py b/src/lighteval/tasks/tasks/long_horizon_execution/single_turn.py new file mode 100644 index 000000000..07d089639 --- /dev/null +++ b/src/lighteval/tasks/tasks/long_horizon_execution/single_turn.py @@ -0,0 +1,128 @@ +""" +Single turn implementation of the Long Horizon Execution task. 
+""" + +import functools +import re + +from inspect_ai.dataset import Sample +from inspect_ai.scorer import Score, Target, accuracy, scorer, stderr +from inspect_ai.solver import TaskState, generate + +from lighteval.metrics.metrics import Metrics +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc +from lighteval.tasks.tasks.long_horizon_execution.constants import CONTEXT_SIZES, PROMPT_TEMPLATE_SINGLE +from lighteval.tasks.tasks.long_horizon_execution.utils import _build_prompt_and_target + + +def single_turn_prompt_function(line, prompt_length=32768, task_name: str = None): + """ + Prompt function for single-turn evaluation (non-inspect-ai backend). + Converts dataset record to Doc object. + Returns: + Doc object for evaluation + """ + prompt, target_str, _ = _build_prompt_and_target( + line, prompt_length=prompt_length, prompt_template=PROMPT_TEMPLATE_SINGLE + ) + + return Doc( + task_name=task_name, + query=prompt, + choices=[target_str], # Expected answer as a choice + gold_index=0, + instruction=prompt, + ) + + +def single_turn_record_to_sample(record, prompt_length=32768): + """ + Converts dataset record to inspect-ai Sample object for single-turn evaluation. + Returns: + Sample object for inspect-ai + """ + prompt, target_str, metadata = _build_prompt_and_target( + record, prompt_length=prompt_length, prompt_template=PROMPT_TEMPLATE_SINGLE + ) + + return Sample( + input=prompt, + target=target_str, + metadata=metadata, + ) + + +@scorer(metrics={"accuracy": [accuracy(), stderr()]}) +def single_turn_scorer(): + """ + Scorer for single-turn evaluation. + Compares the model's predicted final sum with the expected final sum (binary score). + Returns: + Scorer function that evaluates single integer responses + """ + + async def score(state: TaskState, target: Target): + response = state.output.completion + + answer_pattern = re.compile(r"(.*?)", re.DOTALL) + match = answer_pattern.search(response) + + if not match: + return Score(value="I", answer="", explanation="No tag found in response.") + + content = match.group(1).strip() + + try: + pred_value = int(content.strip()) + except ValueError: + return Score(value="I", answer=content, explanation=f"Failed to parse integer from: {content}") + + try: + exp_value = int(target.text.strip()) + except (ValueError, AttributeError): + return Score( + value="I", + answer=str(pred_value), + explanation=f"Failed to parse expected target: {target.text}", + ) + + is_correct = pred_value == exp_value + return Score( + value="C" if is_correct else "I", + answer=str(pred_value), + explanation=(f"Expected {exp_value}, Got {pred_value}. Match: {is_correct}"), + ) + + return score + + +def create_single_turn_tasks(): + """ + Create all single-turn task configurations for different context sizes. 
+ Returns: + list[LightevalTaskConfig]: List of task configurations for single-turn evaluation + """ + tasks = [] + + for context_size in CONTEXT_SIZES: + task_name = f"long_horizon_execution:{context_size}" + prompt_fn = functools.partial(single_turn_prompt_function, prompt_length=context_size) + sample_fn = functools.partial(single_turn_record_to_sample, prompt_length=context_size) + + task = LightevalTaskConfig( + name=task_name, + prompt_function=prompt_fn, + sample_fields=sample_fn, + solver=[generate(cache=True)], + scorer=single_turn_scorer(), + hf_repo="arvindh75/Long-Horizon-Execution", + hf_subset="default", + evaluation_splits=("test",), + generation_size=context_size, + metrics=[Metrics.exact_match], + ) + + tasks.append(task) + + return tasks diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/utils.py b/src/lighteval/tasks/tasks/long_horizon_execution/utils.py new file mode 100644 index 000000000..9ea4ca6ab --- /dev/null +++ b/src/lighteval/tasks/tasks/long_horizon_execution/utils.py @@ -0,0 +1,206 @@ +""" +Utility functions for Long Horizon Execution task. +""" + +from lighteval.tasks.tasks.long_horizon_execution.constants import ( + PROMPT_TEMPLATE_MULTI_START, + PROMPT_TEMPLATE_SINGLE, +) + + +def _binary_search_max_items(input_keys, build_prompt_fn, prompt_length, min_items=1): + """ + Generic binary search to find maximum number of items that fit within prompt_length. + Returns: + int: Maximum number of items that fit + """ + # Pre-validate that at least min_items fit within prompt_length + test_prompt = build_prompt_fn(min_items) + if test_prompt is None: + raise ValueError("Cannot build prompt: unable to generate prompt with available items") + + if len(test_prompt) > prompt_length: + item_label = "item" if min_items == 1 else f"{min_items} items" + raise ValueError( + f"Prompt length ({prompt_length} chars) is too small to fit {item_label}. " + f"Minimum required: {len(test_prompt)} chars. " + f"Please increase prompt_length or reduce dataset complexity." + ) + + # Binary search to find maximum n that fits within prompt_length + left, right = min_items, len(input_keys) + max_n = min_items + + while left <= right: + mid = (left + right) // 2 + prompt = build_prompt_fn(mid) + + if prompt is None: + right = mid - 1 + continue + + if len(prompt) <= prompt_length: + max_n = mid + left = mid + 1 + else: + right = mid - 1 + + return max_n + + +def _build_prompt_and_target(record, prompt_length=32768, prompt_template=PROMPT_TEMPLATE_SINGLE): + """ + Helper function to extract common logic for building prompt and target. + Uses binary search to find the maximum number of items that fit within prompt_length. + Processes the record and returns prompt, target, and metadata. + Args: + record: Dictionary with 'input', 'values', and 'output' keys + prompt_length: Maximum character length for the prompt. Defaults to 32768. + prompt_template: Prompt template to use for formatting. Defaults to PROMPT_TEMPLATE_SINGLE. 
+ Returns: + tuple: (prompt: str, target_str: str, metadata: dict) + """ + input_keys = record["input"] + input_values = record["values"] + expected_output = record["output"] + + def build_prompt_for_n(n): + """Build a prompt with the first n items.""" + if n == 0: + return None + keys_n = input_keys[:n] + values_n = input_values[:n] + dictionary_n = dict(zip(keys_n, values_n)) + dict_str = str(dictionary_n) + keys_str = str(keys_n) + return prompt_template.format(dict_str=dict_str, keys_str=keys_str, num_keys=n) + + # Handle empty input case + if len(input_keys) == 0: + raise ValueError("Cannot build prompt: no items available in record") + + max_n = _binary_search_max_items(input_keys, build_prompt_for_n, prompt_length, min_items=1) + + # Use the maximum n that fits + input_keys = input_keys[:max_n] + input_values = input_values[:max_n] + expected_output = expected_output[:max_n] + + dictionary = dict(zip(input_keys, input_values)) + dict_str = str(dictionary) + keys_str = str(input_keys) + prompt = prompt_template.format(dict_str=dict_str, keys_str=keys_str, num_keys=len(input_keys)) + + target_str = str(expected_output[-1]) + + metadata = { + "input_keys": input_keys, + "input_values": input_values, + "expected_output": expected_output, + "dictionary": dictionary, + "num_items": len(input_keys), + } + + return prompt, target_str, metadata + + +def _find_max_items_for_multi_turn(input_keys, input_values, prompt_length, k): + """ + Find maximum number of items that fit within prompt_length for multi-turn evaluation. + Uses binary search to find max items where initial prompt (dict + first K keys) fits. + Returns: + int: Maximum number of items that fit + """ + + def build_initial_prompt_for_n(n): + """Build initial prompt with dictionary and first K keys from n total items.""" + if n == 0: + return None + keys_n = input_keys[:n] + values_n = input_values[:n] + dictionary_n = dict(zip(keys_n, values_n)) + dict_str = str(dictionary_n) + + # First turn has first K keys + first_turn_keys = keys_n[:k] + keys_str = ", ".join(first_turn_keys) + + return PROMPT_TEMPLATE_MULTI_START.format( + dict_str=dict_str, keys_str=keys_str, k=k, num_keys=len(first_turn_keys) + ) + + return _binary_search_max_items(input_keys, build_initial_prompt_for_n, prompt_length, min_items=k) + + +def _chunk_and_calculate_expected(input_keys, input_values, k): + """ + Chunk keys into turns of size K and calculate expected cumulative sums per turn. + Returns: + tuple: (turn_chunks: list, value_chunks: list, expected_per_turn: list) + """ + # Chunk keys into turns of size K + turn_chunks = [] + value_chunks = [] + for i in range(0, len(input_keys), k): + turn_chunks.append(input_keys[i : i + k]) + value_chunks.append(input_values[i : i + k]) + + # Calculate expected cumulative sums for each turn + expected_per_turn = [] + cumulative_sum = 0 + for values in value_chunks: + cumulative_sum += sum(values) + expected_per_turn.append(cumulative_sum) + + return turn_chunks, value_chunks, expected_per_turn + + +def _build_multi_turn_prompts(record, prompt_length=32768, k=1): + """ + Build prompts for multi-turn evaluation. + Uses binary search to find maximum number of items that fit within prompt_length. + Chunks keys into turns of size K. + Args: + record: Dictionary with 'input', 'values', and 'output' keys + prompt_length: Maximum character length for the prompt. Defaults to 32768. + k: Turn complexity (number of keys per turn). Defaults to 1. 
+ Returns: + tuple: (initial_prompt: str, turn_chunks: list, expected_per_turn: list, metadata: dict) + """ + input_keys = record["input"] + input_values = record["values"] + expected_output = record["output"] + + # Handle empty input case + if len(input_keys) == 0: + raise ValueError("Cannot build prompt: no items available in record") + + # Find maximum number of items that fit + max_n = _find_max_items_for_multi_turn(input_keys, input_values, prompt_length, k) + + # Use the maximum n that fits + input_keys = input_keys[:max_n] + input_values = input_values[:max_n] + expected_output = expected_output[:max_n] + + turn_chunks, value_chunks, expected_per_turn = _chunk_and_calculate_expected(input_keys, input_values, k) + + dictionary = dict(zip(input_keys, input_values)) + dict_str = str(dictionary) + + first_turn_keys_str = ", ".join(turn_chunks[0]) + initial_prompt = PROMPT_TEMPLATE_MULTI_START.format( + dict_str=dict_str, keys_str=first_turn_keys_str, k=k, num_keys=len(turn_chunks[0]) + ) + + metadata = { + "turn_chunks": turn_chunks, + "value_chunks": value_chunks, + "expected_per_turn": expected_per_turn, + "dictionary": dictionary, + "k": k, + "num_turns": len(turn_chunks), + "num_items": len(input_keys), + } + + return initial_prompt, turn_chunks, expected_per_turn, metadata From 2c0ceaebf0eb13917b3a4bddc5446623960dbe6c Mon Sep 17 00:00:00 2001 From: Akshath Mangudi Date: Sat, 22 Nov 2025 17:47:58 +0530 Subject: [PATCH 2/4] some fixes --- .../tasks/tasks/long_horizon_execution/main.py | 10 ++++++++-- .../tasks/tasks/long_horizon_execution/single_turn.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/main.py b/src/lighteval/tasks/tasks/long_horizon_execution/main.py index 3e8de0e38..9f523aa63 100644 --- a/src/lighteval/tasks/tasks/long_horizon_execution/main.py +++ b/src/lighteval/tasks/tasks/long_horizon_execution/main.py @@ -1,8 +1,10 @@ """ name: Long Horizon Execution + dataset: arvindh75/Long-Horizon-Execution + abstract: Evaluation benchmark for long-context execution capabilities of language models. Tests a model's ability to maintain state and perform cumulative operations over @@ -15,17 +17,21 @@ 4. Handle varying context sizes and turn complexities Single-turn evaluation (Section 3.3): Model outputs only the final cumulative sum after processing all keys, allowing any aggregation strategy. + Multi-turn evaluation: Model processes keys in batches of K per turn, maintaining conversation history and outputting cumulative sums incrementally. Evaluates fractional accuracy (correct turns / total turns). + languages: english + tags: long-context, state-tracking, arithmetic, execution + paper: https://arxiv.org/abs/2509.09677 -starred: -true + +starred: true """ from lighteval.tasks.tasks.long_horizon_execution.multi_turn import create_multi_turn_tasks diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/single_turn.py b/src/lighteval/tasks/tasks/long_horizon_execution/single_turn.py index 07d089639..d7bd508be 100644 --- a/src/lighteval/tasks/tasks/long_horizon_execution/single_turn.py +++ b/src/lighteval/tasks/tasks/long_horizon_execution/single_turn.py @@ -53,7 +53,7 @@ def single_turn_record_to_sample(record, prompt_length=32768): ) -@scorer(metrics={"accuracy": [accuracy(), stderr()]}) +@scorer(metrics=[accuracy(), stderr()]) def single_turn_scorer(): """ Scorer for single-turn evaluation. 
From 7485c14965028d7ff9b4eb7fc7518c13a1d3d2b4 Mon Sep 17 00:00:00 2001 From: Akshath Mangudi Date: Tue, 25 Nov 2025 18:19:45 +0530 Subject: [PATCH 3/4] addressing comments + fixing multi turn --- .../tasks/long_horizon_execution/__init__.py | 0 .../tasks/long_horizon_execution/constants.py | 8 + .../tasks/long_horizon_execution/main.py | 151 +++++++++++++++++- .../long_horizon_execution/multi_turn.py | 77 +++++---- .../long_horizon_execution/single_turn.py | 128 --------------- 5 files changed, 196 insertions(+), 168 deletions(-) delete mode 100644 src/lighteval/tasks/tasks/long_horizon_execution/__init__.py delete mode 100644 src/lighteval/tasks/tasks/long_horizon_execution/single_turn.py diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/__init__.py b/src/lighteval/tasks/tasks/long_horizon_execution/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/constants.py b/src/lighteval/tasks/tasks/long_horizon_execution/constants.py index dea40af8a..a7bb7b8ea 100644 --- a/src/lighteval/tasks/tasks/long_horizon_execution/constants.py +++ b/src/lighteval/tasks/tasks/long_horizon_execution/constants.py @@ -4,28 +4,36 @@ PROMPT_TEMPLATE_SINGLE = """You are an AI assistant. I will provide you with a dictionary and then give you a list of keys. Your task is to calculate the final cumulative sum after processing all keys in order. + For each key in the list, you need to: 1. Look up the value in the dictionary 2. Add it to the running sum 3. After processing all keys, output the final cumulative sum + Dictionary to use: {dict_str} + Keys to process in order: {keys_str} + Your task: Process all keys in order and calculate the final cumulative sum after processing all {num_keys} keys. + IMPORTANT: - Output your answer as a single integer value inside tags - Do not include any other text outside the answer tags - Format: final_sum - Example: If the final cumulative sum is 42, output: 42 + Your answer:""" PROMPT_TEMPLATE_MULTI_START = """You are an AI assistant. I will provide you with a dictionary and then give you keys in groups of {k}. Your task is to keep a running total (starting from 0) by adding the values associated with the keys I provide. In each turn, I'll provide {k} keys (comma-separated). Respond with the current running sum, enclosed in tags. + Dictionary to maintain: {dict_str} + Ready to start! 
**User**: {keys_str} **Assistant**:""" diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/main.py b/src/lighteval/tasks/tasks/long_horizon_execution/main.py index 9f523aa63..3c58c0bb3 100644 --- a/src/lighteval/tasks/tasks/long_horizon_execution/main.py +++ b/src/lighteval/tasks/tasks/long_horizon_execution/main.py @@ -34,8 +34,157 @@ starred: true """ +import functools +import re + +from inspect_ai.dataset import Sample +from inspect_ai.scorer import Score, Target, accuracy, scorer, stderr +from inspect_ai.solver import TaskState, generate + +from lighteval.metrics.metrics import Metrics +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc +from lighteval.tasks.tasks.long_horizon_execution.constants import CONTEXT_SIZES from lighteval.tasks.tasks.long_horizon_execution.multi_turn import create_multi_turn_tasks -from lighteval.tasks.tasks.long_horizon_execution.single_turn import create_single_turn_tasks +from lighteval.tasks.tasks.long_horizon_execution.utils import _build_prompt_and_target + + +# Single-turn prompt template +PROMPT_TEMPLATE_SINGLE = """You are an AI assistant. I will provide you with a dictionary and then give you a list of keys. +Your task is to calculate the final cumulative sum after processing all keys in order. + +For each key in the list, you need to: +1. Look up the value in the dictionary +2. Add it to the running sum +3. After processing all keys, output the final cumulative sum + +Dictionary to use: +{dict_str} + +Keys to process in order: +{keys_str} + +Your task: Process all keys in order and calculate the final cumulative sum after processing all {num_keys} keys. + +IMPORTANT: +- Output your answer as a single integer value inside tags +- Do not include any other text outside the answer tags +- Format: final_sum +- Example: If the final cumulative sum is 42, output: 42 + +Your answer:""" + + +def single_turn_prompt_function(line, prompt_length=32768, task_name: str = None): + """ + Prompt function for single-turn evaluation (non-inspect-ai backend). + Converts dataset record to Doc object. + Returns: + Doc object for evaluation + """ + prompt, target_str, _ = _build_prompt_and_target( + line, prompt_length=prompt_length, prompt_template=PROMPT_TEMPLATE_SINGLE + ) + + return Doc( + task_name=task_name, + query=prompt, + choices=[target_str], # Expected answer as a choice + gold_index=0, + instruction=prompt, + ) + + +def single_turn_record_to_sample(record, prompt_length=32768): + """ + Converts dataset record to inspect-ai Sample object for single-turn evaluation. + Returns: + Sample object for inspect-ai + """ + prompt, target_str, metadata = _build_prompt_and_target( + record, prompt_length=prompt_length, prompt_template=PROMPT_TEMPLATE_SINGLE + ) + + return Sample( + input=prompt, + target=target_str, + metadata=metadata, + ) + + +@scorer(metrics=[accuracy(), stderr()]) +def single_turn_scorer(): + """ + Scorer for single-turn evaluation. + Compares the model's predicted final sum with the expected final sum (binary score). 
+ Returns: + Scorer function that evaluates single integer responses + """ + + async def score(state: TaskState, target: Target): + response = state.output.completion + + answer_pattern = re.compile(r"(.*?)", re.DOTALL) + match = answer_pattern.search(response) + + if not match: + return Score(value="I", answer="", explanation="No tag found in response.") + + content = match.group(1).strip() + + try: + pred_value = int(content.strip()) + except ValueError: + return Score(value="I", answer=content, explanation=f"Failed to parse integer from: {content}") + + try: + exp_value = int(target.text.strip()) + except (ValueError, AttributeError): + return Score( + value="I", + answer=str(pred_value), + explanation=f"Failed to parse expected target: {target.text}", + ) + + is_correct = pred_value == exp_value + return Score( + value="C" if is_correct else "I", + answer=str(pred_value), + explanation=(f"Expected {exp_value}, Got {pred_value}. Match: {is_correct}"), + ) + + return score + + +def create_single_turn_tasks(): + """ + Create all single-turn task configurations for different context sizes. + Returns: + list[LightevalTaskConfig]: List of task configurations for single-turn evaluation + """ + tasks = [] + + for context_size in CONTEXT_SIZES: + task_name = f"long_horizon_execution_single:{context_size}" + prompt_fn = functools.partial(single_turn_prompt_function, prompt_length=context_size) + sample_fn = functools.partial(single_turn_record_to_sample, prompt_length=context_size) + + task = LightevalTaskConfig( + name=task_name, + prompt_function=prompt_fn, + sample_fields=sample_fn, + solver=[generate(cache=True)], + scorer=single_turn_scorer(), + hf_repo="arvindh75/Long-Horizon-Execution", + hf_subset="default", + evaluation_splits=("test",), + generation_size=context_size, + metrics=[Metrics.exact_match], + ) + + tasks.append(task) + + return tasks single_turn_tasks = create_single_turn_tasks() diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/multi_turn.py b/src/lighteval/tasks/tasks/long_horizon_execution/multi_turn.py index ec86a093a..b7c750fc5 100644 --- a/src/lighteval/tasks/tasks/long_horizon_execution/multi_turn.py +++ b/src/lighteval/tasks/tasks/long_horizon_execution/multi_turn.py @@ -9,8 +9,9 @@ import re from inspect_ai.dataset import Sample +from inspect_ai.model import ChatMessageUser, ModelOutput from inspect_ai.scorer import Score, Target, accuracy, scorer, stderr -from inspect_ai.solver import TaskState, generate +from inspect_ai.solver import Generate, TaskState, generate, solver from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig @@ -66,68 +67,66 @@ def _extract_response_content(response): return str(response) -async def _process_single_turn(state, turn_chunk, config): +async def _process_single_turn(state, turn_chunk, generate): """Process a single turn: add user message, get model response, add assistant message.""" keys_str = ", ".join(turn_chunk) followup_prompt = PROMPT_TEMPLATE_MULTI_FOLLOWUP.format(keys_str=keys_str) - state.messages.append({"role": "user", "content": followup_prompt}) + state.messages.append(ChatMessageUser(content=followup_prompt)) - response = await state.model.generate(messages=state.messages, config=config) - turn_response = _extract_response_content(response) + # generate() takes the state and returns updated state with assistant message added + updated_state = await generate(state) + turn_response = _extract_response_content(updated_state.output.completion if 
updated_state.output else "") - state.messages.append({"role": "assistant", "content": turn_response}) - return turn_response + return updated_state, turn_response -async def multi_turn_solver(state: TaskState): +@solver +def multi_turn_solver(): """ - Custom solver for multi-turn evaluation. + Solver for multi-turn evaluation. Loops through turns, calling the model for each turn while maintaining conversation history. This implements offline evaluation: all turns are called, then evaluation happens. """ - from inspect_ai.model import GenerateConfig, ModelOutput - turn_chunks = state.metadata.get("turn_chunks", []) + async def solve(state: TaskState, generate: Generate): + turn_chunks = state.metadata.get("turn_chunks", []) - if not turn_chunks or len(turn_chunks) == 0: - return state - - # Initialize messages - if not hasattr(state, "messages") or state.messages is None: - state.messages = [] + if not turn_chunks or len(turn_chunks) == 0: + return state - if not state.messages: - state.messages.append({"role": "user", "content": state.input}) + # Initialize messages + if not hasattr(state, "messages") or state.messages is None: + state.messages = [] - all_turn_outputs = [] + if not state.messages: + state.messages.append(ChatMessageUser(content=state.input)) - # Process all turns - if hasattr(state, "model") and state.model is not None: - config = GenerateConfig() + all_turn_outputs = [] # Process first turn (already in messages as initial prompt) - response = await state.model.generate(messages=state.messages, config=config) - turn_response = _extract_response_content(response) + updated_state = await generate(state) + turn_response = _extract_response_content(updated_state.output.completion if updated_state.output else "") all_turn_outputs.append(turn_response) - state.messages.append({"role": "assistant", "content": turn_response}) + + state = updated_state # Process remaining turns for turn_idx in range(1, len(turn_chunks)): - if not hasattr(state, "model") or state.model is None: - break - turn_response = await _process_single_turn(state, turn_chunks[turn_idx], config) + state, turn_response = await _process_single_turn(state, turn_chunks[turn_idx], generate) all_turn_outputs.append(turn_response) - state.metadata["all_turn_outputs"] = all_turn_outputs + state.metadata["all_turn_outputs"] = all_turn_outputs - # Set final output - if all_turn_outputs: - if hasattr(state, "output") and state.output is not None: - state.output.completion = all_turn_outputs[-1] - else: - state.output = ModelOutput(completion=all_turn_outputs[-1]) + # Set final output + if all_turn_outputs: + if hasattr(state, "output") and state.output is not None: + state.output.completion = all_turn_outputs[-1] + else: + state.output = ModelOutput(completion=all_turn_outputs[-1]) + + return state - return state + return solve @scorer(metrics={"turn_accuracy": [accuracy(), stderr()], "fractional_accuracy": [accuracy(), stderr()]}) @@ -200,7 +199,7 @@ def create_multi_turn_tasks(): for context_size in CONTEXT_SIZES: for k in TURN_COMPLEXITIES: - task_name = f"long_horizon_execution:multi:{context_size}:k{k}" + task_name = f"long_horizon_execution_multi_k{k}:{context_size}" prompt_fn = functools.partial(multi_turn_prompt_function, prompt_length=context_size, k=k) sample_fn = functools.partial(multi_turn_record_to_sample, prompt_length=context_size, k=k) @@ -208,7 +207,7 @@ def create_multi_turn_tasks(): name=task_name, prompt_function=prompt_fn, sample_fields=sample_fn, - solver=[multi_turn_solver, generate(cache=True)], 
+ solver=[multi_turn_solver(), generate(cache=True)], scorer=multi_turn_scorer(), hf_repo="arvindh75/Long-Horizon-Execution", hf_subset="default", diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/single_turn.py b/src/lighteval/tasks/tasks/long_horizon_execution/single_turn.py deleted file mode 100644 index d7bd508be..000000000 --- a/src/lighteval/tasks/tasks/long_horizon_execution/single_turn.py +++ /dev/null @@ -1,128 +0,0 @@ -""" -Single turn implementation of the Long Horizon Execution task. -""" - -import functools -import re - -from inspect_ai.dataset import Sample -from inspect_ai.scorer import Score, Target, accuracy, scorer, stderr -from inspect_ai.solver import TaskState, generate - -from lighteval.metrics.metrics import Metrics -from lighteval.tasks.lighteval_task import LightevalTaskConfig -from lighteval.tasks.requests import Doc -from lighteval.tasks.tasks.long_horizon_execution.constants import CONTEXT_SIZES, PROMPT_TEMPLATE_SINGLE -from lighteval.tasks.tasks.long_horizon_execution.utils import _build_prompt_and_target - - -def single_turn_prompt_function(line, prompt_length=32768, task_name: str = None): - """ - Prompt function for single-turn evaluation (non-inspect-ai backend). - Converts dataset record to Doc object. - Returns: - Doc object for evaluation - """ - prompt, target_str, _ = _build_prompt_and_target( - line, prompt_length=prompt_length, prompt_template=PROMPT_TEMPLATE_SINGLE - ) - - return Doc( - task_name=task_name, - query=prompt, - choices=[target_str], # Expected answer as a choice - gold_index=0, - instruction=prompt, - ) - - -def single_turn_record_to_sample(record, prompt_length=32768): - """ - Converts dataset record to inspect-ai Sample object for single-turn evaluation. - Returns: - Sample object for inspect-ai - """ - prompt, target_str, metadata = _build_prompt_and_target( - record, prompt_length=prompt_length, prompt_template=PROMPT_TEMPLATE_SINGLE - ) - - return Sample( - input=prompt, - target=target_str, - metadata=metadata, - ) - - -@scorer(metrics=[accuracy(), stderr()]) -def single_turn_scorer(): - """ - Scorer for single-turn evaluation. - Compares the model's predicted final sum with the expected final sum (binary score). - Returns: - Scorer function that evaluates single integer responses - """ - - async def score(state: TaskState, target: Target): - response = state.output.completion - - answer_pattern = re.compile(r"(.*?)", re.DOTALL) - match = answer_pattern.search(response) - - if not match: - return Score(value="I", answer="", explanation="No tag found in response.") - - content = match.group(1).strip() - - try: - pred_value = int(content.strip()) - except ValueError: - return Score(value="I", answer=content, explanation=f"Failed to parse integer from: {content}") - - try: - exp_value = int(target.text.strip()) - except (ValueError, AttributeError): - return Score( - value="I", - answer=str(pred_value), - explanation=f"Failed to parse expected target: {target.text}", - ) - - is_correct = pred_value == exp_value - return Score( - value="C" if is_correct else "I", - answer=str(pred_value), - explanation=(f"Expected {exp_value}, Got {pred_value}. Match: {is_correct}"), - ) - - return score - - -def create_single_turn_tasks(): - """ - Create all single-turn task configurations for different context sizes. 
- Returns: - list[LightevalTaskConfig]: List of task configurations for single-turn evaluation - """ - tasks = [] - - for context_size in CONTEXT_SIZES: - task_name = f"long_horizon_execution:{context_size}" - prompt_fn = functools.partial(single_turn_prompt_function, prompt_length=context_size) - sample_fn = functools.partial(single_turn_record_to_sample, prompt_length=context_size) - - task = LightevalTaskConfig( - name=task_name, - prompt_function=prompt_fn, - sample_fields=sample_fn, - solver=[generate(cache=True)], - scorer=single_turn_scorer(), - hf_repo="arvindh75/Long-Horizon-Execution", - hf_subset="default", - evaluation_splits=("test",), - generation_size=context_size, - metrics=[Metrics.exact_match], - ) - - tasks.append(task) - - return tasks From 3138c7c79b703851c5cd40267cdf467f2067b04b Mon Sep 17 00:00:00 2001 From: Akshath Mangudi Date: Sun, 14 Dec 2025 12:07:47 +0530 Subject: [PATCH 4/4] addressing comments and fixing impl --- .../tasks/long_horizon_execution/constants.py | 2 +- .../tasks/long_horizon_execution/main.py | 28 +------------------ .../long_horizon_execution/multi_turn.py | 27 +++++++++--------- .../tasks/long_horizon_execution/utils.py | 13 ++++----- 4 files changed, 20 insertions(+), 50 deletions(-) diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/constants.py b/src/lighteval/tasks/tasks/long_horizon_execution/constants.py index a7bb7b8ea..2a5c9954f 100644 --- a/src/lighteval/tasks/tasks/long_horizon_execution/constants.py +++ b/src/lighteval/tasks/tasks/long_horizon_execution/constants.py @@ -28,7 +28,7 @@ PROMPT_TEMPLATE_MULTI_START = """You are an AI assistant. I will provide you with a dictionary and then give you keys in groups of {k}. Your task is to keep a running total (starting from 0) by adding the values associated with the keys I provide. -In each turn, I'll provide {k} keys (comma-separated). +In each turn, I'll provide {k} key(s) (comma-separated). Respond with the current running sum, enclosed in tags. Dictionary to maintain: diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/main.py b/src/lighteval/tasks/tasks/long_horizon_execution/main.py index 3c58c0bb3..09686b40c 100644 --- a/src/lighteval/tasks/tasks/long_horizon_execution/main.py +++ b/src/lighteval/tasks/tasks/long_horizon_execution/main.py @@ -44,37 +44,11 @@ from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc -from lighteval.tasks.tasks.long_horizon_execution.constants import CONTEXT_SIZES +from lighteval.tasks.tasks.long_horizon_execution.constants import CONTEXT_SIZES, PROMPT_TEMPLATE_SINGLE from lighteval.tasks.tasks.long_horizon_execution.multi_turn import create_multi_turn_tasks from lighteval.tasks.tasks.long_horizon_execution.utils import _build_prompt_and_target -# Single-turn prompt template -PROMPT_TEMPLATE_SINGLE = """You are an AI assistant. I will provide you with a dictionary and then give you a list of keys. -Your task is to calculate the final cumulative sum after processing all keys in order. - -For each key in the list, you need to: -1. Look up the value in the dictionary -2. Add it to the running sum -3. After processing all keys, output the final cumulative sum - -Dictionary to use: -{dict_str} - -Keys to process in order: -{keys_str} - -Your task: Process all keys in order and calculate the final cumulative sum after processing all {num_keys} keys. 
- -IMPORTANT: -- Output your answer as a single integer value inside tags -- Do not include any other text outside the answer tags -- Format: final_sum -- Example: If the final cumulative sum is 42, output: 42 - -Your answer:""" - - def single_turn_prompt_function(line, prompt_length=32768, task_name: str = None): """ Prompt function for single-turn evaluation (non-inspect-ai backend). diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/multi_turn.py b/src/lighteval/tasks/tasks/long_horizon_execution/multi_turn.py index b7c750fc5..e34638fd0 100644 --- a/src/lighteval/tasks/tasks/long_horizon_execution/multi_turn.py +++ b/src/lighteval/tasks/tasks/long_horizon_execution/multi_turn.py @@ -67,14 +67,14 @@ def _extract_response_content(response): return str(response) -async def _process_single_turn(state, turn_chunk, generate): +async def _process_single_turn(state, turn_chunk, generate_fn): """Process a single turn: add user message, get model response, add assistant message.""" keys_str = ", ".join(turn_chunk) followup_prompt = PROMPT_TEMPLATE_MULTI_FOLLOWUP.format(keys_str=keys_str) state.messages.append(ChatMessageUser(content=followup_prompt)) - # generate() takes the state and returns updated state with assistant message added - updated_state = await generate(state) + # generate_fn() takes the state and returns updated state with assistant message added + updated_state = await generate_fn(state) turn_response = _extract_response_content(updated_state.output.completion if updated_state.output else "") return updated_state, turn_response @@ -91,7 +91,7 @@ def multi_turn_solver(): async def solve(state: TaskState, generate: Generate): turn_chunks = state.metadata.get("turn_chunks", []) - if not turn_chunks or len(turn_chunks) == 0: + if not turn_chunks: return state # Initialize messages @@ -129,7 +129,7 @@ async def solve(state: TaskState, generate: Generate): return solve -@scorer(metrics={"turn_accuracy": [accuracy(), stderr()], "fractional_accuracy": [accuracy(), stderr()]}) +@scorer(metrics={"fractional_accuracy": [accuracy(), stderr()]}) def multi_turn_scorer(): """ Scorer for multi-turn Long Horizon Execution task. 
@@ -143,11 +143,15 @@ async def score(state: TaskState, target: Target): expected_per_turn = state.metadata.get("expected_per_turn", []) if not all_turn_outputs: - return Score(value=0.0, answer="", explanation="No turn outputs found in state.metadata") + return Score( + value={"fractional_accuracy": 0.0}, + answer="", + explanation="No turn outputs found in state.metadata", + ) if len(all_turn_outputs) != len(expected_per_turn): return Score( - value=0.0, + value={"fractional_accuracy": 0.0}, answer="", explanation=f"Mismatch: {len(all_turn_outputs)} outputs vs {len(expected_per_turn)} expected turns", ) @@ -155,7 +159,7 @@ async def score(state: TaskState, target: Target): parsed_outputs = [] answer_pattern = re.compile(r"(.*?)", re.DOTALL) - for turn_idx, turn_output in enumerate(all_turn_outputs): + for turn_output in all_turn_outputs: match = answer_pattern.search(turn_output) if match: try: @@ -177,12 +181,7 @@ async def score(state: TaskState, target: Target): fractional_accuracy = correct_turns / len(expected_per_turn) if expected_per_turn else 0.0 return Score( - value={ - "turn_accuracy": fractional_accuracy, - "fractional_accuracy": fractional_accuracy, - "correct_turns": correct_turns, - "total_turns": len(expected_per_turn), - }, + value={"fractional_accuracy": fractional_accuracy}, answer=str(parsed_outputs), explanation=f"Correct {correct_turns}/{len(expected_per_turn)} turns. Details: {turn_results}", ) diff --git a/src/lighteval/tasks/tasks/long_horizon_execution/utils.py b/src/lighteval/tasks/tasks/long_horizon_execution/utils.py index 9ea4ca6ab..f96acda86 100644 --- a/src/lighteval/tasks/tasks/long_horizon_execution/utils.py +++ b/src/lighteval/tasks/tasks/long_horizon_execution/utils.py @@ -126,7 +126,9 @@ def build_initial_prompt_for_n(n): keys_str = ", ".join(first_turn_keys) return PROMPT_TEMPLATE_MULTI_START.format( - dict_str=dict_str, keys_str=keys_str, k=k, num_keys=len(first_turn_keys) + dict_str=dict_str, + keys_str=keys_str, + k=k, ) return _binary_search_max_items(input_keys, build_initial_prompt_for_n, prompt_length, min_items=k) @@ -169,7 +171,6 @@ def _build_multi_turn_prompts(record, prompt_length=32768, k=1): """ input_keys = record["input"] input_values = record["values"] - expected_output = record["output"] # Handle empty input case if len(input_keys) == 0: @@ -181,21 +182,17 @@ def _build_multi_turn_prompts(record, prompt_length=32768, k=1): # Use the maximum n that fits input_keys = input_keys[:max_n] input_values = input_values[:max_n] - expected_output = expected_output[:max_n] - turn_chunks, value_chunks, expected_per_turn = _chunk_and_calculate_expected(input_keys, input_values, k) + turn_chunks, _, expected_per_turn = _chunk_and_calculate_expected(input_keys, input_values, k) dictionary = dict(zip(input_keys, input_values)) dict_str = str(dictionary) first_turn_keys_str = ", ".join(turn_chunks[0]) - initial_prompt = PROMPT_TEMPLATE_MULTI_START.format( - dict_str=dict_str, keys_str=first_turn_keys_str, k=k, num_keys=len(turn_chunks[0]) - ) + initial_prompt = PROMPT_TEMPLATE_MULTI_START.format(dict_str=dict_str, keys_str=first_turn_keys_str, k=k) metadata = { "turn_chunks": turn_chunks, - "value_chunks": value_chunks, "expected_per_turn": expected_per_turn, "dictionary": dictionary, "k": k,
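The arithmetic that both modes score is easiest to see on a toy record. The sketch below mirrors _chunk_and_calculate_expected and the fractional-accuracy computation in multi_turn_scorer using made-up data; the field names "input", "values" and "output" follow what _build_prompt_and_target and _build_multi_turn_prompts read from each dataset row, and the snippet is a standalone illustration rather than part of the patch.

    # Toy record shaped like the dataset rows the utils consume (hypothetical values).
    record = {
        "input": ["a", "b", "c", "d", "e"],   # keys, in processing order
        "values": [3, -1, 4, 1, 5],           # value looked up for each key
        "output": [3, 2, 6, 7, 12],           # running sum after each key
    }
    k = 2  # turn complexity: keys revealed per turn

    # Chunk keys/values into turns of size k, as _chunk_and_calculate_expected does.
    turn_chunks = [record["input"][i : i + k] for i in range(0, len(record["input"]), k)]
    value_chunks = [record["values"][i : i + k] for i in range(0, len(record["values"]), k)]

    # Expected running sum after each turn.
    expected_per_turn, running = [], 0
    for chunk in value_chunks:
        running += sum(chunk)
        expected_per_turn.append(running)

    assert turn_chunks == [["a", "b"], ["c", "d"], ["e"]]
    assert expected_per_turn == [2, 7, 12]
    assert record["output"][-1] == expected_per_turn[-1]  # single-turn target: final sum only

    # Fractional accuracy as the multi-turn scorer computes it, supposing the model
    # answered 2, 6, 12 across the three turns (one wrong).
    predicted = [2, 6, 12]
    correct = sum(p == e for p, e in zip(predicted, expected_per_turn))
    print(correct / len(expected_per_turn))  # 2/3 turns correct -> 0.666...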
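In the multi-turn mode the solver drives an alternating user/assistant transcript: the first user message comes from PROMPT_TEMPLATE_MULTI_START and each later turn from PROMPT_TEMPLATE_MULTI_FOLLOWUP. The sketch below builds that transcript offline for a toy record, substituting a hard-coded correct reply for the model call and plain dicts for inspect-ai's ChatMessage objects, so it only assumes the constants module added in this patch; the answer-tag literal in the stand-in reply is an assumption.

    from lighteval.tasks.tasks.long_horizon_execution.constants import (
        PROMPT_TEMPLATE_MULTI_FOLLOWUP,
        PROMPT_TEMPLATE_MULTI_START,
    )

    keys, values, k = ["a", "b", "c", "d"], [3, -1, 4, 1], 2
    dictionary = dict(zip(keys, values))
    turn_chunks = [keys[i : i + k] for i in range(0, len(keys), k)]

    messages, running = [], 0
    for turn_idx, chunk in enumerate(turn_chunks):
        keys_str = ", ".join(chunk)
        if turn_idx == 0:
            user_msg = PROMPT_TEMPLATE_MULTI_START.format(dict_str=str(dictionary), keys_str=keys_str, k=k)
        else:
            user_msg = PROMPT_TEMPLATE_MULTI_FOLLOWUP.format(keys_str=keys_str)
        messages.append({"role": "user", "content": user_msg})

        # Stand-in for the model call: answer each turn correctly (tag convention assumed).
        running += sum(dictionary[key] for key in chunk)
        messages.append({"role": "assistant", "content": f"<answer>{running}</answer>"})

    for message in messages:
        print(message["role"], "->", message["content"][:60].replace("\n", " "))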
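Both scorers recover the model's integer from an answer tag before comparing it to the target. A minimal standalone version of that parsing is below; it assumes an <answer>...</answer> tag literal, which should be kept in sync with whatever markup the prompt templates actually instruct the model to emit.

    import re

    # Assumed tag convention; match the literal used in the prompt templates.
    ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

    def parse_final_sum(completion: str):
        """Return the integer inside the answer tag, or None if it is absent or not an int."""
        match = ANSWER_PATTERN.search(completion)
        if not match:
            return None
        try:
            return int(match.group(1).strip())
        except ValueError:
            return None

    print(parse_final_sum("The running total is <answer> 42 </answer>"))  # 42
    print(parse_final_sum("no tag here"))                                 # None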
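The character-budget handling in utils.py can be exercised on its own: _binary_search_max_items only needs a callable that builds a prompt from the first n items, so a toy builder is enough to see how a budget caps the number of keys. The builder and the 2,000-character budget below are made up; the import path is the one this patch introduces.

    from lighteval.tasks.tasks.long_horizon_execution.utils import _binary_search_max_items

    keys = [f"key_{i}" for i in range(1000)]
    values = list(range(1000))

    def build_prompt_for_n(n):
        # Same shape as the builders in utils.py: dict of the first n items plus the key list.
        if n == 0:
            return None
        return f"dict={dict(zip(keys[:n], values[:n]))} keys={keys[:n]}"

    # With a 2,000-character budget only a prefix of the record fits; the helper
    # returns the largest n whose prompt still fits.
    max_n = _binary_search_max_items(keys, build_prompt_for_n, prompt_length=2000, min_items=1)
    print(max_n, len(build_prompt_for_n(max_n)))  # prompt length stays <= 2000 chars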