54 changes: 47 additions & 7 deletions F2LLM/model.py
@@ -1,5 +1,6 @@
import torch
from transformers import AutoModel, AutoTokenizer
from utils import detect_encoder_only_model


class F2LLM:
@@ -12,7 +13,25 @@ def __init__(self,
        self.args = args
        self.dtype = torch.bfloat16
        self.device = None # set after accelerator.prepare
        self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2')

        try:
            self.lm = AutoModel.from_pretrained(
                model_path,
                trust_remote_code=True,
                torch_dtype=self.dtype,
                attn_implementation='flash_attention_2'
            )
        except Exception as e:
            print(f"Flash Attention 2 unavailable, falling back to the default attention implementation: {e}")
            self.lm = AutoModel.from_pretrained(
                model_path,
                trust_remote_code=True,
                torch_dtype=self.dtype
            )

        self.is_encoder_only = detect_encoder_only_model(self.lm.config)
        print(f"Model type: {'Encoder-only' if self.is_encoder_only else 'Decoder-only'}")

        self.lm.config.use_cache = False
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.max_seq_length = max_seq_length
@@ -24,14 +43,35 @@ def forward(self, batch):
        bs = batch['bs']
        num_hard_neg = int((len(batch['input_ids']) - 2*bs) / bs)

        outputs = self.lm(batch['input_ids'],
                          batch['attention_mask'],
                          )
        outputs = self.lm(
            batch['input_ids'],
            batch['attention_mask'],
        )

        passage_features_all_tokens = outputs.last_hidden_state

        if self.is_encoder_only:
            # Encoder-only: pool the [CLS] embedding at position 0
            query_features = passage_features_all_tokens[:bs, 0, :]
            passage_features = passage_features_all_tokens[bs:2*bs, 0, :]
            negative_features = None if num_hard_neg == 0 else \
                passage_features_all_tokens[2*bs:, 0, :].view(bs, num_hard_neg, -1)
        else:
            # Decoder-only: pool the embedding of the last real (EOS) token
            query_features = torch.stack([
                passage_features_all_tokens[i, batch['seq_lens'][i]-1]
                for i in range(bs)
            ])
            passage_features = torch.stack([
                passage_features_all_tokens[i, batch['seq_lens'][i]-1]
                for i in range(bs, 2*bs)
            ])
            negative_features = None if num_hard_neg == 0 else torch.stack([
                passage_features_all_tokens[i, batch['seq_lens'][i]-1]
                for i in range(2*bs, len(batch['seq_lens']))
            ]).view(bs, num_hard_neg, -1)

        return {
            'query_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs)]),
            'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs, 2*bs)]),
            'negative_passage_features': None if num_hard_neg == 0 else torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(2*bs, len(batch['seq_lens']))]).view(bs, num_hard_neg, -1)
            'query_passage_features': query_features,
            'passage_passage_features': passage_features,
            'negative_passage_features': negative_features
        }
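A minimal standalone sketch (illustrative, not part of this diff) of the two pooling strategies the new branches implement: encoder-only models take the [CLS] embedding at position 0, while decoder-only models take the embedding of the last real (EOS) token of each sequence.

import torch

bs, seq_len, hidden = 2, 5, 4
last_hidden_state = torch.randn(bs, seq_len, hidden)  # stand-in for outputs.last_hidden_state
seq_lens = [3, 5]                                     # unpadded length of each sequence

# Encoder-only: position 0 holds the [CLS] summary token
cls_pooled = last_hidden_state[:, 0, :]               # shape (bs, hidden)

# Decoder-only: only the final real token has attended to the whole sequence
eos_pooled = torch.stack([
    last_hidden_state[i, seq_lens[i] - 1] for i in range(bs)
])                                                    # shape (bs, hidden)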

10 changes: 8 additions & 2 deletions F2LLM/run.py
@@ -120,7 +120,13 @@ def __iter__(self):

accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************")
model = F2LLM(args.model_path, args.max_seq_length, args=args)
model.lm.gradient_checkpointing_enable()

try:
    model.lm.gradient_checkpointing_enable()
    accelerator.print("Gradient checkpointing enabled")
except Exception as e:
    accelerator.print(f"Gradient checkpointing unavailable, continuing without it: {e}")

# set seed again to make sure that different models share the same seed
set_seed(0)

@@ -150,4 +156,4 @@ def __iter__(self):


accelerate_train(args, accelerator, model, train_dataloader, valid_loaders,
                 optimizer, lr_scheduler, len(dataset))
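As an alternative to the try/except above, support could be probed up front; this is a sketch assuming model.lm is a transformers PreTrainedModel, whose subclasses expose a supports_gradient_checkpointing flag.

# Sketch: probe support instead of catching the failure
if getattr(model.lm, "supports_gradient_checkpointing", False):
    model.lm.gradient_checkpointing_enable()
else:
    accelerator.print("Model does not support gradient checkpointing; continuing without it")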
36 changes: 28 additions & 8 deletions F2LLM/tokenize_data_qwen.py
@@ -2,20 +2,37 @@
import numpy as np
import pandas as pd
import os
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoConfig
from tqdm.auto import tqdm
from utils import detect_encoder_only_model


tokenizer = AutoTokenizer.from_pretrained('models/qwen3-0.6b')
model_path = 'models/albert-base-v2'
tokenizer = AutoTokenizer.from_pretrained(model_path)
max_seq_length = 1023
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
is_encoder_only = detect_encoder_only_model(config)


def process_sent(sentence):

    # We make sure there's always an eos token at the end of each sequence
    tokenizer_outputs = tokenizer(sentence, max_length=max_seq_length, truncation=True, add_special_tokens=False)

    return np.array(tokenizer_outputs.input_ids + [tokenizer.eos_token_id])
    if is_encoder_only:
        # Encoder-only models: let the tokenizer add its special tokens ([CLS], [SEP]) automatically
        tokenizer_outputs = tokenizer(
            sentence,
            max_length=max_seq_length,
            truncation=True,
            add_special_tokens=True
        )
        return np.array(tokenizer_outputs.input_ids)
    else:
        # Decoder-only models: append the eos_token manually
        tokenizer_outputs = tokenizer(
            sentence,
            max_length=max_seq_length,
            truncation=True,
            add_special_tokens=False
        )
        return np.array(tokenizer_outputs.input_ids + [tokenizer.eos_token_id])


def process_sent_batch(s):
@@ -30,8 +47,11 @@ def parallelize(data, func, num_of_processes=8):


root_dir = 'training_data'
model_name = config.model_type
for ds_name in tqdm(sorted(os.listdir(root_dir))):
    print(ds_name, flush=True)
    if not ds_name.endswith(".parquet"):
        continue

    df = pd.read_parquet(f"{root_dir}/{ds_name}")
    df['query_input_ids'] = parallelize(df['query'], process_sent_batch, 62)
@@ -51,4 +71,4 @@ def parallelize(data, func, num_of_processes=8):
    for i in range(1, num_neg+1):
        df[f'negative_{i}_input_ids'] = df[f'negative_{i}'].map(df_tmp.input_ids)

    df.to_parquet(f'data_tokenized_qwen/{ds_name}', index=False)
    df.to_parquet(f'data_tokenized_{model_name}/{ds_name}', index=False)
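A quick sketch (model names are examples, not taken from the repo) of how the two process_sent branches treat the same input:

from transformers import AutoTokenizer

# Encoder-only path: the tokenizer inserts [CLS] ... [SEP] itself
bert_tok = AutoTokenizer.from_pretrained('bert-base-uncased')
enc_ids = bert_tok('hello world', add_special_tokens=True).input_ids

# Decoder-only path: no special tokens from the tokenizer; EOS appended by hand
qwen_tok = AutoTokenizer.from_pretrained('Qwen/Qwen3-0.6B')
dec_ids = qwen_tok('hello world', add_special_tokens=False).input_ids + [qwen_tok.eos_token_id]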
23 changes: 23 additions & 0 deletions F2LLM/utils.py
@@ -15,6 +15,29 @@ def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_s
        summary_writer.add_scalar(key, value, completed_steps)


def detect_encoder_only_model(config):
    # Primary signal: match model_type against known encoder-only families
    model_type = config.model_type.lower() if hasattr(config, 'model_type') else ''

    encoder_only_types = [
        'bert', 'roberta', 'electra', 'deberta', 'deberta-v2',
        'albert', 'xlm-roberta', 'camembert', 'distilbert',
        'mpnet', 'squeezebert', 'mobilebert'
    ]

    for encoder_type in encoder_only_types:
        if encoder_type in model_type:
            return True

    # Fallback signals: an explicit decoder flag rules out encoder-only,
    # while a config that is neither encoder-decoder nor carries a decoder
    # flag at all is treated as encoder-only
    if hasattr(config, 'is_decoder') and config.is_decoder:
        return False

    if hasattr(config, 'is_encoder_decoder') and not config.is_encoder_decoder:
        if not hasattr(config, 'is_decoder'):
            return True

    # Default: assume decoder-only
    return False


def save_checkpoint(args, accelerator, model, output_dir, lr_scheduler):
    accelerator.wait_for_everyone()
    accelerator.print(f"Saving checkpoint to {output_dir}")
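A minimal usage sketch for detect_encoder_only_model (hub names are illustrative assumptions):

from transformers import AutoConfig
from utils import detect_encoder_only_model

# 'albert' is in the encoder-only list, so this returns True
print(detect_encoder_only_model(AutoConfig.from_pretrained('albert-base-v2')))

# 'qwen3' is not listed and the config carries the standard decoder flags,
# so the function falls through to the decoder-only default (False)
print(detect_encoder_only_model(AutoConfig.from_pretrained('Qwen/Qwen3-0.6B')))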