
Commit 9f2f0cf

tuning optimization (#1032)

1 parent: f9a3fe2

15 files changed: +392 −73 lines

docker/Dockerfile

Lines changed: 2 additions & 1 deletion

@@ -39,7 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir
 
 RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
 
-RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .
+# TODO: offline compile
+# RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .
 
 RUN apt-get update && apt-get install -y libnuma-dev # for sgl_kernel

docker/Dockerfile.deepep

Lines changed: 2 additions & 1 deletion

@@ -39,7 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir
 
 RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
 
-RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .
+# TODO: offline compile
+# RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .
 
 RUN apt-get update && apt-get install -y libnuma-dev wget devscripts debhelper dh-make build-essential dkms
 RUN apt-get install -y ibverbs-providers infiniband-diags perftest rdma-core libibverbs-dev librdmacm-dev

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 36 additions & 22 deletions

@@ -332,6 +332,7 @@ def grouped_matmul_kernel(
     GROUP_SIZE_M: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr = False,
     NEED_K_MASK: tl.constexpr = True,
+    NEED_TRANS: tl.constexpr = False,
 ):
     pid = tl.program_id(0)

@@ -367,13 +368,6 @@ def grouped_matmul_kernel(
         mask=token_mask,
         other=0,
     )
-    if MUL_ROUTED_WEIGHT:
-        a_m_scale = tl.load(
-            expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
-            mask=token_mask,
-            other=0.0,
-        )
-
     offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n
     offs_k = tl.arange(0, BLOCK_SIZE_K)

@@ -387,7 +381,7 @@ def grouped_matmul_kernel(
             b_scale = tl.load(weight_scale_ptr + expert_id, eviction_policy="evict_last")
             ab_scale = a_scale * b_scale
 
-    if use_fp8_w8a8:
+    if NEED_TRANS:
         a_ptrs = token_ptr + (a_m_index // topk_num)[None, :] * token_stride_0 + offs_k[:, None]
         b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[None, :] + offs_bn[:, None] * weight_stride_1
         accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32)

@@ -401,16 +395,20 @@ def grouped_matmul_kernel(
         # tl.multiple_of(a_ptrs, [16, 16])
         # tl.multiple_of(b_ptrs, [16, 16])
 
-        if use_fp8_w8a8:
+        if NEED_TRANS:
             if NEED_K_MASK:
-                a = tl.load(a_ptrs, mask=(token_mask[None, :]) & (offs_k[:, None] < k), other=0.0)
+                a = tl.load(
+                    a_ptrs, mask=(token_mask[None, :]) & (offs_k[:, None] < k - step_k * BLOCK_SIZE_K), other=0.0
+                )
                 b = tl.load(b_ptrs, mask=(offs_k[None, :] < k), other=0.0)
             else:
                 a = tl.load(a_ptrs, mask=(token_mask[None, :]), other=0.0)
                 b = tl.load(b_ptrs)
         else:
             if NEED_K_MASK:
-                a = tl.load(a_ptrs, mask=(token_mask[:, None]) & (offs_k[None, :] < k), other=0.0)
+                a = tl.load(
+                    a_ptrs, mask=(token_mask[:, None]) & (offs_k[None, :] < k - step_k * BLOCK_SIZE_K), other=0.0
+                )
                 b = tl.load(b_ptrs, mask=(offs_k[:, None] < k), other=0.0)
             else:
                 a = tl.load(a_ptrs, mask=(token_mask[:, None]), other=0.0)

@@ -421,24 +419,34 @@ def grouped_matmul_kernel(
                 offs_ks = step_k * BLOCK_SIZE_K // block_size_k
                 a_scale = tl.load(a_scale_ptrs + offs_ks, mask=token_mask, other=0.0)
                 b_scale = tl.load(b_scale_ptrs + offs_ks * weight_scale_stride2)
-                accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :]
+                if NEED_TRANS:
+                    accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :]
+                else:
+                    accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
             else:
-                accumulator = tl.dot(b, a, acc=accumulator)
+                if NEED_TRANS:
+                    accumulator = tl.dot(b, a, acc=accumulator)
+                else:
+                    accumulator = tl.dot(a, b, acc=accumulator)
         else:
             accumulator += tl.dot(a, b)
 
         a_ptrs += BLOCK_SIZE_K
         b_ptrs += BLOCK_SIZE_K
-        offs_k += BLOCK_SIZE_K
+
+    if NEED_TRANS:
+        accumulator = accumulator.T
 
     if use_fp8_w8a8:
-        if block_size_k > 0 and block_size_n > 0:
-            accumulator = accumulator.T
-        else:
-            accumulator = accumulator.T
+        if not (block_size_k > 0 and block_size_n > 0):
             accumulator *= ab_scale
 
     if MUL_ROUTED_WEIGHT:
+        a_m_scale = tl.load(
+            expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
+            mask=token_mask,
+            other=0.0,
+        )
         accumulator *= a_m_scale[:, None]
 
     c = accumulator.to(compute_type)

@@ -478,13 +486,15 @@ def _get_grouped_matmul_configs():
             "GROUP_SIZE_M": gm,
             "num_warps": nw,
             "num_stages": ns,
+            "NEED_TRANS": need_trans,
         }
-        for ns in [1, 2, 3, 4, 5]
-        for gm in [1, 2, 4, 8]
-        for nw in [2, 4, 8]
+        for ns in [2, 3, 4, 5]
+        for gm in [1, 16, 32, 64]
+        for nw in [4, 8]
         for bm in [16, 32, 64, 128]
         for bn in [16, 32, 64, 128]
-        for bk in [16, 32, 64, 128]
+        for bk in [32, 64, 128]
+        for need_trans in [True, False]
     ]

@@ -559,6 +569,9 @@ def grouped_matmul(
     GROUP_SIZE_M = run_config["GROUP_SIZE_M"]
     num_warps = run_config["num_warps"]
     num_stages = run_config["num_stages"]
+    NEED_TRANS = run_config.get("NEED_TRANS", False)
+    if not use_fp8_w8a8:
+        assert NEED_TRANS is False, "only use_fp8_w8a8 mode can use NEED_TRANS to accelerate"
 
     if block_size_k != 0:
         # if block-wise quantization is used, the tile size must not exceed the block size

@@ -638,6 +651,7 @@ def grouped_matmul(
         GROUP_SIZE_M=GROUP_SIZE_M,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         NEED_K_MASK=NEED_K_MASK,
+        NEED_TRANS=NEED_TRANS,
         num_warps=num_warps,
         num_stages=num_stages,
     )
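
The main kernel change above is the new NEED_TRANS constexpr: instead of always accumulating the fp8 tile product as tl.dot(b, a) into an (N, M) accumulator and transposing once after the k-loop, the kernel can also accumulate tl.dot(a, b) directly in (M, N) layout, and the autotuner now searches both variants per shape. A minimal sketch (not part of the commit) of why the two paths give the same result, using plain torch tensors as stand-ins for the Triton tiles and scales; all names here are illustrative:

import torch

m, n, k = 16, 32, 64
a = torch.randn(m, k)      # activation tile
b = torch.randn(k, n)      # weight tile
a_scale = torch.rand(m)    # per-token scale (stand-in for the fp8 block scale)
b_scale = torch.rand(n)    # per-output-channel scale

# NEED_TRANS=True path: accumulate in (n, m) layout, transpose once at the end.
acc_t = (b.T @ a.T) * b_scale[:, None] * a_scale[None, :]
out_trans = acc_t.T

# NEED_TRANS=False path: accumulate directly in (m, n) layout.
out_plain = (a @ b) * a_scale[:, None] * b_scale[None, :]

assert torch.allclose(out_trans, out_plain, atol=1e-4)

Which layout is faster depends on the tile shape and hardware, which is why NEED_TRANS is tuned per token-count bucket rather than hard-coded.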

lightllm/common/fused_moe/grouped_topk.py

Lines changed: 8 additions & 2 deletions

@@ -140,6 +140,7 @@ def grouped_topk_kernel(
     offs_group = tl.arange(0, EXPERT_GROUP_NUM)
     offs_group_v = tl.arange(0, EXPERT_GROUP_SIZE)
     tl.store(scores_buffer_ptr + scores_stride_m * token_index + offs_n, scores, mask=offs_n < total_expert_num)
+    tl.debug_barrier()
     group_scores = tl.load(
         scores_buffer_ptr
         + scores_stride_token_m * token_index

@@ -174,7 +175,7 @@ def grouped_topk_kernel(
         mask_group_scores,
         mask=((offs_group < group_num)[:, None]) & ((offs_group_v < group_expert_num)[None, :]),
     )  # [group, group_size]
-
+    tl.debug_barrier()
     mask_scores = tl.load(
         scores_buffer_ptr + scores_stride_m * token_index + offs_n, mask=offs_n < total_expert_num, other=-10000000.0
     )

@@ -227,6 +228,11 @@ def triton_grouped_topk(
 
     assert total_expert_num % num_expert_group == 0
 
+    if token_num <= 256:
+        num_warps = 4
+    else:
+        num_warps = 1
+
     grouped_topk_kernel[(token_num,)](
         gating_output,
         *gating_output.stride(),

@@ -250,7 +256,7 @@ def triton_grouped_topk(
         EXPERT_GROUP_SIZE=triton.next_power_of_2(total_expert_num // num_expert_group),
         RENORMALIZE=renormalize,
         GROUP_SCORE_USED_TOPK_NUM=group_score_used_topk_num,
-        num_warps=1,
+        num_warps=num_warps,
         num_stages=1,
     )
     return out_topk_weights, out_topk_ids
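
Both tl.debug_barrier() calls sit between a store to the scores scratch buffer and a later load from it with a different access pattern, apparently to guarantee the written scores are visible to the whole program before they are re-read; the launch also raises num_warps for small token counts. A self-contained sketch of the store/barrier/load pattern in isolation (the kernel and names below are illustrative, not the repo's kernel, and it needs a CUDA device to run):

import torch
import triton
import triton.language as tl


@triton.jit
def _scratch_roundtrip(x_ptr, buf_ptr, out_ptr, n_elem, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < n_elem
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(buf_ptr + offs, x * 2.0, mask=mask)  # write the scratch buffer
    tl.debug_barrier()  # make the store visible to every thread before the buffer is re-read
    y = tl.load(buf_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, y, mask=mask)


x = torch.randn(100, device="cuda")
buf = torch.empty_like(x)
out = torch.empty_like(x)
_scratch_roundtrip[(1,)](x, buf, out, 100, BLOCK=128)
assert torch.allclose(out, 2.0 * x)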

lightllm/common/fused_moe/moe_kernel_configs.py

Lines changed: 2 additions & 0 deletions

@@ -46,6 +46,7 @@ def try_to_get_best_config(
             "BLOCK_SIZE_N": 32,
             "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 1,
+            "NEED_TRANS": False,
             "num_warps": 4,
             "num_stages": 1,
         }

@@ -55,6 +56,7 @@ def try_to_get_best_config(
             "BLOCK_SIZE_N": 64,
             "BLOCK_SIZE_K": 32,
             "GROUP_SIZE_M": 8,
+            "NEED_TRANS": False,
             "num_warps": 4,
             "num_stages": 1,
         }
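
The fallback configs gain an explicit "NEED_TRANS": False, and grouped_matmul above reads the key with run_config.get("NEED_TRANS", False), so tuned JSON files written before this commit keep working. A small sketch of that fallback in isolation (the helper name is hypothetical, not from the repo):

def read_need_trans(run_config: dict) -> bool:
    # Older tuned configs have no "NEED_TRANS" entry; treat them as the old (M, N) layout.
    return bool(run_config.get("NEED_TRANS", False))


assert read_need_trans({"BLOCK_SIZE_M": 16, "num_warps": 4}) is False
assert read_need_trans({"NEED_TRANS": True}) is True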

lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ def _get_silu_and_mul_configs():
         {"BLOCK_M": bm, "BLOCK_N": bn, "num_warps": nw, "NUM_STAGES": ns}
         for ns in [1, 2, 4]
         for nw in [1, 4, 8]
-        for bm in [32, 64, 128, 256]
+        for bm in [1, 8, 32, 64, 128, 256]
         for bn in [32, 64, 128, 256]
     ]
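
Adding BLOCK_M values 1 and 8 grows the silu-and-mul autotune grid from 3·3·4·4 = 144 to 3·3·6·4 = 216 candidate configs, presumably to cover very small token counts. A quick standalone check of that count, mirroring the comprehension above:

configs = [
    {"BLOCK_M": bm, "BLOCK_N": bn, "num_warps": nw, "NUM_STAGES": ns}
    for ns in [1, 2, 4]
    for nw in [1, 4, 8]
    for bm in [1, 8, 32, 64, 128, 256]
    for bn in [32, 64, 128, 256]
]
assert len(configs) == 216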

lightllm/common/quantization/deepgemm_quant.py

Lines changed: 10 additions & 1 deletion

@@ -66,5 +66,14 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
 
         if out is None:
             out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-        deep_gemm.gemm_fp8_fp8_bf16_nt([qinput_tensor, input_scale], [qweight.t(), weight_scale.t()], out)
+        _deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight.t(), weight_scale.t()), out)
         return out
+
+
+def _deepgemm_fp8_nt(a_tuple, b_tuple, out):
+    if HAS_DEEPGEMM:
+        if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"):
+            return deep_gemm.gemm_fp8_fp8_bf16_nt([a_tuple[0], a_tuple[1]], [b_tuple[0], b_tuple[1]], out)
+        if hasattr(deep_gemm, "fp8_gemm_nt"):
+            return deep_gemm.fp8_gemm_nt((a_tuple[0], a_tuple[1]), (b_tuple[0], b_tuple[1]), out)
+    raise RuntimeError("deep_gemm does not provide fp8 NT GEMM kernel in this version")
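
The new _deepgemm_fp8_nt wrapper exists because different deep_gemm builds expose the fp8 NT GEMM under different names (gemm_fp8_fp8_bf16_nt in some releases, fp8_gemm_nt in others), so the call site probes for whichever is present. The same pattern in isolation, as a hedged sketch (the function name below is illustrative, not the repo's):

def resolve_fp8_nt_gemm(deep_gemm_module):
    """Return whichever fp8 NT GEMM entry point this deep_gemm build provides."""
    for name in ("gemm_fp8_fp8_bf16_nt", "fp8_gemm_nt"):
        fn = getattr(deep_gemm_module, name, None)
        if fn is not None:
            return fn
    raise RuntimeError("deep_gemm does not provide an fp8 NT GEMM kernel in this version")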

lightllm/common/triton_utils/autotune_kernel_configs/triton_3.3.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json

Lines changed: 43 additions & 15 deletions

@@ -3,80 +3,108 @@
         "BLOCK_SIZE_K": 128,
         "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
-        "GROUP_SIZE_M": 8,
+        "GROUP_SIZE_M": 32,
+        "NEED_TRANS": true,
         "num_stages": 2,
         "num_warps": 4
     },
     "128": {
         "BLOCK_SIZE_K": 128,
         "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
-        "GROUP_SIZE_M": 8,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
         "num_stages": 2,
         "num_warps": 4
     },
     "131072": {
         "BLOCK_SIZE_K": 128,
-        "BLOCK_SIZE_M": 128,
-        "BLOCK_SIZE_N": 64,
-        "GROUP_SIZE_M": 2,
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 16,
+        "NEED_TRANS": false,
         "num_stages": 3,
         "num_warps": 4
     },
     "16384": {
         "BLOCK_SIZE_K": 128,
         "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 128,
-        "GROUP_SIZE_M": 8,
+        "GROUP_SIZE_M": 32,
+        "NEED_TRANS": false,
         "num_stages": 3,
         "num_warps": 4
     },
     "2048": {
         "BLOCK_SIZE_K": 128,
         "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
-        "GROUP_SIZE_M": 8,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "256": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
         "num_stages": 2,
         "num_warps": 4
     },
     "32768": {
         "BLOCK_SIZE_K": 128,
         "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 128,
-        "GROUP_SIZE_M": 8,
+        "GROUP_SIZE_M": 32,
+        "NEED_TRANS": false,
         "num_stages": 3,
         "num_warps": 4
     },
     "512": {
         "BLOCK_SIZE_K": 128,
         "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
-        "GROUP_SIZE_M": 8,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
         "num_stages": 2,
         "num_warps": 4
     },
     "64": {
         "BLOCK_SIZE_K": 128,
         "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
-        "GROUP_SIZE_M": 2,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
         "num_stages": 2,
         "num_warps": 4
     },
     "8": {
-        "BLOCK_SIZE_K": 32,
+        "BLOCK_SIZE_K": 64,
         "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
-        "GROUP_SIZE_M": 1,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": true,
         "num_stages": 3,
-        "num_warps": 2
+        "num_warps": 4
     },
-    "8192": {
+    "800": {
         "BLOCK_SIZE_K": 128,
         "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
-        "GROUP_SIZE_M": 8,
+        "GROUP_SIZE_M": 32,
+        "NEED_TRANS": true,
         "num_stages": 2,
         "num_warps": 4
+    },
+    "8192": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": false,
+        "num_stages": 3,
+        "num_warps": 4
     }
 }
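
This tuned-config file is keyed by token-count buckets ("8" through "131072"), with the problem shape and dtype encoded in the file name. A hedged sketch of reading such a file and picking a bucket; the real selection logic lives in lightllm's autotune utilities, and the nearest-bucket rule below is only an assumption for illustration:

import json


def pick_config(path: str, token_num: int) -> dict:
    with open(path) as f:
        configs = json.load(f)  # {"<token_count>": {kernel config}, ...}
    buckets = sorted(int(k) for k in configs)
    best = min(buckets, key=lambda b: abs(b - token_num))  # assumption: closest tuned bucket
    return configs[str(best)]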
