Skip to content

Commit 68b6131

Browse files
authored
Merge branch 'main' into feat/extend_schedulers_list
2 parents d99cb1e + 9016924 commit 68b6131

File tree

5 files changed

+149
-7
lines changed

5 files changed

+149
-7
lines changed

torchx/schedulers/slurm_scheduler.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def _should_use_gpus_per_node_from_version() -> bool:
135135
"comment",
136136
"mail-user",
137137
"mail-type",
138+
"account",
138139
}
139140
SBATCH_GROUP_OPTIONS = {
140141
"partition",
@@ -159,6 +160,7 @@ def _apply_app_id_env(s: str) -> str:
159160
SlurmOpts = TypedDict(
160161
"SlurmOpts",
161162
{
163+
"account": Optional[str],
162164
"partition": str,
163165
"time": str,
164166
"comment": Optional[str],
@@ -404,6 +406,12 @@ def __init__(self, session_name: str) -> None:
404406

405407
def _run_opts(self) -> runopts:
406408
opts = runopts()
409+
opts.add(
410+
"account",
411+
type_=str,
412+
help="The account to use for the slurm job.",
413+
default=None,
414+
)
407415
opts.add(
408416
"partition",
409417
type_=str,

torchx/schedulers/test/aws_batch_scheduler_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ def test_submit_dryrun_tags(self, _) -> None:
159159
def test_submit_dryrun_job_role_arn(self) -> None:
160160
cfg = AWSBatchOpts({"queue": "ignored_in_test", "job_role_arn": "fizzbuzz"})
161161
info = create_scheduler("test").submit_dryrun(_test_app(), cfg)
162-
# pyre-ignore[16]
163162
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
164163
self.assertEqual(1, len(node_groups))
165164
self.assertEqual(cfg["job_role_arn"], node_groups[0]["container"]["jobRoleArn"])
@@ -169,7 +168,6 @@ def test_submit_dryrun_execution_role_arn(self) -> None:
169168
{"queue": "ignored_in_test", "execution_role_arn": "veryexecutive"}
170169
)
171170
info = create_scheduler("test").submit_dryrun(_test_app(), cfg)
172-
# pyre-ignore[16]
173171
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
174172
self.assertEqual(1, len(node_groups))
175173
self.assertEqual(
@@ -179,7 +177,6 @@ def test_submit_dryrun_execution_role_arn(self) -> None:
179177
def test_submit_dryrun_privileged(self) -> None:
180178
cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})
181179
info = create_scheduler("test").submit_dryrun(_test_app(), cfg)
182-
# pyre-ignore[16]
183180
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
184181
self.assertEqual(1, len(node_groups))
185182
self.assertTrue(node_groups[0]["container"]["privileged"])
@@ -189,7 +186,6 @@ def test_submit_dryrun_instance_type_multinode(self) -> None:
189186
resource = specs.named_resources_aws.aws_p3dn_24xlarge()
190187
app = _test_app(num_replicas=2, resource=resource)
191188
info = create_scheduler("test").submit_dryrun(app, cfg)
192-
# pyre-ignore[16]
193189
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
194190
self.assertEqual(1, len(node_groups))
195191
self.assertEqual(
@@ -202,7 +198,6 @@ def test_submit_dryrun_instance_type_singlenode(self) -> None:
202198
resource = specs.named_resources_aws.aws_p3dn_24xlarge()
203199
app = _test_app(num_replicas=1, resource=resource)
204200
info = create_scheduler("test").submit_dryrun(app, cfg)
205-
# pyre-ignore[16]
206201
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
207202
self.assertEqual(1, len(node_groups))
208203
self.assertTrue("instanceType" in node_groups[0]["container"])
@@ -212,7 +207,6 @@ def test_submit_dryrun_no_instance_type_non_aws(self) -> None:
212207
resource = specs.named_resources_aws.aws_p3dn_24xlarge()
213208
app = _test_app(num_replicas=2)
214209
info = create_scheduler("test").submit_dryrun(app, cfg)
215-
# pyre-ignore[16]
216210
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
217211
self.assertEqual(1, len(node_groups))
218212
self.assertTrue("instanceType" not in node_groups[0]["container"])

torchx/schedulers/test/slurm_scheduler_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,24 @@ def test_dryrun_comment(self, mock_version: MagicMock) -> None:
696696
info.request.cmd,
697697
)
698698

699+
@patch(
700+
"torchx.schedulers.slurm_scheduler.version",
701+
return_value=SLURM_VERSION_24_5,
702+
)
703+
def test_account(self, mock_version: MagicMock) -> None:
704+
scheduler = create_scheduler("foo")
705+
app = simple_app()
706+
info = scheduler.submit_dryrun(
707+
app,
708+
cfg={
709+
"account": "foobar",
710+
},
711+
)
712+
self.assertIn(
713+
"--account=foobar",
714+
info.request.cmd,
715+
)
716+
699717
@patch(
700718
"torchx.schedulers.slurm_scheduler.version",
701719
return_value=SLURM_VERSION_24_5,

torchx/specs/api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,9 @@ def _apply_nested(self, d: typing.Dict[str, Any]) -> typing.Dict[str, Any]:
253253
current_dict[k] = self.substitute(v)
254254
elif isinstance(v, list):
255255
for i in range(len(v)):
256-
if isinstance(v[i], str):
256+
if isinstance(v[i], dict):
257+
stack.append(v[i])
258+
elif isinstance(v[i], str):
257259
v[i] = self.substitute(v[i])
258260
return d
259261

torchx/specs/test/api_test.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,3 +945,123 @@ def test_apply(self) -> None:
945945
self.assertNotEqual(newrole, role)
946946
self.assertEqual(newrole.args, ["img_root"])
947947
self.assertEqual(newrole.env, {"FOO": "app_id"})
948+
949+
def test_apply_nested_with_list_of_dicts(self) -> None:
950+
"""Test that _apply_nested correctly handles dictionaries nested inside lists."""
951+
role = Role(
952+
name="test",
953+
image="test_image",
954+
entrypoint="foo.py",
955+
metadata={
956+
"nested_list": [
957+
{"key1": macros.app_id, "key2": "static"},
958+
{"key3": macros.img_root},
959+
]
960+
},
961+
)
962+
v = macros.Values(
963+
img_root="img_root_value",
964+
app_id="app_id_value",
965+
replica_id="replica_id_value",
966+
base_img_root="base_img_root_value",
967+
rank0_env="rank0_env_value",
968+
)
969+
newrole = v.apply(role)
970+
self.assertEqual(newrole.metadata["nested_list"][0]["key1"], "app_id_value")
971+
self.assertEqual(newrole.metadata["nested_list"][0]["key2"], "static")
972+
self.assertEqual(newrole.metadata["nested_list"][1]["key3"], "img_root_value")
973+
974+
def test_apply_nested_with_deeply_nested_structures(self) -> None:
975+
"""Test that _apply_nested handles deeply nested structures with mixed types."""
976+
role = Role(
977+
name="test",
978+
image="test_image",
979+
entrypoint="foo.py",
980+
metadata={
981+
"level1": {
982+
"level2": {
983+
"list_with_dicts": [
984+
{
985+
"nested_key": macros.replica_id,
986+
"nested_list": [macros.app_id, "static_value"],
987+
},
988+
{"another_key": macros.img_root},
989+
],
990+
"simple_string": macros.rank0_env,
991+
}
992+
}
993+
},
994+
)
995+
v = macros.Values(
996+
img_root="img_root_value",
997+
app_id="app_id_value",
998+
replica_id="replica_id_value",
999+
base_img_root="base_img_root_value",
1000+
rank0_env="rank0_env_value",
1001+
)
1002+
newrole = v.apply(role)
1003+
1004+
# Check deeply nested dict in list
1005+
nested_dict = newrole.metadata["level1"]["level2"]["list_with_dicts"][0]
1006+
self.assertEqual(nested_dict["nested_key"], "replica_id_value")
1007+
self.assertEqual(nested_dict["nested_list"][0], "app_id_value")
1008+
self.assertEqual(nested_dict["nested_list"][1], "static_value")
1009+
1010+
# Check second dict in list
1011+
second_dict = newrole.metadata["level1"]["level2"]["list_with_dicts"][1]
1012+
self.assertEqual(second_dict["another_key"], "img_root_value")
1013+
1014+
# Check simple string at nested level
1015+
self.assertEqual(
1016+
newrole.metadata["level1"]["level2"]["simple_string"], "rank0_env_value"
1017+
)
1018+
1019+
def test_apply_nested_with_list_of_strings(self) -> None:
1020+
"""Test that _apply_nested still works correctly with lists of strings."""
1021+
role = Role(
1022+
name="test",
1023+
image="test_image",
1024+
entrypoint="foo.py",
1025+
metadata={
1026+
"string_list": [macros.app_id, macros.img_root, "static"],
1027+
},
1028+
)
1029+
v = macros.Values(
1030+
img_root="img_root_value",
1031+
app_id="app_id_value",
1032+
replica_id="replica_id_value",
1033+
base_img_root="base_img_root_value",
1034+
rank0_env="rank0_env_value",
1035+
)
1036+
newrole = v.apply(role)
1037+
self.assertEqual(newrole.metadata["string_list"][0], "app_id_value")
1038+
self.assertEqual(newrole.metadata["string_list"][1], "img_root_value")
1039+
self.assertEqual(newrole.metadata["string_list"][2], "static")
1040+
1041+
def test_apply_nested_with_mixed_list_types(self) -> None:
1042+
"""Test that _apply_nested handles lists with mixed types (strings, dicts, other)."""
1043+
role = Role(
1044+
name="test",
1045+
image="test_image",
1046+
entrypoint="foo.py",
1047+
metadata={
1048+
"mixed_list": [
1049+
macros.app_id,
1050+
{"nested": macros.img_root},
1051+
42, # non-string, non-dict value
1052+
"static_string",
1053+
],
1054+
},
1055+
)
1056+
v = macros.Values(
1057+
img_root="img_root_value",
1058+
app_id="app_id_value",
1059+
replica_id="replica_id_value",
1060+
base_img_root="base_img_root_value",
1061+
rank0_env="rank0_env_value",
1062+
)
1063+
newrole = v.apply(role)
1064+
self.assertEqual(newrole.metadata["mixed_list"][0], "app_id_value")
1065+
self.assertEqual(newrole.metadata["mixed_list"][1]["nested"], "img_root_value")
1066+
self.assertEqual(newrole.metadata["mixed_list"][2], 42)
1067+
self.assertEqual(newrole.metadata["mixed_list"][3], "static_string")

0 commit comments

Comments
 (0)