Commit 6ff2908

fix gpu out token counter update error (#988)
1 parent 826b08d commit 6ff2908

File tree: 5 files changed, 111 additions, 26 deletions

lightllm/common/basemodel/triton_kernel/gen_sampling_params.py

Lines changed: 27 additions & 10 deletions
@@ -121,37 +121,54 @@ def _token_id_counter_update_kernel(
     counter_stride_m,
     counter_stride_n,
     next_token_ids_ptr,
+    mask_ptr,
     batch_size,
+    HAS_MASK: tl.constexpr,
     BLOCK: tl.constexpr,
 ):
 
     block_start_index = tl.program_id(0) * BLOCK
     offs = block_start_index + tl.arange(0, BLOCK)
-    mask = offs < batch_size
-
-    req_idx = tl.load(b_req_idx_ptr + offs, mask=mask, other=0)
-    token_ids = tl.load(next_token_ids_ptr + offs, mask=mask, other=0)
-
-    tl.atomic_add(
-        req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n, 1, mask=mask
-    )
+    loc_mask = offs < batch_size
+
+    req_idx = tl.load(b_req_idx_ptr + offs, mask=loc_mask, other=0)
+    token_ids = tl.load(next_token_ids_ptr + offs, mask=loc_mask, other=0)
+
+    if HAS_MASK:
+        mask = tl.load(mask_ptr + offs, mask=loc_mask, other=False)
+        tl.atomic_add(
+            req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n,
+            1,
+            mask=loc_mask & mask,
+        )
+    else:
+        tl.atomic_add(
+            req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n,
+            1,
+            mask=loc_mask,
+        )
     return
 
 
 @torch.no_grad()
 def update_req_to_token_id_counter(
-    b_req_idx: torch.Tensor, next_token_ids: torch.Tensor, req_to_out_token_id_counter: torch.Tensor
+    b_req_idx: torch.Tensor,
+    next_token_ids: torch.Tensor,
+    req_to_out_token_id_counter: torch.Tensor,
+    mask: torch.Tensor = None,
 ):
     batch_size = b_req_idx.shape[0]
     BLOCK = 256
-
+    has_mask = mask is not None
     _token_id_counter_update_kernel[(triton.cdiv(batch_size, BLOCK),)](
         b_req_idx_ptr=b_req_idx,
         req_to_out_token_id_counter_ptr=req_to_out_token_id_counter,
         counter_stride_m=req_to_out_token_id_counter.stride(0),
         counter_stride_n=req_to_out_token_id_counter.stride(1),
         next_token_ids_ptr=next_token_ids,
+        mask_ptr=mask,
         batch_size=batch_size,
+        HAS_MASK=has_mask,
         BLOCK=BLOCK,
         num_warps=1,
     )
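
The effect of the new HAS_MASK branch can be summarized with a small pure-PyTorch reference (an illustrative sketch, not part of the commit): for each batch entry selected by the mask, and for every entry when no mask is given, the per-request counter cell for the sampled token id is incremented by one.

import torch

def counter_update_reference(b_req_idx, next_token_ids, counter, mask=None):
    # Illustrative reference for the masked kernel above: increment
    # counter[b_req_idx[i], next_token_ids[i]] by 1 for every selected entry i.
    # Assumes `counter` is a contiguous 2-D (num_reqs, vocab_size) tensor.
    if mask is None:
        mask = torch.ones_like(b_req_idx, dtype=torch.bool)
    keep = mask.nonzero(as_tuple=True)[0]
    flat = b_req_idx[keep].long() * counter.shape[1] + next_token_ids[keep].long()
    counter.view(-1).scatter_add_(0, flat, torch.ones_like(flat, dtype=counter.dtype))
    return counter

# Example: only the first of two batch entries actually emitted a token.
counter = torch.zeros((4, 16), dtype=torch.int32)
counter_update_reference(
    torch.tensor([2, 3]), torch.tensor([5, 7]), counter, mask=torch.tensor([True, False])
)
assert counter[2, 5] == 1 and counter[3, 7] == 0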

lightllm/common/req_manager.py

Lines changed: 22 additions & 16 deletions
@@ -162,29 +162,35 @@ def init_req_sampling_params(self, req):
 
         return
 
+    def update_reqs_out_token_counter_gpu(
+        self, b_req_idx: torch.Tensor, next_token_ids: torch.Tensor, mask: torch.Tensor = None
+    ):
+        if self.penalty_counter_mode not in ["gpu_counter", "pin_mem_counter"]:
+            return
+
+        assert b_req_idx.is_cuda and next_token_ids.is_cuda and b_req_idx.shape[0] == next_token_ids.shape[0]
+
+        update_req_to_token_id_counter(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            req_to_out_token_id_counter=self.req_to_out_token_id_counter,
+            mask=mask,
+        )
+        return
+
     def update_reqs_token_counter(
         self, req_objs: List, next_token_ids: List[int], accept_mark: Optional[List[List[bool]]] = None
    ):
        from lightllm.server.router.model_infer.infer_batch import InferReq
 
        req_objs: List[InferReq] = req_objs
 
-        if self.penalty_counter_mode == "cpu_counter":
-            for req_obj, next_token_id in zip(req_objs, next_token_ids):
-                if req_obj.need_out_token_id_statistics and req_obj.cur_output_len > 0:
-                    req_obj.out_token_id_count[next_token_id] += 1
-        else:
-            b_req_idx = torch.tensor(
-                [req.req_idx for req in req_objs], dtype=torch.int32, device="cpu", pin_memory=True
-            ).cuda(non_blocking=True)
-            next_token_ids = (
-                torch.tensor(next_token_ids, dtype=torch.int32, device="cpu").pin_memory().cuda(non_blocking=True)
-            )
-            update_req_to_token_id_counter(
-                b_req_idx=b_req_idx,
-                next_token_ids=next_token_ids,
-                req_to_out_token_id_counter=self.req_to_out_token_id_counter,
-            )
+        if self.penalty_counter_mode != "cpu_counter":
+            return
+
+        for req_obj, next_token_id in zip(req_objs, next_token_ids):
+            if req_obj.need_out_token_id_statistics and req_obj.cur_output_len > 0:
+                req_obj.out_token_id_count[next_token_id] += 1
         return
 
     def gen_cpu_out_token_counter_sampling_params(self, req_objs: List):
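
With this split, the manager exposes one update path per counter mode. A rough usage sketch follows; the call sites and variable names below are hypothetical, only the two method signatures come from the diff.

# cpu_counter mode: per-request Python counters updated from host-side token ids.
manager.update_reqs_token_counter(req_objs, next_token_ids_cpu)

# gpu_counter / pin_mem_counter mode: CUDA tensors go straight to the Triton
# kernel; `b_has_out` masks out entries that emitted no output token this step.
manager.update_reqs_out_token_counter_gpu(
    b_req_idx=b_req_idx,
    next_token_ids=next_token_ids,
    mask=b_has_out,
)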

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ def register(
         self, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int
     ):
         self.req_manager = req_manager
+        self.req_sampling_manager = self.req_manager.req_sampling_params_manager
         self.radix_cache = radix_cache
         self.shm_req_manager = shm_req_manager
 

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 21 additions & 0 deletions
@@ -115,6 +115,11 @@ def prefill_normal(
             b_mtp_index=model_input.b_mtp_index,
             b_has_out=b_has_out,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=model_input.b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=b_has_out,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -158,6 +163,10 @@ def decode_normal(
             model_input.b_req_idx,
             model_input.b_mtp_index,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=model_input.b_req_idx,
+            next_token_ids=next_token_ids,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -205,6 +214,11 @@ def prefill_mtp(
             b_mtp_index=model_input.b_mtp_index,
             b_has_out=b_has_out,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=model_input.b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=b_has_out,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -305,6 +319,13 @@ def decode_mtp(
             b_req_idx=model_input.b_req_idx,
             mtp_accept_len=mtp_accept_len,
         )
+
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=model_input.b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=accepted_index == 1,
+        )
+
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
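
In the MTP decode paths the mask is derived from which speculative tokens were actually accepted, so rejected draft tokens never reach the counter. A toy illustration of that filtering (values are made up; only the `accepted_index == 1` expression mirrors the diff):

import torch

# Per sampled position: 1 if the speculative token was accepted, 0 if rejected.
accepted_index = torch.tensor([1, 0, 1, 1])
next_token_ids = torch.tensor([11, 12, 13, 14])

mask = accepted_index == 1  # bool mask handed to update_reqs_out_token_counter_gpu
print(next_token_ids[mask])  # tensor([11, 13, 14]) -> only these are counted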

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 40 additions & 0 deletions
@@ -134,6 +134,12 @@ def prefill_normal(
             b_mtp_index=b_mtp_index,
             b_has_out=b_has_out,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=b_has_out,
+        )
+
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -182,6 +188,10 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq
             b_req_idx=b_req_idx,
             b_mtp_index=b_mtp_index,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -254,6 +264,11 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer
             b_mtp_index=b_mtp_index,
             b_has_out=b_has_out,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=b_has_out,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -318,6 +333,10 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe
             b_req_idx=b_req_idx,
             b_mtp_index=b_mtp_index,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -374,6 +393,11 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
             b_mtp_index=b_mtp_index,
             b_has_out=b_has_out,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=b_has_out,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -493,6 +517,11 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
             b_req_idx=b_req_idx,
             mtp_accept_len=mtp_accept_len,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=accepted_index == 1,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -571,6 +600,11 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
             b_mtp_index=b_mtp_index,
             b_has_out=b_has_out,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=b_has_out,
+        )
 
         # spec prefill: MTP
         draft_micro_input0, draft_micro_input1 = micro_input0, micro_input1
@@ -733,6 +767,12 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf
             b_req_idx=b_req_idx,
             mtp_accept_len=mtp_accept_len,
         )
+
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=accepted_index == 1,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
