Skip to content

Commit 9b35b42

Browse files
authored
Add --num-cubes for superslicing cluster create (#914)
* Super-slicing cluster create arguments validation * lint and goldens fixes * fix f-string-without-interpolation * add --num-cubes for superslicing cluster create * update help message
1 parent 48599a1 commit 9b35b42

File tree

6 files changed

+193
-6
lines changed

6 files changed

+193
-6
lines changed

goldens.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ goldens:
1818
"Cluster create sub-slicing":
1919
command: SUB_SLICING_ENABLED=true xpk cluster create --project=golden-project --zone=us-central1-a --cluster=golden-cluster --tpu-type=v6e-4x4 --reservation=golden-reservation --sub-slicing --dry-run
2020
"Cluster create super-slicing":
21-
command: SUPER_SLICING_ENABLED=true xpk cluster create --project=golden-project --zone=us-central1-a --cluster=golden-cluster --tpu-type=tpu7x-4x4x4 --reservation=golden-reservation/reservationBlocks/block/reservationSubBlocks/subblock --super-slicing --num-slices=5 --dry-run
21+
command: SUPER_SLICING_ENABLED=true xpk cluster create --project=golden-project --zone=us-central1-a --cluster=golden-cluster --tpu-type=tpu7x-4x4x4 --reservation=golden-reservation/reservationBlocks/block/reservationSubBlocks/subblock --super-slicing --num-cubes=5 --dry-run
2222
"Cluster create private":
2323
command: xpk cluster create-pathways --project=golden-project --zone=us-central1-a --cluster=golden-cluster-private --private --tpu-type=v5p-8 --num-slices=1 --default-pool-cpu-machine-type=n1-standard-16 --default-pool-cpu-num-nodes=4 --reservation=golden-reservation --dry-run
2424
"Cluster create with Managed Lustre driver":

goldens/Cluster_create_super-slicing.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
$ SUPER_SLICING_ENABLED=true xpk cluster create --project=golden-project --zone=us-central1-a --cluster=golden-cluster --tpu-type=tpu7x-4x4x4 --reservation=golden-reservation/reservationBlocks/block/reservationSubBlocks/subblock --super-slicing --num-slices=5 --dry-run
1+
$ SUPER_SLICING_ENABLED=true xpk cluster create --project=golden-project --zone=us-central1-a --cluster=golden-cluster --tpu-type=tpu7x-4x4x4 --reservation=golden-reservation/reservationBlocks/block/reservationSubBlocks/subblock --super-slicing --num-cubes=5 --dry-run
22
[XPK] Starting xpk v0.0.0
33
[XPK] Starting cluster create for cluster golden-cluster:
44
[XPK] Working on golden-project and us-central1-a

src/xpk/commands/cluster.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,11 @@ def _validate_cluster_create_args(args, system: SystemCharacteristics):
211211
if FeatureFlags.SUB_SLICING_ENABLED and args.sub_slicing:
212212
validate_sub_slicing_system(system)
213213
_validate_sub_slicing_reservation(args)
214-
if FeatureFlags.SUPER_SLICING_ENABLED and args.super_slicing:
215-
validate_super_slicing_system(system)
216-
_validate_super_slicing_reservation(args)
214+
if FeatureFlags.SUPER_SLICING_ENABLED:
215+
_validate_num_slices_and_set_default(args)
216+
if args.super_slicing:
217+
validate_super_slicing_system(system)
218+
_validate_super_slicing_reservation(args)
217219
if args.enable_pathways:
218220
_validate_pathways_machine(args)
219221

@@ -281,6 +283,22 @@ def _validate_gsc_reservation(args, creation_description: str):
281283
xpk_exit(1)
282284

283285

286+
def _validate_num_slices_and_set_default(args):
287+
if args.num_cubes is not None and not args.super_slicing:
288+
xpk_print('--num-cubes can only be used with --super-slicing')
289+
xpk_exit(1)
290+
291+
if (
292+
args.num_cubes is not None
293+
and args.num_slices is not None
294+
and args.num_cubes != args.num_slices
295+
):
296+
xpk_print('--num-cubes must not be different from --num-slices')
297+
xpk_exit(1)
298+
299+
args.num_slices = args.num_slices or args.num_cubes or 1
300+
301+
284302
def cluster_create(args) -> None:
285303
"""Function around cluster creation.
286304

src/xpk/commands/cluster_test.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,12 +638,16 @@ def test_validate_cluster_create_args_for_correct_super_slicing_args_pass(
638638
args = construct_args(
639639
super_slicing=True,
640640
reservation='test-reservation/reservationBlocks/block',
641+
num_cubes=None,
642+
num_slices=None,
641643
)
642644

643645
_validate_cluster_create_args(args, SUPER_SLICING_SYSTEM)
644646
args = construct_args(
645647
super_slicing=True,
646648
reservation='test-reservation/reservationBlocks/block/reservationSubBlocks/subblock',
649+
num_cubes=None,
650+
num_slices=None,
647651
)
648652
_validate_cluster_create_args(
649653
args, UserFacingNameToSystemCharacteristics['tpu7x-128']
@@ -659,6 +663,8 @@ def test_validate_cluster_create_args_for_super_slicing_system_not_supported_thr
659663
args = construct_args(
660664
super_slicing=True,
661665
reservation='test-reservation/reservationBlocks/block',
666+
num_cubes=None,
667+
num_slices=None,
662668
)
663669

664670
with pytest.raises(SystemExit):
@@ -680,6 +686,8 @@ def test_validate_cluster_create_args_for_super_slicing_missing_reservation(
680686
args = construct_args(
681687
super_slicing=True,
682688
reservation=None,
689+
num_cubes=None,
690+
num_slices=None,
683691
)
684692

685693
with pytest.raises(SystemExit):
@@ -699,6 +707,8 @@ def test_validate_cluster_create_args_for_super_slicing_reservation_no_blocks(
699707
args = construct_args(
700708
super_slicing=True,
701709
reservation='reservation',
710+
num_cubes=None,
711+
num_slices=None,
702712
)
703713

704714
with pytest.raises(SystemExit):
@@ -718,6 +728,8 @@ def test_validate_cluster_create_args_for_super_slicing_sparse_deployment_type_r
718728
args = construct_args(
719729
super_slicing=True,
720730
reservation='test-reservation/reservationBlocks/block',
731+
num_cubes=None,
732+
num_slices=None,
721733
)
722734
mocks.commands_get_reservation_deployment_type.return_value = 'SPARSE'
723735

@@ -729,3 +741,73 @@ def test_validate_cluster_create_args_for_super_slicing_sparse_deployment_type_r
729741
'Refer to the documentation for more information on creating Cluster'
730742
in mocks.commands_print_mock.call_args[0][0]
731743
)
744+
745+
746+
def test_validate_cluster_create_args_forbids_num_cubes_without_superslicing(
747+
mocks: _Mocks,
748+
):
749+
FeatureFlags.SUPER_SLICING_ENABLED = True # enable the feature
750+
args = construct_args(
751+
super_slicing=False, # but disable the flag
752+
reservation='test-reservation/reservationBlocks/block',
753+
num_cubes=1,
754+
num_slices=None,
755+
)
756+
757+
with pytest.raises(SystemExit):
758+
_validate_cluster_create_args(args, SUPER_SLICING_SYSTEM)
759+
760+
assert mocks.commands_print_mock.call_count == 1
761+
assert (
762+
'--num-cubes can only be used with --super-slicing'
763+
in mocks.commands_print_mock.call_args[0][0]
764+
)
765+
766+
767+
def test_validate_cluster_create_args_forbids_num_cubes_different_from_num_slices(
768+
mocks: _Mocks,
769+
):
770+
FeatureFlags.SUPER_SLICING_ENABLED = True
771+
args = construct_args(
772+
super_slicing=True,
773+
reservation='test-reservation/reservationBlocks/block',
774+
num_cubes=1,
775+
num_slices=2,
776+
)
777+
778+
with pytest.raises(SystemExit):
779+
_validate_cluster_create_args(args, SUPER_SLICING_SYSTEM)
780+
781+
assert mocks.commands_print_mock.call_count == 1
782+
assert (
783+
'--num-cubes must not be different from --num-slices'
784+
in mocks.commands_print_mock.call_args[0][0]
785+
)
786+
787+
788+
@pytest.mark.parametrize(
789+
'num_cubes, num_slices, expected',
790+
[
791+
(None, None, 1),
792+
(3, None, 3),
793+
(None, 3, 3),
794+
(3, 3, 3),
795+
],
796+
)
797+
def test_validate_cluster_create_args_sets_correct_num_slices(
798+
mocks: _Mocks,
799+
num_cubes: int | None,
800+
num_slices: int | None,
801+
expected: int,
802+
):
803+
FeatureFlags.SUPER_SLICING_ENABLED = True
804+
args = construct_args(
805+
super_slicing=True,
806+
reservation='test-reservation/reservationBlocks/block',
807+
num_cubes=num_cubes,
808+
num_slices=num_slices,
809+
)
810+
811+
_validate_cluster_create_args(args, SUPER_SLICING_SYSTEM)
812+
813+
assert args.num_slices == expected

src/xpk/parser/cluster.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,10 @@ def add_shared_cluster_create_optional_arguments(
604604
parser_or_group.add_argument(
605605
'--num-slices',
606606
type=int,
607-
default=1,
607+
# removing default in case of super slicing because
608+
# --num-slices must be equal to --num-cubes if both are set
609+
# it will default to 1 during validation
610+
default=1 if not FeatureFlags.SUPER_SLICING_ENABLED else None,
608611
help='The number of slices to run the job on, defaults to 1.',
609612
required=False,
610613
)
@@ -928,3 +931,14 @@ def add_cluster_create_super_slicing_arguments(
928931
action='store_true',
929932
help='Whether to set up cluster to support super-slicing',
930933
)
934+
parser_or_group.add_argument(
935+
'--num-cubes',
936+
type=int,
937+
# default value is set during validation because it needs to be compared
938+
# against --num-slices
939+
help=(
940+
'Total number of cubes to create within a cluster, defaults to 1. Can'
941+
' only be used with --super-slicing.'
942+
),
943+
required=False,
944+
)

src/xpk/parser/cluster_test.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,76 @@ def test_cluster_create_super_slicing_can_be_set():
188188
)
189189

190190
assert args.super_slicing is True
191+
192+
193+
def test_cluster_create_num_cubes_is_hidden_with_flag_off():
194+
FeatureFlags.SUPER_SLICING_ENABLED = False
195+
parser = argparse.ArgumentParser()
196+
197+
set_cluster_create_parser(parser)
198+
help_str = parser.format_help()
199+
200+
assert "--num-cubes" not in help_str
201+
202+
203+
def test_cluster_create_num_cubes_is_shown_with_flag_on():
204+
FeatureFlags.SUPER_SLICING_ENABLED = True
205+
parser = argparse.ArgumentParser()
206+
207+
set_cluster_create_parser(parser)
208+
help_str = parser.format_help()
209+
210+
assert "--num-cubes" in help_str
211+
212+
213+
def test_cluster_create_num_cubes_can_be_set():
214+
FeatureFlags.SUPER_SLICING_ENABLED = True
215+
parser = argparse.ArgumentParser()
216+
217+
set_cluster_create_parser(parser)
218+
args = parser.parse_args(
219+
[
220+
"--cluster",
221+
"test-cluster",
222+
"--tpu-type",
223+
"tpu7x-2",
224+
"--num-cubes",
225+
"5",
226+
],
227+
)
228+
229+
assert args.num_cubes == 5
230+
231+
232+
def test_cluster_create_num_slices_defaults_to_1_if_no_superslicing_feature():
233+
FeatureFlags.SUPER_SLICING_ENABLED = False
234+
parser = argparse.ArgumentParser()
235+
236+
set_cluster_create_parser(parser)
237+
args = parser.parse_args(
238+
[
239+
"--cluster",
240+
"test-cluster",
241+
"--tpu-type",
242+
"tpu7x-2",
243+
],
244+
)
245+
246+
assert args.num_slices == 1
247+
248+
249+
def test_cluster_create_num_slices_has_no_default_if_superslicing_feature():
250+
FeatureFlags.SUPER_SLICING_ENABLED = True
251+
parser = argparse.ArgumentParser()
252+
253+
set_cluster_create_parser(parser)
254+
args = parser.parse_args(
255+
[
256+
"--cluster",
257+
"test-cluster",
258+
"--tpu-type",
259+
"tpu7x-2",
260+
],
261+
)
262+
263+
assert args.num_slices is None

0 commit comments

Comments
 (0)