@@ -150,15 +150,15 @@ class ModelConfig:
     activation: str = "relu"
 
     # UNet-specific parameters (MONAI UNet)
-    spatial_dims: int = (
-        3  # Spatial dimensions: 2 for 2D, 3 for 3D (auto-inferred from input_size length, not used directly)
-    )
+    spatial_dims: int = 3  # Spatial dimensions: 2 for 2D, 3 for 3D
+    # (auto-inferred from input_size length, not used directly)
     num_res_units: int = 2  # Number of residual units per block
     kernel_size: int = 3  # Convolution kernel size
     strides: Optional[List[int]] = None  # Downsampling strides (e.g., [2, 2, 2, 2] for 4 levels)
     act: str = "relu"  # Activation function: 'relu', 'prelu', 'elu', etc.
     upsample: str = (
-        "deconv"  # Upsampling mode for MONAI BasicUNet: 'deconv' (transposed conv), 'nontrainable' (interpolation + conv), or 'pixelshuffle'
+        "deconv"  # Upsampling mode: 'deconv' (transposed conv),
+        # 'nontrainable' (interpolation + conv), or 'pixelshuffle'
     )
 
     # Transformer-specific (UNETR, etc.)
@@ -911,19 +911,16 @@ class TestTimeAugmentationConfig:
     - rotation90_axes: Uses spatial-only indices (e.g., [1, 2] for H-W plane where 0=D, 1=H, 2=W)
     """
 
-    enabled: bool = False
-    flip_axes: Any = (
-        None  # TTA flip strategy: "all" (8 flips), null (no aug), or list like [[2], [3]] (full tensor indices)
-    )
-    rotation90_axes: Any = (
-        None  # TTA rotation90 strategy: "all" (3 planes × 4 rotations), null, or list like [[1, 2]] (spatial indices: 0=D, 1=H, 2=W)
-    )
+    flip_axes: Any = None  # TTA flip strategy: "all" (8 flips), null (no aug),
+    # or list like [[2], [3]] (full tensor indices)
+    rotation90_axes: Any = None  # TTA rotation90 strategy: "all" (3 planes × 4 rotations),
+    # null, or list like [[1, 2]] (spatial indices: 0=D, 1=H, 2=W)
     channel_activations: Optional[List[Any]] = (
-        None  # Per-channel activations: [[start_ch, end_ch, 'activation'], ...] e.g., [[0, 2, 'softmax'], [2, 3, 'sigmoid'], [3, 4, 'tanh']]
-    )
-    select_channel: Any = (
-        None  # Channel selection: null (all), [1] (foreground), -1 (all) (applied even with null flip_axes)
+        None  # Per-channel activations: [[start_ch, end_ch, 'activation'], ...]
+        # e.g., [[0, 2, 'softmax'], [2, 3, 'sigmoid'], [3, 4, 'tanh']]
     )
+    select_channel: Any = None  # Channel selection: null (all), [1] (foreground), -1 (all)
+    # (applied even with null flip_axes)
     ensemble_mode: str = "mean"  # Ensemble mode for TTA: 'mean', 'min', 'max'
     apply_mask: bool = False  # Multiply each channel by corresponding test_mask after ensemble
 
@@ -943,8 +940,9 @@ class SavePredictionConfig:
 
     enabled: bool = True  # Enable saving intermediate predictions
     intensity_scale: float = (
-        -1.0
-    )  # If < 0, keep raw predictions (no normalization/scaling). If > 0, normalize to [0,1] then scale.
+        -1.0  # If < 0, keep raw predictions (no normalization/scaling).
+        # If > 0, normalize to [0,1] then scale.
+    )
     intensity_dtype: str = (
         "uint8"  # Save as uint8 for visualization (ignored if intensity_scale < 0)
     )
@@ -983,9 +981,8 @@ class DecodeBinaryContourDistanceWatershedConfig:
 class DecodeModeConfig:
     """Configuration for a single decode mode/function."""
 
-    name: str = (
-        "decode_binary_watershed"  # Function name: decode_binary_cc, decode_binary_watershed, decode_binary_contour_distance_watershed, etc.
-    )
+    name: str = "decode_binary_watershed"  # Function name: decode_binary_cc,
+    # decode_binary_watershed, decode_binary_contour_distance_watershed, etc.
     kwargs: Dict[str, Any] = field(
         default_factory=dict
     )  # Keyword arguments for the decode function
@@ -1105,6 +1102,8 @@ class TestDataConfig:
     cache_suffix: str = "_prediction.h5"
     # Image transformation (applied to test images during inference)
     image_transform: ImageTransformConfig = field(default_factory=ImageTransformConfig)
+    # Label transformation (optional, typically not used for test mode to preserve raw labels for evaluation)
+    label_transform: Optional[LabelTransformConfig] = None
 
 
 @dataclass
@@ -1128,6 +1127,8 @@ class TuneDataConfig:
     tune_resolution: Optional[List[int]] = None
     # Image transformation (applied to tune images during inference)
     image_transform: ImageTransformConfig = field(default_factory=ImageTransformConfig)
+    # Label transformation (optional, typically not used for tune mode to preserve raw labels for evaluation)
+    label_transform: Optional[LabelTransformConfig] = None
 
 
 @dataclass
@@ -1210,6 +1211,24 @@ class TuneConfig:
     parameter_space: ParameterSpaceConfig = field(default_factory=ParameterSpaceConfig)
 
 
+# Allow safe loading of checkpoints with PyTorch 2.6+ weights_only defaults
+try:
+    import torch
+
+    if hasattr(torch, "serialization") and hasattr(torch.serialization, "add_safe_globals"):
+        torch.serialization.add_safe_globals(
+            [
+                ParameterConfig,
+                DecodingParameterSpace,
+                PostprocessingParameterSpace,
+                ParameterSpaceConfig,
+            ]
+        )
+except Exception:
+    # Best-effort registration; ignore if torch not available at import time
+    pass
+
+
 @dataclass
 class Config:
     """Main configuration for PyTorch Connectomics.