Skip to content

Commit 60574e5

Browse files
rsareddy0329Roja Reddy Sareddy
andauthored
fix: Address Hyperparameter issue, validate s3 output path, additional unit tests (#5376)
* fix: Fix the recipe selection for multiple recipe scenario * fix: Fix the recipe selection for multiple recipe scenario * fix: Hyperparameter issue fixes, validate s3 output path,additional unit tests --------- Co-authored-by: Roja Reddy Sareddy <rsareddy@amazon.com>
1 parent f32f615 commit 60574e5

14 files changed

+846
-92
lines changed

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,13 +352,18 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
352352
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
353353
recipe = next((r for r in recipes_with_template if not r.get("Peft")), None)
354354

355-
if recipe and recipe.get("SmtjOverrideParamsS3Uri"):
355+
if not recipe:
356+
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")
357+
358+
elif recipe and recipe.get("SmtjOverrideParamsS3Uri"):
356359
s3_uri = recipe["SmtjOverrideParamsS3Uri"]
357360
s3 = boto3.client("s3")
358361
bucket, key = s3_uri.replace("s3://", "").split("/", 1)
359362
obj = s3.get_object(Bucket=bucket, Key=key)
360363
options_dict = json.loads(obj["Body"].read())
361364
return FineTuningOptions(options_dict), model_arn, is_gated_model
365+
else:
366+
return FineTuningOptions({}), model_arn, is_gated_model
362367

363368
except Exception as e:
364369
logger.error("Exception getting fine-tuning options: %s", e)
@@ -598,6 +603,9 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None
598603
# Use default S3 output path if none provided
599604
if s3_output_path is None:
600605
s3_output_path = _get_default_s3_output_path(sagemaker_session)
606+
607+
# Validate S3 path exists
608+
_validate_s3_path_exists(s3_output_path, sagemaker_session)
601609

602610
return OutputDataConfig(
603611
s3_output_path=s3_output_path,
@@ -682,3 +690,43 @@ def _validate_eula_for_gated_model(model, accept_eula, is_gated_model):
682690
)
683691

684692
return accept_eula
693+
694+
695+
def _validate_s3_path_exists(s3_path: str, sagemaker_session):
    """Validate that an S3 path exists and is accessible.

    Args:
        s3_path: S3 URI of the form ``s3://bucket`` or ``s3://bucket/prefix``.
        sagemaker_session: Object whose ``boto_session.client('s3')`` returns
            the client used for the ``head_bucket`` and ``list_objects_v2``
            calls.

    Raises:
        ValueError: If ``s3_path`` does not start with ``s3://``, if the
            bucket check fails, if a non-empty prefix matches no objects, or
            if any other error occurs while talking to S3.
    """
    scheme = "s3://"
    if not s3_path.startswith(scheme):
        raise ValueError(f"Invalid S3 path format: {s3_path}")

    # Strip only the leading scheme; str.replace would also rewrite any later
    # "s3://" occurrence inside the key.
    bucket_name, _, prefix = s3_path[len(scheme):].partition("/")

    s3_client = sagemaker_session.boto_session.client('s3')

    try:
        # Check if bucket exists and is accessible
        s3_client.head_bucket(Bucket=bucket_name)

        # If prefix is provided, check that at least one object lives under it
        prefix_found = True
        if prefix:
            response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix, MaxKeys=1)
            prefix_found = 'Contents' in response
    except Exception as e:
        # HeadBucket reports a missing bucket as a generic 404 rather than a
        # "NoSuchBucket" error, so inspect the error code as well as the text.
        error_response = getattr(e, "response", None)
        error_code = None
        if isinstance(error_response, dict):
            error_code = str(error_response.get("Error", {}).get("Code"))
        if "NoSuchBucket" in str(e) or error_code in ("404", "NoSuchBucket"):
            raise ValueError(
                f"S3 bucket '{bucket_name}' does not exist or is not accessible"
            ) from e
        raise ValueError(f"Failed to validate S3 path '{s3_path}': {str(e)}") from e

    # Raised outside the try block so the message is not re-wrapped as a
    # generic "Failed to validate" error.
    if not prefix_found:
        raise ValueError(f"S3 prefix '{prefix}' does not exist in bucket '{bucket_name}'")
