@@ -5,48 +5,78 @@
 
 @triton.jit
 def _fwd_kernel(
-    Prompt_ids,
+    Prompt_ids,
     Text_weight_embs,
     Img_embs,
     Out,
     Img_token_lens,
     Img_start_token_ids,
     Img_start_locs,
-    stride_text_emb_s, stride_text_emb_d,  # text_stride
-    stride_img_emb_s, stride_img_emb_d,  # img_stride
-    stride_out_s, stride_out_d,
+    stride_text_emb_s,
+    stride_text_emb_d,  # text_stride
+    stride_img_emb_s,
+    stride_img_emb_d,  # img_stride
+    stride_out_s,
+    stride_out_d,
     tp_text_start_token_id,
     tp_text_end_token_id,
     hidden_size,
-    BLOCK_HIDDEN_DIM: tl.constexpr
-):
+    BLOCK_HIDDEN_DIM: tl.constexpr,
+):
 
     seq_index = tl.program_id(0).to(tl.int64)
     img_handle_id = tl.program_id(1)
 
     token_id = tl.load(Prompt_ids + seq_index)
     off_d = tl.arange(0, BLOCK_HIDDEN_DIM)
-
+
     # load store text emb
-    for _ in range(0, tl.where((img_handle_id == 0) & (token_id < tp_text_end_token_id) & (token_id >= tp_text_start_token_id), 1, 0), 1):
-        load_emb = tl.load(Text_weight_embs + stride_text_emb_s * (token_id - tp_text_start_token_id) + off_d * stride_text_emb_d, mask=off_d < hidden_size, other=0)
+    for _ in range(
+        0,
+        tl.where((img_handle_id == 0) & (token_id < tp_text_end_token_id) & (token_id >= tp_text_start_token_id), 1, 0),
+        1,
+    ):
+        load_emb = tl.load(
+            Text_weight_embs + stride_text_emb_s * (token_id - tp_text_start_token_id) + off_d * stride_text_emb_d,
+            mask=off_d < hidden_size,
+            other=0,
+        )
         tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
-
+
     img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     img_start_loc = tl.load(Img_start_locs + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     img_token_len = tl.load(Img_token_lens + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     # load store img emb
-    for _ in range(0, tl.where((img_handle_id != 0) & (token_id >= img_start_token_id) & (token_id < img_start_token_id + img_token_len), 1, 0), 1):
-        load_emb = tl.load(Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d, mask=off_d < hidden_size, other=0)
+    for _ in range(
+        0,
+        tl.where(
+            (img_handle_id != 0) & (token_id >= img_start_token_id) & (token_id < img_start_token_id + img_token_len),
+            1,
+            0,
+        ),
+        1,
+    ):
+        load_emb = tl.load(
+            Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d,
+            mask=off_d < hidden_size,
+            other=0,
+        )
         tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
     return
 
 
 @torch.no_grad()
-def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs: torch.Tensor, img_embs: torch.Tensor,
-                   img_token_lens: torch.Tensor, img_start_token_ids: torch.Tensor, img_start_locs: torch.Tensor,
-                   tp_text_start_token_id,
-                   tp_text_end_token_id):
+def multimodal_emb(
+    out: torch.Tensor,
+    prompt_ids: torch.Tensor,
+    text_weight_embs: torch.Tensor,
+    img_embs: torch.Tensor,
+    img_token_lens: torch.Tensor,
+    img_start_token_ids: torch.Tensor,
+    img_start_locs: torch.Tensor,
+    tp_text_start_token_id,
+    tp_text_end_token_id,
+):
     total_len = prompt_ids.shape[0]
     BLOCK = triton.next_power_of_2(out.shape[1])
     # print(len(img_token_lens))
@@ -60,9 +90,12 @@ def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs
         img_token_lens,
         img_start_token_ids,
         img_start_locs,
-        text_weight_embs.stride(0), text_weight_embs.stride(1),
-        img_embs.stride(0), img_embs.stride(1),
-        out.stride(0), out.stride(1),
+        text_weight_embs.stride(0),
+        text_weight_embs.stride(1),
+        img_embs.stride(0),
+        img_embs.stride(1),
+        out.stride(0),
+        out.stride(1),
         tp_text_start_token_id,
         tp_text_end_token_id,
         hidden_size=out.shape[1],
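A note on the control flow in `_fwd_kernel` above: `for _ in range(0, tl.where(cond, 1, 0), 1)` is a loop whose trip count is 1 when `cond` holds and 0 otherwise, so the body acts as a predicated branch on the `token_id` loaded from memory. Each program instance is identified by `(seq_index, img_handle_id)`: program `(s, 0)` handles the text case for prompt position `s`, and program `(s, k)` for `k >= 1` handles the `k`-th image's token range. A minimal host-side sketch of what one program does; `_fwd_kernel_reference` is a hypothetical name for intuition only, not part of the commit:

def _fwd_kernel_reference(seq_index, img_handle_id, prompt_ids, text_weight_embs,
                          img_embs, out, img_token_lens, img_start_token_ids,
                          img_start_locs, tp_text_start_token_id, tp_text_end_token_id):
    token_id = int(prompt_ids[seq_index])
    if img_handle_id == 0:
        # text program: copy the text embedding row when token_id is a text token
        if tp_text_start_token_id <= token_id < tp_text_end_token_id:
            out[seq_index] = text_weight_embs[token_id - tp_text_start_token_id]
    else:
        # image program: copy a row from the (img_handle_id - 1)-th image's slab
        start_id = int(img_start_token_ids[img_handle_id - 1])
        start_loc = int(img_start_locs[img_handle_id - 1])
        token_len = int(img_token_lens[img_handle_id - 1])
        if start_id <= token_id < start_id + token_len:
            out[seq_index] = img_embs[start_loc + token_id - start_id]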
@@ -73,40 +106,44 @@ def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs
     return
 
 
-
-def test():
-    S, D = 1024 * 1000, 128 * 64
-    vob_size = 320000
-    image_size = 10
-    image_token_size = 512
-
-    text_weight = torch.randn((vob_size, D), device='cuda', dtype=torch.float16)
-    img_weight = torch.randn((image_size * image_token_size, D), device='cuda', dtype=torch.float16)
-    img_token_lens = torch.full((image_size,), image_token_size, device='cuda', dtype=torch.long)
-    img_start_token_ids = (torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long()
-    img_start_locs = torch.arange(0, image_size * image_token_size, image_token_size).cuda().long()
-
-    prompt_ids = torch.arange(0, S, 1).cuda().long()
-    prompt_ids[0: image_size * image_token_size] = (vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long()
-
-    out = torch.zeros((S, D), dtype=torch.float16, device="cuda")
-    print(out.shape)
-
-    import time
-
-    triton_output = multimodal_emb(out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size)
-
-    torch.cuda.synchronize()
-    iters = 20
-    t1 = time.time()
-    for _ in range(iters):
-        triton_output = multimodal_emb(out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size)
-    torch.cuda.synchronize()
-    t2 = time.time()
-    print("Triton time cost", (t2 - t1) / iters)
+@triton.jit
+def _mark_multimodal_obj_need_kernel(
+    obj_start_token_ids_ptr,
+    obj_token_lens_ptr,
+    obj_marks_ptr,
+    input_ids_ptr,
+    input_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+
+    obj_index = tl.program_id(0)
+    start_id = tl.load(obj_start_token_ids_ptr + obj_index)
+    token_len = tl.load(obj_token_lens_ptr + obj_index)
+
+    for block_start in range(0, input_size, BLOCK_SIZE):
+        block_range = block_start + tl.arange(0, BLOCK_SIZE)
+        cur_input_ids = tl.load(input_ids_ptr + block_range, mask=block_range < input_size, other=0)
+        mark = tl.where((cur_input_ids >= start_id) & (cur_input_ids < start_id + token_len), 1, 0)
+        mark = tl.sum(mark)
+        tl.store(obj_marks_ptr + obj_index, 1, mask=mark > 0)
     return
 
 
-# if __name__ == "__main__":
-#     test()
-
+@torch.no_grad()
+def mark_multimodal_obj(obj_start_token_ids: torch.Tensor, obj_token_lens: torch.Tensor, input_ids: torch.Tensor):
+    out_mark = torch.empty_like(obj_start_token_ids)
+    out_mark.fill_(0)
+    assert obj_start_token_ids.shape == obj_token_lens.shape
+    BLOCK = 512
+    grid = (obj_start_token_ids.shape[0],)
+    _mark_multimodal_obj_need_kernel[grid](
+        obj_start_token_ids_ptr=obj_start_token_ids,
+        obj_token_lens_ptr=obj_token_lens,
+        obj_marks_ptr=out_mark,
+        input_ids_ptr=input_ids,
+        input_size=input_ids.shape[0],
+        BLOCK_SIZE=BLOCK,
+        num_warps=1,
+        num_stages=1,
+    )
+    return out_mark
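The `test()` deleted in this commit doubled as the only usage example for `multimodal_emb`, so a trimmed sketch is worth keeping here. It is reconstructed from the removed code with sizes shrunk for illustration; image tokens get ids starting at `vob_size * 10`, far past the text vocabulary, so each prompt position falls in exactly one text or image range:

import torch

S, D = 1024 * 10, 128  # shrunk from the removed benchmark's 1024 * 1000, 128 * 64
vob_size = 320000
image_size = 10
image_token_size = 512

text_weight = torch.randn((vob_size, D), device="cuda", dtype=torch.float16)
img_weight = torch.randn((image_size * image_token_size, D), device="cuda", dtype=torch.float16)
img_token_lens = torch.full((image_size,), image_token_size, device="cuda", dtype=torch.long)
img_start_token_ids = (torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long()
img_start_locs = torch.arange(0, image_size * image_token_size, image_token_size).cuda().long()

# the first image_size * image_token_size positions hold image token ids, the rest text ids
prompt_ids = torch.arange(0, S, 1).cuda().long()
prompt_ids[0 : image_size * image_token_size] = (
    vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)
).cuda().long()

out = torch.zeros((S, D), dtype=torch.float16, device="cuda")
multimodal_emb(out, prompt_ids, text_weight, img_weight, img_token_lens,
               img_start_token_ids, img_start_locs, 0, vob_size)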
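The newly added `mark_multimodal_obj` answers, for each multimodal object, whether any id in its token range occurs in `input_ids`: one program per object scans the id sequence in `BLOCK_SIZE` chunks and stores a 1 on a hit. A hypothetical call, assuming two objects whose id ranges start at 1000 and 2000:

import torch

obj_start_token_ids = torch.tensor([1000, 2000], device="cuda", dtype=torch.long)
obj_token_lens = torch.tensor([16, 16], device="cuda", dtype=torch.long)
input_ids = torch.tensor([3, 1005, 7, 9], device="cuda", dtype=torch.long)

marks = mark_multimodal_obj(obj_start_token_ids, obj_token_lens, input_ids)
# marks == tensor([1, 0]): an id in [1000, 1016) appears in input_ids; none in [2000, 2016) does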