@@ -163,6 +163,9 @@ def __init__(
         self.tp_size = 1
         self.tp_rank = 0

+        self.attn_tp_size = fd_config.parallel_config.tensor_parallel_size
+        self.attn_tp_rank = fd_config.parallel_config.tensor_parallel_rank
+
         assert (self.tp_size >= 1 and self.ep_size == 1) or (
             self.tp_size == 1 and self.ep_size > 1
         ), "MoE only support parallelism on TP or EP dimension."
@@ -598,18 +601,18 @@ def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer):
         Forward split allgather function.
         """
         token_num = x.shape[0]
-        token_num_per_rank = (token_num + self.tp_size - 1) // self.tp_size
+        token_num_per_rank = (token_num + self.attn_tp_size - 1) // self.attn_tp_size
         # AllGather will hang when the data shapes on multi-ranks are different!
         part_x = paddle.zeros(shape=[token_num_per_rank, x.shape[1]], dtype=x.dtype)
-        start_offset = self.tp_rank * token_num_per_rank
-        end_offset = (self.tp_rank + 1) * token_num_per_rank
+        start_offset = self.attn_tp_rank * token_num_per_rank
+        end_offset = (self.attn_tp_rank + 1) * token_num_per_rank
         if start_offset >= token_num:
             start_offset = token_num
         if end_offset > token_num:
             end_offset = token_num
         part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
         out = self.quant_method.apply(self, part_x, gate)
-        multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, x.shape[1]], dtype=x.dtype)
+        multi_outs = paddle.zeros([token_num_per_rank * self.attn_tp_size, x.shape[1]], dtype=x.dtype)
         paddle.distributed.all_gather(multi_outs, out, self.tp_group)
         out = multi_outs[:token_num, :]

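The hunk above partitions the tokens across the attention-TP ranks, pads each rank's slice to the same token_num_per_rank (unequal shapes would hang all_gather, as the inline comment notes), runs the expert computation on the slice, gathers the results, and trims the padding. Below is a single-process sketch of that padded split/gather arithmetic, using NumPy concatenation in place of paddle.distributed.all_gather and skipping the expert computation; it only illustrates the shape bookkeeping, not the real communication:

```python
import numpy as np


def split_allgather_sim(x: np.ndarray, attn_tp_size: int) -> np.ndarray:
    """Single-process simulation of the padded split + all_gather pattern.

    Every rank pads its slice to the same token_num_per_rank so that the
    gathered shapes match on all ranks (uneven shapes would hang all_gather).
    """
    token_num = x.shape[0]
    token_num_per_rank = (token_num + attn_tp_size - 1) // attn_tp_size

    gathered = []
    for rank in range(attn_tp_size):
        start = min(rank * token_num_per_rank, token_num)
        end = min((rank + 1) * token_num_per_rank, token_num)
        part = np.zeros((token_num_per_rank, x.shape[1]), dtype=x.dtype)
        part[: end - start, :] = x[start:end, :]  # real tokens plus zero padding
        gathered.append(part)  # stands in for this rank's expert output

    multi_outs = np.concatenate(gathered, axis=0)  # what all_gather would assemble
    return multi_outs[:token_num, :]  # drop the padding rows


x = np.arange(10 * 4, dtype=np.float32).reshape(10, 4)  # 10 tokens, hidden size 4
assert np.array_equal(split_allgather_sim(x, attn_tp_size=3), x)
```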
@@ -629,9 +632,9 @@ def forward(self, x: paddle.Tensor, gate: nn.Layer):
         token_num = x.shape[0]
         if (
             self.ep_size > 1
-            and self.tp_size > 1
+            and self.attn_tp_size > 1
             and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
-            and token_num >= self.tp_size
+            and token_num >= self.attn_tp_size
         ):
             out = self.forward_split_allgather(x, gate)
         elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
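The condition rewritten in this hunk now keys off the attention-TP degree instead of the MoE layer's own tp_size. A hedged restatement of that branch as a standalone predicate; the helper name and flat argument list are illustrative, while the field names mirror the diff:

```python
def should_split_allgather(
    ep_size: int,
    attn_tp_size: int,
    use_sequence_parallel_moe: bool,
    token_num: int,
) -> bool:
    """Split tokens across attention-TP ranks only when expert parallelism is on,
    the attention layers are TP-sharded, sequence-parallel MoE is disabled, and
    there are at least as many tokens as ranks (so every rank gets a real slice)."""
    return (
        ep_size > 1
        and attn_tp_size > 1
        and not use_sequence_parallel_moe
        and token_num >= attn_tp_size
    )


assert should_split_allgather(ep_size=8, attn_tp_size=4, use_sequence_parallel_moe=False, token_num=16)
assert not should_split_allgather(ep_size=8, attn_tp_size=4, use_sequence_parallel_moe=False, token_num=2)
```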