diff --git a/.gitignore b/.gitignore index d3f14a9ca..bfed5cdc6 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ # related to files .pybiomart.sqlite +.venv/ logs/ params* resources* diff --git a/common b/common index 876036f71..f01ff2170 160000 --- a/common +++ b/common @@ -1 +1 @@ -Subproject commit 876036f71713cbd79285b108ab0a9a8238f2b5e1 +Subproject commit f01ff2170161295e89014ee5453c61b29b4e4e77 diff --git a/dockers/dictys_v4/Dockerfile b/dockers/dictys_v4/Dockerfile new file mode 100644 index 000000000..bf3cfa0e4 --- /dev/null +++ b/dockers/dictys_v4/Dockerfile @@ -0,0 +1,122 @@ +FROM ubuntu:22.04 + +ARG DEBIAN_FRONTEND=noninteractive +ENV TZ="America/New_York" + +# Base OS deps (build tools + common libs) + libpng for matplotlib +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl ca-certificates git unzip zip \ + build-essential pkg-config \ + zlib1g-dev libbz2-dev liblzma-dev \ + libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libreadline-dev libffi-dev libsqlite3-dev \ + libfreetype6-dev libpng-dev \ + python3-venv python3-distutils python3-dev \ + # bio tools via apt instead of building + samtools tabix \ + perl \ + && rm -rf /var/lib/apt/lists/* + +# Install CPython 3.9.17 from source +ARG PYTHON_VERSION=3.9.17 +RUN set -eux; \ + cd /tmp; \ + curl -fsSLO https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz; \ + tar -xzf Python-${PYTHON_VERSION}.tgz; \ + cd Python-${PYTHON_VERSION}; \ + ./configure --enable-optimizations; \ + make -j"$(nproc)"; \ + make install; \ + cd /; rm -rf /tmp/Python-${PYTHON_VERSION}*; \ + ln -s /usr/local/bin/python3 /usr/local/bin/python; \ + ln -s /usr/local/bin/pip3 /usr/local/bin/pip + +# Make constraints global for all pip installs +COPY constraints.txt /tmp/constraints.txt +ENV PIP_CONSTRAINT=/tmp/constraints.txt \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + PIP_DEFAULT_TIMEOUT=180 + +# Clean any existing numpy/matplotlib remnants aggressively +RUN python - <<'PY' 
+import sys, site, pkgutil, shutil, pathlib +paths = set(site.getsitepackages() + [site.getusersitepackages()]) +for p in list(paths): + if not p: + continue + for name in ("numpy", "matplotlib"): + for m in pathlib.Path(p).glob(name): + shutil.rmtree(m, ignore_errors=True) + for m in pathlib.Path(p).glob(f"{name}-*.dist-info"): + shutil.rmtree(m, ignore_errors=True) + for m in pathlib.Path(p).glob(f"{name}-*.egg-info"): + shutil.rmtree(m, ignore_errors=True) +print("Cleaned numpy/matplotlib from:", *paths, sep="\n - ") +PY + +# Install bedtools +RUN apt-get update && apt-get install -y --no-install-recommends bedtools \ + && rm -rf /var/lib/apt/lists/* + +# Install tools + exact pins +RUN python -m pip install --no-cache-dir -U pip setuptools wheel \ + && pip install --no-cache-dir --upgrade --force-reinstall \ + "numpy==1.26.4" "matplotlib==3.8.4" "cython<3" + +# Install MACS2. Build-from-source packages must reuse our pinned toolchain +RUN pip install --no-cache-dir --no-build-isolation MACS2==2.2.9.1 + +# Install Dictys without dependencies (we'll install them manually right after) +RUN pip install --no-cache-dir --no-build-isolation --no-deps \ + git+https://github.com/pinellolab/dictys.git@a82930fe8030af2785f9069ef5e909e49acc866f + +# Install Dictys dependencies and more +RUN pip install --no-cache-dir --prefer-binary \ + pandas scipy networkx h5py threadpoolctl joblib \ + jupyter jupyterlab adjustText pyro-ppl docutils requests + +# Install pyDNase and anndata without dependencies so it can't pin matplotlib<2 +RUN pip install --no-cache-dir --no-build-isolation --no-deps pyDNase clint pysam packaging array_api_compat legacy-api-wrap zarr natsort anndata + +# Install pybedtools version that works with cython<3 +RUN pip install --no-cache-dir --no-build-isolation "pybedtools==0.9.1" + +# Install pytorch +# RUN pip install --no-cache-dir --prefer-binary --index-url https://download.pytorch.org/whl/cpu torch + +# HOMER prerequisites +RUN apt-get update && 
apt-get install -y --no-install-recommends \ + wget perl unzip build-essential zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install HOMER core + hg38 genome +RUN set -eux; \ + mkdir -p /opt/homer && cd /opt/homer; \ + curl -fsSLO http://homer.ucsd.edu/homer/configureHomer.pl; \ + chmod +x configureHomer.pl; \ + perl configureHomer.pl -install homer; \ + perl configureHomer.pl -install homerTools; \ + perl configureHomer.pl -install hg38 +ENV PATH="/opt/homer/bin:${PATH}" + +# hg38 annotations +RUN set -eux; \ + cd /opt/homer; \ + grep "hg38" update.txt > tmp.txt && mv tmp.txt update.txt; \ + cd update && ./updateUCSCGenomeAnnotations.pl ../update.txt + +# Install CUDA +# RUN curl -fsSLo /etc/apt/preferences.d/cuda-repository-pin-600 \ +# https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin && \ +# curl -fsSLo /usr/share/keyrings/nvidia-cuda.gpg \ +# https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub && \ +# echo "deb [signed-by=/usr/share/keyrings/nvidia-cuda.gpg] http://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /" \ +# > /etc/apt/sources.list.d/cuda.list && \ +# apt-get update && apt-get install -y --no-install-recommends cuda && \ +# rm -rf /var/lib/apt/lists/* + +# Install AWS CLI +RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" && \ + unzip awscliv2.zip && \ + ./aws/install + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/dockers/dictys_v4/constraints.txt b/dockers/dictys_v4/constraints.txt new file mode 100644 index 000000000..c5b8e303b --- /dev/null +++ b/dockers/dictys_v4/constraints.txt @@ -0,0 +1,3 @@ +numpy==1.26.4 +matplotlib<3.9 +cython<3 \ No newline at end of file diff --git a/src/methods/dictys/helper.py b/src/methods/dictys/helper.py index 90cde9b03..46c28901e 100644 --- a/src/methods/dictys/helper.py +++ b/src/methods/dictys/helper.py @@ -1,6 +1,8 @@ import os 
def run_cmd(cmd: List[str], cwd: Optional[str] = None) -> None:
    """Run a command and echo its output line by line while it is running.

    The child's stderr is merged into its stdout so that both streams are
    printed in the order in which they were produced.

    Args:
        cmd: Program and arguments, passed to the OS without a shell.
        cwd: Directory to run the command in. ``None`` keeps the current
            working directory.

    Raises:
        RuntimeError: If the command exits with a non-zero status.
    """
    with subprocess.Popen(
        cmd,
        cwd=cwd,  # None is Popen's own default: inherit the current directory
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,  # line-buffered, so lines arrive as soon as they are written
    ) as proc:
        for line in proc.stdout:
            # Flush like the other progress prints in this module; otherwise
            # the output of long-running tools sits in the buffer when stdout
            # is redirected to a log file.
            print(line, end="", flush=True)
        rc = proc.wait()
    if rc != 0:
        raise RuntimeError(f"Command {cmd} failed with exit code {rc}")
