@@ -97,24 +97,27 @@ def test_load_balancing_loss(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.num_labels = 3
         config.num_experts = 3
-        config.expert_interval = 2
         config.output_router_logits = True
         input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
+        attention_mask = input_ids.ne(config.pad_token_id).to(torch_device)
         model = Ernie4_5_MoeForCausalLM(config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=attention_mask)
-        self.assertEqual(result.router_logits[0].shape, (91, config.num_experts))
+        bs, seqlen = input_ids.shape
+        self.assertEqual(result.router_logits[0].shape, (bs * seqlen, config.num_experts))
         torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2)

         # First, we make sure that adding padding tokens doesn't change the loss
         # loss(input_ids, attention_mask=None) == loss(input_ids + padding, attention_mask=attention_mask_with_padding)
+        # (This length is selected from experiments)
         pad_length = input_ids.shape[1] * 4
-        # Add padding tokens (assume that pad_token_id=1) to input_ids
-        padding_block = torch.ones(input_ids.shape[0], pad_length, dtype=torch.int32).to(torch_device)
+        # Add padding tokens to input_ids
+        padding_block = config.pad_token_id * torch.ones(input_ids.shape[0], pad_length, dtype=torch.int32).to(
+            torch_device
+        )
         padded_input_ids = torch.cat((padding_block, input_ids), dim=1)  # this is to simulate padding to the left
-        padded_attention_mask = padded_input_ids.ne(1).to(torch_device)
+        padded_attention_mask = padded_input_ids.ne(config.pad_token_id).to(torch_device)

         padded_result = model(padded_input_ids, attention_mask=padded_attention_mask)
         torch.testing.assert_close(result.aux_loss.cpu(), padded_result.aux_loss.cpu(), rtol=1e-4, atol=1e-4)
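For readers skimming the diff: the equality asserted at the end holds because a mask-aware load-balancing loss excludes padding positions from the expert-usage statistics, so prepended pad tokens contribute nothing. The sketch below is illustrative only — the `aux_loss` helper, the 4-expert/top-2 setup, and the random logits are assumptions for the example, not the actual `Ernie4_5_MoeForCausalLM` aux-loss code:

```python
# Minimal sketch (assumed helper, not the transformers implementation) of a
# Switch-style load-balancing loss that masks out padding positions.
import torch
import torch.nn.functional as F


def aux_loss(router_logits, num_experts, top_k, attention_mask=None):
    # router_logits: (num_tokens, num_experts); attention_mask: (num_tokens,)
    probs = torch.softmax(router_logits, dim=-1)
    top_idx = probs.topk(top_k, dim=-1).indices        # (num_tokens, top_k)
    one_hot = F.one_hot(top_idx, num_experts).float()  # (num_tokens, top_k, num_experts)
    if attention_mask is None:
        num_real = float(probs.shape[0])
    else:
        mask = attention_mask.to(probs.dtype)
        one_hot = one_hot * mask[:, None, None]        # drop padded expert assignments
        probs = probs * mask[:, None]                  # drop padded router probabilities
        num_real = mask.sum()
    tokens_per_expert = one_hot.sum(dim=(0, 1)) / (num_real * top_k)  # fraction routed to each expert
    prob_per_expert = probs.sum(dim=0) / num_real                     # mean router prob per expert
    return num_experts * (tokens_per_expert * prob_per_expert).sum()


torch.manual_seed(0)
real = torch.randn(8, 4)                            # 8 real tokens, 4 experts (illustrative)
pad = torch.randn(32, 4)                            # router logits at pad positions
mask = torch.cat([torch.zeros(32), torch.ones(8)])  # left padding, as in the test
padded = torch.cat([pad, real], dim=0)
torch.testing.assert_close(aux_loss(real, 4, 2), aux_loss(padded, 4, 2, attention_mask=mask))
```

In the sketch the two losses match exactly because masked positions are zeroed out of both statistics; the test itself compares at rtol=1e-4, presumably to absorb small numerical differences from running attention over the padded batch.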