3 changes: 3 additions & 0 deletions src/transformers/conversion_mapping.py
@@ -166,6 +166,9 @@ def _build_checkpoint_conversion_mapping():
mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
mapping["dots1"] = mapping["qwen2_moe"].copy()
mapping["ernie4_5_moe"] = mapping["qwen2_moe"].copy()
mapping["ernie4_5_moe"] += [
WeightRenaming("mlp.moe_statics.e_score_correction_bias", "mlp.gate.moe_statics.e_score_correction_bias")
]
Comment on lines +169 to +171 (Contributor Author):
This weight was missing.

mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
mapping["longcat_flash"] = mapping["qwen2_moe"].copy()
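For context, a WeightRenaming entry like the one added here rewrites matching checkpoint keys before the state dict is loaded, so original ERNIE 4.5 checkpoints (which store the correction bias under mlp.moe_statics) line up with the refactored module layout (mlp.gate.moe_statics). A minimal standalone sketch of that idea follows; apply_renaming is a hypothetical helper, not the actual transformers conversion machinery:

# Hypothetical illustration of what the added WeightRenaming accomplishes.
def apply_renaming(state_dict: dict, source: str, target: str) -> dict:
    # Rewrite every checkpoint key that contains the source pattern.
    return {key.replace(source, target): value for key, value in state_dict.items()}

old = {"model.layers.3.mlp.moe_statics.e_score_correction_bias": None}
new = apply_renaming(
    old,
    "mlp.moe_statics.e_score_correction_bias",
    "mlp.gate.moe_statics.e_score_correction_bias",
)
# -> {"model.layers.3.mlp.gate.moe_statics.e_score_correction_bias": None}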
16 changes: 8 additions & 8 deletions src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
@@ -373,14 +373,14 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens

with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = F.linear(hidden_states.float(), self.weight)
-router_logits = F.softmax(router_logits, dim=1, dtype=torch.float)
-router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1)
-router_top_value = router_top_value / torch.clamp(
-    router_top_value.sum(dim=-1, keepdim=True), min=self.norm_min
+routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
+routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
+routing_weights = routing_weights / torch.clamp(
+    routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
-router_scores = router_top_value
-router_scores = router_scores.to(hidden_states.dtype)
-return router_logits, router_scores, router_indices
+routing_weights = routing_weights.to(hidden_states.dtype)
+return router_logits, selected_experts, routing_weights


class Ernie4_5_MoeSparseMoeBlock(nn.Module):
@@ -403,7 +403,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)

-_, top_k_weights, top_k_index = self.gate(hidden_states)
+_, top_k_index, top_k_weights = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)

if self.shared_experts is not None:
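To summarize the router change in a standalone form: top-k expert selection now runs on the moe_statics-adjusted scores (presumably adding the e_score_correction_bias, DeepSeek-V3 style), but the weights applied to the chosen experts are gathered from the plain softmax and renormalized, and the gate returns indices before weights so the sparse block's unpacking order matches. The sketch below uses made-up shapes and a bare bias tensor in place of the module's parameters; it is illustrative, not the model code:

import torch
import torch.nn.functional as F

def route(hidden_states, gate_weight, e_score_correction_bias, top_k=2, norm_min=1e-12):
    # (num_tokens, num_experts) logits, computed in float32 as in the autocast block above
    router_logits = F.linear(hidden_states.float(), gate_weight)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    # Selection uses the bias-corrected scores...
    _, selected_experts = torch.topk(routing_weights + e_score_correction_bias, top_k, dim=-1)
    # ...but the expert weights come from the uncorrected softmax, then renormalize.
    routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
    routing_weights = routing_weights / torch.clamp(routing_weights.sum(dim=-1, keepdim=True), min=norm_min)
    return router_logits, selected_experts, routing_weights.to(hidden_states.dtype)

hidden_states = torch.randn(4, 8)   # 4 tokens, hidden size 8 (toy values)
gate_weight = torch.randn(16, 8)    # 16 experts
bias = torch.zeros(16)
_, top_k_index, top_k_weights = route(hidden_states, gate_weight, bias)
assert top_k_index.shape == (4, 2)
assert torch.allclose(top_k_weights.sum(-1), torch.ones(4))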
16 changes: 8 additions & 8 deletions src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py
@@ -148,14 +148,14 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens

with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = F.linear(hidden_states.float(), self.weight)
-router_logits = F.softmax(router_logits, dim=1, dtype=torch.float)
-router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1)
-router_top_value = router_top_value / torch.clamp(
-    router_top_value.sum(dim=-1, keepdim=True), min=self.norm_min
+routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
+routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
+routing_weights = routing_weights / torch.clamp(
+    routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
-router_scores = router_top_value
-router_scores = router_scores.to(hidden_states.dtype)
-return router_logits, router_scores, router_indices
+routing_weights = routing_weights.to(hidden_states.dtype)
+return router_logits, selected_experts, routing_weights


class Ernie4_5_MoeSparseMoeBlock(nn.Module):
@@ -178,7 +178,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)

-_, top_k_weights, top_k_index = self.gate(hidden_states)
+_, top_k_index, top_k_weights = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)

if self.shared_experts is not None:
2 changes: 1 addition & 1 deletion tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py
@@ -161,7 +161,7 @@ def get_model(cls):
@require_bitsandbytes
@slow
def test_model_21b_a3b_generation(self):
EXPECTED_TEXT_COMPLETION = "User: Hey, are you conscious? Can you talk to me?\nAssistant: I don't have consciousness in the way humans do. I'm a text-based AI created to process and generate responses based on patterns in data." # fmt: skip
EXPECTED_TEXT_COMPLETION = "User: Hey, are you conscious? Can you talk to me?\nAssistant: \nI don't have consciousness in the way humans do. I don't feel emotions, have thoughts, or experience awareness. However, I'm" # fmt: skip
Comment (Collaborator):
Can you add a fast test for this one as well? We seem to break it often.


model = self.get_model()
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT", revision="refs/pr/11")
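On the reviewer's request above: the slow test needs the 21B checkpoint, but the behavior this PR fixes can also be pinned down by a small CPU-only property test on the routing math. The sketch below is illustrative only — it re-implements the gather-from-softmax logic on plain tensors rather than instantiating the actual gate module, and the test name and values are made up:

import torch

def test_bias_changes_selection_but_not_weight_source():
    # One token, four experts: the softmax strongly prefers expert 0,
    # while the correction bias strongly favors expert 3.
    probs = torch.tensor([[0.90, 0.05, 0.03, 0.02]])
    bias = torch.tensor([0.0, 0.0, 0.0, 10.0])
    _, selected = torch.topk(probs + bias, k=2, dim=-1)
    weights = torch.gather(probs, dim=-1, index=selected)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    chosen = selected[0].tolist()
    # The bias decides *which* experts are picked...
    assert 3 in chosen
    # ...but the weight of expert 3 still reflects its tiny softmax probability,
    # which is what the gather-based routing in this PR guarantees.
    assert weights[0, chosen.index(3)] < 0.05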