def get_priors(par):
    """Fetch the reference files needed downstream into ``par['data_dir']``.

    Obtains, in order: the reference genome (``genome/genome.fa``), the
    Ensembl gene annotation (``gene.gtf``, converted to ``par['gene_bed']``)
    and the HOCOMOCO motif file (``motifs.motif``). Files that already exist
    are reused unless the module-level ``OVERRIDE_DOWNLOAD`` flag is set.

    Args:
        par: Parameter dictionary; reads ``data_dir`` and ``gene_bed``.

    Raises:
        ValueError: If the genome can be obtained neither from S3 nor from
            the local ``resources`` folder.
    """
    import gzip
    import shutil
    from pathlib import Path

    # - get the genome
    print('Getting genome ...', flush=True)
    genome_dir = f"{par['data_dir']}/genome/"
    if OVERRIDE_DOWNLOAD or (not os.path.exists(f"{genome_dir}genome.fa")):
        os.makedirs(genome_dir, exist_ok=True)
        try:
            run_cmd([
                "aws", "s3", "cp", "s3://openproblems-data/resources/grn/supp_data/genome/genome.fa",
                genome_dir, "--no-sign-request"
            ])
        # RuntimeError: non-zero exit status; OSError: executable not found.
        # Anything else (e.g. KeyboardInterrupt) must propagate.
        except (RuntimeError, OSError) as s3_error:
            print(f"S3 download failed ({s3_error}); trying local copy", flush=True)
            try:
                run_cmd([
                    "cp", "resources/supp_data/genome/genome.fa", genome_dir
                ])
            except (RuntimeError, OSError) as cp_error:
                raise ValueError("Could not get the reference genome") from cp_error

    # - get gene annotation
    print('Getting gene annotation ...', flush=True)
    data_dir = Path(par['data_dir'])
    gtf_gz = data_dir / "gene.gtf.gz"
    gtf = data_dir / "gene.gtf"
    url = "http://ftp.ensembl.org/pub/release-107/gtf/homo_sapiens/Homo_sapiens.GRCh38.107.gtf.gz"
    # The archive is deleted after extraction, so the cache check has to look
    # at the extracted file; checking the archive would re-download every run.
    if OVERRIDE_DOWNLOAD or (not gtf.exists()):
        if OVERRIDE_DOWNLOAD or (not gtf_gz.exists()):
            download_file(url, gtf_gz)
        # Extract to a temporary name and rename at the end, so an interrupted
        # extraction never leaves a truncated gene.gtf that looks cached.
        gtf_tmp = data_dir / "gene.gtf.tmp"
        with gzip.open(gtf_gz, "rb") as f_in, open(gtf_tmp, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.replace(gtf_tmp, gtf)
        gtf_gz.unlink()

    print('Making bed files for gene annotation ...', flush=True)
    run_cmd([
        "bash", "dictys_helper", "gene_gtf.sh", str(gtf), par['gene_bed']
    ])

    print('Downloading motif file...', flush=True)
    url = 'https://hocomoco11.autosome.org/final_bundle/hocomoco11/full/HUMAN/mono/HOCOMOCOv11_full_HUMAN_mono_homer_format_0.0001.motif'
    motif_file = data_dir / 'motifs.motif'
    if OVERRIDE_DOWNLOAD or (not motif_file.exists()):
        download_file(url, motif_file)
{par['num_workers']} -J 1 static" - run_cmd(cmd) + run_cmd([ + "bash", "dictys_helper", "network_inference.sh", "-j", str(par['num_workers']), "-J", "1", "static" + ], cwd=par['temp_dir']) + + def export_net(par): from util import process_links from dictys.net import network @@ -224,8 +274,8 @@ def export_net(par): output.write(par['prediction']) def main(par): - define_vars(par) + define_vars(par) extract_exp(par) extract_atac(par) create_bam(par) diff --git a/src/methods/portia/script.py b/src/methods/portia/script.py index d2f0cfab5..2ab94a09b 100644 --- a/src/methods/portia/script.py +++ b/src/methods/portia/script.py @@ -44,19 +44,17 @@ def main(par): tf_names = [gene_name for gene_name in gene_names if (gene_name in tfs)] tf_idx = np.asarray([i for i, gene_name in enumerate(gene_names) if gene_name in tf_names], dtype=int) - print('Inferring grn') dataset = pt.GeneExpressionDataset() for exp_id, data in enumerate(X): dataset.add(pt.Experiment(exp_id, data)) - + M_bar = pt.run(dataset, tf_idx=tf_idx, method='no-transform') ranked_scores = pt.rank_scores(M_bar, gene_names, limit=par['max_n_links']) sources, targets, weights = zip(*[(gene_a, gene_b, score) for gene_a, gene_b, score in ranked_scores]) grn = pd.DataFrame({'source':sources, 'target':targets, 'weight':weights}) - print(grn.shape) grn = grn[grn.source.isin(tf_names)] grn = process_links(grn, par) diff --git a/src/metrics/experimental/anchor_regression/helper.py b/src/metrics/experimental/anchor_regression/helper.py index 9c29eb042..5bdfa9a65 100644 --- a/src/metrics/experimental/anchor_regression/helper.py +++ b/src/metrics/experimental/anchor_regression/helper.py @@ -10,6 +10,7 @@ from scipy.sparse.linalg import LinearOperator from scipy.stats import pearsonr, wilcoxon from scipy.sparse import csr_matrix +from sklearn.model_selection import KFold from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder import anndata as ad import warnings @@ -23,6 +24,7 @@ from util import 
def compute_stabilities(
        X: np.ndarray,
        y: np.ndarray,
        Z: np.ndarray,
        is_selected: np.ndarray,
        eps: float = 1e-50
) -> float:
    """Contrast coefficient agreement between selected and unselected features.

    Anchor regression is fitted twice (``gamma=1`` and ``gamma=1.2``). The
    absolute coefficients of each fit are normalised to sum to one and
    multiplied element-wise. The result is the normalised difference between
    the mean product over selected features and the mean over the others.

    Args:
        X: Feature matrix passed to ``anchor_regression``.
        y: Target passed to ``anchor_regression``.
        Z: Anchor variables passed to ``anchor_regression``.
        is_selected: Boolean mask over the features.
        eps: Added to the final denominator to avoid division by zero.

    Returns:
        ``(s1 - s2) / (s1 + s2 + eps)`` with ``s1``/``s2`` the selected /
        unselected means.
    """
    def _normalised_magnitudes(gamma):
        magnitudes = np.abs(anchor_regression(X, Z, y, gamma=gamma))
        return magnitudes / np.sum(magnitudes)

    reference = _normalised_magnitudes(1)
    anchored = _normalised_magnitudes(1.2)

    agreement = anchored * reference
    inside = np.mean(agreement[is_selected])
    outside = np.mean(agreement[~is_selected])

    return (inside - outside) / (inside + outside + eps)
evaluate_gene_stability( @@ -148,39 +148,27 @@ def evaluate_gene_stability( eps: Small epsilon for numerical stability. Returns: - Stability score for gene j + Stability score for gene j. """ is_selected = np.array(A[:, j] != 0) if (not np.any(is_selected)) or np.all(is_selected): - return 0.0 + return np.nan assert not is_selected[j] # Exclude target gene from features mask = np.ones(X.shape[1], dtype=bool) mask[j] = False - stabilities_selected = np.mean(compute_stabilities(X[:, mask], X[:, j], Z, A, is_selected[mask], eps=eps)) - #stabilities_non_selected = np.mean(compute_stabilities(X[:, mask], X[:, j], Z, A, ~is_selected[mask], eps=eps)) - - #score = (stabilities_selected - stabilities_non_selected) / (stabilities_selected + stabilities_non_selected + eps) - score = np.mean(stabilities_selected) - - return score + return compute_stabilities(X[:, mask], X[:, j], Z, is_selected[mask]) def main(par): """Main anchor regression evaluation function.""" + # Load evaluation data adata = ad.read_h5ad(par['evaluation_data']) dataset_id = adata.uns['dataset_id'] method_id = ad.read_h5ad(par['prediction'], backed='r').uns['method_id'] - - # Get dataset-specific anchor variables - if dataset_id not in DATASET_GROUPS: - raise ValueError(f"Dataset {dataset_id} not found in DATASET_GROUPS") - - anchor_cols = DATASET_GROUPS[dataset_id].get('anchors', ['donor_id', 'plate_name']) - print(f"Using anchor variables: {anchor_cols}") # Manage layer layer = manage_layer(adata, par) @@ -192,17 +180,33 @@ def main(par): gene_names = adata.var_names gene_dict = {gene_name: i for i, gene_name in enumerate(gene_names)} + # Get dataset-specific anchor variables + if dataset_id not in DATASET_GROUPS: + raise ValueError(f"Dataset {dataset_id} not found in DATASET_GROUPS") + anchor_cols = DATASET_GROUPS[dataset_id].get('anchors', ['donor_id', 'plate_name']) + print(f"Using anchor variables: {anchor_cols}") + # Encode anchor variables anchor_variables = encode_obs_cols(adata, anchor_cols) 
anchor_encoded = combine_multi_index(*anchor_variables) + + # Get CV groups + if "cell_type" in adata.obs: + cv_groups = LabelEncoder().fit_transform(adata.obs["cell_type"].values) + else: + np.random.randint(0, 5) + cv_groups = np.random.shuffle() if len(anchor_variables) == 0: raise ValueError(f"No anchor variables found in dataset for columns: {anchor_cols}") # One-hot encode anchor variables - Z = OneHotEncoder(sparse_output=False, dtype=np.float32).fit_transform(anchor_encoded.reshape(-1, 1)) + Z = OneHotEncoder(drop="first", sparse_output=False, dtype=np.float32).fit_transform(anchor_encoded.reshape(-1, 1)) print(f"Anchor matrix Z shape: {Z.shape}") + # Add intercept + Z = np.concatenate((Z, np.ones((len(Z), 1), dtype=np.float32)), axis=1) + # Load inferred GRN df = read_prediction(par) sources = df["source"].to_numpy() @@ -232,24 +236,33 @@ def main(par): np.fill_diagonal(A, 0) print(f"Evaluating {X.shape[1]} genes with {np.sum(A != 0)} regulatory links") + # Whether or not to take into account the regulatory modes (enhancer/inhibitor) + signed = np.any(A < 0) + # Center and scale dataset scaler = StandardScaler() X = scaler.fit_transform(X) - # Create baseline GRN - A_baseline = np.copy(A) - for j in range(A_baseline.shape[1]): - np.random.shuffle(A_baseline[:j, j]) - np.random.shuffle(A_baseline[j+1:, j]) - assert np.any(A != A_baseline) + # Create baseline model + A_baseline = create_grn_baseline(A) + # Compute gene stabilities scores, scores_baseline = [], [] for j in tqdm.tqdm(range(X.shape[1]), desc="Evaluating gene stability"): scores.append(evaluate_gene_stability(X, Z, A, j)) + scores_baseline.append(evaluate_gene_stability(X, Z, A_baseline, j)) scores = np.array(scores) + scores_baseline = np.array(scores_baseline) + + # Skip NaNs + mask = ~np.logical_or(np.isnan(scores), np.isnan(scores_baseline)) + scores = scores[mask] + scores_baseline = scores_baseline[mask] # Calculate final score - final_score = np.mean(scores) + p_value = wilcoxon(scores, 
scores_baseline).pvalue + p_value = np.clip(p_value, 1e-300, 1) + final_score = -np.log10(p_value) print(f"Method: {method_id}") print(f"Anchor Regression Score: {final_score:.6f}") diff --git a/src/metrics/experimental/anchor_regression/run_local.sh b/src/metrics/experimental/anchor_regression/run_local.sh index fc872cc59..78e66b9ac 100644 --- a/src/metrics/experimental/anchor_regression/run_local.sh +++ b/src/metrics/experimental/anchor_regression/run_local.sh @@ -1,5 +1,5 @@ #!/bin/bash -#SBATCH --job-name=anchor_regression +#SBATCH --job-name=regression_3 #SBATCH --output=logs/%j.out #SBATCH --error=logs/%j.err #SBATCH --ntasks=1 @@ -12,7 +12,7 @@ set -euo pipefail -save_dir="output/anchor_regression" +save_dir="output/regression_3" mkdir -p "$save_dir" # datasets to process @@ -21,7 +21,7 @@ datasets=( 'op' "300BCG" 'parsebioscience') #"300BCG" "ibd" 'parsebioscience' methods=("pearson_corr" "positive_control" "negative_control" "ppcor" "portia" "scenic" "grnboost" "scprint" "scenicplus" "celloracle" "scglue" "figr" "granie") # temporary file to collect CSV rows -combined_csv="${save_dir}/anchor_regression_scores.csv" +combined_csv="${save_dir}/regression_3_scores.csv" echo "dataset,method,metric,value" > "$combined_csv" for dataset in "${datasets[@]}"; do @@ -31,7 +31,7 @@ for dataset in "${datasets[@]}"; do for method in "${methods[@]}"; do prediction="resources/results/${dataset}/${dataset}.${method}.${method}.prediction.h5ad" - score="${save_dir}/anchor_regression_${dataset}_${method}.h5ad" + score="${save_dir}/regression_3_${dataset}_${method}.h5ad" if [[ ! -f "$prediction" ]]; then echo "File not found: $prediction, skipping..." 
@@ -39,7 +39,7 @@ for dataset in "${datasets[@]}"; do fi echo -e "\nProcessing method: $method\n" - python src/metrics/anchor_regression/script.py \ + python src/metrics/regression_3/script.py \ --prediction "$prediction" \ --evaluation_data "$evaluation_data" \ --score "$score" diff --git a/src/metrics/experimental/anchor_regression/script.py b/src/metrics/experimental/anchor_regression/script.py index a4687b358..b140d64c5 100644 --- a/src/metrics/experimental/anchor_regression/script.py +++ b/src/metrics/experimental/anchor_regression/script.py @@ -30,14 +30,14 @@ sys.path.append(meta["resources_dir"]) except: meta = { - "resources_dir":'src/metrics/anchor_regression/', + "resources_dir":'src/metrics/regression_3/', "util_dir": 'src/utils', - 'helper_dir': 'src/metrics/anchor_regression/' + 'helper_dir': 'src/metrics/regression_3/' } sys.path.append(meta["resources_dir"]) sys.path.append(meta["util_dir"]) sys.path.append(meta["helper_dir"]) -from helper import main as main_anchor +from helper import main as main_reg3 from util import format_save_score @@ -48,7 +48,7 @@ par[key] = value if __name__ == "__main__": - output = main_anchor(par) + output = main_reg3(par) dataset_id = ad.read_h5ad(par['evaluation_data'], backed='r').uns['dataset_id'] method_id = ad.read_h5ad(par['prediction'], backed='r').uns['method_id'] diff --git a/src/metrics/experimental/regression_3/helper.py b/src/metrics/experimental/regression_3/helper.py index 25c0acea2..97e47d5bd 100644 --- a/src/metrics/experimental/regression_3/helper.py +++ b/src/metrics/experimental/regression_3/helper.py @@ -26,6 +26,7 @@ from util import read_prediction, manage_layer from dataset_config import DATASET_GROUPS +from baseline import create_grn_baseline def encode_obs_cols(adata, cols): @@ -53,7 +54,8 @@ def compute_residual_correlations( Z_test: np.ndarray ) -> np.ndarray: model = xgboost.XGBRegressor(n_estimators=10) - #model = Ridge(alpha=10) + #model = xgboost.XGBRegressor(n_estimators=10) + model = 
Ridge(alpha=1) model.fit(X_train, y_train) y_hat = model.predict(X_test) residuals = y_test - y_hat @@ -115,7 +117,7 @@ def main(par): gene_mask = np.logical_or(np.any(A, axis=1), np.any(A, axis=0)) in_degrees = np.sum(A != 0, axis=0) out_degrees = np.sum(A != 0, axis=1) - idx = np.argsort(np.maximum(out_degrees, in_degrees))[:-2000] + idx = np.argsort(np.maximum(out_degrees, in_degrees))[:-1000] gene_mask[idx] = False X = X[:, gene_mask] X = X.toarray() if isinstance(X, csr_matrix) else X @@ -127,11 +129,7 @@ def main(par): print(f"Evaluating {X.shape[1]} genes with {np.sum(A != 0)} regulatory links") # Create baseline model - A_baseline = np.copy(A) - for j in range(A.shape[1]): - np.random.shuffle(A_baseline[:j, j]) - np.random.shuffle(A_baseline[j+1:, j]) - assert np.any(A_baseline != A) + A_baseline = create_grn_baseline(A) scores, baseline_scores = [], [] for group in np.unique(anchor_encoded): diff --git a/src/metrics/experimental/regression_3/script.py b/src/metrics/experimental/regression_3/script.py index 09ce46e56..b9022698b 100644 --- a/src/metrics/experimental/regression_3/script.py +++ b/src/metrics/experimental/regression_3/script.py @@ -25,7 +25,7 @@ } ## VIASH END try: - sys.path.append(meta["resources_dir"]) + sys.path.append(meta["resources_dir"]) except: meta = { "resources_dir":'src/metrics/regression/', @@ -41,9 +41,9 @@ if __name__ == '__main__': - print(par) - output = main(par) - print(output) - method_id = ad.read_h5ad(par['prediction'], backed='r').uns['method_id'] - dataset_id = ad.read_h5ad(par['evaluation_data'], backed='r').uns['dataset_id'] - format_save_score(output, method_id, dataset_id, par['score']) + print(par) + output = main(par) + print(output) + method_id = ad.read_h5ad(par['prediction'], backed='r').uns['method_id'] + dataset_id = ad.read_h5ad(par['evaluation_data'], backed='r').uns['dataset_id'] + format_save_score(output, method_id, dataset_id, par['score']) diff --git a/src/metrics/experimental/vc/helper.py 
def set_seed(seed: int = 0xCAFE) -> None:
    """Seed every random number generator used here, for reproducibility.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), then switches PyTorch to deterministic algorithms.

    Args:
        seed: Seed shared by all generators. Defaults to the value that was
            previously hard-coded.
    """
    # Only inherited by processes started from now on: the hash seed of the
    # running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # With deterministic algorithms enabled, cuBLAS operations raise on CUDA
    # unless a workspace configuration is set. An existing value is kept.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True)
len(control_indices) > 0: - control_map[0] = control_indices[0] return control_map, match_groups -def compute_pds_cell_level(X_true, X_pred, perturbations, gene_names, max_cells_per_pert=10): - """ - Compute PDS at individual cell level (more challenging than mean-based PDS). - - For each individual predicted cell, find its rank when compared to - true mean profiles of all perturbations. - """ - unique_perts = np.unique(perturbations) - n_perts = len(unique_perts) - - print(f"Computing cell-level PDS for {n_perts} unique perturbations") - - if n_perts < 2: - return 0.0 - - # Compute mean true expression profiles per perturbation - true_means = {} - for pert in unique_perts: - mask = (perturbations == pert) - if np.sum(mask) == 0: - continue - true_means[pert] = np.mean(X_true[mask, :], axis=0) - - all_scores = [] - - # For each perturbation, sample some cells and compute their PDS - for pert in unique_perts: - if pert not in true_means: - continue - - # Get cells for this perturbation - mask = (perturbations == pert) - pert_indices = np.where(mask)[0] - - if len(pert_indices) == 0: - continue - - # Sample max_cells_per_pert cells to avoid bias from perturbations with many cells - # Use seeded random generator for reproducibility - cell_rng = np.random.RandomState(seed + int(pert)) # Different seed per perturbation - sampled_indices = cell_rng.choice( - pert_indices, - size=min(max_cells_per_pert, len(pert_indices)), - replace=False - ) - - for cell_idx in sampled_indices: - # Get predicted profile for this individual cell - pred_vec = X_pred[cell_idx, :].copy() - - # Calculate distances to all true mean profiles - dists = [] - for t in unique_perts: - if t not in true_means: - continue - true_vec = true_means[t].copy() - - # Remove target gene if it exists - if str(pert) in gene_names: - gene_idx = np.where(gene_names == str(pert))[0] - if len(gene_idx) > 0: - true_vec = np.delete(true_vec, gene_idx) - pred_vec_temp = np.delete(pred_vec, gene_idx) - else: - 
pred_vec_temp = pred_vec - else: - pred_vec_temp = pred_vec - - dist = cityblock(pred_vec_temp, true_vec) - dists.append((t, dist)) - - # Sort by distance and find rank of correct perturbation - dists_sorted = sorted(dists, key=lambda x: x[1]) - true_rank = next((i for i, (t, _) in enumerate(dists_sorted) if t == pert), n_perts-1) - - # Cell-level PDS - pds = 1 - (true_rank / (n_perts - 1)) if n_perts > 1 else 1.0 - all_scores.append(pds) - - mean_pds = np.mean(all_scores) if all_scores else 0.0 - print(f"Cell-level PDS scores: min={min(all_scores):.3f}, max={max(all_scores):.3f}, mean={mean_pds:.3f} (n_cells={len(all_scores)})") - return mean_pds - - -def compute_pds(X_true, X_pred, perturbations, gene_names): - """ - Compute both mean-level and cell-level PDS for comparison. - """ - # Mean-level PDS (original approach) - unique_perts = np.unique(perturbations) - n_perts = len(unique_perts) - - print(f"Computing mean-level PDS for {n_perts} unique perturbations") - - if n_perts < 2: - return 0.0, 0.0 - - # Compute mean expression profiles per perturbation - true_means = {} - pred_means = {} - - for pert in unique_perts: - mask = (perturbations == pert) - if np.sum(mask) == 0: - continue - true_means[pert] = np.mean(X_true[mask, :], axis=0) - pred_means[pert] = np.mean(X_pred[mask, :], axis=0) - - scores = {} - for pert in unique_perts: - if pert not in pred_means or pert not in true_means: - continue - - pred_vec = pred_means[pert].copy() - dists = [] - - for t in unique_perts: - if t not in true_means: - continue - true_vec = true_means[t].copy() - - if str(pert) in gene_names: - gene_idx = np.where(gene_names == str(pert))[0] - if len(gene_idx) > 0: - true_vec = np.delete(true_vec, gene_idx) - pred_vec_temp = np.delete(pred_vec, gene_idx) - else: - pred_vec_temp = pred_vec - else: - pred_vec_temp = pred_vec - - dist = cityblock(pred_vec_temp, true_vec) - dists.append((t, dist)) - - dists_sorted = sorted(dists, key=lambda x: x[1]) - true_rank = next((i for i, (t, 
class GRNLayer(torch.nn.Module):
    """Linear layer whose weights are constrained by a GRN adjacency matrix.

    With ``M = I - alpha * A.t()``, the layer maps ``x`` to ``x @ M``
    (``inverse=False``) or to ``x @ inv(M)`` (``inverse=True``), optionally
    followed by a learnable bias. ``A`` is rebuilt from ``A_weights`` on every
    call: either ``|A_weights| * A_signs`` (signed) or ``A_weights * A_mask``.
    """

    def __init__(
            self,
            A_weights: torch.nn.Parameter,
            A_signs: torch.Tensor,
            signed: bool = True,
            inverse: bool = True,
            alpha: float = 0.2,
            stable: bool = True,
            bias: bool = True
    ):
        """Build the layer.

        Args:
            A_weights: Square matrix of edge weights (n_genes x n_genes).
            A_signs: Sign of each edge; zero entries mark absent edges.
            signed: Force each weight to keep the sign given in ``A_signs``.
            inverse: Apply ``inv(I - alpha * A.t())`` instead of
                ``I - alpha * A.t()``.
            alpha: Scaling applied to ``A``.
            stable: For the inverse, use a truncated Neumann series rather
                than an exact linear solve.
            bias: Add a learnable per-gene bias, initialised to zero.
        """
        torch.nn.Module.__init__(self)
        self.n_genes: int = A_weights.size(1)
        self.A_weights: torch.nn.Parameter = A_weights
        dtype = A_weights.dtype
        device = A_weights.device
        if bias:
            self.b: torch.nn.Parameter = torch.nn.Parameter(torch.zeros((1, self.n_genes), dtype=dtype, device=device))
        else:
            self.b = None
        self.register_buffer('A_signs', A_signs.to(A_weights.device))
        # Any non-zero sign is an edge (negative signs used to be masked out)
        self.register_buffer('A_mask', (A_signs != 0).to(self.A_weights.dtype).to(A_weights.device))
        self.register_buffer('I', torch.eye(self.n_genes, dtype=dtype, device=device))
        self.signed: bool = signed
        self.inverse: bool = inverse
        self.alpha: float = alpha
        self.stable: bool = stable

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the GRN transformation to a batch of expression profiles.

        Args:
            x: Tensor of shape (batch, n_genes).

        Returns:
            Tensor of shape (batch, n_genes).
        """
        if self.signed:
            A = torch.abs(self.A_weights) * self.A_signs
            # NOTE(review): fails for a signed GRN without any negative edge,
            # and is evaluated on every forward pass; confirm this is intended.
            assert torch.any(A < 0)
        else:
            A = self.A_weights * self.A_mask

        if self.inverse:
            if self.stable:
                # Approximation using Neumann series
                B = GRNLayer.neumann_series(A.t(), self.alpha)
                y = torch.mm(x, B)
            else:
                # Exact inverse: find y such that y @ (I - alpha * A.t()) = x.
                ia = self.I - self.alpha * A.t()

                try:
                    # Transposing the system gives ia.t() @ y.t() = x.t().
                    # Solving with `ia` itself would yield x @ inv(ia).t(),
                    # which is neither the inverse of the forward map nor what
                    # the Neumann branches approximate. Solving also avoids
                    # forming the explicit inverse.
                    y = torch.linalg.solve(ia.t(), x.t()).t()
                except torch.linalg.LinAlgError:
                    # Fallback: approximation using Neumann series
                    B = GRNLayer.neumann_series(A.t(), self.alpha)
                    y = torch.mm(x, B)
        else:
            # Forward transformation: apply GRN directly
            y = torch.mm(x, self.I - self.alpha * A.t())

        # Add bias term
        if self.b is not None:
            y = y + self.b

        return y

    @staticmethod
    def neumann_series(A: torch.Tensor, alpha: float, k: int = 2) -> torch.Tensor:
        """Approximate the inverse of ``I - alpha * A`` using a Neumann series.

        Computes ``I + M + M^2 + ... + M^k`` with ``M = alpha * A``. The series
        only converges to the inverse when the spectral radius of ``M`` is
        below one.

        Args:
            A: Square matrix.
            alpha: Scaling applied to ``A`` before the expansion.
            k: Highest power kept in the series. The higher, the more accurate.

        Returns:
            Approximated inverse of ``I - alpha * A``.
        """
        identity = torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
        M = alpha * A
        term = identity
        B = identity.clone()
        for _ in range(k):
            term = term @ M
            B = B + term
        return B
torch.nn.PReLU(self.n_hidden), torch.nn.Linear(self.n_hidden, self.n_hidden), - torch.nn.PReLU(1), + torch.nn.PReLU(self.n_hidden), + torch.nn.Linear(self.n_hidden, self.n_hidden), + torch.nn.PReLU(self.n_hidden), ) self.decoder = torch.nn.Sequential( + torch.nn.LayerNorm(self.n_hidden), torch.nn.Linear(self.n_hidden, self.n_hidden), - torch.nn.PReLU(1), + torch.nn.PReLU(self.n_hidden), + torch.nn.Linear(self.n_hidden, self.n_hidden), + torch.nn.PReLU(self.n_hidden), torch.nn.Linear(self.n_hidden, self.n_genes), torch.nn.PReLU(1), + torch.nn.Dropout(p=0.4, inplace=True), ) # Last layer: GRN-informed transformation to final expression - self.grn_output_layer = GRNLayer(A_weights, A_signs, inverse=True, signed=signed, alpha=0.1) + self.grn_output_layer = GRNLayer(A_weights, A_signs, inverse=True, signed=signed) def forward(self, x: torch.Tensor, pert: torch.LongTensor) -> torch.Tensor: - # Apply GRN transformation to control expression + + # Encode each expression profile x = self.grn_input_layer(x) y = self.encoder(x) - z = self.perturbation_embedding(pert) - y = y + z + + # Each perturbation is a linear transformation in the latent space. + # Apply perturbation transform to the encoded profile. 
+ z = self.perturbation_embedding(pert).view(len(x), self.n_hidden, self.n_hidden) + y = torch.einsum('ij,ijk->ik', y, z) + #z = self.perturbation_embedding(pert) + #y = y + z + + # Decode each expression profile y = self.decoder(y) y = self.grn_output_layer(y) return y + def set_grn(self, A: np.ndarray) -> None: + signed = np.any(A < 0) + A_signs = torch.from_numpy(np.sign(A).astype(NUMPY_DTYPE)) + A_weights = np.copy(A).astype(NUMPY_DTYPE) + A_weights /= (np.sqrt(self.n_genes) * float(np.std(A_weights))) + A_weights = torch.nn.Parameter(torch.from_numpy(A_weights)) + # Ensure A_signs is on the same device as A_weights + A_signs = A_signs.to(A_weights.device) + self.grn_input_layer = GRNLayer(A_weights, A_signs, inverse=False, signed=self.signed) + self.grn_output_layer = GRNLayer(A_weights, A_signs, inverse=True, signed=self.signed) + class PerturbationDataset(Dataset): @@ -342,97 +264,97 @@ def __len__(self) -> int: def __getitem__(self, i: int) -> Tuple[torch.Tensor, int, torch.Tensor]: i = self.idx[i] y = torch.from_numpy(self.X[i, :]) - group = int(self.match_groups[i]) + # Find matched control + group = int(self.match_groups[i]) if group in self.control_map: j = int(self.control_map[group]) - elif int(self.loose_match_groups[i]) in self.loose_control_map: + else: group = int(self.loose_match_groups[i]) j = int(self.loose_control_map[group]) - else: - # Fallback: use any available control sample - # This handles cases where no matching control exists (e.g., single control scenarios) - available_controls = list(self.control_map.values()) + list(self.loose_control_map.values()) - if available_controls: - j = available_controls[0] # Use first available control - else: - raise ValueError("No control samples available for matching!") x = torch.from_numpy(self.X[j, :]) p = int(self.perturbations[i]) - return x, p, y + d_x = y - x + return x, p, d_x +def coefficients_of_determination(y_target: np.ndarray, y_pred: np.ndarray, eps: float = 1e-20) -> np.ndarray: + 
residuals = np.square(y_target - y_pred) + ss_res = np.sum(residuals, axis=0) + eps + mean = np.mean(y_target, axis=0)[np.newaxis, :] + residuals = np.square(y_target - mean) + ss_tot = np.sum(residuals, axis=0) + eps + return 1 - ss_res / ss_tot -def evaluate(A, train_data_loader, test_data_loader, n_perturbations: int) -> Tuple[float, float]: - # Training - signed = np.any(A < 0) - model = Model(A, n_perturbations, n_hidden=16, signed=signed) +def evaluate(A, train_data_loader, test_data_loader, state_dict, n_perturbations: int, signed: bool = True) -> np.ndarray: + set_seed() + A = np.copy(A) + model = Model(A, n_perturbations, signed=signed) model = model.to(DEVICE) - optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5) + model.load_state_dict(state_dict, strict=False) + model.set_grn(A) + optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-6) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode='min', patience=5, - min_lr=1e-5, cooldown=3, factor=0.8 + min_lr=1e-6, cooldown=3, factor=0.8 ) - pbar = tqdm.tqdm(range(100)) # Reduced epochs for faster testing - best_val_loss = float('inf') - best_ss_res = None - best_epoch = 0 + pbar = tqdm.tqdm(range(1000)) + best_avg_r2, best_r2 = -np.inf, None model.train() for epoch in pbar: total_loss = 0 - for x, pert, y in train_data_loader: - x, pert, y = x.to(DEVICE), pert.to(DEVICE), y.to(DEVICE) + y_target, y_pred = [], [] + for x, pert, d_x in train_data_loader: + x, pert, d_x = x.to(DEVICE), pert.to(DEVICE), d_x.to(DEVICE) + + # Reset gradients optimizer.zero_grad() + # Model now predicts full perturbed expression directly - y_hat = model(x, pert) - loss = torch.mean(torch.square(y - y_hat)) + d_x_hat = model(x, pert) + y_target.append(d_x.cpu().data.numpy()) + y_pred.append(d_x_hat.cpu().data.numpy()) + + # Compute mean squared error + loss = torch.mean(torch.square(d_x - d_x_hat)) + total_loss += loss.item() * len(x) + + # Compute gradients (clip 
them to prevent divergence) loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) + + # Update parameters optimizer.step() - total_loss += loss.item() * len(x) - pbar.set_description(str(total_loss)) + scheduler.step(total_loss) + r2_train = coefficients_of_determination(np.concatenate(y_target, axis=0), np.concatenate(y_pred, axis=0)) model.eval() - ss_res = 0 + y_target, y_pred = [], [] with torch.no_grad(): - for x, pert, y in test_data_loader: - x, pert, y = x.to(DEVICE), pert.to(DEVICE), y.to(DEVICE) - # Model predicts full perturbed expression - y_hat = model(x, pert) - residuals = torch.square(y - y_hat).cpu().data.numpy() - ss_res += np.sum(residuals, axis=0) - if np.sum(ss_res) < best_val_loss: - best_val_loss = np.sum(ss_res) - best_epoch = epoch - best_ss_res = ss_res + for x, pert, d_x in test_data_loader: + x, pert, d_x = x.to(DEVICE), pert.to(DEVICE), d_x.to(DEVICE) + d_x_hat = model(x, pert) + y_target.append(d_x.cpu().data.numpy()) + y_pred.append(d_x_hat.cpu().data.numpy()) + r2_test = coefficients_of_determination(np.concatenate(y_target, axis=0), np.concatenate(y_pred, axis=0)) + avg_r2 = np.mean(r2_test) + if avg_r2 > best_avg_r2: + best_avg_r2 = avg_r2 + best_r2 = r2_test + pbar.set_description(str(np.mean(r2_train)) + " " + str(np.mean(r2_test))) model.train() - ss_res = best_ss_res - - # Final evaluation with PDS model.eval() - ss_tot = 0 - - with torch.no_grad(): - for x, pert, y in test_data_loader: - x, pert, y = x.to(DEVICE), pert.to(DEVICE), y.to(DEVICE) - y_hat = model(x, pert) - - residuals = torch.square(y - torch.mean(y, dim=0).unsqueeze(0)).cpu().data.numpy() - ss_tot += np.sum(residuals, axis=0) - - print(f"Best epoch: {best_epoch} ({best_val_loss})") - return best_ss_res, ss_tot - + return best_r2 def main(par): + # Load evaluation data adata = ad.read_h5ad(par['evaluation_data']) - assert 'is_control' in adata.obs.columns, "'is_control' column is required in the dataset for perturbation evaluation" - 
assert adata.obs['is_control'].sum() > 0, "'is_control' column must contain at least one True value for control samples" dataset_id = adata.uns['dataset_id'] method_id = ad.read_h5ad(par['prediction'], backed='r').uns['method_id'] @@ -458,12 +380,20 @@ def main(par): cv_groups = encode_obs_cols(adata, par['cv_groups']) match_groups = encode_obs_cols(adata, par['match']) loose_match_groups = encode_obs_cols(adata, par['loose_match']) + + # Get cell types + N_FOLDS = 5 + try: + cell_types = np.squeeze(encode_obs_cols(adata, ["cell_type"])) + except Exception: + print(traceback.format_exc()) + cell_types = np.random.randint(0, 5, size=len(X)) # Set perturbations to first column (perturbation) perturbations = cv_groups[0] # perturbation codes - # Groups used for cross-validation - cv_groups = combine_multi_index(*cv_groups) + # Validation strategy: evaluate on unseen (perturbation, cell type) pairs. + cv_groups = combine_multi_index(cell_types, perturbations) # Groups used for matching with negative controls match_groups = combine_multi_index(*match_groups) @@ -506,10 +436,12 @@ def main(par): A = A[gene_mask, :][:, gene_mask] gene_names = gene_names[gene_mask] + # Filter genes based on GRN instead of HVGs + # Keep all genes that are present in the GRN (already filtered above) print(f"Using {len(gene_names)} genes present in the GRN") # Additional memory-aware gene filtering for very large GRNs - MAX_GENES_FOR_MEMORY = 3000 # Reduced further to avoid memory issues + MAX_GENES_FOR_MEMORY = 500 # Reduced further to avoid memory issues if len(gene_names) > MAX_GENES_FOR_MEMORY: print(f"Too many genes ({len(gene_names)}) for memory. 
Selecting top {MAX_GENES_FOR_MEMORY} by GRN connectivity.") @@ -523,29 +455,28 @@ def main(par): print(f"Final: Using {len(gene_names)} most connected genes for evaluation") - # Remove self-regulations - np.fill_diagonal(A, 0) + # Add self-regulations + np.fill_diagonal(A, 1) - # Create baseline model - A_baseline = np.copy(A) - for j in range(A.shape[1]): - np.random.shuffle(A[:j, j]) - np.random.shuffle(A[j+1:, j]) - assert np.any(A_baseline != A) + # Check whether the inferred GRN contains signed predictions + signed = np.any(A < 0) # Mapping between gene expression profiles and their matched negative controls - control_map, _ = create_control_matching(are_controls, match_groups) loose_control_map, _ = create_control_matching(are_controls, loose_match_groups) - - ss_res = 0 - ss_tot = 0 - cv = GroupKFold(n_splits=5) + r2 = [] + r2_baseline = [] + cv = StratifiedGroupKFold(n_splits=N_FOLDS, shuffle=True, random_state=0xCAFE) - results = [] - - for i, (train_index, test_index) in enumerate(cv.split(X, X, cv_groups)): + for i, (train_index, test_index) in enumerate(cv.split(X, perturbations, cv_groups)): + + if (len(train_index) == 0) or (len(test_index) == 0): + continue + + # Create baseline model + A_baseline = create_grn_baseline(A) + # Center and scale dataset scaler = StandardScaler() scaler.fit(X[train_index, :]) @@ -553,45 +484,60 @@ def main(par): # Create data loaders n_perturbations = int(np.max(perturbations) + 1) - train_dataset = PerturbationDataset( - X_standardized, - train_index, - match_groups, - control_map, - loose_match_groups, - loose_control_map, - perturbations - ) - train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512) - test_dataset = PerturbationDataset( - X_standardized, - test_index, - match_groups, - control_map, - loose_match_groups, - loose_control_map, - perturbations - ) - test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=512) + def create_data_loaders(): + train_dataset = 
PerturbationDataset( + X_standardized, + train_index, + match_groups, + control_map, + loose_match_groups, + loose_control_map, + perturbations + ) + train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=True) + test_dataset = PerturbationDataset( + X_standardized, + test_index, + match_groups, + control_map, + loose_match_groups, + loose_control_map, + perturbations + ) + test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=512) + return train_data_loader, test_data_loader + + # For fair comparison, we first randomly initialize NN parameters, and use these same + # parameters to build both models (only the GRN weights will differ). + model_template = Model(A, n_perturbations, signed=signed).to(DEVICE) + state_dict = model_template.state_dict() # Evaluate inferred GRN - res = evaluate(A, train_data_loader, test_data_loader, n_perturbations) - ss_res = ss_res + res[0] - ss_tot = ss_tot + res[1] + train_data_loader, test_data_loader = create_data_loaders() + r2.append(evaluate(A, train_data_loader, test_data_loader, state_dict, n_perturbations, signed=signed)) # Evaluate baseline GRN (shuffled target genes) - #ss_tot = ss_tot + evaluate(A_baseline, train_data_loader, test_data_loader, n_perturbations) - - r2 = 1 - ss_res / ss_tot + train_data_loader, test_data_loader = create_data_loaders() + r2_baseline.append(evaluate(A_baseline, train_data_loader, test_data_loader, state_dict, n_perturbations, signed=signed)) + + r2 = np.asarray(r2).flatten() + r2_baseline = np.asarray(r2_baseline).flatten() + print("Mean R2", np.mean(r2), np.mean(r2_baseline)) + if np.all(r2 == r2_baseline): + final_score = 0 + else: + p_value = wilcoxon(r2, r2_baseline, alternative="greater").pvalue + final_score = -np.log10(p_value) - final_score = np.mean(np.clip(r2, 0, 1)) print(f"Method: {method_id}") - print(f"R2: {final_score}") + print(f"Final score: {final_score}") results = { + 'r2': [float(np.mean(r2))], + 'r2_baseline': 
[float(np.mean(r2_baseline))], + 'r2_diff': [float(np.mean(r2)) - float(np.mean(r2_baseline))], 'vc': [float(final_score)] - } df_results = pd.DataFrame(results) - return df_results \ No newline at end of file + return df_results diff --git a/src/metrics/gt/acquire/run_local.sh b/src/metrics/gt/acquire/run_local.sh new file mode 100644 index 000000000..c735e7857 --- /dev/null +++ b/src/metrics/gt/acquire/run_local.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH --job-name=beeline_data +#SBATCH --output=logs/%j.out +#SBATCH --error=logs/%j.err +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=2 +#SBATCH --time=10:00:00 +#SBATCH --mem=250GB +#SBATCH --partition=cpu +#SBATCH --mail-type=END,FAIL +#SBATCH --mail-user=jalil.nourisa@gmail.com + +python src/metrics/tf_binding/acquire/script.py \ No newline at end of file diff --git a/src/metrics/gt/acquire/script.py b/src/metrics/gt/acquire/script.py new file mode 100644 index 000000000..5c4eaad6f --- /dev/null +++ b/src/metrics/gt/acquire/script.py @@ -0,0 +1,19 @@ +import os +import sys +import io +import zipfile +import tempfile +import requests +from pathlib import Path + + +sys.path.append("src/utils") +from util import download_and_uncompress_zip + + +if __name__ == "__main__": + + download_and_uncompress_zip( + "https://zenodo.org/records/3701939/files/BEELINE-Networks.zip?download=1", + "resources/grn_benchmark/ground_truth/beeline" + ) diff --git a/src/metrics/gt/helper.py b/src/metrics/gt/helper.py new file mode 100644 index 000000000..2e2adc853 --- /dev/null +++ b/src/metrics/gt/helper.py @@ -0,0 +1,71 @@ +from typing import Optional, Dict + +import numpy as np +import pandas as pd +from sklearn.metrics import recall_score, precision_score, matthews_corrcoef, roc_auc_score, average_precision_score +from util import read_prediction + + +# For reproducibility +seed = 42 +np.random.seed(seed) + + +def load_grn_from_dataframe( + df: pd.DataFrame, + tf_dict: Dict[str, int], + tg_dict: Dict[str, int] +) -> np.ndarray: + G = 
np.zeros((len(tf_dict), len(tg_dict)), dtype=float) + for tf_name, tg_name, weight in zip(df['source'], df['target'], df['weight']): + if tf_name not in tf_dict: + continue + if tg_name not in tg_dict: + continue + i = tf_dict[tf_name] + j = tg_dict[tg_name] + G[i, j] = weight + return G + + +def main(par): + + # Load ground-truth edges (consider only TFs listed in the file loaded hereabove) + true_graph = pd.read_csv(par['ground_truth']) + if not {'Gene1', 'Gene2'}.issubset(set(true_graph.columns)): + raise ValueError("ground_truth must have columns: 'Gene1', 'Gene2'") + if 'weight' not in true_graph: + true_graph['weight'] = np.ones(len(true_graph)) + true_graph.rename(columns={"Gene1": "source", "Gene2": "target"}, copy=False, inplace=True) + + # Load inferred GRN + prediction = read_prediction(par) + assert prediction.shape[0] > 0, 'No links found in the network' + if not {'source', 'target', 'weight'}.issubset(set(prediction.columns)): + raise ValueError("prediction must have columns: 'source', 'target', 'weight'") + + # Intersect TF lists, intersect TG lists + tf_names = set(true_graph['source'].unique()) #.intersection(set(tf_all)) + tg_names = set(true_graph['target'].unique()) + #tf_names = set(true_graph['source'].unique()).intersection(set(prediction['source'].unique())) #.intersection(set(tf_all)) + #tg_names = set(true_graph['target'].unique()).intersection(set(prediction['target'].unique())) + tf_dict = {tf_name: i for i, tf_name in enumerate(tf_names)} + tg_dict = {tg_name: i for i, tg_name in enumerate(tg_names)} + + # Reformat GRNs as NumPy arrays + A = load_grn_from_dataframe(true_graph, tf_dict, tg_dict) + G = load_grn_from_dataframe(prediction, tf_dict, tg_dict) + G = np.abs(G) + + # Evaluate inferred GRN + tf_binding_precision = precision_score((A != 0).flatten(), (G != 0).flatten()) + tf_binding_recall = recall_score((A != 0).flatten(), (G != 0).flatten()) + final_score = roc_auc_score(A.flatten(), G.flatten()) + + summary_df = pd.DataFrame([{ 
+ 'tf_binding_precision': tf_binding_precision, + 'tf_binding_recall': tf_binding_recall, + 'final_score': final_score + }]) + + return summary_df diff --git a/src/metrics/gt/run_local.py b/src/metrics/gt/run_local.py new file mode 100644 index 000000000..3d8e1898e --- /dev/null +++ b/src/metrics/gt/run_local.py @@ -0,0 +1,37 @@ +import subprocess + +import anndata as ad +import pandas as pd + + +results = {"dataset": [], "method": [], "score": []} +for dataset in ["beeline_hESC", "beeline_hHep", "beeline_mDC", "beeline_mESC", "beeline_mESC-E", "beeline_mHSC-E", "beeline_mHSC-L"]: + #for method in ["negative_control", "granie", "ppcor", "portia", "pearson_corr", "positive_control", "scenic", "grnboost", "scprint", "scenicplus", "celloracle", "scglue", "figr"]: + #for method in ['granie', 'ppcor', 'portia', 'negative_control']: + for method in ['portia']: + + print() + print(method) + + score_filepath = f"output/gt/gt.h5ad" + if "_h" in dataset: + gt = f"human/{dataset.split('_')[1]}-ChIP-seq-network" + else: + gt = f"mouse/{dataset.split('_')[1]}-ChIP-seq-network" + subprocess.call([ + "python", + "src/metrics/gt/script.py", + "--prediction", f"resources/results/{dataset}/{method}.h5ad", + "--evaluation_data", f'resources/grn_benchmark/inference_data/{dataset}.h5ad', + "--ground_truth", f"resources/grn_benchmark/ground_truth/beeline/Networks/{gt}.csv", + "--score", score_filepath + ]) + + adata = ad.read_h5ad(score_filepath) + if "metric_values" in adata.uns: + metric_names = adata.uns["metric_ids"] + metric_values = adata.uns["metric_values"] + print(metric_values) + +df = pd.DataFrame(results) +df.to_csv(f"output/gt/gt_scores_beeline.csv", header=True, index=False) diff --git a/src/metrics/gt/run_local.sh b/src/metrics/gt/run_local.sh new file mode 100644 index 000000000..0b2ef233a --- /dev/null +++ b/src/metrics/gt/run_local.sh @@ -0,0 +1,78 @@ +#!/bin/bash +#SBATCH --job-name=tf_binding +#SBATCH --output=logs/%j.out +#SBATCH --error=logs/%j.err +#SBATCH 
--ntasks=1 +#SBATCH --cpus-per-task=20 +#SBATCH --time=30:00:00 +#SBATCH --mem=250GB +#SBATCH --partition=cpu +#SBATCH --mail-type=END,FAIL +#SBATCH --mail-user=jalil.nourisa@gmail.com + +set -euo pipefail + +save_dir="output/tf_binding" +mkdir -p "$save_dir" + +# datasets to process +datasets=( 'replogle' 'norman' 'adamson' ) # 'xaira_HCT116' 'replogle' 'norman' 'adamson') #"300BCG" "ibd" 'parsebioscience''op' "300BCG" 'parsebioscience' 'replogle' 'norman' 'adamson' +# methods to process +methods=( "pearson_corr" "negative_control" "positive_control" "ppcor" "portia" "scenic" "grnboost" "scprint" "scenicplus" "celloracle" "scglue" "figr" "granie") + +for dataset in "${datasets[@]}"; do + echo -e "\n\nProcessing dataset: $dataset\n" + + # Create separate CSV file for each dataset + dataset_csv="${save_dir}/tf_binding_scores_${dataset}.csv" + echo "dataset,method,metric,value" > "$dataset_csv" + + evaluation_data="resources/grn_benchmark/evaluation_data/${dataset}_bulk.h5ad" + + for method in "${methods[@]}"; do + prediction="resources/results/${dataset}/${dataset}.${method}.${method}.prediction.h5ad" + score="${save_dir}/tf_binding_${dataset}_${method}.h5ad" + + if [[ ! -f "$prediction" ]]; then + echo "File not found: $prediction, skipping..." + continue + fi + if [[ "$dataset" == "replogle" || "$dataset" == "norman" || "$dataset" == "adamson" ]]; then + ground_truth="resources/grn_benchmark/ground_truth/K562_remap.csv" + elif [[ "$dataset" == "xaira_HEK293T" ]]; then + ground_truth="resources/grn_benchmark/ground_truth/HEK293T_remap.csv" + elif [[ "$dataset" == "xaira_HCT116" ]]; then + ground_truth="resources/grn_benchmark/ground_truth/HCT116_chipatlas.csv" + else + echo "No ground truth available for dataset: $dataset, skipping..." 
+ continue + fi + + echo -e "\nProcessing method: $method\n" + python src/metrics/tf_binding/script.py \ + --prediction "$prediction" \ + --evaluation_data "$evaluation_data" \ + --ground_truth "$ground_truth" \ + --score "$score" + + # Extract metrics from the .h5ad and append to CSV + python -u - < np.array: + """Combine parallel label arrays into a single integer label per position.""" + A = np.stack(arrays) + n_classes = tuple(int(A[i].max()) + 1 for i in range(A.shape[0])) + return np.ravel_multi_index(A, dims=n_classes, order='C') + + +def compute_residual_correlations( + X_train: np.ndarray, + y_train: np.ndarray, + X_test: np.ndarray, + y_test: np.ndarray, + Z_test: np.ndarray +) -> np.ndarray: + model = xgboost.XGBRegressor(n_estimators=10) + #model = Ridge(alpha=10) + model.fit(X_train, y_train) + y_hat = model.predict(X_test) + residuals = y_test - y_hat + coefs = pearsonr(residuals[:, np.newaxis], Z_test, axis=0)[0] + coefs = np.nan_to_num(coefs, nan=0) + assert coefs.shape[0] == Z_test.shape[1] + return np.abs(coefs) + + +def main(par): + # Load evaluation data + adata = ad.read_h5ad(par['evaluation_data']) + dataset_id = adata.uns['dataset_id'] + method_id = ad.read_h5ad(par['prediction'], backed='r').uns['method_id'] + + # Get dataset-specific anchor variables + if dataset_id not in DATASET_GROUPS: + raise ValueError(f"Dataset {dataset_id} not found in DATASET_GROUPS") + + anchor_cols = DATASET_GROUPS[dataset_id].get('anchors', ['donor_id', 'plate_name']) + print(f"Using anchor variables: {anchor_cols}") + + # Manage layer + layer = manage_layer(adata, par) + X = adata.layers[layer] + if isinstance(X, csr_matrix): + X = X.toarray() + X = X.astype(np.float32) + + gene_names = adata.var_names + gene_dict = {gene_name: i for i, gene_name in enumerate(gene_names)} + + # Encode anchor variables + anchor_variables = encode_obs_cols(adata, anchor_cols) + anchor_encoded = combine_multi_index(*anchor_variables) + + if len(anchor_variables) == 0: + raise 
ValueError(f"No anchor variables found in dataset for columns: {anchor_cols}") + + # One-hot encode anchor variables + Z = OneHotEncoder(sparse_output=False, dtype=np.float32).fit_transform(anchor_encoded.reshape(-1, 1)) + print(f"Anchor matrix Z shape: {Z.shape}") + + # Load inferred GRN + df = read_prediction(par) + sources = df["source"].to_numpy() + targets = df["target"].to_numpy() + weights = df["weight"].to_numpy() + + A = np.zeros((len(gene_names), len(gene_names)), dtype=X.dtype) + for source, target, weight in zip(sources, targets, weights): + if (source in gene_dict) and (target in gene_dict): + i = gene_dict[source] + j = gene_dict[target] + A[i, j] = float(weight) + + # Only consider the genes that are actually present in the inferred GRN, + # and keep only the most-connected genes (for speed). + gene_mask = np.logical_or(np.any(A, axis=1), np.any(A, axis=0)) + in_degrees = np.sum(A != 0, axis=0) + out_degrees = np.sum(A != 0, axis=1) + idx = np.argsort(np.maximum(out_degrees, in_degrees))[:-1000] + gene_mask[idx] = False + X = X[:, gene_mask] + X = X.toarray() if isinstance(X, csr_matrix) else X + A = A[gene_mask, :][:, gene_mask] + gene_names = gene_names[gene_mask] + + # Remove self-regulations + np.fill_diagonal(A, 0) + print(f"Evaluating {X.shape[1]} genes with {np.sum(A != 0)} regulatory links") + + # Create baseline model + A_baseline = np.copy(A) + for j in range(A.shape[1]): + np.random.shuffle(A_baseline[:j, j]) + np.random.shuffle(A_baseline[j+1:, j]) + assert np.any(A_baseline != A) + + scores, baseline_scores = [], [] + for group in np.unique(anchor_encoded): + + # Train/test split + mask = (anchor_encoded != group) + X_train = X[mask, :] + X_test = X[~mask, :] + + # Standardize features + scaler = StandardScaler() + X_train = scaler.fit_transform(X_train) + X_test = scaler.transform(X_test) + + for j in tqdm.tqdm(range(X_train.shape[1])): + + # Evaluate inferred GRN + selected = (A[:, j] != 0) + unselected = ~np.copy(selected) + 
unselected[j] = False + if (not np.any(selected)) or (not np.any(unselected)): + continue + else: + coefs = compute_residual_correlations( + X_train[:, selected], + X_train[:, j], + X_test[:, selected], + X_test[:, j], + X_test[:, ~selected] + ) + scores.append(np.mean(coefs)) + + # Evaluate baseline GRN + selected = (A_baseline[:, j] != 0) + unselected = ~np.copy(selected) + unselected[j] = False + coefs = compute_residual_correlations( + X_train[:, selected], + X_train[:, j], + X_test[:, selected], + X_test[:, j], + X_test[:, ~selected] + ) + baseline_scores.append(np.mean(coefs)) + scores = np.array(scores) + baseline_scores = np.array(baseline_scores) + + p_value = wilcoxon(baseline_scores, scores, alternative="greater").pvalue + p_value = max(p_value, 1e-300) + + # Calculate final score + final_score = -np.log10(p_value) + print(f"Anchor Regression Score: {final_score:.6f}") + print(f"Method: {method_id}") + + # Return results as DataFrame + results = { + 'regression_3': [final_score] + } + + df_results = pd.DataFrame(results) + return df_results diff --git a/src/metrics/tf_binding/run_local.py b/src/metrics/tf_binding/run_local.py new file mode 100644 index 000000000..0599d0d8f --- /dev/null +++ b/src/metrics/tf_binding/run_local.py @@ -0,0 +1,32 @@ +import subprocess + +import anndata as ad +import pandas as pd + + +for dataset in ["replogle"]: + #for method in ['granie']: + for method in ["negative_control", "granie", "ppcor", "portia", "pearson_corr", "positive_control", "scenic", "grnboost", "scprint", "scenicplus", "celloracle", "scglue", "figr"]: + + print() + print(method) + + score_filepath = f"output/tf_binding/tf_binding.h5ad" + subprocess.call([ + "python", + "src/metrics/tf_binding/script.py", + "--prediction", f"resources/results/{dataset}/{dataset}.{method}.{method}.prediction.h5ad", + "--evaluation_data", f"resources/grn_benchmark/evaluation_data/{dataset}_bulk.h5ad", + "--ground_truth", "resources/grn_benchmark/ground_truth/K562.csv", + 
"--score", score_filepath + ]) + + adata = ad.read_h5ad(score_filepath) + if "metric_values" in adata.uns: + metric_names = adata.uns["metric_ids"] + metric_values = adata.uns["metric_values"] + df = pd.DataFrame({"metric": metric_names, "value": metric_values}) + df["dataset"] = dataset + df["method"] = method + df = df[["dataset", "method", "metric", "value"]] # Reorder columns to match header + df.to_csv(f"output/tf_binding/tf_binding_scores_{dataset}.csv", mode="a", header=False, index=False) diff --git a/src/process_data/beeline/config.novsh.yaml b/src/process_data/beeline/config.novsh.yaml new file mode 100644 index 000000000..922e7a7de --- /dev/null +++ b/src/process_data/beeline/config.novsh.yaml @@ -0,0 +1,53 @@ + +name: beeline +namespace: "process_data" +info: + label: Process BEELINE dataset + summary: "Process sourced BEELINE data to generate inference and evaluation datasets" + +arguments: + - name: --replogle_gwps + type: file + required: true + direction: input + - name: --tf_all + type: file + required: true + direction: input + - name: --replogle_test_perturbs + type: file + required: true + direction: input + - name: --replogle_gwps_test_sc + type: file + required: true + direction: output + - name: --replogle_gwps_train_sc + type: file + required: true + direction: output + - name: --replogle_gwps_train_sc_subset + type: file + required: true + direction: output + + +resources: + - type: python_script + path: script.py + - path: /src/utils/util.py + dest: util.py + +engines: + - type: docker + image: ghcr.io/openproblems-bio/base_python:1.0.4 + setup: + - type: python + packages: [ numpy==1.26.4 ] + - type: native + +runners: + - type: executable + - type: nextflow + directives: + label: [midtime, midmem, midcpu] diff --git a/src/process_data/beeline/script.py b/src/process_data/beeline/script.py new file mode 100644 index 000000000..72c03dfd7 --- /dev/null +++ b/src/process_data/beeline/script.py @@ -0,0 +1,73 @@ + +import os +import anndata as 
def add_metadata(adata, name: str):
    """Record BEELINE provenance in ``adata.uns`` and hand the object back.

    Args:
        adata: Object exposing a dict-like ``uns`` attribute; it is updated
            in place.
        name: BEELINE dataset label (e.g. ``"hESC"``). Labels beginning with
            ``"h"`` are tagged as human, everything else as mouse.

    Returns:
        The same ``adata`` object that was passed in.
    """
    blurb = 'Processed experimental single-cell gene expression datasets used in BEELINE.'
    # BibTeX entry for the Zenodo record, split per field for readability;
    # adjacent literals concatenate to the single stored string.
    citation = (
        "@dataset{aditya_pratapa_2020_3701939,\n"
        "author = {Aditya Pratapa and Amogh Jalihal and Jeffrey Law and Aditya Bharadwaj and T M Murali},\n"
        " title = {Benchmarking algorithms for gene regulatory network inference from single-cell transcriptomic data },\n"
        "month = mar,\n"
        "year = 2020,\n"
        "publisher = {Zenodo},\n"
        "doi = {10.5281/zenodo.3701939},\n"
        "url = {https://doi.org/10.5281/zenodo.3701939},\n"
        "}"
    )
    if name.startswith('h'):
        organism = 'human'
    else:
        organism = 'mouse'

    annotations = {
        'dataset_summary': blurb,
        'dataset_description': blurb,
        'data_reference': citation,
        'data_url': 'https://zenodo.org/records/3701939',
        'dataset_id': f'beeline_{name}',
        'dataset_name': f'BEELINE_{name}',
        'dataset_organism': organism,
        'normalization_id': 'X_norm',
    }
    for key, value in annotations.items():
        adata.uns[key] = value
    return adata


