diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py
index 5a6fd8644d..ee0256b79c 100644
--- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py
+++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py
@@ -352,13 +352,18 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
         elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
             recipe = next((r for r in recipes_with_template if not r.get("Peft")), None)
 
-        if recipe and recipe.get("SmtjOverrideParamsS3Uri"):
+        if not recipe:
+            raise ValueError(f"No SMTJ recipes found for technique: {customization_technique}, training_type: {training_type}")
+
+        elif recipe.get("SmtjOverrideParamsS3Uri"):
             s3_uri = recipe["SmtjOverrideParamsS3Uri"]
             s3 = boto3.client("s3")
             bucket, key = s3_uri.replace("s3://", "").split("/", 1)
             obj = s3.get_object(Bucket=bucket, Key=key)
             options_dict = json.loads(obj["Body"].read())
             return FineTuningOptions(options_dict), model_arn, is_gated_model
+        else:
+            return FineTuningOptions({}), model_arn, is_gated_model
 
     except Exception as e:
         logger.error("Exception getting fine-tuning options: %s", e)
@@ -598,6 +603,9 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None
     # Use default S3 output path if none provided
     if s3_output_path is None:
         s3_output_path = _get_default_s3_output_path(sagemaker_session)
+
+    # Validate S3 path exists
+    _validate_s3_path_exists(s3_output_path, sagemaker_session)
 
     return OutputDataConfig(
         s3_output_path=s3_output_path,
@@ -682,3 +690,43 @@ def _validate_eula_for_gated_model(model, accept_eula, is_gated_model):
         )
 
     return accept_eula
+
+
+def _validate_s3_path_exists(s3_path: str, sagemaker_session):
+    """Validate if S3 path exists and is accessible."""
+    if not s3_path.startswith("s3://"):
+        raise ValueError(f"Invalid S3 path format: {s3_path}")
+
+    # Parse S3 URI
+    s3_parts = s3_path.replace("s3://", "").split("/", 1)
+    bucket_name = s3_parts[0]
+    prefix = s3_parts[1] if len(s3_parts) > 1 else ""
+
+    s3_client = sagemaker_session.boto_session.client('s3')
+
+    try:
+        # Check if bucket exists and is accessible
+        s3_client.head_bucket(Bucket=bucket_name)
+
+        # If prefix is provided, check if it exists
+        if prefix:
+            response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix, MaxKeys=1)
+            if 'Contents' not in response:
+                raise ValueError(f"S3 prefix '{prefix}' does not exist in bucket '{bucket_name}'")
+
+    except Exception as e:
+        if "NoSuchBucket" in str(e):
+            raise ValueError(f"S3 bucket '{bucket_name}' does not exist or is not accessible")
+        raise ValueError(f"Failed to validate S3 path '{s3_path}': {str(e)}")
+
+
+def _validate_hyperparameter_values(hyperparameters: dict):
+    """Validate hyperparameter values for allowed characters."""
+    import re
+    allowed_chars = r"^[a-zA-Z0-9/_.:,\-\s'\"\[\]]*$"
+    for key, value in hyperparameters.items():
+        if isinstance(value, str) and not re.match(allowed_chars, value):
+            raise ValueError(
+                f"Hyperparameter '{key}' value '{value}' contains invalid characters. "
+                f"Only a-z, A-Z, 0-9, /, _, ., :, -, whitespace, ', \", [, ] and , are allowed."
+            )
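A quick sanity check of what the new validator accepts and rejects (a sketch; it imports the private helper added above, and assumes the package is installed locally). Non-string values are skipped, and the character class admits any whitespace, not just spaces:

    from sagemaker.train.common_utils.finetune_utils import _validate_hyperparameter_values

    # Passes: letters, digits, '.', '-' are in the allowed set; the int is skipped.
    _validate_hyperparameter_values({"learning_rate": "2e-5", "epochs": 3})

    # Raises ValueError: ';' is outside the allowed character class.
    _validate_hyperparameter_values({"extra_args": "--foo; echo pwned"})
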
diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py
index 766d693b6a..66ca88130b 100644
--- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py
+++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py
@@ -17,7 +17,8 @@
     _create_serverless_config,
     _create_mlflow_config,
     _create_model_package_config,
-    _validate_eula_for_gated_model
+    _validate_eula_for_gated_model,
+    _validate_hyperparameter_values
 )
 from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
 from sagemaker.core.telemetry.constants import Feature
@@ -137,8 +138,38 @@ def __init__(
             ))
 
+        # Process hyperparameters
+        self._process_hyperparameters()
+
         # Validate and set EULA acceptance
         self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
+
+    def _process_hyperparameters(self):
+        """Remove hyperparameter keys that are handled by constructor inputs."""
+        if self.hyperparameters:
+            # Remove keys that are handled by constructor inputs
+            if hasattr(self.hyperparameters, 'data_path'):
+                delattr(self.hyperparameters, 'data_path')
+                self.hyperparameters._specs.pop('data_path', None)
+            if hasattr(self.hyperparameters, 'output_path'):
+                delattr(self.hyperparameters, 'output_path')
+                self.hyperparameters._specs.pop('output_path', None)
+            if hasattr(self.hyperparameters, 'data_s3_path'):
+                delattr(self.hyperparameters, 'data_s3_path')
+                self.hyperparameters._specs.pop('data_s3_path', None)
+            if hasattr(self.hyperparameters, 'output_s3_path'):
+                delattr(self.hyperparameters, 'output_s3_path')
+                self.hyperparameters._specs.pop('output_s3_path', None)
+            if hasattr(self.hyperparameters, 'training_data_name'):
+                delattr(self.hyperparameters, 'training_data_name')
+                self.hyperparameters._specs.pop('training_data_name', None)
+            if hasattr(self.hyperparameters, 'validation_data_name'):
+                delattr(self.hyperparameters, 'validation_data_name')
+                self.hyperparameters._specs.pop('validation_data_name', None)
+            if hasattr(self.hyperparameters, 'validation_data_path'):
+                delattr(self.hyperparameters, 'validation_data_path')
+                self.hyperparameters._specs.pop('validation_data_path', None)
+
     @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="DPOTrainer.train")
     def train(self,
               training_dataset: Optional[Union[str, DataSet]] = None,
@@ -198,6 +229,7 @@ def train(self,
         )
 
         final_hyperparameters = self.hyperparameters.to_dict()
+        _validate_hyperparameter_values(final_hyperparameters)
 
         model_package_config = _create_model_package_config(
             model_package_group_name=self.model_package_group_name,
diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py
index 06090d0eb4..bc6ab234e9 100644
--- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py
+++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py
@@ -21,7 +21,8 @@
     _create_serverless_config,
     _create_mlflow_config,
     _create_model_package_config,
-    _validate_eula_for_gated_model
+    _validate_eula_for_gated_model,
+    _validate_hyperparameter_values
 )
 from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
 from sagemaker.core.telemetry.constants import Feature
@@ -163,7 +164,8 @@ def __init__(
         self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
 
         # Process reward_prompt parameter
-        self._process_reward_prompt()
+        self._process_hyperparameters()
+
 
     @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train")
     def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
@@ -223,6 +225,8 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
 
         final_hyperparameters = self.hyperparameters.to_dict()
 
+        _validate_hyperparameter_values(final_hyperparameters)
+
         model_package_config = _create_model_package_config(
             model_package_group_name=self.model_package_group_name,
             model=self.model,
@@ -258,40 +262,107 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
         self.latest_training_job = training_job
         return training_job
 
-    def _process_reward_prompt(self):
-        """Process reward_prompt parameter for builtin vs custom prompts."""
-        if not self.reward_prompt:
-            return
-
-        # Handle Evaluator object
-        if not isinstance(self.reward_prompt, str):
-            evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
-            self._evaluator_arn = evaluator_arn
-            self._reward_prompt_processed = {"custom_prompt_arn": evaluator_arn}
+    def _process_hyperparameters(self):
+        """Update hyperparameters based on constructor inputs and process reward_prompt."""
+        if not self.hyperparameters or not hasattr(self.hyperparameters, '_specs') or not self.hyperparameters._specs:
             return
-
-        # Handle string inputs
-        if self.reward_prompt.startswith("Builtin"):
-            # Map to preset_prompt in hyperparameters
-            self._reward_prompt_processed = {"preset_prompt": self.reward_prompt}
-        elif self.reward_prompt.startswith("arn:aws:sagemaker:"):
+
+        # Remove keys that are handled by constructor inputs
+        if hasattr(self.hyperparameters, 'output_path'):
+            delattr(self.hyperparameters, 'output_path')
+            self.hyperparameters._specs.pop('output_path', None)
+        if hasattr(self.hyperparameters, 'data_path'):
+            delattr(self.hyperparameters, 'data_path')
+            self.hyperparameters._specs.pop('data_path', None)
+        if hasattr(self.hyperparameters, 'validation_data_path'):
+            delattr(self.hyperparameters, 'validation_data_path')
+            self.hyperparameters._specs.pop('validation_data_path', None)
+
+        # Update judge_model_id if reward_model_id is provided
+        if hasattr(self, 'reward_model_id') and self.reward_model_id:
+            judge_model_value = f"bedrock/{self.reward_model_id}"
+            self.hyperparameters.judge_model_id = judge_model_value
+
+        # Process reward_prompt parameter
+        if hasattr(self, 'reward_prompt') and self.reward_prompt:
+            if isinstance(self.reward_prompt, str):
+                if self.reward_prompt.startswith("Builtin"):
+                    # Handle builtin reward prompts
+                    self._update_judge_prompt_template_direct(self.reward_prompt)
+                else:
+                    # Handle evaluator ARN or hub content name
+                    self._process_non_builtin_reward_prompt()
+            else:
+                # Handle evaluator object
+                if hasattr(self.hyperparameters, 'judge_prompt_template'):
+                    delattr(self.hyperparameters, 'judge_prompt_template')
+                    self.hyperparameters._specs.pop('judge_prompt_template', None)
+
+                evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
+                self._evaluator_arn = evaluator_arn
+
+    def _process_non_builtin_reward_prompt(self):
+        """Process non-builtin reward prompt (ARN or hub content name)."""
+        # Remove judge_prompt_template for non-builtin prompts
+        if hasattr(self.hyperparameters, 'judge_prompt_template'):
+            delattr(self.hyperparameters, 'judge_prompt_template')
+            self.hyperparameters._specs.pop('judge_prompt_template', None)
+
+        if self.reward_prompt.startswith("arn:aws:sagemaker:"):
             # Validate and assign ARN
             evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
             self._evaluator_arn = evaluator_arn
-            self._reward_prompt_processed = {"custom_prompt_arn": evaluator_arn}
         else:
             try:
-                session = self.sagemaker_session or _get_beta_session()
+                session = TrainDefaults.get_sagemaker_session(
+                    sagemaker_session=self.sagemaker_session
+                )
                 hub_content = _get_hub_content_metadata(
-                    hub_name=HUB_NAME,  # or appropriate hub name
+                    hub_name=HUB_NAME,
                     hub_content_type="JsonDoc",
                     hub_content_name=self.reward_prompt,
                     session=session.boto_session,
-                    region=session.boto_session.region_name or "us-west-2"
+                    region=session.boto_session.region_name
                 )
-                # Store ARN for evaluator_arn in ServerlessJobConfig
+                # Store ARN for evaluator_arn
                 self._evaluator_arn = hub_content.hub_content_arn
-                self._reward_prompt_processed = {"custom_prompt_arn": hub_content.hub_content_arn}
             except Exception as e:
                 raise ValueError(f"Custom prompt '{self.reward_prompt}' not found in HubContent: {e}")
+
+
+
+    def _update_judge_prompt_template_direct(self, reward_prompt):
+        """Update judge_prompt_template based on Builtin reward function."""
+        # Get available templates from hyperparameters specs
+        judge_prompt_spec = self.hyperparameters._specs.get('judge_prompt_template', {})
+        available_templates = judge_prompt_spec.get('enum', [])
+
+        if not available_templates:
+            # If no enum found, use the current value as the only available option
+            current_value = getattr(self.hyperparameters, 'judge_prompt_template', None)
+            if current_value:
+                available_templates = [current_value]
+            else:
+                return
+
+        # Extract template name after "Builtin." and convert to lowercase
+        template_name = reward_prompt.split(".", 1)[1].lower()
+
+        # Find matching template by extracting filename without extension
+        matching_template = None
+        for template in available_templates:
+            template_filename = template.split("/")[-1].replace(".jinja", "").lower()
+            if template_filename == template_name:
+                matching_template = template
+                break
+
+        if matching_template:
+            self.hyperparameters.judge_prompt_template = matching_template
+        else:
+            available_options = [f"Builtin.{t.split('/')[-1].replace('.jinja', '')}" for t in available_templates]
+            raise ValueError(
+                f"Selected reward function option '{reward_prompt}' is not available. "
+                f"Choose one from the available options: {available_options}. "
+                f"Example: reward_prompt='Builtin.summarize'"
+            )
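The Builtin-to-template matching above is filename-based. A standalone sketch of the lookup, reusing the enum values that the unit tests below exercise (the template paths are illustrative; real values come from the hyperparameter spec's enum):

    available = ["templates/summarize.jinja", "templates/helpfulness.jinja"]
    name = "Builtin.Summarize".split(".", 1)[1].lower()  # -> "summarize"
    match = next(
        (t for t in available
         if t.split("/")[-1].replace(".jinja", "").lower() == name),
        None,
    )
    print(match)  # templates/summarize.jinja

Because both sides are lowercased, user input such as 'Builtin.Summarize' and 'Builtin.summarize' resolve to the same template.
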
" + f"Example: reward_prompt='Builtin.summarize'" + ) diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index 60dd4f8593..e14734b692 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -19,7 +19,8 @@ _create_serverless_config, _create_mlflow_config, _create_model_package_config, - _validate_eula_for_gated_model + _validate_eula_for_gated_model, + _validate_hyperparameter_values ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature @@ -148,9 +149,32 @@ def __init__( sagemaker_session=self.sagemaker_session )) + # Remove constructor-handled hyperparameters + self._process_hyperparameters() + # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) + def _process_hyperparameters(self): + """Remove hyperparameter keys that are handled by constructor inputs.""" + if self.hyperparameters: + # Remove keys that are handled by constructor inputs + if hasattr(self.hyperparameters, 'data_s3_path'): + delattr(self.hyperparameters, 'data_s3_path') + self.hyperparameters._specs.pop('data_s3_path', None) + if hasattr(self.hyperparameters, 'reward_lambda_arn'): + delattr(self.hyperparameters, 'reward_lambda_arn') + self.hyperparameters._specs.pop('reward_lambda_arn', None) + if hasattr(self.hyperparameters, 'data_path'): + delattr(self.hyperparameters, 'data_path') + self.hyperparameters._specs.pop('data_path', None) + if hasattr(self.hyperparameters, 'validation_data_path'): + delattr(self.hyperparameters, 'validation_data_path') + self.hyperparameters._specs.pop('validation_data_path', None) + if hasattr(self.hyperparameters, 'output_path'): + delattr(self.hyperparameters, 'output_path') + self.hyperparameters._specs.pop('output_path', None) + @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLVRTrainer.train") def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True): @@ -210,6 +234,9 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, ) final_hyperparameters = self.hyperparameters.to_dict() + + # Validate hyperparameter values + _validate_hyperparameter_values(final_hyperparameters) model_package_config = _create_model_package_config( model_package_group_name=self.model_package_group_name, diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index 6a0009b28b..4e109a85b9 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -17,7 +17,8 @@ _create_serverless_config, _create_mlflow_config, _create_model_package_config, - _validate_eula_for_gated_model + _validate_eula_for_gated_model, + _validate_hyperparameter_values ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature @@ -139,9 +140,38 @@ def __init__( sagemaker_session=self.sagemaker_session )) + # Process hyperparameters + self._process_hyperparameters() + # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) + def _process_hyperparameters(self): + """Remove hyperparameter keys that are handled by constructor inputs.""" + if self.hyperparameters: + # Remove keys that 
are handled by constructor inputs + if hasattr(self.hyperparameters, 'data_path'): + delattr(self.hyperparameters, 'data_path') + self.hyperparameters._specs.pop('data_path', None) + if hasattr(self.hyperparameters, 'output_path'): + delattr(self.hyperparameters, 'output_path') + self.hyperparameters._specs.pop('output_path', None) + if hasattr(self.hyperparameters, 'data_s3_path'): + delattr(self.hyperparameters, 'data_s3_path') + self.hyperparameters._specs.pop('data_s3_path', None) + if hasattr(self.hyperparameters, 'output_s3_path'): + delattr(self.hyperparameters, 'output_s3_path') + self.hyperparameters._specs.pop('output_s3_path', None) + if hasattr(self.hyperparameters, 'training_data_name'): + delattr(self.hyperparameters, 'training_data_name') + self.hyperparameters._specs.pop('training_data_name', None) + if hasattr(self.hyperparameters, 'validation_data_name'): + delattr(self.hyperparameters, 'validation_data_name') + self.hyperparameters._specs.pop('validation_data_name', None) + if hasattr(self.hyperparameters, 'validation_data_path'): + delattr(self.hyperparameters, 'validation_data_path') + self.hyperparameters._specs.pop('validation_data_path', None) + @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="SFTTrainer.train") def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True): """Execute the SFT training job. @@ -197,6 +227,9 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati ) final_hyperparameters = self.hyperparameters.to_dict() + + # Validate hyperparameter values + _validate_hyperparameter_values(final_hyperparameters) model_package_config = _create_model_package_config( model_package_group_name=self.model_package_group_name, diff --git a/sagemaker-train/tests/integ/train/test_dpo_trainer_integration.py b/sagemaker-train/tests/integ/train/test_dpo_trainer_integration.py index d220e77aa9..8c2c49dbc4 100644 --- a/sagemaker-train/tests/integ/train/test_dpo_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_dpo_trainer_integration.py @@ -30,10 +30,8 @@ def test_dpo_trainer_lora_complete_workflow(sagemaker_session): model="meta-textgeneration-llama-3-2-1b-instruct", training_type=TrainingType.LORA, model_package_group_name="sdk-test-finetuned-models", - training_dataset="s3://mc-flows-sdk-testing/input_data/dpo/preference_dataset_train_256.jsonl", + training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/dpo-oss-test-data/0.0.1", s3_output_path="s3://mc-flows-sdk-testing/output/", - # Unique job name - base_job_name=f"dpo-llama-{random.randint(1, 1000)}", accept_eula=True ) @@ -71,11 +69,9 @@ def test_dpo_trainer_with_validation_dataset(sagemaker_session): model="meta-textgeneration-llama-3-2-1b-instruct", training_type=TrainingType.LORA, model_package_group_name="sdk-test-finetuned-models", - training_dataset="s3://mc-flows-sdk-testing/input_data/dpo/preference_dataset_train_256.jsonl", - validation_dataset="s3://mc-flows-sdk-testing/input_data/dpo/preference_dataset_train_256.jsonl", + training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/dpo-oss-test-data/0.0.1", + validation_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/dpo-oss-test-data/0.0.1", s3_output_path="s3://mc-flows-sdk-testing/output/", - # Unique job name - base_job_name=f"dpo-llama-{random.randint(1, 1000)}", accept_eula=True ) diff --git 
a/sagemaker-train/tests/integ/train/test_rlaif_trainer_integration.py b/sagemaker-train/tests/integ/train/test_rlaif_trainer_integration.py index 9f3594ad01..7e7de19dee 100644 --- a/sagemaker-train/tests/integ/train/test_rlaif_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_rlaif_trainer_integration.py @@ -29,15 +29,15 @@ def test_rlaif_trainer_lora_complete_workflow(sagemaker_session): model="meta-textgeneration-llama-3-2-1b-instruct", training_type=TrainingType.LORA, model_package_group_name="sdk-test-finetuned-models", - reward_model_id='anthropic.claude-3-5-sonnet-20240620-v1:0', - reward_prompt='Builtin.Correctness', + reward_model_id='openai.gpt-oss-120b-1:0', + reward_prompt='Builtin.Summarize', mlflow_experiment_name="test-rlaif-finetuned-models-exp", mlflow_run_name="test-rlaif-finetuned-models-run", - training_dataset="s3://mc-flows-sdk-testing/input_data/rlvr-rlaif-test-data/train_285.jsonl", + training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/rlvr-rlaif-oss-test-data/0.0.1", s3_output_path="s3://mc-flows-sdk-testing/output/", accept_eula=True ) - + # Create training job training_job = rlaif_trainer.train(wait=False) @@ -64,16 +64,16 @@ def test_rlaif_trainer_lora_complete_workflow(sagemaker_session): @pytest.mark.skip(reason="Skipping GPU resource intensive test") def test_rlaif_trainer_with_custom_reward_settings(sagemaker_session): """Test RLAIF trainer with different reward model and prompt.""" - + rlaif_trainer = RLAIFTrainer( model="meta-textgeneration-llama-3-2-1b-instruct", training_type=TrainingType.LORA, model_package_group_name="sdk-test-finetuned-models", - reward_model_id='anthropic.claude-3-5-sonnet-20240620-v1:0', + reward_model_id='openai.gpt-oss-120b-1:0', reward_prompt="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/JsonDoc/rlaif-test-prompt/0.0.1", mlflow_experiment_name="test-rlaif-finetuned-models-exp", mlflow_run_name="test-rlaif-finetuned-models-run", - training_dataset="s3://mc-flows-sdk-testing/input_data/rlvr-rlaif-test-data/train_285.jsonl", + training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/rlvr-rlaif-oss-test-data/0.0.1", s3_output_path="s3://mc-flows-sdk-testing/output/", accept_eula=True ) @@ -108,11 +108,11 @@ def test_rlaif_trainer_continued_finetuning(sagemaker_session): model="arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/1", training_type=TrainingType.LORA, model_package_group_name="sdk-test-finetuned-models", - reward_model_id='anthropic.claude-3-5-sonnet-20240620-v1:0', - reward_prompt='Builtin.Correctness', + reward_model_id='openai.gpt-oss-120b-1:0', + reward_prompt='Builtin.Summarize', mlflow_experiment_name="test-rlaif-finetuned-models-exp", mlflow_run_name="test-rlaif-finetuned-models-run", - training_dataset="s3://mc-flows-sdk-testing/input_data/rlvr-rlaif-test-data/train_285.jsonl", + training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/rlvr-rlaif-oss-test-data/0.0.1", s3_output_path="s3://mc-flows-sdk-testing/output/", accept_eula=True ) diff --git a/sagemaker-train/tests/integ/train/test_rlvr_trainer_integration.py b/sagemaker-train/tests/integ/train/test_rlvr_trainer_integration.py index d723b3338c..6637a1fdb4 100644 --- a/sagemaker-train/tests/integ/train/test_rlvr_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_rlvr_trainer_integration.py @@ -32,7 +32,7 @@ def test_rlvr_trainer_lora_complete_workflow(sagemaker_session): 
model_package_group_name="sdk-test-finetuned-models", mlflow_experiment_name="test-rlvr-finetuned-models-exp", mlflow_run_name="test-rlvr-finetuned-models-run", - training_dataset="s3://mc-flows-sdk-testing/input_data/rlvr-rlaif-test-data/train_285.jsonl", + training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/rlvr-rlaif-oss-test-data/0.0.1", s3_output_path="s3://mc-flows-sdk-testing/output/", accept_eula=True ) @@ -70,7 +70,7 @@ def test_rlvr_trainer_with_custom_reward_function(sagemaker_session): model_package_group_name="sdk-test-finetuned-models", mlflow_experiment_name="test-rlvr-finetuned-models-exp", mlflow_run_name="test-rlvr-finetuned-models-run", - training_dataset="s3://mc-flows-sdk-testing/input_data/rlvr-rlaif-test-data/train_285.jsonl", + training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/rlvr-rlaif-oss-test-data/0.0.1", s3_output_path="s3://mc-flows-sdk-testing/output/", custom_reward_function="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/JsonDoc/rlvr-test-rf/0.0.1", accept_eula=True diff --git a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py index e473761bed..aced084c6b 100644 --- a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py @@ -30,7 +30,7 @@ def test_sft_trainer_lora_complete_workflow(sagemaker_session): model="meta-textgeneration-llama-3-2-1b-instruct", training_type=TrainingType.LORA, model_package_group_name="arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models", - training_dataset="s3://mc-flows-sdk-testing/input_data/sft/", + training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/sft-oss-test-data/0.0.1", s3_output_path="s3://mc-flows-sdk-testing/output/", accept_eula=True ) @@ -66,8 +66,8 @@ def test_sft_trainer_with_validation_dataset(sagemaker_session): model="meta-textgeneration-llama-3-2-1b-instruct", training_type=TrainingType.LORA, model_package_group_name="arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models", - training_dataset="s3://mc-flows-sdk-testing/input_data/sft/", - validation_dataset="s3://mc-flows-sdk-testing/input_data/sft/", + training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/sft-oss-test-data/0.0.1", + validation_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/sft-oss-test-data/0.0.1", accept_eula=True ) diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 960624ccb5..e77b019e68 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -25,7 +25,8 @@ _convert_input_data_to_channels, _create_mlflow_config, _validate_eula_for_gated_model, - _validate_model_region_availability + _validate_model_region_availability, + _validate_s3_path_exists ) from sagemaker.core.resources import ModelPackage, ModelPackageGroup from sagemaker.ai_registry.dataset import DataSet @@ -435,13 +436,15 @@ def test__create_mlflow_config(self): assert config.mlflow_resource_arn == "mlflow-arn" assert config.mlflow_experiment_name == "test-exp" - def test__create_output_config(self): + 
@patch('sagemaker.train.common_utils.finetune_utils._validate_s3_path_exists') + def test__create_output_config(self, mock_validate_s3): mock_session = Mock() config = _create_output_config(mock_session, "s3://bucket/output", "kms-key") assert config.s3_output_path == "s3://bucket/output" assert config.kms_key_id == "kms-key" + mock_validate_s3.assert_called_once_with("s3://bucket/output", mock_session) def test__convert_input_data_to_channels(self): @@ -500,3 +503,50 @@ def test__validate_model_region_availability_open_weights_invalid_region(self): """Test open weights model validation fails for invalid region""" with pytest.raises(ValueError, match="Region 'us-west-1' does not support model customization"): _validate_model_region_availability("meta-textgeneration-llama-3-2-1b", "us-west-1") + + def test__validate_s3_path_exists_invalid_format(self): + """Test S3 path validation fails for invalid format""" + mock_session = Mock() + + with pytest.raises(ValueError, match="Invalid S3 path format"): + _validate_s3_path_exists("invalid-path", mock_session) + + @patch('boto3.client') + def test__validate_s3_path_exists_bucket_only_success(self, mock_boto_client): + """Test S3 path validation succeeds for bucket-only path""" + mock_session = Mock() + mock_s3_client = Mock() + mock_session.boto_session.client.return_value = mock_s3_client + + _validate_s3_path_exists("s3://test-bucket", mock_session) + + mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket") + + @patch('boto3.client') + def test__validate_s3_path_exists_with_prefix_exists(self, mock_boto_client): + """Test S3 path validation succeeds when prefix exists""" + mock_session = Mock() + mock_s3_client = Mock() + mock_session.boto_session.client.return_value = mock_s3_client + mock_s3_client.list_objects_v2.return_value = {"Contents": [{"Key": "prefix/file.txt"}]} + + _validate_s3_path_exists("s3://test-bucket/prefix/", mock_session) + + mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket") + mock_s3_client.list_objects_v2.assert_called_once_with(Bucket="test-bucket", Prefix="prefix/", MaxKeys=1) + + @patch('boto3.client') + def test__validate_s3_path_exists_with_prefix_not_exists(self, mock_boto_client): + """Test S3 path validation raises error when prefix doesn't exist""" + mock_session = Mock() + mock_s3_client = Mock() + mock_session.boto_session.client.return_value = mock_s3_client + mock_s3_client.list_objects_v2.return_value = {} # No contents + + with pytest.raises(ValueError, match="Failed to validate S3 path 's3://test-bucket/prefix': S3 prefix 'prefix' does not exist in bucket 'test-bucket'"): + _validate_s3_path_exists("s3://test-bucket/prefix", mock_session) + + mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket") + mock_s3_client.list_objects_v2.assert_called_once_with(Bucket="test-bucket", Prefix="prefix", MaxKeys=1) + + diff --git a/sagemaker-train/tests/unit/train/test_dpo_trainer.py b/sagemaker-train/tests/unit/train/test_dpo_trainer.py index 79671f91be..85dce8d56b 100644 --- a/sagemaker-train/tests/unit/train/test_dpo_trainer.py +++ b/sagemaker-train/tests/unit/train/test_dpo_trainer.py @@ -17,7 +17,9 @@ def mock_session(self): @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') def test_init_with_defaults(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + 
mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = DPOTrainer(model="test-model", model_package_group_name="test-group") assert trainer.training_type == TrainingType.LORA assert trainer.model == "test-model" @@ -26,7 +28,9 @@ def test_init_with_defaults(self, mock_finetuning_options, mock_validate_group, @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') def test_init_with_full_training_type(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = DPOTrainer(model="test-model", training_type=TrainingType.FULL, model_package_group_name="test-group") assert trainer.training_type == TrainingType.FULL @@ -79,7 +83,9 @@ def test_train_with_lora(self, mock_training_job_create, mock_model_package_conf @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') def test_training_type_string_value(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = DPOTrainer(model="test-model", training_type="CUSTOM", model_package_group_name="test-group") assert trainer.training_type == "CUSTOM" @@ -88,7 +94,9 @@ def test_training_type_string_value(self, mock_finetuning_options, mock_validate @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') def test_model_package_input(self, mock_finetuning_options, mock_validate_group, mock_resolve_model): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) model_package = Mock(spec=ModelPackage) model_package.inference_specification = Mock() @@ -102,7 +110,9 @@ def test_model_package_input(self, mock_finetuning_options, mock_validate_group, @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') def test_init_with_datasets(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = DPOTrainer( model="test-model", model_package_group_name="test-group", @@ -116,7 +126,9 @@ def test_init_with_datasets(self, mock_finetuning_options, mock_validate_group, @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') def test_init_with_mlflow_config(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = DPOTrainer( model="test-model", 
model_package_group_name="test-group", @@ -177,7 +189,9 @@ def test_train_with_full_training(self, mock_training_job_create, mock_model_pac @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') def test_fit_without_datasets_raises_error(self, mock_finetuning_options, mock_validate_group): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = DPOTrainer(model="test-model", model_package_group_name="test-group") with pytest.raises(Exception): @@ -189,7 +203,9 @@ def test_fit_without_datasets_raises_error(self, mock_finetuning_options, mock_v def test_model_package_group_handling(self, mock_validate_group, mock_get_options, mock_resolve_model): mock_validate_group.return_value = "test-group" mock_resolve_model.return_value = "resolved-model" - mock_get_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_get_options.return_value = (mock_hyperparams, "model-arn", False) trainer = DPOTrainer( model="test-model", @@ -201,7 +217,9 @@ def test_model_package_group_handling(self, mock_validate_group, mock_get_option @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') def test_s3_output_path_configuration(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = DPOTrainer( model="test-model", model_package_group_name="test-group", @@ -260,7 +278,9 @@ def test_train_with_tags(self, mock_training_job_create, mock_model_package_conf def test_gated_model_eula_validation(self, mock_finetuning_options, mock_validate_group, mock_session): """Test EULA validation for gated models""" mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", True) # is_gated_model=True + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", True) # is_gated_model=True # Should raise error when accept_eula=False for gated model with pytest.raises(ValueError, match="gated model and requires EULA acceptance"): @@ -269,3 +289,71 @@ def test_gated_model_eula_validation(self, mock_finetuning_options, mock_validat # Should work when accept_eula=True for gated model trainer = DPOTrainer(model="gated-model", model_package_group_name="test-group", accept_eula=True) assert trainer.accept_eula == True + + def test_process_hyperparameters_removes_constructor_handled_keys(self): + """Test that _process_hyperparameters removes keys handled by constructor inputs.""" + # Create mock hyperparameters with all possible keys + mock_hyperparams = Mock() + mock_hyperparams._specs = { + 'data_path': 'test_data_path', + 'output_path': 'test_output_path', + 'training_data_name': 'test_training_data_name', + 'validation_data_name': 'test_validation_data_name', + 'other_param': 'should_remain' + } + + # Add attributes to mock + mock_hyperparams.data_path = 'test_data_path' + mock_hyperparams.output_path = 'test_output_path' + mock_hyperparams.training_data_name = 
'test_training_data_name' + mock_hyperparams.validation_data_name = 'test_validation_data_name' + + # Create trainer instance with mock hyperparameters + trainer = DPOTrainer.__new__(DPOTrainer) + trainer.hyperparameters = mock_hyperparams + + # Call the method + trainer._process_hyperparameters() + + # Verify attributes were removed + assert not hasattr(mock_hyperparams, 'data_path') + assert not hasattr(mock_hyperparams, 'output_path') + assert not hasattr(mock_hyperparams, 'training_data_name') + assert not hasattr(mock_hyperparams, 'validation_data_name') + + # Verify _specs were updated + assert 'data_path' not in mock_hyperparams._specs + assert 'output_path' not in mock_hyperparams._specs + assert 'training_data_name' not in mock_hyperparams._specs + assert 'validation_data_name' not in mock_hyperparams._specs + assert 'other_param' in mock_hyperparams._specs + + def test_process_hyperparameters_handles_missing_attributes(self): + """Test that _process_hyperparameters handles missing attributes gracefully.""" + # Create mock hyperparameters with only some keys + mock_hyperparams = Mock() + mock_hyperparams._specs = { + 'data_path': 'test_data_path', + 'other_param': 'should_remain' + } + mock_hyperparams.data_path = 'test_data_path' + + # Create trainer instance + trainer = DPOTrainer.__new__(DPOTrainer) + trainer.hyperparameters = mock_hyperparams + + # Call the method + trainer._process_hyperparameters() + + # Verify only existing attributes were processed + assert not hasattr(mock_hyperparams, 'data_path') + assert 'data_path' not in mock_hyperparams._specs + assert 'other_param' in mock_hyperparams._specs + + def test_process_hyperparameters_with_none_hyperparameters(self): + """Test that _process_hyperparameters handles None hyperparameters.""" + trainer = DPOTrainer.__new__(DPOTrainer) + trainer.hyperparameters = None + + # Should not raise an exception + trainer._process_hyperparameters() diff --git a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py index 32df0300c0..0008b88912 100644 --- a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py @@ -17,7 +17,9 @@ def mock_session(self): @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') def test_init_with_defaults(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = RLAIFTrainer(model="test-model", model_package_group_name="test-group") assert trainer.training_type == TrainingType.LORA assert trainer.model == "test-model" @@ -26,7 +28,9 @@ def test_init_with_defaults(self, mock_finetuning_options, mock_validate_group, @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') def test_init_with_full_training_type(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = RLAIFTrainer(model="test-model", training_type=TrainingType.FULL, model_package_group_name="test-group") assert 
trainer.training_type == TrainingType.FULL @@ -122,7 +126,9 @@ def test_peft_value_for_full_training(self, mock_training_job_create, mock_model @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') def test_training_type_string_value(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = RLAIFTrainer(model="test-model", training_type="CUSTOM", model_package_group_name="test-group") assert trainer.training_type == "CUSTOM" @@ -133,7 +139,9 @@ def test_training_type_string_value(self, mock_finetuning_options, mock_validate def test_model_package_input(self, mock_finetuning_options, mock_validate_group, mock_resolve_model, mock_get_session): mock_validate_group.return_value = "test-group" mock_get_session.return_value = Mock() - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) model_package = Mock(spec=ModelPackage) model_package.inference_specification = Mock() @@ -148,7 +156,9 @@ def test_model_package_input(self, mock_finetuning_options, mock_validate_group, @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') def test_init_with_datasets(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = RLAIFTrainer( model="test-model", model_package_group_name="test-group", @@ -162,7 +172,9 @@ def test_init_with_datasets(self, mock_finetuning_options, mock_validate_group, @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') def test_init_with_mlflow_config(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = RLAIFTrainer( model="test-model", model_package_group_name="test-group", @@ -179,7 +191,9 @@ def test_init_with_mlflow_config(self, mock_finetuning_options, mock_validate_gr @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') def test_train_without_datasets_raises_error(self, mock_finetuning_options, mock_validate_group, mock_get_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) mock_get_session.return_value = Mock() trainer = RLAIFTrainer(model="test-model", model_package_group_name="test-group") @@ -194,7 +208,9 @@ def test_model_package_group_handling(self, mock_validate_group, mock_get_option mock_validate_group.return_value = "test-group" mock_get_session.return_value = Mock() mock_resolve_model.return_value = "resolved-model" - 
mock_get_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_get_options.return_value = (mock_hyperparams, "model-arn", False) trainer = RLAIFTrainer( model="test-model", @@ -206,7 +222,9 @@ def test_model_package_group_handling(self, mock_validate_group, mock_get_option @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') def test_s3_output_path_configuration(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = RLAIFTrainer( model="test-model", model_package_group_name="test-group", @@ -265,7 +283,9 @@ def test_train_with_tags(self, mock_training_job_create, mock_model_package_conf def test_gated_model_eula_validation(self, mock_finetuning_options, mock_validate_group, mock_session): """Test EULA validation for gated models""" mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", True) # is_gated_model=True + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", True) # is_gated_model=True # Should raise error when accept_eula=False for gated model with pytest.raises(ValueError, match="gated model and requires EULA acceptance"): @@ -274,3 +294,212 @@ def test_gated_model_eula_validation(self, mock_finetuning_options, mock_validat # Should work when accept_eula=True for gated model trainer = RLAIFTrainer(model="gated-model", model_package_group_name="test-group", accept_eula=True) assert trainer.accept_eula == True + + def test_process_hyperparameters_removes_constructor_handled_keys(self): + """Test that _process_hyperparameters removes keys handled by constructor inputs.""" + # Create mock hyperparameters with all possible keys + mock_hyperparams = Mock() + mock_hyperparams._specs = { + 'output_path': 'test_output_path', + 'data_path': 'test_data_path', + 'validation_data_path': 'test_validation_data_path', + 'other_param': 'should_remain' + } + + # Add attributes to mock + mock_hyperparams.output_path = 'test_output_path' + mock_hyperparams.data_path = 'test_data_path' + mock_hyperparams.validation_data_path = 'test_validation_data_path' + + # Create trainer instance with mock hyperparameters + trainer = RLAIFTrainer.__new__(RLAIFTrainer) + trainer.hyperparameters = mock_hyperparams + trainer.reward_model_id = "test-reward-model" + + # Call the method + trainer._process_hyperparameters() + + # Verify attributes were removed + assert not hasattr(mock_hyperparams, 'output_path') + assert not hasattr(mock_hyperparams, 'data_path') + assert not hasattr(mock_hyperparams, 'validation_data_path') + + # Verify _specs were updated + assert 'output_path' not in mock_hyperparams._specs + assert 'data_path' not in mock_hyperparams._specs + assert 'validation_data_path' not in mock_hyperparams._specs + assert 'other_param' in mock_hyperparams._specs + + # Verify judge_model_id was set + assert mock_hyperparams.judge_model_id == "bedrock/test-reward-model" + + def test_process_hyperparameters_updates_judge_model_id(self): + """Test that _process_hyperparameters updates judge_model_id when reward_model_id is provided.""" + # Use a simple object instead of Mock to 
allow proper attribute assignment + class MockHyperparams: + def __init__(self): + self._specs = {'some_param': 'value'} # Non-empty specs + + mock_hyperparams = MockHyperparams() + + trainer = RLAIFTrainer.__new__(RLAIFTrainer) + trainer.hyperparameters = mock_hyperparams + trainer.reward_model_id = "my-reward-model" + + trainer._process_hyperparameters() + + assert hasattr(mock_hyperparams, 'judge_model_id') + assert mock_hyperparams.judge_model_id == "bedrock/my-reward-model" + + def test_process_hyperparameters_handles_missing_attributes(self): + """Test that _process_hyperparameters handles missing attributes gracefully.""" + # Create mock hyperparameters with only some keys + mock_hyperparams = Mock() + mock_hyperparams._specs = { + 'data_path': 'test_data_path', + 'other_param': 'should_remain' + } + mock_hyperparams.data_path = 'test_data_path' + + # Create trainer instance + trainer = RLAIFTrainer.__new__(RLAIFTrainer) + trainer.hyperparameters = mock_hyperparams + trainer.reward_model_id = None + + # Call the method + trainer._process_hyperparameters() + + # Verify only existing attributes were processed + assert not hasattr(mock_hyperparams, 'data_path') + assert 'data_path' not in mock_hyperparams._specs + assert 'other_param' in mock_hyperparams._specs + + def test_process_hyperparameters_with_none_hyperparameters(self): + """Test that _process_hyperparameters handles None hyperparameters.""" + trainer = RLAIFTrainer.__new__(RLAIFTrainer) + trainer.hyperparameters = None + + # Should not raise an exception + trainer._process_hyperparameters() + + def test_process_hyperparameters_early_return_on_none(self): + """Test that _process_hyperparameters returns early when hyperparameters is None.""" + trainer = RLAIFTrainer.__new__(RLAIFTrainer) + trainer.hyperparameters = None + trainer.reward_model_id = "test-model" + + # Should return early and not attempt to set judge_model_id + trainer._process_hyperparameters() + + # No exception should be raised + + def test_update_judge_prompt_template_direct_with_matching_template(self): + """Test _update_judge_prompt_template_direct with matching template.""" + mock_hyperparams = Mock() + mock_hyperparams._specs = { + 'judge_prompt_template': { + 'enum': ['templates/summarize.jinja', 'templates/helpfulness.jinja'] + } + } + + trainer = RLAIFTrainer.__new__(RLAIFTrainer) + trainer.hyperparameters = mock_hyperparams + + trainer._update_judge_prompt_template_direct("Builtin.summarize") + + assert mock_hyperparams.judge_prompt_template == 'templates/summarize.jinja' + + def test_update_judge_prompt_template_direct_with_no_enum(self): + """Test _update_judge_prompt_template_direct when no enum is available.""" + mock_hyperparams = Mock() + mock_hyperparams._specs = {'judge_prompt_template': {}} + mock_hyperparams.judge_prompt_template = 'current_template.jinja' + + trainer = RLAIFTrainer.__new__(RLAIFTrainer) + trainer.hyperparameters = mock_hyperparams + + trainer._update_judge_prompt_template_direct("Builtin.current_template") + + assert mock_hyperparams.judge_prompt_template == 'current_template.jinja' + + def test_update_judge_prompt_template_direct_no_matching_template(self): + """Test _update_judge_prompt_template_direct raises error for non-matching template.""" + mock_hyperparams = Mock() + mock_hyperparams._specs = { + 'judge_prompt_template': { + 'enum': ['templates/summarize.jinja', 'templates/helpfulness.jinja'] + } + } + + trainer = RLAIFTrainer.__new__(RLAIFTrainer) + trainer.hyperparameters = mock_hyperparams + + with 
pytest.raises(ValueError, match="Selected reward function option 'Builtin.nonexistent' is not available"): + trainer._update_judge_prompt_template_direct("Builtin.nonexistent") + + def test_update_judge_prompt_template_direct_early_return(self): + """Test _update_judge_prompt_template_direct returns early when no templates available.""" + mock_hyperparams = Mock() + mock_hyperparams._specs = {'judge_prompt_template': {}} + mock_hyperparams.judge_prompt_template = None + + trainer = RLAIFTrainer.__new__(RLAIFTrainer) + trainer.hyperparameters = mock_hyperparams + + # Should return early without error + trainer._update_judge_prompt_template_direct("Builtin.anything") + + def test_process_non_builtin_reward_prompt_removes_judge_template(self): + """Test _process_non_builtin_reward_prompt removes judge_prompt_template.""" + mock_hyperparams = Mock() + mock_hyperparams._specs = {'judge_prompt_template': 'template.jinja'} + mock_hyperparams.judge_prompt_template = 'template.jinja' + + trainer = RLAIFTrainer.__new__(RLAIFTrainer) + trainer.hyperparameters = mock_hyperparams + trainer.reward_prompt = "arn:aws:sagemaker:us-east-1:123456789012:evaluator/test" + + with patch('sagemaker.train.rlaif_trainer._extract_evaluator_arn') as mock_extract: + mock_extract.return_value = "test-arn" + trainer._process_non_builtin_reward_prompt() + + assert not hasattr(mock_hyperparams, 'judge_prompt_template') + assert 'judge_prompt_template' not in mock_hyperparams._specs + assert trainer._evaluator_arn == "test-arn" + + def test_process_non_builtin_reward_prompt_with_hub_content(self): + """Test _process_non_builtin_reward_prompt with hub content name.""" + mock_hyperparams = Mock() + mock_hyperparams._specs = {} + + trainer = RLAIFTrainer.__new__(RLAIFTrainer) + trainer.hyperparameters = mock_hyperparams + trainer.reward_prompt = "custom-prompt-name" + trainer.sagemaker_session = None + + with patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_sagemaker_session') as mock_session, \ + patch('sagemaker.train.rlaif_trainer._get_hub_content_metadata') as mock_hub: + mock_session.return_value = Mock(boto_session=Mock(region_name="us-west-2")) + mock_hub.return_value = Mock(hub_content_arn="hub-content-arn") + + trainer._process_non_builtin_reward_prompt() + + assert trainer._evaluator_arn == "hub-content-arn" + + def test_process_non_builtin_reward_prompt_hub_content_error(self): + """Test _process_non_builtin_reward_prompt raises error for invalid hub content.""" + mock_hyperparams = Mock() + mock_hyperparams._specs = {} + + trainer = RLAIFTrainer.__new__(RLAIFTrainer) + trainer.hyperparameters = mock_hyperparams + trainer.reward_prompt = "invalid-prompt" + trainer.sagemaker_session = None + + with patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_sagemaker_session') as mock_session, \ + patch('sagemaker.train.rlaif_trainer._get_hub_content_metadata') as mock_hub: + mock_session.return_value = Mock(boto_session=Mock(region_name="us-west-2")) + mock_hub.side_effect = Exception("Not found") + + with pytest.raises(ValueError, match="Custom prompt 'invalid-prompt' not found in HubContent"): + trainer._process_non_builtin_reward_prompt() diff --git a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py index 4ff9c7552c..7128a3545c 100644 --- a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py @@ -17,7 +17,9 @@ def mock_session(self): 
diff --git a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py
index 4ff9c7552c..7128a3545c 100644
--- a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py
+++ b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py
@@ -17,7 +17,9 @@ def mock_session(self):
     @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn')
     def test_init_with_defaults(self, mock_finetuning_options, mock_validate_group, mock_session):
         mock_validate_group.return_value = "test-group"
-        mock_finetuning_options.return_value = (Mock(), "model-arn", False)
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False)
         trainer = RLVRTrainer(model="test-model", model_package_group_name="test-group")
         assert trainer.training_type == TrainingType.LORA
         assert trainer.model == "test-model"
@@ -26,7 +28,9 @@ def test_init_with_defaults(self, mock_finetuning_options, mock_validate_group,
     @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn')
     def test_init_with_full_training_type(self, mock_finetuning_options, mock_validate_group, mock_session):
         mock_validate_group.return_value = "test-group"
-        mock_finetuning_options.return_value = (Mock(), "model-arn", False)
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False)
         trainer = RLVRTrainer(model="test-model", training_type=TrainingType.FULL, model_package_group_name="test-group")
         assert trainer.training_type == TrainingType.FULL
@@ -122,7 +126,9 @@ def test_peft_value_for_full_training(self, mock_training_job_create, mock_model
     @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn')
     def test_training_type_string_value(self, mock_finetuning_options, mock_validate_group, mock_session):
         mock_validate_group.return_value = "test-group"
-        mock_finetuning_options.return_value = (Mock(), "model-arn", False)
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False)
         trainer = RLVRTrainer(model="test-model", training_type="CUSTOM", model_package_group_name="test-group")
         assert trainer.training_type == "CUSTOM"
@@ -133,7 +139,9 @@ def test_training_type_string_value(self, mock_finetuning_options, mock_validate
     def test_model_package_input(self, mock_finetuning_options, mock_validate_group, mock_resolve_model, mock_get_session):
         mock_validate_group.return_value = "test-group"
         mock_get_session.return_value = Mock()
-        mock_finetuning_options.return_value = (Mock(), "model-arn", False)
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False)
         model_package = Mock(spec=ModelPackage)
         model_package.inference_specification = Mock()
@@ -148,7 +156,9 @@ def test_model_package_input(self, mock_finetuning_options, mock_validate_group,
     @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn')
     def test_init_with_datasets(self, mock_finetuning_options, mock_validate_group, mock_session):
         mock_validate_group.return_value = "test-group"
-        mock_finetuning_options.return_value = (Mock(), "model-arn", False)
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False)
         trainer = RLVRTrainer(
             model="test-model",
             model_package_group_name="test-group",
@@ -162,7 +172,9 @@ def test_init_with_datasets(self, mock_finetuning_options, mock_validate_group,
     @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn')
     def test_init_with_mlflow_config(self, mock_finetuning_options, mock_validate_group, mock_session):
         mock_validate_group.return_value = "test-group"
-        mock_finetuning_options.return_value = (Mock(), "model-arn", False)
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False)
         trainer = RLVRTrainer(
             model="test-model",
             model_package_group_name="test-group",
@@ -179,7 +191,9 @@ def test_init_with_mlflow_config(self, mock_finetuning_options, mock_validate_gr
     @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn')
     def test_train_without_datasets_raises_error(self, mock_finetuning_options, mock_validate_group, mock_get_session):
         mock_validate_group.return_value = "test-group"
-        mock_finetuning_options.return_value = (Mock(), "model-arn", False)
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False)
         mock_get_session.return_value = Mock()
         trainer = RLVRTrainer(model="test-model", model_package_group_name="test-group")
@@ -194,7 +208,9 @@ def test_model_package_group_handling(self, mock_validate_group, mock_get_option
         mock_validate_group.return_value = "test-group"
         mock_get_session.return_value = Mock()
         mock_resolve_model.return_value = "resolved-model"
-        mock_get_options.return_value = (Mock(), "model-arn", False)
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_get_options.return_value = (mock_hyperparams, "model-arn", False)
         trainer = RLVRTrainer(
             model="test-model",
@@ -206,7 +222,9 @@
     @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn')
     def test_s3_output_path_configuration(self, mock_finetuning_options, mock_validate_group, mock_session):
         mock_validate_group.return_value = "test-group"
-        mock_finetuning_options.return_value = (Mock(), "model-arn", False)
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False)
         trainer = RLVRTrainer(
             model="test-model",
             model_package_group_name="test-group",
@@ -263,7 +281,9 @@ def test_train_with_tags(self, mock_training_job_create, mock_model_package_conf
     def test_gated_model_eula_validation(self, mock_finetuning_options, mock_validate_group, mock_session):
         """Test EULA validation for gated models"""
         mock_validate_group.return_value = "test-group"
-        mock_finetuning_options.return_value = (Mock(), "model-arn", True)  # is_gated_model=True
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", True)  # is_gated_model=True
         # Should raise error when accept_eula=False for gated model
         with pytest.raises(ValueError, match="gated model and requires EULA acceptance"):
@@ -272,3 +292,71 @@ def test_gated_model_eula_validat
         # Should work when accept_eula=True for gated model
         trainer = RLVRTrainer(model="gated-model", model_package_group_name="test-group", accept_eula=True)
         assert trainer.accept_eula == True
+
+    def test_process_hyperparameters_removes_constructor_handled_keys(self):
+        """Test that _process_hyperparameters removes keys handled by constructor inputs."""
+        # Create mock hyperparameters with all possible keys
+        mock_hyperparams = Mock()
+        mock_hyperparams._specs = {
+            'data_s3_path': 'test_data_s3_path',
+            'reward_lambda_arn': 'test_reward_lambda_arn',
+            'data_path': 'test_data_path',
+            'validation_data_path': 'test_validation_data_path',
+            'other_param': 'should_remain'
+        }
+
+        # Add attributes to mock
+        mock_hyperparams.data_s3_path = 'test_data_s3_path'
+        mock_hyperparams.reward_lambda_arn = 'test_reward_lambda_arn'
+        mock_hyperparams.data_path = 'test_data_path'
+        mock_hyperparams.validation_data_path = 'test_validation_data_path'
+
+        # Create trainer instance with mock hyperparameters
+        trainer = RLVRTrainer.__new__(RLVRTrainer)
+        trainer.hyperparameters = mock_hyperparams
+
+        # Call the method
+        trainer._process_hyperparameters()
+
+        # Verify attributes were removed
+        assert not hasattr(mock_hyperparams, 'data_s3_path')
+        assert not hasattr(mock_hyperparams, 'reward_lambda_arn')
+        assert not hasattr(mock_hyperparams, 'data_path')
+        assert not hasattr(mock_hyperparams, 'validation_data_path')
+
+        # Verify _specs were updated
+        assert 'data_s3_path' not in mock_hyperparams._specs
+        assert 'reward_lambda_arn' not in mock_hyperparams._specs
+        assert 'data_path' not in mock_hyperparams._specs
+        assert 'validation_data_path' not in mock_hyperparams._specs
+        assert 'other_param' in mock_hyperparams._specs
+
+    def test_process_hyperparameters_handles_missing_attributes(self):
+        """Test that _process_hyperparameters handles missing attributes gracefully."""
+        # Create mock hyperparameters with only some keys
+        mock_hyperparams = Mock()
+        mock_hyperparams._specs = {
+            'data_s3_path': 'test_data_s3_path',
+            'other_param': 'should_remain'
+        }
+        mock_hyperparams.data_s3_path = 'test_data_s3_path'
+
+        # Create trainer instance
+        trainer = RLVRTrainer.__new__(RLVRTrainer)
+        trainer.hyperparameters = mock_hyperparams
+
+        # Call the method
+        trainer._process_hyperparameters()
+
+        # Verify only existing attributes were processed
+        assert not hasattr(mock_hyperparams, 'data_s3_path')
+        assert 'data_s3_path' not in mock_hyperparams._specs
+        assert 'other_param' in mock_hyperparams._specs
+
+    def test_process_hyperparameters_with_none_hyperparameters(self):
+        """Test that _process_hyperparameters handles None hyperparameters."""
+        trainer = RLVRTrainer.__new__(RLVRTrainer)
+        trainer.hyperparameters = None
+
+        # Should not raise an exception
+        trainer._process_hyperparameters()
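
Reviewer note on the _process_hyperparameters tests above (and the SFT counterparts below): the implementation under test repeats the same hasattr/delattr/_specs.pop triple once per key, while the tests only require that each constructor-owned key disappear from both the attribute namespace and _specs, that other keys survive, and that a None hyperparameters object is tolerated. A table-driven variant would satisfy them equally well. This is a minimal sketch, assuming the hyperparameter object exposes a _specs dict as the mocks do; _CONSTRUCTOR_OWNED_KEYS is a hypothetical name, not part of the PR:

    # Sketch of a possible DRY refactor these tests would still pass.
    class _HyperparameterPruningMixin:
        # Each trainer declares the keys its constructor owns, e.g.
        # RLVRTrainer: ("data_s3_path", "reward_lambda_arn",
        #               "data_path", "validation_data_path")
        # SFTTrainer:  ("data_path", "output_path", "training_data_name",
        #               "validation_data_name", "validation_data_path")
        _CONSTRUCTOR_OWNED_KEYS: tuple = ()

        def _process_hyperparameters(self):
            """Remove hyperparameter keys that are handled by constructor inputs."""
            if not self.hyperparameters:
                return  # tolerate None, per the *_with_none_hyperparameters tests
            for key in self._CONSTRUCTOR_OWNED_KEYS:
                if hasattr(self.hyperparameters, key):
                    delattr(self.hyperparameters, key)
                    self.hyperparameters._specs.pop(key, None)
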
"model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = SFTTrainer(model="test-model", training_type=TrainingType.FULL, model_package_group_name="test-group") assert trainer.training_type == TrainingType.FULL @@ -122,7 +126,9 @@ def test_peft_value_for_full_training(self, mock_training_job_create, mock_model @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') def test_training_type_string_value(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = SFTTrainer(model="test-model", training_type="CUSTOM", model_package_group_name="test-group") assert trainer.training_type == "CUSTOM" @@ -131,7 +137,9 @@ def test_training_type_string_value(self, mock_finetuning_options, mock_validate @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') def test_model_package_input(self, mock_finetuning_options, mock_validate_group, mock_resolve_model): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) model_package = Mock(spec=ModelPackage) model_package.inference_specification = Mock() @@ -146,7 +154,9 @@ def test_model_package_input(self, mock_finetuning_options, mock_validate_group, @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') def test_init_with_datasets(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = SFTTrainer( model="test-model", model_package_group_name="test-group", @@ -160,7 +170,9 @@ def test_init_with_datasets(self, mock_finetuning_options, mock_validate_group, @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') def test_init_with_mlflow_config(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = SFTTrainer( model="test-model", model_package_group_name="test-group", @@ -177,7 +189,9 @@ def test_init_with_mlflow_config(self, mock_finetuning_options, mock_validate_gr @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') def test_fit_without_datasets_raises_error(self, mock_finetuning_options, mock_validate_group, mock_get_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) mock_get_session.return_value = Mock() trainer = 
SFTTrainer(model="test-model", model_package_group_name="test-group") @@ -188,7 +202,9 @@ def test_fit_without_datasets_raises_error(self, mock_finetuning_options, mock_v @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') def test_model_package_group_handling(self, mock_validate_group, mock_get_options): mock_validate_group.return_value = "test-group" - mock_get_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_get_options.return_value = (mock_hyperparams, "model-arn", False) trainer = SFTTrainer( model="test-model", @@ -200,7 +216,9 @@ def test_model_package_group_handling(self, mock_validate_group, mock_get_option @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') def test_s3_output_path_configuration(self, mock_finetuning_options, mock_validate_group, mock_session): mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", False) + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) trainer = SFTTrainer( model="test-model", model_package_group_name="test-group", @@ -213,7 +231,9 @@ def test_s3_output_path_configuration(self, mock_finetuning_options, mock_valida def test_gated_model_eula_validation(self, mock_finetuning_options, mock_validate_group, mock_session): """Test EULA validation for gated models""" mock_validate_group.return_value = "test-group" - mock_finetuning_options.return_value = (Mock(), "model-arn", True) # is_gated_model=True + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", True) # is_gated_model=True # Should raise error when accept_eula=False for gated model with pytest.raises(ValueError, match="gated model and requires EULA acceptance"): @@ -267,3 +287,75 @@ def test_train_with_tags(self, mock_training_job_create, mock_model_package_conf {"key": "sagemaker-studio:jumpstart-model-id", "value": "test-model"}, {"key": "sagemaker-studio:jumpstart-hub-name", "value": "SageMakerPublicHub"} ] + + def test_process_hyperparameters_removes_constructor_handled_keys(self): + """Test that _process_hyperparameters removes keys handled by constructor inputs.""" + # Create mock hyperparameters with all possible keys + mock_hyperparams = Mock() + mock_hyperparams._specs = { + 'data_path': 'test_data_path', + 'output_path': 'test_output_path', + 'training_data_name': 'test_training_data_name', + 'validation_data_name': 'test_validation_data_name', + 'validation_data_path': 'test_validation_data_path', + 'other_param': 'should_remain' + } + + # Add attributes to mock + mock_hyperparams.data_path = 'test_data_path' + mock_hyperparams.output_path = 'test_output_path' + mock_hyperparams.training_data_name = 'test_training_data_name' + mock_hyperparams.validation_data_name = 'test_validation_data_name' + mock_hyperparams.validation_data_path = 'test_validation_data_path' + + # Create trainer instance with mock hyperparameters + trainer = SFTTrainer.__new__(SFTTrainer) + trainer.hyperparameters = mock_hyperparams + + # Call the method + trainer._process_hyperparameters() + + # Verify attributes were removed + assert not hasattr(mock_hyperparams, 'data_path') + assert not hasattr(mock_hyperparams, 'output_path') + assert not hasattr(mock_hyperparams, 'training_data_name') + assert not hasattr(mock_hyperparams, 
'validation_data_name') + assert not hasattr(mock_hyperparams, 'validation_data_path') + + # Verify _specs were updated + assert 'data_path' not in mock_hyperparams._specs + assert 'output_path' not in mock_hyperparams._specs + assert 'training_data_name' not in mock_hyperparams._specs + assert 'validation_data_name' not in mock_hyperparams._specs + assert 'validation_data_path' not in mock_hyperparams._specs + assert 'other_param' in mock_hyperparams._specs + + def test_process_hyperparameters_handles_missing_attributes(self): + """Test that _process_hyperparameters handles missing attributes gracefully.""" + # Create mock hyperparameters with only some keys + mock_hyperparams = Mock() + mock_hyperparams._specs = { + 'data_path': 'test_data_path', + 'other_param': 'should_remain' + } + mock_hyperparams.data_path = 'test_data_path' + + # Create trainer instance + trainer = SFTTrainer.__new__(SFTTrainer) + trainer.hyperparameters = mock_hyperparams + + # Call the method + trainer._process_hyperparameters() + + # Verify only existing attributes were processed + assert not hasattr(mock_hyperparams, 'data_path') + assert 'data_path' not in mock_hyperparams._specs + assert 'other_param' in mock_hyperparams._specs + + def test_process_hyperparameters_with_none_hyperparameters(self): + """Test that _process_hyperparameters handles None hyperparameters.""" + trainer = SFTTrainer.__new__(SFTTrainer) + trainer.hyperparameters = None + + # Should not raise an exception + trainer._process_hyperparameters()
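
One detail worth flagging about the Mock-based assertions throughout these tests: hasattr on a plain Mock is normally always true, because attribute access auto-creates child mocks. The assert not hasattr(...) checks only work because unittest.mock records attribute deletion and raises AttributeError on subsequent access. A side effect is that test_process_hyperparameters_handles_missing_attributes never actually exercises a missing-attribute branch: hasattr(mock_hyperparams, 'output_path') is true even though the test never set that attribute, so the delete path runs for every key regardless. A quick self-contained illustration of the Mock behavior in play:

    from unittest.mock import Mock

    m = Mock()
    assert hasattr(m, "data_path")       # attribute access auto-creates a child mock
    del m.data_path                      # Mock records the deletion...
    assert not hasattr(m, "data_path")   # ...and raises AttributeError afterwards
    assert hasattr(m, "never_touched")   # any name not deleted still "exists"

If the missing-attribute path ever needs real coverage, a spec'd mock (Mock(spec=[...])) or a plain object would make hasattr behave like it does on the production hyperparameter class.
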