22 changes: 22 additions & 0 deletions F2LLM/README.md
@@ -42,6 +42,28 @@ where N_NODE is the number of machines; N_PROCESSES is N_NODE\*8; MASTER_IP is t

On worker nodes, also run the above command but modify `machine_rank` accordingly.

### LoRA Training

This repository now supports Parameter-Efficient Fine-Tuning (PEFT) with LoRA (Low-Rank Adaptation): the base model is frozen and only small low-rank adapter matrices are trained, which substantially reduces the number of trainable parameters and the optimizer/gradient memory required during training.

To use LoRA training:

1. Add LoRA parameters to your config file (see `configs/config_lora.json` for an example):
```json
{
"use_lora": true,
"lora_r": 8,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_modules": "q_proj,v_proj"
}
```

2. Run training with the LoRA config:
```bash
accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config_lora.json
```
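
LoRA training only updates the adapter weights; for standalone inference, the adapter can be merged back into the base model with `peft`. The snippet below is an illustrative sketch rather than part of this repository's tooling, and the adapter path is a hypothetical placeholder:

```python
# Illustrative sketch: merge a trained LoRA adapter into the base model for inference.
# The adapter path is a hypothetical placeholder -- use the directory your run actually wrote.
import torch
from transformers import AutoModel, AutoTokenizer
from peft import PeftModel

base = AutoModel.from_pretrained("models/qwen3-4b", torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, "output/your-lora-checkpoint")
model = model.merge_and_unload()   # folds the low-rank updates into the base weights
tokenizer = AutoTokenizer.from_pretrained("models/qwen3-4b")
```

After merging, the model can be used like a fully fine-tuned checkpoint, with no `peft` dependency at inference time.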

### Citation

If you use the F2LLM models, data, or code, please cite the following technical report.
6 changes: 6 additions & 0 deletions F2LLM/arguments.py
@@ -27,6 +27,12 @@ class Args:
log_interval: int = 20
checkpointing_steps: int = 100
validation_steps: int = 100
# LoRA settings
use_lora: bool = False
lora_r: int = 8
lora_alpha: int = 32
lora_dropout: float = 0.1
lora_target_modules: str = "q_proj,v_proj"
# just placeholder, for logging purpose
num_processes: int=0
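
For illustration, the new fields behave like any other `Args` field and can be filled from a JSON config. The snippet below is an assumption about usage, not the repository's actual config-loading code:

```python
import json
from arguments import Args

# Hypothetical illustration: build Args (including the new LoRA fields) from a JSON config.
with open("configs/config_lora.json") as f:
    args = Args(**json.load(f))

target_modules = args.lora_target_modules.split(",")  # e.g. ["q_proj", "v_proj"]
print(args.use_lora, args.lora_r, args.lora_alpha, target_modules)
```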

24 changes: 24 additions & 0 deletions F2LLM/configs/config_lora.json
@@ -0,0 +1,24 @@
{
"model_path": "models/qwen3-4b",
"experiment_id": "4b+lr.8e-6+bs.16x32+context.1024+2epochs+lora",
"train_data_path": "training_data/data_tokenized_qwen",
"output_dir": "output",
"tb_dir": "output/tb",
"cache_dir": "cache",
"train_batch_size": 16,
"checkpointing_steps": 5000,
"validation_steps": 5000,
"max_seq_length": 1024,
"learning_rate": 8e-6,
"min_lr": 1e-7,
"weight_decay": 0.01,
"warmup_steps": 500,
"train_epochs": 2,
"log_interval": 100,
"num_hard_neg": 7,
"use_lora": true,
"lora_r": 8,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_modules": "q_proj,v_proj"
}
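
Two properties of these hyperparameters are worth keeping in mind: LoRA scales the low-rank update by `lora_alpha / lora_r` (here 32 / 8 = 4), and the number of added parameters per adapted projection grows linearly with `lora_r`. A small back-of-the-envelope check (the dimensions below are placeholders, not the actual qwen3-4b shapes):

```python
# Back-of-the-envelope LoRA accounting; d_in/d_out are placeholder dimensions.
d_in, d_out = 4096, 4096
lora_r, lora_alpha = 8, 32

added_params = lora_r * (d_in + d_out)   # A is (r, d_in), B is (d_out, r)
scaling = lora_alpha / lora_r            # the update B @ A is multiplied by this factor
print(f"extra params per adapted projection: {added_params:,}")   # 65,536
print(f"effective scaling: {scaling}")                            # 4.0
```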
29 changes: 28 additions & 1 deletion F2LLM/model.py
@@ -1,5 +1,6 @@
import torch
from transformers import AutoModel, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType


class F2LLM:
@@ -16,6 +17,33 @@ def __init__(self,
self.lm.config.use_cache = False
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.max_seq_length = max_seq_length

# Apply LoRA if enabled
if args and args.use_lora:
self._apply_lora()

# Enable gradient requirements for LoRA with flash attention
if hasattr(self.lm, 'enable_input_require_grads'):
self.lm.enable_input_require_grads()

def _apply_lora(self):
"""Apply LoRA adaptation to the model"""
# Print LoRA training message
print("Using LoRA training, optimizing only LoRA parameters")

target_modules = self.args.lora_target_modules.split(",") if self.args.lora_target_modules else None

peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, # For decoder-only models
inference_mode=False,
r=self.args.lora_r,
lora_alpha=self.args.lora_alpha,
lora_dropout=self.args.lora_dropout,
target_modules=target_modules
)

self.lm = get_peft_model(self.lm, peft_config)
self.lm.print_trainable_parameters()

def set_device(self):
self.device = self.lm.device
@@ -34,4 +62,3 @@ def forward(self, batch):
'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs, 2*bs)]),
'negative_passage_features': None if num_hard_neg == 0 else torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(2*bs, len(batch['seq_lens']))]).view(bs, num_hard_neg, -1)
}
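
A quick way to confirm that `get_peft_model` froze the base model is to inspect which parameters still require gradients. This is an illustrative check, assuming an `F2LLM` instance built with `use_lora=True`; it is not part of the PR:

```python
# Sanity check: with LoRA applied, only the injected adapter tensors should be trainable.
trainable = [(name, p.numel()) for name, p in model.lm.named_parameters() if p.requires_grad]
frozen = sum(p.numel() for p in model.lm.parameters() if not p.requires_grad)

print(f"trainable tensors: {len(trainable)}, trainable params: {sum(n for _, n in trainable):,}")
print(f"frozen params: {frozen:,}")
# peft names adapter weights lora_A / lora_B, so every trainable tensor should carry that marker
assert all("lora_" in name for name, _ in trainable)
```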

15 changes: 8 additions & 7 deletions F2LLM/requirements.txt
@@ -1,7 +1,8 @@
accelerate
datasets
deepspeed
flash-attn
torch
transformers
tensorboard
accelerate==1.3.0
datasets==2.21.0
deepspeed==0.16.2
flash-attn==2.3.6+pack.glm.mask
torch==2.4.0+cu124
transformers==4.51.0
tensorboard==2.20.0
peft==0.3.0
2 changes: 2 additions & 0 deletions F2LLM/run.py
@@ -124,6 +124,8 @@ def __iter__(self):
# set seed again to make sure that different models share the same seed
set_seed(0)

if args.use_lora:
accelerator.print("Using LoRA training, optimizing only LoRA parameters")
optimizer = AdamW(model.lm.parameters(),
weight_decay=args.weight_decay,
lr=args.learning_rate,
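
One note on the optimizer: it is still constructed over all of `model.lm.parameters()`. With LoRA the frozen weights never receive gradients, so PyTorch's AdamW simply skips them and training behaves as expected. If one preferred to make this explicit (and keep the optimizer's parameter list small), the call could be restricted to trainable parameters; the sketch below is an alternative, not what this PR does:

```python
# Alternative sketch: pass only parameters that require gradients to AdamW.
# Functionally equivalent for LoRA training; frozen weights are simply left out.
trainable_params = [p for p in model.lm.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params,
                  weight_decay=args.weight_decay,
                  lr=args.learning_rate)
```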