@@ -43,3 +43,82 @@ def test_get_local_schedulers(self, mock_load_group: MagicMock) -> None:
4343
4444 for scheduler in schedulers .values ():
4545 self .assertEqual ("test_session" , scheduler .session_name )
46+
47+ @patch ("torchx.schedulers.load_group" )
48+ def test_torchx_schedulers_overrides_all (self , mock_load_group : MagicMock ) -> None :
49+ """torchx.schedulers completely overrides defaults and ignores extras"""
50+ mock_custom = MagicMock ()
51+ mock_extra = MagicMock ()
52+
53+ def load_group_side_effect (group , default ):
54+ if group == "torchx.schedulers" :
55+ return {"custom" : mock_custom }
56+ elif group == "torchx.schedulers.extra" :
57+ return {"extra" : mock_extra }
58+ return {}
59+
60+ mock_load_group .side_effect = load_group_side_effect
61+
62+ factories = get_scheduler_factories ()
63+
64+ self .assertEqual (factories , {"custom" : mock_custom })
65+ self .assertNotIn ("local_docker" , factories )
66+ self .assertNotIn ("extra" , factories )
67+
68+ @patch ("torchx.schedulers.load_group" )
69+ def test_no_custom_returns_defaults_and_extras (
70+ self , mock_load_group : MagicMock
71+ ) -> None :
72+ """no custom schedulers returns built-in + extras"""
73+ mock_extra = MagicMock ()
74+
75+ def load_group_side_effect (group , default ):
76+ if group == "torchx.schedulers.extra" :
77+ return {"extra" : mock_extra }
78+ return default
79+
80+ mock_load_group .side_effect = load_group_side_effect
81+
82+ factories = get_scheduler_factories ()
83+
84+ self .assertIn ("local_docker" , factories )
85+ self .assertIn ("slurm" , factories )
86+ self .assertIn ("extra" , factories )
87+
88+ @patch ("torchx.schedulers.load_group" )
89+ def test_no_custom_no_extras_returns_builtins (
90+ self , mock_load_group : MagicMock
91+ ) -> None :
92+ """no custom, no extras returns only built-in schedulers"""
93+ mock_load_group .side_effect = lambda group , default : default
94+
95+ factories = get_scheduler_factories ()
96+
97+ self .assertIn ("local_docker" , factories )
98+ self .assertIn ("slurm" , factories )
99+
100+ @patch ("torchx.schedulers.load_group" )
101+ def test_skip_defaults_returns_empty (self , mock_load_group : MagicMock ) -> None :
102+ """skip_defaults=True with no custom schedulers returns empty"""
103+ mock_load_group .side_effect = lambda group , default : default
104+
105+ factories = get_scheduler_factories (skip_defaults = True )
106+
107+ self .assertEqual (factories , {})
108+
109+ @patch ("torchx.schedulers.load_group" )
110+ def test_custom_scheduler_is_default (self , mock_load_group : MagicMock ) -> None :
111+ """first custom scheduler becomes the default"""
112+ mock_aws = MagicMock ()
113+ mock_custom = MagicMock ()
114+
115+ def load_group_side_effect (group , default ):
116+ if group == "torchx.schedulers" :
117+ return {"aws_batch" : mock_aws , "custom_1" : mock_custom }
118+ return {}
119+
120+ mock_load_group .side_effect = load_group_side_effect
121+
122+ default_name = get_default_scheduler_name ()
123+
124+ self .assertIn (default_name , ["aws_batch" , "custom_1" ])
0 commit comments