4 changes: 4 additions & 0 deletions scripts/cellmap/configs/mednext_cos7.py
@@ -24,6 +24,7 @@
     'scale': (8, 8, 8), # 8nm isotropic resolution
 }
 target_array_info = input_array_info
+force_all_classes = 'both' # keep every organelle present in both splits

 # Output paths
 output_dir = 'outputs/cellmap_cos7'
@@ -52,6 +53,9 @@
 epochs = 500 # Maximum epochs
 num_gpus = 1 # Number of GPUs
 precision = '16-mixed' # Mixed precision training
+iterations_per_epoch = None # Keep dataloader on the cheap shuffle path
+train_batches_per_epoch = 2000 # Lightning caps epoch length at 2k steps
+val_batches_per_epoch = 200 # Limit validation passes per epoch

 # Learning rate scheduler (constant for MedNeXt)
 scheduler_config = {
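The iterations_per_epoch / *_batches_per_epoch split deserves a note: the dataloader no longer resamples a fixed-size subset each epoch, and epoch length is enforced on the Trainer side instead. A minimal sketch of that wiring, assuming the training script forwards these config values to PyTorch Lightning (limit_train_batches and limit_val_batches are real Trainer arguments; the config-to-Trainer mapping shown here is an assumption, not part of this diff):

# Sketch only: assumes the train script maps the config caps onto
# Lightning's Trainer. The Trainer argument names themselves are real.
import pytorch_lightning as pl

trainer = pl.Trainer(
    max_epochs=500,            # epochs from the config
    devices=1,                 # num_gpus
    precision='16-mixed',
    limit_train_batches=2000,  # train_batches_per_epoch: epoch ends after 2000 steps
    limit_val_batches=200,     # val_batches_per_epoch: validation capped at 200 batches
)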
4 changes: 4 additions & 0 deletions scripts/cellmap/configs/mednext_mito.py
@@ -22,6 +22,7 @@
     'scale': (4, 4, 4), # 4nm isotropic (higher resolution)
 }
 target_array_info = input_array_info
+force_all_classes = 'both' # ensure mito voxels present in both train/val splits

 # Output paths
 output_dir = 'outputs/cellmap_mito'
@@ -50,6 +51,9 @@
 epochs = 1000 # More epochs for single class
 num_gpus = 1 # Number of GPUs
 precision = '16-mixed' # Mixed precision training
+iterations_per_epoch = None # Leave None so dataloader avoids huge subset shuffles
+train_batches_per_epoch = 2000 # Cap Lightning's epoch length instead
+val_batches_per_epoch = 200 # Limit validation passes per epoch

 # Learning rate scheduler
 scheduler_config = {
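Both configs set force_all_classes = 'both'. The intent can be illustrated with a standalone filter: keep only volumes whose ground truth contains every requested class, applied to the train and validation splits alike. This is a sketch of the idea with invented helper names, not the cellmap-data implementation:

# Illustration only: what "force every class into both splits" means.
# The real enforcement lives inside the cellmap-data split machinery.
import numpy as np

def has_all_classes(labels: dict, classes: list) -> bool:
    # True if every requested class has at least one annotated voxel.
    return all(np.any(labels.get(cls, np.zeros(1))) for cls in classes)

def filter_split(volumes: list, classes: list) -> list:
    # Drop volumes missing any class, so neither split ever sees
    # a class with zero positive voxels.
    return [v for v in volumes if has_all_classes(v['labels'], classes)]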
3 changes: 3 additions & 0 deletions scripts/cellmap/configs/monai_unet_quick.py
@@ -14,6 +14,9 @@
 # Classes to segment (just 2 classes for quick test)
 classes = ['nuc', 'mito']

+# Data root path
+data_root = '/projects/weilab/dataset/cellmap'
+
 # Data configuration
 input_array_info = {
     'shape': (64, 64, 64), # Small patches for speed
124 changes: 88 additions & 36 deletions scripts/cellmap/predict_cellmap.py
@@ -35,6 +35,7 @@
 import numpy as np
 from tqdm import tqdm
 from monai.inferers import SlidingWindowInferer
+import torch.nn.functional as F

 # CellMap utilities
 from cellmap_segmentation_challenge.utils import TEST_CROPS, load_safe_config
@@ -45,34 +46,51 @@
 from omegaconf import OmegaConf


-def find_scale_level(zarr_path, target_resolution):
-    """Find the scale level that best matches target resolution."""
+def select_scale_level(zarr_path, target_resolution):
+    """Return the scale path plus voxel size/translation metadata closest to target resolution."""
     store = zarr.open(zarr_path, mode='r')

     # Read OME-NGFF multiscale metadata
     multiscale_meta = store.attrs.get('multiscales', [{}])[0]
     datasets_meta = multiscale_meta.get('datasets', [])

+    # Default fallback if metadata is missing
     if not datasets_meta:
-        # Fallback: use s2 (typically 8nm)
-        return 2
+        return {
+            "path": "s2",
+            "voxel_size": np.array(target_resolution, dtype=float),
+            "translation": np.zeros(3, dtype=float),
+        }

     # Find closest scale to target resolution
-    best_scale = 0
+    best = datasets_meta[0]
     min_diff = float('inf')

-    for i, ds_meta in enumerate(datasets_meta):
-        transforms = ds_meta.get('coordinateTransformations', [{}])
-        scale = transforms[0].get('scale', [1, 1, 1]) if transforms else [1, 1, 1]
-        # scale is [z, y, x] in nm
+    for ds_meta in datasets_meta:
+        transforms = ds_meta.get('coordinateTransformations', [])
+        scale = next(
+            (np.array(t.get('scale', [1, 1, 1]), dtype=float) for t in transforms if t.get('type') == 'scale'),
+            np.ones(3, dtype=float),
+        )
         avg_resolution = np.mean(scale)
         diff = abs(avg_resolution - np.mean(target_resolution))

         if diff < min_diff:
             min_diff = diff
-            best_scale = i
+            best = ds_meta

-    return best_scale
+    transforms = best.get('coordinateTransformations', [])
+    voxel_size = next(
+        (np.array(t.get('scale', [1, 1, 1]), dtype=float) for t in transforms if t.get('type') == 'scale'),
+        np.array(target_resolution, dtype=float),
+    )
+    translation = next(
+        (np.array(t.get('translation', [0, 0, 0]), dtype=float) for t in transforms if t.get('type') == 'translation'),
+        np.zeros(3, dtype=float),
+    )
+
+    return {
+        "path": best.get('path', 's0'),
+        "voxel_size": voxel_size,
+        "translation": translation,
+    }


 def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None):
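To make the metadata handling concrete, here is the shape of the OME-NGFF 'multiscales' attribute that select_scale_level walks, with illustrative values (not taken from the real CellMap volumes):

# Hypothetical OME-NGFF metadata in the layout the function reads.
multiscales = [{
    "datasets": [
        {"path": "s0", "coordinateTransformations": [
            {"type": "scale", "scale": [4.0, 4.0, 4.0]},
            {"type": "translation", "translation": [0.0, 0.0, 0.0]},
        ]},
        {"path": "s1", "coordinateTransformations": [
            {"type": "scale", "scale": [8.0, 8.0, 8.0]},
            {"type": "translation", "translation": [2.0, 2.0, 2.0]},
        ]},
    ],
}]

With target_resolution = (8, 8, 8), the loop picks the "s1" entry, so the call returns {"path": "s1", "voxel_size": array([8., 8., 8.]), "translation": array([2., 2., 2.])}.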
@@ -132,14 +150,19 @@ def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None):
     model = model.to(device)
     print(f"Using device: {device}")

-    # Setup sliding window inferer (MONAI)
-    inferer = SlidingWindowInferer(
-        roi_size=(128, 128, 128),
-        sw_batch_size=4,
-        overlap=0.5,
-        mode='gaussian',
-        device=torch.device(device),
-    )
+    base_roi = (128, 128, 128)
+    inferer_cache: dict[tuple[int, int, int], SlidingWindowInferer] = {}
+
+    def get_inferer(roi_size: tuple[int, int, int]) -> SlidingWindowInferer:
+        if roi_size not in inferer_cache:
+            inferer_cache[roi_size] = SlidingWindowInferer(
+                roi_size=roi_size,
+                sw_batch_size=4,
+                overlap=0.5,
+                mode='gaussian',
+                device=torch.device(device),
+            )
+        return inferer_cache[roi_size]

     # Filter test crops if specified
     if crop_filter:
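The cache exists because SlidingWindowInferer precomputes its Gaussian importance map for one fixed roi_size, so reusing an inferer per shape avoids rebuilding it on every crop. A standalone sketch of the behaviour, with an identity model and made-up shapes:

# Illustration with a dummy model; shapes are invented.
import torch
from monai.inferers import SlidingWindowInferer

cache = {}

def get_inferer(roi_size):
    if roi_size not in cache:
        cache[roi_size] = SlidingWindowInferer(roi_size=roi_size, sw_batch_size=4, overlap=0.5, mode='gaussian')
    return cache[roi_size]

model = torch.nn.Identity()
volume = torch.zeros(1, 1, 96, 128, 128)            # crop smaller than 128^3
roi = tuple(min(128, s) for s in volume.shape[2:])  # clamp ROI to the volume
out = get_inferer(roi)(volume, model)               # runs sliding-window inference
assert get_inferer(roi) is get_inferer(roi)         # same shape -> same cached inferer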
@@ -167,16 +190,20 @@ def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None):

         # Find appropriate scale level for target resolution
         em_path = f"{zarr_path}/recon-1/em/fibsem-uint8"
-        scale_level = find_scale_level(em_path, target_resolution)
-        print(f"  Using scale level: s{scale_level} (target resolution: {target_resolution} nm)")
+        scale_info = select_scale_level(em_path, target_resolution)
+        scale_level = scale_info['path']
+        scale_voxel_size = scale_info['voxel_size']
+        scale_translation = scale_info['translation']
+        print(f"  Using scale level: {scale_level} (voxel size: {scale_voxel_size} nm)")

         # Load EM data once for all crops in this dataset
         try:
-            raw_array = zarr.open(f"{em_path}/s{scale_level}", mode='r')
+            raw_array = zarr.open(f"{em_path}/{scale_level}", mode='r')
         except Exception as e:
             print(f"  Error loading EM data: {e}")
             print(f"  Skipping dataset {dataset}")
             continue
+        raw_shape = np.array(raw_array.shape, dtype=int)

         for crop in tqdm(dataset_crops, desc=f"  Crops in {dataset}"):
             crop_id = crop.id
@@ -186,32 +213,57 @@
             if class_label not in classes:
                 continue

-            # Extract crop region from full volume
-            # Note: This is simplified - in production, use crop.translation and crop.shape
-            # to extract exact region
+            # Extract crop region using precise metadata
             crop_output_dir = f"{output_dir}/{dataset}/crop{crop_id}"
             os.makedirs(crop_output_dir, exist_ok=True)

-            # Load a reasonable-sized region (simplified)
-            # In production, use crop metadata to extract exact region
             try:
-                # Get raw data shape
-                raw_shape = raw_array.shape
+                target_shape = np.array(crop.shape, dtype=int)
+                target_voxel = np.array(crop.voxel_size, dtype=float)
+                translation_nm = np.array(crop.translation, dtype=float)
+
+                physical_extent = target_shape * target_voxel
+                start_idx = np.floor((translation_nm - scale_translation) / scale_voxel_size).astype(int)
+                end_idx = np.ceil((translation_nm + physical_extent - scale_translation) / scale_voxel_size).astype(int)
+
+                end_idx = np.maximum(end_idx, start_idx + 1)
+                start_idx = np.clip(start_idx, 0, np.maximum(raw_shape - 1, 0))
+                end_idx = np.clip(end_idx, start_idx + 1, raw_shape)

-                # Simple extraction (center crop for demo)
-                # TODO: Use actual crop.translation and crop.shape for exact extraction
-                d, h, w = min(256, raw_shape[0]), min(256, raw_shape[1]), min(256, raw_shape[2])
-                raw_volume = raw_array[:d, :h, :w]
+                slices = tuple(slice(int(s), int(e)) for s, e in zip(start_idx, end_idx))
+                raw_volume = raw_array[slices]

                 # Normalize and convert to tensor
                 raw_volume = np.array(raw_volume).astype(np.float32) / 255.0
                 raw_tensor = torch.from_numpy(raw_volume[None, None, ...]).to(device)  # (1, 1, D, H, W)

+                roi_size = tuple(
+                    int(max(1, min(base_dim, vol_dim)))
+                    for base_dim, vol_dim in zip(base_roi, raw_volume.shape)
+                )
+                inferer = get_inferer(roi_size)
+
                 # Run inference
                 with torch.no_grad():
                     predictions = inferer(raw_tensor, model)
                     predictions = torch.sigmoid(predictions).cpu().numpy()[0]  # (C, D, H, W)

+                # Resize predictions back to the official crop shape if needed
+                target_shape_tuple = tuple(int(x) for x in target_shape)
+                if predictions.shape[1:] != target_shape_tuple:
+                    pred_tensor = torch.from_numpy(predictions).unsqueeze(0)
+                    predictions = (
+                        F.interpolate(
+                            pred_tensor,
+                            size=target_shape_tuple,
+                            mode="trilinear",
+                            align_corners=False,
+                        )
+                        .squeeze(0)
+                        .cpu()
+                        .numpy()
+                    )
+
                 # Save predictions for each class
                 for i, cls in enumerate(classes):
                     pred_array = (predictions[i] > 0.5).astype(np.uint8)
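A worked example of the index arithmetic in this hunk, with made-up metadata: a (128, 128, 128) crop at 4 nm voxels translated 16000 nm from the origin, read from a scale level stored at 8 nm with zero offset:

# Made-up numbers; mirrors the start/end index math above.
import numpy as np

target_shape = np.array([128, 128, 128])
target_voxel = np.array([4.0, 4.0, 4.0])      # nm
translation_nm = np.array([16000.0] * 3)
scale_voxel_size = np.array([8.0, 8.0, 8.0])  # nm
scale_translation = np.zeros(3)

physical_extent = target_shape * target_voxel  # 512 nm per axis
start_idx = np.floor((translation_nm - scale_translation) / scale_voxel_size).astype(int)
end_idx = np.ceil((translation_nm + physical_extent - scale_translation) / scale_voxel_size).astype(int)
print(start_idx, end_idx)  # [2000 2000 2000] [2064 2064 2064]

The 512 nm extent spans only 64 voxels at 8 nm, so the (C, 64, 64, 64) prediction is trilinearly upsampled by the F.interpolate branch back to the official (128, 128, 128) crop shape before thresholding.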