Commit 7fa925f

Author: Ubuntu (committed)

Increased LoRA capacity, more epochs, and increased max length for generated output
1 parent 300f419 commit 7fa925f

File tree

3 files changed: +10 −8 lines changed


src/korea_travel_guide/inference.py

Lines changed: 4 additions & 2 deletions
@@ -117,9 +117,11 @@ def main():
     # fast batched generate
     out = model.generate(
         **enc,
-        max_length=128,
-        num_beams=4,
+        max_length=1024,
+        num_beams=5,
         early_stopping=True,
+        length_penalty=1.0,
+        repetition_penalty=1.1,
     )
 
     decoded = tok.batch_decode(out, skip_special_tokens=True)
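For context, the updated call allows outputs up to 1024 tokens with a slightly wider beam and a mild repetition penalty. Below is a minimal, self-contained sketch of the same generation settings; the flan-t5-base checkpoint and the prompt are placeholders for illustration, not taken from this repo:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Placeholder checkpoint; any seq2seq model accepts these generate() kwargs.
tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

enc = tok(["Suggest a two-day itinerary for Seoul."], return_tensors="pt", padding=True)
out = model.generate(
    **enc,
    max_length=1024,         # was 128: room for much longer guides
    num_beams=5,             # was 4: slightly wider beam search
    early_stopping=True,     # stop once all beams have finished
    length_penalty=1.0,      # neutral length preference
    repetition_penalty=1.1,  # mildly discourage repeated tokens
)
print(tok.batch_decode(out, skip_special_tokens=True)[0])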

src/korea_travel_guide/model.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ def build_peft_model(
     lora_alpha: int = 16,
     lora_dropout: float = 0.1,
     bias: str = "none",
-    target_modules: list[str] = ("q_proj", "v_proj"),
+    target_modules: list[str] = ("q_proj", "k_proj", "v_proj", "o_proj"),
     modules_to_save: list[str] = ("lm_head",),
 ) -> PeftModel:
     config = LoraConfig(
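Adding k_proj and o_proj means LoRA now adapts all four attention projections instead of only query and value. A minimal sketch of the resulting adapter config, assuming the defaults visible above plus the new peft_rank=32 from train.py:

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=32,              # new peft_rank default from train.py
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # all four attention projections
    modules_to_save=["lm_head"],  # lm_head is trained fully, not via LoRA
)
# peft_model = get_peft_model(base_model, config)  # the base model must expose these module names

Note that with lora_alpha fixed at 16, raising r from 8 to 32 lowers the standard LoRA scaling factor alpha/r from 2.0 to 0.5, so each adapter update contributes less per step; the learning-rate increase in train.py may be compensating for this.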

src/korea_travel_guide/train.py

Lines changed: 5 additions & 5 deletions
@@ -47,12 +47,12 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
     eval_strategy: str = "epoch"
     save_strategy: str = "epoch"
     logging_steps: int = 5
-    learning_rate: float = 3e-5
+    learning_rate: float = 1e-4
     lr_scheduler_type: str = "linear"
     warmup_ratio: float = 0.05
-    num_train_epochs: int = 5
-    per_device_train_batch_size: int = 16
-    per_device_eval_batch_size: int = 32
+    num_train_epochs: int = 6
+    per_device_train_batch_size: int = 8
+    per_device_eval_batch_size: int = 16
     max_grad_norm: float = 0.5
     # label_smoothing_factor: float = 0.1
     weight_decay: float = 0.01
@@ -68,7 +68,7 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
     label_names: List[str] = field(default_factory=lambda: ["labels"])
 
     # additional custom args
-    peft_rank: int = field(default=8, metadata={"help": "LoRA adapter rank (r)."})
+    peft_rank: int = field(default=32, metadata={"help": "LoRA adapter rank (r)."})
     hf_hub_repo_id: str | None = None
     run_test: bool = field(
         default=False,
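Taken together, the commit trades per-step batch size for a higher learning rate, one more epoch, and a much larger adapter (r 8→32 with twice as many target modules). An abridged sketch of the updated dataclass, keeping only the fields this diff touches; everything else, including the output_dir value, is a placeholder:

from dataclasses import dataclass, field
from transformers import Seq2SeqTrainingArguments

@dataclass
class CustomTrainingArgs(Seq2SeqTrainingArguments):
    learning_rate: float = 1e-4            # up from 3e-5
    num_train_epochs: int = 6              # up from 5
    per_device_train_batch_size: int = 8   # halved from 16
    per_device_eval_batch_size: int = 16   # halved from 32
    peft_rank: int = field(default=32, metadata={"help": "LoRA adapter rank (r)."})  # up from 8

args = CustomTrainingArgs(output_dir="outputs")  # output_dir value is a placeholder
print(args.learning_rate, args.num_train_epochs, args.peft_rank)

Halving both batch sizes is consistent with the larger adapter and the 8x longer generation budget, both of which raise peak memory use.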
