Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# Changelog

## v2.255.0 (2025-12-03)

### Features

* Extracts reward Lambda ARN from Nova recipes
* Passes it as training job hyperparameter
* Added LLMFT recipe support with standardized recipe handling
* Enhanced recipe validation and multi-model type compatibility


## v2.254.1 (2025-10-31)

### Bug Fixes and Other Changes
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.254.2.dev0
2.254.2.dev0
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ classifiers = [
]
dependencies = [
"attrs>=24,<26",
"boto3>=1.39.5,<2.0",
"boto3>=1.42.2,<2.0",
"cloudpickle>=2.2.1",
"docker",
"fastapi",
Expand All @@ -49,7 +49,7 @@ dependencies = [
"psutil",
"PyYAML>=6.0.1",
"requests",
"sagemaker-core>=1.0.17,<2.0.0",
"sagemaker-core>=1.0.71,<2.0.0",
"schema",
"smdebug_rulesconfig==1.0.1",
"tblib>=1.7.0,<4",
Expand Down
2 changes: 1 addition & 1 deletion requirements/extras/test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ mypy-boto3-s3==1.35.76
mypy-extensions==1.0.0
mypy==1.9.0
# apache-airflow transitive dependancy
google-re2<1.1.20250805; python_version < "3.10"
google-re2<1.1.20250805; python_version < "3.10"
1 change: 1 addition & 0 deletions src/sagemaker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from sagemaker.analytics import TrainingJobAnalytics, HyperparameterTuningJobAnalytics # noqa: F401
from sagemaker.local.local_session import LocalSession # noqa: F401

from sagemaker.container_base_model import ContainerBaseModel # noqa: F401
from sagemaker.model import Model, ModelPackage # noqa: F401
from sagemaker.model_metrics import ModelMetrics, MetricsSource, FileSource # noqa: F401
from sagemaker.pipeline import PipelineModel # noqa: F401
Expand Down
9 changes: 8 additions & 1 deletion src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Callable, Optional, Union, List, Dict

import sagemaker
from sagemaker import image_uris, ModelMetrics
from sagemaker import image_uris, ModelMetrics, ContainerBaseModel
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.fw_utils import (
model_code_key_prefix,
Expand Down Expand Up @@ -182,6 +182,8 @@ def register(
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
model_package_registration_type: Optional[Union[str, PipelineVariable]] = None,
base_model: Optional[ContainerBaseModel] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -236,6 +238,9 @@ def register(
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
model_package_registration_type (str or PipelineVariable): Model Package Registration
Type (default: None).
base_model (ContainerBaseModel): ContainerBaseModel object (default: None).

Returns:
str: A string of SageMaker Model Package ARN.
Expand Down Expand Up @@ -278,6 +283,8 @@ def register(
source_uri=source_uri,
model_card=model_card,
model_life_cycle=model_life_cycle,
model_package_registration_type=model_package_registration_type,
base_model=base_model,
)

def prepare_container_def(
Expand Down
52 changes: 52 additions & 0 deletions src/sagemaker/container_base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express oXr implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This file contains code related to base model for containers."""
from __future__ import absolute_import

from typing import Optional, Union

from sagemaker.workflow.entities import PipelineVariable


class ContainerBaseModel(object):
"""Accepts Base Model parameters for conversion to request dict."""

def __init__(
self,
hub_content_name: Union[str, PipelineVariable] = None,
hub_content_version: Optional[Union[str, PipelineVariable]] = None,
recipe_name: Optional[Union[str, PipelineVariable]] = None,
):
"""Initialize a ``ContainerBaseModel`` instance and turn parameters into dict.

Args:
hub_content_name (str or PipelineVariable): The hub content name
hub_content_version (str or PipelineVariable): The hub content version
(default: None)
recipe_name (str or PipelineVariable): The Recipe name
(default: None)
"""
self.hub_content_name = hub_content_name
self.hub_content_version = hub_content_version
self.recipe_name = recipe_name

def _to_request_dict(self):
"""Generates a request dictionary using the parameters provided to the class."""
base_model_request = {}
if self.hub_content_name is not None:
base_model_request["HubContentName"] = self.hub_content_name
if self.hub_content_version is not None:
base_model_request["HubContentVersion"] = self.hub_content_version
if self.recipe_name is not None:
base_model_request["RecipeName"] = self.recipe_name
return base_model_request
7 changes: 7 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1817,6 +1817,8 @@ def register(
source_uri=None,
model_life_cycle=None,
model_card=None,
model_package_registration_type=None,
base_model=None,
**kwargs,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.
Expand Down Expand Up @@ -1868,6 +1870,9 @@ def register(
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
model_package_registration_type (str): Model Package Registration
Type (default: None).
base_model (ContainerBaseModel): ContainerBaseModel object (default: None).
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
``create_model()`` to accept ``**kwargs`` to customize model creation during
deploy. For more, see the implementation docs.
Expand Down Expand Up @@ -1924,6 +1929,8 @@ def register(
source_uri=source_uri,
model_card=model_card,
model_life_cycle=model_life_cycle,
model_package_registration_type=model_package_registration_type,
base_model=base_model,
)

@property
Expand Down
9 changes: 8 additions & 1 deletion src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Callable, Optional, Union, List, Dict

import sagemaker
from sagemaker import image_uris, ModelMetrics
from sagemaker import image_uris, ModelMetrics, ContainerBaseModel
from sagemaker.deserializers import JSONDeserializer
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.fw_utils import (
Expand Down Expand Up @@ -372,6 +372,8 @@ def register(
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
model_package_registration_type: Optional[Union[str, PipelineVariable]] = None,
base_model: Optional[ContainerBaseModel] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -427,6 +429,9 @@ def register(
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
model_package_registration_type (str or PipelineVariable): Model Package Registration
Type (default: None).
base_model (ContainerBaseModel): ContainerBaseModel object (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -477,6 +482,8 @@ def register(
source_uri=source_uri,
model_life_cycle=model_life_cycle,
model_card=model_card,
model_package_registration_type=model_package_registration_type,
base_model=base_model,
)

def prepare_container_def(
Expand Down
12 changes: 12 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from sagemaker.model_card.helpers import _hash_content_str
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
from sagemaker.session import Session
from sagemaker.container_base_model import ContainerBaseModel
from sagemaker.model_metrics import ModelMetrics
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.explainer import ExplainerConfig
Expand Down Expand Up @@ -477,6 +478,8 @@ def register(
model_life_cycle: Optional[ModelLifeCycle] = None,
accept_eula: Optional[bool] = None,
model_type: Optional[JumpStartModelType] = None,
model_package_registration_type: Optional[Union[str, PipelineVariable]] = None,
base_model: Optional[ContainerBaseModel] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -531,6 +534,9 @@ def register(
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
model_package_registration_type (str or PipelineVariable): Model Package Registration
Type (default: None).
base_model (ContainerBaseModel): ContainerBaseModel object (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
Expand Down Expand Up @@ -578,6 +584,9 @@ def register(
if self.model_data is not None:
container_def["ModelDataUrl"] = self.model_data

if base_model is not None and hasattr(base_model, "_to_request_dict"):
container_def["BaseModel"] = base_model._to_request_dict()

model_pkg_args = sagemaker.get_model_package_args(
self.content_types,
self.response_types,
Expand All @@ -601,6 +610,8 @@ def register(
source_uri=source_uri,
model_card=model_card,
model_life_cycle=model_life_cycle,
model_package_registration_type=model_package_registration_type,
base_model=base_model,
)
model_package = self.sagemaker_session.create_model_package_from_containers(
**model_pkg_args
Expand Down Expand Up @@ -2150,6 +2161,7 @@ def __init__(
You can find additional parameters for initializing this class at
:class:`~sagemaker.model.Model`.
"""

super(FrameworkModel, self).__init__(
image_uri,
model_data,
Expand Down
22 changes: 15 additions & 7 deletions src/sagemaker/modules/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
_get_args_from_recipe,
_determine_device_type,
_is_nova_recipe,
_is_llmft_recipe,
_load_base_recipe,
)

Expand Down Expand Up @@ -252,6 +253,7 @@ class ModelTrainer(BaseModel):
_metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)

_is_nova_recipe: Optional[bool] = PrivateAttr(default=None)
_is_llmft_recipe: Optional[bool] = PrivateAttr(default=None)
_temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
_temp_code_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)

Expand Down Expand Up @@ -632,12 +634,13 @@ def _create_training_job_args(

final_input_data_config = list(existing_channels.values()) + new_channels

if self._is_nova_recipe:
if self._is_nova_recipe or self._is_llmft_recipe:

for input_data in final_input_data_config:
if input_data.channel_name == SM_RECIPE:
raise ValueError(
"Cannot use reserved channel name 'recipe' as an input channel name "
" for Nova Recipe"
" for Nova or LLMFT Recipe"
)
recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML)
recipe_channel = self.create_input_data_channel(
Expand All @@ -646,7 +649,10 @@ def _create_training_job_args(
key_prefix=input_data_key_prefix,
)
final_input_data_config.append(recipe_channel)
self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH})
if self._is_nova_recipe:
self.hyperparameters.update(
{"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH}
)

if final_input_data_config:
final_input_data_config = self._get_input_data_config(
Expand Down Expand Up @@ -1201,14 +1207,15 @@ def from_recipe(
training_recipe=training_recipe, recipe_overrides=recipe_overrides
)
is_nova = _is_nova_recipe(recipe=recipe)
is_llmft = _is_llmft_recipe(recipe=recipe)

if device_type == "cpu" and not is_nova:
if device_type == "cpu" and not (is_nova or is_llmft):
raise ValueError(
"Training recipe is not supported for CPU instances. "
+ "Please provide a GPU or Tranium instance type."
)
if training_image is None and is_nova:
raise ValueError("training_image must be provided when using recipe for Nova.")
if training_image is None and (is_nova or is_llmft):
raise ValueError("training_image must be provided when using recipe for Nova or LLMFT")

if training_image_config and training_image is None:
raise ValueError("training_image must be provided when using training_image_config.")
Expand Down Expand Up @@ -1238,7 +1245,7 @@ def from_recipe(
model_trainer_args["training_image"] = training_image
if hyperparameters and not is_nova:
logger.warning(
"Hyperparameters are not supported for general training recipes. "
"Hyperparameters are not supported for general and LLMFT training recipes. "
+ "Ignoring hyperparameters input."
)
if is_nova:
Expand All @@ -1264,6 +1271,7 @@ def from_recipe(
**model_trainer_args,
)
model_trainer._is_nova_recipe = is_nova
model_trainer._is_llmft_recipe = is_llmft
model_trainer._temp_recipe_train_dir = tmp_dir
return model_trainer

Expand Down
Loading
Loading