From a68d600f64bf48889aa46deb8aa2b55188eda41a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BE=97=E7=A6=8F?=
Date: Sun, 30 Nov 2025 22:23:21 +0800
Subject: [PATCH] support encoder-only models

---
 F2LLM/model.py              | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++-------
 F2LLM/run.py                | 10 ++++++++--
 F2LLM/tokenize_data_qwen.py | 37 +++++++++++++++++++++++++++++++++---------
 F2LLM/utils.py              | 20 ++++++++++++++++++++
 4 files changed, 105 insertions(+), 17 deletions(-)

diff --git a/F2LLM/model.py b/F2LLM/model.py
index d33ade7..8c9c202 100644
--- a/F2LLM/model.py
+++ b/F2LLM/model.py
@@ -1,5 +1,6 @@
 import torch
 from transformers import AutoModel, AutoTokenizer
+from utils import detect_encoder_only_model
 
 
 class F2LLM:
@@ -12,7 +13,25 @@ def __init__(self,
         self.args = args
         self.dtype = torch.bfloat16
         self.device = None  # set after accelerator.prepare
-        self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2')
+
+        try:
+            self.lm = AutoModel.from_pretrained(
+                model_path,
+                trust_remote_code=True,
+                torch_dtype=self.dtype,
+                attn_implementation='flash_attention_2'
+            )
+        except Exception as e:
+            print(f"Flash Attention 2 unavailable, falling back to the default attention implementation: {e}")
+            self.lm = AutoModel.from_pretrained(
+                model_path,
+                trust_remote_code=True,
+                torch_dtype=self.dtype
+            )
+
+        self.is_encoder_only = detect_encoder_only_model(self.lm.config)
+        print(f"Model type: {'Encoder-only' if self.is_encoder_only else 'Decoder-only'}")
+
         self.lm.config.use_cache = False
         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
         self.max_seq_length = max_seq_length
@@ -24,14 +43,36 @@ def forward(self, batch):
         bs = batch['bs']
         num_hard_neg = int((len(batch['input_ids']) - 2*bs) / bs)
 
-        outputs = self.lm(batch['input_ids'],
-                          batch['attention_mask'],
-                          )
+        outputs = self.lm(
+            batch['input_ids'],
+            attention_mask=batch['attention_mask'],
+        )
         passage_features_all_tokens = outputs.last_hidden_state
+
+        if self.is_encoder_only:
+            # Encoder-only models: pool the [CLS] (first-token) embedding
+            query_features = passage_features_all_tokens[:bs, 0, :]
+            passage_features = passage_features_all_tokens[bs:2*bs, 0, :]
+            negative_features = None if num_hard_neg == 0 else \
+                passage_features_all_tokens[2*bs:, 0, :].view(bs, num_hard_neg, -1)
+        else:
+            # Decoder-only models: pool the embedding of the last (EOS) token
+            query_features = torch.stack([
+                passage_features_all_tokens[i, batch['seq_lens'][i]-1]
+                for i in range(bs)
+            ])
+            passage_features = torch.stack([
+                passage_features_all_tokens[i, batch['seq_lens'][i]-1]
+                for i in range(bs, 2*bs)
+            ])
+            negative_features = None if num_hard_neg == 0 else torch.stack([
+                passage_features_all_tokens[i, batch['seq_lens'][i]-1]
+                for i in range(2*bs, len(batch['seq_lens']))
+            ]).view(bs, num_hard_neg, -1)
 
         return {
-            'query_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs)]),
-            'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs, 2*bs)]),
-            'negative_passage_features': None if num_hard_neg == 0 else torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(2*bs, len(batch['seq_lens']))]).view(bs, num_hard_neg, -1)
+            'query_passage_features': query_features,
+            'passage_passage_features': passage_features,
+            'negative_passage_features': negative_features
         }
 
diff --git a/F2LLM/run.py b/F2LLM/run.py
index e40b707..fae529b 100644
--- a/F2LLM/run.py
+++ b/F2LLM/run.py
@@ -120,7 +120,13 @@ def __iter__(self):
 accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************")
 model = F2LLM(args.model_path,
               args.max_seq_length, args=args)
-model.lm.gradient_checkpointing_enable()
+
+try:
+    model.lm.gradient_checkpointing_enable()
+    accelerator.print("Gradient checkpointing enabled")
+except Exception as e:
+    accelerator.print(f"Gradient checkpointing unavailable, continuing without it: {e}")
+
 
 # set seed again to make sure that different models share the same seed
 set_seed(0)
@@ -150,4 +156,4 @@ def __iter__(self):
 
 
 accelerate_train(args, accelerator, model, train_dataloader, valid_loaders,
-                 optimizer, lr_scheduler, len(dataset))
\ No newline at end of file
+                 optimizer, lr_scheduler, len(dataset))
diff --git a/F2LLM/tokenize_data_qwen.py b/F2LLM/tokenize_data_qwen.py
index 2d9c47e..7a8de23 100644
--- a/F2LLM/tokenize_data_qwen.py
+++ b/F2LLM/tokenize_data_qwen.py
@@ -2,20 +2,37 @@
 import numpy as np
 import pandas as pd
 import os
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoConfig
 from tqdm.auto import tqdm
+from utils import detect_encoder_only_model
 
-tokenizer = AutoTokenizer.from_pretrained('models/qwen3-0.6b')
+model_path = 'models/albert-base-v2'
+tokenizer = AutoTokenizer.from_pretrained(model_path)
 max_seq_length = 1023
+config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+is_encoder_only = detect_encoder_only_model(config)
 
 
 def process_sent(sentence):
-
-    # We make sure there's always an eos token at the end of each sequence
-    tokenizer_outputs = tokenizer(sentence, max_length=max_seq_length, truncation=True, add_special_tokens=False)
-
-    return np.array(tokenizer_outputs.input_ids + [tokenizer.eos_token_id])
+    if is_encoder_only:
+        # Encoder-only models: let the tokenizer add its special tokens ([CLS], [SEP])
+        tokenizer_outputs = tokenizer(
+            sentence,
+            max_length=max_seq_length,
+            truncation=True,
+            add_special_tokens=True
+        )
+        return np.array(tokenizer_outputs.input_ids)
+    else:
+        # Decoder-only models: append the eos token manually
+        tokenizer_outputs = tokenizer(
+            sentence,
+            max_length=max_seq_length,
+            truncation=True,
+            add_special_tokens=False
+        )
+        return np.array(tokenizer_outputs.input_ids + [tokenizer.eos_token_id])
 
 
 def process_sent_batch(s):
@@ -30,8 +47,12 @@ def parallelize(data, func, num_of_processes=8):
 
 
 root_dir = 'training_data'
+model_name = config.model_type
+os.makedirs(f'data_tokenized_{model_name}', exist_ok=True)
 for ds_name in tqdm(sorted(os.listdir(root_dir))):
     print(ds_name, flush=True)
+    if not ds_name.endswith(".parquet"):
+        continue
     df = pd.read_parquet(f"{root_dir}/{ds_name}")
 
     df['query_input_ids'] = parallelize(df['query'], process_sent_batch, 62)
@@ -51,4 +72,4 @@ def parallelize(data, func, num_of_processes=8):
     for i in range(1, num_neg+1):
         df[f'negative_{i}_input_ids'] = df[f'negative_{i}'].map(df_tmp.input_ids)
 
-    df.to_parquet(f'data_tokenized_qwen/{ds_name}', index=False)
+    df.to_parquet(f'data_tokenized_{model_name}/{ds_name}', index=False)
diff --git a/F2LLM/utils.py b/F2LLM/utils.py
index b167d3c..7fbb443 100644
--- a/F2LLM/utils.py
+++ b/F2LLM/utils.py
@@ -15,6 +15,26 @@ def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps):
         summary_writer.add_scalar(key, value, completed_steps)
 
 
+def detect_encoder_only_model(config):
+    """Return True if `config` describes an encoder-only (BERT-style) model."""
+    model_type = config.model_type.lower() if hasattr(config, 'model_type') else ''
+
+    encoder_only_types = [
+        'bert', 'roberta', 'electra', 'deberta', 'deberta-v2',
+        'albert', 'xlm-roberta', 'camembert', 'distilbert',
+        'mpnet', 'squeezebert', 'mobilebert'
+    ]
+
+    for encoder_type in encoder_only_types:
+        if encoder_type in model_type:
+            return True
+
+    # The explicit config flags are unreliable for this check: `is_decoder`
+    # and `is_encoder_decoder` both default to False even for decoder-only
+    # models, so any model type not listed above is treated as decoder-only.
+    return False
+
+
 def save_checkpoint(args, accelerator, model, output_dir, lr_scheduler):
     accelerator.wait_for_everyone()
     accelerator.print(f"Saving checkpoint to {output_dir}")
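-- 
Reviewer note (not part of the patch): a minimal sketch of how
detect_encoder_only_model is expected to classify configs, and which pooling
path F2LLM.forward then takes. The model names are illustrative assumptions,
and the snippet presumes `transformers` is installed and the configs are
reachable; adjust to local paths as needed.

    from transformers import AutoConfig
    from utils import detect_encoder_only_model

    for name in ['albert-base-v2', 'Qwen/Qwen3-0.6B']:
        config = AutoConfig.from_pretrained(name, trust_remote_code=True)
        # Encoder-only -> [CLS] (first-token) pooling; otherwise last-token
        # pooling, mirroring the branch added to F2LLM.forward above.
        pooling = 'CLS token' if detect_encoder_only_model(config) else 'last token'
        print(f"{name}: model_type={config.model_type!r} -> {pooling} pooling")

Expected behavior under these assumptions: albert-base-v2 (model_type
'albert') matches the encoder-only list, while Qwen/Qwen3-0.6B falls through
to the decoder-only default.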