Skip to content

Commit c0766cf

Browse files
author
Mohamed Zeidan
committed
rumtime env bug fix
1 parent 1286f17 commit c0766cf

File tree

2 files changed

+107
-21
lines changed

2 files changed

+107
-21
lines changed

sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py

Lines changed: 104 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,50 @@ def from_dependency_file_path(dependency_file_path):
9494
class RuntimeEnvironmentManager:
9595
"""Runtime Environment Manager class to manage runtime environment."""
9696

97+
def _validate_path(self, path: str) -> str:
98+
"""Validate and sanitize file path to prevent path traversal attacks.
99+
100+
Args:
101+
path (str): The file path to validate
102+
103+
Returns:
104+
str: The validated absolute path
105+
106+
Raises:
107+
ValueError: If the path is invalid or contains suspicious patterns
108+
"""
109+
if not path:
110+
raise ValueError("Path cannot be empty")
111+
112+
# Get absolute path to prevent path traversal
113+
abs_path = os.path.abspath(path)
114+
115+
# Check for null bytes (common in path traversal attacks)
116+
if '\x00' in path:
117+
raise ValueError(f"Invalid path contains null byte: {path}")
118+
119+
return abs_path
120+
121+
def _validate_env_name(self, env_name: str) -> None:
122+
"""Validate conda environment name to prevent command injection.
123+
124+
Args:
125+
env_name (str): The environment name to validate
126+
127+
Raises:
128+
ValueError: If the environment name contains invalid characters
129+
"""
130+
if not env_name:
131+
raise ValueError("Environment name cannot be empty")
132+
133+
# Allow only alphanumeric, underscore, and hyphen
134+
import re
135+
if not re.match(r'^[a-zA-Z0-9_-]+$', env_name):
136+
raise ValueError(
137+
f"Invalid environment name '{env_name}'. "
138+
"Only alphanumeric characters, underscores, and hyphens are allowed."
139+
)
140+
97141
def snapshot(self, dependencies: str = None) -> str:
98142
"""Creates snapshot of the user's environment
99143
@@ -252,42 +296,77 @@ def _is_file_exists(self, dependencies):
252296

253297
def _install_requirements_txt(self, local_path, python_executable):
254298
"""Install requirements.txt file"""
255-
cmd = f"{python_executable} -m pip install -r {local_path} -U"
256-
logger.info("Running command: '%s' in the dir: '%s' ", cmd, os.getcwd())
299+
# Validate path to prevent command injection
300+
validated_path = self._validate_path(local_path)
301+
cmd = [python_executable, "-m", "pip", "install", "-r", validated_path, "-U"]
302+
logger.info("Running command: '%s' in the dir: '%s' ", " ".join(cmd), os.getcwd())
257303
_run_shell_cmd(cmd)
258-
logger.info("Command %s ran successfully", cmd)
304+
logger.info("Command %s ran successfully", " ".join(cmd))
259305

260306
def _create_conda_env(self, env_name, local_path):
261307
"""Create conda env using conda yml file"""
308+
# Validate inputs to prevent command injection
309+
self._validate_env_name(env_name)
310+
validated_path = self._validate_path(local_path)
262311

263-
cmd = f"{self._get_conda_exe()} env create -n {env_name} --file {local_path}"
264-
logger.info("Creating conda environment %s using: %s.", env_name, cmd)
312+
cmd = [self._get_conda_exe(), "env", "create", "-n", env_name, "--file", validated_path]
313+
logger.info("Creating conda environment %s using: %s.", env_name, " ".join(cmd))
265314
_run_shell_cmd(cmd)
266315
logger.info("Conda environment %s created successfully.", env_name)
267316

268317
def _install_req_txt_in_conda_env(self, env_name, local_path):
269318
"""Install requirements.txt in the given conda environment"""
319+
# Validate inputs to prevent command injection
320+
self._validate_env_name(env_name)
321+
validated_path = self._validate_path(local_path)
270322

271-
cmd = f"{self._get_conda_exe()} run -n {env_name} pip install -r {local_path} -U"
272-
logger.info("Activating conda env and installing requirements: %s", cmd)
323+
cmd = [self._get_conda_exe(), "run", "-n", env_name, "pip", "install", "-r", validated_path, "-U"]
324+
logger.info("Activating conda env and installing requirements: %s", " ".join(cmd))
273325
_run_shell_cmd(cmd)
274326
logger.info("Requirements installed successfully in conda env %s", env_name)
275327

276328
def _update_conda_env(self, env_name, local_path):
277329
"""Update conda env using conda yml file"""
330+
# Validate inputs to prevent command injection
331+
self._validate_env_name(env_name)
332+
validated_path = self._validate_path(local_path)
278333

279-
cmd = f"{self._get_conda_exe()} env update -n {env_name} --file {local_path}"
280-
logger.info("Updating conda env: %s", cmd)
334+
cmd = [self._get_conda_exe(), "env", "update", "-n", env_name, "--file", validated_path]
335+
logger.info("Updating conda env: %s", " ".join(cmd))
281336
_run_shell_cmd(cmd)
282337
logger.info("Conda env %s updated succesfully", env_name)
283338

284339
def _export_conda_env_from_prefix(self, prefix, local_path):
285340
"""Export the conda env to a conda yml file"""
286-
287-
cmd = f"{self._get_conda_exe()} env export -p {prefix} --no-builds > {local_path}"
288-
logger.info("Exporting conda environment: %s", cmd)
289-
_run_shell_cmd(cmd)
290-
logger.info("Conda environment %s exported successfully", prefix)
341+
# Validate inputs to prevent command injection
342+
validated_prefix = self._validate_path(prefix)
343+
validated_path = self._validate_path(local_path)
344+
345+
cmd = [self._get_conda_exe(), "env", "export", "-p", validated_prefix, "--no-builds"]
346+
logger.info("Exporting conda environment: %s", " ".join(cmd))
347+
348+
# Capture output and write to file instead of using shell redirection
349+
try:
350+
process = subprocess.Popen(
351+
cmd,
352+
stdout=subprocess.PIPE,
353+
stderr=subprocess.PIPE,
354+
shell=False
355+
)
356+
output, error_output = process.communicate()
357+
return_code = process.wait()
358+
359+
if return_code:
360+
error_message = f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_output.decode('utf-8')}"
361+
raise RuntimeEnvironmentError(error_message)
362+
363+
# Write the captured output to the file
364+
with open(validated_path, 'w') as f:
365+
f.write(output.decode('utf-8'))
366+
367+
logger.info("Conda environment %s exported successfully", validated_prefix)
368+
except Exception as e:
369+
raise RuntimeEnvironmentError(f"Failed to export conda environment: {str(e)}")
291370

292371
def _write_conda_env_to_file(self, env_name):
293372
"""Writes conda env to the text file"""
@@ -402,19 +481,26 @@ def _run_pre_execution_command_script(script_path: str):
402481
return return_code, error_logs
403482

404483

405-
def _run_shell_cmd(cmd: str):
484+
def _run_shell_cmd(cmd: list):
406485
"""This method runs a given shell command using subprocess
407486
408-
Raises RuntimeEnvironmentError if the command fails
487+
Args:
488+
cmd (list): Command and arguments as a list (e.g., ['pip', 'install', '-r', 'requirements.txt'])
489+
490+
Raises:
491+
RuntimeEnvironmentError: If the command fails
492+
ValueError: If cmd is not a list
409493
"""
494+
if not isinstance(cmd, list):
495+
raise ValueError("Command must be a list of arguments for security reasons")
410496

411-
process = subprocess.Popen((cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
497+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
412498

413499
_log_output(process)
414500
error_logs = _log_error(process)
415501
return_code = process.wait()
416502
if return_code:
417-
error_message = f"Encountered error while running command '{cmd}'. Reason: {error_logs}"
503+
error_message = f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_logs}"
418504
raise RuntimeEnvironmentError(error_message)
419505

420506

sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def test_run_shell_cmd_success(self, mock_log_error, mock_log_output, mock_popen
448448
mock_popen.return_value = mock_process
449449
mock_log_error.return_value = ""
450450

451-
_run_shell_cmd("echo test")
451+
_run_shell_cmd(["echo", "test"])
452452

453453
mock_popen.assert_called_once()
454454

@@ -463,7 +463,7 @@ def test_run_shell_cmd_failure(self, mock_log_error, mock_log_output, mock_popen
463463
mock_log_error.return_value = "Error message"
464464

465465
with pytest.raises(RuntimeEnvironmentError, match="Encountered error"):
466-
_run_shell_cmd("false")
466+
_run_shell_cmd(["false"])
467467

468468
def test_python_executable(self):
469469
"""Test _python_executable"""
@@ -502,4 +502,4 @@ def test_get_logger(self):
502502
logger = get_logger()
503503

504504
assert logger is not None
505-
assert logger.name == "sagemaker.remote_function"
505+
assert logger.name == "sagemaker.remote_function"

0 commit comments

Comments
 (0)