@@ -43,3 +43,75 @@ def test_get_local_schedulers(self, mock_load_group: MagicMock) -> None:
4343
4444 for scheduler in schedulers .values ():
4545 self .assertEqual ("test_session" , scheduler .session_name )
46+
47+ @patch ("torchx.schedulers.load_group" )
48+ def test_torchx_schedulers_overrides_all (self , mock_load_group : MagicMock ) -> None :
49+ """torchx.schedulers completely overrides defaults and ignores extras"""
50+ mock_custom : MagicMock = MagicMock ()
51+ mock_extra : MagicMock = MagicMock ()
52+
53+ mock_load_group .side_effect = lambda group , default : (
54+ {"custom" : mock_custom }
55+ if group == "torchx.schedulers"
56+ else {"extra" : mock_extra } if group == "torchx.schedulers.extra" else {}
57+ )
58+
59+ factories = get_scheduler_factories ()
60+
61+ self .assertEqual (factories , {"custom" : mock_custom })
62+ self .assertNotIn ("local_docker" , factories )
63+ self .assertNotIn ("extra" , factories )
64+
65+ @patch ("torchx.schedulers.load_group" )
66+ def test_no_custom_returns_defaults_and_extras (
67+ self , mock_load_group : MagicMock
68+ ) -> None :
69+ """no custom schedulers returns built-in + extras"""
70+ mock_extra : MagicMock = MagicMock ()
71+
72+ mock_load_group .side_effect = lambda group , default : (
73+ {"extra" : mock_extra } if group == "torchx.schedulers.extra" else default
74+ )
75+
76+ factories = get_scheduler_factories ()
77+
78+ self .assertIn ("local_docker" , factories )
79+ self .assertIn ("slurm" , factories )
80+ self .assertIn ("extra" , factories )
81+
82+ @patch ("torchx.schedulers.load_group" )
83+ def test_no_custom_no_extras_returns_builtins (
84+ self , mock_load_group : MagicMock
85+ ) -> None :
86+ """no custom, no extras returns only built-in schedulers"""
87+ mock_load_group .side_effect = lambda group , default : default
88+
89+ factories = get_scheduler_factories ()
90+
91+ self .assertIn ("local_docker" , factories )
92+ self .assertIn ("slurm" , factories )
93+
94+ @patch ("torchx.schedulers.load_group" )
95+ def test_skip_defaults_returns_empty (self , mock_load_group : MagicMock ) -> None :
96+ """skip_defaults=True with no custom schedulers returns empty"""
97+ mock_load_group .side_effect = lambda group , default : default
98+
99+ factories = get_scheduler_factories (skip_defaults = True )
100+
101+ self .assertEqual (factories , {})
102+
103+ @patch ("torchx.schedulers.load_group" )
104+ def test_custom_scheduler_is_default (self , mock_load_group : MagicMock ) -> None :
105+ """first custom scheduler becomes the default"""
106+ mock_aws : MagicMock = MagicMock ()
107+ mock_custom : MagicMock = MagicMock ()
108+
109+ mock_load_group .side_effect = lambda group , default : (
110+ {"aws_batch" : mock_aws , "custom_1" : mock_custom }
111+ if group == "torchx.schedulers"
112+ else {}
113+ )
114+
115+ default_name = get_default_scheduler_name ()
116+
117+ self .assertIn (default_name , ["aws_batch" , "custom_1" ])
0 commit comments