@@ -343,29 +343,15 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
343343 recipes_with_template = [r for r in matching_recipes if r .get ("SmtjRecipeTemplateS3Uri" )]
344344
345345 if not recipes_with_template :
346- raise ValueError (f"No recipes found with SmtjRecipeTemplateS3Uri for technique: { customization_technique } " )
347-
348- # If multiple recipes, filter by training_type (peft key)
349- if len (recipes_with_template ) > 1 :
350-
351- if isinstance (training_type , TrainingType ) and training_type == TrainingType .LORA :
352- # Filter recipes that have peft key for LORA
353- lora_recipes = [r for r in recipes_with_template if r .get ("Peft" )]
354- if lora_recipes :
355- recipes_with_template = lora_recipes
356- elif len (recipes_with_template ) > 1 :
357- raise ValueError (f"Multiple recipes found for LORA training but none have peft key" )
358- elif isinstance (training_type , TrainingType ) and training_type == TrainingType .FULL :
359- # For FULL training, if multiple recipes exist, throw error
360- if len (recipes_with_template ) > 1 :
361- raise ValueError (f"Multiple recipes found for FULL training - cannot determine which to use" )
362-
363- # If still multiple recipes after filtering, throw error
364- if len (recipes_with_template ) > 1 :
365- raise ValueError (f"Multiple recipes found after filtering - cannot determine which to use" )
366-
367- recipe = recipes_with_template [0 ]
368-
346+ raise ValueError (f"No recipes found with Smtj for technique: { customization_technique } " )
347+
348+ # Select recipe based on training type
349+ recipe = None
350+ if (isinstance (training_type , TrainingType ) and training_type == TrainingType .LORA ) or training_type == "LORA" :
351+ recipe = next ((r for r in recipes_with_template if r .get ("Peft" )), None )
352+ elif (isinstance (training_type , TrainingType ) and training_type == TrainingType .FULL ) or training_type == "FULL" :
353+ recipe = next ((r for r in recipes_with_template if not r .get ("Peft" )), None )
354+
369355 if recipe and recipe .get ("SmtjOverrideParamsS3Uri" ):
370356 s3_uri = recipe ["SmtjOverrideParamsS3Uri" ]
371357 s3 = boto3 .client ("s3" )
0 commit comments