Skip to content

Commit 5f82b9b

Browse files
committed
feat: list all registered schedulers (#1009)
1 parent 8dcad29 commit 5f82b9b

File tree

2 files changed

+84
-11
lines changed

2 files changed

+84
-11
lines changed

torchx/schedulers/__init__.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torchx.schedulers.api import Scheduler
1414
from torchx.util.entrypoints import load_group
1515

16-
DEFAULT_SCHEDULER_MODULES: Mapping[str, str] = {
16+
BUILTIN_SCHEDULER_MODULES: Mapping[str, str] = {
1717
"local_docker": "torchx.schedulers.docker_scheduler",
1818
"local_cwd": "torchx.schedulers.local_scheduler",
1919
"slurm": "torchx.schedulers.slurm_scheduler",
@@ -39,24 +39,25 @@ def run(*args: object, **kwargs: object) -> Scheduler:
3939
return run
4040

4141

42+
def default_schedulers() -> dict[str, SchedulerFactory]:
43+
"""Build default schedulers (built-in + extras from torchx.schedulers.extra)."""
44+
return {
45+
**{s: _defer_load_scheduler(p) for s, p in BUILTIN_SCHEDULER_MODULES.items()},
46+
**load_group("torchx.schedulers.extra", default={}),
47+
}
48+
49+
4250
def get_scheduler_factories(
43-
group: str = "torchx.schedulers", skip_defaults: bool = False
51+
group: str = "torchx.schedulers",
52+
skip_defaults: bool = False,
4453
) -> dict[str, SchedulerFactory]:
4554
"""
4655
get_scheduler_factories returns all the available schedulers names under `group` and the
4756
method to instantiate them.
4857
4958
The first scheduler in the dictionary is used as the default scheduler.
5059
"""
51-
52-
if skip_defaults:
53-
default_schedulers = {}
54-
else:
55-
default_schedulers: dict[str, SchedulerFactory] = {}
56-
for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
57-
default_schedulers[scheduler] = _defer_load_scheduler(path)
58-
59-
return load_group(group, default=default_schedulers)
60+
return load_group(group, default={} if skip_defaults else default_schedulers())
6061

6162

6263
def get_default_scheduler_name() -> str:

torchx/schedulers/test/registry_test.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,75 @@ def test_get_local_schedulers(self, mock_load_group: MagicMock) -> None:
4343

4444
for scheduler in schedulers.values():
4545
self.assertEqual("test_session", scheduler.session_name)
46+
47+
@patch("torchx.schedulers.load_group")
48+
def test_torchx_schedulers_overrides_all(self, mock_load_group: MagicMock) -> None:
49+
"""torchx.schedulers completely overrides defaults and ignores extras"""
50+
mock_custom: MagicMock = MagicMock()
51+
mock_extra: MagicMock = MagicMock()
52+
53+
mock_load_group.side_effect = lambda group, default: (
54+
{"custom": mock_custom}
55+
if group == "torchx.schedulers"
56+
else {"extra": mock_extra} if group == "torchx.schedulers.extra" else {}
57+
)
58+
59+
factories = get_scheduler_factories()
60+
61+
self.assertEqual(factories, {"custom": mock_custom})
62+
self.assertNotIn("local_docker", factories)
63+
self.assertNotIn("extra", factories)
64+
65+
@patch("torchx.schedulers.load_group")
66+
def test_no_custom_returns_defaults_and_extras(
67+
self, mock_load_group: MagicMock
68+
) -> None:
69+
"""no custom schedulers returns built-in + extras"""
70+
mock_extra: MagicMock = MagicMock()
71+
72+
mock_load_group.side_effect = lambda group, default: (
73+
{"extra": mock_extra} if group == "torchx.schedulers.extra" else default
74+
)
75+
76+
factories = get_scheduler_factories()
77+
78+
self.assertIn("local_docker", factories)
79+
self.assertIn("slurm", factories)
80+
self.assertIn("extra", factories)
81+
82+
@patch("torchx.schedulers.load_group")
83+
def test_no_custom_no_extras_returns_builtins(
84+
self, mock_load_group: MagicMock
85+
) -> None:
86+
"""no custom, no extras returns only built-in schedulers"""
87+
mock_load_group.side_effect = lambda group, default: default
88+
89+
factories = get_scheduler_factories()
90+
91+
self.assertIn("local_docker", factories)
92+
self.assertIn("slurm", factories)
93+
94+
@patch("torchx.schedulers.load_group")
95+
def test_skip_defaults_returns_empty(self, mock_load_group: MagicMock) -> None:
96+
"""skip_defaults=True with no custom schedulers returns empty"""
97+
mock_load_group.side_effect = lambda group, default: default
98+
99+
factories = get_scheduler_factories(skip_defaults=True)
100+
101+
self.assertEqual(factories, {})
102+
103+
@patch("torchx.schedulers.load_group")
104+
def test_custom_scheduler_is_default(self, mock_load_group: MagicMock) -> None:
105+
"""first custom scheduler becomes the default"""
106+
mock_aws: MagicMock = MagicMock()
107+
mock_custom: MagicMock = MagicMock()
108+
109+
mock_load_group.side_effect = lambda group, default: (
110+
{"aws_batch": mock_aws, "custom_1": mock_custom}
111+
if group == "torchx.schedulers"
112+
else {}
113+
)
114+
115+
default_name = get_default_scheduler_name()
116+
117+
self.assertIn(default_name, ["aws_batch", "custom_1"])

0 commit comments

Comments
 (0)