Skip to content

Commit f21d198

Browse files
committed
Add multi-container support for McJax workload create
1 parent 9b35b42 commit f21d198

File tree

9 files changed

+88
-1
lines changed

9 files changed

+88
-1
lines changed

goldens/Workload_create.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ kubectl get configmap golden-cluster-resources-configmap -o=custom-columns="Conf
1616
[XPK] No gcp parallelstore instances to add detected.
1717
[XPK] No gce persistent disk instances to add detected.
1818
[XPK] No managed lustre instances to add detected.
19+
Namespace(xpk_subcommands='workload', func=<function workload_create at 0x7fbd15b03100>, xpk_workload_subcommands='create', command='bash hello', tpu_type='v5p-8', device_type=None, storage=[], num_nodes=1, multi_container=False, scheduler='default-scheduler', ramdisk_directory='', mtc_enabled=False, debug_dump_gcs=None, deploy_stacktrace_sidecar=False, use_pathways=False, restart_on_exit_codes=None, workload='golden-workload', cluster='golden-cluster', project='golden-project', zone='us-central1-a', dry_run=True, skip_validation=False, quiet=False, docker_name='jax-tpu', output_manifest_file=None, num_slices=1, priority='medium', max_restarts='0', ttl_seconds_after_finished=43200, termination_grace_period_seconds='30', colocated_python_sidecar_image='', enable_debug_logs=False, env_file=None, env={}, base_docker_image='python:3.10', script_dir='/tmp', docker_image=None, docker_image_pull_secret=None, use_vertex_tensorboard=False, experiment_name=None, on_demand=False, reservation=None, spot=False, flex=False, enable_ray_cluster=False)
1920
[XPK] Temp file (4b6736a12db8ea0f78ce793fd0d4ee0c94c652303f1dc0fecad085ea0993f688) content:
2021
FROM python:3.10
2122

goldens/Workload_create_pathways.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ kubectl get configmap golden-cluster-metadata-configmap -o=custom-columns="Confi
1111
[XPK] Task: `GKE Cluster Get ConfigMap` is implemented by the following command not running since it is a dry run.
1212
kubectl get configmap golden-cluster-resources-configmap -o=custom-columns="ConfigData:data" --no-headers=true
1313
[XPK] gke_accelerator type not found in config map. Autoprovisioning is not enabled.
14+
Namespace(xpk_subcommands='workload', func=<function workload_create_pathways at 0x7fa65b206fc0>, xpk_workload_subcommands='create-pathways', tpu_type='v5p-8', headless=False, proxy_server_image='', server_image='', pathways_gcs_location='gs://cloud-pathways-staging/tmp', command='bash hello', storage=[], custom_pathways_server_args='', custom_pathways_proxy_server_args='', custom_pathways_worker_args='', elastic_slices=0, max_slice_restarts=1, workload='golden-workload', cluster='golden-cluster', project='golden-project', zone='us-central1-a', dry_run=True, skip_validation=False, quiet=False, docker_name='jax-tpu', output_manifest_file=None, num_slices=1, priority='medium', max_restarts='0', ttl_seconds_after_finished=43200, termination_grace_period_seconds='30', colocated_python_sidecar_image='', enable_debug_logs=False, env_file=None, env={}, base_docker_image='python:3.10', script_dir='/tmp', docker_image=None, docker_image_pull_secret=None, use_vertex_tensorboard=False, experiment_name=None, on_demand=False, reservation=None, spot=False, flex=False, enable_ray_cluster=False, use_pathways=True)
1415
[XPK] Task: `Check if PathwaysJob is installed on golden-cluster` is implemented by the following command not running since it is a dry run.
1516
kubectl get pods -n pathways-job-system --no-headers -o custom-columns=NAME:.metadata.name
1617
[XPK] check_if_pathways_job_is_installed 0 0

