From fac4de29805ee509c212689577fc8ede811e0344 Mon Sep 17 00:00:00 2001 From: Habib Rehman Date: Tue, 14 Oct 2025 16:09:21 -0400 Subject: [PATCH 1/3] Initial rna2seg script --- common | 2 +- .../rna2seg/config.vsh.yaml | 37 ++++ .../rna2seg/script.py | 180 ++++++++++++++++++ 3 files changed, 218 insertions(+), 1 deletion(-) create mode 100644 src/methods_transcript_assignment/rna2seg/config.vsh.yaml create mode 100644 src/methods_transcript_assignment/rna2seg/script.py diff --git a/common b/common index 79b884b4..65e05af6 160000 --- a/common +++ b/common @@ -1 +1 @@ -Subproject commit 79b884b4c7fed300972d83a6ca025abb6116cbdc +Subproject commit 65e05af68a11ee87853fcf7a3c6b579001f21abe diff --git a/src/methods_transcript_assignment/rna2seg/config.vsh.yaml b/src/methods_transcript_assignment/rna2seg/config.vsh.yaml new file mode 100644 index 00000000..f520b4b1 --- /dev/null +++ b/src/methods_transcript_assignment/rna2seg/config.vsh.yaml @@ -0,0 +1,37 @@ +__merge__: /src/api/comp_method_transcript_assignment.yaml + +name: rna2seg +label: "RNA2Seg Transcript Assignment" +summary: "Assign transcripts to cells using the RNA2Seg method" +description: "RNA2seg is a deep learning-based segmentation model designed to improve cell segmentation in Imaging-based Spatial Transcriptomics (IST)." +links: + documentation: "https://rna2seg.readthedocs.io/en/latest/index.html" + repository: "https://github.com/fish-quant/rna2seg" +references: + doi: "10.1101/2025.03.03.641259" + +arguments: + - name: --transcripts_key + type: string + description: The key of the transcripts within the points of the spatial data + default: transcripts + +resources: + - type: python_script + path: script.py + +engines: + - type: docker + image: openproblems/base_python:1 + __merge__: + - /src/base/setup_spatialdata_partial.yaml + setup: + - type: python + pypi: [rna2seg] + - type: native + +runners: + - type: executable + - type: nextflow + directives: + label: [ hightime, highcpu, midmem ] diff --git a/src/methods_transcript_assignment/rna2seg/script.py b/src/methods_transcript_assignment/rna2seg/script.py new file mode 100644 index 00000000..9a81e098 --- /dev/null +++ b/src/methods_transcript_assignment/rna2seg/script.py @@ -0,0 +1,180 @@ +import os +import shutil +from pathlib import Path +import xarray as xr +import dask +import numpy as np +import pandas as pd +import anndata as ad +import spatialdata as sd +import sopa +import cv2 + +from rna2seg.dataset_zarr.patches import create_patch_rna2seg +import albumentations as A +from rna2seg.dataset_zarr import RNA2segDataset +from rna2seg.models import RNA2seg +from tqdm import tqdm +from rna2seg.utils import save_shapes2zarr + + + +## VIASH START +# Note: this section is auto-generated by viash at runtime. To edit it, make changes +# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. +par = { + 'input_ist': 'resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr', + 'input_segmentation': 'resources_test/task_ist_preprocessing/mouse_brain_combined/segmentation.zarr', + 'transcripts_key': 'transcripts', + 'coordinate_system': 'global', + 'output': './temp/sopa_testing/rna2seg_transcripts.zarr' +} +meta = { + 'name': 'rna2seg', + 'temp_dir': "/Users/habib/Projects/txsim_project/task_ist_preprocessing/temp/sopa", + 'cpus': 10 +} +## VIASH END + +TMP_DIR = Path(meta["temp_dir"] or "/tmp") +TMP_ZARR = TMP_DIR / 'rna2seg_sdata.zarr' + + +# Read input +print('Reading input files', flush=True) +sdata = sd.read_zarr(par['input_ist']) +sdata_segm = sd.read_zarr(par['input_segmentation']) + +# Check if coordinate system is available in input data +transcripts_coord_systems = sd.transformations.get_transformation(sdata[par["transcripts_key"]], get_all=True).keys() +assert par['coordinate_system'] in transcripts_coord_systems, f"Coordinate system '{par['coordinate_system']}' not found in input data." +segmentation_coord_systems = sd.transformations.get_transformation(sdata_segm["segmentation"], get_all=True).keys() +assert par['coordinate_system'] in segmentation_coord_systems, f"Coordinate system '{par['coordinate_system']}' not found in input data." + +# Transform transcript coordinates to the coordinate system +print('Transforming transcripts coordinates', flush=True) +transcripts = sd.transform(sdata[par['transcripts_key']], to_coordinate_system=par['coordinate_system']) + +# In case of a translation transformation of the segmentation (e.g. crop of the data), we need to adjust the transcript coordinates +trans = sd.transformations.get_transformation(sdata_segm["segmentation"], get_all=True)[par['coordinate_system']].inverse() +transcripts = sd.transform(transcripts, trans, par['coordinate_system']) + +### Run RNA2seg with sopa + + +# Create reduced sdata +print("Creating sopa SpatialData object") +sdata_sopa = sd.SpatialData( + points={ + "transcripts": sdata[par['transcripts_key']] + }, + images={ + "morphology_mip": sdata['morphology_mip'] + } +) +sdata_sopa.write(TMP_ZARR, overwrite=True) + +# create patch in the sdata and precompute transcipt.csv for each patch with sopa +image_key = 'morphology_mip' +points_key = par["transcripts_key"] +gene_column_name="feature_name" # typically "feature_name" for Xenium +patch_width = 2000 +patch_overlap = 150 +min_points_per_patch = 1 +folder_patch_rna2seg = Path(TMP_ZARR / f".rna2seg_{patch_width}_{patch_overlap}") +create_patch_rna2seg(sdata=sdata_sopa, + image_key=image_key, + points_key=points_key, + patch_width=patch_width, + patch_overlap=patch_overlap, + min_points_per_patch=min_points_per_patch, + folder_patch_rna2seg = folder_patch_rna2seg, + gene_column_name=gene_column_name, + overwrite = True) + +# Resize and create RNA2Seg dataset object +transform_resize = A.Compose([ + A.Resize(width=512, height=512, interpolation=cv2.INTER_NEAREST), +]) +dataset = RNA2segDataset( + sdata=sdata_sopa, + channels_dapi=[0], + channels_cellbound=[0], #TODO idk + patch_width = patch_width, + patch_overlap = patch_overlap, + gene_column=gene_column_name, + transform_resize = transform_resize, + patch_dir=folder_patch_rna2seg +) + +#TODO how to fix +device = "cpu" + +# Set up RNA2Seg model +rna2seg = RNA2seg( + device, + net='unet', + flow_threshold = 0.9, + cellbound_flow_threshold = 0.4, + pretrained_model = "default_pretrained" +) + +#Run on patches +for i in tqdm(range(len(dataset))): + input_dict = dataset[i] + rna2seg.run( + path_temp_save=folder_patch_rna2seg, + input_dict=input_dict + ) + +# save shapes to zarr +segmentation_shape_name = "rna2seg_boundaries" +save_shapes2zarr(dataset=dataset, + path_parquet_files=folder_patch_rna2seg, + segmentation_key=segmentation_shape_name, + overwrite= True + ) + +# Assign transcripts based on shapes +sopa.spatial.assign_transcript_to_cell( + sdata_sopa, + points_key="transcripts", + shapes_key="rna2seg_boundaries", #TODO what is the key + key_added="cell_id", + unassigned_value=0 +) + +# Create objects for cells table +print('Creating objects for cells table', flush=True) +#create new .obs for cells based on the segmentation output (corresponding with the transcripts 'cell_id') +unique_cells = np.unique(sdata_sopa["transcripts"]["cell_id"]) + +# check if a '0' (noise/background) cell is in cell_id and remove +zero_idx = np.where(unique_cells == 0) +if len(zero_idx[0]): unique_cells=np.delete(unique_cells, zero_idx[0][0]) + +#transform into pandas series and check +cell_id_col = pd.Series(unique_cells, name='cell_id', index=unique_cells) +assert 0 not in cell_id_col, "Found '0' in cell_id column of assingment output cell matrix" + + +# Create transcripts only sdata +print('Subsetting to transcripts cell id data', flush=True) +sdata_transcripts_only = sd.SpatialData( + points={ + "transcripts": sdata_sopa['transcripts'] + }, + tables={ + "table": ad.AnnData( + obs=pd.DataFrame(cell_id_col), + var=sdata.tables["table"].var[[]] + ) + } +) + +# Write output +print('Write transcripts with cell ids', flush=True) +if os.path.exists(par["output"]): + shutil.rmtree(par["output"]) + +sdata_transcripts_only.write(par['output']) From 7e566877a558e9bc6283acbffb50776030caf4d4 Mon Sep 17 00:00:00 2001 From: Habib Rehman Date: Tue, 4 Nov 2025 10:42:36 -0500 Subject: [PATCH 2/3] attempted to fix docker issues --- src/base/setup_spatialdata_partial.yaml | 3 +- .../rna2seg/config.vsh.yaml | 46 ++++++++++ .../rna2seg/script.py | 92 ++++++++++++++----- 3 files changed, 119 insertions(+), 22 deletions(-) diff --git a/src/base/setup_spatialdata_partial.yaml b/src/base/setup_spatialdata_partial.yaml index 517e3ba8..9d542cb2 100644 --- a/src/base/setup_spatialdata_partial.yaml +++ b/src/base/setup_spatialdata_partial.yaml @@ -1,3 +1,4 @@ setup: - type: python - pypi: [spatialdata, "anndata>=0.12.0"] + pypi: [spatialdata, "anndata>=0.12.0", "pyarrow<22.0.0"] + # remove pyarrow when https://github.com/scverse/spatialdata/issues/1007 is fixed \ No newline at end of file diff --git a/src/methods_transcript_assignment/rna2seg/config.vsh.yaml b/src/methods_transcript_assignment/rna2seg/config.vsh.yaml index f520b4b1..34a1e19c 100644 --- a/src/methods_transcript_assignment/rna2seg/config.vsh.yaml +++ b/src/methods_transcript_assignment/rna2seg/config.vsh.yaml @@ -16,6 +16,47 @@ arguments: description: The key of the transcripts within the points of the spatial data default: transcripts + - name: --coordinate_system + type: string + description: The key of the pixel space coordinate system within the spatial data + default: global + + - name: --flow_threshold + type: double + description: Flow threshold for detecting cells + default: 0.9 + + - name: --cellbound_flow_threshold + type: double + description: Cell boundary flow threshold for detecting cells + default: 0.4 + + - name: --create_cytoplasm_image + type: boolean + description: Whether to create an artificial cytoplasm image based on thresholding the nuclear stain image + default: False + + - name: --cytoplasm_min_threshold + type: double + description: Minimum percentile for cytoplasm image thresholding + default: 0.25 + + - name: --cytoplasm_max_threshold + type: double + description: Maximum percentile for cytoplasm image thresholding + default: 0.75 + + - name: patch_width + type: integer + description: Size of sopa patches to tile over image for parallelization. May affect segmentation output. + default: 1000 + + - name: patch_overlap + type: integer + description: Overlap of sopa patches to tile over image for parallelization. + default: 50 + + resources: - type: python_script path: script.py @@ -26,6 +67,11 @@ engines: __merge__: - /src/base/setup_spatialdata_partial.yaml setup: + - type: docker + # env: + # - PATH="/root/.cargo/bin:${PATH}" + run: + - apt-get update && apt-get install libgl1 -y - type: python pypi: [rna2seg] - type: native diff --git a/src/methods_transcript_assignment/rna2seg/script.py b/src/methods_transcript_assignment/rna2seg/script.py index 9a81e098..01def974 100644 --- a/src/methods_transcript_assignment/rna2seg/script.py +++ b/src/methods_transcript_assignment/rna2seg/script.py @@ -9,6 +9,7 @@ import spatialdata as sd import sopa import cv2 +import torch from rna2seg.dataset_zarr.patches import create_patch_rna2seg import albumentations as A @@ -27,11 +28,18 @@ 'input_segmentation': 'resources_test/task_ist_preprocessing/mouse_brain_combined/segmentation.zarr', 'transcripts_key': 'transcripts', 'coordinate_system': 'global', - 'output': './temp/sopa_testing/rna2seg_transcripts.zarr' + 'output': './temp/sopa_testing/rna2seg_transcripts.zarr', + 'flow_threshold': 0.9, + 'cellbound_flow_threshold': 0.4, + 'create_cytoplasm_image': False, + 'cytoplasm_min_threshold': 0.25, + 'cytoplasm_max_threshold': 0.75, + 'patch_width': 1000, + 'patch_overlap': 50, } meta = { 'name': 'rna2seg', - 'temp_dir': "/Users/habib/Projects/txsim_project/task_ist_preprocessing/temp/sopa", + 'temp_dir': "/Users/habib/Projects/txsim_project/task_ist_preprocessing/temp/rna2seg", 'cpus': 10 } ## VIASH END @@ -43,25 +51,53 @@ # Read input print('Reading input files', flush=True) sdata = sd.read_zarr(par['input_ist']) -sdata_segm = sd.read_zarr(par['input_segmentation']) +# sdata_segm = sd.read_zarr(par['input_segmentation']) # Check if coordinate system is available in input data transcripts_coord_systems = sd.transformations.get_transformation(sdata[par["transcripts_key"]], get_all=True).keys() assert par['coordinate_system'] in transcripts_coord_systems, f"Coordinate system '{par['coordinate_system']}' not found in input data." -segmentation_coord_systems = sd.transformations.get_transformation(sdata_segm["segmentation"], get_all=True).keys() -assert par['coordinate_system'] in segmentation_coord_systems, f"Coordinate system '{par['coordinate_system']}' not found in input data." +# segmentation_coord_systems = sd.transformations.get_transformation(sdata_segm["segmentation"], get_all=True).keys() +# assert par['coordinate_system'] in segmentation_coord_systems, f"Coordinate system '{par['coordinate_system']}' not found in input data." -# Transform transcript coordinates to the coordinate system -print('Transforming transcripts coordinates', flush=True) -transcripts = sd.transform(sdata[par['transcripts_key']], to_coordinate_system=par['coordinate_system']) - -# In case of a translation transformation of the segmentation (e.g. crop of the data), we need to adjust the transcript coordinates -trans = sd.transformations.get_transformation(sdata_segm["segmentation"], get_all=True)[par['coordinate_system']].inverse() -transcripts = sd.transform(transcripts, trans, par['coordinate_system']) ### Run RNA2seg with sopa +### CREATE CYTOPLASM IMAGE FUNCTION +# TODO define this function somewhere else and import +def get_nuclear_outline(nuclear_image, threshold_min = 0.25, threshold_max = 0.75): + threshold_image = np.clip(nuclear_image, np.quantile(nuclear_image, threshold_min), np.quantile(nuclear_image, threshold_max)) + # get the nuclear values (over a nuclear mask) + nuclear_mask = (nuclear_image > np.quantile(nuclear_image, threshold_max)) * nuclear_image + # scale nuclear values to whole cell values + scaling_factor = (np.max(threshold_image) - np.min(threshold_image)) / np.max(nuclear_mask) + # subtract nucleus from whole cell to get cytoplasm + cyto_image = threshold_image - (nuclear_mask * scaling_factor) + cyto_image = np.clip(cyto_image, 0 ,np.inf).astype(nuclear_image.dtype) + return cyto_image + + +#create composite image with 2nd channel as either 0s or generated cytoplasm image +nuclear_image = sdata['morphology_mip']['scale0'].image.compute().to_numpy() +composite = np.zeros([2, nuclear_image.shape[1], nuclear_image.shape[2]], dtype=nuclear_image.dtype) +composite[0,:,:] = nuclear_image + +if par['create_cytoplasm_image']: + cyto_image = get_nuclear_outline(nuclear_image=nuclear_image, + threshold_min=par['cytoplasm_min_threshold'], + threshold_max=par['cytoplasm_max_threshold']) + composite[1,:,:] = cyto_image +# else: # redundant since the matrix is initialized to zeros +# composite[1,:,:] = 0 +morphology_mip = sd.models.Image2DModel.parse(data=composite, + scale_factors=[2]*(len(sdata['morphology_mip'].groups)-2), + dims=['c','y','x'], + chunks=composite.shape) + +#make sure image is transformed correctly +img_trans=sd.transformations.get_transformation(sdata['morphology_mip'], to_coordinate_system=par['coordinate_system']) +sd.transformations.set_transformation(morphology_mip, img_trans, to_coordinate_system=par['coordinate_system']) + # Create reduced sdata print("Creating sopa SpatialData object") sdata_sopa = sd.SpatialData( @@ -69,17 +105,19 @@ "transcripts": sdata[par['transcripts_key']] }, images={ - "morphology_mip": sdata['morphology_mip'] + "morphology_mip": morphology_mip } ) + sdata_sopa.write(TMP_ZARR, overwrite=True) +print("Running RNA2Seg") # create patch in the sdata and precompute transcipt.csv for each patch with sopa image_key = 'morphology_mip' points_key = par["transcripts_key"] gene_column_name="feature_name" # typically "feature_name" for Xenium -patch_width = 2000 -patch_overlap = 150 +patch_width = par['patch_width'] +patch_overlap = par['patch_overlap'] min_points_per_patch = 1 folder_patch_rna2seg = Path(TMP_ZARR / f".rna2seg_{patch_width}_{patch_overlap}") create_patch_rna2seg(sdata=sdata_sopa, @@ -99,7 +137,7 @@ dataset = RNA2segDataset( sdata=sdata_sopa, channels_dapi=[0], - channels_cellbound=[0], #TODO idk + channels_cellbound=[1], patch_width = patch_width, patch_overlap = patch_overlap, gene_column=gene_column_name, @@ -107,15 +145,14 @@ patch_dir=folder_patch_rna2seg ) -#TODO how to fix -device = "cpu" +device = 'cuda' if torch.cuda.is_available() else 'cpu' # Set up RNA2Seg model rna2seg = RNA2seg( device, net='unet', - flow_threshold = 0.9, - cellbound_flow_threshold = 0.4, + flow_threshold = par['flow_threshold'], + cellbound_flow_threshold = par['cellbound_flow_threshold'], pretrained_model = "default_pretrained" ) @@ -135,11 +172,23 @@ overwrite= True ) +# ONLY IF TESTING/USING CROP +# for whatever reason the cropping breaks the rna2seg transformation +# this fixes it, somehow +transcript_min = sd.transform(sdata_sopa['transcripts'],to_coordinate_system=par['coordinate_system']).compute()['x'].min() +shapes_max = sd.transform(sdata_sopa['rna2seg_boundaries'],to_coordinate_system=par['coordinate_system']).bounds['maxx'].max() +if transcript_min > shapes_max: + print(f"crop detected ({transcript_min} > {shapes_max}), reformatting") + trans = sd.transformations.get_transformation(sdata_sopa['morphology_mip'], to_coordinate_system=par['coordinate_system']) + sd.transformations.set_transformation(sdata_sopa['rna2seg_boundaries'], trans, to_coordinate_system=par['coordinate_system']) + # print(sd.transform(sdata_sopa['rna2seg_boundaries'],to_coordinate_system=par['coordinate_system']).bounds['maxx'].max()) + + # Assign transcripts based on shapes sopa.spatial.assign_transcript_to_cell( sdata_sopa, points_key="transcripts", - shapes_key="rna2seg_boundaries", #TODO what is the key + shapes_key="rna2seg_boundaries", key_added="cell_id", unassigned_value=0 ) @@ -148,6 +197,7 @@ print('Creating objects for cells table', flush=True) #create new .obs for cells based on the segmentation output (corresponding with the transcripts 'cell_id') unique_cells = np.unique(sdata_sopa["transcripts"]["cell_id"]) +# print(unique_cells) # check if a '0' (noise/background) cell is in cell_id and remove zero_idx = np.where(unique_cells == 0) From d532aa91d6a3532685810ef8d898546eff9a7d40 Mon Sep 17 00:00:00 2001 From: Habib Rehman Date: Sun, 16 Nov 2025 22:33:28 -0500 Subject: [PATCH 3/3] Fixed docker env --- src/methods_transcript_assignment/rna2seg/config.vsh.yaml | 2 +- src/methods_transcript_assignment/rna2seg/script.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/methods_transcript_assignment/rna2seg/config.vsh.yaml b/src/methods_transcript_assignment/rna2seg/config.vsh.yaml index 34a1e19c..2874db6c 100644 --- a/src/methods_transcript_assignment/rna2seg/config.vsh.yaml +++ b/src/methods_transcript_assignment/rna2seg/config.vsh.yaml @@ -73,7 +73,7 @@ engines: run: - apt-get update && apt-get install libgl1 -y - type: python - pypi: [rna2seg] + pypi: [rna2seg, "numcodecs==0.15.1"] - type: native runners: diff --git a/src/methods_transcript_assignment/rna2seg/script.py b/src/methods_transcript_assignment/rna2seg/script.py index 01def974..bbdbf085 100644 --- a/src/methods_transcript_assignment/rna2seg/script.py +++ b/src/methods_transcript_assignment/rna2seg/script.py @@ -45,7 +45,7 @@ ## VIASH END TMP_DIR = Path(meta["temp_dir"] or "/tmp") -TMP_ZARR = TMP_DIR / 'rna2seg_sdata.zarr' +TMP_ZARR = TMP_DIR / 'temp_rna2seg_sdata.zarr' # Read input