From 393fcf145272a437c28be06aa819a826ccf369e2 Mon Sep 17 00:00:00 2001 From: jkobject Date: Tue, 18 Feb 2025 14:15:21 +0100 Subject: [PATCH 01/22] dbug scprint --- src/methods/scprint/config.vsh.yaml | 12 ++++-------- src/methods/scprint/script.py | 9 ++++++--- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 0efe1d84..5fc8c98d 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -57,7 +57,7 @@ arguments: - name: --batch_size type: integer description: The size of the batches to be used in the DataLoader. - default: 64 + default: 32 - name: --max_len type: integer description: The maximum length of the gene sequence. @@ -75,19 +75,15 @@ engines: setup: - type: python pip: - - huggingface_hub - # Can be unpinned after https://github.com/cantinilab/scPRINT/issues/14 is resolved - - scprint==1.6.2 - - scdataloader==1.6.4 + - scprint - type: docker run: lamin init --storage ./main --name main --schema bionty - - type: python - script: import bionty as bt; bt.core.sync_all_sources_to_latest() - type: docker run: lamin load anonymous/main - type: python script: from scdataloader.utils import populate_my_ontology; populate_my_ontology() - + - type: python + script: import bionty as bt; bt.core.sync_all_sources_to_latest() runners: - type: executable - type: nextflow diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 5f0c95e8..adb040e9 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -74,16 +74,19 @@ print("CUDA is not available, using CPU", flush=True) precision = "32" dtype = torch.float32 -n_cores_available = len(os.sched_getaffinity(0)) -print(f"Using {n_cores_available} worker cores") +n_cores = min(len(os.sched_getaffinity(0)), 24) +print(f"Using {n_cores} worker cores") embedder = Embedder( how="random expr", batch_size=par["batch_size"], max_len=par["max_len"], add_zero_genes=0, - num_workers=n_cores_available, + num_workers=n_cores, doclass=False, doplot=False, + pred_embedding=["cell_type_ontology_term_id"], + keep_all_cls_pred=False, + output_expression="none", precision=precision, dtype=dtype, ) From 0bc9c5ab855e7fbe9c84ba82b6d6db660fd6d647 Mon Sep 17 00:00:00 2001 From: jkobject Date: Tue, 18 Feb 2025 14:31:17 +0100 Subject: [PATCH 02/22] allowing flash attn --- src/methods/scprint/script.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index adb040e9..7ec38d3e 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -58,22 +58,26 @@ model_checkpoint_file = hf_hub_download( repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt" ) -print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) -model = scPrint.load_from_checkpoint( - model_checkpoint_file, - transformer="normal", # Don't use this for GPUs with flashattention - precpt_gene_emb=None, -) print("\n>>> Embedding data...", flush=True) if torch.cuda.is_available(): print("CUDA is available, using GPU", flush=True) precision = "16" dtype = torch.float16 + transformer="flash" else: print("CUDA is not available, using CPU", flush=True) precision = "32" dtype = torch.float32 + transformer="normal" + +print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) +model = scPrint.load_from_checkpoint( + model_checkpoint_file, + transformer=transformer, # Don't use this for GPUs with flashattention + precpt_gene_emb=None, +) + n_cores = min(len(os.sched_getaffinity(0)), 24) print(f"Using {n_cores} worker cores") embedder = Embedder( From 3cb94a302e5ec114174df5a3e6666733c81e84bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Kalfon?= Date: Thu, 20 Feb 2025 12:44:50 +0100 Subject: [PATCH 03/22] Update _viash.yaml --- _viash.yaml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/_viash.yaml b/_viash.yaml index 36dd957a..5b612d43 100644 --- a/_viash.yaml +++ b/_viash.yaml @@ -91,7 +91,11 @@ authors: info: github: sainirmayi orcid: 0009-0003-6319-9803 - + - name: Jeremie Kalfon + roles: [contributor] + info: + github: jkobject + orcid: 0000-0002-2818-9728 config_mods: | .runners[.type == "nextflow"].config.labels := { lowmem : "memory = 20.Gb", midmem : "memory = 50.Gb", highmem : "memory = 100.Gb", lowcpu : "cpus = 5", midcpu : "cpus = 15", highcpu : "cpus = 30", lowtime : "time = 1.h", midtime : "time = 4.h", hightime : "time = 8.h", veryhightime : "time = 24.h" } From 086946bcfd57b7540ff76ae6b3f09ad8ead5c453 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Fri, 21 Feb 2025 08:47:45 +0100 Subject: [PATCH 04/22] Update CHANGELOG --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 48f7bceb..53a22533 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +# task_batch_integration devel + +## Minor changes + +* Un-pin the scPRINT version and update parameters (PR #51) + # task_batch_integration 2.0.0 A major update to the OpenProblems framework, switching from a Python-based framework to a Viash + Nextflow-based framework. This update features the same concepts as the previous version, but with a new implementation that is more flexible, scalable, and maintainable. From f56905ed5b7584382c6cff6ab3e899ade73b232b Mon Sep 17 00:00:00 2001 From: jkobject Date: Tue, 4 Mar 2025 17:18:58 +0100 Subject: [PATCH 05/22] adding some debug --- src/methods/scprint/config.vsh.yaml | 6 +++--- src/methods/scprint/script.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 5fc8c98d..d2f6b07f 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -48,8 +48,8 @@ arguments: - name: "--model_name" type: "string" description: Which model to use. Not used if --model is provided. - choices: ["large", "medium", "small"] - default: "large" + choices: ["large", "v2-medium", "small"] + default: "v2-medium" - name: --model type: file description: Path to the scPRINT model. @@ -75,7 +75,7 @@ engines: setup: - type: python pip: - - scprint + - git+https://github.com/cantinilab/scPRINT.git - type: docker run: lamin init --storage ./main --name main --schema bionty - type: docker diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 7ec38d3e..20590e80 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -13,7 +13,7 @@ par = { "input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad", "output": "output.h5ad", - "model_name": "large", + "model_name": "v2-medium", "model": None, } meta = {"name": "scprint"} @@ -64,12 +64,12 @@ print("CUDA is available, using GPU", flush=True) precision = "16" dtype = torch.float16 - transformer="flash" + transformer = "flash" else: print("CUDA is not available, using CPU", flush=True) precision = "32" dtype = torch.float32 - transformer="normal" + transformer = "normal" print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) model = scPrint.load_from_checkpoint( From 92942b82acff5f38b6bc986b125b87f91f6a1183 Mon Sep 17 00:00:00 2001 From: jkobject Date: Wed, 5 Mar 2025 17:44:32 +0100 Subject: [PATCH 06/22] better model loading and new model --- src/methods/scprint/config.vsh.yaml | 8 ++++---- src/methods/scprint/script.py | 21 ++++++++++++++++----- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index d2f6b07f..15a30a79 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -35,7 +35,7 @@ info: scprint_large: model_name: "large" scprint_medium: - model_name: "medium" + model_name: "v2-medium" scprint_small: model_name: "small" test_setup: @@ -75,15 +75,15 @@ engines: setup: - type: python pip: - - git+https://github.com/cantinilab/scPRINT.git + - scprint - type: docker run: lamin init --storage ./main --name main --schema bionty - type: docker run: lamin load anonymous/main - - type: python - script: from scdataloader.utils import populate_my_ontology; populate_my_ontology() - type: python script: import bionty as bt; bt.core.sync_all_sources_to_latest() + - type: python + script: from scdataloader.utils import populate_my_ontology; populate_my_ontology() runners: - type: executable - type: nextflow diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 20590e80..d4016420 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -72,11 +72,22 @@ transformer = "normal" print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) -model = scPrint.load_from_checkpoint( - model_checkpoint_file, - transformer=transformer, # Don't use this for GPUs with flashattention - precpt_gene_emb=None, -) + +m = torch.load(model_checkpoint_file, map_location=torch.device("cpu")) +if "label_counts" in m["hyper_parameters"]: + model = scPrint.load_from_checkpoint( + model_checkpoint_file, + transformer=transformer, # Don't use this for GPUs with flashattention + precpt_gene_emb=None, + classes=m["hyper_parameters"]["label_counts"], + ) +else: + model = scPrint.load_from_checkpoint( + model_checkpoint_file, + transformer=transformer, # Don't use this for GPUs with flashattention + precpt_gene_emb=None, + ) +del m n_cores = min(len(os.sched_getaffinity(0)), 24) print(f"Using {n_cores} worker cores") From 9042fba5a3128d4f1c4526453755bfc2b3aa57d5 Mon Sep 17 00:00:00 2001 From: jkobject Date: Thu, 6 Mar 2025 09:42:45 +0100 Subject: [PATCH 07/22] final debug --- src/methods/scprint/script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index d4016420..88195eec 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -59,7 +59,6 @@ repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt" ) -print("\n>>> Embedding data...", flush=True) if torch.cuda.is_available(): print("CUDA is available, using GPU", flush=True) precision = "16" @@ -89,6 +88,7 @@ ) del m +print("\n>>> Embedding data...", flush=True) n_cores = min(len(os.sched_getaffinity(0)), 24) print(f"Using {n_cores} worker cores") embedder = Embedder( From 6491b5b09cdb9fd6679d88e7eae01db41244686e Mon Sep 17 00:00:00 2001 From: jkobject Date: Fri, 7 Mar 2025 11:04:50 +0100 Subject: [PATCH 08/22] better now --- src/methods/scprint/config.vsh.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 15a30a79..f59ea0e2 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -75,7 +75,8 @@ engines: setup: - type: python pip: - - scprint + - scprint==2.2.1 + - gseapy==1.1.2 - type: docker run: lamin init --storage ./main --name main --schema bionty - type: docker From 30facd805ebc8a4ea1884da8d46b936593256168 Mon Sep 17 00:00:00 2001 From: jkobject Date: Mon, 10 Mar 2025 12:10:20 +0100 Subject: [PATCH 09/22] finish debug --- src/methods/scprint/script.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 88195eec..46dd9fdd 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -30,14 +30,18 @@ print("\n>>> Reading input data...", flush=True) input = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns") -if input.uns["dataset_organism"] == "homo_sapiens": - input.obs["organism_ontology_term_id"] = "NCBITaxon:9606" -elif input.uns["dataset_organism"] == "mus_musculus": - input.obs["organism_ontology_term_id"] = "NCBITaxon:10090" -else: - exit_non_applicable( - f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'" - ) +if ( + "organism_ontology_term_id" not in input.obs.columns + and "dataset_organism" in input.uns +): + if input.uns["dataset_organism"] == "homo_sapiens": + input.obs["organism_ontology_term_id"] = "NCBITaxon:9606" + elif input.uns["dataset_organism"] == "mus_musculus": + input.obs["organism_ontology_term_id"] = "NCBITaxon:10090" + else: + exit_non_applicable( + f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'" + ) adata = input.copy() print("\n>>> Preprocessing data...", flush=True) @@ -112,7 +116,7 @@ obs=input.obs[[]], var=input.var[[]], obsm={ - "X_emb": embedded.obsm["scprint"], + "X_emb": embedded.obsm["scprint_emb"], }, uns={ "dataset_id": input.uns["dataset_id"], From 0f9bf7b047bb502dfba7c69145b9875be84ab331 Mon Sep 17 00:00:00 2001 From: jkobject Date: Thu, 13 Mar 2025 21:07:27 +0100 Subject: [PATCH 10/22] ending tests successfully --- src/methods/scprint/config.vsh.yaml | 4 +++- src/methods/scprint/script.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index f59ea0e2..22ac28f1 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -75,8 +75,9 @@ engines: setup: - type: python pip: - - scprint==2.2.1 + - git+https://github.com/cantinilab/scPRINT.git@d8cc270b099c8d5dacf6913acc26f2b696685b2b - gseapy==1.1.2 + - git+https://github.com/jkobject/scDataLoader.git@0f9e1858c8a4c6b0239ceb00e762d52032d745e7 - type: docker run: lamin init --storage ./main --name main --schema bionty - type: docker @@ -87,6 +88,7 @@ engines: script: from scdataloader.utils import populate_my_ontology; populate_my_ontology() runners: - type: executable + docker_run_args: --gpus all - type: nextflow directives: label: [hightime, highmem, midcpu, gpu, highsharedmem] diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 0b87dff2..2342875a 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -67,7 +67,7 @@ print("CUDA is available, using GPU", flush=True) precision = "16" dtype = torch.float16 - transformer="flash" + transformer = "flash" else: print("CUDA is not available, using CPU", flush=True) precision = "32" @@ -106,6 +106,7 @@ pred_embedding=["cell_type_ontology_term_id"], keep_all_cls_pred=False, output_expression="none", + save_every=30_000, precision=precision, dtype=dtype, ) From 446f23e7c6fd5c11a6478ae2afb9e7bbd3dca01c Mon Sep 17 00:00:00 2001 From: jkobject Date: Fri, 14 Mar 2025 09:59:11 +0100 Subject: [PATCH 11/22] removing flag --- src/methods/scprint/config.vsh.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 22ac28f1..c0fb6ac6 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -88,7 +88,6 @@ engines: script: from scdataloader.utils import populate_my_ontology; populate_my_ontology() runners: - type: executable - docker_run_args: --gpus all - type: nextflow directives: label: [hightime, highmem, midcpu, gpu, highsharedmem] From 0455229369387c8ba773558dbac0399f21db7520 Mon Sep 17 00:00:00 2001 From: jkobject Date: Fri, 14 Mar 2025 17:08:40 +0100 Subject: [PATCH 12/22] new dataloader version --- src/methods/scprint/config.vsh.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index c0fb6ac6..a240e24f 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -77,7 +77,7 @@ engines: pip: - git+https://github.com/cantinilab/scPRINT.git@d8cc270b099c8d5dacf6913acc26f2b696685b2b - gseapy==1.1.2 - - git+https://github.com/jkobject/scDataLoader.git@0f9e1858c8a4c6b0239ceb00e762d52032d745e7 + - git+https://github.com/jkobject/scDataLoader.git@c67c24a2e5c62399912be39169aae76e29e108aa - type: docker run: lamin init --storage ./main --name main --schema bionty - type: docker From ab0136f22bd3a439befa92e2ed42497d4ecad324 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Tue, 18 Mar 2025 10:42:03 +0100 Subject: [PATCH 13/22] Update CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 53a22533..edf66154 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## Minor changes * Un-pin the scPRINT version and update parameters (PR #51) +* Update scPRINT to better handle large datasets, including a new default model (PR #54) # task_batch_integration 2.0.0 From f4075f10383f18b1ff71d1214d8a04ade9d5c0f9 Mon Sep 17 00:00:00 2001 From: jkobject Date: Wed, 23 Apr 2025 16:23:38 +0200 Subject: [PATCH 14/22] solving some issues --- src/methods/scprint/config.vsh.yaml | 1 - src/methods/scprint/script.py | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index a240e24f..caf6ca51 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -48,7 +48,6 @@ arguments: - name: "--model_name" type: "string" description: Which model to use. Not used if --model is provided. - choices: ["large", "v2-medium", "small"] default: "v2-medium" - name: --model type: file diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 2342875a..d57e839b 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -66,12 +66,10 @@ if torch.cuda.is_available(): print("CUDA is available, using GPU", flush=True) precision = "16" - dtype = torch.float16 transformer = "flash" else: print("CUDA is not available, using CPU", flush=True) precision = "32" - dtype = torch.float32 transformer = "normal" print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) @@ -92,6 +90,9 @@ ) del m +if model.device == "cpu" and torch.cuda.is_available(): + model = model.to("cuda") + print("\n>>> Embedding data...", flush=True) n_cores = min(len(os.sched_getaffinity(0)), 24) print(f"Using {n_cores} worker cores") @@ -107,8 +108,6 @@ keep_all_cls_pred=False, output_expression="none", save_every=30_000, - precision=precision, - dtype=dtype, ) embedded, _ = embedder(model, adata, cache=False) From ec2df331ae752e0dc1de4fa6d34802786c948ace Mon Sep 17 00:00:00 2001 From: jkobject-tower Date: Tue, 12 Aug 2025 15:52:45 +0200 Subject: [PATCH 15/22] update scprint --- src/methods/scprint/config.vsh.yaml | 5 ++-- src/methods/scprint/script.py | 45 ++++++++++++++++++++++------- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index caf6ca51..6960b78a 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -74,9 +74,7 @@ engines: setup: - type: python pip: - - git+https://github.com/cantinilab/scPRINT.git@d8cc270b099c8d5dacf6913acc26f2b696685b2b - - gseapy==1.1.2 - - git+https://github.com/jkobject/scDataLoader.git@c67c24a2e5c62399912be39169aae76e29e108aa + - scprint==2.3.5 - type: docker run: lamin init --storage ./main --name main --schema bionty - type: docker @@ -87,6 +85,7 @@ engines: script: from scdataloader.utils import populate_my_ontology; populate_my_ontology() runners: - type: executable + # docker_run_args: --gpus all - type: nextflow directives: label: [hightime, highmem, midcpu, gpu, highsharedmem] diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index d57e839b..b090b72a 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -6,6 +6,7 @@ import torch from huggingface_hub import hf_hub_download from scdataloader import Preprocessor +from scdataloader.utils import load_genes from scprint import scPrint from scprint.tasks import Embedder @@ -63,35 +64,56 @@ repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt" ) +print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) + if torch.cuda.is_available(): print("CUDA is available, using GPU", flush=True) - precision = "16" transformer = "flash" else: print("CUDA is not available, using CPU", flush=True) - precision = "32" transformer = "normal" -print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) - -m = torch.load(model_checkpoint_file, map_location=torch.device("cpu")) +# make sure that you check if you have a GPU with flashattention or not (see README) +try: + m = torch.load(model_checkpoint_file) +# if not use this instead since the model weights are by default mapped to GPU types +except RuntimeError: + m = torch.load(model_checkpoint_file, map_location=torch.device("cpu")) + +# both are for compatibility issues with different versions of the pretrained model, so we need to load it with the correct transformer +if "prenorm" in m["hyper_parameters"]: + m["hyper_parameters"].pop("prenorm") + torch.save(m, model_checkpoint_file) if "label_counts" in m["hyper_parameters"]: + # you need to set precpt_gene_emb=None otherwise the model will look for its precomputed gene embeddings files although they were already converted into model weights, so you don't need this file for a pretrained model model = scPrint.load_from_checkpoint( model_checkpoint_file, - transformer=transformer, # Don't use this for GPUs with flashattention precpt_gene_emb=None, classes=m["hyper_parameters"]["label_counts"], + transformer=transformer, ) else: model = scPrint.load_from_checkpoint( - model_checkpoint_file, - transformer=transformer, # Don't use this for GPUs with flashattention - precpt_gene_emb=None, + model_checkpoint_file, precpt_gene_emb=None, transformer=transformer ) del m +# this might happen if you have a model that was trained with a different set of genes than the one you are using in the ontology (e.g. newer ontologies), While having genes in the onlogy not in the model is fine. the opposite is not, so we need to remove the genes that are in the model but not in the ontology +missing = set(model.genes) - set(load_genes(model.organisms).index) +if len(missing) > 0: + print( + "Warning: some genes missmatch exist between model and ontology: solving...", + ) + model._rm_genes(missing) + +# again if not on GPU you need to convert the model to float64 +if not torch.cuda.is_available(): + model = model.to(torch.float32) + +# you can perform your inference on float16 if you have a GPU, otherwise use float64 +dtype = torch.float16 if torch.cuda.is_available() else torch.float32 -if model.device == "cpu" and torch.cuda.is_available(): - model = model.to("cuda") +# the models are often loaded with some parts still displayed as "cuda" and some as "cpu", so we need to make sure that the model is fully on the right device +model = model.to("cuda" if torch.cuda.is_available() else "cpu") print("\n>>> Embedding data...", flush=True) n_cores = min(len(os.sched_getaffinity(0)), 24) @@ -108,6 +130,7 @@ keep_all_cls_pred=False, output_expression="none", save_every=30_000, + dtype=dtype, ) embedded, _ = embedder(model, adata, cache=False) From e93bd4644f53efb0c91b2c10a05794363c3c3a30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Kalfon?= Date: Fri, 29 Aug 2025 15:59:13 +0200 Subject: [PATCH 16/22] Update src/methods/scprint/script.py Co-authored-by: Luke Zappia --- src/methods/scprint/script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index b090b72a..7f788767 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -105,7 +105,7 @@ ) model._rm_genes(missing) -# again if not on GPU you need to convert the model to float64 +# again if not on GPU you need to convert the model to float32 if not torch.cuda.is_available(): model = model.to(torch.float32) From 0a80d00a782fef38606b875c6f02255d8d2269c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Kalfon?= Date: Thu, 4 Sep 2025 18:00:07 +0200 Subject: [PATCH 17/22] Update src/methods/scprint/script.py --- src/methods/scprint/script.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 7f788767..46297afa 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -73,7 +73,6 @@ print("CUDA is not available, using CPU", flush=True) transformer = "normal" -# make sure that you check if you have a GPU with flashattention or not (see README) try: m = torch.load(model_checkpoint_file) # if not use this instead since the model weights are by default mapped to GPU types From cbac6ce4cc823aeb392061907f0886acc1383b98 Mon Sep 17 00:00:00 2001 From: jkobject-tower Date: Mon, 29 Sep 2025 10:57:04 +0200 Subject: [PATCH 18/22] improve the scgpt installation (now uses flash attention) --- src/methods/scgpt_finetuned/config.vsh.yaml | 19 +++++++++++++------ src/methods/scgpt_zeroshot/config.vsh.yaml | 19 +++++++++++++------ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/methods/scgpt_finetuned/config.vsh.yaml b/src/methods/scgpt_finetuned/config.vsh.yaml index 20760aa3..2b949cbb 100644 --- a/src/methods/scgpt_finetuned/config.vsh.yaml +++ b/src/methods/scgpt_finetuned/config.vsh.yaml @@ -51,13 +51,20 @@ engines: image: openproblems/base_pytorch_nvidia:1 # TODO: Try to find working installation of flash attention (flash-attn<1.0.5) setup: - - type: python - pypi: - - gdown - - scgpt # Install from PyPI to get dependencies + #- type: python + # pypi: + # - gdown + # - scgpt # Install from PyPI to get dependencies + #- type: docker + # # Force re-installing from GitHub to get bug fixes + # run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git - type: docker - # Force re-installing from GitHub to get bug fixes - run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git + run: | + git clone https://github.com/bowang-lab/scGPT && \ + pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \ + pip install "flash-attn<1.0.5" --no-build-isolation && \ + pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \ + cd scGPT && pip install -e . --no-deps runners: - type: executable diff --git a/src/methods/scgpt_zeroshot/config.vsh.yaml b/src/methods/scgpt_zeroshot/config.vsh.yaml index ba2455c6..3ff6425c 100644 --- a/src/methods/scgpt_zeroshot/config.vsh.yaml +++ b/src/methods/scgpt_zeroshot/config.vsh.yaml @@ -53,13 +53,20 @@ engines: image: openproblems/base_pytorch_nvidia:1 # TODO: Try to find working installation of flash attention (flash-attn<1.0.5) setup: - - type: python - pypi: - - gdown - - scgpt # Install from PyPI to get dependencies + #- type: python + # pypi: + # - gdown + # - scgpt # Install from PyPI to get dependencies + #- type: docker + # # Force re-installing from GitHub to get bug fixes + # run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git - type: docker - # Force re-installing from GitHub to get bug fixes - run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git + run: | + git clone https://github.com/bowang-lab/scGPT && \ + pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \ + pip install "flash-attn<1.0.5" --no-build-isolation && \ + pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \ + cd scGPT && pip install -e . --no-deps runners: - type: executable From a9b80b90bb055a9893b35cc3f5b529dced39299b Mon Sep 17 00:00:00 2001 From: jkobject-tower Date: Mon, 13 Oct 2025 13:21:25 +0200 Subject: [PATCH 19/22] changing default parameters --- src/methods/scprint/config.vsh.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 0e3020e5..0a9a6c7d 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -61,7 +61,7 @@ arguments: - name: --max_len type: integer description: The maximum length of the gene sequence. - default: 4000 + default: 2300 resources: - type: python_script @@ -86,7 +86,7 @@ engines: script: from scdataloader.utils import populate_my_ontology; populate_my_ontology() runners: - type: executable - # docker_run_args: --gpus all + # docker_run_args: --gpus all - type: nextflow directives: label: [hightime, highmem, midcpu, gpu, highsharedmem] From 6215ffabac0486b816f56dc0df78e318dd2f4631 Mon Sep 17 00:00:00 2001 From: jkobject-tower Date: Mon, 13 Oct 2025 13:37:53 +0200 Subject: [PATCH 20/22] rm scgpt --- src/methods/scgpt_finetuned/config.vsh.yaml | 19 ++++++------------- src/methods/scgpt_zeroshot/config.vsh.yaml | 19 ++++++------------- 2 files changed, 12 insertions(+), 26 deletions(-) diff --git a/src/methods/scgpt_finetuned/config.vsh.yaml b/src/methods/scgpt_finetuned/config.vsh.yaml index 2b949cbb..20760aa3 100644 --- a/src/methods/scgpt_finetuned/config.vsh.yaml +++ b/src/methods/scgpt_finetuned/config.vsh.yaml @@ -51,20 +51,13 @@ engines: image: openproblems/base_pytorch_nvidia:1 # TODO: Try to find working installation of flash attention (flash-attn<1.0.5) setup: - #- type: python - # pypi: - # - gdown - # - scgpt # Install from PyPI to get dependencies - #- type: docker - # # Force re-installing from GitHub to get bug fixes - # run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git + - type: python + pypi: + - gdown + - scgpt # Install from PyPI to get dependencies - type: docker - run: | - git clone https://github.com/bowang-lab/scGPT && \ - pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \ - pip install "flash-attn<1.0.5" --no-build-isolation && \ - pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \ - cd scGPT && pip install -e . --no-deps + # Force re-installing from GitHub to get bug fixes + run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git runners: - type: executable diff --git a/src/methods/scgpt_zeroshot/config.vsh.yaml b/src/methods/scgpt_zeroshot/config.vsh.yaml index 3ff6425c..ba2455c6 100644 --- a/src/methods/scgpt_zeroshot/config.vsh.yaml +++ b/src/methods/scgpt_zeroshot/config.vsh.yaml @@ -53,20 +53,13 @@ engines: image: openproblems/base_pytorch_nvidia:1 # TODO: Try to find working installation of flash attention (flash-attn<1.0.5) setup: - #- type: python - # pypi: - # - gdown - # - scgpt # Install from PyPI to get dependencies - #- type: docker - # # Force re-installing from GitHub to get bug fixes - # run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git + - type: python + pypi: + - gdown + - scgpt # Install from PyPI to get dependencies - type: docker - run: | - git clone https://github.com/bowang-lab/scGPT && \ - pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \ - pip install "flash-attn<1.0.5" --no-build-isolation && \ - pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \ - cd scGPT && pip install -e . --no-deps + # Force re-installing from GitHub to get bug fixes + run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git runners: - type: executable From 773e5536aeb53d72cb6ac24b6222983936155dfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Kalfon?= Date: Thu, 11 Dec 2025 17:12:15 +0100 Subject: [PATCH 21/22] Update script.py --- src/methods/scprint/script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 46297afa..b26aa79f 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -14,7 +14,7 @@ par = { "input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad", "output": "output.h5ad", - "model_name": "v2-medium", + "model_name": "medium-v1.5", "model": None, } meta = {"name": "scprint"} From 6994abab140890c2d3841264b731d5eee2f72519 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Kalfon?= Date: Thu, 11 Dec 2025 17:14:53 +0100 Subject: [PATCH 22/22] Update model names in config.vsh.yaml --- src/methods/scprint/config.vsh.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index cbb968df..cd10ef6c 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -33,14 +33,14 @@ info: method_types: [embedding] variants: scprint_large: - model_name: "large" + model_name: "large-v1" scprint_medium: - model_name: "v2-medium" + model_name: "medium-v1.5" scprint_small: - model_name: "small" + model_name: "small-v1" test_setup: run: - model_name: small + model_name: small-v1 batch_size: 16 max_len: 100 @@ -48,8 +48,8 @@ arguments: - name: "--model_name" type: "string" description: Which model to use. Not used if --model is provided. - choices: ["large", "v2-medium", "small"] - default: "v2-medium" + choices: ["large-v1", "medium-v1.5", "small-v1"] + default: "medium-v1.5" - name: --model type: file description: Path to the scPRINT model.