Skip to content

Commit 9679c5d

Browse files
author
Roja Reddy Sareddy
committed
Fix: Add validation to bedrock reward models
1 parent 112e12b commit 9679c5d

File tree

3 files changed

+51
-3
lines changed

3 files changed

+51
-3
lines changed

sagemaker-train/src/sagemaker/train/constants.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,11 @@
4141
]
4242

4343
HUB_NAME = "SageMakerPublicHub"
44+
45+
# Reward model IDs that RLAIFTrainer accepts for ``reward_model_id``.
# Any other non-empty value is rejected with a ValueError at trainer
# construction time (see RLAIFTrainer._validate_reward_model_id).
ALLOWED_REWARD_MODEL_IDS = [
    "openai.gpt-oss-120b-1:0",
    "openai.gpt-oss-20b-1:0",
    "qwen.qwen3-32b-v1:0",
    "qwen.qwen3-coder-30b-a3b-v1:0",
]

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2828
from sagemaker.core.telemetry.constants import Feature
29-
from sagemaker.train.constants import HUB_NAME
29+
from sagemaker.train.constants import HUB_NAME, ALLOWED_REWARD_MODEL_IDS
3030

3131
logger = logging.getLogger(__name__)
3232

@@ -87,7 +87,6 @@ class RLAIFTrainer(BaseTrainer):
8787
ARN, or ModelPackageGroup object. Required when model is not a ModelPackage.
8888
reward_model_id (str):
8989
Bedrock model identifier for generating LLM feedback.
90-
Evaluator models available: https://docs.aws.amazon.com/bedrock/latest/userguide/evaluation-judge.html
9190
Required for RLAIF training to provide reward signals.
9291
reward_prompt (Union[str, Evaluator]):
9392
The reward prompt or evaluator for AI feedback generation.
@@ -141,7 +140,7 @@ def __init__(
141140
self.training_type = training_type
142141
self.model_package_group_name = _validate_and_resolve_model_package_group(model,
143142
model_package_group_name)
144-
self.reward_model_id = reward_model_id
143+
self.reward_model_id = self._validate_reward_model_id(reward_model_id)
145144
self.reward_prompt = reward_prompt
146145
self.mlflow_resource_arn = mlflow_resource_arn
147146
self.mlflow_experiment_name = mlflow_experiment_name
@@ -165,6 +164,18 @@ def __init__(
165164

166165
# Process reward_prompt parameter
167166
self._process_hyperparameters()
167+
168+
def _validate_reward_model_id(self, reward_model_id):
    """Validate ``reward_model_id`` against ``ALLOWED_REWARD_MODEL_IDS``.

    Args:
        reward_model_id: Bedrock model identifier supplied by the caller.
            May be ``None`` or empty.

    Returns:
        ``None`` when ``reward_model_id`` is falsy (for example ``None``
        or ``""``); otherwise ``reward_model_id`` unchanged.

    Raises:
        ValueError: If a non-empty ``reward_model_id`` is not one of the
            entries in ``ALLOWED_REWARD_MODEL_IDS``.
    """
    # A missing/empty ID is not rejected here; it is normalized to None.
    if not reward_model_id:
        return None

    if reward_model_id not in ALLOWED_REWARD_MODEL_IDS:
        # List the accepted IDs so the caller can correct the value
        # without looking up the constant.
        raise ValueError(
            f"Invalid reward_model_id '{reward_model_id}'. "
            f"Available models are: {ALLOWED_REWARD_MODEL_IDS}"
        )
    return reward_model_id
168179

169180

170181
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train")

sagemaker-train/tests/unit/train/test_rlaif_trainer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,3 +503,32 @@ def test_process_non_builtin_reward_prompt_hub_content_error(self):
503503

504504
with pytest.raises(ValueError, match="Custom prompt 'invalid-prompt' not found in HubContent"):
505505
trainer._process_non_builtin_reward_prompt()
506+
507+
def test_validate_reward_model_id_valid_models(self):
    """Test _validate_reward_model_id with valid model IDs."""
    # __new__ bypasses __init__, so the validator is exercised without
    # running the rest of the trainer's constructor.
    trainer = RLAIFTrainer.__new__(RLAIFTrainer)

    valid_models = [
        "openai.gpt-oss-120b-1:0",
        "openai.gpt-oss-20b-1:0",
        "qwen.qwen3-32b-v1:0",
        "qwen.qwen3-coder-30b-a3b-v1:0",
    ]

    # Every allowed ID must be returned unchanged.
    for model_id in valid_models:
        result = trainer._validate_reward_model_id(model_id)
        assert result == model_id
521+
522+
def test_validate_reward_model_id_invalid_model(self):
    """Test _validate_reward_model_id raises error for invalid model ID."""
    # __new__ bypasses __init__, so only the validator is under test.
    trainer = RLAIFTrainer.__new__(RLAIFTrainer)

    # An ID outside the allowed list must raise, and the message must
    # echo the offending value.
    with pytest.raises(ValueError, match="Invalid reward_model_id 'invalid-model-id'"):
        trainer._validate_reward_model_id("invalid-model-id")
528+
529+
def test_validate_reward_model_id_none_model(self):
    """Test _validate_reward_model_id handles None model ID."""
    # __new__ bypasses __init__, so only the validator is under test.
    trainer = RLAIFTrainer.__new__(RLAIFTrainer)

    # None is passed through as None rather than raising.
    result = trainer._validate_reward_model_id(None)
    assert result is None

0 commit comments

Comments
 (0)