From 4db51f7886aac6f778edab057fced3225c2ccfcf Mon Sep 17 00:00:00 2001 From: zrobertson466920 Date: Wed, 19 Nov 2025 18:02:25 -0800 Subject: [PATCH 1/6] add tvd-mi prompt + parser --- src/lighteval/metrics/utils/judge_utils.py | 40 ++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/lighteval/metrics/utils/judge_utils.py b/src/lighteval/metrics/utils/judge_utils.py index cde25fd26..e7cc299fd 100644 --- a/src/lighteval/metrics/utils/judge_utils.py +++ b/src/lighteval/metrics/utils/judge_utils.py @@ -125,3 +125,43 @@ def process_judge_response_simpleqa(response: str) -> float: else: logger.warning(f"Unknown response from judge: {response}") return 0.0 + + +def get_judge_prompt_tvdmi( + question: str, + answer: str, + options: list[str] | None = None, + gold: str | None = None, + **kwargs, +): + # `question` = response_a, `answer` = response_b + TEMPLATE = """ +You are judging whether two responses are likely derived from the same underlying task or item. + +Consider ONLY whether the two responses plausibly come from the same task/source, not whether they are factually correct. + +Response A: +{response_a} + +Response B: +{response_b} + +Grade as: +A: SAME TASK/SOURCE +B: DIFFERENT TASK/SOURCE + +Just return the letter "A" or "B", with no other text. +""".strip() + + content = TEMPLATE.format(response_a=question, response_b=answer) + return [{"role": "user", "content": content}] + + +def process_judge_response_tvdmi(response: str) -> int: + if response == "A": + return 1 # SAME + elif response == "B": + return 0 # DIFFERENT + else: + logger.warning(f"Unknown response from TVD-MI judge: {response!r}") + return 0 From 3cb0dfcb89a96c8bfbe72df066c820fd2d41c1ba Mon Sep 17 00:00:00 2001 From: zrobertson466920 Date: Wed, 19 Nov 2025 18:16:00 -0800 Subject: [PATCH 2/6] implement judgellmtvdmi and aggregator --- src/lighteval/metrics/metrics_corpus.py | 4 +- src/lighteval/metrics/metrics_sample.py | 85 ++++++++++++++++++++++--- 2 files changed, 79 insertions(+), 10 deletions(-) diff --git a/src/lighteval/metrics/metrics_corpus.py b/src/lighteval/metrics/metrics_corpus.py index 92c2c574a..8532e5080 100644 --- a/src/lighteval/metrics/metrics_corpus.py +++ b/src/lighteval/metrics/metrics_corpus.py @@ -107,7 +107,9 @@ def compute_corpus(self, items: list[LogprobCorpusMetricInput]): for i in range(self.num_classes): f1s.append( sklearn.metrics.f1_score( - y_true=[g == i for g in golds], y_pred=[p == i for p in preds], average=self.average + y_true=[g == i for g in golds], + y_pred=[p == i for p in preds], + average=self.average, ) ) return float(np.mean(f1s)) diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index d83e64e22..0d98ccd45 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -51,7 +51,12 @@ remove_braces, remove_braces_and_strip, ) -from lighteval.metrics.utils.judge_utils import get_judge_prompt_simpleqa, process_judge_response_simpleqa +from lighteval.metrics.utils.judge_utils import ( + get_judge_prompt_simpleqa, + get_judge_prompt_tvdmi, + process_judge_response_simpleqa, + process_judge_response_tvdmi, +) from lighteval.metrics.utils.llm_as_judge import JudgeLM from lighteval.models.model_output import ModelResponse from lighteval.tasks.requests import Doc @@ -643,7 +648,10 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict[str logger.warning("The first metric computation step might be a bit longer as we need to download the model.") # We only 
initialize on first compute self.bert_scorer = BERTScorer( - model_type="microsoft/deberta-large-mnli", lang="en", rescale_with_baseline=True, num_layers=9 + model_type="microsoft/deberta-large-mnli", + lang="en", + rescale_with_baseline=True, + num_layers=9, ) golds = as_list(golds) predictions = as_list(predictions) @@ -655,7 +663,11 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict[str predictions = [self.normalize_pred(p) for p in predictions] p, r, f = self.bert_scorer.score(predictions, golds) - return {"BERTScore-P": p[0].item(), "BERTScore-R": r[0].item(), "BERTScore-F": f[0].item()} + return { + "BERTScore-P": p[0].item(), + "BERTScore-R": r[0].item(), + "BERTScore-F": f[0].item(), + } class Extractiveness(SampleLevelComputation): @@ -856,7 +868,11 @@ def __init__( metric_types (list[str] | str): Can be one or any of `longest_common_prefix_length`, `edit_distance` or `edit_similarity`. strip_prediction (bool, optional): Whether to strip the prediction. Defaults to True. """ - allowed_values = ["longest_common_prefix_length", "edit_distance", "edit_similarity"] + allowed_values = [ + "longest_common_prefix_length", + "edit_distance", + "edit_similarity", + ] metric_types = as_list(metric_types) if any(metric_type not in allowed_values for metric_type in metric_types): raise ValueError( @@ -864,7 +880,11 @@ def __init__( ) self.metric_types = metric_types self.strip_prediction = strip_prediction - self.sample_aggregations = {"longest_common_prefix_length": max, "edit_distance": min, "edit_similarity": max} + self.sample_aggregations = { + "longest_common_prefix_length": max, + "edit_distance": min, + "edit_similarity": max, + } def compute(self, doc: Doc, model_response: ModelResponse, **kwargs): """Computes all the requested metrics on the golds and prediction. @@ -940,7 +960,13 @@ def edit_similarity(self, s1, s2): class JudgeLLM(SampleLevelComputation): - available_models_openai = ["gpt-3.5-turbo", "gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-4o-2024-08-06"] + available_models_openai = [ + "gpt-3.5-turbo", + "gpt-4o", + "gpt-4-turbo", + "gpt-4", + "gpt-4o-2024-08-06", + ] def __init__( self, @@ -1065,10 +1091,14 @@ def compute(self, model_response: list[ModelResponse], doc: list[Doc], **kwargs) query_context_2 = {"query": questions[1], "context": predictions[0]} score_turn_1, message_turn_1, judgement_turn_1 = self.judge.evaluate_answer( - question=json.dumps(query_context_1, indent=2), answer=predictions[0], gold=golds[0] if golds else None + question=json.dumps(query_context_1, indent=2), + answer=predictions[0], + gold=golds[0] if golds else None, ) score_turn_2, message_turn_2, judgement_turn_2 = self.judge.evaluate_answer( - question=json.dumps(query_context_2, indent=2), answer=predictions[1], gold=golds[1] if golds else None + question=json.dumps(query_context_2, indent=2), + answer=predictions[1], + gold=golds[1] if golds else None, ) return { @@ -1106,6 +1136,43 @@ def compute(self, responses: list[ModelResponse], docs: list[Doc], **kwargs): return metrics +class JudgeLLMTVDMI(JudgeLLM): + def __init__(self): + super().__init__( + judge_model_name="gpt-4o-2024-08-06", + template=get_judge_prompt_tvdmi, + process_judge_response=process_judge_response_tvdmi, + judge_backend="openai", + short_judge_name="gpt4o", + ) + + def compute(self, responses: list[ModelResponse], docs: list[Doc], **kwargs) -> list: + # For TVD-MI, the evaluated model is the judge; the “responses” from + # base models are already baked into docs as response_a / response_b. 
+ questions = [d.response_a for d in docs] + answers = [d.response_b for d in docs] + labels = [int(d.pair_label) for d in docs] + + options = [None] * len(docs) + golds = [None] * len(docs) + + scores, prompts, judge_responses = self.judge.evaluate_answer_batch(questions, answers, options, golds) + + metrics = [] + for i in range(len(docs)): + pred = scores[i] # already 0/1 from process_judge_response_tvdmi + metrics.append( + { + "label": labels[i], + "pred": pred, + f"user_prompt_{self.short_judge_name}": prompts[i], + f"judgement_{self.short_judge_name}": judge_responses[i], + } + ) + + return metrics + + class SamplingMetric: """All sampling metrics we have defined below use the same set of normalization parameters and same behavior for the default sample_scoring_function. This class just holds the normalization and applies it to all samples passed to preprocess, then uses the default sample function if not provided. @@ -1115,7 +1182,7 @@ def __init__( self, normalize: Callable | str | None = None, strip_strings: bool = False, - sample_scoring_function: Callable[[Doc, ModelResponse], float] | str | None = None, + sample_scoring_function: (Callable[[Doc, ModelResponse], float] | str | None) = None, ): if isinstance(normalize, str): import lighteval.metrics.normalizations From bfa6e654a9115a20f9fe758078daee7fbb59dd82 Mon Sep 17 00:00:00 2001 From: zrobertson466920 Date: Wed, 19 Nov 2025 18:24:47 -0800 Subject: [PATCH 3/6] corpus aggregator + register metric + sanity test --- src/lighteval/metrics/metrics.py | 197 ++++++++++++++++++++---- src/lighteval/metrics/metrics_corpus.py | 19 +++ 2 files changed, 182 insertions(+), 34 deletions(-) diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py index 82cfbb706..2c6e24b99 100644 --- a/src/lighteval/metrics/metrics.py +++ b/src/lighteval/metrics/metrics.py @@ -35,6 +35,7 @@ CorpusLevelF1Score, CorpusLevelPerplexityMetric, CorpusLevelTranslationMetric, + CorpusLevelTVDMI, MatthewsCorrCoef, ) from lighteval.metrics.metrics_sample import ( @@ -51,13 +52,18 @@ Faithfulness, GPassAtK, JudgeLLMSimpleQA, + JudgeLLMTVDMI, LoglikelihoodAcc, MajAtN, PassAtK, Recall, StringDistance, ) -from lighteval.metrics.normalizations import bigbench_normalizer, remove_braces, remove_braces_and_strip +from lighteval.metrics.normalizations import ( + bigbench_normalizer, + remove_braces, + remove_braces_and_strip, +) from lighteval.metrics.sample_preparator import ( GenerativePreparator, LoglikelihoodPreparator, @@ -84,7 +90,10 @@ @scorer(metrics=[accuracy()]) def math_scorer(): gold_extraction_target = (ExprExtractionConfig(),) - pred_extraction_target = (ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)) + pred_extraction_target = ( + ExprExtractionConfig(), + LatexExtractionConfig(boxed_match_priority=0), + ) language = Language.ENGLISH fallback_mode = "first_match" extraction_mode = "first_match" @@ -95,10 +104,18 @@ def math_scorer(): async def score(state: TaskState, target: Target): extracted_predictions = extract_target_from_pred( - state.output.completion, pred_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds + state.output.completion, + pred_extraction_regexes, + fallback_mode, + extraction_mode, + timeout_seconds, ) extracted_gold = extract_target_from_pred( - target.text, gold_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds + target.text, + gold_extraction_regexes, + fallback_mode, + extraction_mode, + timeout_seconds, ) return Score( # Correct or Incorrect, used by 
inspect-ai backend @@ -128,10 +145,18 @@ def multichoice_scorer(): async def score(state: TaskState, target: Target): extracted_predictions = extract_target_from_pred( - state.output.completion, pred_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds + state.output.completion, + pred_extraction_regexes, + fallback_mode, + extraction_mode, + timeout_seconds, ) extracted_gold = extract_target_from_pred( - target.text, gold_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds + target.text, + gold_extraction_regexes, + fallback_mode, + extraction_mode, + timeout_seconds, ) return Score( # Correct or Incorrect, used by inspect-ai backend @@ -163,8 +188,14 @@ class Metrics(Enum): sample_level_fn=AvgAtN( sample_scoring_function=MultilingualExtractiveMatchMetric( language=Language.ENGLISH, - gold_extraction_target=[ExprExtractionConfig(), LatexExtractionConfig()], - pred_extraction_target=[ExprExtractionConfig(), LatexExtractionConfig()], + gold_extraction_target=[ + ExprExtractionConfig(), + LatexExtractionConfig(), + ], + pred_extraction_target=[ + ExprExtractionConfig(), + LatexExtractionConfig(), + ], precision=6, ), ), @@ -176,8 +207,16 @@ class Metrics(Enum): metric_name=["BERTScore-P", "BERTScore-R", "BERTScore-F"], sample_level_fn=BertScore(normalize_gold=remove_braces, normalize_pred=remove_braces_and_strip), category=SamplingMethod.GENERATIVE, - corpus_level_fn={"BERTScore-P": np.mean, "BERTScore-R": np.mean, "BERTScore-F": np.mean}, - higher_is_better={"BERTScore-P": True, "BERTScore-R": True, "BERTScore-F": True}, + corpus_level_fn={ + "BERTScore-P": np.mean, + "BERTScore-R": np.mean, + "BERTScore-F": np.mean, + }, + higher_is_better={ + "BERTScore-P": True, + "BERTScore-R": True, + "BERTScore-F": True, + }, ) bits_per_byte = CorpusLevelMetric( metric_name="bits_per_byte", @@ -237,13 +276,30 @@ class Metrics(Enum): higher_is_better=True, ) copyright = SampleLevelMetricGrouping( - metric_name=["longest_common_prefix_length", "edit_distance", "edit_similarity"], + metric_name=[ + "longest_common_prefix_length", + "edit_distance", + "edit_similarity", + ], sample_level_fn=StringDistance( - metric_types=["longest_common_prefix_length", "edit_distance", "edit_similarity"], strip_prediction=True + metric_types=[ + "longest_common_prefix_length", + "edit_distance", + "edit_similarity", + ], + strip_prediction=True, ), category=SamplingMethod.GENERATIVE, - corpus_level_fn={"longest_common_prefix_length": max, "edit_distance": min, "edit_similarity": max}, - higher_is_better={"longest_common_prefix_length": True, "edit_distance": False, "edit_similarity": True}, + corpus_level_fn={ + "longest_common_prefix_length": max, + "edit_distance": min, + "edit_similarity": max, + }, + higher_is_better={ + "longest_common_prefix_length": True, + "edit_distance": False, + "edit_similarity": True, + }, ) drop = SampleLevelMetricGrouping( metric_name=["em", "f1"], @@ -267,7 +323,10 @@ class Metrics(Enum): precision=5, gold_extraction_target=(ExprExtractionConfig(),), # Match boxed first before trying other regexes - pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)), + pred_extraction_target=( + ExprExtractionConfig(), + LatexExtractionConfig(boxed_match_priority=0), + ), aggregation_function=max, ), category=SamplingMethod.GENERATIVE, @@ -275,9 +334,15 @@ class Metrics(Enum): higher_is_better=True, ) extractiveness = SampleLevelMetricGrouping( - metric_name=["summarization_coverage", "summarization_density", 
"summarization_compression"], + metric_name=[ + "summarization_coverage", + "summarization_density", + "summarization_compression", + ], sample_level_fn=Extractiveness( - normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text" + normalize_input=remove_braces, + normalize_pred=remove_braces_and_strip, + input_column="text", ), category=SamplingMethod.GENERATIVE, corpus_level_fn={ @@ -292,9 +357,16 @@ class Metrics(Enum): }, ) extractiveness_de = SampleLevelMetricGrouping( - metric_name=["summarization_coverage", "summarization_density", "summarization_compression"], + metric_name=[ + "summarization_coverage", + "summarization_density", + "summarization_compression", + ], sample_level_fn=Extractiveness( - normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text", language="de" + normalize_input=remove_braces, + normalize_pred=remove_braces_and_strip, + input_column="text", + language="de", ), category=SamplingMethod.GENERATIVE, corpus_level_fn={ @@ -309,9 +381,16 @@ class Metrics(Enum): }, ) extractiveness_fr = SampleLevelMetricGrouping( - metric_name=["summarization_coverage", "summarization_density", "summarization_compression"], + metric_name=[ + "summarization_coverage", + "summarization_density", + "summarization_compression", + ], sample_level_fn=Extractiveness( - normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text", language="fr" + normalize_input=remove_braces, + normalize_pred=remove_braces_and_strip, + input_column="text", + language="fr", ), category=SamplingMethod.GENERATIVE, corpus_level_fn={ @@ -326,9 +405,16 @@ class Metrics(Enum): }, ) extractiveness_it = SampleLevelMetricGrouping( - metric_name=["summarization_coverage", "summarization_density", "summarization_compression"], + metric_name=[ + "summarization_coverage", + "summarization_density", + "summarization_compression", + ], sample_level_fn=Extractiveness( - normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text", language="it" + normalize_input=remove_braces, + normalize_pred=remove_braces_and_strip, + input_column="text", + language="it", ), category=SamplingMethod.GENERATIVE, corpus_level_fn={ @@ -366,7 +452,9 @@ class Metrics(Enum): faithfulness = SampleLevelMetric( metric_name="summac", sample_level_fn=Faithfulness( - normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text" + normalize_input=remove_braces, + normalize_pred=remove_braces_and_strip, + input_column="text", ), category=SamplingMethod.GENERATIVE, corpus_level_fn=np.mean, @@ -390,7 +478,10 @@ class Metrics(Enum): precision=5, gold_extraction_target=(ExprExtractionConfig(),), # Match boxed first before trying other regexes - pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)), + pred_extraction_target=( + ExprExtractionConfig(), + LatexExtractionConfig(boxed_match_priority=0), + ), aggregation_function=max, ), ), @@ -409,7 +500,10 @@ class Metrics(Enum): precision=5, gold_extraction_target=(LatexExtractionConfig(),), # Match boxed first before trying other regexes - pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)), + pred_extraction_target=( + ExprExtractionConfig(), + LatexExtractionConfig(boxed_match_priority=0), + ), aggregation_function=max, ), ), @@ -473,8 +567,14 @@ class Metrics(Enum): # Extracting mathematical expressions and latex expressions 
sample_scoring_function=MultilingualExtractiveMatchMetric( language=Language.ENGLISH, - gold_extraction_target=[ExprExtractionConfig(), LatexExtractionConfig()], - pred_extraction_target=[ExprExtractionConfig(), LatexExtractionConfig()], + gold_extraction_target=[ + ExprExtractionConfig(), + LatexExtractionConfig(), + ], + pred_extraction_target=[ + ExprExtractionConfig(), + LatexExtractionConfig(), + ], precision=6, ), ), @@ -519,8 +619,18 @@ class Metrics(Enum): normalize_pred=bigbench_normalizer, ), category=SamplingMethod.GENERATIVE, - corpus_level_fn={"rouge1": np.mean, "rouge2": np.mean, "rougeL": np.mean, "rougeLsum": np.mean}, - higher_is_better={"rouge1": True, "rouge2": True, "rougeL": True, "rougeLsum": True}, + corpus_level_fn={ + "rouge1": np.mean, + "rouge2": np.mean, + "rougeL": np.mean, + "rougeLsum": np.mean, + }, + higher_is_better={ + "rouge1": True, + "rouge2": True, + "rougeL": True, + "rougeLsum": True, + }, ) rouge1 = SampleLevelMetric( metric_name="rouge1", @@ -593,10 +703,16 @@ class Metrics(Enum): sample_level_fn=MultilingualExtractiveMatchMetric( language=Language.ENGLISH, gold_extraction_target=[ - IndicesExtractionConfig(prefix_for_extraction="NativeLetters", try_extract_without_anchor=True) + IndicesExtractionConfig( + prefix_for_extraction="NativeLetters", + try_extract_without_anchor=True, + ) ], pred_extraction_target=[ - IndicesExtractionConfig(prefix_for_extraction="NativeLetters", try_extract_without_anchor=True) + IndicesExtractionConfig( + prefix_for_extraction="NativeLetters", + try_extract_without_anchor=True, + ) ], precision=6, ), @@ -610,10 +726,16 @@ class Metrics(Enum): sample_scoring_function=MultilingualExtractiveMatchMetric( language=Language.ENGLISH, gold_extraction_target=[ - IndicesExtractionConfig(prefix_for_extraction="NativeLetters", try_extract_without_anchor=True) + IndicesExtractionConfig( + prefix_for_extraction="NativeLetters", + try_extract_without_anchor=True, + ) ], pred_extraction_target=[ - IndicesExtractionConfig(prefix_for_extraction="NativeLetters", try_extract_without_anchor=True) + IndicesExtractionConfig( + prefix_for_extraction="NativeLetters", + try_extract_without_anchor=True, + ) ], precision=6, ), @@ -622,6 +744,13 @@ class Metrics(Enum): corpus_level_fn=np.mean, higher_is_better=True, ) + tvd_mi = CorpusLevelMetric( + metric_name="tvd_mi", + sample_level_fn=JudgeLLMTVDMI(), + category=SamplingMethod.GENERATIVE, + corpus_level_fn=CorpusLevelTVDMI(), + higher_is_better=True, + ) def __str__(self): return self.name.replace("_at_", "@") diff --git a/src/lighteval/metrics/metrics_corpus.py b/src/lighteval/metrics/metrics_corpus.py index 8532e5080..2cd630b54 100644 --- a/src/lighteval/metrics/metrics_corpus.py +++ b/src/lighteval/metrics/metrics_corpus.py @@ -192,3 +192,22 @@ def compute_corpus(self, items: list[PerplexityCorpusMetricInput]): return math.exp(-sum(logprobs) / sum(weights)) if self.metric_type == "bits_per_byte": return -sum(logprobs) / sum(weights) * 1 / math.log(2) + + +class CorpusLevelTVDMI: + def __call__(self, items): + # items: list of dicts returned by JudgeLLMTVDMI.compute + labels = np.array([it["label"] for it in items]) + preds = np.array([it["pred"] for it in items]) + + pos = labels == 1 + neg = ~pos + + if pos.sum() == 0 or neg.sum() == 0: + return {"tvd_mi": float("nan")} + + tpr = (preds[pos] == 1).mean() + tnr = (preds[neg] == 0).mean() + tvd_mi = tpr + tnr - 1.0 + + return {"tvd_mi": float(tvd_mi)} From 81ec63c049869d2eaa9939912e7feb1f15845e2d Mon Sep 17 00:00:00 2001 From: 
zrobertson466920 Date: Thu, 20 Nov 2025 15:58:57 -0800 Subject: [PATCH 4/6] add unit-testing + response normalization --- src/lighteval/metrics/utils/judge_utils.py | 14 ++- tests/unit/metrics/test_tvd_mi.py | 140 +++++++++++++++++++++ 2 files changed, 150 insertions(+), 4 deletions(-) create mode 100644 tests/unit/metrics/test_tvd_mi.py diff --git a/src/lighteval/metrics/utils/judge_utils.py b/src/lighteval/metrics/utils/judge_utils.py index e7cc299fd..b78d00dc0 100644 --- a/src/lighteval/metrics/utils/judge_utils.py +++ b/src/lighteval/metrics/utils/judge_utils.py @@ -158,10 +158,16 @@ def get_judge_prompt_tvdmi( def process_judge_response_tvdmi(response: str) -> int: - if response == "A": - return 1 # SAME - elif response == "B": - return 0 # DIFFERENT + # Normalize + if response is None: + return 0 + + cleaned = response.strip().lower() + + if cleaned == "a": + return 1 + elif cleaned == "b": + return 0 else: logger.warning(f"Unknown response from TVD-MI judge: {response!r}") return 0 diff --git a/tests/unit/metrics/test_tvd_mi.py b/tests/unit/metrics/test_tvd_mi.py new file mode 100644 index 000000000..a948c462c --- /dev/null +++ b/tests/unit/metrics/test_tvd_mi.py @@ -0,0 +1,140 @@ +import math +from dataclasses import dataclass + +import pytest + +from lighteval.metrics.metrics_corpus import CorpusLevelTVDMI +from lighteval.metrics.metrics_sample import JudgeLLMTVDMI +from lighteval.metrics.utils.judge_utils import ( + get_judge_prompt_tvdmi, + process_judge_response_tvdmi, +) + + +def test_get_judge_prompt_tvdmi_injects_responses(): + question = "Resp A" + answer = "Resp B" + + messages = get_judge_prompt_tvdmi(question=question, answer=answer, options=None, gold=None) + + # Should be a single chat message + assert isinstance(messages, list) + assert len(messages) == 1 + msg = messages[0] + assert msg["role"] == "user" + + content = msg["content"] + # Basic structure checks + assert "Response A:" in content + assert "Response B:" in content + assert "Resp A" in content + assert "Resp B" in content + # Should mention A/B grading + assert "A:" in content + assert "B:" in content + + +def test_process_judge_response_tvdmi_maps_A_B(): + assert process_judge_response_tvdmi("A") == 1 + assert process_judge_response_tvdmi("B") == 0 + # Robust to case/whitespace + assert process_judge_response_tvdmi(" a \n") == 1 + assert process_judge_response_tvdmi(" b\t") == 0 + + +def test_process_judge_response_tvdmi_unknown_falls_back_to_0(caplog): + with caplog.at_level("WARNING"): + out = process_judge_response_tvdmi("weird") + assert out == 0 + # Optional: check that we actually logged something + assert any("TVD-MI judge" in rec.message for rec in caplog.records) + + +def test_corpus_level_tvdmi_perfect_critic(): + # Always correct on both positive and negative + items = [ + {"label": 1, "pred": 1}, + {"label": 1, "pred": 1}, + {"label": 0, "pred": 0}, + {"label": 0, "pred": 0}, + ] + + result = CorpusLevelTVDMI()(items) + assert "tvd_mi" in result + assert result["tvd_mi"] == pytest.approx(1.0) + + +def test_corpus_level_tvdmi_random_critic(): + # 50% TPR, 50% TNR → TVD-MI = 0 + items = [ + {"label": 1, "pred": 1}, + {"label": 1, "pred": 0}, + {"label": 0, "pred": 0}, + {"label": 0, "pred": 1}, + ] + + result = CorpusLevelTVDMI()(items) + assert result["tvd_mi"] == pytest.approx(0.0) + + +def test_corpus_level_tvdmi_missing_class_returns_nan(): + # No negatives → TVD-MI undefined + items = [ + {"label": 1, "pred": 1}, + {"label": 1, "pred": 0}, + ] + + result = CorpusLevelTVDMI()(items) + assert 
math.isnan(result["tvd_mi"]) + + +@dataclass +class FakeDoc: + response_a: str + response_b: str + pair_label: int + + +def test_judge_tvdmi_compute(monkeypatch): + judge = JudgeLLMTVDMI() + + # Two examples: one positive, one negative + docs = [ + FakeDoc("A1", "A2", 1), + FakeDoc("B1", "B2", 0), + ] + + # Fake judge backend: we want to check what arguments it receives, + # and return deterministic scores/prompts/responses. + def fake_evaluate_answer_batch(questions, answers, options, golds, **kwargs): + # Input wiring checks + assert questions == ["A1", "B1"] + assert answers == ["A2", "B2"] + assert options == [None, None] + assert golds == [None, None] + + scores = [1, 0] # predict SAME for first, DIFFERENT for second + prompts = ["prompt-0", "prompt-1"] + responses = ["A", "B"] # raw judge outputs + return scores, prompts, responses + + # Attach a fake .judge with our method + class FakeInnerJudge: + def evaluate_answer_batch(self, *args, **kwargs): + return fake_evaluate_answer_batch(*args, **kwargs) + + monkeypatch.setattr(judge, "judge", FakeInnerJudge()) + + metrics = judge.compute(responses=[], docs=docs) + + assert len(metrics) == 2 + + # Check labels and preds propagated correctly + assert metrics[0]["label"] == 1 + assert metrics[0]["pred"] == 1 + assert metrics[1]["label"] == 0 + assert metrics[1]["pred"] == 0 + + # Check extra fields exist (names match your short_judge_name) + assert any(k.startswith("user_prompt_") for k in metrics[0].keys()) + assert any(k.startswith("judgement_") for k in metrics[0].keys()) From 0febc81496db3f924c5ac1ada14d1f920593a773 Mon Sep 17 00:00:00 2001 From: zrobertson466920 Date: Fri, 21 Nov 2025 15:37:42 -0800 Subject: [PATCH 5/6] Document tvd_mi metric --- docs/source/metric-list.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/metric-list.mdx b/docs/source/metric-list.mdx index 06d3dd069..66b862ffa 100644 --- a/docs/source/metric-list.mdx +++ b/docs/source/metric-list.mdx @@ -61,3 +61,4 @@ These metrics need the model to generate an output. They are therefore slower. - `llm_judge_llama_3_405b`: Can be used for any generative task, the model will be scored by a Llama 3.405B model using the HuggingFace API. - `llm_judge_multi_turn_gpt3p5`: Can be used for any generative task, the model will be scored by a GPT3.5 model using the OpenAI API. It is used for multiturn tasks like mt-bench. - `llm_judge_multi_turn_llama_3_405b`: Can be used for any generative task, the model will be scored by a Llama 3.405B model using the HuggingFace API. It is used for multiturn tasks like mt-bench. +- `tvd_mi`: Corpus-level LLM-as-a-judge metric that estimates a lower bound on total variation mutual information using paired responses. It assumes each example has two responses and a binary label indicating whether they are from the same underlying item (`1`) or from different items (`0`), and computes `TPR + TNR - 1` from the judge’s binary decisions. 
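
# --- Illustrative sketch (not part of the patch): how the corpus-level ---
# --- TVD-MI lower bound in `CorpusLevelTVDMI` behaves on toy judge output ---
# The aggregator consumes the per-pair dicts emitted by JudgeLLMTVDMI.compute
# ({"label": 0/1, "pred": 0/1, ...}) and returns TPR + TNR - 1, i.e. the
# detector-based lower bound on total variation described in the docs entry
# above. The numbers below are made up purely for illustration and assume the
# patch has been applied.

from lighteval.metrics.metrics_corpus import CorpusLevelTVDMI

# A judge that is right on 3/4 of the "same task" pairs (TPR = 0.75)
# and on 1/2 of the "different task" pairs (TNR = 0.5).
items = [
    {"label": 1, "pred": 1},
    {"label": 1, "pred": 1},
    {"label": 1, "pred": 1},
    {"label": 1, "pred": 0},
    {"label": 0, "pred": 0},
    {"label": 0, "pred": 1},
]

print(CorpusLevelTVDMI()(items))  # -> {'tvd_mi': 0.25}, since 0.75 + 0.5 - 1 = 0.25
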
From 5a0c3dfe9f6674defcf46d86ab8a4e9e26ce27aa Mon Sep 17 00:00:00 2001 From: zrobertson466920 Date: Mon, 24 Nov 2025 13:29:17 -0800 Subject: [PATCH 6/6] Add inspect implementation for tvd-mi metric --- src/lighteval/metrics/metrics.py | 40 +++++++++++++++++++++++++++ tests/unit/metrics/test_tvd_mi.py | 45 +++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py index 2c6e24b99..ae3470465 100644 --- a/src/lighteval/metrics/metrics.py +++ b/src/lighteval/metrics/metrics.py @@ -58,6 +58,7 @@ PassAtK, Recall, StringDistance, + process_judge_response_tvdmi, ) from lighteval.metrics.normalizations import ( bigbench_normalizer, @@ -168,6 +169,45 @@ async def score(state: TaskState, target: Target): return score +@scorer(metrics=[accuracy(), stderr()]) +def tvd_mi_scorer(): + """ + Inspect-compatible scorer for TVD-MI pair classification. + """ + + def _normalize_gold(label: str) -> int | None: + if not isinstance(label, str): + return None + s = label.strip().upper() + if s in {"A", "SAME", "POSITIVE", "1"}: + return 1 + if s in {"B", "DIFFERENT", "NEGATIVE", "0"}: + return 0 + return None + + async def score(state: TaskState, target: Target) -> Score: + raw_pred = state.output.completion + # Interpretation mapping logic + pred_val = process_judge_response_tvdmi(raw_pred) + + try: + pred_label = int(pred_val) + except Exception: + pred_label = None + + gold_label = _normalize_gold(str(target.text)) + + correct = pred_label is not None and gold_label is not None and pred_label == gold_label + # Correct or Incorrect, used by inspect-ai backend + return Score( + value="C" if correct else "I", + explanation=raw_pred, + answer=str(pred_label), + ) + + return score + + class Metrics(Enum): acc_golds_likelihood = SampleLevelMetric( # todo: we need a better name for this! 
metric_name="acc", diff --git a/tests/unit/metrics/test_tvd_mi.py b/tests/unit/metrics/test_tvd_mi.py index a948c462c..f2bb5e088 100644 --- a/tests/unit/metrics/test_tvd_mi.py +++ b/tests/unit/metrics/test_tvd_mi.py @@ -1,8 +1,10 @@ +import asyncio import math from dataclasses import dataclass import pytest +from lighteval.metrics.metrics import tvd_mi_scorer from lighteval.metrics.metrics_corpus import CorpusLevelTVDMI from lighteval.metrics.metrics_sample import JudgeLLMTVDMI from lighteval.metrics.utils.judge_utils import ( @@ -138,3 +140,46 @@ def evaluate_answer_batch(self, *args, **kwargs): # Check extra fields exist (names match your short_judge_name) assert any(k.startswith("user_prompt_") for k in metrics[0].keys()) assert any(k.startswith("judgement_") for k in metrics[0].keys()) + + +# ---- Inspect-compatible scorer tests ---- + + +class _DummyOutput: + def __init__(self, completion: str): + self.completion = completion + + +class _DummyState: + def __init__(self, completion: str): + self.output = _DummyOutput(completion) + + +class _DummyTarget: + def __init__(self, text: str): + self.text = text + + +def test_tvd_mi_scorer_matches_label_A(): + """The inspect-ai tvd_mi_scorer should mark matching 'A' labels as correct.""" + scorer_fn = tvd_mi_scorer() + + state = _DummyState("A") # model/judge output + target = _DummyTarget("A") # gold label + + score = asyncio.run(scorer_fn(state, target)) + + assert score.value == "C" + assert score.answer == "1" # normalized positive class + + +def test_tvd_mi_scorer_mismatch_is_incorrect(): + """Mismatched A/B labels should be scored as incorrect.""" + scorer_fn = tvd_mi_scorer() + + state = _DummyState("B") # model says DIFFERENT + target = _DummyTarget("A") # gold SAME + + score = asyncio.run(scorer_fn(state, target)) + + assert score.value == "I"
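
# --- Illustrative sketch (not part of the patch): building the paired docs ---
# JudgeLLMTVDMI.compute expects docs carrying `response_a`, `response_b` and a
# binary `pair_label` (1 = same underlying item, 0 = different items), as
# exercised by FakeDoc in the unit tests above. The PR does not include the
# task that constructs those pairs, so the pairing scheme below is only a
# guess at one reasonable protocol: positive pairs cross two agents' answers
# to the same item, negative pairs cross them on different items.

import random
from dataclasses import dataclass


@dataclass
class PairDoc:  # hypothetical container mirroring the fields the metric reads
    response_a: str
    response_b: str
    pair_label: int


def build_pairs(responses_a: list[str], responses_b: list[str], seed: int = 0) -> list[PairDoc]:
    """responses_a[i] and responses_b[i] are two agents' answers to the same item i."""
    assert len(responses_a) == len(responses_b)
    n = len(responses_a)
    rng = random.Random(seed)
    docs = []
    for i in range(n):
        # Positive pair: both responses come from item i.
        docs.append(PairDoc(responses_a[i], responses_b[i], 1))
        # Negative pair: response_b is drawn from a different item j != i.
        if n > 1:
            j = (i + rng.randrange(1, n)) % n
            docs.append(PairDoc(responses_a[i], responses_b[j], 0))
    return docs


# Usage (hypothetical): feed the pairs to the judge, then aggregate.
#   docs = build_pairs(model_a_answers, model_b_answers)
#   metrics = JudgeLLMTVDMI().compute(responses=[], docs=docs)
#   CorpusLevelTVDMI()(metrics)  # -> {"tvd_mi": ...}
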