 import random
 import numpy as np
 import torch
-from contextlib import nullcontext
 from datasets import load_dataset
 from dataclasses import dataclass, field
 from transformers import (
     HfArgumentParser,
     Seq2SeqTrainingArguments,
     DataCollatorForSeq2Seq,
     Seq2SeqTrainer,
+    EarlyStoppingCallback,
 )
 from typing import List
 from pathlib import Path
 from korea_travel_guide.utils import load_environ_vars, print_trainable_parameters
 from korea_travel_guide.model import build_base_model, build_peft_model
 from korea_travel_guide.data import tokenize_and_format
-from korea_travel_guide.evaluation import build_compute_metrics
 from uuid import uuid4

 logger = logging.getLogger(__name__)
@@ -38,6 +37,7 @@ class DataArgs:
         default=False, metadata={"help": "If True, ignore CSVs and load SQuAD instead."}
     )

+
 # training & LoRA extras — extend HF’s own Seq2SeqTrainingArguments
 @dataclass
 class CustomTrainingArgs(Seq2SeqTrainingArguments):
@@ -46,22 +46,25 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
         default="outputs/bart-base-korea-travel-guide-lora",
         metadata={"help": "Prefix folder for all checkpoints/run logs."},
     )
-    eval_strategy: str = "epoch"
-    save_strategy: str = "epoch"
-    logging_steps: int = 50
-    learning_rate: float = 1e-4
-    lr_scheduler_type: str = "linear"
-    warmup_ratio: float = 0.05
     num_train_epochs: int = 6
     per_device_train_batch_size: int = 8
     per_device_eval_batch_size: int = 16
+    learning_rate: float = 7e-5
+    lr_scheduler_type: str = "cosine"
+    warmup_ratio: float = 0.1
     max_grad_norm: float = 0.5
     label_smoothing_factor: float = 0.1
-    # weight_decay: float = 0.01
-    generation_max_length: int = 384
+    weight_decay: float = 0.01
+
+    eval_strategy: str = "epoch"
+    save_strategy: str = "epoch"
+    logging_steps: int = 50
     save_total_limit: int = 2
+    load_best_model_at_end: bool = True
+    metric_for_best_model: str = "eval_loss"
+    greater_is_better: bool = False
+
     fp16: bool = True
-    predict_with_generate: bool = True
     push_to_hub: bool = False
     report_to: str = "wandb"
     run_name: str = field(
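Note on the new checkpoint-selection fields: `metric_for_best_model` must name a key in the metrics dict returned by `Trainer.evaluate()` (the validation loss appears there as `eval_loss`, not as the W&B panel name `eval/loss`); with `greater_is_better=False` the trainer ranks checkpoints by the lowest value, and `load_best_model_at_end=True` reloads that checkpoint after training. A minimal sketch, assuming a `trainer` built from these arguments:

    # sketch only, not part of the commit
    metrics = trainer.evaluate()     # e.g. {"eval_loss": ..., "eval_runtime": ..., ...}
    print(metrics["eval_loss"])      # the value used to rank checkpoints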
@@ -92,10 +95,6 @@ def parse_args() -> tuple[DataArgs, CustomTrainingArgs]:
     if training_args.push_to_hub and not training_args.hf_hub_repo_id:
         parser.error("--hf_hub_repo_id is required when --push_to_hub is set")

-    # # isolate each run’s artefacts (good for sweeps)
-    # run_id = os.environ.get("WANDB_RUN_ID", uuid4().hex[:8])
-    # training_args.output_dir = f"{training_args.output_dir}/{run_id}"
-
     # set wandb for logging
     training_args.report_to = "wandb"

@@ -130,14 +129,13 @@ def main() -> None:
     # ---------- Data Preprocessing ----------
     # load either CSVs or SQuAD for a quick pipeline sanity check
     if data_args.use_squad:
-        # 1) pull down SQuAD
+        # 1) pull down SQuAD
         raw = load_dataset("squad")
+
         # 2) map to simple Q/A pairs (first answer only)
         def to_qa(ex):
-            return {
-                "question": ex["question"],
-                "answer": ex["answers"]["text"][0]
-            }
+            return {"question": ex["question"], "answer": ex["answers"]["text"][0]}
+
         ds = raw.map(to_qa, remove_columns=raw["train"].column_names)
     else:
         # load from your processed CSVs
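For reference, SQuAD stores answers as parallel lists inside an `answers` dict, which is why `to_qa` takes the first `text` entry. A rough sketch with a made-up record in that schema (illustrative values only):

    ex = {
        "question": "When did the festival start?",
        "answers": {"text": ["in 1998"], "answer_start": [42]},
    }
    to_qa(ex)  # -> {"question": "When did the festival start?", "answer": "in 1998"}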
@@ -198,7 +196,7 @@ def to_qa(ex):
         eval_dataset=ds_tok["validation"],
         tokenizer=tok,
         data_collator=data_collator,
-        # compute_metrics=build_compute_metrics(tok),
+        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
     )

     trainer.train()
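Usage note on the new callback: `EarlyStoppingCallback` counts evaluations (one per epoch here) without improvement in `metric_for_best_model` and stops training once the patience is exhausted; it relies on `load_best_model_at_end=True` and a configured best-model metric, both set above. A minimal sketch of an equivalent construction:

    # sketch only: same behaviour as the callback added above
    early_stop = EarlyStoppingCallback(
        early_stopping_patience=2,     # stop after 2 consecutive evals with no improvement
        early_stopping_threshold=0.0,  # any decrease in eval_loss resets the counter
    )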