Skip to content

Commit 436372a

Browse files
committed
[Executorch] Use temp allocator for allocating scratch memory
Pull Request resolved: #15728 This allows us to leverage the temp memory allocator, and if that allocator is a caching allocator it reduces the allocation overhead. ghstack-source-id: 325655867 @exported-using-ghexport Differential Revision: [D85532076](https://our.internmc.facebook.com/intern/diff/D85532076/)
1 parent 5664e1a commit 436372a

File tree

2 files changed

+31
-19
lines changed

2 files changed

+31
-19
lines changed

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ Tensor& flash_attention_kernel_out(
273273
// we might consider another appraoch
274274
if (seq_len >= 768) {
275275
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
276+
ctx,
276277
output,
277278
query,
278279
key,
@@ -289,6 +290,7 @@ Tensor& flash_attention_kernel_out(
289290
nullopt);
290291
} else if (seq_len >= 192) {
291292
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
293+
ctx,
292294
output,
293295
query,
294296
key,
@@ -305,6 +307,7 @@ Tensor& flash_attention_kernel_out(
305307
nullopt);
306308
} else {
307309
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
310+
ctx,
308311
output,
309312
query,
310313
key,
@@ -418,6 +421,7 @@ Tensor& custom_sdpa_out_impl(
418421
// we might consider another appraoch
419422
if (seq_len >= 768) {
420423
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
424+
ctx,
421425
output,
422426
q,
423427
k,
@@ -437,6 +441,7 @@ Tensor& custom_sdpa_out_impl(
437441
num_keys_for_causal_attention);
438442
} else if (seq_len >= 192) {
439443
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
444+
ctx,
440445
output,
441446
q,
442447
k,
@@ -456,6 +461,7 @@ Tensor& custom_sdpa_out_impl(
456461
num_keys_for_causal_attention);
457462
} else {
458463
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
464+
ctx,
459465
output,
460466
q,
461467
k,

extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ TODO: Just handle conversion of bool mask to float
543543
*/
544544
template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
545545
void cpu_flash_attention(
546+
RuntimeContext& ctx,
546547
Tensor& output,
547548
const Tensor& query,
548549
const Tensor& key,
@@ -763,29 +764,34 @@ void cpu_flash_attention(
763764

764765
// Since all intermediate compute is accum_t, we need to
765766
// allocate a buffer accordingly.
766-
int64_t size_of_intermediate_precision = sizeof(accum_t);
767-
int64_t size_bytes = size_per_thread * num_thread * query.element_size() *
768-
size_of_intermediate_precision;
769-
std::vector<char> buf_vec(size_bytes);
770-
void* buf = reinterpret_cast<void*>(buf_vec.data());
771-
// Need to double check the following
772-
size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
773-
std::vector<char> buf_reduced_vec(size_bytes);
774-
void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
775-
// at::Tensor buf_reduced = at::empty(
776-
// {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
777-
// query.options());
778-
int64_t size_per_thread_qdq_vec = kvSplitSize * headSize;
767+
int64_t size_bytes = size_per_thread * num_thread * sizeof(accum_t);
768+
std::unique_ptr<char[]> allocated_buf;
769+
void* buf;
770+
Result<void*> scratch = ctx.allocate_temp(size_bytes, 64);
771+
if (!scratch.ok()) {
772+
allocated_buf = std::make_unique<char[]>(size_bytes);
773+
buf = allocated_buf.get();
774+
} else {
775+
buf = scratch.get();
776+
}
777+
void* buf_reduced = nullptr;
778+
int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
779779
// Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
780780
// by padding with right number of per thread elements
781-
constexpr int64_t kAlignment = 64;
782-
size_per_thread_qdq_vec =
783-
(size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
784781
int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof(accum_t);
785782
int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
786-
std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
787-
accum_t* scratch_for_quant_dequant =
788-
reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());
783+
std::unique_ptr<char[]> allocated_buf_for_qdq;
784+
accum_t* scratch_for_quant_dequant;
785+
Result<void*> scratch_for_quant_dequant_res =
786+
ctx.allocate_temp(size_qdq_bytes, 64);
787+
if (!scratch_for_quant_dequant_res.ok()) {
788+
allocated_buf_for_qdq = std::make_unique<char[]>(size_qdq_bytes);
789+
scratch_for_quant_dequant =
790+
reinterpret_cast<accum_t*>(allocated_buf_for_qdq.get());
791+
} else {
792+
scratch_for_quant_dequant =
793+
reinterpret_cast<accum_t*>(scratch_for_quant_dequant_res.get());
794+
}
789795

790796
// Data ptrs
791797
const scalar_t* q_data = query.const_data_ptr<scalar_t>();

0 commit comments

Comments
 (0)