From ed69449ffb4215affa87597fe282dbc54e8ba512 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 1 Dec 2025 12:40:30 +0000
Subject: [PATCH 1/8] qwen3_moe mtp

---
 lightllm/models/qwen3_moe_mtp/__init__.py     |  0
 .../qwen3_moe_mtp/layer_infer/__init__.py     |  0
 .../layer_infer/pre_layer_infer.py            | 60 +++++++++++++++++++
 .../qwen3_moe_mtp/layer_weights/__init__.py   |  0
 .../pre_and_post_layer_weight.py              | 29 +++++++++
 lightllm/models/qwen3_moe_mtp/model.py        | 47 +++++++++++++++
 .../model_infer/mode_backend/base_backend.py  | 10 +++-
 7 files changed, 143 insertions(+), 3 deletions(-)
 create mode 100644 lightllm/models/qwen3_moe_mtp/__init__.py
 create mode 100644 lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py
 create mode 100644 lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py
 create mode 100644 lightllm/models/qwen3_moe_mtp/layer_weights/__init__.py
 create mode 100644 lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
 create mode 100644 lightllm/models/qwen3_moe_mtp/model.py

diff --git a/lightllm/models/qwen3_moe_mtp/__init__.py b/lightllm/models/qwen3_moe_mtp/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py b/lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py
new file mode 100644
index 000000000..66a41da73
--- /dev/null
+++ b/lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py
@@ -0,0 +1,60 @@
+import torch
+
+from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight
+from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
+from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+
+
+class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer):
+    """ """
+
+    def __init__(self, network_config, mode):
+        super().__init__(network_config, mode)
+        self.eps_ = network_config["rms_norm_eps"]
+        self.hidden_size = network_config["hidden_size"]
+        return
+
+    def _mtp_context_forward(
+        self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
+    ):
+        tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
+        assert input_embdings.shape[0] == tgt_embdings.shape[0]
+        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
+        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)
+
+        cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)
+
+        ans_logics = self.alloc_tensor(
+            (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype
+        )
+        torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics)
+        return ans_logics
+
+    def _mtp_token_forward(
+        self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
+    ):
+        tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
+        assert input_embdings.shape[0] == tgt_embdings.shape[0]
+        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
+        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)
+
+        cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)
+
+        ans_logics = self.alloc_tensor(
+            (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype
+        )
+        torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics)
+        return ans_logics
+
+    def context_forward(
+        self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
+    ):
+        input_embdings = super().context_forward(input_ids, infer_state, layer_weight)
+        return self._mtp_context_forward(input_embdings, infer_state, layer_weight)
+
+    def token_forward(
+        self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
+    ):
+        input_embdings = super().token_forward(input_ids, infer_state, layer_weight)
+        return self._mtp_token_forward(input_embdings, infer_state, layer_weight)
diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/__init__.py b/lightllm/models/qwen3_moe_mtp/layer_weights/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 000000000..f5b805647
--- /dev/null
+++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,29 @@
+import numpy as np
+from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
+
+
+class Deepseek3MTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
+    def __init__(self, data_type, network_config, mode):
+        super().__init__(data_type, network_config, mode)
+        # shared with the DeepseekV3 model
+        self.wte_weight_ = None
+        self.lm_head_weight_ = None
+        return
+
+    def load_hf_weights(self, weights):
+        if "model.layers.0.eh_proj.weight" in weights:
+            self.eh_proj_weight_ = self._cuda(weights["model.layers.0.eh_proj.weight"]).t()
+        if "model.layers.0.enorm.weight" in weights:
+            self.enorm_weight_ = self._cuda(weights["model.layers.0.enorm.weight"])
+        if "model.layers.0.hnorm.weight" in weights:
+            self.hnorm_weight_ = self._cuda(weights["model.layers.0.hnorm.weight"])
+        if "model.layers.0.shared_head.norm.weight" in weights:
+            self.final_norm_weight_ = self._cuda(weights["model.layers.0.shared_head.norm.weight"])
+        return
+
+    def verify_load(self):
+        errors = "weights load not ok"
+        weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_, self.final_norm_weight_]
+        for i in range(len(weights)):
+            assert weights[i] is not None, "index:" + str(i) + " " + errors
+        return
diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py
new file mode 100644
index 000000000..4586db4be
--- /dev/null
+++ b/lightllm/models/qwen3_moe_mtp/model.py
@@ -0,0 +1,47 @@
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer
+from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight
+from lightllm.common.basemodel import TpPartBaseModel
+
+
+class Qwen3MOEMTPModel(Qwen3MOEModel):
+
+    pre_and_post_weight_class = Deepseek3MTPPreAndPostLayerWeight
+    pre_layer_infer_class = Deepseek3MTPPreLayerInfer
+
+    def __init__(self, kvargs: dict):
+        self._pre_init(kvargs)
+        super().__init__(kvargs)
+        return
+
+    def _pre_init(self, kvargs: dict):
+        self.main_model: TpPartBaseModel = kvargs.pop("main_model")
+        self.mem_layer_start = kvargs.pop("mem_layer_start", 0)
+        return
+
+    def _init_custom(self):
+        self._cos_cached = self.main_model._cos_cached
+        self._sin_cached = self.main_model._sin_cached
+        return
+
+    def _init_req_manager(self):
+        self.req_manager = self.main_model.req_manager
+        return
+
+    def _init_mem_manager(self):
+        self.mem_manager = self.main_model.mem_manager
+        return
+
+    def _init_weights(self):
+        super()._init_weights()
+        self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_
+        self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_
+        self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_
+        return
+
+    def _init_infer_layer(self):
+        super()._init_infer_layer()
+        # reset the layer_num_ of the self.layers_infer
+        for layer in self.layers_infer:
+            layer.layer_num_ = layer.layer_num_ + self.mem_layer_start
+        return
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 95f0c9951..da8e095ef 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -35,6 +35,7 @@
 from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
 from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
 from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
+from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel
 from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
 from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
 from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
@@ -281,9 +282,12 @@ def init_mtp_draft_model(self, main_kvargs: dict):
         }
 
         mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir)
-        assert mtp_model_cfg["model_type"] == "deepseek_v3"
-        assert mtp_model_cfg["architectures"][0] == "DeepseekV3ForCausalLMNextN"
-        self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs))
+        if mtp_model_cfg["model_type"] == "deepseek_v3":
+            self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs))
+        elif mtp_model_cfg["model_type"] == "qwen3_moe":
+            self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs))
+        else:
+            assert False, f"unsupported mtp draft model type: {mtp_model_cfg['model_type']}"
         self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}")
         return
 

From a4fb3ed9a31cc8f77b8c501bceb24999d2e8e6a1 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 1 Dec 2025 12:52:47 +0000
Subject: [PATCH 2/8] fix weight name

---
 lightllm/common/basemodel/basemodel.py        |  4 +-
 .../qwen3_moe_mtp/layer_infer/__init__.py     |  0
 .../layer_infer/pre_layer_infer.py            | 60 -------------------
 .../pre_and_post_layer_weight.py              | 18 +++---
 lightllm/models/qwen3_moe_mtp/model.py        |  4 +-
 5 files changed, 13 insertions(+), 73 deletions(-)
 delete mode 100644 lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py
 delete mode 100644 lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py

diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 4421ea0da..ffb1cb75a 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -908,7 +908,9 @@ def _init_padded_req(self):
 
     def _gen_special_model_input(self, token_num: int):
         special_model_input = {}
-        is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__)
+        is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str(
+            self.__class__
+        )
         if is_deepseekv3_mtp_draft_model:
             special_model_input["deepseekv3_mtp_draft_input_hiddens"] = torch.randn(
                 token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda"
diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py b/lightllm/models/qwen3_moe_mtp/layer_infer/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py
deleted file mode 100644
index 66a41da73..000000000
--- a/lightllm/models/qwen3_moe_mtp/layer_infer/pre_layer_infer.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import torch
-
-from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight
-from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
-from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
-
-
-class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer):
-    """ """
-
-    def __init__(self, network_config, mode):
-        super().__init__(network_config, mode)
-        self.eps_ = network_config["rms_norm_eps"]
-        self.hidden_size = network_config["hidden_size"]
-        return
-
-    def _mtp_context_forward(
-        self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
-    ):
-        tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
-        assert input_embdings.shape[0] == tgt_embdings.shape[0]
-        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
-        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)
-
-        cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)
-
-        ans_logics = self.alloc_tensor(
-            (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype
-        )
-        torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics)
-        return ans_logics
-
-    def _mtp_token_forward(
-        self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
-    ):
-        tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
-        assert input_embdings.shape[0] == tgt_embdings.shape[0]
-        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
-        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)
-
-        cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)
-
-        ans_logics = self.alloc_tensor(
-            (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype
-        )
-        torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics)
-        return ans_logics
-
-    def context_forward(
-        self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
-    ):
-        input_embdings = super().context_forward(input_ids, infer_state, layer_weight)
-        return self._mtp_context_forward(input_embdings, infer_state, layer_weight)
-
-    def token_forward(
-        self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
-    ):
-        input_embdings = super().token_forward(input_ids, infer_state, layer_weight)
-        return self._mtp_token_forward(input_embdings, infer_state, layer_weight)
diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
index f5b805647..408992178 100644
--- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
+++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
@@ -2,23 +2,21 @@
 from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
 
 
-class Deepseek3MTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
+class Qwen3MOEMTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
     def __init__(self, data_type, network_config, mode):
         super().__init__(data_type, network_config, mode)
-        # shared with the DeepseekV3 model
+        # shared with the Qwen3MOE model
         self.wte_weight_ = None
         self.lm_head_weight_ = None
         return
 
     def load_hf_weights(self, weights):
-        if "model.layers.0.eh_proj.weight" in weights:
-            self.eh_proj_weight_ = self._cuda(weights["model.layers.0.eh_proj.weight"]).t()
-        if "model.layers.0.enorm.weight" in weights:
-            self.enorm_weight_ = self._cuda(weights["model.layers.0.enorm.weight"])
-        if "model.layers.0.hnorm.weight" in weights:
-            self.hnorm_weight_ = self._cuda(weights["model.layers.0.hnorm.weight"])
-        if "model.layers.0.shared_head.norm.weight" in weights:
-            self.final_norm_weight_ = self._cuda(weights["model.layers.0.shared_head.norm.weight"])
+        if "model.0.proj.weight" in weights:
+            self.eh_proj_weight_ = self._cuda(weights["model.0.proj.weight"]).t()
+        if "model.0.norm_after_embedding.weight" in weights:
+            self.enorm_weight_ = self._cuda(weights["model.0.norm_after_embedding.weight"])
+        if "model.0.norm_before_output.weight" in weights:
+            self.hnorm_weight_ = self._cuda(weights["model.0.norm_before_output.weight"])
         return
 
     def verify_load(self):
diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py
index 4586db4be..ba6c82804 100644
--- a/lightllm/models/qwen3_moe_mtp/model.py
+++ b/lightllm/models/qwen3_moe_mtp/model.py
@@ -1,12 +1,12 @@
 from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.models.qwen3_moe_mtp.layer_weights.pre_and_post_layer_weight import Qwen3MOEMTPPreAndPostLayerWeight
 from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer
-from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight
 from lightllm.common.basemodel import TpPartBaseModel
 
 
 class Qwen3MOEMTPModel(Qwen3MOEModel):
 
-    pre_and_post_weight_class = Deepseek3MTPPreAndPostLayerWeight
+    pre_and_post_weight_class = Qwen3MOEMTPPreAndPostLayerWeight
     pre_layer_infer_class = Deepseek3MTPPreLayerInfer
 
     def __init__(self, kvargs: dict):

From ee464d9b5db21d85865e5e32f943aa9cd4c77e61 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 1 Dec 2025 13:13:11 +0000
Subject: [PATCH 3/8] fix qwen3 fa3 mtp

---
 .../models/llama/flashattention_infer_struct.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py
index c6e7aa560..4cfd72e81 100644
--- a/lightllm/models/llama/flashattention_infer_struct.py
+++ b/lightllm/models/llama/flashattention_infer_struct.py
@@ -38,22 +38,29 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor):
         self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
         self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
         max_seq_len_k = self.max_kv_seq_len
+        args_mtp_step = get_env_start_args().mtp_step
+        att_batch_size = self.batch_size // (args_mtp_step + 1)
         if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
             page_buffer = FlashAttentionStateInfo.get_page_table_buffer(
                 model.graph_max_batch_size, model.graph_max_len_in_batch
             )
             self.page_table = page_buffer[self.microbatch_index][
-                : self.batch_size * model.graph_max_len_in_batch
-            ].reshape(self.batch_size, model.graph_max_len_in_batch)
+                : att_batch_size * model.graph_max_len_in_batch
+            ].reshape(att_batch_size, model.graph_max_len_in_batch)
         else:
             self.page_table = torch.empty(
-                (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
+                (att_batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
             )
+
         page_table_copy(
             page_table=self.page_table[:, :max_seq_len_k],
             req_to_token_indexs=model.req_manager.req_to_token_indexs,
-            b_req_idx=self.b_req_idx,
+            b_req_idx=self.b_req_idx[args_mtp_step :: (args_mtp_step + 1)],
         )
+        if args_mtp_step > 0:
+            self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()
+        else:
+            self.b_att_seq_len = self.b_seq_len
 
         if "offline_calibration_fp8kv" in model.mode:
             if self.is_prefill:

From 2db3418e6166179cf91a2665e9d2a71029168ee4 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 1 Dec 2025 13:17:57 +0000
Subject: [PATCH 4/8] fix

---
 lightllm/models/llama/layer_infer/transformer_layer_infer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
index 8c6015677..ea44fe2e5 100755
--- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -883,10 +883,10 @@ def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionS
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=infer_state.page_table,
-            cache_seqlens=infer_state.b_seq_len,
+            cache_seqlens=infer_state.b_att_seq_len,
             cu_seqlens_q=infer_state.cu_seqlens_q,
             cu_seqlens_k_new=infer_state.cu_seqlens_k,
-            max_seqlen_q=1,
+            max_seqlen_q=infer_state.max_q_seq_len,
             softmax_scale=sm_scale,
             causal=True,
             window_size=(-1, -1),

From e3fae1cb6c1bceff6c2dd7329a5eed7208127c6b Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 1 Dec 2025 13:27:58 +0000
Subject: [PATCH 5/8] fix

---
 .../qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
index 408992178..bb18a4dda 100644
--- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
+++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
@@ -21,7 +21,7 @@ def load_hf_weights(self, weights):
 
     def verify_load(self):
         errors = "weights load not ok"
-        weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_, self.final_norm_weight_]
+        weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_]
         for i in range(len(weights)):
             assert weights[i] is not None, "index:" + str(i) + " " + errors
         return

From 1ff65f7580885b2e58001c6e6839a8280e98bc32 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 1 Dec 2025 13:51:25 +0000
Subject: [PATCH 6/8] fix

---
 .../layer_weights/pre_and_post_layer_weight.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
index bb18a4dda..57d98eec9 100644
--- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
+++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
@@ -11,12 +11,12 @@ def __init__(self, data_type, network_config, mode):
         return
 
     def load_hf_weights(self, weights):
-        if "model.0.proj.weight" in weights:
-            self.eh_proj_weight_ = self._cuda(weights["model.0.proj.weight"]).t()
-        if "model.0.norm_after_embedding.weight" in weights:
-            self.enorm_weight_ = self._cuda(weights["model.0.norm_after_embedding.weight"])
-        if "model.0.norm_before_output.weight" in weights:
-            self.hnorm_weight_ = self._cuda(weights["model.0.norm_before_output.weight"])
+        if "model.layers.0.proj.weight" in weights:
+            self.eh_proj_weight_ = self._cuda(weights["model.layers.0.proj.weight"]).t()
+        if "model.layers.0.norm_after_embedding.weight" in weights:
+            self.enorm_weight_ = self._cuda(weights["model.layers.0.norm_after_embedding.weight"])
+        if "model.layers.0.norm_before_output.weight" in weights:
+            self.hnorm_weight_ = self._cuda(weights["model.layers.0.norm_before_output.weight"])
         return
 
     def verify_load(self):

From e0b6ed6b1c4d6092904cb493e958a4fd111af4f2 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 1 Dec 2025 14:00:23 +0000
Subject: [PATCH 7/8] fix

---
 lightllm/models/qwen2/model.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py
index e3d8de461..457ccf592 100644
--- a/lightllm/models/qwen2/model.py
+++ b/lightllm/models/qwen2/model.py
@@ -3,6 +3,7 @@
 from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.mem_utils import select_mem_manager_class
+from lightllm.common.env_utils import get_env_start_args
 
 
 @ModelRegistry("qwen2")
@@ -41,12 +42,20 @@ def _init_mem_manager(self):
         head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"]
         head_dim_ = self.config.get("head_dim", head_dim_)
         tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
+
+        # in mtp mode, the mem manager needs extra layers for the draft model
+        added_mtp_layer_num = 0
+        if get_env_start_args().mtp_mode == "deepseekv3_eagle":
+            added_mtp_layer_num += 1
+        elif get_env_start_args().mtp_mode == "deepseekv3_vanilla":
+            added_mtp_layer_num += get_env_start_args().mtp_step
+
         self.mem_manager = select_mem_manager_class(self.mode)(
             self.max_total_token_num,
             dtype=self.data_type,
             head_num=tp_k_head_num_,
             head_dim=head_dim_,
-            layer_num=self.config["num_hidden_layers"],
+            layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num,
             mem_fraction=self.mem_fraction,
         )
         return

From 88301ed50ea025799edb481bbd999810fbe7dbef Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 1 Dec 2025 14:02:45 +0000
Subject: [PATCH 8/8] fix

---
 lightllm/models/qwen2/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py
index 457ccf592..7e2f1a302 100644
--- a/lightllm/models/qwen2/model.py
+++ b/lightllm/models/qwen2/model.py
@@ -3,7 +3,7 @@
 from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.mem_utils import select_mem_manager_class
-from lightllm.common.env_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args
 
 
 @ModelRegistry("qwen2")
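
Note (illustrative only, not part of the patch series): a minimal sketch of how the draft-model dispatch added in PATCH 1 is expected to behave. The config dicts and the helper name below are hypothetical stand-ins, not lightllm APIs; the real code obtains the config via PretrainedConfig.get_config_dict(args.mtp_draft_model_dir) and appends the chosen model to self.draft_models.

    from typing import Dict, Type

    class Deepseek3MTPModel:   # stand-in for lightllm.models.deepseek_mtp.model.Deepseek3MTPModel
        pass

    class Qwen3MOEMTPModel:    # stand-in for lightllm.models.qwen3_moe_mtp.model.Qwen3MOEMTPModel
        pass

    def select_mtp_draft_model_class(mtp_model_cfg: Dict[str, str]) -> Type:
        # mirrors the if/elif/else chain added to init_mtp_draft_model()
        if mtp_model_cfg["model_type"] == "deepseek_v3":
            return Deepseek3MTPModel
        elif mtp_model_cfg["model_type"] == "qwen3_moe":
            return Qwen3MOEMTPModel
        else:
            raise AssertionError(f"unsupported mtp draft model type: {mtp_model_cfg['model_type']}")

    # the draft model class is picked purely from the draft model's HF config
    assert select_mtp_draft_model_class({"model_type": "qwen3_moe"}) is Qwen3MOEMTPModel
    assert select_mtp_draft_model_class({"model_type": "deepseek_v3"}) is Deepseek3MTPModel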