if __name__ == '__main__':

    raw_dir = "resources/datasets_raw/beeline"
    beeline_datasets = ["hESC", "hHep", "mDC", "mESC", "mHSC-E", "mHSC-GM", "mHSC-L"]

    # Fetch and unpack the BEELINE archive from Zenodo
    download_and_uncompress_zip(
        "https://zenodo.org/records/3701939/files/BEELINE-data.zip?download=1",
        raw_dir
    )

    # One H5AD file per expression CSV
    for dataset_name in beeline_datasets:
        csv_path = f"{raw_dir}/BEELINE-data/inputs/scRNA-Seq/{dataset_name}/ExpressionData.csv"

        # The CSV is transposed after loading so that rows become observations
        expression = pd.read_csv(csv_path, index_col=0).transpose()
        cell_table = expression.index.to_frame(index=False)
        gene_table = pd.DataFrame(index=expression.columns)
        adata = ad.AnnData(X=expression.values, obs=cell_table, var=gene_table)
        adata.var.index.name = "gene_name"
        adata.layers['X_norm'] = adata.X.copy()

        # Column labels of obs/var: force to str and replace "/" with "_"
        for frame in (adata.obs, adata.var):
            frame.columns = pd.Index(frame.columns).map(str).str.replace("/", "_")

        # Filtering thresholds: min_genes=100 per cell, min_cells=10 per gene
        sc.pp.filter_cells(adata, min_genes=100)
        sc.pp.filter_genes(adata, min_cells=10)

        # Provenance metadata, then write to disk
        adata = add_metadata(adata, dataset_name)
        adata.write(f"resources/grn_benchmark/inference_data/beeline_{dataset_name}.h5ad")
