Skip to content

Commit d5bcfed

Browse files
committed
[Executorch] Use temp allocator for allocating scratch memory
Pull Request resolved: #15728 This allows us to leverage temp memory allocator and if that allocator is caching allocator it reduces the allocation overhead. ghstack-source-id: 327191611 @exported-using-ghexport Differential Revision: [D85532076](https://our.internmc.facebook.com/intern/diff/D85532076/)
1 parent 4864f7d commit d5bcfed

File tree

2 files changed

+32
-18
lines changed

2 files changed

+32
-18
lines changed

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ Tensor& flash_attention_kernel_out(
273273
// we might consider another appraoch
274274
if (seq_len >= 768) {
275275
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
276+
ctx,
276277
output,
277278
query,
278279
key,
@@ -289,6 +290,7 @@ Tensor& flash_attention_kernel_out(
289290
nullopt);
290291
} else if (seq_len >= 192) {
291292
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
293+
ctx,
292294
output,
293295
query,
294296
key,
@@ -305,6 +307,7 @@ Tensor& flash_attention_kernel_out(
305307
nullopt);
306308
} else {
307309
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
310+
ctx,
308311
output,
309312
query,
310313
key,
@@ -418,6 +421,7 @@ Tensor& custom_sdpa_out_impl(
418421
// we might consider another appraoch
419422
if (seq_len >= 768) {
420423
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
424+
ctx,
421425
output,
422426
q,
423427
k,
@@ -437,6 +441,7 @@ Tensor& custom_sdpa_out_impl(
437441
num_keys_for_causal_attention);
438442
} else if (seq_len >= 192) {
439443
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
444+
ctx,
440445
output,
441446
q,
442447
k,
@@ -456,6 +461,7 @@ Tensor& custom_sdpa_out_impl(
456461
num_keys_for_causal_attention);
457462
} else {
458463
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
464+
ctx,
459465
output,
460466
q,
461467
k,

extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ TODO: Just handle conversion of bool mask to float
543543
*/
544544
template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
545545
void cpu_flash_attention(
546+
RuntimeContext& ctx,
546547
Tensor& output,
547548
const Tensor& query,
548549
const Tensor& key,
@@ -763,29 +764,34 @@ void cpu_flash_attention(
763764

764765
// Since all intermediate compute is accum_t, we need to
765766
// allocate a buffer accordingly.
766-
int64_t size_of_intermediate_precision = sizeof(accum_t);
767-
int64_t size_bytes = size_per_thread * num_thread * query.element_size() *
768-
size_of_intermediate_precision;
769-
std::vector<char> buf_vec(size_bytes);
770-
void* buf = reinterpret_cast<void*>(buf_vec.data());
771-
// Need to double check the following
772-
size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
773-
std::vector<char> buf_reduced_vec(size_bytes);
774-
void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
775-
// at::Tensor buf_reduced = at::empty(
776-
// {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
777-
// query.options());
767+
int64_t size_bytes = size_per_thread * num_thread * sizeof(accum_t);
768+
std::unique_ptr<char[]> allocated_buf;
769+
void* buf;
770+
Result<void*> scratch = ctx.allocate_temp(size_bytes, 64);
771+
if (!scratch.ok()) {
772+
allocated_buf = std::make_unique<char[]>(size_bytes);
773+
buf = allocated_buf.get();
774+
} else {
775+
buf = scratch.get();
776+
}
777+
void* buf_reduced = nullptr;
778778
int64_t size_per_thread_qdq_vec = kvSplitSize * headSize;
779779
// Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
780780
// by padding with right number of per thread elements
781-
constexpr int64_t kAlignment = 64;
782-
size_per_thread_qdq_vec =
783-
(size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
784781
int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof(accum_t);
785782
int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
786-
std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
787-
accum_t* scratch_for_quant_dequant =
788-
reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());
783+
std::unique_ptr<char[]> allocated_buf_for_qdq;
784+
accum_t* scratch_for_quant_dequant;
785+
Result<void*> scratch_for_quant_dequant_res =
786+
ctx.allocate_temp(size_qdq_bytes, 64);
787+
if (!scratch_for_quant_dequant_res.ok()) {
788+
allocated_buf_for_qdq = std::make_unique<char[]>(size_qdq_bytes);
789+
scratch_for_quant_dequant =
790+
reinterpret_cast<accum_t*>(allocated_buf_for_qdq.get());
791+
} else {
792+
scratch_for_quant_dequant =
793+
reinterpret_cast<accum_t*>(scratch_for_quant_dequant_res.get());
794+
}
789795

790796
// Data ptrs
791797
const scalar_t* q_data = query.const_data_ptr<scalar_t>();
@@ -819,6 +825,7 @@ void cpu_flash_attention(
819825
// Initialize max and sum
820826
fill_stub(
821827
qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
828+
fill_stub(qk_sum_data, static_cast<accum_t>(0), qBlockSize);
822829
// Original flash sdpa wasnt really meant to be used
823830
// for decode the way we are using via start_pos here.
824831
// Thus when num_keys is 1 during decode phase, we
@@ -850,6 +857,7 @@ void cpu_flash_attention(
850857
is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize;
851858
int64_t m_start_pos = m + start_pos;
852859
auto j_kv = j / num_reps;
860+
fill_stub(dst_data, static_cast<accum_t>(0), qSplitSize * headSize);
853861
for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
854862
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
855863
// Calculate scale * q @ k.T

0 commit comments

Comments
 (0)