721+
722+
723+
def _validate_hyperparameter_values(hyperparameters: dict):
    """Validate that string hyperparameter values use only allowed characters.

    Non-string values are not checked.

    Args:
        hyperparameters: Mapping of hyperparameter name to value.

    Raises:
        ValueError: If a string value contains a character outside the
            allowed set (letters, digits, ``/ _ . : , -``, whitespace,
            quotes and square brackets).
    """
    import re

    # Note: "\-" is an escaped hyphen, so a backslash is NOT part of the set.
    allowed_chars = re.compile(r"[a-zA-Z0-9/_.:,\-\s'\"\[\]]*")
    for key, value in hyperparameters.items():
        if isinstance(value, str) and not allowed_chars.fullmatch(value):
            raise ValueError(
                f"Hyperparameter '{key}' value '{value}' contains invalid characters. "
                f"Only a-z, A-Z, 0-9, /, _, ., :, -, whitespace, ', \", [, ] and , are allowed."
            )

sagemaker-train/src/sagemaker/train/dpo_trainer.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
_create_serverless_config,
1818
_create_mlflow_config,
1919
_create_model_package_config,
20-
_validate_eula_for_gated_model
20+
_validate_eula_for_gated_model,
21+
_validate_hyperparameter_values
2122
)
2223
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2324
from sagemaker.core.telemetry.constants import Feature
@@ -137,8 +138,38 @@ def __init__(
137138

138139
))
139140

141+
# Process hyperparameters
142+
self._process_hyperparameters()
143+
140144
# Validate and set EULA acceptance
141145
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
146+
147+
def _process_hyperparameters(self):
    """Strip hyperparameters whose values are supplied through constructor inputs."""
    options = self.hyperparameters
    if not options:
        return

    # Each name is removed from the options object and from its ``_specs``
    # mapping, but only when the options object currently has it.
    constructor_managed = (
        'data_path',
        'output_path',
        'data_s3_path',
        'output_s3_path',
        'training_data_name',
        'validation_data_name',
        'validation_data_path',
    )
    for field_name in constructor_managed:
        if not hasattr(options, field_name):
            continue
        delattr(options, field_name)
        options._specs.pop(field_name, None)
