Merged
28 commits
351a400  [Executorch] Use temp allocator for allocating scratch memory (kimishpatel, Nov 11, 2025)
1d96c89  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 14, 2025)
0d121ef  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 20, 2025)
e1c0756  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 20, 2025)
f001497  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 20, 2025)
5fa655c  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 21, 2025)
c1e599d  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 22, 2025)
8d194a5  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 23, 2025)
739cf13  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 23, 2025)
01cefc3  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 24, 2025)
ac7cc0c  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 24, 2025)
6f5e330  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 24, 2025)
ec321c3  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 24, 2025)
0513d96  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 25, 2025)
e0c90b0  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Nov 25, 2025)
e643419  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Dec 4, 2025)
7d8c5fb  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Dec 4, 2025)
d26a21c  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Dec 4, 2025)
cba93f9  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Dec 4, 2025)
18950f6  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Dec 4, 2025)
38c21b8  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Dec 4, 2025)
47fcfc7  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Dec 4, 2025)
8a708e3  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Dec 4, 2025)
f27cb26  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Dec 5, 2025)
60498e4  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Dec 5, 2025)
cc28fd4  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Dec 5, 2025)
2c8757e  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Dec 5, 2025)
af048bd  Update on "[Executorch] Use temp allocator for allocating scratch mem… (kimishpatel, Dec 5, 2025)
extension/llm/custom_ops/op_sdpa.cpp (6 additions, 0 deletions)

@@ -273,6 +273,7 @@ Tensor& flash_attention_kernel_out(
   // we might consider another appraoch
   if (seq_len >= 768) {
     sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
+        ctx,
         output,
         query,
         key,
@@ -289,6 +290,7 @@ Tensor& flash_attention_kernel_out(
         nullopt);
   } else if (seq_len >= 192) {
     sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
+        ctx,
         output,
         query,
         key,
@@ -305,6 +307,7 @@ Tensor& flash_attention_kernel_out(
         nullopt);
   } else {
     sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
+        ctx,
         output,
         query,
         key,
@@ -418,6 +421,7 @@ Tensor& custom_sdpa_out_impl(
   // we might consider another appraoch
   if (seq_len >= 768) {
     sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
+        ctx,
         output,
         q,
         k,
@@ -437,6 +441,7 @@ Tensor& custom_sdpa_out_impl(
         num_keys_for_causal_attention);
   } else if (seq_len >= 192) {
     sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
+        ctx,
         output,
         q,
         k,
@@ -456,6 +461,7 @@ Tensor& custom_sdpa_out_impl(
         num_keys_for_causal_attention);
   } else {
     sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
+        ctx,
         output,
         q,
         k,
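
The op_sdpa.cpp side of the change is pure plumbing: each of the six cpu_flash_attention call sites now forwards the kernel's ctx, so the implementation can reach the runtime's temp allocator. Below is a minimal compilable sketch of that forwarding pattern; RuntimeContext here is a hypothetical stand-in for the ExecuTorch kernel runtime context, and the tensor arguments are omitted.

// Sketch only: stand-in types, not the real ExecuTorch declarations.
struct RuntimeContext {};

namespace sdpa::impl {
template <typename scalar_t, int q_split_size, int kv_split_size>
void cpu_flash_attention(RuntimeContext& ctx /*, output, query, key, ... */) {
  // The body can now ask ctx for temp scratch memory.
  (void)ctx;
}
} // namespace sdpa::impl

void flash_attention_kernel_out(RuntimeContext& ctx, long seq_len) {
  // Same dispatch shape as the diff: split sizes chosen by sequence length,
  // with ctx threaded through as the new first argument at every call site.
  if (seq_len >= 768) {
    sdpa::impl::cpu_flash_attention<float, 256, 512>(ctx);
  } else if (seq_len >= 192) {
    sdpa::impl::cpu_flash_attention<float, 64, 512>(ctx);
  } else {
    sdpa::impl::cpu_flash_attention<float, 32, 512>(ctx);
  }
}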
extension/llm/custom_ops/op_sdpa_impl.h (26 additions, 13 deletions)

@@ -35,6 +35,7 @@ enum class SeqDim { ONE = 1, TWO };
 
 namespace sdpa::impl {
 
+static std::vector<char> scratch_for_quant_dequant_vec;
Review comment from Copilot AI (Nov 17, 2025):
This static vector scratch_for_quant_dequant_vec is declared but never used in the code. It appears to be a leftover from the refactoring where the local vector was replaced with the temp allocator approach. This should be removed.
Suggested change (delete the line):
-static std::vector<char> scratch_for_quant_dequant_vec;
 struct MaybeQuantizedMatrixData {
   const void* data{nullptr};
   const int8_t* zero_points{nullptr};
@@ -543,6 +544,7 @@ TODO: Just handle conversion of bool mask to float
 */
 template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
 void cpu_flash_attention(
+    RuntimeContext& ctx,
     Tensor& output,
     const Tensor& query,
     const Tensor& key,
@@ -766,26 +768,37 @@ void cpu_flash_attention(
   int64_t size_of_intermediate_precision = sizeof(accum_t);
   int64_t size_bytes = size_per_thread * num_thread * query.element_size() *
       size_of_intermediate_precision;
-  std::vector<char> buf_vec(size_bytes);
-  void* buf = reinterpret_cast<void*>(buf_vec.data());
   // Need to double check the following
   size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
-  std::vector<char> buf_reduced_vec(size_bytes);
-  void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
-  // at::Tensor buf_reduced = at::empty(
-  //     {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
-  //     query.options());
+  Result<void*> buff_res = ctx.allocate_temp(size_bytes);
+  std::unique_ptr<char[]> allocated_buf;
+  void* buf;
+  if (!buff_res.ok()) {
+    allocated_buf = std::make_unique<char[]>(size_bytes);
+    buf = reinterpret_cast<void*>(allocated_buf.get());
+  } else {
+    buf = buff_res.get();
+  }
+  void* buf_reduced = nullptr;
   int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
   // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
   // by padding with right number of per thread elements
   constexpr int64_t kAlignment = 32;
-  size_per_thread_qdq_vec =
-      (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
-  int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof(accum_t);
+  int64_t size_per_thread_qdq_bytes =
+      size_per_thread_qdq_vec * size_of_intermediate_precision;
   int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
-  std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
-  accum_t* scratch_for_quant_dequant =
-      reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());
+  std::unique_ptr<char[]> allocated_buf_for_qdq;
+  Result<void*> scratch_for_quant_dequant_res =
+      ctx.allocate_temp(size_qdq_bytes);
+  accum_t* scratch_for_quant_dequant;
+  if (!scratch_for_quant_dequant_res.ok()) {
+    allocated_buf_for_qdq = std::make_unique<char[]>(size_qdq_bytes);
+    scratch_for_quant_dequant =
+        reinterpret_cast<accum_t*>(allocated_buf_for_qdq.get());
+  } else {
+    scratch_for_quant_dequant =
+        reinterpret_cast<accum_t*>(scratch_for_quant_dequant_res.get());
+  }
 
   // Data ptrs
   const scalar_t* q_data = query.const_data_ptr<scalar_t>();
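
The substance of the PR is in the hunk above: the function-local std::vector scratch buffers are replaced by a try-the-temp-allocator-first scheme that falls back to a heap allocation when ctx.allocate_temp fails (for example, on a runner configured without a temp allocator). The following self-contained sketch shows that scheme in isolation; Result, RuntimeContext, and allocate_temp are simplified stand-ins for the ExecuTorch runtime types, not the real API surface.

#include <cstddef>
#include <memory>

// Simplified stand-ins for the runtime types used in the diff.
template <typename T>
struct Result {
  T value{};
  bool has_value{false};
  bool ok() const { return has_value; }
  T get() const { return value; }
};

struct RuntimeContext {
  // Always fails here so the sketch exercises the fallback path; the real
  // allocate_temp returns memory owned by the runtime's temp allocator.
  Result<void*> allocate_temp(size_t /*nbytes*/) { return {}; }
};

// The pattern applied to both `buf` and `scratch_for_quant_dequant`:
// prefer runtime-managed temp memory; otherwise allocate on the heap and
// keep it alive via `fallback` for the duration of the kernel.
void* allocate_scratch(
    RuntimeContext& ctx,
    size_t nbytes,
    std::unique_ptr<char[]>& fallback) {
  Result<void*> res = ctx.allocate_temp(nbytes);
  if (res.ok()) {
    return res.get();  // reclaimed when the temp allocator is reset
  }
  fallback = std::make_unique<char[]>(nbytes);
  return fallback.get();
}

int main() {
  RuntimeContext ctx;
  std::unique_ptr<char[]> owner;
  void* scratch = allocate_scratch(ctx, 4096, owner);
  (void)scratch;
  return 0;
}

One design point visible in the diff: the fallback unique_ptr is declared in the kernel body so the heap buffer outlives every use of the raw pointer, while in the temp-allocator branch the lifetime is managed by the runtime rather than the kernel.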