From 85e5eabc66fca1fd0f8b140a3c94fc73e5825c0e Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 3 Dec 2025 09:55:40 -0800 Subject: [PATCH 1/2] renamed sagemaker training drivers folder --- .../{aii3xf60 => sagemaker_training_drivers}/__init__.py | 0 .../{aii3xf60 => sagemaker_training_drivers}/common/__init__.py | 0 .../{aii3xf60 => sagemaker_training_drivers}/common/utils.py | 0 .../{aii3xf60 => sagemaker_training_drivers}/distributed.json | 0 .../distributed_drivers/__init__.py | 0 .../distributed_drivers/basic_script_driver.py | 0 .../distributed_drivers/mpi_driver.py | 0 .../distributed_drivers/mpi_utils.py | 0 .../distributed_drivers/torchrun_driver.py | 0 .../{aii3xf60 => sagemaker_training_drivers}/scripts/__init__.py | 0 .../scripts/environment.py | 0 .../{aii3xf60 => sagemaker_training_drivers}/sm_train.sh | 0 .../{aii3xf60 => sagemaker_training_drivers}/sourcecode.json | 0 13 files changed, 0 insertions(+), 0 deletions(-) rename v3-examples/training-examples/{aii3xf60 => sagemaker_training_drivers}/__init__.py (100%) rename v3-examples/training-examples/{aii3xf60 => sagemaker_training_drivers}/common/__init__.py (100%) rename v3-examples/training-examples/{aii3xf60 => sagemaker_training_drivers}/common/utils.py (100%) rename v3-examples/training-examples/{aii3xf60 => sagemaker_training_drivers}/distributed.json (100%) rename v3-examples/training-examples/{aii3xf60 => sagemaker_training_drivers}/distributed_drivers/__init__.py (100%) rename v3-examples/training-examples/{aii3xf60 => sagemaker_training_drivers}/distributed_drivers/basic_script_driver.py (100%) rename v3-examples/training-examples/{aii3xf60 => sagemaker_training_drivers}/distributed_drivers/mpi_driver.py (100%) rename v3-examples/training-examples/{aii3xf60 => sagemaker_training_drivers}/distributed_drivers/mpi_utils.py (100%) rename v3-examples/training-examples/{aii3xf60 => sagemaker_training_drivers}/distributed_drivers/torchrun_driver.py (100%) rename v3-examples/training-examples/{aii3xf60 => sagemaker_training_drivers}/scripts/__init__.py (100%) rename v3-examples/training-examples/{aii3xf60 => sagemaker_training_drivers}/scripts/environment.py (100%) rename v3-examples/training-examples/{aii3xf60 => sagemaker_training_drivers}/sm_train.sh (100%) rename v3-examples/training-examples/{aii3xf60 => sagemaker_training_drivers}/sourcecode.json (100%) diff --git a/v3-examples/training-examples/aii3xf60/__init__.py b/v3-examples/training-examples/sagemaker_training_drivers/__init__.py similarity index 100% rename from v3-examples/training-examples/aii3xf60/__init__.py rename to v3-examples/training-examples/sagemaker_training_drivers/__init__.py diff --git a/v3-examples/training-examples/aii3xf60/common/__init__.py b/v3-examples/training-examples/sagemaker_training_drivers/common/__init__.py similarity index 100% rename from v3-examples/training-examples/aii3xf60/common/__init__.py rename to v3-examples/training-examples/sagemaker_training_drivers/common/__init__.py diff --git a/v3-examples/training-examples/aii3xf60/common/utils.py b/v3-examples/training-examples/sagemaker_training_drivers/common/utils.py similarity index 100% rename from v3-examples/training-examples/aii3xf60/common/utils.py rename to v3-examples/training-examples/sagemaker_training_drivers/common/utils.py diff --git a/v3-examples/training-examples/aii3xf60/distributed.json b/v3-examples/training-examples/sagemaker_training_drivers/distributed.json similarity index 100% rename from 
v3-examples/training-examples/aii3xf60/distributed.json rename to v3-examples/training-examples/sagemaker_training_drivers/distributed.json diff --git a/v3-examples/training-examples/aii3xf60/distributed_drivers/__init__.py b/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/__init__.py similarity index 100% rename from v3-examples/training-examples/aii3xf60/distributed_drivers/__init__.py rename to v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/__init__.py diff --git a/v3-examples/training-examples/aii3xf60/distributed_drivers/basic_script_driver.py b/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/basic_script_driver.py similarity index 100% rename from v3-examples/training-examples/aii3xf60/distributed_drivers/basic_script_driver.py rename to v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/basic_script_driver.py diff --git a/v3-examples/training-examples/aii3xf60/distributed_drivers/mpi_driver.py b/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/mpi_driver.py similarity index 100% rename from v3-examples/training-examples/aii3xf60/distributed_drivers/mpi_driver.py rename to v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/mpi_driver.py diff --git a/v3-examples/training-examples/aii3xf60/distributed_drivers/mpi_utils.py b/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/mpi_utils.py similarity index 100% rename from v3-examples/training-examples/aii3xf60/distributed_drivers/mpi_utils.py rename to v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/mpi_utils.py diff --git a/v3-examples/training-examples/aii3xf60/distributed_drivers/torchrun_driver.py b/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/torchrun_driver.py similarity index 100% rename from v3-examples/training-examples/aii3xf60/distributed_drivers/torchrun_driver.py rename to v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/torchrun_driver.py diff --git a/v3-examples/training-examples/aii3xf60/scripts/__init__.py b/v3-examples/training-examples/sagemaker_training_drivers/scripts/__init__.py similarity index 100% rename from v3-examples/training-examples/aii3xf60/scripts/__init__.py rename to v3-examples/training-examples/sagemaker_training_drivers/scripts/__init__.py diff --git a/v3-examples/training-examples/aii3xf60/scripts/environment.py b/v3-examples/training-examples/sagemaker_training_drivers/scripts/environment.py similarity index 100% rename from v3-examples/training-examples/aii3xf60/scripts/environment.py rename to v3-examples/training-examples/sagemaker_training_drivers/scripts/environment.py diff --git a/v3-examples/training-examples/aii3xf60/sm_train.sh b/v3-examples/training-examples/sagemaker_training_drivers/sm_train.sh similarity index 100% rename from v3-examples/training-examples/aii3xf60/sm_train.sh rename to v3-examples/training-examples/sagemaker_training_drivers/sm_train.sh diff --git a/v3-examples/training-examples/aii3xf60/sourcecode.json b/v3-examples/training-examples/sagemaker_training_drivers/sourcecode.json similarity index 100% rename from v3-examples/training-examples/aii3xf60/sourcecode.json rename to v3-examples/training-examples/sagemaker_training_drivers/sourcecode.json From d7b9dfbd96bffa09994f12d40040944e9ace1345 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 3 
Dec 2025 10:04:33 -0800 Subject: [PATCH 2/2] Removing folder --- .../sagemaker_training_drivers/__init__.py | 14 - .../common/__init__.py | 14 - .../common/utils.py | 205 ------------ .../distributed.json | 1 - .../distributed_drivers/__init__.py | 14 - .../basic_script_driver.py | 81 ----- .../distributed_drivers/mpi_driver.py | 105 ------ .../distributed_drivers/mpi_utils.py | 302 ----------------- .../distributed_drivers/torchrun_driver.py | 129 -------- .../scripts/__init__.py | 14 - .../scripts/environment.py | 305 ------------------ .../sagemaker_training_drivers/sm_train.sh | 59 ---- .../sourcecode.json | 1 - 13 files changed, 1244 deletions(-) delete mode 100644 v3-examples/training-examples/sagemaker_training_drivers/__init__.py delete mode 100644 v3-examples/training-examples/sagemaker_training_drivers/common/__init__.py delete mode 100644 v3-examples/training-examples/sagemaker_training_drivers/common/utils.py delete mode 100644 v3-examples/training-examples/sagemaker_training_drivers/distributed.json delete mode 100644 v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/__init__.py delete mode 100644 v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/basic_script_driver.py delete mode 100644 v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/mpi_driver.py delete mode 100644 v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/mpi_utils.py delete mode 100644 v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/torchrun_driver.py delete mode 100644 v3-examples/training-examples/sagemaker_training_drivers/scripts/__init__.py delete mode 100644 v3-examples/training-examples/sagemaker_training_drivers/scripts/environment.py delete mode 100644 v3-examples/training-examples/sagemaker_training_drivers/sm_train.sh delete mode 100644 v3-examples/training-examples/sagemaker_training_drivers/sourcecode.json diff --git a/v3-examples/training-examples/sagemaker_training_drivers/__init__.py b/v3-examples/training-examples/sagemaker_training_drivers/__init__.py deleted file mode 100644 index 864f3663b8..0000000000 --- a/v3-examples/training-examples/sagemaker_training_drivers/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container drivers directory.""" -from __future__ import absolute_import diff --git a/v3-examples/training-examples/sagemaker_training_drivers/common/__init__.py b/v3-examples/training-examples/sagemaker_training_drivers/common/__init__.py deleted file mode 100644 index aab88c6b97..0000000000 --- a/v3-examples/training-examples/sagemaker_training_drivers/common/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. 
A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container drivers - common directory.""" -from __future__ import absolute_import diff --git a/v3-examples/training-examples/sagemaker_training_drivers/common/utils.py b/v3-examples/training-examples/sagemaker_training_drivers/common/utils.py deleted file mode 100644 index 03146a3bbe..0000000000 --- a/v3-examples/training-examples/sagemaker_training_drivers/common/utils.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module provides utility functions for the container drivers.""" -from __future__ import absolute_import - -import os -import logging -import sys -import subprocess -import traceback -import json - -from typing import List, Dict, Any, Tuple, IO, Optional - -# Initialize logger -SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) -logger = logging.getLogger(__name__) -console_handler = logging.StreamHandler(sys.stdout) -logger.addHandler(console_handler) -logger.setLevel(int(SM_LOG_LEVEL)) - -FAILURE_FILE = "/opt/ml/output/failure" -DEFAULT_FAILURE_MESSAGE = """ -Training Execution failed. -For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'. 
-TrainingJob - {training_job_name} -""" - -USER_CODE_PATH = "/opt/ml/input/data/code" -SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json" -DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json" - -HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json" - -SM_EFA_NCCL_INSTANCES = [ - "ml.g4dn.8xlarge", - "ml.g4dn.12xlarge", - "ml.g5.48xlarge", - "ml.p3dn.24xlarge", - "ml.p4d.24xlarge", - "ml.p4de.24xlarge", - "ml.p5.48xlarge", - "ml.trn1.32xlarge", -] - -SM_EFA_RDMA_INSTANCES = [ - "ml.p4d.24xlarge", - "ml.p4de.24xlarge", - "ml.trn1.32xlarge", -] - - -def write_failure_file(message: Optional[str] = None): - """Write a failure file with the message.""" - if message is None: - message = DEFAULT_FAILURE_MESSAGE.format(training_job_name=os.environ["TRAINING_JOB_NAME"]) - if not os.path.exists(FAILURE_FILE): - with open(FAILURE_FILE, "w") as f: - f.write(message) - - -def read_source_code_json(source_code_json: Dict[str, Any] = SOURCE_CODE_JSON): - """Read the source code config json file.""" - try: - with open(source_code_json, "r") as f: - source_code_dict = json.load(f) or {} - except FileNotFoundError: - source_code_dict = {} - return source_code_dict - - -def read_distributed_json(distributed_json: Dict[str, Any] = DISTRIBUTED_JSON): - """Read the distribution config json file.""" - try: - with open(distributed_json, "r") as f: - distributed_dict = json.load(f) or {} - except FileNotFoundError: - distributed_dict = {} - return distributed_dict - - -def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAMETERS_JSON): - """Read the hyperparameters config json file.""" - try: - with open(hyperparameters_json, "r") as f: - hyperparameters_dict = json.load(f) or {} - except FileNotFoundError: - hyperparameters_dict = {} - return hyperparameters_dict - - -def get_process_count(process_count: Optional[int] = None) -> int: - """Get the number of processes to run on each node in the training job.""" - return ( - process_count - or int(os.environ.get("SM_NUM_GPUS", 0)) - or int(os.environ.get("SM_NUM_NEURONS", 0)) - or 1 - ) - - -def hyperparameters_to_cli_args(hyperparameters: Dict[str, Any]) -> List[str]: - """Convert the hyperparameters to CLI arguments.""" - cli_args = [] - for key, value in hyperparameters.items(): - value = safe_deserialize(value) - cli_args.extend([f"--{key}", safe_serialize(value)]) - - return cli_args - - -def safe_deserialize(data: Any) -> Any: - """Safely deserialize data from a JSON string. - - This function handles the following cases: - 1. If `data` is not a string, it returns the input as-is. - 2. If `data` is a JSON-encoded string, it attempts to deserialize it using `json.loads()`. - 3. If `data` is a string but cannot be decoded as JSON, it returns the original string. - - Returns: - Any: The deserialized data, or the original input if it cannot be JSON-decoded. - """ - if not isinstance(data, str): - return data - - try: - return json.loads(data) - except json.JSONDecodeError: - return data - - -def safe_serialize(data): - """Serialize the data without wrapping strings in quotes. - - This function handles the following cases: - 1. If `data` is a string, it returns the string as-is without wrapping in quotes. - 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns - the JSON-encoded string using `json.dumps()`. - 3. If `data` cannot be serialized (e.g., a custom object), it returns the string - representation of the data using `str(data)`. 
- - Args: - data (Any): The data to serialize. - - Returns: - str: The serialized JSON-compatible string or the string representation of the input. - """ - if isinstance(data, str): - return data - try: - return json.dumps(data) - except TypeError: - return str(data) - - -def get_python_executable() -> str: - """Get the python executable path.""" - return sys.executable - - -def log_subprocess_output(pipe: IO[bytes]): - """Log the output from the subprocess.""" - for line in iter(pipe.readline, b""): - logger.info(line.decode("utf-8").strip()) - - -def execute_commands(commands: List[str]) -> Tuple[int, str]: - """Execute the provided commands and return exit code with failure traceback if any.""" - try: - process = subprocess.Popen( - commands, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) - with process.stdout: - log_subprocess_output(process.stdout) - exitcode = process.wait() - if exitcode != 0: - raise subprocess.CalledProcessError(exitcode, commands) - return exitcode, "" - except subprocess.CalledProcessError as e: - # Capture the traceback in case of failure - error_traceback = traceback.format_exc() - print(f"Command failed with exit code {e.returncode}. Traceback: {error_traceback}") - return e.returncode, error_traceback - - -def is_worker_node() -> bool: - """Check if the current node is a worker node.""" - return os.environ.get("SM_CURRENT_HOST") != os.environ.get("SM_MASTER_ADDR") - - -def is_master_node() -> bool: - """Check if the current node is the master node.""" - return os.environ.get("SM_CURRENT_HOST") == os.environ.get("SM_MASTER_ADDR") diff --git a/v3-examples/training-examples/sagemaker_training_drivers/distributed.json b/v3-examples/training-examples/sagemaker_training_drivers/distributed.json deleted file mode 100644 index 9e26dfeeb6..0000000000 --- a/v3-examples/training-examples/sagemaker_training_drivers/distributed.json +++ /dev/null @@ -1 +0,0 @@ -{} \ No newline at end of file diff --git a/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/__init__.py b/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/__init__.py deleted file mode 100644 index a44e7e81a9..0000000000 --- a/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container drivers - drivers directory.""" -from __future__ import absolute_import diff --git a/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/basic_script_driver.py b/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/basic_script_driver.py deleted file mode 100644 index a298da80a2..0000000000 --- a/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/basic_script_driver.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). 
You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module is the entry point for the Basic Script Driver.""" -from __future__ import absolute_import - -import os -import sys -import json -import shlex - -from pathlib import Path -from typing import List - -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 - logger, - get_python_executable, - write_failure_file, - hyperparameters_to_cli_args, - execute_commands, -) - - -def create_commands() -> List[str]: - """Create the commands to execute.""" - entry_script = os.environ["SM_ENTRY_SCRIPT"] - hyperparameters = json.loads(os.environ["SM_HPS"]) - python_executable = get_python_executable() - - args = hyperparameters_to_cli_args(hyperparameters) - if entry_script.endswith(".py"): - commands = [python_executable, entry_script] - commands += args - elif entry_script.endswith(".sh"): - args_str = " ".join(shlex.quote(arg) for arg in args) - commands = [ - "/bin/sh", - "-c", - f"chmod +x {entry_script} && ./{entry_script} {args_str}", - ] - else: - raise ValueError( - f"Unsupported entry script type: {entry_script}. Only .py and .sh are supported." - ) - return commands - - -def main(): - """Main function for the Basic Script Driver. - - This function is the entry point for the Basic Script Driver. - - Execution Lifecycle: - 1. Read the source code and hyperparameters JSON files. - 2. Set hyperparameters as command line arguments. - 3. Create the commands to execute. - 4. Execute the commands. - """ - - cmd = create_commands() - - logger.info(f"Executing command: {' '.join(cmd)}") - exit_code, traceback = execute_commands(cmd) - if exit_code != 0: - write_failure_file(traceback) - sys.exit(exit_code) - - -if __name__ == "__main__": - main() diff --git a/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/mpi_driver.py b/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/mpi_driver.py deleted file mode 100644 index 8ffe1f4318..0000000000 --- a/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/mpi_driver.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
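# --- Editor's note (illustrative sketch, not part of the original patch) ---
# The deleted basic_script_driver.py above builds its command line from the
# SM_ENTRY_SCRIPT and SM_HPS environment variables, round-tripping each
# hyperparameter through the safe_deserialize/safe_serialize helpers from the
# deleted common/utils.py. A minimal, self-contained sketch of that
# conversion, with hypothetical hyperparameter values:

import json
from typing import Any, Dict, List

def safe_deserialize(data: Any) -> Any:
    # JSON-decode strings when possible; pass everything else through.
    if not isinstance(data, str):
        return data
    try:
        return json.loads(data)
    except json.JSONDecodeError:
        return data

def safe_serialize(data: Any) -> str:
    # Emit strings bare (no surrounding quotes); JSON-encode anything else.
    if isinstance(data, str):
        return data
    try:
        return json.dumps(data)
    except TypeError:
        return str(data)

def hyperparameters_to_cli_args(hyperparameters: Dict[str, Any]) -> List[str]:
    cli_args: List[str] = []
    for key, value in hyperparameters.items():
        cli_args.extend([f"--{key}", safe_serialize(safe_deserialize(value))])
    return cli_args

# SageMaker delivers hyperparameters as strings; the round trip keeps numbers
# and JSON structures intact while leaving plain strings unquoted.
print(hyperparameters_to_cli_args({"epochs": "10", "optimizer": "adam"}))
# -> ['--epochs', '10', '--optimizer', 'adam']
# --- End editor's note ------------------------------------------------------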
-"""This module is the entry point for the MPI driver script.""" -from __future__ import absolute_import - -import os -import sys -import json - -from sagemaker.train.container_drivers.distributed_drivers.mpi_utils import ( - start_sshd_daemon, - bootstrap_master_node, - bootstrap_worker_node, - get_mpirun_command, - write_status_file_to_workers, - write_env_vars_to_file, -) - - -from sagemaker.train.container_drivers.common.utils import ( - logger, - hyperparameters_to_cli_args, - get_process_count, - execute_commands, - write_failure_file, -) - - -def main(): - """Main function for the MPI driver script. - - The MPI Driver is responsible for setting up the MPI environment, - generating the correct MPI commands, and launching the MPI job. - - Execution Lifecycle: - 1. Setup General Environment Variables at /etc/environment - 2. Start SSHD Daemon - 3. Bootstrap Worker Nodes - a. Wait to establish connection with Master Node - b. Wait for Master Node to write status file - 4. Bootstrap Master Node - a. Wait to establish connection with Worker Nodes - b. Generate MPI Command - c. Execute MPI Command with user script provided in `entry_script` - d. Write status file to Worker Nodes - 5. Exit - - """ - entry_script = os.environ["SM_ENTRY_SCRIPT"] - distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) - hyperparameters = json.loads(os.environ["SM_HPS"]) - - sm_current_host = os.environ["SM_CURRENT_HOST"] - sm_hosts = json.loads(os.environ["SM_HOSTS"]) - sm_master_addr = os.environ["SM_MASTER_ADDR"] - - write_env_vars_to_file() - start_sshd_daemon() - - if sm_current_host != sm_master_addr: - bootstrap_worker_node(sm_master_addr) - else: - worker_hosts = [host for host in sm_hosts if host != sm_master_addr] - bootstrap_master_node(worker_hosts) - - host_list = json.loads(os.environ["SM_HOSTS"]) - host_count = int(os.environ["SM_HOST_COUNT"]) - process_count = int(distributed_config["process_count_per_node"] or 0) - process_count = get_process_count(process_count) - - if process_count > 1: - host_list = ["{}:{}".format(host, process_count) for host in host_list] - - mpi_command = get_mpirun_command( - host_count=host_count, - host_list=host_list, - num_processes=process_count, - additional_options=distributed_config["mpi_additional_options"] or [], - entry_script_path=entry_script, - ) - - args = hyperparameters_to_cli_args(hyperparameters) - mpi_command += args - - logger.info(f"Executing command: {' '.join(mpi_command)}") - exit_code, error_traceback = execute_commands(mpi_command) - write_status_file_to_workers(worker_hosts) - - if exit_code != 0: - write_failure_file(error_traceback) - sys.exit(exit_code) - - -if __name__ == "__main__": - main() diff --git a/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/mpi_utils.py b/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/mpi_utils.py deleted file mode 100644 index ec9e1fcef9..0000000000 --- a/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/mpi_utils.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied.
See the License for the specific -# language governing permissions and limitations under the License. -"""This module provides mpi related utility functions for the container drivers.""" -from __future__ import absolute_import - -import os -import sys -import subprocess -import time - -from pathlib import Path -from typing import List - -import paramiko - -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 - SM_EFA_NCCL_INSTANCES, - SM_EFA_RDMA_INSTANCES, - get_python_executable, - logger, -) - -FINISHED_STATUS_FILE = "/tmp/done.algo-1" -READY_FILE = "/tmp/ready.%s" -DEFAULT_SSH_PORT = 22 - - -def _write_file_to_host(host: str, status_file: str) -> bool: - """Write the given file to the provided host.""" - try: - logger.info(f"Writing {status_file} to {host}") - subprocess.run( - ["ssh", host, "touch", f"{status_file}"], - capture_output=True, - text=True, - check=True, - ) - logger.info("Finished writing status file") - return True - except subprocess.CalledProcessError: - logger.info(f"Cannot connect to {host}") - return False - - -def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE): - """Write the status file to all worker nodes.""" - for worker in worker_hosts: - retry = 0 - while not _write_file_to_host(worker, status_file): - time.sleep(5) - retry += 1 - if retry > 5: - raise TimeoutError(f"Timed out waiting for {worker} to be reachable.") - logger.info(f"Retrying to write status file to {worker}") - - -def _wait_for_status_file(status_file: str): - """Wait for the status file to be created.""" - logger.info(f"Waiting for status file {status_file}") - while not os.path.exists(status_file): - time.sleep(30) - logger.info(f"Found status file {status_file}") - - -def start_sshd_daemon(): - """Start the SSH daemon on the current node.""" - sshd_executable = "/usr/sbin/sshd" - - if not os.path.exists(sshd_executable): - raise RuntimeError("SSH daemon not found.") - - # Start the sshd in daemon mode (-D) - subprocess.Popen([sshd_executable, "-D"]) - logger.info("Started SSH daemon.") - - -class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy): - """Class to handle host key policy for SageMaker distributed training SSH connections. - - Example: - >>> client = paramiko.SSHClient() - >>> client.set_missing_host_key_policy(CustomHostKeyPolicy()) - >>> # Will succeed for SageMaker algorithm containers - >>> client.connect('algo-1234.internal') - >>> # Will raise SSHException for other unknown hosts - >>> client.connect('unknown-host') # raises SSHException - """ - - def missing_host_key(self, client, hostname, key): - """Accept host keys for algo-* hostnames, reject others.
- - Args: - client: The SSHClient instance - hostname: The hostname attempting to connect - key: The host key - - Raises: - paramiko.SSHException: If hostname doesn't match algo-* pattern - """ - if hostname.startswith("algo-"): - client.get_host_keys().add(hostname, key.get_name(), key) - return - raise paramiko.SSHException(f"Unknown host key for {hostname}") - - -def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool: - """Check if the connection to the provided host and port is possible.""" - try: - logger.debug("Testing connection to host %s", host) - with paramiko.SSHClient() as client: - client.load_system_host_keys() - client.set_missing_host_key_policy(CustomHostKeyPolicy()) - client.connect(host, port=port) - logger.info("Can connect to host %s", host) - return True - except Exception as e: # pylint: disable=W0703 - logger.info("Cannot connect to host %s", host) - logger.debug(f"Connection failed with exception: {e}") - return False - - -def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300): - """Master node waits until it can connect to all worker nodes.""" - start_time = time.time() - if not worker_hosts: - logger.info("No worker nodes to connect to.") - return - - while True: - logger.info("Master is attempting to connect to all workers...") - all_workers_connected = all( - _can_connect(worker, port) and os.path.exists(READY_FILE % worker) - for worker in worker_hosts - ) - - if all_workers_connected: - logger.info("Master can connect to all worker nodes.") - break - if time.time() - start_time > timeout: - raise TimeoutError("Timed out waiting for workers to be reachable.") - - time.sleep(5) # Wait for 5 seconds before trying again - - -def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300): - """Worker nodes wait until they can connect to the master node.""" - start_time = time.time() - while True: - logger.info(f"Worker is attempting to connect to the master node {master_host}...") - if _can_connect(master_host, port): - logger.info(f"Worker can connect to master node {master_host}.") - break - if time.time() - start_time > timeout: - raise TimeoutError(f"Timed out waiting for master {master_host} to be reachable.") - - time.sleep(5) # Wait for 5 seconds before trying again - - -def bootstrap_worker_node(master_host: str, status_file: str = FINISHED_STATUS_FILE): - """Bootstrap the worker nodes.""" - logger.info("Bootstrapping worker node...") - _wait_for_master(master_host) - _write_file_to_host(master_host, READY_FILE % os.environ["SM_CURRENT_HOST"]) - _wait_for_status_file(status_file) - - -def bootstrap_master_node(worker_hosts: List[str]): - """Bootstrap the master node.""" - logger.info("Bootstrapping master node...") - _wait_for_workers(worker_hosts) - - -def validate_smddprun() -> bool: - """Whether smddprun is installed. - - Returns: - bool: True if installed - """ - try: - output = subprocess.run( - ["which", "smddprun"], - capture_output=True, - text=True, - check=True, - ) - return output.stdout != "" - except subprocess.CalledProcessError: - return False - - -def validate_smddpmprun() -> bool: - """Whether smddpmprun is installed. 
- - Returns: - bool: True if installed - """ - try: - output = subprocess.run( - ["which", "smddpmprun"], - capture_output=True, - text=True, - check=True, - ) - return output.stdout != "" - except subprocess.CalledProcessError: - return False - - -def write_env_vars_to_file(): - """Write environment variables to /etc/environment file.""" - with open("/etc/environment", "a", encoding="utf-8") as f: - for name in os.environ: - f.write(f"{name}={os.environ.get(name)}\n") - - -def get_mpirun_command( - host_count: int, - host_list: List[str], - num_processes: int, - additional_options: List[str], - entry_script_path: str, -): - """Fetch the mpirun command.""" - network_interface_name = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0") - - mpirun_command = [ - "mpirun", - "--host", - ",".join(host_list), - "-np", - str(num_processes), - "--allow-run-as-root", - "--tag-output", - "-mca", - "btl_tcp_if_include", - network_interface_name, - "-mca", - "oob_tcp_if_include", - network_interface_name, - "-mca", - "plm_rsh_no_tree_spawn", - "1", - "-mca", - "pml", - "ob1", - "-mca", - "btl", - "^openib", - "-mca", - "orte_abort_on_non_zero_status", - "1", - "-mca", - "btl_vader_single_copy_mechanism", - "none", - "-mca", - "plm_rsh_num_concurrent", - str(host_count), - "-x", - "NCCL_SOCKET_IFNAME=%s" % network_interface_name, - "-x", - "LD_LIBRARY_PATH", - "-x", - "PATH", - ] - - if additional_options: - mpirun_command.extend(additional_options) - - instance_type = os.environ["SM_CURRENT_INSTANCE_TYPE"] - # EFA settings - if instance_type in SM_EFA_NCCL_INSTANCES: - mpirun_command.extend(["-x", "FI_PROVIDER=efa"]) - # Use simple protocol to handle the out-of-order data delivery from EFA - mpirun_command.extend(["-x", "NCCL_PROTO=simple"]) - - if instance_type in SM_EFA_RDMA_INSTANCES: - # Use EFA's RDMA functionality for one-sided and two-sided transfer - mpirun_command.extend(["-x", "FI_EFA_USE_DEVICE_RDMA=1"]) - - for credential in [ - "AWS_ACCESS_KEY_ID", - "AWS_SECRET_ACCESS_KEY", - "AWS_SESSION_TOKEN", - ]: - if credential in os.environ: - mpirun_command.extend(["-x", credential]) - - mpirun_command.extend([get_python_executable()]) - mpirun_command.extend(["-m", "mpi4py", entry_script_path]) - return mpirun_command diff --git a/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/torchrun_driver.py b/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/torchrun_driver.py deleted file mode 100644 index 7fcfabe05d..0000000000 --- a/v3-examples/training-examples/sagemaker_training_drivers/distributed_drivers/torchrun_driver.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License.
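# --- Editor's note (illustrative sketch, not part of the original patch) ---
# Before calling get_mpirun_command above, the deleted mpi_driver.py expands
# each host entry with the per-node process count. A sketch of the resulting
# command shape, assuming a hypothetical two-node ml.p4d.24xlarge cluster
# with 8 processes per node:

hosts = ["algo-1", "algo-2"]
process_count = 8
host_list = [f"{host}:{process_count}" for host in hosts]
print(host_list)  # -> ['algo-1:8', 'algo-2:8']

# get_mpirun_command(host_count=2, host_list=host_list,
#                    num_processes=process_count, ...) then assembles, roughly:
#   mpirun --host algo-1:8,algo-2:8 -np 8 --allow-run-as-root --tag-output
#          -mca btl_tcp_if_include eth0 ... -x FI_PROVIDER=efa
#          -x NCCL_PROTO=simple -x FI_EFA_USE_DEVICE_RDMA=1
#          <python> -m mpi4py <entry_script> <hyperparameter args>
# --- End editor's note ------------------------------------------------------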
-"""This module is the entry point for the Torchrun driver script.""" -from __future__ import absolute_import - -import os -import sys -import json - -from pathlib import Path -from typing import List, Tuple - -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 - logger, - hyperparameters_to_cli_args, - get_process_count, - get_python_executable, - execute_commands, - write_failure_file, - SM_EFA_NCCL_INSTANCES, - SM_EFA_RDMA_INSTANCES, -) - - -def pytorch_version() -> Tuple[int, int]: - """Get the PyTorch version as a tuple of integers.""" - import torch - - return tuple(map(int, torch.__version__.split(".")[:2])) - - -def get_base_pytorch_command() -> List[str]: - """Get the base Torch Distributed launcher to execute""" - if pytorch_version() >= (1, 9): - return ["torchrun"] - return [f"{get_python_executable()}", "-m", "torch.distributed.launch"] - - -def setup_env(): - """Setup the environment variables for PyTorch distributed training""" - instance_type = os.environ["SM_CURRENT_INSTANCE_TYPE"] - network_interface_name = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0") - if instance_type in SM_EFA_NCCL_INSTANCES: - # Enable EFA use - os.environ["FI_PROVIDER"] = "efa" - if instance_type in SM_EFA_RDMA_INSTANCES: - # Use EFA's RDMA functionality for one-sided and two-sided transfer - os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1" - os.environ["RDMAV_FORK_SAFE"] = "1" - os.environ["NCCL_SOCKET_IFNAME"] = str(network_interface_name) - os.environ["NCCL_PROTO"] = "simple" - - -def create_commands(): - """Create the Torch Distributed command to execute""" - entry_script = os.environ["SM_ENTRY_SCRIPT"] - distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) - hyperparameters = json.loads(os.environ["SM_HPS"]) - - process_count = int(distributed_config["process_count_per_node"] or 0) - process_count = get_process_count(process_count) - host_count = int(os.environ["SM_HOST_COUNT"]) - - torch_cmd = [] - if os.environ.get("RUN_NEURON_PARALLEL_COMPILE") == "1": - torch_cmd.append("neuron_parallel_compile") - - torch_cmd.extend(get_base_pytorch_command()) - torch_cmd.extend( - [ - f"--nnodes={host_count}", - f"--nproc_per_node={process_count}", - ] - ) - - # If more than one node is used, add node rank information - if int(host_count) > 1: - torch_cmd.extend( - [ - f"--master_addr={os.environ['SM_MASTER_ADDR']}", - f"--master_port={os.environ['SM_MASTER_PORT']}", - f"--node_rank={os.environ['SM_CURRENT_HOST_RANK']}", - ] - ) - - torch_cmd.extend([entry_script]) - - args = hyperparameters_to_cli_args(hyperparameters) - torch_cmd += args - - return torch_cmd - - -def main(): - """Main function to execute the PyTorch distributed training script. - - This function sets some environment variables and executes the PyTorch - distributed training script. - - Execution Lifecycle: - 1. Setup Environment Variables for PyTorch Distributed Training - 2. Create Torch Distributed Command - 3. Execute Torch Distributed Command with user script provided in `entry_script` - 4. 
Exit - - """ - setup_env() - torch_cmd = create_commands() - logger.info(f"Executing command: {' '.join(torch_cmd)}") - exit_code, traceback = execute_commands(torch_cmd) - if exit_code != 0: - write_failure_file(traceback) - sys.exit(exit_code) - - -if __name__ == "__main__": - main() diff --git a/v3-examples/training-examples/sagemaker_training_drivers/scripts/__init__.py b/v3-examples/training-examples/sagemaker_training_drivers/scripts/__init__.py deleted file mode 100644 index f04c5b17a0..0000000000 --- a/v3-examples/training-examples/sagemaker_training_drivers/scripts/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container drivers - scripts directory.""" -from __future__ import absolute_import diff --git a/v3-examples/training-examples/sagemaker_training_drivers/scripts/environment.py b/v3-examples/training-examples/sagemaker_training_drivers/scripts/environment.py deleted file mode 100644 index 897b1f8af4..0000000000 --- a/v3-examples/training-examples/sagemaker_training_drivers/scripts/environment.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
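# --- Editor's note (illustrative sketch, not part of the original patch) ---
# For a multi-node job, the deleted create_commands() above emits a torchrun
# invocation like the following. Host names, port, rank, script name, and
# hyperparameters are hypothetical values:

host_count, process_count = 2, 8
torch_cmd = [
    "torchrun",
    f"--nnodes={host_count}",
    f"--nproc_per_node={process_count}",
    "--master_addr=algo-1",  # from SM_MASTER_ADDR
    "--master_port=7777",    # from SM_MASTER_PORT
    "--node_rank=0",         # from SM_CURRENT_HOST_RANK
    "train.py",              # from SM_ENTRY_SCRIPT (hypothetical name)
    "--epochs", "10",        # hyperparameters appended as CLI args
]
print(" ".join(torch_cmd))
# --- End editor's note ------------------------------------------------------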
-"""This module is used to define the environment variables for the training job container.""" -from __future__ import absolute_import - -from typing import Dict, Any -import multiprocessing -import subprocess -import json -import os -import sys -from pathlib import Path -import logging - -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 - safe_serialize, - safe_deserialize, - read_distributed_json, - read_source_code_json, -) - -# Initialize logger -SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) -logger = logging.getLogger(__name__) -console_handler = logging.StreamHandler(sys.stdout) -logger.addHandler(console_handler) -logger.setLevel(int(SM_LOG_LEVEL)) - -SM_MODEL_DIR = "/opt/ml/model" - -SM_INPUT_DIR = "/opt/ml/input" -SM_INPUT_DATA_DIR = "/opt/ml/input/data" -SM_INPUT_CONFIG_DIR = "/opt/ml/input/config" - -SM_OUTPUT_DIR = "/opt/ml/output" -SM_OUTPUT_FAILURE = "/opt/ml/output/failure" -SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" -SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code" -SM_DISTRIBUTED_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/distributed_drivers" - -SM_MASTER_ADDR = "algo-1" -SM_MASTER_PORT = 7777 - -RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json" -INPUT_DATA_CONFIG = f"{SM_INPUT_CONFIG_DIR}/inputdataconfig.json" -HYPERPARAMETERS_CONFIG = f"{SM_INPUT_CONFIG_DIR}/hyperparameters.json" - -ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env" - -SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"] -HIDDEN_VALUE = "******" - - -def num_cpus() -> int: - """Return the number of CPUs available in the current container. - - Returns: - int: Number of CPUs available in the current container. - """ - return multiprocessing.cpu_count() - - -def num_gpus() -> int: - """Return the number of GPUs available in the current container. - - Returns: - int: Number of GPUs available in the current container. - """ - try: - cmd = ["nvidia-smi", "--list-gpus"] - output = subprocess.check_output(cmd).decode("utf-8") - return sum(1 for line in output.splitlines() if line.startswith("GPU ")) - except (OSError, subprocess.CalledProcessError): - logger.info("No GPUs detected (normal if no gpus installed)") - return 0 - - -def num_neurons() -> int: - """Return the number of neuron cores available in the current container. - - Returns: - int: Number of Neuron Cores available in the current container. - """ - try: - cmd = ["neuron-ls", "-j"] - output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8") - j = json.loads(output) - neuron_cores = 0 - for item in j: - neuron_cores += item.get("nc_count", 0) - logger.info("Found %s neurons on this instance", neuron_cores) - return neuron_cores - except OSError: - logger.info("No Neurons detected (normal if no neurons installed)") - return 0 - except subprocess.CalledProcessError as e: - if e.output is not None: - try: - msg = e.output.decode("utf-8").partition("error=")[2] - logger.info( - "No Neurons detected (normal if no neurons installed). \ - If neuron installed then %s", - msg, - ) - except AttributeError: - logger.info("No Neurons detected (normal if no neurons installed)") - else: - logger.info("No Neurons detected (normal if no neurons installed)") - - return 0 - - -def deserialize_hyperparameters(hyperparameters: Dict[str, str]) -> Dict[str, Any]: - """Deserialize hyperparameters from string to their original types. - - Args: - hyperparameters (Dict[str, str]): Hyperparameters as strings. 
- - Returns: - Dict[str, Any]: Hyperparameters as their original types. - """ - deserialized_hyperparameters = {} - for key, value in hyperparameters.items(): - deserialized_hyperparameters[key] = safe_deserialize(value) - return deserialized_hyperparameters - - -def set_env( - resource_config: Dict[str, Any], - input_data_config: Dict[str, Any], - hyperparameters_config: Dict[str, Any], - output_file: str = ENV_OUTPUT_FILE, -): - """Set environment variables for the training job container. - - Args: - resource_config (Dict[str, Any]): Resource configuration for the training job. - input_data_config (Dict[str, Any]): Input data configuration for the training job. - hyperparameters_config (Dict[str, Any]): Hyperparameters configuration for the training job. - output_file (str): Output file to write the environment variables. - """ - # Constants - env_vars = { - "SM_MODEL_DIR": SM_MODEL_DIR, - "SM_INPUT_DIR": SM_INPUT_DIR, - "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR, - "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR, - "SM_OUTPUT_DIR": SM_OUTPUT_DIR, - "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE, - "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR, - "SM_LOG_LEVEL": SM_LOG_LEVEL, - "SM_MASTER_ADDR": SM_MASTER_ADDR, - "SM_MASTER_PORT": SM_MASTER_PORT, - } - - # SourceCode and DistributedConfig Environment Variables - source_code = read_source_code_json() - if source_code: - env_vars["SM_SOURCE_DIR"] = SM_SOURCE_DIR_PATH - env_vars["SM_ENTRY_SCRIPT"] = source_code.get("entry_script", "") - - distributed = read_distributed_json() - if distributed: - env_vars["SM_DISTRIBUTED_DRIVER_DIR"] = SM_DISTRIBUTED_DRIVER_DIR_PATH - env_vars["SM_DISTRIBUTED_CONFIG"] = distributed - - # Data Channels - channels = list(input_data_config.keys()) - for channel in channels: - env_vars[f"SM_CHANNEL_{channel.upper()}"] = f"{SM_INPUT_DATA_DIR}/{channel}" - env_vars["SM_CHANNELS"] = channels - - # Hyperparameters - hps = deserialize_hyperparameters(hyperparameters_config) - for key, value in hps.items(): - key_upper = key.replace("-", "_").upper() - env_vars[f"SM_HP_{key_upper}"] = value - env_vars["SM_HPS"] = hps - - # Host Variables - current_host = resource_config["current_host"] - current_instance_type = resource_config["current_instance_type"] - hosts = resource_config["hosts"] - sorted_hosts = sorted(hosts) - - env_vars["SM_CURRENT_HOST"] = current_host - env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type - env_vars["SM_HOSTS"] = sorted_hosts - env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"] - env_vars["SM_HOST_COUNT"] = len(sorted_hosts) - env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host) - - env_vars["SM_NUM_CPUS"] = num_cpus() - env_vars["SM_NUM_GPUS"] = num_gpus() - env_vars["SM_NUM_NEURONS"] = num_neurons() - - # Misc. 
- env_vars["SM_RESOURCE_CONFIG"] = resource_config - env_vars["SM_INPUT_DATA_CONFIG"] = input_data_config - - # All Training Environment Variables - env_vars["SM_TRAINING_ENV"] = { - "channel_input_dirs": { - channel: env_vars[f"SM_CHANNEL_{channel.upper()}"] for channel in channels - }, - "current_host": env_vars["SM_CURRENT_HOST"], - "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"], - "hosts": env_vars["SM_HOSTS"], - "master_addr": env_vars["SM_MASTER_ADDR"], - "master_port": env_vars["SM_MASTER_PORT"], - "hyperparameters": env_vars["SM_HPS"], - "input_data_config": input_data_config, - "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"], - "input_data_dir": env_vars["SM_INPUT_DATA_DIR"], - "input_dir": env_vars["SM_INPUT_DIR"], - "job_name": os.environ["TRAINING_JOB_NAME"], - "log_level": env_vars["SM_LOG_LEVEL"], - "model_dir": env_vars["SM_MODEL_DIR"], - "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"], - "num_cpus": env_vars["SM_NUM_CPUS"], - "num_gpus": env_vars["SM_NUM_GPUS"], - "num_neurons": env_vars["SM_NUM_NEURONS"], - "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"], - "resource_config": env_vars["SM_RESOURCE_CONFIG"], - } - with open(output_file, "w") as f: - for key, value in env_vars.items(): - f.write(f"export {key}='{safe_serialize(value)}'\n") - - logger.info("Environment Variables:") - log_env_variables(env_vars_dict=env_vars) - - -def mask_sensitive_info(data): - """Recursively mask sensitive information in a dictionary.""" - if isinstance(data, dict): - for k, v in data.items(): - if isinstance(v, dict): - data[k] = mask_sensitive_info(v) - elif isinstance(v, str) and any( - keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS - ): - data[k] = HIDDEN_VALUE - return data - - -def log_key_value(key: str, value: str): - """Log a key-value pair, masking sensitive values if necessary.""" - if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS): - logger.info("%s=%s", key, HIDDEN_VALUE) - elif isinstance(value, dict): - masked_value = mask_sensitive_info(value) - logger.info("%s=%s", key, json.dumps(masked_value)) - else: - try: - decoded_value = json.loads(value) - if isinstance(decoded_value, dict): - masked_value = mask_sensitive_info(decoded_value) - logger.info("%s=%s", key, json.dumps(masked_value)) - else: - logger.info("%s=%s", key, decoded_value) - except (json.JSONDecodeError, TypeError): - logger.info("%s=%s", key, value) - - -def log_env_variables(env_vars_dict: Dict[str, Any]): - """Log Environment Variables from the environment and an env_vars_dict.""" - for key, value in os.environ.items(): - log_key_value(key, value) - - for key, value in env_vars_dict.items(): - log_key_value(key, value) - - -def main(): - """Main function to set the environment variables for the training job container.""" - with open(RESOURCE_CONFIG, "r") as f: - resource_config = json.load(f) - with open(INPUT_DATA_CONFIG, "r") as f: - input_data_config = json.load(f) - with open(HYPERPARAMETERS_CONFIG, "r") as f: - hyperparameters_config = json.load(f) - - set_env( - resource_config=resource_config, - input_data_config=input_data_config, - hyperparameters_config=hyperparameters_config, - output_file=ENV_OUTPUT_FILE, - ) - - -if __name__ == "__main__": - main() diff --git a/v3-examples/training-examples/sagemaker_training_drivers/sm_train.sh b/v3-examples/training-examples/sagemaker_training_drivers/sm_train.sh deleted file mode 100644 index 33ddf13bae..0000000000 --- 
a/v3-examples/training-examples/sagemaker_training_drivers/sm_train.sh +++ /dev/null @@ -1,59 +0,0 @@ - -#!/bin/bash -set -e -echo "Starting training script" - -handle_error() { - EXIT_STATUS=$? - echo "An error occurred with exit code $EXIT_STATUS" - if [ ! -s /opt/ml/output/failure ]; then - echo "Training Execution failed. For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'. -TrainingJob - $TRAINING_JOB_NAME" >> /opt/ml/output/failure - fi - exit $EXIT_STATUS -} - -check_python() { - SM_PYTHON_CMD=$(command -v python3 || command -v python) - SM_PIP_CMD=$(command -v pip3 || command -v pip) - - # Check if Python is found - if [[ -z "$SM_PYTHON_CMD" || -z "$SM_PIP_CMD" ]]; then - echo "Error: The Python executable was not found in the system path." - return 1 - fi - - return 0 -} - -trap 'handle_error' ERR - -check_python - -set -x -$SM_PYTHON_CMD --version - -echo "/opt/ml/input/config/resourceconfig.json:" -cat /opt/ml/input/config/resourceconfig.json -echo - -echo "/opt/ml/input/config/inputdataconfig.json:" -cat /opt/ml/input/config/inputdataconfig.json -echo - -echo "Setting up environment variables" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/scripts/environment.py - -set +x -source /opt/ml/input/sm_training.env -set -x - -cd /opt/ml/input/data/code - - - -echo "Running Basic Script driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/basic_script_driver.py - - -echo "Training Container Execution Completed" diff --git a/v3-examples/training-examples/sagemaker_training_drivers/sourcecode.json b/v3-examples/training-examples/sagemaker_training_drivers/sourcecode.json deleted file mode 100644 index fccf2e3f1b..0000000000 --- a/v3-examples/training-examples/sagemaker_training_drivers/sourcecode.json +++ /dev/null @@ -1 +0,0 @@ -{"source_dir": "/var/folders/12/bjmscmk114v7hxrzj6v4wd840000gq/T/tmpy87jl49z/source", "requirements": null, "entry_script": "local_training_script.py", "command": null} \ No newline at end of file
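Editor's note (illustrative sketch, not part of the original patch): the deleted scripts/environment.py writes every variable to /opt/ml/input/sm_training.env as export KEY='value' lines, which sm_train.sh then sources before launching the driver. A minimal sketch of that write format, using hypothetical values and a local output path:

import json

def safe_serialize(data):
    # Mirrors the deleted common/utils.py helper: strings stay bare,
    # everything else is JSON-encoded (or stringified as a last resort).
    if isinstance(data, str):
        return data
    try:
        return json.dumps(data)
    except TypeError:
        return str(data)

env_vars = {
    "SM_MODEL_DIR": "/opt/ml/model",
    "SM_HOSTS": ["algo-1", "algo-2"],  # hypothetical two-node cluster
    "SM_HOST_COUNT": 2,
}
with open("sm_training.env", "w") as f:  # hypothetical local path for the sketch
    for key, value in env_vars.items():
        f.write(f"export {key}='{safe_serialize(value)}'\n")
# Produces, e.g.: export SM_HOSTS='["algo-1", "algo-2"]'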