@@ -21,7 +21,8 @@
     _create_serverless_config,
     _create_mlflow_config,
     _create_model_package_config,
-    _validate_eula_for_gated_model
+    _validate_eula_for_gated_model,
+    _validate_hyperparameter_values
 )
 from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
 from sagemaker.core.telemetry.constants import Feature
@@ -163,7 +164,8 @@ def __init__(
         self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
 
-        # Process reward_prompt parameter
-        self._process_reward_prompt()
+        # Process hyperparameters, including reward_prompt
+        self._process_hyperparameters()
+
 
     @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train")
     def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
@@ -223,6 +225,8 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
 
         final_hyperparameters = self.hyperparameters.to_dict()
 
+        _validate_hyperparameter_values(final_hyperparameters)
+
         model_package_config = _create_model_package_config(
             model_package_group_name=self.model_package_group_name,
             model=self.model,
@@ -258,40 +262,100 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
         self.latest_training_job = training_job
         return training_job
 
-    def _process_reward_prompt(self):
-        """Process reward_prompt parameter for builtin vs custom prompts."""
-        if not self.reward_prompt:
-            return
-
-        # Handle Evaluator object
-        if not isinstance(self.reward_prompt, str):
-            evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
-            self._evaluator_arn = evaluator_arn
-            self._reward_prompt_processed = {"custom_prompt_arn": evaluator_arn}
+    def _process_hyperparameters(self):
+        """Update hyperparameters based on constructor inputs and process reward_prompt."""
+        if not self.hyperparameters or not getattr(self.hyperparameters, '_specs', None):
             return
-
-        # Handle string inputs
-        if self.reward_prompt.startswith("Builtin"):
-            # Map to preset_prompt in hyperparameters
-            self._reward_prompt_processed = {"preset_prompt": self.reward_prompt}
-        elif self.reward_prompt.startswith("arn:aws:sagemaker:"):
+
+        # Remove keys that are handled by constructor inputs
+        for key in ('output_path', 'data_path', 'validation_data_path'):
+            if hasattr(self.hyperparameters, key):
+                delattr(self.hyperparameters, key)
+                self.hyperparameters._specs.pop(key, None)
+
+        # Update judge_model_id if reward_model_id is provided
+        if hasattr(self, 'reward_model_id') and self.reward_model_id:
+            judge_model_value = f"bedrock/{self.reward_model_id}"
+            self.hyperparameters.judge_model_id = judge_model_value
+
+        # Process reward_prompt parameter
+        if hasattr(self, 'reward_prompt') and self.reward_prompt:
+            if isinstance(self.reward_prompt, str):
+                if self.reward_prompt.startswith("Builtin."):
+                    # Handle builtin reward prompts
+                    self._update_judge_prompt_template_direct(self.reward_prompt)
+                else:
+                    # Handle evaluator ARN or hub content name
+                    self._process_non_builtin_reward_prompt()
+            else:
+                # Handle Evaluator object
+                if hasattr(self.hyperparameters, 'judge_prompt_template'):
+                    delattr(self.hyperparameters, 'judge_prompt_template')
+                    self.hyperparameters._specs.pop('judge_prompt_template', None)
+
+                evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
+                self._evaluator_arn = evaluator_arn
+
+    def _process_non_builtin_reward_prompt(self):
+        """Process a non-builtin reward prompt (evaluator ARN or hub content name)."""
+        # Remove judge_prompt_template for non-builtin prompts
+        if hasattr(self.hyperparameters, 'judge_prompt_template'):
+            delattr(self.hyperparameters, 'judge_prompt_template')
+            self.hyperparameters._specs.pop('judge_prompt_template', None)
+
+        if self.reward_prompt.startswith("arn:aws:sagemaker:"):
             # Validate and assign ARN
             evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
             self._evaluator_arn = evaluator_arn
-            self._reward_prompt_processed = {"custom_prompt_arn": evaluator_arn}
         else:
             try:
-                session = self.sagemaker_session or _get_beta_session()
+                session = TrainDefaults.get_sagemaker_session(
+                    sagemaker_session=self.sagemaker_session
+                )
                 hub_content = _get_hub_content_metadata(
-                    hub_name=HUB_NAME,  # or appropriate hub name
+                    hub_name=HUB_NAME,
                     hub_content_type="JsonDoc",
                     hub_content_name=self.reward_prompt,
                     session=session.boto_session,
-                    region=session.boto_session.region_name or "us-west-2"
+                    region=session.boto_session.region_name
                 )
-                # Store ARN for evaluator_arn in ServerlessJobConfig
+                # Store ARN for evaluator_arn
                 self._evaluator_arn = hub_content.hub_content_arn
-                self._reward_prompt_processed = {"custom_prompt_arn": hub_content.hub_content_arn}
             except Exception as e:
                 raise ValueError(f"Custom prompt '{self.reward_prompt}' not found in HubContent: {e}")
+
+    def _update_judge_prompt_template_direct(self, reward_prompt):
+        """Update judge_prompt_template based on a Builtin reward prompt."""
+        # Get available templates from hyperparameter specs
+        judge_prompt_spec = self.hyperparameters._specs.get('judge_prompt_template', {})
+        available_templates = judge_prompt_spec.get('enum', [])
+
+        if not available_templates:
+            # If no enum is found, use the current value as the only available option
+            current_value = getattr(self.hyperparameters, 'judge_prompt_template', None)
+            if current_value:
+                available_templates = [current_value]
+            else:
+                return
+
+        # Extract the template name after "Builtin." and convert it to lowercase
+        template_name = reward_prompt.split(".", 1)[1].lower()
+
+        # Find a matching template by comparing filenames without the .jinja extension
+        matching_template = None
+        for template in available_templates:
+            template_filename = template.split("/")[-1].replace(".jinja", "").lower()
+            if template_filename == template_name:
+                matching_template = template
+                break
+
+        if matching_template:
+            self.hyperparameters.judge_prompt_template = matching_template
+        else:
+            available_options = [f"Builtin.{t.split('/')[-1].replace('.jinja', '')}" for t in available_templates]
+            raise ValueError(
+                f"Selected reward function option '{reward_prompt}' is not available. "
+                f"Choose one from the available options: {available_options}. "
+                f"Example: reward_prompt='Builtin.summarize'"
+            )
 
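Reviewer note: the `Builtin.*` matching in `_update_judge_prompt_template_direct` is easy to sanity-check in isolation. Below is a minimal standalone sketch of that lookup; the template paths are invented stand-ins, since the real values come from the hyperparameter spec's `enum`.

```python
# Standalone sketch of the Builtin-prompt matching logic above.
# The template paths here are invented examples, not real hub content.
def match_builtin_prompt(reward_prompt, available_templates):
    # "Builtin.Summarize" -> "summarize" (match is case-insensitive)
    template_name = reward_prompt.split(".", 1)[1].lower()
    for template in available_templates:
        filename = template.split("/")[-1].replace(".jinja", "").lower()
        if filename == template_name:
            return template
    options = [f"Builtin.{t.split('/')[-1].replace('.jinja', '')}" for t in available_templates]
    raise ValueError(f"'{reward_prompt}' is not available; choose one of: {options}")

templates = ["prompts/judge/summarize.jinja", "prompts/judge/classify.jinja"]
assert match_builtin_prompt("Builtin.Summarize", templates) == "prompts/judge/summarize.jinja"
```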
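As a compact mental model of how `_process_hyperparameters` routes the accepted `reward_prompt` forms, here is a small decision sketch; the returned labels are descriptive only and do not exist in the SDK, and any non-string input stands in for an Evaluator object.

```python
# Minimal routing sketch mirroring the reward_prompt branches above.
def route_reward_prompt(reward_prompt):
    if not isinstance(reward_prompt, str):
        return "evaluator-object"    # ARN extracted via _extract_evaluator_arn
    if reward_prompt.startswith("Builtin."):
        return "builtin-template"    # mapped onto judge_prompt_template
    if reward_prompt.startswith("arn:aws:sagemaker:"):
        return "evaluator-arn"       # validated, stored on self._evaluator_arn
    return "hub-content-name"        # resolved via _get_hub_content_metadata

assert route_reward_prompt("Builtin.summarize") == "builtin-template"
assert route_reward_prompt("arn:aws:sagemaker:us-west-2:123456789012:evaluator/demo") == "evaluator-arn"
assert route_reward_prompt("my-custom-prompt") == "hub-content-name"
```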