class WeightedDegree(object):
    """Sort key for a GRN node: ordered by degree, ties broken by summed weight.

    Only ``<`` and ``==`` are implemented. ``node_idx`` is carried along for
    reference and takes no part in comparisons.
    """

    def __init__(self, node_idx: int, degree: int, weight: float):
        """Store the node index with its degree and total edge weight.

        Args:
            node_idx: index of the node in the adjacency matrix.
            degree: in-degree or out-degree of the node.
            weight: sum of weights of incoming or outgoing edges.
        """
        self.node_idx: int = int(node_idx)
        self.degree: int = int(degree)
        self.weight: float = float(weight)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(node_idx={self.node_idx}, "
            f"degree={self.degree}, weight={self.weight})"
        )

    def __lt__(self, other: "WeightedDegree") -> bool:
        # Defer to the other operand instead of failing with AttributeError
        if not isinstance(other, WeightedDegree):
            return NotImplemented
        if self.degree < other.degree:
            return True
        elif self.degree == other.degree:
            return self.weight < other.weight
        else:
            return False

    def __eq__(self, other: object) -> bool:
        # Comparing with an unrelated type is "not equal", not an error
        if not isinstance(other, WeightedDegree):
            return NotImplemented
        return (self.degree == other.degree) and (self.weight == other.weight)

    @staticmethod
    def from_grn(A: np.ndarray, incoming: bool = True) -> List["WeightedDegree"]:
        """Build one ``WeightedDegree`` per node of the adjacency matrix ``A``.

        Args:
            A: square matrix; non-zero entries are counted as edges.
            incoming: if True use column statistics of ``A``; otherwise ``A``
                is transposed first, so row statistics are used.

        Returns:
            A list indexed by node, holding the count of non-zero entries and
            the sum of the entries along the selected axis.
        """
        if not incoming:
            A = A.T
        D = (A != 0)
        degrees = np.sum(D, axis=0)
        weights = np.sum(A, axis=0)
        return [WeightedDegree(i, degrees[i], weights[i]) for i in range(len(degrees))]


