sagemaker-train/src/sagemaker/train/constants.py (8 additions, 0 deletions)

@@ -41,3 +41,11 @@
]

HUB_NAME = "SageMakerPublicHub"

# Allowed reward model IDs for RLAIF trainer
_ALLOWED_REWARD_MODEL_IDS = [
"openai.gpt-oss-120b-1:0",
"openai.gpt-oss-20b-1:0",
"qwen.qwen3-32b-v1:0",
"qwen.qwen3-coder-30b-a3b-v1:0"
]
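These IDs must match exactly, version suffix included; a shortened or otherwise altered variant will not pass the membership check the trainer applies (see rlaif_trainer.py below). A quick illustrative check against the list as committed:

from sagemaker.train.constants import _ALLOWED_REWARD_MODEL_IDS

assert "qwen.qwen3-32b-v1:0" in _ALLOWED_REWARD_MODEL_IDS
assert "qwen.qwen3-32b" not in _ALLOWED_REWARD_MODEL_IDS  # version suffix ":0" is required
assert "openai.gpt-oss-120b-1:0" in _ALLOWED_REWARD_MODEL_IDS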
sagemaker-train/src/sagemaker/train/rlaif_trainer.py (14 additions, 3 deletions)

@@ -26,7 +26,7 @@
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
-from sagemaker.train.constants import HUB_NAME
+from sagemaker.train.constants import HUB_NAME, _ALLOWED_REWARD_MODEL_IDS

logger = logging.getLogger(__name__)

@@ -87,7 +87,6 @@ class RLAIFTrainer(BaseTrainer):
            ARN, or ModelPackageGroup object. Required when model is not a ModelPackage.
        reward_model_id (str):
            Bedrock model identifier for generating LLM feedback.
-           Evaluator models available: https://docs.aws.amazon.com/bedrock/latest/userguide/evaluation-judge.html
            Required for RLAIF training to provide reward signals.
        reward_prompt (Union[str, Evaluator]):
            The reward prompt or evaluator for AI feedback generation.
@@ -141,7 +140,7 @@ def __init__(
        self.training_type = training_type
        self.model_package_group_name = _validate_and_resolve_model_package_group(model,
                                                                                  model_package_group_name)
-       self.reward_model_id = reward_model_id
+       self.reward_model_id = self._validate_reward_model_id(reward_model_id)
        self.reward_prompt = reward_prompt
        self.mlflow_resource_arn = mlflow_resource_arn
        self.mlflow_experiment_name = mlflow_experiment_name
@@ -165,6 +164,18 @@ def __init__(

        # Process reward_prompt parameter
        self._process_hyperparameters()
+
+    def _validate_reward_model_id(self, reward_model_id):
+        """Validate reward_model_id is one of the allowed values."""
+        if not reward_model_id:
+            return None
+
+        if reward_model_id not in _ALLOWED_REWARD_MODEL_IDS:
+            raise ValueError(
+                f"Invalid reward_model_id '{reward_model_id}'. "
+                f"Available models are: {_ALLOWED_REWARD_MODEL_IDS}"
+            )
+        return reward_model_id


    @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train")
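Since _validate_reward_model_id does not depend on instance state, the new behavior can be exercised locally without running the full constructor (which also resolves the model package group). The unit tests below bypass __init__ via __new__; the same trick works interactively, as a sketch:

from sagemaker.train.rlaif_trainer import RLAIFTrainer

trainer = RLAIFTrainer.__new__(RLAIFTrainer)  # skip __init__ and its resource resolution

print(trainer._validate_reward_model_id("qwen.qwen3-32b-v1:0"))  # "qwen.qwen3-32b-v1:0"
print(trainer._validate_reward_model_id(None))                   # None; empty IDs pass through

try:
    trainer._validate_reward_model_id("not-a-real-model")
except ValueError as err:
    print(err)  # Invalid reward_model_id 'not-a-real-model'. Available models are: [...]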
sagemaker-train/tests/unit/train/test_rlaif_trainer.py (29 additions, 0 deletions)

@@ -503,3 +503,32 @@ def test_process_non_builtin_reward_prompt_hub_content_error(self):

with pytest.raises(ValueError, match="Custom prompt 'invalid-prompt' not found in HubContent"):
trainer._process_non_builtin_reward_prompt()

def test_validate_reward_model_id_valid_models(self):
"""Test _validate_reward_model_id with valid model IDs."""
trainer = RLAIFTrainer.__new__(RLAIFTrainer)

valid_models = [
"openai.gpt-oss-120b-1:0",
"openai.gpt-oss-20b-1:0",
"qwen.qwen3-32b-v1:0",
"qwen.qwen3-coder-30b-a3b-v1:0"
]

for model_id in valid_models:
result = trainer._validate_reward_model_id(model_id)
assert result == model_id

def test_validate_reward_model_id_invalid_model(self):
"""Test _validate_reward_model_id raises error for invalid model ID."""
trainer = RLAIFTrainer.__new__(RLAIFTrainer)

with pytest.raises(ValueError, match="Invalid reward_model_id 'invalid-model-id'"):
trainer._validate_reward_model_id("invalid-model-id")

def test_validate_reward_model_id_none_model(self):
"""Test _validate_reward_model_id handles None model ID."""
trainer = RLAIFTrainer.__new__(RLAIFTrainer)

result = trainer._validate_reward_model_id(None)
assert result is None
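An optional refactor, not part of this PR: the loop in test_validate_reward_model_id_valid_models could be expressed with pytest.mark.parametrize so each allowed ID reports as its own test case and reads straight off the shared constant:

import pytest

from sagemaker.train.constants import _ALLOWED_REWARD_MODEL_IDS
from sagemaker.train.rlaif_trainer import RLAIFTrainer

@pytest.mark.parametrize("model_id", _ALLOWED_REWARD_MODEL_IDS)
def test_validate_reward_model_id_parametrized(model_id):
    trainer = RLAIFTrainer.__new__(RLAIFTrainer)
    assert trainer._validate_reward_model_id(model_id) == model_id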