diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index 24eab78c14fc..d263f22eaeab 100644
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -166,6 +166,9 @@ def _build_checkpoint_conversion_mapping():
     mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
     mapping["dots1"] = mapping["qwen2_moe"].copy()
     mapping["ernie4_5_moe"] = mapping["qwen2_moe"].copy()
+    mapping["ernie4_5_moe"] += [
+        WeightRenaming("mlp.moe_statics.e_score_correction_bias", "mlp.gate.moe_statics.e_score_correction_bias")
+    ]
     mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
     mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
     mapping["longcat_flash"] = mapping["qwen2_moe"].copy()
diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
index 27602d04d7a1..7e0cc0cdce26 100644
--- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
+++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
@@ -373,14 +373,14 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
         )
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             router_logits = F.linear(hidden_states.float(), self.weight)
-            router_logits = F.softmax(router_logits, dim=1, dtype=torch.float)
-            router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1)
-            router_top_value = router_top_value / torch.clamp(
-                router_top_value.sum(dim=-1, keepdim=True), min=self.norm_min
+            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+            _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
+            routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
+            routing_weights = routing_weights / torch.clamp(
+                routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
             )
-        router_scores = router_top_value
-        router_scores = router_scores.to(hidden_states.dtype)
-        return router_logits, router_scores, router_indices
+        routing_weights = routing_weights.to(hidden_states.dtype)
+        return router_logits, selected_experts, routing_weights
 
 
 class Ernie4_5_MoeSparseMoeBlock(nn.Module):
@@ -403,7 +403,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         if self.shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
-        _, top_k_weights, top_k_index = self.gate(hidden_states)
+        _, top_k_index, top_k_weights = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
 
         if self.shared_experts is not None:
diff --git a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py
index 74feb925a7ce..fb2acd4c31f7 100644
--- a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py
+++ b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py
@@ -148,14 +148,14 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
         )
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             router_logits = F.linear(hidden_states.float(), self.weight)
-            router_logits = F.softmax(router_logits, dim=1, dtype=torch.float)
-            router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1)
-            router_top_value = router_top_value / torch.clamp(
-                router_top_value.sum(dim=-1, keepdim=True), min=self.norm_min
+            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+            _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
+            routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
+            routing_weights = routing_weights / torch.clamp(
+                routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
             )
-        router_scores = router_top_value
-        router_scores = router_scores.to(hidden_states.dtype)
-        return router_logits, router_scores, router_indices
+        routing_weights = routing_weights.to(hidden_states.dtype)
+        return router_logits, selected_experts, routing_weights
 
 
 class Ernie4_5_MoeSparseMoeBlock(nn.Module):
@@ -178,7 +178,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         if self.shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
-        _, top_k_weights, top_k_index = self.gate(hidden_states)
+        _, top_k_index, top_k_weights = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
 
         if self.shared_experts is not None:
diff --git a/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py b/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py
index 9e47ec145255..b86d8ef98e0b 100644
--- a/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py
+++ b/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py
@@ -130,9 +130,7 @@ def test_load_balancing_loss(self):
         self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item())
 
 
-# Run on runners with larger accelerators (for example A10 instead of T4) with a lot of CPU RAM (e.g. g5-12xlarge)
-@require_torch_multi_accelerator
-@require_torch_large_accelerator
+@slow
 @require_torch
 class Ernie4_5_MoeIntegrationTest(unittest.TestCase):
     @classmethod
@@ -144,27 +142,59 @@ def tearDownClass(cls):
         del cls.model
         cleanup(torch_device, gc_collect=True)
 
+    def setUp(self):
+        cleanup(torch_device, gc_collect=True)
+
     def tearDown(self):
         cleanup(torch_device, gc_collect=True)
 
     @classmethod
-    def get_model(cls):
-        if cls.model is None:
-            cls.model = Ernie4_5_MoeForCausalLM.from_pretrained(
-                "baidu/ERNIE-4.5-21B-A3B-PT",
-                device_map="auto",
-                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
-            )
+    def get_large_model(cls):
+        cls.model = Ernie4_5_MoeForCausalLM.from_pretrained(
+            "baidu/ERNIE-4.5-21B-A3B-PT",
+            device_map="auto",
+            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
+        )
+
+        return cls.model
+
+    @classmethod
+    def get_small_model(cls):
+        cls.model = Ernie4_5_MoeForCausalLM.from_pretrained(
+            "hf-internal-testing/ERNIE-4.5-Small-Moe",
+            device_map="auto",
+            dtype="auto",
+        )
 
         return cls.model
 
+    @require_torch_multi_accelerator
+    @require_torch_large_accelerator
     @require_bitsandbytes
-    @slow
     def test_model_21b_a3b_generation(self):
-        EXPECTED_TEXT_COMPLETION = "User: Hey, are you conscious? Can you talk to me?\nAssistant: I don't have consciousness in the way humans do. I'm a text-based AI created to process and generate responses based on patterns in data."  # fmt: skip
+        EXPECTED_TEXT_COMPLETION = "User: Hey, are you conscious? Can you talk to me?\nAssistant: \nI don't have consciousness in the way humans do. I don't feel emotions, have thoughts, or experience awareness. However, I'm"  # fmt: skip
+
+        model = self.get_large_model()
+        tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT")
+        prompt = "Hey, are you conscious? Can you talk to me?"
+        messages = [{"role": "user", "content": prompt}]
+        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)
+
+        generated_ids = model.generate(
+            model_inputs.input_ids,
+            max_new_tokens=32,
+            do_sample=False,
+        )
+        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip("\n")
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
+
+    def test_shortened_model_generation(self):
+        # Gibberish output is expected here, as the small model consists of the first few layers of the original 21B model
+        EXPECTED_TEXT_COMPLETION = 'User: Hey, are you conscious? Can you talk to me?\nAssistant: 不了的شم尔斯graveyard效应osm osmos乎哉哉哉哉 bargaining程度level打好莱坞制片制片amme瑙瑙��eka地步chansesaurian'  # fmt: skip
 
-        model = self.get_model()
-        tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT", revision="refs/pr/11")
+        model = self.get_small_model()
+        tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT")
         prompt = "Hey, are you conscious? Can you talk to me?"
         messages = [{"role": "user", "content": prompt}]
         text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)