Commit 3a14754

Author: niushengxiao (committed)

feat: move init_req_to_token_indexes and copy_kv_index_to_req to alloc fun

1 parent f9a3fe2 · commit 3a14754

File tree: 10 files changed, +58 −79 lines
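In short, this commit moves the bookkeeping of the request-to-token index table out of the model forward paths and into MemoryManager.alloc. A condensed, illustrative sketch of the call-site change (names as they appear in the per-file diffs below; surrounding code omitted):

    # before: callers allocated KV slots and then wrote req_to_token_indexs themselves
    mem_indexes = mem_manager.alloc(need_size)
    init_req_to_token_indexes(
        req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, mem_indexes
    )  # prefill path; the decode path called copy_kv_index_to_req instead

    # after: callers hand the batch metadata to alloc, which updates the table internally
    mem_indexes = mem_manager.alloc(need_size, b_req_idx, b_seq_len, b_ready_cache_len, True)  # prefill
    mem_indexes = mem_manager.alloc(need_size, b_req_idx, b_seq_len)  # decode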

lightllm/common/basemodel/basemodel.py

Lines changed: 3 additions & 63 deletions
@@ -12,9 +12,7 @@
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
-from lightllm.common.infer_utils import init_req_to_token_indexes
 from lightllm.common.build_utils import repair_config
-from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.basemodel.cuda_graph import CudaGraph
 from lightllm.common.quantization import Quantcfg
@@ -333,14 +331,6 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input.b_req_idx,
-            model_input.b_seq_len,
-            infer_state.b_ready_cache_len,
-            model_input.max_len_in_batch,
-            infer_state.mem_index,
-        )

         infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)
@@ -361,12 +351,6 @@ def _decode(
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
             padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)
             infer_state = self._create_inferstate(padded_model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, padded_model_input.input_ids)

             if self.graph.need_capture(find_graph_batch_size):
@@ -382,12 +366,6 @@ def _decode(
             )
         else:
             infer_state = self._create_inferstate(model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, model_input.input_ids)
             model_output = self._token_forward(model_input.input_ids, infer_state)

@@ -472,25 +450,9 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids

         infer_state0 = self._create_inferstate(model_input0, 0)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input0.b_req_idx,
-            model_input0.b_seq_len,
-            infer_state0.b_ready_cache_len,
-            model_input0.max_len_in_batch,
-            infer_state0.mem_index,
-        )
         infer_state0.init_some_extra_state(self, input_ids0)

         infer_state1 = self._create_inferstate(model_input1, 1)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input1.b_req_idx,
-            model_input1.b_seq_len,
-            infer_state1.b_ready_cache_len,
-            model_input1.max_len_in_batch,
-            infer_state1.mem_index,
-        )
         infer_state1.init_some_extra_state(self, input_ids1)

         model_output0, model_output1 = self._overlap_tpsp_context_forward(
@@ -532,20 +494,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
             padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
             infer_state0 = self._create_inferstate(padded_model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
             infer_state1 = self._create_inferstate(padded_model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)

             if self.graph.need_capture(find_graph_batch_size):
@@ -570,20 +520,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size)
         else:
             infer_state0 = self._create_inferstate(model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, model_input0.input_ids)
             infer_state1 = self._create_inferstate(model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, model_input1.input_ids)

             model_output0, model_output1 = self._overlap_tpsp_token_forward(
@@ -684,10 +622,12 @@ def _check_max_len_infer(self):
         logger.info("begin check max_len infer")
         dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
         b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-        mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = self.batch_max_tokens
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        mem_indexes = self.mem_manager.alloc(
+            len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True
+        ).cuda()
         total_token_num = self.batch_max_tokens
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(

lightllm/common/basemodel/cuda_graph.py

Lines changed: 2 additions & 2 deletions
@@ -196,13 +196,13 @@ def warmup(self, model):
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
             input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
-            mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
             b_req_idx = torch.tensor(
                 [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
             )
             b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
             b_seq_len.fill_(seq_len)
             b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+            mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda()

             model_input = ModelInput(
                 batch_size=batch_size,
@@ -252,13 +252,13 @@ def warmup_overlap(self, model):
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
             input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
-            mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
             b_req_idx = torch.tensor(
                 [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
             )
             b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
             b_seq_len.fill_(seq_len)
             b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+            mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda()

             micro_batch = ModelInput(
                 is_prefill=False,

lightllm/common/infer_utils.py

Lines changed: 1 addition & 3 deletions
@@ -1,6 +1,4 @@
-def init_req_to_token_indexes(
-    req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, alloc_mem_index
-):
+def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, alloc_mem_index):
     start_index = 0
     b_seq_len_numpy = b_seq_len.cpu().numpy()
     b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy()
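The helper's body is not part of this diff; the signature change only drops the unused max_len_in_batch argument. For readers following along, a rough sketch of what the function is assumed to do, inferred from its parameters and from how alloc now calls it (the authoritative version lives in lightllm/common/infer_utils.py):

    # Assumed sketch, not the committed implementation.
    def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, alloc_mem_index):
        start_index = 0
        b_seq_len_numpy = b_seq_len.cpu().numpy()
        b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy()
        for i in range(len(b_seq_len_numpy)):
            cur_seq_len = int(b_seq_len_numpy[i])
            cur_ready_len = int(b_ready_cache_len_numpy[i])
            need = cur_seq_len - cur_ready_len
            # fill the uncached tail of request i's row with the freshly allocated KV indices
            req_to_token_indexs[b_req_idx[i], cur_ready_len:cur_seq_len] = alloc_mem_index[
                start_index : start_index + need
            ]
            start_index += need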

lightllm/common/mem_manager.py

Lines changed: 27 additions & 1 deletion
@@ -12,6 +12,8 @@
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
 from lightllm.distributed.pynccl import PyNcclCommunicator
 from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.common.infer_utils import init_req_to_token_indexes
+from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req

 logger = init_logger(__name__)

@@ -52,6 +54,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
             layer_num,
         )
         self.HOLD_TOKEN_MEMINDEX = self.size
+        self.req_to_token_indexs = None

     def get_cell_size(self):
         return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
@@ -243,7 +246,9 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to
     def _free_buffers(self):
         self.kv_buffer = None

-    def alloc(self, need_size) -> torch.Tensor:
+    def alloc(
+        self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False
+    ) -> torch.Tensor:
         if need_size > self.mark_end - self.mark_start:
             logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
             assert False, "error alloc state"
@@ -255,8 +260,29 @@ def alloc(self, need_size) -> torch.Tensor:

         self.can_use_mem_size -= need_size
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        if self.req_to_token_indexs is not None:
+            assert b_req_idx is not None and b_seq_len is not None, "b_req_idx and b_seq_len must be provided"
+            if is_prefill:
+                init_req_to_token_indexes(
+                    self.req_to_token_indexs,
+                    b_req_idx,
+                    b_seq_len,
+                    b_ready_cache_len,
+                    ans,
+                )
+            else:
+                copy_kv_index_to_req(
+                    self.req_to_token_indexs,
+                    b_req_idx.cuda(),
+                    b_seq_len.cuda(),
+                    ans.cuda(),
+                )
         return ans

+    def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor):
+        self.req_to_token_indexs[req_idx, start:end] = values
+
     def free(self, free_index: Union[torch.Tensor, List[int]]):
         """_summary_
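From a caller's perspective, the new contract is: once a ReqManager has been built (it now hands its req_to_token_indexs table to the MemoryManager, see lightllm/common/req_manager.py below), alloc both reserves KV slots and records them for each request. A minimal usage sketch based on the call sites in this commit:

    # prefill: pass b_ready_cache_len and is_prefill=True so alloc runs init_req_to_token_indexes
    mem_indexes = mem_manager.alloc(input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True)

    # decode: only b_req_idx and b_seq_len are needed; alloc runs the copy_kv_index_to_req kernel
    mem_indexes = mem_manager.alloc(b_seq_len.shape[0], b_req_idx, b_seq_len)

    # prefix-cache hits write through the new helper instead of indexing the table directly
    mem_manager.set_prefix_cache_to_req(req_idx, 0, ready_cache_len, value_tensor)

If req_to_token_indexs has not been wired up (a MemoryManager used standalone), alloc skips the bookkeeping and simply returns the indices, as before.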

lightllm/common/req_manager.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
         self.req_to_token_indexs = torch.zeros(
             (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
         )
+        mem_manager.req_to_token_indexs = self.req_to_token_indexs
         self.mem_manager = mem_manager
         self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num)
         self.max_request_num = max_request_num

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 4 additions & 2 deletions
@@ -340,7 +340,9 @@ def _match_radix_cache(self):
             self.shared_kv_node = share_node
             ready_cache_len = share_node.node_prefix_total_len
             # the CPU-to-GPU copy is a blocking operation on the stream
-            g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor
+            g_infer_context.req_manager.mem_manager.set_prefix_cache_to_req(
+                self.req_idx, 0, ready_cache_len, value_tensor
+            )
             self.cur_kv_len = int(ready_cache_len)  # serialization issue: the value may be numpy.int64, convert it with int(*)
             self.shm_req.prompt_cache_len = self.cur_kv_len  # record the prompt cache hit length

@@ -458,7 +460,7 @@ def diverse_copy(self, req_manager, is_prefill):
             req = g_infer_context.requests_mapping[req_id]
             req.finish_status.set_status(FinishStatus.NO_FINISH)
             input_len = req.get_chuncked_input_token_len()
-            req_manager.req_to_token_indexs[req.req_idx][prefix_len:input_len] = cache_token_id
+            req_manager.mem_manager.set_prefix_cache_to_req(req.req_idx, prefix_len, input_len, cache_token_id)
             assert input_len == pre_input_len

lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py

Lines changed: 6 additions & 2 deletions
@@ -78,7 +78,9 @@ def padded_prepare_prefill_inputs(
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num)
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num)
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len, True
+    )
     g_infer_state_lock.release()

     if padded_req_num > 0:
@@ -163,7 +165,9 @@ def padded_prepare_decode_inputs(
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num)
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_req_num)
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        b_seq_len.shape[0] - padded_req_num, b_req_idx, b_seq_len
+    )
     g_infer_state_lock.release()

     if padded_req_num > 0:

lightllm/server/router/model_infer/mode_backend/generic_pre_process.py

Lines changed: 4 additions & 2 deletions
@@ -56,7 +56,9 @@ def prepare_prefill_inputs(
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0])
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0])
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
+    )
     g_infer_state_lock.release()

     model_input = ModelInput(
@@ -112,7 +114,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0])
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0])
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0], b_req_idx, b_seq_len)
     g_infer_state_lock.release()

     model_input = ModelInput(

test/benchmark/static_inference/model_infer.py

Lines changed: 4 additions & 2 deletions
@@ -258,7 +258,9 @@ def run_forward_once(
         b_seq_len[i] = input_len

     total_token_num = batch_size * input_len
-    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0])
+    mem_indexes = model_part.req_manager.mem_manager.alloc(
+        test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
+    ).cuda()
     b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
     rank_id = model_kvargs["rank_id"]

@@ -321,7 +323,7 @@ def run_forward_once(
         step_start = time.time()
         total_token_num += batch_size
         b_seq_len += 1
-        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0])
+        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0], b_req_idx, b_seq_len).cuda()
         max_len_in_batch = input_len + i + 1
         logits = decode_fn(
             model_part,

test/benchmark/static_inference/model_infer_mtp.py

Lines changed: 6 additions & 2 deletions
@@ -124,7 +124,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
         b_seq_len[i] = input_len

     total_token_num = input_len * batch_size
-    mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
+    mem_indexes = main_model.req_manager.mem_manager.alloc(
+        test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
+    ).cuda()
     # Main model Prefill
     model_input = ModelInput(
         batch_size=batch_size,
@@ -191,7 +193,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_

     nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda")
     nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda")
-    mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda()
+    mem_indexes = main_model.req_manager.mem_manager.alloc(
+        batch_size * (len(draft_models) + 1), nopad_b_seq_idx, nopad_b_seq_len
+    ).cuda()

     model_input = ModelInput(
         batch_size=batch_size * (len(draft_models) + 1),
