2 changes: 1 addition & 1 deletion F2LLM/README.md
@@ -26,7 +26,7 @@ In this repo we provide a streamlined and efficient script for training embeddin

- Set up the environment following `requirements.txt`. We note that transformers>=4.51.0 is required for training Qwen3 models.
- Download data and backbone models from Hugging Face (we use Qwen3 models).
- Run `tokenize_data_qwen.py` to tokenize the downloaded data
- Run `tokenize_data.py --tokenizer qwen` to tokenize the downloaded data
- Modify model path, data path, and other arguments in `configs/config.json`.
- Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json`.
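
The transformers floor mentioned in the setup step above can be verified up front. A minimal pre-flight sketch (not part of the repo), reusing the `packaging` library the training code already depends on:

```python
# Hypothetical pre-flight check for the requirement noted above:
# transformers>=4.51.0 is needed for training Qwen3 models.
import transformers
from packaging import version

assert version.parse(transformers.__version__) >= version.parse("4.51.0"), \
    f"found transformers {transformers.__version__}, but >=4.51.0 is required for Qwen3"
```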

2 changes: 1 addition & 1 deletion F2LLM/configs/accelerate_config.yaml
@@ -11,7 +11,7 @@ distributed_type: DEEPSPEED
downcast_bf16: "no"
machine_rank: 0
main_training_function: main
mixed_precision: "bf16"
mixed_precision: "no" # use "no" for encoder-only models with Flash Attention 2
num_machines: 1
num_processes: 8
rdzv_backend: static
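
For context on the `mixed_precision: "no"` change: the backbone is already loaded in bfloat16 in `model.py` (`torch_dtype=self.dtype`), so the weights are half-precision regardless of this setting. A minimal sketch to confirm the parameter dtype, with a placeholder model path:

```python
# Sketch only: verify the backbone's parameter dtype independently of the
# accelerate mixed_precision setting. The path is a placeholder.
import torch
from transformers import AutoModel

lm = AutoModel.from_pretrained(
    "models/bert_multilingual",     # placeholder; any local backbone works
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",    # avoids requiring flash-attn for this check
)
print(next(lm.parameters()).dtype)  # expected: torch.bfloat16
```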
19 changes: 19 additions & 0 deletions F2LLM/configs/bert_config.json
@@ -0,0 +1,19 @@
{
    "model_path": "models/bert_multilingual",
    "experiment_id": "bert_multilingual+lr.2e-5+bs.32+context.512+3epochs",
    "train_data_path": "training_data/data_tokenized_bert",
    "output_dir": "output",
    "tb_dir": "output/tb",
    "cache_dir": "cache",
    "train_batch_size": 32,
    "checkpointing_steps": 5000,
    "validation_steps": 5000,
    "max_seq_length": 512,
    "learning_rate": 2e-5,
    "min_lr": 1e-7,
    "weight_decay": 0.01,
    "warmup_steps": 500,
    "train_epochs": 3,
    "log_interval": 100,
    "num_hard_neg": 7
}
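
A small sketch (not part of the repo) of how a config like this can be loaded and sanity-checked before launching a run; the required keys are taken from the JSON above:

```python
import json

def load_train_config(path="configs/bert_config.json"):
    # Load the JSON config and make sure the fields shown above are present.
    with open(path) as f:
        cfg = json.load(f)
    required = {"model_path", "train_data_path", "output_dir", "train_batch_size",
                "max_seq_length", "learning_rate", "num_hard_neg"}
    missing = required - cfg.keys()
    if missing:
        raise ValueError(f"missing config keys: {sorted(missing)}")
    return cfg

cfg = load_train_config()
print(cfg["experiment_id"], cfg["num_hard_neg"])
```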
46 changes: 36 additions & 10 deletions F2LLM/model.py
@@ -1,5 +1,8 @@
import torch
from transformers import AutoModel, AutoTokenizer
from utils import detect_model_type, extract_cls_embeddings, extract_mean_pooling_embeddings, extract_last_token_embeddings
import flash_attn
from packaging import version


class F2LLM:
@@ -12,11 +15,34 @@ def __init__(self,
        self.args = args
        self.dtype = torch.bfloat16
        self.device = None  # set after accelerator.prepare
        self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2')
        self.model_type = detect_model_type(model_path)

        flash_attn_version = getattr(flash_attn, '__version__', '0.0.0')
        no_support_deterministic = version.parse(flash_attn_version) < version.parse("2.6.0")

        if self.model_type == 'encoder_only' and no_support_deterministic:
            attn_implementation = 'eager'
        else:
            attn_implementation = 'flash_attention_2'

        print(f"Model Type: {self.model_type} | Attention Implementation: {attn_implementation}")
        self.lm = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=self.dtype,
            attn_implementation=attn_implementation
        )

        self.lm.config.use_cache = False
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.max_seq_length = max_seq_length

        self.use_cls_pooling = self._should_use_cls_pooling()

    def _should_use_cls_pooling(self):
        if self.model_type != 'encoder_only':
            return False
        return not hasattr(self.lm, 'pooler') or self.lm.pooler is None

    def set_device(self):
        self.device = self.lm.device
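
The constructor above hard-imports `flash_attn` and only falls back to eager attention for encoder-only models on flash-attn builds older than 2.6.0 (treated as lacking deterministic support). A standalone sketch of the same selection logic, written defensively so a missing flash-attn install also degrades to eager; that extra fallback is an assumption beyond the PR, which requires `flash_attn` to be importable:

```python
from packaging import version

def pick_attn_implementation(model_type: str) -> str:
    # Mirrors the gate in F2LLM.__init__: encoder-only models need
    # flash-attn >= 2.6.0, otherwise fall back to eager attention.
    try:
        import flash_attn
        fa_version = version.parse(getattr(flash_attn, "__version__", "0.0.0"))
    except ImportError:
        return "eager"   # defensive extra: no flash-attn installed at all
    if model_type == "encoder_only" and fa_version < version.parse("2.6.0"):
        return "eager"
    return "flash_attention_2"

print(pick_attn_implementation("encoder_only"))
```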

@@ -27,11 +53,11 @@ def forward(self, batch):
        outputs = self.lm(batch['input_ids'],
                          batch['attention_mask'],
                          )

        passage_features_all_tokens = outputs.last_hidden_state
        return {
            'query_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs)]),
            'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs, 2*bs)]),
            'negative_passage_features': None if num_hard_neg == 0 else torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(2*bs, len(batch['seq_lens']))]).view(bs, num_hard_neg, -1)
        }

        if self.model_type == 'decoder_only':
            return extract_last_token_embeddings(bs, num_hard_neg, outputs.last_hidden_state, batch)
        else:
            if self.use_cls_pooling:
                return extract_cls_embeddings(bs, num_hard_neg, outputs.last_hidden_state, batch)
            else:
                return extract_mean_pooling_embeddings(bs, num_hard_neg, outputs.last_hidden_state, batch)
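
Whichever pooling branch is taken, `forward` returns the same structure. A shape sketch with illustrative sizes (`bs`, `num_hard_neg`, and `hidden` are placeholders):

```python
import torch

# Illustrative shapes of the dict returned by F2LLM.forward:
bs, num_hard_neg, hidden = 4, 7, 1024
features = {
    "query_passage_features": torch.zeros(bs, 1, hidden),
    "passage_passage_features": torch.zeros(bs, 1, hidden),
    "negative_passage_features": torch.zeros(bs, num_hard_neg, hidden),  # None when num_hard_neg == 0
}
```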
123 changes: 123 additions & 0 deletions F2LLM/tokenize_data.py
@@ -0,0 +1,123 @@
from multiprocessing import Pool
from functools import partial
import numpy as np
import pandas as pd
import os
import argparse
from transformers import AutoTokenizer
from tqdm.auto import tqdm


def get_tokenizer_config(tokenizer_type):
    if tokenizer_type == 'bert':
        return {
            'model_path': 'models/bert_multilingual',
            'max_seq_length': 512,
            'add_special_tokens': True,
            'add_eos_token': False,
            'output_dir': 'data_tokenized_bert'
        }
    elif tokenizer_type == 'qwen':
        return {
            'model_path': 'models/qwen3-0.6b',
            'max_seq_length': 1023,
            'add_special_tokens': False,
            'add_eos_token': True,
            'output_dir': 'data_tokenized_qwen'
        }
    else:
        raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")


def initialize_tokenizer(config):
    tokenizer = AutoTokenizer.from_pretrained(config['model_path'])
    max_seq_length = config['max_seq_length']
    return tokenizer, max_seq_length


def process_sent(sentence, tokenizer, max_seq_length, config):
    # Process a single sentence with the specified tokenizer and configuration
    tokenizer_outputs = tokenizer(
        sentence,
        max_length=max_seq_length,
        truncation=True,
        add_special_tokens=config['add_special_tokens']
    )

    input_ids = tokenizer_outputs.input_ids

    # Add EOS token for Qwen if required
    if config['add_eos_token']:
        input_ids = input_ids + [tokenizer.eos_token_id]

    return np.array(input_ids)
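
A usage sketch for `process_sent`, assuming the model paths in `get_tokenizer_config` have already been downloaded locally:

```python
# Sketch only: requires models/qwen3-0.6b to exist locally.
config = get_tokenizer_config('qwen')
tokenizer, max_seq_length = initialize_tokenizer(config)
ids = process_sent("What is a dense embedding model?", tokenizer, max_seq_length, config)
# For the 'qwen' config: truncation to 1023 tokens, then the EOS id is appended,
# so len(ids) <= 1024 and ids[-1] == tokenizer.eos_token_id.
```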


def process_sent_batch(s, tokenizer, max_seq_length, config):
    return s.apply(lambda x: process_sent(x, tokenizer, max_seq_length, config))

def parallelize(data, tokenizer, max_seq_length, config, num_of_processes=8):
    indices = np.array_split(data.index, num_of_processes)
    data_split = [data.iloc[idx] for idx in indices]
    with Pool(num_of_processes) as pool:
        data = pd.concat(pool.map(partial(process_sent_batch, tokenizer=tokenizer, max_seq_length=max_seq_length, config=config), data_split))
    return data


def tokenize_dataset(tokenizer_type, root_dir='training_data'):
    config = get_tokenizer_config(tokenizer_type)
    tokenizer, max_seq_length = initialize_tokenizer(config)

    output_dir = config['output_dir']
    os.makedirs(output_dir, exist_ok=True)

    for ds_name in tqdm(sorted(os.listdir(root_dir))):
        print(ds_name, flush=True)

        df = pd.read_parquet(f"{root_dir}/{ds_name}")
        df['query_input_ids'] = parallelize(
            df['query'],
            tokenizer,
            max_seq_length,
            config,
            62
        )

        num_neg = 24 if 'negative_2' in df.keys() else 1

        ls = df.passage.to_list()
        for i in range(1, num_neg+1):
            ls += df[f'negative_{i}'].to_list()
        ls = list(set(ls))
        df_tmp = pd.DataFrame({'text': ls})
        df_tmp['input_ids'] = parallelize(
            df_tmp['text'],
            tokenizer,
            max_seq_length,
            config,
            62
        )
        df_tmp = df_tmp.set_index('text')

        df['passage_input_ids'] = df.passage.map(df_tmp.input_ids)

        for i in range(1, num_neg+1):
            df[f'negative_{i}_input_ids'] = df[f'negative_{i}'].map(df_tmp.input_ids)

        df.to_parquet(f'{output_dir}/{ds_name}', index=False)


def main():
    parser = argparse.ArgumentParser(description='Tokenize datasets using BERT or Qwen tokenizer')
    parser.add_argument('--tokenizer', type=str, choices=['bert', 'qwen'], default='bert',
                        help='Tokenizer type to use (default: bert)')
    parser.add_argument('--data_dir', type=str, default='training_data',
                        help='Directory containing the training data')

    args = parser.parse_args()

    tokenize_dataset(args.tokenizer, args.data_dir)


if __name__ == '__main__':
    main()
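
The same entry point can also be driven programmatically (for example from a notebook); output directories follow `get_tokenizer_config`:

```python
# Equivalent to: python tokenize_data.py --tokenizer qwen --data_dir training_data
from tokenize_data import tokenize_dataset

tokenize_dataset('qwen', root_dir='training_data')   # writes to data_tokenized_qwen/
tokenize_dataset('bert', root_dir='training_data')   # writes to data_tokenized_bert/
```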
71 changes: 70 additions & 1 deletion F2LLM/utils.py
@@ -5,9 +5,12 @@
from torch.nn import CrossEntropyLoss
import os

from transformers import AutoConfig

CLASSIFICATION_DATASETS = ['amazon_counterfactual', 'amazon_polarity', 'imdb', 'toxic_conversations', 'cola']
CLUSTERING_DATASETS = ['amazon_reviews', 'banking77', 'emotion', 'mtop_intent', 'mtop_domain', 'massive_scenario', 'massive_intent', 'tweet_sentiment_extraction', 'arxiv_clustering_p2p', 'arxiv_clustering_s2s', 'biorxiv_clustering_p2p', 'biorxiv_clustering_s2s', 'medrxiv_clustering_p2p', 'medrxiv_clustering_s2s', 'reddit_clustering_p2p', 'reddit_clustering_s2s', 'stackexchange_clustering_p2p', 'stackexchange_clustering_s2s', 'twentynewsgroups']
RETRIEVAL_DATASETS = ['arguana', 'snli', 'mnli', 'anli', 'paq', 'squad', 'stackexchange', 'msmarco', 'natural_questions', 'hotpotqa', 'fever', 'eli5', 'fiqa', 'bioasq', 'nfcorpus', 'miracl', 'mrtidy', 'scifact', 'qqp', 'stackoverflowdupquestions', 'sts12', 'sts22', 'stsbenchmark', 'amazon_qa', 'cnn_dm', 'coliee', 'paq_part2', 'pubmedqa', 's2orc_abstract_citation', 's2orc_title_abstract', 's2orc_title_citation', 'sentence_compression', 'specter', 'triviaqa', 'xsum', 'stackexchange_part2', 'stackexchangedupquestions_s2s', 'stackexchangedupquestions_p2p']
ENCODER_ONLY_INDICATORS = ['bert', 'electra', 'mpnet']


def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps):
@@ -224,4 +227,70 @@ def accelerate_train(args,
model.lm.train()

if summary_writer:
summary_writer.close()
summary_writer.close()


def detect_model_type(model_path):
    try:
        # Loading the config only serves as a sanity check that the path points to a valid model.
        AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    except Exception:
        # If we can't load the config, default to decoder-only for backward compatibility
        return 'decoder_only'

    model_name = model_path.split('/')[-1].lower()

    if any(indicator in model_name for indicator in ENCODER_ONLY_INDICATORS):
        return 'encoder_only'

    return 'decoder_only'
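
Expected behaviour of `detect_model_type`, assuming both model directories exist locally so the config load succeeds (otherwise it falls back to `'decoder_only'`):

```python
# Illustrative only; the paths come from the configs above and must exist locally.
assert detect_model_type('models/bert_multilingual') == 'encoder_only'  # 'bert' is an encoder-only indicator
assert detect_model_type('models/qwen3-0.6b') == 'decoder_only'         # no indicator in the folder name
```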


def extract_cls_embeddings(batch_size, num_hard_neg, last_hidden_state, batch):
    features = {}
    features['query_passage_features'] = last_hidden_state[0:batch_size, 0, :].unsqueeze(1)
    features['passage_passage_features'] = last_hidden_state[batch_size:2*batch_size, 0, :].unsqueeze(1)
    features['negative_passage_features'] = (
        last_hidden_state[2*batch_size:, 0, :].view(batch_size, num_hard_neg, -1)
        if num_hard_neg > 0 else None
    )
    return features


def extract_mean_pooling_embeddings(batch_size, num_hard_neg, last_hidden_state, batch):
    # Apply mean pooling
    attention_mask = batch['attention_mask']
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    mean_pooled = sum_embeddings / sum_mask

    # Extract features
    features = {}
    features['query_passage_features'] = mean_pooled[0:batch_size, :].unsqueeze(1)
    features['passage_passage_features'] = mean_pooled[batch_size:2*batch_size, :].unsqueeze(1)
    features['negative_passage_features'] = (
        mean_pooled[2*batch_size:, :].view(batch_size, num_hard_neg, -1)
        if num_hard_neg > 0 else None
    )

    return features


def extract_last_token_embeddings(batch_size, num_hard_neg, last_hidden_state, batch):
    # Extract features using the last token for each sequence
    features = {}
    features['query_passage_features'] = extract_last_token_features(last_hidden_state, batch, 0, batch_size)
    features['passage_passage_features'] = extract_last_token_features(last_hidden_state, batch, batch_size, 2 * batch_size)
    features['negative_passage_features'] = (
        extract_last_token_features(last_hidden_state, batch, 2 * batch_size, len(batch['seq_lens'])).view(batch_size, num_hard_neg, -1)
        if num_hard_neg != 0 else None
    )

    return features


def extract_last_token_features(hidden_states, batch, start_idx, end_idx):
    return torch.stack([
        hidden_states[i, [batch['seq_lens'][i] - 1]]
        for i in range(start_idx, end_idx)
    ])
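
A quick shape check over the three extraction helpers with synthetic tensors, run from the `F2LLM` directory so `utils` is importable; the sizes are arbitrary and `seq_lens` is given as a plain list, matching how the helpers index it:

```python
import torch
from utils import (extract_cls_embeddings, extract_mean_pooling_embeddings,
                   extract_last_token_embeddings)

bs, num_hard_neg, seq_len, hidden = 2, 3, 16, 8
total = bs * (2 + num_hard_neg)                # queries + passages + hard negatives
hidden_states = torch.randn(total, seq_len, hidden)
batch = {
    'attention_mask': torch.ones(total, seq_len, dtype=torch.long),
    'seq_lens': [seq_len] * total,
}

for fn in (extract_cls_embeddings, extract_mean_pooling_embeddings, extract_last_token_embeddings):
    out = fn(bs, num_hard_neg, hidden_states, batch)
    assert out['query_passage_features'].shape == (bs, 1, hidden)
    assert out['passage_passage_features'].shape == (bs, 1, hidden)
    assert out['negative_passage_features'].shape == (bs, num_hard_neg, hidden)
```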