@@ -12,9 +12,7 @@
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
-from lightllm.common.infer_utils import init_req_to_token_indexes
 from lightllm.common.build_utils import repair_config
-from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.basemodel.cuda_graph import CudaGraph
 from lightllm.common.quantization import Quantcfg
@@ -333,14 +331,6 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input.b_req_idx,
-            model_input.b_seq_len,
-            infer_state.b_ready_cache_len,
-            model_input.max_len_in_batch,
-            infer_state.mem_index,
-        )
 
         infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)
@@ -361,12 +351,6 @@ def _decode(
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
             padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)
             infer_state = self._create_inferstate(padded_model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, padded_model_input.input_ids)
 
             if self.graph.need_capture(find_graph_batch_size):
@@ -382,12 +366,6 @@ def _decode(
             )
         else:
             infer_state = self._create_inferstate(model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, model_input.input_ids)
             model_output = self._token_forward(model_input.input_ids, infer_state)
 
@@ -472,25 +450,9 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
         infer_state0 = self._create_inferstate(model_input0, 0)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input0.b_req_idx,
-            model_input0.b_seq_len,
-            infer_state0.b_ready_cache_len,
-            model_input0.max_len_in_batch,
-            infer_state0.mem_index,
-        )
         infer_state0.init_some_extra_state(self, input_ids0)
 
         infer_state1 = self._create_inferstate(model_input1, 1)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input1.b_req_idx,
-            model_input1.b_seq_len,
-            infer_state1.b_ready_cache_len,
-            model_input1.max_len_in_batch,
-            infer_state1.mem_index,
-        )
         infer_state1.init_some_extra_state(self, input_ids1)
 
         model_output0, model_output1 = self._overlap_tpsp_context_forward(
@@ -532,20 +494,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
             padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
             infer_state0 = self._create_inferstate(padded_model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
             infer_state1 = self._create_inferstate(padded_model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)
 
             if self.graph.need_capture(find_graph_batch_size):
@@ -570,20 +520,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size)
         else:
             infer_state0 = self._create_inferstate(model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, model_input0.input_ids)
             infer_state1 = self._create_inferstate(model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, model_input1.input_ids)
 
             model_output0, model_output1 = self._overlap_tpsp_token_forward(
@@ -684,10 +622,12 @@ def _check_max_len_infer(self):
             logger.info("begin check max_len infer")
             dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
             b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
             b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
             b_seq_len[:] = self.batch_max_tokens
             b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+            mem_indexes = self.mem_manager.alloc(
+                len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True
+            ).cuda()
             total_token_num = self.batch_max_tokens
             b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
             model_input = ModelInput(
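Across the hunks above, the explicit init_req_to_token_indexes (prefill) and copy_kv_index_to_req (decode) calls disappear from the forward paths, while the call site in _check_max_len_infer now passes b_req_idx, b_seq_len, b_ready_cache_len, and a trailing boolean into self.mem_manager.alloc(...). The MemoryManager side of the change is not part of this diff, so what follows is only a minimal sketch of how an allocator with that call shape could take over the req-to-token bookkeeping itself; the parameter names, the is_prefill flag, and the SketchMemoryManager class are assumptions inferred from the new call site, not the actual lightllm API.

import torch


class SketchMemoryManager:
    """Hypothetical allocator that also records req -> token-index mappings.

    Illustration only: lightllm's real MemoryManager.alloc may differ in
    signature and behavior; only the call shape
    alloc(need_size, b_req_idx, b_seq_len, b_ready_cache_len, is_prefill)
    is taken from the diff above.
    """

    def __init__(self, size: int, req_to_token_indexs: torch.Tensor):
        self.mem_state = torch.arange(size, dtype=torch.int32)  # simplified free list
        self.can_use_mem_size = size
        self.req_to_token_indexs = req_to_token_indexs  # [max_reqs, max_seq_len]

    def alloc(self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False):
        assert need_size <= self.can_use_mem_size, "not enough free KV cache slots"
        # Take the first `need_size` free slots.
        select_index = self.mem_state[:need_size].clone()
        self.mem_state = self.mem_state[need_size:]
        self.can_use_mem_size -= need_size

        # If batch metadata is provided, fill req_to_token_indexs here instead of
        # calling init_req_to_token_indexes / copy_kv_index_to_req at the call site.
        if b_req_idx is not None:
            start = 0
            for i in range(b_req_idx.shape[0]):
                req = b_req_idx[i].item()
                seq_len = b_seq_len[i].item()
                if is_prefill:
                    # Prefill: write indexes for every token not already cached.
                    ready = b_ready_cache_len[i].item() if b_ready_cache_len is not None else 0
                    new_tokens = seq_len - ready
                    self.req_to_token_indexs[req, ready:seq_len] = select_index[start:start + new_tokens]
                    start += new_tokens
                else:
                    # Decode: one new token per request at position seq_len - 1.
                    self.req_to_token_indexs[req, seq_len - 1] = select_index[start]
                    start += 1
        return select_index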