diff --git a/scripts/run_benchmark/run_full_local.sh b/scripts/run_benchmark/run_full_local.sh index 20e434b3..b60940c9 100755 --- a/scripts/run_benchmark/run_full_local.sh +++ b/scripts/run_benchmark/run_full_local.sh @@ -26,7 +26,7 @@ input_states: resources/datasets/**/state.yaml rename_keys: 'input_dataset:output_dataset;input_solution:output_solution' output_state: "state.yaml" publish_dir: "$publish_dir" -settings: '{"methods_exclude": ["uce", "scgpt_finetuned"]}' +settings: '{"methods_exclude": ["uce", "scgpt_finetuned", "transcriptformer_mlflow"]}' HERE # run the benchmark diff --git a/scripts/run_benchmark/run_test_local.sh b/scripts/run_benchmark/run_test_local.sh index 85e39583..4b7bf15e 100755 --- a/scripts/run_benchmark/run_test_local.sh +++ b/scripts/run_benchmark/run_test_local.sh @@ -21,7 +21,7 @@ input_states: resources_test/task_batch_integration/**/state.yaml rename_keys: 'input_dataset:output_dataset;input_solution:output_solution' output_state: "state.yaml" publish_dir: "$publish_dir" -settings: '{"methods_exclude": ["uce", "scgpt_finetuned"]}' +settings: '{"methods_exclude": ["uce", "scgpt_finetuned", "transcriptformer_mlflow"]}' HERE nextflow run . \ diff --git a/src/methods/geneformer_mlflow/config.vsh.yaml b/src/methods/geneformer_mlflow/config.vsh.yaml new file mode 100644 index 00000000..e1d187cf --- /dev/null +++ b/src/methods/geneformer_mlflow/config.vsh.yaml @@ -0,0 +1,51 @@ +__merge__: ../../api/base_method.yaml + +name: geneformer_mlflow +label: Geneformer (MLflow model) +summary: Geneformer is a foundation transformer model pretrained on a large-scale corpus of single cell transcriptomes +description: | + Geneformer is a foundation transformer model pretrained on a large-scale + corpus of single cell transcriptomes to enable context-aware predictions in + network biology. For this task, Geneformer is used to create a batch-corrected + cell embedding. + + Here, we use a version packaged as an MLflow model. +references: + doi: + - 10.1038/s41586-023-06139-9 + - 10.1101/2024.08.16.608180 +links: + documentation: https://geneformer.readthedocs.io/en/latest/index.html + repository: https://huggingface.co/ctheodoris/Geneformer + +info: + method_types: [embedding] + preferred_normalization: counts + +arguments: + - name: --model + type: file + description: | + An MLflow model URI for the Geneformer model. If it is a .zip or + .tar.gz file it will be extracted to a temporary directory. 
+ required: true + +resources: + - type: python_script + path: script.py + - path: /src/utils/read_anndata_partial.py + - path: /src/utils/exit_codes.py + - path: /src/utils/unpack.py + - path: /src/utils/mlflow.py + - path: requirements.txt + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1 + __merge__: /src/utils/mlflow_docker_setup.yaml + +runners: + - type: executable + - type: nextflow + directives: + label: [hightime, highmem, midcpu, gpu] diff --git a/src/methods/geneformer_mlflow/requirements.txt b/src/methods/geneformer_mlflow/requirements.txt new file mode 100644 index 00000000..21bec26b --- /dev/null +++ b/src/methods/geneformer_mlflow/requirements.txt @@ -0,0 +1,540 @@ +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# pip-compile --output-file=/tmp/tmpmz65ifid/requirements_pip_final.txt requirements.in +# +absl-py==2.3.1 + # via tensorboard +accelerate==1.10.0 + # via peft +accumulation-tree==0.6.4 + # via tdigest +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.15 + # via fsspec +aiosignal==1.4.0 + # via aiohttp +alembic==1.16.4 + # via + # mlflow + # optuna +anndata==0.10.9 + # via + # -r requirements.in + # geneformer + # scanpy +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via omegaconf +anyio==4.10.0 + # via starlette +array-api-compat==1.12.0 + # via anndata +attrs==25.3.0 + # via + # aiohttp + # jsonschema + # referencing +blinker==1.9.0 + # via flask +cachetools==5.5.2 + # via + # google-auth + # mlflow-skinny +certifi==2025.8.3 + # via requests +charset-normalizer==3.4.3 + # via requests +click==8.2.1 + # via + # flask + # loompy + # mlflow-skinny + # ray + # uvicorn +cloudpickle==3.1.1 + # via mlflow-skinny +colorlog==6.9.0 + # via optuna +contourpy==1.3.3 + # via matplotlib +cycler==0.12.1 + # via matplotlib +databricks-sdk==0.62.0 + # via mlflow-skinny +datasets==4.0.0 + # via geneformer +dill==0.3.8 + # via + # datasets + # multiprocess +docker==7.1.0 + # via mlflow +fastapi==0.116.1 + # via mlflow-skinny +filelock==3.18.0 + # via + # datasets + # huggingface-hub + # ray + # torch + # transformers +flask==3.1.1 + # via mlflow +fonttools==4.59.0 + # via matplotlib +frozenlist==1.7.0 + # via + # aiohttp + # aiosignal +fsspec[http]==2025.3.0 + # via + # datasets + # huggingface-hub + # torch +geneformer @ git+https://huggingface.co/ctheodoris/Geneformer@69e6887 + # via -r requirements.in +gitdb==4.0.12 + # via gitpython +gitpython==3.1.45 + # via mlflow-skinny +google-auth==2.40.3 + # via databricks-sdk +graphene==3.4.3 + # via mlflow +graphql-core==3.2.6 + # via + # graphene + # graphql-relay +graphql-relay==3.2.0 + # via graphene +greenlet==3.2.4 + # via sqlalchemy +grpcio==1.74.0 + # via tensorboard +gunicorn==23.0.0 + # via mlflow +h11==0.16.0 + # via uvicorn +h5py==3.14.0 + # via + # anndata + # loompy + # scanpy +hf-xet==1.1.7 + # via huggingface-hub +huggingface-hub==0.34.4 + # via + # accelerate + # datasets + # peft + # tokenizers + # transformers +idna==3.10 + # via + # anyio + # requests + # yarl +importlib-metadata==8.7.0 + # via + # mlflow-skinny + # opentelemetry-api +itsdangerous==2.2.0 + # via flask +jinja2==3.1.6 + # via + # flask + # torch +joblib==1.5.1 + # via + # pynndescent + # scanpy + # scikit-learn +jsonschema==4.25.0 + # via ray +jsonschema-specifications==2025.4.1 + # via jsonschema +kiwisolver==1.4.9 + # via matplotlib +legacy-api-wrap==1.4.1 + # via scanpy +llvmlite==0.44.0 + # via + # numba + # pynndescent +loompy==3.0.8 + # via 
geneformer +mako==1.3.10 + # via alembic +markdown==3.8.2 + # via tensorboard +markupsafe==3.0.2 + # via + # flask + # jinja2 + # mako + # werkzeug +matplotlib==3.10.5 + # via + # geneformer + # mlflow + # scanpy + # seaborn +mlflow==3.1.0 + # via -r requirements.in +mlflow-skinny==3.1.0 + # via mlflow +mpmath==1.3.0 + # via sympy +msgpack==1.1.1 + # via ray +multidict==6.6.4 + # via + # aiohttp + # yarl +multiprocess==0.70.16 + # via datasets +natsort==8.4.0 + # via + # anndata + # scanpy +networkx==3.5 + # via + # scanpy + # torch +numba==0.61.2 + # via + # loompy + # pynndescent + # scanpy + # umap-learn +numpy==2.2.6 + # via + # accelerate + # anndata + # contourpy + # datasets + # geneformer + # h5py + # loompy + # matplotlib + # mlflow + # numba + # numpy-groupies + # optuna + # pandas + # patsy + # peft + # scanpy + # scikit-learn + # scipy + # seaborn + # statsmodels + # tensorboard + # transformers + # umap-learn +numpy-groupies==0.11.3 + # via loompy +nvidia-cublas-cu12==12.8.4.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.8.90 + # via torch +nvidia-cuda-nvrtc-cu12==12.8.93 + # via torch +nvidia-cuda-runtime-cu12==12.8.90 + # via torch +nvidia-cudnn-cu12==9.10.2.21 + # via torch +nvidia-cufft-cu12==11.3.3.83 + # via torch +nvidia-cufile-cu12==1.13.1.3 + # via torch +nvidia-curand-cu12==10.3.9.90 + # via torch +nvidia-cusolver-cu12==11.7.3.90 + # via torch +nvidia-cusparse-cu12==12.5.8.93 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-cusparselt-cu12==0.7.1 + # via torch +nvidia-nccl-cu12==2.27.3 + # via torch +nvidia-nvjitlink-cu12==12.8.93 + # via + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvtx-cu12==12.8.90 + # via torch +omegaconf==2.3.0 + # via -r requirements.in +opentelemetry-api==1.36.0 + # via + # mlflow-skinny + # opentelemetry-sdk + # opentelemetry-semantic-conventions +opentelemetry-sdk==1.36.0 + # via mlflow-skinny +opentelemetry-semantic-conventions==0.57b0 + # via opentelemetry-sdk +optuna==4.4.0 + # via + # geneformer + # optuna-integration +optuna-integration==4.4.0 + # via geneformer +packaging==25.0 + # via + # accelerate + # anndata + # datasets + # geneformer + # gunicorn + # huggingface-hub + # matplotlib + # mlflow-skinny + # optuna + # peft + # ray + # scanpy + # statsmodels + # tensorboard + # transformers +pandas==2.3.1 + # via + # anndata + # datasets + # geneformer + # mlflow + # scanpy + # seaborn + # statsmodels +patsy==1.0.1 + # via + # scanpy + # statsmodels +peft==0.17.0 + # via geneformer +pillow==11.3.0 + # via + # matplotlib + # tensorboard +propcache==0.3.2 + # via + # aiohttp + # yarl +protobuf==6.31.1 + # via + # mlflow-skinny + # ray + # tensorboard +psutil==7.0.0 + # via + # accelerate + # peft +pyarrow==20.0.0 + # via + # datasets + # geneformer + # mlflow +pyasn1==0.6.1 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pydantic==2.11.7 + # via + # fastapi + # mlflow-skinny +pydantic-core==2.33.2 + # via pydantic +pynndescent==0.5.13 + # via + # scanpy + # umap-learn +pyparsing==3.2.3 + # via matplotlib +python-dateutil==2.9.0.post0 + # via + # graphene + # matplotlib + # pandas +pytz==2025.2 + # via + # geneformer + # pandas +pyudorandom==1.0.0 + # via tdigest +pyyaml==6.0.2 + # via + # accelerate + # datasets + # huggingface-hub + # mlflow-skinny + # omegaconf + # optuna + # peft + # ray + # transformers +ray==2.48.0 + # via geneformer +referencing==0.36.2 + # via + # jsonschema + # jsonschema-specifications 
+regex==2025.7.34 + # via transformers +requests==2.32.4 + # via + # databricks-sdk + # datasets + # docker + # huggingface-hub + # mlflow-skinny + # ray + # transformers +rpds-py==0.27.0 + # via + # jsonschema + # referencing +rsa==4.9.1 + # via google-auth +safetensors==0.6.2 + # via + # accelerate + # peft + # transformers +scanpy==1.11.4 + # via geneformer +scikit-learn==1.7.1 + # via + # geneformer + # mlflow + # pynndescent + # scanpy + # umap-learn +scipy==1.16.1 + # via + # anndata + # geneformer + # loompy + # mlflow + # pynndescent + # scanpy + # scikit-learn + # statsmodels + # umap-learn +seaborn==0.13.2 + # via + # geneformer + # scanpy +session-info2==0.2 + # via scanpy +six==1.17.0 + # via python-dateutil +smmap==5.0.2 + # via gitdb +sniffio==1.3.1 + # via anyio +sqlalchemy==2.0.43 + # via + # alembic + # mlflow + # optuna +sqlparse==0.5.3 + # via mlflow-skinny +starlette==0.47.2 + # via fastapi +statsmodels==0.14.5 + # via + # geneformer + # scanpy +sympy==1.14.0 + # via torch +tdigest==0.5.2.2 + # via geneformer +tensorboard==2.20.0 + # via geneformer +tensorboard-data-server==0.7.2 + # via tensorboard +threadpoolctl==3.6.0 + # via scikit-learn +tokenizers==0.21.4 + # via transformers +torch==2.8.0 + # via + # accelerate + # geneformer + # peft +tqdm==4.67.1 + # via + # datasets + # geneformer + # huggingface-hub + # optuna + # peft + # scanpy + # transformers + # umap-learn +transformers==4.49.0 + # via + # -r requirements.in + # geneformer + # peft +triton==3.4.0 + # via torch +typing-extensions==4.14.1 + # via + # aiosignal + # alembic + # anyio + # fastapi + # graphene + # huggingface-hub + # mlflow-skinny + # opentelemetry-api + # opentelemetry-sdk + # opentelemetry-semantic-conventions + # pydantic + # pydantic-core + # referencing + # scanpy + # sqlalchemy + # starlette + # torch + # typing-inspection +typing-inspection==0.4.1 + # via pydantic +tzdata==2025.2 + # via pandas +umap-learn==0.5.9.post2 + # via scanpy +urllib3==2.5.0 + # via + # docker + # requests +uvicorn==0.35.0 + # via mlflow-skinny +werkzeug==3.1.3 + # via + # flask + # tensorboard +xxhash==3.5.0 + # via datasets +yarl==1.20.1 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/src/methods/geneformer_mlflow/script.py b/src/methods/geneformer_mlflow/script.py new file mode 100644 index 00000000..dc710040 --- /dev/null +++ b/src/methods/geneformer_mlflow/script.py @@ -0,0 +1,89 @@ +import os +import sys + +import anndata as ad +import mlflow.pyfunc +import numpy as np + +## VIASH START +# Note: this section is auto-generated by viash at runtime. To edit it, make changes +# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. 
+par = {
+    "input": "resources_test/.../input.h5ad",
+    "output": "output.h5ad",
+    "model": "resources_test/.../model",
+}
+meta = {"name": "geneformer_mlflow"}
+## VIASH END
+
+sys.path.append(meta["resources_dir"])
+from exit_codes import exit_non_applicable  # noqa: E402
+from mlflow import embed  # noqa: E402
+from read_anndata_partial import read_anndata  # noqa: E402
+from unpack import unpack_directory  # noqa: E402
+
+print("====== Geneformer (MLflow model) ======", flush=True)
+
+print("\n>>> Reading input files...", flush=True)
+print(f"Input H5AD file: '{par['input']}'", flush=True)
+adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")
+
+if adata.uns["dataset_organism"] != "homo_sapiens":
+    exit_non_applicable(
+        f"Geneformer (MLflow) can only be used with human data "
+        f'(dataset_organism == "{adata.uns["dataset_organism"]}")'
+    )
+
+print(adata, flush=True)
+
+print("\n>>> Unpacking model...", flush=True)
+model_dir, model_temp = unpack_directory(par["model"])
+
+print("\n>>> Loading model...", flush=True)
+model = mlflow.pyfunc.load_model(model_dir)
+print(model, flush=True)
+
+n_processors = meta.get("cpus") or os.cpu_count()
+print(f"Available processors: {n_processors}", flush=True)
+
+
+def process_geneformer_input(input_adata):
+    """Add Geneformer-specific fields to input AnnData."""
+    input_adata.obs["cell_idx"] = np.arange(input_adata.n_obs)
+    input_adata.obs["n_counts"] = input_adata.X.sum(axis=1)
+
+
+print("\n>>> Embedding data...", flush=True)
+embedding = embed(
+    adata,
+    model,
+    layers=["counts"],
+    var={"feature_id": "ensembl_id"},
+    model_params={"nproc": n_processors},
+    process_adata=process_geneformer_input,
+)
+
+print("\n>>> Storing output...", flush=True)
+output = ad.AnnData(
+    obs=adata.obs[[]],
+    var=adata.var[[]],
+    obsm={
+        "X_emb": embedding,
+    },
+    uns={
+        "dataset_id": adata.uns["dataset_id"],
+        "normalization_id": adata.uns["normalization_id"],
+        "method_id": meta["name"],
+    },
+)
+print(output)
+
+print("\n>>> Writing output to file...", flush=True)
+print(f"Output H5AD file: '{par['output']}'", flush=True)
+output.write_h5ad(par["output"], compression="gzip")
+
+print("\n>>> Cleaning up temporary files...", flush=True)
+if model_temp is not None:
+    model_temp.cleanup()
+
+print("\n>>> Done!", flush=True)
diff --git a/src/methods/scgpt_mlflow/config.vsh.yaml b/src/methods/scgpt_mlflow/config.vsh.yaml
new file mode 100644
index 00000000..d684085e
--- /dev/null
+++ b/src/methods/scgpt_mlflow/config.vsh.yaml
@@ -0,0 +1,48 @@
+__merge__: ../../api/base_method.yaml
+
+name: scgpt_mlflow
+label: scGPT (MLflow model)
+summary: A foundation model for single-cell biology
+description: |
+  scGPT is a foundation model for single-cell biology based on a generative
+  pre-trained transformer and trained on a repository of over 33 million cells.
+
+  Here, we use a version packaged as an MLflow model.
+references:
+  doi:
+    - 10.1038/s41592-024-02201-0
+links:
+  documentation: https://scgpt.readthedocs.io/en/latest/
+  repository: https://github.com/bowang-lab/scGPT
+
+info:
+  method_types: [embedding]
+  preferred_normalization: counts
+
+arguments:
+  - name: --model
+    type: file
+    description: |
+      An MLflow model URI for the scGPT model. If it is a .zip or
+      .tar.gz file it will be extracted to a temporary directory.
+ required: true + +resources: + - type: python_script + path: script.py + - path: /src/utils/read_anndata_partial.py + - path: /src/utils/exit_codes.py + - path: /src/utils/unpack.py + - path: /src/utils/mlflow.py + - path: requirements.txt + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1 + __merge__: /src/utils/mlflow_docker_setup.yaml + +runners: + - type: executable + - type: nextflow + directives: + label: [hightime, highmem, midcpu, gpu] diff --git a/src/methods/scgpt_mlflow/requirements.txt b/src/methods/scgpt_mlflow/requirements.txt new file mode 100644 index 00000000..2ad53dc3 --- /dev/null +++ b/src/methods/scgpt_mlflow/requirements.txt @@ -0,0 +1,684 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile requirements.in -o /tmp/tmp7yfkiop2/requirements_initial.txt +absl-py==2.3.1 + # via + # chex + # ml-collections + # optax + # orbax + # orbax-checkpoint +aiofiles==24.1.0 + # via orbax-checkpoint +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.15 + # via + # datasets + # fsspec +aiosignal==1.4.0 + # via aiohttp +alembic==1.16.4 + # via mlflow +anndata==0.10.9 + # via + # -r requirements.in + # mudata + # scanpy + # scib + # scvi-tools +annotated-types==0.7.0 + # via pydantic +anyio==4.10.0 + # via starlette +array-api-compat==1.12.0 + # via anndata +asttokens==3.0.0 + # via stack-data +async-timeout==5.0.1 + # via aiohttp +attrs==25.3.0 + # via aiohttp +blinker==1.9.0 + # via flask +cached-property==2.0.1 + # via orbax +cachetools==5.5.2 + # via + # google-auth + # mlflow-skinny +cell-gears==0.0.2 + # via scgpt +certifi==2025.8.3 + # via requests +charset-normalizer==3.4.3 + # via requests +chex==0.1.90 + # via + # optax + # scvi-tools +click==8.2.1 + # via + # flask + # mlflow-skinny + # uvicorn +cloudpickle==3.1.1 + # via mlflow-skinny +contourpy==1.3.2 + # via matplotlib +cycler==0.12.1 + # via matplotlib +databricks-sdk==0.62.0 + # via mlflow-skinny +datasets==2.14.4 + # via scgpt +dcor==0.6 + # via cell-gears +decorator==5.2.1 + # via ipython +deprecated==1.2.18 + # via scib +dill==0.3.7 + # via + # datasets + # multiprocess +docker==7.1.0 + # via mlflow +docrep==0.3.2 + # via scvi-tools +et-xmlfile==2.0.0 + # via openpyxl +etils==1.13.0 + # via + # orbax + # orbax-checkpoint +exceptiongroup==1.3.0 + # via + # anndata + # anyio + # ipython +executing==2.2.0 + # via stack-data +fastapi==0.116.1 + # via mlflow-skinny +filelock==3.18.0 + # via + # huggingface-hub + # torch + # triton +flask==3.1.1 + # via mlflow +flax==0.10.7 + # via scvi-tools +fonttools==4.59.0 + # via matplotlib +frozenlist==1.7.0 + # via + # aiohttp + # aiosignal +fsspec==2025.7.0 + # via + # datasets + # etils + # huggingface-hub + # pytorch-lightning + # torch +gitdb==4.0.12 + # via gitpython +gitpython==3.1.45 + # via mlflow-skinny +google-auth==2.40.3 + # via databricks-sdk +graphene==3.4.3 + # via mlflow +graphql-core==3.2.6 + # via + # graphene + # graphql-relay +graphql-relay==3.2.0 + # via graphene +greenlet==3.2.4 + # via sqlalchemy +gunicorn==23.0.0 + # via mlflow +h11==0.16.0 + # via uvicorn +h5py==3.14.0 + # via + # anndata + # scanpy + # scib + # scvi-tools +hf-xet==1.1.7 + # via huggingface-hub +huggingface-hub==0.34.4 + # via datasets +humanize==4.12.3 + # via orbax-checkpoint +idna==3.10 + # via + # anyio + # requests + # yarl +igraph==0.11.9 + # via + # leidenalg + # scib +importlib-metadata==8.7.0 + # via + # mlflow-skinny + # opentelemetry-api +importlib-resources==6.5.2 + # via + # etils + # orbax +ipython==8.27.0 + # via -r 
requirements.in +itsdangerous==2.2.0 + # via flask +jax==0.6.2 + # via + # chex + # flax + # numpyro + # optax + # orbax + # orbax-checkpoint + # scvi-tools +jaxlib==0.6.2 + # via + # chex + # jax + # numpyro + # optax + # orbax + # scvi-tools +jedi==0.19.2 + # via ipython +jinja2==3.1.6 + # via + # flask + # torch +joblib==1.5.1 + # via + # dcor + # pynndescent + # scanpy + # scikit-learn +kiwisolver==1.4.9 + # via matplotlib +legacy-api-wrap==1.4.1 + # via scanpy +leidenalg==0.10.2 + # via + # scgpt + # scib +lightning-utilities==0.15.2 + # via + # pytorch-lightning + # torchmetrics +llvmlite==0.44.0 + # via + # numba + # pynndescent + # scib +mako==1.3.10 + # via alembic +markdown-it-py==4.0.0 + # via rich +markupsafe==3.0.2 + # via + # flask + # jinja2 + # mako + # werkzeug +matplotlib==3.10.5 + # via + # mlflow + # scanpy + # scib + # seaborn +matplotlib-inline==0.1.7 + # via ipython +mdurl==0.1.2 + # via markdown-it-py +ml-collections==1.1.0 + # via scvi-tools +ml-dtypes==0.5.3 + # via + # jax + # jaxlib + # tensorstore +mlflow==3.1.0 + # via -r requirements.in +mlflow-skinny==3.1.0 + # via mlflow +mpmath==1.3.0 + # via sympy +msgpack==1.1.1 + # via + # flax + # orbax + # orbax-checkpoint +mudata==0.3.2 + # via scvi-tools +multidict==6.6.4 + # via + # aiohttp + # yarl +multipledispatch==1.0.0 + # via numpyro +multiprocess==0.70.15 + # via datasets +natsort==8.4.0 + # via + # anndata + # scanpy +nest-asyncio==1.6.0 + # via + # orbax + # orbax-checkpoint +networkx==3.4.2 + # via + # cell-gears + # scanpy + # torch +numba==0.61.2 + # via + # dcor + # pynndescent + # scanpy + # scgpt + # scib + # umap-learn +numpy==1.26.4 + # via + # anndata + # cell-gears + # chex + # contourpy + # datasets + # dcor + # h5py + # jax + # jaxlib + # matplotlib + # ml-dtypes + # mlflow + # numba + # numpyro + # optax + # orbax + # orbax-checkpoint + # pandas + # patsy + # pyro-ppl + # pytorch-lightning + # scanpy + # scib + # scikit-learn + # scikit-misc + # scipy + # scvi-tools + # seaborn + # statsmodels + # tensorstore + # torchmetrics + # torchtext + # treescope + # umap-learn +numpyro==0.19.0 + # via scvi-tools +nvidia-cublas-cu12==12.1.3.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.1.105 + # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 + # via torch +nvidia-cuda-runtime-cu12==12.1.105 + # via torch +nvidia-cudnn-cu12==8.9.2.26 + # via torch +nvidia-cufft-cu12==11.0.2.54 + # via torch +nvidia-curand-cu12==10.3.2.106 + # via torch +nvidia-cusolver-cu12==11.4.5.107 + # via torch +nvidia-cusparse-cu12==12.1.0.106 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-nccl-cu12==2.18.1 + # via torch +nvidia-nvjitlink-cu12==12.9.86 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 + # via torch +openpyxl==3.1.5 + # via scvi-tools +opentelemetry-api==1.36.0 + # via + # mlflow-skinny + # opentelemetry-sdk + # opentelemetry-semantic-conventions +opentelemetry-sdk==1.36.0 + # via mlflow-skinny +opentelemetry-semantic-conventions==0.57b0 + # via opentelemetry-sdk +opt-einsum==3.4.0 + # via + # jax + # pyro-ppl +optax==0.2.5 + # via + # flax + # scvi-tools +orbax==0.1.7 + # via scgpt +orbax-checkpoint==0.11.21 + # via flax +packaging==25.0 + # via + # anndata + # datasets + # gunicorn + # huggingface-hub + # lightning-utilities + # matplotlib + # mlflow-skinny + # pytorch-lightning + # scanpy + # statsmodels + # torchmetrics +pandas==2.3.1 + # via + # anndata + # cell-gears + # datasets + # mlflow + # scanpy + # scgpt + # scib + # 
scvi-tools + # seaborn + # statsmodels +parso==0.8.4 + # via jedi +patsy==1.0.1 + # via + # scanpy + # statsmodels +pexpect==4.9.0 + # via ipython +pillow==11.3.0 + # via matplotlib +prompt-toolkit==3.0.51 + # via ipython +propcache==0.3.2 + # via + # aiohttp + # yarl +protobuf==6.31.1 + # via + # mlflow-skinny + # orbax-checkpoint +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.3 + # via stack-data +pyarrow==20.0.0 + # via + # datasets + # mlflow +pyasn1==0.6.1 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pydantic==2.11.7 + # via + # fastapi + # mlflow-skinny +pydantic-core==2.33.2 + # via pydantic +pydot==4.0.1 + # via scib +pygments==2.19.2 + # via + # ipython + # rich +pynndescent==0.5.13 + # via + # scanpy + # umap-learn +pyparsing==3.2.3 + # via + # matplotlib + # pydot +pyro-api==0.1.2 + # via pyro-ppl +pyro-ppl==1.9.1 + # via scvi-tools +python-dateutil==2.9.0.post0 + # via + # graphene + # matplotlib + # pandas +pytorch-lightning==1.9.5 + # via scvi-tools +pytz==2025.2 + # via pandas +pyyaml==6.0.2 + # via + # datasets + # flax + # huggingface-hub + # ml-collections + # mlflow-skinny + # orbax + # orbax-checkpoint + # pytorch-lightning +requests==2.32.4 + # via + # databricks-sdk + # datasets + # docker + # huggingface-hub + # mlflow-skinny + # torchdata + # torchtext +rich==14.1.0 + # via + # flax + # scvi-tools +rsa==4.9.1 + # via google-auth +scanpy==1.11.4 + # via + # cell-gears + # scgpt + # scib +scgpt==0.2.1 + # via -r requirements.in +scib==1.1.7 + # via scgpt +scikit-learn==1.7.1 + # via + # cell-gears + # mlflow + # pynndescent + # scanpy + # scib + # scvi-tools + # umap-learn +scikit-misc==0.5.1 + # via + # scgpt + # scib +scipy==1.12.0 + # via + # -r requirements.in + # anndata + # dcor + # jax + # jaxlib + # mlflow + # pynndescent + # scanpy + # scib + # scikit-learn + # scvi-tools + # statsmodels + # umap-learn +scvi-tools==0.20.3 + # via scgpt +seaborn==0.13.2 + # via + # scanpy + # scib +session-info2==0.2 + # via scanpy +setuptools==80.9.0 + # via lightning-utilities +simplejson==3.20.1 + # via orbax-checkpoint +six==1.17.0 + # via + # docrep + # python-dateutil +smmap==5.0.2 + # via gitdb +sniffio==1.3.1 + # via anyio +sqlalchemy==2.0.43 + # via + # alembic + # mlflow +sqlparse==0.5.3 + # via mlflow-skinny +stack-data==0.6.3 + # via ipython +starlette==0.47.2 + # via fastapi +statsmodels==0.14.5 + # via scanpy +sympy==1.14.0 + # via torch +tensorstore==0.1.76 + # via + # flax + # orbax + # orbax-checkpoint +texttable==1.7.0 + # via igraph +threadpoolctl==3.6.0 + # via scikit-learn +tomli==2.2.1 + # via alembic +toolz==1.0.0 + # via chex +torch==2.1.2 + # via + # cell-gears + # pyro-ppl + # pytorch-lightning + # scgpt + # scvi-tools + # torchdata + # torchmetrics + # torchtext +torchdata==0.7.1 + # via torchtext +torchmetrics==1.8.1 + # via + # pytorch-lightning + # scvi-tools +torchtext==0.16.2 + # via scgpt +tqdm==4.67.1 + # via + # cell-gears + # datasets + # huggingface-hub + # numpyro + # pyro-ppl + # pytorch-lightning + # scanpy + # scvi-tools + # torchtext + # umap-learn +traitlets==5.14.3 + # via + # ipython + # matplotlib-inline +treescope==0.1.10 + # via flax +triton==2.1.0 + # via torch +typing-extensions==4.14.1 + # via + # aiosignal + # alembic + # anyio + # chex + # etils + # exceptiongroup + # fastapi + # flax + # graphene + # huggingface-hub + # ipython + # lightning-utilities + # mlflow-skinny + # multidict + # opentelemetry-api + # opentelemetry-sdk + # opentelemetry-semantic-conventions + # orbax + # 
orbax-checkpoint
+    #   pydantic
+    #   pydantic-core
+    #   pytorch-lightning
+    #   scanpy
+    #   scgpt
+    #   sqlalchemy
+    #   starlette
+    #   torch
+    #   typing-inspection
+    #   uvicorn
+typing-inspection==0.4.1
+    # via pydantic
+tzdata==2025.2
+    # via pandas
+umap-learn==0.5.9.post2
+    # via
+    #   scanpy
+    #   scgpt
+    #   scib
+urllib3==2.5.0
+    # via
+    #   docker
+    #   requests
+    #   torchdata
+uvicorn==0.35.0
+    # via mlflow-skinny
+wcwidth==0.2.13
+    # via prompt-toolkit
+werkzeug==3.1.3
+    # via flask
+wrapt==1.17.3
+    # via deprecated
+xxhash==3.5.0
+    # via datasets
+yarl==1.20.1
+    # via aiohttp
+zipp==3.23.0
+    # via
+    #   etils
+    #   importlib-metadata
diff --git a/src/methods/scgpt_mlflow/script.py b/src/methods/scgpt_mlflow/script.py
new file mode 100644
index 00000000..db54fd23
--- /dev/null
+++ b/src/methods/scgpt_mlflow/script.py
@@ -0,0 +1,76 @@
+import sys
+
+import anndata as ad
+import mlflow.pyfunc
+
+## VIASH START
+# Note: this section is auto-generated by viash at runtime. To edit it, make changes
+# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
+par = {
+    "input": "resources_test/.../input.h5ad",
+    "output": "output.h5ad",
+    "model": "resources_test/.../model",
+}
+meta = {"name": "scgpt_mlflow"}
+## VIASH END
+
+sys.path.append(meta["resources_dir"])
+from exit_codes import exit_non_applicable  # noqa: E402
+from mlflow import embed  # noqa: E402
+from read_anndata_partial import read_anndata  # noqa: E402
+from unpack import unpack_directory  # noqa: E402
+
+print("====== scGPT (MLflow model) ======", flush=True)
+
+print("\n>>> Reading input files...", flush=True)
+print(f"Input H5AD file: '{par['input']}'", flush=True)
+adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")
+
+if adata.uns["dataset_organism"] != "homo_sapiens":
+    exit_non_applicable(
+        f"scGPT (MLflow) can only be used with human data "
+        f'(dataset_organism == "{adata.uns["dataset_organism"]}")'
+    )
+
+print(adata, flush=True)
+
+print("\n>>> Unpacking model...", flush=True)
+model_dir, model_temp = unpack_directory(par["model"])
+
+print("\n>>> Loading model...", flush=True)
+model = mlflow.pyfunc.load_model(model_dir)
+print(model, flush=True)
+
+print("\n>>> Embedding data...", flush=True)
+embedding = embed(
+    adata,
+    model,
+    layers=["counts"],
+    var={"feature_name": "feature_name"},
+    model_params={"gene_col": "feature_name"},
+)
+
+print("\n>>> Storing output...", flush=True)
+output = ad.AnnData(
+    obs=adata.obs[[]],
+    var=adata.var[[]],
+    obsm={
+        "X_emb": embedding,
+    },
+    uns={
+        "dataset_id": adata.uns["dataset_id"],
+        "normalization_id": adata.uns["normalization_id"],
+        "method_id": meta["name"],
+    },
+)
+print(output)
+
+print("\n>>> Writing output to file...", flush=True)
+print(f"Output H5AD file: '{par['output']}'", flush=True)
+output.write_h5ad(par["output"], compression="gzip")
+
+print("\n>>> Cleaning up temporary files...", flush=True)
+if model_temp is not None:
+    model_temp.cleanup()
+
+print("\n>>> Done!", flush=True)
diff --git a/src/methods/scvi_mlflow/config.vsh.yaml b/src/methods/scvi_mlflow/config.vsh.yaml
new file mode 100644
index 00000000..85b6520f
--- /dev/null
+++ b/src/methods/scvi_mlflow/config.vsh.yaml
@@ -0,0 +1,49 @@
+__merge__: ../../api/base_method.yaml
+
+name: scvi_mlflow
+label: scVI (MLflow model)
+summary: scVI combines a variational autoencoder with a hierarchical Bayesian model (MLflow model)
+description: |
+  scVI combines a variational autoencoder with a hierarchical Bayesian model.
+ It uses the negative binomial distribution to describe gene expression of + each cell, conditioned on unobserved factors and the batch variable. + + This version uses a pre-trained MLflow model. +references: + doi: + - 10.1038/s41592-018-0229-2 +links: + repository: https://github.com/scverse/scvi-tools + documentation: https://docs.scvi-tools.org/en/stable/user_guide/models/scvi.html + +info: + method_types: [embedding] + preferred_normalization: counts + +arguments: + - name: --model + type: file + description: | + An MLflow model URI for the scVI model. If it is a .zip or + .tar.gz file it will be extracted to a temporary directory. + required: true + +resources: + - type: python_script + path: script.py + - path: /src/utils/read_anndata_partial.py + - path: /src/utils/exit_codes.py + - path: /src/utils/unpack.py + - path: /src/utils/mlflow.py + - path: requirements.txt + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1 + __merge__: /src/utils/mlflow_docker_setup.yaml + +runners: + - type: executable + - type: nextflow + directives: + label: [hightime, highmem, midcpu, gpu] diff --git a/src/methods/scvi_mlflow/requirements.txt b/src/methods/scvi_mlflow/requirements.txt new file mode 100644 index 00000000..c3c79df5 --- /dev/null +++ b/src/methods/scvi_mlflow/requirements.txt @@ -0,0 +1,459 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile requirements.in -o /tmp/tmp6b02zuzi/requirements_initial.txt +absl-py==2.3.1 + # via + # chex + # ml-collections + # optax + # orbax-checkpoint +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.15 + # via fsspec +aiosignal==1.4.0 + # via aiohttp +alembic==1.16.4 + # via mlflow +anndata==0.10.8 + # via + # -r requirements.in + # mudata + # scvi-tools +annotated-types==0.7.0 + # via pydantic +anyio==4.10.0 + # via starlette +array-api-compat==1.12.0 + # via anndata +attrs==25.3.0 + # via aiohttp +blinker==1.9.0 + # via flask +cachetools==5.5.2 + # via + # google-auth + # mlflow-skinny +certifi==2025.8.3 + # via requests +charset-normalizer==3.4.3 + # via requests +chex==0.1.90 + # via optax +click==8.2.1 + # via + # flask + # mlflow-skinny + # uvicorn +cloudpickle==3.1.1 + # via mlflow-skinny +contourpy==1.3.3 + # via matplotlib +cycler==0.12.1 + # via matplotlib +databricks-sdk==0.62.0 + # via mlflow-skinny +docker==7.1.0 + # via mlflow +docrep==0.3.2 + # via scvi-tools +etils==1.13.0 + # via orbax-checkpoint +fastapi==0.116.1 + # via mlflow-skinny +filelock==3.18.0 + # via + # torch + # triton +flask==3.1.1 + # via mlflow +flax==0.10.4 + # via scvi-tools +fonttools==4.59.0 + # via matplotlib +frozenlist==1.7.0 + # via + # aiohttp + # aiosignal +fsspec==2025.7.0 + # via + # etils + # lightning + # pytorch-lightning + # torch +gitdb==4.0.12 + # via gitpython +gitpython==3.1.45 + # via mlflow-skinny +google-auth==2.40.3 + # via databricks-sdk +graphene==3.4.3 + # via mlflow +graphql-core==3.2.6 + # via + # graphene + # graphql-relay +graphql-relay==3.2.0 + # via graphene +greenlet==3.2.4 + # via sqlalchemy +gunicorn==23.0.0 + # via mlflow +h11==0.16.0 + # via uvicorn +h5py==3.14.0 + # via + # anndata + # scvi-tools +humanize==4.12.3 + # via orbax-checkpoint +idna==3.10 + # via + # anyio + # requests + # yarl +importlib-metadata==8.7.0 + # via + # mlflow-skinny + # opentelemetry-api +importlib-resources==6.5.2 + # via etils +itsdangerous==2.2.0 + # via flask +jax==0.4.33 + # via + # -r requirements.in + # chex + # flax + # numpyro + # optax + # orbax-checkpoint + # scvi-tools +jaxlib==0.4.33 + 
# via + # -r requirements.in + # chex + # jax + # numpyro + # optax + # orbax-checkpoint + # scvi-tools +jinja2==3.1.6 + # via + # flask + # torch +joblib==1.5.1 + # via scikit-learn +kiwisolver==1.4.9 + # via matplotlib +lightning==2.5.2 + # via scvi-tools +lightning-utilities==0.15.2 + # via + # lightning + # pytorch-lightning + # torchmetrics +mako==1.3.10 + # via alembic +markdown-it-py==4.0.0 + # via rich +markupsafe==3.0.2 + # via + # flask + # jinja2 + # mako + # werkzeug +matplotlib==3.10.5 + # via mlflow +mdurl==0.1.2 + # via markdown-it-py +ml-collections==1.1.0 + # via scvi-tools +ml-dtypes==0.5.3 + # via + # jax + # jaxlib + # tensorstore +mlflow==3.1.0 + # via -r requirements.in +mlflow-skinny==3.1.0 + # via mlflow +mpmath==1.3.0 + # via sympy +msgpack==1.1.1 + # via + # flax + # orbax-checkpoint +mudata==0.3.2 + # via scvi-tools +multidict==6.6.4 + # via + # aiohttp + # yarl +multipledispatch==1.0.0 + # via numpyro +natsort==8.4.0 + # via anndata +nest-asyncio==1.6.0 + # via orbax-checkpoint +networkx==3.5 + # via torch +numpy==1.26.4 + # via + # anndata + # chex + # contourpy + # flax + # h5py + # jax + # jaxlib + # matplotlib + # ml-dtypes + # mlflow + # numpyro + # optax + # orbax-checkpoint + # pandas + # pyro-ppl + # scikit-learn + # scipy + # scvi-tools + # tensorstore + # torchmetrics + # treescope +numpyro==0.19.0 + # via scvi-tools +nvidia-cublas-cu12==12.4.5.8 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.4.127 + # via torch +nvidia-cuda-nvrtc-cu12==12.4.127 + # via torch +nvidia-cuda-runtime-cu12==12.4.127 + # via torch +nvidia-cudnn-cu12==9.1.0.70 + # via torch +nvidia-cufft-cu12==11.2.1.3 + # via torch +nvidia-curand-cu12==10.3.5.147 + # via torch +nvidia-cusolver-cu12==11.6.1.9 + # via torch +nvidia-cusparse-cu12==12.3.1.170 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-nccl-cu12==2.21.5 + # via torch +nvidia-nvjitlink-cu12==12.4.127 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvtx-cu12==12.4.127 + # via torch +opentelemetry-api==1.36.0 + # via + # mlflow-skinny + # opentelemetry-sdk + # opentelemetry-semantic-conventions +opentelemetry-sdk==1.36.0 + # via mlflow-skinny +opentelemetry-semantic-conventions==0.57b0 + # via opentelemetry-sdk +opt-einsum==3.4.0 + # via + # jax + # pyro-ppl +optax==0.2.5 + # via + # flax + # scvi-tools +orbax-checkpoint==0.6.4 + # via flax +packaging==25.0 + # via + # anndata + # gunicorn + # lightning + # lightning-utilities + # matplotlib + # mlflow-skinny + # pytorch-lightning + # torchmetrics +pandas==2.2.3 + # via + # -r requirements.in + # anndata + # mlflow + # scvi-tools +pillow==11.3.0 + # via matplotlib +propcache==0.3.2 + # via + # aiohttp + # yarl +protobuf==6.31.1 + # via + # mlflow-skinny + # orbax-checkpoint +pyarrow==20.0.0 + # via mlflow +pyasn1==0.6.1 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pydantic==2.11.7 + # via + # fastapi + # mlflow-skinny +pydantic-core==2.33.2 + # via pydantic +pygments==2.19.2 + # via rich +pyparsing==3.2.3 + # via matplotlib +pyro-api==0.1.2 + # via pyro-ppl +pyro-ppl==1.9.1 + # via scvi-tools +python-dateutil==2.9.0.post0 + # via + # graphene + # matplotlib + # pandas +pytorch-lightning==2.5.2 + # via lightning +pytz==2025.2 + # via pandas +pyyaml==6.0.2 + # via + # flax + # lightning + # ml-collections + # mlflow-skinny + # orbax-checkpoint + # pytorch-lightning +requests==2.32.4 + # via + # databricks-sdk + # docker + # mlflow-skinny +rich==14.1.0 + # via + # flax + 
# scvi-tools +rsa==4.9.1 + # via google-auth +scikit-learn==1.7.1 + # via + # mlflow + # scvi-tools +scipy==1.16.1 + # via + # anndata + # jax + # jaxlib + # mlflow + # scikit-learn + # scvi-tools +scvi-tools==1.1.6.post2 + # via -r requirements.in +setuptools==80.9.0 + # via lightning-utilities +six==1.17.0 + # via + # docrep + # python-dateutil +smmap==5.0.2 + # via gitdb +sniffio==1.3.1 + # via anyio +sqlalchemy==2.0.43 + # via + # alembic + # mlflow +sqlparse==0.5.3 + # via mlflow-skinny +starlette==0.47.2 + # via fastapi +sympy==1.13.1 + # via torch +tensorstore==0.1.76 + # via + # flax + # orbax-checkpoint +threadpoolctl==3.6.0 + # via scikit-learn +toolz==1.0.0 + # via chex +torch==2.5.1 + # via + # -r requirements.in + # lightning + # pyro-ppl + # pytorch-lightning + # scvi-tools + # torchmetrics +torchmetrics==1.8.1 + # via + # lightning + # pytorch-lightning + # scvi-tools +tqdm==4.67.1 + # via + # lightning + # numpyro + # pyro-ppl + # pytorch-lightning + # scvi-tools +treescope==0.1.10 + # via flax +triton==3.1.0 + # via torch +typing-extensions==4.14.1 + # via + # aiosignal + # alembic + # anyio + # chex + # etils + # fastapi + # flax + # graphene + # lightning + # lightning-utilities + # mlflow-skinny + # opentelemetry-api + # opentelemetry-sdk + # opentelemetry-semantic-conventions + # orbax-checkpoint + # pydantic + # pydantic-core + # pytorch-lightning + # sqlalchemy + # starlette + # torch + # typing-inspection +typing-inspection==0.4.1 + # via pydantic +tzdata==2025.2 + # via pandas +urllib3==2.5.0 + # via + # docker + # requests +uvicorn==0.35.0 + # via mlflow-skinny +werkzeug==3.1.3 + # via flask +yarl==1.20.1 + # via aiohttp +zipp==3.23.0 + # via + # etils + # importlib-metadata diff --git a/src/methods/scvi_mlflow/script.py b/src/methods/scvi_mlflow/script.py new file mode 100644 index 00000000..0c27a11a --- /dev/null +++ b/src/methods/scvi_mlflow/script.py @@ -0,0 +1,80 @@ +import sys + +import anndata as ad +import mlflow.pyfunc + +## VIASH START +# Note: this section is auto-generated by viash at runtime. To edit it, make changes +# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. 
+par = {
+    "input": "resources_test/.../input.h5ad",
+    "output": "output.h5ad",
+    "model": "resources_test/.../model",
+}
+meta = {"name": "scvi_mlflow"}
+## VIASH END
+
+sys.path.append(meta["resources_dir"])
+from exit_codes import exit_non_applicable  # noqa: E402
+from mlflow import embed  # noqa: E402
+from read_anndata_partial import read_anndata  # noqa: E402
+from unpack import unpack_directory  # noqa: E402
+
+print("====== scVI (MLflow model) ======", flush=True)
+
+print("\n>>> Reading input files...", flush=True)
+print(f"Input H5AD file: '{par['input']}'", flush=True)
+adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")
+
+if adata.uns["dataset_organism"] == "homo_sapiens":
+    organism = "human"
+elif adata.uns["dataset_organism"] == "mus_musculus":
+    organism = "mouse"
+else:
+    exit_non_applicable(
+        f"scVI (MLflow) can only be used with human or mouse data "
+        f'(dataset_organism == "{adata.uns["dataset_organism"]}")'
+    )
+
+print(adata, flush=True)
+
+print("\n>>> Unpacking model...", flush=True)
+model_dir, model_temp = unpack_directory(par["model"])
+
+print(f"\n>>> Loading {organism} model...", flush=True)
+model = mlflow.pyfunc.load_model(model_dir, model_config={"organism": organism})
+print(model, flush=True)
+
+print("\n>>> Embedding data...", flush=True)
+embedding = embed(
+    adata,
+    model,
+    layers=["counts"],
+    obs=["batch"],
+    var={"feature_id": "feature_id"},
+)
+
+print("\n>>> Storing output...", flush=True)
+output = ad.AnnData(
+    obs=adata.obs[[]],
+    var=adata.var[[]],
+    obsm={
+        "X_emb": embedding,
+    },
+    uns={
+        "dataset_id": adata.uns["dataset_id"],
+        "normalization_id": adata.uns["normalization_id"],
+        "method_id": meta["name"],
+    },
+)
+print(output)
+
+print("\n>>> Writing output to file...", flush=True)
+print(f"Output H5AD file: '{par['output']}'", flush=True)
+output.write_h5ad(par["output"], compression="gzip")
+
+print("\n>>> Cleaning up temporary files...", flush=True)
+if model_temp is not None:
+    model_temp.cleanup()
+
+print("\n>>> Done!", flush=True)
diff --git a/src/methods/transcriptformer_mlflow/config.vsh.yaml b/src/methods/transcriptformer_mlflow/config.vsh.yaml
new file mode 100644
index 00000000..453ba275
--- /dev/null
+++ b/src/methods/transcriptformer_mlflow/config.vsh.yaml
@@ -0,0 +1,53 @@
+__merge__: ../../api/base_method.yaml
+
+name: transcriptformer_mlflow
+label: TranscriptFormer (MLflow model)
+summary: "Context-aware representations of single-cell transcriptomes by jointly modeling genes and transcripts"
+description: |
+  TranscriptFormer is designed to learn rich, context-aware representations of
+  single-cell transcriptomes while jointly modeling genes and transcripts using
+  a novel generative architecture.
+
+  It is a family of generative foundation models representing a cross-species
+  generative cell atlas trained on up to 112 million cells spanning 1.53 billion
+  years of evolution across 12 species.
+
+  Here, we use a version packaged as an MLflow model.
+references:
+  doi:
+    - 10.1101/2025.04.25.650731
+links:
+  documentation: https://github.com/czi-ai/transcriptformer#readme
+  repository: https://github.com/czi-ai/transcriptformer
+
+info:
+  method_types: [embedding]
+  preferred_normalization: counts
+
+arguments:
+  - name: --model
+    type: file
+    description: |
+      An MLflow model URI for the TranscriptFormer model. If it is a .zip or
+      .tar.gz file it will be extracted to a temporary directory.
+ required: true + +resources: + - type: python_script + path: script.py + - path: /src/utils/read_anndata_partial.py + - path: /src/utils/exit_codes.py + - path: /src/utils/unpack.py + - path: /src/utils/mlflow.py + - path: requirements.txt + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1 + __merge__: /src/utils/mlflow_docker_setup.yaml + +runners: + - type: executable + - type: nextflow + directives: + label: [hightime, highmem, midcpu, gpu] diff --git a/src/methods/transcriptformer_mlflow/requirements.txt b/src/methods/transcriptformer_mlflow/requirements.txt new file mode 100644 index 00000000..70d923d1 --- /dev/null +++ b/src/methods/transcriptformer_mlflow/requirements.txt @@ -0,0 +1,338 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile requirements.in -o requirements.txt +aiobotocore==2.23.0 + # via s3fs +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.13 + # via + # aiobotocore + # fsspec + # s3fs +aioitertools==0.12.0 + # via aiobotocore +aiosignal==1.3.2 + # via aiohttp +anndata==0.11.4 + # via + # cellxgene-census + # scanpy + # somacore + # tiledbsoma + # transcriptformer +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf +array-api-compat==1.12.0 + # via anndata +attrs==25.3.0 + # via + # aiohttp + # somacore + # tiledbsoma +boto3==1.38.27 + # via transcriptformer +botocore==1.38.27 + # via + # aiobotocore + # boto3 + # s3transfer +cellxgene-census==1.17.0 + # via transcriptformer +certifi==2025.6.15 + # via requests +charset-normalizer==3.4.2 + # via requests +contourpy==1.3.2 + # via matplotlib +cycler==0.12.1 + # via matplotlib +filelock==3.18.0 + # via + # torch + # triton +fonttools==4.58.4 + # via matplotlib +frozenlist==1.7.0 + # via + # aiohttp + # aiosignal +fsspec==2025.5.1 + # via + # pytorch-lightning + # s3fs + # torch +h5py==3.14.0 + # via + # anndata + # scanpy + # transcriptformer +hydra-core==1.3.2 + # via transcriptformer +idna==3.10 + # via + # requests + # yarl +iniconfig==2.1.0 + # via pytest +jinja2==3.1.6 + # via torch +jmespath==1.0.1 + # via + # aiobotocore + # boto3 + # botocore +joblib==1.5.1 + # via + # pynndescent + # scanpy + # scikit-learn +kiwisolver==1.4.8 + # via matplotlib +legacy-api-wrap==1.4.1 + # via scanpy +lightning-utilities==0.14.3 + # via + # pytorch-lightning + # torchmetrics +llvmlite==0.44.0 + # via + # numba + # pynndescent +markupsafe==3.0.2 + # via jinja2 +matplotlib==3.10.3 + # via + # scanpy + # seaborn +more-itertools==10.7.0 + # via tiledbsoma +mpmath==1.3.0 + # via sympy +multidict==6.6.0 + # via + # aiobotocore + # aiohttp + # yarl +natsort==8.4.0 + # via + # anndata + # scanpy +networkx==3.5 + # via + # scanpy + # torch +numba==0.61.2 + # via + # pynndescent + # scanpy + # umap-learn +numpy==2.2.6 + # via + # anndata + # cellxgene-census + # contourpy + # h5py + # matplotlib + # numba + # pandas + # patsy + # scanpy + # scikit-learn + # scipy + # seaborn + # shapely + # somacore + # statsmodels + # tiledbsoma + # torchmetrics + # transcriptformer + # umap-learn +nvidia-cublas-cu12==12.4.5.8 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.4.127 + # via torch +nvidia-cuda-nvrtc-cu12==12.4.127 + # via torch +nvidia-cuda-runtime-cu12==12.4.127 + # via torch +nvidia-cudnn-cu12==9.1.0.70 + # via torch +nvidia-cufft-cu12==11.2.1.3 + # via torch +nvidia-curand-cu12==10.3.5.147 + # via torch +nvidia-cusolver-cu12==11.6.1.9 + # via torch +nvidia-cusparse-cu12==12.3.1.170 + # via + # nvidia-cusolver-cu12 + 
# torch +nvidia-ml-py==12.575.51 + # via pynvml +nvidia-nccl-cu12==2.21.5 + # via torch +nvidia-nvjitlink-cu12==12.4.127 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvtx-cu12==12.4.127 + # via torch +omegaconf==2.3.0 + # via hydra-core +packaging==25.0 + # via + # anndata + # hydra-core + # lightning-utilities + # matplotlib + # pytest + # pytorch-lightning + # scanpy + # statsmodels + # torchmetrics +pandas==2.3.0 + # via + # anndata + # scanpy + # seaborn + # somacore + # statsmodels + # tiledbsoma + # transcriptformer +patsy==1.0.1 + # via + # scanpy + # statsmodels +pillow==11.2.1 + # via matplotlib +pluggy==1.6.0 + # via pytest +propcache==0.3.2 + # via + # aiohttp + # yarl +psutil==7.0.0 + # via transcriptformer +pyarrow==20.0.0 + # via + # somacore + # tiledbsoma +pyarrow-hotfix==0.7 + # via somacore +pygments==2.19.2 + # via pytest +pynndescent==0.5.13 + # via + # scanpy + # umap-learn +pynvml==12.0.0 + # via transcriptformer +pyparsing==3.2.3 + # via matplotlib +pytest==8.4.1 + # via transcriptformer +python-dateutil==2.9.0.post0 + # via + # aiobotocore + # botocore + # matplotlib + # pandas +pytorch-lightning==2.5.2 + # via transcriptformer +pytz==2025.2 + # via pandas +pyyaml==6.0.2 + # via + # omegaconf + # pytorch-lightning +requests==2.32.4 + # via cellxgene-census +s3fs==2025.5.1 + # via cellxgene-census +s3transfer==0.13.0 + # via boto3 +scanpy==1.11.2 + # via + # tiledbsoma + # transcriptformer +scikit-learn==1.7.0 + # via + # pynndescent + # scanpy + # umap-learn +scipy==1.16.0 + # via + # anndata + # pynndescent + # scanpy + # scikit-learn + # somacore + # statsmodels + # tiledbsoma + # transcriptformer + # umap-learn +seaborn==0.13.2 + # via scanpy +session-info2==0.1.2 + # via scanpy +setuptools==80.9.0 + # via lightning-utilities +shapely==2.1.1 + # via somacore +six==1.17.0 + # via python-dateutil +somacore==1.0.28 + # via tiledbsoma +statsmodels==0.14.4 + # via scanpy +sympy==1.13.1 + # via torch +threadpoolctl==3.6.0 + # via scikit-learn +tiledbsoma==1.17.0 + # via cellxgene-census +timeout-decorator==0.5.0 + # via transcriptformer +torch==2.5.1 + # via + # pytorch-lightning + # torchmetrics + # transcriptformer +torchmetrics==1.7.3 + # via pytorch-lightning +tqdm==4.67.1 + # via + # pytorch-lightning + # scanpy + # umap-learn +transcriptformer==0.3.0 + # via -r requirements.in +triton==3.1.0 + # via torch +typing-extensions==4.14.0 + # via + # cellxgene-census + # lightning-utilities + # pytorch-lightning + # scanpy + # somacore + # tiledbsoma + # torch +tzdata==2025.2 + # via pandas +umap-learn==0.5.7 + # via scanpy +urllib3==2.5.0 + # via + # botocore + # requests +wrapt==1.17.2 + # via aiobotocore +yarl==1.20.1 + # via aiohttp diff --git a/src/methods/transcriptformer_mlflow/script.py b/src/methods/transcriptformer_mlflow/script.py new file mode 100644 index 00000000..0ddacee8 --- /dev/null +++ b/src/methods/transcriptformer_mlflow/script.py @@ -0,0 +1,82 @@ +import sys + +import anndata as ad +import mlflow.pyfunc + +## VIASH START +# Note: this section is auto-generated by viash at runtime. To edit it, make changes +# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. 
+par = {
+    "input": "resources_test/.../input.h5ad",
+    "output": "output.h5ad",
+    "model": "resources_test/.../model",
+}
+meta = {"name": "transcriptformer_mlflow"}
+## VIASH END
+
+sys.path.append(meta["resources_dir"])
+from exit_codes import exit_non_applicable  # noqa: E402
+from mlflow import embed  # noqa: E402
+from read_anndata_partial import read_anndata  # noqa: E402
+from unpack import unpack_directory  # noqa: E402
+
+print("====== TranscriptFormer (MLflow model) ======", flush=True)
+
+print("\n>>> Reading input files...", flush=True)
+print(f"Input H5AD file: '{par['input']}'", flush=True)
+adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")
+
+if adata.uns["dataset_organism"] != "homo_sapiens":
+    exit_non_applicable(
+        f"TranscriptFormer can only be used with human data "
+        f'(dataset_organism == "{adata.uns["dataset_organism"]}")'
+    )
+
+print(adata, flush=True)
+
+print("\n>>> Unpacking model...", flush=True)
+model_dir, model_temp = unpack_directory(par["model"])
+
+print("\n>>> Loading model...", flush=True)
+model = mlflow.pyfunc.load_model(model_dir)
+print(model, flush=True)
+
+
+def process_transcriptformer_input(input_adata):
+    """Add TranscriptFormer-specific fields to input AnnData."""
+    input_adata.obs["assay"] = "unknown"  # Avoid error if assay is missing
+
+
+print("\n>>> Embedding data...", flush=True)
+embedding = embed(
+    adata,
+    model,
+    layers=["counts"],
+    var={"feature_id": "ensembl_id"},
+    process_adata=process_transcriptformer_input,
+)
+
+print("\n>>> Storing output...", flush=True)
+output = ad.AnnData(
+    obs=adata.obs[[]],
+    var=adata.var[[]],
+    obsm={
+        "X_emb": embedding,
+    },
+    uns={
+        "dataset_id": adata.uns["dataset_id"],
+        "normalization_id": adata.uns["normalization_id"],
+        "method_id": meta["name"],
+    },
+)
+print(output)
+
+print("\n>>> Writing output to file...", flush=True)
+print(f"Output H5AD file: '{par['output']}'", flush=True)
+output.write_h5ad(par["output"], compression="gzip")
+
+print("\n>>> Cleaning up temporary files...", flush=True)
+if model_temp is not None:
+    model_temp.cleanup()
+
+print("\n>>> Done!", flush=True)
diff --git a/src/methods/uce_mlflow/config.vsh.yaml b/src/methods/uce_mlflow/config.vsh.yaml
new file mode 100644
index 00000000..96ccbd8b
--- /dev/null
+++ b/src/methods/uce_mlflow/config.vsh.yaml
@@ -0,0 +1,49 @@
+__merge__: ../../api/base_method.yaml
+
+name: uce_mlflow
+label: UCE (MLflow model)
+summary: UCE offers a unified biological latent space that can represent any cell
+description: |
+  Universal Cell Embedding (UCE) is a single-cell foundation model that offers a
+  unified biological latent space that can represent any cell, regardless of
+  tissue or species.
+
+  Here, we use a version packaged as an MLflow model.
+references:
+  doi:
+    - 10.1101/2023.11.28.568918
+links:
+  documentation: https://github.com/snap-stanford/UCE/blob/main/README.md
+  repository: https://github.com/snap-stanford/UCE
+
+info:
+  method_types: [embedding]
+  preferred_normalization: counts
+
+arguments:
+  - name: --model
+    type: file
+    description: |
+      An MLflow model URI for the UCE model. If it is a .zip or
+      .tar.gz file it will be extracted to a temporary directory.
+ required: true + +resources: + - type: python_script + path: script.py + - path: /src/utils/read_anndata_partial.py + - path: /src/utils/exit_codes.py + - path: /src/utils/unpack.py + - path: /src/utils/mlflow.py + - path: requirements.txt + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1 + __merge__: /src/utils/mlflow_docker_setup.yaml + +runners: + - type: executable + - type: nextflow + directives: + label: [hightime, highmem, midcpu, biggpu] diff --git a/src/methods/uce_mlflow/requirements.txt b/src/methods/uce_mlflow/requirements.txt new file mode 100644 index 00000000..b2f4227b --- /dev/null +++ b/src/methods/uce_mlflow/requirements.txt @@ -0,0 +1,366 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile requirements.in -o /tmp/tmpg2ov1w_7/requirements_initial.txt +accelerate==0.34.2 + # via -r requirements.in +alembic==1.16.4 + # via mlflow +anndata==0.10.9 + # via + # -r requirements.in + # scanpy +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via omegaconf +anyio==4.10.0 + # via starlette +array-api-compat==1.12.0 + # via anndata +blinker==1.9.0 + # via flask +cachetools==5.5.2 + # via + # google-auth + # mlflow-skinny +certifi==2025.8.3 + # via requests +charset-normalizer==3.4.3 + # via requests +click==8.2.1 + # via + # flask + # mlflow-skinny + # uvicorn +cloudpickle==3.1.1 + # via mlflow-skinny +contourpy==1.3.3 + # via matplotlib +cycler==0.12.1 + # via matplotlib +databricks-sdk==0.62.0 + # via mlflow-skinny +docker==7.1.0 + # via mlflow +fastapi==0.116.1 + # via mlflow-skinny +filelock==3.18.0 + # via + # huggingface-hub + # torch + # triton +flask==3.1.1 + # via mlflow +fonttools==4.59.0 + # via matplotlib +fsspec==2025.7.0 + # via + # huggingface-hub + # torch +gitdb==4.0.12 + # via gitpython +gitpython==3.1.45 + # via mlflow-skinny +google-auth==2.40.3 + # via databricks-sdk +graphene==3.4.3 + # via mlflow +graphql-core==3.2.6 + # via + # graphene + # graphql-relay +graphql-relay==3.2.0 + # via graphene +greenlet==3.2.4 + # via sqlalchemy +gunicorn==23.0.0 + # via mlflow +h11==0.16.0 + # via uvicorn +h5py==3.14.0 + # via + # anndata + # scanpy +hf-xet==1.1.7 + # via huggingface-hub +huggingface-hub==0.34.4 + # via accelerate +idna==3.10 + # via + # anyio + # requests +importlib-metadata==8.7.0 + # via + # mlflow-skinny + # opentelemetry-api +itsdangerous==2.2.0 + # via flask +jinja2==3.1.6 + # via + # flask + # torch +joblib==1.5.1 + # via + # pynndescent + # scanpy + # scikit-learn +kiwisolver==1.4.9 + # via matplotlib +legacy-api-wrap==1.4.1 + # via scanpy +llvmlite==0.44.0 + # via + # numba + # pynndescent +mako==1.3.10 + # via alembic +markupsafe==3.0.2 + # via + # flask + # jinja2 + # mako + # werkzeug +matplotlib==3.10.5 + # via + # mlflow + # scanpy + # seaborn +mlflow==3.1.0 + # via -r requirements.in +mlflow-skinny==3.1.0 + # via mlflow +mpmath==1.3.0 + # via sympy +natsort==8.4.0 + # via + # anndata + # scanpy +networkx==3.5 + # via + # scanpy + # torch +numba==0.61.2 + # via + # pynndescent + # scanpy + # umap-learn +numpy==1.26.4 + # via + # -r requirements.in + # accelerate + # anndata + # contourpy + # h5py + # matplotlib + # mlflow + # numba + # pandas + # patsy + # scanpy + # scikit-learn + # scipy + # seaborn + # statsmodels + # umap-learn +nvidia-cublas-cu12==12.1.3.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.1.105 + # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 + # via torch +nvidia-cuda-runtime-cu12==12.1.105 + # via 
torch +nvidia-cudnn-cu12==9.1.0.70 + # via torch +nvidia-cufft-cu12==11.0.2.54 + # via torch +nvidia-curand-cu12==10.3.2.106 + # via torch +nvidia-cusolver-cu12==11.4.5.107 + # via torch +nvidia-cusparse-cu12==12.1.0.106 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-nccl-cu12==2.20.5 + # via torch +nvidia-nvjitlink-cu12==12.9.86 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 + # via torch +omegaconf==2.3.0 + # via -r requirements.in +opentelemetry-api==1.36.0 + # via + # mlflow-skinny + # opentelemetry-sdk + # opentelemetry-semantic-conventions +opentelemetry-sdk==1.36.0 + # via mlflow-skinny +opentelemetry-semantic-conventions==0.57b0 + # via opentelemetry-sdk +packaging==25.0 + # via + # accelerate + # anndata + # gunicorn + # huggingface-hub + # matplotlib + # mlflow-skinny + # scanpy + # statsmodels +pandas==2.2.3 + # via + # -r requirements.in + # anndata + # mlflow + # scanpy + # seaborn + # statsmodels +patsy==1.0.1 + # via + # scanpy + # statsmodels +pillow==11.3.0 + # via matplotlib +protobuf==6.31.1 + # via mlflow-skinny +psutil==7.0.0 + # via accelerate +pyarrow==20.0.0 + # via mlflow +pyasn1==0.6.1 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pydantic==2.11.7 + # via + # fastapi + # mlflow-skinny +pydantic-core==2.33.2 + # via pydantic +pynndescent==0.5.13 + # via + # scanpy + # umap-learn +pyparsing==3.2.3 + # via matplotlib +python-dateutil==2.9.0.post0 + # via + # graphene + # matplotlib + # pandas +pytz==2025.2 + # via pandas +pyyaml==6.0.2 + # via + # accelerate + # huggingface-hub + # mlflow-skinny + # omegaconf +requests==2.32.4 + # via + # databricks-sdk + # docker + # huggingface-hub + # mlflow-skinny +rsa==4.9.1 + # via google-auth +safetensors==0.6.2 + # via accelerate +scanpy==1.10.2 + # via -r requirements.in +scikit-learn==1.7.1 + # via + # mlflow + # pynndescent + # scanpy + # umap-learn +scipy==1.14.1 + # via + # -r requirements.in + # anndata + # mlflow + # pynndescent + # scanpy + # scikit-learn + # statsmodels + # umap-learn +seaborn==0.13.2 + # via scanpy +session-info==1.0.1 + # via scanpy +six==1.17.0 + # via python-dateutil +smmap==5.0.2 + # via gitdb +sniffio==1.3.1 + # via anyio +sqlalchemy==2.0.43 + # via + # alembic + # mlflow +sqlparse==0.5.3 + # via mlflow-skinny +starlette==0.47.2 + # via fastapi +statsmodels==0.14.5 + # via scanpy +stdlib-list==0.11.1 + # via session-info +sympy==1.14.0 + # via torch +threadpoolctl==3.6.0 + # via scikit-learn +torch==2.4.1 + # via + # -r requirements.in + # accelerate +tqdm==4.66.5 + # via + # -r requirements.in + # huggingface-hub + # scanpy + # umap-learn +triton==3.0.0 + # via torch +typing-extensions==4.14.1 + # via + # alembic + # anyio + # fastapi + # graphene + # huggingface-hub + # mlflow-skinny + # opentelemetry-api + # opentelemetry-sdk + # opentelemetry-semantic-conventions + # pydantic + # pydantic-core + # sqlalchemy + # starlette + # torch + # typing-inspection +typing-inspection==0.4.1 + # via pydantic +tzdata==2025.2 + # via pandas +umap-learn==0.5.9.post2 + # via scanpy +urllib3==1.26.6 + # via + # -r requirements.in + # docker + # requests +uvicorn==0.35.0 + # via mlflow-skinny +werkzeug==3.1.3 + # via flask +zipp==3.23.0 + # via importlib-metadata diff --git a/src/methods/uce_mlflow/script.py b/src/methods/uce_mlflow/script.py new file mode 100644 index 00000000..6e6fffb6 --- /dev/null +++ b/src/methods/uce_mlflow/script.py @@ -0,0 +1,75 @@ +import sys + +import anndata as ad +import mlflow.pyfunc + +## VIASH START +# 
diff --git a/src/utils/mlflow.py b/src/utils/mlflow.py
new file mode 100644
index 00000000..447614e6
--- /dev/null
+++ b/src/utils/mlflow.py
@@ -0,0 +1,174 @@
+"""
+Common utilities for MLflow-based methods.
+"""
+import os
+import tempfile
+
+import anndata as ad
+import pandas as pd
+import sklearn.neighbors
+
+
+def create_temp_h5ad(
+    adata, layers=None, obs=None, var=None, obsm=None, varm=None, uns=None
+):
+    """
+    Create a temporary H5AD file with specified data from an AnnData object.
+
+    Args:
+        adata: Input AnnData object
+        layers: List of layer names to include (e.g., ["counts"])
+        obs: List of obs column names to include (e.g., ["batch"])
+        var: Dict mapping var column names to new names (e.g., {"feature_id": "ensembl_id"})
+        obsm: List of obsm keys to include
+        varm: List of varm keys to include
+        uns: List of uns keys to include
+
+    Returns:
+        tuple: (h5ad_file, input_adata) where h5ad_file is the NamedTemporaryFile and
+            input_adata is the created AnnData object
+    """
+    # Extract X from the first requested layer or use X directly
+    if layers and len(layers) > 0:
+        X = adata.layers[layers[0]].copy()
+    else:
+        X = adata.X.copy()
+
+    # Create new AnnData
+    input_adata = ad.AnnData(X=X)
+
+    # Set var_names
+    input_adata.var_names = adata.var_names
+
+    # Add obs columns
+    if obs:
+        for obs_key in obs:
+            if obs_key in adata.obs:
+                input_adata.obs[obs_key] = adata.obs[obs_key].values
+
+    # Add var columns (with optional renaming)
+    if var:
+        for old_name, new_name in var.items():
+            if old_name in adata.var:
+                input_adata.var[new_name] = adata.var[old_name].values
+
+    # Add obsm
+    if obsm:
+        for obsm_key in obsm:
+            if obsm_key in adata.obsm:
+                input_adata.obsm[obsm_key] = adata.obsm[obsm_key].copy()
+
+    # Add varm
+    if varm:
+        for varm_key in varm:
+            if varm_key in adata.varm:
+                input_adata.varm[varm_key] = adata.varm[varm_key].copy()
+
+    # Add uns
+    if uns:
+        for uns_key in uns:
+            if uns_key in adata.uns:
+                input_adata.uns[uns_key] = adata.uns[uns_key]
+
+    # Write to temp file
+    h5ad_file = tempfile.NamedTemporaryFile(suffix=".h5ad", delete=False)
+    input_adata.write(h5ad_file.name)
+
+    return h5ad_file, input_adata
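+
+# Illustrative usage of create_temp_h5ad() (a sketch, not called in this
+# module directly): assuming `adata` has a "counts" layer and a "feature_id"
+# var column, this writes a slimmed-down copy to a temporary H5AD file,
+# which the caller is responsible for removing:
+#
+#     h5ad_file, input_adata = create_temp_h5ad(
+#         adata,
+#         layers=["counts"],
+#         var={"feature_id": "ensembl_id"},
+#     )
+#     ...
+#     h5ad_file.close()
+#     os.unlink(h5ad_file.name)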
+
+
+def embed(adata, model, layers=None, obs=None, var=None, model_params=None, process_adata=None):
+    """
+    Embed data using an MLflow model.
+
+    Args:
+        adata: Input AnnData object to embed
+        model: Loaded MLflow model
+        layers: List of layer names to include (e.g., ["counts"])
+        obs: List of obs column names to include (e.g., ["batch"])
+        var: Dict mapping var column names to new names (e.g., {"feature_id": "ensembl_id"})
+        model_params: Optional dict of parameters to pass to model.predict()
+        process_adata: Optional function to process input_adata before writing (e.g., to add defaults)
+
+    Returns:
+        np.ndarray: Embeddings for the input data
+    """
+    print("Writing temporary input H5AD file...", flush=True)
+    h5ad_file, input_adata = create_temp_h5ad(adata, layers=layers, obs=obs, var=var)
+
+    # Apply any post-processing to input_adata
+    if process_adata:
+        process_adata(input_adata)
+
+    print(f"Temporary H5AD file: '{h5ad_file.name}'", flush=True)
+    print(input_adata, flush=True)
+
+    # Re-write the file so any post-processing is captured on disk
+    input_adata.write(h5ad_file.name)
+
+    print("Running model...", flush=True)
+    input_df = pd.DataFrame({"input_uri": [h5ad_file.name]})
+    if model_params:
+        embedding = model.predict(input_df, params=model_params)
+    else:
+        embedding = model.predict(input_df)
+
+    # Clean up the temporary file
+    h5ad_file.close()
+    os.unlink(h5ad_file.name)
+
+    return embedding
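+
+# Usage sketch for embed() (assumes `model` was loaded with
+# mlflow.pyfunc.load_model() and reads the H5AD path from the "input_uri"
+# column, as above; "batch_size" is a hypothetical model parameter shown
+# only to illustrate model_params):
+#
+#     embedding = embed(
+#         adata,
+#         model,
+#         layers=["counts"],
+#         var={"feature_name": "feature_name"},
+#         model_params={"batch_size": 32},
+#     )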
+
+
+def embed_and_classify(
+    train_adata,
+    test_adata,
+    model,
+    layers=None,
+    obs=None,
+    var=None,
+    model_params=None,
+    process_adata=None,
+    n_neighbors=5,
+):
+    """
+    Generic pipeline for embedding data and training a kNN classifier.
+
+    Args:
+        train_adata: Training AnnData object with labels
+        test_adata: Test AnnData object to predict
+        model: Loaded MLflow model
+        layers: List of layer names to include (e.g., ["counts"])
+        obs: List of obs column names to include (e.g., ["batch"])
+        var: Dict mapping var column names to new names (e.g., {"feature_id": "ensembl_id"})
+        model_params: Optional dict of parameters to pass to model.predict()
+        process_adata: Optional function to process input_adata before writing (e.g., to add defaults)
+        n_neighbors: Number of neighbors for kNN classifier
+
+    Returns:
+        np.ndarray: Predicted labels for test data
+    """
+    # Embed training data
+    print("\n>>> Embedding training data...", flush=True)
+    embedding_train = embed(
+        train_adata, model, layers=layers, obs=obs, var=var,
+        model_params=model_params, process_adata=process_adata
+    )
+
+    # Train kNN classifier
+    print("\n>>> Training kNN classifier...", flush=True)
+    classifier = sklearn.neighbors.KNeighborsClassifier(n_neighbors=n_neighbors)
+    classifier.fit(embedding_train, train_adata.obs["label"].astype(str))
+
+    # Embed test data
+    print("\n>>> Embedding test data...", flush=True)
+    embedding_test = embed(
+        test_adata, model, layers=layers, obs=obs, var=var,
+        model_params=model_params, process_adata=process_adata
+    )
+
+    # Classify
+    print("\n>>> Classifying test data...", flush=True)
+    predictions = classifier.predict(embedding_test)
+
+    return predictions
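+
+# Usage sketch for embed_and_classify() (assumes both AnnData objects have a
+# "counts" layer and that train_adata.obs["label"] holds the training labels,
+# matching the hard-coded "label" column above):
+#
+#     predictions = embed_and_classify(
+#         train_adata,
+#         test_adata,
+#         model,
+#         layers=["counts"],
+#         n_neighbors=5,
+#     )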
+ """ + print(f"Unpacking directory: '{directory}'", flush=True) + + if os.path.isdir(directory): + print(f"Returning provided directory: '{directory}'", flush=True) + temp_directory = None + unpacked_directory = directory + else: + temp_directory = tempfile.TemporaryDirectory() + unpacked_directory = temp_directory.name + + if zipfile.is_zipfile(directory): + print("Extracting .zip...", flush=True) + with zipfile.ZipFile(directory, "r") as zip_file: + zip_file.extractall(unpacked_directory) + elif tarfile.is_tarfile(directory) and directory.endswith(".tar.gz"): + print("Extracting .tar.gz...", flush=True) + with tarfile.open(directory, "r:gz") as tar_file: + tar_file.extractall(unpacked_directory) + unpacked_directory = os.path.join(unpacked_directory, os.listdir(unpacked_directory)[0]) + else: + raise ValueError( + "The 'directory' argument should be a directory, a .zip file or a .tar.gz file" + ) + print(f"Extracted to '{unpacked_directory}'", flush=True) + + return (unpacked_directory, temp_directory) diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml index 09905ad0..d9ae23f4 100644 --- a/src/workflows/run_benchmark/config.vsh.yaml +++ b/src/workflows/run_benchmark/config.vsh.yaml @@ -92,7 +92,7 @@ dependencies: - name: methods/batchelor_mnn_correct - name: methods/bbknn - name: methods/combat - - name: methods/geneformer + - name: methods/geneformer_mlflow - name: methods/harmony - name: methods/harmonypy - name: methods/liger @@ -101,12 +101,18 @@ dependencies: - name: methods/scalex - name: methods/scanorama - name: methods/scanvi + - name: methods/scgpt_mlflow - name: methods/scgpt_finetuned - name: methods/scgpt_zeroshot - name: methods/scimilarity - name: methods/scprint - name: methods/scvi + - name: methods/scvi_mlflow + - name: methods/transcriptformer_mlflow - name: methods/uce + - name: methods/uce_mlflow + # outdated methods + # - name: methods/geneformer # metrics - name: metrics/asw_batch - name: metrics/asw_label diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf index 6196f749..88f83327 100644 --- a/src/workflows/run_benchmark/main.nf +++ b/src/workflows/run_benchmark/main.nf @@ -20,7 +20,10 @@ methods = [ batchelor_mnn_correct, bbknn, combat, - geneformer, + // geneformer, + geneformer_mlflow.run( + args: [model: file("s3://openproblems-work/cache/geneformer-mlflow-model.zip")] + ), harmony, harmonypy, liger, @@ -32,6 +35,9 @@ methods = [ scgpt_finetuned.run( args: [model: file("s3://openproblems-work/cache/scGPT_human.zip")] ), + scgpt_mlflow.run( + args: [model: file("s3://openproblems-work/cache/scgpt-mlflow-model.zip")] + ), scgpt_zeroshot.run( args: [model: file("s3://openproblems-work/cache/scGPT_human.zip")] ), @@ -40,8 +46,17 @@ methods = [ ), scprint, scvi, + scvi_mlflow.run( + args: [model: file("s3://openproblems-work/cache/scvi-mlflow-model.zip")] + ), + transcriptformer_mlflow.run( + args: [model: file("s3://openproblems-work/cache/transcriptformer-mlflow-model.zip")] + ), uce.run( args: [model: file("s3://openproblems-work/cache/uce-model-v5.zip")] + ), + uce_mlflow.run( + args: [model: file("s3://openproblems-work/cache/uce-mlflow-model.zip")] ) ] @@ -55,7 +70,7 @@ metrics = [ hvg_overlap, isolated_label_asw, isolated_label_f1, - kbet, + // kbet, kbet_pg, kbet_pg_label, lisi,