Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,18 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
recipe = next((r for r in recipes_with_template if not r.get("Peft")), None)

if recipe and recipe.get("SmtjOverrideParamsS3Uri"):
if not recipe:
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")

elif recipe and recipe.get("SmtjOverrideParamsS3Uri"):
s3_uri = recipe["SmtjOverrideParamsS3Uri"]
s3 = boto3.client("s3")
bucket, key = s3_uri.replace("s3://", "").split("/", 1)
obj = s3.get_object(Bucket=bucket, Key=key)
options_dict = json.loads(obj["Body"].read())
return FineTuningOptions(options_dict), model_arn, is_gated_model
else:
return FineTuningOptions({}), model_arn, is_gated_model

except Exception as e:
logger.error("Exception getting fine-tuning options: %s", e)
Expand Down Expand Up @@ -598,6 +603,9 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None
# Use default S3 output path if none provided
if s3_output_path is None:
s3_output_path = _get_default_s3_output_path(sagemaker_session)

# Validate S3 path exists
_validate_s3_path_exists(s3_output_path, sagemaker_session)

return OutputDataConfig(
s3_output_path=s3_output_path,
Expand Down Expand Up @@ -682,3 +690,43 @@ def _validate_eula_for_gated_model(model, accept_eula, is_gated_model):
)

return accept_eula


def _validate_s3_path_exists(s3_path: str, sagemaker_session):
"""Validate if S3 path exists and is accessible."""
if not s3_path.startswith("s3://"):
raise ValueError(f"Invalid S3 path format: {s3_path}")

# Parse S3 URI
s3_parts = s3_path.replace("s3://", "").split("/", 1)
bucket_name = s3_parts[0]
prefix = s3_parts[1] if len(s3_parts) > 1 else ""

s3_client = sagemaker_session.boto_session.client('s3')

try:
# Check if bucket exists and is accessible
s3_client.head_bucket(Bucket=bucket_name)

# If prefix is provided, check if it exists
if prefix:
response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix, MaxKeys=1)
if 'Contents' not in response:
raise ValueError(f"S3 prefix '{prefix}' does not exist in bucket '{bucket_name}'")

except Exception as e:
if "NoSuchBucket" in str(e):
raise ValueError(f"S3 bucket '{bucket_name}' does not exist or is not accessible")
raise ValueError(f"Failed to validate S3 path '{s3_path}': {str(e)}")


def _validate_hyperparameter_values(hyperparameters: dict):
"""Validate hyperparameter values for allowed characters."""
import re
allowed_chars = r"^[a-zA-Z0-9/_.:,\-\s'\"\[\]]*$"
for key, value in hyperparameters.items():
if isinstance(value, str) and not re.match(allowed_chars, value):
raise ValueError(
f"Hyperparameter '{key}' value '{value}' contains invalid characters. "
f"Only a-z, A-Z, 0-9, /, _, ., :, \\, -, space, ', \", [, ] and , are allowed."
)
34 changes: 33 additions & 1 deletion sagemaker-train/src/sagemaker/train/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
_create_serverless_config,
_create_mlflow_config,
_create_model_package_config,
_validate_eula_for_gated_model
_validate_eula_for_gated_model,
_validate_hyperparameter_values
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
Expand Down Expand Up @@ -137,8 +138,38 @@ def __init__(

))

# Process hyperparameters
self._process_hyperparameters()

# Validate and set EULA acceptance
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)

def _process_hyperparameters(self):
"""Remove hyperparameter keys that are handled by constructor inputs."""
if self.hyperparameters:
# Remove keys that are handled by constructor inputs
if hasattr(self.hyperparameters, 'data_path'):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: there is code redundancy improvements possible here. Have an ignore_list somewhere in constants and look over it to delattr.

delattr(self.hyperparameters, 'data_path')
self.hyperparameters._specs.pop('data_path', None)
if hasattr(self.hyperparameters, 'output_path'):
delattr(self.hyperparameters, 'output_path')
self.hyperparameters._specs.pop('output_path', None)
if hasattr(self.hyperparameters, 'data_s3_path'):
delattr(self.hyperparameters, 'data_s3_path')
self.hyperparameters._specs.pop('data_s3_path', None)
if hasattr(self.hyperparameters, 'output_s3_path'):
delattr(self.hyperparameters, 'output_s3_path')
self.hyperparameters._specs.pop('output_s3_path', None)
if hasattr(self.hyperparameters, 'training_data_name'):
delattr(self.hyperparameters, 'training_data_name')
self.hyperparameters._specs.pop('training_data_name', None)
if hasattr(self.hyperparameters, 'validation_data_name'):
delattr(self.hyperparameters, 'validation_data_name')
self.hyperparameters._specs.pop('validation_data_name', None)
if hasattr(self.hyperparameters, 'validation_data_path'):
delattr(self.hyperparameters, 'validation_data_path')
self.hyperparameters._specs.pop('validation_data_path', None)

