Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions src/methods_transcript_assignment/rna2seg/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
__merge__: /src/api/comp_method_transcript_assignment.yaml

name: rna2seg
label: "RNA2Seg Transcript Assignment"
summary: "Assign transcripts to cells using the RNA2Seg method"
description: "RNA2seg is a deep learning-based segmentation model designed to improve cell segmentation in Imaging-based Spatial Transcriptomics (IST)."
links:
documentation: "https://rna2seg.readthedocs.io/en/latest/index.html"
repository: "https://github.com/fish-quant/rna2seg"
references:
doi: "10.1101/2025.03.03.641259"

arguments:
- name: --transcripts_key
type: string
description: The key of the transcripts within the points of the spatial data
default: transcripts

- name: --coordinate_system
type: string
description: The key of the pixel space coordinate system within the spatial data
default: global

- name: --flow_threshold
type: double
description: Flow threshold for detecting cells
default: 0.9

- name: --cellbound_flow_threshold
type: double
description: Cell boundary flow threshold for detecting cells
default: 0.4

- name: --create_cytoplasm_image
type: boolean
description: Whether to create an artificial cytoplasm image based on thresholding the nuclear stain image
default: False

- name: --cytoplasm_min_threshold
type: double
description: Minimum percentile for cytoplasm image thresholding
default: 0.25

- name: --cytoplasm_max_threshold
type: double
description: Maximum percentile for cytoplasm image thresholding
default: 0.75

# Named option (leading `--`) like the other arguments of this component;
# a bare name would declare a positional argument.
- name: --patch_width
  type: integer
  description: Size of sopa patches to tile over image for parallelization. May affect segmentation output.
  default: 1000

# Named option (leading `--`) like the other arguments of this component;
# a bare name would declare a positional argument.
- name: --patch_overlap
  type: integer
  description: Overlap of sopa patches to tile over image for parallelization.
  default: 50


resources:
- type: python_script
path: script.py

engines:
- type: docker
image: openproblems/base_python:1
__merge__:
- /src/base/setup_spatialdata_partial.yaml
setup:
- type: docker
# env:
# - PATH="/root/.cargo/bin:${PATH}"
run:
- apt-get update && apt-get install libgl1 -y
- type: python
pypi: [rna2seg, "numcodecs==0.15.1"]
- type: native

runners:
- type: executable
- type: nextflow
directives:
label: [ hightime, highcpu, midmem ]
230 changes: 230 additions & 0 deletions src/methods_transcript_assignment/rna2seg/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
import os
import shutil
from pathlib import Path
import xarray as xr
import dask
import numpy as np
import pandas as pd
import anndata as ad
import spatialdata as sd
import sopa
import cv2
import torch

from rna2seg.dataset_zarr.patches import create_patch_rna2seg
import albumentations as A
from rna2seg.dataset_zarr import RNA2segDataset
from rna2seg.models import RNA2seg
from tqdm import tqdm
from rna2seg.utils import save_shapes2zarr



## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
# Placeholder parameters used when the script is run directly; per the note
# above, this section is generated by viash at runtime.
par = {
'input_ist': 'resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr',
# NOTE(review): 'input_segmentation' is only referenced in commented-out code
# further down; the active code never reads it.
'input_segmentation': 'resources_test/task_ist_preprocessing/mouse_brain_combined/segmentation.zarr',
'transcripts_key': 'transcripts',
'coordinate_system': 'global',
'output': './temp/sopa_testing/rna2seg_transcripts.zarr',
'flow_threshold': 0.9,
'cellbound_flow_threshold': 0.4,
'create_cytoplasm_image': False,
'cytoplasm_min_threshold': 0.25,
'cytoplasm_max_threshold': 0.75,
'patch_width': 1000,
'patch_overlap': 50,
}
# Placeholder component metadata for direct runs.
meta = {
'name': 'rna2seg',
# NOTE(review): developer-local absolute path; adjust when running this
# script by hand on another machine.
'temp_dir': "/Users/habib/Projects/txsim_project/task_ist_preprocessing/temp/rna2seg",
'cpus': 10
}
## VIASH END

# Scratch location for the intermediate SpatialData store; falls back to /tmp
# when no temp dir is provided in `meta`.
TMP_DIR = Path(meta["temp_dir"] or "/tmp")
TMP_ZARR = TMP_DIR / 'temp_rna2seg_sdata.zarr'


# Read input
print('Reading input files', flush=True)
sdata = sd.read_zarr(par['input_ist'])
# NOTE: par['input_segmentation'] is not read by this script.

# Check that the requested coordinate system is available on the transcripts.
# Raised explicitly (instead of `assert`) so the validation of user input is
# not stripped when Python runs with -O.
transcripts_coord_systems = sd.transformations.get_transformation(
    sdata[par["transcripts_key"]], get_all=True
).keys()
if par['coordinate_system'] not in transcripts_coord_systems:
    raise ValueError(
        f"Coordinate system '{par['coordinate_system']}' not found in input data."
    )


