Skip to content

Commit cfb8a07

Browse files
committed
Add multi-container support for McJax workload create
1 parent 9b35b42 commit cfb8a07

File tree

6 files changed

+159
-1
lines changed

6 files changed

+159
-1
lines changed

pathways-job

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 2880c34e7d71596664bafa1c3cecb5754a9991e7

src/xpk/commands/cluster_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def construct_args(**kwargs: Any) -> Namespace:
147147
docker_image_pull_secret='',
148148
managed_mldiagnostics=False,
149149
output_manifest_file='',
150+
multi_container=False,
150151
)
151152
args_dict.update(kwargs)
152153
return Namespace(**args_dict)

src/xpk/commands/workload.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,23 @@ def workload_create(args) -> None:
489489
- PodFailurePolicy"""
490490
restart_on_exit_codes_list = get_restart_exit_codes(args)
491491
restart_on_exit_codes = ','.join(map(str, restart_on_exit_codes_list))
492-
pod_failure_policy = f"""
492+
if args.multi_container:
493+
pod_failure_policy = f"""
494+
podFailurePolicy:
495+
rules:
496+
- action: FailJob
497+
onExitCodes:
498+
containerName: {get_main_container_docker_image(args, workload_system)}-1
499+
operator: NotIn
500+
values: [{restart_on_exit_codes}]
501+
- action: FailJob
502+
onExitCodes:
503+
containerName: {get_main_container_docker_image(args, workload_system)}-2
504+
operator: NotIn
505+
values: [{restart_on_exit_codes}]"""
506+
507+
else:
508+
pod_failure_policy = f"""
493509
podFailurePolicy:
494510
rules:
495511
- action: FailJob

src/xpk/commands/workload_test.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ..core.system_characteristics import DockerPlatform, SystemCharacteristics, AcceleratorType, UserFacingNameToSystemCharacteristics, GpuConfig
2424
from .workload import workload_create
2525
from .cluster_test import construct_args
26+
from ..core.docker_container import get_user_workload_container as real_get_user_workload_container
2627

2728

2829
SYSTEM_CHARACTERISTICS = SystemCharacteristics(
@@ -205,3 +206,76 @@ def test_workload_create_dry_run_with_output_file(mocker):
205206
written_content = mock_open.return_value.write.call_args[0][0]
206207
assert 'test-workload' in written_content
207208
assert 'cloud.google.com/gke-tpu-topology: 8x8' in written_content
209+
210+
211+
def test_workload_create_multi_container(
212+
workload_create_mocks: _WorkloadCreateMocks,
213+
mocker,
214+
):
215+
"""Tests that the generated YAML for a multi-container workload has correct pod failure policy and container structure."""
216+
217+
# Enable dry_run to prevent external calls like get_storages_to_mount -> gcloud
218+
mocker.patch('xpk.utils.execution_context.dry_run', True)
219+
220+
# Mock dependencies required by get_user_workload_container -> get_main_container
221+
mocker.patch(
222+
'xpk.core.docker_container.setup_docker_image',
223+
return_value=(0, 'dummy-image'),
224+
)
225+
mocker.patch(
226+
'xpk.core.docker_container.get_gke_debugging_dashboard', return_value=None
227+
)
228+
229+
# Use the real get_user_workload_container to test integration
230+
workload_create_mocks.get_user_workload_container.side_effect = (
231+
real_get_user_workload_container
232+
)
233+
234+
# Use a system with chips_per_vm=4 to test resource splitting logic
235+
system_chars = dataclasses.replace(SYSTEM_CHARACTERISTICS, chips_per_vm=4)
236+
237+
with patch(
238+
'xpk.commands.workload.get_system_characteristics',
239+
return_value=(system_chars, 0),
240+
):
241+
args = construct_args(
242+
workload='test-workload',
243+
command='echo hello',
244+
num_nodes=1,
245+
restart_on_exit_codes=None,
246+
multi_container=True,
247+
docker_name='test-docker',
248+
deploy_stacktrace_sidecar=False,
249+
enable_debug_logs=False,
250+
scheduler='default-scheduler',
251+
)
252+
workload_create(args)
253+
254+
assert workload_create_mocks.write_tmp_file.called
255+
yaml_content = workload_create_mocks.write_tmp_file.call_args[0][0]
256+
jobset = yaml.safe_load(yaml_content)
257+
258+
# Verify Pod Failure Policy
259+
pod_failure_rules = jobset['spec']['replicatedJobs'][0]['template']['spec'][
260+
'podFailurePolicy'
261+
]['rules']
262+
# Should have 2 rules for multi_container
263+
assert len(pod_failure_rules) == 2
264+
assert pod_failure_rules[0]['onExitCodes']['containerName'].endswith('-1')
265+
assert pod_failure_rules[1]['onExitCodes']['containerName'].endswith('-2')
266+
267+
# Verify Containers
268+
# Navigate to the containers list in the YAML
269+
containers = jobset['spec']['replicatedJobs'][0]['template']['spec'][
270+
'template'
271+
]['spec']['containers']
272+
273+
assert len(containers) == 2
274+
assert containers[0]['name'] == 'jax-tpu-1'
275+
assert containers[0]['image'] == 'dummy-image'
276+
# Check if resources are split correctly (4 chips / 2 containers = 2 chips)
277+
assert containers[0]['resources']['limits']['google.com/tpu'] == 2
278+
279+
assert containers[1]['name'] == 'jax-tpu-2'
280+
assert containers[1]['image'] == 'dummy-image'
281+
assert containers[1]['resources']['limits']['google.com/tpu'] == 2

src/xpk/core/docker_container.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,67 @@ def get_main_container(args, system, docker_image, resource_type) -> str:
112112
'touch /shared-volume/stacktrace_signal; '
113113
)
114114

115+
if not args.use_pathways and args.multi_container:
116+
containers = []
117+
for i in range(2):
118+
container_yaml = """
119+
- name: {docker_name}
120+
image: {docker_image}
121+
{image_pull_policy}
122+
env: {env}
123+
securityContext:
124+
privileged: true
125+
command:
126+
- bash
127+
- -c
128+
- |
129+
echo XPK Start: $(date);
130+
_sigterm() (kill -SIGTERM $! 2>/dev/null;);
131+
trap _sigterm SIGTERM;
132+
{gsutil_test_command}
133+
({command}) & PID=$!;
134+
while kill -0 $PID 2>/dev/null;
135+
do sleep 5;
136+
done;
137+
wait $PID;
138+
EXIT_CODE=$?;
139+
{xpk_internal_commands}
140+
echo XPK End: $(date);
141+
echo EXIT_CODE=$EXIT_CODE;
142+
{tpu_stacktrace_terminate_command}
143+
{gpu_workload_terminate_command}
144+
exit $EXIT_CODE
145+
resources:
146+
limits:
147+
{resources}
148+
"""
149+
volume_mounts = get_volume_mounts(args, system)
150+
if volume_mounts != '':
151+
container_yaml += """
152+
volumeMounts:
153+
{volume_mounts}
154+
"""
155+
containers.append(
156+
container_yaml.format(
157+
args=args,
158+
system=system,
159+
image_pull_policy=add_image_pull_policy_for_pw_or_gpu(
160+
args, system
161+
),
162+
env=get_env_container(args, system),
163+
docker_name=f'jax-tpu-{i+1}',
164+
docker_image=docker_image,
165+
gsutil_test_command=gsutil_test_command,
166+
command=command,
167+
tpu_stacktrace_terminate_command=tpu_stacktrace_terminate_command,
168+
gpu_workload_terminate_command=gpu_workload_terminate_command,
169+
xpk_internal_commands=xpk_internal_commands,
170+
resources=f'{resource_type}: {int(system.chips_per_vm / 2)}',
171+
volume_mounts=volume_mounts,
172+
)
173+
)
174+
return ''.join(containers)
175+
115176
yaml = """- name: {docker_name}
116177
image: {docker_image}
117178
{image_pull_policy}

src/xpk/parser/workload.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,11 @@ def set_workload_create_parser(workload_create_parser: ArgumentParser):
131131
default=1,
132132
help='The number of nodes to use, default=1.',
133133
)
134+
workload_create_parser_optional_arguments.add_argument(
135+
'--multi-container',
136+
action='store_true',
137+
help='Enable multi-container workload.',
138+
)
134139
workload_create_parser_optional_arguments.add_argument(
135140
'--scheduler',
136141
type=str,

0 commit comments

Comments
 (0)