 
 import torch
 import triton
-import vllm
 from transformers import AutoConfig
 from lightllm.common.fused_moe.topk_select import select_experts
 from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-    fused_moe as fused_moe_sglang,
-)
 
 
 def get_model_config(model_name: str, tp_size: int):
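The hunk above drops the unconditional top-level vllm and sglang imports; each framework is instead imported lazily inside its own wrapper (see the later hunks), so the script no longer needs every framework installed just to benchmark one of them. A minimal sketch of that lazy-import pattern; the patch itself only moves the import inside the wrapper, and the try/except fallback shown here is an illustrative assumption:

```python
# Sketch only: the lazy-import style this commit moves to. The RuntimeError fallback is an
# illustrative addition, not a line from the patch.
def fused_moe_vllm_api(*args, **kwargs):
    try:
        # Imported on first use, so sglang/lightllm-only runs do not require vLLM.
        from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
    except ImportError as exc:
        raise RuntimeError("vLLM is required for the vllm_fused_moe_triton provider") from exc
    return fused_moe_vllm(*args, **kwargs)
```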
@@ -59,12 +54,10 @@ def get_model_config(model_name: str, tp_size: int):
     intermediate_size = config.intermediate_size
     shard_intermediate_size = 2 * intermediate_size // tp_size
 
-    vllm_version_num = vllm.__version_tuple__[0] * 100 + vllm.__version_tuple__[1] * 10 + vllm.__version_tuple__[2]
     block_shape = None
     if hasattr(config, "quantization_config") and "weight_block_size" in config.quantization_config:
         block_shape = config.quantization_config["weight_block_size"]
         assert len(block_shape) == 2
-        assert vllm_version_num >= 66, "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
 
     shape_configs = {
         "num_experts": E,
@@ -131,6 +124,8 @@ def fused_moe_vllm_api(
     a2_scale=None,
     block_shape=None,
 ):
+    from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
+
     if block_shape is not None:
         return fused_moe_vllm(
             x,
@@ -177,14 +172,21 @@ def fused_moe_sglang_api(
     a2_scale=None,
     block_shape=None,
 ):
+    from sglang.srt.layers.moe.topk import TopKConfig, select_experts
+    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+        fused_moe as fused_moe_sglang,
+    )
+
+    topk_output = select_experts(
+        hidden_states=x,
+        router_logits=input_gating,
+        topk_config=TopKConfig(top_k=topk, renormalize=False),
+    )
     return fused_moe_sglang(
         x,
         w1,
         w2,
-        input_gating,
-        topk,
-        renormalize=True,
-        inplace=True,
+        topk_output,
         use_fp8_w8a8=use_fp8_w8a8,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
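This hunk tracks sglang's fused_moe API change: the kernel no longer takes the raw router logits plus top_k/renormalize flags; expert selection happens up front via select_experts with a TopKConfig, and the resulting topk_output is passed positionally. A standalone sketch of that selection step, assuming sglang exposes the same call signature used above (tensor shapes and top_k are illustrative):

```python
# Sketch under the assumption that select_experts/TopKConfig behave as they are used in the
# hunk above; shapes and top_k are illustrative only.
import torch
from sglang.srt.layers.moe.topk import TopKConfig, select_experts

hidden = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")      # [tokens, hidden]
router_logits = torch.randn(16, 8, dtype=torch.float32, device="cuda")   # [tokens, num_experts]

# Routing is now a separate step; its output object is handed to fused_moe positionally.
topk_output = select_experts(
    hidden_states=hidden,
    router_logits=router_logits,
    topk_config=TopKConfig(top_k=2, renormalize=False),
)
```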
@@ -197,7 +199,7 @@ def fused_moe_sglang_api(
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[1, 8, 16, 32, 64, 128],
+        x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192],
         line_arg="provider",
         line_vals=[
             "vllm_fused_moe_triton",
@@ -219,7 +221,7 @@ def fused_moe_sglang_api(
         args={},
     )
 )
-def benchmark(batch_size, provider, model_config, use_fp8=False):
+def benchmark(batch_size, provider, model_config, use_fp8=False, use_cuda_graph=False):
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
 
@@ -264,9 +266,9 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     api_func = (
         fused_moe_vllm_api
         if provider == "vllm_fused_moe_triton"
-        else fused_moe_sglang_api
-        if provider == "lightllm_fused_moe_triton"
         else fused_moe_lightllm_api
+        if provider == "lightllm_fused_moe_triton"
+        else fused_moe_sglang_api
     )
     for _ in range(10):
         api_func(
@@ -285,7 +287,8 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     torch.cuda.synchronize()
 
     quantiles = [0.5, 0.2, 0.8]
-    ms, min_ms, max_ms = triton.testing.do_bench(
+    do_bench = triton.testing.do_bench if not use_cuda_graph else triton.testing.do_bench_cudagraph
+    ms, min_ms, max_ms = do_bench(
         lambda: api_func(
             x,
             w1,
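The timing call now picks between eager benchmarking and CUDA-graph replay. A minimal sketch of the same toggle on a stand-in workload; the constraint noted in the comment (the callable must be graph-capture safe) is general CUDA-graph behaviour, not something this patch enforces:

```python
# Sketch: same runner-selection logic as the hunk above, applied to a stand-in matmul.
import torch
import triton

def time_it(fn, use_cuda_graph=False):
    quantiles = [0.5, 0.2, 0.8]
    do_bench = triton.testing.do_bench if not use_cuda_graph else triton.testing.do_bench_cudagraph
    # do_bench_cudagraph replays the callable inside a captured CUDA graph, removing Python
    # and launch overhead from the measurement; the callable must therefore avoid host-side
    # synchronization and keep shapes/pointers stable across replays.
    return do_bench(fn, quantiles=quantiles)

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
ms, min_ms, max_ms = time_it(lambda: a @ b, use_cuda_graph=True)
```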
@@ -309,6 +312,7 @@ def main():
     parser.add_argument("--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1")
     parser.add_argument("--tp-size", type=int, default=8)
     parser.add_argument("--use-fp8", action="store_true")
+    parser.add_argument("--use-cuda-graph", action="store_true")
     parser.add_argument(
         "--save-path",
         type=str,
@@ -323,6 +327,7 @@ def main():
         save_path=args.save_path,
         model_config=model_config,
         use_fp8=args.use_fp8,
+        use_cuda_graph=args.use_cuda_graph,
     )
 
 
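The new use_cuda_graph=args.use_cuda_graph argument reaches benchmark() because triton's perf_report runner forwards extra keyword arguments of benchmark.run() into the decorated function, the same mechanism already used for model_config and use_fp8. A toy, self-contained illustration of that forwarding (the metric and names here are invented for the example, not taken from the patch):

```python
# Toy sketch: Benchmark.run forwards unknown kwargs into the decorated function, which is
# why adding use_cuda_graph=... to benchmark.run() is enough to toggle the runner above.
import triton

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 8, 16],
        line_arg="provider",
        line_vals=["demo"],
        line_names=["demo"],
        ylabel="fake-us",
        plot_name="kwarg-forwarding-demo",
        args={},
    )
)
def toy_benchmark(batch_size, provider, use_cuda_graph=False):
    # Returns a made-up number so the example runs without launching a GPU kernel.
    return float(batch_size) * (0.5 if use_cuda_graph else 1.0)

toy_benchmark.run(print_data=True, use_cuda_graph=True)
```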