from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp
from lightllm.utils.bench_utils import do_bench

+
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    query = query.float()
    key = key.float()
@@ -36,8 +37,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    s_q = query.shape[-2]
    s_k = key.shape[-2]
    attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device)
-    temp_mask = torch.ones(
-        s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
+    temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    attn_bias.to(query.dtype)
    attn_weight += attn_bias
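As a quick check of the mask the hunk above reflows: `tril(diagonal=s_k - s_q)` keeps, for each of the trailing s_q query rows, exactly the keys up to that row's absolute position. A minimal sketch, with shapes chosen purely for illustration:

```python
import torch

# s_q trailing queries against s_k cached keys; s_q=2, s_k=4 are
# illustrative values, not the benchmark defaults.
s_q, s_k = 2, 4
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
print(temp_mask)
# tensor([[ True,  True,  True, False],
#         [ True,  True,  True,  True]])
```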
@@ -47,8 +47,9 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):


@torch.inference_mode()
-def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q,
-                  h_kv, d, dv, causal, dtype):
+def run_torch_mla(
+    q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
+):
    # q: [b, s_q, h_q, d]
    # block_table: [b, max_seqlen_pad // block_size]
    # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d]
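The layout comments above describe a paged KV cache: `block_table` maps each request to its blocks in `blocked_k`. A toy gather, with hypothetical sizes, shows how one request's keys are reassembled contiguously:

```python
import torch

# Toy paged-KV gather matching the layout comments above; the sizes
# (2 requests, block_size=1, h_kv=1, d=4) are hypothetical.
b, max_seqlen_pad, block_size, h_kv, d = 2, 4, 1, 1, 4
block_table = torch.arange(b * max_seqlen_pad, dtype=torch.int32).view(b, max_seqlen_pad)
blocked_k = torch.randn(b * max_seqlen_pad, block_size, h_kv, d)
# Keys of request 0, gathered into one contiguous [seqlen, h_kv, d] tensor:
k0 = blocked_k[block_table[0].long()].view(max_seqlen_pad * block_size, h_kv, d)
print(k0.shape)  # torch.Size([4, 1, 4])
```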
@@ -77,27 +78,35 @@ def ref_mla():
    return out_torch


-def run_fa3_mla_mtp(mtp_size, q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
-                    h_q, h_kv, d, dv, causal, dtype):
+def run_fa3_mla_mtp(
+    mtp_size,
+    q,
+    block_table,
+    blocked_k,
+    max_seqlen_pad,
+    block_size,
+    b,
+    s_q,
+    cache_seqlens,
+    h_q,
+    h_kv,
+    d,
+    dv,
+    causal,
+    dtype,
+):

    assert d > dv, "mla with rope dim should be larger than no rope dim"
    q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
-    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
-                                                                               dv:].contiguous()
+    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()

    dpe = d - dv
-    num_kv_splits = 1
-
-    out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
-    glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)

    batch_mtp = b // mtp_size
-    cu_seqlens_q = torch.arange(
-        0, batch_mtp + 1, step=s_q, dtype=torch.int32, device=q.device
-    )
+    cu_seqlens_q = torch.arange(0, batch_mtp + 1, step=s_q, dtype=torch.int32, device=q.device)
    cu_seqlens_k = torch.cumsum(cache_seqlens, dim=0)
    cu_seqlens_k = torch.cat([torch.tensor([0]).to(cu_seqlens_k), cu_seqlens_k])
-    scale = (1.0 / (dv + dpe))**0.5  # log2(e)
+    scale = (1.0 / (dv + dpe)) ** 0.5  # softmax scale = 1 / sqrt(d)
    k_descale, v_descale = None, None
    BLOCK_H = h_q * mtp_size

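The varlen bookkeeping built here is easiest to see on a toy case: cu_seqlens_q and cu_seqlens_k are cumulative sequence-length offsets with a leading 0, marking where each request's queries and keys start. A minimal sketch, assuming batch_mtp = 3 decode requests with s_q = 1 and made-up cache lengths [5, 7, 6]:

```python
import torch

# Toy version of the cu_seqlens construction above; the lengths
# [5, 7, 6] are hypothetical.
s_q, batch_mtp = 1, 3
cache_seqlens = torch.tensor([5, 7, 6], dtype=torch.int32)
cu_seqlens_q = torch.arange(0, batch_mtp + 1, step=s_q, dtype=torch.int32)
cu_seqlens_k = torch.cumsum(cache_seqlens, dim=0)
cu_seqlens_k = torch.cat([torch.tensor([0]).to(cu_seqlens_k), cu_seqlens_k])
print(cu_seqlens_q.tolist())  # [0, 1, 2, 3]
print(cu_seqlens_k.tolist())  # [0, 5, 12, 18]
```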
@@ -119,23 +128,24 @@ def flash_mla_fa3():
            k_descale=k_descale,
            v_descale=v_descale,
            return_softmax_lse=False,
-            mtp_step=1
+            mtp_step=1,
        )
        return out.view([b, s_q, h_q, dv])

    out_flash = flash_mla_fa3()
    t = do_bench(flash_mla_fa3)

-    out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
-                            cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
-
-    # Compute the relative absolute error
+    out_ref = run_torch_mla(
+        q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
+    )
+
+    # Compute the relative absolute error
    def print_error(a, b, name=""):
        max_absolute_error = torch.abs(a - b).max()
        relative_abs_error = torch.abs(a - b) / (torch.abs(a) + 1e-4)
        max_relative_abs_error = relative_abs_error.max()
        mean_relative_abs_error = relative_abs_error.mean()
-
+
        print(f"{name}: Maximum absolute difference: {max_absolute_error:.6e}")
        print(f"Maximum relative absolute error: {max_relative_abs_error:.6e}")
        print(f"Mean relative absolute error: {mean_relative_abs_error:.6e}")
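For a feel of what these metrics report, here is the same arithmetic on toy tensors (values made up; print_error itself is a local helper, so this sketch recomputes the metrics inline):

```python
import torch

# Toy check of the error metrics above (values are made up).
a = torch.tensor([1.0, 2.0, 4.0])
b = torch.tensor([1.0, 2.5, 4.0])
print(torch.abs(a - b).max())  # tensor(0.5000)
rel = torch.abs(a - b) / (torch.abs(a) + 1e-4)
print(rel.max())   # tensor(0.2500), exactly 0.5 / 2.0001
print(rel.mean())  # tensor(0.0833)
```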
@@ -148,13 +158,13 @@ def print_error(a, b, name=""):

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=128, help='batch size')
-    parser.add_argument('--h_q', type=int, default=16, help='q heads number')
-    parser.add_argument('--h_kv', type=int, default=1, help='kv heads number')
-    parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length')
-    parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe')
-    parser.add_argument('--dv', type=int, default=512, help='value head dim')
-    parser.add_argument('--mtp_size', type=int, default=2, help='Specifies the number of tokens per prediction.')
+    parser.add_argument("--batch", type=int, default=128, help="batch size")
+    parser.add_argument("--h_q", type=int, default=16, help="q heads number")
+    parser.add_argument("--h_kv", type=int, default=1, help="kv heads number")
+    parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length")
+    parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe")
+    parser.add_argument("--dv", type=int, default=512, help="value head dim")
+    parser.add_argument("--mtp_size", type=int, default=2, help="Specifies the number of tokens per prediction.")
    args = parser.parse_args()
    b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv
    mtp_size = args.mtp_size
@@ -165,28 +175,41 @@ def print_error(a, b, name=""):
    s_q = 1  # for decode, s_q = 1
    block_size = 1
    batch_mtp = b // mtp_size
-    cache_seqlens = torch.tensor([cache_seqlen + i for i in range(batch_mtp)],
-                                 dtype=torch.int32,
-                                 device=device)
+    cache_seqlens = torch.tensor([cache_seqlen + i for i in range(batch_mtp)], dtype=torch.int32, device=device)
    # print(cache_seqlens[-1])
    dpe = d - dv
    causal = True

    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
-    max_seqlen_pad = math.ceil(max_seqlen / 256) * 256  # ? why align to 256
+    max_seqlen_pad = math.ceil(max_seqlen / 256) * 256  # ? why align to 256

    total_flops = s_q * (total_seqlens * 2 - batch_mtp) * h_q * (d + dv) * 2

    q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
-    block_table = torch.arange(
-        batch_mtp * max_seqlen_pad, dtype=torch.int32,
-        device=device).view(batch_mtp, max_seqlen_pad)
+    block_table = torch.arange(batch_mtp * max_seqlen_pad, dtype=torch.int32, device=device).view(
+        batch_mtp, max_seqlen_pad
+    )

    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device)
-    out_flash, latency = run_fa3_mla_mtp(mtp_size, q, block_table, blocked_k, max_seqlen_pad, block_size, b,
-                                         s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
+    out_flash, latency = run_fa3_mla_mtp(
+        mtp_size,
+        q,
+        block_table,
+        blocked_k,
+        max_seqlen_pad,
+        block_size,
+        b,
+        s_q,
+        cache_seqlens,
+        h_q,
+        h_kv,
+        d,
+        dv,
+        causal,
+        dtype,
+    )

    print("Tile-lang: {:.3f} ms".format(latency))
    print("Tile-lang: {:.3f} TFlops".format(total_flops / latency * 1e-9))
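For intuition, the total_flops expression feeding these prints can be checked numerically. A sketch with the benchmark defaults, assuming mtp_size = 2 (the constant 2 in total_seqlens * 2 appears to hardcode it): reading the formula, each request's two MTP query tokens causally attend to L - 1 and L keys, i.e. 2L - 1 query/key pairs per request of length L, which matches the (total_seqlens * 2 - batch_mtp) term; each pair costs h_q * (d + dv) multiply-accumulates across QK^T and PV, i.e. twice that in FLOPs.

```python
# Hedged sanity check of total_flops with the benchmark defaults;
# the 2L - 1 reading of (total_seqlens * 2 - batch_mtp) is an
# interpretation, not taken from the source.
b, mtp_size, h_q, d, dv, cache_seqlen, s_q = 128, 2, 16, 576, 512, 8192, 1
batch_mtp = b // mtp_size
total_seqlens = sum(cache_seqlen + i for i in range(batch_mtp))
total_flops = s_q * (total_seqlens * 2 - batch_mtp) * h_q * (d + dv) * 2
print(f"{total_flops / 1e12:.3f} TFLOPs")  # ~0.037 TFLOPs per call
```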