
Commit f5a6a2b

Revert SDPA changes (#5078)
* Revert "[SDP] aten::scaled_dot_product_attention applies different format support in mha kernel (#4885)"
  This reverts commit cd81509.
* Revert "[SDP][FWD] support dynamic memory format (#4776)"
  This reverts commit 7cb5af5.

1 parent: e06f287

File tree

5 files changed: +26 -87 lines changed


csrc/gpu/aten/operators/transformers/attention.cpp
Lines changed: 7 additions & 22 deletions

@@ -41,9 +41,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> ipex_sdp_dropout_backward(
     c10::optional<double> scale);

 inline Tensor _scaled_dot_product_efficient_attention_impl(
-    const Tensor& query,
-    const Tensor& key,
-    const Tensor& value,
+    const Tensor& _query,
+    const Tensor& _key,
+    const Tensor& _value,
     const c10::optional<Tensor>& attn_mask,
     const c10::optional<at::Tensor>& dropout_mask,
     const c10::optional<at::Tensor>& seed_t,
@@ -63,10 +63,10 @@ inline Tensor _scaled_dot_product_efficient_attention_impl(
     attn_mask_padded_block_size = alignTo * ((lastDim + alignTo - 1) / alignTo);
   }

-  // check q, k, v
-  CHECK_NOSPARSE_LASTCONTIGUOUS_XPU(query);
-  CHECK_NOSPARSE_LASTCONTIGUOUS_XPU(key);
-  CHECK_NOSPARSE_LASTCONTIGUOUS_XPU(value);
+  // make q, k, v strided
+  auto query = _query.transpose(1, 2).contiguous().transpose(1, 2);
+  auto key = _key.transpose(1, 2).contiguous().transpose(1, 2);
+  auto value = _value.transpose(1, 2).contiguous().transpose(1, 2);

   // create strided output
   // size [bs, num_head, qsize, head_size]
@@ -102,12 +102,7 @@ inline Tensor _scaled_dot_product_efficient_attention_impl(
       query.size(3),
       query.size(2),
      key.size(2),
-      query.stride(0),
-      query.stride(1),
       query.stride(2),
-      key.stride(0),
-      key.stride(1),
-      key.stride(2),
       attn_mask.has_value() ? attn_mask->stride(0) : -1,
       attn_mask.has_value() ? attn_mask->stride(1) : -1,
       attn_mask.has_value() ? attn_mask->stride(2) : -1,
@@ -1134,12 +1129,7 @@ Tensor varlen_fwd(
       head_dim,
       num_queries,
       num_keys,
-      /* q_strideB */ query.stride(0),
-      /* q_strideN */ query.stride(1),
       /* q_strideF */ query.stride(2),
-      /* kv_strideB */ key.stride(0),
-      /* kv_strideN */ key.stride(1),
-      /* kv_strideT */ key.stride(2),
       /* bias_strideB */ -1,
       /* bias_strideN */ -1,
       /* bias_strideF */ -1,
@@ -1255,12 +1245,7 @@ Tensor xetla_fsdp_forward_atten_mask_alibi_strided(
       head_dim,
       M,
       N,
-      query.stride(0),
-      query.stride(1),
       query.stride(2),
-      key.stride(0),
-      key.stride(1),
-      key.stride(2),
       attn_mask.has_value() ? attn_mask_bc.stride(0) : -1,
       attn_mask.has_value() ? attn_mask_bc.stride(1) : -1,
       attn_mask.has_value() ? attn_mask_bc.stride(2) : -1,
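Worth spelling out, since it is the heart of the revert in this file: the restored `_query.transpose(1, 2).contiguous().transpose(1, 2)` idiom copies the tensor into a buffer that is contiguous in [B, F, N, H] order while keeping the logical [B, N, F, H] shape, which is exactly the 2D [BxF, NxH] view the XeTLA kernel addresses. A minimal LibTorch sketch of the idiom, with sizes invented purely for illustration:

#include <torch/torch.h>
#include <iostream>

int main() {
  // Hypothetical sizes: [batch=2, num_heads=8, seq_len=16, head_dim=64].
  auto q = torch::randn({2, 8, 16, 64});

  // The restored idiom: materialize a copy that is contiguous in
  // [B, F, N, H] order, then view it back as [B, N, F, H].
  auto strided = q.transpose(1, 2).contiguous().transpose(1, 2);

  std::cout << strided.sizes() << '\n';   // [2, 8, 16, 64] -- shape unchanged
  std::cout << strided.strides() << '\n'; // [8192, 64, 512, 1] -- row pitch N*H = 512
  return 0;
}

After the copy, `query.stride(2)` equals N*H, which is how the single `q_strideF` argument can stand in for the six stride parameters removed above.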

csrc/gpu/aten/operators/transformers/sdp_utils.h
Lines changed: 0 additions & 6 deletions

@@ -20,12 +20,6 @@
 using namespace gpu::xetla;
 #endif

-#define CHECK_NOSPARSE_LASTCONTIGUOUS_XPU(TENSOR)                        \
-  TORCH_CHECK(TENSOR.is_xpu(), #TENSOR " must be a XPU tensor");         \
-  TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor");   \
-  TORCH_CHECK(                                                           \
-      TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");
-
 using namespace at;
 using namespace torch_ipex::xpu::dpcpp;
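For reference, the deleted macro expands to three TORCH_CHECKs per tensor. Hand-expanding it for a tensor named `query`, and wrapping the result in a hypothetical free function so the sketch is self-contained:

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Hypothetical function form of the removed macro, expanded for
// TENSOR = query; the #TENSOR stringification becomes the literal "query".
void check_nosparse_lastcontiguous_xpu(const at::Tensor& query) {
  TORCH_CHECK(query.is_xpu(), "query must be a XPU tensor");
  TORCH_CHECK(!query.is_sparse(), "query must be a dense tensor");
  TORCH_CHECK(query.stride(-1) == 1, "query: last dimension must be contiguous");
}

With the explicit `.contiguous()` copies restored in attention.cpp, the kernel inputs satisfy the last-dimension condition by construction, so the macro loses its only call sites.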

csrc/gpu/aten/operators/xetla/kernels/SDP/fmha_forward.hpp
Lines changed: 17 additions & 37 deletions

@@ -49,12 +49,7 @@ class fmha_forward_t {
   uint32_t uH;
   uint32_t uF;
   uint32_t uT;
-  uint32_t q_strideB;
-  uint32_t q_strideN;
-  uint32_t q_strideF;
-  uint32_t kv_strideB;
-  uint32_t kv_strideN;
-  uint32_t kv_strideT;
+  uint64_t q_strideF;
   uint32_t bias_strideB;
   uint32_t bias_strideN;
   uint32_t bias_strideF;
@@ -89,12 +84,7 @@ class fmha_forward_t {
       uint32_t head_size,
       uint32_t num_queries,
       uint32_t num_keys,
-      uint32_t q_strideB,
-      uint32_t q_strideN,
-      uint32_t q_strideF,
-      uint32_t kv_strideB,
-      uint32_t kv_strideN,
-      uint32_t kv_strideT,
+      uint64_t q_strideF,
       uint32_t bias_strideB,
       uint32_t bias_strideN,
       uint32_t bias_strideF,
@@ -121,12 +111,7 @@ class fmha_forward_t {
         uH(head_size),
         uF(num_queries),
         uT(num_keys),
-        q_strideB(q_strideB),
-        q_strideN(q_strideN),
         q_strideF(q_strideF),
-        kv_strideB(kv_strideB),
-        kv_strideN(kv_strideN),
-        kv_strideT(kv_strideT),
         bias_strideB(bias_strideB),
         bias_strideN(bias_strideN),
         bias_strideF(bias_strideF),
@@ -318,25 +303,21 @@ class fmha_forward_t {
       mem_desc_Oi.init(
           args.O_ptr, {end_x, end_y, ld_qo}, {start_acc, start_y});
     } else { // 2d mem: [BxF, NxH]
-      uint32_t ptr_offset =
-          batch_id * args.q_strideB + head_id * args.q_strideN;
-      auto Q_ptr = args.Q_ptr + ptr_offset;
-      auto O_ptr = args.O_ptr + ptr_offset;
-
       // startF
-      int32_t start_y = item.get_group(1) * kBr;
+      int32_t start_y = batch_id * args.uF + item.get_group(1) * kBr;
       uint32_t end_y = start_y + kBr;
       // boundaryF
-      uint32_t boundary_y = args.uF;
+      uint32_t boundary_y = (batch_id + 1) * args.uF;
       end_y = end_y > boundary_y ? boundary_y : end_y;

-      int32_t start_acc = 0;
+      int32_t start_acc = head_id * args.uH;
       uint32_t end_acc = start_acc + args.uH;
+      const uint32_t ld_o = args.uH * args.uN;

       mem_desc_Qi.init(
-          Q_ptr, {end_acc, end_y, args.q_strideF}, {start_acc, start_y});
+          args.Q_ptr, {end_acc, end_y, args.q_strideF}, {start_acc, start_y});
       mem_desc_Oi.init(
-          O_ptr, {end_acc, end_y, args.q_strideF}, {start_acc, start_y});
+          args.O_ptr, {end_acc, end_y, ld_o}, {start_acc, start_y});
     }

     int32_t start_x_ml = item.get_group(1) * kBr + sg_idy * kSgBr;
@@ -394,23 +375,22 @@ class fmha_forward_t {
           {start_acc, start_x});

     } else {
-      uint32_t ptr_offset =
-          batch_id * args.kv_strideB + head_id_kv * args.kv_strideN;
-      auto K_ptr = args.K_ptr + ptr_offset;
-      auto V_ptr = args.V_ptr + ptr_offset;
-
-      int32_t start_x = startT;
+      int32_t start_x = batch_id * args.uT + startT;
       uint32_t end_x = start_x + kBc;
-      uint32_t boundary_x = args.uT;
+      uint32_t boundary_x = (batch_id + 1) * args.uT;
       end_x = end_x > boundary_x ? boundary_x : end_x;

-      int32_t start_acc = 0;
+      int32_t start_acc = head_id_kv * args.uH;
       uint32_t end_acc = start_acc + args.uH;

       mem_desc_Kj_T.init(
-          K_ptr, {end_x, end_acc, args.kv_strideT}, {start_x, start_acc});
+          args.K_ptr,
+          {end_x, end_acc, args.uH * args.uNkv},
+          {start_x, start_acc});
       mem_desc_Vj.init(
-          V_ptr, {end_acc, end_x, args.kv_strideT}, {start_acc, start_x});
+          args.V_ptr,
+          {end_acc, end_x, args.uH * args.uNkv},
+          {start_acc, start_x});
     }

     // B, N, 1, T
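The pattern in both hunks above is the same: rather than pre-offsetting the base pointer by `batch_id * strideB + head_id * strideN`, the restored code folds the batch index into the y coordinate and the head index into the x coordinate of the 2D [BxF, NxH] view, whose leading dimension is N*H. A standalone sketch of the equivalent index arithmetic (the names are illustrative, not the kernel's actual descriptor types):

#include <cstddef>
#include <cstdint>

// Flat offset of element (b, n, f, h) in a 2D [BxF, NxH] view, mirroring
// start_y = batch_id * uF + f and start_acc = head_id * uH above.
inline std::size_t offset_2d(
    uint32_t b, uint32_t n, uint32_t f, uint32_t h,
    uint32_t uF, uint32_t uN, uint32_t uH) {
  std::size_t row = std::size_t(b) * uF + f; // batch folded into rows (y)
  std::size_t col = std::size_t(n) * uH + h; // head folded into columns (x)
  std::size_t ld = std::size_t(uN) * uH;     // leading dimension of the view
  return row * ld + col; // equals the base-pointer math: b*F*N*H + f*N*H + n*H + h
}

Both routes reach the same element; the folded form only needs the shape (uF, uN/uNkv, uH) rather than per-tensor stride triples, which is what lets the q/kv stride arguments disappear from the structs below.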

csrc/gpu/aten/operators/xetla/kernels/SDP/fmha_forward_kernel.hpp
Lines changed: 1 addition & 16 deletions

@@ -33,12 +33,7 @@ struct dispatch_fmha_forward_args_t {
   uint32_t head_size;
   uint32_t num_queries;
   uint32_t num_keys;
-  uint32_t q_strideB;
-  uint32_t q_strideN;
-  uint32_t q_strideF;
-  uint32_t kv_strideB;
-  uint32_t kv_strideN;
-  uint32_t kv_strideT;
+  uint64_t q_strideF;
   uint32_t bias_strideB;
   uint32_t bias_strideN;
   uint32_t bias_strideF;
@@ -65,12 +60,7 @@ struct dispatch_fmha_forward_args_t {
         head_size(args.head_size),
         num_queries(args.num_queries),
         num_keys(args.num_keys),
-        q_strideB(args.q_strideB),
-        q_strideN(args.q_strideN),
         q_strideF(args.q_strideF),
-        kv_strideB(args.kv_strideB),
-        kv_strideN(args.kv_strideN),
-        kv_strideT(args.kv_strideT),
         bias_strideB(args.bias_strideB),
         bias_strideN(args.bias_strideN),
         bias_strideF(args.bias_strideF),
@@ -137,12 +127,7 @@ struct FmhaForwardKernelFunctor {
         args.head_size,
         args.num_queries,
         args.num_keys,
-        args.q_strideB,
-        args.q_strideN,
         args.q_strideF,
-        args.kv_strideB,
-        args.kv_strideN,
-        args.kv_strideT,
         args.bias_strideB,
         args.bias_strideN,
         args.bias_strideF,

csrc/gpu/aten/operators/xetla/mha.h
Lines changed: 1 addition & 6 deletions

@@ -32,12 +32,7 @@ struct fmha_forward_kernel_args_t {
   uint32_t head_size;
   uint32_t num_queries;
   uint32_t num_keys;
-  uint32_t q_strideB;
-  uint32_t q_strideN;
-  uint32_t q_strideF;
-  uint32_t kv_strideB;
-  uint32_t kv_strideN;
-  uint32_t kv_strideT;
+  uint64_t q_strideF;
   uint32_t bias_strideB;
   uint32_t bias_strideN;
   uint32_t bias_strideF;
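One detail shared by all three headers: the surviving stride is declared `uint64_t` where the removed fields were `uint32_t`. A plausible reading (an inference from the types alone, not stated in the commit) is that with batch folded into the row coordinate, offsets computed against the leading dimension scale with B*F and can overflow 32 bits. A quick check with hypothetical sizes:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical shape: batch 64, 32768 queries, 32 heads, head_dim 128.
  uint64_t rows = 64ull * 32768;   // B * F rows in the 2D view
  uint64_t ld = 32ull * 128;       // N * H elements per row
  uint64_t last = (rows - 1) * ld; // element offset of the final row
  std::printf("%llu vs UINT32_MAX %u\n", (unsigned long long)last, UINT32_MAX);
  // ~8.6e9, which no longer fits in a uint32_t.
  return 0;
}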
