Skip to content

Commit 7db1d4f

Browse files
author
Donglai Wei
committed
fix doc build and test failure
1 parent eaa8472 commit 7db1d4f

File tree: 6 files changed (+51 additions, −11 deletions)

connectomics/config/auto_config.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -390,7 +390,9 @@ def auto_plan_config(
     if getattr(config.optimization, "precision", None) is not None:
         manual_overrides["precision"] = config.optimization.precision
     if getattr(config.optimization, "accumulate_grad_batches", None) is not None:
-        manual_overrides["accumulate_grad_batches"] = config.optimization.accumulate_grad_batches
+        manual_overrides["accumulate_grad_batches"] = (
+            config.optimization.accumulate_grad_batches
+        )
 
     opt_cfg = getattr(config.optimization, "optimizer", None)
     if opt_cfg and getattr(opt_cfg, "lr", None) is not None:
@@ -406,8 +408,7 @@ def auto_plan_config(
 
     # Plan
     use_mixed_precision = not (
-        hasattr(config, "optimization")
-        and getattr(config.optimization, "precision", None) == "32"
+        hasattr(config, "optimization") and getattr(config.optimization, "precision", None) == "32"
     )
 
     result = planner.plan(
```

connectomics/config/hydra_config.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -1220,6 +1220,18 @@ class TuneConfig:
             DecodingParameterSpace,
             PostprocessingParameterSpace,
             ParameterSpaceConfig,
+            # Core config dataclasses (for Lightning checkpoints)
+            Config,
+            SystemConfig,
+            SystemTrainingConfig,
+            SystemInferenceConfig,
+            ModelConfig,
+            DataConfig,
+            OptimizationConfig,
+            MonitorConfig,
+            InferenceConfig,
+            TestConfig,
+            TuneConfig,
         ]
     )
 except Exception:
```

connectomics/data/augment/monai_transforms.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -237,7 +237,9 @@ def _apply_missing_section(
     if is_tensor:
         keep_mask = torch.ones(depth, dtype=torch.bool, device=img.device)
         keep_mask[indices_to_remove] = False
-        return torch.index_select(img, dim=depth_axis, index=keep_mask.nonzero(as_tuple=False).squeeze(-1))
+        return torch.index_select(
+            img, dim=depth_axis, index=keep_mask.nonzero(as_tuple=False).squeeze(-1)
+        )
     else:
         return np.delete(img, indices_to_remove, axis=depth_axis)
```

connectomics/training/lit/model.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -82,12 +82,10 @@ def __init__(
             if hasattr(cfg.model, "loss_weights")
             else [1.0] * len(self.loss_functions)
         )
-        self.multi_task_config = (
-            getattr(cfg.model, "multi_task_config", None) or []
-        )
+        self.multi_task_config = getattr(cfg.model, "multi_task_config", None) or []
         self.multi_task_enabled = len(self.multi_task_config) > 0
-        num_tasks = len(self.multi_task_config) if self.multi_task_config else len(
-            self.loss_functions
+        num_tasks = (
+            len(self.multi_task_config) if self.multi_task_config else len(self.loss_functions)
         )
         self.loss_weighter = build_loss_weighter(cfg, num_tasks=num_tasks, model=self.model)
```

connectomics/training/lit/trainer.py

Lines changed: 27 additions & 1 deletion
```diff
@@ -24,12 +24,38 @@
 from pytorch_lightning.strategies import DDPStrategy
 
 from ...config import Config
+from ...config.hydra_config import (
+    SystemConfig,
+    SystemTrainingConfig,
+    SystemInferenceConfig,
+    ModelConfig,
+    DataConfig,
+    OptimizationConfig,
+    MonitorConfig,
+    InferenceConfig,
+    TestConfig,
+    TuneConfig,
+)
 from .callbacks import VisualizationCallback, EMAWeightsCallback
 
 # Register safe globals for PyTorch 2.6+ checkpoint loading
 # This allows our Config class to be unpickled from Lightning checkpoints
 try:
-    torch.serialization.add_safe_globals([Config])
+    torch.serialization.add_safe_globals(
+        [
+            Config,
+            SystemConfig,
+            SystemTrainingConfig,
+            SystemInferenceConfig,
+            ModelConfig,
+            DataConfig,
+            OptimizationConfig,
+            MonitorConfig,
+            InferenceConfig,
+            TestConfig,
+            TuneConfig,
+        ]
+    )
 except AttributeError:
     # PyTorch < 2.6 doesn't have add_safe_globals
     pass
```

docs/source/_templates/layout.html

Lines changed: 2 additions & 1 deletion
```diff
@@ -45,7 +45,8 @@
 
 {% endif %}
 
-<link rel="stylesheet" href="{{ pathto('_static/' + style, 1) }}" type="text/css" />
+{# `style` is absent on newer Sphinx, so default to the theme stylesheet when undefined #}
+<link rel="stylesheet" href="{{ pathto('_static/' + (style | default('css/theme.css')), 1) }}" type="text/css" />
 <!-- <link rel="stylesheet" href="{{ pathto('_static/pygments.css', 1) }}" type="text/css" /> -->
 {%- for css in css_files %}
 {%- if css|attr("rel") %}
```

0 commit comments

Comments
 (0)