goldens/Workload_create_sub-slicing.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ kubectl get configmap golden-cluster-resources-configmap -o=custom-columns="Conf
1919
[XPK] No gcp parallelstore instances to add detected.
2020
[XPK] No gce persistent disk instances to add detected.
2121
[XPK] No managed lustre instances to add detected.
22+
Namespace(xpk_subcommands='workload', func=<function workload_create at 0x7f750e2cf100>, xpk_workload_subcommands='create', command='bash hello', tpu_type='v6e-2x4', device_type=None, storage=[], num_nodes=1, multi_container=False, scheduler='default-scheduler', ramdisk_directory='', mtc_enabled=False, debug_dump_gcs=None, deploy_stacktrace_sidecar=False, use_pathways=False, restart_on_exit_codes=None, workload='golden-workload', cluster='golden-cluster', project='golden-project', zone='us-central1-a', dry_run=True, skip_validation=False, quiet=False, docker_name='jax-tpu', output_manifest_file=None, num_slices=1, priority='medium', max_restarts='0', ttl_seconds_after_finished=43200, termination_grace_period_seconds='30', colocated_python_sidecar_image='', enable_debug_logs=False, env_file=None, env={}, base_docker_image='python:3.10', script_dir='/tmp', docker_image=None, docker_image_pull_secret=None, use_vertex_tensorboard=False, experiment_name=None, on_demand=False, reservation=None, spot=False, flex=False, enable_ray_cluster=False)
2223
[XPK] Workload will be scheduled using the Sub-slicing feature.
2324
[XPK] Temp file (4b6736a12db8ea0f78ce793fd0d4ee0c94c652303f1dc0fecad085ea0993f688) content:
2425
FROM python:3.10

goldens/Workload_create_super-slicing.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ kubectl get configmap golden-cluster-resources-configmap -o=custom-columns="Conf
1919
[XPK] No gcp parallelstore instances to add detected.
2020
[XPK] No gce persistent disk instances to add detected.
2121
[XPK] No managed lustre instances to add detected.
22+
Namespace(xpk_subcommands='workload', func=<function workload_create at 0x7f1062083100>, xpk_workload_subcommands='create', command='bash hello', tpu_type='tpu7x-4x4x20', device_type=None, storage=[], num_nodes=1, multi_container=False, scheduler='default-scheduler', ramdisk_directory='', mtc_enabled=False, debug_dump_gcs=None, deploy_stacktrace_sidecar=False, use_pathways=False, restart_on_exit_codes=None, workload='golden-workload', cluster='golden-cluster', project='golden-project', zone='us-central1-a', dry_run=True, skip_validation=False, quiet=False, docker_name='jax-tpu', output_manifest_file=None, num_slices=1, priority='medium', max_restarts='0', ttl_seconds_after_finished=43200, termination_grace_period_seconds='30', colocated_python_sidecar_image='', enable_debug_logs=False, env_file=None, env={}, base_docker_image='python:3.10', script_dir='/tmp', docker_image=None, docker_image_pull_secret=None, use_vertex_tensorboard=False, experiment_name=None, on_demand=False, reservation=None, spot=False, flex=False, enable_ray_cluster=False)
2223
[XPK] Workload will be scheduled using the Super-slicing feature.
2324
[XPK] Temp file (4b6736a12db8ea0f78ce793fd0d4ee0c94c652303f1dc0fecad085ea0993f688) content:
2425
FROM python:3.10

goldens/Workload_create_with_output-manifest-file.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ kubectl get configmap golden-cluster-resources-configmap -o=custom-columns="Conf
1616
[XPK] No gcp parallelstore instances to add detected.
1717
[XPK] No gce persistent disk instances to add detected.
1818
[XPK] No managed lustre instances to add detected.
19+
Namespace(xpk_subcommands='workload', func=<function workload_create at 0x7f0fae8cf100>, xpk_workload_subcommands='create', command='bash hello', tpu_type='v5p-8', device_type=None, storage=[], num_nodes=1, multi_container=False, scheduler='default-scheduler', ramdisk_directory='', mtc_enabled=False, debug_dump_gcs=None, deploy_stacktrace_sidecar=False, use_pathways=False, restart_on_exit_codes=None, workload='golden-workload', cluster='golden-cluster', project='golden-project', zone='us-central1-a', dry_run=True, skip_validation=False, quiet=False, docker_name='jax-tpu', output_manifest_file='/var/tmp/manifest.yaml', num_slices=1, priority='medium', max_restarts='0', ttl_seconds_after_finished=43200, termination_grace_period_seconds='30', colocated_python_sidecar_image='', enable_debug_logs=False, env_file=None, env={}, base_docker_image='python:3.10', script_dir='/tmp', docker_image=None, docker_image_pull_secret=None, use_vertex_tensorboard=False, experiment_name=None, on_demand=False, reservation=None, spot=False, flex=False, enable_ray_cluster=False)
1920
[XPK] Temp file (4b6736a12db8ea0f78ce793fd0d4ee0c94c652303f1dc0fecad085ea0993f688) content:
2021
FROM python:3.10
2122

