@@ -332,6 +332,7 @@ def grouped_matmul_kernel(
     GROUP_SIZE_M: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr = False,
     NEED_K_MASK: tl.constexpr = True,
+    NEED_TRANS: tl.constexpr = False,
 ):
     pid = tl.program_id(0)

@@ -367,13 +368,6 @@ def grouped_matmul_kernel(
         mask=token_mask,
         other=0,
     )
-    if MUL_ROUTED_WEIGHT:
-        a_m_scale = tl.load(
-            expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
-            mask=token_mask,
-            other=0.0,
-        )
-
     offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n
     offs_k = tl.arange(0, BLOCK_SIZE_K)
@@ -387,7 +381,7 @@ def grouped_matmul_kernel(
             b_scale = tl.load(weight_scale_ptr + expert_id, eviction_policy="evict_last")
             ab_scale = a_scale * b_scale

-    if use_fp8_w8a8:
+    if NEED_TRANS:
         a_ptrs = token_ptr + (a_m_index // topk_num)[None, :] * token_stride_0 + offs_k[:, None]
         b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[None, :] + offs_bn[:, None] * weight_stride_1
         accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32)
@@ -401,16 +395,20 @@ def grouped_matmul_kernel(
         # tl.multiple_of(a_ptrs, [16, 16])
         # tl.multiple_of(b_ptrs, [16, 16])

-        if use_fp8_w8a8:
+        if NEED_TRANS:
             if NEED_K_MASK:
-                a = tl.load(a_ptrs, mask=(token_mask[None, :]) & (offs_k[:, None] < k), other=0.0)
+                a = tl.load(
+                    a_ptrs, mask=(token_mask[None, :]) & (offs_k[:, None] < k - step_k * BLOCK_SIZE_K), other=0.0
+                )
                 b = tl.load(b_ptrs, mask=(offs_k[None, :] < k), other=0.0)
             else:
                 a = tl.load(a_ptrs, mask=(token_mask[None, :]), other=0.0)
                 b = tl.load(b_ptrs)
         else:
             if NEED_K_MASK:
-                a = tl.load(a_ptrs, mask=(token_mask[:, None]) & (offs_k[None, :] < k), other=0.0)
+                a = tl.load(
+                    a_ptrs, mask=(token_mask[:, None]) & (offs_k[None, :] < k - step_k * BLOCK_SIZE_K), other=0.0
+                )
                 b = tl.load(b_ptrs, mask=(offs_k[:, None] < k), other=0.0)
             else:
                 a = tl.load(a_ptrs, mask=(token_mask[:, None]), other=0.0)
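Note (illustration, not part of the diff): the rewritten masks stop advancing offs_k inside the K loop (the `offs_k += BLOCK_SIZE_K` line is dropped in the next hunk) and instead fold the step offset into the bound, so `offs_k < k - step_k * BLOCK_SIZE_K` selects exactly the lanes the old running-offset check did. A quick NumPy sketch of that equivalence, with hypothetical sizes:

import numpy as np

BLOCK_SIZE_K, k = 32, 100
offs_k = np.arange(BLOCK_SIZE_K)
for step_k in range(-(-k // BLOCK_SIZE_K)):  # ceil-div over K, mirroring the kernel's step loop
    old_mask = (offs_k + step_k * BLOCK_SIZE_K) < k   # old scheme: offs_k advanced each step, compared to k
    new_mask = offs_k < (k - step_k * BLOCK_SIZE_K)   # new scheme: offs_k fixed, bound shifted instead
    assert (old_mask == new_mask).all()
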
@@ -421,24 +419,34 @@ def grouped_matmul_kernel(
                 offs_ks = step_k * BLOCK_SIZE_K // block_size_k
                 a_scale = tl.load(a_scale_ptrs + offs_ks, mask=token_mask, other=0.0)
                 b_scale = tl.load(b_scale_ptrs + offs_ks * weight_scale_stride2)
-                accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :]
+                if NEED_TRANS:
+                    accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :]
+                else:
+                    accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
             else:
-                accumulator = tl.dot(b, a, acc=accumulator)
+                if NEED_TRANS:
+                    accumulator = tl.dot(b, a, acc=accumulator)
+                else:
+                    accumulator = tl.dot(a, b, acc=accumulator)
         else:
             accumulator += tl.dot(a, b)

         a_ptrs += BLOCK_SIZE_K
         b_ptrs += BLOCK_SIZE_K
-        offs_k += BLOCK_SIZE_K
+
+    if NEED_TRANS:
+        accumulator = accumulator.T

     if use_fp8_w8a8:
-        if block_size_k > 0 and block_size_n > 0:
-            accumulator = accumulator.T
-        else:
-            accumulator = accumulator.T
+        if not (block_size_k > 0 and block_size_n > 0):
             accumulator *= ab_scale

     if MUL_ROUTED_WEIGHT:
+        a_m_scale = tl.load(
+            expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
+            mask=token_mask,
+            other=0.0,
+        )
         accumulator *= a_m_scale[:, None]

     c = accumulator.to(compute_type)
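Note (illustration, not part of the diff): with NEED_TRANS the A tile is loaded as (K, M) and the B tile as (N, K), the dot runs as tl.dot(b, a) into an (N, M) accumulator with the two scale vectors applied to the matching axes, and a single transpose after the loop restores the (M, N) layout. A NumPy sketch of the identity the two paths rely on, using hypothetical tile sizes and random scales:

import numpy as np

M, N, K = 8, 16, 32
a_mk = np.random.rand(M, K).astype(np.float32)   # A as seen by the NEED_TRANS=False path: (M, K)
b_kn = np.random.rand(K, N).astype(np.float32)   # B as seen by the NEED_TRANS=False path: (K, N)
a_scale = np.random.rand(M).astype(np.float32)   # per-token scale
b_scale = np.random.rand(N).astype(np.float32)   # per-output-channel scale

# NEED_TRANS=False: accumulator is (M, N), scales applied as a_scale[:, None] * b_scale[None, :]
acc_plain = (a_mk @ b_kn) * a_scale[:, None] * b_scale[None, :]

# NEED_TRANS=True: tiles loaded pre-transposed, dot(b, a) gives (N, M) with the scales swapped
# onto the transposed axes, then one transpose at the end restores (M, N).
acc_trans = ((b_kn.T @ a_mk.T) * b_scale[:, None] * a_scale[None, :]).T

assert np.allclose(acc_plain, acc_trans, atol=1e-5)
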
@@ -478,13 +486,15 @@ def _get_grouped_matmul_configs():
             "GROUP_SIZE_M": gm,
             "num_warps": nw,
             "num_stages": ns,
+            "NEED_TRANS": need_trans,
         }
-        for ns in [1, 2, 3, 4, 5]
-        for gm in [1, 2, 4, 8]
-        for nw in [2, 4, 8]
+        for ns in [2, 3, 4, 5]
+        for gm in [1, 16, 32, 64]
+        for nw in [4, 8]
         for bm in [16, 32, 64, 128]
         for bn in [16, 32, 64, 128]
-        for bk in [16, 32, 64, 128]
+        for bk in [32, 64, 128]
+        for need_trans in [True, False]
     ]


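Note (illustration, not part of the diff): even with the extra NEED_TRANS axis, the narrower ranges above slightly shrink the tuning space. A quick check of the arithmetic, taking the values straight from the comprehension:

# Size of the candidate-config space before and after the change
old = 5 * 4 * 3 * 4 * 4 * 4          # ns * gm * nw * bm * bn * bk
new = 4 * 4 * 2 * 4 * 4 * 3 * 2      # ns * gm * nw * bm * bn * bk * need_trans
print(old, new)                      # 3840 3072
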
@@ -559,6 +569,9 @@ def grouped_matmul(
     GROUP_SIZE_M = run_config["GROUP_SIZE_M"]
     num_warps = run_config["num_warps"]
     num_stages = run_config["num_stages"]
+    NEED_TRANS = run_config.get("NEED_TRANS", False)
+    if not use_fp8_w8a8:
+        assert NEED_TRANS is False, "only use_fp8_w8a8 mode can use NEED_TRANS to accelerate"

     if block_size_k != 0:
         # When block-wise quantization is used, the tile size must not exceed the block size
@@ -638,6 +651,7 @@ def grouped_matmul(
         GROUP_SIZE_M=GROUP_SIZE_M,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         NEED_K_MASK=NEED_K_MASK,
+        NEED_TRANS=NEED_TRANS,
         num_warps=num_warps,
         num_stages=num_stages,
     )