diff --git a/scripts/audit/importance_connectome_synthetic.py b/scripts/audit/importance_connectome_synthetic.py
new file mode 100644
index 0000000..3eea848
--- /dev/null
+++ b/scripts/audit/importance_connectome_synthetic.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+"""
+Audit: synthetic connectome amplification invariants.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from door_toolkit.pathways.connectome_analysis import compute_connectome_influence
+
+
+def main() -> None:
+    """Check two invariants of ``compute_connectome_influence`` on toy inputs.
+
+    Case 1 passes all-ones connectivity matrices and expects a mean
+    amplification factor of 1 with amplified importance equal to base
+    importance. Case 2 permutes the receptor axis and expects each receptor's
+    amplified importance, looked up by name, to be unchanged.
+
+    Raises:
+        AssertionError: If either invariant is violated.
+    """
+    receptor_names = ["R1", "R2", "R3"]
+    s_orn = np.array([1.0, 2.0, 3.0], dtype=np.float64)
+
+    # Case 1: Uniform fanout yields amplification factor mean ~= 1 and amplified == base.
+    A = torch.ones((2, 3), dtype=torch.float32)  # PN x ORN
+    B = torch.ones((2, 2), dtype=torch.float32)  # KC x PN
+    result = compute_connectome_influence(
+        s_orn,
+        A,
+        B,
+        receptor_names=receptor_names,
+        pn_ids=["PN1", "PN2"],
+        kc_ids=["KC1", "KC2"],
+        top_pn=2,
+        top_kc=2,
+    )
+    amp_mean = float(result.orn_table["amplification_factor_kc_mean1"].mean())
+    assert abs(amp_mean - 1.0) < 1e-9, f"Expected mean amplification 1.0, got {amp_mean}"
+    assert np.allclose(
+        result.orn_table["connectome_amplified_importance"].to_numpy(),
+        result.orn_table["base_importance"].to_numpy(),
+        atol=1e-9,
+    ), "Uniform fanout should leave amplified importance equal to base importance"
+
+    # Case 2: Permute receptor order and verify named outputs are stable.
+    perm = np.array([2, 0, 1], dtype=int)
+    A_perm = A[:, perm]
+    s_perm = s_orn[perm]
+    names_perm = [receptor_names[i] for i in perm]
+    result_perm = compute_connectome_influence(
+        s_perm,
+        A_perm,
+        B,
+        receptor_names=names_perm,
+        pn_ids=["PN1", "PN2"],
+        kc_ids=["KC1", "KC2"],
+        top_pn=2,
+        top_kc=2,
+    )
+    base_map = dict(
+        zip(result.orn_table["receptor"], result.orn_table["connectome_amplified_importance"])
+    )
+    perm_map = dict(
+        zip(result_perm.orn_table["receptor"], result_perm.orn_table["connectome_amplified_importance"])
+    )
+    for receptor in receptor_names:
+        assert abs(base_map[receptor] - perm_map[receptor]) < 1e-9, (
+            f"Amplified importance for {receptor} changed under receptor permutation"
+        )
+
+    print("OK: Connectome importance invariants passed.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/audit/importance_glm_synthetic.py b/scripts/audit/importance_glm_synthetic.py
new file mode 100644
index 0000000..0ad4aa3
--- /dev/null
+++ b/scripts/audit/importance_glm_synthetic.py
@@ -0,0 +1,206 @@
+#!/usr/bin/env python3
+"""
+Audit: synthetic GLM importance invariants (weight + ablation).
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, Tuple
+
+import numpy as np
+import torch
+
+from door_toolkit.pathways.behavior_rate_model import (
+    SparseRateGLM,
+    TrainingTable,
+    ablation_importance,
+    weight_importance_dataframe,
+)
+
+
+@dataclass(frozen=True)
+class SyntheticSpec:
+    """Receptor labels, table size and RNG seed for the synthetic table."""
+
+    receptor_names: List[str]  # one feature channel per receptor name
+    n_datasets: int  # dataset 0 is named "opto_A", every other one "control_B"
+    n_odors: int  # number of test odors generated per dataset
+    seed: int = 0  # seed for the NumPy generator that draws the features
+
+
+def build_synthetic_table(spec: SyntheticSpec) -> TrainingTable:
+    """Build a training table whose labels are driven by receptor 0 only.
+
+    Test features are drawn from a seeded standard normal and train features
+    are all zero. Labels are the sigmoid of the logits of a ``SparseRateGLM``
+    whose only non-zero parameter is ``w_test[0] = 3.0``.
+
+    Args:
+        spec: Receptor names, table dimensions and RNG seed.
+
+    Returns:
+        A ``TrainingTable`` with ``spec.n_datasets * spec.n_odors`` rows.
+    """
+    rng = np.random.default_rng(spec.seed)
+    n_cells = spec.n_datasets * spec.n_odors
+    n_channels = len(spec.receptor_names)
+
+    # Deterministic features (x_train unused but required by table schema).
+    x_test = rng.normal(0.0, 1.0, size=(n_cells, n_channels)).astype(np.float32)
+    x_train = np.zeros_like(x_test)
+
+    datasets = []
+    test_odors = []
+    reward_flags = []
+    for d in range(spec.n_datasets):
+        ds_name = "opto_A" if d == 0 else "control_B"
+        for o in range(spec.n_odors):
+            datasets.append(ds_name)
+            test_odors.append(f"odor_{o+1}")
+            reward_flags.append(1.0 if ds_name.startswith("opto_") else 0.0)
+
+    # Build labels from a known linear rule (receptor 0 only).
+    model = SparseRateGLM(
+        n_channels=n_channels,
+        include_test=True,
+        include_train=False,
+        include_interaction=False,
+        include_diff=False,
+        include_reward=False,
+    )
+    with torch.no_grad():
+        model.w_test[:] = 0.0
+        model.w_test[0] = 3.0
+        model.intercept.fill_(0.0)
+
+    x_test_t = torch.tensor(x_test, dtype=torch.float32)
+    x_train_t = torch.tensor(x_train, dtype=torch.float32)
+    reward_t = torch.tensor(reward_flags, dtype=torch.float32)
+    with torch.no_grad():
+        logits = model(x_test_t, x_train_t, reward_t)
+    y = torch.sigmoid(logits).detach().cpu().numpy().astype(np.float32)
+
+    return TrainingTable(
+        dataset=datasets,
+        test_odor=test_odors,
+        y=torch.tensor(y, dtype=torch.float32),
+        x_test=x_test_t,
+        x_train=x_train_t,
+        reward_flag=reward_t,
+        receptor_names=list(spec.receptor_names),
+    )
+
+
+def compute_importance(
+    table: TrainingTable, weight_vector: np.ndarray
+) -> Tuple[np.ndarray, np.ndarray]:
+    """Score receptors by weight and by ablation for a fixed-weight GLM.
+
+    Args:
+        table: Evaluation data passed to ``ablation_importance``.
+        weight_vector: Values copied into ``w_test``, ordered like
+            ``table.receptor_names``.
+
+    Returns:
+        ``(weight_scores, ablation_scores)`` as float64 arrays ordered like
+        ``table.receptor_names``.
+    """
+    model = SparseRateGLM(
+        n_channels=len(table.receptor_names),
+        include_test=True,
+        include_train=False,
+        include_interaction=False,
+        include_diff=False,
+        include_reward=False,
+    )
+    with torch.no_grad():
+        model.w_test[:] = torch.tensor(weight_vector, dtype=torch.float32)
+        model.intercept.fill_(0.0)
+
+    weight_df = weight_importance_dataframe(model, table.receptor_names)
+    ab_global, _, _ = ablation_importance(model, table)
+    weight_scores = weight_df.set_index("receptor")["importance_weight"].to_dict()
+    ab_scores = ab_global.set_index("receptor")["delta_bce"].to_dict()
+    return (
+        np.array([weight_scores[r] for r in table.receptor_names], dtype=np.float64),
+        np.array([ab_scores[r] for r in table.receptor_names], dtype=np.float64),
+    )
+
+
+def assert_invariants(
+    receptor_names: List[str], weight_scores: np.ndarray, ab_scores: np.ndarray
+) -> None:
+    """Assert that receptor 0 dominates and the others are inert under ablation.
+
+    Args:
+        receptor_names: Receptor labels; index 0 is the known driver.
+        weight_scores: Weight-based importance, ordered like ``receptor_names``.
+        ab_scores: Ablation ``delta_bce`` scores, ordered like ``receptor_names``.
+
+    Raises:
+        AssertionError: If any invariant is violated.
+    """
+    # Known driver receptor should dominate.
+    assert receptor_names[0] == receptor_names[weight_scores.argmax()]
+    assert receptor_names[0] == receptor_names[ab_scores.argmax()]
+
+    # Non-driver receptors should have near-zero ablation impact. Compare the
+    # magnitude: a large negative delta is just as much a violation as a large
+    # positive one, and a signed max would let it through.
+    if len(ab_scores) > 1:
+        max_other = np.max(np.abs(ab_scores[1:]))
+        assert max_other < 1e-6, f"Expected near-zero ablation for non-drivers, got {max_other}"
+
+
+def main() -> None:
+    """Run the driver-dominance and permutation-invariance checks.
+
+    Raises:
+        AssertionError: If any invariant is violated.
+    """
+    spec = SyntheticSpec(
+        receptor_names=["R1", "R2", "R3", "R4"],
+        n_datasets=2,
+        n_odors=3,
+        seed=7,
+    )
+
+    table = build_synthetic_table(spec)
+    base_weights = np.zeros(len(spec.receptor_names), dtype=np.float32)
+    base_weights[0] = 3.0
+
+    weight_scores, ab_scores = compute_importance(table, base_weights)
+    assert_invariants(spec.receptor_names, weight_scores, ab_scores)
+
+    # Permutation invariance (named results should be stable).
+    perm = np.array([2, 0, 3, 1], dtype=int)
+    perm_names = [spec.receptor_names[i] for i in perm]
+    perm_table = TrainingTable(
+        dataset=table.dataset,
+        test_odor=table.test_odor,
+        y=table.y,
+        x_test=table.x_test[:, perm],
+        x_train=table.x_train[:, perm],
+        reward_flag=table.reward_flag,
+        receptor_names=perm_names,
+    )
+    perm_weights = base_weights[perm]
+    weight_scores_perm, ab_scores_perm = compute_importance(perm_table, perm_weights)
+
+    # Compare by receptor name.
+    base_by_name = dict(zip(spec.receptor_names, weight_scores))
+    perm_by_name = dict(zip(perm_names, weight_scores_perm))
+    for receptor in spec.receptor_names:
+        assert abs(base_by_name[receptor] - perm_by_name[receptor]) < 1e-9
+
+    base_ab_by_name = dict(zip(spec.receptor_names, ab_scores))
+    perm_ab_by_name = dict(zip(perm_names, ab_scores_perm))
+    for receptor in spec.receptor_names:
+        assert abs(base_ab_by_name[receptor] - perm_ab_by_name[receptor]) < 1e-9
+
+    print("OK: GLM importance invariants passed.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/audit/importance_lasso_synthetic.py b/scripts/audit/importance_lasso_synthetic.py
new file mode 100644
index 0000000..3dbdb96
--- /dev/null
+++ b/scripts/audit/importance_lasso_synthetic.py
@@ -0,0 +1,213 @@
+#!/usr/bin/env python3
+"""
+Audit: synthetic LASSO importance invariants.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import numpy as np
+import pandas as pd
+from sklearn.linear_model import Lasso, LassoCV
+from sklearn.model_selection import cross_val_score
+from sklearn.preprocessing import StandardScaler
+
+from door_toolkit.encoder import DoOREncoder
+from door_toolkit.pathways.behavioral_prediction import LassoBehavioralPredictor
+
+
+def _clean_token(text: str) -> str:
+    """Lower-case *text* and drop underscores, spaces and hyphens."""
+    return str(text).lower().replace("_", "").replace(" ", "").replace("-", "")
+
+
+def resolve_door_name(csv_name: str, encoder: DoOREncoder) -> str | None:
+    """Map a behavior-CSV odor label to a name the DoOR encoder accepts.
+
+    Args:
+        csv_name: Odor label as written in the behavior CSV.
+        encoder: Encoder used to probe each candidate name.
+
+    Returns:
+        The first candidate that ``encoder.encode`` accepts without raising
+        ``KeyError``, or ``None`` when the cleaned label has no entry in
+        ``LassoBehavioralPredictor.ODORANT_NAME_MAPPING``.
+
+    Raises:
+        ValueError: If the label is mapped but every candidate raises ``KeyError``.
+    """
+    mapping = LassoBehavioralPredictor.ODORANT_NAME_MAPPING
+    key = _clean_token(csv_name)
+    candidates = mapping.get(key)
+    if candidates is None:
+        return None
+    for candidate in candidates:
+        try:
+            encoder.encode(candidate)
+            return candidate
+        except KeyError:
+            continue
+    raise ValueError(f"No DoOR match for {csv_name}")
+
+
+def fit_lasso_weights(
+    X: np.ndarray,
+    y: np.ndarray,
+    receptor_names: List[str],
+    lambda_range: np.ndarray,
+    cv_folds: int = 5,
+) -> Tuple[Dict[str, float], float]:
+    """Fit a cross-validated LASSO on standardized features.
+
+    Args:
+        X: Feature matrix, one row per sample and one column per receptor.
+        y: Target value per sample.
+        receptor_names: Column labels for ``X``.
+        lambda_range: Candidate regularization strengths for ``LassoCV``.
+        cv_folds: Fold count used when there are at least 10 samples; with
+            fewer samples one fold per sample is used instead.
+
+    Returns:
+        ``(weights, cv_mse)``: coefficients with magnitude above ``1e-6`` keyed
+        by receptor name, and the mean cross-validated MSE at the chosen alpha.
+    """
+    X_scaled = StandardScaler().fit_transform(X)
+    if X_scaled.shape[0] < 10:
+        cv_folds_adjusted = X_scaled.shape[0]
+    else:
+        cv_folds_adjusted = min(cv_folds, X_scaled.shape[0])
+
+    lasso_cv = LassoCV(
+        alphas=lambda_range, cv=cv_folds_adjusted, max_iter=10000, random_state=42
+    )
+    lasso_cv.fit(X_scaled, y)
+
+    lasso = Lasso(alpha=lasso_cv.alpha_, max_iter=10000, random_state=42)
+    lasso.fit(X_scaled, y)
+
+    weights = {
+        receptor_names[i]: float(coef)
+        for i, coef in enumerate(lasso.coef_)
+        if abs(coef) > 1e-6
+    }
+    cv_scores = cross_val_score(
+        lasso, X_scaled, y, cv=cv_folds_adjusted, scoring="neg_mean_squared_error"
+    )
+    cv_mse = float(-np.mean(cv_scores))
+    return weights, cv_mse
+
+
+def main() -> None:
+    """Run the three LASSO invariants on a DoOR-derived synthetic target.
+
+    The target is one receptor's min-max scaled response across eight odors,
+    so that receptor is expected to carry the largest weight.
+
+    Raises:
+        ValueError: If an odor cannot be resolved, the target has no dynamic
+            range, or the fit yields no non-zero weights.
+        AssertionError: If any invariant is violated.
+    """
+    out_dir = Path("outputs/audit/importance_lasso_synthetic")
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    odors = [
+        "Hexanol",
+        "Ethyl_Butyrate",
+        "Ethyl_Butyrate_(6-Training)",
+        "Benzaldehyde",
+        "Apple_Cider_Vinegar",
+        "3-Octonol",
+        "Citral",
+        "Linalool",
+    ]
+
+    encoder = DoOREncoder("door_cache", use_torch=False)
+    door_names = [resolve_door_name(o, encoder) for o in odors]
+    if any(name is None for name in door_names):
+        missing = [o for o, name in zip(odors, door_names) if name is None]
+        raise ValueError(f"Unresolved odors in synthetic set: {missing}")
+    profiles = np.stack(
+        [
+            np.asarray(encoder.encode(name, fill_missing=0.0), dtype=np.float64)
+            for name in door_names
+            if name is not None
+        ],
+        axis=0,
+    )
+
+    # Choose a receptor with high variance across odors.
+    variances = profiles.var(axis=0)
+    target_idx = int(np.argmax(variances))
+    target_receptor = encoder.receptor_names[target_idx]
+    responses = profiles[:, target_idx]
+    resp_min = float(np.min(responses))
+    resp_max = float(np.max(responses))
+    if resp_max <= resp_min:
+        raise ValueError("No dynamic range in synthetic responses.")
+    y = (responses - resp_min) / (resp_max - resp_min)
+
+    behavior_df = pd.DataFrame([y], index=["opto_hex"], columns=odors)
+    behavior_csv = out_dir / "synthetic_behavior.csv"
+    behavior_df.to_csv(behavior_csv)
+
+    predictor = LassoBehavioralPredictor(
+        doorcache_path="door_cache",
+        behavior_csv_path=str(behavior_csv),
+        scale_features=True,
+        scale_targets=False,
+    )
+    lambda_range = np.logspace(-4, 0, 30)
+    results = predictor.fit_behavior(
+        condition_name="opto_hex",
+        prediction_mode="test_odorant",
+        lambda_range=lambda_range.tolist(),
+        cv_folds=5,
+    )
+
+    # Invariant 1: target receptor should be top by absolute weight.
+    if not results.lasso_weights:
+        raise ValueError("No non-zero LASSO weights found.")
+    top_receptor = max(results.lasso_weights.items(), key=lambda x: abs(x[1]))[0]
+    assert top_receptor == target_receptor, (
+        f"Expected top receptor {target_receptor}, got {top_receptor}"
+    )
+
+    # Invariant 2: permuting receptor order preserves top receptors by name.
+    X = results.feature_matrix
+    names = results.receptor_names
+    weights_orig, cv_mse_orig = fit_lasso_weights(X, results.actual_per, names, lambda_range)
+
+    # Shuffle the leading columns; clip the pattern to the actual width so a
+    # feature matrix with fewer than eight receptors cannot index out of range.
+    n_features = X.shape[1]
+    head = [i for i in (2, 0, 4, 1, 3) if i < n_features]
+    perm = np.array(head + list(range(len(head), n_features)), dtype=int)
+    X_perm = X[:, perm]
+    names_perm = [names[i] for i in perm]
+    weights_perm, _ = fit_lasso_weights(X_perm, results.actual_per, names_perm, lambda_range)
+
+    def top_receptors(weights: Dict[str, float], k: int = 3) -> List[str]:
+        ordered = sorted(weights.items(), key=lambda x: abs(x[1]), reverse=True)
+        return [name for name, _ in ordered[:k]]
+
+    top_orig = top_receptors(weights_orig, k=3)
+    top_perm = top_receptors(weights_perm, k=3)
+    assert target_receptor in top_orig
+    assert target_receptor in top_perm
+
+    # Invariant 3: random labels worsen CV MSE.
+    rng = np.random.default_rng(0)
+    shuffled = rng.permutation(results.actual_per)
+    _, cv_mse_shuf = fit_lasso_weights(X, shuffled, names, lambda_range)
+    assert cv_mse_shuf >= cv_mse_orig * 1.1, (
+        f"Label shuffle did not degrade MSE (orig={cv_mse_orig:.4f}, shuf={cv_mse_shuf:.4f})"
+    )
+
+    print("OK: LASSO importance invariants passed.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/audit/importance_shapley_proxy_synthetic.py b/scripts/audit/importance_shapley_proxy_synthetic.py
new file mode 100644
index 0000000..b20829e
--- /dev/null
+++ b/scripts/audit/importance_shapley_proxy_synthetic.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+"""
+Audit: proxy Shapley importance invariants (variance-based).
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from door_toolkit.pathways.analyzer import PathwayAnalyzer
+
+
+def main() -> None:
+    """Check sum-to-one, non-negativity and odorant-order invariance.
+
+    Raises:
+        AssertionError: If any invariant is violated.
+    """
+    # Tiny synthetic behavioral dataset (written for audit parity; not used by the proxy method).
+    out_dir = Path("outputs/audit/importance_shapley_proxy")
+    out_dir.mkdir(parents=True, exist_ok=True)
+    behavior_df = pd.DataFrame(
+        [[0.1, 0.5, 0.9]], index=["opto_hex"], columns=["Hexanol", "Benzaldehyde", "Linalool"]
+    )
+    behavior_csv = out_dir / "synthetic_behavior.csv"
+    behavior_df.to_csv(behavior_csv)
+
+    analyzer = PathwayAnalyzer("door_cache")
+    odorants = analyzer.encoder.odorant_names[:5]
+    importance = analyzer.compute_shapley_importance("feeding", odorants=list(odorants))
+
+    # An empty result skips the normalization checks rather than failing them.
+    if importance:
+        total = float(sum(importance.values()))
+        assert abs(total - 1.0) < 1e-6, f"Importance should sum to 1, got {total}"
+        assert all(v >= 0.0 for v in importance.values()), "Importance must be non-negative"
+
+    # Permutation invariance for odorant list order.
+    reversed_importance = analyzer.compute_shapley_importance(
+        "feeding", odorants=list(reversed(odorants))
+    )
+    keys = set(importance.keys()) | set(reversed_importance.keys())
+    for key in keys:
+        v1 = importance.get(key, 0.0)
+        v2 = reversed_importance.get(key, 0.0)
+        assert abs(v1 - v2) < 1e-9, f"Importance for {key} depends on odorant order"
+
+    print("OK: Shapley-proxy importance invariants passed.")
+
+
+if __name__ == "__main__":
+    main()