@@ -50,11 +50,12 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
     learning_rate: float = 1e-4
     lr_scheduler_type: str = "linear"
     warmup_ratio: float = 0.05
-    num_train_epochs: int = 6
+    num_train_epochs: int = 3
+
     per_device_train_batch_size: int = 8
     per_device_eval_batch_size: int = 16
     max_grad_norm: float = 0.5
-    # label_smoothing_factor: float = 0.1
+    label_smoothing_factor: float = 0.1
     weight_decay: float = 0.01
     save_total_limit: int = 2
     fp16: bool = True
@@ -69,6 +70,7 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
 
     # additional custom args
     peft_rank: int = field(default=32, metadata={"help": "LoRA adapter rank (r)."})
+    lora_alpha: int = 64
     hf_hub_repo_id: str | None = None
     run_test: bool = field(
         default=False,
@@ -144,9 +146,11 @@ def main() -> None:
     logger.info(
         f"Base model trainable params:\n{print_trainable_parameters(base_model)}"
     )
-    lora_model = build_peft_model(base_model, training_args.peft_rank)
+    lora_model = build_peft_model(
+        base_model, training_args.peft_rank, training_args.lora_alpha
+    )
     logger.info(
-        f"LoRA model (peft_rank={training_args.peft_rank}) trainable params:\n{print_trainable_parameters(lora_model)}"
+        f"LoRA model (peft_rank={training_args.peft_rank}, lora_alpha={training_args.lora_alpha}) trainable params:\n{print_trainable_parameters(lora_model)}"
     )

     # from torch.utils.data import DataLoader
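
The diff changes the `build_peft_model` call to pass the new `lora_alpha` argument alongside `peft_rank`, but the function body is not part of this commit. Below is a minimal sketch, assuming the helper wraps the base model with PEFT's `LoraConfig` and `get_peft_model`; the `lora_dropout` value and `target_modules` list are illustrative assumptions, not taken from the repository.

```python
from peft import LoraConfig, TaskType, get_peft_model


def build_peft_model(base_model, peft_rank: int, lora_alpha: int):
    """Wrap a seq2seq base model with LoRA adapters.

    The adapter update is scaled by lora_alpha / peft_rank, so the new
    defaults (r=32, alpha=64) apply a scaling factor of 2.
    """
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=peft_rank,
        lora_alpha=lora_alpha,
        lora_dropout=0.05,          # assumed value, not from the diff
        target_modules=["q", "v"],  # assumed T5-style attention projections
    )
    return get_peft_model(base_model, lora_config)
```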