From d9f751ff5e60414cbb9cd6acde85a38452234fcd Mon Sep 17 00:00:00 2001
From: BoyuShen2004
Date: Mon, 8 Dec 2025 01:08:32 -0500
Subject: [PATCH] Fix CellMap training issues: resolve NaN losses and infinite
 epochs

- Add data type conversion (_prepare_images/_prepare_labels) to prevent
  uint8/float16 mismatch
- Implement CellMapBalancedLoss with comprehensive NaN/inf handling for
  sparse segmentation
- Return gradient-safe zero losses for empty batches instead of skipping
  them, preventing NaN contamination in progress bars
- Replace the problematic pos_weight computation with uniform weights to
  avoid numerical instability
- Add limit_train_batches/limit_val_batches to prevent effectively infinite
  epochs on large datasets
- Harden the loss function with gradient-connected safety values and
  extensive numerical stability checks
---
 scripts/cellmap/configs/mednext_cos7.py     |   4 +
 scripts/cellmap/configs/mednext_mito.py     |   4 +
 scripts/cellmap/configs/monai_unet_quick.py |   3 +
 scripts/cellmap/predict_cellmap.py          | 124 +++--
 scripts/cellmap/train_cellmap.py            | 554 ++++++++++++++++++--
 5 files changed, 596 insertions(+), 93 deletions(-)

diff --git a/scripts/cellmap/configs/mednext_cos7.py b/scripts/cellmap/configs/mednext_cos7.py
index 79500217..d59903f2 100644
--- a/scripts/cellmap/configs/mednext_cos7.py
+++ b/scripts/cellmap/configs/mednext_cos7.py
@@ -24,6 +24,7 @@
     'scale': (8, 8, 8),  # 8nm isotropic resolution
 }
 target_array_info = input_array_info
+force_all_classes = 'both'  # keep every organelle present in both splits

 # Output paths
 output_dir = 'outputs/cellmap_cos7'
@@ -52,6 +53,9 @@
 epochs = 500  # Maximum epochs
 num_gpus = 1  # Number of GPUs
 precision = '16-mixed'  # Mixed precision training
+iterations_per_epoch = None  # Keep dataloader on the cheap shuffle path
+train_batches_per_epoch = 2000  # Lightning caps epoch length at 2k steps
+val_batches_per_epoch = 200  # Limit validation passes per epoch

 # Learning rate scheduler (constant for MedNeXt)
 scheduler_config = {
diff --git a/scripts/cellmap/configs/mednext_mito.py b/scripts/cellmap/configs/mednext_mito.py
index 40056c65..23c8a0a1 100644
--- a/scripts/cellmap/configs/mednext_mito.py
+++ b/scripts/cellmap/configs/mednext_mito.py
@@ -22,6 +22,7 @@
     'scale': (4, 4, 4),  # 4nm isotropic (higher resolution)
 }
 target_array_info = input_array_info
+force_all_classes = 'both'  # ensure mito voxels present in both train/val splits

 # Output paths
 output_dir = 'outputs/cellmap_mito'
@@ -50,6 +51,9 @@
 epochs = 1000  # More epochs for single class
 num_gpus = 1  # Number of GPUs
 precision = '16-mixed'  # Mixed precision training
+iterations_per_epoch = None  # Leave None so dataloader avoids huge subset shuffles
+train_batches_per_epoch = 2000  # Cap Lightning's epoch length instead
+val_batches_per_epoch = 200  # Limit validation passes per epoch

 # Learning rate scheduler
 scheduler_config = {
diff --git a/scripts/cellmap/configs/monai_unet_quick.py b/scripts/cellmap/configs/monai_unet_quick.py
index 9a8fca9e..42d710b4 100644
--- a/scripts/cellmap/configs/monai_unet_quick.py
+++ b/scripts/cellmap/configs/monai_unet_quick.py
@@ -14,6 +14,9 @@
 # Classes to segment (just 2 classes for quick test)
 classes = ['nuc', 'mito']

+# Data root path
+data_root = '/projects/weilab/dataset/cellmap'
+
 # Data configuration
 input_array_info = {
     'shape': (64, 64, 64),  # Small patches for speed
diff --git a/scripts/cellmap/predict_cellmap.py b/scripts/cellmap/predict_cellmap.py
index 13cf25ea..8b8b44e3 100755
--- a/scripts/cellmap/predict_cellmap.py
+++ b/scripts/cellmap/predict_cellmap.py
@@ -35,6 +35,7 @@
 import numpy as np
 from tqdm
import tqdm from monai.inferers import SlidingWindowInferer +import torch.nn.functional as F # CellMap utilities from cellmap_segmentation_challenge.utils import TEST_CROPS, load_safe_config @@ -45,34 +46,51 @@ from omegaconf import OmegaConf -def find_scale_level(zarr_path, target_resolution): - """Find the scale level that best matches target resolution.""" +def select_scale_level(zarr_path, target_resolution): + """Return the scale path plus voxel size/translation metadata closest to target resolution.""" store = zarr.open(zarr_path, mode='r') - # Read OME-NGFF multiscale metadata multiscale_meta = store.attrs.get('multiscales', [{}])[0] datasets_meta = multiscale_meta.get('datasets', []) + # Default fallback if metadata is missing if not datasets_meta: - # Fallback: use s2 (typically 8nm) - return 2 + return { + "path": "s2", + "voxel_size": np.array(target_resolution, dtype=float), + "translation": np.zeros(3, dtype=float), + } - # Find closest scale to target resolution - best_scale = 0 + best = datasets_meta[0] min_diff = float('inf') - for i, ds_meta in enumerate(datasets_meta): - transforms = ds_meta.get('coordinateTransformations', [{}]) - scale = transforms[0].get('scale', [1, 1, 1]) if transforms else [1, 1, 1] - # scale is [z, y, x] in nm + for ds_meta in datasets_meta: + transforms = ds_meta.get('coordinateTransformations', []) + scale = next( + (np.array(t.get('scale', [1, 1, 1]), dtype=float) for t in transforms if t.get('type') == 'scale'), + np.ones(3, dtype=float), + ) avg_resolution = np.mean(scale) diff = abs(avg_resolution - np.mean(target_resolution)) - if diff < min_diff: min_diff = diff - best_scale = i + best = ds_meta + + transforms = best.get('coordinateTransformations', []) + voxel_size = next( + (np.array(t.get('scale', [1, 1, 1]), dtype=float) for t in transforms if t.get('type') == 'scale'), + np.array(target_resolution, dtype=float), + ) + translation = next( + (np.array(t.get('translation', [0, 0, 0]), dtype=float) for t in transforms if t.get('type') == 'translation'), + np.zeros(3, dtype=float), + ) - return best_scale + return { + "path": best.get('path', 's0'), + "voxel_size": voxel_size, + "translation": translation, + } def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None): @@ -132,14 +150,19 @@ def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None): model = model.to(device) print(f"Using device: {device}") - # Setup sliding window inferer (MONAI) - inferer = SlidingWindowInferer( - roi_size=(128, 128, 128), - sw_batch_size=4, - overlap=0.5, - mode='gaussian', - device=torch.device(device), - ) + base_roi = (128, 128, 128) + inferer_cache: dict[tuple[int, int, int], SlidingWindowInferer] = {} + + def get_inferer(roi_size: tuple[int, int, int]) -> SlidingWindowInferer: + if roi_size not in inferer_cache: + inferer_cache[roi_size] = SlidingWindowInferer( + roi_size=roi_size, + sw_batch_size=4, + overlap=0.5, + mode='gaussian', + device=torch.device(device), + ) + return inferer_cache[roi_size] # Filter test crops if specified if crop_filter: @@ -167,16 +190,20 @@ def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None): # Find appropriate scale level for target resolution em_path = f"{zarr_path}/recon-1/em/fibsem-uint8" - scale_level = find_scale_level(em_path, target_resolution) - print(f" Using scale level: s{scale_level} (target resolution: {target_resolution} nm)") + scale_info = select_scale_level(em_path, target_resolution) + scale_level = scale_info['path'] + 
scale_voxel_size = scale_info['voxel_size'] + scale_translation = scale_info['translation'] + print(f" Using scale level: {scale_level} (voxel size: {scale_voxel_size} nm)") # Load EM data once for all crops in this dataset try: - raw_array = zarr.open(f"{em_path}/s{scale_level}", mode='r') + raw_array = zarr.open(f"{em_path}/{scale_level}", mode='r') except Exception as e: print(f" Error loading EM data: {e}") print(f" Skipping dataset {dataset}") continue + raw_shape = np.array(raw_array.shape, dtype=int) for crop in tqdm(dataset_crops, desc=f" Crops in {dataset}"): crop_id = crop.id @@ -186,32 +213,57 @@ def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None): if class_label not in classes: continue - # Extract crop region from full volume - # Note: This is simplified - in production, use crop.translation and crop.shape - # to extract exact region + # Extract crop region using precise metadata crop_output_dir = f"{output_dir}/{dataset}/crop{crop_id}" os.makedirs(crop_output_dir, exist_ok=True) - # Load a reasonable-sized region (simplified) - # In production, use crop metadata to extract exact region try: - # Get raw data shape - raw_shape = raw_array.shape + target_shape = np.array(crop.shape, dtype=int) + target_voxel = np.array(crop.voxel_size, dtype=float) + translation_nm = np.array(crop.translation, dtype=float) + + physical_extent = target_shape * target_voxel + start_idx = np.floor((translation_nm - scale_translation) / scale_voxel_size).astype(int) + end_idx = np.ceil((translation_nm + physical_extent - scale_translation) / scale_voxel_size).astype(int) + + end_idx = np.maximum(end_idx, start_idx + 1) + start_idx = np.clip(start_idx, 0, np.maximum(raw_shape - 1, 0)) + end_idx = np.clip(end_idx, start_idx + 1, raw_shape) - # Simple extraction (center crop for demo) - # TODO: Use actual crop.translation and crop.shape for exact extraction - d, h, w = min(256, raw_shape[0]), min(256, raw_shape[1]), min(256, raw_shape[2]) - raw_volume = raw_array[:d, :h, :w] + slices = tuple(slice(int(s), int(e)) for s, e in zip(start_idx, end_idx)) + raw_volume = raw_array[slices] # Normalize and convert to tensor raw_volume = np.array(raw_volume).astype(np.float32) / 255.0 raw_tensor = torch.from_numpy(raw_volume[None, None, ...]).to(device) # (1, 1, D, H, W) + roi_size = tuple( + int(max(1, min(base_dim, vol_dim))) + for base_dim, vol_dim in zip(base_roi, raw_volume.shape) + ) + inferer = get_inferer(roi_size) + # Run inference with torch.no_grad(): predictions = inferer(raw_tensor, model) predictions = torch.sigmoid(predictions).cpu().numpy()[0] # (C, D, H, W) + # Resize predictions back to the official crop shape if needed + target_shape_tuple = tuple(int(x) for x in target_shape) + if predictions.shape[1:] != target_shape_tuple: + pred_tensor = torch.from_numpy(predictions).unsqueeze(0) + predictions = ( + F.interpolate( + pred_tensor, + size=target_shape_tuple, + mode="trilinear", + align_corners=False, + ) + .squeeze(0) + .cpu() + .numpy() + ) + # Save predictions for each class for i, cls in enumerate(classes): pred_array = (predictions[i] > 0.5).astype(np.uint8) diff --git a/scripts/cellmap/train_cellmap.py b/scripts/cellmap/train_cellmap.py index 52750e44..97beb091 100755 --- a/scripts/cellmap/train_cellmap.py +++ b/scripts/cellmap/train_cellmap.py @@ -20,31 +20,197 @@ import os import sys from pathlib import Path +from typing import Mapping, Sequence # Add PyTC to path PYTC_ROOT = Path(__file__).parent.parent.parent sys.path.insert(0, str(PYTC_ROOT)) import torch 
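A side note on the crop-extraction hunk in predict_cellmap.py above: the index arithmetic maps a crop's physical origin and extent (in nanometers) onto voxel indices of the chosen scale level, clamping to the stored volume. A minimal standalone sketch under the same assumptions; the function and argument names here are illustrative, not from the patch:

    import numpy as np

    def crop_to_voxel_slices(translation_nm, crop_shape_vox, crop_voxel_nm,
                             scale_translation_nm, scale_voxel_nm, raw_shape):
        """Map a crop's physical bounds onto index slices of one scale level."""
        extent_nm = np.asarray(crop_shape_vox) * np.asarray(crop_voxel_nm)
        start = np.floor((translation_nm - scale_translation_nm) / scale_voxel_nm).astype(int)
        end = np.ceil((translation_nm + extent_nm - scale_translation_nm) / scale_voxel_nm).astype(int)
        end = np.maximum(end, start + 1)          # guarantee a non-empty window
        start = np.clip(start, 0, np.maximum(np.asarray(raw_shape) - 1, 0))
        end = np.clip(end, start + 1, raw_shape)  # stay inside the stored volume
        return tuple(slice(int(s), int(e)) for s, e in zip(start, end))

Because floor/ceil can grow the window by a voxel on each side, the predictions are interpolated back to the official crop shape afterwards, as the hunk above does.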
+import torch.nn as nn +import torch.nn.functional as F import pytorch_lightning as pl from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor from pytorch_lightning.loggers import TensorBoardLogger # CellMap data loading (official) from cellmap_segmentation_challenge.utils import ( - get_dataloader, # Official dataloader factory - make_datasplit_csv, # Auto-generate train/val split - get_tested_classes, # Official class list - CellMapLossWrapper, # NaN-aware loss + get_dataloader, # Official dataloader factory + make_datasplit_csv, # Auto-generate train/val split + make_s3_datasplit_csv, + get_tested_classes, # Official class list + load_safe_config, ) from cellmap_segmentation_challenge import config as cellmap_cfg +from upath import UPath + +# --------------------------------------------------------------------------- +# Compatibility patch: +# xarray-tensorstore>=0.3.0 expects callers to pass the zarr_format argument to +# _zarr_spec_from_path. cellmap-data currently calls it with a single argument. +# Monkey patch the helper so we fall back to zarr v2 automatically. +# --------------------------------------------------------------------------- +try: + import inspect + import xarray_tensorstore as xt + + _orig_zarr_spec = xt._zarr_spec_from_path + _spec_sig = inspect.signature(_orig_zarr_spec) + needs_patch = ( + len(_spec_sig.parameters) >= 2 + and list(_spec_sig.parameters.values())[1].default is inspect._empty + ) + + if needs_patch: + def _compat_zarr_spec(path: str, zarr_format: int | None = None): + if zarr_format is None: + zarr_format = 2 + return _orig_zarr_spec(path, zarr_format) + + xt._zarr_spec_from_path = _compat_zarr_spec +except Exception as patch_err: # pragma: no cover - best-effort guard + print(f"[WARN] Failed to patch xarray_tensorstore: {patch_err}") # PyTC model building (import only, no modification) from connectomics.models import build_model from connectomics.models.loss import create_loss # Import config utilities -from cellmap_segmentation_challenge.utils import load_safe_config + +class CellMapBalancedLoss(nn.Module): + """Combines class-balanced BCE with soft Dice while masking NaNs.""" + + def __init__(self, bce_weight: float = 0.5, dice_weight: float = 0.5, eps: float = 1e-6): + super().__init__() + self.bce_weight = bce_weight + self.dice_weight = dice_weight + self.eps = eps + + def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + mask = torch.isfinite(target).float() + valid_voxels = mask.sum() + + if valid_voxels == 0: + # No supervision in this patch; return zero loss but keep gradients defined. 
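The `logits.sum() * 0.0` pattern used on the next line returns a zero loss that is still attached to the autograd graph, so every upstream parameter receives a (zero) gradient and optimizer/AMP/DDP bookkeeping stays well-defined. A minimal sketch of the behavior with toy tensors:

    import torch

    logits = torch.randn(2, 3, 8, 8, 8, requires_grad=True)
    safe_loss = logits.sum() * 0.0        # scalar 0.0, but grad_fn is preserved
    safe_loss.backward()
    assert logits.grad is not None        # the graph stayed connected
    assert torch.all(logits.grad == 0)    # and the gradients are exactly zero

A detached `torch.tensor(0.0)` would instead leave parameters without gradients for the step, which can trip up DDP's gradient reductions.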
+ return logits.sum() * 0.0 + + target = target.nan_to_num(0.0) + + # Clamp logits to prevent numerical instability (overflow in sigmoid/BCE) + logits = torch.clamp(logits, -10.0, 10.0) + # Compute per-channel pos_weight to counter extreme imbalance + dims = (0, 2, 3, 4) if logits.dim() == 5 else tuple(range(logits.dim() - 1)) + pos = (target * mask).sum(dim=dims) + neg = ((1.0 - target) * mask).sum(dim=dims) + + # For sparse segmentation (like CellMap), compute loss on all classes + # even when some classes have no positive examples in a batch + + # Check if any classes have positive examples + valid_classes = (pos > 0) + + # Debug: check for potential issues + if torch.isnan(pos).any() or torch.isinf(pos).any() or torch.isnan(neg).any() or torch.isinf(neg).any(): + # If pos/neg have NaN/inf, return safe loss + return logits.sum() * 0.0 + 1e-6 + + if not valid_classes.any(): + # No classes have positive examples; return small finite loss to keep gradients flowing + return logits.sum() * 0.0 + 1e-6 + + # For multi-class sparse segmentation, extreme pos_weight can cause numerical issues + # Use a simpler approach: uniform weight for all classes to avoid extreme imbalance + pos_weight = torch.ones_like(pos) + + # Compute per-class losses and average across classes + num_classes = logits.shape[1] + bce_losses = [] + dice_losses = [] + + for c in range(num_classes): + logits_c = logits[:, c] + target_c = target[:, c] + mask_c = mask[:, c] if mask is not None else None + + # BCE for this class with pos_weight to handle class imbalance + if mask_c is not None and mask_c.sum() > 0: + # Get pos_weight as a tensor (0D tensor) for this class + pos_weight_c = pos_weight[c] # Keep as tensor, not scalar + bce_c = F.binary_cross_entropy_with_logits( + logits_c, + target_c, + weight=mask_c, + pos_weight=pos_weight_c, # Add pos_weight to handle class imbalance + reduction="mean", # Mean for this class + ) + # Safety check: if BCE is not finite, use 0 but maintain gradient connection + if not torch.isfinite(bce_c): + bce_c = logits_c.sum() * 0.0 + bce_losses.append(bce_c) + else: + bce_losses.append(logits_c.sum() * 0.0) + + # Dice for this class + probs_c = torch.sigmoid(logits_c) * (mask_c if mask_c is not None else torch.ones_like(logits_c)) + # Clamp probabilities to prevent numerical issues + probs_c = torch.clamp(probs_c, 1e-7, 1.0 - 1e-7) + target_masked_c = target_c * (mask_c if mask_c is not None else torch.ones_like(target_c)) + + # Use correct spatial dimensions for the per-class tensor (4D: B, D, H, W) + # After removing channel dim, spatial dims are (1, 2, 3) for 4D tensor + spatial_dims = tuple(range(1, logits_c.dim())) + intersection_c = (probs_c * target_masked_c).sum(dim=spatial_dims) # Sum spatial dims, shape: (B,) + denom_c = probs_c.sum(dim=spatial_dims) + target_masked_c.sum(dim=spatial_dims) # Shape: (B,) + + # Replace any NaN/inf values with 0 to prevent propagation + intersection_c = intersection_c.nan_to_num(0.0) + denom_c = denom_c.nan_to_num(0.0) + + # Compute Dice per batch element, handling edge cases + # Use torch.where to handle denom_c == 0 case, and ensure no NaN in computation + numerator = 2.0 * intersection_c + self.eps + denominator = denom_c + self.eps + + # Replace any NaN/inf in numerator/denominator + numerator = numerator.nan_to_num(0.0) + denominator = denominator.nan_to_num(self.eps) # Ensure denominator is never 0 + + dice_per_batch = torch.where( + denom_c > 0, + 1.0 - numerator / denominator, + torch.zeros_like(intersection_c) + ) + + # Replace any NaN/inf in 
dice_per_batch before averaging + dice_per_batch = dice_per_batch.nan_to_num(0.0) + + # Average over batch dimension to get scalar for this class + dice_c = dice_per_batch.mean() + + # Final safety check: if Dice is not finite, use 0 but maintain gradient connection + if not torch.isfinite(dice_c): + dice_c = logits_c.sum() * 0.0 + dice_losses.append(dice_c) + + # Ensure all individual losses are finite before stacking (use gradient-connected replacements) + bce_losses = [l if torch.isfinite(l) else (logits.sum() * 0.0) for l in bce_losses] + dice_losses = [l if torch.isfinite(l) else (logits.sum() * 0.0) for l in dice_losses] + + # Average across classes + bce = torch.stack(bce_losses).mean() + dice = torch.stack(dice_losses).mean() + + # Safety check for averaged losses + if not torch.isfinite(bce) or not torch.isfinite(dice): + return logits.sum() * 0.0 + 1e-6 + + loss = self.bce_weight * bce + self.dice_weight * dice + + # Final safety check: if loss is NaN or inf, return a small finite loss that maintains gradients + if not torch.isfinite(loss): + return logits.sum() * 0.0 + 1e-6 + + return loss class CellMapLightningModule(pl.LightningModule): @@ -74,6 +240,20 @@ def __init__( # Save hyperparameters self.save_hyperparameters(ignore=['model', 'criterion']) + def _prepare_images(self, images: torch.Tensor) -> torch.Tensor: + """Cast raw EM voxels to float and normalize if they arrive as uint8.""" + if images.dtype == torch.uint8: + images = images.float().div_(255.0) + else: + images = images.float() + return images + + def _prepare_labels(self, labels: torch.Tensor) -> torch.Tensor: + """Ensure supervision tensors are float32 before loss computations.""" + if labels.dtype != torch.float32: + labels = labels.float() + return labels + def _maybe_resample(self, images: torch.Tensor, labels: torch.Tensor): """Optionally resample images/labels to a fixed shape to avoid scale filtering.""" if self.target_shape is None: @@ -92,47 +272,109 @@ def forward(self, x): return self.model(x) def training_step(self, batch, batch_idx): - images = batch['input'] - labels = batch['output'] + images = self._prepare_images(batch['input']) + labels = self._prepare_labels(batch['output']) + batch_size = images.shape[0] images, labels = self._maybe_resample(images, labels) + + # Let the loss function handle all edge cases (empty batches, NaN, etc.) + # Don't skip batches - the loss function returns safe finite values for all cases predictions = self(images) + predictions = self._normalize_predictions(predictions, labels.shape[-3:]) loss = self.criterion(predictions, labels) + # The loss function should always return a finite value, but double-check + if not torch.isfinite(loss): + # This should never happen if loss function is correct, but handle gracefully + loss = predictions.sum() * 0.0 + 1e-6 + self.log('train/loss', loss, prog_bar=True, sync_dist=True) return loss def validation_step(self, batch, batch_idx): - images = batch['input'] - labels = batch['output'] + images = self._prepare_images(batch['input']) + labels = self._prepare_labels(batch['output']) + batch_size = images.shape[0] images, labels = self._maybe_resample(images, labels) + + # Let the loss function handle all edge cases (empty batches, NaN, etc.) 
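For reference, the `_prepare_images` guard above exists because the official dataloader can hand back raw uint8 EM volumes while the model runs under 16-bit mixed precision; casting and rescaling before the forward pass avoids the uint8/float16 mismatch named in the commit message. A toy sketch of the cast-then-scale behavior, with an illustrative standalone function name:

    import torch

    def prepare_images(images: torch.Tensor) -> torch.Tensor:
        # uint8 EM volumes arrive as raw [0, 255] intensities
        if images.dtype == torch.uint8:
            return images.float().div_(255.0)  # cast to a new tensor, scale in place on the copy
        return images.float()

    batch = torch.randint(0, 256, (1, 1, 4, 4, 4), dtype=torch.uint8)
    out = prepare_images(batch)
    assert out.dtype == torch.float32
    assert 0.0 <= out.min() and out.max() <= 1.0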
+ # Don't skip batches - the loss function returns safe finite values for all cases + valid_mask = torch.isfinite(labels).float() + labels = labels.nan_to_num(0.0) + predictions = self(images) + predictions = self._normalize_predictions(predictions, labels.shape[-3:]) loss = self.criterion(predictions, labels) + # The loss function should always return a finite value, but double-check + if not torch.isfinite(loss): + # This should never happen if loss function is correct, but handle gracefully + loss = predictions.sum() * 0.0 + 1e-6 + # Compute Dice score per class with torch.no_grad(): pred_binary = (torch.sigmoid(predictions) > 0.5).float() # Average Dice across classes dice_scores = [] + eps = 1e-7 for c in range(predictions.shape[1]): - pred_c = pred_binary[:, c] - label_c = labels[:, c] - intersection = (pred_c * label_c).sum() - dice = (2. * intersection) / (pred_c.sum() + label_c.sum() + 1e-7) + mask_c = valid_mask[:, c] + valid_vox = mask_c.sum() + if valid_vox == 0: + dice = torch.tensor(1.0, device=labels.device) + else: + pred_c = pred_binary[:, c] * mask_c + label_c = labels[:, c] * mask_c + intersection = (pred_c * label_c).sum() + denom = pred_c.sum() + label_c.sum() + dice = (2. * intersection + eps) / (denom + eps) dice_scores.append(dice) # Log per-class Dice if we have class names if c < len(self.classes): - self.log(f'val/dice_{self.classes[c]}', dice, sync_dist=True) + self.log( + f'val/dice_{self.classes[c]}', + dice, + sync_dist=True, + batch_size=batch_size, + ) mean_dice = torch.stack(dice_scores).mean() - self.log('val/loss', loss, prog_bar=True, sync_dist=True) - self.log('val/dice', mean_dice, prog_bar=True, sync_dist=True) + self.log('val/loss', loss, prog_bar=True, sync_dist=True, batch_size=batch_size) + self.log('val/dice', mean_dice, prog_bar=True, sync_dist=True, batch_size=batch_size) return loss + def _normalize_predictions(self, predictions, target_shape): + """ + Convert model outputs to a single tensor aligned with target_shape. + + MedNeXt (and other deep-supervision models) can return a dict of multi-scale logits. + We upsample each prediction to the target shape and average them so downstream loss + functions see consistent tensors. 
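    A hedged usage sketch (the head names and shapes below are hypothetical):

        preds = {"ds0": torch.randn(1, 2, 64, 64, 64),
                 "ds1": torch.randn(1, 2, 32, 32, 32)}
        out = self._normalize_predictions(preds, (64, 64, 64))
        # out.shape == (1, 2, 64, 64, 64): each head upsampled, then averaged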
+ """ + if not isinstance(predictions, dict): + return predictions + + merged = [] + for tensor in predictions.values(): + if tensor.shape[-3:] != tuple(target_shape): + tensor = F.interpolate( + tensor, + size=target_shape, + mode="trilinear", + align_corners=False, + ) + merged.append(tensor) + + if not merged: + raise ValueError("Prediction dictionary is empty; cannot compute loss.") + + return torch.mean(torch.stack(merged, dim=0), dim=0) + def configure_optimizers(self): optimizer = torch.optim.AdamW( self.parameters(), @@ -174,6 +416,50 @@ def train_cellmap(config_path: str, data_root: str | None = None, target_shape=N # Load config (CellMap's safe config loader) print(f"Loading config from: {config_path}") config = load_safe_config(config_path) + base_experiment_path = getattr(config, "base_experiment_path", None) + if base_experiment_path is not None: + base_experiment_path = UPath(base_experiment_path) + + def _resolve_path(value, default): + """Resolve relative paths against the configured experiment root.""" + if value is None: + value = default + if value is None: + return None + path = UPath(value) + if not path.is_absolute() and base_experiment_path is not None: + path = base_experiment_path / path + return path.path + + def _infer_scale(filter_value, array_info): + """Infer a scale tuple based on filter settings and array metadata.""" + if filter_value in (False, None): + return None + + def _extract(info): + if isinstance(info, Mapping) and "scale" in info: + return info["scale"] + if isinstance(info, Mapping): + for value in info.values(): + result = _extract(value) + if result is not None: + return result + return None + + if filter_value is True: + scale = _extract(array_info) + return tuple(scale) if scale is not None else None + if isinstance(filter_value, (int, float)): + return (float(filter_value),) * 3 + if isinstance(filter_value, Sequence) and not isinstance( + filter_value, (str, bytes) + ): + seq = list(filter_value) + if len(seq) == 1: + return tuple(seq * 3) + return tuple(seq) + return tuple(filter_value) + # Allow CLI overrides if data_root: @@ -182,29 +468,83 @@ def train_cellmap(config_path: str, data_root: str | None = None, target_shape=N setattr(config, "target_shape", target_shape) # Extract config values - model_name = getattr(config, 'model_name', 'mednext') - classes = getattr(config, 'classes', get_tested_classes()) - learning_rate = getattr(config, 'learning_rate', 1e-3) - batch_size = getattr(config, 'batch_size', 2) - max_epochs = getattr(config, 'epochs', 1000) - num_gpus = getattr(config, 'num_gpus', 1) - precision = getattr(config, 'precision', '16-mixed') - - # Output paths - output_dir = getattr(config, 'output_dir', 'outputs/cellmap') + model_name = getattr(config, "model_name", "mednext") + classes = getattr(config, "classes", get_tested_classes()) + learning_rate = getattr(config, "learning_rate", 1e-3) + batch_size = getattr(config, "batch_size", 2) + batch_size = getattr(config, "train_micro_batch_size_per_gpu", batch_size) + max_epochs = getattr(config, "epochs", 1000) + num_gpus = getattr(config, "num_gpus", 1) + precision = getattr(config, "precision", "16-mixed") + validation_prob = getattr(config, "validation_prob", 0.15) + filter_by_scale = getattr(config, "filter_by_scale", False) + force_classes_mode = getattr(config, "force_all_classes", "both") + use_s3 = getattr(config, "use_s3", False) + weighted_sampler = getattr(config, "weighted_sampler", False) + use_mutual_exclusion = getattr(config, "use_mutual_exclusion", False) + 
train_raw_value_transforms = getattr( + config, "train_raw_value_transforms", None + ) + val_raw_value_transforms = getattr(config, "val_raw_value_transforms", None) + target_value_transforms = getattr(config, "target_value_transforms", None) + dataloader_kwargs = dict(getattr(config, "dataloader_kwargs", {})) + datasplit_kwargs = dict(getattr(config, "datasplit_kwargs", {})) + train_batches_per_epoch = getattr(config, "train_batches_per_epoch", None) + val_batches_per_epoch = getattr(config, "val_batches_per_epoch", None) + + output_dir = _resolve_path( + getattr(config, "output_dir", "outputs/cellmap"), "outputs/cellmap" + ) os.makedirs(output_dir, exist_ok=True) - datasplit_path = getattr(config, 'datasplit_path', f'{output_dir}/datasplit.csv') - input_array_info = getattr(config, 'input_array_info', { - 'shape': (128, 128, 128), - 'scale': (8, 8, 8), - }) - target_array_info = getattr(config, 'target_array_info', input_array_info) - spatial_transforms = getattr(config, 'spatial_transforms', { - 'mirror': {'axes': {'x': 0.5, 'y': 0.5, 'z': 0.5}}, - 'transpose': {'axes': ['x', 'y', 'z']}, - 'rotate': {'axes': {'x': [-180, 180], 'y': [-180, 180], 'z': [-180, 180]}}, - }) + datasplit_path = _resolve_path( + getattr(config, "datasplit_path", None), + os.path.join(output_dir, "datasplit.csv"), + ) + datasplit_dir = os.path.dirname(datasplit_path) + if datasplit_dir: + os.makedirs(datasplit_dir, exist_ok=True) + + tensorboard_dir = _resolve_path( + getattr(config, "logs_save_path", None), + os.path.join(output_dir, "tensorboard"), + ) + checkpoint_dir = _resolve_path( + getattr(config, "checkpoint_dir", None), + os.path.join(output_dir, "checkpoints"), + ) + if tensorboard_dir: + os.makedirs(tensorboard_dir, exist_ok=True) + if checkpoint_dir: + os.makedirs(checkpoint_dir, exist_ok=True) + + input_array_info = getattr( + config, + "input_array_info", + { + "shape": (128, 128, 128), + "scale": (8, 8, 8), + }, + ) + target_array_info = getattr(config, "target_array_info", input_array_info) + spatial_transforms = getattr( + config, + "spatial_transforms", + { + "mirror": {"axes": {"x": 0.5, "y": 0.5, "z": 0.5}}, + "transpose": {"axes": ["x", "y", "z"]}, + "rotate": { + "axes": { + "x": [-180, 180], + "y": [-180, 180], + "z": [-180, 180], + } + }, + }, + ) + iterations_per_epoch = getattr(config, "iterations_per_epoch", None) + validation_time_limit = getattr(config, "validation_time_limit", None) + validation_batch_limit = getattr(config, "validation_batch_limit", None) print(f"Training configuration:") print(f" Model: {model_name}") @@ -213,6 +553,17 @@ def train_cellmap(config_path: str, data_root: str | None = None, target_shape=N print(f" Max epochs: {max_epochs}") print(f" GPUs: {num_gpus}") print(f" Precision: {precision}") + iter_msg = ( + iterations_per_epoch + if iterations_per_epoch is not None + else "auto (full dataset shuffle)" + ) + print(f" Iterations per epoch: {iter_msg}") + print(f" Weighted sampler: {weighted_sampler}") + if train_batches_per_epoch is not None: + print(f" Trainer train batches/epoch: {train_batches_per_epoch}") + if val_batches_per_epoch is not None: + print(f" Trainer val batches/epoch: {val_batches_per_epoch}") if target_shape: print(f" Target resample shape: {target_shape}") @@ -228,45 +579,123 @@ def train_cellmap(config_path: str, data_root: str | None = None, target_shape=N else: print(f"Using default CellMap search path: {search_path}") - # Generate datasplit CSV if doesn't exist (CellMap's official utility) + # Generate datasplit CSV if it doesn't 
exist if not os.path.exists(datasplit_path): print(f"Generating datasplit CSV: {datasplit_path}") - # If resampling is enabled, allow all scales (skip filtering) - scale_filter = None if target_shape else input_array_info.get('scale') - make_datasplit_csv( - classes=classes, - csv_path=datasplit_path, - validation_prob=0.15, - scale=scale_filter, - force_all_classes='validate', - search_path=search_path, - ) + scale_filter = None + if not target_shape: + scale_filter = _infer_scale(filter_by_scale, input_array_info) + if force_classes_mode not in {"train", "validate", "both", None}: + raise ValueError( + "force_all_classes must be one of {'train', 'validate', 'both', None}" + ) + effective_force_mode = force_classes_mode or "both" + print(f"Forcing class coverage in datasplit: {effective_force_mode}") + datasplit_args = dict(datasplit_kwargs) + datasplit_args.setdefault("search_path", search_path) + datasplit_args.setdefault("csv_path", datasplit_path) + if use_s3: + make_s3_datasplit_csv( + classes=classes, + scale=scale_filter, + force_all_classes=effective_force_mode, + validation_prob=validation_prob, + **datasplit_args, + ) + else: + make_datasplit_csv( + classes=classes, + scale=scale_filter, + force_all_classes=effective_force_mode, + validation_prob=validation_prob, + **datasplit_args, + ) else: print(f"Using existing datasplit: {datasplit_path}") # Get dataloaders (CellMap's official dataloader) print("Creating dataloaders...") - train_loader, val_loader = get_dataloader( + dataloader_args = dict( datasplit_path=datasplit_path, classes=classes, batch_size=batch_size, input_array_info=input_array_info, target_array_info=target_array_info, spatial_transforms=spatial_transforms, - iterations_per_epoch=1000, - weighted_sampler=True, + target_value_transforms=target_value_transforms, + train_raw_value_transforms=train_raw_value_transforms, + val_raw_value_transforms=val_raw_value_transforms, + random_validation=bool(validation_time_limit or validation_batch_limit), + use_mutual_exclusion=use_mutual_exclusion, + weighted_sampler=weighted_sampler, + **dataloader_kwargs, ) + # Always pass iterations_per_epoch (even None) so we override the default 1000 + dataloader_args["iterations_per_epoch"] = iterations_per_epoch + + train_loader, val_loader = get_dataloader(**dataloader_args) + + # Wrap loaders to satisfy PyTorch Lightning's expectation of a batch_sampler attribute + class _LightningDataLoaderWrapper: + """Adds batch_sampler attribute so custom loaders work with Lightning.""" + + def __init__(self, loader): + self._loader = loader + self.batch_sampler = None # Lightning inspects this attribute + + def __iter__(self): + return iter(self._loader) + + def __len__(self): + inner_loader = getattr(self._loader, "loader", None) + if inner_loader is not None: + try: + return len(inner_loader) + except TypeError: + pass + + if hasattr(self._loader, "__len__"): + try: + return len(self._loader) + except TypeError: + pass + + iterations = getattr(self._loader, "iterations_per_epoch", None) + if isinstance(iterations, int) and iterations > 0: + return iterations + + dataset = getattr(self._loader, "dataset", None) + if dataset is not None: + try: + return len(dataset) + except TypeError: + pass + + raise TypeError( + "Wrapped loader does not define a finite length required for progress bars." 
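The wrapper above relies on a standard delegation idiom: attributes assigned in `__init__` shadow `__getattr__`, so only unknown names fall through to the wrapped loader, while `__len__` probes several places for a finite epoch length. A condensed sketch of the same pattern with a single explicit fallback; the class name and fallback argument are illustrative, not from the patch:

    class LoaderFacade:
        """Delegating wrapper: expose batch_sampler/__len__ to Lightning."""

        def __init__(self, loader, fallback_len: int):
            self._loader = loader
            self.batch_sampler = None      # Lightning introspects this attribute
            self._fallback_len = fallback_len

        def __iter__(self):
            return iter(self._loader)

        def __len__(self):
            try:
                return len(self._loader)   # prefer a real length when defined
            except TypeError:
                return self._fallback_len  # otherwise report a finite epoch size

        def __getattr__(self, name):
            # only called for names not found on the wrapper itself
            return getattr(self._loader, name)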
+ ) + + def __getattr__(self, name): + return getattr(self._loader, name) + + train_loader = _LightningDataLoaderWrapper(train_loader) + if val_loader is not None: + val_loader = _LightningDataLoaderWrapper(val_loader) + # Build model using PyTC's model factory (MONAI models) print(f"Building model: {model_name}") # Create minimal config for PyTC's build_model from omegaconf import OmegaConf + # Get input shape from config (D, H, W) for 3D + input_shape = input_array_info.get('shape', (64, 64, 64)) model_config = OmegaConf.create({ 'model': { 'architecture': model_name, 'in_channels': 1, 'out_channels': len(classes), + 'input_size': list(input_shape), # [D, H, W] for 3D 'mednext_size': getattr(config, 'mednext_size', 'B'), 'mednext_kernel_size': getattr(config, 'mednext_kernel_size', 5), 'deep_supervision': getattr(config, 'deep_supervision', True), @@ -276,10 +705,9 @@ def train_cellmap(config_path: str, data_root: str | None = None, target_shape=N model = build_model(model_config) print(f"Model built successfully") - # Create loss (CellMap's NaN-aware wrapper + PyTC loss) + # Create loss (balanced BCE + Dice, NaN-aware) print("Creating loss function...") - base_loss = torch.nn.BCEWithLogitsLoss - criterion = CellMapLossWrapper(base_loss, reduction='mean') + criterion = CellMapBalancedLoss(bce_weight=0.7, dice_weight=0.3) # Create Lightning module lit_model = CellMapLightningModule( @@ -293,7 +721,7 @@ def train_cellmap(config_path: str, data_root: str | None = None, target_shape=N # Setup callbacks checkpoint_callback = ModelCheckpoint( - dirpath=f'{output_dir}/checkpoints', + dirpath=checkpoint_dir, filename=f'{model_name}-{{epoch:02d}}-{{val/dice:.3f}}', monitor='val/dice', mode='max', @@ -313,11 +741,21 @@ def train_cellmap(config_path: str, data_root: str | None = None, target_shape=N # Setup loggers tb_logger = TensorBoardLogger( - f'{output_dir}/tensorboard', + tensorboard_dir, name=model_name, ) # Create trainer + limit_train_batches = ( + train_batches_per_epoch if train_batches_per_epoch is not None else 1.0 + ) + if validation_batch_limit is not None: + limit_val_batches = validation_batch_limit + elif val_batches_per_epoch is not None: + limit_val_batches = val_batches_per_epoch + else: + limit_val_batches = 1.0 + trainer = pl.Trainer( max_epochs=max_epochs, accelerator='gpu' if torch.cuda.is_available() else 'cpu', @@ -330,11 +768,13 @@ def train_cellmap(config_path: str, data_root: str | None = None, target_shape=N log_every_n_steps=50, enable_progress_bar=True, enable_model_summary=True, + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, ) # Train! print("Starting training...") - print(f"Monitor progress: tensorboard --logdir {output_dir}/tensorboard") + print(f"Monitor progress: tensorboard --logdir {tensorboard_dir}") trainer.fit(lit_model, train_loader, val_loader) print(f"\nTraining complete!")
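Taken together, the two caps are what bound the previously endless epochs: Lightning treats an integer `limit_train_batches`/`limit_val_batches` as an absolute batch count and a float `1.0` as the full dataloader. A minimal sketch of the resulting Trainer wiring, assuming the config values introduced in the MedNeXt configs above:

    import pytorch_lightning as pl

    train_batches_per_epoch = 2000  # from configs/mednext_*.py in this patch
    val_batches_per_epoch = 200

    trainer = pl.Trainer(
        max_epochs=500,
        limit_train_batches=train_batches_per_epoch if train_batches_per_epoch is not None else 1.0,
        limit_val_batches=val_batches_per_epoch if val_batches_per_epoch is not None else 1.0,
    )

With `iterations_per_epoch = None`, the dataloader itself stays on its default shuffle path and the epoch length is enforced entirely on the trainer side.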