Skip to content

Commit 28e0048

Browse files
Authored by: WANDY666, wangzaijun, hiworldwzj
add fp8_scaled_mm_per_token (#1112)
Co-authored-by: wangzaijun <wangzaijun@sensetime.com>
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
1 parent 7e7bc60 commit 28e0048

17 files changed: +1205 additions, −1 deletion

lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py

Lines changed: 471 additions & 0 deletions
Large diffs are not rendered by default.

lightllm/common/quantization/w8a8_quant.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
import os
22
import torch
3+
4+
from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_scaled_mm_per_token_kernel import fp8_scaled_mm_per_token
35
from .quantize_method import QuantizationMethod
46
from .registry import QUANTMETHODS
57
import torch.nn.functional as F
@@ -21,6 +23,12 @@ def scaled_fp8_quant(tensor, *args, **kwargs):
2123
if HAS_VLLM:
2224
scaled_fp8_quant = vllm_ops.scaled_fp8_quant
2325

26+
LIGHTLLM_USE_TRITON_FP8_SCALED_MM = os.getenv("LIGHTLLM_USE_TRITON_FP8_SCALED_MM", "False").upper() in [
27+
"ON",
28+
"TRUE",
29+
"1",
30+
]
31+
2432

2533
class BaseQuantizationMethod(QuantizationMethod):
2634
def __init__(self):
@@ -146,7 +154,10 @@ def apply(
146154
)
147155
else:
148156
out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
149-
cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias)
157+
if LIGHTLLM_USE_TRITON_FP8_SCALED_MM:
158+
out = fp8_scaled_mm_per_token(x_q, qweight, x_scale, weight_scale, input_tensor.dtype, out)
159+
else:
160+
cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias)
150161
return out
151162

