Skip to content

Commit ad8b4ae

Browse files
SangChengC and sangchengmeng authored
fix triton_rotary_rope_emb (#1125)
Co-authored-by: sangchengmeng <sangchengmeng@sensetime.com>
1 parent 58f8c1d commit ad8b4ae

File tree

1 file changed

+66
-26
lines changed

1 file changed

+66
-26
lines changed
Lines changed: 66 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
1-
import math
2-
import torch
31
import triton
42
import triton.language as tl
3+
import torch
54

65

76
@triton.jit
@@ -17,57 +16,94 @@ def rotary_kernel(
1716
stride_cos_d,
1817
stride_sin_l,
1918
stride_sin_d,
20-
D: tl.constexpr,
21-
HALF_D: tl.constexpr,
19+
L,
20+
H,
21+
D,
22+
BLOCK_SEQ: tl.constexpr,
23+
BLOCK_HEAD: tl.constexpr,
2224
BLOCK_D: tl.constexpr,
2325
):
24-
pid_h = tl.program_id(0).to(tl.int64)
25-
pid_l = tl.program_id(1).to(tl.int64)
26-
pid_blk = tl.program_id(2).to(tl.int64)
26+
pid_head_blk = tl.program_id(0)
27+
pid_seq_blk = tl.program_id(1)
2728

29+
offs_h = pid_head_blk * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
30+
offs_l = pid_seq_blk * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
2831
offs_d = tl.arange(0, BLOCK_D)
29-
d = pid_blk * BLOCK_D + offs_d
30-
mask = d < D
3132

32-
base = pid_l * stride_l + pid_h * stride_h
33+
offs_h = offs_h.to(tl.int64)
34+
offs_l = offs_l.to(tl.int64)
35+
offs_d = offs_d.to(tl.int64)
36+
37+
mask_h = offs_h < H
38+
mask_l = offs_l < L
39+
mask_d = offs_d < D
40+
41+
HALF_D = D // 2
42+
43+
l_b = offs_l[:, None, None]
44+
h_b = offs_h[None, :, None]
45+
d_b = offs_d[None, None, :]
46+
47+
mask = mask_l[:, None, None] & mask_h[None, :, None] & mask_d[None, None, :]
48+
49+
base = l_b * stride_l + h_b * stride_h + d_b * stride_d
50+
x = tl.load(inp_ptr + base, mask=mask, other=0.0)
3351

34-
in_ptr = inp_ptr + base + d * stride_d
35-
cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
36-
sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
52+
cos_base_2d = offs_l[:, None] * stride_cos_l + offs_d[None, :] * stride_cos_d
53+
sin_base_2d = offs_l[:, None] * stride_sin_l + offs_d[None, :] * stride_sin_d
54+
mask_ld = mask_l[:, None] & mask_d[None, :]
3755

38-
x = tl.load(in_ptr, mask=mask)
39-
cos = tl.load(cos_ptr_, mask=mask)
40-
sin = tl.load(sin_ptr_, mask=mask)
56+
cos_2d = tl.load(cos_ptr + cos_base_2d, mask=mask_ld, other=0.0)
57+
sin_2d = tl.load(sin_ptr + sin_base_2d, mask=mask_ld, other=0.0)
4158

42-
partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D)
43-
partner_ptr = inp_ptr + base + partner_d * stride_d
44-
partner_val = tl.load(partner_ptr, mask=mask)
45-
rotated = tl.where(d < HALF_D, -partner_val, partner_val)
59+
cos = cos_2d[:, None, :]
60+
sin = sin_2d[:, None, :]
61+
62+
partner_d = tl.where(offs_d < HALF_D, offs_d + HALF_D, offs_d - HALF_D)
63+
partner_d_b = partner_d[None, None, :]
64+
65+
partner_base = l_b * stride_l + h_b * stride_h + partner_d_b * stride_d
66+
partner_val = tl.load(inp_ptr + partner_base, mask=mask, other=0.0)
67+
68+
rotated = tl.where(d_b < HALF_D, -partner_val, partner_val)
4669

4770
y = x * cos + rotated * sin
4871

49-
out_ptr_ = out_ptr + base + d
50-
tl.store(out_ptr_, y, mask=mask)
72+
tl.store(out_ptr + base, y, mask=mask)
5173

5274

5375
def apply_rotary_pos_emb_triton(
54-
tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, BLOCK_D: int = 128
76+
tensor: torch.Tensor,
77+
cos: torch.Tensor,
78+
sin: torch.Tensor,
5579
) -> torch.Tensor:
5680
assert tensor.is_cuda and cos.is_cuda and sin.is_cuda
5781
assert cos.is_contiguous() and sin.is_contiguous()
5882
if tensor.ndim != 3:
5983
raise RuntimeError("tensor shape should be [L, H, D]")
84+
6085
orig_dtype = tensor.dtype
6186
x = tensor.float()
6287

6388
cos = cos.repeat(1, 2).view(cos.size(0), -1).contiguous().float()
6489
sin = sin.repeat(1, 2).view(sin.size(0), -1).contiguous().float()
6590

6691
L, H, D = x.shape
67-
HALF_D = D // 2
6892
y = torch.empty_like(x)
6993

70-
grid = (H, L, triton.cdiv(D, BLOCK_D))
94+
BLOCK_SEQ = 16
95+
BLOCK_HEAD = 4
96+
BLOCK_D = triton.next_power_of_2(D)
97+
98+
if D >= 128:
99+
num_warps = 8
100+
else:
101+
num_warps = 4
102+
103+
grid = (
104+
triton.cdiv(H, BLOCK_HEAD),
105+
triton.cdiv(L, BLOCK_SEQ),
106+
)
71107

72108
rotary_kernel[grid](
73109
inp_ptr=x,
@@ -81,9 +117,13 @@ def apply_rotary_pos_emb_triton(
81117
stride_cos_d=cos.stride(1),
82118
stride_sin_l=sin.stride(0),
83119
stride_sin_d=sin.stride(1),
120+
L=L,
121+
H=H,
84122
D=D,
85-
HALF_D=HALF_D,
123+
BLOCK_SEQ=BLOCK_SEQ,
124+
BLOCK_HEAD=BLOCK_HEAD,
86125
BLOCK_D=BLOCK_D,
126+
num_warps=num_warps,
87127
)
88128

89129
return y.to(orig_dtype)

0 commit comments

Comments
 (0)