@@ -34,7 +34,9 @@ class DataArgs:
3434 train_sample : bool = field (
3535 default = False , metadata = {"help" : "Use the mini CSV for smoke tests if True." }
3636 )
37-
37+ use_squad : bool = field (
38+ default = False , metadata = {"help" : "If True, ignore CSVs and load SQuAD instead." }
39+ )
3840
3941# training & LoRA extras — extend HF’s own Seq2SeqTrainingArguments
4042@dataclass
@@ -126,19 +128,42 @@ def main() -> None:
126128 logger .info (f"Set seed: { training_args .seed } " )
127129
128130 # ---------- Data Preprocessing ----------
129- # load and tokenize dataset
130- # load CSVs
131- data_files = {
132- "train" : str (
133- data_args .train_sample_file
134- if data_args .train_sample
135- else data_args .train_file
136- ),
137- "validation" : str (data_args .validation_file ),
138- "test" : str (data_args .test_file ),
139- }
140-
141- ds = load_dataset ("csv" , data_files = data_files )
131+ # load either CSVs or SQuAD for a quick pipeline sanity check
132+ if data_args .use_squad :
133+ # 1) pull down SQuAD
134+ raw = load_dataset ("squad" )
135+ # 2) map to simple Q/A pairs (first answer only)
136+ def to_qa (ex ):
137+ return {
138+ "question" : ex ["question" ],
139+ "answer" : ex ["answers" ]["text" ][0 ]
140+ }
141+ ds = raw .map (to_qa , remove_columns = raw ["train" ].column_names )
142+ else :
143+ # load from your processed CSVs
144+ data_files = {
145+ "train" : str (
146+ data_args .train_sample_file
147+ if data_args .train_sample
148+ else data_args .train_file
149+ ),
150+ "validation" : str (data_args .validation_file ),
151+ "test" : str (data_args .test_file ),
152+ }
153+ ds = load_dataset ("csv" , data_files = data_files )
154+ # # load and tokenize dataset
155+ # # load CSVs
156+ # data_files = {
157+ # "train": str(
158+ # data_args.train_sample_file
159+ # if data_args.train_sample
160+ # else data_args.train_file
161+ # ),
162+ # "validation": str(data_args.validation_file),
163+ # "test": str(data_args.test_file),
164+ # }
165+
166+ # ds = load_dataset("csv", data_files=data_files)
142167 ds_tok , tok = tokenize_and_format (ds )
143168
144169 # initialize base model and LoRA
0 commit comments