### Run RNA2seg with sopa


### CREATE CYTOPLASM IMAGE FUNCTION
# TODO define this function somewhere else and import
def get_nuclear_outline(nuclear_image, threshold_min=0.25, threshold_max=0.75):
    """Derive an artificial cytoplasm image from a nuclear stain image.

    The image is clipped to the [threshold_min, threshold_max] quantile range.
    Pixels brighter than the upper quantile are treated as nuclear signal,
    rescaled to the dynamic range of the clipped image and subtracted from it.
    Negative results are clipped to 0.

    Args:
        nuclear_image: Numeric array of nuclear stain intensities.
        threshold_min: Lower quantile (between 0 and 1) used for clipping.
        threshold_max: Upper quantile (between 0 and 1) used for clipping and
            for selecting the nuclear pixels.

    Returns:
        Array with the same shape and dtype as ``nuclear_image``.
    """
    # Compute both quantiles in a single pass over the image.
    lower, upper = np.quantile(nuclear_image, [threshold_min, threshold_max])
    threshold_image = np.clip(nuclear_image, lower, upper)

    # Nuclear values (image masked to the pixels above the upper quantile).
    nuclear_signal = (nuclear_image > upper) * nuclear_image
    peak = np.max(nuclear_signal)

    if peak > 0:
        # Scale nuclear values to the range of the clipped (whole cell) image,
        # then subtract the nucleus from the whole cell to get the cytoplasm.
        scaling_factor = (np.max(threshold_image) - np.min(threshold_image)) / peak
        cyto_image = threshold_image - (nuclear_signal * scaling_factor)
    else:
        # No pixel exceeds the upper quantile (e.g. constant image or
        # threshold_max == 1): there is nothing to subtract, and dividing by
        # the zero peak would only produce NaN/inf.
        cyto_image = threshold_image

    return np.clip(cyto_image, 0, None).astype(nuclear_image.dtype)


# Assemble a two-channel (c, y, x) composite: channel 0 holds the nuclear
# stain, channel 1 holds the generated cytoplasm image when requested and
# stays all zeros otherwise.
nuclear_image = sdata['morphology_mip']['scale0'].image.compute().to_numpy()
composite_shape = (2, nuclear_image.shape[1], nuclear_image.shape[2])
composite = np.zeros(composite_shape, dtype=nuclear_image.dtype)
composite[0] = nuclear_image

if par['create_cytoplasm_image']:
    composite[1] = get_nuclear_outline(
        nuclear_image=nuclear_image,
        threshold_min=par['cytoplasm_min_threshold'],
        threshold_max=par['cytoplasm_max_threshold'],
    )

# One factor of 2 per additional pyramid level of the source image.
# NOTE(review): presumably `groups` counts the root plus every scale level,
# hence the "- 2" — confirm against the input data.
n_scale_factors = len(sdata['morphology_mip'].groups) - 2
morphology_mip = sd.models.Image2DModel.parse(
    data=composite,
    scale_factors=[2] * n_scale_factors,
    dims=['c', 'y', 'x'],
    chunks=composite.shape,
)

# Carry the source image's transformation over to the composite image.
target_cs = par['coordinate_system']
img_trans = sd.transformations.get_transformation(
    sdata['morphology_mip'], to_coordinate_system=target_cs
)
sd.transformations.set_transformation(
    morphology_mip, img_trans, to_coordinate_system=target_cs
)

# Create reduced sdata holding only what the segmentation needs: the
# transcripts (always stored under the key "transcripts") and the composite image.
print("Creating sopa SpatialData object")
sopa_points = {"transcripts": sdata[par['transcripts_key']]}
sopa_images = {"morphology_mip": morphology_mip}
sdata_sopa = sd.SpatialData(points=sopa_points, images=sopa_images)

# Write to the temporary store, replacing any previous one.
sdata_sopa.write(TMP_ZARR, overwrite=True)

print("Running RNA2Seg")
# Create patches in the sdata and precompute transcript.csv for each patch with sopa
image_key = 'morphology_mip'
# sdata_sopa always stores the points under the literal key "transcripts";
# par["transcripts_key"] names the element in the *input* data and would not
# be found in sdata_sopa whenever it differs from "transcripts".
points_key = "transcripts"
gene_column_name = "feature_name"  # typically "feature_name" for Xenium
patch_width = par['patch_width']
patch_overlap = par['patch_overlap']
min_points_per_patch = 1
# Patch files live inside the temporary zarr store, keyed by the patch geometry.
folder_patch_rna2seg = TMP_ZARR / f".rna2seg_{patch_width}_{patch_overlap}"
create_patch_rna2seg(
    sdata=sdata_sopa,
    image_key=image_key,
    points_key=points_key,
    patch_width=patch_width,
    patch_overlap=patch_overlap,
    min_points_per_patch=min_points_per_patch,
    folder_patch_rna2seg=folder_patch_rna2seg,
    gene_column_name=gene_column_name,
    overwrite=True,
)

