@@ -46,16 +46,16 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
     )
     eval_strategy: str = "epoch"
     save_strategy: str = "epoch"
-    logging_steps: int = 5
+    logging_steps: int = 50
     learning_rate: float = 1e-4
     lr_scheduler_type: str = "linear"
     warmup_ratio: float = 0.05
-    num_train_epochs: int = 3
-    per_device_train_batch_size: int = 8
-    per_device_eval_batch_size: int = 16
+    num_train_epochs: int = 8
+    per_device_train_batch_size: int = 16
+    per_device_eval_batch_size: int = 32
     max_grad_norm: float = 0.5
     label_smoothing_factor: float = 0.1
-    weight_decay: float = 0.01
+    # weight_decay: float = 0.01
     generation_max_length: int = 384
     save_total_limit: int = 2
     fp16: bool = True
@@ -76,8 +76,8 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
         default=False,
         metadata={"help": "If True, run the test-split evaluation after training."},
     )
-    use_flash_attention: bool = field(
-        default=True, metadata={"help": "Whether to enable Flash Attention v1."}
+    use_sdpa_attention: bool = field(
+        default=True, metadata={"help": "Enable SDPA for the memory-efficient attention kernel."}
     )


@@ -143,6 +143,8 @@ def main() -> None:

     # initialize base model and LoRA
     base_model = build_base_model()
+    if training_args.use_sdpa_attention:
+        base_model.config.attn_implementation = "sdpa"
     logger.info(
         f"Base model trainable params:\n{print_trainable_parameters(base_model)}"
     )
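
Note: the hunk above selects SDPA by mutating `base_model.config.attn_implementation` after the model is built. In recent `transformers` releases the same backend can also be requested at load time via the `attn_implementation` argument of `from_pretrained`. A minimal sketch, assuming `build_base_model()` wraps `AutoModelForSeq2SeqLM.from_pretrained`; the helper signature and model name below are illustrative, not taken from this repo:

```python
# Sketch only: helper signature and model name are illustrative assumptions.
from transformers import AutoModelForSeq2SeqLM

def build_base_model(model_name: str = "google/flan-t5-base"):
    # Request the PyTorch SDPA backend when the weights are loaded,
    # instead of flipping the config flag afterwards.
    return AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        attn_implementation="sdpa",
    )
```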
@@ -153,42 +155,7 @@ def main() -> None:
         f"LoRA model (peft_rank={training_args.peft_rank}, lora_alpha={training_args.lora_alpha}) trainable params:\n{print_trainable_parameters(lora_model)}"
     )

-    # from torch.utils.data import DataLoader
-
-    # data_collator = DataCollatorForSeq2Seq(
-    #     tok,
-    #     model=lora_model,
-    #     padding="longest",
-    #     label_pad_token_id=-100,
-    # )
-
-    # batch = next(iter(DataLoader(ds_tok["train"], batch_size=2, collate_fn=data_collator)))
-    # # 1) decode inputs normally
-    # print("INPUTS:")
-    # print(tok.batch_decode(batch["input_ids"], skip_special_tokens=True))
-
-    # # 2) map -100 → pad_token_id before decoding labels
-    # labels = batch["labels"].detach().cpu().numpy()
-    # labels = np.where(labels != -100, labels, tok.pad_token_id)
-
-    # print("LABELS:")
-    # print(tok.batch_decode(labels, skip_special_tokens=True))
-
-    # import sys
-    # sys.exit()
-
     # ---------- Train ----------
-    # toggle flash attention
-    if training_args.use_flash_attention:
-        logger.info("Using flash attention")
-        ctx = torch.backends.cuda.sdp_kernel(
-            enable_flash=True, enable_math=True, enable_mem_efficient=True
-        )
-    else:
-        logger.info("Skipping flash attention")
-        ctx = nullcontext()
-    # ctx = nullcontext()
-
     # data collator: dynamic padding per batch
     data_collator = DataCollatorForSeq2Seq(
         tok,
@@ -209,15 +176,13 @@ def main() -> None:
         compute_metrics=build_compute_metrics(tok),
     )

-    with ctx:
-        trainer.train()
+    trainer.train()

     # ---------- Save, Test or Push ----------
     # evaluate test
     if training_args.run_test:
         logger.info("Running final test-set evaluation...")
-        with ctx:
-            metrics = trainer.evaluate(ds_tok["test"])
+        metrics = trainer.evaluate(ds_tok["test"])
         logger.info(f"Test metrics:\n{metrics}")
     else:
         logger.info("Skipping test evaluation.")
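
The `torch.backends.cuda.sdp_kernel(...)` context manager deleted above is deprecated in recent PyTorch releases. If per-call backend selection is ever needed again, the replacement is `torch.nn.attention.sdpa_kernel`. A minimal sketch reusing the script's `trainer` object, shown only for reference and not part of this change:

```python
# Sketch only: the newer PyTorch (>= 2.3) API for constraining SDPA backends;
# torch.backends.cuda.sdp_kernel is deprecated in its favor.
from torch.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    trainer.train()
```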