@@ -61,6 +61,8 @@ def __init__(self, kvargs):
         self.finetune_config = kvargs.get("finetune_config", None)
         self.max_req_num = kvargs.get("max_req_num", 1000)
         self.max_seq_length = kvargs.get("max_seq_length", 1024 * 5)
+        # Used to wait for peripheral modules to finish initializing (e.g. CPU KV Cache registration).
+        self.wait_events = kvargs.get("wait_events", [])
         # is_token_healing and return_all_prompt_logics are mutually exclusive modes; only one can be active at a time.
         # They mainly control how many tokens are returned in the prefill stage for subsequent processing.
         self.is_token_healing = kvargs.get("is_token_healing", False)
@@ -110,12 +112,19 @@ def __init__(self, kvargs):
         self._init_inferstate_cls()
         self._autotune_warmup()
         self._init_padded_req()
+        # The wait must happen before cuda graph init, to avoid incorrect graph capture.
+        self._wait_other_modules_ready()
         self._init_cudagraph()
         self._check_max_len_infer()
         torch.cuda.empty_cache()
         set_model_init_status(True)
         return

+    def _wait_other_modules_ready(self):
+        for event in self.wait_events:
+            event.wait()
+        return
+
     def _init_config(self):
         with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
             self.config = json.load(json_file)
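
The new `wait_events` hook accepts any objects that expose a blocking `wait()` (e.g. `threading.Event`), letting model init overlap with slower external setup and only blocking right before cuda graph capture. A minimal sketch of the caller side, assuming a hypothetical background CPU KV cache registration step (the `register_cpu_kv_cache` name and the kvargs wiring here are illustrative, not the project's actual code):

```python
import threading

cpu_cache_ready = threading.Event()

def register_cpu_kv_cache():
    # ... long-running CPU KV cache registration (illustrative) ...
    cpu_cache_ready.set()

threading.Thread(target=register_cpu_kv_cache, daemon=True).start()

kvargs = {
    # ... the usual model kvargs ...
    "wait_events": [cpu_cache_ready],  # consumed by _wait_other_modules_ready()
}
```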
@@ -343,17 +352,22 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
-        infer_state.init_some_extra_state(self, model_input.input_ids)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state.b_req_idx,
             b_seq_len=infer_state.b_seq_len,
             b_ready_cache_len=infer_state.b_ready_cache_len,
-            b_start_loc=infer_state.b_start_loc,
+            b_start_loc=model_input.b_prefill_start_loc,
             alloc_mem_index=infer_state.mem_index,
             max_q_seq_len=infer_state.max_q_seq_len,
         )
-        return self._context_forward(model_input.input_ids, infer_state)
+        prefill_mem_indexes_ready_event = torch.cuda.Event()
+        prefill_mem_indexes_ready_event.record()
+
+        infer_state.init_some_extra_state(self, model_input.input_ids)
+        model_output = self._context_forward(model_input.input_ids, infer_state)
+        model_output.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
+        return model_output

     def _decode(
         self,
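
For context, the recorded `torch.cuda.Event()` marks the point on the current CUDA stream right after the req-to-token index writes are enqueued, so a consumer on another stream can synchronize before touching that memory. A hedged sketch of what such a consumer might look like; the `offload_on_side_stream` function, its buffers, and the offload use case are assumptions for illustration, only the `prefill_mem_indexes_ready_event` attribute comes from this diff:

```python
import torch

def offload_on_side_stream(model_output, kv_buffer, host_buffer, side_stream):
    # Sketch only: block the side stream until the prefill mem-index writes
    # recorded on the main stream have finished, then copy asynchronously.
    with torch.cuda.stream(side_stream):
        side_stream.wait_event(model_output.prefill_mem_indexes_ready_event)
        host_buffer.copy_(kv_buffer, non_blocking=True)
```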
@@ -482,28 +496,31 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids

         infer_state0 = self._create_inferstate(model_input0, 0)
-        infer_state0.init_some_extra_state(self, input_ids0)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state0.b_req_idx,
             b_seq_len=infer_state0.b_seq_len,
             b_ready_cache_len=infer_state0.b_ready_cache_len,
-            b_start_loc=infer_state0.b_start_loc,
+            b_start_loc=model_input0.b_prefill_start_loc,
             alloc_mem_index=infer_state0.mem_index,
             max_q_seq_len=infer_state0.max_q_seq_len,
         )
+        infer_state0.init_some_extra_state(self, input_ids0)

         infer_state1 = self._create_inferstate(model_input1, 1)
-        infer_state1.init_some_extra_state(self, input_ids1)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state1.b_req_idx,
             b_seq_len=infer_state1.b_seq_len,
             b_ready_cache_len=infer_state1.b_ready_cache_len,
-            b_start_loc=infer_state1.b_start_loc,
+            b_start_loc=model_input1.b_prefill_start_loc,
             alloc_mem_index=infer_state1.mem_index,
             max_q_seq_len=infer_state1.max_q_seq_len,
         )
+        infer_state1.init_some_extra_state(self, input_ids1)
+
+        prefill_mem_indexes_ready_event = torch.cuda.Event()
+        prefill_mem_indexes_ready_event.record()

         model_output0, model_output1 = self._overlap_tpsp_context_forward(
             input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
@@ -512,6 +529,8 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         # When deepep is enabled, clear_deepep_buffer must be called to clean up resources;
         # when it is not enabled, the call has no real effect.
         dist_group_manager.clear_deepep_buffer()
+        model_output0.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
+        model_output1.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
         return model_output0, model_output1

     @torch.no_grad()
@@ -704,6 +723,7 @@ def _check_max_len_infer(self):
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = self.batch_max_tokens
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")
         total_token_num = self.batch_max_tokens
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -721,6 +741,7 @@ def _check_max_len_infer(self):
             b_mtp_index=b_mtp_index,
             is_prefill=True,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
         )
         model_output = self.forward(
             model_input,
@@ -778,6 +799,7 @@ def _autotune_warmup(self):
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = input_len
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")
         total_token_num = input_len
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -795,6 +817,7 @@ def _autotune_warmup(self):
             b_mtp_index=b_mtp_index,
             is_prefill=True,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
             multimodal_params=[],
             **self._gen_special_model_input(total_token_num),
@@ -838,6 +861,8 @@ def _init_padded_req(self):
         )
         b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
         b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        b_q_seq_len = b_seq_len - b_ready_cache_len
+        b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
         total_token_num = prefill_input_len * batch_size
         b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -854,6 +879,7 @@ def _init_padded_req(self):
             b_mtp_index=b_mtp_index,
             b_seq_len=b_seq_len,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
             is_prefill=True,
             multimodal_params=[],
             **self._gen_special_model_input(total_token_num),
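
The two lines added in `_init_padded_req` compute an exclusive prefix sum: each request's start offset in the flattened prefill token buffer equals the total query length of all preceding requests. A small standalone illustration of that arithmetic, with made-up values:

```python
import torch

b_seq_len = torch.tensor([5, 3, 4], dtype=torch.int32)
b_ready_cache_len = torch.tensor([2, 0, 1], dtype=torch.int32)

# Query length per request, then its exclusive cumulative sum.
b_q_seq_len = b_seq_len - b_ready_cache_len                                   # tensor([3, 3, 3])
b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
print(b_prefill_start_loc)                                                    # tensor([0, 3, 6], dtype=torch.int32)
```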