Commit 855d517

Author: niushengxiao

feat: move init_req_to_token_indexes and copy_kv_index_to_req to alloc fun

1 parent a4bc0d6 · commit 855d517

10 files changed: 58 additions & 79 deletions

lightllm/common/basemodel/basemodel.py

Lines changed: 3 additions & 63 deletions

@@ -11,9 +11,7 @@
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
-from lightllm.common.infer_utils import init_req_to_token_indexes
 from lightllm.common.build_utils import repair_config
-from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.basemodel.cuda_graph import CudaGraph
 from lightllm.common.quantization import Quantcfg
@@ -332,14 +330,6 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input.b_req_idx,
-            model_input.b_seq_len,
-            infer_state.b_ready_cache_len,
-            model_input.max_len_in_batch,
-            infer_state.mem_index,
-        )

         infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)
@@ -360,12 +350,6 @@ def _decode(
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
             padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)
             infer_state = self._create_inferstate(padded_model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, padded_model_input.input_ids)

             if self.graph.need_capture(find_graph_batch_size):
@@ -381,12 +365,6 @@ def _decode(
            )
         else:
             infer_state = self._create_inferstate(model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, model_input.input_ids)
             model_output = self._token_forward(model_input.input_ids, infer_state)

@@ -471,25 +449,9 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids

         infer_state0 = self._create_inferstate(model_input0, 0)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input0.b_req_idx,
-            model_input0.b_seq_len,
-            infer_state0.b_ready_cache_len,
-            model_input0.max_len_in_batch,
-            infer_state0.mem_index,
-        )
         infer_state0.init_some_extra_state(self, input_ids0)

         infer_state1 = self._create_inferstate(model_input1, 1)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input1.b_req_idx,
-            model_input1.b_seq_len,
-            infer_state1.b_ready_cache_len,
-            model_input1.max_len_in_batch,
-            infer_state1.mem_index,
-        )
         infer_state1.init_some_extra_state(self, input_ids1)

         model_output0, model_output1 = self._overlap_tpsp_context_forward(
@@ -531,20 +493,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
             padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
             infer_state0 = self._create_inferstate(padded_model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
             infer_state1 = self._create_inferstate(padded_model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)

             if self.graph.need_capture(find_graph_batch_size):
@@ -569,20 +519,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size)
         else:
             infer_state0 = self._create_inferstate(model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, model_input0.input_ids)
             infer_state1 = self._create_inferstate(model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, model_input1.input_ids)

             model_output0, model_output1 = self._overlap_tpsp_token_forward(
@@ -683,10 +621,12 @@ def _check_max_len_infer(self):
         logger.info("begin check max_len infer")
         dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
         b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-        mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = self.batch_max_tokens
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        mem_indexes = self.mem_manager.alloc(
+            len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True
+        ).cuda()
         total_token_num = self.batch_max_tokens
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
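
Taken together, these hunks remove the explicit init_req_to_token_indexes and copy_kv_index_to_req calls from the prefill and decode paths; the same bookkeeping now runs inside MemoryManager.alloc (see lightllm/common/mem_manager.py below). As a rough sketch, assuming nothing beyond what this diff shows, the prefill path reduces to:

# Sketch only: after this commit the forward paths no longer touch
# req_to_token_indexs directly. The mem_index for the new tokens was already
# allocated by the scheduler via mem_manager.alloc(..., is_prefill=True), which
# also filled the req -> token-index rows for this batch.
def _prefill(self, model_input: ModelInput):
    infer_state = self._create_inferstate(model_input)
    infer_state.init_some_extra_state(self, model_input.input_ids)
    return self._context_forward(model_input.input_ids, infer_state)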

lightllm/common/basemodel/cuda_graph.py

Lines changed: 2 additions & 2 deletions

@@ -196,13 +196,13 @@ def warmup(self, model):
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
             input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
-            mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
             b_req_idx = torch.tensor(
                 [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
             )
             b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
             b_seq_len.fill_(seq_len)
             b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+            mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda()

             model_input = ModelInput(
                 batch_size=batch_size,
@@ -252,13 +252,13 @@ def warmup_overlap(self, model):
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
             input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
-            mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
             b_req_idx = torch.tensor(
                 [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
             )
             b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
             b_seq_len.fill_(seq_len)
             b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+            mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda()

             micro_batch = ModelInput(
                 is_prefill=False,

lightllm/common/infer_utils.py

Lines changed: 1 addition & 3 deletions

@@ -1,6 +1,4 @@
-def init_req_to_token_indexes(
-    req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, alloc_mem_index
-):
+def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, alloc_mem_index):
     start_index = 0
     b_seq_len_numpy = b_seq_len.cpu().numpy()
     b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy()

lightllm/common/mem_manager.py

Lines changed: 27 additions & 1 deletion

@@ -12,6 +12,8 @@
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
 from lightllm.distributed.pynccl import PyNcclCommunicator
 from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.common.infer_utils import init_req_to_token_indexes
+from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req

 logger = init_logger(__name__)

@@ -52,6 +54,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
             layer_num,
         )
         self.HOLD_TOKEN_MEMINDEX = self.size
+        self.req_to_token_indexs = None

     def get_cell_size(self):
         return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
@@ -243,7 +246,9 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to
     def _free_buffers(self):
         self.kv_buffer = None

-    def alloc(self, need_size) -> torch.Tensor:
+    def alloc(
+        self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False
+    ) -> torch.Tensor:
         if need_size > self.mark_end - self.mark_start:
             logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
             assert False, "error alloc state"
@@ -255,8 +260,29 @@ def alloc(self, need_size) -> torch.Tensor:

         self.can_use_mem_size -= need_size
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        if self.req_to_token_indexs is not None:
+            assert b_req_idx is not None and b_seq_len is not None, "b_req_idx and b_seq_len must be provided"
+            if is_prefill:
+                init_req_to_token_indexes(
+                    self.req_to_token_indexs,
+                    b_req_idx,
+                    b_seq_len,
+                    b_ready_cache_len,
+                    ans,
+                )
+            else:
+                copy_kv_index_to_req(
+                    self.req_to_token_indexs,
+                    b_req_idx.cuda(),
+                    b_seq_len.cuda(),
+                    ans.cuda(),
+                )
         return ans

+    def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor):
+        self.req_to_token_indexs[req_idx, start:end] = values
+
     def free(self, free_index: Union[torch.Tensor, List[int]]):
         """_summary_
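
With this hunk, alloc optionally receives the batch metadata and, when a req_to_token_indexs table has been attached (see the req_manager.py diff below), records the newly allocated slots itself: init_req_to_token_indexes for prefill batches, the copy_kv_index_to_req Triton kernel for decode steps. A hedged usage sketch of the new signature; num_prefill_tokens, batch_size and the b_* tensors are assumed inputs from the batch scheduler, not names introduced by this commit:

# Usage sketch (argument names taken from this diff; the variables passed in are
# assumed to be prepared by the caller).
# Prefill: one slot per uncached input token; is_prefill=True routes the
# bookkeeping through init_req_to_token_indexes.
prefill_idx = mem_manager.alloc(num_prefill_tokens, b_req_idx, b_seq_len, b_ready_cache_len, True)

# Decode: one slot per request; is_prefill defaults to False, so the indexes are
# written back through copy_kv_index_to_req.
decode_idx = mem_manager.alloc(batch_size, b_req_idx, b_seq_len)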

lightllm/common/req_manager.py

Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
         self.req_to_token_indexs = torch.zeros(
             (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
         )
+        mem_manager.req_to_token_indexs = self.req_to_token_indexs
         self.mem_manager = mem_manager
         self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num)
         self.max_request_num = max_request_num
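
This single added line is what makes the new alloc bookkeeping possible: ReqManager owns the req_to_token_indexs table and hands the tensor to its MemoryManager at construction time. A minimal sketch of the assumed wiring order, with constructor arguments abbreviated:

# Sketch under assumed construction order; argument lists shortened for brevity.
mem_manager = MemoryManager(size, dtype, head_num, head_dim, layer_num)
req_manager = ReqManager(max_request_num, max_sequence_length, mem_manager)
# ReqManager.__init__ assigns mem_manager.req_to_token_indexs, so any later
# mem_manager.alloc(...) call can update the per-request token-index rows.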

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 4 additions & 2 deletions

@@ -340,7 +340,9 @@ def _match_radix_cache(self):
             self.shared_kv_node = share_node
             ready_cache_len = share_node.node_prefix_total_len
             # copying from CPU to GPU is a stream-blocking operation
-            g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor
+            g_infer_context.req_manager.mem_manager.set_prefix_cache_to_req(
+                self.req_idx, 0, ready_cache_len, value_tensor
+            )
             self.cur_kv_len = int(ready_cache_len)  # serialization issue: the value may be numpy.int64, convert with int(*)
             self.shm_req.prompt_cache_len = self.cur_kv_len  # record the prompt cache hit length

@@ -461,7 +463,7 @@ def diverse_copy(self, req_manager, is_prefill):
             req = g_infer_context.requests_mapping[req_id]
             req.finish_status.set_status(FinishStatus.NO_FINISH)
             input_len = req.get_chuncked_input_token_len()
-            req_manager.req_to_token_indexs[req.req_idx][prefix_len:input_len] = cache_token_id
+            req_manager.mem_manager.set_prefix_cache_to_req(req.req_idx, prefix_len, input_len, cache_token_id)
             assert input_len == pre_input_len

lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py

Lines changed: 6 additions & 2 deletions

@@ -78,7 +78,9 @@ def padded_prepare_prefill_inputs(
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num)
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num)
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len, True
+    )
     g_infer_state_lock.release()

     if padded_req_num > 0:
@@ -163,7 +165,9 @@ def padded_prepare_decode_inputs(
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num)
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_req_num)
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        b_seq_len.shape[0] - padded_req_num, b_req_idx, b_seq_len
+    )
     g_infer_state_lock.release()

     if padded_req_num > 0:

lightllm/server/router/model_infer/mode_backend/generic_pre_process.py

Lines changed: 4 additions & 2 deletions

@@ -56,7 +56,9 @@ def prepare_prefill_inputs(
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0])
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0])
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
+    )
     g_infer_state_lock.release()

     model_input = ModelInput(
@@ -112,7 +114,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0])
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0])
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0], b_req_idx, b_seq_len)
     g_infer_state_lock.release()

     model_input = ModelInput(

test/benchmark/static_inference/model_infer.py

Lines changed: 4 additions & 2 deletions

@@ -242,7 +242,9 @@ def run_forward_once(
         b_seq_len[i] = input_len

     total_token_num = batch_size * input_len
-    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
+    mem_indexes = model_part.req_manager.mem_manager.alloc(
+        test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
+    ).cuda()

     rank_id = model_kvargs["rank_id"]

@@ -303,7 +305,7 @@ def run_forward_once(
         step_start = time.time()
         total_token_num += batch_size
         b_seq_len += 1
-        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda()
+        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0], b_req_idx, b_seq_len).cuda()
         max_len_in_batch = input_len + i + 1
         logits = decode_fn(
             model_part,

test/benchmark/static_inference/model_infer_mtp.py

Lines changed: 6 additions & 2 deletions

@@ -126,7 +126,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
         b_seq_len[i] = input_len

     total_token_num = input_len * batch_size
-    mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
+    mem_indexes = main_model.req_manager.mem_manager.alloc(
+        test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
+    ).cuda()
     # Main model Prefill
     model_input = ModelInput(
         batch_size=batch_size,
@@ -193,7 +195,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_

         nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda")
         nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda")
-        mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda()
+        mem_indexes = main_model.req_manager.mem_manager.alloc(
+            batch_size * (len(draft_models) + 1), nopad_b_seq_idx, nopad_b_seq_len
+        ).cuda()

         model_input = ModelInput(
             batch_size=batch_size * (len(draft_models) + 1),
