From c0766cf9d83d532c723fd8a0af04de00f2139f99 Mon Sep 17 00:00:00 2001 From: Mohamed Zeidan Date: Mon, 1 Dec 2025 17:09:02 -0800 Subject: [PATCH] runtime env bug fix --- .../runtime_environment_manager.py | 122 +++++++++++++++--- .../test_runtime_environment_manager.py | 6 +- 2 files changed, 107 insertions(+), 21 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py b/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py index f4d95f5412..8379503fe2 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py @@ -94,6 +94,50 @@ def from_dependency_file_path(dependency_file_path): class RuntimeEnvironmentManager: """Runtime Environment Manager class to manage runtime environment.""" + def _validate_path(self, path: str) -> str: + """Validate and sanitize file path to prevent path traversal attacks. + + Args: + path (str): The file path to validate + + Returns: + str: The validated absolute path + + Raises: + ValueError: If the path is invalid or contains suspicious patterns + """ + if not path: + raise ValueError("Path cannot be empty") + + # Get absolute path to prevent path traversal + abs_path = os.path.abspath(path) + + # Check for null bytes (common in path traversal attacks) + if '\x00' in path: + raise ValueError(f"Invalid path contains null byte: {path}") + + return abs_path + + def _validate_env_name(self, env_name: str) -> None: + """Validate conda environment name to prevent command injection. 
+ + Args: + env_name (str): The environment name to validate + + Raises: + ValueError: If the environment name contains invalid characters + """ + if not env_name: + raise ValueError("Environment name cannot be empty") + + # Allow only alphanumeric, underscore, and hyphen + import re + if not re.match(r'^[a-zA-Z0-9_-]+$', env_name): + raise ValueError( + f"Invalid environment name '{env_name}'. " + "Only alphanumeric characters, underscores, and hyphens are allowed." + ) + def snapshot(self, dependencies: str = None) -> str: """Creates snapshot of the user's environment @@ -252,42 +296,77 @@ def _is_file_exists(self, dependencies): def _install_requirements_txt(self, local_path, python_executable): """Install requirements.txt file""" - cmd = f"{python_executable} -m pip install -r {local_path} -U" - logger.info("Running command: '%s' in the dir: '%s' ", cmd, os.getcwd()) + # Validate path to prevent command injection + validated_path = self._validate_path(local_path) + cmd = [python_executable, "-m", "pip", "install", "-r", validated_path, "-U"] + logger.info("Running command: '%s' in the dir: '%s' ", " ".join(cmd), os.getcwd()) _run_shell_cmd(cmd) - logger.info("Command %s ran successfully", cmd) + logger.info("Command %s ran successfully", " ".join(cmd)) def _create_conda_env(self, env_name, local_path): """Create conda env using conda yml file""" + # Validate inputs to prevent command injection + self._validate_env_name(env_name) + validated_path = self._validate_path(local_path) - cmd = f"{self._get_conda_exe()} env create -n {env_name} --file {local_path}" - logger.info("Creating conda environment %s using: %s.", env_name, cmd) + cmd = [self._get_conda_exe(), "env", "create", "-n", env_name, "--file", validated_path] + logger.info("Creating conda environment %s using: %s.", env_name, " ".join(cmd)) _run_shell_cmd(cmd) logger.info("Conda environment %s created successfully.", env_name) def _install_req_txt_in_conda_env(self, env_name, local_path): """Install 
requirements.txt in the given conda environment""" + # Validate inputs to prevent command injection + self._validate_env_name(env_name) + validated_path = self._validate_path(local_path) - cmd = f"{self._get_conda_exe()} run -n {env_name} pip install -r {local_path} -U" - logger.info("Activating conda env and installing requirements: %s", cmd) + cmd = [self._get_conda_exe(), "run", "-n", env_name, "pip", "install", "-r", validated_path, "-U"] + logger.info("Activating conda env and installing requirements: %s", " ".join(cmd)) _run_shell_cmd(cmd) logger.info("Requirements installed successfully in conda env %s", env_name) def _update_conda_env(self, env_name, local_path): """Update conda env using conda yml file""" + # Validate inputs to prevent command injection + self._validate_env_name(env_name) + validated_path = self._validate_path(local_path) - cmd = f"{self._get_conda_exe()} env update -n {env_name} --file {local_path}" - logger.info("Updating conda env: %s", cmd) + cmd = [self._get_conda_exe(), "env", "update", "-n", env_name, "--file", validated_path] + logger.info("Updating conda env: %s", " ".join(cmd)) _run_shell_cmd(cmd) logger.info("Conda env %s updated succesfully", env_name) def _export_conda_env_from_prefix(self, prefix, local_path): """Export the conda env to a conda yml file""" - - cmd = f"{self._get_conda_exe()} env export -p {prefix} --no-builds > {local_path}" - logger.info("Exporting conda environment: %s", cmd) - _run_shell_cmd(cmd) - logger.info("Conda environment %s exported successfully", prefix) + # Validate inputs to prevent command injection + validated_prefix = self._validate_path(prefix) + validated_path = self._validate_path(local_path) + + cmd = [self._get_conda_exe(), "env", "export", "-p", validated_prefix, "--no-builds"] + logger.info("Exporting conda environment: %s", " ".join(cmd)) + + # Capture output and write to file instead of using shell redirection + try: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + 
stderr=subprocess.PIPE, + shell=False + ) + output, error_output = process.communicate() + return_code = process.wait() + + if return_code: + error_message = f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_output.decode('utf-8')}" + raise RuntimeEnvironmentError(error_message) + + # Write the captured output to the file + with open(validated_path, 'w') as f: + f.write(output.decode('utf-8')) + + logger.info("Conda environment %s exported successfully", validated_prefix) + except Exception as e: + raise RuntimeEnvironmentError(f"Failed to export conda environment: {str(e)}") def _write_conda_env_to_file(self, env_name): """Writes conda env to the text file""" @@ -402,19 +481,26 @@ def _run_pre_execution_command_script(script_path: str): return return_code, error_logs -def _run_shell_cmd(cmd: str): +def _run_shell_cmd(cmd: list): """This method runs a given shell command using subprocess - Raises RuntimeEnvironmentError if the command fails + Args: + cmd (list): Command and arguments as a list (e.g., ['pip', 'install', '-r', 'requirements.txt']) + + Raises: + RuntimeEnvironmentError: If the command fails + ValueError: If cmd is not a list """ + if not isinstance(cmd, list): + raise ValueError("Command must be a list of arguments for security reasons") - process = subprocess.Popen((cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False) _log_output(process) error_logs = _log_error(process) return_code = process.wait() if return_code: - error_message = f"Encountered error while running command '{cmd}'. Reason: {error_logs}" + error_message = f"Encountered error while running command '{' '.join(cmd)}'. 
Reason: {error_logs}" raise RuntimeEnvironmentError(error_message) diff --git a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py index 5c9689a62b..5eaca91d94 100644 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py +++ b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py @@ -448,7 +448,7 @@ def test_run_shell_cmd_success(self, mock_log_error, mock_log_output, mock_popen mock_popen.return_value = mock_process mock_log_error.return_value = "" - _run_shell_cmd("echo test") + _run_shell_cmd(["echo", "test"]) mock_popen.assert_called_once() @@ -463,7 +463,7 @@ def test_run_shell_cmd_failure(self, mock_log_error, mock_log_output, mock_popen mock_log_error.return_value = "Error message" with pytest.raises(RuntimeEnvironmentError, match="Encountered error"): - _run_shell_cmd("false") + _run_shell_cmd(["false"]) def test_python_executable(self): """Test _python_executable""" @@ -502,4 +502,4 @@ def test_get_logger(self): logger = get_logger() assert logger is not None - assert logger.name == "sagemaker.remote_function" + assert logger.name == "sagemaker.remote_function" \ No newline at end of file