diff --git a/nemo_run/core/execution/dgxcloud.py b/nemo_run/core/execution/dgxcloud.py
index 13596ebf..1d17afdb 100644
--- a/nemo_run/core/execution/dgxcloud.py
+++ b/nemo_run/core/execution/dgxcloud.py
@@ -29,8 +29,11 @@
 import requests
 from invoke.context import Context
 
-from nemo_run.config import get_nemorun_home
+from nemo_run.config import RUNDIR_NAME, get_nemorun_home
 from nemo_run.core.execution.base import Executor, ExecutorMacros
+from nemo_run.core.execution.launcher import FaultTolerance, Launcher, Torchrun
+from nemo_run.core.execution.utils import fill_template
+from nemo_run.core.frontend.console.api import CONSOLE
 from nemo_run.core.packaging.base import Packager
 from nemo_run.core.packaging.git import GitArchivePackager
@@ -461,6 +464,24 @@ def cancel(self, job_id: str):
                 response.text,
             )
 
+    def _setup_launcher(self):
+        super()._setup_launcher()
+        launcher = self.launcher
+        if launcher and isinstance(launcher, (FaultTolerance, Torchrun)):
+            self.torchrun_nproc_per_node = self.nprocs_per_node
+            self.ntasks_per_node = 1
+            CONSOLE.log(
+                f"Detected {launcher.__class__.__name__} launcher, setting ntasks_per_node=1 and torchrun_nproc_per_node={self.torchrun_nproc_per_node}"
+            )
+
+        if launcher and isinstance(launcher, FaultTolerance):
+            base_dir = os.path.join(self.job_dir, Path(self.job_dir).name)
+            launcher.cfg_path = os.path.join(base_dir, f"{self.job_name}_ft_cfg.yml")
+            launcher.finished_flag_file = os.path.join(
+                "/", RUNDIR_NAME, f"{self.job_name}_finished_flag"
+            )
+            launcher.job_results_file = os.path.join(base_dir, f"{self.job_name}_job_results")
+
     def cleanup(self, handle: str): ...
 
     def assign(
@@ -556,3 +577,55 @@ def _default_headers(self, token: Optional[str] = None) -> dict:
         if token:
             headers["Authorization"] = f"Bearer {token}"
         return headers
+
+
+@dataclass(kw_only=True)
+class DGXCloudRequest:
+    launch_cmd: list[str]
+    jobs: list[str]
+    executor: DGXCloudExecutor
+    max_retries: int
+    extra_env: dict[str, str]
+    launcher: Optional[Launcher] = None
+
+    def materialize(self) -> str:
+        """Creates the content of a DGXC entrypoint script."""
+
+        # 1. Environment variables:
+        # combine executor defaults with the extra env vars
+        env_vars = []
+        full_env_vars = self.executor.env_vars | self.extra_env
+        for key, value in full_env_vars.items():
+            env_vars.append(f"export {key.upper()}={value}")
+
+        # 2. Prepare template variables
+        vars_to_fill = {
+            "max_retries": self.max_retries,
+            "env_vars": env_vars,
+            "training_command": " ".join(self.launch_cmd),
+            "ft_enabled": self.launcher and isinstance(self.launcher, FaultTolerance),
+        }
+
+        # 3. Fault tolerance injection
+        if self.launcher and isinstance(self.launcher, FaultTolerance):
+            assert (
+                self.launcher.cfg_path
+                and self.launcher.finished_flag_file
+                and self.launcher.job_results_file
+            ), "Fault Tolerance requires cfg_path, finished_flag_file, and job_results_file"
+
+            vars_to_fill["fault_tol_cfg_path"] = self.launcher.cfg_path
+            vars_to_fill["fault_tol_finished_flag_file"] = self.launcher.finished_flag_file
+            vars_to_fill["fault_tol_job_results_file"] = self.launcher.job_results_file
+
+        # 4. Render the template
+        entrypoint_script = fill_template("dgxc.sh.j2", vars_to_fill)
+        return entrypoint_script
+
+    def __repr__(self) -> str:
+        return f"""# DGXC Entrypoint Script Request
+# Executor: {self.executor.__class__.__name__}
+# Jobs: {self.jobs}
+# ---------------------------------------------------
+{self.materialize()}
+"""
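To make the path wiring in _setup_launcher concrete, here is a sketch of the values the launcher fields take for an illustrative job named "pretrain" whose job_dir is /nemo_run/experiments/pretrain, assuming RUNDIR_NAME resolves to "nemo_run" (both are assumptions, not values produced by this diff). Note that base_dir appends the job_dir's own basename, and that the finished-flag check in ft_launcher_dgxc.j2 below looks for the flag's basename next to the job results file rather than at finished_flag_file itself.

    # Illustrative values only; job_dir, job_name, and RUNDIR_NAME are assumptions.
    # base_dir = job_dir joined with its own basename:
    #   /nemo_run/experiments/pretrain/pretrain
    export FAULT_TOL_CFG_PATH="/nemo_run/experiments/pretrain/pretrain/pretrain_ft_cfg.yml"
    export FAULT_TOL_FINISHED_FLAG_FILE="/nemo_run/pretrain_finished_flag"
    JOB_RESULTS_FILE="/nemo_run/experiments/pretrain/pretrain/pretrain_job_results"
    # is_training_finished() in ft_launcher_dgxc.j2 then tests for:
    #   /nemo_run/experiments/pretrain/pretrain/pretrain_finished_flag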
diff --git a/nemo_run/core/execution/templates/dgxc.sh.j2 b/nemo_run/core/execution/templates/dgxc.sh.j2
new file mode 100644
index 00000000..75bdede2
--- /dev/null
+++ b/nemo_run/core/execution/templates/dgxc.sh.j2
@@ -0,0 +1,31 @@
+{%- import "ft_launcher_dgxc.j2" as fault_tolerance -%}
+#!/bin/bash
+
+set -evx  # Trace commands; exit-on-error stays on until we disable it around the training command below
+export PYTHONUNBUFFERED=1
+export TORCHX_MAX_RETRIES={{max_retries}}
+
+{%- for env_var in env_vars %}
+{{env_var}}
+{%- endfor %}
+
+{%- if ft_enabled %}
+{{ fault_tolerance.ft_launcher_setup(fault_tol_cfg_path, fault_tol_finished_flag_file, fault_tol_job_results_file) }}
+{%- endif %}
+
+echo "Starting training command..."
+set +e  # Turn off auto-exit so we can capture the exit code
+
+{{ training_command }}
+
+exitcode=$?
+set -e
+
+echo "Main command exited with code $exitcode"
+
+{%- if ft_enabled %}
+{{ fault_tolerance.ft_launcher_teardown() }}
+{%- else %}
+
+exit $exitcode
+{%- endif %}
diff --git a/nemo_run/core/execution/templates/ft_launcher_dgxc.j2 b/nemo_run/core/execution/templates/ft_launcher_dgxc.j2
new file mode 100644
index 00000000..150d8b0c
--- /dev/null
+++ b/nemo_run/core/execution/templates/ft_launcher_dgxc.j2
@@ -0,0 +1,24 @@
+{% macro ft_launcher_setup(fault_tol_cfg_path, fault_tol_finished_flag_file, fault_tol_job_results_file) -%}
+# This script uses the experimental fault tolerance launcher
+# Fault tolerance related items
+export FAULT_TOL_CFG_PATH="{{fault_tol_cfg_path}}"
+export FAULT_TOL_FINISHED_FLAG_FILE="{{fault_tol_finished_flag_file}}"
+
+JOB_RESULTS_FILE="{{fault_tol_job_results_file}}"
+
+is_training_finished() {
+    test -f "$(dirname "$JOB_RESULTS_FILE")/$(basename "$FAULT_TOL_FINISHED_FLAG_FILE")"
+}
+
+if is_training_finished ; then
+    echo "Training is finished"
+    exit 0
+else
+    rm -f "$FAULT_TOL_FINISHED_FLAG_FILE" "$JOB_RESULTS_FILE"
+fi
+
+{%- endmacro %}
+
+{% macro ft_launcher_teardown() -%}
+exit $exitcode
+{%- endmacro %}
diff --git a/nemo_run/core/execution/templates/ft_launcher.j2 b/nemo_run/core/execution/templates/ft_launcher_slurm.j2
similarity index 100%
rename from nemo_run/core/execution/templates/ft_launcher.j2
rename to nemo_run/core/execution/templates/ft_launcher_slurm.j2
diff --git a/nemo_run/core/execution/templates/slurm.sh.j2 b/nemo_run/core/execution/templates/slurm.sh.j2
index 26f756fa..dc2c93fa 100644
--- a/nemo_run/core/execution/templates/slurm.sh.j2
+++ b/nemo_run/core/execution/templates/slurm.sh.j2
@@ -1,4 +1,4 @@
-{%- import "ft_launcher.j2" as fault_tolerance -%}
+{%- import "ft_launcher_slurm.j2" as fault_tolerance -%}
 #!/bin/bash
 #
 # Generated by NeMo Run
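For orientation, here is roughly what dgxc.sh.j2 renders to for a job without fault tolerance. The retry count, the WANDB_PROJECT variable, and the torchrun command are illustrative placeholders, not values produced by this diff; exact blank-line placement depends on Jinja whitespace trimming.

    #!/bin/bash

    set -evx  # Trace commands; exit-on-error stays on until we disable it around the training command below
    export PYTHONUNBUFFERED=1
    export TORCHX_MAX_RETRIES=3
    export WANDB_PROJECT=demo

    echo "Starting training command..."
    set +e  # Turn off auto-exit so we can capture the exit code

    torchrun --nproc-per-node 8 train.py

    exitcode=$?
    set -e

    echo "Main command exited with code $exitcode"

    exit $exitcode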
diff --git a/nemo_run/run/torchx_backend/components/ft_launcher.py b/nemo_run/run/torchx_backend/components/ft_launcher.py
index 3920041f..2395465f 100644
--- a/nemo_run/run/torchx_backend/components/ft_launcher.py
+++ b/nemo_run/run/torchx_backend/components/ft_launcher.py
@@ -92,31 +92,34 @@ def ft_launcher(
     ):
         if workload_check_interval:
             ft_args += [
-                "--ft-param-workload_check_interval",
+                "--ft-workload_check_interval",
                 str(workload_check_interval),
             ]
         if initial_rank_heartbeat_timeout:
             ft_args += [
-                "--ft-param-initial_rank_heartbeat_timeout",
+                "--ft-initial_rank_heartbeat_timeout",
                 str(initial_rank_heartbeat_timeout),
             ]
         if rank_heartbeat_timeout:
             ft_args += [
-                "--ft-param-rank_heartbeat_timeout",
+                "--ft-rank_heartbeat_timeout",
                 str(rank_heartbeat_timeout),
             ]
         if rank_termination_signal:
-            ft_args += ["--ft-param-rank_termination_signal", rank_termination_signal]
+            ft_args += ["--ft-rank_termination_signal", rank_termination_signal]
         if log_level:
-            ft_args += ["--ft-param-log_level", log_level]
+            ft_args += ["--ft-log_level", log_level]
         if max_restarts:
             ft_args += ["--max-restarts", str(max_restarts)]
+        if dgxc:
+            ft_args += ["--ft-use-infra-group-rank", "False"]
+
     else:
         ft_args = ["--ignore-missing-fault-tol-cfg"]
diff --git a/nemo_run/run/torchx_backend/packaging.py b/nemo_run/run/torchx_backend/packaging.py
index 84b9dc4c..8a850de4 100644
--- a/nemo_run/run/torchx_backend/packaging.py
+++ b/nemo_run/run/torchx_backend/packaging.py
@@ -203,6 +203,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
             log_level=launcher.log_level,
             max_retries=executor.retries,
             max_restarts=launcher.max_restarts,
+            dgxc=isinstance(executor, DGXCloudExecutor),
             use_env=use_env,
         )
     else:
diff --git a/nemo_run/run/torchx_backend/schedulers/dgxcloud.py b/nemo_run/run/torchx_backend/schedulers/dgxcloud.py
index 4377ec71..b786d3c0 100644
--- a/nemo_run/run/torchx_backend/schedulers/dgxcloud.py
+++ b/nemo_run/run/torchx_backend/schedulers/dgxcloud.py
@@ -37,7 +37,7 @@
 from nemo_run.config import get_nemorun_home
 from nemo_run.core.execution.base import Executor
-from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudState
+from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudRequest, DGXCloudState
 from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
 from nemo_run.run.torchx_backend.schedulers.api import SchedulerMixin
@@ -109,6 +109,23 @@ def _submit_dryrun(  # type: ignore
         role = values.apply(role)
         cmd = [role.entrypoint] + role.args
+
+        req = DGXCloudRequest(
+            launch_cmd=cmd,
+            jobs=[role.name],
+            executor=executor,
+            max_retries=role.max_retries,
+            extra_env=role.env,
+            launcher=executor.get_launcher(),
+        )
+
+        # Write the generated entrypoint script
+        path = os.path.join(executor.experiment_dir, "torchrun_job.sh")
+        script = req.materialize()
+
+        with open(path, "w") as f:
+            f.write(script)
+
         return AppDryRunInfo(
             DGXRequest(app=app, executor=executor, cmd=cmd, name=role.name),
             # Minimal function to show the config, if any
@@ -128,7 +145,9 @@ def schedule(self, dryrun_info: AppDryRunInfo[DGXRequest]) -> str:
 
         # The DGXExecutor's launch call typically returns (job_id, handle).
-        # We'll call it without additional parameters here.
-        job_id, status = executor.launch(name=req.name, cmd=req.cmd)
+        # Point it at the generated entrypoint script instead of the raw command.
+        cmd = os.path.join(executor.experiment_dir, "torchrun_job.sh")
+        req.launch_cmd = ["bash", cmd]
+        job_id, status = executor.launch(name=req.name, cmd=req.launch_cmd)
 
         if not job_id:
             raise RuntimeError("Failed scheduling run on DGX: no job_id returned")
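With the renamed flags and the new dgxc branch, the argument list assembled in ft_launcher.py would yield an invocation along these lines for a DGX Cloud job. The timeout, signal, log-level, and restart values are illustrative assumptions, and the trailing arguments (elided as "...") come from parts of the component this diff does not touch:

    ft_launcher \
        --ft-workload_check_interval 300 \
        --ft-initial_rank_heartbeat_timeout 1800 \
        --ft-rank_heartbeat_timeout 300 \
        --ft-rank_termination_signal SIGKILL \
        --ft-log_level INFO \
        --max-restarts 3 \
        --ft-use-infra-group-rank False \
        ...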