172+
142173
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="DPOTrainer.train")
143174
def train(self,
144175
training_dataset: Optional[Union[str, DataSet]] = None,
@@ -198,6 +229,7 @@ def train(self,
198229
)
199230

200231
final_hyperparameters = self.hyperparameters.to_dict()
232+
_validate_hyperparameter_values(final_hyperparameters)
201233

202234
model_package_config = _create_model_package_config(
203235
model_package_group_name=self.model_package_group_name,

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 95 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
_create_serverless_config,
2222
_create_mlflow_config,
2323
_create_model_package_config,
24-
_validate_eula_for_gated_model
24+
_validate_eula_for_gated_model,
25+
_validate_hyperparameter_values
2526
)
2627
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2728
from sagemaker.core.telemetry.constants import Feature
@@ -163,7 +164,8 @@ def __init__(
163164
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
164165

165166
# Process reward_prompt parameter
166-
self._process_reward_prompt()
167+
self._process_hyperparameters()
168+
167169

168170
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train")
169171
def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
@@ -223,6 +225,8 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
223225

224226
final_hyperparameters = self.hyperparameters.to_dict()
225227

228+
_validate_hyperparameter_values(final_hyperparameters)
229+
226230
model_package_config = _create_model_package_config(
227231
model_package_group_name=self.model_package_group_name,
228232
model=self.model,
@@ -258,40 +262,107 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
258262
self.latest_training_job = training_job
259263
return training_job
260264

261-
def _process_reward_prompt(self):
262-
"""Process reward_prompt parameter for builtin vs custom prompts."""
263-
if not self.reward_prompt:
264-
return
265-
266-
# Handle Evaluator object
267-
if not isinstance(self.reward_prompt, str):
268-
evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
269-
self._evaluator_arn = evaluator_arn
270-
self._reward_prompt_processed = {"custom_prompt_arn": evaluator_arn}
265+
def _process_hyperparameters(self):
    """Update hyperparameters based on constructor inputs and process reward_prompt.

    When the hyperparameters object has a non-empty ``_specs`` mapping, the
    keys handled by constructor inputs are removed and ``judge_model_id`` is
    set from ``reward_model_id``.

    ``reward_prompt`` is processed even when there are no hyperparameter
    specs, so that a custom prompt or evaluator is still resolved into
    ``_evaluator_arn`` instead of being silently ignored. Only the builtin
    branch is skipped in that case, because it reads the specs.
    """
    hyperparameters = self.hyperparameters
    has_specs = bool(
        hyperparameters
        and hasattr(hyperparameters, '_specs')
        and hyperparameters._specs
    )

    if has_specs:
        # Remove keys that are handled by constructor inputs
        for name in ('output_path', 'data_path', 'validation_data_path'):
            if hasattr(hyperparameters, name):
                delattr(hyperparameters, name)
                hyperparameters._specs.pop(name, None)

        # Update judge_model_id if reward_model_id is provided
        if hasattr(self, 'reward_model_id') and self.reward_model_id:
            hyperparameters.judge_model_id = f"bedrock/{self.reward_model_id}"

    # Process reward_prompt parameter
    if not (hasattr(self, 'reward_prompt') and self.reward_prompt):
        return

    if isinstance(self.reward_prompt, str):
        if self.reward_prompt.startswith("Builtin"):
            # Builtin prompts are mapped onto a judge_prompt_template value,
            # which needs the hyperparameter specs.
            if has_specs:
                self._update_judge_prompt_template_direct(self.reward_prompt)
        else:
            # Handle evaluator ARN or hub content name
            self._process_non_builtin_reward_prompt()
        return

    # Handle evaluator object
    if hasattr(hyperparameters, 'judge_prompt_template'):
        delattr(hyperparameters, 'judge_prompt_template')
        specs = getattr(hyperparameters, '_specs', None)
        if specs is not None:
            specs.pop('judge_prompt_template', None)

    evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
    self._evaluator_arn = evaluator_arn
303+
304+
def _process_non_builtin_reward_prompt(self):
    """Process non-builtin reward prompt (ARN or hub content name).

    Removes ``judge_prompt_template`` from the hyperparameters, then stores
    the evaluator ARN in ``self._evaluator_arn``. A prompt starting with
    ``arn:aws:sagemaker:`` goes through ``_extract_evaluator_arn``; any other
    string is looked up as a ``JsonDoc`` hub content name in ``HUB_NAME``.

    Raises:
        ValueError: If the hub content lookup fails for any reason; the
            original exception is chained as the cause.
    """
    # Remove judge_prompt_template for non-builtin prompts
    if hasattr(self.hyperparameters, 'judge_prompt_template'):
        delattr(self.hyperparameters, 'judge_prompt_template')
        self.hyperparameters._specs.pop('judge_prompt_template', None)

    if self.reward_prompt.startswith("arn:aws:sagemaker:"):
        # Validate and assign ARN
        evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
        self._evaluator_arn = evaluator_arn
    else:
        try:
            session = TrainDefaults.get_sagemaker_session(
                sagemaker_session=self.sagemaker_session
            )
            hub_content = _get_hub_content_metadata(
                hub_name=HUB_NAME,
                hub_content_type="JsonDoc",
                hub_content_name=self.reward_prompt,
                session=session.boto_session,
                region=session.boto_session.region_name
            )
            # Store ARN for evaluator_arn
            self._evaluator_arn = hub_content.hub_content_arn
        except Exception as e:
            # Chain the cause so the underlying lookup failure stays visible
            # in the traceback.
            raise ValueError(
                f"Custom prompt '{self.reward_prompt}' not found in HubContent: {e}"
            ) from e
331+
332+
333+
334+
def _update_judge_prompt_template_direct(self, reward_prompt):
    """Update judge_prompt_template based on a builtin reward prompt.

    Candidate templates come from the ``enum`` entry of the
    ``judge_prompt_template`` spec. If there is no enum, the currently set
    ``judge_prompt_template`` value is the only candidate; if that is also
    missing, nothing is changed.

    The text after the first "." in ``reward_prompt`` is compared,
    case-insensitively, with each candidate's file name (last "/" segment
    with ".jinja" removed).

    Args:
        reward_prompt: Builtin prompt name, e.g. ``"Builtin.summarize"``.

    Raises:
        ValueError: If no candidate matches, including when ``reward_prompt``
            has no name after ``"Builtin."``.
    """
    # Get available templates from hyperparameters specs
    judge_prompt_spec = self.hyperparameters._specs.get('judge_prompt_template', {})
    available_templates = judge_prompt_spec.get('enum', [])

    if not available_templates:
        # If no enum found, use the current value as the only available option
        current_value = getattr(self.hyperparameters, 'judge_prompt_template', None)
        if not current_value:
            return
        available_templates = [current_value]

    def _display_name(template):
        return template.split("/")[-1].replace(".jinja", "")

    # partition() yields "" when there is no "." (e.g. "Builtin"), so a
    # malformed prompt reaches the descriptive ValueError below instead of
    # raising IndexError.
    template_name = reward_prompt.partition(".")[2].lower()

    matching_template = next(
        (
            template
            for template in available_templates
            if template_name and _display_name(template).lower() == template_name
        ),
        None,
    )

    if matching_template:
        self.hyperparameters.judge_prompt_template = matching_template
        return

    available_options = [f"Builtin.{_display_name(t)}" for t in available_templates]
    raise ValueError(
        f"Selected reward function option '{reward_prompt}' is not available. "
        f"Choose one from the available options: {available_options}. "
        f"Example: reward_prompt='Builtin.summarize'"
    )
297368

sagemaker-train/src/sagemaker/train/rlvr_trainer.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
_create_serverless_config,
2020
_create_mlflow_config,
2121
_create_model_package_config,
22-
_validate_eula_for_gated_model
22+
_validate_eula_for_gated_model,
23+
_validate_hyperparameter_values
2324
)
2425
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2526
from sagemaker.core.telemetry.constants import Feature
@@ -148,9 +149,32 @@ def __init__(
148149
sagemaker_session=self.sagemaker_session
149150
))
150151

152+
# Remove constructor-handled hyperparameters
153+
self._process_hyperparameters()
154+
151155
# Validate and set EULA acceptance
152156
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
153157

158+
def _process_hyperparameters(self):
    """Drop hyperparameter keys that are handled by constructor inputs."""
    options = self.hyperparameters
    if not options:
        return

    # Removal happens on both the options object and its ``_specs`` mapping,
    # and only for names the options object currently exposes.
    for field_name in (
        'data_s3_path',
        'reward_lambda_arn',
        'data_path',
        'validation_data_path',
        'output_path',
    ):
        if hasattr(options, field_name):
            delattr(options, field_name)
            options._specs.pop(field_name, None)
177+
154178
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLVRTrainer.train")
155179
def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
156180
validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
@@ -210,6 +234,9 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
210234
)
211235

212236
final_hyperparameters = self.hyperparameters.to_dict()
237+
238+
# Validate hyperparameter values
239+
_validate_hyperparameter_values(final_hyperparameters)
213240

214241
model_package_config = _create_model_package_config(
215242
model_package_group_name=self.model_package_group_name,

0 commit comments

Comments
 (0)