Commit 9e82c77
Fix Ernie Moe Test (#42595)
* fix
* fix
* rm unnecessary config
* remove references
1 parent d5d8793 · commit 9e82c77

1 file changed: +9 -6 lines changed

tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py

Lines changed: 9 additions & 6 deletions

```diff
@@ -97,24 +97,27 @@ def test_load_balancing_loss(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.num_labels = 3
         config.num_experts = 3
-        config.expert_interval = 2
         config.output_router_logits = True
         input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
+        attention_mask = input_ids.ne(config.pad_token_id).to(torch_device)
         model = Ernie4_5_MoeForCausalLM(config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=attention_mask)
-        self.assertEqual(result.router_logits[0].shape, (91, config.num_experts))
+        bs, seqlen = input_ids.shape
+        self.assertEqual(result.router_logits[0].shape, (bs * seqlen, config.num_experts))
         torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2)
 
         # First, we make sure that adding padding tokens doesn't change the loss
         # loss(input_ids, attention_mask=None) == loss(input_ids + padding, attention_mask=attention_mask_with_padding)
+        # (This length is selected from experiments)
         pad_length = input_ids.shape[1] * 4
-        # Add padding tokens (assume that pad_token_id=1) to input_ids
-        padding_block = torch.ones(input_ids.shape[0], pad_length, dtype=torch.int32).to(torch_device)
+        # Add padding tokens to input_ids
+        padding_block = config.pad_token_id * torch.ones(input_ids.shape[0], pad_length, dtype=torch.int32).to(
+            torch_device
+        )
         padded_input_ids = torch.cat((padding_block, input_ids), dim=1) # this is to simulate padding to the left
-        padded_attention_mask = padded_input_ids.ne(1).to(torch_device)
+        padded_attention_mask = padded_input_ids.ne(config.pad_token_id).to(torch_device)
 
         padded_result = model(padded_input_ids, attention_mask=padded_attention_mask)
         torch.testing.assert_close(result.aux_loss.cpu(), padded_result.aux_loss.cpu(), rtol=1e-4, atol=1e-4)
```
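For context, here is a minimal standalone sketch of the two patterns this commit standardizes: deriving the attention mask from the configured pad token instead of a hardcoded `.ne(1)`, and computing the expected router-logits shape from the input shape instead of a magic number. The concrete values (`pad_token_id`, batch size, sequence length, expert count) are illustrative assumptions, not taken from the Ernie 4.5 MoE test config:

```python
import torch

# Illustrative values; the real test reads these from the prepared config.
pad_token_id = 0
batch_size, seq_len, num_experts = 2, 8, 3

# Random token ids, all guaranteed non-pad here for simplicity.
input_ids = torch.randint(low=1, high=100, size=(batch_size, seq_len))

# Core of the fix: build the mask from the configured pad token id
# rather than assuming pad_token_id == 1.
attention_mask = input_ids.ne(pad_token_id)

# MoE models in transformers typically return router logits flattened over
# batch and sequence, so each layer's tensor is
# (batch_size * sequence_length, num_experts); the test now computes this
# instead of asserting a literal like 91.
expected_router_shape = (batch_size * seq_len, num_experts)

# Left-padding, mirroring the test: prepend a block of pad tokens and
# rebuild the mask, so padded and unpadded inputs should yield the same
# auxiliary load-balancing loss.
pad_length = seq_len * 4
padding_block = pad_token_id * torch.ones(batch_size, pad_length, dtype=input_ids.dtype)
padded_input_ids = torch.cat((padding_block, input_ids), dim=1)
padded_attention_mask = padded_input_ids.ne(pad_token_id)

print(expected_router_shape)         # (16, 3)
print(padded_input_ids.shape)        # torch.Size([2, 40])
print(padded_attention_mask[:, :3])  # all False: pad positions are masked out
```

To run just this test locally, something like `pytest tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py -k test_load_balancing_loss` should select it.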