def create_grn_baseline(A):
    """
    Deterministic baseline for a directed simple graph.
    Preserves in/out degree sequences of A (A may be weighted/signed; nonzeros indicate edges).
    Returns a weighted B: same topology as the deterministic baseline, with the exact weights from A
    reassigned deterministically to different edges (no randomness).

    Self-loops (diagonal entries of A) are ignored: the baseline never places
    an edge ``u -> u``, so counting them would inflate the degree sequences
    and leak their weights onto other edges. The caller's array is not modified.

    Args:
        A: square adjacency matrix, ``A[u, v]`` being the weight of ``u -> v``.

    Returns:
        Array with the shape and dtype of ``A`` holding the baseline network.

    Raises:
        ValueError: if the greedy construction cannot give some source node
            as many targets as its out-degree requires.
    """

    # Ensure no self-regulation (work on a copy so the input stays untouched)
    A = np.copy(A)
    np.fill_diagonal(A, 0)

    n = A.shape[0]

    # Order nodes by degrees, with explicit tie-breaking by weight
    in_degrees = WeightedDegree.from_grn(A, incoming=True)
    out_degrees = WeightedDegree.from_grn(A, incoming=False)
    in_order = np.argsort(in_degrees)[::-1]
    out_order = np.argsort(out_degrees)[::-1]

    # Baseline GRN; during this first phase it only holds 0/1 edge markers
    B = np.zeros_like(A)

    for u in out_order:
        k = out_degrees[u].degree
        if k == 0:
            continue
        # Greedily fill from current in_order (targets with most residual in-degree first)
        picks = []
        for v in in_order:
            if v == u or B[u, v] == 1 or in_degrees[v].degree == 0:
                continue
            picks.append(v)
            if len(picks) == k:
                break

        # Deterministic repair if not enough picks
        # NOTE(review): the greedy pass above already takes every target that
        # is not `u` and still has residual in-degree, so this pass appears
        # unable to find further candidates -- confirm before relying on it.
        if len(picks) < k:
            for v in in_order:
                if v == u or in_degrees[v].degree == 0 or v in picks:
                    continue

                # Find w already chosen where we can swap capacity
                swapped = False
                for w in picks:
                    # Look for x != u with B[x, w]==1 and B[x, v]==0 to move (x,w)->(x,v)
                    # Deterministic scan by x id
                    for x in range(n):
                        if x == u:
                            continue
                        if B[x, w] == 1 and B[x, v] == 0 and x != v and x != w:
                            B[x, w] = 0
                            B[x, v] = 1
                            in_degrees[w].degree += 1
                            in_degrees[v].degree -= 1
                            picks.append(v)
                            swapped = True
                            break
                    if swapped or len(picks) == k:
                        break
                if len(picks) == k:
                    break

        if len(picks) < k:
            raise ValueError("Directed degree sequence not realizable with simple digraph under constraints.")

        # Place edges for u and consume the targets' residual in-degree
        for v in picks:
            B[u, v] = 1
            in_degrees[v].degree -= 1
        out_degrees[u].degree = 0

        # Re-sort by residual in-degree, with explicit tie-breaking by weight
        in_order = np.argsort(in_degrees)[::-1]

    # Zero diagonal guarantee
    np.fill_diagonal(B, 0)

    # Recompute initial incoming stats from A (in_degrees was consumed above)
    init_in_degrees = WeightedDegree.from_grn(A, incoming=True)

    # Convenience arrays for deterministic target ranking
    # (higher degree first, then higher total incoming weight, then smaller id)
    init_in_deg_arr = np.array([wd.degree for wd in init_in_degrees])
    init_in_wgt_arr = np.array([wd.weight for wd in init_in_degrees])

    for u in range(n):

        # Outgoing weights in A
        mask_A_u = (A[u, :] != 0)
        if not np.any(mask_A_u):
            continue
        orig_targets = np.where(mask_A_u)[0]
        W = A[u, orig_targets].astype(float)

        # Sort weights deterministically: by |w| desc, then w asc, then orig target id asc
        # (lexsort uses last key as primary)
        w_keys_3 = np.abs(W) * -1.0   # primary: |w| descending
        w_keys_2 = W                  # secondary: actual value ascending
        w_keys_1 = orig_targets       # tertiary: original target id ascending
        order_w = np.lexsort((w_keys_1, w_keys_2, w_keys_3))
        W_sorted = W[order_w]

        # Targets in B for this source, ranked by the targets' original incoming stats:
        # in-degree desc, then incoming-weight desc, then id asc
        targets_B = np.where(B[u, :] == 1)[0]
        if targets_B.size == 0:
            continue
        t_deg = init_in_deg_arr[targets_B]
        t_wgt = init_in_wgt_arr[targets_B]
        t_id = targets_B
        order_t = np.lexsort((t_id, -t_wgt, -t_deg))
        targets_ranked = targets_B[order_t]

        # Sanity: degrees should match. If not, deterministically trim/pad.
        kA = W_sorted.size
        kB = targets_ranked.size
        if kA > kB:
            # Drop evenly from both ends to avoid bias
            excess = kA - kB
            left = excess // 2
            right = excess - left
            W_sorted = W_sorted[left: kA - right]
        elif kA < kB:
            W_sorted = np.concatenate([W_sorted, np.zeros(kB - kA, dtype=float)], axis=0)

        # Assign exact weights to new edges deterministically
        B[u, targets_ranked] = W_sorted

    return B
