1+ import math
2+ import torch
13import triton
24import triton .language as tl
3- import torch
45
56
67@triton .jit
@@ -16,94 +17,68 @@ def rotary_kernel(
1617 stride_cos_d ,
1718 stride_sin_l ,
1819 stride_sin_d ,
19- L ,
20- H ,
21- D ,
22- BLOCK_SEQ : tl .constexpr ,
20+ total_len ,
21+ head_num ,
22+ D : tl .constexpr ,
2323 BLOCK_HEAD : tl .constexpr ,
24+ HALF_D : tl .constexpr ,
2425 BLOCK_D : tl .constexpr ,
2526):
26- pid_head_blk = tl .program_id (0 )
27- pid_seq_blk = tl .program_id (1 )
27+ pid_h_block_index = tl .program_id (0 ).to (tl .int64 )
28+ pid_l_start = tl .program_id (1 ).to (tl .int64 )
29+ pid_blk = tl .program_id (2 ).to (tl .int64 )
2830
29- offs_h = pid_head_blk * BLOCK_HEAD + tl .arange (0 , BLOCK_HEAD )
30- offs_l = pid_seq_blk * BLOCK_SEQ + tl .arange (0 , BLOCK_SEQ )
3131 offs_d = tl .arange (0 , BLOCK_D )
32+ d = pid_blk * BLOCK_D + offs_d
33+ mask = d < D
34+ partner_d = tl .where (d < HALF_D , d + HALF_D , d - HALF_D )
3235
33- offs_h = offs_h .to (tl .int64 )
34- offs_l = offs_l .to (tl .int64 )
35- offs_d = offs_d .to (tl .int64 )
36-
37- mask_h = offs_h < H
38- mask_l = offs_l < L
39- mask_d = offs_d < D
40-
41- HALF_D = D // 2
42-
43- l_b = offs_l [:, None , None ]
44- h_b = offs_h [None , :, None ]
45- d_b = offs_d [None , None , :]
46-
47- mask = mask_l [:, None , None ] & mask_h [None , :, None ] & mask_d [None , None , :]
48-
49- base = l_b * stride_l + h_b * stride_h + d_b * stride_d
50- x = tl .load (inp_ptr + base , mask = mask , other = 0.0 )
51-
52- cos_base_2d = offs_l [:, None ] * stride_cos_l + offs_d [None , :] * stride_cos_d
53- sin_base_2d = offs_l [:, None ] * stride_sin_l + offs_d [None , :] * stride_sin_d
54- mask_ld = mask_l [:, None ] & mask_d [None , :]
55-
56- cos_2d = tl .load (cos_ptr + cos_base_2d , mask = mask_ld , other = 0.0 )
57- sin_2d = tl .load (sin_ptr + sin_base_2d , mask = mask_ld , other = 0.0 )
36+ for pid_l in tl .range (pid_l_start , total_len , step = tl .num_programs (axis = 1 )):
37+ cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
38+ sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
39+ cos = tl .load (cos_ptr_ , mask = mask )
40+ sin = tl .load (sin_ptr_ , mask = mask )
5841
59- cos = cos_2d [:, None , :]
60- sin = sin_2d [:, None , :]
42+ for iter_index in tl .static_range (0 , BLOCK_HEAD ):
43+ pid_h = pid_h_block_index * BLOCK_HEAD + iter_index
44+ pid_h = tl .where (pid_h < head_num , pid_h , pid_h_block_index * BLOCK_HEAD )
45+ base = pid_l * stride_l + pid_h * stride_h
46+ in_ptr = inp_ptr + base + d * stride_d
47+ x = tl .load (in_ptr , mask = mask , other = 0.0 )
6148
62- partner_d = tl .where (offs_d < HALF_D , offs_d + HALF_D , offs_d - HALF_D )
63- partner_d_b = partner_d [None , None , :]
49+ partner_ptr = inp_ptr + base + partner_d * stride_d
50+ partner_val = tl .load (partner_ptr , mask = mask , other = 0.0 )
51+ rotated = tl .where (d < HALF_D , - partner_val , partner_val )
6452
65- partner_base = l_b * stride_l + h_b * stride_h + partner_d_b * stride_d
66- partner_val = tl .load (inp_ptr + partner_base , mask = mask , other = 0.0 )
53+ y = x * cos + rotated * sin
6754
68- rotated = tl .where (d_b < HALF_D , - partner_val , partner_val )
69-
70- y = x * cos + rotated * sin
71-
72- tl .store (out_ptr + base , y , mask = mask )
55+ out_ptr_ = out_ptr + base + d
56+ tl .store (out_ptr_ , y , mask = mask )
7357
7458
7559def apply_rotary_pos_emb_triton (
76- tensor : torch .Tensor ,
77- cos : torch .Tensor ,
78- sin : torch .Tensor ,
60+ tensor : torch .Tensor , cos : torch .Tensor , sin : torch .Tensor , BLOCK_D : int = 128
7961) -> torch .Tensor :
8062 assert tensor .is_cuda and cos .is_cuda and sin .is_cuda
8163 assert cos .is_contiguous () and sin .is_contiguous ()
8264 if tensor .ndim != 3 :
8365 raise RuntimeError ("tensor shape should be [L, H, D]" )
84-
8566 orig_dtype = tensor .dtype
8667 x = tensor .float ()
8768
8869 cos = cos .repeat (1 , 2 ).view (cos .size (0 ), - 1 ).contiguous ().float ()
8970 sin = sin .repeat (1 , 2 ).view (sin .size (0 ), - 1 ).contiguous ().float ()
9071
9172 L , H , D = x .shape
73+ HALF_D = D // 2
9274 y = torch .empty_like (x )
93-
94- BLOCK_SEQ = 16
95- BLOCK_HEAD = 4
96- BLOCK_D = triton .next_power_of_2 (D )
97-
98- if D >= 128 :
99- num_warps = 8
75+ if L < 1024 :
76+ grid_L = L
10077 else :
101- num_warps = 4
78+ grid_L = 1024
10279
103- grid = (
104- triton .cdiv (H , BLOCK_HEAD ),
105- triton .cdiv (L , BLOCK_SEQ ),
106- )
80+ BLOCK_HEAD = 4
81+ grid = (triton .cdiv (H , BLOCK_HEAD ), grid_L , triton .cdiv (D , BLOCK_D ))
10782
10883 rotary_kernel [grid ](
10984 inp_ptr = x ,
@@ -117,13 +92,12 @@ def apply_rotary_pos_emb_triton(
11792 stride_cos_d = cos .stride (1 ),
11893 stride_sin_l = sin .stride (0 ),
11994 stride_sin_d = sin .stride (1 ),
120- L = L ,
121- H = H ,
95+ total_len = L ,
96+ head_num = H ,
12297 D = D ,
123- BLOCK_SEQ = BLOCK_SEQ ,
12498 BLOCK_HEAD = BLOCK_HEAD ,
99+ HALF_D = HALF_D ,
125100 BLOCK_D = BLOCK_D ,
126- num_warps = num_warps ,
127101 )
128102
129103 return y .to (orig_dtype )
0 commit comments