This repository was archived by the owner on Nov 19, 2025. It is now read-only.
Draft
23 changes: 0 additions & 23 deletions Dockerfile
@@ -121,26 +121,3 @@ RUN cd /opt/NeMo-Aligner && \

RUN cd TensorRT-LLM && patch -p1 < ../NeMo-Aligner/setup/trtllm.patch

# TODO(terryk): This layer should be deleted ASAP after NeMo is bumped to include all of these PRs
RUN <<"EOF" bash -exu
cd NeMo
# Ensures we don't cherry-pick "future" origin/main commits
git fetch -a
# 0c92fe17df4642ffc33d5d8c0c83fda729e3910c: [fix] Ensures disabling exp_manager with exp_manager=null does not error NeMo#10651
# 60e677423667c029dd05875da72bf0719774f844: [feat] Update get_model_parallel_src_rank to support tp-pp-dp ordering NeMo#10652
# 0deaf6716cb4f20766c995ce25d129795f1ae200: fix[export]: update API for disabling device reassignment in TRTLLM for Aligner NeMo#10863
# (superseded by 10863) 148543d6e9c66ff1f8562e84484448202249811d: feat: Migrate GPTSession refit path in Nemo export to ModelRunner for Aligner NeMo#10654
for pr_and_commit in \
"10651 0c92fe17df4642ffc33d5d8c0c83fda729e3910c" \
"10652 60e677423667c029dd05875da72bf0719774f844" \
"10863 0deaf6716cb4f20766c995ce25d129795f1ae200" \
; do
pr=$(cut -f1 -d' ' <<<"$pr_and_commit")
head_pr_commit=$(cut -f2 -d' ' <<<"$pr_and_commit")
git fetch origin $head_pr_commit:PR-${pr}
# cherry-picks all commits between main and the top of the PR
git cherry-pick --allow-empty $(git merge-base origin/main PR-${pr})..PR-${pr}
# Tag cherry-picks to help
git tag cherry-pick-PR-${pr}
done
EOF
1 change: 1 addition & 0 deletions examples/nlp/gpt/conf/gpt_dpo.yaml
@@ -57,6 +57,7 @@ model:
micro_batch_size: 1
global_batch_size: 64
megatron_amp_O2: True
mamba_hybrid: False

dpo:
# This default value ensures there are no numeric differences between trained and reference policies when computing log probs.
2 changes: 1 addition & 1 deletion examples/nlp/gpt/conf/gpt_sft.yaml
@@ -191,7 +191,7 @@ model:
output_original_text: True # needed for the proper metrics support

optim:
name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
lr: 3e-5
weight_decay: 0.01
betas:
4 changes: 2 additions & 2 deletions examples/nlp/gpt/train_gpt_dpo.py
@@ -21,7 +21,7 @@
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate
from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets, identity_collate
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel, MegatronMambaDPOModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
@@ -53,7 +53,7 @@ def main(cfg) -> None:
logger = CustomLoggerWrapper(trainer.loggers)

ptl_model = load_from_nemo(
MegatronGPTDPOModel,
MegatronMambaDPOModel if cfg.model.mamba_hybrid else MegatronGPTDPOModel,
cfg.model,
trainer,
strict=True,
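The mamba_hybrid: False default added to examples/nlp/gpt/conf/gpt_dpo.yaml above is what makes the bare attribute access cfg.model.mamba_hybrid in this script safe. The selection itself is a single conditional; the following is a minimal sketch of that pattern with a toy config standing in for the real Hydra config (illustrative only):

from omegaconf import OmegaConf

# Imports mirror the change above; a NeMo-Aligner install is assumed.
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import (
    MegatronGPTDPOModel,
    MegatronMambaDPOModel,
)

# Toy stand-in for the Hydra-resolved config; the False default comes from gpt_dpo.yaml.
cfg = OmegaConf.create({"model": {"mamba_hybrid": False}})

# Same conditional as in main(): the Mamba class is used only when the flag is set.
model_cls = MegatronMambaDPOModel if cfg.model.mamba_hybrid else MegatronGPTDPOModel
print(model_cls.__name__)  # -> MegatronGPTDPOModel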
83 changes: 9 additions & 74 deletions examples/nlp/gpt/train_gpt_sft.py
@@ -27,7 +27,7 @@
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.supervised import SupervisedTrainer
from nemo_aligner.data.nlp.builders import build_dataloader, build_sft_dataset
from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel
from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel, MambaSFTModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
@@ -39,7 +39,7 @@
resolve_and_create_trainer,
retrieve_custom_trainer_state_dict,
)
from nemo_aligner.utils.utils import load_from_nemo
from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo

"""Script to start SFT training"""

@@ -49,75 +49,10 @@
mp.set_start_method("spawn", force=True)


def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
"""
This function modifies the original gpt pre-training config (gpt_cfg) with attributes from the finetuning config (cfg).
The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
"""
OmegaConf.set_struct(gpt_cfg, True)
OmegaConf.resolve(cfg)
with open_dict(gpt_cfg):
gpt_cfg.megatron_amp_O2 = cfg.model.get("megatron_amp_O2", False)
gpt_cfg.micro_batch_size = cfg.model.data.train_ds.micro_batch_size
gpt_cfg.global_batch_size = cfg.model.data.train_ds.global_batch_size
gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False)
gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None)
gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None)
gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None)
gpt_cfg.activations_checkpoint_layers_per_pipeline = cfg.model.get(
"activations_checkpoint_layers_per_pipeline", None
)
gpt_cfg.peft = cfg.model.peft
gpt_cfg.data = cfg.model.data
gpt_cfg.optim = cfg.model.optim
gpt_cfg.precision = cfg.trainer.precision
gpt_cfg.answer_only_loss = cfg.model.answer_only_loss
gpt_cfg.restore_from_path = cfg.model.restore_from_path
gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint
gpt_cfg.save_nemo_on_validation_end = cfg.model.save_nemo_on_validation_end
gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view
gpt_cfg.hidden_dropout = cfg.model.get("hidden_dropout", 0.0)
gpt_cfg.attention_dropout = cfg.model.get("attention_dropout", 0.0)
gpt_cfg.ffn_dropout = cfg.model.ffn_dropout
gpt_cfg.use_flash_attention = cfg.model.get("use_flash_attention", False)
# if TP/PP size is -1, keep the original model's TP/PP size
if cfg.model.get("tensor_model_parallel_size", 1) > 0:
gpt_cfg.tensor_model_parallel_size = cfg.model.get("tensor_model_parallel_size", 1)
if cfg.model.get("pipeline_model_parallel_size", 1) > 0:
gpt_cfg.pipeline_model_parallel_size = cfg.model.get("pipeline_model_parallel_size", 1)
gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get("pipeline_model_parallel_split_rank", 0)

if cfg.model.data.get("chat", False):
# chat model, overwrite the prompt template
prompt_template = get_prompt_template_example(cfg.model.data.chat_prompt_tokens)
gpt_cfg.data.train_ds.prompt_template = prompt_template
gpt_cfg.data.validation_ds.prompt_template = prompt_template

sft_cls = GPTSFTModel
gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"

if cfg.model.get("use_flash_attention", None) is not None:
gpt_cfg.use_flash_attention = cfg.model.use_flash_attention

if cfg.model.get("seq_len_interpolation_factor", None) is not None:
gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor

if cfg.model.get("dist_ckpt_load_strictness", None) is not None:
gpt_cfg.dist_ckpt_load_strictness = cfg.model.dist_ckpt_load_strictness

gpt_cfg.inference = cfg.model.get("inference", {})

# This is needed when modifying a hparam file directly to load `.ckpt` files.
# This is not needed to modify the cfg in `.nemo` files.
if add_cfg_to_tree:
OmegaConf.resolve(gpt_cfg)
gpt_cfg.cfg = gpt_cfg

return gpt_cfg


@hydra_runner(config_path="conf", config_name="gpt_sft")
def main(cfg) -> None:
cfg.model = load_and_override_model_config(cfg.model.restore_from_path, cfg.model)

logging.info("\n\n************** Experiment configuration ***********")
logging.info(f"\n{OmegaConf.to_yaml(cfg)}")

@@ -129,17 +64,15 @@ def main(cfg) -> None:
with open_dict(cfg):
cfg.model.precision = cfg.trainer.precision

ptl_model, updated_cfg = load_from_nemo(
GPTSFTModel,
ptl_model = load_from_nemo(
MambaSFTModel if cfg.model.get("mamba_hybrid", False) else GPTSFTModel,
cfg,
trainer,
strict=True,
modify_config_fn=_modify_config,
restore_path=cfg.model.restore_from_path,
return_updated_cfg=True,
)

init_peft(ptl_model, updated_cfg)
init_peft(ptl_model, cfg.model)

with open_dict(cfg):
# overwrite the model config with the config from the checkpoint
@@ -173,6 +106,7 @@ def main(cfg) -> None:
train_data_cfg,
ptl_model.tokenizer,
num_samples,
is_mamba=cfg.model.get("mamba_hybrid", False),
answer_only_loss=True,
is_chat=cfg.model.data.chat,
special_tokens=cfg.model.data.chat_prompt_tokens,
@@ -185,6 +119,7 @@
val_data_cfg,
ptl_model.tokenizer,
num_samples,
is_mamba=cfg.model.get("mamba_hybrid", False),
answer_only_loss=True,
is_chat=cfg.model.data.chat,
special_tokens=cfg.model.data.chat_prompt_tokens,
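The hand-written _modify_config helper above is replaced by load_and_override_model_config from nemo_aligner.utils.utils. Its body is not shown in this diff; conceptually it restores the model config stored with the .nemo checkpoint and overlays the fine-tuning overrides, which is why load_from_nemo no longer needs modify_config_fn or return_updated_cfg and init_peft can now take cfg.model directly. A rough sketch of the merge step only, using hypothetical toy configs (this is not the NeMo-Aligner implementation):

from omegaconf import OmegaConf

def merge_model_config(base_model_cfg, override_model_cfg):
    # Sketch: later values win, mirroring how a config restored from a checkpoint
    # would be overridden by the fine-tuning YAML.
    return OmegaConf.merge(base_model_cfg, override_model_cfg)

base = OmegaConf.create({"micro_batch_size": 4, "optim": {"name": "distributed_fused_adam"}})
overrides = OmegaConf.create({"micro_batch_size": 1, "optim": {"name": "fused_adam"}})
print(merge_model_config(base, overrides).optim.name)  # -> fused_adam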
6 changes: 5 additions & 1 deletion nemo_aligner/data/nlp/builders.py
@@ -379,7 +379,9 @@ def build_dataset(index, name):
)


def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, is_chat=True, special_tokens=None):
def build_sft_dataset(
data_cfg, tokenizer, num_samples, is_mamba=False, answer_only_loss=True, is_chat=True, special_tokens=None
):
packed_sequence = data_cfg.get("packed_sequence", False)
dataset_kwargs = {}

@@ -411,9 +413,11 @@ def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, i
answer_only_loss=answer_only_loss,
truncation_field=data_cfg.get("truncation_field", "text"),
pad_to_max_length=data_cfg.get("pad_to_max_length", False),
pad_seq_length_to_mult=256 if is_mamba else 16,
index_mapping_dir=data_cfg.get("index_mapping_dir", None),
prompt_template=data_cfg.get("prompt_template", None),
virtual_tokens=0,
meta_tokens=data_cfg.get("meta_tokens", 0),
memmap_workers=data_cfg.get(
"memmap_workers", None
), # used to set num. of workers to create the memmap index files
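build_sft_dataset now accepts an is_mamba flag and passes pad_seq_length_to_mult=256 instead of 16 when it is set, presumably to satisfy a coarser sequence-length alignment requirement of the Mamba/hybrid kernels. The arithmetic that parameter asks of the underlying dataset is a round-up to the next multiple; a small illustrative helper (not part of NeMo):

import math

def pad_seq_length(seq_len: int, multiple: int) -> int:
    # Round the sequence length up to the next multiple.
    return int(math.ceil(seq_len / multiple)) * multiple

print(pad_seq_length(1000, 16))   # 1008 with the GPT default
print(pad_seq_length(1000, 256))  # 1024 when is_mamba=True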
103 changes: 95 additions & 8 deletions nemo_aligner/data/nlp/datasets.py
@@ -15,13 +15,19 @@
"""Custom datasets for RLHF training"""

import os
from typing import Dict, List

import numpy as np
import scipy
import torch
from omegaconf import OmegaConf

from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import _create_ltor_masks_and_position_ids
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import (
GPTSFTChatDataset,
_get_header_conversation_type_mask_role,
get_prompt_template_example,
)
from nemo.core import Dataset
from nemo.utils import logging

@@ -344,16 +350,97 @@ def encode(self, text, append_eod=False):

return text_ids, len(text_ids)

@staticmethod
def _convert_messages(
input_list: List[Dict[str, str]]
) -> Dict: # TODO: (@adithyare) this method should live elsewhere.
"""
args:
input_list: is a list of dicts in the openai format
for example:
[{"role": "system", "content": "you are helpful},
{"role": "user", "content": "Why is the sky blue?"},
{"role": "assistant", "content": "Because blablabla"},
...]
returns:
output_dict: a dict in nemo's format {"system": "system prompt",
"conversations": [],
...
}
"""
output_dict = {
"system": "",
"conversations": [],
"mask": "User",
"type": "VALUE_TO_TEXT",
}

# Extract the system message
num_system_msg = 0
for msg in input_list:
if msg["role"] == "system":
output_dict["system"] = msg["content"]
num_system_msg += 1
if num_system_msg > 1:
raise RuntimeError("Multiple system messages seen, please consolidate into a single system message.")

# Build the conversations list
for msg in input_list:
if msg["role"] != "system":
conversation_entry = {
"from": msg["role"].capitalize(), # Capitalize 'user' and 'assistant'
"value": msg["content"],
"label": None,
}
output_dict["conversations"].append(conversation_entry)

return output_dict

def convert(self, messages):
"""
args:
messages: is a list of dicts in the openai format
for example:
[{"role": "system", "content": "you are helpful},
{"role": "user", "content": "Why is the sky blue?"},
{"role": "assistant", "content": "Because blablabla"},
...]
returns:
conversation: is a string formatted with the chat template
"""
if OmegaConf.select(self.cfg, "data.chat_prompt_tokens") is None:
raise RuntimeError(
"You don't have a model (model_config.yaml) which has chat_prompt_tokens, are you sure this is a Chat/Instruction model?"
)
special_tokens = self.cfg.data.chat_prompt_tokens
nemo_source = self._convert_messages(messages)
header, conversation, data_type, mask_role = _get_header_conversation_type_mask_role(
nemo_source, special_tokens
)
return conversation

def __getitem__(self, idx):
"""Returns a pair of chosen/rejected pairs, their respective lengths, and labels."""
payload = self.data[idx]
prompt, prompt_len = self.encode(payload["prompt"], append_eod=False)
chosen, chosen_len = self.encode(
payload["prompt"] + payload["chosen_response"], append_eod=self.cfg.data.get("append_eod", False)
)
reject, reject_len = self.encode(
payload["prompt"] + payload["rejected_response"], append_eod=self.cfg.data.get("append_eod", False)
)

if isinstance(payload["prompt"], str):
# (@adithyare) format with hardcoded chat tokens
# will allow this for the time being.
prompt_fmtd = payload["prompt"]
chosen_fmtd = payload["prompt"] + payload["chosen_response"]
rejected_fmtd = payload["prompt"] + payload["rejected_response"]
logging.warning(
"Pre-formatting chat conversation as string with hardcoded chat tokens will be deprecated."
) # (@adithyare) this will spam the console for now.
else:
prompt_fmtd = self.convert(payload["prompt"]) # (@adithyare) read var as "prompt formatted"
chosen_fmtd = self.convert(payload["prompt"] + [payload["chosen_response"]])
rejected_fmtd = self.convert(payload["prompt"] + [payload["rejected_response"]])

prompt, prompt_len = self.encode(prompt_fmtd, append_eod=False)
chosen, chosen_len = self.encode(chosen_fmtd, append_eod=self.cfg.data.get("append_eod", False))
reject, reject_len = self.encode(rejected_fmtd, append_eod=self.cfg.data.get("append_eod", False))

# chosen_response_only, chosen_response_len = self.encode(payload['chosen_response'])
# reject_response_only, reject_response_len = self.encode(payload['rejected_response'])
chosen_labels = ([-100] * prompt_len) + chosen[prompt_len:]
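The new _convert_messages/convert pair lets the DPO dataset accept OpenAI-style message lists: _convert_messages maps them to NeMo's conversation dict, convert renders that dict with the model's chat_prompt_tokens via _get_header_conversation_type_mask_role, and pre-formatted string prompts still work in __getitem__ but now log a deprecation warning. A usage sketch based only on the code above; the enclosing class name is not visible in this hunk, so DPOModelDataset below is an assumption:

# Illustrative input/output for the static conversion helper added above.
# DPOModelDataset is assumed to be the class this diff modifies.
messages = [
    {"role": "system", "content": "you are helpful"},
    {"role": "user", "content": "Why is the sky blue?"},
    {"role": "assistant", "content": "Because blablabla"},
]

nemo_source = DPOModelDataset._convert_messages(messages)
# nemo_source == {
#     "system": "you are helpful",
#     "conversations": [
#         {"from": "User", "value": "Why is the sky blue?", "label": None},
#         {"from": "Assistant", "value": "Because blablabla", "label": None},
#     ],
#     "mask": "User",
#     "type": "VALUE_TO_TEXT",
# }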