Commit 1efada3

feat: kv fp8 quant calibration for fa3 and flashinfer (#935)
1 parent fc72ffa commit 1efada3

32 files changed: +7777 −35 lines

lightllm/common/basemodel/basemodel.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
+from lightllm.utils.envs_utils import set_model_init_status


 logger = init_logger(__name__)

@@ -103,6 +104,7 @@ def __init__(self, kvargs):
         self._init_cudagraph()
         self._check_max_len_infer()
         torch.cuda.empty_cache()
+        set_model_init_status(True)
         return

     def _init_config(self):
Lines changed: 98 additions & 0 deletions (new file)

import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_destindex_copy_kv_per_head_fp8(
    K,
    Dest_loc,
    Out,
    scale,
    stride_k_bs,
    stride_k_h,
    stride_k_d,
    stride_o_bs,
    stride_o_h,
    stride_o_d,
    head_num,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
):
    cur_index = tl.program_id(0)
    offs_h = tl.arange(0, BLOCK_HEAD)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)

    k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
    o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]

    # to fp8
    scale_ptrs = scale + offs_h
    scales = tl.load(scale_ptrs, mask=offs_h < head_num, other=1.0)
    k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
    k_scale = k / scales[:, None]
    k_fp8 = tl.clamp(k_scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)

    tl.store(o_ptrs, k_fp8, mask=offs_h[:, None] < head_num)
    return


@torch.no_grad()
def destindex_copy_kv_fp8(K, DestLoc, scales, Out):
    if scales is None:
        Out[DestLoc] = K.to(torch.float8_e4m3fn)
        return

    seq_len = DestLoc.shape[0]
    head_num = K.shape[1]
    head_dim = K.shape[2]
    assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]
    BLOCK_HEAD = triton.next_power_of_2(head_num)
    grid = (seq_len,)
    num_warps = 1

    _fwd_kernel_destindex_copy_kv_per_head_fp8[grid](
        K,
        DestLoc,
        Out,
        scales,
        K.stride(0),
        K.stride(1),
        K.stride(2),
        Out.stride(0),
        Out.stride(1),
        Out.stride(2),
        head_num,
        BLOCK_DMODEL=head_dim,
        BLOCK_HEAD=BLOCK_HEAD,
        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
        num_warps=num_warps,
        num_stages=1,
    )


if __name__ == "__main__":
    import torch.nn.functional as F
    from lightllm.utils.vllm_utils import vllm_ops

    B, N_CTX, H, HEAD_DIM = 32, 1024, 16, 128
    dtype = torch.bfloat16
    NUM = B
    dest_loc = torch.arange(NUM).cuda() * 2
    kv = torch.randn((len(dest_loc), H, HEAD_DIM), dtype=dtype).cuda()
    out = torch.zeros((B * N_CTX, H, HEAD_DIM), dtype=torch.uint8).cuda()
    scale = kv.abs().amax(dim=(0, 2)).to(torch.float32) / 448
    destindex_copy_kv_fp8(kv, dest_loc, scale, out.view(torch.float8_e4m3fn))

    assert torch.allclose(
        out[:, :, :HEAD_DIM][dest_loc].view(torch.float8_e4m3fn).float() * scale.view(H, 1).expand(NUM, H, 1),
        kv.float(),
        atol=1e-5,
        rtol=1e-1,
    )
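
For reference, the per-head FP8 copy that the kernel above performs can be written in a few lines of plain PyTorch. This is only a sketch for clarity, not part of the commit; the shapes and the float8_e4m3fn target follow the self-test above, and the function name is made up here.

def ref_destindex_copy_kv_fp8(kv, dest_loc, scale, out):
    # kv: (num_tokens, head_num, head_dim); scale: (head_num,) per-head scales
    # out: (total_cache_slots, head_num, head_dim) fp8 KV-cache buffer
    fp8_min = torch.finfo(torch.float8_e4m3fn).min
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # divide each head by its scale, clamp to the fp8 range, then cast
    k_scaled = (kv.float() / scale.view(1, -1, 1)).clamp(fp8_min, fp8_max)
    # scatter each token's quantized heads to its destination cache slot
    out[dest_loc] = k_scaled.to(torch.float8_e4m3fn)
    # dequantization later recovers kv approximately as
    # out[dest_loc].float() * scale.view(1, -1, 1)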

Lines changed: 151 additions & 0 deletions (new file)

import torch

import triton
import triton.language as tl


@triton.jit
def _per_head_max_reduce_kernel(
    Q,
    Scales,
    StartLoc,
    stride_q_t,
    stride_q_h,
    stride_scales_b,
    FP8_MAX: tl.constexpr,
    BLOCK_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    b_id = tl.program_id(0)
    h_id = tl.program_id(1)

    max_val = 0.0

    start_loc = tl.load(StartLoc + b_id)
    end_loc = tl.load(StartLoc + b_id + 1)
    for t_offset in range(start_loc, end_loc, BLOCK_T):
        t_idx = t_offset + tl.arange(0, BLOCK_T)
        q_range = tl.arange(0, BLOCK_D)
        q_ptrs = Q + t_idx[:, None] * stride_q_t + h_id * stride_q_h + q_range[None, :]
        mask = (t_idx[:, None] < end_loc) & (q_range[None, :] < stride_q_h)
        q_vals = tl.load(q_ptrs, mask=mask, other=0.0)
        max_val = tl.maximum(tl.max(q_vals.abs()), max_val)

    scale = tl.where(max_val > 0, max_val / FP8_MAX, 1.0)
    scale_ptr = Scales + b_id * stride_scales_b + h_id
    tl.store(scale_ptr, scale)


@triton.jit
def _apply_quantization_kernel(
    Q,
    Q_out,
    BatchIds,
    Scales,
    stride_q_t,
    stride_q_h,
    stride_qout_t,
    stride_qout_h,
    stride_scales_b,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    t_id = tl.program_id(0)
    h_id = tl.program_id(1)

    batch_id = tl.load(BatchIds + t_id)
    scale_ptr = Scales + batch_id * stride_scales_b + h_id
    scale = tl.load(scale_ptr)

    q_range = tl.arange(0, BLOCK_D)
    q_ptrs = Q + t_id * stride_q_t + h_id * stride_q_h + q_range
    qout_ptrs = Q_out + t_id * stride_qout_t + h_id * stride_qout_h + q_range
    mask = q_range < stride_q_h
    q_vals = tl.load(q_ptrs, mask=mask, other=0.0)
    q_scaled = q_vals / scale
    q_clamped = tl.clamp(q_scaled, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)
    tl.store(qout_ptrs, q_clamped, mask=q_range < stride_qout_h)


@torch.no_grad()
def q_per_head_fp8_quant(q, seq_lens, b1_start_loc, scale_out=None, token_batch_ids=None):
    T, H, D = q.shape
    B = seq_lens.shape[0]

    BLOCK_D = triton.next_power_of_2(D)
    BLOCK_T = 256
    num_warps = 4
    num_stages = 2

    q_out = torch.empty_like(q, dtype=torch.float8_e4m3fn)
    if scale_out is None:
        scale_out = torch.empty((B, H), dtype=torch.float32, device=q.device)
    if token_batch_ids is None:
        token_batch_ids = torch.repeat_interleave(torch.arange(B, device=q.device), seq_lens)

    _per_head_max_reduce_kernel[(B, H)](
        q,
        scale_out,
        b1_start_loc,
        q.stride(0),
        q.stride(1),
        scale_out.stride(0),
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
        BLOCK_T=BLOCK_T,
        BLOCK_D=BLOCK_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    _apply_quantization_kernel[(T, H)](
        q,
        q_out,
        token_batch_ids,
        scale_out,
        q.stride(0),
        q.stride(1),
        q_out.stride(0),
        q_out.stride(1),
        scale_out.stride(0),
        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
        BLOCK_D=BLOCK_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return q_out, scale_out


def ref_q_per_head_fp8_quant(q, seq_lens):
    min_fp8 = torch.finfo(torch.float8_e4m3fn).min
    max_fp8 = torch.finfo(torch.float8_e4m3fn).max
    B = seq_lens.size(0)
    device = q.device
    token_batch_ids = torch.repeat_interleave(torch.arange(B, device=device), seq_lens)
    max_per_time_head = q.abs().amax(dim=2)
    max_per_bh = torch.zeros((B, max_per_time_head.size(1)), device=device, dtype=max_per_time_head.dtype)
    max_per_bh.scatter_reduce_(
        0,
        token_batch_ids.unsqueeze(-1).expand(-1, max_per_time_head.size(1)),
        max_per_time_head,
        reduce="amax",
        include_self=False,
    )
    scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)).to(torch.float32)
    scale_expanded = scales[token_batch_ids].view(-1, scales.size(1), 1)
    q_q = (q / scale_expanded).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn)
    return q_q, scales


if __name__ == "__main__":
    B, T, H, D = 200, 1000, 4, 7 * 128
    seq_lens = torch.ones((B,), dtype=torch.int32).cuda() * T // B
    start_locs = torch.zeros(B + 1, dtype=torch.int32).cuda()
    start_locs[1:] = seq_lens.cumsum(dim=0)
    q = torch.randn((T, H, D), dtype=torch.float32).cuda()

    q_out, scales = q_per_head_fp8_quant(q, seq_lens, start_locs)
    q_out1, scales1 = ref_q_per_head_fp8_quant(q, seq_lens)
    assert torch.allclose(scales, scales1, atol=1e-10, rtol=0)
    assert torch.allclose(q_out.int(), q_out1.int(), atol=1e-10, rtol=0)
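
The quantization above runs in two passes: _per_head_max_reduce_kernel walks each request's token range (given by b1_start_loc) and stores one absmax-derived scale per (request, head), and _apply_quantization_kernel then divides every token's head slice by its request's scale and casts to float8_e4m3fn. Below is a minimal dequantization sketch, assuming q_out, scales, and token_batch_ids were produced as in the reference implementation above; the helper name is illustrative, not from the commit.

def dequant_q_per_head(q_out, scales, token_batch_ids):
    # q_out: (T, H, D) float8_e4m3fn; scales: (B, H) float32
    # token_batch_ids: (T,) maps each token row back to its request index
    scale_expanded = scales[token_batch_ids].unsqueeze(-1)  # (T, H, 1)
    # multiply the per-(request, head) scale back in; exact up to fp8 rounding and clamping
    return q_out.float() * scale_expanded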

Lines changed: 6 additions & 0 deletions (new file)

from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager


class CalibrationFP8KVMemoryManager(OfflineFP8QuantMemManager):
    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_export_mode=False)

Lines changed: 6 additions & 0 deletions (new file)

from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager


class ExportCalibrationMemoryManager(OfflineFP8QuantMemManager):
    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_export_mode=True)
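
Both classes are thin wrappers over OfflineFP8QuantMemManager and differ only in the is_export_mode flag they forward. The construction sketch below is hedged: the argument values are purely illustrative and only the constructor signature comes from this diff, while the comments state the presumed roles of the two modes.

import torch
from lightllm.common.calibration_fp8kv_mem_manager import CalibrationFP8KVMemoryManager
from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager

# Illustrative sizes: KV-cache slots, kv head count, head dim, layer count.
size, head_num, head_dim, layer_num = 65536, 8, 128, 32

# Presumably used when serving with previously exported calibration scales (is_export_mode=False).
calib_mgr = CalibrationFP8KVMemoryManager(size, torch.bfloat16, head_num, head_dim, layer_num)

# Presumably used to run a calibration pass and export the scales (is_export_mode=True).
export_mgr = ExportCalibrationMemoryManager(size, torch.bfloat16, head_num, head_dim, layer_num)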

lightllm/common/mem_utils.py

Lines changed: 8 additions & 0 deletions
@@ -1,5 +1,7 @@
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.int8kv_mem_manager import INT8KVMemoryManager
+from lightllm.common.calibration_fp8kv_mem_manager import CalibrationFP8KVMemoryManager
+from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager
 from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
 from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
 from lightllm.utils.log_utils import init_logger

@@ -20,6 +22,12 @@ def select_mem_manager_class(mode):
         logger.info("Model kv cache using mode triton int8kv")
     elif "triton_fp8kv" in mode:
         raise Exception("currently only for deepseek")
+    elif "offline_calibration_fp8kv" in mode:
+        memory_manager_class = CalibrationFP8KVMemoryManager
+        logger.info("Model kv cache using mode offline calibration fp8kv")
+    elif "export_fp8kv_calibration" in mode:
+        memory_manager_class = ExportCalibrationMemoryManager
+        logger.info("Using mode export fp8kv calibration")
     else:
         memory_manager_class = MemoryManager
         logger.info("Model kv cache using mode normal")
