Skip to content

Commit b2649ed

Browse files
authored
fix pos_emb (#1126)
1 parent 5faa29c commit b2649ed

File tree

1 file changed

+39 -65 lines changed

1 file changed

+39 -65 lines changed
Lines changed: 39 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
1+
import math
2+
import torch
13
import triton
24
import triton.language as tl
3-
import torch
45

56

67
@triton.jit
@@ -16,94 +17,68 @@ def rotary_kernel(
1617
stride_cos_d,
1718
stride_sin_l,
1819
stride_sin_d,
19-
L,
20-
H,
21-
D,
22-
BLOCK_SEQ: tl.constexpr,
20+
total_len,
21+
head_num,
22+
D: tl.constexpr,
2323
BLOCK_HEAD: tl.constexpr,
24+
HALF_D: tl.constexpr,
2425
BLOCK_D: tl.constexpr,
2526
):
26-
pid_head_blk = tl.program_id(0)
27-
pid_seq_blk = tl.program_id(1)
27+
pid_h_block_index = tl.program_id(0).to(tl.int64)
28+
pid_l_start = tl.program_id(1).to(tl.int64)
29+
pid_blk = tl.program_id(2).to(tl.int64)
2830

29-
offs_h = pid_head_blk * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
30-
offs_l = pid_seq_blk * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
3131
offs_d = tl.arange(0, BLOCK_D)
32+
d = pid_blk * BLOCK_D + offs_d
33+
mask = d < D
34+
partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D)
3235

33-
offs_h = offs_h.to(tl.int64)
34-
offs_l = offs_l.to(tl.int64)
35-
offs_d = offs_d.to(tl.int64)
36-
37-
mask_h = offs_h < H
38-
mask_l = offs_l < L
39-
mask_d = offs_d < D
40-
41-
HALF_D = D // 2
42-
43-
l_b = offs_l[:, None, None]
44-
h_b = offs_h[None, :, None]
45-
d_b = offs_d[None, None, :]
46-
47-
mask = mask_l[:, None, None] & mask_h[None, :, None] & mask_d[None, None, :]
48-
49-
base = l_b * stride_l + h_b * stride_h + d_b * stride_d
50-
x = tl.load(inp_ptr + base, mask=mask, other=0.0)
51-
52-
cos_base_2d = offs_l[:, None] * stride_cos_l + offs_d[None, :] * stride_cos_d
53-
sin_base_2d = offs_l[:, None] * stride_sin_l + offs_d[None, :] * stride_sin_d
54-
mask_ld = mask_l[:, None] & mask_d[None, :]
55-
56-
cos_2d = tl.load(cos_ptr + cos_base_2d, mask=mask_ld, other=0.0)
57-
sin_2d = tl.load(sin_ptr + sin_base_2d, mask=mask_ld, other=0.0)
36+
for pid_l in tl.range(pid_l_start, total_len, step=tl.num_programs(axis=1)):
37+
cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
38+
sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
39+
cos = tl.load(cos_ptr_, mask=mask)
40+
sin = tl.load(sin_ptr_, mask=mask)
5841

59-
cos = cos_2d[:, None, :]
60-
sin = sin_2d[:, None, :]
42+
for iter_index in tl.static_range(0, BLOCK_HEAD):
43+
pid_h = pid_h_block_index * BLOCK_HEAD + iter_index
44+
pid_h = tl.where(pid_h < head_num, pid_h, pid_h_block_index * BLOCK_HEAD)
45+
base = pid_l * stride_l + pid_h * stride_h
46+
in_ptr = inp_ptr + base + d * stride_d
47+
x = tl.load(in_ptr, mask=mask, other=0.0)
6148

62-
partner_d = tl.where(offs_d < HALF_D, offs_d + HALF_D, offs_d - HALF_D)
63-
partner_d_b = partner_d[None, None, :]
49+
partner_ptr = inp_ptr + base + partner_d * stride_d
50+
partner_val = tl.load(partner_ptr, mask=mask, other=0.0)
51+
rotated = tl.where(d < HALF_D, -partner_val, partner_val)
6452

65-
partner_base = l_b * stride_l + h_b * stride_h + partner_d_b * stride_d
66-
partner_val = tl.load(inp_ptr + partner_base, mask=mask, other=0.0)
53+
y = x * cos + rotated * sin
6754

68-
rotated = tl.where(d_b < HALF_D, -partner_val, partner_val)
69-
70-
y = x * cos + rotated * sin
71-
72-
tl.store(out_ptr + base, y, mask=mask)
55+
out_ptr_ = out_ptr + base + d
56+
tl.store(out_ptr_, y, mask=mask)
7357

7458

7559
def apply_rotary_pos_emb_triton(
76-
tensor: torch.Tensor,
77-
cos: torch.Tensor,
78-
sin: torch.Tensor,
60+
tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, BLOCK_D: int = 128
7961
) -> torch.Tensor:
8062
assert tensor.is_cuda and cos.is_cuda and sin.is_cuda
8163
assert cos.is_contiguous() and sin.is_contiguous()
8264
if tensor.ndim != 3:
8365
raise RuntimeError("tensor shape should be [L, H, D]")
84-
8566
orig_dtype = tensor.dtype
8667
x = tensor.float()
8768

8869
cos = cos.repeat(1, 2).view(cos.size(0), -1).contiguous().float()
8970
sin = sin.repeat(1, 2).view(sin.size(0), -1).contiguous().float()
9071

9172
L, H, D = x.shape
73+
HALF_D = D // 2
9274
y = torch.empty_like(x)
93-
94-
BLOCK_SEQ = 16
95-
BLOCK_HEAD = 4
96-
BLOCK_D = triton.next_power_of_2(D)
97-
98-
if D >= 128:
99-
num_warps = 8
75+
if L < 1024:
76+
grid_L = L
10077
else:
101-
num_warps = 4
78+
grid_L = 1024
10279

103-
grid = (
104-
triton.cdiv(H, BLOCK_HEAD),
105-
triton.cdiv(L, BLOCK_SEQ),
106-
)
80+
BLOCK_HEAD = 4
81+
grid = (triton.cdiv(H, BLOCK_HEAD), grid_L, triton.cdiv(D, BLOCK_D))
10782

10883
rotary_kernel[grid](
10984
inp_ptr=x,
@@ -117,13 +92,12 @@ def apply_rotary_pos_emb_triton(
11792
stride_cos_d=cos.stride(1),
11893
stride_sin_l=sin.stride(0),
11994
stride_sin_d=sin.stride(1),
120-
L=L,
121-
H=H,
95+
total_len=L,
96+
head_num=H,
12297
D=D,
123-
BLOCK_SEQ=BLOCK_SEQ,
12498
BLOCK_HEAD=BLOCK_HEAD,
99+
HALF_D=HALF_D,
125100
BLOCK_D=BLOCK_D,
126-
num_warps=num_warps,
127101
)
128102

129103
return y.to(orig_dtype)

0 commit comments

Comments
 (0)