Skip to content

Commit 2aa512f

Browse files
author
Donglai Wei
committed
fix optuna
1 parent 4888f39 commit 2aa512f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+1050
-623
lines changed

connectomics/config/auto_config.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,15 @@ def plan(
199199
)
200200
result.planning_notes.append(
201201
f"Estimated memory: {result.estimated_gpu_memory_gb:.2f} GB "
202-
f"({result.estimated_gpu_memory_gb/gpu_memory_gb*100:.1f}% of GPU)"
202+
f"({result.estimated_gpu_memory_gb / gpu_memory_gb * 100:.1f}% of GPU)"
203203
)
204204
result.planning_notes.append(f"Batch size: {batch_size}")
205205

206206
# Gradient accumulation if batch size is very small
207207
if batch_size == 1:
208208
result.accumulate_grad_batches = 4
209209
result.planning_notes.append(
210-
f"Using gradient accumulation (4 batches) for effective batch_size=4"
210+
"Using gradient accumulation (4 batches) for effective batch_size=4"
211211
)
212212

213213
# Step 5: Determine num_workers
@@ -311,7 +311,7 @@ def print_plan(self, result: AutoPlanResult):
311311
print(f" Available: {result.available_gpu_memory_gb:.2f} GB")
312312
print(
313313
f" Estimated Usage: {result.estimated_gpu_memory_gb:.2f} GB "
314-
f"({result.estimated_gpu_memory_gb/result.available_gpu_memory_gb*100:.1f}%)"
314+
f"({result.estimated_gpu_memory_gb / result.available_gpu_memory_gb * 100:.1f}%)"
315315
)
316316
print(f" Per Sample: {result.gpu_memory_per_sample_gb:.2f} GB")
317317
print()
@@ -450,7 +450,6 @@ def auto_plan_config(
450450
if __name__ == "__main__":
451451
# Test auto planning
452452
from connectomics.config import Config
453-
from omegaconf import OmegaConf
454453

455454
# Create test config
456455
cfg = OmegaConf.structured(Config())

connectomics/config/hydra_config.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,15 @@ class ModelConfig:
150150
activation: str = "relu"
151151

152152
# UNet-specific parameters (MONAI UNet)
153-
spatial_dims: int = (
154-
3 # Spatial dimensions: 2 for 2D, 3 for 3D (auto-inferred from input_size length, not used directly)
155-
)
153+
spatial_dims: int = 3 # Spatial dimensions: 2 for 2D, 3 for 3D
154+
# (auto-inferred from input_size length, not used directly)
156155
num_res_units: int = 2 # Number of residual units per block
157156
kernel_size: int = 3 # Convolution kernel size
158157
strides: Optional[List[int]] = None # Downsampling strides (e.g., [2, 2, 2, 2] for 4 levels)
159158
act: str = "relu" # Activation function: 'relu', 'prelu', 'elu', etc.
160159
upsample: str = (
161-
"deconv" # Upsampling mode for MONAI BasicUNet: 'deconv' (transposed conv), 'nontrainable' (interpolation + conv), or 'pixelshuffle'
160+
"deconv" # Upsampling mode: 'deconv' (transposed conv),
161+
# 'nontrainable' (interpolation + conv), or 'pixelshuffle'
162162
)
163163

164164
# Transformer-specific (UNETR, etc.)
@@ -911,19 +911,16 @@ class TestTimeAugmentationConfig:
911911
- rotation90_axes: Uses spatial-only indices (e.g., [1, 2] for H-W plane where 0=D, 1=H, 2=W)
912912
"""
913913

914-
enabled: bool = False
915-
flip_axes: Any = (
916-
None # TTA flip strategy: "all" (8 flips), null (no aug), or list like [[2], [3]] (full tensor indices)
917-
)
918-
rotation90_axes: Any = (
919-
None # TTA rotation90 strategy: "all" (3 planes × 4 rotations), null, or list like [[1, 2]] (spatial indices: 0=D, 1=H, 2=W)
920-
)
914+
flip_axes: Any = None # TTA flip strategy: "all" (8 flips), null (no aug),
915+
# or list like [[2], [3]] (full tensor indices)
916+
rotation90_axes: Any = None # TTA rotation90 strategy: "all" (3 planes × 4 rotations),
917+
# null, or list like [[1, 2]] (spatial indices: 0=D, 1=H, 2=W)
921918
channel_activations: Optional[List[Any]] = (
922-
None # Per-channel activations: [[start_ch, end_ch, 'activation'], ...] e.g., [[0, 2, 'softmax'], [2, 3, 'sigmoid'], [3, 4, 'tanh']]
923-
)
924-
select_channel: Any = (
925-
None # Channel selection: null (all), [1] (foreground), -1 (all) (applied even with null flip_axes)
919+
None # Per-channel activations: [[start_ch, end_ch, 'activation'], ...]
920+
# e.g., [[0, 2, 'softmax'], [2, 3, 'sigmoid'], [3, 4, 'tanh']]
926921
)
922+
select_channel: Any = None # Channel selection: null (all), [1] (foreground), -1 (all)
923+
# (applied even with null flip_axes)
927924
ensemble_mode: str = "mean" # Ensemble mode for TTA: 'mean', 'min', 'max'
928925
apply_mask: bool = False # Multiply each channel by corresponding test_mask after ensemble
929926

@@ -943,8 +940,9 @@ class SavePredictionConfig:
943940

944941
enabled: bool = True # Enable saving intermediate predictions
945942
intensity_scale: float = (
946-
-1.0
947-
) # If < 0, keep raw predictions (no normalization/scaling). If > 0, normalize to [0,1] then scale.
943+
-1.0 # If < 0, keep raw predictions (no normalization/scaling).
944+
# If > 0, normalize to [0,1] then scale.
945+
)
948946
intensity_dtype: str = (
949947
"uint8" # Save as uint8 for visualization (ignored if intensity_scale < 0)
950948
)
@@ -983,9 +981,8 @@ class DecodeBinaryContourDistanceWatershedConfig:
983981
class DecodeModeConfig:
984982
"""Configuration for a single decode mode/function."""
985983

986-
name: str = (
987-
"decode_binary_watershed" # Function name: decode_binary_cc, decode_binary_watershed, decode_binary_contour_distance_watershed, etc.
988-
)
984+
name: str = "decode_binary_watershed" # Function name: decode_binary_cc,
985+
# decode_binary_watershed, decode_binary_contour_distance_watershed, etc.
989986
kwargs: Dict[str, Any] = field(
990987
default_factory=dict
991988
) # Keyword arguments for the decode function
@@ -1105,6 +1102,8 @@ class TestDataConfig:
11051102
cache_suffix: str = "_prediction.h5"
11061103
# Image transformation (applied to test images during inference)
11071104
image_transform: ImageTransformConfig = field(default_factory=ImageTransformConfig)
1105+
# Label transformation (optional, typically not used for test mode to preserve raw labels for evaluation)
1106+
label_transform: Optional[LabelTransformConfig] = None
11081107

11091108

11101109
@dataclass
@@ -1128,6 +1127,8 @@ class TuneDataConfig:
11281127
tune_resolution: Optional[List[int]] = None
11291128
# Image transformation (applied to tune images during inference)
11301129
image_transform: ImageTransformConfig = field(default_factory=ImageTransformConfig)
1130+
# Label transformation (optional, typically not used for tune mode to preserve raw labels for evaluation)
1131+
label_transform: Optional[LabelTransformConfig] = None
11311132

11321133

11331134
@dataclass
@@ -1210,6 +1211,24 @@ class TuneConfig:
12101211
parameter_space: ParameterSpaceConfig = field(default_factory=ParameterSpaceConfig)
12111212

12121213

1214+
# Allow safe loading of checkpoints with PyTorch 2.6+ weights_only defaults
1215+
try:
1216+
import torch
1217+
1218+
if hasattr(torch, "serialization") and hasattr(torch.serialization, "add_safe_globals"):
1219+
torch.serialization.add_safe_globals(
1220+
[
1221+
ParameterConfig,
1222+
DecodingParameterSpace,
1223+
PostprocessingParameterSpace,
1224+
ParameterSpaceConfig,
1225+
]
1226+
)
1227+
except Exception:
1228+
# Best-effort registration; ignore if torch not available at import time
1229+
pass
1230+
1231+
12131232
@dataclass
12141233
class Config:
12151234
"""Main configuration for PyTorch Connectomics.

connectomics/config/hydra_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,11 @@ def resolve_data_paths(cfg: Config) -> Config:
269269
>>> cfg.data.train_image = ["PT37/*_raw.tif", "file.tif"]
270270
>>> resolve_data_paths(cfg)
271271
>>> print(cfg.data.train_image)
272-
['/data/barcode/PT37/img1_raw.tif', '/data/barcode/PT37/img2_raw.tif', '/data/barcode/file.tif']
272+
[
273+
'/data/barcode/PT37/img1_raw.tif',
274+
'/data/barcode/PT37/img2_raw.tif',
275+
'/data/barcode/file.tif'
276+
]
273277
274278
>>> cfg.test.data.test_path = "/data/test/"
275279
>>> cfg.test.data.test_image = ["volume_*.tif"]
@@ -325,7 +329,8 @@ def _combine_path(
325329
index = int(selector)
326330
if index < -len(expanded) or index >= len(expanded):
327331
print(
328-
f"Warning: Index {index} out of range for {len(expanded)} files, using first"
332+
f"Warning: Index {index} out of range for {len(expanded)} files, "
333+
f"using first"
329334
)
330335
return expanded[0]
331336
return expanded[index]
@@ -343,7 +348,8 @@ def _combine_path(
343348
return matching[0]
344349
else:
345350
print(
346-
f"Warning: No file matches selector '{selector}', using first of {len(expanded)} files"
351+
f"Warning: No file matches selector '{selector}', "
352+
f"using first of {len(expanded)} files"
347353
)
348354
return expanded[0]
349355

connectomics/config/slurm_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,6 @@ def filter_partitions(
275275
and resources.gpus >= min_gpus
276276
and resources.memory_gb >= min_memory_gb
277277
):
278-
279278
# Check GPU type if specified
280279
if gpu_type and resources.gpu_type != gpu_type:
281280
continue

connectomics/data/augment/build.py

Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,10 @@
1212
RandRotate90d,
1313
RandFlipd,
1414
RandAffined,
15-
RandZoomd,
1615
RandGaussianNoised,
1716
RandShiftIntensityd,
18-
RandGaussianSmoothd,
1917
RandAdjustContrastd,
2018
RandSpatialCropd,
21-
ScaleIntensityRanged,
2219
ToTensord,
2320
CenterSpatialCropd,
2421
SpatialPadd,
@@ -55,7 +52,8 @@ def build_train_transforms(
5552
5653
Args:
5754
cfg: Hydra Config object
58-
keys: Keys to transform (default: ['image', 'label'] or ['image', 'label', 'mask'] if masks are used)
55+
keys: Keys to transform (default: ['image', 'label'] or
56+
['image', 'label', 'mask'] if masks are used)
5957
skip_loading: Skip LoadVolumed (for pre-cached datasets)
6058
6159
Returns:
@@ -295,7 +293,8 @@ def _build_eval_transforms_impl(cfg: Config, mode: str = "val", keys: list[str]
295293
transforms.append(ApplyVolumetricSplitd(keys=keys))
296294

297295
# Apply resize if configured (before cropping)
298-
# For test/tune mode, only use test.data.image_transform or tune.data.image_transform (no fallback)
296+
# For test/tune mode, only use test.data.image_transform or
297+
# tune.data.image_transform (no fallback)
299298
# For val mode, use data.image_transform
300299
resize_factors = None
301300
if mode == "test":
@@ -370,7 +369,8 @@ def _build_eval_transforms_impl(cfg: Config, mode: str = "val", keys: list[str]
370369
# else: mode == "test" -> no cropping for sliding window inference
371370

372371
# Normalization - use smart normalization
373-
# For test/tune mode, only use test.data.image_transform or tune.data.image_transform (no fallback)
372+
# For test/tune mode, only use test.data.image_transform or
373+
# tune.data.image_transform (no fallback)
374374
# For val mode, use data.image_transform
375375
image_transform = None
376376
if (
@@ -407,42 +407,20 @@ def _build_eval_transforms_impl(cfg: Config, mode: str = "val", keys: list[str]
407407
if getattr(cfg.data, "normalize_labels", False):
408408
transforms.append(NormalizeLabelsd(keys=["label"]))
409409

410-
# Check if we should skip label transforms (test/tune mode)
411-
# Skip label transforms if test.data or tune.data has evaluation.enabled=True
412-
# This preserves original instance labels for metric computation
413-
skip_label_transform = False
414-
if mode == "test":
415-
# Check if test.evaluation or tune.evaluation is enabled (for adapted_rand, etc.)
416-
evaluation_config = None
417-
if hasattr(cfg, "test") and hasattr(cfg.test, "evaluation"):
418-
evaluation_config = cfg.test.evaluation
419-
elif hasattr(cfg, "tune") and cfg.tune and hasattr(cfg.tune, "optimization"):
420-
# For tune mode, check if we're optimizing metrics that need instance labels
421-
if hasattr(cfg.tune.optimization, "single_objective"):
422-
metric = getattr(cfg.tune.optimization.single_objective, "metric", None)
423-
if metric == "adapted_rand":
424-
skip_label_transform = True
425-
print(
426-
f" ⚠️ Skipping label transforms for Optuna tuning (keeping original labels for {metric})"
427-
)
428-
429-
if evaluation_config:
430-
evaluation_enabled = getattr(evaluation_config, "enabled", False)
431-
metrics = getattr(evaluation_config, "metrics", [])
432-
if evaluation_enabled and metrics:
433-
skip_label_transform = True
434-
print(
435-
f" ⚠️ Skipping label transforms for metric evaluation (keeping original labels for {metrics})"
436-
)
437-
438410
# Label transformations (affinity, distance transform, etc.)
439-
# Only apply if not skipped AND label_transform is configured
440-
if hasattr(cfg.data, "label_transform") and not skip_label_transform:
411+
# For test/tune modes: NEVER apply label transforms (keep raw instance labels for evaluation)
412+
# For val mode: use training label_transform config
413+
label_cfg = None
414+
if mode == "val":
415+
# Validation always uses training label_transform
416+
if hasattr(cfg.data, "label_transform"):
417+
label_cfg = cfg.data.label_transform
418+
419+
# Apply label transforms if configured
420+
if label_cfg is not None:
441421
from ..process.build import create_label_transform_pipeline
442422
from ..process.monai_transforms import SegErosionInstanced
443423

444-
label_cfg = cfg.data.label_transform
445-
446424
# Apply instance erosion first if specified
447425
if hasattr(label_cfg, "erosion") and label_cfg.erosion > 0:
448426
transforms.append(SegErosionInstanced(keys=["label"], tsz_h=label_cfg.erosion))
@@ -642,27 +620,32 @@ def should_augment(aug_name: str, aug_enabled: Optional[bool]) -> bool:
642620
scale_range = aug_cfg.affine.scale_range
643621
shear_range = aug_cfg.affine.shear_range
644622

623+
# Interpolation per key: bilinear for images, nearest for labels/masks
624+
affine_modes = ["bilinear" if k == "image" else "nearest" for k in keys]
625+
645626
transforms.append(
646627
RandAffined(
647628
keys=keys,
648629
prob=aug_cfg.affine.prob,
649630
rotate_range=rotate_range,
650631
scale_range=scale_range,
651632
shear_range=shear_range,
652-
mode="bilinear",
633+
mode=affine_modes,
653634
padding_mode="reflection",
654635
)
655636
)
656637

657638
if should_augment("elastic", aug_cfg.elastic.enabled):
658639
# Unified elastic deformation that supports both 2D and 3D
640+
elastic_modes = ["bilinear" if k == "image" else "nearest" for k in keys]
659641
transforms.append(
660642
RandElasticd(
661643
keys=keys,
662644
do_2d=do_2d,
663645
prob=aug_cfg.elastic.prob,
664646
sigma_range=aug_cfg.elastic.sigma_range,
665647
magnitude_range=aug_cfg.elastic.magnitude_range,
648+
mode=elastic_modes,
666649
)
667650
)
668651

connectomics/data/augment/monai_transforms.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
"""
77

88
from __future__ import annotations
9-
from typing import Dict, Any, List, Optional, Union, Tuple
9+
1010
import math
11+
from typing import Any, Dict, List, Optional, Tuple, Union
12+
13+
import cv2
1114
import numpy as np
1215
import torch
13-
import cv2
1416
from monai.config import KeysCollection
1517
from monai.transforms import MapTransform, RandomizableTransform
1618

@@ -1068,7 +1070,8 @@ def __init__(
10681070
self.mode = "divide"
10691071
except ValueError:
10701072
raise ValueError(
1071-
f"Invalid divide mode '{mode}'. Format should be 'divide-K' where K is a number (e.g., 'divide-255')"
1073+
f"Invalid divide mode '{mode}'. Format should be 'divide-K' "
1074+
f"where K is a number (e.g., 'divide-255')"
10721075
)
10731076
elif mode not in ["none", "normal", "0-1"]:
10741077
raise ValueError(
@@ -1212,7 +1215,8 @@ def __init__(
12121215

12131216
if orientation not in ["horizontal", "vertical", "random"]:
12141217
raise ValueError(
1215-
f"Invalid orientation '{orientation}'. Must be 'horizontal', 'vertical', or 'random'"
1218+
f"Invalid orientation '{orientation}'. Must be 'horizontal', "
1219+
f"'vertical', or 'random'"
12161220
)
12171221
self.orientation = orientation
12181222

@@ -1356,7 +1360,8 @@ class ResizeByFactord(MapTransform):
13561360
13571361
Args:
13581362
keys: Keys to transform
1359-
scale_factors: Scale factors for each spatial dimension (e.g., [0.25, 0.25] for 2D, [0.5, 0.5, 0.5] for 3D)
1363+
scale_factors: Scale factors for each spatial dimension
1364+
(e.g., [0.25, 0.25] for 2D, [0.5, 0.5, 0.5] for 3D)
13601365
mode: Interpolation mode ('bilinear', 'nearest', 'area', etc.)
13611366
align_corners: Whether to align corners (True for bilinear, None for nearest)
13621367
allow_missing_keys: Whether to allow missing keys

0 commit comments

Comments
 (0)