Qwen3 moe mtp #1131

base: main
Conversation
Summary of Changes (Gemini Code Assist)

This pull request introduces support for the Qwen3 Mixture-of-Experts (MoE) model within the Multi-Token Prediction (MTP) framework. It defines a new Qwen3 MoE MTP model, adapts core components such as FlashAttention and memory management to handle MTP-specific requirements, and updates the model-loading mechanism to integrate the new model type as a draft model, so that LightLLM can use Qwen3 MoE models for improved multi-token prediction performance.
Code Review
This pull request adds support for the Qwen3 MoE model in MTP (Multi-Token Prediction) mode. The changes include new model and weight classes for Qwen3MOEMTP and modifications to memory management, attention mechanisms, and model loading to accommodate this new model. While the changes are mostly correct, I've identified a few issues related to code clarity, maintainability, and a design flaw where model-specific logic is misplaced. My review includes suggestions to rename misleading variables, simplify redundant code, refactor hardcoded logic to a more appropriate location, and improve debugging messages.
```python
if get_env_start_args().mtp_mode == "deepseekv3_eagle":
    added_mtp_layer_num += 1
elif get_env_start_args().mtp_mode == "deepseekv3_vanilla":
    added_mtp_layer_num += get_env_start_args().mtp_step
```
This logic for calculating added_mtp_layer_num is specific to deepseekv3 MTP modes, but it's located in the qwen2 model file. This creates a tight and incorrect coupling. When running with a Qwen MTP model, this logic will fail to calculate the correct number of extra layers for the memory manager, potentially leading to insufficient memory allocation and runtime errors. This logic should be generalized or moved to a more appropriate location, such as a base MTP model or handled during model initialization based on the specific draft model's configuration.
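One way to decouple this, as a hypothetical sketch only (the function and its placement are illustrative, not LightLLM's actual API): move the per-mode layer accounting into a single helper keyed on the MTP mode string, so each draft model family can be handled in one place.

```python
def extra_mtp_layers(mtp_mode, mtp_step):
    """Return the number of extra KV-cache layers the draft model needs.

    Hypothetical helper; mode names mirror the snippet above.
    """
    if mtp_mode is None:
        return 0
    if mtp_mode.endswith("_eagle"):
        # Eagle-style MTP adds a single shared draft layer.
        return 1
    if mtp_mode.endswith("_vanilla"):
        # Vanilla MTP adds one draft layer per speculative step.
        return mtp_step
    raise ValueError(f"unknown mtp_mode: {mtp_mode}")


print(extra_mtp_layers("deepseekv3_eagle", 3))    # -> 1
print(extra_mtp_layers("deepseekv3_vanilla", 3))  # -> 3
```

A Qwen MTP mode could then be added by extending this one function rather than editing the qwen2 model file.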
```python
is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str(
    self.__class__
)
if is_deepseekv3_mtp_draft_model:
```
The variable is_deepseekv3_mtp_draft_model is now misleading as it also checks for Qwen3MOEMTPModel. It should be renamed to something more generic, like is_mtp_draft_model, to accurately reflect its purpose.
Suggested change:

```diff
-is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str(
-    self.__class__
-)
-if is_deepseekv3_mtp_draft_model:
+is_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str(
+    self.__class__
+)
+if is_mtp_draft_model:
```
```python
if args_mtp_step > 0:
    self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()
else:
    self.b_att_seq_len = self.b_seq_len
```
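A quick sanity check of the slicing in this hunk, using plain Python lists as a stand-in (lists slice with the same start/step semantics as torch tensors; `.contiguous()` is torch-specific and omitted here):

```python
# Hypothetical per-slot sequence lengths, purely for illustration.
b_seq_len = list(range(1, 13))  # [1, 2, ..., 12]

# args_mtp_step = 0: the slice [0::1] is the identity, so no branch is needed.
assert b_seq_len[0::1] == b_seq_len

# args_mtp_step = 2: [2::3] keeps every third entry starting at index 2,
# i.e. the last position of each (mtp_step + 1)-sized group.
print(b_seq_len[2::3])  # -> [3, 6, 9, 12]
```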
This if/else block can be simplified. The slicing logic [args_mtp_step :: (args_mtp_step + 1)] works correctly for args_mtp_step = 0 as well, where it becomes [0::1]. The .contiguous() call is necessary for slices with a step greater than 1 and is harmless for already contiguous tensors, so it can be applied unconditionally.

```python
self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()
```

```python
weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_]
for i in range(len(weights)):
    assert weights[i] is not None, "index:" + str(i) + " " + errors
```
The assertion message in verify_load is not very descriptive. For better debuggability, it would be helpful to include the name of the weight that failed to load, rather than just its index.
Suggested change:

```diff
-weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_]
-for i in range(len(weights)):
-    assert weights[i] is not None, "index:" + str(i) + " " + errors
+weight_names = ["eh_proj_weight_", "enorm_weight_", "hnorm_weight_"]
+weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_]
+for i, name in enumerate(weight_names):
+    assert weights[i] is not None, f"{name} {errors}"
```
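To illustrate why the named form is easier to debug, here is a minimal stand-in (the class and attribute values are fabricated for the example; only the attribute names come from the diff above):

```python
class FakeWeights:
    """Stand-in for the real weight holder, for illustration only."""
    eh_proj_weight_ = object()
    enorm_weight_ = object()
    hnorm_weight_ = None  # simulate a weight that failed to load


w = FakeWeights()
names = ["eh_proj_weight_", "enorm_weight_", "hnorm_weight_"]

# Name-based reporting points straight at the missing weight,
# instead of an opaque "index:2" message.
missing = [n for n in names if getattr(w, n) is None]
print(missing)  # -> ['hnorm_weight_']
```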
```python
class Qwen3MOEMTPModel(Qwen3MOEModel):

    pre_and_post_weight_class = Qwen3MOEMTPPreAndPostLayerWeight
    pre_layer_infer_class = Deepseek3MTPPreLayerInfer
```