
Commit 7781fb7: group deepgemm update api (#1035)
1 parent df0812d
16 files changed (+1546, -20 lines)

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 7 additions & 10 deletions
@@ -4,7 +4,11 @@
 from typing import Optional, Tuple, List, Dict, Any
 from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id
 from .base_weight import BaseWeight
-from lightllm.common.fused_moe.grouped_fused_moe_ep import fused_experts_impl, masked_group_gemm
+from lightllm.common.fused_moe.grouped_fused_moe_ep import (
+    fused_experts_impl,
+    masked_group_gemm,
+    _deepgemm_grouped_fp8_nt_contiguous,
+)
 from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
 from lightllm.distributed import dist_group_manager
 from lightllm.common.fused_moe.topk_select import select_experts
@@ -23,11 +27,6 @@

 logger = init_logger(__name__)

-try:
-    import deep_gemm
-except:
-    logger.warning("no deepep or deep_gemm")
-

 class FusedMoeWeightEP(BaseWeight):
     def __init__(
@@ -336,7 +335,7 @@ def prefilled_group_gemm(
             # groupgemm (contiguous layout)
             gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype)

-            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices)
+            _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices)

             # silu_and_mul_fwd + qaunt
             # TODO fused kernel
@@ -350,9 +349,7 @@ def prefilled_group_gemm(
             # groupgemm (contiguous layout)
             gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype)

-            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-                (qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices
-            )
+            _deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices)
             # gather and local reduce
             ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out)
         else:

lightllm/common/fused_moe/grouped_fused_moe_ep.py

Lines changed: 36 additions & 9 deletions
@@ -24,12 +24,14 @@
     from deep_ep import Buffer, EventOverlap
     import deep_gemm

+    HAS_DEEPGEMM = True
 except:
     logger.warning("no deepep or deep_gemm")
+    HAS_DEEPGEMM = False


 def masked_group_gemm(
-    recv_x: Tuple[torch.Tensor],
+    recv_x: Tuple[torch.Tensor, torch.Tensor],
     masked_m: torch.Tensor,
     dtype: torch.dtype,
     w1: torch.Tensor,
@@ -49,12 +51,10 @@ def masked_group_gemm(
     # groupgemm (masked layout)
     gemm_out_b = torch.empty_like(recv_x[0], device=recv_x[0].device, dtype=dtype)

-    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(recv_x, (w1, w1_scale), gemm_out_a, masked_m, expected_m)
+    _deepgemm_grouped_fp8_nt_masked(recv_x, (w1, w1_scale), gemm_out_a, masked_m, expected_m)

     silu_and_mul_masked_post_quant_fwd(gemm_out_a, qsilu_out, qsilu_out_scale, block_size, masked_m)
-    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-        (qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, masked_m, expected_m
-    )
+    _deepgemm_grouped_fp8_nt_masked((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, masked_m, expected_m)
     return gemm_out_b


@@ -168,7 +168,7 @@ def fused_experts_impl(
     # groupgemm (contiguous layout)
     gemm_out_a = torch.empty((all_tokens, N), device=hidden_states.device, dtype=hidden_states.dtype)
     input_tensor[1] = tma_align_input_scale(input_tensor[1])
-    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices)
+    _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices)

     # silu_and_mul_fwd + qaunt
     # TODO fused kernel
@@ -182,9 +182,7 @@ def fused_experts_impl(
     # groupgemm (contiguous layout)
     gemm_out_b = torch.empty((all_tokens, K), device=hidden_states.device, dtype=hidden_states.dtype)

-    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-        (qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices
-    )
+    _deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices)

     # gather and local reduce
     ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out)
@@ -227,3 +225,32 @@ def fused_experts_impl(
         gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=False
     )
     return combined_x
+
+
+def _deepgemm_grouped_fp8_nt_contiguous(
+    input_tuple: Tuple[torch.Tensor, torch.Tensor],
+    w_tuple: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    m_indices: torch.Tensor,
+):
+    if HAS_DEEPGEMM:
+        if hasattr(deep_gemm, "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous"):
+            return deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(input_tuple, w_tuple, out, m_indices)
+        if hasattr(deep_gemm, "m_grouped_fp8_gemm_nt_contiguous"):
+            return deep_gemm.m_grouped_fp8_gemm_nt_contiguous(input_tuple, w_tuple, out, m_indices)
+    raise RuntimeError("deep_gemm does not provide grouped_gemm_fp8 NT contiguous GEMM kernel in this version")
+
+
+def _deepgemm_grouped_fp8_nt_masked(
+    input_tuple: Tuple[torch.Tensor, torch.Tensor],
+    w_tuple: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m: int,
+):
+    if HAS_DEEPGEMM:
+        if hasattr(deep_gemm, "m_grouped_fp8_gemm_nt_masked"):
+            return deep_gemm.m_grouped_fp8_gemm_nt_masked(input_tuple, w_tuple, out, masked_m, expected_m)
+        if hasattr(deep_gemm, "m_grouped_gemm_fp8_fp8_bf16_nt_masked"):
+            return deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(input_tuple, w_tuple, out, masked_m, expected_m)
+    raise RuntimeError("deep_gemm does not provide grouped_gemm_fp8 NT masked GEMM kernel in this version")
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+{
+    "1": {
+        "BLOCK_DIM": 128,
+        "BLOCK_M": 4,
+        "NUM_STAGE": 1,
+        "num_warps": 4
+    },
+    "100": {
+        "BLOCK_DIM": 512,
+        "BLOCK_M": 1,
+        "NUM_STAGE": 4,
+        "num_warps": 1
+    },
+    "1024": {
+        "BLOCK_DIM": 1024,
+        "BLOCK_M": 4,
+        "NUM_STAGE": 4,
+        "num_warps": 4
+    },
+    "128": {
+        "BLOCK_DIM": 1024,
+        "BLOCK_M": 1,
+        "NUM_STAGE": 1,
+        "num_warps": 4
+    },
+    "16": {
+        "BLOCK_DIM": 512,
+        "BLOCK_M": 1,
+        "NUM_STAGE": 1,
+        "num_warps": 8
+    },
+    "16384": {
+        "BLOCK_DIM": 1024,
+        "BLOCK_M": 1,
+        "NUM_STAGE": 4,
+        "num_warps": 4
+    },
+    "2048": {
+        "BLOCK_DIM": 512,
+        "BLOCK_M": 1,
+        "NUM_STAGE": 4,
+        "num_warps": 1
+    },
+    "256": {
+        "BLOCK_DIM": 1024,
+        "BLOCK_M": 1,
+        "NUM_STAGE": 2,
+        "num_warps": 4
+    },
+    "32": {
+        "BLOCK_DIM": 1024,
+        "BLOCK_M": 1,
+        "NUM_STAGE": 4,
+        "num_warps": 4
+    },
+    "4096": {
+        "BLOCK_DIM": 512,
+        "BLOCK_M": 1,
+        "NUM_STAGE": 4,
+        "num_warps": 2
+    },
+    "64": {
+        "BLOCK_DIM": 512,
+        "BLOCK_M": 1,
+        "NUM_STAGE": 4,
+        "num_warps": 1
+    },
+    "8": {
+        "BLOCK_DIM": 256,
+        "BLOCK_M": 1,
+        "NUM_STAGE": 1,
+        "num_warps": 4
+    }
+}
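This new JSON file is a tuning table keyed by token count; each entry records the Triton launch parameters (BLOCK_DIM, BLOCK_M, NUM_STAGE, num_warps) chosen for that size. A hedged sketch of how such a table is typically loaded and queried at runtime follows; the file name and the nearest-key policy are assumptions for illustration, not lightllm's actual loader.

import json
from functools import lru_cache


@lru_cache(maxsize=None)
def load_tuning_table(path: str) -> dict:
    # JSON object keys are strings; convert them to ints so lookups can be numeric.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}


def pick_launch_params(path: str, num_tokens: int) -> dict:
    # Pick the entry whose token-count key is closest to the runtime size.
    table = load_tuning_table(path)
    best_key = min(table, key=lambda k: abs(k - num_tokens))
    return table[best_key]


# Example (hypothetical file name): pick_launch_params("silu_and_mul_post_quant_config.json", 300)
# would select the "256" entry of the table above.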
Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+{
+    "1024": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "128": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "131072": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 16,
+        "NEED_TRANS": false,
+        "num_stages": 3,
+        "num_warps": 4
+    },
+    "16384": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": false,
+        "num_stages": 3,
+        "num_warps": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "256": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "32": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "32768": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 32,
+        "NEED_TRANS": false,
+        "num_stages": 3,
+        "num_warps": 4
+    },
+    "512": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "64": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "8": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
+        "num_stages": 3,
+        "num_warps": 4
+    },
+    "800": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 32,
+        "NEED_TRANS": true,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "8192": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": false,
+        "num_stages": 3,
+        "num_warps": 4
+    }
+}
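This second table uses the same token-count keying but carries grouped-GEMM tile parameters (BLOCK_SIZE_M/N/K, GROUP_SIZE_M, NEED_TRANS) alongside num_warps and num_stages. As a rough, hedged illustration of how such fields usually reach a Triton kernel, the toy launch below forwards the tile sizes as tl.constexpr meta-parameters and the warp/stage counts as launch options; the kernel body is a trivial stand-in, not lightllm's grouped GEMM.

import torch
import triton
import triton.language as tl


@triton.jit
def _toy_kernel(x_ptr, n_elements, BLOCK_SIZE_M: tl.constexpr, NEED_TRANS: tl.constexpr):
    # Stand-in body: touches BLOCK_SIZE_M-sized tiles of a 1-D tensor.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x + 1, mask=mask)


def launch_with_config(x: torch.Tensor, cfg: dict) -> None:
    # cfg is one entry from a table like the one above, e.g. its "1024" block.
    n = x.numel()
    grid = (triton.cdiv(n, cfg["BLOCK_SIZE_M"]),)
    _toy_kernel[grid](
        x,
        n,
        BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"],
        NEED_TRANS=cfg["NEED_TRANS"],
        num_warps=cfg["num_warps"],
        num_stages=cfg["num_stages"],
    )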
