1- import math
2- import torch
31import triton
42import triton .language as tl
3+ import torch
54
65
76@triton .jit
@@ -17,57 +16,94 @@ def rotary_kernel(
1716 stride_cos_d ,
1817 stride_sin_l ,
1918 stride_sin_d ,
20- D : tl .constexpr ,
21- HALF_D : tl .constexpr ,
19+ L ,
20+ H ,
21+ D ,
22+ BLOCK_SEQ : tl .constexpr ,
23+ BLOCK_HEAD : tl .constexpr ,
2224 BLOCK_D : tl .constexpr ,
2325):
24- pid_h = tl .program_id (0 ).to (tl .int64 )
25- pid_l = tl .program_id (1 ).to (tl .int64 )
26- pid_blk = tl .program_id (2 ).to (tl .int64 )
26+ pid_head_blk = tl .program_id (0 )
27+ pid_seq_blk = tl .program_id (1 )
2728
29+ offs_h = pid_head_blk * BLOCK_HEAD + tl .arange (0 , BLOCK_HEAD )
30+ offs_l = pid_seq_blk * BLOCK_SEQ + tl .arange (0 , BLOCK_SEQ )
2831 offs_d = tl .arange (0 , BLOCK_D )
29- d = pid_blk * BLOCK_D + offs_d
30- mask = d < D
3132
32- base = pid_l * stride_l + pid_h * stride_h
33+ offs_h = offs_h .to (tl .int64 )
34+ offs_l = offs_l .to (tl .int64 )
35+ offs_d = offs_d .to (tl .int64 )
36+
37+ mask_h = offs_h < H
38+ mask_l = offs_l < L
39+ mask_d = offs_d < D
40+
41+ HALF_D = D // 2
42+
43+ l_b = offs_l [:, None , None ]
44+ h_b = offs_h [None , :, None ]
45+ d_b = offs_d [None , None , :]
46+
47+ mask = mask_l [:, None , None ] & mask_h [None , :, None ] & mask_d [None , None , :]
48+
49+ base = l_b * stride_l + h_b * stride_h + d_b * stride_d
50+ x = tl .load (inp_ptr + base , mask = mask , other = 0.0 )
3351
34- in_ptr = inp_ptr + base + d * stride_d
35- cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
36- sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
52+ cos_base_2d = offs_l [:, None ] * stride_cos_l + offs_d [ None , :] * stride_cos_d
53+ sin_base_2d = offs_l [:, None ] * stride_sin_l + offs_d [ None , :] * stride_sin_d
54+ mask_ld = mask_l [:, None ] & mask_d [ None , :]
3755
38- x = tl .load (in_ptr , mask = mask )
39- cos = tl .load (cos_ptr_ , mask = mask )
40- sin = tl .load (sin_ptr_ , mask = mask )
56+ cos_2d = tl .load (cos_ptr + cos_base_2d , mask = mask_ld , other = 0.0 )
57+ sin_2d = tl .load (sin_ptr + sin_base_2d , mask = mask_ld , other = 0.0 )
4158
42- partner_d = tl .where (d < HALF_D , d + HALF_D , d - HALF_D )
43- partner_ptr = inp_ptr + base + partner_d * stride_d
44- partner_val = tl .load (partner_ptr , mask = mask )
45- rotated = tl .where (d < HALF_D , - partner_val , partner_val )
59+ cos = cos_2d [:, None , :]
60+ sin = sin_2d [:, None , :]
61+
62+ partner_d = tl .where (offs_d < HALF_D , offs_d + HALF_D , offs_d - HALF_D )
63+ partner_d_b = partner_d [None , None , :]
64+
65+ partner_base = l_b * stride_l + h_b * stride_h + partner_d_b * stride_d
66+ partner_val = tl .load (inp_ptr + partner_base , mask = mask , other = 0.0 )
67+
68+ rotated = tl .where (d_b < HALF_D , - partner_val , partner_val )
4669
4770 y = x * cos + rotated * sin
4871
49- out_ptr_ = out_ptr + base + d
50- tl .store (out_ptr_ , y , mask = mask )
72+ tl .store (out_ptr + base , y , mask = mask )
5173
5274
5375def apply_rotary_pos_emb_triton (
54- tensor : torch .Tensor , cos : torch .Tensor , sin : torch .Tensor , BLOCK_D : int = 128
76+ tensor : torch .Tensor ,
77+ cos : torch .Tensor ,
78+ sin : torch .Tensor ,
5579) -> torch .Tensor :
5680 assert tensor .is_cuda and cos .is_cuda and sin .is_cuda
5781 assert cos .is_contiguous () and sin .is_contiguous ()
5882 if tensor .ndim != 3 :
5983 raise RuntimeError ("tensor shape should be [L, H, D]" )
84+
6085 orig_dtype = tensor .dtype
6186 x = tensor .float ()
6287
6388 cos = cos .repeat (1 , 2 ).view (cos .size (0 ), - 1 ).contiguous ().float ()
6489 sin = sin .repeat (1 , 2 ).view (sin .size (0 ), - 1 ).contiguous ().float ()
6590
6691 L , H , D = x .shape
67- HALF_D = D // 2
6892 y = torch .empty_like (x )
6993
70- grid = (H , L , triton .cdiv (D , BLOCK_D ))
94+ BLOCK_SEQ = 16
95+ BLOCK_HEAD = 4
96+ BLOCK_D = triton .next_power_of_2 (D )
97+
98+ if D >= 128 :
99+ num_warps = 8
100+ else :
101+ num_warps = 4
102+
103+ grid = (
104+ triton .cdiv (H , BLOCK_HEAD ),
105+ triton .cdiv (L , BLOCK_SEQ ),
106+ )
71107
72108 rotary_kernel [grid ](
73109 inp_ptr = x ,
@@ -81,9 +117,13 @@ def apply_rotary_pos_emb_triton(
81117 stride_cos_d = cos .stride (1 ),
82118 stride_sin_l = sin .stride (0 ),
83119 stride_sin_d = sin .stride (1 ),
120+ L = L ,
121+ H = H ,
84122 D = D ,
85- HALF_D = HALF_D ,
123+ BLOCK_SEQ = BLOCK_SEQ ,
124+ BLOCK_HEAD = BLOCK_HEAD ,
86125 BLOCK_D = BLOCK_D ,
126+ num_warps = num_warps ,
87127 )
88128
89129 return y .to (orig_dtype )
0 commit comments