@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="DPOTrainer.train")
def train(self,
training_dataset: Optional[Union[str, DataSet]] = None,
Expand Down Expand Up @@ -198,6 +229,7 @@ def train(self,
)

final_hyperparameters = self.hyperparameters.to_dict()
_validate_hyperparameter_values(final_hyperparameters)

model_package_config = _create_model_package_config(
model_package_group_name=self.model_package_group_name,
Expand Down
119 changes: 95 additions & 24 deletions sagemaker-train/src/sagemaker/train/rlaif_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
_create_serverless_config,
_create_mlflow_config,
_create_model_package_config,
_validate_eula_for_gated_model
_validate_eula_for_gated_model,
_validate_hyperparameter_values
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
Expand Down Expand Up @@ -163,7 +164,8 @@ def __init__(
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)

# Process reward_prompt parameter
self._process_reward_prompt()
self._process_hyperparameters()


@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train")
def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
Expand Down Expand Up @@ -223,6 +225,8 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati

final_hyperparameters = self.hyperparameters.to_dict()

_validate_hyperparameter_values(final_hyperparameters)

model_package_config = _create_model_package_config(
model_package_group_name=self.model_package_group_name,
model=self.model,
Expand Down Expand Up @@ -258,40 +262,107 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
self.latest_training_job = training_job
return training_job

def _process_reward_prompt(self):
"""Process reward_prompt parameter for builtin vs custom prompts."""
if not self.reward_prompt:
return

# Handle Evaluator object
if not isinstance(self.reward_prompt, str):
evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
self._evaluator_arn = evaluator_arn
self._reward_prompt_processed = {"custom_prompt_arn": evaluator_arn}
def _process_hyperparameters(self):
"""Update hyperparameters based on constructor inputs and process reward_prompt."""
if not self.hyperparameters or not hasattr(self.hyperparameters, '_specs') or not self.hyperparameters._specs:
return

# Handle string inputs
if self.reward_prompt.startswith("Builtin"):
# Map to preset_prompt in hyperparameters
self._reward_prompt_processed = {"preset_prompt": self.reward_prompt}
elif self.reward_prompt.startswith("arn:aws:sagemaker:"):

# Remove keys that are handled by constructor inputs
if hasattr(self.hyperparameters, 'output_path'):
delattr(self.hyperparameters, 'output_path')
self.hyperparameters._specs.pop('output_path', None)
if hasattr(self.hyperparameters, 'data_path'):
delattr(self.hyperparameters, 'data_path')
self.hyperparameters._specs.pop('data_path', None)
if hasattr(self.hyperparameters, 'validation_data_path'):
delattr(self.hyperparameters, 'validation_data_path')
self.hyperparameters._specs.pop('validation_data_path', None)

# Update judge_model_id if reward_model_id is provided
if hasattr(self, 'reward_model_id') and self.reward_model_id:
judge_model_value = f"bedrock/{self.reward_model_id}"
self.hyperparameters.judge_model_id = judge_model_value

# Process reward_prompt parameter
if hasattr(self, 'reward_prompt') and self.reward_prompt:
if isinstance(self.reward_prompt, str):
if self.reward_prompt.startswith("Builtin"):
# Handle builtin reward prompts
self._update_judge_prompt_template_direct(self.reward_prompt)
else:
# Handle evaluator ARN or hub content name
self._process_non_builtin_reward_prompt()
else:
# Handle evaluator object
if hasattr(self.hyperparameters, 'judge_prompt_template'):
delattr(self.hyperparameters, 'judge_prompt_template')
self.hyperparameters._specs.pop('judge_prompt_template', None)

evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
self._evaluator_arn = evaluator_arn

def _process_non_builtin_reward_prompt(self):
    """Process a non-builtin reward prompt (evaluator ARN or hub content name).

    Raises:
        ValueError: If a hub-content name cannot be resolved to an evaluator.
    """
    # A custom prompt supersedes the builtin judge template entirely.
    if hasattr(self.hyperparameters, 'judge_prompt_template'):
        delattr(self.hyperparameters, 'judge_prompt_template')
        self.hyperparameters._specs.pop('judge_prompt_template', None)

    if self.reward_prompt.startswith("arn:aws:sagemaker:"):
        # Validate and assign ARN.
        evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
        self._evaluator_arn = evaluator_arn
    else:
        # Treat the string as a hub content name and resolve it to an ARN.
        try:
            session = TrainDefaults.get_sagemaker_session(
                sagemaker_session=self.sagemaker_session
            )
            hub_content = _get_hub_content_metadata(
                hub_name=HUB_NAME,
                hub_content_type="JsonDoc",
                hub_content_name=self.reward_prompt,
                session=session.boto_session,
                region=session.boto_session.region_name
            )
            # Store ARN for evaluator_arn.
            self._evaluator_arn = hub_content.hub_content_arn
        except Exception as e:
            raise ValueError(
                f"Custom prompt '{self.reward_prompt}' not found in HubContent: {e}"
            ) from e



def _update_judge_prompt_template_direct(self, reward_prompt):
"""Update judge_prompt_template based on Builtin reward function."""
# Get available templates from hyperparameters specs
judge_prompt_spec = self.hyperparameters._specs.get('judge_prompt_template', {})
available_templates = judge_prompt_spec.get('enum', [])

if not available_templates:
# If no enum found, use the current value as the only available option
current_value = getattr(self.hyperparameters, 'judge_prompt_template', None)
if current_value:
available_templates = [current_value]
else:
return

# Extract template name after "Builtin." and convert to lowercase
template_name = reward_prompt.split(".", 1)[1].lower()

# Find matching template by extracting filename without extension
matching_template = None
for template in available_templates:
template_filename = template.split("/")[-1].replace(".jinja", "").lower()
if template_filename == template_name:
matching_template = template
break

if matching_template:
self.hyperparameters.judge_prompt_template = matching_template
else:
available_options = [f"Builtin.{t.split('/')[-1].replace('.jinja', '')}" for t in available_templates]
raise ValueError(
f"Selected reward function option '{reward_prompt}' is not available. "
f"Choose one from the available options: {available_options}. "
f"Example: reward_prompt='Builtin.summarize'"
)

29 changes: 28 additions & 1 deletion sagemaker-train/src/sagemaker/train/rlvr_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
_create_serverless_config,
_create_mlflow_config,
_create_model_package_config,
_validate_eula_for_gated_model
_validate_eula_for_gated_model,
_validate_hyperparameter_values
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
Expand Down Expand Up @@ -148,9 +149,32 @@ def __init__(
sagemaker_session=self.sagemaker_session
))

# Remove constructor-handled hyperparameters
self._process_hyperparameters()

# Validate and set EULA acceptance
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)

def _process_hyperparameters(self):
"""Remove hyperparameter keys that are handled by constructor inputs."""
if self.hyperparameters:
# Remove keys that are handled by constructor inputs
if hasattr(self.hyperparameters, 'data_s3_path'):
delattr(self.hyperparameters, 'data_s3_path')
self.hyperparameters._specs.pop('data_s3_path', None)
if hasattr(self.hyperparameters, 'reward_lambda_arn'):
delattr(self.hyperparameters, 'reward_lambda_arn')
self.hyperparameters._specs.pop('reward_lambda_arn', None)
if hasattr(self.hyperparameters, 'data_path'):
delattr(self.hyperparameters, 'data_path')
self.hyperparameters._specs.pop('data_path', None)
if hasattr(self.hyperparameters, 'validation_data_path'):
delattr(self.hyperparameters, 'validation_data_path')
self.hyperparameters._specs.pop('validation_data_path', None)
if hasattr(self.hyperparameters, 'output_path'):
delattr(self.hyperparameters, 'output_path')
self.hyperparameters._specs.pop('output_path', None)

@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLVRTrainer.train")
def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
Expand Down Expand Up @@ -210,6 +234,9 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
)

final_hyperparameters = self.hyperparameters.to_dict()

# Validate hyperparameter values
_validate_hyperparameter_values(final_hyperparameters)

model_package_config = _create_model_package_config(
model_package_group_name=self.model_package_group_name,
Expand Down
Loading
Loading