Skip to content

Commit 0e8f567

Browse files
committed
Add multi-container support for McJax workload create
1 parent 9b35b42 commit 0e8f567

File tree

6 files changed

+148
-1
lines changed

6 files changed

+148
-1
lines changed

pathways-job

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 2880c34e7d71596664bafa1c3cecb5754a9991e7

src/xpk/commands/cluster_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def construct_args(**kwargs: Any) -> Namespace:
147147
docker_image_pull_secret='',
148148
managed_mldiagnostics=False,
149149
output_manifest_file='',
150+
multi_container=False,
150151
)
151152
args_dict.update(kwargs)
152153
return Namespace(**args_dict)

src/xpk/commands/workload.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,23 @@ def workload_create(args) -> None:
489489
- PodFailurePolicy"""
490490
restart_on_exit_codes_list = get_restart_exit_codes(args)
491491
restart_on_exit_codes = ','.join(map(str, restart_on_exit_codes_list))
492-
pod_failure_policy = f"""
492+
if args.multi_container:
493+
pod_failure_policy = f"""
494+
podFailurePolicy:
495+
rules:
496+
- action: FailJob
497+
onExitCodes:
498+
containerName: {get_main_container_docker_image(args, workload_system)}-1
499+
operator: NotIn
500+
values: [{restart_on_exit_codes}]
501+
- action: FailJob
502+
onExitCodes:
503+
containerName: {get_main_container_docker_image(args, workload_system)}-2
504+
operator: NotIn
505+
values: [{restart_on_exit_codes}]"""
506+
507+
else:
508+
pod_failure_policy = f"""
493509
podFailurePolicy:
494510
rules:
495511
- action: FailJob

src/xpk/commands/workload_test.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,66 @@ def test_workload_create_dry_run_with_output_file(mocker):
205205
written_content = mock_open.return_value.write.call_args[0][0]
206206
assert 'test-workload' in written_content
207207
assert 'cloud.google.com/gke-tpu-topology: 8x8' in written_content
208+
209+
210+
def test_workload_create_multi_container(
    workload_create_mocks: _WorkloadCreateMocks,
    mocker,
):
    """Checks the manifest that `workload_create` emits with multi_container=True.

    The test runs `workload_create` against a system whose `chips_per_vm` is 4
    and inspects the YAML handed to the mocked `write_tmp_file`. It expects:

    * two pod-failure-policy rules whose `containerName` values end in
      `-1` and `-2` respectively;
    * two containers named `jax-tpu-1` and `jax-tpu-2`, both using the
      mocked image and each limited to 2 `google.com/tpu` chips
      (4 chips / 2 containers).

    Args:
      workload_create_mocks: fixture bundling the mocks used by
        `workload_create`; only `get_user_workload_container` and
        `write_tmp_file` are touched here.
      mocker: pytest-mock fixture used for the additional patches.
    """
    # Dry-run mode keeps the code under test away from external calls
    # (e.g. get_storages_to_mount -> gcloud).
    mocker.patch('xpk.utils.execution_context.dry_run', True)

    # Stub what get_user_workload_container -> get_main_container depends on.
    mocker.patch(
        'xpk.core.docker_container.setup_docker_image',
        return_value=(0, 'dummy-image'),
    )
    mocker.patch(
        'xpk.core.docker_container.get_gke_debugging_dashboard',
        return_value=None,
    )

    # Route the mocked entry point to the real implementation so the
    # container YAML is actually rendered (integration-style check).
    from xpk.core.docker_container import (
        get_user_workload_container as real_get_user_workload_container,
    )

    workload_create_mocks.get_user_workload_container.side_effect = (
        real_get_user_workload_container
    )

    # chips_per_vm=4 makes the per-container split observable (4 / 2 = 2).
    four_chip_system = dataclasses.replace(
        SYSTEM_CHARACTERISTICS, chips_per_vm=4
    )

    with patch(
        'xpk.commands.workload.get_system_characteristics',
        return_value=(four_chip_system, 0),
    ):
        args = construct_args(
            workload='test-workload',
            command='echo hello',
            num_nodes=1,
            restart_on_exit_codes=None,
            multi_container=True,
            docker_name='test-docker',
            deploy_stacktrace_sidecar=False,
            enable_debug_logs=False,
            scheduler='default-scheduler',
        )
        workload_create(args)

    assert workload_create_mocks.write_tmp_file.called
    manifest_text = workload_create_mocks.write_tmp_file.call_args[0][0]
    manifest = yaml.safe_load(manifest_text)

    # Both the failure policy and the pod template hang off the first
    # replicated job's Job spec.
    job_spec = manifest['spec']['replicatedJobs'][0]['template']['spec']

    # Pod failure policy: one rule per container.
    # NOTE(review): only the '-1' / '-2' suffix of containerName is checked;
    # the rule names are not compared with the container names asserted
    # below (args use docker_name='test-docker') -- confirm they are meant
    # to match.
    failure_rules = job_spec['podFailurePolicy']['rules']
    assert len(failure_rules) == 2
    for rule, expected_suffix in zip(failure_rules, ('-1', '-2')):
        assert rule['onExitCodes']['containerName'].endswith(expected_suffix)

    # Containers: two identical workers, each with half of the VM's chips.
    containers = job_spec['template']['spec']['containers']
    assert len(containers) == 2
    for ordinal, container in enumerate(containers, start=1):
        assert container['name'] == f'jax-tpu-{ordinal}'
        assert container['image'] == 'dummy-image'
        assert container['resources']['limits']['google.com/tpu'] == 2

src/xpk/core/docker_container.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,67 @@ def get_main_container(args, system, docker_image, resource_type) -> str:
112112
'touch /shared-volume/stacktrace_signal; '
113113
)
114114

115+
if not args.use_pathways and args.multi_container:
116+
containers = []
117+
for i in range(2):
118+
container_yaml = """
119+
- name: {docker_name}
120+
image: {docker_image}
121+
{image_pull_policy}
122+
env: {env}
123+
securityContext:
124+
privileged: true
125+
command:
126+
- bash
127+
- -c
128+
- |
129+
echo XPK Start: $(date);
130+
_sigterm() (kill -SIGTERM $! 2>/dev/null;);
131+
trap _sigterm SIGTERM;
132+
{gsutil_test_command}
133+
({command}) & PID=$!;
134+
while kill -0 $PID 2>/dev/null;
135+
do sleep 5;
136+
done;
137+
wait $PID;
138+
EXIT_CODE=$?;
139+
{xpk_internal_commands}
140+
echo XPK End: $(date);
141+
echo EXIT_CODE=$EXIT_CODE;
142+
{tpu_stacktrace_terminate_command}
143+
{gpu_workload_terminate_command}
144+
exit $EXIT_CODE
145+
resources:
146+
limits:
147+
{resources}
148+
"""
149+
volume_mounts = get_volume_mounts(args, system)
150+
if volume_mounts != '':
151+
container_yaml += """
152+
volumeMounts:
153+
{volume_mounts}
154+
"""
155+
containers.append(
156+
container_yaml.format(
157+
args=args,
158+
system=system,
159+
image_pull_policy=add_image_pull_policy_for_pw_or_gpu(
160+
args, system
161+
),
162+
env=get_env_container(args, system),
163+
docker_name=f'jax-tpu-{i+1}',
164+
docker_image=docker_image,
165+
gsutil_test_command=gsutil_test_command,
166+
command=command,
167+
tpu_stacktrace_terminate_command=tpu_stacktrace_terminate_command,
168+
gpu_workload_terminate_command=gpu_workload_terminate_command,
169+
xpk_internal_commands=xpk_internal_commands,
170+
resources=f'{resource_type}: {int(system.chips_per_vm / 2)}',
171+
volume_mounts=volume_mounts,
172+
)
173+
)
174+
return ''.join(containers)
175+
115176
yaml = """- name: {docker_name}
116177
image: {docker_image}
117178
{image_pull_policy}

src/xpk/parser/workload.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,11 @@ def set_workload_create_parser(workload_create_parser: ArgumentParser):
131131
default=1,
132132
help='The number of nodes to use, default=1.',
133133
)
134+
workload_create_parser_optional_arguments.add_argument(
135+
'--multi-container',
136+
action='store_true',
137+
help='Enable multi-container workload.',
138+
)
134139
workload_create_parser_optional_arguments.add_argument(
135140
'--scheduler',
136141
type=str,

0 commit comments

Comments
 (0)