From f2d41d2f959a52d53d341d2a28b8a44930bcb78a Mon Sep 17 00:00:00 2001
From: "qingnang.lh"
Date: Sat, 6 Dec 2025 13:59:35 +0800
Subject: [PATCH 1/2] Add support for encoder-only models

---
 F2LLM/README.md                |   2 +-
 F2LLM/configs/bert_config.json |  19 +++++
 F2LLM/model.py                 |  46 +++++++++---
 F2LLM/tokenize_data.py         | 123 +++++++++++++++++++++++++++++++++
 F2LLM/utils.py                 |  71 ++++++++++++++++++-
 5 files changed, 249 insertions(+), 12 deletions(-)
 create mode 100644 F2LLM/configs/bert_config.json
 create mode 100644 F2LLM/tokenize_data.py

diff --git a/F2LLM/README.md b/F2LLM/README.md
index 6b79819..f9b5f74 100644
--- a/F2LLM/README.md
+++ b/F2LLM/README.md
@@ -26,7 +26,7 @@ In this repo we provide a streamlined and efficient script for training embeddin
 
 - Setup environment following `requirements.txt`. We note that transformers>=4.51.0 is required for training Qwen3 models.
 - Download data and backbone models from Hugging Face (we use Qwen3 models).
-- Run `tokenize_data_qwen.py` to tokenize the downloaded data
+- Run `tokenize_data.py --tokenizer qwen` to tokenize the downloaded data.
 - Modify model path, data path, and other arguments in `configs/config.json`.
 - Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json`.
 
diff --git a/F2LLM/configs/bert_config.json b/F2LLM/configs/bert_config.json
new file mode 100644
index 0000000..0934211
--- /dev/null
+++ b/F2LLM/configs/bert_config.json
@@ -0,0 +1,19 @@
+{
+    "model_path": "models/bert_multilingual",
+    "experiment_id": "bert_multilingual+lr.2e-5+bs.32+context.512+3epochs",
+    "train_data_path": "training_data/data_tokenized_bert",
+    "output_dir": "output",
+    "tb_dir": "output/tb",
+    "cache_dir": "cache",
+    "train_batch_size": 32,
+    "checkpointing_steps": 5000,
+    "validation_steps": 5000,
+    "max_seq_length": 512,
+    "learning_rate": 2e-5,
+    "min_lr": 1e-7,
+    "weight_decay": 0.01,
+    "warmup_steps": 500,
+    "train_epochs": 3,
+    "log_interval": 100,
+    "num_hard_neg": 7
+}
\ No newline at end of file
diff --git a/F2LLM/model.py b/F2LLM/model.py
index d33ade7..f1992c8 100644
--- a/F2LLM/model.py
+++ b/F2LLM/model.py
@@ -1,5 +1,8 @@
 import torch
 from transformers import AutoModel, AutoTokenizer
+from utils import detect_model_type, extract_cls_embeddings, extract_mean_pooling_embeddings, extract_last_token_embeddings
+import flash_attn
+from packaging import version
 
 
 class F2LLM:
@@ -12,11 +15,34 @@ def __init__(self,
         self.args = args
         self.dtype = torch.bfloat16
         self.device = None # set after accelerator.prepare
-        self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2')
+        self.model_type = detect_model_type(model_path)
+
+        flash_attn_version = getattr(flash_attn, '__version__', '0.0.0')
+        no_support_deterministic = version.parse(flash_attn_version) < version.parse("2.4.1")  # older flash-attn lacks the deterministic backward pass needed here
+
+        if self.model_type == 'encoder_only' and no_support_deterministic:
+            attn_implementation = 'eager'
+        else:
+            attn_implementation = 'flash_attention_2'
+
+        print(f"{self.model_type}")
+        self.lm = AutoModel.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            torch_dtype=self.dtype,
+            attn_implementation=attn_implementation
+        )
+        self.lm.config.use_cache = False
         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
         self.max_seq_length = max_seq_length
-
+        self.use_cls_pooling = self._should_use_cls_pooling()
+
+    def _should_use_cls_pooling(self):
+        if self.model_type != 'encoder_only':
+            return False
+        return not hasattr(self.lm, 'pooler') or self.lm.pooler is None  # plain [CLS] pooling only when the encoder ships no pooler head
+
     def set_device(self):
         self.device = self.lm.device
 
@@ -27,11 +53,11 @@ def forward(self, batch):
         outputs = self.lm(batch['input_ids'],
                           batch['attention_mask'],
                           )
-
-        passage_features_all_tokens = outputs.last_hidden_state
-        return {
-            'query_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs)]),
-            'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs, 2*bs)]),
-            'negative_passage_features': None if num_hard_neg == 0 else torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(2*bs, len(batch['seq_lens']))]).view(bs, num_hard_neg, -1)
-        }
-
+
+        # Decoder-only models embed with the last token; encoders use CLS or mean pooling.
+        if self.model_type == 'decoder_only':
+            return extract_last_token_embeddings(bs, num_hard_neg, outputs.last_hidden_state, batch)
+        elif self.use_cls_pooling:
+            return extract_cls_embeddings(bs, num_hard_neg, outputs.last_hidden_state, batch)
+        else:
+            return extract_mean_pooling_embeddings(bs, num_hard_neg, outputs.last_hidden_state, batch)
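Taken together, the three pooling paths forward() now dispatches between reduce to the sketch below. This is illustrative only, not part of the diff: the pool() helper and its argument names are hypothetical, mirroring extract_cls_embeddings, extract_mean_pooling_embeddings, and extract_last_token_embeddings from utils.py further down.

    import torch

    def pool(last_hidden_state, attention_mask, seq_lens, strategy):
        # last_hidden_state: [batch, seq, hidden]; attention_mask: [batch, seq]; seq_lens: [batch]
        if strategy == 'cls':
            # Encoder-only without a pooler head: take the first ([CLS]) token.
            return last_hidden_state[:, 0, :]
        if strategy == 'mean':
            # Encoder-only with a pooler head: average over non-padding tokens.
            mask = attention_mask.unsqueeze(-1).float()
            return (last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        if strategy == 'last_token':
            # Decoder-only: take the hidden state of the final real token.
            rows = torch.arange(last_hidden_state.size(0))
            return last_hidden_state[rows, seq_lens - 1, :]
        raise ValueError(f"unknown pooling strategy: {strategy}")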
diff --git a/F2LLM/tokenize_data.py b/F2LLM/tokenize_data.py
new file mode 100644
index 0000000..aeacb7a
--- /dev/null
+++ b/F2LLM/tokenize_data.py
@@ -0,0 +1,123 @@
+from multiprocessing import Pool
+from functools import partial
+import numpy as np
+import pandas as pd
+import os
+import argparse
+from transformers import AutoTokenizer
+from tqdm.auto import tqdm
+
+
+def get_tokenizer_config(tokenizer_type):
+    if tokenizer_type == 'bert':
+        return {
+            'model_path': 'models/bert_multilingual',
+            'max_seq_length': 512,
+            'add_special_tokens': True,
+            'add_eos_token': False,
+            'output_dir': 'data_tokenized_bert'
+        }
+    elif tokenizer_type == 'qwen':
+        return {
+            'model_path': 'models/qwen3-0.6b',
+            'max_seq_length': 1023,
+            'add_special_tokens': False,
+            'add_eos_token': True,
+            'output_dir': 'data_tokenized_qwen'
+        }
+    else:
+        raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
+
+
+def initialize_tokenizer(config):
+    tokenizer = AutoTokenizer.from_pretrained(config['model_path'])
+    max_seq_length = config['max_seq_length']
+    return tokenizer, max_seq_length
+
+
+def process_sent(sentence, tokenizer, max_seq_length, config):
+    # Tokenize a single sentence with the configured truncation and special-token policy
+    tokenizer_outputs = tokenizer(
+        sentence,
+        max_length=max_seq_length,
+        truncation=True,
+        add_special_tokens=config['add_special_tokens']
+    )
+
+    input_ids = tokenizer_outputs.input_ids
+
+    # Append the EOS token when configured (Qwen); the max_seq_length of 1023 leaves room for it
+    if config['add_eos_token']:
+        input_ids = input_ids + [tokenizer.eos_token_id]
+
+    return np.array(input_ids)
+
+
+def process_sent_batch(s, tokenizer, max_seq_length, config):
+    return s.apply(lambda x: process_sent(x, tokenizer, max_seq_length, config))
+
+def parallelize(data, tokenizer, max_seq_length, config, num_of_processes=8):
+    indices = np.array_split(data.index, num_of_processes)
+    data_split = [data.iloc[idx] for idx in indices]
+    with Pool(num_of_processes) as pool:
+        data = pd.concat(pool.map(partial(process_sent_batch, tokenizer=tokenizer, max_seq_length=max_seq_length, config=config), data_split))
+    return data
+
+
+def tokenize_dataset(tokenizer_type, root_dir='training_data'):
+    config = get_tokenizer_config(tokenizer_type)
+    tokenizer, max_seq_length = initialize_tokenizer(config)
+
+    output_dir = config['output_dir']
+    os.makedirs(output_dir, exist_ok=True)
+
+    for ds_name in tqdm(sorted(os.listdir(root_dir))):
+        print(ds_name, flush=True)
+
+        df = pd.read_parquet(f"{root_dir}/{ds_name}")
+        df['query_input_ids'] = parallelize(
+            df['query'],
+            tokenizer,
+            max_seq_length,
+            config,
+            62  # number of worker processes
+        )
+
+        num_neg = 24 if 'negative_2' in df.keys() else 1  # datasets carry either 1 or 24 hard negatives
+
+        ls = df.passage.to_list()  # tokenize the union of passages and negatives once, then map results back
+        for i in range(1, num_neg+1):
+            ls += df[f'negative_{i}'].to_list()
+        ls = list(set(ls))
+        df_tmp = pd.DataFrame({'text': ls})
+        df_tmp['input_ids'] = parallelize(
+            df_tmp['text'],
+            tokenizer,
+            max_seq_length,
+            config,
+            62  # number of worker processes
+        )
+        df_tmp = df_tmp.set_index('text')
+
+        df['passage_input_ids'] = df.passage.map(df_tmp.input_ids)
+
+        for i in range(1, num_neg+1):
+            df[f'negative_{i}_input_ids'] = df[f'negative_{i}'].map(df_tmp.input_ids)
+
+        df.to_parquet(f'{output_dir}/{ds_name}', index=False)
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Tokenize datasets using a BERT or Qwen tokenizer')
+    parser.add_argument('--tokenizer', type=str, choices=['bert', 'qwen'], default='bert',
+                        help='Tokenizer type to use (default: bert)')
+    parser.add_argument('--data_dir', type=str, default='training_data',
+                        help='Directory containing the training data')
+
+    args = parser.parse_args()
+
+    tokenize_dataset(args.tokenizer, args.data_dir)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
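Both supported invocations, using the defaults wired into get_tokenizer_config (note that the script writes data_tokenized_bert/ or data_tokenized_qwen/ relative to the working directory, while configs/bert_config.json expects training_data/data_tokenized_bert, so run it from, or move the output to, wherever that path resolves):

    python tokenize_data.py --tokenizer bert --data_dir training_data
    python tokenize_data.py --tokenizer qwen --data_dir training_data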
diff --git a/F2LLM/utils.py b/F2LLM/utils.py
index b167d3c..31d6b84 100644
--- a/F2LLM/utils.py
+++ b/F2LLM/utils.py
@@ -5,9 +5,12 @@
 from torch.nn import CrossEntropyLoss
 import os
 
+from transformers import AutoConfig
+
 CLASSIFICATION_DATASETS = ['amazon_counterfactual', 'amazon_polarity', 'imdb', 'toxic_conversations', 'cola']
 CLUSTERING_DATASETS = ['amazon_reviews', 'banking77', 'emotion', 'mtop_intent', 'mtop_domain', 'massive_scenario', 'massive_intent', 'tweet_sentiment_extraction', 'arxiv_clustering_p2p', 'arxiv_clustering_s2s', 'biorxiv_clustering_p2p', 'biorxiv_clustering_s2s', 'medrxiv_clustering_p2p', 'medrxiv_clustering_s2s', 'reddit_clustering_p2p', 'reddit_clustering_s2s', 'stackexchange_clustering_p2p', 'stackexchange_clustering_s2s', 'twentynewsgroups']
 RETRIEVAL_DATASETS = ['arguana', 'snli', 'mnli', 'anli', 'paq', 'squad', 'stackexchange', 'msmarco', 'natural_questions', 'hotpotqa', 'fever', 'eli5', 'fiqa', 'bioasq', 'nfcorpus', 'miracl', 'mrtidy', 'scifact', 'qqp', 'stackoverflowdupquestions', 'sts12', 'sts22', 'stsbenchmark', 'amazon_qa', 'cnn_dm', 'coliee', 'paq_part2', 'pubmedqa', 's2orc_abstract_citation', 's2orc_title_abstract', 's2orc_title_citation', 'sentence_compression', 'specter', 'triviaqa', 'xsum', 'stackexchange_part2', 'stackexchangedupquestions_s2s', 'stackexchangedupquestions_p2p']
+ENCODER_ONLY_INDICATORS = ['bert', 'electra', 'mpnet']
 
 
 def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps):
@@ -224,4 +227,70 @@ def accelerate_train(args,
         model.lm.train()
 
     if summary_writer:
-        summary_writer.close()
\ No newline at end of file
+        summary_writer.close()
+
+
+def detect_model_type(model_path):
+    try:
+        AutoConfig.from_pretrained(model_path, trust_remote_code=True)  # the config load doubles as a model-path validity check
+    except Exception:
+        # If the config cannot be loaded, default to decoder-only for backward compatibility
+        return 'decoder_only'
+
+    model_name = model_path.split('/')[-1].lower()
+
+    if any(indicator in model_name for indicator in ENCODER_ONLY_INDICATORS):
+        return 'encoder_only'
+
+    return 'decoder_only'
+
+
+def extract_cls_embeddings(batch_size, num_hard_neg, last_hidden_state, batch):  # CLS pooling: first-token embeddings
+    features = {}
+    features['query_passage_features'] = last_hidden_state[0:batch_size, 0, :].unsqueeze(1)
+    features['passage_passage_features'] = last_hidden_state[batch_size:2*batch_size, 0, :].unsqueeze(1)
+    features['negative_passage_features'] = (
+        last_hidden_state[2*batch_size:, 0, :].view(batch_size, num_hard_neg, -1)
+        if num_hard_neg > 0 else None
+    )
+    return features
+
+
+def extract_mean_pooling_embeddings(batch_size, num_hard_neg, last_hidden_state, batch):
+    # Apply mean pooling over non-padding positions
+    attention_mask = batch['attention_mask']
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
+    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
+    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+    mean_pooled = sum_embeddings / sum_mask
+
+    # Extract features
+    features = {}
+    features['query_passage_features'] = mean_pooled[0:batch_size, :].unsqueeze(1)
+    features['passage_passage_features'] = mean_pooled[batch_size:2*batch_size, :].unsqueeze(1)
+    features['negative_passage_features'] = (
+        mean_pooled[2*batch_size:, :].view(batch_size, num_hard_neg, -1)
+        if num_hard_neg > 0 else None
+    )
+
+    return features
+
+
+def extract_last_token_embeddings(batch_size, num_hard_neg, last_hidden_state, batch):
+    # Extract features using the last token of each sequence
+    features = {}
+    features['query_passage_features'] = extract_last_token_features(last_hidden_state, batch, 0, batch_size)
+    features['passage_passage_features'] = extract_last_token_features(last_hidden_state, batch, batch_size, 2 * batch_size)
+    features['negative_passage_features'] = (
+        extract_last_token_features(last_hidden_state, batch, 2 * batch_size, len(batch['seq_lens'])).view(batch_size, num_hard_neg, -1)
+        if num_hard_neg > 0 else None
+    )
+
+    return features
+
+
+def extract_last_token_features(hidden_states, batch, start_idx, end_idx):
+    return torch.stack([  # the single-element index list keeps the seq dim, giving [1, hidden] per sequence
+        hidden_states[i, [batch['seq_lens'][i] - 1]]
+        for i in range(start_idx, end_idx)
+    ])
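A quick shape check for the helpers above (a hypothetical snippet, not part of the patch; it assumes it is run from F2LLM/ so that utils.py and its dependencies import cleanly, and follows the batch layout forward() relies on: bs queries, then bs passages, then bs*num_hard_neg hard negatives along dim 0):

    import torch
    from utils import extract_mean_pooling_embeddings

    bs, num_hard_neg, seq_len, hidden = 2, 7, 16, 32
    hidden_states = torch.randn(bs * (2 + num_hard_neg), seq_len, hidden)
    batch = {'attention_mask': torch.ones(bs * (2 + num_hard_neg), seq_len, dtype=torch.long)}

    feats = extract_mean_pooling_embeddings(bs, num_hard_neg, hidden_states, batch)
    assert feats['query_passage_features'].shape == (bs, 1, hidden)
    assert feats['passage_passage_features'].shape == (bs, 1, hidden)
    assert feats['negative_passage_features'].shape == (bs, num_hard_neg, hidden)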
From 3613fb59df3559ca1c7836982d9708f6987795af Mon Sep 17 00:00:00 2001
From: "qingnang.lh"
Date: Sat, 6 Dec 2025 18:34:24 +0800
Subject: [PATCH 2/2] Test model training with flash_attn==2.6.0 and
 flash_attn==2.3.6

---
 F2LLM/configs/accelerate_config.yaml | 2 +-
 F2LLM/model.py                       | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/F2LLM/configs/accelerate_config.yaml b/F2LLM/configs/accelerate_config.yaml
index 5133305..aa65b01 100644
--- a/F2LLM/configs/accelerate_config.yaml
+++ b/F2LLM/configs/accelerate_config.yaml
@@ -11,7 +11,7 @@ distributed_type: DEEPSPEED
 downcast_bf16: "no"
 machine_rank: 0
 main_training_function: main
-mixed_precision: "bf16"
+mixed_precision: "no" # set to "no" for encoder-only models with flash attention 2; use "bf16" otherwise
 num_machines: 1
 num_processes: 8
 rdzv_backend: static
diff --git a/F2LLM/model.py b/F2LLM/model.py
index f1992c8..e63b384 100644
--- a/F2LLM/model.py
+++ b/F2LLM/model.py
@@ -18,14 +18,14 @@ def __init__(self,
         self.model_type = detect_model_type(model_path)
 
         flash_attn_version = getattr(flash_attn, '__version__', '0.0.0')
-        no_support_deterministic = version.parse(flash_attn_version) < version.parse("2.4.1")  # older flash-attn lacks the deterministic backward pass needed here
+        no_support_deterministic = version.parse(flash_attn_version) < version.parse("2.6.0")  # older flash-attn lacks the deterministic backward pass needed here
 
         if self.model_type == 'encoder_only' and no_support_deterministic:
             attn_implementation = 'eager'
         else:
             attn_implementation = 'flash_attention_2'
 
-        print(f"{self.model_type}")
+        print(f"Model Type: {self.model_type} | Attention Implementation: {attn_implementation}")
         self.lm = AutoModel.from_pretrained(
             model_path,
             trust_remote_code=True,
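With both patches applied, an encoder-only training run mirrors the README workflow for Qwen, swapping in the new BERT config (assuming run.py consumes configs/bert_config.json the same way it consumes configs/config.json):

    python tokenize_data.py --tokenizer bert
    accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/bert_config.json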