diff --git a/sagemaker-core/src/sagemaker/core/remote_function/client.py b/sagemaker-core/src/sagemaker/core/remote_function/client.py index b140c03901..a38b57662a 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/client.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/client.py @@ -366,7 +366,7 @@ def wrapper(*args, **kwargs): s3_uri=s3_path_join( job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER ), - hmac_key=job.hmac_key, + ) except ServiceError as serr: chained_e = serr.__cause__ @@ -403,7 +403,7 @@ def wrapper(*args, **kwargs): return serialization.deserialize_obj_from_s3( sagemaker_session=job_settings.sagemaker_session, s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER), - hmac_key=job.hmac_key, + ) if job.describe()["TrainingJobStatus"] == "Stopped": @@ -983,7 +983,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_return = serialization.deserialize_obj_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER), - hmac_key=job.hmac_key, + ) except DeserializationError as e: client_exception = e @@ -995,7 +995,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_exception = serialization.deserialize_exception_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER), - hmac_key=job.hmac_key, + ) except ServiceError as serr: chained_e = serr.__cause__ @@ -1085,7 +1085,7 @@ def result(self, timeout: float = None) -> Any: self._return = serialization.deserialize_obj_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER), - hmac_key=self._job.hmac_key, + ) self._state = _FINISHED return self._return @@ -1094,7 +1094,7 @@ def result(self, timeout: float = None) -> Any: self._exception = serialization.deserialize_exception_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER), - hmac_key=self._job.hmac_key, + ) except ServiceError as serr: chained_e = serr.__cause__ diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py b/sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py index 5278306063..491267b35f 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py @@ -164,7 +164,6 @@ class _DelayedReturnResolver: def __init__( self, delayed_returns: List[_DelayedReturn], - hmac_key: str, properties_resolver: _PropertiesResolver, parameter_resolver: _ParameterResolver, execution_variable_resolver: _ExecutionVariableResolver, @@ -175,7 +174,6 @@ def __init__( Args: delayed_returns: list of delayed returns to resolve. - hmac_key: key used to encrypt serialized and deserialized function and arguments. properties_resolver: resolver used to resolve step properties. parameter_resolver: resolver used to pipeline parameters. execution_variable_resolver: resolver used to resolve execution variables. @@ -197,7 +195,6 @@ def deserialization_task(uri): return uri, deserialize_obj_from_s3( sagemaker_session=settings["sagemaker_session"], s3_uri=uri, - hmac_key=hmac_key, ) with ThreadPoolExecutor() as executor: @@ -247,7 +244,6 @@ def resolve_pipeline_variables( context: Context, func_args: Tuple, func_kwargs: Dict, - hmac_key: str, s3_base_uri: str, **settings, ): @@ -257,7 +253,6 @@ def resolve_pipeline_variables( context: context for the execution. func_args: function args. func_kwargs: function kwargs. - hmac_key: key used to encrypt serialized and deserialized function and arguments. s3_base_uri: the s3 base uri of the function step that the serialized artifacts will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name. **settings: settings to pass to the deserialization function. @@ -280,7 +275,6 @@ def resolve_pipeline_variables( properties_resolver = _PropertiesResolver(context) delayed_return_resolver = _DelayedReturnResolver( delayed_returns=delayed_returns, - hmac_key=hmac_key, properties_resolver=properties_resolver, parameter_resolver=parameter_resolver, execution_variable_resolver=execution_variable_resolver, diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py index 39517bdc6b..8871f6727f 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py @@ -19,7 +19,6 @@ import io import sys -import hmac import hashlib import pickle @@ -156,7 +155,7 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any: # TODO: use dask serializer in case dask distributed is installed in users' environment. def serialize_func_to_s3( - func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None + func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None ): """Serializes function and uploads it to S3. @@ -164,7 +163,6 @@ def serialize_func_to_s3( sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. func: function to be serialized and persisted Raises: @@ -173,14 +171,13 @@ def serialize_func_to_s3( _upload_payload_and_metadata_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(func), - hmac_key=hmac_key, s3_uri=s3_uri, sagemaker_session=sagemaker_session, s3_kms_key=s3_kms_key, ) -def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable: +def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callable: """Downloads from S3 and then deserializes data objects. This method downloads the serialized training job outputs to a temporary directory and @@ -190,7 +187,6 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: sagemaker_session (sagemaker.core.helper.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func. Returns : The deserialized function. Raises: @@ -203,14 +199,14 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) def serialize_obj_to_s3( - obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None + obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None ): """Serializes data object and uploads it to S3. @@ -219,7 +215,6 @@ def serialize_obj_to_s3( The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. obj: object to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. @@ -227,7 +222,6 @@ def serialize_obj_to_s3( _upload_payload_and_metadata_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(obj), - hmac_key=hmac_key, s3_uri=s3_uri, sagemaker_session=sagemaker_session, s3_kms_key=s3_kms_key, @@ -274,14 +268,13 @@ def json_serialize_obj_to_s3( ) -def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any: +def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: """Downloads from S3 and then deserializes data objects. Args: sagemaker_session (sagemaker.core.helper.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. Returns : Deserialized python objects. Raises: @@ -295,14 +288,14 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: s bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) def serialize_exception_to_s3( - exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None + exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None ): """Serializes exception with traceback and uploads it to S3. @@ -311,7 +304,6 @@ def serialize_exception_to_s3( The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception. exc: Exception to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. @@ -320,7 +312,6 @@ def serialize_exception_to_s3( _upload_payload_and_metadata_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(exc), - hmac_key=hmac_key, s3_uri=s3_uri, sagemaker_session=sagemaker_session, s3_kms_key=s3_kms_key, @@ -329,7 +320,6 @@ def serialize_exception_to_s3( def _upload_payload_and_metadata_to_s3( bytes_to_upload: Union[bytes, io.BytesIO], - hmac_key: str, s3_uri: str, sagemaker_session: Session, s3_kms_key, @@ -338,7 +328,6 @@ def _upload_payload_and_metadata_to_s3( Args: bytes_to_upload (bytes): Serialized bytes to upload. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. @@ -346,7 +335,7 @@ def _upload_payload_and_metadata_to_s3( """ _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session) - sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key) + sha256_hash = _compute_hash(bytes_to_upload) _upload_bytes_to_s3( _MetaData(sha256_hash).to_json(), @@ -356,14 +345,13 @@ def _upload_payload_and_metadata_to_s3( ) -def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any: +def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: """Downloads from S3 and then deserializes exception. Args: sagemaker_session (sagemaker.core.helper.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception. Returns : Deserialized exception with traceback. Raises: @@ -377,7 +365,7 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_ bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) @@ -403,19 +391,19 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session): ) from e -def _compute_hash(buffer: bytes, secret_key: str) -> str: - """Compute the hmac-sha256 hash""" - return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest() +def _compute_hash(buffer: bytes) -> str: + """Compute the sha256 hash""" + return hashlib.sha256(buffer).hexdigest() -def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes): +def _perform_integrity_check(expected_hash_value: str, buffer: bytes): """Performs integrity checks for serialized code/arguments uploaded to s3. Verifies whether the hash read from s3 matches the hash calculated during remote function execution. """ - actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key) - if not hmac.compare_digest(expected_hash_value, actual_hash_value): + actual_hash_value = _compute_hash(buffer=buffer) + if expected_hash_value != actual_hash_value: raise DeserializationError( "Integrity check for the serialized function or data failed. " "Please restrict access to your S3 bucket" diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py b/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py index 48724d8e36..c7ee86f8a7 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py @@ -55,7 +55,6 @@ def __init__( self, sagemaker_session: Session, s3_base_uri: str, - hmac_key: str, s3_kms_key: str = None, context: Context = Context(), ): @@ -66,13 +65,11 @@ def __init__( AWS service calls are delegated to. s3_base_uri: the base uri to which serialized artifacts will be uploaded. s3_kms_key: KMS key used to encrypt artifacts uploaded to S3. - hmac_key: Key used to encrypt serialized and deserialized function and arguments. context: Build or run context of a pipeline step. """ self.sagemaker_session = sagemaker_session self.s3_base_uri = s3_base_uri self.s3_kms_key = s3_kms_key - self.hmac_key = hmac_key self.context = context # For pipeline steps, function code is at: base/step_name/build_timestamp/ @@ -114,7 +111,7 @@ def save(self, func, *args, **kwargs): sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), s3_kms_key=self.s3_kms_key, - hmac_key=self.hmac_key, + ) logger.info( @@ -126,7 +123,7 @@ def save(self, func, *args, **kwargs): obj=(args, kwargs), sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - hmac_key=self.hmac_key, + s3_kms_key=self.s3_kms_key, ) @@ -144,7 +141,7 @@ def save_pipeline_step_function(self, serialized_data): ) serialization._upload_payload_and_metadata_to_s3( bytes_to_upload=serialized_data.func, - hmac_key=self.hmac_key, + s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), sagemaker_session=self.sagemaker_session, s3_kms_key=self.s3_kms_key, @@ -156,7 +153,7 @@ def save_pipeline_step_function(self, serialized_data): ) serialization._upload_payload_and_metadata_to_s3( bytes_to_upload=serialized_data.args, - hmac_key=self.hmac_key, + s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), sagemaker_session=self.sagemaker_session, s3_kms_key=self.s3_kms_key, @@ -172,7 +169,7 @@ def load_and_invoke(self) -> Any: func = serialization.deserialize_func_from_s3( sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), - hmac_key=self.hmac_key, + ) logger.info( @@ -182,7 +179,7 @@ def load_and_invoke(self) -> Any: args, kwargs = serialization.deserialize_obj_from_s3( sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - hmac_key=self.hmac_key, + ) logger.info("Resolving pipeline variables") @@ -190,7 +187,7 @@ def load_and_invoke(self) -> Any: self.context, args, kwargs, - hmac_key=self.hmac_key, + s3_base_uri=self.s3_base_uri, sagemaker_session=self.sagemaker_session, ) @@ -206,7 +203,7 @@ def load_and_invoke(self) -> Any: obj=result, sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.results_upload_path, RESULTS_FOLDER), - hmac_key=self.hmac_key, + s3_kms_key=self.s3_kms_key, ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/errors.py b/sagemaker-core/src/sagemaker/core/remote_function/errors.py index d12fde52d6..3f391570cf 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/errors.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/errors.py @@ -70,7 +70,7 @@ def _write_failure_reason_file(failure_msg): f.write(failure_msg) -def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) -> int: +def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: """Handle all exceptions raised during remote function execution. Args: @@ -79,7 +79,6 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) -> AWS service calls are delegated to. s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - hmac_key (str): Key used to calculate hmac hash of the serialized exception. Returns : exit_code (int): Exit code to terminate current job. """ @@ -97,7 +96,6 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) -> exc=error, sagemaker_session=sagemaker_session, s3_uri=s3_path_join(s3_base_uri, "exception"), - hmac_key=hmac_key, s3_kms_key=s3_kms_key, ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py b/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py index d353232b57..2e69f4f116 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py @@ -98,7 +98,7 @@ def _load_pipeline_context(args) -> Context: def _execute_remote_function( - sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, hmac_key, context + sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, context ): """Execute stored remote function""" from sagemaker.core.remote_function.core.stored_function import StoredFunction @@ -107,7 +107,6 @@ def _execute_remote_function( sagemaker_session=sagemaker_session, s3_base_uri=s3_base_uri, s3_kms_key=s3_kms_key, - hmac_key=hmac_key, context=context, ) @@ -138,15 +137,12 @@ def main(sys_args=None): run_in_context = args.run_in_context pipeline_context = _load_pipeline_context(args) - hmac_key = os.getenv("REMOTE_FUNCTION_SECRET_KEY") - sagemaker_session = _get_sagemaker_session(region) _execute_remote_function( sagemaker_session=sagemaker_session, s3_base_uri=s3_base_uri, s3_kms_key=s3_kms_key, run_in_context=run_in_context, - hmac_key=hmac_key, context=pipeline_context, ) @@ -162,7 +158,6 @@ def main(sys_args=None): sagemaker_session=sagemaker_session, s3_base_uri=s3_uri, s3_kms_key=s3_kms_key, - hmac_key=hmac_key, ) finally: sys.exit(exit_code) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/job.py b/sagemaker-core/src/sagemaker/core/remote_function/job.py index bed00e148f..435062db57 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/job.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/job.py @@ -17,7 +17,6 @@ import json import os import re -import secrets import shutil import sys import time @@ -621,11 +620,6 @@ def __init__( {"AWS_DEFAULT_REGION": self.sagemaker_session.boto_region_name} ) - # The following will be overridden by the _Job.compile method. - # However, it needs to be kept here for feature store SDK. - # TODO: update the feature store SDK to set the HMAC key there. - self.environment_variables.update({"REMOTE_FUNCTION_SECRET_KEY": secrets.token_hex(32)}) - if spark_config and image_uri: raise ValueError("spark_config and image_uri cannot be specified at the same time!") @@ -839,19 +833,17 @@ def _get_default_spark_image(session): class _Job: """Helper class that interacts with the SageMaker training service.""" - def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session, hmac_key: str): + def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session): """Initialize a _Job object. Args: job_name (str): The training job name. s3_uri (str): The training job output S3 uri. sagemaker_session (Session): SageMaker boto session. - hmac_key (str): Remote function secret key. """ self.job_name = job_name self.s3_uri = s3_uri self.sagemaker_session = sagemaker_session - self.hmac_key = hmac_key self._last_describe_response = None @staticmethod @@ -867,9 +859,8 @@ def from_describe_response(describe_training_job_response, sagemaker_session): """ job_name = describe_training_job_response["TrainingJobName"] s3_uri = describe_training_job_response["OutputDataConfig"]["S3OutputPath"] - hmac_key = describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"] - job = _Job(job_name, s3_uri, sagemaker_session, hmac_key) + job = _Job(job_name, s3_uri, sagemaker_session) job._last_describe_response = describe_training_job_response return job @@ -907,7 +898,6 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non job_name, s3_base_uri, job_settings.sagemaker_session, - training_job_request["Environment"]["REMOTE_FUNCTION_SECRET_KEY"], ) @staticmethod @@ -935,18 +925,11 @@ def compile( jobs_container_entrypoint = JOBS_CONTAINER_ENTRYPOINT[:] - # generate hmac key for integrity check - if step_compilation_context is None: - hmac_key = secrets.token_hex(32) - else: - hmac_key = step_compilation_context.function_step_secret_token - # serialize function and arguments if step_compilation_context is None: stored_function = StoredFunction( sagemaker_session=job_settings.sagemaker_session, s3_base_uri=s3_base_uri, - hmac_key=hmac_key, s3_kms_key=job_settings.s3_kms_key, ) stored_function.save(func, *func_args, **func_kwargs) @@ -954,7 +937,6 @@ def compile( stored_function = StoredFunction( sagemaker_session=job_settings.sagemaker_session, s3_base_uri=s3_base_uri, - hmac_key=hmac_key, s3_kms_key=job_settings.s3_kms_key, context=Context( step_name=step_compilation_context.step_name, @@ -1114,7 +1096,6 @@ def compile( request_dict["EnableManagedSpotTraining"] = job_settings.use_spot_instances request_dict["Environment"] = job_settings.environment_variables - request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key}) extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri) extended_request = _extend_mpirun_to_request(extended_request, job_settings) diff --git a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py index cc8319f935..461a3ecb73 100644 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py +++ b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py @@ -10,20 +10,23 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""Tests for bootstrap_runtime_environment module.""" +from __future__ import absolute_import -import pytest -from unittest.mock import Mock, patch, mock_open, MagicMock import json -import sys +import os +import pytest +import subprocess +from unittest.mock import patch, MagicMock, mock_open, call from sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment import ( + _parse_args, _bootstrap_runtime_env_for_remote_function, _bootstrap_runtime_env_for_pipeline_step, _handle_pre_exec_scripts, _install_dependencies, _unpack_user_workspace, _write_failure_reason_file, - _parse_args, log_key_value, log_env_variables, mask_sensitive_info, @@ -35,6 +38,11 @@ main, SUCCESS_EXIT_CODE, DEFAULT_FAILURE_CODE, + FAILURE_REASON_PATH, + REMOTE_FUNCTION_WORKSPACE, + BASE_CHANNEL_PATH, + JOB_REMOTE_FUNCTION_WORKSPACE, + SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME, SENSITIVE_KEYWORDS, HIDDEN_VALUE, ) @@ -43,506 +51,629 @@ ) -class TestBootstrapRuntimeEnvironment: - """Test cases for bootstrap runtime environment functions""" - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies" - ) - def test_bootstrap_runtime_env_for_remote_function( - self, mock_install, mock_handle, mock_unpack - ): - """Test _bootstrap_runtime_env_for_remote_function""" - mock_unpack.return_value = "/workspace" - dependency_settings = _DependencySettings(dependency_file="requirements.txt") - - _bootstrap_runtime_env_for_remote_function( - client_python_version="3.8", conda_env="myenv", dependency_settings=dependency_settings - ) - - mock_unpack.assert_called_once() - mock_handle.assert_called_once_with("/workspace") - mock_install.assert_called_once() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace" - ) - def test_bootstrap_runtime_env_for_remote_function_no_workspace(self, mock_unpack): - """Test _bootstrap_runtime_env_for_remote_function with no workspace""" - mock_unpack.return_value = None - - _bootstrap_runtime_env_for_remote_function(client_python_version="3.8") - - mock_unpack.assert_called_once() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.mkdir" - ) - def test_bootstrap_runtime_env_for_pipeline_step(self, mock_mkdir, mock_exists, mock_unpack): - """Test _bootstrap_runtime_env_for_pipeline_step""" - mock_unpack.return_value = None - mock_exists.return_value = False - - _bootstrap_runtime_env_for_pipeline_step( - client_python_version="3.8", func_step_workspace="workspace" - ) - - mock_mkdir.assert_called_once() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.isfile" - ) - def test_handle_pre_exec_scripts_exists(self, mock_isfile, mock_manager_class): - """Test _handle_pre_exec_scripts when script exists""" - mock_isfile.return_value = True - mock_manager = Mock() - mock_manager_class.return_value = mock_manager - - _handle_pre_exec_scripts("/workspace") - - mock_manager.run_pre_exec_script.assert_called_once() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.isfile" - ) - def test_handle_pre_exec_scripts_not_exists(self, mock_isfile, mock_manager_class): - """Test _handle_pre_exec_scripts when script doesn't exist""" - mock_isfile.return_value = False - mock_manager = Mock() - mock_manager_class.return_value = mock_manager - - _handle_pre_exec_scripts("/workspace") - - mock_manager.run_pre_exec_script.assert_not_called() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.join" - ) - def test_install_dependencies_with_file(self, mock_join, mock_manager_class): - """Test _install_dependencies with dependency file""" - mock_join.return_value = "/workspace/requirements.txt" - mock_manager = Mock() - mock_manager_class.return_value = mock_manager - - dependency_settings = _DependencySettings(dependency_file="requirements.txt") - - _install_dependencies( - dependency_file_dir="/workspace", - conda_env="myenv", - client_python_version="3.8", - channel_name="channel", - dependency_settings=dependency_settings, - ) - - mock_manager.bootstrap.assert_called_once() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager" - ) - def test_install_dependencies_no_file(self, mock_manager_class): - """Test _install_dependencies with no dependency file""" - mock_manager = Mock() - mock_manager_class.return_value = mock_manager - - dependency_settings = _DependencySettings(dependency_file=None) - - _install_dependencies( - dependency_file_dir="/workspace", - conda_env=None, - client_python_version="3.8", - channel_name="channel", - dependency_settings=dependency_settings, - ) - - mock_manager.bootstrap.assert_not_called() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.isfile" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.shutil.unpack_archive" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.pathlib.Path" - ) - def test_unpack_user_workspace_success(self, mock_path, mock_unpack, mock_isfile, mock_exists): - """Test _unpack_user_workspace successfully unpacks workspace""" - mock_exists.return_value = True - mock_isfile.return_value = True - mock_path.return_value.absolute.return_value = "/workspace" - - result = _unpack_user_workspace() - - assert result is not None - mock_unpack.assert_called_once() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists" - ) - def test_unpack_user_workspace_no_directory(self, mock_exists): - """Test _unpack_user_workspace when directory doesn't exist""" - mock_exists.return_value = False - - result = _unpack_user_workspace() - - assert result is None - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists" - ) - @patch("builtins.open", new_callable=mock_open) - def test_write_failure_reason_file(self, mock_file, mock_exists): - """Test _write_failure_reason_file""" - mock_exists.return_value = False - - _write_failure_reason_file("Test error message") - - mock_file.assert_called_once() - mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message") - - def test_parse_args(self): - """Test _parse_args""" - args = _parse_args( - [ - "--job_conda_env", - "myenv", - "--client_python_version", - "3.8", - "--dependency_settings", - '{"dependency_file": "requirements.txt"}', - ] - ) - - assert args.job_conda_env == "myenv" - assert args.client_python_version == "3.8" - assert args.dependency_settings == '{"dependency_file": "requirements.txt"}' - - -class TestLoggingFunctions: - """Test cases for logging functions""" - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger" - ) - def test_log_key_value_normal(self, mock_logger): - """Test log_key_value with normal key""" - log_key_value("MY_KEY", "my_value") - +class TestParseArgs: + """Test _parse_args function.""" + + def test_parse_required_args(self): + """Test parsing required arguments.""" + args = [ + "--client_python_version", "3.8", + ] + parsed = _parse_args(args) + assert parsed.client_python_version == "3.8" + + def test_parse_all_args(self): + """Test parsing all arguments.""" + args = [ + "--job_conda_env", "my-env", + "--client_python_version", "3.9", + "--client_sagemaker_pysdk_version", "2.100.0", + "--pipeline_execution_id", "exec-123", + "--dependency_settings", '{"dependency_file": "requirements.txt"}', + "--func_step_s3_dir", "s3://bucket/func", + "--distribution", "torchrun", + "--user_nproc_per_node", "4", + ] + parsed = _parse_args(args) + assert parsed.job_conda_env == "my-env" + assert parsed.client_python_version == "3.9" + assert parsed.client_sagemaker_pysdk_version == "2.100.0" + assert parsed.pipeline_execution_id == "exec-123" + assert parsed.dependency_settings == '{"dependency_file": "requirements.txt"}' + assert parsed.func_step_s3_dir == "s3://bucket/func" + assert parsed.distribution == "torchrun" + assert parsed.user_nproc_per_node == "4" + + def test_parse_default_values(self): + """Test default values for optional arguments.""" + args = [ + "--client_python_version", "3.8", + ] + parsed = _parse_args(args) + assert parsed.job_conda_env is None + assert parsed.client_sagemaker_pysdk_version is None + assert parsed.pipeline_execution_id is None + assert parsed.dependency_settings is None + assert parsed.func_step_s3_dir is None + assert parsed.distribution is None + assert parsed.user_nproc_per_node is None + + +class TestLogKeyValue: + """Test log_key_value function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") + def test_logs_regular_value(self, mock_logger): + """Test logs regular key-value pair.""" + log_key_value("my_name", "my_value") + mock_logger.info.assert_called_once_with("%s=%s", "my_name", "my_value") + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") + def test_masks_sensitive_key(self, mock_logger): + """Test masks sensitive keywords.""" + for keyword in ["PASSWORD", "SECRET", "TOKEN", "KEY", "PRIVATE", "CREDENTIALS"]: + mock_logger.reset_mock() + log_key_value(f"my_{keyword}", "sensitive_value") + mock_logger.info.assert_called_once_with("%s=%s", f"my_{keyword}", HIDDEN_VALUE) + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") + def test_logs_dict_value(self, mock_logger): + """Test logs dictionary value.""" + value = {"field1": "value1", "field2": "value2"} + log_key_value("my_config", value) + mock_logger.info.assert_called_once_with("%s=%s", "my_config", json.dumps(value)) + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") + def test_logs_json_string_value(self, mock_logger): + """Test logs JSON string value.""" + value = '{"key1": "value1"}' + log_key_value("my_key", value) mock_logger.info.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger" - ) - def test_log_key_value_sensitive(self, mock_logger): - """Test log_key_value with sensitive key""" - log_key_value("MY_PASSWORD", "secret123") - mock_logger.info.assert_called_once() - call_args = mock_logger.info.call_args[0] - assert HIDDEN_VALUE in str(call_args) - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger" - ) - def test_log_key_value_dict(self, mock_logger): - """Test log_key_value with dictionary value""" - log_key_value("MY_CONFIG", {"key": "value"}) - - mock_logger.info.assert_called_once() +class TestLogEnvVariables: + """Test log_env_variables function.""" - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.environ", - {"ENV_VAR": "value"}, - ) - def test_log_env_variables(self, mock_logger): - """Test log_env_variables""" - log_env_variables({"CUSTOM_VAR": "custom_value"}) + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_key_value") + @patch.dict("os.environ", {"ENV_VAR1": "value1", "ENV_VAR2": "value2"}) + def test_logs_env_and_dict_variables(self, mock_log_kv): + """Test logs both environment and dictionary variables.""" + env_dict = {"DICT_VAR1": "dict_value1", "DICT_VAR2": "dict_value2"} + log_env_variables(env_dict) + + # Should be called for env vars and dict vars + assert mock_log_kv.call_count >= 4 - assert mock_logger.info.call_count >= 2 - def test_mask_sensitive_info(self): - """Test mask_sensitive_info""" - data = {"username": "user", "password": "secret", "nested": {"api_key": "key123"}} +class TestMaskSensitiveInfo: + """Test mask_sensitive_info function.""" + def test_masks_sensitive_keys_in_dict(self): + """Test masks sensitive keys in dictionary.""" + data = { + "username": "user", + "password": "secret123", + "api_key": "key123", + } result = mask_sensitive_info(data) - - assert result["password"] == HIDDEN_VALUE - assert result["nested"]["api_key"] == HIDDEN_VALUE assert result["username"] == "user" + assert result["password"] == HIDDEN_VALUE + assert result["api_key"] == HIDDEN_VALUE + + def test_masks_nested_dict(self): + """Test masks sensitive keys in nested dictionary.""" + data = { + "config": { + "username": "user", + "secret": "secret123", + } + } + result = mask_sensitive_info(data) + assert result["config"]["username"] == "user" + assert result["config"]["secret"] == HIDDEN_VALUE + def test_returns_non_dict_unchanged(self): + """Test returns non-dictionary unchanged.""" + data = "string_value" + result = mask_sensitive_info(data) + assert result == "string_value" -class TestResourceFunctions: - """Test cases for resource detection functions""" - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.multiprocessing.cpu_count" - ) - def test_num_cpus(self, mock_cpu_count): - """Test num_cpus""" - mock_cpu_count.return_value = 4 +class TestNumCpus: + """Test num_cpus function.""" - result = num_cpus() + @patch("multiprocessing.cpu_count") + def test_returns_cpu_count(self, mock_cpu_count): + """Test returns CPU count.""" + mock_cpu_count.return_value = 8 + assert num_cpus() == 8 - assert result == 4 - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.subprocess.check_output" - ) - def test_num_gpus_with_gpus(self, mock_check_output): - """Test num_gpus when GPUs are present""" - mock_check_output.return_value = b"GPU 0: Tesla V100\nGPU 1: Tesla V100\n" +class TestNumGpus: + """Test num_gpus function.""" - result = num_gpus() + @patch("subprocess.check_output") + def test_returns_gpu_count(self, mock_check_output): + """Test returns GPU count.""" + mock_check_output.return_value = b"GPU 0: Tesla V100\nGPU 1: Tesla V100\n" + assert num_gpus() == 2 - assert result == 2 + @patch("subprocess.check_output") + def test_returns_zero_on_error(self, mock_check_output): + """Test returns zero when nvidia-smi fails.""" + mock_check_output.side_effect = subprocess.CalledProcessError(1, "nvidia-smi") + assert num_gpus() == 0 - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.subprocess.check_output" - ) - def test_num_gpus_no_gpus(self, mock_check_output): - """Test num_gpus when no GPUs are present""" + @patch("subprocess.check_output") + def test_returns_zero_on_os_error(self, mock_check_output): + """Test returns zero when nvidia-smi not found.""" mock_check_output.side_effect = OSError() + assert num_gpus() == 0 - result = num_gpus() - assert result == 0 +class TestNumNeurons: + """Test num_neurons function.""" - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.subprocess.check_output" - ) - def test_num_neurons_with_neurons(self, mock_check_output): - """Test num_neurons when neurons are present""" - mock_check_output.return_value = b'[{"nc_count": 2}, {"nc_count": 2}]' + @patch("subprocess.check_output") + def test_returns_neuron_count(self, mock_check_output): + """Test returns neuron core count.""" + mock_output = json.dumps([{"nc_count": 2}, {"nc_count": 4}]) + mock_check_output.return_value = mock_output.encode("utf-8") + assert num_neurons() == 6 - result = num_neurons() - - assert result == 4 - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.subprocess.check_output" - ) - def test_num_neurons_no_neurons(self, mock_check_output): - """Test num_neurons when no neurons are present""" + @patch("subprocess.check_output") + def test_returns_zero_on_os_error(self, mock_check_output): + """Test returns zero when neuron-ls not found.""" mock_check_output.side_effect = OSError() + assert num_neurons() == 0 - result = num_neurons() - - assert result == 0 - - -class TestSerializationFunctions: - """Test cases for serialization functions""" - - def test_safe_serialize_string(self): - """Test safe_serialize with string""" - result = safe_serialize("test_string") - - assert result == "test_string" + @patch("subprocess.check_output") + def test_returns_zero_on_called_process_error(self, mock_check_output): + """Test returns zero when neuron-ls fails.""" + error = subprocess.CalledProcessError(1, "neuron-ls") + error.output = b"error=No neuron devices found" + mock_check_output.side_effect = error + assert num_neurons() == 0 - def test_safe_serialize_dict(self): - """Test safe_serialize with dictionary""" - result = safe_serialize({"key": "value"}) - assert result == '{"key": "value"}' +class TestSafeSerialize: + """Test safe_serialize function.""" - def test_safe_serialize_list(self): - """Test safe_serialize with list""" - result = safe_serialize([1, 2, 3]) + def test_returns_string_as_is(self): + """Test returns string without quotes.""" + assert safe_serialize("test_string") == "test_string" - assert result == "[1, 2, 3]" + def test_serializes_dict(self): + """Test serializes dictionary.""" + data = {"key": "value"} + assert safe_serialize(data) == '{"key": "value"}' - def test_safe_serialize_non_serializable(self): - """Test safe_serialize with non-serializable object""" + def test_serializes_list(self): + """Test serializes list.""" + data = [1, 2, 3] + assert safe_serialize(data) == "[1, 2, 3]" - class CustomObject: + def test_returns_str_for_non_serializable(self): + """Test returns str() for non-serializable objects.""" + class CustomObj: def __str__(self): return "custom_object" - - result = safe_serialize(CustomObject()) - - assert "custom_object" in result + + obj = CustomObj() + assert safe_serialize(obj) == "custom_object" class TestSetEnv: - """Test cases for set_env function""" + """Test set_env function.""" @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.environ", - {"TRAINING_JOB_NAME": "test-job"}, - ) - def test_set_env_basic(self, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test set_env with basic configuration""" - mock_cpus.return_value = 4 - mock_gpus.return_value = 0 + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") + @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) + def test_sets_basic_env_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): + """Test sets basic environment variables.""" + mock_cpus.return_value = 8 + mock_gpus.return_value = 2 mock_neurons.return_value = 0 - + resource_config = { "current_host": "algo-1", - "current_instance_type": "ml.m5.xlarge", - "hosts": ["algo-1"], + "current_instance_type": "ml.p3.2xlarge", + "hosts": ["algo-1", "algo-2"], "network_interface_name": "eth0", } - + set_env(resource_config) - + mock_file.assert_called_once() + mock_log_env.assert_called_once() @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.environ", - {"TRAINING_JOB_NAME": "test-job"}, - ) - def test_set_env_with_torchrun(self, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test set_env with torchrun distribution""" - mock_cpus.return_value = 4 + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") + @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) + def test_sets_torchrun_distribution_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): + """Test sets torchrun distribution environment variables.""" + mock_cpus.return_value = 8 mock_gpus.return_value = 2 mock_neurons.return_value = 0 - + resource_config = { "current_host": "algo-1", - "current_instance_type": "ml.p3.2xlarge", - "hosts": ["algo-1", "algo-2"], + "current_instance_type": "ml.p4d.24xlarge", + "hosts": ["algo-1"], "network_interface_name": "eth0", } - + set_env(resource_config, distribution="torchrun") - + + # Verify file was written mock_file.assert_called_once() @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.environ", - {"TRAINING_JOB_NAME": "test-job"}, - ) - def test_set_env_with_mpirun(self, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test set_env with mpirun distribution""" - mock_cpus.return_value = 4 + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") + @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) + def test_sets_mpirun_distribution_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): + """Test sets mpirun distribution environment variables.""" + mock_cpus.return_value = 8 mock_gpus.return_value = 2 mock_neurons.return_value = 0 - + resource_config = { "current_host": "algo-1", "current_instance_type": "ml.p3.2xlarge", "hosts": ["algo-1", "algo-2"], "network_interface_name": "eth0", } - + set_env(resource_config, distribution="mpirun") + + mock_file.assert_called_once() + @patch("builtins.open", new_callable=mock_open) + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") + @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) + def test_uses_user_nproc_per_node(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): + """Test uses user-specified nproc_per_node.""" + mock_cpus.return_value = 8 + mock_gpus.return_value = 2 + mock_neurons.return_value = 0 + + resource_config = { + "current_host": "algo-1", + "current_instance_type": "ml.p3.2xlarge", + "hosts": ["algo-1"], + "network_interface_name": "eth0", + } + + set_env(resource_config, user_nproc_per_node="4") + mock_file.assert_called_once() +class TestWriteFailureReasonFile: + """Test _write_failure_reason_file function.""" + + @patch("builtins.open", new_callable=mock_open) + @patch("os.path.exists") + def test_writes_failure_file(self, mock_exists, mock_file): + """Test writes failure reason file.""" + mock_exists.return_value = False + + _write_failure_reason_file("Test error message") + + mock_file.assert_called_once_with(FAILURE_REASON_PATH, "w") + mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message") + + @patch("builtins.open", new_callable=mock_open) + @patch("os.path.exists") + def test_does_not_write_if_exists(self, mock_exists, mock_file): + """Test does not write if failure file already exists.""" + mock_exists.return_value = True + + _write_failure_reason_file("Test error message") + + mock_file.assert_not_called() + + +class TestUnpackUserWorkspace: + """Test _unpack_user_workspace function.""" + + @patch("os.path.exists") + def test_returns_none_if_dir_not_exists(self, mock_exists): + """Test returns None if workspace directory doesn't exist.""" + mock_exists.return_value = False + + result = _unpack_user_workspace() + + assert result is None + + @patch("os.path.isfile") + @patch("os.path.exists") + def test_returns_none_if_archive_not_exists(self, mock_exists, mock_isfile): + """Test returns None if workspace archive doesn't exist.""" + mock_exists.return_value = True + mock_isfile.return_value = False + + result = _unpack_user_workspace() + + assert result is None + + @patch("shutil.unpack_archive") + @patch("os.path.isfile") + @patch("os.path.exists") + @patch("os.getcwd") + def test_unpacks_workspace_successfully(self, mock_getcwd, mock_exists, mock_isfile, mock_unpack): + """Test unpacks workspace successfully.""" + mock_getcwd.return_value = "/tmp/workspace" + mock_exists.return_value = True + mock_isfile.return_value = True + + result = _unpack_user_workspace() + + mock_unpack.assert_called_once() + assert result is not None + + +class TestHandlePreExecScripts: + """Test _handle_pre_exec_scripts function.""" + + @patch("os.path.isfile") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + def test_runs_pre_exec_script(self, mock_manager_class, mock_isfile): + """Test runs pre-execution script.""" + mock_isfile.return_value = True + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + _handle_pre_exec_scripts("/tmp/scripts") + + mock_manager.run_pre_exec_script.assert_called_once() + + +class TestInstallDependencies: + """Test _install_dependencies function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + def test_installs_with_dependency_settings(self, mock_manager_class): + """Test installs dependencies with dependency settings.""" + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + dep_settings = _DependencySettings(dependency_file="requirements.txt") + + _install_dependencies( + "/tmp/deps", + "my-env", + "3.8", + "channel", + dep_settings + ) + + mock_manager.bootstrap.assert_called_once() + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + def test_skips_if_no_dependency_file(self, mock_manager_class): + """Test skips installation if no dependency file.""" + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + dep_settings = _DependencySettings(dependency_file=None) + + _install_dependencies( + "/tmp/deps", + "my-env", + "3.8", + "channel", + dep_settings + ) + + mock_manager.bootstrap.assert_not_called() + + @patch("os.listdir") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + def test_finds_dependency_file_legacy(self, mock_manager_class, mock_listdir): + """Test finds dependency file in legacy mode.""" + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + mock_listdir.return_value = ["requirements.txt", "script.py"] + + _install_dependencies( + "/tmp/deps", + "my-env", + "3.8", + "channel", + None + ) + + mock_manager.bootstrap.assert_called_once() + + +class TestBootstrapRuntimeEnvForRemoteFunction: + """Test _bootstrap_runtime_env_for_remote_function function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") + def test_bootstraps_successfully(self, mock_unpack, mock_handle_scripts, mock_install): + """Test bootstraps runtime environment successfully.""" + mock_unpack.return_value = "/tmp/workspace" + + _bootstrap_runtime_env_for_remote_function("3.8", "my-env", None) + + mock_unpack.assert_called_once() + mock_handle_scripts.assert_called_once() + mock_install.assert_called_once() + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") + def test_returns_early_if_no_workspace(self, mock_unpack): + """Test returns early if no workspace to unpack.""" + mock_unpack.return_value = None + + _bootstrap_runtime_env_for_remote_function("3.8", "my-env", None) + + mock_unpack.assert_called_once() + + +class TestBootstrapRuntimeEnvForPipelineStep: + """Test _bootstrap_runtime_env_for_pipeline_step function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") + @patch("shutil.copy") + @patch("os.listdir") + @patch("os.path.exists") + @patch("os.mkdir") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") + def test_bootstraps_with_workspace(self, mock_unpack, mock_mkdir, mock_exists, mock_listdir, mock_copy, mock_handle_scripts, mock_install): + """Test bootstraps pipeline step with workspace.""" + mock_unpack.return_value = "/tmp/workspace" + mock_exists.return_value = True + mock_listdir.return_value = ["requirements.txt"] + + _bootstrap_runtime_env_for_pipeline_step("3.8", "func_step", "my-env", None) + + mock_unpack.assert_called_once() + mock_handle_scripts.assert_called_once() + mock_install.assert_called_once() + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") + @patch("os.path.exists") + @patch("os.mkdir") + @patch("os.getcwd") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") + def test_creates_workspace_if_none(self, mock_unpack, mock_getcwd, mock_mkdir, mock_exists, mock_handle_scripts, mock_install): + """Test creates workspace directory if none exists.""" + mock_unpack.return_value = None + mock_getcwd.return_value = "/tmp" + mock_exists.return_value = False + + _bootstrap_runtime_env_for_pipeline_step("3.8", "func_step", "my-env", None) + + mock_mkdir.assert_called_once() + + class TestMain: - """Test cases for main function""" - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_remote_function" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.getpass.getuser" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists" - ) - def test_main_success( - self, mock_exists, mock_getuser, mock_manager_class, mock_bootstrap, mock_parse - ): - """Test main function successful execution""" - mock_args = Mock() + """Test main function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.set_env") + @patch("builtins.open", new_callable=mock_open, read_data='{"current_host": "algo-1", "current_instance_type": "ml.m5.xlarge", "hosts": ["algo-1"], "network_interface_name": "eth0"}') + @patch("os.path.exists") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_remote_function") + @patch("getpass.getuser") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args") + def test_main_success(self, mock_parse_args, mock_getuser, mock_bootstrap, mock_manager_class, mock_exists, mock_file, mock_set_env): + """Test main function successful execution.""" + mock_getuser.return_value = "root" + mock_exists.return_value = True + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + # Mock parsed args + mock_args = MagicMock() mock_args.client_python_version = "3.8" - mock_args.client_sagemaker_pysdk_version = "2.0.0" + mock_args.client_sagemaker_pysdk_version = None mock_args.job_conda_env = None mock_args.pipeline_execution_id = None mock_args.dependency_settings = None mock_args.func_step_s3_dir = None mock_args.distribution = None mock_args.user_nproc_per_node = None - mock_parse.return_value = mock_args - - mock_getuser.return_value = "root" - mock_exists.return_value = False - - mock_manager = Mock() - mock_manager_class.return_value = mock_manager - + mock_parse_args.return_value = mock_args + + args = [ + "--client_python_version", "3.8", + ] + with pytest.raises(SystemExit) as exc_info: - main([]) - + main(args) + assert exc_info.value.code == SUCCESS_EXIT_CODE + mock_bootstrap.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._write_failure_reason_file" - ) - def test_main_failure(self, mock_write_failure, mock_parse): - """Test main function with failure""" - mock_parse.side_effect = Exception("Test error") - + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._write_failure_reason_file") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + @patch("getpass.getuser") + def test_main_handles_exception(self, mock_getuser, mock_manager_class, mock_write_failure): + """Test main function handles exceptions.""" + mock_getuser.return_value = "root" + mock_manager = MagicMock() + mock_manager._validate_python_version.side_effect = Exception("Test error") + mock_manager_class.return_value = mock_manager + + args = [ + "--client_python_version", "3.8", + ] + with pytest.raises(SystemExit) as exc_info: - main([]) - + main(args) + assert exc_info.value.code == DEFAULT_FAILURE_CODE mock_write_failure.assert_called_once() + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.set_env") + @patch("builtins.open", new_callable=mock_open, read_data='{"current_host": "algo-1", "current_instance_type": "ml.m5.xlarge", "hosts": ["algo-1"], "network_interface_name": "eth0"}') + @patch("os.path.exists") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_pipeline_step") + @patch("getpass.getuser") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args") + def test_main_pipeline_execution(self, mock_parse_args, mock_getuser, mock_bootstrap, mock_manager_class, mock_exists, mock_file, mock_set_env): + """Test main function for pipeline execution.""" + mock_getuser.return_value = "root" + mock_exists.return_value = True + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + # Mock parsed args + mock_args = MagicMock() + mock_args.client_python_version = "3.8" + mock_args.client_sagemaker_pysdk_version = None + mock_args.job_conda_env = None + mock_args.pipeline_execution_id = "exec-123" + mock_args.dependency_settings = None + mock_args.func_step_s3_dir = "s3://bucket/func" + mock_args.distribution = None + mock_args.user_nproc_per_node = None + mock_parse_args.return_value = mock_args + + args = [ + "--client_python_version", "3.8", + "--pipeline_execution_id", "exec-123", + "--func_step_s3_dir", "s3://bucket/func", + ] + + with pytest.raises(SystemExit) as exc_info: + main(args) + + assert exc_info.value.code == SUCCESS_EXIT_CODE + mock_bootstrap.assert_called_once() + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + @patch("getpass.getuser") + def test_main_non_root_user(self, mock_getuser, mock_manager_class): + """Test main function with non-root user.""" + mock_getuser.return_value = "ubuntu" + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + args = [ + "--client_python_version", "3.8", + ] + + with pytest.raises(SystemExit): + main(args) + + mock_manager.change_dir_permission.assert_called_once() diff --git a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py index e075489b6b..b84dda5c1a 100644 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py +++ b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py @@ -10,10 +10,14 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""Tests for mpi_utils_remote module.""" +from __future__ import absolute_import +import os import pytest -from unittest.mock import Mock, patch, MagicMock, mock_open import subprocess +import time +from unittest.mock import patch, MagicMock, mock_open, call import paramiko from sagemaker.core.remote_function.runtime_environment.mpi_utils_remote import ( @@ -32,6 +36,7 @@ main, SUCCESS_EXIT_CODE, DEFAULT_FAILURE_CODE, + FAILURE_REASON_PATH, FINISHED_STATUS_FILE, READY_FILE, DEFAULT_SSH_PORT, @@ -39,328 +44,381 @@ class TestCustomHostKeyPolicy: - """Test cases for CustomHostKeyPolicy class""" + """Test CustomHostKeyPolicy class.""" - def test_missing_host_key_algo_hostname(self): - """Test missing_host_key accepts algo-* hostnames""" + def test_accepts_algo_hostname(self): + """Test accepts hostnames starting with algo-.""" policy = CustomHostKeyPolicy() - client = Mock() - client.get_host_keys.return_value = Mock() - key = Mock() - key.get_name.return_value = "ssh-rsa" - + mock_client = MagicMock() + mock_hostname = "algo-1234" + mock_key = MagicMock() + mock_key.get_name.return_value = "ssh-rsa" + # Should not raise exception - policy.missing_host_key(client, "algo-1", key) - - client.get_host_keys().add.assert_called_once() + policy.missing_host_key(mock_client, mock_hostname, mock_key) + + mock_client.get_host_keys().add.assert_called_once_with(mock_hostname, "ssh-rsa", mock_key) - def test_missing_host_key_unknown_hostname(self): - """Test missing_host_key rejects unknown hostnames""" + def test_rejects_non_algo_hostname(self): + """Test rejects hostnames not starting with algo-.""" policy = CustomHostKeyPolicy() - client = Mock() - key = Mock() + mock_client = MagicMock() + mock_hostname = "unknown-host" + mock_key = MagicMock() + + with pytest.raises(paramiko.SSHException): + policy.missing_host_key(mock_client, mock_hostname, mock_key) + + +class TestParseArgs: + """Test _parse_args function.""" - with pytest.raises(paramiko.SSHException, match="Unknown host key"): - policy.missing_host_key(client, "unknown-host", key) + def test_parse_default_args(self): + """Test parsing with default arguments.""" + args = [] + parsed = _parse_args(args) + assert parsed.job_ended == "0" + def test_parse_job_ended_true(self): + """Test parsing with job_ended set to true.""" + args = ["--job_ended", "1"] + parsed = _parse_args(args) + assert parsed.job_ended == "1" -class TestConnectionFunctions: - """Test cases for connection functions""" + def test_parse_job_ended_false(self): + """Test parsing with job_ended set to false.""" + args = ["--job_ended", "0"] + parsed = _parse_args(args) + assert parsed.job_ended == "0" - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.paramiko.SSHClient") + +class TestCanConnect: + """Test _can_connect function.""" + + @patch("paramiko.SSHClient") def test_can_connect_success(self, mock_ssh_client_class): - """Test _can_connect when connection succeeds""" - mock_client = Mock() + """Test successful connection.""" + mock_client = MagicMock() mock_ssh_client_class.return_value.__enter__.return_value = mock_client - + result = _can_connect("algo-1", DEFAULT_SSH_PORT) - + assert result is True mock_client.connect.assert_called_once_with("algo-1", port=DEFAULT_SSH_PORT) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.paramiko.SSHClient") + @patch("paramiko.SSHClient") def test_can_connect_failure(self, mock_ssh_client_class): - """Test _can_connect when connection fails""" - mock_client = Mock() + """Test failed connection.""" + mock_client = MagicMock() mock_client.connect.side_effect = Exception("Connection failed") mock_ssh_client_class.return_value.__enter__.return_value = mock_client - + result = _can_connect("algo-1", DEFAULT_SSH_PORT) - + assert result is False - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.subprocess.run") - def test_write_file_to_host_success(self, mock_run): - """Test _write_file_to_host when write succeeds""" - mock_run.return_value = Mock() + @patch("paramiko.SSHClient") + def test_can_connect_uses_custom_port(self, mock_ssh_client_class): + """Test connection with custom port.""" + mock_client = MagicMock() + mock_ssh_client_class.return_value.__enter__.return_value = mock_client + + _can_connect("algo-1", 2222) + + mock_client.connect.assert_called_once_with("algo-1", port=2222) - result = _write_file_to_host("algo-1", "/tmp/status") +class TestWriteFileToHost: + """Test _write_file_to_host function.""" + + @patch("subprocess.run") + def test_write_file_success(self, mock_run): + """Test successful file write.""" + mock_run.return_value = MagicMock(returncode=0) + + result = _write_file_to_host("algo-1", "/tmp/status") + assert result is True mock_run.assert_called_once() - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.subprocess.run") - def test_write_file_to_host_failure(self, mock_run): - """Test _write_file_to_host when write fails""" + @patch("subprocess.run") + def test_write_file_failure(self, mock_run): + """Test failed file write.""" mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") - + result = _write_file_to_host("algo-1", "/tmp/status") - + assert result is False - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists") + +class TestWriteFailureReasonFile: + """Test _write_failure_reason_file function.""" + @patch("builtins.open", new_callable=mock_open) - def test_write_failure_reason_file(self, mock_file, mock_exists): - """Test _write_failure_reason_file""" + @patch("os.path.exists") + def test_writes_failure_file(self, mock_exists, mock_file): + """Test writes failure reason file.""" mock_exists.return_value = False + + _write_failure_reason_file("Test error message") + + mock_file.assert_called_once_with(FAILURE_REASON_PATH, "w") + mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message") - _write_failure_reason_file("Test error") - - mock_file.assert_called_once() - mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error") + @patch("builtins.open", new_callable=mock_open) + @patch("os.path.exists") + def test_does_not_write_if_exists(self, mock_exists, mock_file): + """Test does not write if failure file already exists.""" + mock_exists.return_value = True + + _write_failure_reason_file("Test error message") + + mock_file.assert_not_called() -class TestWaitFunctions: - """Test cases for wait functions""" +class TestWaitForMaster: + """Test _wait_for_master function.""" + @patch("time.sleep") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - def test_wait_for_master_success(self, mock_sleep, mock_can_connect): - """Test _wait_for_master when master becomes available""" - mock_can_connect.side_effect = [False, False, True] - + def test_wait_for_master_success(self, mock_can_connect, mock_sleep): + """Test successful wait for master.""" + mock_can_connect.return_value = True + _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) + + mock_can_connect.assert_called_once_with("algo-1", DEFAULT_SSH_PORT) - assert mock_can_connect.call_count == 3 - + @patch("time.time") + @patch("time.sleep") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.time") - def test_wait_for_master_timeout(self, mock_time, mock_sleep, mock_can_connect): - """Test _wait_for_master when timeout occurs""" + def test_wait_for_master_timeout(self, mock_can_connect, mock_sleep, mock_time): + """Test timeout waiting for master.""" mock_can_connect.return_value = False - mock_time.side_effect = [0, 100, 200, 301, 301] - - with pytest.raises(TimeoutError, match="Timed out waiting for master"): + # Need enough values for all time.time() calls in the loop + mock_time.side_effect = [0] + [i * 5 for i in range(1, 100)] # Simulate time passing + + with pytest.raises(TimeoutError): _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - def test_wait_for_status_file(self, mock_sleep, mock_exists): - """Test _wait_for_status_file""" - mock_exists.side_effect = [False, False, True] + @patch("time.time") + @patch("time.sleep") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") + def test_wait_for_master_retries(self, mock_can_connect, mock_sleep, mock_time): + """Test retries before successful connection.""" + mock_can_connect.side_effect = [False, False, True] + # Return value instead of side_effect for time.time() + mock_time.return_value = 0 + + _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) + + assert mock_can_connect.call_count == 3 + +class TestWaitForStatusFile: + """Test _wait_for_status_file function.""" + + @patch("time.sleep") + @patch("os.path.exists") + def test_wait_for_status_file_exists(self, mock_exists, mock_sleep): + """Test wait for status file that exists.""" + mock_exists.return_value = True + _wait_for_status_file("/tmp/status") + + mock_exists.assert_called_once_with("/tmp/status") + @patch("time.sleep") + @patch("os.path.exists") + def test_wait_for_status_file_waits(self, mock_exists, mock_sleep): + """Test waits until status file exists.""" + mock_exists.side_effect = [False, False, True] + + _wait_for_status_file("/tmp/status") + assert mock_exists.call_count == 3 + assert mock_sleep.call_count == 2 + + +class TestWaitForWorkers: + """Test _wait_for_workers function.""" + + @patch("os.path.exists") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") + def test_wait_for_workers_empty_list(self, mock_can_connect, mock_exists): + """Test wait for workers with empty list.""" + _wait_for_workers([], DEFAULT_SSH_PORT, timeout=300) + + mock_can_connect.assert_not_called() + @patch("time.sleep") + @patch("os.path.exists") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - def test_wait_for_workers_success(self, mock_sleep, mock_exists, mock_can_connect): - """Test _wait_for_workers when all workers become available""" + def test_wait_for_workers_success(self, mock_can_connect, mock_exists, mock_sleep): + """Test successful wait for workers.""" mock_can_connect.return_value = True mock_exists.return_value = True - + _wait_for_workers(["algo-2", "algo-3"], DEFAULT_SSH_PORT, timeout=300) - + assert mock_can_connect.call_count == 2 + @patch("time.time") + @patch("time.sleep") + @patch("os.path.exists") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.time") - def test_wait_for_workers_timeout(self, mock_time, mock_sleep, mock_can_connect): - """Test _wait_for_workers when timeout occurs""" + def test_wait_for_workers_timeout(self, mock_can_connect, mock_exists, mock_sleep, mock_time): + """Test timeout waiting for workers.""" mock_can_connect.return_value = False - mock_time.side_effect = [0, 100, 200, 301, 301] - - with pytest.raises(TimeoutError, match="Timed out waiting for workers"): + mock_exists.return_value = False + # Need enough values for all time.time() calls in the loop + mock_time.side_effect = [0] + [i * 5 for i in range(1, 100)] + + with pytest.raises(TimeoutError): _wait_for_workers(["algo-2"], DEFAULT_SSH_PORT, timeout=300) - def test_wait_for_workers_no_workers(self): - """Test _wait_for_workers with no workers""" - # Should not raise exception - _wait_for_workers([], DEFAULT_SSH_PORT, timeout=300) - -class TestBootstrapFunctions: - """Test cases for bootstrap functions""" +class TestBootstrapMasterNode: + """Test bootstrap_master_node function.""" @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_workers") def test_bootstrap_master_node(self, mock_wait): - """Test bootstrap_master_node""" - bootstrap_master_node(["algo-2", "algo-3"]) + """Test bootstrap master node.""" + worker_hosts = ["algo-2", "algo-3"] + + bootstrap_master_node(worker_hosts) + + mock_wait.assert_called_once_with(worker_hosts) - mock_wait.assert_called_once_with(["algo-2", "algo-3"]) +class TestBootstrapWorkerNode: + """Test bootstrap_worker_node function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_status_file") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_master") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_status_file" - ) - def test_bootstrap_worker_node(self, mock_wait_status, mock_write, mock_wait_master): - """Test bootstrap_worker_node""" + def test_bootstrap_worker_node(self, mock_wait_master, mock_write, mock_wait_status): + """Test bootstrap worker node.""" bootstrap_worker_node("algo-1", "algo-2", "/tmp/status") - + mock_wait_master.assert_called_once_with("algo-1") mock_write.assert_called_once() mock_wait_status.assert_called_once_with("/tmp/status") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.subprocess.Popen") - def test_start_sshd_daemon_success(self, mock_popen, mock_exists): - """Test start_sshd_daemon when sshd exists""" - mock_exists.return_value = True - start_sshd_daemon() +class TestStartSshdDaemon: + """Test start_sshd_daemon function.""" - mock_popen.assert_called_once() + @patch("subprocess.Popen") + @patch("os.path.exists") + def test_starts_sshd_successfully(self, mock_exists, mock_popen): + """Test starts SSH daemon successfully.""" + mock_exists.return_value = True + + start_sshd_daemon() + + mock_popen.assert_called_once_with(["/usr/sbin/sshd", "-D"]) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists") - def test_start_sshd_daemon_not_found(self, mock_exists): - """Test start_sshd_daemon when sshd not found""" + @patch("os.path.exists") + def test_raises_error_if_sshd_not_found(self, mock_exists): + """Test raises error if SSH daemon not found.""" mock_exists.return_value = False - - with pytest.raises(RuntimeError, match="SSH daemon not found"): + + with pytest.raises(RuntimeError): start_sshd_daemon() - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host" - ) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - def test_write_status_file_to_workers_success(self, mock_sleep, mock_write): - """Test write_status_file_to_workers when writes succeed""" - mock_write.return_value = True - write_status_file_to_workers(["algo-2", "algo-3"], "/tmp/status") +class TestWriteStatusFileToWorkers: + """Test write_status_file_to_workers function.""" + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") + def test_writes_to_all_workers(self, mock_write): + """Test writes status file to all workers.""" + mock_write.return_value = True + worker_hosts = ["algo-2", "algo-3"] + + write_status_file_to_workers(worker_hosts, "/tmp/status") + assert mock_write.call_count == 2 - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host" - ) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - def test_write_status_file_to_workers_timeout(self, mock_sleep, mock_write): - """Test write_status_file_to_workers when timeout occurs""" + @patch("time.sleep") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") + def test_retries_on_failure(self, mock_write, mock_sleep): + """Test retries writing status file on failure.""" + mock_write.side_effect = [False, False, True] + worker_hosts = ["algo-2"] + + write_status_file_to_workers(worker_hosts, "/tmp/status") + + assert mock_write.call_count == 3 + assert mock_sleep.call_count == 2 + + @patch("time.sleep") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") + def test_raises_timeout_after_retries(self, mock_write, mock_sleep): + """Test raises timeout after max retries.""" mock_write.return_value = False - - with pytest.raises(TimeoutError, match="Timed out waiting"): - write_status_file_to_workers(["algo-2"], "/tmp/status") - - -class TestParseArgs: - """Test cases for _parse_args function""" - - def test_parse_args_job_ended_false(self): - """Test _parse_args with job_ended=0""" - args = _parse_args(["--job_ended", "0"]) - - assert args.job_ended == "0" - - def test_parse_args_job_ended_true(self): - """Test _parse_args with job_ended=1""" - args = _parse_args(["--job_ended", "1"]) - - assert args.job_ended == "1" - - def test_parse_args_default(self): - """Test _parse_args with default values""" - args = _parse_args([]) - - assert args.job_ended == "0" + worker_hosts = ["algo-2"] + + with pytest.raises(TimeoutError): + write_status_file_to_workers(worker_hosts, "/tmp/status") class TestMain: - """Test cases for main function""" + """Test main function.""" - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._parse_args") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.environ", - {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}, - ) - def test_main_worker_node_job_running(self, mock_bootstrap_worker, mock_start_sshd, mock_parse): - """Test main for worker node when job is running""" - mock_args = Mock() - mock_args.job_ended = "0" - mock_parse.return_value = mock_args - - main([]) - + @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) + def test_main_worker_node_running(self, mock_start_sshd, mock_bootstrap_worker): + """Test main function for worker node during job run.""" + args = ["--job_ended", "0"] + + main(args) + mock_start_sshd.assert_called_once() mock_bootstrap_worker.assert_called_once() - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._parse_args") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node" - ) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.json.loads") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.environ", - { - "SM_MASTER_ADDR": "algo-1", - "SM_CURRENT_HOST": "algo-1", - "SM_HOSTS": '["algo-1", "algo-2", "algo-3"]', - }, - ) - def test_main_master_node_job_running( - self, mock_json_loads, mock_bootstrap_master, mock_start_sshd, mock_parse - ): - """Test main for master node when job is running""" - mock_args = Mock() - mock_args.job_ended = "0" - mock_parse.return_value = mock_args - mock_json_loads.return_value = ["algo-1", "algo-2", "algo-3"] - - main([]) - + @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-1", "SM_HOSTS": '["algo-1", "algo-2"]'}) + def test_main_master_node_running(self, mock_start_sshd, mock_bootstrap_master): + """Test main function for master node during job run.""" + args = ["--job_ended", "0"] + + main(args) + mock_start_sshd.assert_called_once() - mock_bootstrap_master.assert_called_once_with(["algo-2", "algo-3"]) - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._parse_args") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers" - ) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.json.loads") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.environ", - { - "SM_MASTER_ADDR": "algo-1", - "SM_CURRENT_HOST": "algo-1", - "SM_HOSTS": '["algo-1", "algo-2"]', - }, - ) - def test_main_master_node_job_ended(self, mock_json_loads, mock_write_status, mock_parse): - """Test main for master node when job has ended""" - mock_args = Mock() - mock_args.job_ended = "1" - mock_parse.return_value = mock_args - mock_json_loads.return_value = ["algo-1", "algo-2"] - - main([]) - - mock_write_status.assert_called_once_with(["algo-2"]) - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._parse_args") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_failure_reason_file" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.environ", - {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}, - ) - def test_main_with_exception(self, mock_write_failure, mock_parse): - """Test main when exception occurs""" - mock_parse.side_effect = Exception("Test error") - + mock_bootstrap_master.assert_called_once() + + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers") + @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-1", "SM_HOSTS": '["algo-1", "algo-2"]'}) + def test_main_master_node_job_ended(self, mock_write_status): + """Test main function for master node after job ends.""" + args = ["--job_ended", "1"] + + main(args) + + mock_write_status.assert_called_once() + + @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) + def test_main_worker_node_job_ended(self): + """Test main function for worker node after job ends.""" + args = ["--job_ended", "1"] + + # Should not raise any exceptions + main(args) + + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_failure_reason_file") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") + @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) + def test_main_handles_exception(self, mock_start_sshd, mock_write_failure): + """Test main function handles exceptions.""" + mock_start_sshd.side_effect = Exception("Test error") + args = ["--job_ended", "0"] + with pytest.raises(SystemExit) as exc_info: - main([]) - + main(args) + assert exc_info.value.code == DEFAULT_FAILURE_CODE mock_write_failure.assert_called_once() diff --git a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py index be2f1430d6..a300daf2b3 100644 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py +++ b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py @@ -10,16 +10,20 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""Tests for runtime_environment_manager module.""" +from __future__ import absolute_import -import pytest -from unittest.mock import Mock, patch, MagicMock, mock_open +import json +import os import subprocess import sys +import pytest +from unittest.mock import patch, MagicMock, mock_open, call from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( + _DependencySettings, RuntimeEnvironmentManager, RuntimeEnvironmentError, - _DependencySettings, get_logger, _run_and_get_output_shell_cmd, _run_pre_execution_command_script, @@ -31,467 +35,465 @@ class TestDependencySettings: - """Test cases for _DependencySettings class""" + """Test _DependencySettings class.""" + + def test_init_with_no_file(self): + """Test initialization without dependency file.""" + settings = _DependencySettings() + assert settings.dependency_file is None def test_init_with_file(self): - """Test initialization with dependency file""" + """Test initialization with dependency file.""" settings = _DependencySettings(dependency_file="requirements.txt") - assert settings.dependency_file == "requirements.txt" - def test_init_without_file(self): - """Test initialization without dependency file""" - settings = _DependencySettings() - - assert settings.dependency_file is None - def test_to_string(self): - """Test to_string method""" + """Test converts to JSON string.""" settings = _DependencySettings(dependency_file="requirements.txt") - result = settings.to_string() + assert result == '{"dependency_file": "requirements.txt"}' - assert "requirements.txt" in result - - def test_from_string(self): - """Test from_string method""" + def test_from_string_with_file(self): + """Test creates from JSON string with file.""" json_str = '{"dependency_file": "requirements.txt"}' - settings = _DependencySettings.from_string(json_str) - assert settings.dependency_file == "requirements.txt" - def test_from_string_none(self): - """Test from_string with None""" + def test_from_string_with_none(self): + """Test creates from None.""" settings = _DependencySettings.from_string(None) - assert settings is None - def test_from_dependency_file_path(self): - """Test from_dependency_file_path method""" - settings = _DependencySettings.from_dependency_file_path("/path/to/requirements.txt") - - assert settings.dependency_file == "requirements.txt" + def test_from_dependency_file_path_with_none(self): + """Test creates from None file path.""" + settings = _DependencySettings.from_dependency_file_path(None) + assert settings.dependency_file is None - def test_from_dependency_file_path_auto_capture(self): - """Test from_dependency_file_path with auto_capture""" + def test_from_dependency_file_path_with_auto_capture(self): + """Test creates from auto_capture.""" settings = _DependencySettings.from_dependency_file_path("auto_capture") - assert settings.dependency_file == "env_snapshot.yml" - def test_from_dependency_file_path_none(self): - """Test from_dependency_file_path with None""" - settings = _DependencySettings.from_dependency_file_path(None) + def test_from_dependency_file_path_with_path(self): + """Test creates from file path.""" + settings = _DependencySettings.from_dependency_file_path("/path/to/requirements.txt") + assert settings.dependency_file == "requirements.txt" - assert settings.dependency_file is None + +class TestGetLogger: + """Test get_logger function.""" + + def test_returns_logger(self): + """Test returns logger instance.""" + logger = get_logger() + assert logger is not None + assert logger.name == "sagemaker.remote_function" class TestRuntimeEnvironmentManager: - """Test cases for RuntimeEnvironmentManager class""" + """Test RuntimeEnvironmentManager class.""" def test_init(self): - """Test initialization""" + """Test initialization.""" manager = RuntimeEnvironmentManager() - assert manager is not None - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile" - ) - def test_snapshot_with_requirements_txt(self, mock_isfile): - """Test snapshot with requirements.txt""" - mock_isfile.return_value = True + @patch("os.path.isfile") + def test_snapshot_returns_none_for_none(self, mock_isfile): + """Test snapshot returns None when dependencies is None.""" manager = RuntimeEnvironmentManager() + result = manager.snapshot(None) + assert result is None - result = manager.snapshot("requirements.txt") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._capture_from_local_runtime") + def test_snapshot_auto_capture(self, mock_capture): + """Test snapshot with auto_capture.""" + mock_capture.return_value = "/path/to/env_snapshot.yml" + manager = RuntimeEnvironmentManager() + result = manager.snapshot("auto_capture") + assert result == "/path/to/env_snapshot.yml" + mock_capture.assert_called_once() + @patch("os.path.isfile") + def test_snapshot_with_txt_file(self, mock_isfile): + """Test snapshot with requirements.txt file.""" + mock_isfile.return_value = True + manager = RuntimeEnvironmentManager() + result = manager.snapshot("requirements.txt") assert result == "requirements.txt" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile" - ) - def test_snapshot_with_conda_yml(self, mock_isfile): - """Test snapshot with conda environment.yml""" + @patch("os.path.isfile") + def test_snapshot_with_yml_file(self, mock_isfile): + """Test snapshot with conda.yml file.""" mock_isfile.return_value = True manager = RuntimeEnvironmentManager() - result = manager.snapshot("environment.yml") - assert result == "environment.yml" - @patch.object(RuntimeEnvironmentManager, "_capture_from_local_runtime") - def test_snapshot_with_auto_capture(self, mock_capture): - """Test snapshot with auto_capture""" - mock_capture.return_value = "env_snapshot.yml" - manager = RuntimeEnvironmentManager() - - result = manager.snapshot("auto_capture") - - assert result == "env_snapshot.yml" - mock_capture.assert_called_once() - - def test_snapshot_with_none(self): - """Test snapshot with None""" - manager = RuntimeEnvironmentManager() - - result = manager.snapshot(None) - - assert result is None - - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile" - ) - def test_snapshot_with_invalid_file(self, mock_isfile): - """Test snapshot with invalid file""" + @patch("os.path.isfile") + def test_snapshot_raises_error_for_invalid_file(self, mock_isfile): + """Test snapshot raises error for invalid file.""" mock_isfile.return_value = False manager = RuntimeEnvironmentManager() + with pytest.raises(ValueError): + manager.snapshot("requirements.txt") - with pytest.raises(ValueError, match="No dependencies file named"): - manager.snapshot("invalid.txt") - - @patch.object(RuntimeEnvironmentManager, "_get_active_conda_env_name") - @patch.object(RuntimeEnvironmentManager, "_get_active_conda_env_prefix") - @patch.object(RuntimeEnvironmentManager, "_export_conda_env_from_prefix") - def test_capture_from_local_runtime_with_conda_env(self, mock_export, mock_prefix, mock_name): - """Test _capture_from_local_runtime with conda environment""" - mock_name.return_value = "myenv" - mock_prefix.return_value = "/opt/conda/envs/myenv" + def test_snapshot_raises_error_for_invalid_format(self): + """Test snapshot raises error for invalid format.""" manager = RuntimeEnvironmentManager() + with pytest.raises(ValueError): + manager.snapshot("invalid.json") - result = manager._capture_from_local_runtime() - - assert "env_snapshot.yml" in result - mock_export.assert_called_once() - - @patch.object(RuntimeEnvironmentManager, "_get_active_conda_env_name") - @patch.object(RuntimeEnvironmentManager, "_get_active_conda_env_prefix") - def test_capture_from_local_runtime_no_conda_env(self, mock_prefix, mock_name): - """Test _capture_from_local_runtime without conda environment""" - mock_name.return_value = None - mock_prefix.return_value = None - manager = RuntimeEnvironmentManager() - - with pytest.raises(ValueError, match="No conda environment"): - manager._capture_from_local_runtime() - - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.getenv" - ) + @patch("os.getenv") def test_get_active_conda_env_prefix(self, mock_getenv): - """Test _get_active_conda_env_prefix""" + """Test gets active conda environment prefix.""" mock_getenv.return_value = "/opt/conda/envs/myenv" manager = RuntimeEnvironmentManager() - result = manager._get_active_conda_env_prefix() - assert result == "/opt/conda/envs/myenv" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.getenv" - ) + @patch("os.getenv") def test_get_active_conda_env_name(self, mock_getenv): - """Test _get_active_conda_env_name""" + """Test gets active conda environment name.""" mock_getenv.return_value = "myenv" manager = RuntimeEnvironmentManager() - result = manager._get_active_conda_env_name() - assert result == "myenv" - @patch.object(RuntimeEnvironmentManager, "_install_req_txt_in_conda_env") - @patch.object(RuntimeEnvironmentManager, "_write_conda_env_to_file") - def test_bootstrap_with_requirements_txt_and_conda_env(self, mock_write, mock_install): - """Test bootstrap with requirements.txt and conda environment""" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._export_conda_env_from_prefix") + @patch("os.getcwd") + @patch("os.getenv") + def test_capture_from_local_runtime(self, mock_getenv, mock_getcwd, mock_export): + """Test captures from local runtime.""" + mock_getenv.side_effect = lambda x: "myenv" if x == "CONDA_DEFAULT_ENV" else "/opt/conda/envs/myenv" + mock_getcwd.return_value = "/tmp" manager = RuntimeEnvironmentManager() + result = manager._capture_from_local_runtime() + assert result == "/tmp/env_snapshot.yml" + mock_export.assert_called_once() - manager.bootstrap( - local_dependencies_file="requirements.txt", - client_python_version="3.8", - conda_env="myenv", - ) - - mock_install.assert_called_once_with("myenv", "requirements.txt") - mock_write.assert_called_once_with("myenv") - - @patch.object(RuntimeEnvironmentManager, "_install_requirements_txt") - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._python_executable" - ) - def test_bootstrap_with_requirements_txt_no_conda_env(self, mock_python_exec, mock_install): - """Test bootstrap with requirements.txt without conda environment""" - mock_python_exec.return_value = "/usr/bin/python3" + @patch("os.getenv") + def test_capture_from_local_runtime_raises_error_no_conda(self, mock_getenv): + """Test raises error when no conda environment active.""" + mock_getenv.return_value = None manager = RuntimeEnvironmentManager() + with pytest.raises(ValueError): + manager._capture_from_local_runtime() - manager.bootstrap(local_dependencies_file="requirements.txt", client_python_version="3.8") - + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._install_requirements_txt") + def test_bootstrap_with_txt_file_no_conda(self, mock_install): + """Test bootstrap with requirements.txt without conda.""" + manager = RuntimeEnvironmentManager() + manager.bootstrap("requirements.txt", "3.8", None) mock_install.assert_called_once() - @patch.object(RuntimeEnvironmentManager, "_update_conda_env") - @patch.object(RuntimeEnvironmentManager, "_write_conda_env_to_file") - def test_bootstrap_with_conda_yml_and_conda_env(self, mock_write, mock_update): - """Test bootstrap with conda yml and existing conda environment""" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._install_req_txt_in_conda_env") + def test_bootstrap_with_txt_file_with_conda(self, mock_install, mock_write): + """Test bootstrap with requirements.txt with conda.""" manager = RuntimeEnvironmentManager() + manager.bootstrap("requirements.txt", "3.8", "myenv") + mock_install.assert_called_once() + mock_write.assert_called_once() - manager.bootstrap( - local_dependencies_file="environment.yml", - client_python_version="3.8", - conda_env="myenv", - ) - + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._update_conda_env") + def test_bootstrap_with_yml_file_with_conda(self, mock_update, mock_write): + """Test bootstrap with conda.yml with existing conda env.""" + manager = RuntimeEnvironmentManager() + manager.bootstrap("environment.yml", "3.8", "myenv") mock_update.assert_called_once() mock_write.assert_called_once() - @patch.object(RuntimeEnvironmentManager, "_create_conda_env") - @patch.object(RuntimeEnvironmentManager, "_validate_python_version") - @patch.object(RuntimeEnvironmentManager, "_write_conda_env_to_file") - def test_bootstrap_with_conda_yml_no_conda_env(self, mock_write, mock_validate, mock_create): - """Test bootstrap with conda yml without existing conda environment""" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._validate_python_version") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._create_conda_env") + def test_bootstrap_with_yml_file_without_conda(self, mock_create, mock_validate, mock_write): + """Test bootstrap with conda.yml without existing conda env.""" manager = RuntimeEnvironmentManager() - - manager.bootstrap(local_dependencies_file="environment.yml", client_python_version="3.8") - + manager.bootstrap("environment.yml", "3.8", None) mock_create.assert_called_once() mock_validate.assert_called_once() mock_write.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script" - ) - def test_run_pre_exec_script_exists(self, mock_run_script, mock_isfile): - """Test run_pre_exec_script when script exists""" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script") + @patch("os.path.isfile") + def test_run_pre_exec_script_exists(self, mock_isfile, mock_run_script): + """Test runs pre-execution script when it exists.""" mock_isfile.return_value = True mock_run_script.return_value = (0, "") manager = RuntimeEnvironmentManager() - manager.run_pre_exec_script("/path/to/script.sh") - mock_run_script.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script" - ) - def test_run_pre_exec_script_fails(self, mock_run_script, mock_isfile): - """Test run_pre_exec_script when script fails""" + @patch("os.path.isfile") + def test_run_pre_exec_script_not_exists(self, mock_isfile): + """Test handles pre-execution script not existing.""" + mock_isfile.return_value = False + manager = RuntimeEnvironmentManager() + # Should not raise exception + manager.run_pre_exec_script("/path/to/script.sh") + + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script") + @patch("os.path.isfile") + def test_run_pre_exec_script_raises_error_on_failure(self, mock_isfile, mock_run_script): + """Test raises error when pre-execution script fails.""" mock_isfile.return_value = True mock_run_script.return_value = (1, "Error message") manager = RuntimeEnvironmentManager() - - with pytest.raises(RuntimeEnvironmentError, match="Encountered error"): + with pytest.raises(RuntimeEnvironmentError): manager.run_pre_exec_script("/path/to/script.sh") - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.run" - ) + @patch("subprocess.run") def test_change_dir_permission_success(self, mock_run): - """Test change_dir_permission successfully""" + """Test changes directory permissions successfully.""" manager = RuntimeEnvironmentManager() - manager.change_dir_permission(["/tmp/dir1", "/tmp/dir2"], "777") - mock_run.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.run" - ) - def test_change_dir_permission_failure(self, mock_run): - """Test change_dir_permission with failure""" - mock_run.side_effect = subprocess.CalledProcessError( - 1, "chmod", stderr=b"Permission denied" - ) + @patch("subprocess.run") + def test_change_dir_permission_raises_error_on_failure(self, mock_run): + """Test raises error when permission change fails.""" + mock_run.side_effect = subprocess.CalledProcessError(1, "chmod", stderr=b"Permission denied") manager = RuntimeEnvironmentManager() + with pytest.raises(RuntimeEnvironmentError): + manager.change_dir_permission(["/tmp/dir1"], "777") + @patch("subprocess.run") + def test_change_dir_permission_raises_error_no_sudo(self, mock_run): + """Test raises error when sudo not found.""" + mock_run.side_effect = FileNotFoundError("[Errno 2] No such file or directory: 'sudo'") + manager = RuntimeEnvironmentManager() with pytest.raises(RuntimeEnvironmentError): - manager.change_dir_permission(["/tmp/dir"], "777") + manager.change_dir_permission(["/tmp/dir1"], "777") - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd" - ) + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") def test_install_requirements_txt(self, mock_run_cmd): - """Test _install_requirements_txt""" + """Test installs requirements.txt.""" manager = RuntimeEnvironmentManager() - - manager._install_requirements_txt("/path/to/requirements.txt", "/usr/bin/python3") - + manager._install_requirements_txt("/path/to/requirements.txt", "/usr/bin/python") mock_run_cmd.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd" - ) - @patch.object(RuntimeEnvironmentManager, "_get_conda_exe") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") def test_create_conda_env(self, mock_get_conda, mock_run_cmd): - """Test _create_conda_env""" + """Test creates conda environment.""" mock_get_conda.return_value = "conda" manager = RuntimeEnvironmentManager() - manager._create_conda_env("myenv", "/path/to/environment.yml") + mock_run_cmd.assert_called_once() + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") + def test_install_req_txt_in_conda_env(self, mock_get_conda, mock_run_cmd): + """Test installs requirements.txt in conda environment.""" + mock_get_conda.return_value = "conda" + manager = RuntimeEnvironmentManager() + manager._install_req_txt_in_conda_env("myenv", "/path/to/requirements.txt") mock_run_cmd.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd" - ) - @patch.object(RuntimeEnvironmentManager, "_get_conda_exe") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") def test_update_conda_env(self, mock_get_conda, mock_run_cmd): - """Test _update_conda_env""" + """Test updates conda environment.""" mock_get_conda.return_value = "conda" manager = RuntimeEnvironmentManager() - manager._update_conda_env("myenv", "/path/to/environment.yml") - mock_run_cmd.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen" - ) - def test_get_conda_exe_mamba(self, mock_popen): - """Test _get_conda_exe returns mamba""" - mock_process = Mock() + @patch("builtins.open", new_callable=mock_open) + @patch("subprocess.Popen") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") + def test_export_conda_env_from_prefix(self, mock_get_conda, mock_popen, mock_file): + """Test exports conda environment.""" + mock_get_conda.return_value = "conda" + mock_process = MagicMock() + mock_process.communicate.return_value = (b"env output", b"") mock_process.wait.return_value = 0 mock_popen.return_value = mock_process + manager = RuntimeEnvironmentManager() + manager._export_conda_env_from_prefix("/opt/conda/envs/myenv", "/tmp/env.yml") + + mock_popen.assert_called_once() + mock_file.assert_called_once_with("/tmp/env.yml", "w") + @patch("builtins.open", new_callable=mock_open) + @patch("os.getcwd") + def test_write_conda_env_to_file(self, mock_getcwd, mock_file): + """Test writes conda environment name to file.""" + mock_getcwd.return_value = "/tmp" + manager = RuntimeEnvironmentManager() + manager._write_conda_env_to_file("myenv") + mock_file.assert_called_once_with("/tmp/remote_function_conda_env.txt", "w") + mock_file().write.assert_called_once_with("myenv") + + @patch("subprocess.Popen") + def test_get_conda_exe_returns_mamba(self, mock_popen): + """Test returns mamba when available.""" + mock_popen.return_value.wait.side_effect = [0, 1] # mamba exists, conda doesn't + manager = RuntimeEnvironmentManager() result = manager._get_conda_exe() - assert result == "mamba" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen" - ) - def test_get_conda_exe_conda(self, mock_popen): - """Test _get_conda_exe returns conda""" - mock_process = Mock() - mock_process.wait.side_effect = [1, 0] # mamba not found, conda found - mock_popen.return_value = mock_process + @patch("subprocess.Popen") + def test_get_conda_exe_returns_conda(self, mock_popen): + """Test returns conda when mamba not available.""" + mock_popen.return_value.wait.side_effect = [1, 0] # mamba doesn't exist, conda does manager = RuntimeEnvironmentManager() - result = manager._get_conda_exe() - assert result == "conda" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen" - ) - def test_get_conda_exe_not_found(self, mock_popen): - """Test _get_conda_exe when neither mamba nor conda found""" - mock_process = Mock() - mock_process.wait.return_value = 1 - mock_popen.return_value = mock_process + @patch("subprocess.Popen") + def test_get_conda_exe_raises_error(self, mock_popen): + """Test raises error when neither conda nor mamba available.""" + mock_popen.return_value.wait.return_value = 1 manager = RuntimeEnvironmentManager() - - with pytest.raises(ValueError, match="Neither conda nor mamba"): + with pytest.raises(ValueError): manager._get_conda_exe() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.check_output" - ) - @patch.object(RuntimeEnvironmentManager, "_get_conda_exe") + @patch("subprocess.check_output") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") def test_python_version_in_conda_env(self, mock_get_conda, mock_check_output): - """Test _python_version_in_conda_env""" + """Test gets Python version in conda environment.""" mock_get_conda.return_value = "conda" mock_check_output.return_value = b"Python 3.8.10" manager = RuntimeEnvironmentManager() - result = manager._python_version_in_conda_env("myenv") - assert result == "3.8" - def test_current_python_version(self): - """Test _current_python_version""" + @patch("subprocess.check_output") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") + def test_python_version_in_conda_env_raises_error(self, mock_get_conda, mock_check_output): + """Test raises error when getting Python version fails.""" + mock_get_conda.return_value = "conda" + mock_check_output.side_effect = subprocess.CalledProcessError(1, "conda", output=b"Error") manager = RuntimeEnvironmentManager() + with pytest.raises(RuntimeEnvironmentError): + manager._python_version_in_conda_env("myenv") + def test_current_python_version(self): + """Test gets current Python version.""" + manager = RuntimeEnvironmentManager() result = manager._current_python_version() + expected = f"{sys.version_info.major}.{sys.version_info.minor}" + assert result == expected - assert result == f"{sys.version_info.major}.{sys.version_info.minor}" - - @patch.object(RuntimeEnvironmentManager, "_python_version_in_conda_env") - def test_validate_python_version_match(self, mock_python_version): - """Test _validate_python_version when versions match""" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._python_version_in_conda_env") + def test_validate_python_version_with_conda(self, mock_python_version): + """Test validates Python version with conda environment.""" mock_python_version.return_value = "3.8" manager = RuntimeEnvironmentManager() + # Should not raise exception + manager._validate_python_version("3.8", "myenv") - # Should not raise error - manager._validate_python_version("3.8", conda_env="myenv") - - @patch.object(RuntimeEnvironmentManager, "_python_version_in_conda_env") - def test_validate_python_version_mismatch(self, mock_python_version): - """Test _validate_python_version when versions don't match""" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._python_version_in_conda_env") + def test_validate_python_version_mismatch_with_conda(self, mock_python_version): + """Test raises error on Python version mismatch with conda.""" mock_python_version.return_value = "3.9" manager = RuntimeEnvironmentManager() + with pytest.raises(RuntimeEnvironmentError): + manager._validate_python_version("3.8", "myenv") - with pytest.raises(RuntimeEnvironmentError, match="does not match"): - manager._validate_python_version("3.8", conda_env="myenv") - - @patch.object(RuntimeEnvironmentManager, "_current_sagemaker_pysdk_version") - def test_validate_sagemaker_pysdk_version_match(self, mock_version): - """Test _validate_sagemaker_pysdk_version when versions match""" - mock_version.return_value = "2.0.0" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_python_version") + def test_validate_python_version_without_conda(self, mock_current_version): + """Test validates Python version without conda environment.""" + mock_current_version.return_value = "3.8" manager = RuntimeEnvironmentManager() + # Should not raise exception + manager._validate_python_version("3.8", None) - # Should not raise error, just log warning - manager._validate_sagemaker_pysdk_version("2.0.0") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_python_version") + def test_validate_python_version_mismatch_without_conda(self, mock_current_version): + """Test raises error on Python version mismatch without conda.""" + mock_current_version.return_value = "3.9" + manager = RuntimeEnvironmentManager() + with pytest.raises(RuntimeEnvironmentError): + manager._validate_python_version("3.8", None) - @patch.object(RuntimeEnvironmentManager, "_current_sagemaker_pysdk_version") - def test_validate_sagemaker_pysdk_version_mismatch(self, mock_version): - """Test _validate_sagemaker_pysdk_version when versions don't match""" - mock_version.return_value = "2.1.0" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") + def test_validate_sagemaker_pysdk_version_match(self, mock_current_version): + """Test validates matching SageMaker SDK version.""" + mock_current_version.return_value = "2.100.0" manager = RuntimeEnvironmentManager() + # Should not raise exception or warning + manager._validate_sagemaker_pysdk_version("2.100.0") - # Should log warning but not raise error - manager._validate_sagemaker_pysdk_version("2.0.0") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") + def test_validate_sagemaker_pysdk_version_mismatch(self, mock_current_version): + """Test logs warning on SageMaker SDK version mismatch.""" + mock_current_version.return_value = "2.101.0" + manager = RuntimeEnvironmentManager() + # Should log warning but not raise exception + manager._validate_sagemaker_pysdk_version("2.100.0") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") + def test_validate_sagemaker_pysdk_version_none(self, mock_current_version): + """Test handles None client version.""" + mock_current_version.return_value = "2.100.0" + manager = RuntimeEnvironmentManager() + # Should not raise exception + manager._validate_sagemaker_pysdk_version(None) -class TestHelperFunctions: - """Test cases for helper functions""" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.check_output" - ) - def test_run_and_get_output_shell_cmd(self, mock_check_output): - """Test _run_and_get_output_shell_cmd""" - mock_check_output.return_value = b"output" +class TestRunAndGetOutputShellCmd: + """Test _run_and_get_output_shell_cmd function.""" + @patch("subprocess.check_output") + def test_runs_command_successfully(self, mock_check_output): + """Test runs command and returns output.""" + mock_check_output.return_value = b"command output" result = _run_and_get_output_shell_cmd("echo test") + assert result == "command output" + - assert result == "output" - - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error" - ) - def test_run_pre_execution_command_script(self, mock_log_error, mock_log_output, mock_popen): - """Test _run_pre_execution_command_script""" - mock_process = Mock() +class TestRunPreExecutionCommandScript: + """Test _run_pre_execution_command_script function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") + @patch("subprocess.Popen") + @patch("os.path.dirname") + def test_runs_script_successfully(self, mock_dirname, mock_popen, mock_log_output, mock_log_error): + """Test runs script successfully.""" + mock_dirname.return_value = "/tmp" + mock_process = MagicMock() mock_process.wait.return_value = 0 mock_popen.return_value = mock_process mock_log_error.return_value = "" + + return_code, error_logs = _run_pre_execution_command_script("/tmp/script.sh") + + assert return_code == 0 + assert error_logs == "" + + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") + @patch("subprocess.Popen") + @patch("os.path.dirname") + def test_runs_script_with_error(self, mock_dirname, mock_popen, mock_log_output, mock_log_error): + """Test runs script that returns error.""" + mock_dirname.return_value = "/tmp" + mock_process = MagicMock() + mock_process.wait.return_value = 1 + mock_popen.return_value = mock_process + mock_log_error.return_value = "Error message" + + return_code, error_logs = _run_pre_execution_command_script("/tmp/script.sh") + + assert return_code == 1 + assert error_logs == "Error message" - return_code, error_logs = _run_pre_execution_command_script("/path/to/script.sh") - assert return_code == 0 +class TestRunShellCmd: + """Test _run_shell_cmd function.""" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error" - ) - def test_run_shell_cmd_success(self, mock_log_error, mock_log_output, mock_popen): - """Test _run_shell_cmd with successful command""" - mock_process = Mock() + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") + @patch("subprocess.Popen") + def test_runs_command_successfully(self, mock_popen, mock_log_output, mock_log_error): + """Test runs command successfully.""" + mock_process = MagicMock() mock_process.wait.return_value = 0 mock_popen.return_value = mock_process mock_log_error.return_value = "" @@ -500,63 +502,71 @@ def test_run_shell_cmd_success(self, mock_log_error, mock_log_output, mock_popen mock_popen.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error" - ) - def test_run_shell_cmd_failure(self, mock_log_error, mock_log_output, mock_popen): - """Test _run_shell_cmd with failed command""" - mock_process = Mock() + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") + @patch("subprocess.Popen") + def test_runs_command_raises_error_on_failure(self, mock_popen, mock_log_output, mock_log_error): + """Test raises error when command fails.""" + mock_process = MagicMock() mock_process.wait.return_value = 1 mock_popen.return_value = mock_process mock_log_error.return_value = "Error message" - - with pytest.raises(RuntimeEnvironmentError, match="Encountered error"): + + with pytest.raises(RuntimeEnvironmentError): _run_shell_cmd(["false"]) - def test_python_executable(self): - """Test _python_executable""" - result = _python_executable() - assert result == sys.executable +class TestLogOutput: + """Test _log_output function.""" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.sys.executable", - None, - ) - def test_python_executable_not_found(self): - """Test _python_executable when not found""" - with pytest.raises(RuntimeEnvironmentError, match="Failed to retrieve"): - _python_executable() + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.logger") + def test_logs_output(self, mock_logger): + """Test logs process output.""" + from io import BytesIO + mock_process = MagicMock() + mock_process.stdout = BytesIO(b"line1\nline2\n") + + _log_output(mock_process) + + assert mock_logger.info.call_count == 2 -class TestRuntimeEnvironmentError: - """Test cases for RuntimeEnvironmentError exception""" +class TestLogError: + """Test _log_error function.""" - def test_init(self): - """Test initialization""" - error = RuntimeEnvironmentError("Test error message") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.logger") + def test_logs_error(self, mock_logger): + """Test logs process errors.""" + from io import BytesIO + mock_process = MagicMock() + mock_process.stderr = BytesIO(b"ERROR: error message\nwarning message\n") + + error_logs = _log_error(mock_process) + + assert "ERROR: error message" in error_logs + assert "warning message" in error_logs - assert error.message == "Test error message" - assert str(error) == "Test error message" - def test_raise(self): - """Test raising the exception""" - with pytest.raises(RuntimeEnvironmentError, match="Test error"): - raise RuntimeEnvironmentError("Test error") +class TestPythonExecutable: + """Test _python_executable function.""" + def test_returns_python_executable(self): + """Test returns Python executable path.""" + result = _python_executable() + assert result == sys.executable -class TestGetLogger: - """Test cases for get_logger function""" + @patch("sys.executable", None) + def test_raises_error_if_no_executable(self): + """Test raises error if no Python executable.""" + with pytest.raises(RuntimeEnvironmentError): + _python_executable() - def test_get_logger(self): - """Test get_logger returns logger""" - logger = get_logger() - assert logger is not None - assert logger.name == "sagemaker.remote_function" +class TestRuntimeEnvironmentError: + """Test RuntimeEnvironmentError class.""" + + def test_creates_error_with_message(self): + """Test creates error with message.""" + error = RuntimeEnvironmentError("Test error") + assert str(error) == "Test error" + assert error.message == "Test error" diff --git a/sagemaker-train/tests/unit/train/remote_function/test_checkpoint_location.py b/sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py similarity index 97% rename from sagemaker-train/tests/unit/train/remote_function/test_checkpoint_location.py rename to sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py index 5f3ca2c78f..98a5f8bcc8 100644 --- a/sagemaker-train/tests/unit/train/remote_function/test_checkpoint_location.py +++ b/sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import pytest -from sagemaker.train.remote_function.checkpoint_location import ( +from sagemaker.core.remote_function.checkpoint_location import ( CheckpointLocation, _validate_s3_uri_for_checkpoint, _JOB_CHECKPOINT_LOCATION, diff --git a/sagemaker-train/tests/unit/train/remote_function/test_custom_file_filter.py b/sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py similarity index 95% rename from sagemaker-train/tests/unit/train/remote_function/test_custom_file_filter.py rename to sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py index 881127543f..5145a77adf 100644 --- a/sagemaker-train/tests/unit/train/remote_function/test_custom_file_filter.py +++ b/sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py @@ -19,7 +19,7 @@ from unittest.mock import patch, MagicMock import pytest -from sagemaker.train.remote_function.custom_file_filter import ( +from sagemaker.core.remote_function.custom_file_filter import ( CustomFileFilter, resolve_custom_file_filter_from_config_file, copy_workdir, @@ -69,14 +69,14 @@ def custom_filter(path, names): result = resolve_custom_file_filter_from_config_file(direct_input=custom_filter) assert result is custom_filter - @patch("sagemaker.train.remote_function.custom_file_filter.resolve_value_from_config") + @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config") def test_returns_none_when_no_config(self, mock_resolve): """Test returns None when no config is found.""" mock_resolve.return_value = None result = resolve_custom_file_filter_from_config_file() assert result is None - @patch("sagemaker.train.remote_function.custom_file_filter.resolve_value_from_config") + @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config") def test_creates_filter_from_config(self, mock_resolve): """Test creates CustomFileFilter from config.""" patterns = ["*.pyc", "*.log"] @@ -85,7 +85,7 @@ def test_creates_filter_from_config(self, mock_resolve): assert isinstance(result, CustomFileFilter) assert result.ignore_name_patterns == patterns - @patch("sagemaker.train.remote_function.custom_file_filter.resolve_value_from_config") + @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config") def test_passes_sagemaker_session_to_resolve(self, mock_resolve): """Test passes sagemaker_session to resolve_value_from_config.""" mock_session = MagicMock() diff --git a/sagemaker-train/tests/unit/train/remote_function/test_invoke_function.py b/sagemaker-core/tests/unit/remote_function/test_invoke_function.py similarity index 87% rename from sagemaker-train/tests/unit/train/remote_function/test_invoke_function.py rename to sagemaker-core/tests/unit/remote_function/test_invoke_function.py index 6beafc3d27..4810eba2e0 100644 --- a/sagemaker-train/tests/unit/train/remote_function/test_invoke_function.py +++ b/sagemaker-core/tests/unit/remote_function/test_invoke_function.py @@ -17,7 +17,7 @@ import pytest from unittest.mock import patch, MagicMock, call -from sagemaker.train.remote_function.invoke_function import ( +from sagemaker.core.remote_function.invoke_function import ( _parse_args, _get_sagemaker_session, _load_run_object, @@ -26,7 +26,7 @@ main, SUCCESS_EXIT_CODE, ) -from sagemaker.train.remote_function.job import KEY_EXPERIMENT_NAME, KEY_RUN_NAME +from sagemaker.core.remote_function.job import KEY_EXPERIMENT_NAME, KEY_RUN_NAME class TestParseArgs: @@ -95,8 +95,8 @@ def test_parse_default_values(self): class TestGetSagemakerSession: """Test _get_sagemaker_session function.""" - @patch("sagemaker.train.remote_function.invoke_function.boto3.session.Session") - @patch("sagemaker.train.remote_function.invoke_function.Session") + @patch("sagemaker.core.remote_function.invoke_function.boto3.session.Session") + @patch("sagemaker.core.remote_function.invoke_function.Session") def test_creates_session_with_region(self, mock_session_class, mock_boto_session): """Test creates SageMaker session with correct region.""" mock_boto = MagicMock() @@ -180,7 +180,6 @@ def test_executes_without_run_context(self, mock_stored_function_class): s3_base_uri="s3://bucket/path", s3_kms_key="key-123", run_in_context=None, - hmac_key="hmac-key", context=mock_context, ) @@ -188,12 +187,11 @@ def test_executes_without_run_context(self, mock_stored_function_class): sagemaker_session=mock_session, s3_base_uri="s3://bucket/path", s3_kms_key="key-123", - hmac_key="hmac-key", context=mock_context, ) mock_stored_func.load_and_invoke.assert_called_once() - @patch("sagemaker.train.remote_function.invoke_function._load_run_object") + @patch("sagemaker.core.remote_function.invoke_function._load_run_object") @patch("sagemaker.core.remote_function.core.stored_function.StoredFunction") def test_executes_with_run_context(self, mock_stored_function_class, mock_load_run): """Test executes stored function with run context.""" @@ -210,7 +208,6 @@ def test_executes_with_run_context(self, mock_stored_function_class, mock_load_r s3_base_uri="s3://bucket/path", s3_kms_key=None, run_in_context=run_json, - hmac_key="hmac-key", context=mock_context, ) @@ -223,11 +220,10 @@ def test_executes_with_run_context(self, mock_stored_function_class, mock_load_r class TestMain: """Test main function.""" - @patch("sagemaker.train.remote_function.invoke_function._execute_remote_function") - @patch("sagemaker.train.remote_function.invoke_function._get_sagemaker_session") - @patch("sagemaker.train.remote_function.invoke_function._load_pipeline_context") - @patch("sagemaker.train.remote_function.invoke_function._parse_args") - @patch.dict("os.environ", {"REMOTE_FUNCTION_SECRET_KEY": "test-key"}) + @patch("sagemaker.core.remote_function.invoke_function._execute_remote_function") + @patch("sagemaker.core.remote_function.invoke_function._get_sagemaker_session") + @patch("sagemaker.core.remote_function.invoke_function._load_pipeline_context") + @patch("sagemaker.core.remote_function.invoke_function._parse_args") def test_main_success(self, mock_parse, mock_load_context, mock_get_session, mock_execute): """Test main function successful execution.""" mock_args = MagicMock() @@ -250,12 +246,11 @@ def test_main_success(self, mock_parse, mock_load_context, mock_get_session, moc assert exc_info.value.code == SUCCESS_EXIT_CODE mock_execute.assert_called_once() - @patch("sagemaker.train.remote_function.invoke_function.handle_error") - @patch("sagemaker.train.remote_function.invoke_function._execute_remote_function") - @patch("sagemaker.train.remote_function.invoke_function._get_sagemaker_session") - @patch("sagemaker.train.remote_function.invoke_function._load_pipeline_context") - @patch("sagemaker.train.remote_function.invoke_function._parse_args") - @patch.dict("os.environ", {"REMOTE_FUNCTION_SECRET_KEY": "test-key"}) + @patch("sagemaker.core.remote_function.invoke_function.handle_error") + @patch("sagemaker.core.remote_function.invoke_function._execute_remote_function") + @patch("sagemaker.core.remote_function.invoke_function._get_sagemaker_session") + @patch("sagemaker.core.remote_function.invoke_function._load_pipeline_context") + @patch("sagemaker.core.remote_function.invoke_function._parse_args") def test_main_handles_exception( self, mock_parse, mock_load_context, mock_get_session, mock_execute, mock_handle_error ): diff --git a/sagemaker-core/tests/unit/remote_function/test_job.py b/sagemaker-core/tests/unit/remote_function/test_job.py index abc5be68be..6f10016643 100644 --- a/sagemaker-core/tests/unit/remote_function/test_job.py +++ b/sagemaker-core/tests/unit/remote_function/test_job.py @@ -143,26 +143,23 @@ class TestJob: def test_init(self, mock_session): """Test _Job initialization.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) assert job.job_name == "test-job" assert job.s3_uri == "s3://bucket/output" - assert job.hmac_key == "test-key" def test_from_describe_response(self, mock_session): """Test creating _Job from describe response.""" response = { "TrainingJobName": "test-job", "OutputDataConfig": {"S3OutputPath": "s3://bucket/output"}, - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "test-key"}, } job = _Job.from_describe_response(response, mock_session) assert job.job_name == "test-job" assert job.s3_uri == "s3://bucket/output" - assert job.hmac_key == "test-key" def test_describe_returns_cached_response(self, mock_session): """Test that describe returns cached response for completed jobs.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) job._last_describe_response = {"TrainingJobStatus": "Completed"} result = job.describe() @@ -171,7 +168,7 @@ def test_describe_returns_cached_response(self, mock_session): def test_describe_calls_api_for_in_progress_jobs(self, mock_session): """Test that describe calls API for in-progress jobs.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) mock_session.sagemaker_client.describe_training_job.return_value = { "TrainingJobStatus": "InProgress" } @@ -182,7 +179,7 @@ def test_describe_calls_api_for_in_progress_jobs(self, mock_session): def test_stop(self, mock_session): """Test stopping a job.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) job.stop() mock_session.sagemaker_client.stop_training_job.assert_called_once_with( TrainingJobName="test-job" @@ -191,7 +188,7 @@ def test_stop(self, mock_session): @patch("sagemaker.core.remote_function.job._logs_for_job") def test_wait(self, mock_logs, mock_session): """Test waiting for job completion.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) mock_logs.return_value = {"TrainingJobStatus": "Completed"} job.wait(timeout=100) @@ -882,7 +879,7 @@ def test_start(self, mock_get_name, mock_compile, mock_session): mock_get_name.return_value = "test-job" mock_compile.return_value = { "TrainingJobName": "test-job", - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "test-key"}, + "Environment": {}, } job_settings = Mock() diff --git a/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py b/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py index 4069029685..bc8d5a8e56 100644 --- a/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py +++ b/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py @@ -144,17 +144,15 @@ def test_from_describe_response(self, mock_session): response = { "TrainingJobName": "test-job", "OutputDataConfig": {"S3OutputPath": "s3://bucket/output"}, - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "test-key"}, } job = _Job.from_describe_response(response, mock_session) assert job.job_name == "test-job" assert job.s3_uri == "s3://bucket/output" - assert job.hmac_key == "test-key" assert job._last_describe_response == response def test_describe_cached_completed(self, mock_session): """Test lines 865-871: describe with cached completed job.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) job._last_describe_response = {"TrainingJobStatus": "Completed"} result = job.describe() @@ -163,7 +161,7 @@ def test_describe_cached_completed(self, mock_session): def test_describe_cached_failed(self, mock_session): """Test lines 865-871: describe with cached failed job.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) job._last_describe_response = {"TrainingJobStatus": "Failed"} result = job.describe() @@ -172,7 +170,7 @@ def test_describe_cached_failed(self, mock_session): def test_describe_cached_stopped(self, mock_session): """Test lines 865-871: describe with cached stopped job.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) job._last_describe_response = {"TrainingJobStatus": "Stopped"} result = job.describe() @@ -181,7 +179,7 @@ def test_describe_cached_stopped(self, mock_session): def test_stop(self, mock_session): """Test lines 886-887: stop method.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) job.stop() mock_session.sagemaker_client.stop_training_job.assert_called_once_with( TrainingJobName="test-job" @@ -190,7 +188,7 @@ def test_stop(self, mock_session): @patch("sagemaker.core.remote_function.job._logs_for_job") def test_wait(self, mock_logs, mock_session): """Test lines 889-903: wait method.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) mock_logs.return_value = {"TrainingJobStatus": "Completed"} job.wait(timeout=100) diff --git a/sagemaker-train/tests/unit/train/remote_function/test_logging_config.py b/sagemaker-core/tests/unit/remote_function/test_logging_config.py similarity index 97% rename from sagemaker-train/tests/unit/train/remote_function/test_logging_config.py rename to sagemaker-core/tests/unit/remote_function/test_logging_config.py index 7812c311eb..6454ea1071 100644 --- a/sagemaker-train/tests/unit/train/remote_function/test_logging_config.py +++ b/sagemaker-core/tests/unit/remote_function/test_logging_config.py @@ -16,7 +16,7 @@ import logging import time from unittest.mock import patch -from sagemaker.train.remote_function.logging_config import _UTCFormatter, get_logger +from sagemaker.core.remote_function.logging_config import _UTCFormatter, get_logger class TestUTCFormatter: diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py index 30fbba3639..9b4b9a191b 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py @@ -1036,7 +1036,6 @@ def get_function_step_result( return deserialize_obj_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_uri, - hmac_key=describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"], ) raise RemoteFunctionError(_ERROR_MSG_OF_STEP_INCOMPLETE) diff --git a/sagemaker-mlops/tests/unit/workflow/test_pipeline.py b/sagemaker-mlops/tests/unit/workflow/test_pipeline.py index 55922a66ca..9169f1ce7f 100644 --- a/sagemaker-mlops/tests/unit/workflow/test_pipeline.py +++ b/sagemaker-mlops/tests/unit/workflow/test_pipeline.py @@ -360,7 +360,6 @@ def test_get_function_step_result_incomplete_job(mock_session): "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT}, "OutputDataConfig": {"S3OutputPath": "s3://bucket/path"}, "TrainingJobStatus": "Failed", - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "key"} } with pytest.raises(RemoteFunctionError, match="not in Completed status"): @@ -376,7 +375,6 @@ def test_get_function_step_result_success(mock_session): "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT}, "OutputDataConfig": {"S3OutputPath": "s3://bucket/path/exec-id/step1/results"}, "TrainingJobStatus": "Completed", - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "key"} } with patch("sagemaker.mlops.workflow.pipeline.deserialize_obj_from_s3", return_value="result"): @@ -443,7 +441,6 @@ def test_pipeline_execution_result_terminal_failure(mock_session): "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT}, "OutputDataConfig": {"S3OutputPath": "s3://bucket/path/exec-id/step1/results"}, "TrainingJobStatus": "Completed", - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "key"} } with patch.object(execution, "wait", side_effect=WaiterError("name", "Waiter encountered a terminal failure state", {})): @@ -461,7 +458,6 @@ def test_get_function_step_result_obsolete_s3_path(mock_session): "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT}, "OutputDataConfig": {"S3OutputPath": "s3://bucket/different/path"}, "TrainingJobStatus": "Completed", - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "key"} } with patch("sagemaker.mlops.workflow.pipeline.deserialize_obj_from_s3", return_value="result"): diff --git a/sagemaker-train/src/sagemaker/train/remote_function/__init__.py b/sagemaker-train/src/sagemaker/train/remote_function/__init__.py deleted file mode 100644 index bf29079921..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function - -This is a backward compatibility shim. Please update your imports to: - from sagemaker.core.remote_function import ... -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.client import remote, RemoteExecutor # noqa: F401 -from sagemaker.core.remote_function.checkpoint_location import CheckpointLocation # noqa: F401 -from sagemaker.core.remote_function.custom_file_filter import CustomFileFilter # noqa: F401 -from sagemaker.core.remote_function.spark_config import SparkConfig # noqa: F401 - -warnings.warn( - "sagemaker.train.remote_function has been moved to sagemaker.core.remote_function. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/checkpoint_location.py b/sagemaker-train/src/sagemaker/train/remote_function/checkpoint_location.py deleted file mode 100644 index 4153fe03d3..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/checkpoint_location.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module is used to define the CheckpointLocation to remote function.""" -from __future__ import absolute_import - -from os import PathLike -import re - -# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CheckpointConfig.html -S3_URI_REGEX_PATTERN = r"^(https|s3)://([^/]+)/?(.*)$" - -_JOB_CHECKPOINT_LOCATION = "/opt/ml/checkpoints/" - - -def _validate_s3_uri_for_checkpoint(s3_uri: str): - """Validate if checkpoint location is specified with a valid s3 URI.""" - return re.match(S3_URI_REGEX_PATTERN, s3_uri) - - -class CheckpointLocation(PathLike): - """Class to represent the location where checkpoints are accessed in a remote function. - - To save or load checkpoints in a remote function, pass an CheckpointLocation object as a - function parameter and use it as a os.PathLike object. This CheckpointLocation object - represents the local directory (/opt/ml/checkpoints/) of checkpoints in side the job. - """ - - _local_path = _JOB_CHECKPOINT_LOCATION - - def __init__(self, s3_uri): - if not _validate_s3_uri_for_checkpoint(s3_uri): - raise ValueError("CheckpointLocation should be specified with valid s3 URI.") - self._s3_uri = s3_uri - - def __fspath__(self): - """Return job local path where checkpoints are stored.""" - return self._local_path diff --git a/sagemaker-train/src/sagemaker/train/remote_function/client.py b/sagemaker-train/src/sagemaker/train/remote_function/client.py deleted file mode 100644 index eb99d14c1e..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/client.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.client - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.client import * # noqa: F401, F403 - -warnings.warn( - "sagemaker.train.remote_function.client has been moved to sagemaker.core.remote_function.client. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/__init__.py b/sagemaker-train/src/sagemaker/train/remote_function/core/__init__.py deleted file mode 100644 index 7e9f2d30da..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.core - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -warnings.warn( - "sagemaker.train.remote_function.core has been moved to sagemaker.core.remote_function.core. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py b/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py deleted file mode 100644 index 20b7a297b5..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py +++ /dev/null @@ -1,56 +0,0 @@ - -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""SageMaker remote function data serializer/deserializer.""" -from __future__ import absolute_import - -from sagemaker.train.remote_function.errors import SerializationError - -from sagemaker.core.helper.pipeline_variable import PipelineVariable -from sagemaker.core.workflow.parameters import ( - ParameterInteger, - ParameterFloat, - ParameterString, - ParameterBoolean, -) -from sagemaker.core.workflow.execution_variables import ExecutionVariable -from sagemaker.mlops.workflow.function_step import DelayedReturn -from sagemaker.core.workflow.properties import ( - Properties, - PropertiesMap, - PropertiesList, -) - - -def _pipeline_variable_reducer(pipeline_variable): - """Reducer for pipeline variable.""" - - raise SerializationError( - """Please pass the pipeline variable to the function decorated with @step as an argument. - Referencing to a pipeline variable from within the function - or passing a pipeline variable nested in a data structure are not supported.""" - ) - - -dispatch_table = { - ParameterInteger: _pipeline_variable_reducer, - ParameterFloat: _pipeline_variable_reducer, - ParameterString: _pipeline_variable_reducer, - ParameterBoolean: _pipeline_variable_reducer, - ExecutionVariable: _pipeline_variable_reducer, - PipelineVariable: _pipeline_variable_reducer, - Properties: _pipeline_variable_reducer, - PropertiesMap: _pipeline_variable_reducer, - PropertiesList: _pipeline_variable_reducer, - DelayedReturn: _pipeline_variable_reducer, -} diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py b/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py deleted file mode 100644 index 5767a07596..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.core.pipeline_variables - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.core.pipeline_variables import * # noqa: F401, F403 - -warnings.warn( - "sagemaker.train.remote_function.core.pipeline_variables has been moved to sagemaker.core.remote_function.core.pipeline_variables. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py b/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py deleted file mode 100644 index d30d1494d5..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.core.serialization - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.core.serialization import * # noqa: F401, F403 - -warnings.warn( - "sagemaker.train.remote_function.core.serialization has been moved to sagemaker.core.remote_function.core.serialization. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py b/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py deleted file mode 100644 index 34915a4d42..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.core.stored_function - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.core.stored_function import * # noqa: F401, F403 - -warnings.warn( - "sagemaker.train.remote_function.core.stored_function has been moved to sagemaker.core.remote_function.core.stored_function. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/custom_file_filter.py b/sagemaker-train/src/sagemaker/train/remote_function/custom_file_filter.py deleted file mode 100644 index 9c1b1e1baa..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/custom_file_filter.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""SageMaker remote function client.""" -from __future__ import absolute_import - -import fnmatch -import os -import shutil -from typing import List, Optional, Callable, Union - -from sagemaker.core.common_utils import resolve_value_from_config -from sagemaker.core.config.config_schema import REMOTE_FUNCTION_PATH, CUSTOM_FILE_FILTER - - -class CustomFileFilter: - """Configuration that specifies how the local working directory should be packaged.""" - - def __init__(self, *, ignore_name_patterns: List[str] = None): - """Initialize a CustomFileFilter. - - Args: - ignore_name_patterns (List[str]): ignore files or directories with names - that match one of the glob-style patterns. Defaults to None. - """ - - if ignore_name_patterns is None: - ignore_name_patterns = [] - - self._workdir = os.getcwd() - self._ignore_name_patterns = ignore_name_patterns - - @property - def ignore_name_patterns(self): - """Get the ignore name patterns.""" - return self._ignore_name_patterns - - @property - def workdir(self): - """Get the working directory.""" - return self._workdir - - -def resolve_custom_file_filter_from_config_file( - direct_input: Union[Callable[[str, List], List], CustomFileFilter] = None, - sagemaker_session=None, -) -> Union[Callable[[str, List], List], CustomFileFilter, None]: - """Resolve the CustomFileFilter configuration from the config file. - - Args: - direct_input (Callable[[str, List], List], CustomFileFilter): direct input from the user. - sagemaker_session (sagemaker.core.helper.session.Session): sagemaker session. - Returns: - CustomFileFilter: configuration that specifies how the local - working directory should be packaged. - """ - if direct_input is not None: - return direct_input - ignore_name_patterns = resolve_value_from_config( - direct_input=None, - config_path=".".join([REMOTE_FUNCTION_PATH, CUSTOM_FILE_FILTER, "IgnoreNamePatterns"]), - default_value=None, - sagemaker_session=sagemaker_session, - ) - if ignore_name_patterns is not None: - return CustomFileFilter(ignore_name_patterns=ignore_name_patterns) - return None - - -def copy_workdir( - dst: str, - custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, -): - """Copy the local working directory to the destination. - - Args: - dst (str): destination path. - custom_file_filter (Union[Callable[[str, List], List], CustomFileFilter): configuration that - specifies how the local working directory should be packaged. - """ - - def _ignore_patterns(path: str, names: List): # pylint: disable=unused-argument - ignored_names = set() - if custom_file_filter.ignore_name_patterns is not None: - for pattern in custom_file_filter.ignore_name_patterns: - ignored_names.update(fnmatch.filter(names, pattern)) - return ignored_names - - def _filter_non_python_files(path: str, names: List) -> List: - """Ignore function for filtering out non python files.""" - to_ignore = [] - for name in names: - full_path = os.path.join(path, name) - if os.path.isfile(full_path): - if not name.endswith(".py"): - to_ignore.append(name) - elif os.path.isdir(full_path): - if name == "__pycache__": - to_ignore.append(name) - else: - to_ignore.append(name) - - return to_ignore - - _ignore = None - _src = os.getcwd() - if not custom_file_filter: - _ignore = _filter_non_python_files - elif callable(custom_file_filter): - _ignore = custom_file_filter - elif isinstance(custom_file_filter, CustomFileFilter): - _ignore = _ignore_patterns - _src = custom_file_filter.workdir - - shutil.copytree( - _src, - dst, - ignore=_ignore, - ) \ No newline at end of file diff --git a/sagemaker-train/src/sagemaker/train/remote_function/errors.py b/sagemaker-train/src/sagemaker/train/remote_function/errors.py deleted file mode 100644 index e67fcf7d9f..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/errors.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.errors - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.errors import * # noqa: F401, F403 - -warnings.warn( - "sagemaker.train.remote_function.errors has been moved to sagemaker.core.remote_function.errors. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py b/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py deleted file mode 100644 index 3bafeffd5b..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""An entry point for invoking remote function inside a job.""" - -from __future__ import absolute_import - -import argparse -import sys -import json -import os -from typing import TYPE_CHECKING - -import boto3 -from sagemaker.train.remote_function.job import ( - KEY_EXPERIMENT_NAME, - KEY_RUN_NAME, -) - -from sagemaker.core.helper.session_helper import Session -from sagemaker.core.s3 import s3_path_join -from sagemaker.train.remote_function.errors import handle_error -from sagemaker.train.remote_function import logging_config -from sagemaker.train.remote_function.core.pipeline_variables import Context - -if TYPE_CHECKING: - from sagemaker.core.experiments.run import Run - - -SUCCESS_EXIT_CODE = 0 - - -def _parse_args(args): - """Parses CLI arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument("--region", type=str, required=True) - parser.add_argument("--s3_base_uri", type=str, required=True) - parser.add_argument("--s3_kms_key", type=str) - parser.add_argument("--run_in_context", type=str) - parser.add_argument("--pipeline_step_name", type=str) - parser.add_argument("--pipeline_execution_id", type=str) - parser.add_argument("--property_references", nargs="+", type=str, default=[]) - parser.add_argument( - "--serialize_output_to_json", default=False, type=lambda x: (str(x).lower() == "true") - ) - parser.add_argument("--func_step_s3_dir", type=str) - - args, _ = parser.parse_known_args(args) - return args - - -def _get_sagemaker_session(region): - """Get sagemaker session for interacting with AWS or Sagemaker services""" - boto_session = boto3.session.Session(region_name=region) - return Session(boto_session=boto_session) - - -def _load_run_object(run_in_context: str, sagemaker_session: Session) -> "Run": - """Load current run in json string into run object""" - from sagemaker.core.experiments.run import Run - - run_dict = json.loads(run_in_context) - return Run( - experiment_name=run_dict.get(KEY_EXPERIMENT_NAME), - run_name=run_dict.get(KEY_RUN_NAME), - sagemaker_session=sagemaker_session, - ) - - -def _load_pipeline_context(args) -> Context: - """Load pipeline build or run context into context object""" - - pipeline_step_name = args.pipeline_step_name - pipeline_execution_id = args.pipeline_execution_id - property_references = args.property_references - serialize_output_to_json = args.serialize_output_to_json - func_step_s3_dir = args.func_step_s3_dir - - property_references_dict = {} - for i in range(0, len(property_references), 2): - property_references_dict[property_references[i]] = property_references[i + 1] - return Context( - step_name=pipeline_step_name, - execution_id=pipeline_execution_id, - property_references=property_references_dict, - serialize_output_to_json=serialize_output_to_json, - func_step_s3_dir=func_step_s3_dir, - ) - - -def _execute_remote_function( - sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, hmac_key, context -): - """Execute stored remote function""" - from sagemaker.train.remote_function.core.stored_function import StoredFunction - - stored_function = StoredFunction( - sagemaker_session=sagemaker_session, - s3_base_uri=s3_base_uri, - s3_kms_key=s3_kms_key, - hmac_key=hmac_key, - context=context, - ) - - if run_in_context: - run_obj = _load_run_object(run_in_context, sagemaker_session) - with run_obj: - stored_function.load_and_invoke() - else: - stored_function.load_and_invoke() - - -def main(sys_args=None): - """Entry point for invoke function script - - Args: - sys_args (list): List of arguments to parse. If not specified, sys.argv is used. - """ - - logger = logging_config.get_logger() - - exit_code = SUCCESS_EXIT_CODE - - try: - args = _parse_args(sys_args) - region = args.region - s3_base_uri = args.s3_base_uri - s3_kms_key = args.s3_kms_key - run_in_context = args.run_in_context - pipeline_context = _load_pipeline_context(args) - - hmac_key = os.getenv("REMOTE_FUNCTION_SECRET_KEY") - - sagemaker_session = _get_sagemaker_session(region) - _execute_remote_function( - sagemaker_session=sagemaker_session, - s3_base_uri=s3_base_uri, - s3_kms_key=s3_kms_key, - run_in_context=run_in_context, - hmac_key=hmac_key, - context=pipeline_context, - ) - - except Exception as e: # pylint: disable=broad-except - logger.exception("Error encountered while invoking the remote function.") - s3_uri = ( - s3_path_join(s3_base_uri, pipeline_context.execution_id, pipeline_context.step_name) - if pipeline_context.step_name - else s3_base_uri - ) - exit_code = handle_error( - error=e, - sagemaker_session=sagemaker_session, - s3_base_uri=s3_uri, - s3_kms_key=s3_kms_key, - hmac_key=hmac_key, - ) - finally: - sys.exit(exit_code) - - -if __name__ == "__main__": - main(sys.argv[1:]) \ No newline at end of file diff --git a/sagemaker-train/src/sagemaker/train/remote_function/job.py b/sagemaker-train/src/sagemaker/train/remote_function/job.py deleted file mode 100644 index 33bf62af86..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/job.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.job - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.job import * # noqa: F401, F403 - -warnings.warn( - "sagemaker.train.remote_function.job has been moved to sagemaker.core.remote_function.job. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/logging_config.py b/sagemaker-train/src/sagemaker/train/remote_function/logging_config.py deleted file mode 100644 index 875fabf6e0..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/logging_config.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Utilities related to logging.""" -from __future__ import absolute_import - -import logging -import time - - -class _UTCFormatter(logging.Formatter): - """Class that overrides the default local time provider in log formatter.""" - - converter = time.gmtime - - -def get_logger(): - """Return a logger with the name 'sagemaker'""" - sagemaker_logger = logging.getLogger("sagemaker.remote_function") - if len(sagemaker_logger.handlers) == 0: - sagemaker_logger.setLevel(logging.INFO) - handler = logging.StreamHandler() - formatter = _UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s") - handler.setFormatter(formatter) - sagemaker_logger.addHandler(handler) - # don't stream logs with the root logger handler - sagemaker_logger.propagate = 0 - - return sagemaker_logger diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/__init__.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/__init__.py deleted file mode 100644 index 18557a2eb5..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container_drivers directory.""" -from __future__ import absolute_import diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py deleted file mode 100644 index afe0f80012..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ /dev/null @@ -1,602 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""An entry point for runtime environment. This must be kept independent of SageMaker PySDK""" -from __future__ import absolute_import - -import argparse -import getpass -import json -import multiprocessing -import os -import pathlib -import shutil -import subprocess -import sys -from typing import Any, Dict - -if __package__ is None or __package__ == "": - from runtime_environment_manager import ( - RuntimeEnvironmentManager, - _DependencySettings, - get_logger, - ) -else: - from sagemaker.train.remote_function.runtime_environment.runtime_environment_manager import ( - RuntimeEnvironmentManager, - _DependencySettings, - get_logger, - ) - -SUCCESS_EXIT_CODE = 0 -DEFAULT_FAILURE_CODE = 1 - -REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" -BASE_CHANNEL_PATH = "/opt/ml/input/data" -FAILURE_REASON_PATH = "/opt/ml/output/failure" -JOB_OUTPUT_DIRS = ["/opt/ml/input", "/opt/ml/output", "/opt/ml/model", "/tmp"] -PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh" -JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace" -SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies" - -SM_MODEL_DIR = "/opt/ml/model" - -SM_INPUT_DIR = "/opt/ml/input" -SM_INPUT_DATA_DIR = "/opt/ml/input/data" -SM_INPUT_CONFIG_DIR = "/opt/ml/input/config" - -SM_OUTPUT_DIR = "/opt/ml/output" -SM_OUTPUT_FAILURE = "/opt/ml/output/failure" -SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" - -SM_MASTER_ADDR = "algo-1" -SM_MASTER_PORT = 7777 - -RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json" -ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env" - -SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"] -HIDDEN_VALUE = "******" - -SM_EFA_NCCL_INSTANCES = [ - "ml.g4dn.8xlarge", - "ml.g4dn.12xlarge", - "ml.g5.48xlarge", - "ml.p3dn.24xlarge", - "ml.p4d.24xlarge", - "ml.p4de.24xlarge", - "ml.p5.48xlarge", - "ml.trn1.32xlarge", -] - -SM_EFA_RDMA_INSTANCES = [ - "ml.p4d.24xlarge", - "ml.p4de.24xlarge", - "ml.trn1.32xlarge", -] - -logger = get_logger() - - -def _bootstrap_runtime_env_for_remote_function( - client_python_version: str, - conda_env: str = None, - dependency_settings: _DependencySettings = None, -): - """Bootstrap runtime environment for remote function invocation. - - Args: - client_python_version (str): Python version at the client side. - conda_env (str): conda environment to be activated. Default is None. - dependency_settings (dict): Settings for installing dependencies. - """ - - workspace_unpack_dir = _unpack_user_workspace() - if not workspace_unpack_dir: - logger.info("No workspace to unpack and setup.") - return - - _handle_pre_exec_scripts(workspace_unpack_dir) - - _install_dependencies( - workspace_unpack_dir, - conda_env, - client_python_version, - REMOTE_FUNCTION_WORKSPACE, - dependency_settings, - ) - - -def _bootstrap_runtime_env_for_pipeline_step( - client_python_version: str, - func_step_workspace: str, - conda_env: str = None, - dependency_settings: _DependencySettings = None, -): - """Bootstrap runtime environment for pipeline step invocation. - - Args: - client_python_version (str): Python version at the client side. - func_step_workspace (str): s3 folder where workspace for FunctionStep is stored - conda_env (str): conda environment to be activated. Default is None. - dependency_settings (dict): Name of the dependency file. Default is None. - """ - - workspace_dir = _unpack_user_workspace(func_step_workspace) - if not workspace_dir: - os.mkdir(JOB_REMOTE_FUNCTION_WORKSPACE) - workspace_dir = pathlib.Path(os.getcwd(), JOB_REMOTE_FUNCTION_WORKSPACE).absolute() - - pre_exec_script_and_dependencies_dir = os.path.join( - BASE_CHANNEL_PATH, SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME - ) - - if not os.path.exists(pre_exec_script_and_dependencies_dir): - logger.info("No dependencies to bootstrap") - return - for file in os.listdir(pre_exec_script_and_dependencies_dir): - src_path = os.path.join(pre_exec_script_and_dependencies_dir, file) - dest_path = os.path.join(workspace_dir, file) - shutil.copy(src_path, dest_path) - - _handle_pre_exec_scripts(workspace_dir) - - _install_dependencies( - workspace_dir, - conda_env, - client_python_version, - SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME, - dependency_settings, - ) - - -def _handle_pre_exec_scripts(script_file_dir: str): - """Run the pre execution scripts. - - Args: - script_file_dir (str): Directory in the container where pre-execution scripts exists. - """ - - path_to_pre_exec_script = os.path.join(script_file_dir, PRE_EXECUTION_SCRIPT_NAME) - RuntimeEnvironmentManager().run_pre_exec_script(pre_exec_script_path=path_to_pre_exec_script) - - -def _install_dependencies( - dependency_file_dir: str, - conda_env: str, - client_python_version: str, - channel_name: str, - dependency_settings: _DependencySettings = None, -): - """Install dependencies in the job container - - Args: - dependency_file_dir (str): Directory in the container where dependency file exists. - conda_env (str): conda environment to be activated. - client_python_version (str): Python version at the client side. - channel_name (str): Channel where dependency file was uploaded. - dependency_settings (dict): Settings for installing dependencies. - """ - - if dependency_settings is not None and dependency_settings.dependency_file is None: - # an empty dict is passed when no dependencies are specified - logger.info("No dependencies to install.") - elif dependency_settings is not None: - dependencies_file = os.path.join(dependency_file_dir, dependency_settings.dependency_file) - RuntimeEnvironmentManager().bootstrap( - local_dependencies_file=dependencies_file, - conda_env=conda_env, - client_python_version=client_python_version, - ) - else: - # no dependency file name is passed when an legacy version of the SDK is used - # we look for a file with .txt, .yml or .yaml extension in the workspace directory - dependencies_file = None - for file in os.listdir(dependency_file_dir): - if file.endswith(".txt") or file.endswith(".yml") or file.endswith(".yaml"): - dependencies_file = os.path.join(dependency_file_dir, file) - break - - if dependencies_file: - RuntimeEnvironmentManager().bootstrap( - local_dependencies_file=dependencies_file, - conda_env=conda_env, - client_python_version=client_python_version, - ) - else: - logger.info( - "Did not find any dependency file in the directory at '%s'." - " Assuming no additional dependencies to install.", - os.path.join(BASE_CHANNEL_PATH, channel_name), - ) - - -def _unpack_user_workspace(func_step_workspace: str = None): - """Unzip the user workspace""" - - workspace_archive_dir_path = ( - os.path.join(BASE_CHANNEL_PATH, REMOTE_FUNCTION_WORKSPACE) - if not func_step_workspace - else os.path.join(BASE_CHANNEL_PATH, func_step_workspace) - ) - if not os.path.exists(workspace_archive_dir_path): - logger.info( - "Directory '%s' does not exist.", - workspace_archive_dir_path, - ) - return None - - workspace_archive_path = os.path.join(workspace_archive_dir_path, "workspace.zip") - if not os.path.isfile(workspace_archive_path): - logger.info( - "Workspace archive '%s' does not exist.", - workspace_archive_dir_path, - ) - return None - - workspace_unpack_dir = pathlib.Path(os.getcwd()).absolute() - shutil.unpack_archive(filename=workspace_archive_path, extract_dir=workspace_unpack_dir) - logger.info("Successfully unpacked workspace archive at '%s'.", workspace_unpack_dir) - workspace_unpack_dir = pathlib.Path(workspace_unpack_dir, JOB_REMOTE_FUNCTION_WORKSPACE) - return workspace_unpack_dir - - -def _write_failure_reason_file(failure_msg): - """Create a file 'failure' with failure reason written if bootstrap runtime env failed. - - See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html - Args: - failure_msg: The content of file to be written. - """ - if not os.path.exists(FAILURE_REASON_PATH): - with open(FAILURE_REASON_PATH, "w") as f: - f.write("RuntimeEnvironmentError: " + failure_msg) - - -def _parse_args(sys_args): - """Parses CLI arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument("--job_conda_env", type=str) - parser.add_argument("--client_python_version", type=str) - parser.add_argument("--client_sagemaker_pysdk_version", type=str, default=None) - parser.add_argument("--pipeline_execution_id", type=str) - parser.add_argument("--dependency_settings", type=str) - parser.add_argument("--func_step_s3_dir", type=str) - parser.add_argument("--distribution", type=str, default=None) - parser.add_argument("--user_nproc_per_node", type=str, default=None) - args, _ = parser.parse_known_args(sys_args) - return args - - -def log_key_value(key: str, value: str): - """Log a key-value pair, masking sensitive values if necessary.""" - if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS): - logger.info("%s=%s", key, HIDDEN_VALUE) - elif isinstance(value, dict): - masked_value = mask_sensitive_info(value) - logger.info("%s=%s", key, json.dumps(masked_value)) - else: - try: - decoded_value = json.loads(value) - if isinstance(decoded_value, dict): - masked_value = mask_sensitive_info(decoded_value) - logger.info("%s=%s", key, json.dumps(masked_value)) - else: - logger.info("%s=%s", key, decoded_value) - except (json.JSONDecodeError, TypeError): - logger.info("%s=%s", key, value) - - -def log_env_variables(env_vars_dict: Dict[str, Any]): - """Log Environment Variables from the environment and an env_vars_dict.""" - for key, value in os.environ.items(): - log_key_value(key, value) - - for key, value in env_vars_dict.items(): - log_key_value(key, value) - - -def mask_sensitive_info(data): - """Recursively mask sensitive information in a dictionary.""" - if isinstance(data, dict): - for k, v in data.items(): - if isinstance(v, dict): - data[k] = mask_sensitive_info(v) - elif isinstance(v, str) and any( - keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS - ): - data[k] = HIDDEN_VALUE - return data - - -def num_cpus() -> int: - """Return the number of CPUs available in the current container. - - Returns: - int: Number of CPUs available in the current container. - """ - return multiprocessing.cpu_count() - - -def num_gpus() -> int: - """Return the number of GPUs available in the current container. - - Returns: - int: Number of GPUs available in the current container. - """ - try: - cmd = ["nvidia-smi", "--list-gpus"] - output = subprocess.check_output(cmd).decode("utf-8") - return sum(1 for line in output.splitlines() if line.startswith("GPU ")) - except (OSError, subprocess.CalledProcessError): - logger.info("No GPUs detected (normal if no gpus installed)") - return 0 - - -def num_neurons() -> int: - """Return the number of neuron cores available in the current container. - - Returns: - int: Number of Neuron Cores available in the current container. - """ - try: - cmd = ["neuron-ls", "-j"] - output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8") - j = json.loads(output) - neuron_cores = 0 - for item in j: - neuron_cores += item.get("nc_count", 0) - logger.info("Found %s neurons on this instance", neuron_cores) - return neuron_cores - except OSError: - logger.info("No Neurons detected (normal if no neurons installed)") - return 0 - except subprocess.CalledProcessError as e: - if e.output is not None: - try: - msg = e.output.decode("utf-8").partition("error=")[2] - logger.info( - "No Neurons detected (normal if no neurons installed). \ - If neuron installed then %s", - msg, - ) - except AttributeError: - logger.info("No Neurons detected (normal if no neurons installed)") - else: - logger.info("No Neurons detected (normal if no neurons installed)") - - return 0 - - -def safe_serialize(data): - """Serialize the data without wrapping strings in quotes. - - This function handles the following cases: - 1. If `data` is a string, it returns the string as-is without wrapping in quotes. - 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns - the JSON-encoded string using `json.dumps()`. - 3. If `data` cannot be serialized (e.g., a custom object), it returns the string - representation of the data using `str(data)`. - - Args: - data (Any): The data to serialize. - - Returns: - str: The serialized JSON-compatible string or the string representation of the input. - """ - if isinstance(data, str): - return data - try: - return json.dumps(data) - except TypeError: - return str(data) - - -def set_env( - resource_config: Dict[str, Any], - distribution: str = None, - user_nproc_per_node: bool = None, - output_file: str = ENV_OUTPUT_FILE, -): - """Set environment variables for the training job container. - - Args: - resource_config (Dict[str, Any]): Resource configuration for the training job. - output_file (str): Output file to write the environment variables. - """ - # Constants - env_vars = { - "SM_MODEL_DIR": SM_MODEL_DIR, - "SM_INPUT_DIR": SM_INPUT_DIR, - "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR, - "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR, - "SM_OUTPUT_DIR": SM_OUTPUT_DIR, - "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE, - "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR, - "SM_MASTER_ADDR": SM_MASTER_ADDR, - "SM_MASTER_PORT": SM_MASTER_PORT, - } - - # Host Variables - current_host = resource_config["current_host"] - current_instance_type = resource_config["current_instance_type"] - hosts = resource_config["hosts"] - sorted_hosts = sorted(hosts) - - env_vars["SM_CURRENT_HOST"] = current_host - env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type - env_vars["SM_HOSTS"] = sorted_hosts - env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"] - env_vars["SM_HOST_COUNT"] = len(sorted_hosts) - env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host) - - env_vars["SM_NUM_CPUS"] = num_cpus() - env_vars["SM_NUM_GPUS"] = num_gpus() - env_vars["SM_NUM_NEURONS"] = num_neurons() - - # Misc. - env_vars["SM_RESOURCE_CONFIG"] = resource_config - - if user_nproc_per_node is not None and int(user_nproc_per_node) > 0: - env_vars["SM_NPROC_PER_NODE"] = int(user_nproc_per_node) - else: - if int(env_vars["SM_NUM_GPUS"]) > 0: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_GPUS"]) - elif int(env_vars["SM_NUM_NEURONS"]) > 0: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_NEURONS"]) - else: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_CPUS"]) - - # All Training Environment Variables - env_vars["SM_TRAINING_ENV"] = { - "current_host": env_vars["SM_CURRENT_HOST"], - "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"], - "hosts": env_vars["SM_HOSTS"], - "host_count": env_vars["SM_HOST_COUNT"], - "nproc_per_node": env_vars["SM_NPROC_PER_NODE"], - "master_addr": env_vars["SM_MASTER_ADDR"], - "master_port": env_vars["SM_MASTER_PORT"], - "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"], - "input_data_dir": env_vars["SM_INPUT_DATA_DIR"], - "input_dir": env_vars["SM_INPUT_DIR"], - "job_name": os.environ["TRAINING_JOB_NAME"], - "model_dir": env_vars["SM_MODEL_DIR"], - "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"], - "num_cpus": env_vars["SM_NUM_CPUS"], - "num_gpus": env_vars["SM_NUM_GPUS"], - "num_neurons": env_vars["SM_NUM_NEURONS"], - "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"], - "resource_config": env_vars["SM_RESOURCE_CONFIG"], - } - - if distribution and distribution == "torchrun": - logger.info("Distribution: torchrun") - - instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] - network_interface_name = env_vars.get("SM_NETWORK_INTERFACE_NAME", "eth0") - - if instance_type in SM_EFA_NCCL_INSTANCES: - # Enable EFA use - env_vars["FI_PROVIDER"] = "efa" - if instance_type in SM_EFA_RDMA_INSTANCES: - # Use EFA's RDMA functionality for one-sided and two-sided transfer - env_vars["FI_EFA_USE_DEVICE_RDMA"] = "1" - env_vars["RDMAV_FORK_SAFE"] = "1" - env_vars["NCCL_SOCKET_IFNAME"] = str(network_interface_name) - env_vars["NCCL_PROTO"] = "simple" - elif distribution and distribution == "mpirun": - logger.info("Distribution: mpirun") - - env_vars["MASTER_ADDR"] = env_vars["SM_MASTER_ADDR"] - env_vars["MASTER_PORT"] = str(env_vars["SM_MASTER_PORT"]) - - host_list = [ - "{}:{}".format(host, int(env_vars["SM_NPROC_PER_NODE"])) for host in sorted_hosts - ] - env_vars["SM_HOSTS_LIST"] = ",".join(host_list) - - instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] - - if instance_type in SM_EFA_NCCL_INSTANCES: - env_vars["SM_FI_PROVIDER"] = "-x FI_PROVIDER=efa" - env_vars["SM_NCCL_PROTO"] = "-x NCCL_PROTO=simple" - else: - env_vars["SM_FI_PROVIDER"] = "" - env_vars["SM_NCCL_PROTO"] = "" - - if instance_type in SM_EFA_RDMA_INSTANCES: - env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "-x FI_EFA_USE_DEVICE_RDMA=1" - else: - env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "" - - with open(output_file, "w") as f: - for key, value in env_vars.items(): - f.write(f"export {key}='{safe_serialize(value)}'\n") - - logger.info("Environment Variables:") - log_env_variables(env_vars_dict=env_vars) - - -def main(sys_args=None): - """Entry point for bootstrap script""" - - exit_code = DEFAULT_FAILURE_CODE - - try: - args = _parse_args(sys_args) - - logger.info("Arguments:") - for arg in vars(args): - logger.info("%s=%s", arg, getattr(args, arg)) - - client_python_version = args.client_python_version - client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version - job_conda_env = args.job_conda_env - pipeline_execution_id = args.pipeline_execution_id - dependency_settings = _DependencySettings.from_string(args.dependency_settings) - func_step_workspace = args.func_step_s3_dir - distribution = args.distribution - user_nproc_per_node = args.user_nproc_per_node - - conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") - - RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) - - user = getpass.getuser() - if user != "root": - log_message = ( - "The job is running on non-root user: %s. Adding write permissions to the " - "following job output directories: %s." - ) - logger.info(log_message, user, JOB_OUTPUT_DIRS) - RuntimeEnvironmentManager().change_dir_permission( - dirs=JOB_OUTPUT_DIRS, new_permission="777" - ) - - if pipeline_execution_id: - _bootstrap_runtime_env_for_pipeline_step( - client_python_version, func_step_workspace, conda_env, dependency_settings - ) - else: - _bootstrap_runtime_env_for_remote_function( - client_python_version, conda_env, dependency_settings - ) - - RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( - client_sagemaker_pysdk_version - ) - - if os.path.exists(RESOURCE_CONFIG): - try: - logger.info("Found %s", RESOURCE_CONFIG) - with open(RESOURCE_CONFIG, "r") as f: - resource_config = json.load(f) - set_env( - resource_config=resource_config, - distribution=distribution, - user_nproc_per_node=user_nproc_per_node, - ) - except (json.JSONDecodeError, FileNotFoundError) as e: - # Optionally, you might want to log this error - logger.info("ERROR: Error processing %s: %s", RESOURCE_CONFIG, str(e)) - - exit_code = SUCCESS_EXIT_CODE - except Exception as e: # pylint: disable=broad-except - logger.exception("Error encountered while bootstrapping runtime environment: %s", e) - - _write_failure_reason_file(str(e)) - finally: - sys.exit(exit_code) - - -if __name__ == "__main__": - main(sys.argv[1:]) \ No newline at end of file diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py deleted file mode 100644 index 79ddd4020b..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""An utils function for runtime environment. This must be kept independent of SageMaker PySDK""" -from __future__ import absolute_import - -import argparse -import json -import os -import subprocess -import sys -import time -from typing import List - -import paramiko - -if __package__ is None or __package__ == "": - from runtime_environment_manager import ( - get_logger, - ) -else: - from sagemaker.train.remote_function.runtime_environment.runtime_environment_manager import ( - get_logger, - ) - -SUCCESS_EXIT_CODE = 0 -DEFAULT_FAILURE_CODE = 1 - -FINISHED_STATUS_FILE = "/tmp/done.algo-1" -READY_FILE = "/tmp/ready.%s" -DEFAULT_SSH_PORT = 22 - -FAILURE_REASON_PATH = "/opt/ml/output/failure" -FINISHED_STATUS_FILE = "/tmp/done.algo-1" - -logger = get_logger() - - -class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy): - """Class to handle host key policy for SageMaker distributed training SSH connections. - - Example: - >>> client = paramiko.SSHClient() - >>> client.set_missing_host_key_policy(CustomHostKeyPolicy()) - >>> # Will succeed for SageMaker algorithm containers - >>> client.connect('algo-1234.internal') - >>> # Will raise SSHException for other unknown hosts - >>> client.connect('unknown-host') # raises SSHException - """ - - def missing_host_key(self, client, hostname, key): - """Accept host keys for algo-* hostnames, reject others. - - Args: - client: The SSHClient instance - hostname: The hostname attempting to connect - key: The host key - Raises: - paramiko.SSHException: If hostname doesn't match algo-* pattern - """ - if hostname.startswith("algo-"): - client.get_host_keys().add(hostname, key.get_name(), key) - return - raise paramiko.SSHException(f"Unknown host key for {hostname}") - - -def _parse_args(sys_args): - """Parses CLI arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument("--job_ended", type=str, default="0") - args, _ = parser.parse_known_args(sys_args) - return args - - -def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool: - """Check if the connection to the provided host and port is possible.""" - try: - with paramiko.SSHClient() as client: - client.load_system_host_keys() - client.set_missing_host_key_policy(CustomHostKeyPolicy()) - client.connect(host, port=port) - logger.info("Can connect to host %s", host) - return True - except Exception as e: # pylint: disable=W0703 - logger.info("Cannot connect to host %s", host) - logger.debug("Connection failed with exception: %s", e) - return False - - -def _write_file_to_host(host: str, status_file: str) -> bool: - """Write the a file to the provided host.""" - try: - logger.info("Writing %s to %s", status_file, host) - subprocess.run( - ["ssh", host, "touch", f"{status_file}"], - capture_output=True, - text=True, - check=True, - ) - logger.info("Finished writing status file") - return True - except subprocess.CalledProcessError: - logger.info("Cannot connect to %s", host) - return False - - -def _write_failure_reason_file(failure_msg): - """Create a file 'failure' with failure reason written if bootstrap runtime env failed. - - See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html - Args: - failure_msg: The content of file to be written. - """ - if not os.path.exists(FAILURE_REASON_PATH): - with open(FAILURE_REASON_PATH, "w") as f: - f.write("RuntimeEnvironmentError: " + failure_msg) - - -def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300): - """Worker nodes wait until they can connect to the master node.""" - start_time = time.time() - while True: - logger.info("Worker is attempting to connect to the master node %s...", master_host) - if _can_connect(master_host, port): - logger.info("Worker can connect to master node %s.", master_host) - break - if time.time() - start_time > timeout: - raise TimeoutError("Timed out waiting for master %s to be reachable." % master_host) - - time.sleep(5) # Wait for 5 seconds before trying again - - -def _wait_for_status_file(status_file: str): - """Wait for the status file to be created.""" - logger.info("Waiting for status file %s", status_file) - while not os.path.exists(status_file): - time.sleep(30) - logger.info("Found status file %s", status_file) - - -def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300): - """Master node waits until it can connect to all worker nodes.""" - start_time = time.time() - if not worker_hosts: - logger.info("No worker nodes to connect to.") - return - - while True: - logger.info("Master is attempting to connect to all workers...") - all_workers_connected = all( - _can_connect(worker, port) and os.path.exists(READY_FILE % worker) - for worker in worker_hosts - ) - - if all_workers_connected: - logger.info("Master can connect to all worker nodes.") - break - if time.time() - start_time > timeout: - raise TimeoutError("Timed out waiting for workers to be reachable.") - - time.sleep(5) # Wait for 5 seconds before trying again - - -def bootstrap_master_node(worker_hosts: List[str]): - """Bootstrap the master node.""" - logger.info("Bootstrapping master node...") - _wait_for_workers(worker_hosts) - - -def bootstrap_worker_node( - master_host: str, current_host: str, status_file: str = FINISHED_STATUS_FILE -): - """Bootstrap the worker nodes.""" - logger.info("Bootstrapping worker node...") - _wait_for_master(master_host) - _write_file_to_host(master_host, READY_FILE % current_host) - _wait_for_status_file(status_file) - - -def start_sshd_daemon(): - """Start the SSH daemon on the current node.""" - sshd_executable = "/usr/sbin/sshd" - - if not os.path.exists(sshd_executable): - raise RuntimeError("SSH daemon not found.") - - # Start the sshd in daemon mode (-D) - subprocess.Popen([sshd_executable, "-D"]) - logger.info("Started SSH daemon.") - - -def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE): - """Write the status file to all worker nodes.""" - for worker in worker_hosts: - retry = 0 - while not _write_file_to_host(worker, status_file): - time.sleep(5) - retry += 1 - if retry > 5: - raise TimeoutError("Timed out waiting for %s to be reachable." % worker) - logger.info("Retrying to write status file to %s", worker) - - -def main(sys_args=None): - """Entry point for bootstrap script""" - try: - args = _parse_args(sys_args) - - job_ended = args.job_ended - - main_host = os.environ["SM_MASTER_ADDR"] - current_host = os.environ["SM_CURRENT_HOST"] - - if job_ended == "0": - logger.info("Job is running, bootstrapping nodes") - - start_sshd_daemon() - - if current_host != main_host: - bootstrap_worker_node(main_host, current_host) - else: - sorted_hosts = json.loads(os.environ["SM_HOSTS"]) - worker_hosts = [host for host in sorted_hosts if host != main_host] - - bootstrap_master_node(worker_hosts) - else: - logger.info("Job ended, writing status file to workers") - - if current_host == main_host: - sorted_hosts = json.loads(os.environ["SM_HOSTS"]) - worker_hosts = [host for host in sorted_hosts if host != main_host] - - write_status_file_to_workers(worker_hosts) - except Exception as e: # pylint: disable=broad-except - logger.exception("Error encountered while bootstrapping runtime environment: %s", e) - - _write_failure_reason_file(str(e)) - - sys.exit(DEFAULT_FAILURE_CODE) - - -if __name__ == "__main__": - main(sys.argv[1:]) \ No newline at end of file diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py deleted file mode 100644 index f4d95f5412..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py +++ /dev/null @@ -1,467 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""SageMaker runtime environment module. This must be kept independent of SageMaker PySDK""" - -from __future__ import absolute_import - - -import logging -import sys -import shlex -import os -import subprocess -import time -import dataclasses -import json - - -class _UTCFormatter(logging.Formatter): - """Class that overrides the default local time provider in log formatter.""" - - converter = time.gmtime - - -def get_logger(): - """Return a logger with the name 'sagemaker'""" - sagemaker_logger = logging.getLogger("sagemaker.remote_function") - if len(sagemaker_logger.handlers) == 0: - sagemaker_logger.setLevel(logging.INFO) - handler = logging.StreamHandler() - formatter = _UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s") - handler.setFormatter(formatter) - sagemaker_logger.addHandler(handler) - # don't stream logs with the root logger handler - sagemaker_logger.propagate = 0 - - return sagemaker_logger - - -logger = get_logger() - - -@dataclasses.dataclass -class _DependencySettings: - """Dependency settings for the remote function. - - Instructs the runtime environment script on how to handle dependencies. - If ``dependency_file`` is set, the runtime environment script will attempt - to install the dependencies. If ``dependency_file`` is not set, the runtime - environment script will assume no dependencies are required. - """ - - dependency_file: str = None - - def to_string(self): - """Converts the dependency settings to a string.""" - return json.dumps(dataclasses.asdict(self)) - - @staticmethod - def from_string(dependency_settings_string): - """Converts a json string to dependency settings. - - Args: - dependency_settings_string (str): The json string to convert. - """ - if dependency_settings_string is None: - return None - dependency_settings_dict = json.loads(dependency_settings_string) - return _DependencySettings(dependency_settings_dict.get("dependency_file")) - - @staticmethod - def from_dependency_file_path(dependency_file_path): - """Converts a dependency file path to dependency settings. - - Args: - dependency_file_path (str): The path to the dependency file. - """ - if dependency_file_path is None: - return _DependencySettings() - if dependency_file_path == "auto_capture": - return _DependencySettings("env_snapshot.yml") - return _DependencySettings(os.path.basename(dependency_file_path)) - - -class RuntimeEnvironmentManager: - """Runtime Environment Manager class to manage runtime environment.""" - - def snapshot(self, dependencies: str = None) -> str: - """Creates snapshot of the user's environment - - If a req.txt or conda.yml file is provided, it verifies their existence and - returns the local file path - If ``auto_capture`` is set, this method will take the snapshot of - user's dependencies installed in the local runtime. - Current support for ``auto_capture``: - * conda env, generate a yml file and return it's local path - - Args: - dependencies (str): Local path where dependencies file exists. - - Returns: - file path of the existing or generated dependencies file - """ - - # No additional dependencies specified - if dependencies is None: - return None - - if dependencies == "auto_capture": - return self._capture_from_local_runtime() - - # Dependencies specified as either req.txt or conda_env.yml - if ( - dependencies.endswith(".txt") - or dependencies.endswith(".yml") - or dependencies.endswith(".yaml") - ): - self._is_file_exists(dependencies) - return dependencies - - raise ValueError(f'Invalid dependencies provided: "{dependencies}"') - - def _capture_from_local_runtime(self) -> str: - """Generates dependencies list from the user's local runtime. - - Raises RuntimeEnvironmentError if not able to. - - Currently supports: conda environments - """ - - # Try to capture dependencies from the conda environment, if any. - conda_env_name = self._get_active_conda_env_name() - conda_env_prefix = self._get_active_conda_env_prefix() - if conda_env_name: - logger.info("Found conda_env_name: '%s'", conda_env_name) - elif conda_env_prefix: - logger.info("Found conda_env_prefix: '%s'", conda_env_prefix) - else: - raise ValueError("No conda environment seems to be active.") - - if conda_env_name == "base": - logger.warning( - "We recommend using an environment other than base to " - "isolate your project dependencies from conda dependencies" - ) - - local_dependencies_path = os.path.join(os.getcwd(), "env_snapshot.yml") - self._export_conda_env_from_prefix(conda_env_prefix, local_dependencies_path) - - return local_dependencies_path - - def _get_active_conda_env_prefix(self) -> str: - """Returns the conda prefix from the set environment variable. None otherwise.""" - return os.getenv("CONDA_PREFIX") - - def _get_active_conda_env_name(self) -> str: - """Returns the conda environment name from the set environment variable. None otherwise.""" - return os.getenv("CONDA_DEFAULT_ENV") - - def bootstrap( - self, local_dependencies_file: str, client_python_version: str, conda_env: str = None - ): - """Bootstraps the runtime environment by installing the additional dependencies if any. - - Args: - local_dependencies_file (str): path where dependencies file exists. - conda_env (str): conda environment to be activated. Default is None. - - Returns: None - """ - - if local_dependencies_file.endswith(".txt"): - if conda_env: - self._install_req_txt_in_conda_env(conda_env, local_dependencies_file) - self._write_conda_env_to_file(conda_env) - - else: - self._install_requirements_txt(local_dependencies_file, _python_executable()) - - elif local_dependencies_file.endswith(".yml") or local_dependencies_file.endswith(".yaml"): - if conda_env: - self._update_conda_env(conda_env, local_dependencies_file) - else: - conda_env = "sagemaker-runtime-env" - self._create_conda_env(conda_env, local_dependencies_file) - self._validate_python_version(client_python_version, conda_env) - self._write_conda_env_to_file(conda_env) - - def run_pre_exec_script(self, pre_exec_script_path: str): - """Runs script of pre-execution commands if existing. - - Args: - pre_exec_script_path (str): Path to pre-execution command script file. - """ - if os.path.isfile(pre_exec_script_path): - logger.info("Running pre-execution commands in '%s'", pre_exec_script_path) - return_code, error_logs = _run_pre_execution_command_script(pre_exec_script_path) - - if return_code: - error_message = ( - f"Encountered error while running pre-execution commands. Reason: {error_logs}" - ) - raise RuntimeEnvironmentError(error_message) - else: - logger.info( - "'%s' does not exist. Assuming no pre-execution commands to run", - pre_exec_script_path, - ) - - def change_dir_permission(self, dirs: list, new_permission: str): - """Change the permission of given directories - - Args: - dirs (list[str]): A list of directories for permission update. - new_permission (str): The new permission for the given directories. - """ - - _ERROR_MSG_PREFIX = "Failed to change directory permissions due to: " - command = ["sudo", "chmod", "-R", new_permission] + dirs - logger.info("Executing '%s'.", " ".join(command)) - - try: - subprocess.run(command, check=True, stderr=subprocess.PIPE) - except subprocess.CalledProcessError as called_process_err: - err_msg = called_process_err.stderr.decode("utf-8") - raise RuntimeEnvironmentError(f"{_ERROR_MSG_PREFIX} {err_msg}") - except FileNotFoundError as file_not_found_err: - if "[Errno 2] No such file or directory: 'sudo'" in str(file_not_found_err): - raise RuntimeEnvironmentError( - f"{_ERROR_MSG_PREFIX} {file_not_found_err}. " - "Please contact the image owner to install 'sudo' in the job container " - "and provide sudo privilege to the container user." - ) - raise RuntimeEnvironmentError(file_not_found_err) - - def _is_file_exists(self, dependencies): - """Check whether the dependencies file exists at the given location. - - Raises error if not - """ - if not os.path.isfile(dependencies): - raise ValueError(f'No dependencies file named "{dependencies}" was found.') - - def _install_requirements_txt(self, local_path, python_executable): - """Install requirements.txt file""" - cmd = f"{python_executable} -m pip install -r {local_path} -U" - logger.info("Running command: '%s' in the dir: '%s' ", cmd, os.getcwd()) - _run_shell_cmd(cmd) - logger.info("Command %s ran successfully", cmd) - - def _create_conda_env(self, env_name, local_path): - """Create conda env using conda yml file""" - - cmd = f"{self._get_conda_exe()} env create -n {env_name} --file {local_path}" - logger.info("Creating conda environment %s using: %s.", env_name, cmd) - _run_shell_cmd(cmd) - logger.info("Conda environment %s created successfully.", env_name) - - def _install_req_txt_in_conda_env(self, env_name, local_path): - """Install requirements.txt in the given conda environment""" - - cmd = f"{self._get_conda_exe()} run -n {env_name} pip install -r {local_path} -U" - logger.info("Activating conda env and installing requirements: %s", cmd) - _run_shell_cmd(cmd) - logger.info("Requirements installed successfully in conda env %s", env_name) - - def _update_conda_env(self, env_name, local_path): - """Update conda env using conda yml file""" - - cmd = f"{self._get_conda_exe()} env update -n {env_name} --file {local_path}" - logger.info("Updating conda env: %s", cmd) - _run_shell_cmd(cmd) - logger.info("Conda env %s updated succesfully", env_name) - - def _export_conda_env_from_prefix(self, prefix, local_path): - """Export the conda env to a conda yml file""" - - cmd = f"{self._get_conda_exe()} env export -p {prefix} --no-builds > {local_path}" - logger.info("Exporting conda environment: %s", cmd) - _run_shell_cmd(cmd) - logger.info("Conda environment %s exported successfully", prefix) - - def _write_conda_env_to_file(self, env_name): - """Writes conda env to the text file""" - - file_name = "remote_function_conda_env.txt" - file_path = os.path.join(os.getcwd(), file_name) - with open(file_path, "w") as output_file: - output_file.write(env_name) - - def _get_conda_exe(self): - """Checks whether conda or mamba is available to use""" - - if not subprocess.Popen(["which", "mamba"]).wait(): - return "mamba" - if not subprocess.Popen(["which", "conda"]).wait(): - return "conda" - raise ValueError("Neither conda nor mamba is installed on the image") - - def _python_version_in_conda_env(self, env_name): - """Returns python version inside a conda environment""" - cmd = f"{self._get_conda_exe()} run -n {env_name} python --version" - try: - output = ( - subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT) - .decode("utf-8") - .strip() - ) - # convert 'Python 3.7.16' to [3, 7, 16] - version = output.split("Python ")[1].split(".") - return version[0] + "." + version[1] - except subprocess.CalledProcessError as e: - raise RuntimeEnvironmentError(e.output) - - def _current_python_version(self): - """Returns the current python version where program is running""" - - return f"{sys.version_info.major}.{sys.version_info.minor}".strip() - - def _current_sagemaker_pysdk_version(self): - """Returns the current sagemaker python sdk version where program is running""" - try: - from importlib import metadata - return metadata.version("sagemaker") - except Exception: - return "3.0.0.dev0" # Development version fallback - - def _validate_python_version(self, client_python_version: str, conda_env: str = None): - """Validate the python version - - Validates if the python version where remote function runs - matches the one used on client side. - """ - if conda_env: - job_python_version = self._python_version_in_conda_env(conda_env) - else: - job_python_version = self._current_python_version() - if client_python_version.strip() != job_python_version.strip(): - raise RuntimeEnvironmentError( - f"Python version found in the container is '{job_python_version}' which " - f"does not match python version '{client_python_version}' on the local client. " - f"Please make sure that the python version used in the training container " - f"is same as the local python version." - ) - - def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version): - """Validate the sagemaker python sdk version - - Validates if the sagemaker python sdk version where remote function runs - matches the one used on client side. - Otherwise, log a warning to call out that unexpected behaviors - may occur in this case. - """ - job_sagemaker_pysdk_version = self._current_sagemaker_pysdk_version() - if ( - client_sagemaker_pysdk_version - and client_sagemaker_pysdk_version != job_sagemaker_pysdk_version - ): - logger.warning( - "Inconsistent sagemaker versions found: " - "sagemaker python sdk version found in the container is " - "'%s' which does not match the '%s' on the local client. " - "Please make sure that the sagemaker version used in the training container " - "is the same as the local sagemaker version in case of unexpected behaviors.", - job_sagemaker_pysdk_version, - client_sagemaker_pysdk_version, - ) - - -def _run_and_get_output_shell_cmd(cmd: str) -> str: - """Run and return the output of the given shell command""" - return subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT).decode("utf-8") - - -def _run_pre_execution_command_script(script_path: str): - """This method runs a given shell script using subprocess - - Raises RuntimeEnvironmentError if the shell script fails - """ - current_dir = os.path.dirname(script_path) - - process = subprocess.Popen( - ["/bin/bash", "-eu", script_path], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=current_dir, - ) - - _log_output(process) - error_logs = _log_error(process) - return_code = process.wait() - - return return_code, error_logs - - -def _run_shell_cmd(cmd: str): - """This method runs a given shell command using subprocess - - Raises RuntimeEnvironmentError if the command fails - """ - - process = subprocess.Popen((cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) - - _log_output(process) - error_logs = _log_error(process) - return_code = process.wait() - if return_code: - error_message = f"Encountered error while running command '{cmd}'. Reason: {error_logs}" - raise RuntimeEnvironmentError(error_message) - - -def _log_output(process): - """This method takes in Popen process and logs the output of that process""" - with process.stdout as pipe: - for line in iter(pipe.readline, b""): - logger.info(str(line, "UTF-8")) - - -def _log_error(process): - """This method takes in Popen process and logs the error of that process. - - Returns those logs as a string - """ - - error_logs = "" - with process.stderr as pipe: - for line in iter(pipe.readline, b""): - error_str = str(line, "UTF-8") - if "ERROR:" in error_str: - logger.error(error_str) - else: - logger.warning(error_str) - error_logs = error_logs + error_str - - return error_logs - - -def _python_executable(): - """Return the real path for the Python executable, if it exists. - - Return RuntimeEnvironmentError otherwise. - - Returns: - (str): The real path of the current Python executable. - """ - if not sys.executable: - raise RuntimeEnvironmentError( - "Failed to retrieve the path for the Python executable binary" - ) - return sys.executable - - -class RuntimeEnvironmentError(Exception): - """The base exception class for bootstrap env excepitons""" - - def __init__(self, message): - self.message = message - super().__init__(self.message) \ No newline at end of file diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py deleted file mode 100644 index 6d4eaeb18e..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This is a simple scrip of spark which invokes the pickled remote function""" -from __future__ import absolute_import - -from sagemaker.train.remote_function import invoke_function - -invoke_function.main() diff --git a/sagemaker-train/src/sagemaker/train/remote_function/spark_config.py b/sagemaker-train/src/sagemaker/train/remote_function/spark_config.py deleted file mode 100644 index b5083b0566..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/spark_config.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.spark_config - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.spark_config import * # noqa: F401, F403 - -warnings.warn( - "sagemaker.train.remote_function.spark_config has been moved to sagemaker.core.remote_function.spark_config. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/tests/unit/train/remote_function/__init__.py b/sagemaker-train/tests/unit/train/remote_function/__init__.py deleted file mode 100644 index a7eb53c855..0000000000 --- a/sagemaker-train/tests/unit/train/remote_function/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Remote function unit tests.""" -from __future__ import absolute_import diff --git a/sagemaker-train/tests/unit/train/remote_function/test_bootstrap_runtime_environment.py b/sagemaker-train/tests/unit/train/remote_function/test_bootstrap_runtime_environment.py deleted file mode 100644 index c74d6e4152..0000000000 --- a/sagemaker-train/tests/unit/train/remote_function/test_bootstrap_runtime_environment.py +++ /dev/null @@ -1,677 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for bootstrap_runtime_environment module.""" -from __future__ import absolute_import - -import json -import os -import pytest -import subprocess -from unittest.mock import patch, MagicMock, mock_open, call - -from sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment import ( - _parse_args, - _bootstrap_runtime_env_for_remote_function, - _bootstrap_runtime_env_for_pipeline_step, - _handle_pre_exec_scripts, - _install_dependencies, - _unpack_user_workspace, - _write_failure_reason_file, - log_key_value, - log_env_variables, - mask_sensitive_info, - num_cpus, - num_gpus, - num_neurons, - safe_serialize, - set_env, - main, - SUCCESS_EXIT_CODE, - DEFAULT_FAILURE_CODE, - FAILURE_REASON_PATH, - REMOTE_FUNCTION_WORKSPACE, - BASE_CHANNEL_PATH, - JOB_REMOTE_FUNCTION_WORKSPACE, - SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME, - SENSITIVE_KEYWORDS, - HIDDEN_VALUE, -) -from sagemaker.train.remote_function.runtime_environment.runtime_environment_manager import ( - _DependencySettings, -) - - -class TestParseArgs: - """Test _parse_args function.""" - - def test_parse_required_args(self): - """Test parsing required arguments.""" - args = [ - "--client_python_version", "3.8", - ] - parsed = _parse_args(args) - assert parsed.client_python_version == "3.8" - - def test_parse_all_args(self): - """Test parsing all arguments.""" - args = [ - "--job_conda_env", "my-env", - "--client_python_version", "3.9", - "--client_sagemaker_pysdk_version", "2.100.0", - "--pipeline_execution_id", "exec-123", - "--dependency_settings", '{"dependency_file": "requirements.txt"}', - "--func_step_s3_dir", "s3://bucket/func", - "--distribution", "torchrun", - "--user_nproc_per_node", "4", - ] - parsed = _parse_args(args) - assert parsed.job_conda_env == "my-env" - assert parsed.client_python_version == "3.9" - assert parsed.client_sagemaker_pysdk_version == "2.100.0" - assert parsed.pipeline_execution_id == "exec-123" - assert parsed.dependency_settings == '{"dependency_file": "requirements.txt"}' - assert parsed.func_step_s3_dir == "s3://bucket/func" - assert parsed.distribution == "torchrun" - assert parsed.user_nproc_per_node == "4" - - def test_parse_default_values(self): - """Test default values for optional arguments.""" - args = [ - "--client_python_version", "3.8", - ] - parsed = _parse_args(args) - assert parsed.job_conda_env is None - assert parsed.client_sagemaker_pysdk_version is None - assert parsed.pipeline_execution_id is None - assert parsed.dependency_settings is None - assert parsed.func_step_s3_dir is None - assert parsed.distribution is None - assert parsed.user_nproc_per_node is None - - -class TestLogKeyValue: - """Test log_key_value function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.logger") - def test_logs_regular_value(self, mock_logger): - """Test logs regular key-value pair.""" - log_key_value("my_name", "my_value") - mock_logger.info.assert_called_once_with("%s=%s", "my_name", "my_value") - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.logger") - def test_masks_sensitive_key(self, mock_logger): - """Test masks sensitive keywords.""" - for keyword in ["PASSWORD", "SECRET", "TOKEN", "KEY", "PRIVATE", "CREDENTIALS"]: - mock_logger.reset_mock() - log_key_value(f"my_{keyword}", "sensitive_value") - mock_logger.info.assert_called_once_with("%s=%s", f"my_{keyword}", HIDDEN_VALUE) - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.logger") - def test_logs_dict_value(self, mock_logger): - """Test logs dictionary value.""" - value = {"field1": "value1", "field2": "value2"} - log_key_value("my_config", value) - mock_logger.info.assert_called_once_with("%s=%s", "my_config", json.dumps(value)) - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.logger") - def test_logs_json_string_value(self, mock_logger): - """Test logs JSON string value.""" - value = '{"key1": "value1"}' - log_key_value("my_key", value) - mock_logger.info.assert_called_once() - - -class TestLogEnvVariables: - """Test log_env_variables function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.log_key_value") - @patch.dict("os.environ", {"ENV_VAR1": "value1", "ENV_VAR2": "value2"}) - def test_logs_env_and_dict_variables(self, mock_log_kv): - """Test logs both environment and dictionary variables.""" - env_dict = {"DICT_VAR1": "dict_value1", "DICT_VAR2": "dict_value2"} - log_env_variables(env_dict) - - # Should be called for env vars and dict vars - assert mock_log_kv.call_count >= 4 - - -class TestMaskSensitiveInfo: - """Test mask_sensitive_info function.""" - - def test_masks_sensitive_keys_in_dict(self): - """Test masks sensitive keys in dictionary.""" - data = { - "username": "user", - "password": "secret123", - "api_key": "key123", - } - result = mask_sensitive_info(data) - assert result["username"] == "user" - assert result["password"] == HIDDEN_VALUE - assert result["api_key"] == HIDDEN_VALUE - - def test_masks_nested_dict(self): - """Test masks sensitive keys in nested dictionary.""" - data = { - "config": { - "username": "user", - "secret": "secret123", - } - } - result = mask_sensitive_info(data) - assert result["config"]["username"] == "user" - assert result["config"]["secret"] == HIDDEN_VALUE - - def test_returns_non_dict_unchanged(self): - """Test returns non-dictionary unchanged.""" - data = "string_value" - result = mask_sensitive_info(data) - assert result == "string_value" - - -class TestNumCpus: - """Test num_cpus function.""" - - @patch("multiprocessing.cpu_count") - def test_returns_cpu_count(self, mock_cpu_count): - """Test returns CPU count.""" - mock_cpu_count.return_value = 8 - assert num_cpus() == 8 - - -class TestNumGpus: - """Test num_gpus function.""" - - @patch("subprocess.check_output") - def test_returns_gpu_count(self, mock_check_output): - """Test returns GPU count.""" - mock_check_output.return_value = b"GPU 0: Tesla V100\nGPU 1: Tesla V100\n" - assert num_gpus() == 2 - - @patch("subprocess.check_output") - def test_returns_zero_on_error(self, mock_check_output): - """Test returns zero when nvidia-smi fails.""" - mock_check_output.side_effect = subprocess.CalledProcessError(1, "nvidia-smi") - assert num_gpus() == 0 - - @patch("subprocess.check_output") - def test_returns_zero_on_os_error(self, mock_check_output): - """Test returns zero when nvidia-smi not found.""" - mock_check_output.side_effect = OSError() - assert num_gpus() == 0 - - -class TestNumNeurons: - """Test num_neurons function.""" - - @patch("subprocess.check_output") - def test_returns_neuron_count(self, mock_check_output): - """Test returns neuron core count.""" - mock_output = json.dumps([{"nc_count": 2}, {"nc_count": 4}]) - mock_check_output.return_value = mock_output.encode("utf-8") - assert num_neurons() == 6 - - @patch("subprocess.check_output") - def test_returns_zero_on_os_error(self, mock_check_output): - """Test returns zero when neuron-ls not found.""" - mock_check_output.side_effect = OSError() - assert num_neurons() == 0 - - @patch("subprocess.check_output") - def test_returns_zero_on_called_process_error(self, mock_check_output): - """Test returns zero when neuron-ls fails.""" - error = subprocess.CalledProcessError(1, "neuron-ls") - error.output = b"error=No neuron devices found" - mock_check_output.side_effect = error - assert num_neurons() == 0 - - -class TestSafeSerialize: - """Test safe_serialize function.""" - - def test_returns_string_as_is(self): - """Test returns string without quotes.""" - assert safe_serialize("test_string") == "test_string" - - def test_serializes_dict(self): - """Test serializes dictionary.""" - data = {"key": "value"} - assert safe_serialize(data) == '{"key": "value"}' - - def test_serializes_list(self): - """Test serializes list.""" - data = [1, 2, 3] - assert safe_serialize(data) == "[1, 2, 3]" - - def test_returns_str_for_non_serializable(self): - """Test returns str() for non-serializable objects.""" - class CustomObj: - def __str__(self): - return "custom_object" - - obj = CustomObj() - assert safe_serialize(obj) == "custom_object" - - -class TestSetEnv: - """Test set_env function.""" - - @patch("builtins.open", new_callable=mock_open) - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") - @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) - def test_sets_basic_env_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test sets basic environment variables.""" - mock_cpus.return_value = 8 - mock_gpus.return_value = 2 - mock_neurons.return_value = 0 - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p3.2xlarge", - "hosts": ["algo-1", "algo-2"], - "network_interface_name": "eth0", - } - - set_env(resource_config) - - mock_file.assert_called_once() - mock_log_env.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") - @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) - def test_sets_torchrun_distribution_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test sets torchrun distribution environment variables.""" - mock_cpus.return_value = 8 - mock_gpus.return_value = 2 - mock_neurons.return_value = 0 - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p4d.24xlarge", - "hosts": ["algo-1"], - "network_interface_name": "eth0", - } - - set_env(resource_config, distribution="torchrun") - - # Verify file was written - mock_file.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") - @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) - def test_sets_mpirun_distribution_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test sets mpirun distribution environment variables.""" - mock_cpus.return_value = 8 - mock_gpus.return_value = 2 - mock_neurons.return_value = 0 - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p3.2xlarge", - "hosts": ["algo-1", "algo-2"], - "network_interface_name": "eth0", - } - - set_env(resource_config, distribution="mpirun") - - mock_file.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") - @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) - def test_uses_user_nproc_per_node(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test uses user-specified nproc_per_node.""" - mock_cpus.return_value = 8 - mock_gpus.return_value = 2 - mock_neurons.return_value = 0 - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p3.2xlarge", - "hosts": ["algo-1"], - "network_interface_name": "eth0", - } - - set_env(resource_config, user_nproc_per_node="4") - - mock_file.assert_called_once() - - -class TestWriteFailureReasonFile: - """Test _write_failure_reason_file function.""" - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.exists") - def test_writes_failure_file(self, mock_exists, mock_file): - """Test writes failure reason file.""" - mock_exists.return_value = False - - _write_failure_reason_file("Test error message") - - mock_file.assert_called_once_with(FAILURE_REASON_PATH, "w") - mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message") - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.exists") - def test_does_not_write_if_exists(self, mock_exists, mock_file): - """Test does not write if failure file already exists.""" - mock_exists.return_value = True - - _write_failure_reason_file("Test error message") - - mock_file.assert_not_called() - - -class TestUnpackUserWorkspace: - """Test _unpack_user_workspace function.""" - - @patch("os.path.exists") - def test_returns_none_if_dir_not_exists(self, mock_exists): - """Test returns None if workspace directory doesn't exist.""" - mock_exists.return_value = False - - result = _unpack_user_workspace() - - assert result is None - - @patch("os.path.isfile") - @patch("os.path.exists") - def test_returns_none_if_archive_not_exists(self, mock_exists, mock_isfile): - """Test returns None if workspace archive doesn't exist.""" - mock_exists.return_value = True - mock_isfile.return_value = False - - result = _unpack_user_workspace() - - assert result is None - - @patch("shutil.unpack_archive") - @patch("os.path.isfile") - @patch("os.path.exists") - @patch("os.getcwd") - def test_unpacks_workspace_successfully(self, mock_getcwd, mock_exists, mock_isfile, mock_unpack): - """Test unpacks workspace successfully.""" - mock_getcwd.return_value = "/tmp/workspace" - mock_exists.return_value = True - mock_isfile.return_value = True - - result = _unpack_user_workspace() - - mock_unpack.assert_called_once() - assert result is not None - - -class TestHandlePreExecScripts: - """Test _handle_pre_exec_scripts function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - def test_runs_pre_exec_script(self, mock_manager_class): - """Test runs pre-execution script.""" - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - _handle_pre_exec_scripts("/tmp/scripts") - - mock_manager.run_pre_exec_script.assert_called_once() - - -class TestInstallDependencies: - """Test _install_dependencies function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - def test_installs_with_dependency_settings(self, mock_manager_class): - """Test installs dependencies with dependency settings.""" - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - dep_settings = _DependencySettings(dependency_file="requirements.txt") - - _install_dependencies( - "/tmp/deps", - "my-env", - "3.8", - "channel", - dep_settings - ) - - mock_manager.bootstrap.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - def test_skips_if_no_dependency_file(self, mock_manager_class): - """Test skips installation if no dependency file.""" - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - dep_settings = _DependencySettings(dependency_file=None) - - _install_dependencies( - "/tmp/deps", - "my-env", - "3.8", - "channel", - dep_settings - ) - - mock_manager.bootstrap.assert_not_called() - - @patch("os.listdir") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - def test_finds_dependency_file_legacy(self, mock_manager_class, mock_listdir): - """Test finds dependency file in legacy mode.""" - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - mock_listdir.return_value = ["requirements.txt", "script.py"] - - _install_dependencies( - "/tmp/deps", - "my-env", - "3.8", - "channel", - None - ) - - mock_manager.bootstrap.assert_called_once() - - -class TestBootstrapRuntimeEnvForRemoteFunction: - """Test _bootstrap_runtime_env_for_remote_function function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") - def test_bootstraps_successfully(self, mock_unpack, mock_handle_scripts, mock_install): - """Test bootstraps runtime environment successfully.""" - mock_unpack.return_value = "/tmp/workspace" - - _bootstrap_runtime_env_for_remote_function("3.8", "my-env", None) - - mock_unpack.assert_called_once() - mock_handle_scripts.assert_called_once() - mock_install.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") - def test_returns_early_if_no_workspace(self, mock_unpack): - """Test returns early if no workspace to unpack.""" - mock_unpack.return_value = None - - _bootstrap_runtime_env_for_remote_function("3.8", "my-env", None) - - mock_unpack.assert_called_once() - - -class TestBootstrapRuntimeEnvForPipelineStep: - """Test _bootstrap_runtime_env_for_pipeline_step function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") - @patch("shutil.copy") - @patch("os.listdir") - @patch("os.path.exists") - @patch("os.mkdir") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") - def test_bootstraps_with_workspace(self, mock_unpack, mock_mkdir, mock_exists, mock_listdir, mock_copy, mock_handle_scripts, mock_install): - """Test bootstraps pipeline step with workspace.""" - mock_unpack.return_value = "/tmp/workspace" - mock_exists.return_value = True - mock_listdir.return_value = ["requirements.txt"] - - _bootstrap_runtime_env_for_pipeline_step("3.8", "func_step", "my-env", None) - - mock_unpack.assert_called_once() - mock_handle_scripts.assert_called_once() - mock_install.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") - @patch("os.path.exists") - @patch("os.mkdir") - @patch("os.getcwd") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") - def test_creates_workspace_if_none(self, mock_unpack, mock_getcwd, mock_mkdir, mock_exists, mock_handle_scripts, mock_install): - """Test creates workspace directory if none exists.""" - mock_unpack.return_value = None - mock_getcwd.return_value = "/tmp" - mock_exists.return_value = False - - _bootstrap_runtime_env_for_pipeline_step("3.8", "func_step", "my-env", None) - - mock_mkdir.assert_called_once() - - -class TestMain: - """Test main function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.set_env") - @patch("builtins.open", new_callable=mock_open, read_data='{"current_host": "algo-1", "current_instance_type": "ml.m5.xlarge", "hosts": ["algo-1"], "network_interface_name": "eth0"}') - @patch("os.path.exists") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_remote_function") - @patch("getpass.getuser") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args") - def test_main_success(self, mock_parse_args, mock_getuser, mock_bootstrap, mock_manager_class, mock_exists, mock_file, mock_set_env): - """Test main function successful execution.""" - mock_getuser.return_value = "root" - mock_exists.return_value = True - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - # Mock parsed args - mock_args = MagicMock() - mock_args.client_python_version = "3.8" - mock_args.client_sagemaker_pysdk_version = None - mock_args.job_conda_env = None - mock_args.pipeline_execution_id = None - mock_args.dependency_settings = None - mock_args.func_step_s3_dir = None - mock_args.distribution = None - mock_args.user_nproc_per_node = None - mock_parse_args.return_value = mock_args - - args = [ - "--client_python_version", "3.8", - ] - - with pytest.raises(SystemExit) as exc_info: - main(args) - - assert exc_info.value.code == SUCCESS_EXIT_CODE - mock_bootstrap.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._write_failure_reason_file") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - @patch("getpass.getuser") - def test_main_handles_exception(self, mock_getuser, mock_manager_class, mock_write_failure): - """Test main function handles exceptions.""" - mock_getuser.return_value = "root" - mock_manager = MagicMock() - mock_manager._validate_python_version.side_effect = Exception("Test error") - mock_manager_class.return_value = mock_manager - - args = [ - "--client_python_version", "3.8", - ] - - with pytest.raises(SystemExit) as exc_info: - main(args) - - assert exc_info.value.code == DEFAULT_FAILURE_CODE - mock_write_failure.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.set_env") - @patch("builtins.open", new_callable=mock_open, read_data='{"current_host": "algo-1", "current_instance_type": "ml.m5.xlarge", "hosts": ["algo-1"], "network_interface_name": "eth0"}') - @patch("os.path.exists") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_pipeline_step") - @patch("getpass.getuser") - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args") - def test_main_pipeline_execution(self, mock_parse_args, mock_getuser, mock_bootstrap, mock_manager_class, mock_exists, mock_file, mock_set_env): - """Test main function for pipeline execution.""" - mock_getuser.return_value = "root" - mock_exists.return_value = True - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - # Mock parsed args - mock_args = MagicMock() - mock_args.client_python_version = "3.8" - mock_args.client_sagemaker_pysdk_version = None - mock_args.job_conda_env = None - mock_args.pipeline_execution_id = "exec-123" - mock_args.dependency_settings = None - mock_args.func_step_s3_dir = "s3://bucket/func" - mock_args.distribution = None - mock_args.user_nproc_per_node = None - mock_parse_args.return_value = mock_args - - args = [ - "--client_python_version", "3.8", - "--pipeline_execution_id", "exec-123", - "--func_step_s3_dir", "s3://bucket/func", - ] - - with pytest.raises(SystemExit) as exc_info: - main(args) - - assert exc_info.value.code == SUCCESS_EXIT_CODE - mock_bootstrap.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - @patch("getpass.getuser") - def test_main_non_root_user(self, mock_getuser, mock_manager_class): - """Test main function with non-root user.""" - mock_getuser.return_value = "ubuntu" - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - args = [ - "--client_python_version", "3.8", - ] - - with pytest.raises(SystemExit): - main(args) - - mock_manager.change_dir_permission.assert_called_once() diff --git a/sagemaker-train/tests/unit/train/remote_function/test_mpi_utils_remote.py b/sagemaker-train/tests/unit/train/remote_function/test_mpi_utils_remote.py deleted file mode 100644 index 81736f36af..0000000000 --- a/sagemaker-train/tests/unit/train/remote_function/test_mpi_utils_remote.py +++ /dev/null @@ -1,424 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for mpi_utils_remote module.""" -from __future__ import absolute_import - -import os -import pytest -import subprocess -import time -from unittest.mock import patch, MagicMock, mock_open, call -import paramiko - -from sagemaker.train.remote_function.runtime_environment.mpi_utils_remote import ( - CustomHostKeyPolicy, - _parse_args, - _can_connect, - _write_file_to_host, - _write_failure_reason_file, - _wait_for_master, - _wait_for_status_file, - _wait_for_workers, - bootstrap_master_node, - bootstrap_worker_node, - start_sshd_daemon, - write_status_file_to_workers, - main, - SUCCESS_EXIT_CODE, - DEFAULT_FAILURE_CODE, - FAILURE_REASON_PATH, - FINISHED_STATUS_FILE, - READY_FILE, - DEFAULT_SSH_PORT, -) - - -class TestCustomHostKeyPolicy: - """Test CustomHostKeyPolicy class.""" - - def test_accepts_algo_hostname(self): - """Test accepts hostnames starting with algo-.""" - policy = CustomHostKeyPolicy() - mock_client = MagicMock() - mock_hostname = "algo-1234" - mock_key = MagicMock() - mock_key.get_name.return_value = "ssh-rsa" - - # Should not raise exception - policy.missing_host_key(mock_client, mock_hostname, mock_key) - - mock_client.get_host_keys().add.assert_called_once_with(mock_hostname, "ssh-rsa", mock_key) - - def test_rejects_non_algo_hostname(self): - """Test rejects hostnames not starting with algo-.""" - policy = CustomHostKeyPolicy() - mock_client = MagicMock() - mock_hostname = "unknown-host" - mock_key = MagicMock() - - with pytest.raises(paramiko.SSHException): - policy.missing_host_key(mock_client, mock_hostname, mock_key) - - -class TestParseArgs: - """Test _parse_args function.""" - - def test_parse_default_args(self): - """Test parsing with default arguments.""" - args = [] - parsed = _parse_args(args) - assert parsed.job_ended == "0" - - def test_parse_job_ended_true(self): - """Test parsing with job_ended set to true.""" - args = ["--job_ended", "1"] - parsed = _parse_args(args) - assert parsed.job_ended == "1" - - def test_parse_job_ended_false(self): - """Test parsing with job_ended set to false.""" - args = ["--job_ended", "0"] - parsed = _parse_args(args) - assert parsed.job_ended == "0" - - -class TestCanConnect: - """Test _can_connect function.""" - - @patch("paramiko.SSHClient") - def test_can_connect_success(self, mock_ssh_client_class): - """Test successful connection.""" - mock_client = MagicMock() - mock_ssh_client_class.return_value.__enter__.return_value = mock_client - - result = _can_connect("algo-1", DEFAULT_SSH_PORT) - - assert result is True - mock_client.connect.assert_called_once_with("algo-1", port=DEFAULT_SSH_PORT) - - @patch("paramiko.SSHClient") - def test_can_connect_failure(self, mock_ssh_client_class): - """Test failed connection.""" - mock_client = MagicMock() - mock_client.connect.side_effect = Exception("Connection failed") - mock_ssh_client_class.return_value.__enter__.return_value = mock_client - - result = _can_connect("algo-1", DEFAULT_SSH_PORT) - - assert result is False - - @patch("paramiko.SSHClient") - def test_can_connect_uses_custom_port(self, mock_ssh_client_class): - """Test connection with custom port.""" - mock_client = MagicMock() - mock_ssh_client_class.return_value.__enter__.return_value = mock_client - - _can_connect("algo-1", 2222) - - mock_client.connect.assert_called_once_with("algo-1", port=2222) - - -class TestWriteFileToHost: - """Test _write_file_to_host function.""" - - @patch("subprocess.run") - def test_write_file_success(self, mock_run): - """Test successful file write.""" - mock_run.return_value = MagicMock(returncode=0) - - result = _write_file_to_host("algo-1", "/tmp/status") - - assert result is True - mock_run.assert_called_once() - - @patch("subprocess.run") - def test_write_file_failure(self, mock_run): - """Test failed file write.""" - mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") - - result = _write_file_to_host("algo-1", "/tmp/status") - - assert result is False - - -class TestWriteFailureReasonFile: - """Test _write_failure_reason_file function.""" - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.exists") - def test_writes_failure_file(self, mock_exists, mock_file): - """Test writes failure reason file.""" - mock_exists.return_value = False - - _write_failure_reason_file("Test error message") - - mock_file.assert_called_once_with(FAILURE_REASON_PATH, "w") - mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message") - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.exists") - def test_does_not_write_if_exists(self, mock_exists, mock_file): - """Test does not write if failure file already exists.""" - mock_exists.return_value = True - - _write_failure_reason_file("Test error message") - - mock_file.assert_not_called() - - -class TestWaitForMaster: - """Test _wait_for_master function.""" - - @patch("time.sleep") - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_master_success(self, mock_can_connect, mock_sleep): - """Test successful wait for master.""" - mock_can_connect.return_value = True - - _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) - - mock_can_connect.assert_called_once_with("algo-1", DEFAULT_SSH_PORT) - - @patch("time.time") - @patch("time.sleep") - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_master_timeout(self, mock_can_connect, mock_sleep, mock_time): - """Test timeout waiting for master.""" - mock_can_connect.return_value = False - # Need enough values for all time.time() calls in the loop - mock_time.side_effect = [0] + [i * 5 for i in range(1, 100)] # Simulate time passing - - with pytest.raises(TimeoutError): - _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) - - @patch("time.time") - @patch("time.sleep") - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_master_retries(self, mock_can_connect, mock_sleep, mock_time): - """Test retries before successful connection.""" - mock_can_connect.side_effect = [False, False, True] - # Return value instead of side_effect for time.time() - mock_time.return_value = 0 - - _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) - - assert mock_can_connect.call_count == 3 - - -class TestWaitForStatusFile: - """Test _wait_for_status_file function.""" - - @patch("time.sleep") - @patch("os.path.exists") - def test_wait_for_status_file_exists(self, mock_exists, mock_sleep): - """Test wait for status file that exists.""" - mock_exists.return_value = True - - _wait_for_status_file("/tmp/status") - - mock_exists.assert_called_once_with("/tmp/status") - - @patch("time.sleep") - @patch("os.path.exists") - def test_wait_for_status_file_waits(self, mock_exists, mock_sleep): - """Test waits until status file exists.""" - mock_exists.side_effect = [False, False, True] - - _wait_for_status_file("/tmp/status") - - assert mock_exists.call_count == 3 - assert mock_sleep.call_count == 2 - - -class TestWaitForWorkers: - """Test _wait_for_workers function.""" - - @patch("os.path.exists") - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_workers_empty_list(self, mock_can_connect, mock_exists): - """Test wait for workers with empty list.""" - _wait_for_workers([], DEFAULT_SSH_PORT, timeout=300) - - mock_can_connect.assert_not_called() - - @patch("time.sleep") - @patch("os.path.exists") - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_workers_success(self, mock_can_connect, mock_exists, mock_sleep): - """Test successful wait for workers.""" - mock_can_connect.return_value = True - mock_exists.return_value = True - - _wait_for_workers(["algo-2", "algo-3"], DEFAULT_SSH_PORT, timeout=300) - - assert mock_can_connect.call_count == 2 - - @patch("time.time") - @patch("time.sleep") - @patch("os.path.exists") - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_workers_timeout(self, mock_can_connect, mock_exists, mock_sleep, mock_time): - """Test timeout waiting for workers.""" - mock_can_connect.return_value = False - mock_exists.return_value = False - # Need enough values for all time.time() calls in the loop - mock_time.side_effect = [0] + [i * 5 for i in range(1, 100)] - - with pytest.raises(TimeoutError): - _wait_for_workers(["algo-2"], DEFAULT_SSH_PORT, timeout=300) - - -class TestBootstrapMasterNode: - """Test bootstrap_master_node function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._wait_for_workers") - def test_bootstrap_master_node(self, mock_wait): - """Test bootstrap master node.""" - worker_hosts = ["algo-2", "algo-3"] - - bootstrap_master_node(worker_hosts) - - mock_wait.assert_called_once_with(worker_hosts) - - -class TestBootstrapWorkerNode: - """Test bootstrap_worker_node function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._wait_for_status_file") - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._wait_for_master") - def test_bootstrap_worker_node(self, mock_wait_master, mock_write, mock_wait_status): - """Test bootstrap worker node.""" - bootstrap_worker_node("algo-1", "algo-2", "/tmp/status") - - mock_wait_master.assert_called_once_with("algo-1") - mock_write.assert_called_once() - mock_wait_status.assert_called_once_with("/tmp/status") - - -class TestStartSshdDaemon: - """Test start_sshd_daemon function.""" - - @patch("subprocess.Popen") - @patch("os.path.exists") - def test_starts_sshd_successfully(self, mock_exists, mock_popen): - """Test starts SSH daemon successfully.""" - mock_exists.return_value = True - - start_sshd_daemon() - - mock_popen.assert_called_once_with(["/usr/sbin/sshd", "-D"]) - - @patch("os.path.exists") - def test_raises_error_if_sshd_not_found(self, mock_exists): - """Test raises error if SSH daemon not found.""" - mock_exists.return_value = False - - with pytest.raises(RuntimeError): - start_sshd_daemon() - - -class TestWriteStatusFileToWorkers: - """Test write_status_file_to_workers function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") - def test_writes_to_all_workers(self, mock_write): - """Test writes status file to all workers.""" - mock_write.return_value = True - worker_hosts = ["algo-2", "algo-3"] - - write_status_file_to_workers(worker_hosts, "/tmp/status") - - assert mock_write.call_count == 2 - - @patch("time.sleep") - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") - def test_retries_on_failure(self, mock_write, mock_sleep): - """Test retries writing status file on failure.""" - mock_write.side_effect = [False, False, True] - worker_hosts = ["algo-2"] - - write_status_file_to_workers(worker_hosts, "/tmp/status") - - assert mock_write.call_count == 3 - assert mock_sleep.call_count == 2 - - @patch("time.sleep") - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") - def test_raises_timeout_after_retries(self, mock_write, mock_sleep): - """Test raises timeout after max retries.""" - mock_write.return_value = False - worker_hosts = ["algo-2"] - - with pytest.raises(TimeoutError): - write_status_file_to_workers(worker_hosts, "/tmp/status") - - -class TestMain: - """Test main function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) - def test_main_worker_node_running(self, mock_start_sshd, mock_bootstrap_worker): - """Test main function for worker node during job run.""" - args = ["--job_ended", "0"] - - main(args) - - mock_start_sshd.assert_called_once() - mock_bootstrap_worker.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-1", "SM_HOSTS": '["algo-1", "algo-2"]'}) - def test_main_master_node_running(self, mock_start_sshd, mock_bootstrap_master): - """Test main function for master node during job run.""" - args = ["--job_ended", "0"] - - main(args) - - mock_start_sshd.assert_called_once() - mock_bootstrap_master.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers") - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-1", "SM_HOSTS": '["algo-1", "algo-2"]'}) - def test_main_master_node_job_ended(self, mock_write_status): - """Test main function for master node after job ends.""" - args = ["--job_ended", "1"] - - main(args) - - mock_write_status.assert_called_once() - - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) - def test_main_worker_node_job_ended(self): - """Test main function for worker node after job ends.""" - args = ["--job_ended", "1"] - - # Should not raise any exceptions - main(args) - - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote._write_failure_reason_file") - @patch("sagemaker.train.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) - def test_main_handles_exception(self, mock_start_sshd, mock_write_failure): - """Test main function handles exceptions.""" - mock_start_sshd.side_effect = Exception("Test error") - args = ["--job_ended", "0"] - - with pytest.raises(SystemExit) as exc_info: - main(args) - - assert exc_info.value.code == DEFAULT_FAILURE_CODE - mock_write_failure.assert_called_once() diff --git a/sagemaker-train/tests/unit/train/remote_function/test_runtime_environment_manager.py b/sagemaker-train/tests/unit/train/remote_function/test_runtime_environment_manager.py deleted file mode 100644 index 464e46db6d..0000000000 --- a/sagemaker-train/tests/unit/train/remote_function/test_runtime_environment_manager.py +++ /dev/null @@ -1,564 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for runtime_environment_manager module.""" -from __future__ import absolute_import - -import json -import os -import subprocess -import sys -import pytest -from unittest.mock import patch, MagicMock, mock_open, call - -from sagemaker.train.remote_function.runtime_environment.runtime_environment_manager import ( - _DependencySettings, - RuntimeEnvironmentManager, - RuntimeEnvironmentError, - get_logger, - _run_and_get_output_shell_cmd, - _run_pre_execution_command_script, - _run_shell_cmd, - _log_output, - _log_error, - _python_executable, -) - - -class TestDependencySettings: - """Test _DependencySettings class.""" - - def test_init_with_no_file(self): - """Test initialization without dependency file.""" - settings = _DependencySettings() - assert settings.dependency_file is None - - def test_init_with_file(self): - """Test initialization with dependency file.""" - settings = _DependencySettings(dependency_file="requirements.txt") - assert settings.dependency_file == "requirements.txt" - - def test_to_string(self): - """Test converts to JSON string.""" - settings = _DependencySettings(dependency_file="requirements.txt") - result = settings.to_string() - assert result == '{"dependency_file": "requirements.txt"}' - - def test_from_string_with_file(self): - """Test creates from JSON string with file.""" - json_str = '{"dependency_file": "requirements.txt"}' - settings = _DependencySettings.from_string(json_str) - assert settings.dependency_file == "requirements.txt" - - def test_from_string_with_none(self): - """Test creates from None.""" - settings = _DependencySettings.from_string(None) - assert settings is None - - def test_from_dependency_file_path_with_none(self): - """Test creates from None file path.""" - settings = _DependencySettings.from_dependency_file_path(None) - assert settings.dependency_file is None - - def test_from_dependency_file_path_with_auto_capture(self): - """Test creates from auto_capture.""" - settings = _DependencySettings.from_dependency_file_path("auto_capture") - assert settings.dependency_file == "env_snapshot.yml" - - def test_from_dependency_file_path_with_path(self): - """Test creates from file path.""" - settings = _DependencySettings.from_dependency_file_path("/path/to/requirements.txt") - assert settings.dependency_file == "requirements.txt" - - -class TestGetLogger: - """Test get_logger function.""" - - def test_returns_logger(self): - """Test returns logger instance.""" - logger = get_logger() - assert logger is not None - assert logger.name == "sagemaker.remote_function" - - -class TestRuntimeEnvironmentManager: - """Test RuntimeEnvironmentManager class.""" - - def test_init(self): - """Test initialization.""" - manager = RuntimeEnvironmentManager() - assert manager is not None - - @patch("os.path.isfile") - def test_snapshot_returns_none_for_none(self, mock_isfile): - """Test snapshot returns None when dependencies is None.""" - manager = RuntimeEnvironmentManager() - result = manager.snapshot(None) - assert result is None - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._capture_from_local_runtime") - def test_snapshot_auto_capture(self, mock_capture): - """Test snapshot with auto_capture.""" - mock_capture.return_value = "/path/to/env_snapshot.yml" - manager = RuntimeEnvironmentManager() - result = manager.snapshot("auto_capture") - assert result == "/path/to/env_snapshot.yml" - mock_capture.assert_called_once() - - @patch("os.path.isfile") - def test_snapshot_with_txt_file(self, mock_isfile): - """Test snapshot with requirements.txt file.""" - mock_isfile.return_value = True - manager = RuntimeEnvironmentManager() - result = manager.snapshot("requirements.txt") - assert result == "requirements.txt" - - @patch("os.path.isfile") - def test_snapshot_with_yml_file(self, mock_isfile): - """Test snapshot with conda.yml file.""" - mock_isfile.return_value = True - manager = RuntimeEnvironmentManager() - result = manager.snapshot("environment.yml") - assert result == "environment.yml" - - @patch("os.path.isfile") - def test_snapshot_raises_error_for_invalid_file(self, mock_isfile): - """Test snapshot raises error for invalid file.""" - mock_isfile.return_value = False - manager = RuntimeEnvironmentManager() - with pytest.raises(ValueError): - manager.snapshot("requirements.txt") - - def test_snapshot_raises_error_for_invalid_format(self): - """Test snapshot raises error for invalid format.""" - manager = RuntimeEnvironmentManager() - with pytest.raises(ValueError): - manager.snapshot("invalid.json") - - @patch("os.getenv") - def test_get_active_conda_env_prefix(self, mock_getenv): - """Test gets active conda environment prefix.""" - mock_getenv.return_value = "/opt/conda/envs/myenv" - manager = RuntimeEnvironmentManager() - result = manager._get_active_conda_env_prefix() - assert result == "/opt/conda/envs/myenv" - - @patch("os.getenv") - def test_get_active_conda_env_name(self, mock_getenv): - """Test gets active conda environment name.""" - mock_getenv.return_value = "myenv" - manager = RuntimeEnvironmentManager() - result = manager._get_active_conda_env_name() - assert result == "myenv" - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._export_conda_env_from_prefix") - @patch("os.getcwd") - @patch("os.getenv") - def test_capture_from_local_runtime(self, mock_getenv, mock_getcwd, mock_export): - """Test captures from local runtime.""" - mock_getenv.side_effect = lambda x: "myenv" if x == "CONDA_DEFAULT_ENV" else "/opt/conda/envs/myenv" - mock_getcwd.return_value = "/tmp" - manager = RuntimeEnvironmentManager() - result = manager._capture_from_local_runtime() - assert result == "/tmp/env_snapshot.yml" - mock_export.assert_called_once() - - @patch("os.getenv") - def test_capture_from_local_runtime_raises_error_no_conda(self, mock_getenv): - """Test raises error when no conda environment active.""" - mock_getenv.return_value = None - manager = RuntimeEnvironmentManager() - with pytest.raises(ValueError): - manager._capture_from_local_runtime() - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._install_requirements_txt") - def test_bootstrap_with_txt_file_no_conda(self, mock_install): - """Test bootstrap with requirements.txt without conda.""" - manager = RuntimeEnvironmentManager() - manager.bootstrap("requirements.txt", "3.8", None) - mock_install.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._install_req_txt_in_conda_env") - def test_bootstrap_with_txt_file_with_conda(self, mock_install, mock_write): - """Test bootstrap with requirements.txt with conda.""" - manager = RuntimeEnvironmentManager() - manager.bootstrap("requirements.txt", "3.8", "myenv") - mock_install.assert_called_once() - mock_write.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._update_conda_env") - def test_bootstrap_with_yml_file_with_conda(self, mock_update, mock_write): - """Test bootstrap with conda.yml with existing conda env.""" - manager = RuntimeEnvironmentManager() - manager.bootstrap("environment.yml", "3.8", "myenv") - mock_update.assert_called_once() - mock_write.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._validate_python_version") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._create_conda_env") - def test_bootstrap_with_yml_file_without_conda(self, mock_create, mock_validate, mock_write): - """Test bootstrap with conda.yml without existing conda env.""" - manager = RuntimeEnvironmentManager() - manager.bootstrap("environment.yml", "3.8", None) - mock_create.assert_called_once() - mock_validate.assert_called_once() - mock_write.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script") - @patch("os.path.isfile") - def test_run_pre_exec_script_exists(self, mock_isfile, mock_run_script): - """Test runs pre-execution script when it exists.""" - mock_isfile.return_value = True - mock_run_script.return_value = (0, "") - manager = RuntimeEnvironmentManager() - manager.run_pre_exec_script("/path/to/script.sh") - mock_run_script.assert_called_once() - - @patch("os.path.isfile") - def test_run_pre_exec_script_not_exists(self, mock_isfile): - """Test handles pre-execution script not existing.""" - mock_isfile.return_value = False - manager = RuntimeEnvironmentManager() - # Should not raise exception - manager.run_pre_exec_script("/path/to/script.sh") - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script") - @patch("os.path.isfile") - def test_run_pre_exec_script_raises_error_on_failure(self, mock_isfile, mock_run_script): - """Test raises error when pre-execution script fails.""" - mock_isfile.return_value = True - mock_run_script.return_value = (1, "Error message") - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager.run_pre_exec_script("/path/to/script.sh") - - @patch("subprocess.run") - def test_change_dir_permission_success(self, mock_run): - """Test changes directory permissions successfully.""" - manager = RuntimeEnvironmentManager() - manager.change_dir_permission(["/tmp/dir1", "/tmp/dir2"], "777") - mock_run.assert_called_once() - - @patch("subprocess.run") - def test_change_dir_permission_raises_error_on_failure(self, mock_run): - """Test raises error when permission change fails.""" - mock_run.side_effect = subprocess.CalledProcessError(1, "chmod", stderr=b"Permission denied") - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager.change_dir_permission(["/tmp/dir1"], "777") - - @patch("subprocess.run") - def test_change_dir_permission_raises_error_no_sudo(self, mock_run): - """Test raises error when sudo not found.""" - mock_run.side_effect = FileNotFoundError("[Errno 2] No such file or directory: 'sudo'") - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager.change_dir_permission(["/tmp/dir1"], "777") - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") - def test_install_requirements_txt(self, mock_run_cmd): - """Test installs requirements.txt.""" - manager = RuntimeEnvironmentManager() - manager._install_requirements_txt("/path/to/requirements.txt", "/usr/bin/python") - mock_run_cmd.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_create_conda_env(self, mock_get_conda, mock_run_cmd): - """Test creates conda environment.""" - mock_get_conda.return_value = "conda" - manager = RuntimeEnvironmentManager() - manager._create_conda_env("myenv", "/path/to/environment.yml") - mock_run_cmd.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_install_req_txt_in_conda_env(self, mock_get_conda, mock_run_cmd): - """Test installs requirements.txt in conda environment.""" - mock_get_conda.return_value = "conda" - manager = RuntimeEnvironmentManager() - manager._install_req_txt_in_conda_env("myenv", "/path/to/requirements.txt") - mock_run_cmd.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_update_conda_env(self, mock_get_conda, mock_run_cmd): - """Test updates conda environment.""" - mock_get_conda.return_value = "conda" - manager = RuntimeEnvironmentManager() - manager._update_conda_env("myenv", "/path/to/environment.yml") - mock_run_cmd.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_export_conda_env_from_prefix(self, mock_get_conda, mock_run_cmd): - """Test exports conda environment.""" - mock_get_conda.return_value = "conda" - manager = RuntimeEnvironmentManager() - manager._export_conda_env_from_prefix("/opt/conda/envs/myenv", "/tmp/env.yml") - mock_run_cmd.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch("os.getcwd") - def test_write_conda_env_to_file(self, mock_getcwd, mock_file): - """Test writes conda environment name to file.""" - mock_getcwd.return_value = "/tmp" - manager = RuntimeEnvironmentManager() - manager._write_conda_env_to_file("myenv") - mock_file.assert_called_once_with("/tmp/remote_function_conda_env.txt", "w") - mock_file().write.assert_called_once_with("myenv") - - @patch("subprocess.Popen") - def test_get_conda_exe_returns_mamba(self, mock_popen): - """Test returns mamba when available.""" - mock_popen.return_value.wait.side_effect = [0, 1] # mamba exists, conda doesn't - manager = RuntimeEnvironmentManager() - result = manager._get_conda_exe() - assert result == "mamba" - - @patch("subprocess.Popen") - def test_get_conda_exe_returns_conda(self, mock_popen): - """Test returns conda when mamba not available.""" - mock_popen.return_value.wait.side_effect = [1, 0] # mamba doesn't exist, conda does - manager = RuntimeEnvironmentManager() - result = manager._get_conda_exe() - assert result == "conda" - - @patch("subprocess.Popen") - def test_get_conda_exe_raises_error(self, mock_popen): - """Test raises error when neither conda nor mamba available.""" - mock_popen.return_value.wait.return_value = 1 - manager = RuntimeEnvironmentManager() - with pytest.raises(ValueError): - manager._get_conda_exe() - - @patch("subprocess.check_output") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_python_version_in_conda_env(self, mock_get_conda, mock_check_output): - """Test gets Python version in conda environment.""" - mock_get_conda.return_value = "conda" - mock_check_output.return_value = b"Python 3.8.10" - manager = RuntimeEnvironmentManager() - result = manager._python_version_in_conda_env("myenv") - assert result == "3.8" - - @patch("subprocess.check_output") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_python_version_in_conda_env_raises_error(self, mock_get_conda, mock_check_output): - """Test raises error when getting Python version fails.""" - mock_get_conda.return_value = "conda" - mock_check_output.side_effect = subprocess.CalledProcessError(1, "conda", output=b"Error") - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager._python_version_in_conda_env("myenv") - - def test_current_python_version(self): - """Test gets current Python version.""" - manager = RuntimeEnvironmentManager() - result = manager._current_python_version() - expected = f"{sys.version_info.major}.{sys.version_info.minor}" - assert result == expected - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._python_version_in_conda_env") - def test_validate_python_version_with_conda(self, mock_python_version): - """Test validates Python version with conda environment.""" - mock_python_version.return_value = "3.8" - manager = RuntimeEnvironmentManager() - # Should not raise exception - manager._validate_python_version("3.8", "myenv") - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._python_version_in_conda_env") - def test_validate_python_version_mismatch_with_conda(self, mock_python_version): - """Test raises error on Python version mismatch with conda.""" - mock_python_version.return_value = "3.9" - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager._validate_python_version("3.8", "myenv") - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_python_version") - def test_validate_python_version_without_conda(self, mock_current_version): - """Test validates Python version without conda environment.""" - mock_current_version.return_value = "3.8" - manager = RuntimeEnvironmentManager() - # Should not raise exception - manager._validate_python_version("3.8", None) - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_python_version") - def test_validate_python_version_mismatch_without_conda(self, mock_current_version): - """Test raises error on Python version mismatch without conda.""" - mock_current_version.return_value = "3.9" - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager._validate_python_version("3.8", None) - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") - def test_validate_sagemaker_pysdk_version_match(self, mock_current_version): - """Test validates matching SageMaker SDK version.""" - mock_current_version.return_value = "2.100.0" - manager = RuntimeEnvironmentManager() - # Should not raise exception or warning - manager._validate_sagemaker_pysdk_version("2.100.0") - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") - def test_validate_sagemaker_pysdk_version_mismatch(self, mock_current_version): - """Test logs warning on SageMaker SDK version mismatch.""" - mock_current_version.return_value = "2.101.0" - manager = RuntimeEnvironmentManager() - # Should log warning but not raise exception - manager._validate_sagemaker_pysdk_version("2.100.0") - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") - def test_validate_sagemaker_pysdk_version_none(self, mock_current_version): - """Test handles None client version.""" - mock_current_version.return_value = "2.100.0" - manager = RuntimeEnvironmentManager() - # Should not raise exception - manager._validate_sagemaker_pysdk_version(None) - - -class TestRunAndGetOutputShellCmd: - """Test _run_and_get_output_shell_cmd function.""" - - @patch("subprocess.check_output") - def test_runs_command_successfully(self, mock_check_output): - """Test runs command and returns output.""" - mock_check_output.return_value = b"command output" - result = _run_and_get_output_shell_cmd("echo test") - assert result == "command output" - - -class TestRunPreExecutionCommandScript: - """Test _run_pre_execution_command_script function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._log_error") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._log_output") - @patch("subprocess.Popen") - @patch("os.path.dirname") - def test_runs_script_successfully(self, mock_dirname, mock_popen, mock_log_output, mock_log_error): - """Test runs script successfully.""" - mock_dirname.return_value = "/tmp" - mock_process = MagicMock() - mock_process.wait.return_value = 0 - mock_popen.return_value = mock_process - mock_log_error.return_value = "" - - return_code, error_logs = _run_pre_execution_command_script("/tmp/script.sh") - - assert return_code == 0 - assert error_logs == "" - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._log_error") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._log_output") - @patch("subprocess.Popen") - @patch("os.path.dirname") - def test_runs_script_with_error(self, mock_dirname, mock_popen, mock_log_output, mock_log_error): - """Test runs script that returns error.""" - mock_dirname.return_value = "/tmp" - mock_process = MagicMock() - mock_process.wait.return_value = 1 - mock_popen.return_value = mock_process - mock_log_error.return_value = "Error message" - - return_code, error_logs = _run_pre_execution_command_script("/tmp/script.sh") - - assert return_code == 1 - assert error_logs == "Error message" - - -class TestRunShellCmd: - """Test _run_shell_cmd function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._log_error") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._log_output") - @patch("subprocess.Popen") - def test_runs_command_successfully(self, mock_popen, mock_log_output, mock_log_error): - """Test runs command successfully.""" - mock_process = MagicMock() - mock_process.wait.return_value = 0 - mock_popen.return_value = mock_process - mock_log_error.return_value = "" - - _run_shell_cmd("echo test") - - mock_popen.assert_called_once() - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._log_error") - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager._log_output") - @patch("subprocess.Popen") - def test_runs_command_raises_error_on_failure(self, mock_popen, mock_log_output, mock_log_error): - """Test raises error when command fails.""" - mock_process = MagicMock() - mock_process.wait.return_value = 1 - mock_popen.return_value = mock_process - mock_log_error.return_value = "Error message" - - with pytest.raises(RuntimeEnvironmentError): - _run_shell_cmd("false") - - -class TestLogOutput: - """Test _log_output function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.logger") - def test_logs_output(self, mock_logger): - """Test logs process output.""" - from io import BytesIO - mock_process = MagicMock() - mock_process.stdout = BytesIO(b"line1\nline2\n") - - _log_output(mock_process) - - assert mock_logger.info.call_count == 2 - - -class TestLogError: - """Test _log_error function.""" - - @patch("sagemaker.train.remote_function.runtime_environment.runtime_environment_manager.logger") - def test_logs_error(self, mock_logger): - """Test logs process errors.""" - from io import BytesIO - mock_process = MagicMock() - mock_process.stderr = BytesIO(b"ERROR: error message\nwarning message\n") - - error_logs = _log_error(mock_process) - - assert "ERROR: error message" in error_logs - assert "warning message" in error_logs - - -class TestPythonExecutable: - """Test _python_executable function.""" - - def test_returns_python_executable(self): - """Test returns Python executable path.""" - result = _python_executable() - assert result == sys.executable - - @patch("sys.executable", None) - def test_raises_error_if_no_executable(self): - """Test raises error if no Python executable.""" - with pytest.raises(RuntimeEnvironmentError): - _python_executable() - - -class TestRuntimeEnvironmentError: - """Test RuntimeEnvironmentError class.""" - - def test_creates_error_with_message(self): - """Test creates error with message.""" - error = RuntimeEnvironmentError("Test error") - assert str(error) == "Test error" - assert error.message == "Test error"