
Commit 9369d2d

Author: niushengxiao (committed)
feat: add page_size_variable mode for fa3 backend
1 parent 3a14754 · commit 9369d2d

File tree: 14 files changed, +835 -47 lines changed


lightllm/common/basemodel/basemodel.py

Lines changed: 62 additions & 0 deletions
@@ -12,7 +12,9 @@
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
+from lightllm.common.infer_utils import init_req_to_token_indexes
 from lightllm.common.build_utils import repair_config
+from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.basemodel.cuda_graph import CudaGraph
 from lightllm.common.quantization import Quantcfg
@@ -331,6 +333,14 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
+        init_req_to_token_indexes(
+            self.req_manager.req_to_token_indexs,
+            model_input.b_req_idx,
+            model_input.b_seq_len,
+            infer_state.b_ready_cache_len,
+            model_input.max_len_in_batch,
+            infer_state.mem_index,
+        )

         infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)
@@ -351,6 +361,12 @@ def _decode(
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
             padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)
             infer_state = self._create_inferstate(padded_model_input)
+            copy_kv_index_to_req(
+                self.req_manager.req_to_token_indexs,
+                infer_state.b_req_idx,
+                infer_state.b_seq_len,
+                infer_state.mem_index,
+            )
             infer_state.init_some_extra_state(self, padded_model_input.input_ids)

             if self.graph.need_capture(find_graph_batch_size):
@@ -366,6 +382,12 @@ def _decode(
             )
         else:
             infer_state = self._create_inferstate(model_input)
+            copy_kv_index_to_req(
+                self.req_manager.req_to_token_indexs,
+                infer_state.b_req_idx,
+                infer_state.b_seq_len,
+                infer_state.mem_index,
+            )
             infer_state.init_some_extra_state(self, model_input.input_ids)
             model_output = self._token_forward(model_input.input_ids, infer_state)

@@ -450,9 +472,25 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput):
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids

         infer_state0 = self._create_inferstate(model_input0, 0)
+        init_req_to_token_indexes(
+            self.req_manager.req_to_token_indexs,
+            model_input0.b_req_idx,
+            model_input0.b_seq_len,
+            infer_state0.b_ready_cache_len,
+            model_input0.max_len_in_batch,
+            infer_state0.mem_index,
+        )
         infer_state0.init_some_extra_state(self, input_ids0)

         infer_state1 = self._create_inferstate(model_input1, 1)
+        init_req_to_token_indexes(
+            self.req_manager.req_to_token_indexs,
+            model_input1.b_req_idx,
+            model_input1.b_seq_len,
+            infer_state1.b_ready_cache_len,
+            model_input1.max_len_in_batch,
+            infer_state1.mem_index,
+        )
         infer_state1.init_some_extra_state(self, input_ids1)

         model_output0, model_output1 = self._overlap_tpsp_context_forward(
@@ -494,8 +532,20 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
             padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
             padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
             infer_state0 = self._create_inferstate(padded_model_input0, 0)
+            copy_kv_index_to_req(
+                self.req_manager.req_to_token_indexs,
+                infer_state0.b_req_idx,
+                infer_state0.b_seq_len,
+                infer_state0.mem_index,
+            )
             infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
             infer_state1 = self._create_inferstate(padded_model_input1, 1)
+            copy_kv_index_to_req(
+                self.req_manager.req_to_token_indexs,
+                infer_state1.b_req_idx,
+                infer_state1.b_seq_len,
+                infer_state1.mem_index,
+            )
             infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)

             if self.graph.need_capture(find_graph_batch_size):
@@ -520,8 +570,20 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
             model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size)
         else:
             infer_state0 = self._create_inferstate(model_input0, 0)
+            copy_kv_index_to_req(
+                self.req_manager.req_to_token_indexs,
+                infer_state0.b_req_idx,
+                infer_state0.b_seq_len,
+                infer_state0.mem_index,
+            )
             infer_state0.init_some_extra_state(self, model_input0.input_ids)
             infer_state1 = self._create_inferstate(model_input1, 1)
+            copy_kv_index_to_req(
+                self.req_manager.req_to_token_indexs,
+                infer_state1.b_req_idx,
+                infer_state1.b_seq_len,
+                infer_state1.mem_index,
+            )
             infer_state1.init_some_extra_state(self, model_input1.input_ids)

             model_output0, model_output1 = self._overlap_tpsp_token_forward(
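Note: with this change the request-to-token bookkeeping moves out of MemoryManager.alloc (see the deletions in lightllm/common/mem_manager.py below) and into the model's prefill/decode paths. As a rough illustration only, the plain-Python sketch below shows the effect the two helpers are expected to have on req_to_token_indexs; the real code uses the imported helper and Triton kernel, and the *_sketch names here are hypothetical.

def init_req_to_token_indexes_sketch(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, mem_index):
    # Prefill: write each request's newly allocated KV slots into its row,
    # starting right after the prefix tokens that are already cached.
    # max_len_in_batch mirrors the real signature but is not needed in this sketch.
    start = 0
    for i in range(len(b_req_idx)):
        new_len = int(b_seq_len[i]) - int(b_ready_cache_len[i])
        req_to_token_indexs[b_req_idx[i], b_ready_cache_len[i] : b_seq_len[i]] = mem_index[start : start + new_len]
        start += new_len


def copy_kv_index_to_req_sketch(req_to_token_indexs, b_req_idx, b_seq_len, mem_index):
    # Decode: every running request gained exactly one token, so one slot is written per request.
    req_to_token_indexs[b_req_idx, b_seq_len - 1] = mem_index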

lightllm/common/infer_utils.py

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,6 @@
-def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, alloc_mem_index):
+def init_req_to_token_indexes(
+    req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, alloc_mem_index
+):
     start_index = 0
     b_seq_len_numpy = b_seq_len.cpu().numpy()
     b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy()

lightllm/common/mem_manager.py

Lines changed: 0 additions & 20 deletions
@@ -12,8 +12,6 @@
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
 from lightllm.distributed.pynccl import PyNcclCommunicator
 from lightllm.utils.dist_utils import get_current_device_id
-from lightllm.common.infer_utils import init_req_to_token_indexes
-from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req

 logger = init_logger(__name__)

@@ -260,24 +258,6 @@ def alloc(

         self.can_use_mem_size -= need_size
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        if self.req_to_token_indexs is not None:
-            assert b_req_idx is not None and b_seq_len is not None, "b_req_idx and b_seq_len must be provided"
-            if is_prefill:
-                init_req_to_token_indexes(
-                    self.req_to_token_indexs,
-                    b_req_idx,
-                    b_seq_len,
-                    b_ready_cache_len,
-                    ans,
-                )
-            else:
-                copy_kv_index_to_req(
-                    self.req_to_token_indexs,
-                    b_req_idx.cuda(),
-                    b_seq_len.cuda(),
-                    ans.cuda(),
-                )
         return ans

     def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor):

lightllm/common/mem_utils.py

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,7 @@
 from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager
 from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
 from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
+from lightllm.common.page_size_variable_mem_manager import PageSizeVariableMemoryManager
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)

@@ -28,6 +29,9 @@ def select_mem_manager_class(mode):
     elif "export_fp8kv_calibration" in mode:
         memory_manager_class = ExportCalibrationMemoryManager
         logger.info("Using mode export fp8kv calibration")
+    elif "page_size_variable" in mode:
+        memory_manager_class = PageSizeVariableMemoryManager
+        logger.info("Page size will be variable")
     else:
         memory_manager_class = MemoryManager
         logger.info("Model kv cache using mode normal")
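For orientation, a minimal usage sketch of the selection above, assuming select_mem_manager_class returns the chosen class (as the assignments suggest) and that the mode list comes from the server's startup mode arguments:

from lightllm.common.mem_utils import select_mem_manager_class

# A mode list containing "page_size_variable" now selects the paged manager;
# an empty mode list still falls through to the default MemoryManager branch.
paged_cls = select_mem_manager_class(["page_size_variable"])
default_cls = select_mem_manager_class([])
print(paged_cls.__name__)    # PageSizeVariableMemoryManager
print(default_cls.__name__)  # MemoryManager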
lightllm/common/page_size_variable_mem_manager.py

Lines changed: 173 additions & 0 deletions

@@ -0,0 +1,173 @@ (new file)
import torch
import numpy as np
from .mem_manager import MemoryManager
from typing import List, Union
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_page_size


def cdiv(a, b):
    return (a + b - 1) // b


logger = init_logger(__name__)


class PageSizeVariableMemoryManager(MemoryManager):
    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)
        self.req_to_page_indexs = None
        page_size = get_page_size()
        self.page_idx_pool = torch.arange(
            0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
        )
        self.mark_page_start = 0
        self.can_use_page_size = cdiv(self.size, page_size)

    def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
        self.kv_buffer = torch.empty(
            (layer_num, cdiv(size, get_page_size()) * get_page_size(), 2 * head_num, head_dim),
            dtype=dtype,
            device="cuda",
        )

    # The length must be an integer multiple of page_size, and token indexes within a page must be contiguous
    def check_cache_page_valid(self, values: torch.Tensor):
        end = len(values)
        assert end % self.page_size == 0, "Values length must be a multiple of page size"
        total_pages = end // self.page_size
        for page_idx in range(total_pages):
            values_start = page_idx * self.page_size
            values_end = min((page_idx + 1) * self.page_size, end)
            page_token_idxs = values[values_start:values_end]
            if len(page_token_idxs) > 1:
                expected_idxs = torch.arange(
                    page_token_idxs[0],
                    page_token_idxs[0] + len(page_token_idxs),
                    dtype=page_token_idxs.dtype,
                    device=page_token_idxs.device,
                )
                if not torch.equal(page_token_idxs, expected_idxs):
                    return False
        return True

    def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor):
        # assert self.check_cache_page_valid(values), "Values must be valid for page size"
        page_size = get_page_size()
        self.req_to_page_indexs[req_idx, start // page_size : end // page_size] = values[::page_size] // page_size
        self.req_to_token_indexs[req_idx, start:end] = values

    def expand_by_page_size(self, b_token_len, page_size):
        # Expand seq_len into per-page token counts, e.g. seq_len = [9, 9, 9] -> page_len = [4, 4, 1, 4, 4, 1, 4, 4, 1] for page_size = 4
        b_page_len = cdiv(b_token_len, page_size)
        need_pages_num = b_page_len.sum()
        p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device)
        cumsum_pages = torch.cumsum(b_page_len, dim=0)
        last_page_positions = cumsum_pages - 1
        remainders = b_token_len - (b_page_len - 1) * page_size
        p_token_len[last_page_positions] = remainders
        return need_pages_num, b_page_len, p_token_len

    def get_paged_token_indexs(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill):
        if is_prefill:
            b_req_idx = b_req_idx.cuda()
            b_seq_len = b_seq_len.cuda()
            b_ready_cache_len = b_ready_cache_len.cuda()

            b_token_len = b_seq_len - b_ready_cache_len
            total_pages_needed, b_page_len, p_token_len = self.expand_by_page_size(b_token_len, page_size)
            if self.can_use_page_size < total_pages_needed:
                raise RuntimeError(
                    f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {total_pages_needed}"
                )

            allocated_pages = self.page_idx_pool[
                self.mark_page_start : self.mark_page_start + total_pages_needed
            ].cuda()

            def get_offsets_by_length(b_len, max_len):
                # e.g. b_len = [3, 4, 5] -> [0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4]
                offsets = torch.arange(max_len, dtype=b_len.dtype, device=b_len.device)
                offset_mask = offsets.unsqueeze(0) < b_len.unsqueeze(1)
                return torch.masked_select(offsets, offset_mask)

            page_offsets = get_offsets_by_length(b_page_len, b_page_len.max())
            token_offsets = get_offsets_by_length(p_token_len, page_size)

            # Update req_to_page_indexs; b_ready_cache_len is always an exact multiple of page_size
            page_starts = b_ready_cache_len // page_size
            req_id = torch.repeat_interleave(
                torch.arange(len(b_req_idx), dtype=b_token_len.dtype, device=b_token_len.device), b_page_len
            )
            self.req_to_page_indexs[b_req_idx[req_id], page_starts[req_id] + page_offsets] = allocated_pages

            self.mark_page_start += total_pages_needed
            self.can_use_page_size -= total_pages_needed
            page_bases = allocated_pages * page_size
            return torch.repeat_interleave(page_bases, p_token_len) + token_offsets
        else:
            b_seq_len = b_seq_len.cuda()
            b_req_idx = b_req_idx.cuda()
            need_new_page_mask = (b_seq_len - 1) % page_size == 0
            new_pages_num = need_new_page_mask.sum()
            if self.can_use_page_size < new_pages_num:
                raise RuntimeError(
                    f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {new_pages_num}"
                )

            token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device)
            if new_pages_num > 0:
                new_pages = self.page_idx_pool[self.mark_page_start : self.mark_page_start + new_pages_num].cuda()
                self.mark_page_start += new_pages_num
                self.can_use_page_size -= new_pages_num
                token_idxs[need_new_page_mask] = new_pages * page_size

                # req_to_page_indexs needs to be updated as well
                new_page_req_indices = b_req_idx[need_new_page_mask]
                page_positions = (b_seq_len[need_new_page_mask] - 1) // page_size
                self.req_to_page_indexs[new_page_req_indices, page_positions] = new_pages

            mask = ~need_new_page_mask
            if mask.any():
                seq_lens = b_seq_len[mask]
                token_idxs[mask] = (
                    self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] // page_size * page_size
                    + (seq_lens - 1) % page_size
                )
            return token_idxs

    def alloc(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None, is_prefill=False) -> torch.Tensor:
        page_size = get_page_size()
        token_idxs = self.get_paged_token_indexs(b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill)
        self.can_use_mem_size -= need_size
        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
        return token_idxs

    def free(self, free_index: Union[torch.Tensor, List[int]]):
        self.can_use_mem_size += len(free_index)
        self.shared_can_use_token_num.set_value(self.can_use_mem_size)

        page_size = get_page_size()
        if isinstance(free_index, list):
            free_index = torch.tensor(free_index, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True)

        if len(free_index) == 0:
            return

        page_indices = free_index // page_size
        unique_pages = torch.unique(page_indices)
        for page_idx in sorted(unique_pages, reverse=True):  # push back in reverse order to keep the pool's relative ordering
            self.mark_page_start -= 1
            self.page_idx_pool[self.mark_page_start] = page_idx
            self.can_use_page_size += 1

        return

    def free_all(self):
        super().free_all()
        page_size = get_page_size()
        self.mark_page_start = 0
        self.can_use_page_size = cdiv(self.size, page_size)
        self.page_idx_pool = torch.arange(
            0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
        )
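The comments in expand_by_page_size and get_offsets_by_length above already carry worked examples; the standalone sketch below (not part of the commit) reproduces the same arithmetic on CPU tensors and shows how a page id and an in-page offset combine into a token index.

import torch

def cdiv(a, b):
    return (a + b - 1) // b

page_size = 4
b_token_len = torch.tensor([9, 9, 9])            # new tokens per request

# expand_by_page_size: per-page token counts; the last page of each request holds the remainder
b_page_len = cdiv(b_token_len, page_size)        # tensor([3, 3, 3])
p_token_len = torch.full((int(b_page_len.sum()),), page_size, dtype=b_token_len.dtype)
p_token_len[torch.cumsum(b_page_len, dim=0) - 1] = b_token_len - (b_page_len - 1) * page_size
print(p_token_len.tolist())                      # [4, 4, 1, 4, 4, 1, 4, 4, 1]

# get_offsets_by_length: flattened per-element ranges
b_len = torch.tensor([3, 4, 5])
offsets = torch.arange(int(b_len.max()))
mask = offsets.unsqueeze(0) < b_len.unsqueeze(1)
print(torch.masked_select(offsets, mask).tolist())   # [0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4]

# A token's KV slot is page_id * page_size + offset_in_page, which is what the
# prefill branch returns via repeat_interleave(page_bases, p_token_len) + token_offsets.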

lightllm/common/req_manager.py

Lines changed: 9 additions & 1 deletion
@@ -5,7 +5,7 @@
 from typing import List, Optional
 from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter
 from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
-from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
+from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_page_size
 from lightllm.utils.config_utils import get_vocab_size

 logger = init_logger(__name__)

@@ -63,6 +63,14 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryManager):
             (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
         )
         mem_manager.req_to_token_indexs = self.req_to_token_indexs
+        if hasattr(mem_manager, "req_to_page_indexs"):
+            page_size = get_page_size()
+            self.req_to_page_indexs = torch.zeros(
+                (max_request_num + 1, (max_sequence_length + page_size - 1) // page_size),
+                dtype=torch.int32,
+                device="cuda",
+            )
+            mem_manager.req_to_page_indexs = self.req_to_page_indexs
         self.mem_manager = mem_manager
         self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num)
         self.max_request_num = max_request_num
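A quick sanity check on the shape of the new page table (the numbers below are hypothetical; page_size comes from get_page_size() at runtime):

max_request_num = 1000
max_sequence_length = 8192
page_size = 64  # hypothetical value of get_page_size()

page_slots = (max_sequence_length + page_size - 1) // page_size
# req_to_page_indexs is (max_request_num + 1, page_slots): one int32 page id
# per page_size-token chunk of each request's sequence.
print((max_request_num + 1, page_slots))  # (1001, 128)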
