Reduce allocation overhead in quantized sdpa #15610
```diff
@@ -213,13 +213,13 @@ void dequant_and_gemm(
     const int64_t v_stride_n,
     float* o_data,
     const int64_t o_stride_m,
-    const float beta) {
-  std::vector<float> dequantized_v_data(v_data.m * v_data.n);
+    const float beta,
+    float* buf_qdq_ptr) {
   dequantize_per_channel_optimized(
       static_cast<const int8_t*>(v_data.data),
       static_cast<const float*>(v_data.scales),
       static_cast<const int8_t*>(v_data.zero_points),
-      dequantized_v_data.data(),
+      buf_qdq_ptr,
       -128,
       127,
       1,
@@ -237,7 +237,7 @@ void dequant_and_gemm(
       m,
       k,
       static_cast<float>(1),
-      dequantized_v_data.data(),
+      buf_qdq_ptr,
       v_data.n,
       qk_data,
       qk_stride_m,
@@ -257,7 +257,8 @@ void _qk_at_v_gemm(
     const int64_t v_stride_n,
     accum_t* o_data,
     const int64_t o_stride_m,
-    const accum_t beta) {
+    const accum_t beta,
+    accum_t* buf_qdq_ptr) {
   if (v_data.dtype == ScalarType::Char) {
     if constexpr (std::is_same<accum_t, float>::value) {
       if (m > 4) {
@@ -273,7 +274,8 @@ void _qk_at_v_gemm(
           v_stride_n,
           o_data,
           o_stride_m,
-          beta);
+          beta,
+          buf_qdq_ptr);
       } else {
         // For smaller batch sizes, use quantized gemm
         int a_stride_m_tmp, b_stride_n_tmp;
```
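The hunks above cover the callee side of the change: `dequant_and_gemm` no longer allocates a `std::vector<float>` on every call, and `_qk_at_v_gemm` gains a `buf_qdq_ptr` parameter so scratch space can be threaded through from the caller. A minimal sketch of that pattern follows; apart from `buf_qdq_ptr`, the names and the toy dequantization are hypothetical, not the actual ExecuTorch kernels:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy sketch of the allocation change; not the ExecuTorch kernels themselves.

// Before: every call pays for a heap allocation of rows * cols floats.
void dequant_then_gemm_old(const int8_t* q, const float* scales,
                           std::size_t rows, std::size_t cols) {
  std::vector<float> dequantized(rows * cols); // per-call allocation
  for (std::size_t i = 0; i < rows * cols; ++i) {
    dequantized[i] = static_cast<float>(q[i]) * scales[i % cols];
  }
  // ... hand dequantized.data() to the GEMM ...
}

// After: the caller owns the scratch buffer, so the hot path allocates nothing.
void dequant_then_gemm_new(const int8_t* q, const float* scales,
                           std::size_t rows, std::size_t cols,
                           float* buf_qdq_ptr) {
  for (std::size_t i = 0; i < rows * cols; ++i) {
    buf_qdq_ptr[i] = static_cast<float>(q[i]) * scales[i % cols];
  }
  // ... hand buf_qdq_ptr to the GEMM ...
}
```

The final hunk, below, is the caller side in `cpu_flash_attention`.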
```diff
@@ -773,6 +775,17 @@ void cpu_flash_attention(
   // at::Tensor buf_reduced = at::empty(
   //     {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
   //     query.options());
+  int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
+  // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
```
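Only the first two added lines of this last hunk are visible here. Going by the review comments below, which reference a `kAlignment` constant of 32 elements and an aligned per-thread size, the caller-side setup presumably allocates one aligned scratch slice per thread. The following is a sketch under those assumptions (the struct and function names are made up, and it uses the corrected rounding expression suggested further down rather than the expression under review):

```cpp
#include <cstdint>
#include <vector>

struct QdqScratch {
  std::vector<float> buf;  // one contiguous allocation shared by all threads
  int64_t size_per_thread; // aligned element count per thread
};

// Hypothetical caller-side setup, roughly what cpu_flash_attention might do.
QdqScratch make_qdq_scratch(int64_t qSplitSize, int64_t kvSplitSize,
                            int64_t headSize, int64_t num_thread) {
  constexpr int64_t kAlignment = 32; // elements (128 bytes for float)
  int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
  // Round up so each thread's slice starts on an aligned boundary.
  int64_t size_per_thread_qdq =
      (size_per_thread_qdq_vec + kAlignment - 1) & -kAlignment;
  return {std::vector<float>(size_per_thread_qdq * num_thread),
          size_per_thread_qdq};
}

// Each worker thread t would then pass
//   scratch.buf.data() + t * scratch.size_per_thread
// down to _qk_at_v_gemm as buf_qdq_ptr.
```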
Suggested change:

```diff
-  // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
+  // Lets align size_per_thread_qdq_vec to 32 elements (128 bytes for float), for coalesced cache reads,
```
Copilot AI commented on Nov 17, 2025:
The alignment calculation is incorrect. The formula `(x + kAlignment - 1) & (-(kAlignment - 1))` uses the wrong mask.

For aligning to a power-of-2 boundary, the correct formula is `(size_per_thread_qdq_vec + kAlignment - 1) & (-kAlignment)`, or equivalently `(size_per_thread_qdq_vec + kAlignment - 1) & ~(kAlignment - 1)`.

The current code uses `-(kAlignment - 1)`, which equals -31 = 0xFFFFFFE1, but the correct mask should be -32 = 0xFFFFFFE0 to properly zero out the bottom 5 bits.
Suggested change:

```diff
-      (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
+      (size_per_thread_qdq_vec + kAlignment - 1) & -kAlignment;
```
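For anyone wanting to verify the claim, here is a standalone snippet that evaluates both forms (kAlignment = 32 as in the thread; 70 is just an arbitrary unaligned size):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t kAlignment = 32;
  int64_t x = 70; // arbitrary size that is not a multiple of 32
  int64_t current = (x + kAlignment - 1) & (-(kAlignment - 1)); // mask ...FFE1
  int64_t fixed   = (x + kAlignment - 1) & -kAlignment;         // mask ...FFE0
  // Prints "current=97 fixed=96": the existing mask leaves bit 0 set,
  // so the result is not a multiple of 32, while the fixed form is.
  std::printf("current=%lld fixed=%lld\n",
              static_cast<long long>(current), static_cast<long long>(fixed));
  return 0;
}
```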
@copilot are you sure? Please double check again
The buffer size calculation appears to be larger than necessary. The dequantize operation needs `kvBlockSize * headSize` elements (at most `kvSplitSize * headSize`), but this allocates `qSplitSize * kvSplitSize * headSize`. The extra `qSplitSize` factor seems unnecessary and wastes memory per thread.

Consider changing to:

`int64_t size_per_thread_qdq_vec = kvSplitSize * headSize;`
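To put a rough number on the potential saving, here is a small back-of-the-envelope snippet; the split sizes and thread count below are illustrative assumptions, not values taken from the PR:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative values only; the real qSplitSize / kvSplitSize / headSize /
  // thread count depend on the kernel configuration and hardware.
  constexpr int64_t qSplitSize = 32, kvSplitSize = 512, headSize = 128;
  constexpr int64_t num_thread = 16;

  const int64_t current = qSplitSize * kvSplitSize * headSize; // as in the PR
  const int64_t proposed = kvSplitSize * headSize;             // reviewer's suggestion

  std::printf("per-thread floats: current=%lld proposed=%lld\n",
              static_cast<long long>(current), static_cast<long long>(proposed));
  std::printf("total scratch:     current=%.1f MiB proposed=%.3f MiB\n",
              current * num_thread * 4.0 / (1024 * 1024),
              proposed * num_thread * 4.0 / (1024 * 1024));
  return 0;
}
```

With these example numbers the per-thread scratch drops from 8 MiB to 256 KiB of floats, which is why the reviewer flags the extra `qSplitSize` factor.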