
Commit f0d433a

increased lora rank more
1 parent: 7fa925f

File tree

3 files changed: +10 additions, −6 deletions

src/korea_travel_guide/inference.py (1 addition, 1 deletion)

```diff
@@ -117,7 +117,7 @@ def main():
     # fast batched generate
     out = model.generate(
         **enc,
-        max_length=1024,
+        max_length=512,
         num_beams=5,
         early_stopping=True,
         length_penalty=1.0,
```
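For context, a minimal sketch of how this batched beam-search call might be driven end to end. The prompts and the `tokenizer` and `model` variables below are assumptions for illustration, not taken from the repo:

```python
# Hypothetical driver for the batched generate call above.
# `model` and `tokenizer` stand in for whatever inference.py loads.
prompts = ["Best street food in Busan?", "Day trip ideas from Seoul?"]
enc = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

out = model.generate(
    **enc,
    max_length=512,  # halved from 1024 in this commit
    num_beams=5,
    early_stopping=True,
    length_penalty=1.0,
)
answers = tokenizer.batch_decode(out, skip_special_tokens=True)
```

Lowering max_length caps the worst-case number of decoding steps, which matters here because beam search scores num_beams=5 candidates at every step.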

src/korea_travel_guide/model.py (1 addition, 1 deletion)

```diff
@@ -16,7 +16,7 @@ def build_peft_model(
     lora_alpha: int = 16,
     lora_dropout: float = 0.1,
     bias: str = "none",
-    target_modules: list[str] = ("q_proj", "k_proj", "v_proj", "o_proj"),
+    target_modules: list[str] = ("q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"),
     modules_to_save: list[str] = ("lm_head",),
 ) -> PeftModel:
     config = LoraConfig(
```
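The hunk cuts off at the first line of the body, so here is a plausible completion of build_peft_model using the standard peft API. Everything below `config = LoraConfig(` is an assumption, and the parameters above `lora_alpha` are guessed from the call site in train.py:

```python
from peft import LoraConfig, PeftModel, get_peft_model

def build_peft_model(
    base_model,
    r: int = 32,
    lora_alpha: int = 16,
    lora_dropout: float = 0.1,
    bias: str = "none",
    # fc1/fc2 extend LoRA beyond the attention projections into the
    # feed-forward layers, which is what this commit adds
    target_modules: list[str] = ("q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"),
    modules_to_save: list[str] = ("lm_head",),
) -> PeftModel:
    config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias=bias,
        target_modules=list(target_modules),
        modules_to_save=list(modules_to_save),
    )
    return get_peft_model(base_model, config)
```

Targeting fc1/fc2 roughly doubles the number of adapted weight matrices per block, which is the "more" capacity the commit message refers to.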

src/korea_travel_guide/train.py (8 additions, 4 deletions)

```diff
@@ -50,11 +50,12 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
     learning_rate: float = 1e-4
     lr_scheduler_type: str = "linear"
     warmup_ratio: float = 0.05
-    num_train_epochs: int = 6
+    num_train_epochs: int = 3
+
     per_device_train_batch_size: int = 8
     per_device_eval_batch_size: int = 16
     max_grad_norm: float = 0.5
-    # label_smoothing_factor: float = 0.1
+    label_smoothing_factor: float = 0.1
     weight_decay: float = 0.01
     save_total_limit: int = 2
     fp16: bool = True
@@ -69,6 +70,7 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
 
     # additional custom args
     peft_rank: int = field(default=32, metadata={"help": "LoRA adapter rank (r)."})
+    lora_alpha: int = 64
     hf_hub_repo_id: str | None = None
     run_test: bool = field(
         default=False,
@@ -144,9 +146,11 @@ def main() -> None:
     logger.info(
         f"Base model trainable params:\n{print_trainable_parameters(base_model)}"
     )
-    lora_model = build_peft_model(base_model, training_args.peft_rank)
+    lora_model = build_peft_model(
+        base_model, training_args.peft_rank, training_args.lora_alpha
+    )
     logger.info(
-        f"LoRA model (peft_rank={training_args.peft_rank}) trainable params:\n{print_trainable_parameters(lora_model)}"
+        f"LoRA model (peft_rank={training_args.peft_rank}, lora_alpha={training_args.lora_alpha}) trainable params:\n{print_trainable_parameters(lora_model)}"
     )
 
     # from torch.utils.data import DataLoader
```
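Two things worth noting about the new knobs. First, peft scales every LoRA update by lora_alpha / r (with the default use_rslora=False), so with peft_rank=32 and the new lora_alpha=64 the adapters are weighted at 2.0, versus the 0.5 that the old model.py default of 16 would give. A quick sanity check of that arithmetic:

```python
# peft applies the LoRA delta as (lora_alpha / r) * B @ A.
r, lora_alpha = 32, 64
scaling = lora_alpha / r
print(scaling)  # 2.0 -- the adapter update now counts double
```

Second, declaring lora_alpha as a plain annotated field on the CustomTrainingArgs dataclass is enough for HfArgumentParser to expose it as a --lora_alpha CLI flag, the same way the metadata-carrying peft_rank field above it is exposed.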
