-
Notifications
You must be signed in to change notification settings - Fork 749
Reduce allocation overhead in quantized sdpa #15610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ae61ab4
99902b8
72292a7
93e17be
602a3a7
9a72754
d389555
223495c
4c1faee
4464d3e
e571e70
4890054
0496db7
ec33b54
d2ce926
dacb047
fd8e7e8
8b575ab
713df2d
1d03b6c
05f7021
4e529d1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -213,13 +213,13 @@ void dequant_and_gemm( | |||||
| const int64_t v_stride_n, | ||||||
| float* o_data, | ||||||
| const int64_t o_stride_m, | ||||||
| const float beta) { | ||||||
| std::vector<float> dequantized_v_data(v_data.m * v_data.n); | ||||||
| const float beta, | ||||||
| float* buf_qdq_ptr) { | ||||||
| dequantize_per_channel_optimized( | ||||||
| static_cast<const int8_t*>(v_data.data), | ||||||
| static_cast<const float*>(v_data.scales), | ||||||
| static_cast<const int8_t*>(v_data.zero_points), | ||||||
| dequantized_v_data.data(), | ||||||
| buf_qdq_ptr, | ||||||
| -128, | ||||||
| 127, | ||||||
| 1, | ||||||
|
|
@@ -237,7 +237,7 @@ void dequant_and_gemm( | |||||
| m, | ||||||
| k, | ||||||
| static_cast<float>(1), | ||||||
| dequantized_v_data.data(), | ||||||
| buf_qdq_ptr, | ||||||
| v_data.n, | ||||||
| qk_data, | ||||||
| qk_stride_m, | ||||||
|
|
@@ -257,7 +257,8 @@ void _qk_at_v_gemm( | |||||
| const int64_t v_stride_n, | ||||||
| accum_t* o_data, | ||||||
| const int64_t o_stride_m, | ||||||
| const accum_t beta) { | ||||||
| const accum_t beta, | ||||||
| accum_t* buf_qdq_ptr) { | ||||||
| if (v_data.dtype == ScalarType::Char) { | ||||||
| if constexpr (std::is_same<accum_t, float>::value) { | ||||||
| if (m > 4) { | ||||||
|
|
@@ -273,7 +274,8 @@ void _qk_at_v_gemm( | |||||
| v_stride_n, | ||||||
| o_data, | ||||||
| o_stride_m, | ||||||
| beta); | ||||||
| beta, | ||||||
| buf_qdq_ptr); | ||||||
| } else { | ||||||
| // For smaller batch sizes, use quantized gemm | ||||||
| int a_stride_m_tmp, b_stride_n_tmp; | ||||||
|
|
@@ -773,6 +775,17 @@ void cpu_flash_attention( | |||||
| // at::Tensor buf_reduced = at::empty( | ||||||
| // {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0}, | ||||||
| // query.options()); | ||||||
| int64_t size_per_thread_qdq_vec = kvSplitSize * headSize; | ||||||
| // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads, | ||||||
| // by padding with right number of per thread elements | ||||||
| constexpr int64_t kAlignment = 64; | ||||||
| size_per_thread_qdq_vec = | ||||||
| (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1)); | ||||||
|
||||||
| (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1)); | |
| (size_per_thread_qdq_vec + kAlignment - 1) & -kAlignment; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot are you sure? Please double check again
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment says "align to 64 bytes" but
`kAlignment = 32` aligns to 32 elements. Since `size_per_thread_qdq_vec` is an element count (not a byte count), and assuming `accum_t` is `float` (4 bytes), this aligns to 128 bytes (32 * 4), not 64 bytes. Either:
Set `kAlignment` to 16 if 64-byte alignment is desired, or