# Resize step applied by the dataset: 512x512 with nearest-neighbour interpolation.
resize_to_512 = A.Resize(width=512, height=512, interpolation=cv2.INTER_NEAREST)
transform_resize = A.Compose([resize_to_512])

# RNA2Seg dataset over the precomputed patches. Channel 0 of the composite is
# passed as the DAPI channel and channel 1 as the cell-boundary channel.
dataset_kwargs = dict(
    sdata=sdata_sopa,
    channels_dapi=[0],
    channels_cellbound=[1],
    patch_width=patch_width,
    patch_overlap=patch_overlap,
    gene_column=gene_column_name,
    transform_resize=transform_resize,
    patch_dir=folder_patch_rna2seg,
)
dataset = RNA2segDataset(**dataset_kwargs)

# Use the GPU when torch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Set up RNA2Seg model with the user-provided thresholds.
rna2seg = RNA2seg(
    device,
    net='unet',
    flow_threshold=par['flow_threshold'],
    cellbound_flow_threshold=par['cellbound_flow_threshold'],
    pretrained_model="default_pretrained",
)

# Run on patches, one at a time; results are saved to the patch folder.
n_patches = len(dataset)
for patch_idx in tqdm(range(n_patches)):
    rna2seg.run(
        path_temp_save=folder_patch_rna2seg,
        input_dict=dataset[patch_idx],
    )

# Save the per-patch results as a shapes element in the zarr store.
segmentation_shape_name = "rna2seg_boundaries"
save_shapes2zarr(
    dataset=dataset,
    path_parquet_files=folder_patch_rna2seg,
    segmentation_key=segmentation_shape_name,
    overwrite=True,
)

# ONLY IF TESTING/USING CROP
# Cropping breaks the transformation of the rna2seg boundaries (cause not
# understood). The symptom checked here: every transcript lies beyond the
# largest x of the boundaries. In that case the image's transformation is
# applied to the boundaries, which has been observed to fix the alignment.
target_cs = par['coordinate_system']
transcripts_in_cs = sd.transform(sdata_sopa['transcripts'], to_coordinate_system=target_cs)
transcript_min = transcripts_in_cs.compute()['x'].min()
boundaries_in_cs = sd.transform(sdata_sopa['rna2seg_boundaries'], to_coordinate_system=target_cs)
shapes_max = boundaries_in_cs.bounds['maxx'].max()
if transcript_min > shapes_max:
    print(f"crop detected ({transcript_min} > {shapes_max}), reformatting")
    trans = sd.transformations.get_transformation(
        sdata_sopa['morphology_mip'], to_coordinate_system=target_cs
    )
    sd.transformations.set_transformation(
        sdata_sopa['rna2seg_boundaries'], trans, to_coordinate_system=target_cs
    )


# Assign transcripts based on shapes; the result is written to the "cell_id"
# column, with 0 used as the unassigned value.
assignment_kwargs = dict(
    points_key="transcripts",
    shapes_key="rna2seg_boundaries",
    key_added="cell_id",
    unassigned_value=0,
)
sopa.spatial.assign_transcript_to_cell(sdata_sopa, **assignment_kwargs)

# Create objects for cells table
print('Creating objects for cells table', flush=True)
# The distinct cell ids carried by the transcripts become the rows of .obs.
unique_cells = np.unique(sdata_sopa["transcripts"]["cell_id"])

# Drop the '0' (noise/background) id when it is present.
zero_positions = np.flatnonzero(unique_cells == 0)
if zero_positions.size:
    unique_cells = np.delete(unique_cells, zero_positions[0])

# Series indexed by its own values. `in` on a Series tests the index labels,
# which mirror the values here, so the sanity check below covers the ids.
cell_id_col = pd.Series(unique_cells, name='cell_id', index=unique_cells)
assert 0 not in cell_id_col, "Found '0' in cell_id column of assingment output cell matrix"


# Create transcripts only sdata: the assigned transcripts plus a table whose
# .obs lists the cell ids and whose .var keeps the input table's var index
# without any columns.
print('Subsetting to transcripts cell id data', flush=True)
cells_table = ad.AnnData(
    obs=pd.DataFrame(cell_id_col),
    var=sdata.tables["table"].var[[]],
)
sdata_transcripts_only = sd.SpatialData(
    points={"transcripts": sdata_sopa['transcripts']},
    tables={"table": cells_table},
)

# Write output, removing any previous output at the same path first.
print('Write transcripts with cell ids', flush=True)
output_path = par["output"]
if os.path.exists(output_path):
    shutil.rmtree(output_path)

sdata_transcripts_only.write(output_path)
Loading