152163
@property
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,18 @@
1+
{
2+
"32130": {
3+
"BLOCK_K": 128,
4+
"BLOCK_M": 128,
5+
"BLOCK_N": 128,
6+
"GROUP_M": 1,
7+
"num_stages": 2,
8+
"num_warps": 4
9+
},
10+
"75348": {
11+
"BLOCK_K": 128,
12+
"BLOCK_M": 128,
13+
"BLOCK_N": 128,
14+
"GROUP_M": 1,
15+
"num_stages": 2,
16+
"num_warps": 4
17+
}
18+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,82 @@
1+
{
2+
"1": {
3+
"BLOCK_K": 256,
4+
"BLOCK_M": 8,
5+
"BLOCK_N": 64,
6+
"GROUP_M": 8,
7+
"num_stages": 6,
8+
"num_warps": 8
9+
},
10+
"100": {
11+
"BLOCK_K": 64,
12+
"BLOCK_M": 32,
13+
"BLOCK_N": 128,
14+
"GROUP_M": 8,
15+
"num_stages": 5,
16+
"num_warps": 8
17+
},
18+
"1024": {
19+
"BLOCK_K": 64,
20+
"BLOCK_M": 64,
21+
"BLOCK_N": 128,
22+
"GROUP_M": 8,
23+
"num_stages": 3,
24+
"num_warps": 2
25+
},
26+
"128": {
27+
"BLOCK_K": 256,
28+
"BLOCK_M": 64,
29+
"BLOCK_N": 64,
30+
"GROUP_M": 8,
31+
"num_stages": 4,
32+
"num_warps": 4
33+
},
34+
"16": {
35+
"BLOCK_K": 256,
36+
"BLOCK_M": 16,
37+
"BLOCK_N": 64,
38+
"GROUP_M": 8,
39+
"num_stages": 4,
40+
"num_warps": 8
41+
},
42+
"2048": {
43+
"BLOCK_K": 64,
44+
"BLOCK_M": 64,
45+
"BLOCK_N": 128,
46+
"GROUP_M": 8,
47+
"num_stages": 3,
48+
"num_warps": 2
49+
},
50+
"256": {
51+
"BLOCK_K": 128,
52+
"BLOCK_M": 64,
53+
"BLOCK_N": 128,
54+
"GROUP_M": 8,
55+
"num_stages": 5,
56+
"num_warps": 8
57+
},
58+
"32": {
59+
"BLOCK_K": 256,
60+
"BLOCK_M": 32,
61+
"BLOCK_N": 64,
62+
"GROUP_M": 8,
63+
"num_stages": 5,
64+
"num_warps": 4
65+
},
66+
"64": {
67+
"BLOCK_K": 256,
68+
"BLOCK_M": 32,
69+
"BLOCK_N": 64,
70+
"GROUP_M": 8,
71+
"num_stages": 5,
72+
"num_warps": 4
73+
},
74+
"8": {
75+
"BLOCK_K": 256,
76+
"BLOCK_M": 8,
77+
"BLOCK_N": 64,
78+
"GROUP_M": 8,
79+
"num_stages": 4,
80+
"num_warps": 4
81+
}
82+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,18 @@
1+
{
2+
"32760": {
3+
"BLOCK_K": 128,
4+
"BLOCK_M": 128,
5+
"BLOCK_N": 128,
6+
"GROUP_M": 8,
7+
"num_stages": 2,
8+
"num_warps": 4
9+
},
10+
"512": {
11+
"BLOCK_K": 64,
12+
"BLOCK_M": 64,
13+
"BLOCK_N": 64,
14+
"GROUP_M": 32,
15+
"num_stages": 2,
16+
"num_warps": 4
17+
}
18+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,10 @@
1+
{
2+
"32760": {
3+
"BLOCK_K": 128,
4+
"BLOCK_M": 128,
5+
"BLOCK_N": 128,
6+
"GROUP_M": 8,
7+
"num_stages": 2,
8+
"num_warps": 4
9+
}
10+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,82 @@
1+
{
2+
"1": {
3+
"BLOCK_K": 256,
4+
"BLOCK_M": 8,
5+
"BLOCK_N": 64,
6+
"GROUP_M": 8,
7+
"num_stages": 5,
8+
"num_warps": 8
9+
},
10+
"100": {
11+
"BLOCK_K": 64,
12+
"BLOCK_M": 32,
13+
"BLOCK_N": 128,
14+
"GROUP_M": 8,
15+
"num_stages": 3,
16+
"num_warps": 8
17+
},
18+
"1024": {
19+
"BLOCK_K": 64,
20+
"BLOCK_M": 64,
21+
"BLOCK_N": 256,
22+
"GROUP_M": 8,
23+
"num_stages": 3,
24+
"num_warps": 4
25+
},
26+
"128": {
27+
"BLOCK_K": 128,
28+
"BLOCK_M": 64,
29+
"BLOCK_N": 64,
30+
"GROUP_M": 8,
31+
"num_stages": 3,
32+
"num_warps": 4
33+
},
34+
"16": {
35+
"BLOCK_K": 256,
36+
"BLOCK_M": 16,
37+
"BLOCK_N": 64,
38+
"GROUP_M": 8,
39+
"num_stages": 4,
40+
"num_warps": 8
41+
},
42+
"2048": {
43+
"BLOCK_K": 64,
44+
"BLOCK_M": 64,
45+
"BLOCK_N": 128,
46+
"GROUP_M": 8,
47+
"num_stages": 3,
48+
"num_warps": 2
49+
},
50+
"256": {
51+
"BLOCK_K": 128,
52+
"BLOCK_M": 64,
53+
"BLOCK_N": 128,
54+
"GROUP_M": 8,
55+
"num_stages": 4,
56+
"num_warps": 4
57+
},
58+
"32": {
59+
"BLOCK_K": 256,
60+
"BLOCK_M": 32,
61+
"BLOCK_N": 64,
62+
"GROUP_M": 8,
63+
"num_stages": 4,
64+
"num_warps": 4
65+
},
66+
"64": {
67+
"BLOCK_K": 256,
68+
"BLOCK_M": 32,
69+
"BLOCK_N": 64,
70+
"GROUP_M": 8,
71+
"num_stages": 4,
72+
"num_warps": 4
73+
},
74+
"8": {
75+
"BLOCK_K": 256,
76+
"BLOCK_M": 8,
77+
"BLOCK_N": 64,
78+
"GROUP_M": 8,
79+
"num_stages": 5,
80+
"num_warps": 8
81+
}
82+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,10 @@
1+
{
2+
"32130": {
3+
"BLOCK_K": 128,
4+
"BLOCK_M": 128,
5+
"BLOCK_N": 128,
6+
"GROUP_M": 1,
7+
"num_stages": 2,
8+
"num_warps": 4
9+
}
10+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,82 @@
1+
{
2+
"1": {
3+
"BLOCK_K": 256,
4+
"BLOCK_M": 8,
5+
"BLOCK_N": 64,
6+
"GROUP_M": 8,
7+
"num_stages": 5,
8+
"num_warps": 8
9+
},
10+
"100": {
11+
"BLOCK_K": 128,
12+
"BLOCK_M": 16,
13+
"BLOCK_N": 128,
14+
"GROUP_M": 8,
15+
"num_stages": 5,
16+
"num_warps": 8
17+
},
18+
"1024": {
19+
"BLOCK_K": 64,
20+
"BLOCK_M": 64,
21+
"BLOCK_N": 128,
22+
"GROUP_M": 8,
23+
"num_stages": 3,
24+
"num_warps": 2
25+
},
26+
"128": {
27+
"BLOCK_K": 256,
28+
"BLOCK_M": 32,
29+
"BLOCK_N": 64,
30+
"GROUP_M": 8,
31+
"num_stages": 5,
32+
"num_warps": 4
33+
},
34+
"16": {
35+
"BLOCK_K": 256,
36+
"BLOCK_M": 8,
37+
"BLOCK_N": 64,
38+
"GROUP_M": 8,
39+
"num_stages": 3,
40+
"num_warps": 8
41+
},
42+
"2048": {
43+
"BLOCK_K": 128,
44+
"BLOCK_M": 64,
45+
"BLOCK_N": 64,
46+
"GROUP_M": 8,
47+
"num_stages": 3,
48+
"num_warps": 4
49+
},
50+
"256": {
51+
"BLOCK_K": 256,
52+
"BLOCK_M": 64,
53+
"BLOCK_N": 64,
54+
"GROUP_M": 8,
55+
"num_stages": 4,
56+
"num_warps": 4
57+
},
58+
"32": {
59+
"BLOCK_K": 256,
60+
"BLOCK_M": 16,
61+
"BLOCK_N": 64,
62+
"GROUP_M": 8,
63+
"num_stages": 5,
64+
"num_warps": 4
65+
},
66+
"64": {
67+
"BLOCK_K": 256,
68+
"BLOCK_M": 16,
69+
"BLOCK_N": 64,
70+
"GROUP_M": 8,
71+
"num_stages": 4,
72+
"num_warps": 4
73+
},
74+
"8": {
75+
"BLOCK_K": 256,
76+
"BLOCK_M": 8,
77+
"BLOCK_N": 64,
78+
"GROUP_M": 8,
79+
"num_stages": 4,
80+
"num_warps": 4
81+
}
82+
}

0 commit comments

Comments (0)