@@ -205,3 +205,66 @@ def test_workload_create_dry_run_with_output_file(mocker):
205205 written_content = mock_open .return_value .write .call_args [0 ][0 ]
206206 assert 'test-workload' in written_content
207207 assert 'cloud.google.com/gke-tpu-topology: 8x8' in written_content
208+
209+
def test_workload_create_multi_container(
    workload_create_mocks: _WorkloadCreateMocks,
    mocker,
):
    """Checks the JobSet YAML generated for a multi-container workload.

    Runs ``workload_create`` with ``multi_container=True`` against system
    characteristics that report four chips per VM, parses the YAML passed to
    ``write_tmp_file`` and verifies that:

    * the pod failure policy holds one rule per container, and
    * both containers carry the expected name, image and TPU limit.
    """
    # dry_run keeps the command from making external calls
    # (e.g. get_storages_to_mount -> gcloud).
    mocker.patch('xpk.utils.execution_context.dry_run', True)

    # Stub what get_user_workload_container -> get_main_container relies on.
    mocker.patch(
        'xpk.core.docker_container.setup_docker_image',
        return_value=(0, 'dummy-image'),
    )
    mocker.patch(
        'xpk.core.docker_container.get_gke_debugging_dashboard',
        return_value=None,
    )

    # Route the mocked entry point to the genuine implementation so the
    # container-building path is exercised end to end.
    from xpk.core.docker_container import (
        get_user_workload_container as real_get_user_workload_container,
    )

    workload_create_mocks.get_user_workload_container.side_effect = (
        real_get_user_workload_container
    )

    # chips_per_vm=4 gives the resource-splitting logic something to divide.
    four_chip_system = dataclasses.replace(
        SYSTEM_CHARACTERISTICS, chips_per_vm=4
    )

    with patch(
        'xpk.commands.workload.get_system_characteristics',
        return_value=(four_chip_system, 0),
    ):
        cli_args = construct_args(
            workload='test-workload',
            command='echo hello',
            num_nodes=1,
            restart_on_exit_codes=None,
            multi_container=True,
            docker_name='test-docker',
            deploy_stacktrace_sidecar=False,
            enable_debug_logs=False,
            scheduler='default-scheduler',
        )
        workload_create(cli_args)

    write_tmp_file = workload_create_mocks.write_tmp_file
    assert write_tmp_file.called
    manifest = yaml.safe_load(write_tmp_file.call_args.args[0])

    # Both checks below hang off the first replicated job's Job spec.
    job_spec = manifest['spec']['replicatedJobs'][0]['template']['spec']

    # Pod failure policy: two rules, one per container, in container order.
    failure_rules = job_spec['podFailurePolicy']['rules']
    assert len(failure_rules) == 2
    for rule, expected_suffix in zip(failure_rules, ('-1', '-2')):
        assert rule['onExitCodes']['containerName'].endswith(expected_suffix)

    # Containers: two entries sharing the image, with the 4 chips split
    # evenly (4 chips / 2 containers = 2 chips each).
    pod_containers = job_spec['template']['spec']['containers']
    assert len(pod_containers) == 2
    for container, expected_name in zip(
        pod_containers, ('jax-tpu-1', 'jax-tpu-2')
    ):
        assert container['name'] == expected_name
        assert container['image'] == 'dummy-image'
        assert container['resources']['limits']['google.com/tpu'] == 2
0 commit comments