diff --git a/.gitignore b/.gitignore index 4eb961f..17ce8d8 100644 --- a/.gitignore +++ b/.gitignore @@ -169,4 +169,7 @@ helper-code/ flywire_orn_database/ -diagnostics/ +diagnostics/* +!diagnostics/ +!diagnostics/*.py +!diagnostics/*.md diff --git a/diagnostics/DELTA_PER_COLLAPSE_EXPLANATION.md b/diagnostics/DELTA_PER_COLLAPSE_EXPLANATION.md new file mode 100644 index 0000000..fb7b77b --- /dev/null +++ b/diagnostics/DELTA_PER_COLLAPSE_EXPLANATION.md @@ -0,0 +1,45 @@ +# ΔPER LASSO Collapse Explanation + +Latest audit run: `diagnostics/postA_postB_audit_20260108_154458` + +## Evidence of intercept-only collapse (LASSO) + +- opto_hex / delta_base: INTERCEPT-ONLY; n_selected=0, pred_std=0, cv_mse=0.100622, intercept_only_mse=0.100622, y_std=0.271894 +- opto_hex / delta_extended: INTERCEPT-ONLY; n_selected=0, pred_std=0, cv_mse=0.100622, intercept_only_mse=0.100622, y_std=0.271894 +- opto_EB / delta_base: INTERCEPT-ONLY; n_selected=0, pred_std=0, cv_mse=0.0165606, intercept_only_mse=0.0165606, y_std=0.110304 +- opto_EB / delta_extended: INTERCEPT-ONLY; n_selected=0, pred_std=0, cv_mse=0.0165606, intercept_only_mse=0.0165606, y_std=0.110304 +- opto_benz_1 / delta_base: non-intercept; n_selected=1, pred_std=0.0544301, cv_mse=0.0382769, intercept_only_mse=0.0388923, y_std=0.169038 +- opto_benz_1 / delta_extended: non-intercept; n_selected=1, pred_std=0.0544301, cv_mse=0.0382769, intercept_only_mse=0.0388923, y_std=0.169038 + +## If LASSO collapsed, how much better are Ridge/ElasticNet? + +- opto_hex / delta_base: best alt = elasticnet_0.5 | Δcv_mse=0, Δnmse=0 +- opto_hex / delta_extended: best alt = elasticnet_0.5 | Δcv_mse=0, Δnmse=0 +- opto_EB / delta_base: best alt = ridge | Δcv_mse=0.00694191, Δnmse=0.570554 +- opto_EB / delta_extended: best alt = ridge | Δcv_mse=0.00694192, Δnmse=0.570555 + +## Why LASSO collapses in ΔPER for opto_hex/opto_EB + +The audit shows ΔPER LASSO selecting zero features with pred_std=0 and cv_mse equal to intercept-only MSE. This indicates the LASSO penalty dominates the signal at small n, so the best cross-validated model is the intercept-only baseline. Expanding the ΔPER lambda grid (delta_extended) does not change this for opto_hex/opto_EB, so it is not a grid-resolution artifact. + +## Why low-range datasets look “perfect” in raw MSE + +Raw MSE is scale-dependent: smaller y_std yields smaller MSE even when relative error is similar. Normalized metrics (nmse and rmse_over_y_std) in `diagnostics/delta_model_comparison.csv` should be used for cross-condition comparisons. This avoids misreading low-variance conditions as “perfect fits.” + +## Recommended default model for ΔPER reporting + +When LASSO is intercept-only (n_selected=0, pred_std=0, cv_mse==intercept_only_mse), report the best ElasticNet/Ridge by CV MSE. This is already surfaced in `audit_primary_models.csv` in the latest audit run. 
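+
+A minimal sketch of this selection rule, assuming a frame loaded from
+`audit_metrics.csv` with the columns shown above (the helper name is
+illustrative, not part of the audit script):
+
+```python
+import pandas as pd
+
+def pick_primary_model(df: pd.DataFrame, condition: str, mode: str) -> pd.Series:
+    """Report LASSO unless it collapsed to the intercept-only baseline."""
+    rows = df[(df["condition"] == condition) & (df["mode"] == mode)]
+    lasso = rows[rows["modelclass"] == "lasso"].iloc[0]
+    collapsed = (
+        lasso["n_selected"] == 0
+        and lasso["pred_std"] < 1e-6
+        and abs(lasso["cv_mse"] - lasso["intercept_only_mse"]) < 1e-12
+    )
+    if not collapsed:
+        return lasso
+    # Fall back to the best ElasticNet/Ridge alternative by CV MSE.
+    alts = rows[rows["modelclass"] != "lasso"]
+    return alts.loc[alts["cv_mse"].idxmin()]
+```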
+
+## Reproducible commands
+
+```bash
+conda run -n DoOR python diagnostics/run_postA_postB_audit.py \
+  --door_cache door_cache \
+  --behavior_csv "/home/ramanlab/Documents/cole/Results/Opto/Reaction_Predictions(Strictest)/reaction_rates_summary_unordered.csv" \
+  --conditions opto_hex,opto_EB,opto_benz_1,opto_ACV,opto_3-oct \
+  --prediction_mode test_odorant \
+  --subtract_control \
+  --lambda_range 0.0001,0.001,0.01,0.1,1.0 \
+  --lambda_range_delta 1e-8,1e-7,1e-6,1e-5,1e-4,1e-3,1e-2,1e-1,1.0 \
+  --missing_control_policy skip
+```
\ No newline at end of file
diff --git a/diagnostics/PLAN_STABILITY.md b/diagnostics/PLAN_STABILITY.md
new file mode 100644
index 0000000..f7db062
--- /dev/null
+++ b/diagnostics/PLAN_STABILITY.md
@@ -0,0 +1,32 @@
+# Stability + Metrics Layer Plan
+
+## Discovery summary
+- Feature matrices X are built in `src/door_toolkit/pathways/behavioral_prediction.py` via:
+  - `_extract_test_odorant_features`, `_extract_trained_odorant_features`, `_extract_interaction_features`.
+- Receptor ordering comes from `DoOREncoder.response_matrix` column order in `src/door_toolkit/encoder.py` and is exposed as `predictor.masked_receptor_names` (or `encoder.receptor_names`).
+- LASSO selection is in `LassoBehavioralPredictor.fit_behavior()` and `fit_lasso_with_fixed_scaler()` (same file).
+- Ridge/ElasticNet CV logic exists in `diagnostics/run_postA_postB_audit.py` (LOOCV grid search).
+- Audit outputs are under `diagnostics/postA_postB_audit_*/` with `audit_metrics.csv` + `audit_artifacts.json`.
+
+## Files to add/change
+- Add: `diagnostics/run_stability_and_metrics.py` (new stability + metrics runner).
+- Add: `tests/test_stability_metrics.py` (determinism + schema + intercept-only flag tests).
+- Update: `.gitignore` to allow tracked `diagnostics/*.py` and `diagnostics/*.md`.
+- Update: `docs/BEHAVIORAL_PREDICTION_ANALYSIS.md` with a 5-line “how to run the stability layer” section.
+
+## Algorithms to implement
+- Standardized metrics for each (condition, mode, modelclass):
+  - y_std, y_var, y_min, y_max; pred_std, pred_min, pred_max; cv_mse; nmse; rmse_over_y_std;
+    intercept_only_flag; intercept_only_mse (LOOCV mean predictor).
+- ORN stability (leave-one-odorant-out, LOOO):
+  - For each fold: fit model on n-1 odorants (same scaling rules as baseline).
+  - Record selected ORNs + coefficients; compute selection_frequency, sign_consistency,
+    mean/std abs(weight), mean rank by abs(weight).
+  - LASSO only if not intercept-only; ElasticNet for ΔPER when LASSO is intercept-only; Ridge uses rank stability.
+- Experiment shortlist: top 5 ORNs by stability_score = selection_frequency * sign_consistency,
+  plus confidence flags (nmse>=1, intercept-only, missing controls).
+
+## Verification steps
+- `pytest -q` (determinism + schema tests for stability outputs).
+- Run stability script on real CSV + conditions with seed=1337; check outputs:
+  - `stability_per_condition.csv`, `model_metrics.csv`, `EXPERIMENT_SHORTLIST.md`, `SUMMARY.md`, `RUN_COMMANDS.txt`.
diff --git a/diagnostics/baseline_drift_hypotheses.md b/diagnostics/baseline_drift_hypotheses.md
new file mode 100644
index 0000000..872cd36
--- /dev/null
+++ b/diagnostics/baseline_drift_hypotheses.md
@@ -0,0 +1,34 @@
+# Baseline Drift Hypotheses
+
+## Summary
+No obvious in-place mutation or stochastic sources were found in the LASSO predictor or the ablation/focus scripts. The most plausible explanations are (1) accidental in-place mutation of a view derived from `X`, or (2) changes in target alignment when using ΔPER (control subtraction). Each hypothesis below includes file and function references.
+
+## Hypotheses (with code locations)
+
+### 1) View-based mutation risk (low likelihood)
+- `src/door_toolkit/pathways/behavioral_prediction.py:1560` `restrict_to_receptors()` returns `X[:, kept_indices_sorted]` without `.copy()`.
+  - If `kept_indices_sorted` is a basic slice this returns a view, and downstream in-place mutation of `X_restricted` could then mutate the original `X` (and appear as baseline drift). If it is an integer index array, NumPy advanced indexing returns a copy and the risk disappears (see the check at the end of this document).
+  - In `scripts/lasso_with_focus_mode.py:421` the restricted matrix is only passed to `StandardScaler.fit` and LASSO fitting, which do not mutate input arrays, so this risk is theoretical but low.
+
+### 2) In-place ablation (low likelihood)
+- `src/door_toolkit/pathways/behavioral_prediction.py:1410` `apply_receptor_ablation()` explicitly copies `X` before ablation.
+  - This is safe; baseline drift would require a different ablation path that modifies `X` in-place.
+  - `scripts/lasso_with_ablations.py:456` uses `apply_receptor_ablation()` (safe).
+
+### 3) Non-determinism in CV or lambda selection (unlikely)
+- `src/door_toolkit/pathways/behavioral_prediction.py:915` `LassoCV(... random_state=42)`.
+- `src/door_toolkit/pathways/behavioral_prediction.py:961` `cross_val_score` uses deterministic KFold (no shuffle).
+  - Without shuffle, folds are deterministic and reproducible; no randomness expected.
+
+### 4) Data alignment differences (likely for ΔPER vs raw)
+- `src/door_toolkit/pathways/behavioral_prediction.py:873-926` control subtraction uses different masks depending on `missing_control_policy`.
+  - ΔPER runs drop rows with NaNs in either opto or control (`skip`), or fill missing controls (`zero`).
+  - This can change sample counts and target variance vs raw fits, potentially leading to different selected features.
+
+### 5) Dataset label normalization changes (low impact)
+- `src/door_toolkit/pathways/behavioral_prediction.py:726` `_resolve_dataset_name()` normalizes dataset labels.
+  - If the CSV index has multiple labels that normalize to the same token, this can cause ambiguity errors; otherwise it should not affect results.
+
+## Notes
+- No global caches or shared mutable matrices were found in the predictor; `get_receptor_profile()` returns fresh arrays.
+- The diagnostic script added in this task will validate reproducibility and detect constant-prediction collapses.
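+
+## Appendix: quick view-vs-copy check (hypothesis 1)
+
+A standalone check of the assumption behind hypothesis 1. NumPy advanced
+(integer-array) indexing returns a copy, while basic slicing returns a view,
+so if `kept_indices_sorted` is an index array the mutation risk is moot:
+
+```python
+import numpy as np
+
+X = np.arange(12, dtype=float).reshape(3, 4)
+
+fancy = X[:, [0, 2]]      # integer-array index -> copy
+fancy[0, 0] = -1.0
+assert X[0, 0] == 0.0     # original X untouched
+
+view = X[:, 0:2]          # basic slice -> view
+view[0, 0] = -1.0
+assert X[0, 0] == -1.0    # original X mutated through the view
+```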
diff --git a/diagnostics/repo_state.md b/diagnostics/repo_state.md new file mode 100644 index 0000000..e9b9e60 --- /dev/null +++ b/diagnostics/repo_state.md @@ -0,0 +1,65 @@ +# Repo State Snapshot + +## Commands + +### git status -sb +``` +## feature/lasso-subtract-control...origin/feature/lasso-subtract-control +``` + +### git log -n 20 --oneline +``` +64c79ce feat: Implement control subtraction for LASSO behavioral prediction and add corresponding tests +3225982 Merge pull request #2 from colehanan1/feature/lasso-ablation-analysis +1ed3db4 refactor: Save ablation_summary and comparison plot to ablations/ subfolder +8183dbc docs: Add ablation and focus mode CLI usage to documentation +4278a80 feat: Add LASSO focus mode analysis for receptor circuit sufficiency +026dad9 feat: Add LASSO ablation analysis for receptor circuit robustness +bc9234f feat: Add support for strict mode in connectome analysis and deprecate Shapley importance method in favor of Shapley-proxy +e5574e8 Merge pull request #1 from colehanan1/audit/codex_repo_analysis +1b83066 feat: Add synthetic importance audit scripts for connectome, GLM, LASSO, and Shapley methods +8093664 Add threshold calibration utilities and corresponding tests +bd12a4b feat: Update .gitignore to include 'outputs/', 'helper-code/', and 'flywire_orn_database/' directories +6f84d36 feat: Update .gitignore to include 'outputs/' and '.claude/' directories +3ac9598 Release v1.0.1 +8ca9bf7 Add comprehensive test suites for mapping accounting and identifier resolution +862c0f2 Add receptor sensitivity diagnostics script +4219faa Release v1.0.0: Production-ready toolkit with mushroom body circuit validation +49ba71a feat: Add Mushroom Body Circuit Validation module and update README with new features +19b029b feat: Add FlyWire Mushroom Body Pathway Analysis script and Mushroom Body Tracer module +458cf39 Add comprehensive documentation for behavioral prediction analysis, connectomics module, custom pathway guide, and FlyWire integration notes +3b5f197 Add LASSO regression-based behavioral prediction and enhance existing predictor +``` + +### git diff +``` + +``` + +### git diff --stat +``` + +``` + +## Changed Files Relevant to Drift Investigation + +### Behavioral prediction core +- `src/door_toolkit/pathways/behavioral_prediction.py` + - Commit `64c79ce` adds control-subtraction in `fit_behavior`, plus dataset name normalization helpers and new metadata fields. + - Helper utilities for ablation/focus mode are in this file (see `apply_receptor_ablation`, `fit_lasso_with_fixed_scaler`, `restrict_to_receptors`). + +### Ablation/focus scripts +- `scripts/lasso_with_ablations.py` and `scripts/lasso_with_focus_mode.py` show no diffs in this branch vs main (branch equals main at `64c79ce`). + +### Helpers for X construction / scaling / lambda selection +- `LassoBehavioralPredictor.fit_behavior()` in `behavioral_prediction.py` constructs X via `_extract_*` helpers. +- Scaling uses `StandardScaler.fit_transform` (new arrays, no in-place mutation). +- Lambda selection uses `LassoCV(random_state=42)` and `cross_val_score` with deterministic folds. 
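+
+For reference, a minimal check that the unshuffled folds behind this claim are
+deterministic (illustrative only, not part of the pipeline):
+
+```python
+import numpy as np
+from sklearn.model_selection import KFold
+
+X = np.arange(20).reshape(10, 2)
+splits_a = [test.tolist() for _, test in KFold(n_splits=5).split(X)]
+splits_b = [test.tolist() for _, test in KFold(n_splits=5).split(X)]
+assert splits_a == splits_b  # shuffle=False (default) -> identical folds
+```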
+ +## Commit Stats (64c79ce) +``` +docs/BEHAVIORAL_PREDICTION_ANALYSIS.md | 20 ++ +scripts/run_lasso_behavioral_prediction.py | 239 +++++++++++++++++++++ +src/door_toolkit/pathways/behavioral_prediction.py | 156 +++++++++++++- +tests/test_lasso_behavioral_prediction.py | 86 ++++++++ +``` diff --git a/diagnostics/run_postA_postB_audit.py b/diagnostics/run_postA_postB_audit.py new file mode 100644 index 0000000..a140e2c --- /dev/null +++ b/diagnostics/run_postA_postB_audit.py @@ -0,0 +1,1034 @@ +#!/usr/bin/env python3 +""" +Audit LASSO behavioral prediction pipeline after Part A/B changes. + +Writes a timestamped diagnostics run folder with metrics, artifacts, and plots. +""" + +from __future__ import annotations + +import argparse +import json +import logging +import random +import sys +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from sklearn.linear_model import ElasticNet, Lasso, Ridge +from sklearn.model_selection import LeaveOneOut, cross_val_score +from sklearn.preprocessing import StandardScaler + +matplotlib.use("Agg") + +# Add src to path for repo-local runs. +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) + +from door_toolkit.pathways.behavioral_prediction import ( + LassoBehavioralPredictor, + apply_receptor_ablation, + fit_lasso_with_fixed_scaler, + get_top_receptors_by_weight, + restrict_to_receptors, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class RunResult: + condition: str + mode: str + modelclass: str + repeat: int + cv_mse: float + lambda_value: float + n_selected: int + pred_std: float + y_std: float + intercept_only_mse: float + nmse: float + mae: float + mae_over_y_std: float + intercept_only: bool + chosen_params: Dict + weights: Dict[str, float] + y_stats: Dict + pred_stats: Dict + reproducible: bool + mutation_ok: bool + + +def _parse_conditions(values: List[str]) -> List[str]: + conditions: List[str] = [] + seen = set() + for value in values: + for token in value.split(","): + token = token.strip() + if not token or token in seen: + continue + conditions.append(token) + seen.add(token) + return conditions + + +def _parse_lambda_range(value: str) -> np.ndarray: + tokens = [token.strip() for token in value.split(",") if token.strip()] + if not tokens: + raise ValueError("lambda_range cannot be empty") + return np.array([float(token) for token in tokens], dtype=np.float64) + + +def _ensure_dir(path: Path) -> None: + path.mkdir(parents=True, exist_ok=True) + + +def _loocv_intercept_mse(y: np.ndarray) -> float: + if len(y) < 2: + return float("nan") + loo = LeaveOneOut() + errors = [] + for train_idx, test_idx in loo.split(y): + train_mean = float(np.mean(y[train_idx])) + err = float((y[test_idx][0] - train_mean) ** 2) + errors.append(err) + return float(np.mean(errors)) + + +def _cv_mse_grid(model_factory, X: np.ndarray, y: np.ndarray, lambdas: np.ndarray) -> List[float]: + loo = LeaveOneOut() + mse_values: List[float] = [] + for lam in lambdas: + model = model_factory(lam) + scores = cross_val_score(model, X, y, cv=loo, scoring="neg_mean_squared_error") + mse_values.append(float(-np.mean(scores))) + return mse_values + + +def _fit_ridge( + X: np.ndarray, + y: np.ndarray, + lambdas: np.ndarray, +) -> Tuple[float, float, np.ndarray, float]: + def factory(alpha: float) -> Ridge: + return Ridge(alpha=alpha, random_state=42) + + mse_grid = 
_cv_mse_grid(factory, X, y, lambdas) + best_idx = int(np.argmin(mse_grid)) + best_lambda = float(lambdas[best_idx]) + + model = factory(best_lambda) + model.fit(X, y) + y_pred = model.predict(X) + return best_lambda, float(mse_grid[best_idx]), model.coef_, y_pred + + +def _fit_elasticnet( + X: np.ndarray, + y: np.ndarray, + lambdas: np.ndarray, + l1_ratios: List[float], +) -> Tuple[float, float, float, np.ndarray, np.ndarray]: + best_lambda = None + best_l1 = None + best_mse = None + best_coef = None + best_pred = None + + for l1_ratio in l1_ratios: + def factory(alpha: float) -> ElasticNet: + return ElasticNet( + alpha=alpha, + l1_ratio=l1_ratio, + random_state=42, + max_iter=10000, + ) + + mse_grid = _cv_mse_grid(factory, X, y, lambdas) + idx = int(np.argmin(mse_grid)) + mse = float(mse_grid[idx]) + if best_mse is None or mse < best_mse: + best_mse = mse + best_lambda = float(lambdas[idx]) + best_l1 = float(l1_ratio) + model = factory(best_lambda) + model.fit(X, y) + best_coef = model.coef_.copy() + best_pred = model.predict(X) + + if best_lambda is None or best_coef is None or best_pred is None or best_l1 is None: + raise RuntimeError("ElasticNet fitting failed") + + return best_lambda, float(best_mse), best_l1, best_coef, best_pred + + +def _fit_lasso( + X: np.ndarray, + y: np.ndarray, + receptor_names: List[str], + lambdas: np.ndarray, + cv_folds: int, + scaler: Optional[StandardScaler], +) -> Tuple[Dict[str, float], float, float, np.ndarray]: + weights, _cv_r2, cv_mse, best_lambda, y_pred = fit_lasso_with_fixed_scaler( + X=X, + y=y, + receptor_names=receptor_names, + scaler=scaler, + lambda_range=lambdas, + cv_folds=cv_folds, + ) + return weights, float(cv_mse), float(best_lambda), y_pred + + +def _get_stats(values: np.ndarray) -> Dict[str, float]: + return { + "mean": float(np.mean(values)), + "std": float(np.std(values)), + "min": float(np.min(values)), + "max": float(np.max(values)), + "n": int(len(values)), + } + + +def _plot_mse_grid( + lambdas: np.ndarray, + mse_values: List[float], + title: str, + path: Path, +) -> None: + _ensure_dir(path.parent) + plt.figure(figsize=(6, 4)) + plt.plot(lambdas, mse_values, marker="o") + plt.xscale("log") + plt.xlabel("lambda") + plt.ylabel("LOOCV MSE") + plt.title(title) + plt.grid(alpha=0.3) + plt.tight_layout() + plt.savefig(path, dpi=200) + plt.close() + + +def _plot_scatter(y: np.ndarray, y_pred: np.ndarray, title: str, path: Path) -> None: + _ensure_dir(path.parent) + plt.figure(figsize=(5, 5)) + plt.scatter(y, y_pred, alpha=0.7, edgecolors="k") + min_val = float(min(np.min(y), np.min(y_pred))) + max_val = float(max(np.max(y), np.max(y_pred))) + plt.plot([min_val, max_val], [min_val, max_val], "r--", alpha=0.5) + plt.xlabel("Actual") + plt.ylabel("Predicted") + plt.title(title) + plt.grid(alpha=0.3) + plt.tight_layout() + plt.savefig(path, dpi=200) + plt.close() + + +def _mutation_check( + X: np.ndarray, + receptor_names: List[str], + receptors_to_use: List[str], +) -> bool: + X_before = X.copy() + _ = apply_receptor_ablation( + X=X, + receptor_names=receptor_names, + receptors_to_ablate=receptors_to_use, + ) + _ = restrict_to_receptors( + X=X, + receptor_names=receptor_names, + receptors_to_keep=receptors_to_use, + ) + return np.array_equal(X, X_before) + + +def _build_valid_odorants( + predictor: LassoBehavioralPredictor, + condition: str, + subtract_control: bool, + control_condition: Optional[str], + missing_control_policy: str, +) -> Tuple[pd.Series, str, Optional[str], int]: + resolved_condition = 
predictor._resolve_dataset_name(condition) + if resolved_condition is None: + raise ValueError(f"Condition '{condition}' not found in behavioral data") + + if not subtract_control: + per_responses = predictor.behavioral_data.loc[resolved_condition] + valid_odorants = per_responses.dropna() + if len(valid_odorants) == 0: + raise ValueError(f"No valid PER data for condition '{resolved_condition}'") + return valid_odorants, resolved_condition, None, 0 + + if missing_control_policy not in {"skip", "zero", "error"}: + raise ValueError( + f"Unknown missing_control_policy '{missing_control_policy}'. " + "Expected one of: skip, zero, error." + ) + + if control_condition is not None: + control_resolved = predictor._resolve_dataset_name(control_condition) + if control_resolved is None: + raise ValueError(f"Control condition '{control_condition}' not found in behavioral data") + else: + control_candidate = predictor._infer_control_condition(resolved_condition) + if control_candidate is None: + raise ValueError( + f"No matched control mapping for '{resolved_condition}'. " + "Provide control_condition or disable subtract_control." + ) + control_resolved = predictor._resolve_dataset_name(control_candidate) + if control_resolved is None: + raise ValueError(f"Matched control '{control_candidate}' not found in behavioral data") + + if control_resolved == resolved_condition: + raise ValueError( + f"Control condition '{control_resolved}' matches opto condition '{resolved_condition}'." + ) + + per_opto = predictor.behavioral_data.loc[resolved_condition] + per_ctrl = predictor.behavioral_data.loc[control_resolved] + + if missing_control_policy == "skip": + valid_mask = per_opto.notna() & per_ctrl.notna() + valid_odorants = (per_opto - per_ctrl)[valid_mask] + elif missing_control_policy == "zero": + valid_mask = per_opto.notna() + per_ctrl_filled = per_ctrl.fillna(0.0) + valid_odorants = (per_opto - per_ctrl_filled)[valid_mask] + else: + missing_mask = per_opto.notna() & per_ctrl.isna() + if missing_mask.any(): + missing_odorants = [str(o) for o in per_opto.index[missing_mask]] + preview = ", ".join(missing_odorants[:5]) + if len(missing_odorants) > 5: + preview = f"{preview} (and {len(missing_odorants) - 5} more)" + raise ValueError( + f"Control condition '{control_resolved}' has missing values for odorants " + f"present in '{resolved_condition}': {preview}" + ) + valid_mask = per_opto.notna() & per_ctrl.notna() + valid_odorants = (per_opto - per_ctrl)[valid_mask] + + if len(valid_odorants) == 0: + raise ValueError( + f"No valid opto/control pairs for '{resolved_condition}' and '{control_resolved}'." 
+ ) + + return valid_odorants, resolved_condition, control_resolved, int(len(valid_odorants)) + + +def _extract_features( + predictor: LassoBehavioralPredictor, + valid_odorants: pd.Series, + condition: str, + prediction_mode: str, +) -> Tuple[np.ndarray, List[str], np.ndarray]: + if prediction_mode == "test_odorant": + return predictor._extract_test_odorant_features(valid_odorants) + if prediction_mode == "trained_odorant": + trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition) + if not trained_odorant: + raise ValueError(f"Could not determine trained odorant for {condition}") + return predictor._extract_trained_odorant_features(trained_odorant, valid_odorants) + if prediction_mode == "interaction": + trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition) + if not trained_odorant: + raise ValueError(f"Could not determine trained odorant for {condition}") + return predictor._extract_interaction_features(trained_odorant, valid_odorants) + raise ValueError(f"Unknown prediction_mode: {prediction_mode}") + + +def _run_one( + *, + condition: str, + mode: str, + predictor: LassoBehavioralPredictor, + lambdas: np.ndarray, + prediction_mode: str, + subtract_control: bool, + control_condition: Optional[str], + missing_control_policy: str, + seed: int, + output_dir: Path, +) -> Tuple[List[RunResult], Dict]: + valid_odorants, resolved_condition, control_resolved, n_pairs_used = _build_valid_odorants( + predictor, + condition=condition, + subtract_control=subtract_control, + control_condition=control_condition, + missing_control_policy=missing_control_policy, + ) + + X, test_odorants, y = _extract_features( + predictor, + valid_odorants, + resolved_condition, + prediction_mode, + ) + + if X.shape[0] < 3: + raise ValueError(f"Insufficient samples for {condition}: {X.shape[0]}") + + receptor_names = list(predictor.masked_receptor_names) + scaler = StandardScaler().fit(X) + X_scaled = scaler.transform(X) + + top_receptors = get_top_receptors_by_weight( + fit_lasso_with_fixed_scaler( + X=X, + y=y, + receptor_names=receptor_names, + scaler=scaler, + lambda_range=lambdas, + cv_folds=X.shape[0], + )[0], + 3, + ) + if not top_receptors: + top_receptors = receptor_names[:1] + + mutation_ok = _mutation_check(X, receptor_names, top_receptors) + + focus_receptors = top_receptors[:2] if len(top_receptors) >= 2 else top_receptors + X_focus, focus_names, _ = restrict_to_receptors( + X=X, + receptor_names=receptor_names, + receptors_to_keep=focus_receptors, + ) + focus_scaler = StandardScaler().fit(X_focus) + X_focus_scaled = focus_scaler.transform(X_focus) + + intercept_mse = _loocv_intercept_mse(y) + y_stats = _get_stats(y) + + results: List[RunResult] = [] + artifacts: Dict = { + "condition": resolved_condition, + "mode": mode, + "control_condition": control_resolved, + "n_pairs_used": int(n_pairs_used), + "test_odorants": test_odorants, + "y_stats": y_stats, + "intercept_only_mse": intercept_mse, + "runs": [], + "mutation_ok": mutation_ok, + } + + model_configs = [ + ("lasso", None), + ("ridge", None), + ("elasticnet", 0.2), + ("elasticnet", 0.5), + ("elasticnet", 0.8), + ] + + for modelclass, l1_ratio in model_configs: + for repeat_idx in (1, 2): + random.seed(seed) + np.random.seed(seed) + + if modelclass == "lasso": + weights, cv_mse, best_lambda, y_pred = _fit_lasso( + X=X, + y=y, + receptor_names=receptor_names, + lambdas=lambdas, + cv_folds=X.shape[0], + scaler=scaler, + ) + coef = np.zeros(len(receptor_names)) + for idx, name in enumerate(receptor_names): + if name in weights: 
+ coef[idx] = weights[name] + chosen_params = {"lambda": best_lambda} + n_selected = int(np.sum(np.abs(coef) > 1e-6)) + elif modelclass == "ridge": + best_lambda, cv_mse, coef, y_pred = _fit_ridge( + X=X_scaled, + y=y, + lambdas=lambdas, + ) + weights = { + receptor_names[i]: float(coef[i]) + for i in range(len(receptor_names)) + if abs(coef[i]) > 1e-6 + } + chosen_params = {"lambda": best_lambda} + n_selected = int(np.sum(np.abs(coef) > 1e-6)) + else: + best_lambda, cv_mse, best_l1, coef, y_pred = _fit_elasticnet( + X=X_scaled, + y=y, + lambdas=lambdas, + l1_ratios=[l1_ratio], + ) + weights = { + receptor_names[i]: float(coef[i]) + for i in range(len(receptor_names)) + if abs(coef[i]) > 1e-6 + } + chosen_params = {"lambda": best_lambda, "l1_ratio": best_l1} + n_selected = int(np.sum(np.abs(coef) > 1e-6)) + + pred_stats = _get_stats(y_pred) + mae = float(np.mean(np.abs(y - y_pred))) + y_std = float(y_stats["std"]) + nmse = float(cv_mse) / (y_std**2 + 1e-12) + mae_over_y_std = mae / (y_std + 1e-12) + intercept_only = ( + n_selected == 0 and abs(float(cv_mse) - float(intercept_mse)) < 1e-12 + ) + reproducible = True + + run = RunResult( + condition=resolved_condition, + mode=mode, + modelclass=modelclass if modelclass != "elasticnet" else f"elasticnet_{l1_ratio}", + repeat=repeat_idx, + cv_mse=float(cv_mse), + lambda_value=float(chosen_params["lambda"]), + n_selected=n_selected, + pred_std=float(pred_stats["std"]), + y_std=y_std, + intercept_only_mse=float(intercept_mse), + nmse=nmse, + mae=mae, + mae_over_y_std=mae_over_y_std, + intercept_only=intercept_only, + chosen_params=chosen_params, + weights=weights, + y_stats=y_stats, + pred_stats=pred_stats, + reproducible=reproducible, + mutation_ok=mutation_ok, + ) + + artifacts["runs"].append( + { + "repeat": repeat_idx, + "modelclass": run.modelclass, + "cv_mse": run.cv_mse, + "lambda": run.lambda_value, + "n_selected": run.n_selected, + "pred_stats": pred_stats, + "mae": mae, + "nmse": nmse, + "mae_over_y_std": mae_over_y_std, + "intercept_only": intercept_only, + "weights": weights, + } + ) + + # Plot for first repeat only + if repeat_idx == 1: + model_dir = output_dir / "plots" + if modelclass == "lasso": + mse_grid = _cv_mse_grid( + lambda alpha: Lasso(alpha=alpha, max_iter=10000, random_state=42), + X_scaled, + y, + lambdas, + ) + title = f"{resolved_condition} {mode} lasso" + plot_path = model_dir / f"cv_mse_vs_lambda_{resolved_condition}_{mode}_lasso.png" + _plot_mse_grid(lambdas, mse_grid, title, plot_path) + scatter_path = model_dir / f"y_vs_pred_scatter_{resolved_condition}_{mode}_lasso.png" + _plot_scatter(y, y_pred, title, scatter_path) + elif modelclass == "ridge": + mse_grid = _cv_mse_grid( + lambda alpha: Ridge(alpha=alpha, random_state=42), + X_scaled, + y, + lambdas, + ) + title = f"{resolved_condition} {mode} ridge" + plot_path = model_dir / f"cv_mse_vs_lambda_{resolved_condition}_{mode}_ridge.png" + _plot_mse_grid(lambdas, mse_grid, title, plot_path) + scatter_path = model_dir / f"y_vs_pred_scatter_{resolved_condition}_{mode}_ridge.png" + _plot_scatter(y, y_pred, title, scatter_path) + else: + title = f"{resolved_condition} {mode} elasticnet l1={l1_ratio}" + mse_grid = _cv_mse_grid( + lambda alpha: ElasticNet( + alpha=alpha, + l1_ratio=l1_ratio, + random_state=42, + max_iter=10000, + ), + X_scaled, + y, + lambdas, + ) + suffix = f"elasticnet_{l1_ratio}" + plot_path = model_dir / f"cv_mse_vs_lambda_{resolved_condition}_{mode}_{suffix}.png" + _plot_mse_grid(lambdas, mse_grid, title, plot_path) + scatter_path = model_dir / 
f"y_vs_pred_scatter_{resolved_condition}_{mode}_{suffix}.png" + _plot_scatter(y, y_pred, title, scatter_path) + + results.append(run) + + # Focus-mode LASSO sanity run (top-2 by baseline weights) + for repeat_idx in (1, 2): + random.seed(seed) + np.random.seed(seed) + + weights, cv_mse, best_lambda, y_pred = _fit_lasso( + X=X_focus, + y=y, + receptor_names=focus_names, + lambdas=lambdas, + cv_folds=X_focus.shape[0], + scaler=focus_scaler, + ) + coef = np.zeros(len(focus_names)) + for idx, name in enumerate(focus_names): + if name in weights: + coef[idx] = weights[name] + pred_stats = _get_stats(y_pred) + n_selected = int(np.sum(np.abs(coef) > 1e-6)) + mae = float(np.mean(np.abs(y - y_pred))) + y_std = float(y_stats["std"]) + nmse = float(cv_mse) / (y_std**2 + 1e-12) + mae_over_y_std = mae / (y_std + 1e-12) + intercept_only = ( + n_selected == 0 and abs(float(cv_mse) - float(intercept_mse)) < 1e-12 + ) + run = RunResult( + condition=resolved_condition, + mode=mode, + modelclass="lasso_focus_top2", + repeat=repeat_idx, + cv_mse=float(cv_mse), + lambda_value=float(best_lambda), + n_selected=n_selected, + pred_std=float(pred_stats["std"]), + y_std=y_std, + intercept_only_mse=float(intercept_mse), + nmse=nmse, + mae=mae, + mae_over_y_std=mae_over_y_std, + intercept_only=intercept_only, + chosen_params={"lambda": best_lambda, "focus_receptors": focus_names}, + weights=weights, + y_stats=y_stats, + pred_stats=pred_stats, + reproducible=True, + mutation_ok=mutation_ok, + ) + + artifacts["runs"].append( + { + "repeat": repeat_idx, + "modelclass": "lasso_focus_top2", + "cv_mse": run.cv_mse, + "lambda": run.lambda_value, + "n_selected": run.n_selected, + "pred_stats": pred_stats, + "mae": mae, + "nmse": nmse, + "mae_over_y_std": mae_over_y_std, + "intercept_only": intercept_only, + "weights": weights, + } + ) + + if repeat_idx == 1: + mse_grid = _cv_mse_grid( + lambda alpha: Lasso(alpha=alpha, max_iter=10000, random_state=42), + X_focus_scaled, + y, + lambdas, + ) + title = f"{resolved_condition} {mode} lasso focus top2" + plot_path = output_dir / "plots" / ( + f"cv_mse_vs_lambda_{resolved_condition}_{mode}_lasso_focus_top2.png" + ) + _plot_mse_grid(lambdas, mse_grid, title, plot_path) + scatter_path = output_dir / "plots" / ( + f"y_vs_pred_scatter_{resolved_condition}_{mode}_lasso_focus_top2.png" + ) + _plot_scatter(y, y_pred, title, scatter_path) + + results.append(run) + + # Reproducibility check (compare repeat 1 vs 2 per modelclass) + for modelclass in {r.modelclass for r in results}: + runs = [r for r in results if r.modelclass == modelclass] + runs_sorted = sorted(runs, key=lambda r: r.repeat) + if len(runs_sorted) < 2: + continue + r1, r2 = runs_sorted[0], runs_sorted[1] + same = ( + abs(r1.cv_mse - r2.cv_mse) < 1e-12 + and abs(r1.lambda_value - r2.lambda_value) < 1e-12 + and abs(r1.pred_std - r2.pred_std) < 1e-12 + and r1.weights == r2.weights + ) + r1.reproducible = same + r2.reproducible = same + + return results, artifacts + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Audit post-A/post-B LASSO pipeline with diagnostics outputs.", + ) + parser.add_argument("--door_cache", required=True, help="Path to DoOR cache directory.") + parser.add_argument("--behavior_csv", required=True, help="Path to behavioral matrix CSV.") + parser.add_argument( + "--conditions", + required=True, + help="Comma-separated condition list (e.g., opto_hex,opto_EB)", + ) + parser.add_argument( + "--lambda_range", + default="1e-4,1e-3,1e-2,1e-1,1.0", + 
help="Comma-separated lambda values.", + ) + parser.add_argument( + "--lambda_range_delta", + default="1e-8,1e-7,1e-6,1e-5,1e-4,1e-3,1e-2,1e-1,1.0", + help="Comma-separated lambda values for ΔPER runs.", + ) + parser.add_argument( + "--prediction_mode", + choices=["test_odorant", "trained_odorant", "interaction"], + default="test_odorant", + ) + parser.add_argument("--seed", type=int, default=1337) + parser.add_argument("--subtract_control", action="store_true") + parser.add_argument("--control_condition", default=None) + parser.add_argument( + "--missing_control_policy", + choices=["skip", "zero", "error"], + default="skip", + ) + parser.add_argument( + "--output_dir", + default=None, + help="Optional output dir; default creates diagnostics/postA_postB_audit_", + ) + + return parser.parse_args() + + +def main() -> int: + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + args = parse_args() + + random.seed(args.seed) + np.random.seed(args.seed) + + conditions = _parse_conditions([args.conditions]) + lambdas_raw = _parse_lambda_range(args.lambda_range) + lambdas_delta = _parse_lambda_range(args.lambda_range_delta) + + if args.output_dir: + output_dir = Path(args.output_dir) + else: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path("diagnostics") / f"postA_postB_audit_{timestamp}" + + _ensure_dir(output_dir) + + predictor = LassoBehavioralPredictor( + doorcache_path=args.door_cache, + behavior_csv_path=args.behavior_csv, + scale_features=False, + scale_targets=False, + ) + + all_metrics: List[Dict] = [] + all_artifacts: Dict[str, Dict] = {} + + for condition in conditions: + mode_specs: List[Tuple[str, bool, np.ndarray]] = [("raw", False, lambdas_raw)] + if args.subtract_control: + if np.array_equal(lambdas_raw, lambdas_delta): + mode_specs.append(("delta", True, lambdas_delta)) + else: + mode_specs.append(("delta_base", True, lambdas_raw)) + mode_specs.append(("delta_extended", True, lambdas_delta)) + + for mode, subtract_control, lambdas in mode_specs: + if ( + subtract_control + and args.missing_control_policy == "skip" + and args.control_condition is None + ): + control_candidate = predictor._infer_control_condition(condition) + if control_candidate is None: + logger.warning( + "%s %s skipped: no matched control mapping (missing_control_policy=skip).", + condition, + mode, + ) + all_artifacts[f"{condition}_{mode}"] = { + "skipped": "no matched control mapping", + } + continue + control_resolved = predictor._resolve_dataset_name(control_candidate) + if control_resolved is None: + logger.warning( + "%s %s skipped: matched control '%s' not found (missing_control_policy=skip).", + condition, + mode, + control_candidate, + ) + all_artifacts[f"{condition}_{mode}"] = { + "skipped": f"matched control '{control_candidate}' not found", + } + continue + + try: + results, artifacts = _run_one( + condition=condition, + mode=mode, + predictor=predictor, + lambdas=lambdas, + prediction_mode=args.prediction_mode, + subtract_control=subtract_control, + control_condition=args.control_condition, + missing_control_policy=args.missing_control_policy, + seed=args.seed, + output_dir=output_dir, + ) + except Exception as exc: + logger.error("%s %s failed: %s", condition, mode, exc) + all_artifacts[f"{condition}_{mode}"] = {"error": str(exc)} + continue + + all_artifacts[f"{condition}_{mode}"] = artifacts + + for run in results: + if run.repeat != 1: + continue + all_metrics.append( + { + "condition": run.condition, + "mode": run.mode, + 
"modelclass": run.modelclass, + "cv_mse": run.cv_mse, + "lambda": run.lambda_value, + "n_selected": run.n_selected, + "pred_std": run.pred_std, + "y_std": run.y_std, + "intercept_only_mse": run.intercept_only_mse, + "nmse": run.nmse, + "mae": run.mae, + "mae_over_y_std": run.mae_over_y_std, + "intercept_only": run.intercept_only, + "reproducible": run.reproducible, + "mutation_ok": run.mutation_ok, + } + ) + + metrics_df = pd.DataFrame(all_metrics) + metrics_path = output_dir / "audit_metrics.csv" + metrics_df.to_csv(metrics_path, index=False) + + artifacts_path = output_dir / "audit_artifacts.json" + with open(artifacts_path, "w", encoding="utf-8") as f: + json.dump(all_artifacts, f, indent=2) + + summary_lines = [ + "# Audit Summary", + f"Run folder: {output_dir}", + f"Conditions: {', '.join(conditions)}", + f"Prediction mode: {args.prediction_mode}", + f"Subtract control: {args.subtract_control}", + f"Lambda grid: {args.lambda_range}", + f"Lambda grid (ΔPER): {args.lambda_range_delta}", + "", + ] + + if not metrics_df.empty: + reproducible_counts = metrics_df.groupby(["modelclass", "mode"])["reproducible"].mean() + summary_lines.append("## Reproducibility (repeat 1 vs 2)") + for (modelclass, mode), ratio in reproducible_counts.items(): + summary_lines.append(f"- {modelclass} {mode}: {ratio:.2f} reproducible fraction") + + errors = {k: v for k, v in all_artifacts.items() if isinstance(v, dict) and "error" in v} + skipped = {k: v for k, v in all_artifacts.items() if isinstance(v, dict) and "skipped" in v} + if errors or skipped: + summary_lines.append("") + summary_lines.append("## Missing Controls / Skipped Runs") + for key, value in errors.items(): + summary_lines.append(f"- {key}: {value['error']}") + for key, value in skipped.items(): + summary_lines.append(f"- {key}: {value['skipped']}") + + mutation_issues = metrics_df.loc[~metrics_df["mutation_ok"]] + summary_lines.append("") + summary_lines.append("## Mutation Check (focus/ablation on copies)") + if mutation_issues.empty: + summary_lines.append("- All conditions reported mutation_ok=True") + else: + for _, row in mutation_issues.iterrows(): + summary_lines.append( + f"- Mutation detected: {row['condition']} {row['mode']} {row['modelclass']}" + ) + + delta_modes = sorted({m for m in metrics_df["mode"].unique() if m.startswith("delta")}) + summary_lines.append("") + summary_lines.append("## ΔPER Collapse (constant predictions)") + if not delta_modes: + summary_lines.append("- No ΔPER runs were executed") + else: + preferred_delta = "delta_extended" if "delta_extended" in delta_modes else delta_modes[0] + lasso_delta = metrics_df[ + (metrics_df["modelclass"] == "lasso") & (metrics_df["mode"] == preferred_delta) + ] + collapsed = lasso_delta[lasso_delta["pred_std"] < 1e-6] + if collapsed.empty: + summary_lines.append(f"- No LASSO {preferred_delta} runs had pred_std < 1e-6") + else: + for _, row in collapsed.iterrows(): + summary_lines.append( + f"- {row['condition']}: pred_std={row['pred_std']:.6g}, " + f"n_selected={int(row['n_selected'])}, " + f"lambda={row['lambda']:.6g}, " + f"cv_mse={row['cv_mse']:.6g}, " + f"intercept_only_mse={row['intercept_only_mse']:.6g}" + ) + + if "delta_base" in delta_modes and "delta_extended" in delta_modes: + base_collapsed = metrics_df[ + (metrics_df["modelclass"] == "lasso") & (metrics_df["mode"] == "delta_base") + ] + ext_collapsed = metrics_df[ + (metrics_df["modelclass"] == "lasso") + & (metrics_df["mode"] == "delta_extended") + ] + summary_lines.append("") + summary_lines.append("## ΔPER Grid 
Comparison (LASSO)") + summary_lines.append( + f"- delta_base collapsed count: {int((base_collapsed['pred_std'] < 1e-6).sum())}" + ) + summary_lines.append( + f"- delta_extended collapsed count: {int((ext_collapsed['pred_std'] < 1e-6).sum())}" + ) + + # Primary model selection for ΔPER (fallback to ElasticNet if LASSO intercept-only) + summary_lines.append("") + summary_lines.append("## ΔPER Primary Model (fallback if LASSO intercept-only)") + primary_rows = [] + for condition in conditions: + mode = preferred_delta + lasso_row = metrics_df[ + (metrics_df["condition"] == condition) + & (metrics_df["mode"] == mode) + & (metrics_df["modelclass"] == "lasso") + ] + if lasso_row.empty: + continue + lasso_row = lasso_row.iloc[0] + if bool(lasso_row["intercept_only"]): + candidates = metrics_df[ + (metrics_df["condition"] == condition) + & (metrics_df["mode"] == mode) + & (metrics_df["modelclass"].str.startswith("elasticnet")) + ] + if candidates.empty: + primary_rows.append( + { + "condition": condition, + "mode": mode, + "primary_modelclass": "intercept_only", + "cv_mse": float(lasso_row["cv_mse"]), + "lambda": float(lasso_row["lambda"]), + } + ) + else: + best = candidates.loc[candidates["cv_mse"].idxmin()] + primary_rows.append( + { + "condition": condition, + "mode": mode, + "primary_modelclass": best["modelclass"], + "cv_mse": float(best["cv_mse"]), + "lambda": float(best["lambda"]), + } + ) + else: + primary_rows.append( + { + "condition": condition, + "mode": mode, + "primary_modelclass": "lasso", + "cv_mse": float(lasso_row["cv_mse"]), + "lambda": float(lasso_row["lambda"]), + } + ) + + if primary_rows: + primary_df = pd.DataFrame(primary_rows) + primary_df.to_csv(output_dir / "audit_primary_models.csv", index=False) + for row in primary_rows: + summary_lines.append( + f"- {row['condition']} ({row['mode']}): {row['primary_modelclass']} " + f"(cv_mse={row['cv_mse']:.6g}, lambda={row['lambda']:.6g})" + ) + + summary_lines.append("") + summary_lines.append("## Top ORNs (raw LASSO, repeat 1)") + for key, artifact in all_artifacts.items(): + if not key.endswith("_raw"): + continue + if "runs" not in artifact: + continue + lasso_runs = [ + r + for r in artifact["runs"] + if r.get("modelclass") == "lasso" and r.get("repeat") == 1 + ] + if not lasso_runs: + continue + weights = lasso_runs[0].get("weights", {}) + top_receptors = sorted(weights.items(), key=lambda x: abs(x[1]), reverse=True)[:5] + receptor_str = ", ".join([f"{r}({w:.3g})" for r, w in top_receptors]) + summary_lines.append( + f"- {artifact['condition']}: {receptor_str if receptor_str else 'no nonzero weights'}" + ) + + raw_lasso = metrics_df[(metrics_df["modelclass"] == "lasso") & (metrics_df["mode"] == "raw")] + summary_lines.append("") + summary_lines.append("## Scale vs Error (raw LASSO)") + if len(raw_lasso) >= 3: + corr = np.corrcoef(raw_lasso["y_std"], raw_lasso["cv_mse"])[0, 1] + summary_lines.append( + f"- Pearson corr(y_std, cv_mse) = {corr:.3f} (n={len(raw_lasso)})" + ) + else: + summary_lines.append("- Not enough conditions to compute correlation") + summary_lines.append( + "- Note: raw MSE is scale-dependent; compare nmse or mae/y_std for cross-condition checks." 
+ ) + + summary_lines.append("") + summary_lines.append("## Files") + summary_lines.append("- audit_metrics.csv") + summary_lines.append("- audit_primary_models.csv") + summary_lines.append("- audit_artifacts.json") + summary_lines.append("- plots/*.png") + + summary_path = output_dir / "AUDIT_SUMMARY.md" + summary_path.write_text("\n".join(summary_lines), encoding="utf-8") + + logger.info("Wrote audit outputs to %s", output_dir) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/diagnostics/run_stability_and_metrics.py b/diagnostics/run_stability_and_metrics.py new file mode 100644 index 0000000..40e13d9 --- /dev/null +++ b/diagnostics/run_stability_and_metrics.py @@ -0,0 +1,780 @@ +#!/usr/bin/env python3 +""" +Run stability scoring + standardized metrics for LASSO/Ridge/ElasticNet. + +Outputs are written under diagnostics/stability_/. +""" + +from __future__ import annotations + +import argparse +import logging +import math +import random +import shlex +import sys +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Tuple + +import numpy as np +import pandas as pd +from sklearn.linear_model import ElasticNet, Ridge +from sklearn.model_selection import LeaveOneOut, cross_val_score +from sklearn.preprocessing import StandardScaler + +# Add src to path for repo-local runs (matches other scripts in this repo). +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) + +from door_toolkit.pathways.behavioral_prediction import ( + LassoBehavioralPredictor, + fit_lasso_with_fixed_scaler, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelFit: + modelclass: str + lambda_value: float + l1_ratio: Optional[float] + cv_mse: float + n_selected: int + coef: np.ndarray + y_pred: np.ndarray + + +def _parse_conditions(values: List[str]) -> List[str]: + conditions: List[str] = [] + seen = set() + for value in values: + for token in value.split(","): + token = token.strip() + if not token or token in seen: + continue + conditions.append(token) + seen.add(token) + return conditions + + +def _parse_lambda_range(value: Optional[str]) -> Optional[np.ndarray]: + if value is None: + return None + tokens = [token.strip() for token in value.split(",") if token.strip()] + if not tokens: + raise ValueError("lambda_range cannot be empty.") + return np.array([float(token) for token in tokens], dtype=np.float64) + + +def _ensure_dir(path: Path) -> None: + path.mkdir(parents=True, exist_ok=True) + + +def _loocv_intercept_mse(y: np.ndarray) -> float: + if len(y) < 2: + return float("nan") + loo = LeaveOneOut() + errors = [] + for train_idx, test_idx in loo.split(y): + train_mean = float(np.mean(y[train_idx])) + err = float((y[test_idx][0] - train_mean) ** 2) + errors.append(err) + return float(np.mean(errors)) + + +def _is_intercept_only(n_selected: int, pred_std: float, cv_mse: float, intercept_only_mse: float) -> bool: + return ( + n_selected == 0 + and pred_std < 1e-6 + and abs(cv_mse - intercept_only_mse) < 1e-12 + ) + + +def _cv_mse_grid(model_factory, X: np.ndarray, y: np.ndarray, lambdas: np.ndarray) -> List[float]: + loo = LeaveOneOut() + mse_values: List[float] = [] + for lam in lambdas: + model = model_factory(lam) + scores = cross_val_score(model, X, y, cv=loo, scoring="neg_mean_squared_error") + mse_values.append(float(-np.mean(scores))) + return mse_values + + +def _fit_ridge(X: np.ndarray, y: np.ndarray, lambdas: np.ndarray, seed: int) -> ModelFit: + def 
factory(alpha: float) -> Ridge: + return Ridge(alpha=alpha, random_state=seed) + + mse_grid = _cv_mse_grid(factory, X, y, lambdas) + best_idx = int(np.argmin(mse_grid)) + best_lambda = float(lambdas[best_idx]) + + model = factory(best_lambda) + model.fit(X, y) + y_pred = model.predict(X) + coef = model.coef_.copy() + n_selected = int(np.sum(np.abs(coef) > 1e-6)) + + return ModelFit( + modelclass="ridge", + lambda_value=best_lambda, + l1_ratio=None, + cv_mse=float(mse_grid[best_idx]), + n_selected=n_selected, + coef=coef, + y_pred=y_pred, + ) + + +def _fit_elasticnet( + X: np.ndarray, + y: np.ndarray, + lambdas: np.ndarray, + l1_ratios: List[float], + seed: int, +) -> ModelFit: + best_lambda = None + best_l1 = None + best_mse = None + best_coef = None + best_pred = None + + for l1_ratio in l1_ratios: + def factory(alpha: float) -> ElasticNet: + return ElasticNet( + alpha=alpha, + l1_ratio=l1_ratio, + random_state=seed, + max_iter=10000, + ) + + mse_grid = _cv_mse_grid(factory, X, y, lambdas) + idx = int(np.argmin(mse_grid)) + mse = float(mse_grid[idx]) + if best_mse is None or mse < best_mse: + best_mse = mse + best_lambda = float(lambdas[idx]) + best_l1 = float(l1_ratio) + model = factory(best_lambda) + model.fit(X, y) + best_coef = model.coef_.copy() + best_pred = model.predict(X) + + if best_lambda is None or best_coef is None or best_pred is None or best_l1 is None: + raise RuntimeError("ElasticNet fitting failed") + + n_selected = int(np.sum(np.abs(best_coef) > 1e-6)) + return ModelFit( + modelclass="elasticnet", + lambda_value=float(best_lambda), + l1_ratio=float(best_l1), + cv_mse=float(best_mse), + n_selected=n_selected, + coef=best_coef, + y_pred=best_pred, + ) + + +def _fit_lasso( + X: np.ndarray, + y: np.ndarray, + receptor_names: List[str], + lambdas: np.ndarray, + cv_folds: int, + scaler: Optional[StandardScaler], + seed: int, +) -> Tuple[Dict[str, float], float, float, np.ndarray]: + weights, _cv_r2, cv_mse, best_lambda, y_pred = fit_lasso_with_fixed_scaler( + X=X, + y=y, + receptor_names=receptor_names, + scaler=scaler, + lambda_range=lambdas, + cv_folds=cv_folds, + random_state=seed, + ) + return weights, float(cv_mse), float(best_lambda), y_pred + + +def _collect_metrics( + *, + condition: str, + mode: str, + model: ModelFit, + y: np.ndarray, + intercept_only_mse: float, +) -> Dict[str, float]: + y_std = float(np.std(y)) + y_var = float(np.var(y)) + y_min = float(np.min(y)) + y_max = float(np.max(y)) + pred_std = float(np.std(model.y_pred)) + pred_min = float(np.min(model.y_pred)) + pred_max = float(np.max(model.y_pred)) + nmse = float(model.cv_mse / (y_var + 1e-12)) + rmse_over_y_std = float(math.sqrt(model.cv_mse) / (y_std + 1e-12)) + intercept_only_flag = _is_intercept_only( + model.n_selected, pred_std, model.cv_mse, intercept_only_mse + ) + + return { + "condition": condition, + "mode": mode, + "modelclass": model.modelclass, + "lambda": model.lambda_value, + "l1_ratio": model.l1_ratio, + "n_selected": model.n_selected, + "cv_mse": model.cv_mse, + "y_std": y_std, + "y_var": y_var, + "y_min": y_min, + "y_max": y_max, + "pred_std": pred_std, + "pred_min": pred_min, + "pred_max": pred_max, + "nmse": nmse, + "rmse_over_y_std": rmse_over_y_std, + "intercept_only_mse": intercept_only_mse, + "intercept_only_flag": intercept_only_flag, + } + + +def _rank_indices(abs_weights: np.ndarray, receptor_names: List[str]) -> List[int]: + return sorted( + range(len(abs_weights)), + key=lambda idx: (-abs_weights[idx], receptor_names[idx]), + ) + + +def _compute_stability( + *, + X: 
np.ndarray, + y: np.ndarray, + receptor_names: List[str], + modelclass: str, + lambdas: np.ndarray, + l1_ratios: List[float], + cv_folds: int, + scale_features: bool, + seed: int, +) -> List[Dict[str, float]]: + n_samples = X.shape[0] + if n_samples < 2: + return [] + + stats = { + name: { + "selected": 0, + "signs": [], + "abs_weights": [], + "ranks": [], + } + for name in receptor_names + } + + for holdout_idx in range(n_samples): + train_mask = np.ones(n_samples, dtype=bool) + train_mask[holdout_idx] = False + X_train = X[train_mask] + y_train = y[train_mask] + + if modelclass == "lasso": + weights, _cv_mse, _best_lambda, _pred = _fit_lasso( + X_train, + y_train, + receptor_names, + lambdas, + cv_folds, + StandardScaler().fit(X_train) if scale_features else None, + seed, + ) + coef = np.zeros(len(receptor_names), dtype=np.float64) + for idx, name in enumerate(receptor_names): + if name in weights: + coef[idx] = weights[name] + elif modelclass.startswith("elasticnet"): + if scale_features: + X_train_scaled = StandardScaler().fit_transform(X_train) + else: + X_train_scaled = X_train + fit = _fit_elasticnet(X_train_scaled, y_train, lambdas, l1_ratios, seed) + coef = fit.coef + elif modelclass == "ridge": + if scale_features: + X_train_scaled = StandardScaler().fit_transform(X_train) + else: + X_train_scaled = X_train + fit = _fit_ridge(X_train_scaled, y_train, lambdas, seed) + coef = fit.coef + else: + raise ValueError(f"Unknown modelclass: {modelclass}") + + abs_weights = np.abs(coef) + if modelclass == "ridge": + selected_indices = list(range(len(coef))) + ranked_indices = _rank_indices(abs_weights, receptor_names) + else: + selected_indices = [i for i, w in enumerate(abs_weights) if w > 1e-6] + ranked_indices = _rank_indices(abs_weights, receptor_names) + + rank_lookup = {idx: rank + 1 for rank, idx in enumerate(ranked_indices)} + + for idx in selected_indices: + name = receptor_names[idx] + weight = float(coef[idx]) + stats[name]["selected"] += 1 + stats[name]["signs"].append(1 if weight >= 0 else -1) + stats[name]["abs_weights"].append(abs(weight)) + stats[name]["ranks"].append(rank_lookup[idx]) + + rows = [] + for name, entry in stats.items(): + selected = entry["selected"] + selection_frequency = selected / n_samples + signs = entry["signs"] + if signs: + median_sign = 1 if float(np.median(signs)) >= 0 else -1 + sign_consistency = sum(1 for s in signs if s == median_sign) / len(signs) + else: + sign_consistency = 0.0 + + abs_weights = entry["abs_weights"] + ranks = entry["ranks"] + + rows.append( + { + "condition": None, + "mode": None, + "modelclass": modelclass, + "orn_name": name, + "selection_frequency": float(selection_frequency), + "sign_consistency": float(sign_consistency), + "mean_abs_weight": float(np.mean(abs_weights)) if abs_weights else 0.0, + "std_abs_weight": float(np.std(abs_weights)) if abs_weights else 0.0, + "mean_rank": float(np.mean(ranks)) if ranks else float("nan"), + "n_folds": n_samples, + } + ) + + return rows + + +def _format_shortlist_table(rows: List[Dict[str, float]]) -> str: + if not rows: + return "(no stable ORNs detected)\n" + + header = "| ORN | stability_score | selection_frequency | sign_consistency | mean_abs_weight | mean_rank |\n" + header += "| --- | --- | --- | --- | --- | --- |\n" + lines = [header] + for row in rows: + lines.append( + f"| {row['orn_name']} | {row['stability_score']:.3f} | " + f"{row['selection_frequency']:.3f} | {row['sign_consistency']:.3f} | " + f"{row['mean_abs_weight']:.4f} | {row['mean_rank']:.2f} |" + ) + return 
"\n".join(lines) + "\n" + + +def _is_missing_control_error(exc: Exception) -> bool: + message = str(exc).lower() + if "no matched control mapping" in message: + return True + if "control" in message and "not found in behavioral data" in message: + return True + return False + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run stability scoring + standardized metrics for LASSO/Ridge/ElasticNet.", + ) + + parser.add_argument("--door_cache", required=True, help="Path to DoOR cache directory.") + parser.add_argument("--behavior_csv", required=True, help="Path to behavioral matrix CSV.") + parser.add_argument( + "--conditions", + required=True, + action="append", + help="Condition name(s). Repeat flag or pass comma-separated list.", + ) + parser.add_argument( + "--prediction_mode", + choices=["test_odorant", "trained_odorant", "interaction"], + default="test_odorant", + help="Feature extraction mode.", + ) + parser.add_argument("--cv_folds", type=int, default=5, help="Number of CV folds.") + parser.add_argument( + "--lambda_range", + default="0.0001,0.001,0.01,0.1,1.0", + help="Comma-separated lambda values for raw runs.", + ) + parser.add_argument( + "--lambda_range_delta", + default="1e-8,1e-7,1e-6,1e-5,1e-4,1e-3,1e-2,1e-1,1.0", + help="Comma-separated lambda values for ΔPER runs.", + ) + parser.add_argument( + "--subtract_control", + action="store_true", + help="If set, run ΔPER (opto - control) in addition to raw.", + ) + parser.add_argument( + "--control_condition", + default=None, + help="Optional control dataset override (applies to all conditions).", + ) + parser.add_argument( + "--missing_control_policy", + choices=["skip", "zero", "error"], + default="error", + help="How to handle missing control values.", + ) + parser.add_argument("--seed", type=int, default=1337, help="Random seed.") + parser.add_argument( + "--output_dir", + default=None, + help="Optional output directory. 
Defaults to diagnostics/stability_.", + ) + parser.add_argument( + "--include_ridge_stability", + action="store_true", + help="If set, compute ridge stability ranks (optional).", + ) + parser.add_argument( + "--adult_only_masking", + action="store_true", + help="Restrict to adult-only receptors via training_receptor_set.json.", + ) + parser.add_argument( + "--training_receptor_set_path", + default=None, + help="Optional path to training_receptor_set.json.", + ) + + return parser.parse_args() + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + args = _parse_args() + + random.seed(args.seed) + np.random.seed(args.seed) + + conditions = _parse_conditions(args.conditions) + if not conditions: + raise ValueError("No valid conditions provided.") + + lambda_range = _parse_lambda_range(args.lambda_range) + lambda_range_delta = _parse_lambda_range(args.lambda_range_delta) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = ( + Path(args.output_dir) + if args.output_dir is not None + else Path("diagnostics") / f"stability_{timestamp}" + ) + _ensure_dir(output_dir) + + command_line = " ".join(shlex.quote(arg) for arg in sys.argv) + (output_dir / "RUN_COMMANDS.txt").write_text(command_line + "\n") + + predictor = LassoBehavioralPredictor( + doorcache_path=args.door_cache, + behavior_csv_path=args.behavior_csv, + adult_only_masking=args.adult_only_masking, + training_receptor_set_path=args.training_receptor_set_path, + ) + + model_metrics_rows: List[Dict[str, float]] = [] + stability_rows: List[Dict[str, float]] = [] + skipped_delta_conditions: List[str] = [] + + l1_ratios = [0.2, 0.5, 0.8] + + modes_to_run = [("raw", False)] + if args.subtract_control: + modes_to_run.append(("delta", True)) + + for condition in conditions: + for mode, subtract_control in modes_to_run: + try: + lambda_values = lambda_range_delta if subtract_control else lambda_range + results = predictor.fit_behavior( + condition_name=condition, + prediction_mode=args.prediction_mode, + lambda_range=lambda_values.tolist() if lambda_values is not None else None, + cv_folds=args.cv_folds, + subtract_control=subtract_control, + control_condition=args.control_condition, + missing_control_policy=args.missing_control_policy, + ) + except ValueError as exc: + if subtract_control and _is_missing_control_error(exc): + if args.missing_control_policy == "skip": + logger.warning( + "No control found for '%s'; skipping ΔPER run (missing_control_policy=skip).", + condition, + ) + skipped_delta_conditions.append(condition) + continue + raise + + X = results.feature_matrix + y = results.actual_per + receptor_names = list(results.receptor_names) + + intercept_only_mse = _loocv_intercept_mse(y) + + scaler = StandardScaler().fit(X) if predictor.scale_features else None + + lasso_weights, lasso_cv_mse, lasso_lambda, lasso_pred = _fit_lasso( + X, + y, + receptor_names, + lambda_values, + args.cv_folds, + scaler, + args.seed, + ) + lasso_coef = np.array([lasso_weights.get(name, 0.0) for name in receptor_names]) + lasso_fit = ModelFit( + modelclass="lasso", + lambda_value=lasso_lambda, + l1_ratio=None, + cv_mse=lasso_cv_mse, + n_selected=len(lasso_weights), + coef=lasso_coef, + y_pred=lasso_pred, + ) + model_metrics_rows.append( + _collect_metrics( + condition=condition, + mode=mode, + model=lasso_fit, + y=y, + intercept_only_mse=intercept_only_mse, + ) + ) + + X_scaled = scaler.transform(X) if scaler is not None else X + ridge_fit = _fit_ridge(X_scaled, y, 
lambda_values, args.seed) + model_metrics_rows.append( + _collect_metrics( + condition=condition, + mode=mode, + model=ridge_fit, + y=y, + intercept_only_mse=intercept_only_mse, + ) + ) + + elasticnet_fit = _fit_elasticnet(X_scaled, y, lambda_values, l1_ratios, args.seed) + model_metrics_rows.append( + _collect_metrics( + condition=condition, + mode=mode, + model=elasticnet_fit, + y=y, + intercept_only_mse=intercept_only_mse, + ) + ) + + lasso_intercept_only = _is_intercept_only( + lasso_fit.n_selected, + float(np.std(lasso_fit.y_pred)), + lasso_fit.cv_mse, + intercept_only_mse, + ) + + stability_models: List[str] = [] + if not lasso_intercept_only: + stability_models.append("lasso") + elif mode == "delta": + stability_models.append(elasticnet_fit.modelclass) + if args.include_ridge_stability: + stability_models.append("ridge") + + for modelclass in stability_models: + stability_model = modelclass + if modelclass.startswith("elasticnet"): + stability_model = "elasticnet" + + rows = _compute_stability( + X=X, + y=y, + receptor_names=receptor_names, + modelclass=stability_model, + lambdas=lambda_values, + l1_ratios=l1_ratios, + cv_folds=args.cv_folds, + scale_features=predictor.scale_features, + seed=args.seed, + ) + for row in rows: + row["condition"] = condition + row["mode"] = mode + row["modelclass"] = modelclass + stability_rows.extend(rows) + + metrics_df = pd.DataFrame(model_metrics_rows) + metrics_df = metrics_df.sort_values(["condition", "mode", "modelclass"]) + metrics_path = output_dir / "model_metrics.csv" + metrics_df.to_csv(metrics_path, index=False) + + stability_columns = [ + "condition", + "mode", + "modelclass", + "orn_name", + "selection_frequency", + "sign_consistency", + "mean_abs_weight", + "std_abs_weight", + "mean_rank", + "n_folds", + ] + stability_df = pd.DataFrame(stability_rows, columns=stability_columns) + if not stability_df.empty: + stability_df = stability_df.sort_values( + ["condition", "mode", "modelclass", "orn_name"] + ) + stability_path = output_dir / "stability_per_condition.csv" + stability_df.to_csv(stability_path, index=False) + + _write_shortlist_and_summary( + output_dir=output_dir, + metrics_df=metrics_df, + stability_df=stability_df, + skipped_delta_conditions=skipped_delta_conditions, + ) + + logger.info("Wrote outputs to %s", output_dir) + + +def _write_shortlist_and_summary( + *, + output_dir: Path, + metrics_df: pd.DataFrame, + stability_df: pd.DataFrame, + skipped_delta_conditions: List[str], +) -> None: + short_path = output_dir / "EXPERIMENT_SHORTLIST.md" + summary_path = output_dir / "SUMMARY.md" + + modes = sorted(metrics_df["mode"].unique()) if not metrics_df.empty else [] + conditions = sorted(metrics_df["condition"].unique()) if not metrics_df.empty else [] + + def _primary_model(condition: str, mode: str) -> str: + lasso_row = metrics_df[ + (metrics_df["condition"] == condition) + & (metrics_df["mode"] == mode) + & (metrics_df["modelclass"] == "lasso") + ] + if lasso_row.empty: + return "lasso" + if bool(lasso_row.iloc[0]["intercept_only_flag"]): + elasticnet_rows = metrics_df[ + (metrics_df["condition"] == condition) + & (metrics_df["mode"] == mode) + & (metrics_df["modelclass"] == "elasticnet") + ] + if not elasticnet_rows.empty: + return "elasticnet" + return "lasso" + + def _confidence_flags(condition: str, mode: str, modelclass: str) -> List[str]: + row = metrics_df[ + (metrics_df["condition"] == condition) + & (metrics_df["mode"] == mode) + & (metrics_df["modelclass"] == modelclass) + ] + flags: List[str] = [] + if row.empty: + 
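+            # No metrics row was recorded for this (condition, mode, modelclass);
+            # the only flag that can still be surfaced is a ΔPER run that was
+            # skipped because no mapped control exists.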
if mode == "delta" and condition in skipped_delta_conditions: + flags.append("ΔPER unavailable (missing control)") + return flags + row = row.iloc[0] + if row["nmse"] >= 1.0: + flags.append("no better than baseline scale-adjusted") + if bool(row["intercept_only_flag"]): + flags.append("no sparse signal; use ridge/elasticnet") + if mode == "delta" and condition in skipped_delta_conditions: + flags.append("ΔPER unavailable (missing control)") + return flags + + short_lines = ["# Experiment Target Shortlist", ""] + + summary_lines = ["# Stability + Metrics Summary", ""] + summary_lines.append("## Key findings") + summary_lines.append("") + + bullet_points = [] + bullet_points.append( + f"Ran stability for {len(conditions)} conditions across modes: {', '.join(modes) if modes else 'none'}" + ) + if skipped_delta_conditions: + bullet_points.append( + f"ΔPER skipped (missing controls): {', '.join(sorted(set(skipped_delta_conditions)))}" + ) + intercept_only_rows = metrics_df[metrics_df["intercept_only_flag"]] + bullet_points.append( + f"Intercept-only LASSO runs: {len(intercept_only_rows)} (see model_metrics.csv)" + ) + if not metrics_df.empty: + bullet_points.append( + f"NMSE range: {metrics_df['nmse'].min():.3f}–{metrics_df['nmse'].max():.3f}" + ) + bullet_points.append("Use nmse/rmse_over_y_std for cross-condition comparison (scale-aware).") + bullet_points.append("Stability scores use selection_frequency × sign_consistency.") + bullet_points.append("Shortlist excludes ORNs with stability_score=0.") + bullet_points.append("Raw mode always run; ΔPER only if requested and controls available.") + + for bullet in bullet_points[:8]: + summary_lines.append(f"- {bullet}") + + summary_lines.append("") + + for condition in conditions: + for mode in modes: + short_lines.append(f"## {condition} ({mode})") + primary_model = _primary_model(condition, mode) + + subset = stability_df[ + (stability_df["condition"] == condition) + & (stability_df["mode"] == mode) + & (stability_df["modelclass"] == primary_model) + ].copy() + + if subset.empty: + short_lines.append("(no stability data)") + short_lines.append("") + continue + + subset["stability_score"] = ( + subset["selection_frequency"] * subset["sign_consistency"] + ) + subset = subset[subset["stability_score"] > 0].sort_values( + ["stability_score", "mean_abs_weight"], ascending=False + ) + top_rows = subset.head(5).to_dict(orient="records") + + flags = _confidence_flags(condition, mode, primary_model) + if flags: + short_lines.append(f"Confidence flags: {', '.join(flags)}") + short_lines.append("\n" + _format_shortlist_table(top_rows)) + + summary_lines.append(f"## {condition} ({mode})") + summary_lines.append(f"Primary model: {primary_model}") + if flags: + summary_lines.append(f"Confidence flags: {', '.join(flags)}") + summary_lines.append("\n" + _format_shortlist_table(top_rows)) + + short_path.write_text("\n".join(short_lines)) + summary_path.write_text("\n".join(summary_lines)) + + +if __name__ == "__main__": + main() diff --git a/docs/BEHAVIORAL_PREDICTION_ANALYSIS.md b/docs/BEHAVIORAL_PREDICTION_ANALYSIS.md index 99a48f2..77b7b53 100644 --- a/docs/BEHAVIORAL_PREDICTION_ANALYSIS.md +++ b/docs/BEHAVIORAL_PREDICTION_ANALYSIS.md @@ -258,7 +258,10 @@ python scripts/run_lasso_behavioral_prediction.py \ Use `--control_condition` to override the default opto→control mapping, and `--also_run_raw` to generate a side-by-side comparison summary CSV. -If a condition lacks a matched control, the CLI logs a warning and falls back to raw mode. 
+If a condition lacks a matched control and `--missing_control_policy` is `skip`,
+the CLI logs a warning and skips the ΔPER run; otherwise it raises an error.
+Conditions currently unmapped for ΔPER: `opto_ACV`, `opto_3-oct` (provide `--control_condition` to run them).
+The default for `--missing_control_policy` is `error` in the CLI scripts.
 
 ```bash
 conda activate DoOR
@@ -526,6 +529,14 @@ behavioral_prediction_results/
 
 ---
 
+## Stability Layer (LOOO)
+Run: `conda run -n DoOR python diagnostics/run_stability_and_metrics.py --door_cache door_cache --behavior_csv /path/to/reaction_rates_summary_unordered.csv --conditions opto_hex,opto_EB,opto_benz_1,opto_ACV,opto_3-oct --prediction_mode test_odorant --seed 1337 --subtract_control --missing_control_policy skip`
+Outputs: `diagnostics/stability_<timestamp>/SUMMARY.md` and `EXPERIMENT_SHORTLIST.md`.
+Use the normalized metrics (`nmse`, `rmse_over_y_std`) for cross-condition comparisons.
+ΔPER runs are skipped for conditions without mapped controls when `missing_control_policy=skip`.
+
+---
+
 ## Contact & Support
 
 **Questions about this analysis?**
diff --git a/scripts/lasso_with_ablations.py b/scripts/lasso_with_ablations.py
index 3d14cf2..5f31509 100644
--- a/scripts/lasso_with_ablations.py
+++ b/scripts/lasso_with_ablations.py
@@ -113,9 +113,9 @@ def parse_args() -> argparse.Namespace:
     )
     parser.add_argument(
         "--condition",
-        type=str,
         required=True,
-        help="Optogenetic condition name (e.g., opto_hex)",
+        action="append",
+        help="Condition name(s). Repeat flag or pass comma-separated list.",
     )
     parser.add_argument(
         "--output_dir",
@@ -158,6 +158,29 @@
         default="test_odorant",
         help="Feature extraction mode (default: test_odorant)",
     )
+    parser.add_argument(
+        "--subtract_control",
+        action="store_true",
+        help="Fit on ΔPER (opto - control) instead of raw PER.",
+    )
+    parser.add_argument(
+        "--control_condition",
+        type=str,
+        default=None,
+        help="Optional control dataset override (default: infer from opto condition).",
+    )
+    parser.add_argument(
+        "--missing_control_policy",
+        type=str,
+        choices=["skip", "zero", "error"],
+        default="error",
+        help="How to handle missing control values (default: error).",
+    )
+    parser.add_argument(
+        "--debug_stats",
+        action="store_true",
+        help="Log y stats, chosen lambda, and nonzero coefficient count.",
+    )
 
     # LASSO parameters
     parser.add_argument(
@@ -205,6 +228,19 @@
     return parser.parse_args()
 
 
+def _parse_conditions(values: List[str]) -> List[str]:
+    conditions: List[str] = []
+    seen = set()
+    for value in values:
+        for token in value.split(","):
+            token = token.strip()
+            if not token or token in seen:
+                continue
+            conditions.append(token)
+            seen.add(token)
+    return conditions
+
+
 def load_receptors_from_file(filepath: str) -> List[str]:
     """Load receptor names from a file (one per line)."""
     path = Path(filepath)
@@ -242,6 +278,111 @@
     logger.info(f"Saved weights to {filepath}")
 
 
+def _build_valid_odorants(
+    predictor: LassoBehavioralPredictor,
+    condition: str,
+    subtract_control: bool,
+    control_condition: Optional[str],
+    missing_control_policy: str,
+) -> Tuple[pd.Series, str, Optional[str], int]:
+    resolved_condition = predictor._resolve_dataset_name(condition)
+    if resolved_condition is None:
+        raise ValueError(f"Condition '{condition}' not found in behavioral data")
+
+    if not subtract_control:
+        per_responses = predictor.behavioral_data.loc[resolved_condition]
+        valid_odorants = per_responses.dropna()
+        if len(valid_odorants) == 0:
+            raise
ValueError(f"No valid PER data for condition '{resolved_condition}'") + return valid_odorants, resolved_condition, None, 0 + + if missing_control_policy not in {"skip", "zero", "error"}: + raise ValueError( + f"Unknown missing_control_policy '{missing_control_policy}'. " + "Expected one of: skip, zero, error." + ) + + if control_condition is not None: + control_resolved = predictor._resolve_dataset_name(control_condition) + if control_resolved is None: + raise ValueError(f"Control condition '{control_condition}' not found in behavioral data") + else: + control_candidate = predictor._infer_control_condition(resolved_condition) + if control_candidate is None: + raise ValueError( + f"No matched control mapping for '{resolved_condition}'. " + "Provide --control_condition or disable --subtract_control." + ) + control_resolved = predictor._resolve_dataset_name(control_candidate) + if control_resolved is None: + raise ValueError(f"Matched control '{control_candidate}' not found in behavioral data") + + if control_resolved == resolved_condition: + raise ValueError( + f"Control condition '{control_resolved}' matches opto condition '{resolved_condition}'." + ) + + per_opto = predictor.behavioral_data.loc[resolved_condition] + per_ctrl = predictor.behavioral_data.loc[control_resolved] + + if missing_control_policy == "skip": + valid_mask = per_opto.notna() & per_ctrl.notna() + valid_odorants = (per_opto - per_ctrl)[valid_mask] + elif missing_control_policy == "zero": + valid_mask = per_opto.notna() + per_ctrl_filled = per_ctrl.fillna(0.0) + valid_odorants = (per_opto - per_ctrl_filled)[valid_mask] + else: + missing_mask = per_opto.notna() & per_ctrl.isna() + if missing_mask.any(): + missing_odorants = [str(o) for o in per_opto.index[missing_mask]] + preview = ", ".join(missing_odorants[:5]) + if len(missing_odorants) > 5: + preview = f"{preview} (and {len(missing_odorants) - 5} more)" + raise ValueError( + f"Control condition '{control_resolved}' has missing values for odorants " + f"present in '{resolved_condition}': {preview}" + ) + valid_mask = per_opto.notna() & per_ctrl.notna() + valid_odorants = (per_opto - per_ctrl)[valid_mask] + + if len(valid_odorants) == 0: + raise ValueError( + f"No valid opto/control pairs for '{resolved_condition}' and '{control_resolved}'." + ) + + return valid_odorants, resolved_condition, control_resolved, int(len(valid_odorants)) + + +def _log_debug_stats( + *, + condition: str, + mode: str, + y: np.ndarray, + n_pairs_used: int, + lambda_value: float, + n_nonzero: int, +) -> None: + logger.info( + "[debug] %s %s y_stats: n=%d mean=%.4f std=%.4f min=%.4f max=%.4f n_pairs=%d", + condition, + mode, + len(y), + float(np.mean(y)), + float(np.std(y)), + float(np.min(y)), + float(np.max(y)), + n_pairs_used, + ) + logger.info( + "[debug] %s %s lambda=%.6f n_nonzero=%d", + condition, + mode, + float(lambda_value), + int(n_nonzero), + ) + + def save_model_json( result: AblationResult, condition: str, @@ -359,6 +500,10 @@ def run_baseline( lambda_range: np.ndarray, cv_folds: int, scale_features: bool, + subtract_control: bool = False, + control_condition: Optional[str] = None, + missing_control_policy: str = "skip", + debug_stats: bool = False, ) -> Tuple[AblationResult, np.ndarray, np.ndarray, List[str], Optional[StandardScaler]]: """Run baseline LASSO fit (no ablation). 
@@ -367,27 +512,35 @@ def run_baseline( """ logger.info(f"Fitting baseline model for condition: {condition}") - # Get behavioral responses - per_responses = predictor.behavioral_data.loc[condition] - valid_odorants = per_responses.dropna() - - if len(valid_odorants) == 0: - raise ValueError(f"No valid PER data for condition '{condition}'") + valid_odorants, condition_resolved, control_resolved, n_pairs_used = _build_valid_odorants( + predictor, + condition=condition, + subtract_control=subtract_control, + control_condition=control_condition, + missing_control_policy=missing_control_policy, + ) + if control_resolved: + logger.info( + "Using control condition '%s' with policy '%s': %d odorants after alignment", + control_resolved, + missing_control_policy, + n_pairs_used, + ) # Extract features based on prediction mode if prediction_mode == "test_odorant": X, test_odorants, y = predictor._extract_test_odorant_features(valid_odorants) elif prediction_mode == "trained_odorant": - trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition) + trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition_resolved) if not trained_odorant: - raise ValueError(f"Could not determine trained odorant for {condition}") + raise ValueError(f"Could not determine trained odorant for {condition_resolved}") X, test_odorants, y = predictor._extract_trained_odorant_features( trained_odorant, valid_odorants ) elif prediction_mode == "interaction": - trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition) + trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition_resolved) if not trained_odorant: - raise ValueError(f"Could not determine trained odorant for {condition}") + raise ValueError(f"Could not determine trained odorant for {condition_resolved}") X, test_odorants, y = predictor._extract_interaction_features( trained_odorant, valid_odorants ) @@ -434,6 +587,15 @@ def run_baseline( f"Baseline: R² = {cv_r2:.4f}, MSE = {cv_mse:.4f}, " f"λ = {best_lambda:.6f}, {len(weights)} receptors selected" ) + if debug_stats: + _log_debug_stats( + condition=condition_resolved, + mode="baseline", + y=y, + n_pairs_used=n_pairs_used, + lambda_value=best_lambda, + n_nonzero=len(weights), + ) return baseline_result, X, y, receptor_names, scaler @@ -449,6 +611,7 @@ def run_ablation( baseline_r2: float, baseline_mse: float, ablation_name: str, + debug_stats: bool = False, ) -> AblationResult: """Run LASSO with specified receptors ablated.""" logger.info(f"Running ablation: {ablation_name} ({receptors_to_ablate})") @@ -487,6 +650,15 @@ def run_ablation( f"Ablation '{ablation_name}': R² = {cv_r2:.4f} (Δ = {result.delta_r2:+.4f}), " f"MSE = {cv_mse:.4f} (Δ = {result.delta_mse:+.4f})" ) + if debug_stats: + _log_debug_stats( + condition=ablation_name, + mode="ablation", + y=y, + n_pairs_used=0, + lambda_value=best_lambda, + n_nonzero=len(weights), + ) return result @@ -523,14 +695,6 @@ def main() -> int: scale_targets=False, ) - # Validate condition - if args.condition not in predictor.behavioral_data.index: - logger.error( - f"Condition '{args.condition}' not found. 
" - f"Available: {list(predictor.behavioral_data.index)}" - ) - return 1 - # Resolve receptor names available_receptors = list(predictor.masked_receptor_names) strict_mode = args.missing_receptor_policy == "error" @@ -552,67 +716,166 @@ def main() -> int: logger.info(f"Resolved receptors: {matched_receptors}") - # Create output directory - output_dir = Path(args.output_dir) / args.condition - output_dir.mkdir(parents=True, exist_ok=True) + conditions = _parse_conditions(args.condition) + if not conditions: + logger.error("No valid conditions provided.") + return 1 - # Run baseline - try: - baseline_result, X, y, receptor_names, scaler = run_baseline( - predictor=predictor, - condition=args.condition, + exit_code = 0 + for condition in conditions: + try: + resolved_condition = predictor._resolve_dataset_name(condition) + except ValueError as exc: + logger.error(str(exc)) + exit_code = 1 + continue + + if resolved_condition is None: + logger.error( + "Condition '%s' not found. Available: %s", + condition, + list(predictor.behavioral_data.index), + ) + exit_code = 1 + continue + + if ( + args.subtract_control + and args.missing_control_policy == "skip" + and args.control_condition is None + ): + control_candidate = predictor._infer_control_condition(resolved_condition) + if control_candidate is None: + logger.warning( + "No matched control mapping for '%s'; skipping ΔPER run " + "(missing_control_policy=skip).", + resolved_condition, + ) + continue + control_resolved = predictor._resolve_dataset_name(control_candidate) + if control_resolved is None: + logger.warning( + "Matched control '%s' not found for '%s'; skipping ΔPER run " + "(missing_control_policy=skip).", + control_candidate, + resolved_condition, + ) + continue + + # Create output directory + output_dir = Path(args.output_dir) / resolved_condition + output_dir.mkdir(parents=True, exist_ok=True) + + # Run baseline + try: + baseline_result, X, y, receptor_names, scaler = run_baseline( + predictor=predictor, + condition=resolved_condition, + prediction_mode=args.prediction_mode, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + scale_features=args.scale_features, + subtract_control=args.subtract_control, + control_condition=args.control_condition, + missing_control_policy=args.missing_control_policy, + debug_stats=args.debug_stats, + ) + except Exception as e: + logger.error("Baseline fit failed for %s: %s", resolved_condition, e) + exit_code = 1 + continue + + # Save baseline artifacts + save_model_json( + result=baseline_result, + condition=resolved_condition, prediction_mode=args.prediction_mode, - lambda_range=lambda_range, - cv_folds=args.cv_folds, - scale_features=args.scale_features, + n_samples=X.shape[0], + n_receptors_total=X.shape[1], + filepath=output_dir / "baseline_model.json", + ) + save_weights_csv( + weights=baseline_result.lasso_weights, + filepath=output_dir / "baseline_weights.csv", + condition=resolved_condition, + ablation_name="baseline", ) - except Exception as e: - logger.error(f"Baseline fit failed: {e}") - return 1 - - # Save baseline artifacts - save_model_json( - result=baseline_result, - condition=args.condition, - prediction_mode=args.prediction_mode, - n_samples=X.shape[0], - n_receptors_total=X.shape[1], - filepath=output_dir / "baseline_model.json", - ) - save_weights_csv( - weights=baseline_result.lasso_weights, - filepath=output_dir / "baseline_weights.csv", - condition=args.condition, - ablation_name="baseline", - ) - # Run ablations - ablation_results: List[AblationResult] = [] + # Run 
ablations + ablation_results: List[AblationResult] = [] + + if args.ablation_set_mode == "single": + # Ablate each receptor individually + for receptor in matched_receptors: + try: + ablation_name = f"ablate_{receptor}" + result = run_ablation( + X=X, + y=y, + receptor_names=receptor_names, + receptors_to_ablate=[receptor], + scaler=scaler, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + baseline_r2=baseline_result.cv_r2, + baseline_mse=baseline_result.cv_mse, + ablation_name=ablation_name, + debug_stats=args.debug_stats, + ) + ablation_results.append(result) + + # Save individual ablation artifacts + ablation_dir = output_dir / ablation_name + save_model_json( + result=result, + condition=resolved_condition, + prediction_mode=args.prediction_mode, + n_samples=X.shape[0], + n_receptors_total=X.shape[1], + filepath=ablation_dir / "model.json", + ) + save_weights_csv( + weights=result.lasso_weights, + filepath=ablation_dir / "weights.csv", + condition=resolved_condition, + ablation_name=ablation_name, + ) + + except Exception as e: + logger.error( + "Ablation '%s' failed for %s: %s", receptor, resolved_condition, e + ) + exit_code = 1 + continue + + else: # all_in_one + # Ablate all receptors together + ablation_name = "ablate_" + "_".join(matched_receptors) + # Truncate if too long + if len(ablation_name) > 100: + ablation_name = f"ablate_{len(matched_receptors)}_receptors" - if args.ablation_set_mode == "single": - # Ablate each receptor individually - for receptor in matched_receptors: try: - ablation_name = f"ablate_{receptor}" result = run_ablation( X=X, y=y, receptor_names=receptor_names, - receptors_to_ablate=[receptor], + receptors_to_ablate=matched_receptors, scaler=scaler, lambda_range=lambda_range, cv_folds=args.cv_folds, baseline_r2=baseline_result.cv_r2, baseline_mse=baseline_result.cv_mse, ablation_name=ablation_name, + debug_stats=args.debug_stats, ) ablation_results.append(result) - # Save individual ablation artifacts + # Save ablation artifacts ablation_dir = output_dir / ablation_name save_model_json( result=result, - condition=args.condition, + condition=resolved_condition, prediction_mode=args.prediction_mode, n_samples=X.shape[0], n_receptors_total=X.shape[1], @@ -621,121 +884,82 @@ def main() -> int: save_weights_csv( weights=result.lasso_weights, filepath=ablation_dir / "weights.csv", - condition=args.condition, + condition=resolved_condition, ablation_name=ablation_name, ) except Exception as e: - logger.error(f"Ablation '{receptor}' failed: {e}") + logger.error("Ablation failed for %s: %s", resolved_condition, e) + exit_code = 1 continue - else: # all_in_one - # Ablate all receptors together - ablation_name = "ablate_" + "_".join(matched_receptors) - # Truncate if too long - if len(ablation_name) > 100: - ablation_name = f"ablate_{len(matched_receptors)}_receptors" - - try: - result = run_ablation( - X=X, - y=y, - receptor_names=receptor_names, - receptors_to_ablate=matched_receptors, - scaler=scaler, - lambda_range=lambda_range, - cv_folds=args.cv_folds, - baseline_r2=baseline_result.cv_r2, - baseline_mse=baseline_result.cv_mse, - ablation_name=ablation_name, - ) - ablation_results.append(result) - - # Save ablation artifacts - ablation_dir = output_dir / ablation_name - save_model_json( - result=result, - condition=args.condition, - prediction_mode=args.prediction_mode, - n_samples=X.shape[0], - n_receptors_total=X.shape[1], - filepath=ablation_dir / "model.json", - ) - save_weights_csv( - weights=result.lasso_weights, - filepath=ablation_dir / 
"weights.csv", - condition=args.condition, - ablation_name=ablation_name, - ) + # Generate summary CSV + summary_rows = [] - except Exception as e: - logger.error(f"Ablation failed: {e}") - - # Generate summary CSV - summary_rows = [] - - # Baseline row - summary_rows.append({ - "ablation_name": "baseline", - "receptors_ablated": "", - "n_ablated": 0, - "cv_r2": baseline_result.cv_r2, - "cv_mse": baseline_result.cv_mse, - "n_receptors_selected": baseline_result.n_receptors_selected, - "lambda_value": baseline_result.lambda_value, - "delta_r2": 0.0, - "delta_mse": 0.0, - }) - - # Ablation rows - for result in ablation_results: + # Baseline row summary_rows.append({ - "ablation_name": result.ablation_name, - "receptors_ablated": ";".join(result.receptors_ablated), - "n_ablated": len(result.receptors_ablated), - "cv_r2": result.cv_r2, - "cv_mse": result.cv_mse, - "n_receptors_selected": result.n_receptors_selected, - "lambda_value": result.lambda_value, - "delta_r2": result.delta_r2, - "delta_mse": result.delta_mse, + "ablation_name": "baseline", + "receptors_ablated": "", + "n_ablated": 0, + "cv_r2": baseline_result.cv_r2, + "cv_mse": baseline_result.cv_mse, + "n_receptors_selected": baseline_result.n_receptors_selected, + "lambda_value": baseline_result.lambda_value, + "delta_r2": 0.0, + "delta_mse": 0.0, }) - summary_df = pd.DataFrame(summary_rows) - - # Create ablations subfolder for summary files - ablations_dir = output_dir / "ablations" - ablations_dir.mkdir(parents=True, exist_ok=True) - - summary_path = ablations_dir / "ablation_summary.csv" - summary_df.to_csv(summary_path, index=False) - logger.info(f"Saved summary to {summary_path}") - - # Generate comparison plot - plot_ablation_comparison( - baseline_result=baseline_result, - ablation_results=ablation_results, - output_dir=ablations_dir, - condition=args.condition, - ) + # Ablation rows + for result in ablation_results: + summary_rows.append({ + "ablation_name": result.ablation_name, + "receptors_ablated": ";".join(result.receptors_ablated), + "n_ablated": len(result.receptors_ablated), + "cv_r2": result.cv_r2, + "cv_mse": result.cv_mse, + "n_receptors_selected": result.n_receptors_selected, + "lambda_value": result.lambda_value, + "delta_r2": result.delta_r2, + "delta_mse": result.delta_mse, + }) + + summary_df = pd.DataFrame(summary_rows) + + # Create ablations subfolder for summary files + ablations_dir = output_dir / "ablations" + ablations_dir.mkdir(parents=True, exist_ok=True) + + summary_path = ablations_dir / "ablation_summary.csv" + summary_df.to_csv(summary_path, index=False) + logger.info("Saved summary to %s", summary_path) + + # Generate comparison plot + plot_ablation_comparison( + baseline_result=baseline_result, + ablation_results=ablation_results, + output_dir=ablations_dir, + condition=resolved_condition, + ) - # Print summary - print("\n" + "=" * 80) - print(f"LASSO Ablation Analysis Complete: {args.condition}") - print("=" * 80) - print(f"\nBaseline: R² = {baseline_result.cv_r2:.4f}, MSE = {baseline_result.cv_mse:.4f}") - print(f" {baseline_result.n_receptors_selected} receptors selected") - print(f"\nAblation Results:") - for result in ablation_results: + # Print summary + print("\n" + "=" * 80) + print(f"LASSO Ablation Analysis Complete: {resolved_condition}") + print("=" * 80) print( - f" {result.ablation_name:40s} " - f"R² = {result.cv_r2:.4f} (Δ = {result.delta_r2:+.4f}) " - f"MSE = {result.cv_mse:.4f} (Δ = {result.delta_mse:+.4f})" + f"\nBaseline: R² = {baseline_result.cv_r2:.4f}, MSE = 
{baseline_result.cv_mse:.4f}" ) - print(f"\nOutputs saved to: {output_dir}") - print("=" * 80) + print(f" {baseline_result.n_receptors_selected} receptors selected") + print(f"\nAblation Results:") + for result in ablation_results: + print( + f" {result.ablation_name:40s} " + f"R² = {result.cv_r2:.4f} (Δ = {result.delta_r2:+.4f}) " + f"MSE = {result.cv_mse:.4f} (Δ = {result.delta_mse:+.4f})" + ) + print(f"\nOutputs saved to: {output_dir}") + print("=" * 80) - return 0 + return exit_code if __name__ == "__main__": diff --git a/scripts/lasso_with_focus_mode.py b/scripts/lasso_with_focus_mode.py index 60661d1..6eccf3c 100644 --- a/scripts/lasso_with_focus_mode.py +++ b/scripts/lasso_with_focus_mode.py @@ -102,9 +102,9 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument( "--condition", - type=str, required=True, - help="Optogenetic condition name (e.g., opto_hex)", + action="append", + help="Condition name(s). Repeat flag or pass comma-separated list.", ) parser.add_argument( "--output_dir", @@ -142,6 +142,29 @@ def parse_args() -> argparse.Namespace: default="test_odorant", help="Feature extraction mode (default: test_odorant)", ) + parser.add_argument( + "--subtract_control", + action="store_true", + help="Fit on ΔPER (opto - control) instead of raw PER.", + ) + parser.add_argument( + "--control_condition", + type=str, + default=None, + help="Optional control dataset override (default: infer from opto condition).", + ) + parser.add_argument( + "--missing_control_policy", + type=str, + choices=["skip", "zero", "error"], + default="error", + help="How to handle missing control values (default: error).", + ) + parser.add_argument( + "--debug_stats", + action="store_true", + help="Log y stats, chosen lambda, and nonzero coefficient count.", + ) # LASSO parameters parser.add_argument( @@ -180,6 +203,19 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() +def _parse_conditions(values: List[str]) -> List[str]: + conditions: List[str] = [] + seen = set() + for value in values: + for token in value.split(","): + token = token.strip() + if not token or token in seen: + continue + conditions.append(token) + seen.add(token) + return conditions + + def save_baseline_json( baseline_weights: Dict[str, float], baseline_r2: float, @@ -245,6 +281,111 @@ def save_focus_json( logger.info(f"Saved focus result to {filepath}") +def _build_valid_odorants( + predictor: LassoBehavioralPredictor, + condition: str, + subtract_control: bool, + control_condition: Optional[str], + missing_control_policy: str, +) -> Tuple[pd.Series, str, Optional[str], int]: + resolved_condition = predictor._resolve_dataset_name(condition) + if resolved_condition is None: + raise ValueError(f"Condition '{condition}' not found in behavioral data") + + if not subtract_control: + per_responses = predictor.behavioral_data.loc[resolved_condition] + valid_odorants = per_responses.dropna() + if len(valid_odorants) == 0: + raise ValueError(f"No valid PER data for condition '{resolved_condition}'") + return valid_odorants, resolved_condition, None, 0 + + if missing_control_policy not in {"skip", "zero", "error"}: + raise ValueError( + f"Unknown missing_control_policy '{missing_control_policy}'. " + "Expected one of: skip, zero, error." 
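+            # argparse already restricts --missing_control_policy via choices,
+            # so this re-validation only guards direct programmatic callers.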
+ ) + + if control_condition is not None: + control_resolved = predictor._resolve_dataset_name(control_condition) + if control_resolved is None: + raise ValueError(f"Control condition '{control_condition}' not found in behavioral data") + else: + control_candidate = predictor._infer_control_condition(resolved_condition) + if control_candidate is None: + raise ValueError( + f"No matched control mapping for '{resolved_condition}'. " + "Provide --control_condition or disable --subtract_control." + ) + control_resolved = predictor._resolve_dataset_name(control_candidate) + if control_resolved is None: + raise ValueError(f"Matched control '{control_candidate}' not found in behavioral data") + + if control_resolved == resolved_condition: + raise ValueError( + f"Control condition '{control_resolved}' matches opto condition '{resolved_condition}'." + ) + + per_opto = predictor.behavioral_data.loc[resolved_condition] + per_ctrl = predictor.behavioral_data.loc[control_resolved] + + if missing_control_policy == "skip": + valid_mask = per_opto.notna() & per_ctrl.notna() + valid_odorants = (per_opto - per_ctrl)[valid_mask] + elif missing_control_policy == "zero": + valid_mask = per_opto.notna() + per_ctrl_filled = per_ctrl.fillna(0.0) + valid_odorants = (per_opto - per_ctrl_filled)[valid_mask] + else: + missing_mask = per_opto.notna() & per_ctrl.isna() + if missing_mask.any(): + missing_odorants = [str(o) for o in per_opto.index[missing_mask]] + preview = ", ".join(missing_odorants[:5]) + if len(missing_odorants) > 5: + preview = f"{preview} (and {len(missing_odorants) - 5} more)" + raise ValueError( + f"Control condition '{control_resolved}' has missing values for odorants " + f"present in '{resolved_condition}': {preview}" + ) + valid_mask = per_opto.notna() & per_ctrl.notna() + valid_odorants = (per_opto - per_ctrl)[valid_mask] + + if len(valid_odorants) == 0: + raise ValueError( + f"No valid opto/control pairs for '{resolved_condition}' and '{control_resolved}'." + ) + + return valid_odorants, resolved_condition, control_resolved, int(len(valid_odorants)) + + +def _log_debug_stats( + *, + condition: str, + mode: str, + y: np.ndarray, + n_pairs_used: int, + lambda_value: float, + n_nonzero: int, +) -> None: + logger.info( + "[debug] %s %s y_stats: n=%d mean=%.4f std=%.4f min=%.4f max=%.4f n_pairs=%d", + condition, + mode, + len(y), + float(np.mean(y)), + float(np.std(y)), + float(np.min(y)), + float(np.max(y)), + n_pairs_used, + ) + logger.info( + "[debug] %s %s lambda=%.6f n_nonzero=%d", + condition, + mode, + float(lambda_value), + int(n_nonzero), + ) + + def plot_focus_curve( focus_results: List[FocusResult], baseline_r2: float, @@ -336,6 +477,10 @@ def run_baseline( lambda_range: np.ndarray, cv_folds: int, scale_features: bool, + subtract_control: bool = False, + control_condition: Optional[str] = None, + missing_control_policy: str = "skip", + debug_stats: bool = False, ) -> Tuple[Dict[str, float], float, float, float, np.ndarray, np.ndarray, List[str], Optional[StandardScaler]]: """Run baseline LASSO fit (full receptor set). 
@@ -344,27 +489,35 @@ def run_baseline( """ logger.info(f"Fitting baseline model for condition: {condition}") - # Get behavioral responses - per_responses = predictor.behavioral_data.loc[condition] - valid_odorants = per_responses.dropna() - - if len(valid_odorants) == 0: - raise ValueError(f"No valid PER data for condition '{condition}'") + valid_odorants, condition_resolved, control_resolved, n_pairs_used = _build_valid_odorants( + predictor, + condition=condition, + subtract_control=subtract_control, + control_condition=control_condition, + missing_control_policy=missing_control_policy, + ) + if control_resolved: + logger.info( + "Using control condition '%s' with policy '%s': %d odorants after alignment", + control_resolved, + missing_control_policy, + n_pairs_used, + ) # Extract features based on prediction mode if prediction_mode == "test_odorant": X, test_odorants, y = predictor._extract_test_odorant_features(valid_odorants) elif prediction_mode == "trained_odorant": - trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition) + trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition_resolved) if not trained_odorant: - raise ValueError(f"Could not determine trained odorant for {condition}") + raise ValueError(f"Could not determine trained odorant for {condition_resolved}") X, test_odorants, y = predictor._extract_trained_odorant_features( trained_odorant, valid_odorants ) elif prediction_mode == "interaction": - trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition) + trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition_resolved) if not trained_odorant: - raise ValueError(f"Could not determine trained odorant for {condition}") + raise ValueError(f"Could not determine trained odorant for {condition_resolved}") X, test_odorants, y = predictor._extract_interaction_features( trained_odorant, valid_odorants ) @@ -398,6 +551,15 @@ def run_baseline( f"Baseline: R² = {cv_r2:.4f}, MSE = {cv_mse:.4f}, " f"λ = {best_lambda:.6f}, {len(weights)} receptors selected" ) + if debug_stats: + _log_debug_stats( + condition=condition_resolved, + mode="baseline", + y=y, + n_pairs_used=n_pairs_used, + lambda_value=best_lambda, + n_nonzero=len(weights), + ) return weights, cv_r2, cv_mse, best_lambda, X, y, receptor_names, scaler @@ -412,6 +574,7 @@ def run_focus( scale_features: bool, baseline_r2: float, baseline_mse: float, + debug_stats: bool = False, ) -> FocusResult: """Run LASSO with restricted receptor set.""" n_receptors = len(receptors_to_keep) @@ -457,6 +620,15 @@ def run_focus( f"Focus N={n_receptors}: R² = {cv_r2:.4f} (Δ = {result.delta_r2:+.4f}), " f"MSE = {cv_mse:.4f} (Δ = {result.delta_mse:+.4f})" ) + if debug_stats: + _log_debug_stats( + condition=f"focus_n{n_receptors}", + mode="focus", + y=y, + n_pairs_used=0, + lambda_value=best_lambda, + n_nonzero=len(weights), + ) return result @@ -481,211 +653,271 @@ def main() -> int: scale_targets=False, ) - # Validate condition - if args.condition not in predictor.behavioral_data.index: - logger.error( - f"Condition '{args.condition}' not found. 
" - f"Available: {list(predictor.behavioral_data.index)}" - ) - return 1 - - # Create output directory - output_dir = Path(args.output_dir) / args.condition - output_dir.mkdir(parents=True, exist_ok=True) - - # Run baseline - try: - ( - baseline_weights, - baseline_r2, - baseline_mse, - baseline_lambda, - X, - y, - receptor_names, - scaler, - ) = run_baseline( - predictor=predictor, - condition=args.condition, - prediction_mode=args.prediction_mode, - lambda_range=lambda_range, - cv_folds=args.cv_folds, - scale_features=args.scale_features, - ) - except Exception as e: - logger.error(f"Baseline fit failed: {e}") + conditions = _parse_conditions(args.condition) + if not conditions: + logger.error("No valid conditions provided.") return 1 - # Save baseline - save_baseline_json( - baseline_weights=baseline_weights, - baseline_r2=baseline_r2, - baseline_mse=baseline_mse, - baseline_lambda=baseline_lambda, - condition=args.condition, - prediction_mode=args.prediction_mode, - n_samples=X.shape[0], - n_receptors_total=X.shape[1], - filepath=output_dir / "baseline_model.json", - ) - - # Determine focus runs - focus_results: List[FocusResult] = [] + exit_code = 0 + for condition in conditions: + try: + resolved_condition = predictor._resolve_dataset_name(condition) + except ValueError as exc: + logger.error(str(exc)) + exit_code = 1 + continue + + if resolved_condition is None: + logger.error( + "Condition '%s' not found. Available: %s", + condition, + list(predictor.behavioral_data.index), + ) + exit_code = 1 + continue + + if ( + args.subtract_control + and args.missing_control_policy == "skip" + and args.control_condition is None + ): + control_candidate = predictor._infer_control_condition(resolved_condition) + if control_candidate is None: + logger.warning( + "No matched control mapping for '%s'; skipping ΔPER run " + "(missing_control_policy=skip).", + resolved_condition, + ) + continue + control_resolved = predictor._resolve_dataset_name(control_candidate) + if control_resolved is None: + logger.warning( + "Matched control '%s' not found for '%s'; skipping ΔPER run " + "(missing_control_policy=skip).", + control_candidate, + resolved_condition, + ) + continue - if args.focus_receptors: - # Use explicit receptor list - focus_receptors = [r.strip() for r in args.focus_receptors.split(",") if r.strip()] - logger.info(f"Using explicit focus receptors: {focus_receptors}") + # Create output directory + output_dir = Path(args.output_dir) / resolved_condition + output_dir.mkdir(parents=True, exist_ok=True) + # Run baseline try: - result = run_focus( - X=X, - y=y, - receptor_names=receptor_names, - receptors_to_keep=focus_receptors, + ( + baseline_weights, + baseline_r2, + baseline_mse, + baseline_lambda, + X, + y, + receptor_names, + scaler, + ) = run_baseline( + predictor=predictor, + condition=resolved_condition, + prediction_mode=args.prediction_mode, lambda_range=lambda_range, cv_folds=args.cv_folds, scale_features=args.scale_features, - baseline_r2=baseline_r2, - baseline_mse=baseline_mse, + subtract_control=args.subtract_control, + control_condition=args.control_condition, + missing_control_policy=args.missing_control_policy, + debug_stats=args.debug_stats, ) - focus_results.append(result) - - # Save individual focus result - focus_dir = output_dir / f"focus_n{result.n_receptors}" - save_focus_json( - result=result, - condition=args.condition, - prediction_mode=args.prediction_mode, - n_samples=X.shape[0], - filepath=focus_dir / "model.json", - ) - except Exception as e: - logger.error(f"Focus 
run failed: {e}") - - else: - # Use topn_list - topn_list = [int(x.strip()) for x in args.topn_list.split(",")] - logger.info(f"Top-N values to test: {topn_list}") - - # Get ranked receptors from baseline - if not baseline_weights: - logger.error("Baseline model has no non-zero weights, cannot determine top-N") - return 1 - - if args.baseline_select_by == "abs_weight": - ranked_receptors = get_top_receptors_by_weight(baseline_weights, len(baseline_weights)) - else: - # stability mode would require multiple runs - fallback to abs_weight - logger.warning("stability mode not yet implemented, using abs_weight") - ranked_receptors = get_top_receptors_by_weight(baseline_weights, len(baseline_weights)) - - logger.info(f"Receptor ranking (top 10): {ranked_receptors[:10]}") - - for n in topn_list: - if n > len(ranked_receptors): - logger.warning( - f"N={n} exceeds available ranked receptors ({len(ranked_receptors)}), skipping" - ) - continue + logger.error("Baseline fit failed for %s: %s", resolved_condition, e) + exit_code = 1 + continue + + # Save baseline + save_baseline_json( + baseline_weights=baseline_weights, + baseline_r2=baseline_r2, + baseline_mse=baseline_mse, + baseline_lambda=baseline_lambda, + condition=resolved_condition, + prediction_mode=args.prediction_mode, + n_samples=X.shape[0], + n_receptors_total=X.shape[1], + filepath=output_dir / "baseline_model.json", + ) - if n <= 0: - logger.warning(f"N={n} is invalid, skipping") - continue + # Determine focus runs + focus_results: List[FocusResult] = [] - top_n_receptors = ranked_receptors[:n] + if args.focus_receptors: + # Use explicit receptor list + focus_receptors = [r.strip() for r in args.focus_receptors.split(",") if r.strip()] + logger.info("Using explicit focus receptors: %s", focus_receptors) try: result = run_focus( X=X, y=y, receptor_names=receptor_names, - receptors_to_keep=top_n_receptors, + receptors_to_keep=focus_receptors, lambda_range=lambda_range, cv_folds=args.cv_folds, scale_features=args.scale_features, baseline_r2=baseline_r2, baseline_mse=baseline_mse, + debug_stats=args.debug_stats, ) focus_results.append(result) # Save individual focus result - focus_dir = output_dir / f"focus_n{n}" + focus_dir = output_dir / f"focus_n{result.n_receptors}" save_focus_json( result=result, - condition=args.condition, + condition=resolved_condition, prediction_mode=args.prediction_mode, n_samples=X.shape[0], filepath=focus_dir / "model.json", ) except Exception as e: - logger.error(f"Focus N={n} failed: {e}") + logger.error("Focus run failed for %s: %s", resolved_condition, e) + exit_code = 1 continue - # Generate focus_curve.csv - curve_rows = [] + else: + # Use topn_list + topn_list = [int(x.strip()) for x in args.topn_list.split(",")] + logger.info("Top-N values to test: %s", topn_list) + + # Get ranked receptors from baseline + if not baseline_weights: + logger.error( + "Baseline model has no non-zero weights for %s; cannot determine top-N", + resolved_condition, + ) + exit_code = 1 + continue - # Add baseline row - curve_rows.append({ - "n_receptors": X.shape[1], - "receptors_used": ";".join(sorted(baseline_weights.keys())), - "cv_r2": baseline_r2, - "cv_mse": baseline_mse, - "lambda_value": baseline_lambda, - "n_receptors_selected": len(baseline_weights), - "delta_r2": 0.0, - "delta_mse": 0.0, - "is_baseline": True, - }) + if args.baseline_select_by == "abs_weight": + ranked_receptors = get_top_receptors_by_weight( + baseline_weights, len(baseline_weights) + ) + else: + # stability mode would require multiple runs - fallback 
to abs_weight + logger.warning("stability mode not yet implemented, using abs_weight") + ranked_receptors = get_top_receptors_by_weight( + baseline_weights, len(baseline_weights) + ) - # Add focus rows - for result in focus_results: + logger.info("Receptor ranking (top 10): %s", ranked_receptors[:10]) + + for n in topn_list: + if n > len(ranked_receptors): + logger.warning( + "N=%d exceeds available ranked receptors (%d), skipping", + n, + len(ranked_receptors), + ) + continue + + if n <= 0: + logger.warning("N=%d is invalid, skipping", n) + continue + + top_n_receptors = ranked_receptors[:n] + + try: + result = run_focus( + X=X, + y=y, + receptor_names=receptor_names, + receptors_to_keep=top_n_receptors, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + scale_features=args.scale_features, + baseline_r2=baseline_r2, + baseline_mse=baseline_mse, + debug_stats=args.debug_stats, + ) + focus_results.append(result) + + # Save individual focus result + focus_dir = output_dir / f"focus_n{n}" + save_focus_json( + result=result, + condition=resolved_condition, + prediction_mode=args.prediction_mode, + n_samples=X.shape[0], + filepath=focus_dir / "model.json", + ) + + except Exception as e: + logger.error("Focus N=%d failed for %s: %s", n, resolved_condition, e) + exit_code = 1 + continue + + # Generate focus_curve.csv + curve_rows = [] + + # Add baseline row curve_rows.append({ - "n_receptors": result.n_receptors, - "receptors_used": ";".join(result.receptors_used), - "cv_r2": result.cv_r2, - "cv_mse": result.cv_mse, - "lambda_value": result.lambda_value, - "n_receptors_selected": result.n_receptors_selected, - "delta_r2": result.delta_r2, - "delta_mse": result.delta_mse, - "is_baseline": False, + "n_receptors": X.shape[1], + "receptors_used": ";".join(sorted(baseline_weights.keys())), + "cv_r2": baseline_r2, + "cv_mse": baseline_mse, + "lambda_value": baseline_lambda, + "n_receptors_selected": len(baseline_weights), + "delta_r2": 0.0, + "delta_mse": 0.0, + "is_baseline": True, }) - curve_df = pd.DataFrame(curve_rows) - curve_df = curve_df.sort_values("n_receptors") - curve_path = output_dir / "focus_curve.csv" - curve_df.to_csv(curve_path, index=False) - logger.info(f"Saved focus curve to {curve_path}") - - # Generate plots - plot_focus_curve( - focus_results=focus_results, - baseline_r2=baseline_r2, - baseline_mse=baseline_mse, - output_dir=output_dir, - condition=args.condition, - ) + # Add focus rows + for result in focus_results: + curve_rows.append({ + "n_receptors": result.n_receptors, + "receptors_used": ";".join(result.receptors_used), + "cv_r2": result.cv_r2, + "cv_mse": result.cv_mse, + "lambda_value": result.lambda_value, + "n_receptors_selected": result.n_receptors_selected, + "delta_r2": result.delta_r2, + "delta_mse": result.delta_mse, + "is_baseline": False, + }) + + curve_df = pd.DataFrame(curve_rows) + curve_df = curve_df.sort_values("n_receptors") + curve_path = output_dir / "focus_curve.csv" + curve_df.to_csv(curve_path, index=False) + logger.info("Saved focus curve to %s", curve_path) + + # Generate plots + plot_focus_curve( + focus_results=focus_results, + baseline_r2=baseline_r2, + baseline_mse=baseline_mse, + output_dir=output_dir, + condition=resolved_condition, + ) - # Print summary - print("\n" + "=" * 80) - print(f"LASSO Focus Mode Analysis Complete: {args.condition}") - print("=" * 80) - print(f"\nBaseline (full): R² = {baseline_r2:.4f}, MSE = {baseline_mse:.4f}") - print(f" {len(baseline_weights)} receptors selected out of {X.shape[1]}") - print(f"\nFocus Mode 
Results:") - for result in sorted(focus_results, key=lambda r: r.n_receptors): + # Print summary + print("\n" + "=" * 80) + print(f"LASSO Focus Mode Analysis Complete: {resolved_condition}") + print("=" * 80) + print(f"\nBaseline (full): R² = {baseline_r2:.4f}, MSE = {baseline_mse:.4f}") print( - f" N={result.n_receptors:3d}: " - f"R² = {result.cv_r2:.4f} (Δ = {result.delta_r2:+.4f}) " - f"MSE = {result.cv_mse:.4f} (Δ = {result.delta_mse:+.4f})" + f" {len(baseline_weights)} receptors selected out of {X.shape[1]}" ) - print(f"\nOutputs saved to: {output_dir}") - print("=" * 80) + print(f"\nFocus Mode Results:") + for result in sorted(focus_results, key=lambda r: r.n_receptors): + print( + f" N={result.n_receptors:3d}: " + f"R² = {result.cv_r2:.4f} (Δ = {result.delta_r2:+.4f}) " + f"MSE = {result.cv_mse:.4f} (Δ = {result.delta_mse:+.4f})" + ) + print(f"\nOutputs saved to: {output_dir}") + print("=" * 80) - return 0 + return exit_code if __name__ == "__main__": diff --git a/scripts/run_lasso_behavioral_prediction.py b/scripts/run_lasso_behavioral_prediction.py index a965e23..763e97f 100644 --- a/scripts/run_lasso_behavioral_prediction.py +++ b/scripts/run_lasso_behavioral_prediction.py @@ -127,7 +127,7 @@ def _parse_args() -> argparse.Namespace: parser.add_argument( "--missing_control_policy", choices=["skip", "zero", "error"], - default="skip", + default="error", help="How to handle missing control values.", ) @@ -143,6 +143,11 @@ def _parse_args() -> argparse.Namespace: default=None, help="Comma-separated lambda values (e.g., 0.0001,0.001,0.01).", ) + parser.add_argument( + "--lambda_range_delta", + default="1e-8,1e-7,1e-6,1e-5,1e-4,1e-3,1e-2,1e-1,1.0", + help="Comma-separated lambda values for ΔPER runs.", + ) return parser.parse_args() @@ -156,6 +161,7 @@ def main() -> None: raise ValueError("No valid conditions provided.") lambda_range = _parse_lambda_range(args.lambda_range) + lambda_range_delta = _parse_lambda_range(args.lambda_range_delta) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) @@ -176,6 +182,7 @@ def main() -> None: ran_raw = False for mode_label, subtract_control in modes_to_run: try: + mode_lambda_range = lambda_range_delta if subtract_control else lambda_range results = _run_condition( predictor, condition_name=condition_name, @@ -184,34 +191,20 @@ def main() -> None: control_condition=args.control_condition, missing_control_policy=args.missing_control_policy, prediction_mode=args.prediction_mode, - lambda_range=lambda_range, + lambda_range=mode_lambda_range, cv_folds=args.cv_folds, output_dir=output_dir, ) except ValueError as exc: if subtract_control and _is_missing_control_error(exc): - logger.warning( - "No control found for '%s'; falling back to raw mode.", - condition_name, - ) - if ran_raw: - logger.info("Raw mode already completed for '%s'; skipping delta.", condition_name) + if args.missing_control_policy == "skip": + logger.warning( + "No control found for '%s'; skipping ΔPER run (missing_control_policy=skip).", + condition_name, + ) continue - results = _run_condition( - predictor, - condition_name=condition_name, - mode_label="raw", - subtract_control=False, - control_condition=None, - missing_control_policy=args.missing_control_policy, - prediction_mode=args.prediction_mode, - lambda_range=lambda_range, - cv_folds=args.cv_folds, - output_dir=output_dir, - ) - mode_label = "raw" - else: raise + raise if mode_label == "raw": ran_raw = True diff --git a/scripts/run_lasso_full_sweep.py b/scripts/run_lasso_full_sweep.py new file 
mode 100644 index 0000000..58e1dfd --- /dev/null +++ b/scripts/run_lasso_full_sweep.py @@ -0,0 +1,860 @@ +#!/usr/bin/env python3 +""" +Run a full LASSO sweep: baseline, ablations, and focus mode per condition. + +This script writes outputs into per-condition folders and produces +data-only summaries comparing each run to the baseline. +""" + +from __future__ import annotations + +import argparse +import json +import logging +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +from sklearn.preprocessing import StandardScaler + +# Add src to path for repo-local runs (matches other scripts in this repo). +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) + +from door_toolkit.pathways.behavioral_prediction import ( + LassoBehavioralPredictor, + apply_receptor_ablation, + fit_lasso_with_fixed_scaler, + get_top_receptors_by_weight, + resolve_receptor_names, + restrict_to_receptors, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class RunMetrics: + run_type: str + label: str + cv_r2: float + cv_mse: float + lambda_value: float + n_receptors_selected: int + n_samples: int + + +def _parse_conditions(values: List[str]) -> List[str]: + conditions: List[str] = [] + seen = set() + for value in values: + for token in value.split(","): + token = token.strip() + if not token or token in seen: + continue + conditions.append(token) + seen.add(token) + return conditions + + +def _parse_lambda_range(value: Optional[str]) -> np.ndarray: + if value is None: + return np.logspace(-4, 0, 50) + tokens = [token.strip() for token in value.split(",") if token.strip()] + if not tokens: + raise ValueError("lambda_range cannot be empty.") + return np.array([float(token) for token in tokens], dtype=np.float64) + + +def _parse_topn_list(value: str) -> List[int]: + tokens = [token.strip() for token in value.split(",") if token.strip()] + return [int(token) for token in tokens] + + +def _parse_topn_list_with_all(value: str) -> List[Optional[int]]: + tokens = [token.strip() for token in value.split(",") if token.strip()] + parsed: List[Optional[int]] = [] + for token in tokens: + if token.lower() == "all": + parsed.append(None) + else: + parsed.append(int(token)) + return parsed + + +def _load_receptors_from_file(filepath: str) -> List[str]: + path = Path(filepath) + if not path.exists(): + raise FileNotFoundError(f"Receptor file not found: {filepath}") + receptors: List[str] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + receptors.append(line) + return receptors + + +def _build_valid_odorants( + predictor: LassoBehavioralPredictor, + condition: str, + subtract_control: bool, + control_condition: Optional[str], + missing_control_policy: str, +) -> Tuple[pd.Series, str, Optional[str], int]: + resolved_condition = predictor._resolve_dataset_name(condition) + if resolved_condition is None: + raise ValueError(f"Condition '{condition}' not found in behavioral data") + + if not subtract_control: + per_responses = predictor.behavioral_data.loc[resolved_condition] + valid_odorants = per_responses.dropna() + if len(valid_odorants) == 0: + raise ValueError(f"No valid PER data for condition '{resolved_condition}'") + return valid_odorants, resolved_condition, None, 0 + + if missing_control_policy not in {"skip", "zero", "error"}: + raise ValueError( + f"Unknown missing_control_policy '{missing_control_policy}'. 
" + "Expected one of: skip, zero, error." + ) + + if control_condition is not None: + control_resolved = predictor._resolve_dataset_name(control_condition) + if control_resolved is None: + raise ValueError(f"Control condition '{control_condition}' not found in behavioral data") + else: + control_candidate = predictor._infer_control_condition(resolved_condition) + if control_candidate is None: + raise ValueError( + f"No matched control mapping for '{resolved_condition}'. " + "Provide --control_condition or disable --subtract_control." + ) + control_resolved = predictor._resolve_dataset_name(control_candidate) + if control_resolved is None: + raise ValueError(f"Matched control '{control_candidate}' not found in behavioral data") + + if control_resolved == resolved_condition: + raise ValueError( + f"Control condition '{control_resolved}' matches opto condition '{resolved_condition}'." + ) + + per_opto = predictor.behavioral_data.loc[resolved_condition] + per_ctrl = predictor.behavioral_data.loc[control_resolved] + + if missing_control_policy == "skip": + valid_mask = per_opto.notna() & per_ctrl.notna() + valid_odorants = (per_opto - per_ctrl)[valid_mask] + elif missing_control_policy == "zero": + valid_mask = per_opto.notna() + per_ctrl_filled = per_ctrl.fillna(0.0) + valid_odorants = (per_opto - per_ctrl_filled)[valid_mask] + else: + missing_mask = per_opto.notna() & per_ctrl.isna() + if missing_mask.any(): + missing_odorants = [str(o) for o in per_opto.index[missing_mask]] + preview = ", ".join(missing_odorants[:5]) + if len(missing_odorants) > 5: + preview = f"{preview} (and {len(missing_odorants) - 5} more)" + raise ValueError( + f"Control condition '{control_resolved}' has missing values for odorants " + f"present in '{resolved_condition}': {preview}" + ) + valid_mask = per_opto.notna() & per_ctrl.notna() + valid_odorants = (per_opto - per_ctrl)[valid_mask] + + if len(valid_odorants) == 0: + raise ValueError( + f"No valid opto/control pairs for '{resolved_condition}' and '{control_resolved}'." 
+ ) + + return valid_odorants, resolved_condition, control_resolved, int(len(valid_odorants)) + + +def _extract_features( + predictor: LassoBehavioralPredictor, + valid_odorants: pd.Series, + condition: str, + prediction_mode: str, +) -> Tuple[np.ndarray, List[str], np.ndarray]: + if prediction_mode == "test_odorant": + return predictor._extract_test_odorant_features(valid_odorants) + if prediction_mode == "trained_odorant": + trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition) + if not trained_odorant: + raise ValueError(f"Could not determine trained odorant for {condition}") + return predictor._extract_trained_odorant_features(trained_odorant, valid_odorants) + if prediction_mode == "interaction": + trained_odorant = predictor.CONDITION_ODORANT_MAPPING.get(condition) + if not trained_odorant: + raise ValueError(f"Could not determine trained odorant for {condition}") + return predictor._extract_interaction_features(trained_odorant, valid_odorants) + raise ValueError(f"Unknown prediction_mode: {prediction_mode}") + + +def _save_weights_csv( + weights: Dict[str, float], filepath: Path, condition: str, label: str +) -> None: + rows = [] + for receptor, weight in sorted(weights.items(), key=lambda x: abs(x[1]), reverse=True): + rows.append( + { + "condition": condition, + "run_label": label, + "receptor": receptor, + "weight": weight, + "abs_weight": abs(weight), + } + ) + df = pd.DataFrame(rows) + filepath.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(filepath, index=False) + + +def _save_metrics_json(metrics: RunMetrics, filepath: Path, extra: Optional[Dict] = None) -> None: + data = { + "run_type": metrics.run_type, + "label": metrics.label, + "cv_r2": metrics.cv_r2, + "cv_mse": metrics.cv_mse, + "lambda_value": metrics.lambda_value, + "n_receptors_selected": metrics.n_receptors_selected, + "n_samples": metrics.n_samples, + } + if extra: + data.update(extra) + filepath.parent.mkdir(parents=True, exist_ok=True) + with open(filepath, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + +def _delta_label(delta: float, tol: float = 1e-12) -> str: + if delta > tol: + return "higher" + if delta < -tol: + return "lower" + return "same" + + +def _summary_markdown(df: pd.DataFrame) -> str: + return df.to_markdown(index=False) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run baseline + ablation + focus LASSO sweeps with summaries.", + ) + parser.add_argument("--door_cache", required=True, help="Path to DoOR cache directory.") + parser.add_argument("--behavior_csv", required=True, help="Path to behavioral matrix CSV.") + parser.add_argument( + "--condition", + required=True, + action="append", + help="Condition name(s). 
Repeat flag or pass comma-separated list.", + ) + parser.add_argument("--output_dir", required=True, help="Directory to write outputs.") + + parser.add_argument( + "--prediction_mode", + choices=["test_odorant", "trained_odorant", "interaction"], + default="test_odorant", + help="Feature extraction mode.", + ) + parser.add_argument("--cv_folds", type=int, default=5, help="Number of CV folds.") + parser.add_argument( + "--lambda_range", + default="0.0001,0.001,0.01,0.1,1.0", + help="Comma-separated lambda values (e.g., 0.0001,0.001,0.01).", + ) + parser.add_argument( + "--scale_features", + action="store_true", + default=True, + help="Standardize receptor features (default: True).", + ) + parser.add_argument( + "--no_scale_features", + action="store_false", + dest="scale_features", + help="Do not standardize receptor features.", + ) + + parser.add_argument("--subtract_control", action="store_true", help="Fit ΔPER = opto - control.") + parser.add_argument( + "--control_condition", + default=None, + help="Optional control dataset override (applies to all conditions).", + ) + parser.add_argument( + "--missing_control_policy", + choices=["skip", "zero", "error"], + default="error", + help="How to handle missing control values.", + ) + + parser.add_argument( + "--ablate", + type=str, + default=None, + help="Comma-separated list of receptors for specific ablation.", + ) + parser.add_argument( + "--ablate_file", + type=str, + default=None, + help="File with receptor names (one per line) for specific ablation.", + ) + parser.add_argument( + "--specific_ablation_mode", + choices=["single", "all_in_one"], + default="all_in_one", + help="Ablation mode for --ablate/--ablate_file (default: all_in_one).", + ) + parser.add_argument( + "--top_k_max", + type=int, + default=0, + help="Max number of top receptors to consider (0 = all).", + ) + parser.add_argument( + "--top_k_group_list", + type=str, + default="2,3", + help="Comma-separated list for cumulative top-K ablations (default: 2,3).", + ) + parser.add_argument( + "--all_but_top_list", + type=str, + default="2,3,all", + help="Comma-separated list for ablate-all-but-top-K (default: 2,3,all).", + ) + parser.add_argument( + "--no_top_single", + action="store_true", + help="Disable single ablations for each top receptor.", + ) + parser.add_argument( + "--no_ablate_all_top", + action="store_true", + help="Disable ablation of all top receptors together.", + ) + parser.add_argument( + "--no_ablate_pos_neg", + action="store_true", + help="Disable ablation of positive/negative top sets.", + ) + parser.add_argument( + "--no_all_but_top", + action="store_true", + help="Disable ablate-all-but-top runs.", + ) + parser.add_argument( + "--focus_topn_list", + type=str, + default="1,2,3", + help="Comma-separated list of top-N values for focus runs (default: 1,2,3).", + ) + parser.add_argument( + "--focus_receptors", + type=str, + default=None, + help="Comma-separated list of receptors to focus on (overrides --focus_topn_list).", + ) + parser.add_argument( + "--missing_receptor_policy", + choices=["error", "skip"], + default="error", + help="Policy for unresolved receptor names in ablation/focus.", + ) + + return parser.parse_args() + + +def main() -> int: + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + args = parse_args() + + conditions = _parse_conditions(args.condition) + if not conditions: + logger.error("No valid conditions provided.") + return 1 + + lambda_range = _parse_lambda_range(args.lambda_range) + 
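# Note: in --all_but_top_list the token "all" parses to None, which the
+ # ablate-all-but-top loop below treats as "keep every baseline top receptor". +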
focus_topn_list = _parse_topn_list(args.focus_topn_list) + top_k_group_list = _parse_topn_list(args.top_k_group_list) + all_but_top_list = _parse_topn_list_with_all(args.all_but_top_list) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + predictor = LassoBehavioralPredictor( + doorcache_path=args.door_cache, + behavior_csv_path=args.behavior_csv, + scale_features=False, + scale_targets=False, + ) + + available_receptors = list(predictor.masked_receptor_names) + strict_receptors = args.missing_receptor_policy == "error" + + all_summary_rows: List[Dict] = [] + exit_code = 0 + + for condition in conditions: + try: + resolved_condition = predictor._resolve_dataset_name(condition) + except ValueError as exc: + logger.error(str(exc)) + exit_code = 1 + continue + + if resolved_condition is None: + logger.error( + "Condition '%s' not found. Available: %s", + condition, + list(predictor.behavioral_data.index), + ) + exit_code = 1 + continue + + if ( + args.subtract_control + and args.missing_control_policy == "skip" + and args.control_condition is None + ): + control_candidate = predictor._infer_control_condition(resolved_condition) + if control_candidate is None: + logger.warning( + "No matched control mapping for '%s'; skipping ΔPER run " + "(missing_control_policy=skip).", + resolved_condition, + ) + continue + control_resolved = predictor._resolve_dataset_name(control_candidate) + if control_resolved is None: + logger.warning( + "Matched control '%s' not found for '%s'; skipping ΔPER run " + "(missing_control_policy=skip).", + control_candidate, + resolved_condition, + ) + continue + + try: + valid_odorants, _, control_resolved, n_pairs_used = _build_valid_odorants( + predictor, + condition=resolved_condition, + subtract_control=args.subtract_control, + control_condition=args.control_condition, + missing_control_policy=args.missing_control_policy, + ) + except Exception as exc: + logger.error("Condition '%s' failed: %s", resolved_condition, exc) + exit_code = 1 + continue + + try: + X, test_odorants, y = _extract_features( + predictor, + valid_odorants, + resolved_condition, + args.prediction_mode, + ) + except Exception as exc: + logger.error("Feature extraction failed for %s: %s", resolved_condition, exc) + exit_code = 1 + continue + + if X.shape[0] < 3: + logger.error("Insufficient samples for %s: %d", resolved_condition, X.shape[0]) + exit_code = 1 + continue + + condition_dir = output_dir / resolved_condition + condition_dir.mkdir(parents=True, exist_ok=True) + + baseline_scaler = StandardScaler().fit(X) if args.scale_features else None + baseline_weights, baseline_r2, baseline_mse, baseline_lambda, baseline_pred = ( + fit_lasso_with_fixed_scaler( + X=X, + y=y, + receptor_names=available_receptors, + scaler=baseline_scaler, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + ) + ) + + baseline_metrics = RunMetrics( + run_type="baseline", + label="baseline", + cv_r2=baseline_r2, + cv_mse=baseline_mse, + lambda_value=baseline_lambda, + n_receptors_selected=len(baseline_weights), + n_samples=X.shape[0], + ) + + _save_weights_csv( + baseline_weights, + condition_dir / "baseline_weights.csv", + resolved_condition, + "baseline", + ) + _save_metrics_json( + baseline_metrics, + condition_dir / "baseline_metrics.json", + extra={ + "prediction_mode": args.prediction_mode, + "subtract_control": args.subtract_control, + "control_condition": control_resolved, + "missing_control_policy": args.missing_control_policy, + "n_pairs_used": int(n_pairs_used), + 
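# n_pairs_used is 0 when --subtract_control is not set (no opto/control pairing). +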
"n_receptors_total": X.shape[1], + }, + ) + + top_receptors_all = get_top_receptors_by_weight( + baseline_weights, len(baseline_weights) + ) + if args.top_k_max and args.top_k_max > 0: + top_receptors = top_receptors_all[: args.top_k_max] + else: + top_receptors = top_receptors_all + + summary_rows: List[Dict] = [] + + def add_summary(metrics: RunMetrics, extra: Optional[Dict] = None) -> None: + delta_r2 = metrics.cv_r2 - baseline_metrics.cv_r2 + delta_mse = metrics.cv_mse - baseline_metrics.cv_mse + row = { + "condition": resolved_condition, + "run_type": metrics.run_type, + "label": metrics.label, + "cv_r2": metrics.cv_r2, + "cv_mse": metrics.cv_mse, + "lambda_value": metrics.lambda_value, + "n_receptors_selected": metrics.n_receptors_selected, + "n_samples": metrics.n_samples, + "n_total_receptors": X.shape[1], + "delta_r2": delta_r2, + "delta_mse": delta_mse, + "r2_vs_baseline": _delta_label(delta_r2), + "mse_vs_baseline": _delta_label(delta_mse), + } + if extra: + row.update(extra) + summary_rows.append(row) + + add_summary( + baseline_metrics, + extra={ + "n_ablated": 0, + "n_kept": X.shape[1], + }, + ) + + ablation_labels_seen: set[str] = set() + + def run_ablation_set( + *, + run_type: str, + label: str, + receptors_to_ablate: List[str], + ) -> None: + if not receptors_to_ablate: + return + if len(label) > 120: + label = f"{label[:80]}_{len(receptors_to_ablate)}" + if label in ablation_labels_seen: + return + ablation_labels_seen.add(label) + + X_ablated, _ = apply_receptor_ablation( + X=X, + receptor_names=available_receptors, + receptors_to_ablate=receptors_to_ablate, + ) + weights, cv_r2, cv_mse, best_lambda, _pred = fit_lasso_with_fixed_scaler( + X=X_ablated, + y=y, + receptor_names=available_receptors, + scaler=baseline_scaler, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + ) + metrics = RunMetrics( + run_type=run_type, + label=label, + cv_r2=cv_r2, + cv_mse=cv_mse, + lambda_value=best_lambda, + n_receptors_selected=len(weights), + n_samples=X_ablated.shape[0], + ) + ablation_dir = condition_dir / "ablations" / label + _save_weights_csv( + weights, + ablation_dir / "weights.csv", + resolved_condition, + label, + ) + _save_metrics_json( + metrics, + ablation_dir / "metrics.json", + extra={ + "receptors_ablated": receptors_to_ablate, + "n_ablated": len(receptors_to_ablate), + }, + ) + add_summary( + metrics, + extra={ + "n_ablated": len(receptors_to_ablate), + "n_kept": X.shape[1] - len(receptors_to_ablate), + }, + ) + + # Specific ablations (optional) + receptors_to_ablate: List[str] = [] + if args.ablate: + receptors_to_ablate = [r.strip() for r in args.ablate.split(",") if r.strip()] + elif args.ablate_file: + receptors_to_ablate = _load_receptors_from_file(args.ablate_file) + + if receptors_to_ablate: + matched, unmatched = resolve_receptor_names( + receptors_to_ablate, available_receptors, strict=strict_receptors + ) + if unmatched and not strict_receptors: + logger.warning("Skipping unmatched receptors: %s", unmatched) + + if matched: + if args.specific_ablation_mode == "single": + for receptor in matched: + run_ablation_set( + run_type="ablation_specific", + label=f"ablate_specific_{receptor}", + receptors_to_ablate=[receptor], + ) + else: + run_ablation_set( + run_type="ablation_specific", + label="ablate_specific_all", + receptors_to_ablate=matched, + ) + + # Auto ablations from baseline top receptors + if not top_receptors: + logger.warning("No baseline weights for %s; skipping auto ablations", resolved_condition) + else: + if not args.no_top_single: + for 
receptor in top_receptors: + run_ablation_set( + run_type="ablation_top_single", + label=f"ablate_top_single_{receptor}", + receptors_to_ablate=[receptor], + ) + + if top_k_group_list: + for k in top_k_group_list: + if k <= 0: + logger.warning("Invalid top-K value %d; skipping", k) + continue + if k > len(top_receptors): + logger.warning( + "Top-K value %d exceeds available top receptors (%d); skipping", + k, + len(top_receptors), + ) + continue + run_ablation_set( + run_type="ablation_top_group", + label=f"ablate_top_{k}", + receptors_to_ablate=top_receptors[:k], + ) + + if not args.no_ablate_all_top: + run_ablation_set( + run_type="ablation_top_all", + label="ablate_top_all", + receptors_to_ablate=top_receptors, + ) + + if not args.no_ablate_pos_neg: + pos_receptors = [r for r in top_receptors if baseline_weights.get(r, 0.0) > 0] + neg_receptors = [r for r in top_receptors if baseline_weights.get(r, 0.0) < 0] + run_ablation_set( + run_type="ablation_top_positive", + label="ablate_top_positive", + receptors_to_ablate=pos_receptors, + ) + run_ablation_set( + run_type="ablation_top_negative", + label="ablate_top_negative", + receptors_to_ablate=neg_receptors, + ) + + if not args.no_all_but_top: + for keep_k in all_but_top_list: + if keep_k is None: + keep_list = top_receptors + label = "ablate_all_but_top_all" + else: + if keep_k <= 0: + logger.warning("Invalid keep-K value %d; skipping", keep_k) + continue + if keep_k > len(top_receptors): + logger.warning( + "Keep-K value %d exceeds available top receptors (%d); skipping", + keep_k, + len(top_receptors), + ) + continue + keep_list = top_receptors[:keep_k] + label = f"ablate_all_but_top_{keep_k}" + + keep_set = set(keep_list) + ablate_list = [r for r in available_receptors if r not in keep_set] + run_ablation_set( + run_type="ablation_all_but_top", + label=label, + receptors_to_ablate=ablate_list, + ) + + # Focus runs + focus_receptors: Optional[List[str]] = None + if args.focus_receptors: + focus_receptors = [r.strip() for r in args.focus_receptors.split(",") if r.strip()] + + if focus_receptors: + matched, unmatched = resolve_receptor_names( + focus_receptors, available_receptors, strict=strict_receptors + ) + if unmatched and not strict_receptors: + logger.warning("Skipping unmatched receptors: %s", unmatched) + if matched: + X_restricted, kept_names, _ = restrict_to_receptors( + X=X, + receptor_names=available_receptors, + receptors_to_keep=matched, + ) + focus_scaler = StandardScaler().fit(X_restricted) if args.scale_features else None + weights, cv_r2, cv_mse, best_lambda, _pred = fit_lasso_with_fixed_scaler( + X=X_restricted, + y=y, + receptor_names=kept_names, + scaler=focus_scaler, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + ) + label = f"focus_{len(kept_names)}" + metrics = RunMetrics( + run_type="focus_specific", + label=label, + cv_r2=cv_r2, + cv_mse=cv_mse, + lambda_value=best_lambda, + n_receptors_selected=len(weights), + n_samples=X_restricted.shape[0], + ) + focus_dir = condition_dir / "focus" / label + _save_weights_csv(weights, focus_dir / "weights.csv", resolved_condition, label) + _save_metrics_json( + metrics, + focus_dir / "metrics.json", + extra={ + "receptors_kept": kept_names, + "n_kept": len(kept_names), + }, + ) + add_summary( + metrics, + extra={ + "n_ablated": X.shape[1] - len(kept_names), + "n_kept": len(kept_names), + }, + ) + else: + if not top_receptors_all: + logger.warning("No baseline weights for %s; skipping focus runs", resolved_condition) + else: + for n in focus_topn_list: + if n <= 0: + 
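# Bad top-N values are skipped with a warning so one bad entry does not abort the sweep. +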
logger.warning("Invalid top-N value %d; skipping", n) + continue + if n > len(top_receptors_all): + logger.warning( + "Top-N value %d exceeds available receptors (%d); skipping", + n, + len(top_receptors_all), + ) + continue + top_n = top_receptors_all[:n] + X_restricted, kept_names, _ = restrict_to_receptors( + X=X, + receptor_names=available_receptors, + receptors_to_keep=top_n, + ) + focus_scaler = ( + StandardScaler().fit(X_restricted) if args.scale_features else None + ) + weights, cv_r2, cv_mse, best_lambda, _pred = fit_lasso_with_fixed_scaler( + X=X_restricted, + y=y, + receptor_names=kept_names, + scaler=focus_scaler, + lambda_range=lambda_range, + cv_folds=args.cv_folds, + ) + label = f"focus_top_{n}" + metrics = RunMetrics( + run_type="focus_topn", + label=label, + cv_r2=cv_r2, + cv_mse=cv_mse, + lambda_value=best_lambda, + n_receptors_selected=len(weights), + n_samples=X_restricted.shape[0], + ) + focus_dir = condition_dir / "focus" / label + _save_weights_csv(weights, focus_dir / "weights.csv", resolved_condition, label) + _save_metrics_json( + metrics, + focus_dir / "metrics.json", + extra={ + "receptors_kept": kept_names, + "n_kept": len(kept_names), + }, + ) + add_summary( + metrics, + extra={ + "n_ablated": X.shape[1] - len(kept_names), + "n_kept": len(kept_names), + }, + ) + + summary_df = pd.DataFrame(summary_rows) + summary_csv = condition_dir / "summary.csv" + summary_df.to_csv(summary_csv, index=False) + + summary_md = condition_dir / "summary.md" + summary_md.write_text(_summary_markdown(summary_df), encoding="utf-8") + + all_summary_rows.extend(summary_rows) + + if all_summary_rows: + all_summary_df = pd.DataFrame(all_summary_rows) + all_summary_df.to_csv(output_dir / "summary_all_conditions.csv", index=False) + (output_dir / "summary_all_conditions.md").write_text( + _summary_markdown(all_summary_df), encoding="utf-8" + ) + + return exit_code + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/test_lasso_behavioral_prediction.py b/tests/test_lasso_behavioral_prediction.py index 1194c6d..b55dfd6 100644 --- a/tests/test_lasso_behavioral_prediction.py +++ b/tests/test_lasso_behavioral_prediction.py @@ -12,6 +12,7 @@ LassoBehavioralPredictor, apply_receptor_ablation, fit_lasso_with_fixed_scaler, + restrict_to_receptors, ) from sklearn.preprocessing import StandardScaler @@ -42,6 +43,18 @@ def control_behavioral_csv(tmp_path): return csv_path +@pytest.fixture +def control_behavioral_csv_extended(tmp_path): + """Create mock behavioral CSV with >=6 odorants and control NaNs.""" + csv_content = """dataset,Hexanol,Benzaldehyde,Linalool,Citral,Apple_Cider_Vinegar,Ethyl_Butyrate +opto_hex,0.8,0.2,0.4,0.6,0.3,0.5 +hex_control,0.1,,0.05,0.3,0.1,0.4 +""" + csv_path = tmp_path / "test_behavior_control_extended.csv" + csv_path.write_text(csv_content) + return csv_path + + @pytest.fixture def lasso_predictor(mock_door_cache, mock_behavioral_csv): """Create LassoBehavioralPredictor instance for testing.""" @@ -64,6 +77,17 @@ def control_lasso_predictor(mock_door_cache, control_behavioral_csv): ) +@pytest.fixture +def control_lasso_predictor_extended(mock_door_cache, control_behavioral_csv_extended): + """Create LassoBehavioralPredictor for extended control subtraction tests.""" + return LassoBehavioralPredictor( + doorcache_path=str(mock_door_cache), + behavior_csv_path=str(control_behavioral_csv_extended), + scale_features=True, + scale_targets=False, + ) + + class TestLassoBehavioralPredictor: """Tests for LassoBehavioralPredictor class.""" @@ -508,6 +532,50 @@ 
def test_subtract_control_missing_control_row_warns_or_errors( subtract_control=True, ) + def test_subtract_control_raw_matches_behavioral_row( + self, control_lasso_predictor_extended + ): + results = control_lasso_predictor_extended.fit_behavior( + condition_name="opto_hex", + lambda_range=[0.1], + cv_folds=2, + subtract_control=False, + ) + + expected = control_lasso_predictor_extended.behavioral_data.loc["opto_hex"] + actual_map = dict(zip(results.test_odorants, results.actual_per)) + for odor, value in actual_map.items(): + assert value == pytest.approx(float(expected[odor])) + + def test_subtract_control_extended_policies(self, control_lasso_predictor_extended): + results_skip = control_lasso_predictor_extended.fit_behavior( + condition_name="opto_hex", + lambda_range=[0.1], + cv_folds=2, + subtract_control=True, + missing_control_policy="skip", + ) + assert "Benzaldehyde" not in results_skip.test_odorants + + results_zero = control_lasso_predictor_extended.fit_behavior( + condition_name="opto_hex", + lambda_range=[0.1], + cv_folds=2, + subtract_control=True, + missing_control_policy="zero", + ) + actual_map_zero = dict(zip(results_zero.test_odorants, results_zero.actual_per)) + assert actual_map_zero["Benzaldehyde"] == pytest.approx(0.2) + + with pytest.raises(ValueError, match="missing values"): + control_lasso_predictor_extended.fit_behavior( + condition_name="opto_hex", + lambda_range=[0.1], + cv_folds=2, + subtract_control=True, + missing_control_policy="error", + ) + class TestLassoBehavioralPredictorRegressionChecks: """Regression checks for mutation and ΔPER prediction collapse.""" @@ -571,9 +639,9 @@ def test_ablation_does_not_mutate_baseline_matrix( assert np.array_equal(X_before, baseline_repeat.feature_matrix) def test_delta_prediction_not_constant(self, mock_door_cache, tmp_path): - csv_content = """dataset,Hexanol,Benzaldehyde,Linalool,Citral -opto_hex,0.8,0.2,0.4,0.6 -hex_control,0.1,0.0,0.05,0.3 + csv_content = """dataset,Hexanol,Benzaldehyde,Linalool,Citral,Apple_Cider_Vinegar,Ethyl_Butyrate +opto_hex,0.8,0.2,0.4,0.6,0.3,0.5 +hex_control,0.1,0.0,0.05,0.3,0.1,0.4 """ csv_path = tmp_path / "delta_variation.csv" csv_path.write_text(csv_content) @@ -587,7 +655,7 @@ def test_delta_prediction_not_constant(self, mock_door_cache, tmp_path): results = predictor.fit_behavior( condition_name="opto_hex", - lambda_range=[1e-4, 1e-3], + lambda_range=[1e-6, 1e-5, 1e-4], cv_folds=2, subtract_control=True, missing_control_policy="skip", @@ -595,3 +663,39 @@ def test_delta_prediction_not_constant(self, mock_door_cache, tmp_path): assert np.std(results.actual_per) > 1e-6 assert np.std(results.predicted_per) > 1e-6 + + +class TestLassoFeatureShapeIntegrity: + """Tests for ablation/focus feature shape changes.""" + + def test_ablation_and_restriction_shapes(self, control_lasso_predictor_extended): + results = control_lasso_predictor_extended.fit_behavior( + condition_name="opto_hex", + lambda_range=[0.1], + cv_folds=2, + subtract_control=False, + ) + + if results.feature_matrix is None: + pytest.skip("Feature matrix not available for shape checks.") + + X = results.feature_matrix + receptor_names = results.receptor_names + if len(receptor_names) < 2: + pytest.skip("Not enough receptors for shape integrity test.") + + X_ablated, indices = apply_receptor_ablation( + X=X, + receptor_names=receptor_names, + receptors_to_ablate=[receptor_names[0]], + ) + assert X_ablated.shape == X.shape + assert np.allclose(X_ablated[:, indices[0]], 0.0) + + X_restricted, kept_names, kept_indices = 
restrict_to_receptors( + X=X, + receptor_names=receptor_names, + receptors_to_keep=receptor_names[:2], + ) + assert X_restricted.shape[1] == 2 + assert kept_names == receptor_names[:2] diff --git a/tests/test_lasso_focus_mode.py b/tests/test_lasso_focus_mode.py index 505360b..644cb0e 100644 --- a/tests/test_lasso_focus_mode.py +++ b/tests/test_lasso_focus_mode.py @@ -430,6 +430,52 @@ def test_run_focus_with_synthetic_data(self): assert result.delta_r2 == result.cv_r2 - 0.9 assert result.delta_mse == result.cv_mse - 0.01 + def test_run_focus_reproducible_and_no_mutation(self): + """Test focus-mode reproducibility and no mutation of X.""" + from scripts.lasso_with_focus_mode import run_focus + + np.random.seed(123) + n_samples, n_features = 12, 8 + X = np.random.rand(n_samples, n_features) + X_before = X.copy() + + true_weights = np.zeros(n_features) + true_weights[0] = 1.0 + true_weights[1] = 0.4 + y = X @ true_weights + np.random.randn(n_samples) * 0.05 + + receptor_names = [f"Or{i}" for i in range(n_features)] + lambda_range = np.array([0.001, 0.01]) + + result1 = run_focus( + X=X, + y=y, + receptor_names=receptor_names, + receptors_to_keep=["Or0", "Or1", "Or2"], + lambda_range=lambda_range, + cv_folds=3, + scale_features=True, + baseline_r2=0.0, + baseline_mse=0.0, + ) + result2 = run_focus( + X=X, + y=y, + receptor_names=receptor_names, + receptors_to_keep=["Or0", "Or1", "Or2"], + lambda_range=lambda_range, + cv_folds=3, + scale_features=True, + baseline_r2=0.0, + baseline_mse=0.0, + ) + + assert np.array_equal(X, X_before) + assert result1.cv_mse == pytest.approx(result2.cv_mse) + assert result1.lambda_value == pytest.approx(result2.lambda_value) + assert result1.lasso_weights == result2.lasso_weights + assert not np.isnan(result1.cv_mse) + # ============================================================================ # Tests for edge cases diff --git a/tests/test_stability_metrics.py b/tests/test_stability_metrics.py new file mode 100644 index 0000000..732bd0d --- /dev/null +++ b/tests/test_stability_metrics.py @@ -0,0 +1,119 @@ +"""Tests for stability + standardized metrics diagnostics.""" + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path + +import pandas as pd +import pytest + + +def _load_stability_module(): + module_path = Path(__file__).resolve().parents[1] / "diagnostics" / "run_stability_and_metrics.py" + spec = importlib.util.spec_from_file_location("run_stability_and_metrics", module_path) + if spec is None or spec.loader is None: + raise RuntimeError("Failed to load diagnostics/run_stability_and_metrics.py") + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def _write_behavior_csv(tmp_path: Path) -> Path: + csv_content = """dataset,Hexanol,Benzaldehyde,Linalool,Citral +opto_hex,0.2,0.1,0.3,0.4 +""" + csv_path = tmp_path / "behavior.csv" + csv_path.write_text(csv_content) + return csv_path + + +def _run_stability(module, output_dir: Path, mock_door_cache: Path, behavior_csv: Path, seed: int) -> tuple[Path, list[str]]: + """Build (output_dir, argv) for run_stability_and_metrics.py; the caller invokes module.main() itself.""" + argv = [ + "run_stability_and_metrics.py", + "--door_cache", + str(mock_door_cache), + "--behavior_csv", + str(behavior_csv), + "--conditions", + "opto_hex", + "--prediction_mode", + "test_odorant", + "--seed", + str(seed), + "--output_dir", + str(output_dir), + "--lambda_range", + "1e-4,1e-3,1e-2", + "--lambda_range_delta", + "1e-4,1e-3,1e-2", + "--include_ridge_stability", + ] + return output_dir, argv + + +def 
test_stability_determinism_and_schema(tmp_path, mock_door_cache, monkeypatch): + module = _load_stability_module() + behavior_csv = _write_behavior_csv(tmp_path) + + run1_dir = tmp_path / "run1" + run2_dir = tmp_path / "run2" + + run1_dir.mkdir() + run2_dir.mkdir() + + _, argv1 = _run_stability(module, run1_dir, mock_door_cache, behavior_csv, seed=123) + monkeypatch.setattr("sys.argv", argv1) + module.main() + + _, argv2 = _run_stability(module, run2_dir, mock_door_cache, behavior_csv, seed=123) + monkeypatch.setattr("sys.argv", argv2) + module.main() + + df1 = pd.read_csv(run1_dir / "stability_per_condition.csv") + df2 = pd.read_csv(run2_dir / "stability_per_condition.csv") + + pd.testing.assert_frame_equal(df1, df2, check_exact=False, atol=1e-12) + + required_columns = { + "condition", + "mode", + "modelclass", + "orn_name", + "selection_frequency", + "sign_consistency", + "mean_abs_weight", + "std_abs_weight", + "mean_rank", + "n_folds", + } + assert required_columns.issubset(df1.columns) + + metrics_df = pd.read_csv(run1_dir / "model_metrics.csv") + metrics_required = { + "condition", + "mode", + "modelclass", + "cv_mse", + "nmse", + "rmse_over_y_std", + "intercept_only_flag", + "intercept_only_mse", + "y_var", + "y_min", + "y_max", + "pred_min", + "pred_max", + } + assert metrics_required.issubset(metrics_df.columns) + + +def test_intercept_only_flag_logic(): + module = _load_stability_module() + assert module._is_intercept_only(0, 0.0, 1.0, 1.0) + assert not module._is_intercept_only(1, 0.0, 1.0, 1.0) + assert not module._is_intercept_only(0, 1e-3, 1.0, 1.0) + assert not module._is_intercept_only(0, 0.0, 1.0, 1.1)
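
For reference, a minimal invocation of the new sweep script might look like the following; it uses only flags defined in `scripts/run_lasso_full_sweep.py` above, and the cache/CSV/output paths are illustrative placeholders rather than real locations.

```bash
conda run -n DoOR python scripts/run_lasso_full_sweep.py \
  --door_cache door_cache \
  --behavior_csv /path/to/reaction_rates_summary.csv \
  --condition opto_hex,opto_EB \
  --output_dir diagnostics/lasso_full_sweep \
  --prediction_mode test_odorant \
  --cv_folds 5 \
  --lambda_range 0.0001,0.001,0.01,0.1,1.0 \
  --subtract_control \
  --missing_control_policy skip \
  --focus_topn_list 1,2,3
```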