pathways-job

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 2880c34e7d71596664bafa1c3cecb5754a9991e7

src/xpk/commands/workload.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,14 +482,31 @@ def workload_create(args) -> None:
482482
# Currently failure policy rules are supported for Pathways workloads. b/408465881
483483
failure_policy_rules = ''
484484
pod_failure_policy = ''
485+
print(args)
485486
if not args.use_pathways:
486487
failure_policy_rules = """rules:
487488
- action: FailJobSet
488489
onJobFailureReasons:
489490
- PodFailurePolicy"""
490491
restart_on_exit_codes_list = get_restart_exit_codes(args)
491492
restart_on_exit_codes = ','.join(map(str, restart_on_exit_codes_list))
492-
pod_failure_policy = f"""
493+
if args.multi_container:
494+
pod_failure_policy = f"""
495+
podFailurePolicy:
496+
rules:
497+
- action: FailJob
498+
onExitCodes:
499+
containerName: {get_main_container_docker_image(args, workload_system)}-1
500+
operator: NotIn
501+
values: [{restart_on_exit_codes}]
502+
- action: FailJob
503+
onExitCodes:
504+
containerName: {get_main_container_docker_image(args, workload_system)}-2
505+
operator: NotIn
506+
values: [{restart_on_exit_codes}]"""
507+
508+
else:
509+
pod_failure_policy = f"""
493510
podFailurePolicy:
494511
rules:
495512
- action: FailJob

src/xpk/core/docker_container.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,65 @@ def get_main_container(args, system, docker_image, resource_type) -> str:
112112
'touch /shared-volume/stacktrace_signal; '
113113
)
114114

115+
if not args.use_pathways and args.multi_container:
116+
containers = []
117+
for i in range(2):
118+
container_yaml = """
119+
- name: {docker_name}
120+
image: {docker_image}
121+
{image_pull_policy}
122+
env: {env}
123+
securityContext:
124+
privileged: true
125+
command:
126+
- bash
127+
- -c
128+
- |
129+
echo XPK Start: $(date);
130+
_sigterm() (kill -SIGTERM $! 2>/dev/null;);
131+
trap _sigterm SIGTERM;
132+
{gsutil_test_command}
133+
({command}) & PID=$!;
134+
while kill -0 $PID 2>/dev/null;
135+
do sleep 5;
136+
done;
137+
wait $PID;
138+
EXIT_CODE=$?;
139+
{xpk_internal_commands}
140+
echo XPK End: $(date);
141+
echo EXIT_CODE=$EXIT_CODE;
142+
{tpu_stacktrace_terminate_command}
143+
{gpu_workload_terminate_command}
144+
exit $EXIT_CODE
145+
resources:
146+
limits:
147+
{resources}
148+
"""
149+
volume_mounts = get_volume_mounts(args, system)
150+
if volume_mounts != '':
151+
container_yaml += """
152+
volumeMounts:
153+
{volume_mounts}
154+
"""
155+
containers.append(
156+
container_yaml.format(
157+
args=args,
158+
system=system,
159+
image_pull_policy=add_image_pull_policy_for_pw_or_gpu(args, system),
160+
env=get_env_container(args, system),
161+
docker_name=f'jax-tpu-{i+1}',
162+
docker_image=docker_image,
163+
gsutil_test_command=gsutil_test_command,
164+
command=command,
165+
tpu_stacktrace_terminate_command=tpu_stacktrace_terminate_command,
166+
gpu_workload_terminate_command=gpu_workload_terminate_command,
167+
xpk_internal_commands=xpk_internal_commands,
168+
resources=f'{resource_type}: {int(system.chips_per_vm / 2)}',
169+
volume_mounts=volume_mounts,
170+
)
171+
)
172+
return ''.join(containers)
173+
115174
yaml = """- name: {docker_name}
116175
image: {docker_image}
117176
{image_pull_policy}

src/xpk/parser/workload.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,11 @@ def set_workload_create_parser(workload_create_parser: ArgumentParser):
131131
default=1,
132132
help='The number of nodes to use, default=1.',
133133
)
134+
workload_create_parser_optional_arguments.add_argument(
135+
'--multi-container',
136+
action='store_true',
137+
help='Enable multi-container workload.',
138+
)
134139
workload_create_parser_optional_arguments.add_argument(
135140
'--scheduler',
136141
type=str,

0 commit comments

Comments
 (0)