
Commit d9cb8c3

Author: sangchengmeng
Commit message: merge main
Merge of 2 parents: 40f8c6a + 230d9d8


57 files changed: +3967 −672 lines

lightllm/common/basemodel/basemodel.py

Lines changed: 33 additions & 7 deletions
@@ -61,6 +61,8 @@ def __init__(self, kvargs):
         self.finetune_config = kvargs.get("finetune_config", None)
         self.max_req_num = kvargs.get("max_req_num", 1000)
         self.max_seq_length = kvargs.get("max_seq_length", 1024 * 5)
+        # Used to wait for peripheral modules to finish initializing (e.g. CPU KV cache registration).
+        self.wait_events = kvargs.get("wait_events", [])
         # is_token_healing and return_all_prompt_logics are mutually exclusive modes; only one can be active.
         # They mainly control how many tokens the prefill stage returns for downstream processing.
         self.is_token_healing = kvargs.get("is_token_healing", False)
@@ -110,12 +112,19 @@ def __init__(self, kvargs):
         self._init_inferstate_cls()
         self._autotune_warmup()
         self._init_padded_req()
+        # The wait must happen before CUDA graph init to avoid capturing it by mistake.
+        self._wait_other_modules_ready()
         self._init_cudagraph()
         self._check_max_len_infer()
         torch.cuda.empty_cache()
         set_model_init_status(True)
         return

+    def _wait_other_modules_ready(self):
+        for event in self.wait_events:
+            event.wait()
+        return
+
     def _init_config(self):
         with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
             self.config = json.load(json_file)
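
The new wait_events hook lets the launcher hand the model a list of objects with a blocking wait(); _wait_other_modules_ready() drains them right before CUDA graph capture. A minimal sketch of how a caller might use it, assuming threading.Event-style objects (the registration thread and kvargs below are illustrative, not part of the commit):

import threading

cpu_kv_cache_ready = threading.Event()

def register_cpu_kv_cache():
    # ... do the (possibly slow) CPU KV cache registration in the background ...
    cpu_kv_cache_ready.set()

threading.Thread(target=register_cpu_kv_cache, daemon=True).start()

kvargs = {
    "max_req_num": 1000,
    "max_seq_length": 1024 * 5,
    # consumed by _wait_other_modules_ready(), which blocks until every event is set
    "wait_events": [cpu_kv_cache_ready],
    # ... remaining model kvargs ...
}

Because the wait is placed after _init_padded_req() and before _init_cudagraph(), the external initialization can never leak into a captured graph.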
@@ -343,17 +352,22 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
-        infer_state.init_some_extra_state(self, model_input.input_ids)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state.b_req_idx,
             b_seq_len=infer_state.b_seq_len,
             b_ready_cache_len=infer_state.b_ready_cache_len,
-            b_start_loc=infer_state.b_start_loc,
+            b_start_loc=model_input.b_prefill_start_loc,
             alloc_mem_index=infer_state.mem_index,
             max_q_seq_len=infer_state.max_q_seq_len,
         )
-        return self._context_forward(model_input.input_ids, infer_state)
+        prefill_mem_indexes_ready_event = torch.cuda.Event()
+        prefill_mem_indexes_ready_event.record()
+
+        infer_state.init_some_extra_state(self, model_input.input_ids)
+        model_output = self._context_forward(model_input.input_ids, infer_state)
+        model_output.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
+        return model_output

     def _decode(
         self,
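
_prefill now records a CUDA event right after init_req_to_token_indexes enqueues the writes of mem_index into the req manager's req_to_token_indexs table, and returns it on the ModelOutput so downstream code can synchronize on those writes instead of on the whole stream. A rough sketch of the pattern; the side stream and its use are assumptions, not part of this diff:

import torch

# Producer side (schematically what _prefill does): record an event after the
# index writes have been enqueued on the current stream.
ready_event = torch.cuda.Event()
# ... init_req_to_token_indexes(...) enqueues the mem_index writes here ...
ready_event.record()

# Assumed consumer side: a hypothetical side stream (e.g. for KV cache offload)
# waits on the event, so its work only starts after the index writes have landed.
side_stream = torch.cuda.Stream()
side_stream.wait_event(ready_event)
with torch.cuda.stream(side_stream):
    pass  # safe to read req_to_token_indexs for these requests here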
@@ -482,28 +496,31 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids

         infer_state0 = self._create_inferstate(model_input0, 0)
-        infer_state0.init_some_extra_state(self, input_ids0)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state0.b_req_idx,
             b_seq_len=infer_state0.b_seq_len,
             b_ready_cache_len=infer_state0.b_ready_cache_len,
-            b_start_loc=infer_state0.b_start_loc,
+            b_start_loc=model_input0.b_prefill_start_loc,
             alloc_mem_index=infer_state0.mem_index,
             max_q_seq_len=infer_state0.max_q_seq_len,
         )
+        infer_state0.init_some_extra_state(self, input_ids0)

         infer_state1 = self._create_inferstate(model_input1, 1)
-        infer_state1.init_some_extra_state(self, input_ids1)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state1.b_req_idx,
             b_seq_len=infer_state1.b_seq_len,
             b_ready_cache_len=infer_state1.b_ready_cache_len,
-            b_start_loc=infer_state1.b_start_loc,
+            b_start_loc=model_input1.b_prefill_start_loc,
             alloc_mem_index=infer_state1.mem_index,
             max_q_seq_len=infer_state1.max_q_seq_len,
         )
+        infer_state1.init_some_extra_state(self, input_ids1)
+
+        prefill_mem_indexes_ready_event = torch.cuda.Event()
+        prefill_mem_indexes_ready_event.record()

         model_output0, model_output1 = self._overlap_tpsp_context_forward(
             input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
@@ -512,6 +529,8 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         # When deepep is enabled, clear_deepep_buffer must be called to release resources;
         # when it is not enabled, the call has no practical effect.
         dist_group_manager.clear_deepep_buffer()
+        model_output0.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
+        model_output1.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
         return model_output0, model_output1

     @torch.no_grad()
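
Both micro-batch outputs carry the same event object, so a consumer only has to wait once for the index writes of both prefill micro-batches. A hedged usage sketch; the consumer code below is an assumption, not part of this commit:

# model_output0, model_output1 = model.microbatch_overlap_prefill(model_input0, model_input1)
event = model_output0.prefill_mem_indexes_ready_event  # same object as on model_output1
if event is not None:
    event.synchronize()  # host-side wait until the mem_index writes are complete
# req_to_token entries for both micro-batches are now safe to read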
@@ -704,6 +723,7 @@ def _check_max_len_infer(self):
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = self.batch_max_tokens
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")
         total_token_num = self.batch_max_tokens
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -721,6 +741,7 @@ def _check_max_len_infer(self):
             b_mtp_index=b_mtp_index,
             is_prefill=True,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
         )
         model_output = self.forward(
             model_input,
@@ -778,6 +799,7 @@ def _autotune_warmup(self):
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = input_len
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")
         total_token_num = input_len
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -795,6 +817,7 @@ def _autotune_warmup(self):
             b_mtp_index=b_mtp_index,
             is_prefill=True,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
             multimodal_params=[],
             **self._gen_special_model_input(total_token_num),
         )
@@ -838,6 +861,8 @@ def _init_padded_req(self):
         )
         b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
         b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        b_q_seq_len = b_seq_len - b_ready_cache_len
+        b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
         total_token_num = prefill_input_len * batch_size
         b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -854,6 +879,7 @@ def _init_padded_req(self):
             b_mtp_index=b_mtp_index,
             b_seq_len=b_seq_len,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
             is_prefill=True,
             multimodal_params=[],
             **self._gen_special_model_input(total_token_num),
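
In _init_padded_req above, b_prefill_start_loc is the exclusive prefix sum of the per-request query lengths: subtracting b_q_seq_len from its own cumulative sum shifts the running total right by one entry, giving each request's starting offset in the packed prefill token stream. A small worked example with illustrative values:

import torch

b_seq_len = torch.tensor([5, 3, 4], dtype=torch.int32)
b_ready_cache_len = torch.tensor([2, 0, 1], dtype=torch.int32)

b_q_seq_len = b_seq_len - b_ready_cache_len                      # tensor([3, 3, 3])
b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
print(b_prefill_start_loc)                                       # tensor([0, 3, 6], dtype=torch.int32)

The other call sites touched by this commit (_check_max_len_infer and _autotune_warmup) build single-request batches, so a zero tensor is the same exclusive prefix sum there.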

lightllm/common/basemodel/batch_objs.py

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,7 @@ class ModelInput:
     mem_indexes: torch.Tensor = None
     is_prefill: bool = False
     b_ready_cache_len: torch.Tensor = None
+    b_prefill_start_loc: torch.Tensor = None
     multimodal_params: list = field(default_factory=list)

     # CPU-side variables
@@ -49,12 +50,16 @@ def to_cuda(self):
         self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
         if self.b_ready_cache_len is not None:
             self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
+        if self.b_prefill_start_loc is not None:
+            self.b_prefill_start_loc = self.b_prefill_start_loc.cuda(non_blocking=True)


 @dataclass
 class ModelOutput:
     # Common variables
     logits: torch.Tensor
+    # Event object used to tell whether the mem_indexes have been successfully written into the req manager.
+    prefill_mem_indexes_ready_event: torch.Event = None

     # Special-purpose variables, used by some special models in special modes to pass
     # extra output values. They are only populated and take effect in those model modes.
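
The new field rides along in to_cuda() with the same None-guarded, non_blocking copy as b_ready_cache_len. A standalone sketch of that pattern, assuming the host-side tensor is allocated in pinned memory (otherwise non_blocking=True generally falls back to a synchronous transfer):

import torch

# Hypothetical host-side tensor; pinning it lets the H2D copy overlap with compute.
b_prefill_start_loc = torch.zeros(64, dtype=torch.int32, pin_memory=True)

if b_prefill_start_loc is not None:
    b_prefill_start_loc = b_prefill_start_loc.cuda(non_blocking=True)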

lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py

Lines changed: 2 additions & 2 deletions
@@ -138,10 +138,10 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
             topk_ids=topk_ids,
             inplace=True,
             use_fp8_w8a8=use_fp8_w8a8,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
             w1_bias=self.w1_bias,
             w2_bias=self.w2_bias / self.tp_world_size_,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
             layout="interleaved",
             alpha=self.alpha,
             limit=self.limit,
