4 changes: 3 additions & 1 deletion lightllm/common/basemodel/basemodel.py
@@ -908,7 +908,9 @@ def _init_padded_req(self):
def _gen_special_model_input(self, token_num: int):
special_model_input = {}

is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__)
is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str(
self.__class__
)
if is_deepseekv3_mtp_draft_model:
Comment on lines +911 to 914
medium

The variable is_deepseekv3_mtp_draft_model is now misleading as it also checks for Qwen3MOEMTPModel. It should be renamed to something more generic, like is_mtp_draft_model, to accurately reflect its purpose.

Suggested change
is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str(
self.__class__
)
if is_deepseekv3_mtp_draft_model:
is_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str(
self.__class__
)
if is_mtp_draft_model:

special_model_input["deepseekv3_mtp_draft_input_hiddens"] = torch.randn(
token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda"
15 changes: 11 additions & 4 deletions lightllm/models/llama/flashattention_infer_struct.py
@@ -38,22 +38,29 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor):
self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
max_seq_len_k = self.max_kv_seq_len
args_mtp_step = get_env_start_args().mtp_step
att_batch_size = self.batch_size // (args_mtp_step + 1)
if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
page_buffer = FlashAttentionStateInfo.get_page_table_buffer(
model.graph_max_batch_size, model.graph_max_len_in_batch
)
self.page_table = page_buffer[self.microbatch_index][
: self.batch_size * model.graph_max_len_in_batch
].reshape(self.batch_size, model.graph_max_len_in_batch)
: att_batch_size * model.graph_max_len_in_batch
].reshape(att_batch_size, model.graph_max_len_in_batch)
else:
self.page_table = torch.empty(
(self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
(att_batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
)

page_table_copy(
page_table=self.page_table[:, :max_seq_len_k],
req_to_token_indexs=model.req_manager.req_to_token_indexs,
b_req_idx=self.b_req_idx,
b_req_idx=self.b_req_idx[args_mtp_step :: (args_mtp_step + 1)],
)
if args_mtp_step > 0:
self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()
else:
self.b_att_seq_len = self.b_seq_len
Comment on lines +60 to +63
medium

This if/else block can be simplified. The slicing logic [args_mtp_step :: (args_mtp_step + 1)] works correctly for args_mtp_step = 0 as well, where it becomes [0::1]. The .contiguous() call is necessary for slices with a step greater than 1 and is harmless for already contiguous tensors, so it can be applied unconditionally.

            self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()
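
A quick standalone check of that claim (illustrative only, not part of the diff): step-1 slicing returns the tensor unchanged, and .contiguous() is a no-op on an already contiguous tensor while copying strided slices into contiguous memory.

import torch

b_seq_len = torch.arange(8, dtype=torch.int32)
for mtp_step in (0, 2):
    sliced = b_seq_len[mtp_step :: (mtp_step + 1)].contiguous()
    # mtp_step == 0 -> [0::1]: every element kept in order (identity slice, no copy)
    # mtp_step == 2 -> [2::3]: keeps indices 2, 5 and copies them into contiguous memory
    print(mtp_step, sliced.tolist(), sliced.is_contiguous())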


if "offline_calibration_fp8kv" in model.mode:
if self.is_prefill:
4 changes: 2 additions & 2 deletions lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -883,10 +883,10 @@ def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionS
k_cache=cache_k,
v_cache=cache_v,
page_table=infer_state.page_table,
cache_seqlens=infer_state.b_seq_len,
cache_seqlens=infer_state.b_att_seq_len,
cu_seqlens_q=infer_state.cu_seqlens_q,
cu_seqlens_k_new=infer_state.cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_q=infer_state.max_q_seq_len,
softmax_scale=sm_scale,
causal=True,
window_size=(-1, -1),
11 changes: 10 additions & 1 deletion lightllm/models/qwen2/model.py
@@ -3,6 +3,7 @@
from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight
from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.common.mem_utils import select_mem_manager_class
from lightllm.utils.envs_utils import get_env_start_args


@ModelRegistry("qwen2")
@@ -41,12 +42,20 @@ def _init_mem_manager(self):
head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"]
head_dim_ = self.config.get("head_dim", head_dim_)
tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)

# In MTP mode, the mem manager must reserve extra layers for the draft model
added_mtp_layer_num = 0
if get_env_start_args().mtp_mode == "deepseekv3_eagle":
added_mtp_layer_num += 1
elif get_env_start_args().mtp_mode == "deepseekv3_vanilla":
added_mtp_layer_num += get_env_start_args().mtp_step
Comment on lines +48 to +51
high

This logic for calculating added_mtp_layer_num is specific to deepseekv3 MTP modes, but it's located in the qwen2 model file. This creates a tight and incorrect coupling. When running with a Qwen MTP model, this logic will fail to calculate the correct number of extra layers for the memory manager, potentially leading to insufficient memory allocation and runtime errors. This logic should be generalized or moved to a more appropriate location, such as a base MTP model or handled during model initialization based on the specific draft model's configuration.
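
One way the suggested generalization could look (a hypothetical helper, not code from this PR; the mode-name suffixes and the empty-mode default are assumptions): derive the extra layer count from the start args in one shared place and reuse it from each base model's _init_mem_manager.

from lightllm.utils.envs_utils import get_env_start_args

def get_added_mtp_layer_num() -> int:
    # Hypothetical shared helper: map the configured MTP mode to the number of
    # extra KV-cache layers the draft model needs from the mem manager.
    args = get_env_start_args()
    if not args.mtp_mode:
        return 0
    if args.mtp_mode.endswith("_eagle"):
        return 1  # eagle-style draft adds a single extra layer
    if args.mtp_mode.endswith("_vanilla"):
        return args.mtp_step  # one extra layer per speculative step
    raise ValueError(f"unknown mtp_mode: {args.mtp_mode}")

With a helper like that, the qwen2 _init_mem_manager above would only need layer_num=self.config["num_hidden_layers"] + get_added_mtp_layer_num().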


self.mem_manager = select_mem_manager_class(self.mode)(
self.max_total_token_num,
dtype=self.data_type,
head_num=tp_k_head_num_,
head_dim=head_dim_,
layer_num=self.config["num_hidden_layers"],
layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num,
mem_fraction=self.mem_fraction,
)
return
Empty file.
Empty file.
lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,27 @@
import numpy as np
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight


class Qwen3MOEMTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
# Shared with the Qwen3MOE main model
self.wte_weight_ = None
self.lm_head_weight_ = None
return

def load_hf_weights(self, weights):
if "model.layers.0.proj.weight" in weights:
self.eh_proj_weight_ = self._cuda(weights["model.layers.0.proj.weight"]).t()
if "model.layers.0.norm_after_embedding.weight" in weights:
self.enorm_weight_ = self._cuda(weights["model.layers.0.norm_after_embedding.weight"])
if "model.layers.0.norm_before_output.weight" in weights:
self.hnorm_weight_ = self._cuda(weights["model.layers.0.norm_before_output.weight"])
return

def verify_load(self):
errors = "weights load not ok"
weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_]
for i in range(len(weights)):
assert weights[i] is not None, "index:" + str(i) + " " + errors
Comment on lines +24 to +26
medium

The assertion message in verify_load is not very descriptive. For better debuggability, it would be helpful to include the name of the weight that failed to load, rather than just its index.

Suggested change
weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_]
for i in range(len(weights)):
assert weights[i] is not None, "index:" + str(i) + " " + errors
weight_names = ["eh_proj_weight_", "enorm_weight_", "hnorm_weight_"]
weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_]
for i, name in enumerate(weight_names):
assert weights[i] is not None, f"{name} {errors}"

return
47 changes: 47 additions & 0 deletions lightllm/models/qwen3_moe_mtp/model.py
@@ -0,0 +1,47 @@
from lightllm.models.qwen3_moe.model import Qwen3MOEModel
from lightllm.models.qwen3_moe_mtp.layer_weights.pre_and_post_layer_weight import Qwen3MOEMTPPreAndPostLayerWeight
from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer
from lightllm.common.basemodel import TpPartBaseModel


class Qwen3MOEMTPModel(Qwen3MOEModel):

pre_and_post_weight_class = Qwen3MOEMTPPreAndPostLayerWeight
pre_layer_infer_class = Deepseek3MTPPreLayerInfer
medium

Reusing Deepseek3MTPPreLayerInfer for a Qwen model might be confusing due to its specific name. If this class contains generic MTP pre-layer inference logic, consider renaming it to something more abstract (e.g., MTPPreLayerInfer) to improve clarity and maintainability.
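
If a rename is adopted, a minimal transitional step (a sketch, not part of this PR; the alias name is hypothetical) is to expose the shared logic under a model-agnostic name while keeping the old name importable:

from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer

# Hypothetical transitional alias: generic name for the shared MTP pre-layer
# inference logic; existing imports of the old name keep working.
MTPPreLayerInfer = Deepseek3MTPPreLayerInfer

Qwen3MOEMTPModel (and eventually Deepseek3MTPModel) could then set pre_layer_infer_class = MTPPreLayerInfer.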


def __init__(self, kvargs: dict):
self._pre_init(kvargs)
super().__init__(kvargs)
return

def _pre_init(self, kvargs: dict):
self.main_model: TpPartBaseModel = kvargs.pop("main_model")
self.mem_layer_start = kvargs.pop("mem_layer_start", 0)
return

def _init_custom(self):
self._cos_cached = self.main_model._cos_cached
self._sin_cached = self.main_model._sin_cached
return

def _init_req_manager(self):
self.req_manager = self.main_model.req_manager
return

def _init_mem_manager(self):
self.mem_manager = self.main_model.mem_manager
return

def _init_weights(self):
super()._init_weights()
self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_
self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_
self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_
return

def _init_infer_layer(self):
super()._init_infer_layer()
# reset the layer_num_ of the self.layers_infer
for layer in self.layers_infer:
layer.layer_num_ = layer.layer_num_ + self.mem_layer_start
return
10 changes: 7 additions & 3 deletions lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -35,6 +35,7 @@
from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel
from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
@@ -281,9 +282,12 @@ def init_mtp_draft_model(self, main_kvargs: dict):
}

mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir)
assert mtp_model_cfg["model_type"] == "deepseek_v3"
assert mtp_model_cfg["architectures"][0] == "DeepseekV3ForCausalLMNextN"
self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs))
if mtp_model_cfg["model_type"] == "deepseekv3":
self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs))
elif mtp_model_cfg["model_type"] == "qwen3_moe":
self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs))
else:
assert False, f"error mtp mode {mtp_model_cfg['model_type']}"

self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}")
return