
Commit 690bcb8

[Optimization] 1.fix tp+ep moe_forward; 2.set max_prefill_batch=env.MAX_PREFILL_NUM (#5315)
1 parent f6544c0 commit 690bcb8


2 files changed: 14 additions and 7 deletions

fastdeploy/config.py

Lines changed: 5 additions & 1 deletion
@@ -1586,7 +1586,11 @@ def __init__(
             self.max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
             if current_platform.is_xpu():
                 self.max_prefill_batch = 1
-            if self.model_config is not None and self.model_config.enable_mm:
+            if (
+                int(envs.ENABLE_V1_KVCACHE_SCHEDULER) == 0
+                and self.model_config is not None
+                and self.model_config.enable_mm
+            ):
                 self.max_prefill_batch = 1  # TODO: the multimodal prefill stage currently only supports a parallelism of 1; to be optimized
         else:
             self.max_prefill_batch = self.scheduler_config.max_num_seqs
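
Read in isolation, the new guard means a multimodal model is only forced down to a prefill batch of 1 when the V1 KV-cache scheduler is disabled; with the scheduler enabled, the value taken from the MAX_PREFILL_NUM environment variable (default 3) is kept, which appears to be what the commit title refers to. A minimal standalone sketch of that decision (the helper name and flattened arguments are illustrative, not the FDConfig API, and the surrounding else branch that falls back to max_num_seqs sits outside this hunk):

import os

def resolve_max_prefill_batch(enable_v1_kvcache_scheduler: int, enable_mm: bool, is_xpu: bool = False) -> int:
    # Illustrative sketch only; not the FDConfig API.
    max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
    if is_xpu:
        max_prefill_batch = 1
    # After this commit the cap to 1 applies only when the V1 KV-cache
    # scheduler is off AND the model is multimodal.
    if enable_v1_kvcache_scheduler == 0 and enable_mm:
        max_prefill_batch = 1
    return max_prefill_batch

assert resolve_max_prefill_batch(enable_v1_kvcache_scheduler=0, enable_mm=True) == 1
assert resolve_max_prefill_batch(enable_v1_kvcache_scheduler=1, enable_mm=True) == int(os.getenv("MAX_PREFILL_NUM", "3"))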

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 9 additions & 6 deletions
@@ -163,6 +163,9 @@ def __init__(
             self.tp_size = 1
             self.tp_rank = 0
 
+        self.attn_tp_size = fd_config.parallel_config.tensor_parallel_size
+        self.attn_tp_rank = fd_config.parallel_config.tensor_parallel_rank
+
         assert (self.tp_size >= 1 and self.ep_size == 1) or (
             self.tp_size == 1 and self.ep_size > 1
         ), "MoE only support parallelism on TP or EP dimension."
@@ -598,18 +601,18 @@ def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer):
         Forward split allgather function.
         """
         token_num = x.shape[0]
-        token_num_per_rank = (token_num + self.tp_size - 1) // self.tp_size
+        token_num_per_rank = (token_num + self.attn_tp_size - 1) // self.attn_tp_size
         # AllGather will hang when the data shapes on multi-ranks are different!
         part_x = paddle.zeros(shape=[token_num_per_rank, x.shape[1]], dtype=x.dtype)
-        start_offset = self.tp_rank * token_num_per_rank
-        end_offset = (self.tp_rank + 1) * token_num_per_rank
+        start_offset = self.attn_tp_rank * token_num_per_rank
+        end_offset = (self.attn_tp_rank + 1) * token_num_per_rank
         if start_offset >= token_num:
             start_offset = token_num
         if end_offset > token_num:
             end_offset = token_num
         part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
         out = self.quant_method.apply(self, part_x, gate)
-        multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, x.shape[1]], dtype=x.dtype)
+        multi_outs = paddle.zeros([token_num_per_rank * self.attn_tp_size, x.shape[1]], dtype=x.dtype)
         paddle.distributed.all_gather(multi_outs, out, self.tp_group)
         out = multi_outs[:token_num, :]

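The forward_split_allgather path pads every rank's slice to the same ceil(token_num / attn_tp_size) row count because, as the in-code comment notes, all_gather hangs when ranks contribute different shapes; the gathered result is then trimmed back to token_num rows. A plain-Python sketch of the offset arithmetic under illustrative values (attn_tp_size=4 and token_num=10 are made-up examples, not from the commit):

# Illustrative values only.
attn_tp_size = 4
token_num = 10
# Ceil division: every rank reserves the same number of rows.
token_num_per_rank = (token_num + attn_tp_size - 1) // attn_tp_size  # 3

for attn_tp_rank in range(attn_tp_size):
    start_offset = min(attn_tp_rank * token_num_per_rank, token_num)
    end_offset = min((attn_tp_rank + 1) * token_num_per_rank, token_num)
    owned = end_offset - start_offset  # rows actually copied; the rest stay zero-padded
    print(f"rank {attn_tp_rank}: tokens [{start_offset}, {end_offset}) -> {owned} real + {token_num_per_rank - owned} padded")

# The gather produces token_num_per_rank * attn_tp_size = 12 rows,
# and the final slice [:token_num] drops the 2 padded rows.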
@@ -629,9 +632,9 @@ def forward(self, x: paddle.Tensor, gate: nn.Layer):
         token_num = x.shape[0]
         if (
             self.ep_size > 1
-            and self.tp_size > 1
+            and self.attn_tp_size > 1
             and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
-            and token_num >= self.tp_size
+            and token_num >= self.attn_tp_size
         ):
             out = self.forward_split_allgather(x, gate)
         elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
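
This routing change is the "fix tp+ep moe_forward" part of the title: under expert parallelism the assertion in __init__ requires self.tp_size to be 1, so the old self.tp_size > 1 check could never hold and the split-allgather path was skipped in TP+EP deployments; checking the attention-side attn_tp_size instead re-enables it. A small sketch of the predicate with flattened arguments (the helper is illustrative, not part of the layer's API):

def should_split_allgather(ep_size: int, attn_tp_size: int,
                           use_sequence_parallel_moe: bool, token_num: int) -> bool:
    # Mirrors the condition in forward() after this commit.
    return (
        ep_size > 1
        and attn_tp_size > 1
        and not use_sequence_parallel_moe
        and token_num >= attn_tp_size
    )

# With EP enabled, tp_size is 1 inside the MoE layer, so the old tp_size-based
# check always failed; attn_tp_size still reflects the attention TP degree.
assert should_split_allgather(ep_size=8, attn_tp_size=4, use_sequence_parallel_moe=False, token_num=16)
assert not should_split_allgather(ep_size=8, attn_tp_size=4, use_sequence_parallel_moe=False, token_num=2)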
