Qwen3 moe mtp #1131
base: main
Changes from all commits: ed69449, a4fb3ed, ee464d9, 2db3418, e3fae1c, 1ff65f7, e0b6ed6, 88301ed
```diff
@@ -38,22 +38,29 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor):
         self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
         self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
         max_seq_len_k = self.max_kv_seq_len
+        args_mtp_step = get_env_start_args().mtp_step
+        att_batch_size = self.batch_size // (args_mtp_step + 1)
         if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
             page_buffer = FlashAttentionStateInfo.get_page_table_buffer(
                 model.graph_max_batch_size, model.graph_max_len_in_batch
             )
             self.page_table = page_buffer[self.microbatch_index][
-                : self.batch_size * model.graph_max_len_in_batch
-            ].reshape(self.batch_size, model.graph_max_len_in_batch)
+                : att_batch_size * model.graph_max_len_in_batch
+            ].reshape(att_batch_size, model.graph_max_len_in_batch)
         else:
             self.page_table = torch.empty(
-                (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
+                (att_batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
             )

         page_table_copy(
             page_table=self.page_table[:, :max_seq_len_k],
             req_to_token_indexs=model.req_manager.req_to_token_indexs,
-            b_req_idx=self.b_req_idx,
+            b_req_idx=self.b_req_idx[args_mtp_step :: (args_mtp_step + 1)],
         )
+        if args_mtp_step > 0:
+            self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()
+        else:
+            self.b_att_seq_len = self.b_seq_len
```
Comment on lines +60 to +63 (Contributor):
This `self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()` …
```diff

         if "offline_calibration_fp8kv" in model.mode:
             if self.is_prefill:
```
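To make the new indexing concrete: with `mtp_step` extra draft entries per request, the flattened decode batch holds `(mtp_step + 1)` entries per request, and the attention metadata only needs one of them. A small standalone sketch of that layout (the tensors and sizes below are made-up illustration values, not LightLLM code):

```python
import torch

mtp_step = 1
batch_size = 6  # 3 requests, each contributing (mtp_step + 1) = 2 entries
att_batch_size = batch_size // (mtp_step + 1)  # -> 3, as in the diff above

b_req_idx = torch.tensor([0, 0, 1, 1, 2, 2])        # request id per flattened entry
b_seq_len = torch.tensor([10, 11, 20, 21, 30, 31])  # sequence length per flattened entry

# Keep the last entry of each request's group, mirroring
# b_req_idx[args_mtp_step :: (args_mtp_step + 1)] in the diff.
per_req_idx = b_req_idx[mtp_step :: (mtp_step + 1)]                  # tensor([0, 1, 2])
b_att_seq_len = b_seq_len[mtp_step :: (mtp_step + 1)].contiguous()   # tensor([11, 21, 31])
assert per_req_idx.numel() == att_batch_size
```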
```diff
@@ -3,6 +3,7 @@
 from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.mem_utils import select_mem_manager_class
+from lightllm.utils.envs_utils import get_env_start_args


 @ModelRegistry("qwen2")
@@ -41,12 +42,20 @@ def _init_mem_manager(self):
         head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"]
         head_dim_ = self.config.get("head_dim", head_dim_)
         tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)

+        # In MTP mode, the mem manager needs extra layers for the draft model
+        added_mtp_layer_num = 0
+        if get_env_start_args().mtp_mode == "deepseekv3_eagle":
+            added_mtp_layer_num += 1
+        elif get_env_start_args().mtp_mode == "deepseekv3_vanilla":
+            added_mtp_layer_num += get_env_start_args().mtp_step
+
```
Comment on lines +48 to +51 (Contributor):
This logic for calculating `added_mtp_layer_num` …
```diff
         self.mem_manager = select_mem_manager_class(self.mode)(
             self.max_total_token_num,
             dtype=self.data_type,
             head_num=tp_k_head_num_,
             head_dim=head_dim_,
-            layer_num=self.config["num_hidden_layers"],
+            layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num,
             mem_fraction=self.mem_fraction,
         )
         return
```
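The same draft-layer accounting appears in each model that gains MTP support. A minimal sketch of how it could live in one shared helper; the function name and placement are assumptions for illustration, not part of this diff:

```python
from lightllm.utils.envs_utils import get_env_start_args


def get_added_mtp_layer_num() -> int:
    # Hypothetical shared helper: how many extra KV-cache layers the draft model needs.
    args = get_env_start_args()
    if args.mtp_mode == "deepseekv3_eagle":
        return 1  # eagle-style MTP adds a single draft layer
    if args.mtp_mode == "deepseekv3_vanilla":
        return args.mtp_step  # one draft layer per speculative step
    return 0
```

With such a helper, `_init_mem_manager` would pass `layer_num=self.config["num_hidden_layers"] + get_added_mtp_layer_num()`.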
```diff
@@ -0,0 +1,27 @@
+import numpy as np
+from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
+
+
+class Qwen3MOEMTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
+    def __init__(self, data_type, network_config, mode):
+        super().__init__(data_type, network_config, mode)
+        # shared with the Qwen3MOE main model
+        self.wte_weight_ = None
+        self.lm_head_weight_ = None
+        return
+
+    def load_hf_weights(self, weights):
+        if "model.layers.0.proj.weight" in weights:
+            self.eh_proj_weight_ = self._cuda(weights["model.layers.0.proj.weight"]).t()
+        if "model.layers.0.norm_after_embedding.weight" in weights:
+            self.enorm_weight_ = self._cuda(weights["model.layers.0.norm_after_embedding.weight"])
+        if "model.layers.0.norm_before_output.weight" in weights:
+            self.hnorm_weight_ = self._cuda(weights["model.layers.0.norm_before_output.weight"])
+        return
+
+    def verify_load(self):
+        errors = "weights load not ok"
+        weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_]
+        for i in range(len(weights)):
+            assert weights[i] is not None, "index:" + str(i) + " " + errors
```
Comment on lines +24 to +26 (Contributor):
The assertion message in `verify_load` reports only a list index, not which weight failed to load.
```diff
+        return
```
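The reviewer attached a suggested change. As a sketch of one way to make the failure message name the missing weight (an illustration of the idea, not the reviewer's actual patch):

```python
    def verify_load(self):
        # Report the missing tensor by name instead of a bare list index.
        named_weights = {
            "eh_proj_weight_": self.eh_proj_weight_,
            "enorm_weight_": self.enorm_weight_,
            "hnorm_weight_": self.hnorm_weight_,
        }
        for name, weight in named_weights.items():
            assert weight is not None, f"weights load not ok: {name} is None"
        return
```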
```diff
@@ -0,0 +1,47 @@
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.models.qwen3_moe_mtp.layer_weights.pre_and_post_layer_weight import Qwen3MOEMTPPreAndPostLayerWeight
+from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer
+from lightllm.common.basemodel import TpPartBaseModel
+
+
+class Qwen3MOEMTPModel(Qwen3MOEModel):
+
+    pre_and_post_weight_class = Qwen3MOEMTPPreAndPostLayerWeight
+    pre_layer_infer_class = Deepseek3MTPPreLayerInfer
```
```diff
+
+    def __init__(self, kvargs: dict):
+        self._pre_init(kvargs)
+        super().__init__(kvargs)
+        return
+
+    def _pre_init(self, kvargs: dict):
+        self.main_model: TpPartBaseModel = kvargs.pop("main_model")
+        self.mem_layer_start = kvargs.pop("mem_layer_start", 0)
+        return
+
+    def _init_custom(self):
+        self._cos_cached = self.main_model._cos_cached
+        self._sin_cached = self.main_model._sin_cached
+        return
+
+    def _init_req_manager(self):
+        self.req_manager = self.main_model.req_manager
+        return
+
+    def _init_mem_manager(self):
+        self.mem_manager = self.main_model.mem_manager
+        return
+
+    def _init_weights(self):
+        super()._init_weights()
+        self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_
+        self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_
+        self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_
+        return
+
+    def _init_infer_layer(self):
+        super()._init_infer_layer()
+        # shift layer_num_ of self.layers_infer so the draft layers use the extra mem manager layers
+        for layer in self.layers_infer:
+            layer.layer_num_ = layer.layer_num_ + self.mem_layer_start
+        return
```
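A rough sketch of how a caller might wire this draft model to the main model, given that `_pre_init` pops `main_model` and `mem_layer_start` from `kvargs`. Using `num_hidden_layers` as the layer offset is an inference from the mem-manager change above, and the helper name is made up for illustration:

```python
def build_qwen3_moe_mtp_draft(main_model, base_kvargs: dict):
    # Illustration only: the draft model shares the main model's req/mem managers
    # and places its layers after the main model's layers in the shared KV cache.
    draft_kvargs = dict(base_kvargs)  # weight dir, data type, tp settings, ...
    draft_kvargs["main_model"] = main_model
    draft_kvargs["mem_layer_start"] = main_model.config["num_hidden_layers"]
    return Qwen3MOEMTPModel(draft_kvargs)
```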
The variable `is_deepseekv3_mtp_draft_model` is now misleading as it also checks for `Qwen3MOEMTPModel`. It should be renamed to something more generic, like `is_mtp_draft_model`, to accurately reflect its purpose.
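As a concrete sketch of that rename (the DeepSeek draft-model class name checked here is an assumption; `Qwen3MOEMTPModel` comes from this PR):

```python
def is_mtp_draft_model(model) -> bool:
    # Formerly is_deepseekv3_mtp_draft_model; the old name no longer matches
    # its behavior once the Qwen3 MoE draft model is also accepted.
    return type(model).__name__ in ("Deepseek3MTPModel", "Qwen3MOEMTPModel")
```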