Skip to content

Commit d99cb1e

Browse files
committed
feat: list all registered schedulers (#1009)
1 parent 8dcad29 commit d99cb1e

File tree

2 files changed

+91
-11
lines changed

2 files changed

+91
-11
lines changed

torchx/schedulers/__init__.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torchx.schedulers.api import Scheduler
1414
from torchx.util.entrypoints import load_group
1515

16-
DEFAULT_SCHEDULER_MODULES: Mapping[str, str] = {
16+
BUILTIN_SCHEDULER_MODULES: Mapping[str, str] = {
1717
"local_docker": "torchx.schedulers.docker_scheduler",
1818
"local_cwd": "torchx.schedulers.local_scheduler",
1919
"slurm": "torchx.schedulers.slurm_scheduler",
@@ -39,24 +39,25 @@ def run(*args: object, **kwargs: object) -> Scheduler:
3939
return run
4040

4141

42+
def default_schedulers() -> dict[str, SchedulerFactory]:
43+
"""Build default schedulers (built-in + extras from torchx.schedulers.extra)."""
44+
return {
45+
**{s: _defer_load_scheduler(p) for s, p in BUILTIN_SCHEDULER_MODULES.items()},
46+
**load_group("torchx.schedulers.extra", default={}),
47+
}
48+
49+
4250
def get_scheduler_factories(
43-
group: str = "torchx.schedulers", skip_defaults: bool = False
51+
group: str = "torchx.schedulers",
52+
skip_defaults: bool = False,
4453
) -> dict[str, SchedulerFactory]:
4554
"""
4655
get_scheduler_factories returns all the available schedulers names under `group` and the
4756
method to instantiate them.
4857
4958
The first scheduler in the dictionary is used as the default scheduler.
5059
"""
51-
52-
if skip_defaults:
53-
default_schedulers = {}
54-
else:
55-
default_schedulers: dict[str, SchedulerFactory] = {}
56-
for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
57-
default_schedulers[scheduler] = _defer_load_scheduler(path)
58-
59-
return load_group(group, default=default_schedulers)
60+
return load_group(group, default={} if skip_defaults else default_schedulers())
6061

6162

6263
def get_default_scheduler_name() -> str:

torchx/schedulers/test/registry_test.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,82 @@ def test_get_local_schedulers(self, mock_load_group: MagicMock) -> None:
4343

4444
for scheduler in schedulers.values():
4545
self.assertEqual("test_session", scheduler.session_name)
46+
47+
@patch("torchx.schedulers.load_group")
48+
def test_torchx_schedulers_overrides_all(self, mock_load_group: MagicMock) -> None:
49+
"""torchx.schedulers completely overrides defaults and ignores extras"""
50+
mock_custom = MagicMock()
51+
mock_extra = MagicMock()
52+
53+
def load_group_side_effect(group, default):
54+
if group == "torchx.schedulers":
55+
return {"custom": mock_custom}
56+
elif group == "torchx.schedulers.extra":
57+
return {"extra": mock_extra}
58+
return {}
59+
60+
mock_load_group.side_effect = load_group_side_effect
61+
62+
factories = get_scheduler_factories()
63+
64+
self.assertEqual(factories, {"custom": mock_custom})
65+
self.assertNotIn("local_docker", factories)
66+
self.assertNotIn("extra", factories)
67+
68+
@patch("torchx.schedulers.load_group")
69+
def test_no_custom_returns_defaults_and_extras(
70+
self, mock_load_group: MagicMock
71+
) -> None:
72+
"""no custom schedulers returns built-in + extras"""
73+
mock_extra = MagicMock()
74+
75+
def load_group_side_effect(group, default):
76+
if group == "torchx.schedulers.extra":
77+
return {"extra": mock_extra}
78+
return default
79+
80+
mock_load_group.side_effect = load_group_side_effect
81+
82+
factories = get_scheduler_factories()
83+
84+
self.assertIn("local_docker", factories)
85+
self.assertIn("slurm", factories)
86+
self.assertIn("extra", factories)
87+
88+
@patch("torchx.schedulers.load_group")
89+
def test_no_custom_no_extras_returns_builtins(
90+
self, mock_load_group: MagicMock
91+
) -> None:
92+
"""no custom, no extras returns only built-in schedulers"""
93+
mock_load_group.side_effect = lambda group, default: default
94+
95+
factories = get_scheduler_factories()
96+
97+
self.assertIn("local_docker", factories)
98+
self.assertIn("slurm", factories)
99+
100+
@patch("torchx.schedulers.load_group")
101+
def test_skip_defaults_returns_empty(self, mock_load_group: MagicMock) -> None:
102+
"""skip_defaults=True with no custom schedulers returns empty"""
103+
mock_load_group.side_effect = lambda group, default: default
104+
105+
factories = get_scheduler_factories(skip_defaults=True)
106+
107+
self.assertEqual(factories, {})
108+
109+
@patch("torchx.schedulers.load_group")
110+
def test_custom_scheduler_is_default(self, mock_load_group: MagicMock) -> None:
111+
"""first custom scheduler becomes the default"""
112+
mock_aws = MagicMock()
113+
mock_custom = MagicMock()
114+
115+
def load_group_side_effect(group, default):
116+
if group == "torchx.schedulers":
117+
return {"aws_batch": mock_aws, "custom_1": mock_custom}
118+
return {}
119+
120+
mock_load_group.side_effect = load_group_side_effect
121+
122+
default_name = get_default_scheduler_name()
123+
124+
self.assertIn(default_name, ["aws_batch", "custom_1"])

0 commit comments

Comments
 (0)