def extract_zip_safely(fileobj, folder) -> None:
    """Extract every member of a ZIP archive into ``folder``.

    Args:
        fileobj: path or seekable binary file object accepted by
            ``zipfile.ZipFile``.
        folder: destination directory; created (with parents) if missing.

    Raises:
        RuntimeError: if a member's resolved path would land outside
            ``folder`` (e.g. ``../`` components or an absolute name).
    """
    # Function-scope import keeps this block independent of the module header
    import shutil

    chunk_size = 1 << 20
    dest = Path(folder)
    dest.mkdir(parents=True, exist_ok=True)
    # Resolve once; every member is checked against this root
    root = dest.resolve()

    with zipfile.ZipFile(fileobj) as zf:
        for member in zf.infolist():
            # Normalize path and prevent traversal outside `root`
            target_path = (root / member.filename).resolve()
            if target_path != root and root not in target_path.parents:
                raise RuntimeError(f"Unsafe path in ZIP: {member.filename}")

            if member.is_dir():
                target_path.mkdir(parents=True, exist_ok=True)
                continue

            # Ensure parent dirs exist, then stream the member to disk
            target_path.parent.mkdir(parents=True, exist_ok=True)
            with zf.open(member, "r") as source, open(target_path, "wb") as target:
                shutil.copyfileobj(source, target, chunk_size)


