Commit b429ddc

multimodal chunked prefill: select only the needed multimodal objects. (#954)

Co-authored-by: baishihao <baishihao@sensetime.com>
Co-authored-by: shihaobai <1798930569@qq.com>
1 parent 998e083 commit b429ddc

File tree

5 files changed: +167 −55 lines changed

lightllm/common/basemodel/infer_struct.py

Lines changed: 22 additions & 0 deletions
@@ -5,6 +5,7 @@
 from typing import Tuple, Any, Optional
 from .triton_kernel.gen_prefill_params import gen_prefill_params
 from .triton_kernel.gen_decode_params import gen_decode_params
+from .triton_kernel.multimodal_emb import mark_multimodal_obj


 class InferStateInfo:
@@ -98,3 +99,24 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
         if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr():
             attr_.copy_(attr_value, non_blocking=True)
         return
+
+    def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
+        """
+        Marks which multimodal objects' tokens need to take part in the computation
+        during chunked prefill; because the input is split into chunks, not all of them do.
+        """
+        multi_objs = []
+        for _, p in enumerate(self.multimodal_params):
+            for obj in p["images"] + p["audios"]:
+                multi_objs.append(obj)
+
+        if multi_objs:
+            obj_start_ids = torch.tensor([e["token_id"] for e in multi_objs], dtype=torch.int64, device="cuda")
+            obj_token_lens = torch.tensor([e["token_num"] for e in multi_objs], dtype=torch.int64, device="cuda")
+            marks = mark_multimodal_obj(
+                obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids
+            )
+            marks_array = marks.detach().cpu().numpy()
+            for mark, obj in zip(marks_array, multi_objs):
+                obj["_prefill_"] = mark > 0
+        return
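For intuition, here is a minimal pure-PyTorch sketch (not part of the commit) of the semantics this method relies on: a multimodal object is kept for the current prefill chunk only if at least one of its placeholder token ids appears in the chunk's input_ids. The toy values below are hypothetical.

import torch

# Toy chunk: the current prefill chunk covers token ids 100..103.
# One image occupies placeholder ids [100, 102), another [500, 503).
input_ids = torch.tensor([100, 101, 102, 103])
obj_start_ids = torch.tensor([100, 500])
obj_token_lens = torch.tensor([2, 3])

# An object is "needed" if any of its placeholder ids fall inside input_ids.
in_range = (input_ids[None, :] >= obj_start_ids[:, None]) & (
    input_ids[None, :] < (obj_start_ids + obj_token_lens)[:, None]
)
marks = in_range.any(dim=1)  # tensor([True, False]): only the first image is loaded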

lightllm/common/basemodel/triton_kernel/multimodal_emb.py

Lines changed: 90 additions & 53 deletions
@@ -5,48 +5,78 @@

 @triton.jit
 def _fwd_kernel(
-    Prompt_ids, 
+    Prompt_ids,
     Text_weight_embs,
     Img_embs,
     Out,
     Img_token_lens,
     Img_start_token_ids,
     Img_start_locs,
-    stride_text_emb_s, stride_text_emb_d,  # text_stride
-    stride_img_emb_s, stride_img_emb_d,  # img_stride
-    stride_out_s, stride_out_d,
+    stride_text_emb_s,
+    stride_text_emb_d,  # text_stride
+    stride_img_emb_s,
+    stride_img_emb_d,  # img_stride
+    stride_out_s,
+    stride_out_d,
     tp_text_start_token_id,
     tp_text_end_token_id,
     hidden_size,
-    BLOCK_HIDDEN_DIM: tl.constexpr
-):
+    BLOCK_HIDDEN_DIM: tl.constexpr,
+):

     seq_index = tl.program_id(0).to(tl.int64)
     img_handle_id = tl.program_id(1)

     token_id = tl.load(Prompt_ids + seq_index)
     off_d = tl.arange(0, BLOCK_HIDDEN_DIM)

     # load store text emb
-    for _ in range(0, tl.where((img_handle_id == 0) & (token_id < tp_text_end_token_id) & (token_id >= tp_text_start_token_id), 1, 0), 1):
-        load_emb = tl.load(Text_weight_embs + stride_text_emb_s * (token_id - tp_text_start_token_id) + off_d * stride_text_emb_d, mask=off_d < hidden_size, other=0)
+    for _ in range(
+        0,
+        tl.where((img_handle_id == 0) & (token_id < tp_text_end_token_id) & (token_id >= tp_text_start_token_id), 1, 0),
+        1,
+    ):
+        load_emb = tl.load(
+            Text_weight_embs + stride_text_emb_s * (token_id - tp_text_start_token_id) + off_d * stride_text_emb_d,
+            mask=off_d < hidden_size,
+            other=0,
+        )
         tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)

     img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     img_start_loc = tl.load(Img_start_locs + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     img_token_len = tl.load(Img_token_lens + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     # load store img emb
-    for _ in range(0, tl.where((img_handle_id != 0) & (token_id >= img_start_token_id) & (token_id < img_start_token_id + img_token_len), 1, 0), 1):
-        load_emb = tl.load(Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d, mask=off_d < hidden_size, other=0)
+    for _ in range(
+        0,
+        tl.where(
+            (img_handle_id != 0) & (token_id >= img_start_token_id) & (token_id < img_start_token_id + img_token_len),
+            1,
+            0,
+        ),
+        1,
+    ):
+        load_emb = tl.load(
+            Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d,
+            mask=off_d < hidden_size,
+            other=0,
+        )
         tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
     return


 @torch.no_grad()
-def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs: torch.Tensor, img_embs: torch.Tensor,
-                   img_token_lens: torch.Tensor, img_start_token_ids: torch.Tensor, img_start_locs: torch.Tensor,
-                   tp_text_start_token_id,
-                   tp_text_end_token_id):
+def multimodal_emb(
+    out: torch.Tensor,
+    prompt_ids: torch.Tensor,
+    text_weight_embs: torch.Tensor,
+    img_embs: torch.Tensor,
+    img_token_lens: torch.Tensor,
+    img_start_token_ids: torch.Tensor,
+    img_start_locs: torch.Tensor,
+    tp_text_start_token_id,
+    tp_text_end_token_id,
+):
     total_len = prompt_ids.shape[0]
     BLOCK = triton.next_power_of_2(out.shape[1])
     # print(len(img_token_lens))
@@ -60,9 +90,12 @@ def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs
         img_token_lens,
         img_start_token_ids,
         img_start_locs,
-        text_weight_embs.stride(0), text_weight_embs.stride(1),
-        img_embs.stride(0), img_embs.stride(1),
-        out.stride(0), out.stride(1),
+        text_weight_embs.stride(0),
+        text_weight_embs.stride(1),
+        img_embs.stride(0),
+        img_embs.stride(1),
+        out.stride(0),
+        out.stride(1),
         tp_text_start_token_id,
         tp_text_end_token_id,
         hidden_size=out.shape[1],
@@ -73,40 +106,44 @@ def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs
     return


-
-def test():
-    S, D = 1024 * 1000, 128 * 64
-    vob_size = 320000
-    image_size = 10
-    image_token_size = 512
-
-    text_weight = torch.randn((vob_size, D), device='cuda', dtype=torch.float16)
-    img_weight = torch.randn((image_size * image_token_size, D), device='cuda', dtype=torch.float16)
-    img_token_lens = torch.full((image_size,), image_token_size, device='cuda', dtype=torch.long)
-    img_start_token_ids = (torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long()
-    img_start_locs = torch.arange(0, image_size * image_token_size, image_token_size).cuda().long()
-
-    prompt_ids = torch.arange(0, S, 1).cuda().long()
-    prompt_ids[0: image_size * image_token_size] = (vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long()
-
-    out = torch.zeros((S, D), dtype=torch.float16, device="cuda")
-    print(out.shape)
-
-    import time
-
-    triton_output = multimodal_emb(out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size)
-
-    torch.cuda.synchronize()
-    iters = 20
-    t1 = time.time()
-    for _ in range(iters):
-        triton_output = multimodal_emb(out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size)
-    torch.cuda.synchronize()
-    t2 = time.time()
-    print("Triton time cost", (t2 - t1) / iters)
+@triton.jit
+def _mark_multimodal_obj_need_kernel(
+    obj_start_token_ids_ptr,
+    obj_token_lens_ptr,
+    obj_marks_ptr,
+    input_ids_ptr,
+    input_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+
+    obj_index = tl.program_id(0)
+    start_id = tl.load(obj_start_token_ids_ptr + obj_index)
+    token_len = tl.load(obj_token_lens_ptr + obj_index)
+
+    for block_start in range(0, input_size, BLOCK_SIZE):
+        block_range = block_start + tl.arange(0, BLOCK_SIZE)
+        cur_input_ids = tl.load(input_ids_ptr + block_range, mask=block_range < input_size, other=0)
+        mark = tl.where((cur_input_ids >= start_id) & (cur_input_ids < start_id + token_len), 1, 0)
+        mark = tl.sum(mark)
+        tl.store(obj_marks_ptr + obj_index, 1, mask=mark > 0)
     return


-# if __name__ == "__main__":
-#     test()
-
+@torch.no_grad()
+def mark_multimodal_obj(obj_start_token_ids: torch.Tensor, obj_token_lens: torch.Tensor, input_ids: torch.Tensor):
+    out_mark = torch.empty_like(obj_start_token_ids)
+    out_mark.fill_(0)
+    assert obj_start_token_ids.shape == obj_token_lens.shape
+    BLOCK = 512
+    grid = (obj_start_token_ids.shape[0],)
+    _mark_multimodal_obj_need_kernel[grid](
+        obj_start_token_ids_ptr=obj_start_token_ids,
+        obj_token_lens_ptr=obj_token_lens,
+        obj_marks_ptr=out_mark,
+        input_ids_ptr=input_ids,
+        input_size=input_ids.shape[0],
+        BLOCK_SIZE=BLOCK,
+        num_warps=1,
+        num_stages=1,
+    )
+    return out_mark
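As a reading aid, here is a hedged CPU reference (not in the commit) for what _mark_multimodal_obj_need_kernel computes: the launch grid assigns one program per multimodal object, and each program scans input_ids in BLOCK_SIZE chunks, writing 1 to the object's mark slot if any id falls in the half-open range [start_id, start_id + token_len).

import torch

def mark_multimodal_obj_reference(
    obj_start_token_ids: torch.Tensor,
    obj_token_lens: torch.Tensor,
    input_ids: torch.Tensor,
) -> torch.Tensor:
    """CPU sketch of the Triton kernel above; illustrative only."""
    out_mark = torch.zeros_like(obj_start_token_ids)
    # The loop index plays the role of tl.program_id(0): one program per object.
    for obj_index in range(obj_start_token_ids.shape[0]):
        start_id = obj_start_token_ids[obj_index]
        token_len = obj_token_lens[obj_index]
        # Equivalent of the blockwise scan, tl.sum, and masked tl.store.
        if ((input_ids >= start_id) & (input_ids < start_id + token_len)).any():
            out_mark[obj_index] = 1
    return out_mark

The wrapper launches with num_warps=1 and num_stages=1, presumably because each program only performs a small blockwise reduction over input_ids, so a heavier launch configuration would buy nothing.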

lightllm/models/gemma3/layer_infer/pre_layer_infer.py

Lines changed: 3 additions & 1 deletion
@@ -35,10 +35,12 @@ def context_forward(self, input_ids, infer_state, layer_weight):
         else:
             weight_mask[idx] = scale

+    infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)
+
     for batch_id, p in enumerate(infer_state.multimodal_params):
         for img in p["images"]:
             # skip the same image
-            if img["token_id"] in img_start_token_ids:
+            if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
                 continue
             # pull the img_embeds by uid from shm
             data = read_shm(get_shm_name_embed(img["uuid"]))

lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py

Lines changed: 4 additions & 1 deletion
@@ -42,10 +42,13 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight):
         device = layer_weight.wte_weight_.device
         dtype = layer_weight.wte_weight_.dtype
         hidden_size = layer_weight.wte_weight_.shape[1]
+
+    infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)
+
     for batch_id, p in enumerate(infer_state.multimodal_params):
         for img in p["images"] + p["audios"]:
             # skip the same image
-            if img["token_id"] in img_start_token_ids:
+            if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
                 continue
             # pull the img_embeds by uid from shm
             data = read_shm(get_shm_name_embed(img["uuid"]))
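Both pre-layer hooks consume the "_prefill_" flag the same way: objects whose tokens are absent from the current chunk were marked False and are skipped, so their embeddings are never pulled from shared memory for this chunk. A runnable toy sketch of that gating (hypothetical data, not lightllm code):

# Hypothetical consumption of the "_prefill_" flag.
multimodal_params = [
    {"images": [
        {"token_id": 100, "uuid": "img-a", "_prefill_": True},
        {"token_id": 500, "uuid": "img-b", "_prefill_": False},
    ]}
]
img_start_token_ids = set()

loaded = []
for p in multimodal_params:
    for img in p["images"]:
        # Skip duplicates and objects not needed for this chunk.
        if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
            continue
        img_start_token_ids.add(img["token_id"])
        loaded.append(img["uuid"])

print(loaded)  # ['img-a']: only the object whose tokens are in this chunk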
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+import torch
+import pytest
+from lightllm.common.basemodel.triton_kernel.multimodal_emb import mark_multimodal_obj, multimodal_emb
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+def test_mark_multimodal_obj():
+    obj_start_ids = torch.tensor([1, 4, 100], device="cuda", dtype=torch.int64)
+    obj_token_lens = torch.tensor([1, 3, 2], device="cuda", dtype=torch.int64)
+    input_ids = torch.tensor([1, 7, 9, 333], device="cuda", dtype=torch.int64)
+
+    mark_obj = mark_multimodal_obj(
+        obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids
+    )
+
+    assert torch.equal(mark_obj, torch.tensor([1, 0, 0], device="cuda"))
+
+
+def test_multimodal_emb():
+    S, D = 1024 * 1000, 128 * 64
+    vob_size = 320000
+    image_size = 10
+    image_token_size = 512
+
+    text_weight = torch.randn((vob_size, D), device="cuda", dtype=torch.float16)
+    img_weight = torch.randn((image_size * image_token_size, D), device="cuda", dtype=torch.float16)
+    img_token_lens = torch.full((image_size,), image_token_size, device="cuda", dtype=torch.long)
+    img_start_token_ids = (
+        (torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long()
+    )
+    img_start_locs = torch.arange(0, image_size * image_token_size, image_token_size).cuda().long()
+
+    prompt_ids = torch.arange(0, S, 1).cuda().long()
+    prompt_ids[0 : image_size * image_token_size] = (
+        (vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long()
+    )
+
+    out = torch.zeros((S, D), dtype=torch.float16, device="cuda")
+    multimodal_emb(
+        out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size
+    )
+    return
+
+
+if __name__ == "__main__":
+    pytest.main()
