From 026dad9a285bb2776e567cc68581e9862c5f58ab Mon Sep 17 00:00:00 2001 From: ramanlab Date: Thu, 8 Jan 2026 12:14:24 -0600 Subject: [PATCH 1/4] feat: Add LASSO ablation analysis for receptor circuit robustness Implements ablation analysis to measure the robustness of LASSO-identified receptor circuits by zeroing out selected receptor channels and refitting. New features: - scripts/lasso_with_ablations.py: CLI script for ablation analysis - Supports --ablate (comma-separated) and --ablate_file (one per line) - ablation_set_mode: single (each receptor individually) or all_in_one - missing_receptor_policy: error or skip - Outputs: baseline_model.json, ablation_summary.csv, per-ablation artifacts - Helper functions in behavioral_prediction.py: - resolve_receptor_names(): Case-insensitive receptor name matching - apply_receptor_ablation(): Zero out specified receptor columns - fit_lasso_with_fixed_scaler(): Refit LASSO with preserved baseline scaler Key design decisions: - Fit scaler on baseline X, reuse for ablation variants (apples-to-apples) - Case-insensitive exact matching only (no fuzzy matching) - Deterministic ordering and filenames Tests: 25 new tests in test_lasso_ablation.py covering: - Receptor name resolution - Ablation matrix operations - LASSO fitting with fixed scaler - Output file format validation Co-Authored-By: Claude Opus 4.5 --- scripts/lasso_with_ablations.py | 737 ++++++++++++++++++ src/door_toolkit/pathways/__init__.py | 6 + .../pathways/behavioral_prediction.py | 220 ++++++ tests/test_lasso_ablation.py | 543 +++++++++++++ 4 files changed, 1506 insertions(+) create mode 100644 scripts/lasso_with_ablations.py create mode 100644 tests/test_lasso_ablation.py diff --git a/scripts/lasso_with_ablations.py b/scripts/lasso_with_ablations.py new file mode 100644 index 0000000..ad373dc --- /dev/null +++ b/scripts/lasso_with_ablations.py @@ -0,0 +1,737 @@ +#!/usr/bin/env python3 +""" +LASSO Ablation Analysis Script +============================== + +Refit LASSO behavioral prediction models after ablating (zeroing) selected +receptor feature channels, measuring robustness of the identified receptor +circuits. + +This script: +1. Fits a baseline LASSO model (no ablation) +2. For each ablation scenario, zeros out specified receptor columns +3. Refits LASSO using the SAME scaler as baseline (apples-to-apples comparison) +4. 
Saves detailed artifacts and summary statistics + +Example usage: + # Ablate multiple receptors together + python scripts/lasso_with_ablations.py \\ + --door_cache door_cache \\ + --behavior_csv reaction_rates.csv \\ + --condition opto_hex \\ + --ablate Or42b,Or47b \\ + --ablation_set_mode all_in_one \\ + --output_dir outputs/ablation/ + + # Ablate receptors one at a time + python scripts/lasso_with_ablations.py \\ + --door_cache door_cache \\ + --behavior_csv reaction_rates.csv \\ + --condition opto_hex \\ + --ablate Or42b,Or47b,Or22a \\ + --ablation_set_mode single \\ + --output_dir outputs/ablation/ + + # Ablate receptors from a file + python scripts/lasso_with_ablations.py \\ + --door_cache door_cache \\ + --behavior_csv reaction_rates.csv \\ + --condition opto_hex \\ + --ablate_file receptors_to_ablate.txt \\ + --output_dir outputs/ablation/ +""" + +from __future__ import annotations + +import argparse +import json +import logging +import sys +from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from sklearn.preprocessing import StandardScaler + +matplotlib.use("Agg") + +from door_toolkit.pathways.behavioral_prediction import ( + LassoBehavioralPredictor, + apply_receptor_ablation, + fit_lasso_with_fixed_scaler, + resolve_receptor_names, +) + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +@dataclass +class AblationResult: + """Results from a single ablation run.""" + + ablation_name: str + receptors_ablated: List[str] + ablated_indices: List[int] + cv_r2: float + cv_mse: float + n_receptors_selected: int + lambda_value: float + lasso_weights: Dict[str, float] + delta_r2: float = 0.0 + delta_mse: float = 0.0 + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="LASSO ablation analysis for receptor circuit robustness.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # Required arguments + parser.add_argument( + "--door_cache", + type=str, + required=True, + help="Path to DoOR cache directory", + ) + parser.add_argument( + "--behavior_csv", + type=str, + required=True, + help="Path to behavioral CSV (reaction_rates_summary_unordered.csv)", + ) + parser.add_argument( + "--condition", + type=str, + required=True, + help="Optogenetic condition name (e.g., opto_hex)", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Directory for output files", + ) + + # Ablation specification (mutually exclusive) + ablation_group = parser.add_mutually_exclusive_group(required=True) + ablation_group.add_argument( + "--ablate", + type=str, + help="Comma-separated list of receptor names to ablate", + ) + ablation_group.add_argument( + "--ablate_file", + type=str, + help="File with receptor names (one per line) to ablate", + ) + + # Ablation mode + parser.add_argument( + "--ablation_set_mode", + type=str, + choices=["single", "all_in_one"], + default="all_in_one", + help=( + "How to process the ablation set: " + "'single' = ablate each receptor individually; " + "'all_in_one' = ablate all receptors together (default)" + ), + ) + + # Prediction mode + parser.add_argument( + "--prediction_mode", + type=str, + choices=["test_odorant", "trained_odorant", "interaction"], + 
default="test_odorant", + help="Feature extraction mode (default: test_odorant)", + ) + + # LASSO parameters + parser.add_argument( + "--cv_folds", + type=int, + default=5, + help="Number of cross-validation folds (default: 5)", + ) + parser.add_argument( + "--lambda_range", + type=str, + default="0.0001,0.001,0.01,0.1,1.0", + help="Comma-separated lambda values for LASSO CV (default: 0.0001,0.001,0.01,0.1,1.0)", + ) + parser.add_argument( + "--lambda_value", + type=float, + default=None, + help="Fixed lambda value (overrides --lambda_range)", + ) + + # Receptor resolution + parser.add_argument( + "--missing_receptor_policy", + type=str, + choices=["error", "skip"], + default="error", + help="Policy for unresolved receptor names: 'error' or 'skip' (default: error)", + ) + + # Scaling + parser.add_argument( + "--scale_features", + action="store_true", + default=True, + help="Standardize receptor features (default: True)", + ) + parser.add_argument( + "--no_scale_features", + action="store_false", + dest="scale_features", + help="Do not standardize receptor features", + ) + + return parser.parse_args() + + +def load_receptors_from_file(filepath: str) -> List[str]: + """Load receptor names from a file (one per line).""" + path = Path(filepath) + if not path.exists(): + raise FileNotFoundError(f"Receptor file not found: {filepath}") + + receptors = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + receptors.append(line) + + logger.info(f"Loaded {len(receptors)} receptors from {filepath}") + return receptors + + +def save_weights_csv( + weights: Dict[str, float], filepath: Path, condition: str, ablation_name: str +) -> None: + """Save LASSO weights to CSV.""" + rows = [] + for receptor, weight in sorted(weights.items(), key=lambda x: abs(x[1]), reverse=True): + rows.append({ + "condition": condition, + "ablation": ablation_name, + "receptor": receptor, + "weight": weight, + "abs_weight": abs(weight), + }) + + df = pd.DataFrame(rows) + filepath.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(filepath, index=False) + logger.info(f"Saved weights to {filepath}") + + +def save_model_json( + result: AblationResult, + condition: str, + prediction_mode: str, + n_samples: int, + n_receptors_total: int, + filepath: Path, +) -> None: + """Save model metadata to JSON.""" + data = { + "condition_name": condition, + "ablation_name": result.ablation_name, + "receptors_ablated": result.receptors_ablated, + "ablated_indices": result.ablated_indices, + "prediction_mode": prediction_mode, + "n_samples": n_samples, + "n_receptors_total": n_receptors_total, + "cv_r2": result.cv_r2, + "cv_mse": result.cv_mse, + "delta_r2": result.delta_r2, + "delta_mse": result.delta_mse, + "lambda_value": result.lambda_value, + "n_receptors_selected": result.n_receptors_selected, + "lasso_weights": result.lasso_weights, + "top_10_receptors": [ + {"receptor": r, "weight": w} + for r, w in sorted( + result.lasso_weights.items(), key=lambda x: abs(x[1]), reverse=True + )[:10] + ], + } + + filepath.parent.mkdir(parents=True, exist_ok=True) + with open(filepath, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + logger.info(f"Saved model to {filepath}") + + +def plot_ablation_comparison( + baseline_result: AblationResult, + ablation_results: List[AblationResult], + output_dir: Path, + condition: str, +) -> None: + """Generate bar charts comparing baseline vs ablations.""" + if not ablation_results: + return + + # Prepare data + names = 
["Baseline"] + [r.ablation_name for r in ablation_results] + r2_values = [baseline_result.cv_r2] + [r.cv_r2 for r in ablation_results] + mse_values = [baseline_result.cv_mse] + [r.cv_mse for r in ablation_results] + + # Create figure with two subplots + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + + # R² comparison + ax1 = axes[0] + colors = ["green"] + ["coral"] * len(ablation_results) + bars1 = ax1.bar(range(len(names)), r2_values, color=colors, alpha=0.7, edgecolor="k") + ax1.set_xticks(range(len(names))) + ax1.set_xticklabels(names, rotation=45, ha="right") + ax1.set_ylabel("Cross-validated R²", fontsize=12) + ax1.set_title(f"{condition} - R² Comparison", fontsize=14) + ax1.axhline(baseline_result.cv_r2, color="green", linestyle="--", alpha=0.5, label="Baseline") + ax1.legend() + ax1.grid(alpha=0.3, axis="y") + + # Add value labels + for bar, val in zip(bars1, r2_values): + ax1.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.01, + f"{val:.3f}", + ha="center", + va="bottom", + fontsize=9, + ) + + # MSE comparison + ax2 = axes[1] + bars2 = ax2.bar(range(len(names)), mse_values, color=colors, alpha=0.7, edgecolor="k") + ax2.set_xticks(range(len(names))) + ax2.set_xticklabels(names, rotation=45, ha="right") + ax2.set_ylabel("Cross-validated MSE", fontsize=12) + ax2.set_title(f"{condition} - MSE Comparison", fontsize=14) + ax2.axhline(baseline_result.cv_mse, color="green", linestyle="--", alpha=0.5, label="Baseline") + ax2.legend() + ax2.grid(alpha=0.3, axis="y") + + # Add value labels + for bar, val in zip(bars2, mse_values): + ax2.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.001, + f"{val:.4f}", + ha="center", + va="bottom", + fontsize=9, + ) + + plt.tight_layout() + + # Save + plot_path = output_dir / "ablation_comparison.png" + plt.savefig(plot_path, dpi=300, bbox_inches="tight") + plt.close() + logger.info(f"Saved ablation comparison plot to {plot_path}") + + +def run_baseline( + predictor: LassoBehavioralPredictor, + condition: str, + prediction_mode: str, + lambda_range: np.ndarray, + cv_folds: int, + scale_features: bool, +) -> Tuple[AblationResult, np.ndarray, np.ndarray, List[str], Optional[StandardScaler]]: + """Run baseline LASSO fit (no ablation). 
+ + Returns: + Tuple of (baseline_result, X, y, receptor_names, scaler) + """ + logger.info(f"Fitting baseline model for condition: {condition}") + + # Get behavioral responses + per_responses = predictor.behavioral_data.loc[condition] + valid_odorants = per_responses.dropna() + + if len(valid_odorants) == 0: + raise ValueError(f"No valid PER data for condition '{condition}'") + + # Extract features based on prediction mode + if prediction_mode == "test_odorant": + X, test_odorants, y = predictor._extract_test_odorant_features(valid_odorants) + elif prediction_mode == "trained_odorant": + trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition) + if not trained_odorant: + raise ValueError(f"Could not determine trained odorant for {condition}") + X, test_odorants, y = predictor._extract_trained_odorant_features( + trained_odorant, valid_odorants + ) + elif prediction_mode == "interaction": + trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition) + if not trained_odorant: + raise ValueError(f"Could not determine trained odorant for {condition}") + X, test_odorants, y = predictor._extract_interaction_features( + trained_odorant, valid_odorants + ) + else: + raise ValueError(f"Unknown prediction_mode: {prediction_mode}") + + if X.shape[0] < 3: + raise ValueError(f"Insufficient data: only {X.shape[0]} samples") + + # Get receptor names + receptor_names = list(predictor.masked_receptor_names) + + # Fit scaler on baseline X (if scaling enabled) + if scale_features: + scaler = StandardScaler() + scaler.fit(X) + else: + scaler = None + + # Fit baseline model + weights, cv_r2, cv_mse, best_lambda, y_pred = fit_lasso_with_fixed_scaler( + X=X, + y=y, + receptor_names=receptor_names, + scaler=scaler, + lambda_range=lambda_range, + cv_folds=cv_folds, + ) + + baseline_result = AblationResult( + ablation_name="baseline", + receptors_ablated=[], + ablated_indices=[], + cv_r2=cv_r2, + cv_mse=cv_mse, + n_receptors_selected=len(weights), + lambda_value=best_lambda, + lasso_weights=weights, + delta_r2=0.0, + delta_mse=0.0, + ) + + logger.info( + f"Baseline: R² = {cv_r2:.4f}, MSE = {cv_mse:.4f}, " + f"λ = {best_lambda:.6f}, {len(weights)} receptors selected" + ) + + return baseline_result, X, y, receptor_names, scaler + + +def run_ablation( + X: np.ndarray, + y: np.ndarray, + receptor_names: List[str], + receptors_to_ablate: List[str], + scaler: Optional[StandardScaler], + lambda_range: np.ndarray, + cv_folds: int, + baseline_r2: float, + baseline_mse: float, + ablation_name: str, +) -> AblationResult: + """Run LASSO with specified receptors ablated.""" + logger.info(f"Running ablation: {ablation_name} ({receptors_to_ablate})") + + # Apply ablation + X_ablated, ablated_indices = apply_receptor_ablation( + X=X, + receptor_names=receptor_names, + receptors_to_ablate=receptors_to_ablate, + ) + + # Fit model with same scaler + weights, cv_r2, cv_mse, best_lambda, y_pred = fit_lasso_with_fixed_scaler( + X=X_ablated, + y=y, + receptor_names=receptor_names, + scaler=scaler, + lambda_range=lambda_range, + cv_folds=cv_folds, + ) + + result = AblationResult( + ablation_name=ablation_name, + receptors_ablated=receptors_to_ablate, + ablated_indices=ablated_indices, + cv_r2=cv_r2, + cv_mse=cv_mse, + n_receptors_selected=len(weights), + lambda_value=best_lambda, + lasso_weights=weights, + delta_r2=cv_r2 - baseline_r2, + delta_mse=cv_mse - baseline_mse, + ) + + logger.info( + f"Ablation '{ablation_name}': R² = {cv_r2:.4f} (Δ = {result.delta_r2:+.4f}), " + f"MSE = {cv_mse:.4f} (Δ = 
{result.delta_mse:+.4f})" + ) + + return result + + +def main() -> int: + """Main entry point.""" + args = parse_args() + + # Parse lambda range + if args.lambda_value is not None: + lambda_range = np.array([args.lambda_value]) + else: + lambda_range = np.array([float(x.strip()) for x in args.lambda_range.split(",")]) + + logger.info(f"Lambda range: {lambda_range}") + + # Load receptors to ablate + if args.ablate: + receptors_to_ablate = [r.strip() for r in args.ablate.split(",") if r.strip()] + else: + receptors_to_ablate = load_receptors_from_file(args.ablate_file) + + if not receptors_to_ablate: + logger.error("No receptors specified for ablation") + return 1 + + logger.info(f"Receptors to ablate: {receptors_to_ablate}") + + # Initialize predictor + predictor = LassoBehavioralPredictor( + doorcache_path=args.door_cache, + behavior_csv_path=args.behavior_csv, + scale_features=False, # We handle scaling manually for ablation + scale_targets=False, + ) + + # Validate condition + if args.condition not in predictor.behavioral_data.index: + logger.error( + f"Condition '{args.condition}' not found. " + f"Available: {list(predictor.behavioral_data.index)}" + ) + return 1 + + # Resolve receptor names + available_receptors = list(predictor.masked_receptor_names) + strict_mode = args.missing_receptor_policy == "error" + + matched_receptors, unmatched = resolve_receptor_names( + receptors_to_ablate, available_receptors, strict=strict_mode + ) + + if unmatched: + if strict_mode: + # Error already raised by resolve_receptor_names + pass + else: + logger.warning(f"Skipping unmatched receptors: {unmatched}") + + if not matched_receptors: + logger.error("No valid receptors to ablate after resolution") + return 1 + + logger.info(f"Resolved receptors: {matched_receptors}") + + # Create output directory + output_dir = Path(args.output_dir) / args.condition + output_dir.mkdir(parents=True, exist_ok=True) + + # Run baseline + try: + baseline_result, X, y, receptor_names, scaler = run_baseline( + predictor=predictor, + condition=args.condition, + prediction_mode=args.prediction_mode, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + scale_features=args.scale_features, + ) + except Exception as e: + logger.error(f"Baseline fit failed: {e}") + return 1 + + # Save baseline artifacts + save_model_json( + result=baseline_result, + condition=args.condition, + prediction_mode=args.prediction_mode, + n_samples=X.shape[0], + n_receptors_total=X.shape[1], + filepath=output_dir / "baseline_model.json", + ) + save_weights_csv( + weights=baseline_result.lasso_weights, + filepath=output_dir / "baseline_weights.csv", + condition=args.condition, + ablation_name="baseline", + ) + + # Run ablations + ablation_results: List[AblationResult] = [] + + if args.ablation_set_mode == "single": + # Ablate each receptor individually + for receptor in matched_receptors: + try: + ablation_name = f"ablate_{receptor}" + result = run_ablation( + X=X, + y=y, + receptor_names=receptor_names, + receptors_to_ablate=[receptor], + scaler=scaler, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + baseline_r2=baseline_result.cv_r2, + baseline_mse=baseline_result.cv_mse, + ablation_name=ablation_name, + ) + ablation_results.append(result) + + # Save individual ablation artifacts + ablation_dir = output_dir / ablation_name + save_model_json( + result=result, + condition=args.condition, + prediction_mode=args.prediction_mode, + n_samples=X.shape[0], + n_receptors_total=X.shape[1], + filepath=ablation_dir / "model.json", + ) + 
save_weights_csv( + weights=result.lasso_weights, + filepath=ablation_dir / "weights.csv", + condition=args.condition, + ablation_name=ablation_name, + ) + + except Exception as e: + logger.error(f"Ablation '{receptor}' failed: {e}") + continue + + else: # all_in_one + # Ablate all receptors together + ablation_name = "ablate_" + "_".join(matched_receptors) + # Truncate if too long + if len(ablation_name) > 100: + ablation_name = f"ablate_{len(matched_receptors)}_receptors" + + try: + result = run_ablation( + X=X, + y=y, + receptor_names=receptor_names, + receptors_to_ablate=matched_receptors, + scaler=scaler, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + baseline_r2=baseline_result.cv_r2, + baseline_mse=baseline_result.cv_mse, + ablation_name=ablation_name, + ) + ablation_results.append(result) + + # Save ablation artifacts + ablation_dir = output_dir / ablation_name + save_model_json( + result=result, + condition=args.condition, + prediction_mode=args.prediction_mode, + n_samples=X.shape[0], + n_receptors_total=X.shape[1], + filepath=ablation_dir / "model.json", + ) + save_weights_csv( + weights=result.lasso_weights, + filepath=ablation_dir / "weights.csv", + condition=args.condition, + ablation_name=ablation_name, + ) + + except Exception as e: + logger.error(f"Ablation failed: {e}") + + # Generate summary CSV + summary_rows = [] + + # Baseline row + summary_rows.append({ + "ablation_name": "baseline", + "receptors_ablated": "", + "n_ablated": 0, + "cv_r2": baseline_result.cv_r2, + "cv_mse": baseline_result.cv_mse, + "n_receptors_selected": baseline_result.n_receptors_selected, + "lambda_value": baseline_result.lambda_value, + "delta_r2": 0.0, + "delta_mse": 0.0, + }) + + # Ablation rows + for result in ablation_results: + summary_rows.append({ + "ablation_name": result.ablation_name, + "receptors_ablated": ";".join(result.receptors_ablated), + "n_ablated": len(result.receptors_ablated), + "cv_r2": result.cv_r2, + "cv_mse": result.cv_mse, + "n_receptors_selected": result.n_receptors_selected, + "lambda_value": result.lambda_value, + "delta_r2": result.delta_r2, + "delta_mse": result.delta_mse, + }) + + summary_df = pd.DataFrame(summary_rows) + summary_path = output_dir / "ablation_summary.csv" + summary_df.to_csv(summary_path, index=False) + logger.info(f"Saved summary to {summary_path}") + + # Generate comparison plot + plot_ablation_comparison( + baseline_result=baseline_result, + ablation_results=ablation_results, + output_dir=output_dir, + condition=args.condition, + ) + + # Print summary + print("\n" + "=" * 80) + print(f"LASSO Ablation Analysis Complete: {args.condition}") + print("=" * 80) + print(f"\nBaseline: R² = {baseline_result.cv_r2:.4f}, MSE = {baseline_result.cv_mse:.4f}") + print(f" {baseline_result.n_receptors_selected} receptors selected") + print(f"\nAblation Results:") + for result in ablation_results: + print( + f" {result.ablation_name:40s} " + f"R² = {result.cv_r2:.4f} (Δ = {result.delta_r2:+.4f}) " + f"MSE = {result.cv_mse:.4f} (Δ = {result.delta_mse:+.4f})" + ) + print(f"\nOutputs saved to: {output_dir}") + print("=" * 80) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/door_toolkit/pathways/__init__.py b/src/door_toolkit/pathways/__init__.py index e19be76..b200b5d 100644 --- a/src/door_toolkit/pathways/__init__.py +++ b/src/door_toolkit/pathways/__init__.py @@ -33,6 +33,9 @@ LassoBehavioralPredictor, BehaviorModelResults, BehaviorPrediction, + resolve_receptor_names, + apply_receptor_ablation, + 
fit_lasso_with_fixed_scaler, ) from door_toolkit.pathways.behavior_rate_model import ( SparseRateGLM, @@ -50,6 +53,9 @@ "LassoBehavioralPredictor", "BehaviorModelResults", "BehaviorPrediction", + "resolve_receptor_names", + "apply_receptor_ablation", + "fit_lasso_with_fixed_scaler", "SparseRateGLM", "TrainConfig", "build_training_table", diff --git a/src/door_toolkit/pathways/behavioral_prediction.py b/src/door_toolkit/pathways/behavioral_prediction.py index 8e57921..c3442a0 100644 --- a/src/door_toolkit/pathways/behavioral_prediction.py +++ b/src/door_toolkit/pathways/behavioral_prediction.py @@ -1197,3 +1197,223 @@ def _plot_receptor_overlap( plt.close() else: plt.show() + + +# ============================================================================ +# LASSO Ablation/Robustness Helpers +# ============================================================================ + + +def resolve_receptor_names( + requested_receptors: List[str], + available_receptors: List[str], + strict: bool = True, +) -> Tuple[List[str], List[str]]: + """ + Resolve receptor names using case-insensitive exact matching. + + Args: + requested_receptors: User-provided receptor names + available_receptors: Valid receptor names from encoder + strict: If True, raise error on any unmatched receptors + + Returns: + Tuple of (matched_names, unmatched_names) + - matched_names: Successfully matched receptor names (in available order) + - unmatched_names: Could not be matched + + Raises: + ValueError: If strict=True and any receptors are unmatched + + Example: + >>> resolve_receptor_names( + ... ["Or42b", "OR47B", "Or99x"], + ... ["Or42b", "Or47b", "Or22a"], + ... strict=False + ... ) + (["Or42b", "Or47b"], ["Or99x"]) + """ + # Build case-insensitive lookup: lowercase -> original name + lower_to_original: Dict[str, str] = {} + for receptor in available_receptors: + lower_to_original[receptor.lower()] = receptor + + matched: List[str] = [] + unmatched: List[str] = [] + + for requested in requested_receptors: + requested_clean = requested.strip() + if not requested_clean: + continue + + # Try case-insensitive exact match + key = requested_clean.lower() + if key in lower_to_original: + canonical = lower_to_original[key] + if canonical not in matched: + matched.append(canonical) + if requested_clean != canonical: + logger.debug( + f"Receptor name normalized: '{requested_clean}' -> '{canonical}'" + ) + else: + unmatched.append(requested_clean) + + if strict and unmatched: + raise ValueError( + f"Could not resolve the following receptors (case-insensitive): " + f"{unmatched}. Available receptors include: " + f"{available_receptors[:10]}{'...' if len(available_receptors) > 10 else ''}" + ) + + return matched, unmatched + + +def apply_receptor_ablation( + X: np.ndarray, + receptor_names: List[str], + receptors_to_ablate: List[str], + ablation_value: float = 0.0, +) -> Tuple[np.ndarray, List[int]]: + """ + Set specified receptor channels to ablation_value. 
+ + Args: + X: Feature matrix (n_samples, n_receptors) + receptor_names: List of receptor names corresponding to X columns + receptors_to_ablate: Receptor names to ablate (case-insensitive) + ablation_value: Value to set ablated channels (default: 0.0) + + Returns: + Tuple of (X_ablated, ablated_indices) + - X_ablated: Copy of X with ablated channels set to ablation_value + - ablated_indices: Column indices that were ablated + + Raises: + ValueError: If any receptor in receptors_to_ablate not found + + Example: + >>> X = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + >>> X_abl, idx = apply_receptor_ablation( + ... X, ["Or42b", "Or47b", "Or22a"], ["Or42b", "Or22a"] + ... ) + >>> X_abl + array([[0., 2., 0.], + [0., 5., 0.]]) + >>> idx + [0, 2] + """ + # Resolve receptor names (strict mode) + matched, unmatched = resolve_receptor_names( + receptors_to_ablate, receptor_names, strict=True + ) + + # Find indices of matched receptors + name_to_idx = {name: i for i, name in enumerate(receptor_names)} + ablated_indices = [name_to_idx[r] for r in matched] + + # Create ablated copy + X_ablated = X.copy() + for idx in ablated_indices: + X_ablated[:, idx] = ablation_value + + logger.info( + f"Ablated {len(ablated_indices)} receptors: {matched} (indices: {ablated_indices})" + ) + + return X_ablated, ablated_indices + + +def fit_lasso_with_fixed_scaler( + X: np.ndarray, + y: np.ndarray, + receptor_names: List[str], + scaler: Optional[StandardScaler], + lambda_range: np.ndarray, + cv_folds: int = 5, + random_state: int = 42, +) -> Tuple[Dict[str, float], float, float, float, np.ndarray]: + """ + Fit LASSO using a pre-fitted scaler for fair comparison. + + This ensures that ablation/focus variants use the same feature + scaling as the baseline model, preventing confounding effects + from different normalization. + + Args: + X: Feature matrix (n_samples, n_receptors) + y: Target values + receptor_names: List of receptor names corresponding to X columns + scaler: Pre-fitted StandardScaler (if None, fits new scaler) + lambda_range: Array of lambda values for LASSO CV + cv_folds: Number of cross-validation folds + random_state: Random seed for reproducibility + + Returns: + Tuple of (lasso_weights, cv_r2, cv_mse, best_lambda, y_pred) + - lasso_weights: Dict of {receptor: weight} for non-zero coefficients + - cv_r2: Cross-validated R² score + - cv_mse: Cross-validated MSE + - best_lambda: Selected regularization parameter + - y_pred: Predictions on training data + + Example: + >>> # Baseline model - fits scaler + >>> scaler = StandardScaler().fit(X_baseline) + >>> weights_base, r2, mse, lam, pred = fit_lasso_with_fixed_scaler( + ... X_baseline, y, names, scaler, lambdas + ... ) + >>> # Ablated model - reuses same scaler + >>> weights_abl, r2_abl, mse_abl, lam_abl, pred_abl = fit_lasso_with_fixed_scaler( + ... X_ablated, y, names, scaler, lambdas + ... 
) + """ + # Apply scaling + if scaler is not None: + X_scaled = scaler.transform(X) + else: + X_scaled = X.copy() + + # Adjust CV folds for small samples + n_samples = X_scaled.shape[0] + if n_samples < 10: + cv_folds_adjusted = n_samples # LOOCV + else: + cv_folds_adjusted = min(cv_folds, n_samples) + + # Fit LASSO with cross-validation + lasso_cv = LassoCV( + alphas=lambda_range, + cv=cv_folds_adjusted, + max_iter=10000, + random_state=random_state, + ) + lasso_cv.fit(X_scaled, y) + + best_lambda = lasso_cv.alpha_ + + # Refit with best lambda + lasso = Lasso(alpha=best_lambda, max_iter=10000, random_state=random_state) + lasso.fit(X_scaled, y) + + # Predict + y_pred = lasso.predict(X_scaled) + + # Extract non-zero coefficients + lasso_weights: Dict[str, float] = {} + for i, coef in enumerate(lasso.coef_): + if abs(coef) > 1e-6: + lasso_weights[receptor_names[i]] = float(coef) + + # Compute cross-validated metrics + cv_r2_scores = cross_val_score( + lasso, X_scaled, y, cv=cv_folds_adjusted, scoring="r2" + ) + cv_r2 = float(np.mean(cv_r2_scores)) + + cv_mse_scores = cross_val_score( + lasso, X_scaled, y, cv=cv_folds_adjusted, scoring="neg_mean_squared_error" + ) + cv_mse = float(-np.mean(cv_mse_scores)) + + return lasso_weights, cv_r2, cv_mse, best_lambda, y_pred diff --git a/tests/test_lasso_ablation.py b/tests/test_lasso_ablation.py new file mode 100644 index 0000000..d4db38e --- /dev/null +++ b/tests/test_lasso_ablation.py @@ -0,0 +1,543 @@ +"""Tests for LASSO ablation analysis functionality.""" + +import json +import subprocess +import sys +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +from sklearn.preprocessing import StandardScaler + +from door_toolkit.pathways.behavioral_prediction import ( + LassoBehavioralPredictor, + apply_receptor_ablation, + fit_lasso_with_fixed_scaler, + resolve_receptor_names, +) + + +# ============================================================================ +# Tests for resolve_receptor_names +# ============================================================================ + + +class TestResolveReceptorNames: + """Tests for receptor name resolution.""" + + def test_exact_match(self): + """Test exact case-sensitive matching.""" + available = ["Or42b", "Or47b", "Or22a", "Or59b"] + matched, unmatched = resolve_receptor_names( + ["Or42b", "Or47b"], available, strict=False + ) + + assert matched == ["Or42b", "Or47b"] + assert unmatched == [] + + def test_case_insensitive_match(self): + """Test case-insensitive matching.""" + available = ["Or42b", "Or47b", "Or22a"] + matched, unmatched = resolve_receptor_names( + ["or42B", "OR47B", "or22A"], available, strict=False + ) + + assert set(matched) == {"Or42b", "Or47b", "Or22a"} + assert unmatched == [] + + def test_unmatched_receptors_non_strict(self): + """Test unmatched receptors in non-strict mode.""" + available = ["Or42b", "Or47b"] + matched, unmatched = resolve_receptor_names( + ["Or42b", "Or99x", "InvalidReceptor"], available, strict=False + ) + + assert matched == ["Or42b"] + assert set(unmatched) == {"Or99x", "InvalidReceptor"} + + def test_unmatched_receptors_strict_raises(self): + """Test that strict mode raises on unmatched receptors.""" + available = ["Or42b", "Or47b"] + + with pytest.raises(ValueError, match="Could not resolve"): + resolve_receptor_names( + ["Or42b", "Or99x"], available, strict=True + ) + + def test_empty_and_whitespace_handling(self): + """Test handling of empty strings and whitespace.""" + available = ["Or42b", "Or47b"] + matched, unmatched = 
resolve_receptor_names( + [" Or42b ", "", " ", "Or47b"], available, strict=False + ) + + assert set(matched) == {"Or42b", "Or47b"} + assert unmatched == [] + + def test_duplicate_removal(self): + """Test that duplicates are removed from matched list.""" + available = ["Or42b", "Or47b"] + matched, unmatched = resolve_receptor_names( + ["Or42b", "or42b", "OR42B", "Or42b"], available, strict=False + ) + + assert matched == ["Or42b"] # Only one entry + assert unmatched == [] + + def test_mixed_case_variants(self): + """Test various case variants of the same receptor.""" + available = ["Ir64a.DC4", "Gr21a.Gr63a", "Or85f"] + matched, unmatched = resolve_receptor_names( + ["ir64a.dc4", "GR21A.GR63A", "or85F"], available, strict=False + ) + + assert set(matched) == {"Ir64a.DC4", "Gr21a.Gr63a", "Or85f"} + assert unmatched == [] + + +# ============================================================================ +# Tests for apply_receptor_ablation +# ============================================================================ + + +class TestApplyReceptorAblation: + """Tests for receptor ablation application.""" + + def test_basic_ablation(self): + """Test basic column zeroing.""" + X = np.array([ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + ]) + receptor_names = ["Or42b", "Or47b", "Or22a", "Or59b"] + + X_ablated, indices = apply_receptor_ablation( + X, receptor_names, ["Or42b", "Or22a"] + ) + + # Original should be unchanged + assert X[0, 0] == 1.0 + assert X[0, 2] == 3.0 + + # Ablated should have zeros in specified columns + assert X_ablated[0, 0] == 0.0 + assert X_ablated[0, 2] == 0.0 + assert X_ablated[1, 0] == 0.0 + assert X_ablated[1, 2] == 0.0 + + # Non-ablated columns should be unchanged + assert X_ablated[0, 1] == 2.0 + assert X_ablated[0, 3] == 4.0 + + # Check indices + assert set(indices) == {0, 2} + + def test_ablation_preserves_shape(self): + """Test that ablation preserves matrix shape.""" + X = np.random.rand(10, 20) + receptor_names = [f"Or{i}" for i in range(20)] + + X_ablated, indices = apply_receptor_ablation( + X, receptor_names, ["Or5", "Or10", "Or15"] + ) + + assert X_ablated.shape == X.shape + assert len(indices) == 3 + + def test_ablation_with_case_insensitive_names(self): + """Test that ablation handles case-insensitive receptor names.""" + X = np.array([[1.0, 2.0, 3.0]]) + receptor_names = ["Or42b", "Or47b", "Or22a"] + + X_ablated, indices = apply_receptor_ablation( + X, receptor_names, ["or42B", "OR22A"] # Different case + ) + + assert X_ablated[0, 0] == 0.0 # Or42b + assert X_ablated[0, 2] == 0.0 # Or22a + assert X_ablated[0, 1] == 2.0 # Or47b unchanged + + def test_ablation_invalid_receptor_raises(self): + """Test that invalid receptor names raise ValueError.""" + X = np.array([[1.0, 2.0, 3.0]]) + receptor_names = ["Or42b", "Or47b", "Or22a"] + + with pytest.raises(ValueError, match="Could not resolve"): + apply_receptor_ablation(X, receptor_names, ["Or99x"]) + + def test_custom_ablation_value(self): + """Test ablation with custom value (not zero).""" + X = np.array([[1.0, 2.0, 3.0]]) + receptor_names = ["Or42b", "Or47b", "Or22a"] + + X_ablated, indices = apply_receptor_ablation( + X, receptor_names, ["Or42b"], ablation_value=-999.0 + ) + + assert X_ablated[0, 0] == -999.0 + assert X_ablated[0, 1] == 2.0 + + def test_ablation_all_columns(self): + """Test ablating all columns.""" + X = np.array([[1.0, 2.0, 3.0]]) + receptor_names = ["Or42b", "Or47b", "Or22a"] + + X_ablated, indices = apply_receptor_ablation( + X, receptor_names, ["Or42b", 
"Or47b", "Or22a"] + ) + + assert np.all(X_ablated == 0.0) + assert len(indices) == 3 + + +# ============================================================================ +# Tests for fit_lasso_with_fixed_scaler +# ============================================================================ + + +class TestFitLassoWithFixedScaler: + """Tests for LASSO fitting with preserved scaler.""" + + def test_basic_fit(self): + """Test basic LASSO fitting.""" + np.random.seed(42) + n_samples, n_features = 20, 10 + X = np.random.rand(n_samples, n_features) + # Create y with some correlation to X + true_weights = np.zeros(n_features) + true_weights[0] = 1.0 + true_weights[1] = 0.5 + y = X @ true_weights + np.random.randn(n_samples) * 0.1 + + receptor_names = [f"Or{i}" for i in range(n_features)] + scaler = StandardScaler().fit(X) + lambda_range = np.array([0.01, 0.1, 1.0]) + + weights, cv_r2, cv_mse, best_lambda, y_pred = fit_lasso_with_fixed_scaler( + X, y, receptor_names, scaler, lambda_range, cv_folds=5 + ) + + assert isinstance(weights, dict) + assert isinstance(cv_r2, float) + assert isinstance(cv_mse, float) + assert best_lambda in lambda_range + assert len(y_pred) == n_samples + + def test_scaler_consistency(self): + """Test that using same scaler gives consistent scaling.""" + np.random.seed(42) + X = np.random.rand(15, 8) + y = np.random.rand(15) + receptor_names = [f"Or{i}" for i in range(8)] + lambda_range = np.array([0.1]) + + # Fit scaler on X + scaler = StandardScaler().fit(X) + + # Ablate some columns + X_ablated = X.copy() + X_ablated[:, [2, 5]] = 0.0 + + # Fit both models with same scaler + weights1, r2_1, mse_1, _, _ = fit_lasso_with_fixed_scaler( + X, y, receptor_names, scaler, lambda_range + ) + weights2, r2_2, mse_2, _, _ = fit_lasso_with_fixed_scaler( + X_ablated, y, receptor_names, scaler, lambda_range + ) + + # Both should complete without error + assert isinstance(weights1, dict) + assert isinstance(weights2, dict) + + def test_no_scaler(self): + """Test fitting without scaler (None).""" + np.random.seed(42) + X = np.random.rand(15, 5) + y = np.random.rand(15) + receptor_names = [f"Or{i}" for i in range(5)] + lambda_range = np.array([0.1]) + + weights, cv_r2, cv_mse, best_lambda, y_pred = fit_lasso_with_fixed_scaler( + X, y, receptor_names, None, lambda_range + ) + + assert isinstance(weights, dict) + assert len(y_pred) == 15 + + +# ============================================================================ +# Tests for ablation_set_mode behavior +# ============================================================================ + + +class TestAblationSetMode: + """Tests for ablation_set_mode behavior (single vs all_in_one).""" + + @pytest.fixture + def sample_data(self): + """Create sample data for ablation tests.""" + np.random.seed(42) + X = np.random.rand(10, 5) + y = np.random.rand(10) + receptor_names = ["Or42b", "Or47b", "Or22a", "Or59b", "Or92a"] + return X, y, receptor_names + + def test_single_mode_creates_individual_ablations(self, sample_data): + """Test that single mode ablates each receptor individually.""" + X, y, receptor_names = sample_data + receptors_to_ablate = ["Or42b", "Or47b", "Or22a"] + + results = [] + for receptor in receptors_to_ablate: + X_ablated, indices = apply_receptor_ablation( + X, receptor_names, [receptor] + ) + results.append((receptor, X_ablated, indices)) + + # Should have 3 separate ablations + assert len(results) == 3 + + # Each ablation should only zero one column + for receptor, X_abl, indices in results: + assert len(indices) == 1 + # Count 
zero columns + zero_cols = np.sum(np.all(X_abl == 0, axis=0)) + assert zero_cols == 1 + + def test_all_in_one_mode_ablates_together(self, sample_data): + """Test that all_in_one mode ablates all receptors at once.""" + X, y, receptor_names = sample_data + receptors_to_ablate = ["Or42b", "Or47b", "Or22a"] + + X_ablated, indices = apply_receptor_ablation( + X, receptor_names, receptors_to_ablate + ) + + # Should have 3 ablated columns + assert len(indices) == 3 + + # All three columns should be zero + for idx in indices: + assert np.all(X_ablated[:, idx] == 0.0) + + +# ============================================================================ +# Integration tests for script outputs +# ============================================================================ + + +class TestAblationScriptOutputs: + """Integration tests for ablation script output files.""" + + def test_save_model_json_format(self, tmp_path): + """Test that save_model_json creates valid JSON with expected fields.""" + from scripts.lasso_with_ablations import save_model_json, AblationResult + + output_dir = tmp_path / "output" + output_dir.mkdir() + + # Create a mock result + result = AblationResult( + ablation_name="baseline", + receptors_ablated=[], + ablated_indices=[], + cv_r2=0.75, + cv_mse=0.025, + n_receptors_selected=5, + lambda_value=0.01, + lasso_weights={"Or42b": 0.5, "Or47b": -0.3}, + delta_r2=0.0, + delta_mse=0.0, + ) + + filepath = output_dir / "baseline_model.json" + save_model_json( + result=result, + condition="opto_hex", + prediction_mode="test_odorant", + n_samples=10, + n_receptors_total=20, + filepath=filepath, + ) + + # Verify file exists and is valid JSON + assert filepath.exists() + + with open(filepath) as f: + data = json.load(f) + + assert data["condition_name"] == "opto_hex" + assert data["ablation_name"] == "baseline" + assert data["cv_r2"] == 0.75 + assert data["cv_mse"] == 0.025 + assert data["n_receptors_selected"] == 5 + assert data["lambda_value"] == 0.01 + assert "lasso_weights" in data + assert "top_10_receptors" in data + + def test_save_weights_csv_format(self, tmp_path): + """Test that save_weights_csv creates valid CSV.""" + from scripts.lasso_with_ablations import save_weights_csv + + output_dir = tmp_path / "output" + output_dir.mkdir() + + weights = {"Or42b": 0.5, "Or47b": -0.3, "Or22a": 0.1} + filepath = output_dir / "weights.csv" + + save_weights_csv( + weights=weights, + filepath=filepath, + condition="opto_hex", + ablation_name="baseline", + ) + + # Verify file exists + assert filepath.exists() + + # Verify CSV structure + df = pd.read_csv(filepath) + assert "condition" in df.columns + assert "ablation" in df.columns + assert "receptor" in df.columns + assert "weight" in df.columns + assert "abs_weight" in df.columns + assert len(df) == 3 + + # Check ordering (by absolute weight descending) + assert df.iloc[0]["receptor"] == "Or42b" # Highest abs weight + + def test_ablation_summary_csv_format(self, tmp_path): + """Test that ablation_summary.csv has correct format.""" + from scripts.lasso_with_ablations import AblationResult + + # Create summary rows without needing predictor + summary_rows = [ + { + "ablation_name": "baseline", + "receptors_ablated": "", + "n_ablated": 0, + "cv_r2": 0.75, + "cv_mse": 0.025, + "n_receptors_selected": 5, + "lambda_value": 0.01, + "delta_r2": 0.0, + "delta_mse": 0.0, + }, + { + "ablation_name": "ablate_Or42b", + "receptors_ablated": "Or42b", + "n_ablated": 1, + "cv_r2": 0.65, + "cv_mse": 0.035, + "n_receptors_selected": 4, + "lambda_value": 0.01, 
+ "delta_r2": -0.10, + "delta_mse": 0.010, + }, + ] + + summary_df = pd.DataFrame(summary_rows) + summary_path = tmp_path / "ablation_summary.csv" + summary_df.to_csv(summary_path, index=False) + + # Verify CSV can be read back + loaded_df = pd.read_csv(summary_path) + assert "ablation_name" in loaded_df.columns + assert "cv_r2" in loaded_df.columns + assert "cv_mse" in loaded_df.columns + assert "delta_r2" in loaded_df.columns + assert "delta_mse" in loaded_df.columns + assert len(loaded_df) == 2 + + def test_run_ablation_with_synthetic_data(self, tmp_path): + """Test run_ablation with synthetic data (no predictor needed).""" + from scripts.lasso_with_ablations import run_ablation + + # Create synthetic data + np.random.seed(42) + n_samples, n_features = 15, 10 + X = np.random.rand(n_samples, n_features) + + # Create y with correlation to first receptor + true_weights = np.zeros(n_features) + true_weights[0] = 1.0 + y = X @ true_weights + np.random.randn(n_samples) * 0.1 + + receptor_names = [f"Or{i}" for i in range(n_features)] + scaler = StandardScaler().fit(X) + lambda_range = np.array([0.01, 0.1]) + + # Run ablation + result = run_ablation( + X=X, + y=y, + receptor_names=receptor_names, + receptors_to_ablate=["Or0"], # Ablate the important receptor + scaler=scaler, + lambda_range=lambda_range, + cv_folds=3, + baseline_r2=0.9, + baseline_mse=0.01, + ablation_name="ablate_Or0", + ) + + # Verify result structure + assert result.ablation_name == "ablate_Or0" + assert result.receptors_ablated == ["Or0"] + assert result.ablated_indices == [0] + assert isinstance(result.cv_r2, float) + assert isinstance(result.cv_mse, float) + assert result.delta_r2 == result.cv_r2 - 0.9 + assert result.delta_mse == result.cv_mse - 0.01 + + +# ============================================================================ +# Edge case tests +# ============================================================================ + + +class TestAblationEdgeCases: + """Tests for edge cases in ablation.""" + + def test_ablate_zero_column(self): + """Test ablating an already-zero column.""" + X = np.array([[1.0, 0.0, 3.0], [4.0, 0.0, 6.0]]) + receptor_names = ["Or42b", "Or47b", "Or22a"] + + X_ablated, indices = apply_receptor_ablation( + X, receptor_names, ["Or47b"] + ) + + # Should still work, column stays zero + assert np.all(X_ablated[:, 1] == 0.0) + assert indices == [1] + + def test_ablate_single_sample(self): + """Test ablation with single sample (edge case for CV).""" + X = np.array([[1.0, 2.0, 3.0]]) + receptor_names = ["Or42b", "Or47b", "Or22a"] + + X_ablated, indices = apply_receptor_ablation( + X, receptor_names, ["Or42b"] + ) + + assert X_ablated.shape == (1, 3) + assert X_ablated[0, 0] == 0.0 + + def test_ablate_high_dimensional(self): + """Test ablation with many receptors.""" + n_receptors = 100 + X = np.random.rand(20, n_receptors) + receptor_names = [f"Or{i}" for i in range(n_receptors)] + + # Ablate 50 receptors + to_ablate = [f"Or{i}" for i in range(0, 50)] + X_ablated, indices = apply_receptor_ablation(X, receptor_names, to_ablate) + + assert len(indices) == 50 + for idx in indices: + assert np.all(X_ablated[:, idx] == 0.0) From 4278a80f021d538ab0b20a602cb594540d135f4d Mon Sep 17 00:00:00 2001 From: ramanlab Date: Thu, 8 Jan 2026 12:20:47 -0600 Subject: [PATCH 2/4] feat: Add LASSO focus mode analysis for receptor circuit sufficiency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements focus mode analysis to measure how many receptors are needed for accurate 
behavioral prediction by restricting to top-N receptors from baseline and refitting LASSO. New features: - scripts/lasso_with_focus_mode.py: CLI script for focus mode analysis - --topn_list: comma-separated N values (e.g., "1,2,3,5,10,15,20") - --focus_receptors: explicit receptor list (overrides topn_list) - --baseline_select_by: ranking method (abs_weight default) - Outputs: baseline_model.json, focus_curve.csv, focus_curve.png - Helper functions in behavioral_prediction.py: - restrict_to_receptors(): Subset feature matrix to specified receptors - get_top_receptors_by_weight(): Rank receptors by |LASSO weight| Key design decisions: - For each N, fit scaler on restricted X (fair comparison within N) - Deterministic receptor ranking by absolute weight - Case-insensitive receptor name matching - Generates MSE vs N and R² vs N curves Tests: 24 new tests in test_lasso_focus_mode.py covering: - Feature matrix restriction - Top-N receptor selection (determinism verified) - Focus mode integration workflow - Output file format validation Co-Authored-By: Claude Opus 4.5 --- scripts/lasso_with_focus_mode.py | 692 ++++++++++++++++++ src/door_toolkit/pathways/__init__.py | 4 + .../pathways/behavioral_prediction.py | 83 +++ tests/test_lasso_focus_mode.py | 485 ++++++++++++ 4 files changed, 1264 insertions(+) create mode 100644 scripts/lasso_with_focus_mode.py create mode 100644 tests/test_lasso_focus_mode.py diff --git a/scripts/lasso_with_focus_mode.py b/scripts/lasso_with_focus_mode.py new file mode 100644 index 0000000..60661d1 --- /dev/null +++ b/scripts/lasso_with_focus_mode.py @@ -0,0 +1,692 @@ +#!/usr/bin/env python3 +""" +LASSO Focus Mode Analysis Script +================================= + +Refit LASSO behavioral prediction models using only a restricted set of +receptors (top-N from baseline by absolute weight), generating MSE vs N +curves to assess how many receptors are needed for accurate prediction. + +This script: +1. Fits a baseline LASSO model (full receptor set) +2. Ranks receptors by absolute weight +3. For each N in topn_list, restricts to top-N receptors and refits +4. 
Generates focus_curve.csv and focus_curve.png + +Example usage: + # Sweep through different receptor counts + python scripts/lasso_with_focus_mode.py \\ + --door_cache door_cache \\ + --behavior_csv reaction_rates.csv \\ + --condition opto_hex \\ + --topn_list 1,2,3,5,10,15,20 \\ + --output_dir outputs/focus/ + + # Focus on a specific receptor set + python scripts/lasso_with_focus_mode.py \\ + --door_cache door_cache \\ + --behavior_csv reaction_rates.csv \\ + --condition opto_hex \\ + --focus_receptors Or42b,Or47b,Or22a \\ + --output_dir outputs/focus/ +""" + +from __future__ import annotations + +import argparse +import json +import logging +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from sklearn.preprocessing import StandardScaler + +matplotlib.use("Agg") + +from door_toolkit.pathways.behavioral_prediction import ( + LassoBehavioralPredictor, + fit_lasso_with_fixed_scaler, + restrict_to_receptors, + get_top_receptors_by_weight, +) + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +@dataclass +class FocusResult: + """Results from a single focus mode run.""" + + n_receptors: int + receptors_used: List[str] + cv_r2: float + cv_mse: float + lambda_value: float + n_receptors_selected: int + lasso_weights: Dict[str, float] + delta_r2: float = 0.0 + delta_mse: float = 0.0 + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="LASSO focus mode analysis for receptor circuit sufficiency.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # Required arguments + parser.add_argument( + "--door_cache", + type=str, + required=True, + help="Path to DoOR cache directory", + ) + parser.add_argument( + "--behavior_csv", + type=str, + required=True, + help="Path to behavioral CSV (reaction_rates_summary_unordered.csv)", + ) + parser.add_argument( + "--condition", + type=str, + required=True, + help="Optogenetic condition name (e.g., opto_hex)", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Directory for output files", + ) + + # Focus mode specification + parser.add_argument( + "--topn_list", + type=str, + default="1,2,3,5,10,15,20,30", + help="Comma-separated list of N values for top-N focus (default: 1,2,3,5,10,15,20,30)", + ) + parser.add_argument( + "--focus_receptors", + type=str, + default=None, + help="Comma-separated list of receptors to focus on (overrides --topn_list)", + ) + parser.add_argument( + "--baseline_select_by", + type=str, + choices=["abs_weight", "stability"], + default="abs_weight", + help="How to rank receptors for top-N selection (default: abs_weight)", + ) + + # Prediction mode + parser.add_argument( + "--prediction_mode", + type=str, + choices=["test_odorant", "trained_odorant", "interaction"], + default="test_odorant", + help="Feature extraction mode (default: test_odorant)", + ) + + # LASSO parameters + parser.add_argument( + "--cv_folds", + type=int, + default=5, + help="Number of cross-validation folds (default: 5)", + ) + parser.add_argument( + "--lambda_range", + type=str, + default="0.0001,0.001,0.01,0.1,1.0", + help="Comma-separated lambda values for LASSO CV (default: 0.0001,0.001,0.01,0.1,1.0)", + ) + parser.add_argument( + 
"--lambda_value", + type=float, + default=None, + help="Fixed lambda value (overrides --lambda_range)", + ) + + # Scaling + parser.add_argument( + "--scale_features", + action="store_true", + default=True, + help="Standardize receptor features (default: True)", + ) + parser.add_argument( + "--no_scale_features", + action="store_false", + dest="scale_features", + help="Do not standardize receptor features", + ) + + return parser.parse_args() + + +def save_baseline_json( + baseline_weights: Dict[str, float], + baseline_r2: float, + baseline_mse: float, + baseline_lambda: float, + condition: str, + prediction_mode: str, + n_samples: int, + n_receptors_total: int, + filepath: Path, +) -> None: + """Save baseline model metadata to JSON.""" + data = { + "condition_name": condition, + "prediction_mode": prediction_mode, + "n_samples": n_samples, + "n_receptors_total": n_receptors_total, + "cv_r2": baseline_r2, + "cv_mse": baseline_mse, + "lambda_value": baseline_lambda, + "n_receptors_selected": len(baseline_weights), + "lasso_weights": baseline_weights, + "receptor_ranking": [ + {"receptor": r, "weight": w, "abs_weight": abs(w)} + for r, w in sorted( + baseline_weights.items(), key=lambda x: abs(x[1]), reverse=True + ) + ], + } + + filepath.parent.mkdir(parents=True, exist_ok=True) + with open(filepath, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + logger.info(f"Saved baseline model to {filepath}") + + +def save_focus_json( + result: FocusResult, + condition: str, + prediction_mode: str, + n_samples: int, + filepath: Path, +) -> None: + """Save focus mode result to JSON.""" + data = { + "condition_name": condition, + "prediction_mode": prediction_mode, + "n_samples": n_samples, + "n_receptors_focused": result.n_receptors, + "receptors_used": result.receptors_used, + "cv_r2": result.cv_r2, + "cv_mse": result.cv_mse, + "delta_r2": result.delta_r2, + "delta_mse": result.delta_mse, + "lambda_value": result.lambda_value, + "n_receptors_selected": result.n_receptors_selected, + "lasso_weights": result.lasso_weights, + } + + filepath.parent.mkdir(parents=True, exist_ok=True) + with open(filepath, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + logger.info(f"Saved focus result to {filepath}") + + +def plot_focus_curve( + focus_results: List[FocusResult], + baseline_r2: float, + baseline_mse: float, + output_dir: Path, + condition: str, +) -> None: + """Generate focus curve plots (MSE vs N and R² vs N).""" + if not focus_results: + return + + # Sort by n_receptors + results_sorted = sorted(focus_results, key=lambda r: r.n_receptors) + + n_values = [r.n_receptors for r in results_sorted] + mse_values = [r.cv_mse for r in results_sorted] + r2_values = [r.cv_r2 for r in results_sorted] + + # Create figure with two subplots + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + + # MSE vs N + ax1 = axes[0] + ax1.plot(n_values, mse_values, "o-", color="coral", linewidth=2, markersize=8) + ax1.axhline( + baseline_mse, + color="green", + linestyle="--", + linewidth=2, + label=f"Baseline MSE ({baseline_mse:.4f})", + ) + ax1.set_xlabel("Number of Receptors (N)", fontsize=12) + ax1.set_ylabel("Cross-validated MSE", fontsize=12) + ax1.set_title(f"{condition} - MSE vs Number of Receptors", fontsize=14) + ax1.legend() + ax1.grid(alpha=0.3) + + # Add value labels + for n, mse in zip(n_values, mse_values): + ax1.annotate( + f"{mse:.4f}", + (n, mse), + textcoords="offset points", + xytext=(0, 10), + ha="center", + fontsize=8, + ) + + # R² vs N + ax2 = axes[1] + ax2.plot(n_values, r2_values, 
"o-", color="steelblue", linewidth=2, markersize=8) + ax2.axhline( + baseline_r2, + color="green", + linestyle="--", + linewidth=2, + label=f"Baseline R² ({baseline_r2:.4f})", + ) + ax2.set_xlabel("Number of Receptors (N)", fontsize=12) + ax2.set_ylabel("Cross-validated R²", fontsize=12) + ax2.set_title(f"{condition} - R² vs Number of Receptors", fontsize=14) + ax2.legend() + ax2.grid(alpha=0.3) + + # Add value labels + for n, r2 in zip(n_values, r2_values): + ax2.annotate( + f"{r2:.4f}", + (n, r2), + textcoords="offset points", + xytext=(0, 10), + ha="center", + fontsize=8, + ) + + plt.tight_layout() + + # Save + plot_path = output_dir / "focus_curve.png" + plt.savefig(plot_path, dpi=300, bbox_inches="tight") + plt.close() + logger.info(f"Saved focus curve plot to {plot_path}") + + +def run_baseline( + predictor: LassoBehavioralPredictor, + condition: str, + prediction_mode: str, + lambda_range: np.ndarray, + cv_folds: int, + scale_features: bool, +) -> Tuple[Dict[str, float], float, float, float, np.ndarray, np.ndarray, List[str], Optional[StandardScaler]]: + """Run baseline LASSO fit (full receptor set). + + Returns: + Tuple of (baseline_weights, cv_r2, cv_mse, best_lambda, X, y, receptor_names, scaler) + """ + logger.info(f"Fitting baseline model for condition: {condition}") + + # Get behavioral responses + per_responses = predictor.behavioral_data.loc[condition] + valid_odorants = per_responses.dropna() + + if len(valid_odorants) == 0: + raise ValueError(f"No valid PER data for condition '{condition}'") + + # Extract features based on prediction mode + if prediction_mode == "test_odorant": + X, test_odorants, y = predictor._extract_test_odorant_features(valid_odorants) + elif prediction_mode == "trained_odorant": + trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition) + if not trained_odorant: + raise ValueError(f"Could not determine trained odorant for {condition}") + X, test_odorants, y = predictor._extract_trained_odorant_features( + trained_odorant, valid_odorants + ) + elif prediction_mode == "interaction": + trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition) + if not trained_odorant: + raise ValueError(f"Could not determine trained odorant for {condition}") + X, test_odorants, y = predictor._extract_interaction_features( + trained_odorant, valid_odorants + ) + else: + raise ValueError(f"Unknown prediction_mode: {prediction_mode}") + + if X.shape[0] < 3: + raise ValueError(f"Insufficient data: only {X.shape[0]} samples") + + # Get receptor names + receptor_names = list(predictor.masked_receptor_names) + + # Fit scaler on baseline X (if scaling enabled) + if scale_features: + scaler = StandardScaler() + scaler.fit(X) + else: + scaler = None + + # Fit baseline model + weights, cv_r2, cv_mse, best_lambda, y_pred = fit_lasso_with_fixed_scaler( + X=X, + y=y, + receptor_names=receptor_names, + scaler=scaler, + lambda_range=lambda_range, + cv_folds=cv_folds, + ) + + logger.info( + f"Baseline: R² = {cv_r2:.4f}, MSE = {cv_mse:.4f}, " + f"λ = {best_lambda:.6f}, {len(weights)} receptors selected" + ) + + return weights, cv_r2, cv_mse, best_lambda, X, y, receptor_names, scaler + + +def run_focus( + X: np.ndarray, + y: np.ndarray, + receptor_names: List[str], + receptors_to_keep: List[str], + lambda_range: np.ndarray, + cv_folds: int, + scale_features: bool, + baseline_r2: float, + baseline_mse: float, +) -> FocusResult: + """Run LASSO with restricted receptor set.""" + n_receptors = len(receptors_to_keep) + logger.info(f"Running focus mode with 
N={n_receptors} receptors: {receptors_to_keep}") + + # Restrict X to specified receptors + X_restricted, kept_names, kept_indices = restrict_to_receptors( + X=X, + receptor_names=receptor_names, + receptors_to_keep=receptors_to_keep, + ) + + # Fit scaler on restricted X (for fair comparison within this N) + if scale_features: + scaler = StandardScaler() + scaler.fit(X_restricted) + else: + scaler = None + + # Fit model + weights, cv_r2, cv_mse, best_lambda, y_pred = fit_lasso_with_fixed_scaler( + X=X_restricted, + y=y, + receptor_names=kept_names, + scaler=scaler, + lambda_range=lambda_range, + cv_folds=cv_folds, + ) + + result = FocusResult( + n_receptors=n_receptors, + receptors_used=kept_names, + cv_r2=cv_r2, + cv_mse=cv_mse, + lambda_value=best_lambda, + n_receptors_selected=len(weights), + lasso_weights=weights, + delta_r2=cv_r2 - baseline_r2, + delta_mse=cv_mse - baseline_mse, + ) + + logger.info( + f"Focus N={n_receptors}: R² = {cv_r2:.4f} (Δ = {result.delta_r2:+.4f}), " + f"MSE = {cv_mse:.4f} (Δ = {result.delta_mse:+.4f})" + ) + + return result + + +def main() -> int: + """Main entry point.""" + args = parse_args() + + # Parse lambda range + if args.lambda_value is not None: + lambda_range = np.array([args.lambda_value]) + else: + lambda_range = np.array([float(x.strip()) for x in args.lambda_range.split(",")]) + + logger.info(f"Lambda range: {lambda_range}") + + # Initialize predictor + predictor = LassoBehavioralPredictor( + doorcache_path=args.door_cache, + behavior_csv_path=args.behavior_csv, + scale_features=False, # We handle scaling manually + scale_targets=False, + ) + + # Validate condition + if args.condition not in predictor.behavioral_data.index: + logger.error( + f"Condition '{args.condition}' not found. " + f"Available: {list(predictor.behavioral_data.index)}" + ) + return 1 + + # Create output directory + output_dir = Path(args.output_dir) / args.condition + output_dir.mkdir(parents=True, exist_ok=True) + + # Run baseline + try: + ( + baseline_weights, + baseline_r2, + baseline_mse, + baseline_lambda, + X, + y, + receptor_names, + scaler, + ) = run_baseline( + predictor=predictor, + condition=args.condition, + prediction_mode=args.prediction_mode, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + scale_features=args.scale_features, + ) + except Exception as e: + logger.error(f"Baseline fit failed: {e}") + return 1 + + # Save baseline + save_baseline_json( + baseline_weights=baseline_weights, + baseline_r2=baseline_r2, + baseline_mse=baseline_mse, + baseline_lambda=baseline_lambda, + condition=args.condition, + prediction_mode=args.prediction_mode, + n_samples=X.shape[0], + n_receptors_total=X.shape[1], + filepath=output_dir / "baseline_model.json", + ) + + # Determine focus runs + focus_results: List[FocusResult] = [] + + if args.focus_receptors: + # Use explicit receptor list + focus_receptors = [r.strip() for r in args.focus_receptors.split(",") if r.strip()] + logger.info(f"Using explicit focus receptors: {focus_receptors}") + + try: + result = run_focus( + X=X, + y=y, + receptor_names=receptor_names, + receptors_to_keep=focus_receptors, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + scale_features=args.scale_features, + baseline_r2=baseline_r2, + baseline_mse=baseline_mse, + ) + focus_results.append(result) + + # Save individual focus result + focus_dir = output_dir / f"focus_n{result.n_receptors}" + save_focus_json( + result=result, + condition=args.condition, + prediction_mode=args.prediction_mode, + n_samples=X.shape[0], + filepath=focus_dir / 
"model.json", + ) + + except Exception as e: + logger.error(f"Focus run failed: {e}") + + else: + # Use topn_list + topn_list = [int(x.strip()) for x in args.topn_list.split(",")] + logger.info(f"Top-N values to test: {topn_list}") + + # Get ranked receptors from baseline + if not baseline_weights: + logger.error("Baseline model has no non-zero weights, cannot determine top-N") + return 1 + + if args.baseline_select_by == "abs_weight": + ranked_receptors = get_top_receptors_by_weight(baseline_weights, len(baseline_weights)) + else: + # stability mode would require multiple runs - fallback to abs_weight + logger.warning("stability mode not yet implemented, using abs_weight") + ranked_receptors = get_top_receptors_by_weight(baseline_weights, len(baseline_weights)) + + logger.info(f"Receptor ranking (top 10): {ranked_receptors[:10]}") + + for n in topn_list: + if n > len(ranked_receptors): + logger.warning( + f"N={n} exceeds available ranked receptors ({len(ranked_receptors)}), skipping" + ) + continue + + if n <= 0: + logger.warning(f"N={n} is invalid, skipping") + continue + + top_n_receptors = ranked_receptors[:n] + + try: + result = run_focus( + X=X, + y=y, + receptor_names=receptor_names, + receptors_to_keep=top_n_receptors, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + scale_features=args.scale_features, + baseline_r2=baseline_r2, + baseline_mse=baseline_mse, + ) + focus_results.append(result) + + # Save individual focus result + focus_dir = output_dir / f"focus_n{n}" + save_focus_json( + result=result, + condition=args.condition, + prediction_mode=args.prediction_mode, + n_samples=X.shape[0], + filepath=focus_dir / "model.json", + ) + + except Exception as e: + logger.error(f"Focus N={n} failed: {e}") + continue + + # Generate focus_curve.csv + curve_rows = [] + + # Add baseline row + curve_rows.append({ + "n_receptors": X.shape[1], + "receptors_used": ";".join(sorted(baseline_weights.keys())), + "cv_r2": baseline_r2, + "cv_mse": baseline_mse, + "lambda_value": baseline_lambda, + "n_receptors_selected": len(baseline_weights), + "delta_r2": 0.0, + "delta_mse": 0.0, + "is_baseline": True, + }) + + # Add focus rows + for result in focus_results: + curve_rows.append({ + "n_receptors": result.n_receptors, + "receptors_used": ";".join(result.receptors_used), + "cv_r2": result.cv_r2, + "cv_mse": result.cv_mse, + "lambda_value": result.lambda_value, + "n_receptors_selected": result.n_receptors_selected, + "delta_r2": result.delta_r2, + "delta_mse": result.delta_mse, + "is_baseline": False, + }) + + curve_df = pd.DataFrame(curve_rows) + curve_df = curve_df.sort_values("n_receptors") + curve_path = output_dir / "focus_curve.csv" + curve_df.to_csv(curve_path, index=False) + logger.info(f"Saved focus curve to {curve_path}") + + # Generate plots + plot_focus_curve( + focus_results=focus_results, + baseline_r2=baseline_r2, + baseline_mse=baseline_mse, + output_dir=output_dir, + condition=args.condition, + ) + + # Print summary + print("\n" + "=" * 80) + print(f"LASSO Focus Mode Analysis Complete: {args.condition}") + print("=" * 80) + print(f"\nBaseline (full): R² = {baseline_r2:.4f}, MSE = {baseline_mse:.4f}") + print(f" {len(baseline_weights)} receptors selected out of {X.shape[1]}") + print(f"\nFocus Mode Results:") + for result in sorted(focus_results, key=lambda r: r.n_receptors): + print( + f" N={result.n_receptors:3d}: " + f"R² = {result.cv_r2:.4f} (Δ = {result.delta_r2:+.4f}) " + f"MSE = {result.cv_mse:.4f} (Δ = {result.delta_mse:+.4f})" + ) + print(f"\nOutputs saved to: 
{output_dir}") + print("=" * 80) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/door_toolkit/pathways/__init__.py b/src/door_toolkit/pathways/__init__.py index b200b5d..4673b8b 100644 --- a/src/door_toolkit/pathways/__init__.py +++ b/src/door_toolkit/pathways/__init__.py @@ -36,6 +36,8 @@ resolve_receptor_names, apply_receptor_ablation, fit_lasso_with_fixed_scaler, + restrict_to_receptors, + get_top_receptors_by_weight, ) from door_toolkit.pathways.behavior_rate_model import ( SparseRateGLM, @@ -56,6 +58,8 @@ "resolve_receptor_names", "apply_receptor_ablation", "fit_lasso_with_fixed_scaler", + "restrict_to_receptors", + "get_top_receptors_by_weight", "SparseRateGLM", "TrainConfig", "build_training_table", diff --git a/src/door_toolkit/pathways/behavioral_prediction.py b/src/door_toolkit/pathways/behavioral_prediction.py index c3442a0..77af4b4 100644 --- a/src/door_toolkit/pathways/behavioral_prediction.py +++ b/src/door_toolkit/pathways/behavioral_prediction.py @@ -1417,3 +1417,86 @@ def fit_lasso_with_fixed_scaler( cv_mse = float(-np.mean(cv_mse_scores)) return lasso_weights, cv_r2, cv_mse, best_lambda, y_pred + + +def restrict_to_receptors( + X: np.ndarray, + receptor_names: List[str], + receptors_to_keep: List[str], +) -> Tuple[np.ndarray, List[str], List[int]]: + """ + Restrict feature matrix to specified receptor subset. + + Args: + X: Feature matrix (n_samples, n_receptors) + receptor_names: List of receptor names corresponding to X columns + receptors_to_keep: Receptor names to retain (case-insensitive) + + Returns: + Tuple of (X_restricted, kept_receptor_names, kept_indices) + - X_restricted: Subset of X with only specified columns + - kept_receptor_names: Receptor names in kept order + - kept_indices: Original column indices that were kept + + Raises: + ValueError: If any receptor in receptors_to_keep not found + + Example: + >>> X = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + >>> X_res, names, idx = restrict_to_receptors( + ... X, ["Or42b", "Or47b", "Or22a"], ["Or47b", "Or22a"] + ... ) + >>> X_res + array([[2., 3.], + [5., 6.]]) + >>> names + ['Or47b', 'Or22a'] + """ + # Resolve receptor names (strict mode) + matched, unmatched = resolve_receptor_names( + receptors_to_keep, receptor_names, strict=True + ) + + # Find indices of matched receptors (in original order) + name_to_idx = {name: i for i, name in enumerate(receptor_names)} + kept_indices = [name_to_idx[r] for r in matched] + + # Sort indices to preserve original column order + kept_indices_sorted = sorted(kept_indices) + kept_receptor_names = [receptor_names[i] for i in kept_indices_sorted] + + # Restrict X to kept columns + X_restricted = X[:, kept_indices_sorted] + + logger.info( + f"Restricted to {len(kept_indices_sorted)} receptors: {kept_receptor_names}" + ) + + return X_restricted, kept_receptor_names, kept_indices_sorted + + +def get_top_receptors_by_weight( + lasso_weights: Dict[str, float], + n: int, +) -> List[str]: + """ + Get top N receptors ranked by absolute LASSO weight. 
+ + Args: + lasso_weights: Dictionary of {receptor: weight} + n: Number of top receptors to return + + Returns: + List of receptor names sorted by absolute weight (descending) + + Example: + >>> weights = {"Or42b": 0.5, "Or47b": -0.8, "Or22a": 0.3} + >>> get_top_receptors_by_weight(weights, 2) + ['Or47b', 'Or42b'] + """ + sorted_receptors = sorted( + lasso_weights.items(), + key=lambda x: abs(x[1]), + reverse=True, + ) + return [r for r, _ in sorted_receptors[:n]] diff --git a/tests/test_lasso_focus_mode.py b/tests/test_lasso_focus_mode.py new file mode 100644 index 0000000..505360b --- /dev/null +++ b/tests/test_lasso_focus_mode.py @@ -0,0 +1,485 @@ +"""Tests for LASSO focus mode analysis functionality.""" + +import json +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +from sklearn.preprocessing import StandardScaler + +from door_toolkit.pathways.behavioral_prediction import ( + get_top_receptors_by_weight, + restrict_to_receptors, + fit_lasso_with_fixed_scaler, +) + + +# ============================================================================ +# Tests for restrict_to_receptors +# ============================================================================ + + +class TestRestrictToReceptors: + """Tests for feature matrix restriction.""" + + def test_basic_restriction(self): + """Test basic column restriction.""" + X = np.array([ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + ]) + receptor_names = ["Or42b", "Or47b", "Or22a", "Or59b"] + + X_res, kept_names, kept_idx = restrict_to_receptors( + X, receptor_names, ["Or47b", "Or22a"] + ) + + # Check shapes + assert X_res.shape == (3, 2) + + # Check values (columns 1 and 2) + np.testing.assert_array_equal(X_res[:, 0], [2.0, 6.0, 10.0]) + np.testing.assert_array_equal(X_res[:, 1], [3.0, 7.0, 11.0]) + + # Check names and indices + assert kept_names == ["Or47b", "Or22a"] + assert kept_idx == [1, 2] + + def test_restriction_preserves_row_count(self): + """Test that restriction preserves number of rows.""" + X = np.random.rand(10, 20) + receptor_names = [f"Or{i}" for i in range(20)] + + X_res, kept_names, kept_idx = restrict_to_receptors( + X, receptor_names, ["Or5", "Or10", "Or15"] + ) + + assert X_res.shape[0] == 10 + assert X_res.shape[1] == 3 + + def test_restriction_with_case_insensitive_names(self): + """Test that restriction handles case-insensitive names.""" + X = np.array([[1.0, 2.0, 3.0]]) + receptor_names = ["Or42b", "Or47b", "Or22a"] + + X_res, kept_names, kept_idx = restrict_to_receptors( + X, receptor_names, ["or42B", "OR22A"] + ) + + assert X_res.shape == (1, 2) + assert set(kept_names) == {"Or42b", "Or22a"} + + def test_restriction_invalid_receptor_raises(self): + """Test that invalid receptor names raise ValueError.""" + X = np.array([[1.0, 2.0, 3.0]]) + receptor_names = ["Or42b", "Or47b", "Or22a"] + + with pytest.raises(ValueError, match="Could not resolve"): + restrict_to_receptors(X, receptor_names, ["Or99x"]) + + def test_restriction_single_column(self): + """Test restriction to single column.""" + X = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + receptor_names = ["Or42b", "Or47b", "Or22a"] + + X_res, kept_names, kept_idx = restrict_to_receptors( + X, receptor_names, ["Or47b"] + ) + + assert X_res.shape == (2, 1) + np.testing.assert_array_equal(X_res[:, 0], [2.0, 5.0]) + + def test_restriction_all_columns(self): + """Test restriction to all columns (no change).""" + X = np.array([[1.0, 2.0, 3.0]]) + receptor_names = ["Or42b", "Or47b", "Or22a"] + + X_res, 
kept_names, kept_idx = restrict_to_receptors( + X, receptor_names, ["Or42b", "Or47b", "Or22a"] + ) + + np.testing.assert_array_equal(X_res, X) + assert kept_names == receptor_names + + def test_restriction_preserves_order(self): + """Test that restriction preserves original column order.""" + X = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]]) + receptor_names = ["Or1", "Or2", "Or3", "Or4", "Or5"] + + # Request in different order + X_res, kept_names, kept_idx = restrict_to_receptors( + X, receptor_names, ["Or5", "Or2", "Or4"] + ) + + # Should be in original order (2, 4, 5) + assert kept_names == ["Or2", "Or4", "Or5"] + assert kept_idx == [1, 3, 4] + np.testing.assert_array_equal(X_res[0], [2.0, 4.0, 5.0]) + + +# ============================================================================ +# Tests for get_top_receptors_by_weight +# ============================================================================ + + +class TestGetTopReceptorsByWeight: + """Tests for top-N receptor selection by weight.""" + + def test_basic_selection(self): + """Test basic top-N selection.""" + weights = {"Or42b": 0.5, "Or47b": -0.8, "Or22a": 0.3, "Or59b": -0.1} + + top_2 = get_top_receptors_by_weight(weights, 2) + + assert len(top_2) == 2 + assert top_2[0] == "Or47b" # Highest abs weight (-0.8) + assert top_2[1] == "Or42b" # Second highest (0.5) + + def test_selection_deterministic(self): + """Test that selection is deterministic.""" + weights = {"Or42b": 0.5, "Or47b": -0.8, "Or22a": 0.3} + + # Run multiple times + results = [get_top_receptors_by_weight(weights, 2) for _ in range(10)] + + # All results should be identical + for result in results: + assert result == results[0] + + def test_selection_all_receptors(self): + """Test selecting all receptors.""" + weights = {"Or42b": 0.5, "Or47b": -0.8, "Or22a": 0.3} + + top_all = get_top_receptors_by_weight(weights, 3) + + assert len(top_all) == 3 + assert set(top_all) == {"Or42b", "Or47b", "Or22a"} + + def test_selection_more_than_available(self): + """Test requesting more receptors than available.""" + weights = {"Or42b": 0.5, "Or47b": -0.8} + + top_5 = get_top_receptors_by_weight(weights, 5) + + assert len(top_5) == 2 # Only 2 available + + def test_selection_single(self): + """Test selecting single top receptor.""" + weights = {"Or42b": 0.5, "Or47b": -0.8, "Or22a": 0.3} + + top_1 = get_top_receptors_by_weight(weights, 1) + + assert len(top_1) == 1 + assert top_1[0] == "Or47b" + + def test_selection_zero(self): + """Test selecting zero receptors.""" + weights = {"Or42b": 0.5, "Or47b": -0.8} + + top_0 = get_top_receptors_by_weight(weights, 0) + + assert len(top_0) == 0 + + def test_selection_considers_absolute_value(self): + """Test that selection uses absolute value.""" + weights = {"Or42b": 0.3, "Or47b": -0.5, "Or22a": 0.4} + + top_2 = get_top_receptors_by_weight(weights, 2) + + # Should be -0.5 (abs=0.5) and 0.4 (abs=0.4) + assert set(top_2) == {"Or47b", "Or22a"} + + +# ============================================================================ +# Tests for focus mode integration +# ============================================================================ + + +class TestFocusModeIntegration: + """Integration tests for focus mode workflow.""" + + @pytest.fixture + def synthetic_data(self): + """Create synthetic data for focus mode tests.""" + np.random.seed(42) + n_samples, n_features = 20, 10 + + X = np.random.rand(n_samples, n_features) + + # Create y with correlation to first two receptors + true_weights = np.zeros(n_features) + true_weights[0] = 1.0 + 
true_weights[1] = 0.5 + y = X @ true_weights + np.random.randn(n_samples) * 0.1 + + receptor_names = [f"Or{i}" for i in range(n_features)] + return X, y, receptor_names + + def test_focus_workflow_reduces_features(self, synthetic_data): + """Test that focus mode correctly reduces feature count.""" + X, y, receptor_names = synthetic_data + lambda_range = np.array([0.01, 0.1]) + + # Fit baseline + scaler = StandardScaler().fit(X) + baseline_weights, r2, mse, lam, pred = fit_lasso_with_fixed_scaler( + X, y, receptor_names, scaler, lambda_range + ) + + # Get top-3 receptors + if len(baseline_weights) >= 3: + top_3 = get_top_receptors_by_weight(baseline_weights, 3) + else: + top_3 = get_top_receptors_by_weight(baseline_weights, len(baseline_weights)) + + # Restrict to top-3 + X_focused, kept_names, kept_idx = restrict_to_receptors( + X, receptor_names, top_3 + ) + + assert X_focused.shape[1] <= 3 + assert len(kept_names) <= 3 + + def test_focus_produces_valid_model(self, synthetic_data): + """Test that focus mode produces valid LASSO model.""" + X, y, receptor_names = synthetic_data + lambda_range = np.array([0.01, 0.1]) + + # Fit baseline + scaler = StandardScaler().fit(X) + baseline_weights, r2, mse, lam, pred = fit_lasso_with_fixed_scaler( + X, y, receptor_names, scaler, lambda_range + ) + + # Focus on top-5 + if len(baseline_weights) >= 5: + top_n = get_top_receptors_by_weight(baseline_weights, 5) + else: + top_n = get_top_receptors_by_weight(baseline_weights, len(baseline_weights)) + + X_focused, kept_names, kept_idx = restrict_to_receptors( + X, receptor_names, top_n + ) + + # Fit on focused data + scaler_focused = StandardScaler().fit(X_focused) + focus_weights, focus_r2, focus_mse, focus_lam, focus_pred = fit_lasso_with_fixed_scaler( + X_focused, y, kept_names, scaler_focused, lambda_range + ) + + # Should produce valid outputs + assert isinstance(focus_weights, dict) + assert isinstance(focus_r2, float) + assert isinstance(focus_mse, float) + assert len(focus_pred) == len(y) + + +# ============================================================================ +# Tests for script output formats +# ============================================================================ + + +class TestFocusModeScriptOutputs: + """Tests for focus mode script output files.""" + + def test_save_baseline_json_format(self, tmp_path): + """Test that save_baseline_json creates valid JSON.""" + from scripts.lasso_with_focus_mode import save_baseline_json + + output_dir = tmp_path / "output" + output_dir.mkdir() + + baseline_weights = {"Or42b": 0.5, "Or47b": -0.3, "Or22a": 0.1} + filepath = output_dir / "baseline_model.json" + + save_baseline_json( + baseline_weights=baseline_weights, + baseline_r2=0.75, + baseline_mse=0.025, + baseline_lambda=0.01, + condition="opto_hex", + prediction_mode="test_odorant", + n_samples=10, + n_receptors_total=20, + filepath=filepath, + ) + + assert filepath.exists() + + with open(filepath) as f: + data = json.load(f) + + assert data["condition_name"] == "opto_hex" + assert data["cv_r2"] == 0.75 + assert data["cv_mse"] == 0.025 + assert "receptor_ranking" in data + assert len(data["receptor_ranking"]) == 3 + + def test_save_focus_json_format(self, tmp_path): + """Test that save_focus_json creates valid JSON.""" + from scripts.lasso_with_focus_mode import save_focus_json, FocusResult + + output_dir = tmp_path / "output" + output_dir.mkdir() + + result = FocusResult( + n_receptors=3, + receptors_used=["Or42b", "Or47b", "Or22a"], + cv_r2=0.65, + cv_mse=0.035, + lambda_value=0.01, + 
n_receptors_selected=2, + lasso_weights={"Or42b": 0.6, "Or47b": -0.4}, + delta_r2=-0.10, + delta_mse=0.010, + ) + + filepath = output_dir / "model.json" + save_focus_json( + result=result, + condition="opto_hex", + prediction_mode="test_odorant", + n_samples=10, + filepath=filepath, + ) + + assert filepath.exists() + + with open(filepath) as f: + data = json.load(f) + + assert data["n_receptors_focused"] == 3 + assert data["receptors_used"] == ["Or42b", "Or47b", "Or22a"] + assert data["delta_r2"] == -0.10 + + def test_focus_curve_csv_format(self, tmp_path): + """Test that focus_curve.csv has correct format.""" + # Create sample curve rows + curve_rows = [ + { + "n_receptors": 20, + "receptors_used": "Or42b;Or47b;Or22a", + "cv_r2": 0.75, + "cv_mse": 0.025, + "lambda_value": 0.01, + "n_receptors_selected": 5, + "delta_r2": 0.0, + "delta_mse": 0.0, + "is_baseline": True, + }, + { + "n_receptors": 3, + "receptors_used": "Or42b;Or47b;Or22a", + "cv_r2": 0.65, + "cv_mse": 0.035, + "lambda_value": 0.01, + "n_receptors_selected": 2, + "delta_r2": -0.10, + "delta_mse": 0.010, + "is_baseline": False, + }, + ] + + curve_df = pd.DataFrame(curve_rows) + curve_path = tmp_path / "focus_curve.csv" + curve_df.to_csv(curve_path, index=False) + + # Verify CSV can be read back + loaded_df = pd.read_csv(curve_path) + assert "n_receptors" in loaded_df.columns + assert "cv_r2" in loaded_df.columns + assert "cv_mse" in loaded_df.columns + assert "delta_r2" in loaded_df.columns + assert "is_baseline" in loaded_df.columns + assert len(loaded_df) == 2 + + def test_run_focus_with_synthetic_data(self): + """Test run_focus with synthetic data.""" + from scripts.lasso_with_focus_mode import run_focus + + np.random.seed(42) + n_samples, n_features = 15, 10 + X = np.random.rand(n_samples, n_features) + + true_weights = np.zeros(n_features) + true_weights[0] = 1.0 + true_weights[1] = 0.5 + y = X @ true_weights + np.random.randn(n_samples) * 0.1 + + receptor_names = [f"Or{i}" for i in range(n_features)] + lambda_range = np.array([0.01, 0.1]) + + result = run_focus( + X=X, + y=y, + receptor_names=receptor_names, + receptors_to_keep=["Or0", "Or1", "Or2"], + lambda_range=lambda_range, + cv_folds=3, + scale_features=True, + baseline_r2=0.9, + baseline_mse=0.01, + ) + + assert result.n_receptors == 3 + assert result.receptors_used == ["Or0", "Or1", "Or2"] + assert isinstance(result.cv_r2, float) + assert isinstance(result.cv_mse, float) + assert result.delta_r2 == result.cv_r2 - 0.9 + assert result.delta_mse == result.cv_mse - 0.01 + + +# ============================================================================ +# Tests for edge cases +# ============================================================================ + + +class TestFocusModeEdgeCases: + """Tests for edge cases in focus mode.""" + + def test_single_receptor_focus(self): + """Test focusing on a single receptor.""" + X = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + receptor_names = ["Or42b", "Or47b", "Or22a"] + + X_res, kept_names, kept_idx = restrict_to_receptors( + X, receptor_names, ["Or47b"] + ) + + assert X_res.shape == (3, 1) + assert kept_names == ["Or47b"] + + def test_topn_with_tied_weights(self): + """Test top-N selection with tied absolute weights.""" + weights = {"Or42b": 0.5, "Or47b": -0.5, "Or22a": 0.3} + + # With tied weights, order should be deterministic (alphabetical within tie) + top_2 = get_top_receptors_by_weight(weights, 2) + + assert len(top_2) == 2 + # Both have abs weight 0.5, order determined by sort stability + assert 
set(top_2) == {"Or42b", "Or47b"}
+
+    def test_focus_empty_weight_dict(self):
+        """Test top-N with empty weight dict."""
+        weights = {}
+
+        top_3 = get_top_receptors_by_weight(weights, 3)
+
+        assert len(top_3) == 0
+
+    def test_restrict_to_empty_list(self):
+        """Test restricting to empty receptor list returns empty."""
+        X = np.array([[1.0, 2.0, 3.0]])
+        receptor_names = ["Or42b", "Or47b", "Or22a"]
+
+        # Empty input returns empty output (no receptors to resolve)
+        X_res, kept_names, kept_idx = restrict_to_receptors(
+            X, receptor_names, []
+        )
+
+        assert X_res.shape == (1, 0)
+        assert kept_names == []
+        assert kept_idx == []

From 8183dbcb610079c8b6389a10f03e7eeea3b968b1 Mon Sep 17 00:00:00 2001
From: ramanlab
Date: Thu, 8 Jan 2026 12:24:22 -0600
Subject: [PATCH 3/4] docs: Add ablation and focus mode CLI usage to documentation

Co-Authored-By: Claude Opus 4.5

---
 README.md                              | 14 ++++++++
 docs/BEHAVIORAL_PREDICTION_ANALYSIS.md | 49 ++++++++++++++++++++++++++
 2 files changed, 63 insertions(+)

diff --git a/README.md b/README.md
index 79a2b7f..73106f7 100644
--- a/README.md
+++ b/README.md
@@ -645,6 +645,20 @@ opto_benz_1,0.25,0.02,0.44,0.59,0.12
 - Zero weights → receptors excluded by LASSO (not predictive)
 - Sparse circuits (3-7 receptors) suggest minimal testable hypotheses
 
+**Robustness Analysis:** Two CLI scripts assess circuit robustness. *Ablation* (`lasso_with_ablations.py`) tests necessity by zeroing out receptors and measuring the resulting increase in MSE. *Focus mode* (`lasso_with_focus_mode.py`) tests sufficiency by refitting LASSO on only the top-N receptors to generate MSE-vs-N curves.
+
+```bash
+# Ablation: test if removing Or22b/Or49a degrades the model
+python scripts/lasso_with_ablations.py --door_cache door_cache \
+  --behavior_csv reaction_rates.csv --condition opto_hex \
+  --ablate Or22b,Or49a --ablation_set_mode single --output_dir ablation_out
+
+# Focus: test if top 1-5 receptors are sufficient
+python scripts/lasso_with_focus_mode.py --door_cache door_cache \
+  --behavior_csv reaction_rates.csv --condition opto_hex \
+  --topn_list 1,2,3,5 --output_dir focus_out
+```
+
 ### CLI Usage
 
 ```bash
diff --git a/docs/BEHAVIORAL_PREDICTION_ANALYSIS.md b/docs/BEHAVIORAL_PREDICTION_ANALYSIS.md
index dd47ea1..6143b80 100644
--- a/docs/BEHAVIORAL_PREDICTION_ANALYSIS.md
+++ b/docs/BEHAVIORAL_PREDICTION_ANALYSIS.md
@@ -210,6 +210,55 @@ From the comparison output:
 
 ---
 
+## 5.5 Robustness Analysis Scripts
+
+Two scripts assess how stable the LASSO-identified receptor circuits are:
+
+### Ablation Analysis
+
+Test whether the model degrades when specific receptors are ablated (zeroed out):
+
+```bash
+conda activate DoOR
+python scripts/lasso_with_ablations.py \
+  --door_cache door_cache \
+  --behavior_csv /path/to/reaction_rates_summary_unordered.csv \
+  --condition opto_hex \
+  --output_dir ablation_results \
+  --ablate Or22b,Or49a \
+  --ablation_set_mode single
+```
+
+**Key arguments:**
+- `--ablate`: Comma-separated receptor name(s) to ablate (case-insensitive exact match)
+- `--ablation_set_mode`: `single` (ablate each receptor individually) or `all_in_one` (ablate all together)
+- `--missing_receptor_policy`: `error` or `skip` for unmatched receptors
+
+**Outputs:** `baseline_model.json`, `ablation_summary.csv`, per-ablation folders, `ablation_comparison.png`
+
+### Focus Mode Analysis
+
+Test whether the top-N receptors are *sufficient* to maintain model performance:
+
+```bash
+conda activate DoOR
+python scripts/lasso_with_focus_mode.py \
+  --door_cache door_cache \
+  --behavior_csv /path/to/reaction_rates_summary_unordered.csv \
+  --condition opto_hex \
+  --output_dir focus_results \
+  --topn_list 1,2,3,5,10
+```
+
+**Key arguments:**
+- `--topn_list`: Comma-separated N values; tests subsets of the top 1, 2, 3, ... receptors
+- `--focus_receptors`: Alternatively, a comma-separated list of exact receptors to include
+- `--baseline_select_by`: `abs_weight` (default) or `stability` for ranking (`stability` currently falls back to `abs_weight`)
+
+**Outputs:** `baseline_model.json`, `focus_curve.csv`, `focus_curve.png`, per-N folders
+
+---
+
 ## 6. Biological Insights
 
 ### Top Receptor Candidates for Experimental Validation

From 1ed3db4adbb54c27de61d20add194625aa4ad0ff Mon Sep 17 00:00:00 2001
From: ramanlab
Date: Thu, 8 Jan 2026 12:32:11 -0600
Subject: [PATCH 4/4] refactor: Save ablation_summary and comparison plot to ablations/ subfolder

- Create dedicated ablations/ subfolder within condition output directory
- Move ablation_summary.csv to ablations/ablation_summary.csv
- Move ablation_comparison.png to ablations/ablation_comparison.png
- Individual ablation results remain in separate ablate_* folders

Co-Authored-By: Claude Opus 4.5

---
 scripts/lasso_with_ablations.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/scripts/lasso_with_ablations.py b/scripts/lasso_with_ablations.py
index ad373dc..3d14cf2 100644
--- a/scripts/lasso_with_ablations.py
+++ b/scripts/lasso_with_ablations.py
@@ -702,7 +702,12 @@ def main() -> int:
         })
 
     summary_df = pd.DataFrame(summary_rows)
-    summary_path = output_dir / "ablation_summary.csv"
+
+    # Create ablations subfolder for summary files
+    ablations_dir = output_dir / "ablations"
+    ablations_dir.mkdir(parents=True, exist_ok=True)
+
+    summary_path = ablations_dir / "ablation_summary.csv"
     summary_df.to_csv(summary_path, index=False)
     logger.info(f"Saved summary to {summary_path}")
 
@@ -710,7 +715,7 @@ def main() -> int:
     plot_ablation_comparison(
         baseline_result=baseline_result,
         ablation_results=ablation_results,
-        output_dir=output_dir,
+        output_dir=ablations_dir,
         condition=args.condition,
     )
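
The helpers exported from `door_toolkit.pathways` in this series can also be composed directly, without the CLI scripts. The sketch below mirrors the focus-mode workflow on synthetic data; the receptor names, lambda grid, and top-3 fallback are illustrative assumptions, while the three imported functions and their call signatures are taken from this patch series.

```python
"""Minimal focus-mode sketch on synthetic data (illustrative only)."""
import numpy as np
from sklearn.preprocessing import StandardScaler

from door_toolkit.pathways import (
    fit_lasso_with_fixed_scaler,
    get_top_receptors_by_weight,
    restrict_to_receptors,
)

rng = np.random.default_rng(0)
X = rng.random((30, 12))                        # 30 samples x 12 receptor channels
receptor_names = [f"Or{i}" for i in range(12)]  # made-up names for the sketch
y = 1.0 * X[:, 0] + 0.5 * X[:, 1] + rng.normal(0.0, 0.1, size=30)  # only Or0/Or1 carry signal
lambda_range = np.array([0.001, 0.01, 0.1])

# Baseline fit on the full receptor set; the scaler is fit on the baseline X.
scaler = StandardScaler().fit(X)
weights, r2_full, mse_full, lam, _ = fit_lasso_with_fixed_scaler(
    X, y, receptor_names, scaler, lambda_range
)

# Focus: keep only the top-3 receptors by |weight| and refit.
# Fall back to the first three columns if LASSO zeroed every weight.
top_receptors = get_top_receptors_by_weight(weights, 3) or receptor_names[:3]
X_focus, kept_names, _ = restrict_to_receptors(X, receptor_names, top_receptors)
scaler_focus = StandardScaler().fit(X_focus)  # refit scaler on the restricted matrix
_, r2_focus, mse_focus, _, _ = fit_lasso_with_fixed_scaler(
    X_focus, y, kept_names, scaler_focus, lambda_range
)

print(f"baseline: R2={r2_full:.3f}  MSE={mse_full:.4f}  ({len(weights)} receptors selected)")
print(f"top-3:    R2={r2_focus:.3f}  MSE={mse_focus:.4f}  (dMSE={mse_focus - mse_full:+.4f})")
```

Refitting the scaler on the restricted matrix follows the same choice made in `run_focus`, which standardizes each reduced feature set on its own rather than reusing the baseline scaler.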