def download_and_uncompress_zip(url: str, folder: str) -> None:
    """Download a ZIP archive from ``url`` and extract it into ``folder``.

    The archive is streamed to an anonymous temporary file, because
    ``zipfile.ZipFile`` needs a seekable input, and then unpacked with
    ``extract_zip_safely``.

    Args:
        url: location of the ZIP archive.
        folder: destination directory; created (with parents) if missing.

    Raises:
        requests.HTTPError: propagated from ``raise_for_status`` when the
            server answers with an error status.
        RuntimeError: if the archive contains a path escaping ``folder``.
    """
    chunk_size = 1 << 20
    # Create the destination first so a bad path fails before any download
    Path(folder).mkdir(parents=True, exist_ok=True)

    with tempfile.TemporaryFile() as tf:
        # The response is closed as soon as the payload is on disk
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            for chunk in r.iter_content(chunk_size=chunk_size):
                if chunk:  # skip keep-alives
                    tf.write(chunk)
        tf.seek(0)
        extract_zip_safely(tf, folder)
""" + # Ensure no self-regulation + A = np.copy(A) + np.fill_diagonal(A, 0) + n = A.shape[0] # Order nodes by degrees, with explicit tie-breaking by weight @@ -564,7 +616,7 @@ def create_grn_baseline(A): B[x, w] = 0 B[x, v] = 1 in_degrees[w].degree += 1 - in_degrees[v].degree -= 1 + #in_degrees[v].degree -= 1 picks.append(v) swapped = True break @@ -573,14 +625,13 @@ def create_grn_baseline(A): if len(picks) == k: break - if len(picks) < k: - raise ValueError("Directed degree sequence not realizable with simple digraph under constraints.") + #if len(picks) < k: + # raise ValueError("Directed degree sequence not realizable with simple digraph under constraints.") # Place edges for u for v in picks: B[u, v] = 1 in_degrees[v].degree -= 1 - #out_degrees[v].degree = 0 out_degrees[u].degree = 0 # Stable re-sort by residual in-degree, with explicit tie-breaking by weight @@ -640,4 +691,12 @@ def create_grn_baseline(A): # Assign exact weights to new edges deterministically B[u, targets_ranked] = W_sorted - return B \ No newline at end of file + # Quality check + #target_in_degrees = np.sum((A != 0), axis=0) + #target_out_degrees = np.sum((A != 0), axis=1) + #in_degrees = np.sum((B != 0), axis=0) + #out_degrees = np.sum((B != 0), axis=1) + #print(target_in_degrees - in_degrees) + #print(target_out_degrees - out_degrees) + + return B