
Commit df0812d

fix benchmark (#1033)
1 parent 9f2f0cf commit df0812d

File tree

1 file changed: +21, -16 lines changed

test/benchmark/kernel/benchmark_fused_moe_triton.py

Lines changed: 21 additions & 16 deletions
@@ -4,14 +4,9 @@
 
 import torch
 import triton
-import vllm
 from transformers import AutoConfig
 from lightllm.common.fused_moe.topk_select import select_experts
 from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-    fused_moe as fused_moe_sglang,
-)
 
 
 def get_model_config(model_name: str, tp_size: int):
@@ -59,12 +54,10 @@ def get_model_config(model_name: str, tp_size: int):
     intermediate_size = config.intermediate_size
     shard_intermediate_size = 2 * intermediate_size // tp_size
 
-    vllm_version_num = vllm.__version_tuple__[0] * 100 + vllm.__version_tuple__[1] * 10 + vllm.__version_tuple__[2]
     block_shape = None
     if hasattr(config, "quantization_config") and "weight_block_size" in config.quantization_config:
         block_shape = config.quantization_config["weight_block_size"]
         assert len(block_shape) == 2
-        assert vllm_version_num >= 66, "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
 
     shape_configs = {
         "num_experts": E,
@@ -131,6 +124,8 @@ def fused_moe_vllm_api(
     a2_scale=None,
     block_shape=None,
 ):
+    from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
+
     if block_shape is not None:
         return fused_moe_vllm(
             x,
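This hunk, together with the sglang hunk below, re-adds inside the wrapper functions the backend imports that the first hunk removed from module scope, so each backend is only imported when its provider is actually benchmarked. A minimal sketch of the deferred-import pattern (hypothetical wrapper name; only the import placement is the point):

def run_vllm_backend(*args, **kwargs):
    # Deferred import: vllm only needs to be installed when this provider
    # is selected, not when the benchmark module itself is loaded.
    from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
    return fused_moe_vllm(*args, **kwargs)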
@@ -177,14 +172,21 @@ def fused_moe_sglang_api(
     a2_scale=None,
     block_shape=None,
 ):
+    from sglang.srt.layers.moe.topk import TopKConfig, select_experts
+    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+        fused_moe as fused_moe_sglang,
+    )
+
+    topk_output = select_experts(
+        hidden_states=x,
+        router_logits=input_gating,
+        topk_config=TopKConfig(top_k=topk, renormalize=False),
+    )
     return fused_moe_sglang(
         x,
         w1,
         w2,
-        input_gating,
-        topk,
-        renormalize=True,
-        inplace=True,
+        topk_output,
         use_fp8_w8a8=use_fp8_w8a8,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
@@ -197,7 +199,7 @@ def fused_moe_sglang_api(
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[1, 8, 16, 32, 64, 128],
+        x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192],
         line_arg="provider",
         line_vals=[
             "vllm_fused_moe_triton",
@@ -219,7 +221,7 @@ def fused_moe_sglang_api(
219221
args={},
220222
)
221223
)
222-
def benchmark(batch_size, provider, model_config, use_fp8=False):
224+
def benchmark(batch_size, provider, model_config, use_fp8=False, use_cuda_graph=False):
223225
torch.set_default_device("cuda")
224226
torch.cuda.manual_seed_all(0)
225227

@@ -264,9 +266,9 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
264266
api_func = (
265267
fused_moe_vllm_api
266268
if provider == "vllm_fused_moe_triton"
267-
else fused_moe_sglang_api
268-
if provider == "lightllm_fused_moe_triton"
269269
else fused_moe_lightllm_api
270+
if provider == "lightllm_fused_moe_triton"
271+
else fused_moe_sglang_api
270272
)
271273
for _ in range(10):
272274
api_func(
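The swapped branches above are the core of the fix: before this commit, the "lightllm_fused_moe_triton" provider was dispatched to fused_moe_sglang_api and the remaining (sglang) provider to fused_moe_lightllm_api, so two of the three curves measured the wrong kernel. An equivalent if/elif form of the corrected dispatch, shown only for readability (the script itself keeps the conditional expression; provider strings other than the two visible in this hunk are not shown here):

if provider == "vllm_fused_moe_triton":
    api_func = fused_moe_vllm_api
elif provider == "lightllm_fused_moe_triton":
    api_func = fused_moe_lightllm_api
else:
    api_func = fused_moe_sglang_api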
@@ -285,7 +287,8 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
285287
torch.cuda.synchronize()
286288

287289
quantiles = [0.5, 0.2, 0.8]
288-
ms, min_ms, max_ms = triton.testing.do_bench(
290+
do_bench = triton.testing.do_bench if not use_cuda_graph else triton.testing.do_bench_cudagraph
291+
ms, min_ms, max_ms = do_bench(
289292
lambda: api_func(
290293
x,
291294
w1,
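With --use-cuda-graph the timing helper switches from triton.testing.do_bench to triton.testing.do_bench_cudagraph, which replays the callable inside a CUDA graph and so largely removes host-side launch overhead from the measurement. A minimal, self-contained sketch of how the two helpers are interchangeable (assumes a CUDA device and a Triton release that ships do_bench_cudagraph; the matmul is just a stand-in workload, not the MoE kernel):

import torch
import triton

def time_kernel(use_cuda_graph: bool = False):
    x = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    w = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    do_bench = triton.testing.do_bench_cudagraph if use_cuda_graph else triton.testing.do_bench
    # Both helpers take a zero-argument callable and return latencies in milliseconds;
    # with quantiles=[0.5, 0.2, 0.8] they return (median, low, high) instead of a single value.
    ms, min_ms, max_ms = do_bench(lambda: x @ w, quantiles=[0.5, 0.2, 0.8])
    return ms, min_ms, max_ms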
@@ -309,6 +312,7 @@ def main():
     parser.add_argument("--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1")
     parser.add_argument("--tp-size", type=int, default=8)
     parser.add_argument("--use-fp8", action="store_true")
+    parser.add_argument("--use-cuda-graph", action="store_true")
     parser.add_argument(
         "--save-path",
         type=str,
@@ -323,6 +327,7 @@ def main():
         save_path=args.save_path,
         model_config=model_config,
         use_fp8=args.use_fp8,
+        use_cuda_graph=args.use_cuda_graph,
     )
 
 