@@ -11,9 +11,7 @@
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
-from lightllm.common.infer_utils import init_req_to_token_indexes
 from lightllm.common.build_utils import repair_config
-from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.basemodel.cuda_graph import CudaGraph
 from lightllm.common.quantization import Quantcfg
@@ -332,14 +330,6 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input.b_req_idx,
-            model_input.b_seq_len,
-            infer_state.b_ready_cache_len,
-            model_input.max_len_in_batch,
-            infer_state.mem_index,
-        )
 
         infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)
@@ -360,12 +350,6 @@ def _decode(
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
             padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)
             infer_state = self._create_inferstate(padded_model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, padded_model_input.input_ids)
 
             if self.graph.need_capture(find_graph_batch_size):
@@ -381,12 +365,6 @@ def _decode(
             )
         else:
             infer_state = self._create_inferstate(model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, model_input.input_ids)
             model_output = self._token_forward(model_input.input_ids, infer_state)
 
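Note for readers of this diff: in the pre-change code, init_req_to_token_indexes (prefill path) and copy_kv_index_to_req (decode path) wrote the freshly allocated KV-cache slot indexes into req_manager.req_to_token_indexs; the hunks above simply drop those call sites. The following is a minimal sketch of that bookkeeping, assuming the upstream lightllm helpers behave as in earlier revisions; it is illustrative only and not code from this PR.

# Approximate restatement of the removed bookkeeping (assumption: matches the upstream helpers).
# req_to_token_indexs: int32 tensor of shape [max_request_num, max_seq_len]
# mem_index: 1-D tensor of freshly allocated KV-cache slot indexes

def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, mem_index):
    # Prefill: fill each request row over [ready_cache_len, seq_len) with the new slot indexes.
    # max_len_in_batch is kept only to mirror the original call signature.
    start = 0
    for i in range(b_req_idx.shape[0]):
        ready = int(b_ready_cache_len[i])
        seq_len = int(b_seq_len[i])
        num_new = seq_len - ready
        req_to_token_indexs[b_req_idx[i], ready:seq_len] = mem_index[start : start + num_new]
        start += num_new

def copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, mem_index):
    # Decode: each request gains exactly one token, stored at position seq_len - 1.
    req_to_token_indexs[b_req_idx.long(), (b_seq_len - 1).long()] = mem_index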
@@ -471,25 +449,9 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
         infer_state0 = self._create_inferstate(model_input0, 0)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input0.b_req_idx,
-            model_input0.b_seq_len,
-            infer_state0.b_ready_cache_len,
-            model_input0.max_len_in_batch,
-            infer_state0.mem_index,
-        )
         infer_state0.init_some_extra_state(self, input_ids0)
 
         infer_state1 = self._create_inferstate(model_input1, 1)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input1.b_req_idx,
-            model_input1.b_seq_len,
-            infer_state1.b_ready_cache_len,
-            model_input1.max_len_in_batch,
-            infer_state1.mem_index,
-        )
         infer_state1.init_some_extra_state(self, input_ids1)
 
         model_output0, model_output1 = self._overlap_tpsp_context_forward(
@@ -531,20 +493,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
             padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
             infer_state0 = self._create_inferstate(padded_model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
             infer_state1 = self._create_inferstate(padded_model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)
 
             if self.graph.need_capture(find_graph_batch_size):
@@ -569,20 +519,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size)
         else:
             infer_state0 = self._create_inferstate(model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, model_input0.input_ids)
             infer_state1 = self._create_inferstate(model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, model_input1.input_ids)
 
             model_output0, model_output1 = self._overlap_tpsp_token_forward(
@@ -683,10 +621,12 @@ def _check_max_len_infer(self):
             logger.info("begin check max_len infer")
             dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
             b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
             b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
             b_seq_len[:] = self.batch_max_tokens
             b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+            mem_indexes = self.mem_manager.alloc(
+                len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True
+            ).cuda()
             total_token_num = self.batch_max_tokens
             b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
             model_input = ModelInput(
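The only call-site edit retained in this hunk is the extended mem_manager.alloc call: besides the slot count it now receives b_req_idx, b_seq_len, b_ready_cache_len and a trailing True, which suggests the request-to-token bookkeeping sketched after the _decode hunk now lives inside the allocator. Below is a hypothetical sketch of an allocator under that reading; the parameter names, the meaning of the boolean (taken here as a prefill flag), and all internals are assumptions, not taken from this PR.

# Hypothetical allocator sketch: alloc() hands out KV-cache slots and, when given the
# batch metadata, records them in req_to_token_indexs itself (assumed reading of the PR).
import torch

class SketchMemoryManager:
    def __init__(self, size, max_request_num, max_seq_len):
        self.free_slots = torch.arange(size, dtype=torch.int32)
        self.req_to_token_indexs = torch.zeros((max_request_num, max_seq_len), dtype=torch.int32)

    def alloc(self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False):
        # Hand out the next `need_size` free slots.
        mem_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        if b_req_idx is not None:
            if is_prefill:
                # Same per-row filling as the removed init_req_to_token_indexes helper.
                start = 0
                for i in range(b_req_idx.shape[0]):
                    ready = int(b_ready_cache_len[i]) if b_ready_cache_len is not None else 0
                    seq_len = int(b_seq_len[i])
                    self.req_to_token_indexs[b_req_idx[i], ready:seq_len] = mem_index[start : start + seq_len - ready]
                    start += seq_len - ready
            else:
                # Same single-slot write as the removed copy_kv_index_to_req kernel.
                self.req_to_token_indexs[b_req_idx.long(), (b_seq_len - 1).long()] = mem_index
        return mem_index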