
Commit a8cc765

removed compute_metrics from training to improve efficiency; added ~1,000 more dataset samples; improved inference output quality
1 parent: 887697b

File tree

4 files changed: +55, -39 lines

data/subreddit_size_map.json

Lines changed: 3 additions & 1 deletion
@@ -4,5 +4,7 @@
     "AskSocialScience": 500,
     "AskDocs": 500,
     "askscience": 1500,
-    "AskHistorians": 1500
+    "AskHistorians": 1500,
+    "AskBiology": 414,
+    "AskEconomics": 532
 }
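
For context, a minimal sketch of how a size map like this could drive per-subreddit sampling; the loader below is hypothetical (the repo's actual data code lives in src/korea_travel_guide/data.py and may differ):

import json
import random

def sample_per_subreddit(records, size_map_path="data/subreddit_size_map.json", seed=42):
    # Keep at most size_map[subreddit] records per subreddit.
    with open(size_map_path) as f:
        size_map = json.load(f)
    rng = random.Random(seed)
    kept = []
    for name, cap in size_map.items():
        pool = [r for r in records if r["subreddit"] == name]
        kept.extend(rng.sample(pool, min(cap, len(pool))))
    return kept

The new AskBiology (414) and AskEconomics (532) entries account for the ~1,000 extra samples the commit message mentions.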

src/korea_travel_guide/evaluation.py

Lines changed: 16 additions & 9 deletions
@@ -3,11 +3,11 @@
 from transformers import EvalPrediction


-def build_compute_metrics(tok):
+def build_compute_metrics(tok, num_process_workers: int = 2):
     """Return a closure that Hugging Face's Trainer can call."""
     rouge = evaluate.load("rouge")  # longest-common-subsequence overlap
     bleu = evaluate.load("bleu")  # n-gram precision
-    # bertscore = evaluate.load("bertscore")  # semantic similarity
+    bertscore = evaluate.load("bertscore")  # semantic similarity

     def _compute_metrics(eval_pred: EvalPrediction):
         preds, labels = eval_pred.predictions, eval_pred.label_ids
@@ -23,24 +23,31 @@ def _compute_metrics(eval_pred: EvalPrediction):

         # metrics
         rouge_l = rouge.compute(
-            predictions=decoded_preds, references=decoded_labels, use_stemmer=True
+            predictions=decoded_preds,
+            references=decoded_labels,
+            use_stemmer=True,
+            num_process_workers=num_process_workers,
         )["rougeL"]
         bleu_score = bleu.compute(
             predictions=decoded_preds,
             references=[[ref] for ref in decoded_labels],  # BLEU expects list-of-lists
             smooth=True,
+            num_process_workers=num_process_workers,
         )["bleu"]
-        # bert_f1 = np.mean(
-        #     bertscore.compute(
-        #         predictions=decoded_preds, references=decoded_labels, lang="en"
-        #     )["f1"]
-        # )
+        bert_f1 = np.mean(
+            bertscore.compute(
+                predictions=decoded_preds,
+                references=decoded_labels,
+                lang="en",
+                num_process_workers=num_process_workers,
+            )["f1"]
+        )

         # round for nice logging
         return {
             "rougeL": round(rouge_l * 100, 4),
             "bleu": round(bleu_score * 100, 4),
-            # "bertscore_f1": round(bert_f1, 4),
+            "bertscore_f1": round(bert_f1, 4),
         }

     return _compute_metrics
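
With BERTScore re-enabled, the closure now returns all three metrics. A minimal usage sketch, mirroring the inference.py wiring below; the bart-base checkpoint and the tokenized dataset ds_tok are placeholders for the project's own objects:

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
from korea_travel_guide.evaluation import build_compute_metrics

tok = AutoTokenizer.from_pretrained("facebook/bart-base")  # placeholder checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

trainer = Seq2SeqTrainer(
    model=model,
    args=Seq2SeqTrainingArguments(
        output_dir="outputs/eval-only",
        per_device_eval_batch_size=16,
        predict_with_generate=True,  # decode to text so ROUGE/BLEU/BERTScore apply
        report_to=[],
    ),
    eval_dataset=ds_tok["test"],  # assumes an already-tokenized DatasetDict
    tokenizer=tok,
    compute_metrics=build_compute_metrics(tok, num_process_workers=2),
)
print(trainer.evaluate())

Note that loading "bertscore" pulls down a contextual-embedding model on first use, so this evaluation is noticeably slower than ROUGE/BLEU alone, which is presumably why it now runs only at inference time rather than during training.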

src/korea_travel_guide/inference.py

Lines changed: 17 additions & 8 deletions
@@ -1,6 +1,5 @@
 import logging
 import torch
-from contextlib import nullcontext
 from dataclasses import dataclass, field
 from datasets import load_dataset
 from transformers import (
@@ -34,6 +33,10 @@ class InferenceArgs:
         default_factory=list,
         metadata={"help": "One or more input texts for `predict` mode."},
     )
+    num_process_workers: int = field(
+        default=2,
+        metadata={"help": "Number of workers to parallelize n-gram counting."},
+    )
     use_sdpa_attention: bool = field(
         default=True, metadata={"help": "Enable SDPA for mem-efficient kernel."}
     )
@@ -82,12 +85,13 @@ def main():
             output_dir="outputs/inference",
             per_device_eval_batch_size=inf_args.batch_size,
             predict_with_generate=True,
+            generation_max_length=384,
             report_to=[],
         ),
         eval_dataset=ds_tok["test"],
         data_collator=data_collator,
         tokenizer=tok,
-        compute_metrics=build_compute_metrics(tok),
+        compute_metrics=build_compute_metrics(tok, inf_args.num_process_workers),
     )

     pred_output = trainer.predict(ds_tok["test"])
@@ -106,14 +110,19 @@ def main():
     )
     enc = {k: v.to(device) for k, v in enc.items()}

-    # fast batched generate
+    # fast batched generate (with arguments for higher-quality output)
     out = model.generate(
         **enc,
-        max_length=512,
-        num_beams=5,
-        early_stopping=True,
-        length_penalty=1.0,
-        repetition_penalty=1.1,
+        max_length=200,
+        num_beams=5,              # beam search improves quality
+        do_sample=True,           # add stochasticity
+        length_penalty=1.2,       # >1 favors longer answers
+        repetition_penalty=1.3,   # >1 penalizes reuse of the same token
+        no_repeat_ngram_size=3,   # block exact n-gram repeats
+        top_p=0.9,                # nucleus sampling for diversity
+        temperature=0.8,          # sampling temperature; <1 sharpens the distribution
+        early_stopping=True,      # end beam search as soon as enough beams finish
+        eos_token_id=tok.eos_token_id,
     )

     decoded = tok.batch_decode(out, skip_special_tokens=True)
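
One caveat worth flagging: in Hugging Face generate, passing do_sample=True together with num_beams=5 selects beam-search multinomial sampling, so top_p and temperature now shape each beam's token draws rather than plain nucleus sampling. A standalone sketch of the new settings, assuming a stock BART checkpoint in place of the fine-tuned model:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/bart-base")  # placeholder checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base").eval()

enc = tok(["What should I pack for Seoul in January?"], return_tensors="pt", padding=True)
with torch.no_grad():
    out = model.generate(
        **enc,
        max_length=200,
        num_beams=5,
        do_sample=True,         # beam-sample mode: each beam samples instead of taking argmax
        length_penalty=1.2,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3,
        top_p=0.9,
        temperature=0.8,
        early_stopping=True,
    )
print(tok.batch_decode(out, skip_special_tokens=True)[0])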

src/korea_travel_guide/train.py

Lines changed: 19 additions & 21 deletions
@@ -2,21 +2,20 @@
 import random
 import numpy as np
 import torch
-from contextlib import nullcontext
 from datasets import load_dataset
 from dataclasses import dataclass, field
 from transformers import (
     HfArgumentParser,
     Seq2SeqTrainingArguments,
     DataCollatorForSeq2Seq,
     Seq2SeqTrainer,
+    EarlyStoppingCallback,
 )
 from typing import List
 from pathlib import Path
 from korea_travel_guide.utils import load_environ_vars, print_trainable_parameters
 from korea_travel_guide.model import build_base_model, build_peft_model
 from korea_travel_guide.data import tokenize_and_format
-from korea_travel_guide.evaluation import build_compute_metrics
 from uuid import uuid4

 logger = logging.getLogger(__name__)
@@ -38,6 +37,7 @@ class DataArgs:
         default=False, metadata={"help": "If True, ignore CSVs and load SQuAD instead."}
     )

+
 # training & LoRA extras — extend HF’s own Seq2SeqTrainingArguments
 @dataclass
 class CustomTrainingArgs(Seq2SeqTrainingArguments):
@@ -46,22 +46,25 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
         default="outputs/bart-base-korea-travel-guide-lora",
         metadata={"help": "Prefix folder for all checkpoints/run logs."},
     )
-    eval_strategy: str = "epoch"
-    save_strategy: str = "epoch"
-    logging_steps: int = 50
-    learning_rate: float = 1e-4
-    lr_scheduler_type: str = "linear"
-    warmup_ratio: float = 0.05
     num_train_epochs: int = 6
     per_device_train_batch_size: int = 8
     per_device_eval_batch_size: int = 16
+    learning_rate: float = 7e-5
+    lr_scheduler_type: str = "cosine"
+    warmup_ratio: float = 0.1
     max_grad_norm: float = 0.5
     label_smoothing_factor: float = 0.1
-    # weight_decay: float = 0.01
-    generation_max_length: int = 384
+    weight_decay: float = 0.01
+
+    eval_strategy: str = "epoch"
+    save_strategy: str = "epoch"
+    logging_steps: int = 50
     save_total_limit: int = 2
+    load_best_model_at_end: bool = True
+    metric_for_best_model: str = "eval_loss"
+    greater_is_better: bool = False
+
     fp16: bool = True
-    predict_with_generate: bool = True
     push_to_hub: bool = False
     report_to: str = "wandb"
     run_name: str = field(
@@ -92,10 +95,6 @@ def parse_args() -> tuple[DataArgs, CustomTrainingArgs]:
     if training_args.push_to_hub and not training_args.hf_hub_repo_id:
         parser.error("--hf_hub_repo_id is required when --push_to_hub is set")

-    # # isolate each run’s artefacts (good for sweeps)
-    # run_id = os.environ.get("WANDB_RUN_ID", uuid4().hex[:8])
-    # training_args.output_dir = f"{training_args.output_dir}/{run_id}"
-
     # set wandb for logging
     training_args.report_to = "wandb"

@@ -130,14 +129,13 @@ def main() -> None:
     # ---------- Data Preprocessing ----------
     # load either CSVs or SQuAD for a quick pipeline sanity check
     if data_args.use_squad:
-        # 1) pull down SQuAD
+        # 1) pull down SQuAD
         raw = load_dataset("squad")
+
         # 2) map to simple Q/A pairs (first answer only)
         def to_qa(ex):
-            return {
-                "question": ex["question"],
-                "answer": ex["answers"]["text"][0]
-            }
+            return {"question": ex["question"], "answer": ex["answers"]["text"][0]}
+
         ds = raw.map(to_qa, remove_columns=raw["train"].column_names)
     else:
         # load from your processed CSVs
@@ -198,7 +196,7 @@ def to_qa(ex):
         eval_dataset=ds_tok["validation"],
         tokenizer=tok,
         data_collator=data_collator,
-        # compute_metrics=build_compute_metrics(tok),
+        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
     )

     trainer.train()
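
EarlyStoppingCallback only works when the Trainer tracks a best model, which is why load_best_model_at_end, metric_for_best_model, and matching epoch-level eval/save strategies land in the same change. A minimal sketch of just that wiring, with the model, dataset, and collator objects as placeholders for the project's own:

from transformers import EarlyStoppingCallback, Seq2SeqTrainer, Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir="outputs/demo",
    num_train_epochs=6,
    eval_strategy="epoch",        # metrics must be produced for the callback to inspect
    save_strategy="epoch",        # must match eval_strategy when load_best_model_at_end=True
    load_best_model_at_end=True,  # required by EarlyStoppingCallback
    metric_for_best_model="eval_loss",
    greater_is_better=False,      # lower eval loss is better
)
trainer = Seq2SeqTrainer(
    model=model,                  # placeholder: the project's PEFT-wrapped BART
    args=args,
    train_dataset=ds_tok["train"],
    eval_dataset=ds_tok["validation"],
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],  # stop after 2 non-improving evals
)
trainer.train()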
