
Commit 60bbd86

format

committed · 1 parent 9028d74 · commit 60bbd86

File tree: 1 file changed, +61 -38 lines


test/benchmark/kernel/benchmark_fa3_decode_mtp.py

Lines changed: 61 additions & 38 deletions
@@ -25,6 +25,7 @@
 from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp
 from lightllm.utils.bench_utils import do_bench
 
+
 def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
     query = query.float()
     key = key.float()
@@ -36,8 +37,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
         s_q = query.shape[-2]
         s_k = key.shape[-2]
         attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device)
-        temp_mask = torch.ones(
-            s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
+        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
         attn_bias.to(query.dtype)
         attn_weight += attn_bias
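
For context on the line being reflowed here, a minimal standalone sketch (toy shapes of my own, not part of the commit) of how tril(diagonal=s_k - s_q) yields the causal pattern when the s_q queries are the last positions of an s_k-long sequence:

import torch

s_q, s_k = 2, 5  # toy decode-style shapes: 2 new queries over 5 cached keys
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
# tensor([[ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])
# Row i may attend to keys 0 .. (s_k - s_q + i), i.e. standard causal masking
# for queries that sit at the tail of the sequence.
attn_bias = torch.zeros(s_q, s_k)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))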
@@ -47,8 +47,9 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
 
 
 @torch.inference_mode()
-def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q,
-                  h_kv, d, dv, causal, dtype):
+def run_torch_mla(
+    q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
+):
     # q: [b, s_q, h_q, d]
     # block_table: [b, max_seqlen_pad // block_size]
     # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d]
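
The shape comments above describe the paged KV layout; a hedged sketch (toy sizes of my own, block_size = 1 as in this benchmark) of how block_table maps one request to rows of blocked_k:

import torch

batch_mtp, max_seqlen_pad, block_size, h_kv, d = 2, 4, 1, 1, 8  # toy sizes
block_table = torch.arange(batch_mtp * max_seqlen_pad, dtype=torch.int32).view(batch_mtp, max_seqlen_pad)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]], dtype=torch.int32)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
# Row i of block_table lists the block ids holding request i's tokens; with
# block_size = 1 each id addresses exactly one cached token.
k_req0 = blocked_k[block_table[0].long()]  # shape [max_seqlen_pad, 1, h_kv, d]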
@@ -77,27 +78,35 @@ def ref_mla():
         return out_torch
 
 
-def run_fa3_mla_mtp(mtp_size, q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
-                    h_q, h_kv, d, dv, causal, dtype):
+def run_fa3_mla_mtp(
+    mtp_size,
+    q,
+    block_table,
+    blocked_k,
+    max_seqlen_pad,
+    block_size,
+    b,
+    s_q,
+    cache_seqlens,
+    h_q,
+    h_kv,
+    d,
+    dv,
+    causal,
+    dtype,
+):
 
     assert d > dv, "mla with rope dim should be larger than no rope dim"
     q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
-    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
-                                                                               dv:].contiguous()
+    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
 
     dpe = d - dv
-    num_kv_splits = 1
-
-    out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
-    glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
 
     batch_mtp = b // mtp_size
-    cu_seqlens_q = torch.arange(
-        0, batch_mtp + 1, step=s_q, dtype=torch.int32, device=q.device
-    )
+    cu_seqlens_q = torch.arange(0, batch_mtp + 1, step=s_q, dtype=torch.int32, device=q.device)
     cu_seqlens_k = torch.cumsum(cache_seqlens, dim=0)
     cu_seqlens_k = torch.cat([torch.tensor([0]).to(cu_seqlens_k), cu_seqlens_k])
-    scale = (1.0 / (dv + dpe))**0.5 # log2(e)
+    scale = (1.0 / (dv + dpe)) ** 0.5  # log2(e)
     k_descale, v_descale = None, None
     BLOCK_H = h_q * mtp_size
 
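A hedged sketch (toy numbers of my own, not the benchmark defaults) of what the cu_seqlens_q / cu_seqlens_k built above look like for the varlen call, assuming b = 4 and mtp_size = 2 so that batch_mtp = 2:

import torch

batch_mtp, s_q = 2, 1  # e.g. b = 4 requests grouped into pairs by mtp_size = 2
cache_seqlens = torch.tensor([8192, 8193], dtype=torch.int32)

cu_seqlens_q = torch.arange(0, batch_mtp + 1, step=s_q, dtype=torch.int32)
# tensor([0, 1, 2], dtype=torch.int32): one query row per MTP group.

cu_seqlens_k = torch.cumsum(cache_seqlens, dim=0)
cu_seqlens_k = torch.cat([torch.tensor([0]).to(cu_seqlens_k), cu_seqlens_k])
# tensor([0, 8192, 16385]): prefix sums marking each group's slice of the KV cache.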
@@ -119,23 +128,24 @@ def flash_mla_fa3():
             k_descale=k_descale,
             v_descale=v_descale,
             return_softmax_lse=False,
-            mtp_step=1
+            mtp_step=1,
         )
         return out.view([b, s_q, h_q, dv])
 
     out_flash = flash_mla_fa3()
     t = do_bench(flash_mla_fa3)
 
-    out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
-                            cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
-
-    # compute the relative absolute error
+    out_ref = run_torch_mla(
+        q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
+    )
+
+    # compute the relative absolute error
     def print_error(a, b, name=""):
         max_absolute_error = torch.abs(a - b).max()
         relative_abs_error = torch.abs(a - b) / (torch.abs(a) + 1e-4)
         max_relative_abs_error = relative_abs_error.max()
         mean_relative_abs_error = relative_abs_error.mean()
-
+
         print(f"{name}: Maximum absolute difference: {max_absolute_error:.6e}")
         print(f"Maximum relative absolute error: {max_relative_abs_error:.6e}")
         print(f"Mean relative absolute error: {mean_relative_abs_error:.6e}")
@@ -148,13 +158,13 @@ def print_error(a, b, name=""):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=128, help='batch size')
-    parser.add_argument('--h_q', type=int, default=16, help='q heads number')
-    parser.add_argument('--h_kv', type=int, default=1, help='kv heads number')
-    parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length')
-    parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe')
-    parser.add_argument('--dv', type=int, default=512, help='value head dim')
-    parser.add_argument('--mtp_size', type=int, default=2, help='Specifies the number of tokens per prediction.')
+    parser.add_argument("--batch", type=int, default=128, help="batch size")
+    parser.add_argument("--h_q", type=int, default=16, help="q heads number")
+    parser.add_argument("--h_kv", type=int, default=1, help="kv heads number")
+    parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length")
+    parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe")
+    parser.add_argument("--dv", type=int, default=512, help="value head dim")
+    parser.add_argument("--mtp_size", type=int, default=2, help="Specifies the number of tokens per prediction.")
     args = parser.parse_args()
     b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv
     mtp_size = args.mtp_size
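
With these defaults the per-head RoPE dim is d - dv = 64 on top of the 512-dim no-RoPE part, and a typical invocation (every flag falls back to its default when omitted) looks like: python test/benchmark/kernel/benchmark_fa3_decode_mtp.py --batch 128 --cache_seqlen 8192 --mtp_size 2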
@@ -165,28 +175,41 @@ def print_error(a, b, name=""):
     s_q = 1  # for decode, s_q = 1
     block_size = 1
     batch_mtp = b // mtp_size
-    cache_seqlens = torch.tensor([cache_seqlen + i for i in range(batch_mtp)],
-                                 dtype=torch.int32,
-                                 device=device)
+    cache_seqlens = torch.tensor([cache_seqlen + i for i in range(batch_mtp)], dtype=torch.int32, device=device)
     # print(cache_seqlens[-1])
     dpe = d - dv
     causal = True
 
     total_seqlens = cache_seqlens.sum().item()
     mean_seqlens = cache_seqlens.float().mean().int().item()
     max_seqlen = cache_seqlens.max().item()
-    max_seqlen_pad = math.ceil(max_seqlen / 256) * 256 # ? why align to 256
+    max_seqlen_pad = math.ceil(max_seqlen / 256) * 256  # ? why align to 256
 
     total_flops = s_q * (total_seqlens * 2 - batch_mtp) * h_q * (d + dv) * 2
 
     q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
-    block_table = torch.arange(
-        batch_mtp * max_seqlen_pad, dtype=torch.int32,
-        device=device).view(batch_mtp, max_seqlen_pad)
+    block_table = torch.arange(batch_mtp * max_seqlen_pad, dtype=torch.int32, device=device).view(
+        batch_mtp, max_seqlen_pad
+    )
 
     blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device)
-    out_flash, latency = run_fa3_mla_mtp(mtp_size, q, block_table, blocked_k, max_seqlen_pad, block_size, b,
-                                         s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
+    out_flash, latency = run_fa3_mla_mtp(
+        mtp_size,
+        q,
+        block_table,
+        blocked_k,
+        max_seqlen_pad,
+        block_size,
+        b,
+        s_q,
+        cache_seqlens,
+        h_q,
+        h_kv,
+        d,
+        dv,
+        causal,
+        dtype,
+    )
 
     print("Tile-lang: {:.3f} ms".format(latency))
     print("Tile-lang: {:.3f} TFlops".format(total_flops / latency * 1e-9))
