diff --git a/docs/BEHAVIORAL_PREDICTION_ANALYSIS.md b/docs/BEHAVIORAL_PREDICTION_ANALYSIS.md index 6143b80..99a48f2 100644 --- a/docs/BEHAVIORAL_PREDICTION_ANALYSIS.md +++ b/docs/BEHAVIORAL_PREDICTION_ANALYSIS.md @@ -240,6 +240,26 @@ python scripts/lasso_with_ablations.py \ Test whether top-N receptors are *sufficient* to maintain model performance: +--- + +## 5.6 Control-Subtracted (ΔPER) Runs + +To fit LASSO on control-subtracted targets (ΔPER = opto − control), use the CLI: + +```bash +python scripts/run_lasso_behavioral_prediction.py \ + --door_cache door_cache \ + --behavior_csv /path/to/reaction_rates_summary_unordered.csv \ + --condition opto_hex \ + --subtract_control \ + --missing_control_policy skip \ + --output_dir outputs/lasso_behavioral_prediction +``` + +Use `--control_condition` to override the default opto→control mapping, and +`--also_run_raw` to also fit raw PER and write a side-by-side comparison summary CSV (`lasso_summary_comparison.csv` in `--output_dir`). +If a condition has no matched control (no entry in the opto→control mapping, or the control dataset is absent from the CSV), the CLI logs a warning and falls back to raw (non-subtracted) mode for that condition. + ```bash conda activate DoOR python scripts/lasso_with_focus_mode.py \ diff --git a/scripts/run_lasso_behavioral_prediction.py b/scripts/run_lasso_behavioral_prediction.py new file mode 100644 index 0000000..a965e23 --- /dev/null +++ b/scripts/run_lasso_behavioral_prediction.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +""" +Run LASSO behavioral prediction with optional control subtraction. + +Example: + python scripts/run_lasso_behavioral_prediction.py \ + --door_cache door_cache \ + --behavior_csv /path/to/reaction_rates_summary_unordered.csv \ + --condition opto_hex \ + --output_dir outputs/lasso_behavioral_prediction \ + --prediction_mode test_odorant \ + --cv_folds 5 \ + --lambda_range 0.0001,0.001,0.01,0.1,1.0 +""" + +from __future__ import annotations + +import argparse +import logging +from pathlib import Path +import sys +from typing import List, Optional + +# Add src to path for repo-local runs (matches other scripts in this repo).
+sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) + +from door_toolkit.pathways import LassoBehavioralPredictor + +logger = logging.getLogger(__name__) + + +def _parse_conditions(values: List[str]) -> List[str]: + conditions: List[str] = [] + seen = set() + for value in values: + for token in value.split(","): + token = token.strip() + if not token or token in seen: + continue + conditions.append(token) + seen.add(token) + return conditions + + +def _parse_lambda_range(value: Optional[str]) -> Optional[List[float]]: + if value is None: + return None + tokens = [token.strip() for token in value.split(",") if token.strip()] + if not tokens: + raise ValueError("lambda_range cannot be empty.") + try: + return [float(token) for token in tokens] + except ValueError as exc: + raise ValueError(f"Invalid lambda_range value: {value}") from exc + + +def _run_condition( + predictor: LassoBehavioralPredictor, + *, + condition_name: str, + mode_label: str, + subtract_control: bool, + control_condition: Optional[str], + missing_control_policy: str, + prediction_mode: str, + lambda_range: Optional[List[float]], + cv_folds: int, + output_dir: Path, +): + logger.info("Running %s mode for condition '%s'", mode_label, condition_name) + condition_dir = output_dir / condition_name + condition_dir.mkdir(parents=True, exist_ok=True) + results = predictor.fit_behavior( + condition_name=condition_name, + lambda_range=lambda_range, + cv_folds=cv_folds, + prediction_mode=prediction_mode, + subtract_control=subtract_control, + control_condition=control_condition, + missing_control_policy=missing_control_policy, + ) + + prefix = f"{mode_label}" + results.plot_predictions(save_to=str(condition_dir / f"{prefix}_predictions.png")) + results.plot_receptors(save_to=str(condition_dir / f"{prefix}_receptors.png")) + results.export_csv(str(condition_dir / f"{prefix}_results.csv")) + results.export_json(str(condition_dir / f"{prefix}_model.json")) + + return results + + +def 
_is_missing_control_error(exc: Exception) -> bool: + message = str(exc).lower() + if "no matched control mapping" in message: + return True + if "control" in message and "not found in behavioral data" in message: + return True + return False + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run LASSO behavioral prediction with optional control subtraction.", + ) + + parser.add_argument("--door_cache", required=True, help="Path to DoOR cache directory.") + parser.add_argument("--behavior_csv", required=True, help="Path to behavioral matrix CSV.") + parser.add_argument( + "--condition", + required=True, + action="append", + help="Condition name(s). Repeat flag or pass comma-separated list.", + ) + parser.add_argument("--output_dir", required=True, help="Directory to write outputs.") + + parser.add_argument("--subtract_control", action="store_true", help="Fit ΔPER = opto - control.") + parser.add_argument( + "--also_run_raw", + action="store_true", + help="If set, run raw PER alongside ΔPER and write a comparison summary.", + ) + parser.add_argument( + "--control_condition", + default=None, + help="Optional control dataset override (applies to all conditions).", + ) + parser.add_argument( + "--missing_control_policy", + choices=["skip", "zero", "error"], + default="skip", + help="How to handle missing control values.", + ) + + parser.add_argument( + "--prediction_mode", + choices=["test_odorant", "trained_odorant", "interaction"], + default="test_odorant", + help="Feature extraction mode.", + ) + parser.add_argument("--cv_folds", type=int, default=5, help="Number of CV folds.") + parser.add_argument( + "--lambda_range", + default=None, + help="Comma-separated lambda values (e.g., 0.0001,0.001,0.01).", + ) + + return parser.parse_args() + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + args = _parse_args() + + conditions = 
_parse_conditions(args.condition) + if not conditions: + raise ValueError("No valid conditions provided.") + + lambda_range = _parse_lambda_range(args.lambda_range) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + predictor = LassoBehavioralPredictor( + doorcache_path=args.door_cache, + behavior_csv_path=args.behavior_csv, + ) + + modes_to_run = [] + if args.also_run_raw or not args.subtract_control: + modes_to_run.append(("raw", False)) + if args.subtract_control: + modes_to_run.append(("delta", True)) + + summary_rows = [] + for condition_name in conditions: + ran_raw = False + for mode_label, subtract_control in modes_to_run: + try: + results = _run_condition( + predictor, + condition_name=condition_name, + mode_label=mode_label, + subtract_control=subtract_control, + control_condition=args.control_condition, + missing_control_policy=args.missing_control_policy, + prediction_mode=args.prediction_mode, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + output_dir=output_dir, + ) + except ValueError as exc: + if subtract_control and _is_missing_control_error(exc): + logger.warning( + "No control found for '%s'; falling back to raw mode.", + condition_name, + ) + if ran_raw: + logger.info("Raw mode already completed for '%s'; skipping delta.", condition_name) + continue + results = _run_condition( + predictor, + condition_name=condition_name, + mode_label="raw", + subtract_control=False, + control_condition=None, + missing_control_policy=args.missing_control_policy, + prediction_mode=args.prediction_mode, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + output_dir=output_dir, + ) + mode_label = "raw" + else: + raise + + if mode_label == "raw": + ran_raw = True + + summary_rows.append( + { + "condition": condition_name, + "mode": mode_label, + "cv_mse": results.cv_mse, + "cv_r2": results.cv_r2_score, + "n_receptors_selected": results.n_receptors_selected, + "control_condition": results.control_condition, + } + 
) + + if args.also_run_raw: + import pandas as pd + + summary_df = pd.DataFrame(summary_rows) + summary_df.to_csv(output_dir / "lasso_summary_comparison.csv", index=False) + logger.info("Wrote comparison summary to %s", output_dir / "lasso_summary_comparison.csv") + + +if __name__ == "__main__": + main() diff --git a/src/door_toolkit/pathways/behavioral_prediction.py b/src/door_toolkit/pathways/behavioral_prediction.py index 77af4b4..a080a2e 100644 --- a/src/door_toolkit/pathways/behavioral_prediction.py +++ b/src/door_toolkit/pathways/behavioral_prediction.py @@ -16,11 +16,12 @@ import json import logging +import re import warnings from dataclasses import dataclass, field from difflib import get_close_matches from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Literal, Optional, Tuple, Union import matplotlib import matplotlib.pyplot as plt @@ -282,6 +283,9 @@ class BehaviorModelResults: actual_per: Actual PER values predicted_per: Predicted PER values receptor_coverage: Number of receptors with data for trained odorant + subtract_control: Whether target PER values were control-subtracted + control_condition: Control condition used for subtraction (if any) + n_pairs_used: Number of matched opto/control odorant pairs used """ condition_name: str @@ -298,6 +302,9 @@ class BehaviorModelResults: receptor_coverage: int feature_matrix: Optional[np.ndarray] = None receptor_names: List[str] = field(default_factory=list) + subtract_control: bool = False + control_condition: Optional[str] = None + n_pairs_used: int = 0 def get_top_receptors(self, n: int = 10) -> List[Tuple[str, float]]: """ @@ -467,6 +474,9 @@ def export_json(self, output_path: str): "cv_mse": float(self.cv_mse), "n_receptors_selected": int(self.n_receptors_selected), "receptor_coverage": int(self.receptor_coverage), + "subtract_control": bool(self.subtract_control), + "control_condition": self.control_condition, + "n_pairs_used": 
int(self.n_pairs_used), "lasso_weights": {k: float(v) for k, v in self.lasso_weights.items()}, "top_10_receptors": [ {"receptor": r, "weight": float(w)} for r, w in self.get_top_receptors(10) @@ -587,6 +597,14 @@ class LassoBehavioralPredictor: "opto_3oct": "3-Octonol", # Alternative naming } + # Matched opto → control mapping for control-subtracted fits. + # Local mapping because no shared helper exists in the codebase today. + OPTO_CONTROL_MAPPING = { + "opto_hex": "hex_control", + "opto_EB": "EB_control", + "opto_benz_1": "Benz_control", + } + def __init__( self, doorcache_path: str, @@ -755,6 +773,42 @@ def match_odorant_name(self, csv_odorant_name: str) -> Optional[str]: csv_name_clean = csv_odorant_name.lower().replace("_", "").replace(" ", "").replace("-", "") return self.odorant_to_door.get(csv_name_clean) + @staticmethod + def _normalize_dataset_name(dataset_name: str) -> str: + """Normalize dataset labels for case/format-insensitive matching.""" + return re.sub(r"[^a-z0-9]+", "", str(dataset_name).lower()) + + def _resolve_dataset_name(self, dataset_name: str) -> Optional[str]: + """ + Resolve a dataset label to the exact index entry in behavioral_data. + + Returns None if no match is found; raises if the match is ambiguous. + """ + if dataset_name in self.behavioral_data.index: + return dataset_name + + normalized = self._normalize_dataset_name(dataset_name) + matches = [ + idx + for idx in self.behavioral_data.index + if self._normalize_dataset_name(idx) == normalized + ] + if len(matches) == 1: + return matches[0] + if len(matches) > 1: + raise ValueError( + f"Ambiguous condition '{dataset_name}' matches multiple datasets: {matches}. " + "Please specify the exact dataset name." 
+ ) + return None + + def _infer_control_condition(self, condition_name: str) -> Optional[str]: + """Infer matched control condition from known opto/control pairs.""" + mapping = { + self._normalize_dataset_name(k): v for k, v in self.OPTO_CONTROL_MAPPING.items() + } + return mapping.get(self._normalize_dataset_name(condition_name)) + def get_receptor_profile( self, odorant_name: str, fill_missing: float = 0.0 ) -> Tuple[np.ndarray, int]: @@ -808,6 +862,9 @@ def fit_behavior( lambda_range: Optional[List[float]] = None, cv_folds: int = 5, prediction_mode: str = "test_odorant", + subtract_control: bool = False, + control_condition: Optional[str] = None, + missing_control_policy: Literal["skip", "zero", "error"] = "skip", ) -> BehaviorModelResults: """ Fit LASSO model to predict PER from receptor profiles. @@ -821,6 +878,12 @@ def fit_behavior( - "test_odorant": Use test odorant receptor profiles (default) - "trained_odorant": Use trained odorant receptor profile - "interaction": Use element-wise product of trained × test + subtract_control: If True, fit on PER(opto) - PER(control) + control_condition: Optional control dataset override + missing_control_policy: How to handle missing control values: + - "skip": use only odorants present in both opto/control + - "zero": treat missing control values as 0 (opto must be present) + - "error": raise if control is missing where opto is present Returns: BehaviorModelResults object with fitted model and metrics @@ -832,18 +895,90 @@ def fit_behavior( """ logger.info(f"Fitting LASSO model for condition: {condition_name}") - # Get behavioral responses for this condition - if condition_name not in self.behavioral_data.index: + resolved_condition = self._resolve_dataset_name(condition_name) + if resolved_condition is None: raise ValueError(f"Condition '{condition_name}' not found in behavioral data") + condition_name = resolved_condition + + control_condition_resolved: Optional[str] = None + n_pairs_used = 0 + + if 
subtract_control: + if missing_control_policy not in {"skip", "zero", "error"}: + raise ValueError( + f"Unknown missing_control_policy '{missing_control_policy}'. " + "Expected one of: skip, zero, error." + ) + + if control_condition is not None: + control_condition_resolved = self._resolve_dataset_name(control_condition) + if control_condition_resolved is None: + raise ValueError( + f"Control condition '{control_condition}' not found in behavioral data" + ) + else: + control_candidate = self._infer_control_condition(condition_name) + if control_candidate is None: + raise ValueError( + f"No matched control mapping for '{condition_name}'. " + "Provide control_condition or set subtract_control=False." + ) + control_condition_resolved = self._resolve_dataset_name(control_candidate) + if control_condition_resolved is None: + raise ValueError( + f"Matched control '{control_candidate}' not found in behavioral data" + ) + + if control_condition_resolved == condition_name: + raise ValueError( + f"Control condition '{control_condition_resolved}' matches opto condition " + f"'{condition_name}'." 
+ ) + + per_opto = self.behavioral_data.loc[condition_name] + per_ctrl = self.behavioral_data.loc[control_condition_resolved] - per_responses = self.behavioral_data.loc[condition_name] + if missing_control_policy == "skip": + valid_mask = per_opto.notna() & per_ctrl.notna() + valid_odorants = (per_opto - per_ctrl)[valid_mask] + elif missing_control_policy == "zero": + valid_mask = per_opto.notna() + per_ctrl_filled = per_ctrl.fillna(0.0) + valid_odorants = (per_opto - per_ctrl_filled)[valid_mask] + else: + missing_mask = per_opto.notna() & per_ctrl.isna() + if missing_mask.any(): + missing_odorants = [str(o) for o in per_opto.index[missing_mask]] + preview = ", ".join(missing_odorants[:5]) + if len(missing_odorants) > 5: + preview = f"{preview} (and {len(missing_odorants) - 5} more)" + raise ValueError( + f"Control condition '{control_condition_resolved}' has missing values for " + f"odorants present in '{condition_name}': {preview}" + ) + valid_mask = per_opto.notna() & per_ctrl.notna() + valid_odorants = (per_opto - per_ctrl)[valid_mask] + + if len(valid_odorants) == 0: + raise ValueError( + f"No valid opto/control pairs for condition '{condition_name}' " + f"and control '{control_condition_resolved}'." 
+ ) + + n_pairs_used = int(len(valid_odorants)) + logger.info( + f"Using control condition '{control_condition_resolved}' with policy " + f"'{missing_control_policy}': {n_pairs_used} odorants after alignment" + ) + else: + per_responses = self.behavioral_data.loc[condition_name] - # Filter out NaN and untested odorants - valid_odorants = per_responses.dropna() - if len(valid_odorants) == 0: - raise ValueError(f"No valid PER data for condition '{condition_name}'") + # Filter out NaN and untested odorants + valid_odorants = per_responses.dropna() + if len(valid_odorants) == 0: + raise ValueError(f"No valid PER data for condition '{condition_name}'") - logger.info(f"Found {len(valid_odorants)} valid test odorants") + logger.info(f"Found {len(valid_odorants)} valid test odorants") # Auto-detect trained odorant (best-effort for all prediction modes) # Decision: Attempt auto-detection for all modes to populate results.trained_odorant @@ -996,6 +1131,9 @@ def fit_behavior( receptor_coverage=trained_coverage, feature_matrix=X, receptor_names=list(active_receptor_names), + subtract_control=subtract_control, + control_condition=control_condition_resolved, + n_pairs_used=n_pairs_used, ) return results diff --git a/tests/test_lasso_behavioral_prediction.py b/tests/test_lasso_behavioral_prediction.py index b59b733..c6f818f 100644 --- a/tests/test_lasso_behavioral_prediction.py +++ b/tests/test_lasso_behavioral_prediction.py @@ -27,6 +27,18 @@ def mock_behavioral_csv(tmp_path): return csv_path +@pytest.fixture +def control_behavioral_csv(tmp_path): + """Create mock behavioral CSV with opto/control rows for subtraction tests.""" + csv_content = """dataset,Hexanol,Benzaldehyde,Linalool,Citral +opto_hex,0.8,0.2,0.4,0.6 +hex_control,0.1,,0.05,0.3 +""" + csv_path = tmp_path / "test_behavior_control.csv" + csv_path.write_text(csv_content) + return csv_path + + @pytest.fixture def lasso_predictor(mock_door_cache, mock_behavioral_csv): """Create LassoBehavioralPredictor instance for 
testing.""" @@ -38,6 +50,17 @@ def lasso_predictor(mock_door_cache, mock_behavioral_csv): ) +@pytest.fixture +def control_lasso_predictor(mock_door_cache, control_behavioral_csv): + """Create LassoBehavioralPredictor for control subtraction tests.""" + return LassoBehavioralPredictor( + doorcache_path=str(mock_door_cache), + behavior_csv_path=str(control_behavioral_csv), + scale_features=True, + scale_targets=False, + ) + + class TestLassoBehavioralPredictor: """Tests for LassoBehavioralPredictor class.""" @@ -418,3 +441,66 @@ def test_insufficient_samples(self, mock_door_cache, tmp_path): with pytest.raises(ValueError, match="Insufficient data"): predictor.fit_behavior("opto_minimal", lambda_range=[0.1], cv_folds=5) + + +class TestLassoBehavioralPredictorControlSubtraction: + """Tests for control-subtracted LASSO fits.""" + + def test_subtract_control_skip_aligns_and_drops_nans(self, control_lasso_predictor): + results = control_lasso_predictor.fit_behavior( + condition_name="opto_hex", + lambda_range=[0.1], + cv_folds=2, + subtract_control=True, + missing_control_policy="skip", + ) + + assert "Benzaldehyde" not in results.test_odorants + assert len(results.test_odorants) == 3 + + actual_map = dict(zip(results.test_odorants, results.actual_per)) + assert actual_map["Hexanol"] == pytest.approx(0.7) + assert actual_map["Linalool"] == pytest.approx(0.35) + assert actual_map["Citral"] == pytest.approx(0.3) + + def test_subtract_control_zero_fills_missing_control(self, control_lasso_predictor): + results = control_lasso_predictor.fit_behavior( + condition_name="opto_hex", + lambda_range=[0.1], + cv_folds=2, + subtract_control=True, + missing_control_policy="zero", + ) + + assert "Benzaldehyde" in results.test_odorants + assert len(results.test_odorants) == 4 + + actual_map = dict(zip(results.test_odorants, results.actual_per)) + assert actual_map["Benzaldehyde"] == pytest.approx(0.2) + + def test_subtract_control_error_raises_on_missing_control_values( + self, 
control_lasso_predictor + ): + with pytest.raises(ValueError, match="missing values"): + control_lasso_predictor.fit_behavior( + condition_name="opto_hex", + lambda_range=[0.1], + cv_folds=2, + subtract_control=True, + missing_control_policy="error", + ) + + def test_subtract_control_missing_control_row_warns_or_errors( + self, control_lasso_predictor + ): + control_lasso_predictor.behavioral_data = control_lasso_predictor.behavioral_data.drop( + index=["hex_control"] + ) + + with pytest.raises(ValueError, match="control"): + control_lasso_predictor.fit_behavior( + condition_name="opto_hex", + lambda_range=[0.1], + cv_folds=2, + subtract_control=True, + )