|
23 | 23 | from ..core.system_characteristics import DockerPlatform, SystemCharacteristics, AcceleratorType, UserFacingNameToSystemCharacteristics, GpuConfig |
24 | 24 | from .workload import workload_create |
25 | 25 | from .cluster_test import construct_args |
| 26 | +from ..core.docker_container import get_user_workload_container as real_get_user_workload_container |
26 | 27 |
|
27 | 28 |
|
28 | 29 | SYSTEM_CHARACTERISTICS = SystemCharacteristics( |
@@ -205,3 +206,76 @@ def test_workload_create_dry_run_with_output_file(mocker): |
205 | 206 | written_content = mock_open.return_value.write.call_args[0][0] |
206 | 207 | assert 'test-workload' in written_content |
207 | 208 | assert 'cloud.google.com/gke-tpu-topology: 8x8' in written_content |
| 209 | + |
| 210 | + |
| 211 | +def test_workload_create_multi_container( |
| 212 | + workload_create_mocks: _WorkloadCreateMocks, |
| 213 | + mocker, |
| 214 | +): |
| 215 | + """Tests that the generated YAML for a multi-container workload has correct pod failure policy and container structure.""" |
| 216 | + |
| 217 | + # Enable dry_run to prevent external calls like get_storages_to_mount -> gcloud |
| 218 | + mocker.patch('xpk.utils.execution_context.dry_run', True) |
| 219 | + |
| 220 | + # Mock dependencies required by get_user_workload_container -> get_main_container |
| 221 | + mocker.patch( |
| 222 | + 'xpk.core.docker_container.setup_docker_image', |
| 223 | + return_value=(0, 'dummy-image'), |
| 224 | + ) |
| 225 | + mocker.patch( |
| 226 | + 'xpk.core.docker_container.get_gke_debugging_dashboard', return_value=None |
| 227 | + ) |
| 228 | + |
| 229 | + # Use the real get_user_workload_container to test integration |
| 230 | + workload_create_mocks.get_user_workload_container.side_effect = ( |
| 231 | + real_get_user_workload_container |
| 232 | + ) |
| 233 | + |
| 234 | + # Use a system with chips_per_vm=4 to test resource splitting logic |
| 235 | + system_chars = dataclasses.replace(SYSTEM_CHARACTERISTICS, chips_per_vm=4) |
| 236 | + |
| 237 | + with patch( |
| 238 | + 'xpk.commands.workload.get_system_characteristics', |
| 239 | + return_value=(system_chars, 0), |
| 240 | + ): |
| 241 | + args = construct_args( |
| 242 | + workload='test-workload', |
| 243 | + command='echo hello', |
| 244 | + num_nodes=1, |
| 245 | + restart_on_exit_codes=None, |
| 246 | + multi_container=True, |
| 247 | + docker_name='test-docker', |
| 248 | + deploy_stacktrace_sidecar=False, |
| 249 | + enable_debug_logs=False, |
| 250 | + scheduler='default-scheduler', |
| 251 | + ) |
| 252 | + workload_create(args) |
| 253 | + |
| 254 | + assert workload_create_mocks.write_tmp_file.called |
| 255 | + yaml_content = workload_create_mocks.write_tmp_file.call_args[0][0] |
| 256 | + jobset = yaml.safe_load(yaml_content) |
| 257 | + |
| 258 | + # Verify Pod Failure Policy |
| 259 | + pod_failure_rules = jobset['spec']['replicatedJobs'][0]['template']['spec'][ |
| 260 | + 'podFailurePolicy' |
| 261 | + ]['rules'] |
| 262 | + # Should have 2 rules for multi_container |
| 263 | + assert len(pod_failure_rules) == 2 |
| 264 | + assert pod_failure_rules[0]['onExitCodes']['containerName'].endswith('-1') |
| 265 | + assert pod_failure_rules[1]['onExitCodes']['containerName'].endswith('-2') |
| 266 | + |
| 267 | + # Verify Containers |
| 268 | + # Navigate to the containers list in the YAML |
| 269 | + containers = jobset['spec']['replicatedJobs'][0]['template']['spec'][ |
| 270 | + 'template' |
| 271 | + ]['spec']['containers'] |
| 272 | + |
| 273 | + assert len(containers) == 2 |
| 274 | + assert containers[0]['name'] == 'jax-tpu-1' |
| 275 | + assert containers[0]['image'] == 'dummy-image' |
| 276 | + # Check if resources are split correctly (4 chips / 2 containers = 2 chips) |
| 277 | + assert containers[0]['resources']['limits']['google.com/tpu'] == 2 |
| 278 | + |
| 279 | + assert containers[1]['name'] == 'jax-tpu-2' |
| 280 | + assert containers[1]['image'] == 'dummy-image' |
| 281 | + assert containers[1]['resources']['limits']['google.com/tpu'] == 2 |
0 commit comments