diff --git a/sagemaker-train/src/sagemaker/train/constants.py b/sagemaker-train/src/sagemaker/train/constants.py
index 2ad66e868c..9a07888064 100644
--- a/sagemaker-train/src/sagemaker/train/constants.py
+++ b/sagemaker-train/src/sagemaker/train/constants.py
@@ -41,3 +41,11 @@
 ]
 
 HUB_NAME = "SageMakerPublicHub"
+
+# Allowed reward model IDs for RLAIF trainer
+_ALLOWED_REWARD_MODEL_IDS = [
+    "openai.gpt-oss-120b-1:0",
+    "openai.gpt-oss-20b-1:0",
+    "qwen.qwen3-32b-v1:0",
+    "qwen.qwen3-coder-30b-a3b-v1:0"
+]
diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py
index bc6ab234e9..68d50a2989 100644
--- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py
+++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py
@@ -26,7 +26,7 @@
 )
 from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
 from sagemaker.core.telemetry.constants import Feature
-from sagemaker.train.constants import HUB_NAME
+from sagemaker.train.constants import HUB_NAME, _ALLOWED_REWARD_MODEL_IDS
 
 logger = logging.getLogger(__name__)
 
@@ -87,7 +87,6 @@ class RLAIFTrainer(BaseTrainer):
             ARN, or ModelPackageGroup object.
             Required when model is not a ModelPackage.
         reward_model_id (str): Bedrock model identifier for generating LLM feedback.
-            Evaluator models available: https://docs.aws.amazon.com/bedrock/latest/userguide/evaluation-judge.html
            Required for RLAIF training to provide reward signals.
         reward_prompt (Union[str, Evaluator]): The reward prompt or evaluator for AI
             feedback generation.
@@ -141,7 +140,7 @@ def __init__(
 
         self.training_type = training_type
         self.model_package_group_name = _validate_and_resolve_model_package_group(model, model_package_group_name)
-        self.reward_model_id = reward_model_id
+        self.reward_model_id = self._validate_reward_model_id(reward_model_id)
         self.reward_prompt = reward_prompt
         self.mlflow_resource_arn = mlflow_resource_arn
         self.mlflow_experiment_name = mlflow_experiment_name
@@ -165,6 +164,18 @@
 
         # Process reward_prompt parameter
         self._process_hyperparameters()
+
+    def _validate_reward_model_id(self, reward_model_id):
+        """Validate reward_model_id is one of the allowed values."""
+        if not reward_model_id:
+            return None
+
+        if reward_model_id not in _ALLOWED_REWARD_MODEL_IDS:
+            raise ValueError(
+                f"Invalid reward_model_id '{reward_model_id}'. "
+                f"Available models are: {_ALLOWED_REWARD_MODEL_IDS}"
+            )
+        return reward_model_id
 
     @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train")
diff --git a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py
index 0008b88912..eca69eed6d 100644
--- a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py
+++ b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py
@@ -503,3 +503,32 @@ def test_process_non_builtin_reward_prompt_hub_content_error(self):
 
         with pytest.raises(ValueError, match="Custom prompt 'invalid-prompt' not found in HubContent"):
             trainer._process_non_builtin_reward_prompt()
+
+    def test_validate_reward_model_id_valid_models(self):
+        """Test _validate_reward_model_id with valid model IDs."""
+        trainer = RLAIFTrainer.__new__(RLAIFTrainer)
+
+        valid_models = [
+            "openai.gpt-oss-120b-1:0",
+            "openai.gpt-oss-20b-1:0",
+            "qwen.qwen3-32b-v1:0",
+            "qwen.qwen3-coder-30b-a3b-v1:0"
+        ]
+
+        for model_id in valid_models:
+            result = trainer._validate_reward_model_id(model_id)
+            assert result == model_id
+
+    def test_validate_reward_model_id_invalid_model(self):
+        """Test _validate_reward_model_id raises error for invalid model ID."""
+        trainer = RLAIFTrainer.__new__(RLAIFTrainer)
+
+        with pytest.raises(ValueError, match="Invalid reward_model_id 'invalid-model-id'"):
+            trainer._validate_reward_model_id("invalid-model-id")
+
+    def test_validate_reward_model_id_none_model(self):
+        """Test _validate_reward_model_id handles None model ID."""
+        trainer = RLAIFTrainer.__new__(RLAIFTrainer)
+
+        result = trainer._validate_reward_model_id(None)
+        assert result is None