diff --git a/CHANGELOG.md b/CHANGELOG.md index a08bb7ee75..8a14df9026 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## v2.255.0 (2025-12-03) + +### Features + +* Extracts reward Lambda ARN from Nova recipes +* Passes it as training job hyperparameter +* Added LLMFT recipe support with standardized recipe handling +* Enhanced recipe validation and multi-model type compatibility + + ## v2.254.1 (2025-10-31) ### Bug Fixes and Other Changes diff --git a/VERSION b/VERSION index 4459d36c7a..a50d32a8ef 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.254.2.dev0 +2.254.2.dev0 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 02b89e975a..d4f6e1e652 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ ] dependencies = [ "attrs>=24,<26", - "boto3>=1.39.5,<2.0", + "boto3>=1.42.2,<2.0", "cloudpickle>=2.2.1", "docker", "fastapi", @@ -49,7 +49,7 @@ dependencies = [ "psutil", "PyYAML>=6.0.1", "requests", - "sagemaker-core>=1.0.17,<2.0.0", + "sagemaker-core>=1.0.71,<2.0.0", "schema", "smdebug_rulesconfig==1.0.1", "tblib>=1.7.0,<4", diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 477103ab1d..ecc805fc47 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -62,4 +62,4 @@ mypy-boto3-s3==1.35.76 mypy-extensions==1.0.0 mypy==1.9.0 # apache-airflow transitive dependancy -google-re2<1.1.20250805; python_version < "3.10" +google-re2<1.1.20250805; python_version < "3.10" \ No newline at end of file diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index 71ea51c60f..6a3f43f66c 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -48,6 +48,7 @@ from sagemaker.analytics import TrainingJobAnalytics, HyperparameterTuningJobAnalytics # noqa: F401 from sagemaker.local.local_session import LocalSession # noqa: F401 +from sagemaker.container_base_model import 
ContainerBaseModel # noqa: F401 from sagemaker.model import Model, ModelPackage # noqa: F401 from sagemaker.model_metrics import ModelMetrics, MetricsSource, FileSource # noqa: F401 from sagemaker.pipeline import PipelineModel # noqa: F401 diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index c2d2187b69..2872d7de67 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -17,7 +17,7 @@ from typing import Callable, Optional, Union, List, Dict import sagemaker -from sagemaker import image_uris, ModelMetrics +from sagemaker import image_uris, ModelMetrics, ContainerBaseModel from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import ( model_code_key_prefix, @@ -182,6 +182,8 @@ def register( source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, model_life_cycle: Optional[ModelLifeCycle] = None, + model_package_registration_type: Optional[Union[str, PipelineVariable]] = None, + base_model: Optional[ContainerBaseModel] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -236,6 +238,9 @@ def register( model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). + model_package_registration_type (str or PipelineVariable): Model Package Registration + Type (default: None). + base_model (ContainerBaseModel): ContainerBaseModel object (default: None). Returns: str: A string of SageMaker Model Package ARN. 
@@ -278,6 +283,8 @@ def register( source_uri=source_uri, model_card=model_card, model_life_cycle=model_life_cycle, + model_package_registration_type=model_package_registration_type, + base_model=base_model, ) def prepare_container_def( diff --git a/src/sagemaker/container_base_model.py b/src/sagemaker/container_base_model.py new file mode 100644 index 0000000000..5845a28d65 --- /dev/null +++ b/src/sagemaker/container_base_model.py @@ -0,0 +1,52 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This file contains code related to base model for containers.""" +from __future__ import absolute_import + +from typing import Optional, Union + +from sagemaker.workflow.entities import PipelineVariable + + +class ContainerBaseModel(object): + """Accepts Base Model parameters for conversion to request dict.""" + + def __init__( + self, + hub_content_name: Union[str, PipelineVariable] = None, + hub_content_version: Optional[Union[str, PipelineVariable]] = None, + recipe_name: Optional[Union[str, PipelineVariable]] = None, + ): + """Initialize a ``ContainerBaseModel`` instance and turn parameters into dict. 
+ + Args: + hub_content_name (str or PipelineVariable): The hub content name + hub_content_version (str or PipelineVariable): The hub content version + (default: None) + recipe_name (str or PipelineVariable): The Recipe name + (default: None) + """ + self.hub_content_name = hub_content_name + self.hub_content_version = hub_content_version + self.recipe_name = recipe_name + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + base_model_request = {} + if self.hub_content_name is not None: + base_model_request["HubContentName"] = self.hub_content_name + if self.hub_content_version is not None: + base_model_request["HubContentVersion"] = self.hub_content_version + if self.recipe_name is not None: + base_model_request["RecipeName"] = self.recipe_name + return base_model_request diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 2d8318fd39..413f948af6 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1817,6 +1817,8 @@ def register( source_uri=None, model_life_cycle=None, model_card=None, + model_package_registration_type=None, + base_model=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1868,6 +1870,9 @@ def register( model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). + model_package_registration_type (str): Model Package Registration + Type (default: None). + base_model (ContainerBaseModel): ContainerBaseModel object (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. 
@@ -1924,6 +1929,8 @@ def register( source_uri=source_uri, model_card=model_card, model_life_cycle=model_life_cycle, + model_package_registration_type=model_package_registration_type, + base_model=base_model, ) @property diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 3ca25fb3ce..2d85cec5ac 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -17,7 +17,7 @@ from typing import Callable, Optional, Union, List, Dict import sagemaker -from sagemaker import image_uris, ModelMetrics +from sagemaker import image_uris, ModelMetrics, ContainerBaseModel from sagemaker.deserializers import JSONDeserializer from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import ( @@ -372,6 +372,8 @@ def register( source_uri: Optional[Union[str, PipelineVariable]] = None, model_life_cycle: Optional[ModelLifeCycle] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_package_registration_type: Optional[Union[str, PipelineVariable]] = None, + base_model: Optional[ContainerBaseModel] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -427,6 +429,9 @@ def register( model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). + model_package_registration_type (str or PipelineVariable): Model Package Registration + Type (default: None). + base_model (ContainerBaseModel): ContainerBaseModel object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
@@ -477,6 +482,8 @@ def register( source_uri=source_uri, model_life_cycle=model_life_cycle, model_card=model_card, + model_package_registration_type=model_package_registration_type, + base_model=base_model, ) def prepare_container_def( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 3bfac0c8da..b7e6017d45 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -52,6 +52,7 @@ from sagemaker.model_card.helpers import _hash_content_str from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum from sagemaker.session import Session +from sagemaker.container_base_model import ContainerBaseModel from sagemaker.model_metrics import ModelMetrics from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.explainer import ExplainerConfig @@ -477,6 +478,8 @@ def register( model_life_cycle: Optional[ModelLifeCycle] = None, accept_eula: Optional[bool] = None, model_type: Optional[JumpStartModelType] = None, + model_package_registration_type: Optional[Union[str, PipelineVariable]] = None, + base_model: Optional[ContainerBaseModel] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -531,6 +534,9 @@ def register( model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). + model_package_registration_type (str or PipelineVariable): Model Package Registration + Type (default: None). + base_model (ContainerBaseModel): ContainerBaseModel object (default: None). 
Returns: A `sagemaker.model.ModelPackage` instance or pipeline step arguments @@ -578,6 +584,9 @@ def register( if self.model_data is not None: container_def["ModelDataUrl"] = self.model_data + if base_model is not None and hasattr(base_model, "_to_request_dict"): + container_def["BaseModel"] = base_model._to_request_dict() + model_pkg_args = sagemaker.get_model_package_args( self.content_types, self.response_types, @@ -601,6 +610,8 @@ def register( source_uri=source_uri, model_card=model_card, model_life_cycle=model_life_cycle, + model_package_registration_type=model_package_registration_type, + base_model=base_model, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args @@ -2150,6 +2161,7 @@ def __init__( You can find additional parameters for initializing this class at :class:`~sagemaker.model.Model`. """ + super(FrameworkModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 828c5da198..179d03519d 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -108,6 +108,7 @@ _get_args_from_recipe, _determine_device_type, _is_nova_recipe, + _is_llmft_recipe, _load_base_recipe, ) @@ -252,6 +253,7 @@ class ModelTrainer(BaseModel): _metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None) _is_nova_recipe: Optional[bool] = PrivateAttr(default=None) + _is_llmft_recipe: Optional[bool] = PrivateAttr(default=None) _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) _temp_code_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) @@ -632,12 +634,13 @@ def _create_training_job_args( final_input_data_config = list(existing_channels.values()) + new_channels - if self._is_nova_recipe: + if self._is_nova_recipe or self._is_llmft_recipe: + for input_data in final_input_data_config: if input_data.channel_name == SM_RECIPE: 
raise ValueError( "Cannot use reserved channel name 'recipe' as an input channel name " - " for Nova Recipe" + " for Nova or LLMFT Recipe" ) recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML) recipe_channel = self.create_input_data_channel( @@ -646,7 +649,10 @@ def _create_training_job_args( key_prefix=input_data_key_prefix, ) final_input_data_config.append(recipe_channel) - self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH}) + if self._is_nova_recipe: + self.hyperparameters.update( + {"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH} + ) if final_input_data_config: final_input_data_config = self._get_input_data_config( @@ -1201,14 +1207,15 @@ def from_recipe( training_recipe=training_recipe, recipe_overrides=recipe_overrides ) is_nova = _is_nova_recipe(recipe=recipe) + is_llmft = _is_llmft_recipe(recipe=recipe) - if device_type == "cpu" and not is_nova: + if device_type == "cpu" and not (is_nova or is_llmft): raise ValueError( "Training recipe is not supported for CPU instances. " + "Please provide a GPU or Tranium instance type." ) - if training_image is None and is_nova: - raise ValueError("training_image must be provided when using recipe for Nova.") + if training_image is None and (is_nova or is_llmft): + raise ValueError("training_image must be provided when using recipe for Nova or LLMFT") if training_image_config and training_image is None: raise ValueError("training_image must be provided when using training_image_config.") @@ -1238,7 +1245,7 @@ def from_recipe( model_trainer_args["training_image"] = training_image if hyperparameters and not is_nova: logger.warning( - "Hyperparameters are not supported for general training recipes. " + "Hyperparameters are not supported for general and LLMFT training recipes. " + "Ignoring hyperparameters input." 
) if is_nova: @@ -1264,6 +1271,7 @@ def from_recipe( **model_trainer_args, ) model_trainer._is_nova_recipe = is_nova + model_trainer._is_llmft_recipe = is_llmft model_trainer._temp_recipe_train_dir = tmp_dir return model_trainer diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py index c7457f6fad..7b4928eb30 100644 --- a/src/sagemaker/modules/train/sm_recipes/utils.py +++ b/src/sagemaker/modules/train/sm_recipes/utils.py @@ -46,6 +46,19 @@ def _try_resolve_recipe(recipe: DictConfig, key=None) -> DictConfig: return recipe[key] +def _resolve_final_recipe(recipe: DictConfig): + """Resolve final recipe.""" + final_recipe = _try_resolve_recipe(recipe) + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "recipes") + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "training") + if final_recipe is None: + raise RuntimeError("Could not resolve provided recipe.") + + return final_recipe + + def _determine_device_type(instance_type: str) -> str: """Determine device type (gpu, cpu, trainium) based on instance type.""" instance_family = instance_type.split(".")[1] @@ -269,6 +282,27 @@ def _is_nova_recipe( return bool(has_nova_model) or bool(has_distillation) +def _is_llmft_recipe( + recipe: DictConfig, +) -> bool: + """Check if the recipe is a LLMFT recipe. + + A recipe is considered a LLMFT recipe if it meets the following conditions: + 1. Having a run section + 2. The model_type in run is llm_finetuning_aws + 3. 
Having a training_config section + + Args: + recipe (DictConfig): The loaded recipe configuration + + Returns: + bool: True if the recipe is a LLMFT recipe, False otherwise + """ + run_config = recipe.get("run", {}) + has_llmft_model = run_config.get("model_type", "").lower() == "llm_finetuning_aws" + return bool(has_llmft_model) and bool(recipe.get("training_config")) + + def _get_args_from_nova_recipe( recipe: DictConfig, compute: Compute, @@ -312,6 +346,12 @@ def _get_args_from_nova_recipe( if lambda_arn: args["hyperparameters"]["eval_lambda_arn"] = lambda_arn + # Handle reward lambda configuration + run_config = recipe.get("run", {}) + reward_lambda_arn = run_config.get("reward_lambda_arn", "") + if reward_lambda_arn: + args["hyperparameters"]["reward_lambda_arn"] = reward_lambda_arn + _register_custom_resolvers() # Resolve Final Recipe @@ -339,6 +379,43 @@ def _get_args_from_nova_recipe( return args, recipe_local_dir +def _get_args_from_llmft_recipe( + recipe: DictConfig, + compute: Compute, +) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]: + + if not compute.instance_count and not recipe.get("trainer", {}).get("num_nodes", None): + raise ValueError( + "Must set ``instance_count`` in compute or ``num_nodes`` in trainer in recipe." + ) + if compute.instance_count and recipe.get("trainer", {}).get("num_nodes", None) is not None: + logger.warning( + f"Using Compute to set instance_count:\n{compute}." + "\nIgnoring trainer -> num_nodes in recipe." 
+ ) + compute.instance_count = compute.instance_count or recipe.get("trainer", {}).get("num_nodes") + + args = dict() + + _register_custom_resolvers() + final_recipe = _resolve_final_recipe(recipe) + + # Save Final Recipe to tmp dir + recipe_local_dir = tempfile.TemporaryDirectory(prefix="recipe_") + final_recipe_path = os.path.join(recipe_local_dir.name, SM_RECIPE_YAML) + OmegaConf.save(config=final_recipe, f=final_recipe_path) + + args.update( + { + "compute": compute, + "training_image": None, + "source_code": None, + "distributed": None, + } + ) + return args, recipe_local_dir + + def _get_args_from_recipe( training_recipe: Union[str, DictConfig], compute: Compute, @@ -383,6 +460,9 @@ def _get_args_from_recipe( if _is_nova_recipe(recipe): args, recipe_local_dir = _get_args_from_nova_recipe(recipe, compute, role=role) return args, recipe_local_dir + if _is_llmft_recipe(recipe): + args, recipe_local_dir = _get_args_from_llmft_recipe(recipe, compute) + return args, recipe_local_dir if "trainer" not in recipe: raise ValueError("Supplied recipe does not contain required field trainer.") diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index fa0c691d2d..53a6a95e37 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -19,7 +19,7 @@ import packaging.version import sagemaker -from sagemaker import image_uris, ModelMetrics +from sagemaker import image_uris, ModelMetrics, ContainerBaseModel from sagemaker.deserializers import JSONDeserializer from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import ( @@ -184,6 +184,8 @@ def register( source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, model_life_cycle: Optional[ModelLifeCycle] = None, + model_package_registration_type: Optional[Union[str, PipelineVariable]] = None, + base_model: Optional[ContainerBaseModel] = None, ): """Creates a model package for creating 
SageMaker models or listing on Marketplace. @@ -238,6 +240,9 @@ def register( model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). + model_package_registration_type (str or PipelineVariable): Model Package Registration + Type (default: None). + base_model (ContainerBaseModel): ContainerBaseModel object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -280,6 +285,8 @@ def register( source_uri=source_uri, model_card=model_card, model_life_cycle=model_life_cycle, + model_package_registration_type=model_package_registration_type, + base_model=base_model, ) def prepare_container_def( diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index b36cd4e917..1e0695bf9e 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -16,7 +16,7 @@ from typing import Callable, Optional, Dict, List, Union import sagemaker -from sagemaker import ModelMetrics, Model +from sagemaker import ModelMetrics, Model, ContainerBaseModel from sagemaker import local from sagemaker import session from sagemaker.config import ( @@ -369,6 +369,8 @@ def register( skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_package_registration_type: Optional[Union[str, PipelineVariable]] = None, + base_model: Optional[ContainerBaseModel] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -422,6 +424,9 @@ def register( (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_package_registration_type (str or PipelineVariable): Model Package Registration + Type (default: None). 
+ base_model (ContainerBaseModel): ContainerBaseModel object (default: None). Returns: If ``sagemaker_session`` is a ``PipelineSession`` instance, returns pipeline step @@ -449,6 +454,8 @@ def register( } for model in self.models ] + if base_model is not None and hasattr(base_model, "_to_request_dict"): + container_def["BaseModel"] = base_model._to_request_dict() model_pkg_args = sagemaker.get_model_package_args( content_types, @@ -471,6 +478,7 @@ def register( skip_model_validation=skip_model_validation, source_uri=source_uri, model_card=model_card, + model_package_registration_type=model_package_registration_type, ) model_package = self.sagemaker_session.create_model_package_from_containers( diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index ce8daae9d1..3b83370c15 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -180,6 +180,28 @@ def _is_eval_recipe(recipe): return bool(eval_config) +def _is_llmft_recipe(recipe): + """Check if the recipe is a llmft recipe. + + A llmft recipe is identified by: + 1. Having a run section + 2. The model_type in run is llm_finetuning_aws or verl + 3. Having a training_config section (also required for verl recipes) + + Args: + recipe (OmegaConf): The loaded recipe configuration + + Returns: + bool: True if the recipe is a llmft recipe, False otherwise + """ + # Check for llmft or verl model + run_config = recipe.get("run", {}) + model_type = run_config.get("model_type", "").lower() + has_llmft_model = model_type == "llm_finetuning_aws" + has_verl_model = model_type == "verl" + return (bool(has_llmft_model) or bool(has_verl_model)) and bool(recipe.get("training_config")) + + def _recipe_initialize_args(source_dir): """Initialize the arguments dictionary for recipe setup. @@ -544,6 +566,7 @@ def __init__( :class:`~sagemaker.estimator.EstimatorBase`. 
""" self.is_nova_or_eval_recipe = False + self.is_llmft_recipe = False if training_recipe is not None: if entry_point is not None: logger.warning("Argument entry_point will be ignored with training_recipe.") @@ -555,8 +578,8 @@ def __init__( training_recipe, recipe_overrides, source_dir, kwargs ) - if self.is_nova_or_eval_recipe and image_uri is None: - raise ValueError("Must supply image_uri for nova jobs.") + if (self.is_nova_or_eval_recipe or self.is_llmft_recipe) and image_uri is None: + raise ValueError("Must supply image_uri when running llmft or nova jobs.") entry_point = args["entry_point"] source_dir = args["source_dir"] @@ -694,6 +717,41 @@ def hyperparameters(self): return hyperparameters + def _create_recipe_copy(self, original_s3_uri): + """Create a copy of the recipe with the name recipe.yaml in the same S3 bucket. + + This helps us standardize the arguments for file name in the container + when launching the llmft recipes + + Args: + original_s3_uri (str): The S3 URI of the original uploaded file + + Returns: + str: The S3 URI of the copied recipe file + """ + try: + # Parse the original S3 URI + parsed_uri = original_s3_uri.replace("s3://", "").split("/") + bucket = parsed_uri[0] + original_key = "/".join(parsed_uri[1:]) + + # Create new key in the same directory + directory = "/".join(original_key.split("/")[:-1]) + new_key = f"{directory}/recipe.yaml" + + s3_client = self.sagemaker_session.boto_session.client("s3") + + # Copy the object with the new name + copy_source = {"Bucket": bucket, "Key": original_key} + + s3_client.copy_object(CopySource=copy_source, Bucket=bucket, Key=new_key) + + return f"s3://{bucket}/{new_key}" + + except Exception as e: + logger.error(f"Failed to create recipe copy: {str(e)}") + raise + def fit( self, inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None, @@ -719,16 +777,21 @@ def fit( """ # Handle recipe upload and input channel creation if we have a recipe if ( - self.is_nova_or_eval_recipe is 
not None - and self.is_nova_or_eval_recipe + ( + (self.is_nova_or_eval_recipe is not None and self.is_nova_or_eval_recipe) + or (self.is_llmft_recipe is not None and self.is_llmft_recipe) + ) and hasattr(self, "training_recipe_file") and self.training_recipe_file ): # Upload the recipe to S3 if it hasn't been uploaded yet if not hasattr(self, "recipe_s3_uri") or not self.recipe_s3_uri: + self.recipe_s3_uri = self._upload_recipe_to_s3( self.sagemaker_session, self.training_recipe_file.name ) + if self.is_llmft_recipe: + self.recipe_duplicated_s3_uri = self._create_recipe_copy(self.recipe_s3_uri) # Prepare inputs dictionary from sagemaker.inputs import TrainingInput @@ -740,12 +803,20 @@ def fit( # Add the recipe channel recipe_channel_name = "recipe" - inputs[recipe_channel_name] = TrainingInput( - s3_data=os.path.dirname(self.recipe_s3_uri), input_mode="File" - ) + if self.is_nova_or_eval_recipe: + inputs[recipe_channel_name] = TrainingInput( + s3_data=os.path.dirname(self.recipe_s3_uri), input_mode="File" + ) + else: + inputs[recipe_channel_name] = self.recipe_duplicated_s3_uri # Update hyperparameters to reference the recipe location in the container - recipe_filename = os.path.basename(self.training_recipe_file.name) + # For LLMFT recipes, use the standardized filename "recipe.yaml" since _create_recipe_copy() + # creates a copy with that name. For other recipes, use the original filename. 
+ if self.is_llmft_recipe: + recipe_filename = "recipe.yaml" + else: + recipe_filename = os.path.basename(self.training_recipe_file.name) self._hyperparameters.update( { @@ -884,7 +955,6 @@ def _recipe_load(training_recipe, recipe_launcher_dir, training_recipes_cfg): """ recipe_name = os.path.splitext(os.path.basename(training_recipe))[0] temp_local_recipe = tempfile.NamedTemporaryFile(prefix=recipe_name, suffix=".yaml").name - try: if training_recipe.endswith(".yaml"): _recipe_load_from_yaml(training_recipe, temp_local_recipe) @@ -1083,7 +1153,6 @@ def _recipe_resolve_and_save(self, recipe, recipe_name, source_dir): suffix=".yaml", ) OmegaConf.save(config=final_recipe, f=self.training_recipe_file.name) - return final_recipe def _upload_recipe_to_s3(self, session, recipe_file_path): @@ -1155,6 +1224,7 @@ def _setup_for_training_recipe(self, training_recipe, recipe_overrides, source_d recipe = OmegaConf.merge(recipe, recipe_overrides) self.is_nova_or_eval_recipe = _is_nova_recipe(recipe) or _is_eval_recipe(recipe) + self.is_llmft_recipe = _is_llmft_recipe(recipe) if self.is_nova_or_eval_recipe: return self._setup_for_nova_recipe( recipe, @@ -1162,6 +1232,13 @@ def _setup_for_training_recipe(self, training_recipe, recipe_overrides, source_d source_dir, kwargs, ) + elif self.is_llmft_recipe: + return self._setup_for_llmft_recipe( + recipe, + recipe_name, + source_dir, + kwargs, + ) else: return self._setup_for_standard_recipe( recipe, @@ -1251,6 +1328,57 @@ def _setup_for_nova_recipe( if lambda_arn: args["hyperparameters"]["eval_lambda_arn"] = lambda_arn + # Handle reward lambda configuration + run_config = recipe.get("run", {}) + reward_lambda_arn = run_config.get("reward_lambda_arn", "") + if reward_lambda_arn: + args["hyperparameters"]["reward_lambda_arn"] = reward_lambda_arn + + # Resolve and save the final recipe + self._recipe_resolve_and_save(recipe, recipe_name, args["source_dir"]) + + return args + + def _setup_for_llmft_recipe( + self, + recipe, + 
recipe_name, + source_dir, + kwargs, + ): + """Set up configuration specifically for llmft recipes. + + Args: + recipe (OmegaConf): Recipe configuration. + recipe_name (str): Recipe name. + source_dir (str): Path to the source directory. + kwargs (dict): Dictionary of keyword arguments. + + Returns: + dict: Arguments dictionary for estimator initialization. + """ + # Initialize args + args = _recipe_initialize_args(source_dir) + + args["entry_point"] = None + args["source_dir"] = None + args["distribution"] = {} + + # Handle instance count for llmft recipes + if "instance_count" in kwargs: + if "num_nodes" in recipe.get("trainer", {}): + logger.warning( + "Using instance_count argument to estimator to set number " + "of nodes. Ignoring trainer -> num_nodes in recipe." + ) + elif "trainer" in recipe and "num_nodes" in recipe["trainer"]: + kwargs["instance_count"] = recipe["trainer"]["num_nodes"] + else: + raise ValueError( + "Must set either instance_count argument for estimator or " + "set trainer -> num_nodes in recipe." 
+ ) + # Resolve and save the final recipe self._recipe_resolve_and_save(recipe, recipe_name, args["source_dir"]) diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 958327ba08..9be5ed2e4f 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -19,7 +19,7 @@ import packaging.version import sagemaker -from sagemaker import image_uris, ModelMetrics +from sagemaker import image_uris, ModelMetrics, ContainerBaseModel from sagemaker.deserializers import NumpyDeserializer from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import ( @@ -186,6 +186,8 @@ def register( source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, model_life_cycle: Optional[ModelLifeCycle] = None, + model_package_registration_type: Optional[Union[str, PipelineVariable]] = None, + base_model: Optional[ContainerBaseModel] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -240,6 +242,9 @@ def register( model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). + model_package_registration_type (str or PipelineVariable): Model Package Registration + Type (default: None). + base_model (ContainerBaseModel): ContainerBaseModel object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
@@ -282,6 +287,8 @@ def register( source_uri=source_uri, model_card=model_card, model_life_cycle=model_life_cycle, + model_package_registration_type=model_package_registration_type, + base_model=base_model, ) def prepare_container_def( diff --git a/src/sagemaker/serve/model_format/mlflow/utils.py b/src/sagemaker/serve/model_format/mlflow/utils.py index 69082fe575..99786397c4 100644 --- a/src/sagemaker/serve/model_format/mlflow/utils.py +++ b/src/sagemaker/serve/model_format/mlflow/utils.py @@ -208,7 +208,7 @@ def _get_deployment_flavor(flavor_metadata: Optional[Dict[str, Any]]) -> str: def _get_python_version_from_parsed_mlflow_model_file( - parsed_metadata: Dict[str, Any] + parsed_metadata: Dict[str, Any], ) -> Optional[str]: """Checks the python version of a given parsed MLflow model file. diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 13fd3155aa..fd7d8554ee 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4470,6 +4470,7 @@ def create_model_package_from_containers( source_uri=None, model_card=None, model_life_cycle=None, + model_package_registration_type=None, ): """Get request dictionary for CreateModelPackage API. @@ -4510,6 +4511,8 @@ def create_model_package_from_containers( model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). + model_package_registration_type (str or PipelineVariable): Model Package Registration + Type (default: None). """ if containers: # Containers are provided. Now we can merge missing entries from config. 
@@ -4569,6 +4572,7 @@ def create_model_package_from_containers( source_uri=source_uri, model_card=model_card, model_life_cycle=model_life_cycle, + model_package_registration_type=model_package_registration_type, ) def submit(request): @@ -7627,6 +7631,8 @@ def get_model_package_args( source_uri=None, model_card=None, model_life_cycle=None, + model_package_registration_type=None, + base_model=None, ): """Get arguments for create_model_package method. @@ -7669,6 +7675,9 @@ def get_model_package_args( model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). + model_package_registration_type (str): Model Package Registration + Type (default: None). + base_model (ContainerBaseModel): ContainerBaseModel object (default: None). Returns: dict: A dictionary of method argument names and values. @@ -7681,9 +7690,15 @@ def get_model_package_args( } if model_data is not None: container["ModelDataUrl"] = model_data - + if base_model is not None: + container["BaseModel"] = base_model._to_request_dict() containers = [container] + # Convert base_model in containers to request dict if they have _to_request_dict method + for container in containers: + if "BaseModel" in container and hasattr(container["BaseModel"], "_to_request_dict"): + container["BaseModel"] = container["BaseModel"]._to_request_dict() + model_package_args = { "containers": containers, "inference_instances": inference_instances, @@ -7735,6 +7750,8 @@ def get_model_package_args( original_req["ModelCardContent"] = original_req["Content"] del original_req["Content"] model_package_args["model_card"] = original_req + if model_package_registration_type is not None: + model_package_args["model_package_registration_type"] = model_package_registration_type return model_package_args @@ -7762,6 +7779,7 @@ def get_create_model_package_request( source_uri=None, model_card=None, 
model_life_cycle=None, + model_package_registration_type=None, ): """Get request dictionary for CreateModelPackage API. @@ -7802,6 +7820,8 @@ def get_create_model_package_request( model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). + model_package_registration_type (str): Model Package Registration + Type (default: None). """ if all([model_package_name, model_package_group_name]): @@ -7903,6 +7923,8 @@ def get_create_model_package_request( request_dict["ModelCard"] = model_card if model_life_cycle is not None: request_dict["ModelLifeCycle"] = model_life_cycle + if model_package_registration_type is not None: + request_dict["ModelPackageRegistrationType"] = model_package_registration_type return request_dict diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index a9b0e2e8f0..2b7afceb23 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -17,7 +17,7 @@ from typing import Callable, Union, Optional, List, Dict import sagemaker -from sagemaker import image_uris, ModelMetrics +from sagemaker import image_uris, ModelMetrics, ContainerBaseModel from sagemaker.deserializers import NumpyDeserializer from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args @@ -179,6 +179,8 @@ def register( source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, model_life_cycle: Optional[ModelLifeCycle] = None, + model_package_registration_type: Optional[Union[str, PipelineVariable]] = None, + base_model: Optional[ContainerBaseModel] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. 
@@ -233,6 +235,9 @@ def register( model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). + model_package_registration_type (str or PipelineVariable): Model Package Registration + Type (default: None). + base_model (ContainerBaseModel): ContainerBaseModel object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -275,6 +280,8 @@ def register( source_uri=source_uri, model_card=model_card, model_life_cycle=model_life_cycle, + model_package_registration_type=model_package_registration_type, + base_model=base_model, ) def prepare_container_def( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index b384cbbbb5..2e07798926 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -17,7 +17,7 @@ from typing import Callable, Union, Optional, List, Dict import sagemaker -from sagemaker import image_uris, s3, ModelMetrics +from sagemaker import image_uris, s3, ModelMetrics, ContainerBaseModel from sagemaker.deserializers import JSONDeserializer from sagemaker.deprecations import removed_kwargs from sagemaker.drift_check_baselines import DriftCheckBaselines @@ -241,6 +241,8 @@ def register( source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, model_life_cycle: Optional[ModelLifeCycle] = None, + model_package_registration_type: Optional[Union[str, PipelineVariable]] = None, + base_model: Optional[ContainerBaseModel] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -295,6 +297,9 @@ def register( model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). 
+ model_package_registration_type (str or PipelineVariable): Model Package Registration + Type (default: None). + base_model (ContainerBaseModel): ContainerBaseModel object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -337,6 +342,8 @@ def register( source_uri=source_uri, model_card=model_card, model_life_cycle=model_life_cycle, + model_package_registration_type=model_package_registration_type, + base_model=base_model, ) def deploy( diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 36c393969a..b3796cac13 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -331,6 +331,8 @@ def __init__( source_uri=None, model_card=None, model_life_cycle=None, + model_package_registration_type=None, + base_model=None, **kwargs, ): """Constructor of a register model step. @@ -387,6 +389,9 @@ def __init__( quantitative information about a model (default: None). model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). **kwargs: additional arguments to `create_model`. + model_package_registration_type (str): Model Package Registration + Type (default: None). + base_model (ContainerBaseModel): ContainerBaseModel object (default: None). 
""" super(_RegisterModelStep, self).__init__( name, StepTypeEnum.REGISTER_MODEL, display_name, description, depends_on, retry_policies @@ -425,7 +430,8 @@ def __init__( self.source_uri = source_uri self.model_card = model_card self.model_life_cycle = model_life_cycle - + self.model_package_registration_type = model_package_registration_type + self.base_model = base_model self._properties = Properties( step_name=name, step=self, shape_name="DescribeModelPackageOutput" ) @@ -502,6 +508,8 @@ def arguments(self) -> RequestType: source_uri=self.source_uri, model_card=self.model_card, model_life_cycle=self.model_life_cycle, + model_package_registration_type=self.model_package_registration_type, + base_model=self.base_model, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/conditions.py b/src/sagemaker/workflow/conditions.py index 4b4996a7fa..6425944a7d 100644 --- a/src/sagemaker/workflow/conditions.py +++ b/src/sagemaker/workflow/conditions.py @@ -296,7 +296,7 @@ def _referenced_steps(self) -> List[str]: def primitive_or_expr( - value: Union[ExecutionVariable, Expression, PrimitiveType, Parameter, Properties, StepOutput] + value: Union[ExecutionVariable, Expression, PrimitiveType, Parameter, Properties, StepOutput], ) -> Union[Dict[str, str], PrimitiveType]: """Provide the expression of the value or return value if it is a primitive. diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index a1d939254c..9439ddca84 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -99,6 +99,8 @@ def __init__( source_uri=None, model_card=None, model_life_cycle=None, + model_package_registration_type=None, + base_model=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. 
@@ -160,6 +162,9 @@ def __init__( model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). + model_package_registration_type (str): Model Package Registration + Type (default: None). + base_model (ContainerBaseModel): ContainerBaseModel object (default: None). **kwargs: additional arguments to `create_model`. """ super().__init__(name=name, depends_on=depends_on) @@ -300,6 +305,8 @@ def __init__( source_uri=source_uri, model_card=model_card, model_life_cycle=model_life_cycle, + model_package_registration_type=model_package_registration_type, + base_model=base_model, **kwargs, ) if not repack_model: diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index f4797c79e7..a28b7a2721 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -17,7 +17,7 @@ from typing import Callable, Optional, Union, List, Dict import sagemaker -from sagemaker import image_uris, ModelMetrics +from sagemaker import image_uris, ModelMetrics, ContainerBaseModel from sagemaker.deserializers import CSVDeserializer from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import model_code_key_prefix @@ -167,6 +167,8 @@ def register( source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, model_life_cycle: Optional[ModelLifeCycle] = None, + model_package_registration_type: Optional[Union[str, PipelineVariable]] = None, + base_model: Optional[ContainerBaseModel] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -221,6 +223,9 @@ def register( model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). 
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). + model_package_registration_type (str or PipelineVariable): Model Package Registration + Type (default: None). + base_model (ContainerBaseModel): ContainerBaseModel object (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -263,6 +268,8 @@ def register( source_uri=source_uri, model_card=model_card, model_life_cycle=model_life_cycle, + model_package_registration_type=model_package_registration_type, + base_model=base_model, ) def prepare_container_def( diff --git a/tests/data/huggingface/run_glue.py b/tests/data/huggingface/run_glue.py index 1060398fa4..f55f484163 100644 --- a/tests/data/huggingface/run_glue.py +++ b/tests/data/huggingface/run_glue.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Finetuning the library models for sequence classification on GLUE.""" +"""Finetuning the library models for sequence classification on GLUE.""" # You can also adapt this script on your own text classification task. Pointers for this are left as comments. import logging diff --git a/tests/data/huggingface_byoc/run_glue.py b/tests/data/huggingface_byoc/run_glue.py index 1060398fa4..f55f484163 100644 --- a/tests/data/huggingface_byoc/run_glue.py +++ b/tests/data/huggingface_byoc/run_glue.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Finetuning the library models for sequence classification on GLUE.""" +"""Finetuning the library models for sequence classification on GLUE.""" # You can also adapt this script on your own text classification task. Pointers for this are left as comments. 
import logging diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index e84c1920f4..8ece810b58 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -40,6 +40,7 @@ Model, ModelMetrics, MetricsSource, + ContainerBaseModel, ) from sagemaker import FileSource, utils from sagemaker.inputs import CreateModelInput @@ -605,6 +606,10 @@ def test_model_registration_with_drift_check_baselines( nearest_model_name = "resnet50" data_input_configuration = '{"input_1":[1,224,224,3]}' skip_model_validation = "All" + model_package_registration_type = "Registered" + base_model = ContainerBaseModel( + hub_content_name="test", hub_content_version="1234.1234", recipe_name="testRecipeName" + ) # If image_uri is not provided, the instance_type should not be a pipeline variable # since instance_type is used to retrieve image_uri in compile time (PySDK) @@ -639,6 +644,8 @@ def test_model_registration_with_drift_check_baselines( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + model_package_registration_type=model_package_registration_type, + base_model=base_model, ) pipeline = Pipeline( @@ -684,7 +691,6 @@ def test_model_registration_with_drift_check_baselines( response = sagemaker_session_for_pipeline.sagemaker_client.describe_model_package( ModelPackageName=execution_steps[0]["Metadata"]["RegisterModel"]["Arn"] ) - assert ( response["ModelMetrics"]["Explainability"]["Report"]["ContentType"] == "application/json" diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index bd870dc461..71e7427fb5 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -422,7 +422,7 @@ def get_base_deployment_configs( def 
append_instance_stat_metrics( - metrics: Dict[str, List[JumpStartBenchmarkStat]] + metrics: Dict[str, List[JumpStartBenchmarkStat]], ) -> Dict[str, List[JumpStartBenchmarkStat]]: if metrics is not None: for key in metrics: diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py index 6087050171..b50ab79719 100644 --- a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py +++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py @@ -29,7 +29,9 @@ _configure_trainium_args, _get_trainining_recipe_gpu_model_name_and_script, _is_nova_recipe, + _is_llmft_recipe, _get_args_from_nova_recipe, + _get_args_from_llmft_recipe, ) from sagemaker.modules.utils import _run_clone_command_silent from sagemaker.modules.configs import Compute @@ -478,3 +480,191 @@ def test_get_args_from_nova_recipe_with_evaluation(test_case): recipe=recipe, compute=test_case["compute"], role=test_case["role"] ) assert args == test_case["expected_args"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "recipe": { + "run": { + "name": "dummy-model", + "model_type": "llm_finetuning_aws", + }, + "trainer": {"num_nodes": "12"}, + "training_config": {"model_save_name": "xyz"}, + }, + "is_llmft": True, + }, + { + "recipe": { + "run": { + "name": "dummy-model", + "model_type": "llm_finetuning_aws", + }, + "training_config": {"model_save_name": "xyz"}, + }, + "is_llmft": True, + }, + { + "recipe": { + "run": { + "name": "dummy-model", + "model_type": "llm_finetuning_aws", + }, + }, + "is_llmft": False, + }, + { + "recipe": { + "run": { + "name": "dummy-model", + "model_type": "xyz", + }, + "training_config": {"model_save_name": "xyz"}, + }, + "is_llmft": False, + }, + ], + ids=[ + "llmft_model", + "llmft_model_subtype", + "llmft_missing_training_config", + "non_llmft_model", + ], +) +def test_is_llmft_recipe(test_case): + recipe = OmegaConf.create(test_case["recipe"]) + is_llmft = _is_llmft_recipe(recipe) + 
assert is_llmft == test_case["is_llmft"] + + +@patch("sagemaker.modules.train.sm_recipes.utils._get_args_from_llmft_recipe") +def test_get_args_from_recipe_with_llmft_and_role(mock_get_args_from_llmft_recipe): + # Set up mock return value + mock_args = {} + mock_dir = MagicMock() + mock_get_args_from_llmft_recipe.return_value = (mock_args, mock_dir) + + recipe = { + "run": { + "name": "dummy-model", + "model_type": "llm_finetuning_aws", + }, + "trainer": {"num_nodes": "12"}, + "training_config": {"model_save_name": "xyz"}, + } + compute = Compute(instance_type="ml.g5.xlarge") + role = "arn:aws:iam::123456789012:role/SageMakerRole" + + # Mock the LLMFT recipe detection to return True + with patch("sagemaker.modules.train.sm_recipes.utils._is_llmft_recipe", return_value=True): + _get_args_from_recipe( + training_recipe=recipe, + compute=compute, + region_name="us-west-2", + recipe_overrides=None, + requirements=None, + role=role, + ) + + # Verify _get_args_from_llmft_recipe was called + mock_get_args_from_llmft_recipe.assert_called_once_with(recipe, compute) + + +@pytest.mark.parametrize( + "test_case", + [ + { + "recipe": { + "run": { + "name": "dummy-model", + "model_type": "llm_finetuning_aws", + }, + "trainer": {"num_nodes": "12"}, + "training_config": {"model_save_name": "xyz"}, + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + { + "recipe": { + "run": { + "name": "dummy-model", + "model_type": "llm_finetuning_aws", + }, + "training_config": {"model_save_name": "xyz"}, + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + ], +) +def 
test_get_args_from_llmft_recipe(test_case): + recipe = OmegaConf.create(test_case["recipe"]) + args, _ = _get_args_from_llmft_recipe(recipe=recipe, compute=test_case["compute"]) + assert args == test_case["expected_args"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "recipe": { + "run": { + "model_type": "amazon.nova", + "model_name_or_path": "dummy-test", + "reward_lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyRewardLambdaFunction", + }, + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "role": "arn:aws:iam::123456789012:role/SageMakerRole", + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "hyperparameters": { + "base_model": "dummy-test", + "reward_lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyRewardLambdaFunction", + }, + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + { + "recipe": { + "run": { + "model_type": "amazon.nova", + "model_name_or_path": "dummy-test", + # No reward_lambda_arn - should not be in hyperparameters + }, + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "role": "arn:aws:iam::123456789012:role/SageMakerRole", + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "hyperparameters": { + "base_model": "dummy-test", + }, + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + ], +) +def test_get_args_from_nova_recipe_with_reward_lambda(test_case): + recipe = OmegaConf.create(test_case["recipe"]) + args, _ = _get_args_from_nova_recipe( + recipe=recipe, compute=test_case["compute"], role=test_case["role"] + ) + assert args == test_case["expected_args"] diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 73893ea7f4..cbb9d3aa92 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ 
b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -1476,3 +1476,93 @@ def test_nova_recipe_with_distillation(modules_session): # Clean up the temporary file os.unlink(recipe.name) + + +@patch("sagemaker.modules.train.model_trainer._get_unique_name") +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_llmft_recipe(mock_training_job, mock_unique_name, modules_session): + def mock_upload_data(path, bucket, key_prefix): + if os.path.isfile(path): + file_name = os.path.basename(path) + return f"s3://{bucket}/{key_prefix}/{file_name}" + else: + return f"s3://{bucket}/{key_prefix}" + + unique_name = "base-job-0123456789" + base_name = "base-job" + + modules_session.upload_data.side_effect = mock_upload_data + mock_unique_name.return_value = unique_name + + recipe_data = { + "run": { + "name": "dummy-model", + "model_type": "llm_finetuning_aws", + }, + "trainer": {"num_nodes": "12"}, + "training_config": {"model_save_name": "xyz"}, + } + with NamedTemporaryFile(suffix=".yaml", delete=False) as recipe: + with open(recipe.name, "w") as file: + yaml.dump(recipe_data, file) + + trainer = ModelTrainer.from_recipe( + training_recipe=recipe.name, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + training_image=DEFAULT_IMAGE, + base_job_name=base_name, + ) + + assert trainer._is_llmft_recipe + + trainer.train() + mock_training_job.create.assert_called_once() + + default_base_path = f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{base_name}" + assert mock_training_job.create.call_args.kwargs["input_data_config"] == [ + Channel( + channel_name="recipe", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"{default_base_path}/{unique_name}/input/recipe/recipe.yaml", + s3_data_distribution_type="FullyReplicated", + ) + ), + input_mode="File", + ) + ] + + +def test_llmft_recipe_missing_training_image_error(modules_session): + """Test that LLMFT recipe throws an error 
when training_image is not provided.""" + recipe_data = { + "run": { + "name": "dummy-model", + "model_type": "llm_finetuning_aws", + }, + "trainer": {"num_nodes": "12"}, + "training_config": {"model_save_name": "xyz"}, + } + + with NamedTemporaryFile(suffix=".yaml", delete=False) as recipe: + with open(recipe.name, "w") as file: + yaml.dump(recipe_data, file) + + # Test that ValueError is raised when training_image is not provided for LLMFT recipe + with pytest.raises( + ValueError, match="training_image must be provided when using recipe for Nova or LLMFT" + ): + ModelTrainer.from_recipe( + training_recipe=recipe.name, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + # Note: training_image is intentionally not provided + base_job_name="base-job", + ) + + # Clean up the temporary file + os.unlink(recipe.name) diff --git a/tests/unit/sagemaker/workflow/test_pipeline_session.py b/tests/unit/sagemaker/workflow/test_pipeline_session.py index d05bb5c20f..2b7eee5a05 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_session.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_session.py @@ -190,6 +190,7 @@ def test_pipeline_session_context_for_model_step_without_instance_types( framework="TENSORFLOW", framework_version="2.9", nearest_model_name="resnet50", + model_package_registration_type="Registered", ) expected_output = { @@ -221,6 +222,7 @@ def test_pipeline_session_context_for_model_step_without_instance_types( "SkipModelValidation": "None", "SamplePayloadUrl": "s3://test-bucket/model", "Task": "IMAGE_CLASSIFICATION", + "ModelPackageRegistrationType": "Registered", } assert register_step_args.create_model_package_request == expected_output @@ -249,6 +251,7 @@ def test_pipeline_session_context_for_model_step_with_one_instance_types( framework_version="2.9", nearest_model_name="resnet50", data_input_configuration='{"input_1":[1,224,224,3]}', + model_package_registration_type="Registered", ) expected_output = { @@ 
-284,6 +287,7 @@ def test_pipeline_session_context_for_model_step_with_one_instance_types( "SkipModelValidation": "None", "SamplePayloadUrl": "s3://test-bucket/model", "Task": "IMAGE_CLASSIFICATION", + "ModelPackageRegistrationType": "Registered", } assert register_step_args.create_model_package_request == expected_output diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index ddc76a05f7..4cca2af50c 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -653,6 +653,7 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): framework="TENSORFLOW", framework_version="2.9", nearest_model_name="resnet50", + model_package_registration_type="Registered", ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -688,6 +689,7 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): }, "ModelApprovalStatus": "Approved", "SkipModelValidation": "None", + "ModelPackageRegistrationType": "Registered", "ModelMetrics": { "Bias": {}, "Explainability": {}, diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 8352f3090b..980bb4435b 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -20,7 +20,7 @@ from packaging.version import Version import tempfile -from sagemaker import image_uris +from sagemaker import image_uris, ContainerBaseModel from sagemaker.pytorch import defaults from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel from sagemaker.pytorch.estimator import ( @@ -772,7 +772,9 @@ def test_register_pytorch_model_auto_infer_framework( py_version=pytorch_inference_py_version, sagemaker_session=sagemaker_session, ) - + base_model = ContainerBaseModel( + hub_content_name="test", hub_content_version="1234.1234", recipe_name="testRecipeName" + ) pytorch_model.register( content_types, response_types, @@ -781,6 
+783,8 @@ def test_register_pytorch_model_auto_infer_framework( model_package_group_name=model_package_group_name, marketplace_cert=True, image_uri=image_uri, + model_package_registration_type="Registered", + base_model=base_model, ) expected_create_model_package_request = { @@ -791,6 +795,11 @@ def test_register_pytorch_model_auto_infer_framework( "ModelDataUrl": ANY, "Framework": "PYTORCH", "FrameworkVersion": pytorch_inference_version, + "BaseModel": { + "HubContentName": "test", + "HubContentVersion": "1234.1234", + "RecipeName": "testRecipeName", + }, }, ], "content_types": content_types, @@ -798,6 +807,7 @@ def test_register_pytorch_model_auto_infer_framework( "inference_instances": inference_instances, "transform_instances": transform_instances, "model_package_group_name": model_package_group_name, + "model_package_registration_type": "Registered", "marketplace_cert": True, } sagemaker_session.create_model_package_from_containers.assert_called_with( @@ -1159,3 +1169,36 @@ def test_training_recipe_images_uri(): } neuron_image_uri = _get_training_recipe_image_uri(neuron_image_cfg, "us-west-2") assert neuron_image_uri == RECIPE_NEURON_IMAGE + + +@pytest.mark.parametrize( + "recipe_config,expected,test_description", + [ + ( + {"run": {"model_type": "llm_finetuning_aws"}, "training_config": {"some": "config"}}, + True, + "standard LLMFT recipe", + ), + ( + { + "run": {"model_type": "verl"}, + "training_config": {"some": "config"}, + "actor_rollout_ref": {"some": "config"}, + }, + True, + "VERL recipe", + ), + ( + {"run": {"model_type": "regular_model"}, "some_config": {"value": "test"}}, + False, + "non-LLMFT recipe", + ), + ], +) +def test_is_llmft_recipe(recipe_config, expected, test_description): + """Test LLMFT recipe detection for various configurations.""" + from sagemaker.pytorch.estimator import _is_llmft_recipe + from omegaconf import OmegaConf + + recipe = OmegaConf.create(recipe_config) + assert _is_llmft_recipe(recipe) is expected, f"Failed for 
{test_description}" diff --git a/tests/unit/test_pytorch_nova.py b/tests/unit/test_pytorch_nova.py index 662d27e85f..ddc4b62d1e 100644 --- a/tests/unit/test_pytorch_nova.py +++ b/tests/unit/test_pytorch_nova.py @@ -832,3 +832,81 @@ def test_setup_for_nova_recipe_sets_model_type(mock_resolve_save, sagemaker_sess # Verify that model_type hyperparameter was set correctly assert pytorch._hyperparameters.get("model_type") == "amazon.nova.llama-2-7b" + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save") +def test_setup_for_nova_recipe_with_reward_lambda(mock_resolve_save, sagemaker_session): + """Test that _setup_for_nova_recipe correctly handles reward lambda configuration.""" + # Create a mock recipe with reward lambda config + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foobar3", + "model_name_or_path": "foobar/foobar-3-8b", + "reward_lambda_arn": "arn:aws:lambda:us-west-2:123456789012:function:reward-function", + "replicas": 1, + }, + } + ) + + with patch( + "sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe) + ): + mock_resolve_save.return_value = recipe + + pytorch = PyTorch( + training_recipe="nova_recipe", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + ) + + # Check that the Nova recipe was correctly identified + assert pytorch.is_nova_or_eval_recipe is True + + # Verify that reward_lambda_arn hyperparameter was set correctly + assert ( + pytorch._hyperparameters.get("reward_lambda_arn") + == "arn:aws:lambda:us-west-2:123456789012:function:reward-function" + ) + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save") +def test_setup_for_nova_recipe_without_reward_lambda(mock_resolve_save, sagemaker_session): + """Test that _setup_for_nova_recipe does not set reward_lambda_arn when not present.""" + # Create a mock 
recipe without reward lambda config + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foobar3", + "model_name_or_path": "foobar/foobar-3-8b", + "replicas": 1, + }, + } + ) + + with patch( + "sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe) + ): + mock_resolve_save.return_value = recipe + + pytorch = PyTorch( + training_recipe="nova_recipe", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + ) + + # Check that the Nova recipe was correctly identified + assert pytorch.is_nova_or_eval_recipe is True + + # Verify that reward_lambda_arn hyperparameter was not set + assert "reward_lambda_arn" not in pytorch._hyperparameters diff --git a/tests/unit/test_pytorch_sft.py b/tests/unit/test_pytorch_sft.py new file mode 100644 index 0000000000..3220e9d4e2 --- /dev/null +++ b/tests/unit/test_pytorch_sft.py @@ -0,0 +1,703 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +import pytest +import tempfile +from mock import Mock, patch +from omegaconf import OmegaConf + +from sagemaker.pytorch import PyTorch +from sagemaker.pytorch.estimator import _is_llmft_recipe +from sagemaker.inputs import TrainingInput +from sagemaker.session_settings import SessionSettings + +# Constants for testing +ROLE = "Dummy" +REGION = "us-west-2" +BUCKET_NAME = "mybucket" +INSTANCE_COUNT = 1 +INSTANCE_TYPE = "ml.c4.4xlarge" +INSTANCE_TYPE_GPU = "ml.p4d.24xlarge" +INSTANCE_TYPE_TRN = "ml.trn1.32xlarge" +IMAGE_URI = "sagemaker-pytorch" + + +@pytest.fixture(name="sagemaker_session") +def fixture_sagemaker_session(): + boto_mock = Mock(name="boto_session", region_name=REGION) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + s3_resource=None, + s3_client=None, + settings=SessionSettings(), + default_bucket_prefix=None, + ) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + session.expand_role = Mock(name="expand_role", return_value=ROLE) + session.upload_data = Mock(return_value="s3://mybucket/recipes/llmft-recipe.yaml") + session.sagemaker_config = {} + return session + + +def test_is_llmft_recipe(): + """Test that _is_llmft_recipe correctly identifies LLMFT recipes.""" + # Valid LLMFT recipe + recipe = OmegaConf.create( + { + "run": { + "model_type": "llm_finetuning_aws", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 1, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + assert _is_llmft_recipe(recipe) is True + + # Not an LLMFT recipe - missing run section + recipe = OmegaConf.create( + { + "trainer": { + "devices": 8, + "num_nodes": 1, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + assert _is_llmft_recipe(recipe) is False + + # Not an LLMFT recipe 
- wrong model_type + recipe = OmegaConf.create( + { + "run": { + "model_type": "dpo", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 1, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + assert _is_llmft_recipe(recipe) is False + + # Not an LLMFT recipe - missing training_config section + recipe = OmegaConf.create( + { + "run": { + "model_type": "llm_finetuning_aws", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 1, + }, + } + ) + assert _is_llmft_recipe(recipe) is False + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save") +def test_setup_for_llmft_recipe_basic(mock_resolve_save, sagemaker_session): + """Test that _setup_for_llmft_recipe correctly sets up hyperparameters for LLMFT recipes.""" + # Create a mock LLMFT recipe + recipe = OmegaConf.create( + { + "run": { + "model_type": "llm_finetuning_aws", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 2, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + "peft_config": { + "peft_type": "lora", + "target_modules": "all-linear", + "r": 16, + "lora_alpha": 32, + }, + }, + "training_args": { + "trainer_type": "sft", + "learning_rate": 0.0001, + "max_epochs": 1, + }, + }, + } + ) + + # Setup the expected return value + expected_args = { + "hyperparameters": {}, + "entry_point": None, + "source_dir": None, + "distribution": {}, + "default_image_uri": IMAGE_URI, + } + + # Mock the _setup_for_llmft_recipe method + with patch( + "sagemaker.pytorch.estimator.PyTorch._setup_for_llmft_recipe", return_value=expected_args + ) as mock_llmft_setup: + # Create the PyTorch estimator with mocked _recipe_load + with patch( + "sagemaker.pytorch.estimator.PyTorch._recipe_load", + return_value=("llmft_recipe", recipe), + ): + # Mock _recipe_resolve_and_save to return our recipe + mock_resolve_save.return_value = recipe + + 
pytorch = PyTorch( + training_recipe="llmft_recipe", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + ) + + # Check that the LLMFT recipe was correctly identified + assert pytorch.is_llmft_recipe is True + + # Verify _setup_for_llmft_recipe was called + mock_llmft_setup.assert_called_once() + call_args = mock_llmft_setup.call_args + assert len(call_args[0]) >= 2 # Check that at least recipe and recipe_name were passed + assert call_args[0][0] == recipe # first arg should be recipe + assert call_args[0][1] == "llmft_recipe" # second arg should be recipe_name + + +def test_device_handle_instance_count_with_llmft_num_nodes(): + """Test that _device_handle_instance_count correctly gets instance_count from LLMFT recipe num_nodes.""" + # Create mock LLMFT recipe with num_nodes + recipe = OmegaConf.create( + { + "run": { + "model_type": "llm_finetuning_aws", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 4, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + + # Test with no instance_count in kwargs + kwargs = {} + PyTorch._device_handle_instance_count(kwargs, recipe) + assert kwargs["instance_count"] == 4 + + +def test_device_handle_instance_count_with_llmft_no_num_nodes(): + """Test that _device_handle_instance_count raises an error when no instance_count or num_nodes are provided.""" + # Create mock LLMFT recipe without num_nodes + recipe = OmegaConf.create( + { + "run": { + "model_type": "llm_finetuning_aws", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + + # Test with no instance_count in kwargs + kwargs = {} + with pytest.raises(ValueError) as error: + PyTorch._device_handle_instance_count(kwargs, recipe) + + 
assert "Must set either instance_count argument for estimator or" in str(error) + + +@patch("sagemaker.pytorch.estimator.logger.warning") +def test_device_handle_instance_count_with_llmft_both_provided(mock_warning): + """Test that _device_handle_instance_count warns when both instance_count and num_nodes are provided.""" + # Create mock LLMFT recipe with num_nodes + recipe = OmegaConf.create( + { + "run": { + "model_type": "llm_finetuning_aws", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 4, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + + # Test with instance_count in kwargs + kwargs = {"instance_count": 2} + PyTorch._device_handle_instance_count(kwargs, recipe) + + # Verify warning was logged + mock_warning.assert_called_with( + "Using instance_count argument to estimator to set number " + "of nodes. Ignoring trainer -> num_nodes in recipe." + ) + + # Verify instance_count wasn't changed + assert kwargs["instance_count"] == 2 + + +def test_device_validate_and_get_type_with_llmft(): + """Test that _device_validate_and_get_type works correctly with LLMFT recipes.""" + # Create mock LLMFT recipe + recipe = OmegaConf.create( + { + "run": { + "model_type": "llm_finetuning_aws", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 1, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + + # Test with GPU instance type + kwargs = {"instance_type": INSTANCE_TYPE_GPU} + device_type = PyTorch._device_validate_and_get_type(kwargs, recipe) + assert device_type == "gpu" + + # Test with CPU instance type + kwargs = {"instance_type": INSTANCE_TYPE} + device_type = PyTorch._device_validate_and_get_type(kwargs, recipe) + assert device_type == "cpu" + + # Test with TRN instance type + kwargs = {"instance_type": INSTANCE_TYPE_TRN} + device_type = PyTorch._device_validate_and_get_type(kwargs, recipe) + 
assert device_type == "trainium" + + +def test_device_validate_and_get_type_no_instance_type_llmft(): + """Test that _device_validate_and_get_type raises an error when no instance_type is provided for LLMFT.""" + # Create mock LLMFT recipe + recipe = OmegaConf.create( + { + "run": { + "model_type": "llm_finetuning_aws", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 1, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + + # Test with no instance_type + kwargs = {} + with pytest.raises(ValueError) as error: + PyTorch._device_validate_and_get_type(kwargs, recipe) + + assert "Must pass instance type to estimator" in str(error) + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("time.time", return_value=1714500000) # May 1, 2024 +def test_upload_recipe_to_s3_llmft(mock_time, mock_recipe_load, sagemaker_session): + """Test that _upload_recipe_to_s3 correctly uploads the LLMFT recipe file to S3.""" + # Create a mock LLMFT recipe + mock_recipe = OmegaConf.create( + { + "run": { + "model_type": "llm_finetuning_aws", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 1, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("llmft_recipe", mock_recipe) + + # Setup + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="llmft_recipe", + ) + + # Set llmft recipe attributes + pytorch.is_llmft_recipe = True + + # Create a temporary file to use as the recipe file + with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file: + # Test uploading the recipe file to S3 + s3_uri = 
pytorch._upload_recipe_to_s3(sagemaker_session, temp_file.name) + + # Verify the upload_data method was called with the correct parameters + sagemaker_session.upload_data.assert_called_once() + + # Check that the S3 URI is returned correctly + assert s3_uri == sagemaker_session.upload_data.return_value + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("tempfile.NamedTemporaryFile") +@patch("omegaconf.OmegaConf.save") +@patch("sagemaker.pytorch.estimator._try_resolve_recipe") +def test_recipe_resolve_and_save_llmft( + mock_try_resolve, mock_save, mock_temp_file, mock_recipe_load, sagemaker_session +): + """Test that _recipe_resolve_and_save correctly resolves and saves the llmft recipe.""" + # Create a mock llmft recipe + mock_recipe = OmegaConf.create( + { + "run": { + "model_type": "llm_finetuning_aws", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 1, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("llmft_recipe", mock_recipe) + + # Setup + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="llmft_recipe", + ) + + # Set llmft recipe attributes + pytorch.is_llmft_recipe = True + + # Mock the temporary file + mock_temp_file_instance = Mock() + mock_temp_file_instance.name = "/tmp/llmft-recipe_12345.yaml" + mock_temp_file.return_value = mock_temp_file_instance + + # Create mock recipe + recipe = OmegaConf.create( + { + "run": { + "model_type": "llmft", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 1, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + + # Mock the recipe resolution + 
mock_try_resolve.side_effect = [recipe, None, None] + + # Call the _recipe_resolve_and_save method + result = pytorch._recipe_resolve_and_save(recipe, "llmft-recipe", ".") + + # Verify the recipe was resolved and saved + mock_try_resolve.assert_called_with(recipe) + mock_save.assert_called_with(config=recipe, f=mock_temp_file_instance.name) + + # Verify the result is the resolved recipe + assert result == recipe + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("sagemaker.pytorch.estimator.Framework.fit") +def test_fit_with_llmft_recipe_s3_upload(mock_framework_fit, mock_recipe_load, sagemaker_session): + """Test that fit correctly uploads the llmft recipe to S3 and adds it to the inputs.""" + # Create a mock llmft recipe + mock_recipe = OmegaConf.create( + { + "run": { + "model_type": "llm_finetuning_aws", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 1, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("llmft_recipe", mock_recipe) + + # Create a PyTorch estimator with an llmft recipe + with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file: + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="llmft_recipe", + ) + + # Set llmft recipe attributes + pytorch.is_llmft_recipe = True + pytorch.training_recipe_file = temp_file + + # Mock the _upload_recipe_to_s3 and _create_recipe_copy methods + with ( + patch.object(pytorch, "_upload_recipe_to_s3") as mock_upload_recipe, + patch.object(pytorch, "_create_recipe_copy") as mock_create_copy, + ): + mock_upload_recipe.return_value = "s3://mybucket/recipes/llmft-recipe.yaml" + mock_create_copy.return_value = 
"s3://mybucket/recipes/recipe.yaml" + + # Call the fit method + pytorch.fit() + + # Verify the upload_recipe_to_s3 method was called + mock_upload_recipe.assert_called_once_with(sagemaker_session, temp_file.name) + + # Verify the create_recipe_copy method was called + mock_create_copy.assert_called_once_with("s3://mybucket/recipes/llmft-recipe.yaml") + + # Verify the fit method was called with the recipe channel + call_args = mock_framework_fit.call_args[1] + assert "inputs" in call_args + assert "recipe" in call_args["inputs"] + + # Verify the hyperparameters were updated with the recipe path + assert "sagemaker_recipe_local_path" in pytorch._hyperparameters + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("sagemaker.pytorch.estimator.PyTorch._upload_recipe_to_s3") +@patch("sagemaker.pytorch.estimator.PyTorch._create_recipe_copy") +@patch("sagemaker.pytorch.estimator.Framework.fit") +def test_fit_with_llmft_recipe_and_inputs( + mock_framework_fit, mock_create_copy, mock_upload_recipe, mock_recipe_load, sagemaker_session +): + """Test that fit correctly handles llmft recipes with additional inputs.""" + # Create a mock llmft recipe + mock_recipe = OmegaConf.create( + { + "run": { + "model_type": "llm_finetuning_aws", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 1, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("llmft_recipe", mock_recipe) + mock_upload_recipe.return_value = "s3://mybucket/recipes/llmft-recipe.yaml" + mock_create_copy.return_value = "s3://mybucket/recipes/recipe.yaml" + + # Create a PyTorch estimator with an llmft recipe + with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file: + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + 
image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="llmft_recipe", + ) + + # Set llmft recipe attributes + pytorch.is_llmft_recipe = True + pytorch.training_recipe_file = temp_file + + # Create training inputs + train_input = TrainingInput(s3_data="s3://mybucket/train") + val_input = TrainingInput(s3_data="s3://mybucket/validation") + inputs = {"train": train_input, "validation": val_input} + + # Call the fit method with inputs + pytorch.fit(inputs=inputs) + + # Verify the fit method was called with both the recipe channel and the provided inputs + call_args = mock_framework_fit.call_args[1] + assert "inputs" in call_args + assert "recipe" in call_args["inputs"] + assert "train" in call_args["inputs"] + assert "validation" in call_args["inputs"] + + # Verify the hyperparameters were updated with the recipe path + assert "sagemaker_recipe_local_path" in pytorch._hyperparameters + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("sagemaker.pytorch.estimator.PyTorch._upload_recipe_to_s3") +@patch("sagemaker.pytorch.estimator.PyTorch._create_recipe_copy") +@patch("sagemaker.pytorch.estimator.Framework.fit") +def test_fit_with_llmft_recipe( + mock_framework_fit, mock_create_copy, mock_upload_recipe, mock_recipe_load, sagemaker_session +): + """Test that fit correctly handles llmft recipes.""" + + # Create a mock llmft recipe + mock_recipe = OmegaConf.create( + { + "run": { + "model_type": "llm_finetuning_aws", + "name": "foo-bar123", + }, + "trainer": { + "devices": 8, + "num_nodes": 1, + }, + "training_config": { + "model_config": { + "model_name_or_path": "foo-bar/foo-bar123", + } + }, + } + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("llmft_recipe", mock_recipe) + + # Create a PyTorch estimator with an llmft recipe + with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file: + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, 
+ instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="llmft_recipe", + ) + + # Set llmft recipe attributes + pytorch.is_llmft_recipe = True + pytorch.training_recipe_file = temp_file + + # Mock the upload_recipe_to_s3 and create_recipe_copy methods + mock_upload_recipe.return_value = "s3://mybucket/recipes/llmft-recipe.yaml" + mock_create_copy.return_value = "s3://mybucket/recipes/recipe.yaml" + + # Call the fit method + pytorch.fit() + + # Verify the upload_recipe_to_s3 method was called + mock_upload_recipe.assert_called_once_with(sagemaker_session, temp_file.name) + + # Verify the create_recipe_copy method was called + mock_create_copy.assert_called_once_with("s3://mybucket/recipes/llmft-recipe.yaml") + + # Verify the fit method was called with the recipe channel + call_args = mock_framework_fit.call_args[1] + assert "inputs" in call_args + assert "recipe" in call_args["inputs"] + + # Verify the hyperparameters were updated with the recipe path + assert "sagemaker_recipe_local_path" in pytorch._hyperparameters diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 721243096d..079bdea8eb 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -4990,6 +4990,7 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session) response_types = ["application/json"] inference_instances = ["ml.m4.xlarge"] transform_instances = ["ml.m4.xlarget"] + model_package_registration_type = "Registered" model_metrics = { "Bias": { "ContentType": "content-type", @@ -5045,6 +5046,7 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session) task=task, validation_specification=validation_specification, skip_model_validation=skip_model_validation, + model_package_registration_type=model_package_registration_type, ) expected_kms_key_id = SAGEMAKER_CONFIG_MODEL_PACKAGE["SageMaker"]["ModelPackage"][ 
"ValidationSpecification" @@ -5083,6 +5085,7 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session) "Task": task, "ValidationSpecification": validation_specification, "SkipModelValidation": skip_model_validation, + "ModelPackageRegistrationType": "Registered", } ) expected_args["ValidationSpecification"]["ValidationRole"] = expected_role_arn @@ -5113,6 +5116,8 @@ def test_create_model_package_from_containers_with_source_uri_and_inference_spec approval_status = ("Approved",) skip_model_validation = "All" source_uri = "dummy-source-uri" + model_package_registration_type = "Registered" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.sagemaker_client.search.return_value = {"Results": []} @@ -5134,6 +5139,7 @@ def test_create_model_package_from_containers_with_source_uri_and_inference_spec approval_status=approval_status, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_package_registration_type=model_package_registration_type, ) expected_create_mp_args = { "ModelPackageGroupName": model_package_group_name, @@ -5147,6 +5153,7 @@ def test_create_model_package_from_containers_with_source_uri_and_inference_spec "CertifyForMarketplace": marketplace_cert, "ModelApprovalStatus": approval_status, "SkipModelValidation": skip_model_validation, + "ModelPackageRegistrationType": "Registered", } sagemaker_session.sagemaker_client.create_model_package.assert_called_once_with( diff --git a/tox.ini b/tox.ini index 9c624b2052..76837104aa 100644 --- a/tox.ini +++ b/tox.ini @@ -90,7 +90,7 @@ commands = pip install 'torchvision==0.18.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' pip install 'dill>=0.3.9' pip install 'altair>=5.3' # needed for amtviz - pip install -U "sagemaker-core" # needed to keep sagemaker-core up to date + pip install -U "sagemaker-core<2.0.0" pytest {posargs} deps =