Skip to content

Commit 71d2419

Browse files
committed
[Executorch] Use temp allocator for allocating scratch memory
Pull Request resolved: #15728 This allows us to leverage the temp memory allocator, and if that allocator is a caching allocator it reduces the allocation overhead. ghstack-source-id: 325471916 @exported-using-ghexport Differential Revision: [D85532076](https://our.internmc.facebook.com/intern/diff/D85532076/)
1 parent fc3079b commit 71d2419

File tree

2 files changed

+31
-15
lines changed

2 files changed

+31
-15
lines changed

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ Tensor& flash_attention_kernel_out(
273273
// we might consider another appraoch
274274
if (seq_len >= 768) {
275275
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
276+
ctx,
276277
output,
277278
query,
278279
key,
@@ -289,6 +290,7 @@ Tensor& flash_attention_kernel_out(
289290
nullopt);
290291
} else if (seq_len >= 192) {
291292
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
293+
ctx,
292294
output,
293295
query,
294296
key,
@@ -305,6 +307,7 @@ Tensor& flash_attention_kernel_out(
305307
nullopt);
306308
} else {
307309
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
310+
ctx,
308311
output,
309312
query,
310313
key,
@@ -418,6 +421,7 @@ Tensor& custom_sdpa_out_impl(
418421
// we might consider another appraoch
419422
if (seq_len >= 768) {
420423
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
424+
ctx,
421425
output,
422426
q,
423427
k,
@@ -437,6 +441,7 @@ Tensor& custom_sdpa_out_impl(
437441
num_keys_for_causal_attention);
438442
} else if (seq_len >= 192) {
439443
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
444+
ctx,
440445
output,
441446
q,
442447
k,
@@ -456,6 +461,7 @@ Tensor& custom_sdpa_out_impl(
456461
num_keys_for_causal_attention);
457462
} else {
458463
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
464+
ctx,
459465
output,
460466
q,
461467
k,

extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ enum class SeqDim { ONE = 1, TWO };
3535

3636
namespace sdpa::impl {
3737

38+
static std::vector<char> scratch_for_quant_dequant_vec;
3839
struct MaybeQuantizedMatrixData {
3940
const void* data{nullptr};
4041
const int8_t* zero_points{nullptr};
@@ -543,6 +544,7 @@ TODO: Just handle conversion of bool mask to float
543544
*/
544545
template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
545546
void cpu_flash_attention(
547+
RuntimeContext& ctx,
546548
Tensor& output,
547549
const Tensor& query,
548550
const Tensor& key,
@@ -763,18 +765,17 @@ void cpu_flash_attention(
763765

764766
// Since all intermediate compute is accum_t, we need to
765767
// allocate a buffer accordingly.
766-
int64_t size_of_intermediate_precision = sizeof(accum_t);
767-
int64_t size_bytes = size_per_thread * num_thread * query.element_size() *
768-
size_of_intermediate_precision;
769-
std::vector<char> buf_vec(size_bytes);
770-
void* buf = reinterpret_cast<void*>(buf_vec.data());
771-
// Need to double check the following
772-
size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
773-
std::vector<char> buf_reduced_vec(size_bytes);
774-
void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
775-
// at::Tensor buf_reduced = at::empty(
776-
// {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
777-
// query.options());
768+
int64_t size_bytes = size_per_thread * num_thread * query.element_size();
769+
Result<void*> buff_res = ctx.allocate_temp(size_bytes);
770+
std::unique_ptr<char[]> allocated_buf;
771+
void* buf;
772+
if (!buff_res.ok()) {
773+
allocated_buf = std::make_unique<char[]>(size_bytes);
774+
buf = reinterpret_cast<void*>(allocated_buf.get());
775+
} else {
776+
buf = buff_res.get();
777+
}
778+
void* buf_reduced = nullptr;
778779
int64_t size_per_thread_qdq_vec = kvSplitSize * headSize;
779780
// Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
780781
// by padding with right number of per thread elements
@@ -783,9 +784,18 @@ void cpu_flash_attention(
783784
(size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
784785
int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof(accum_t);
785786
int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
786-
std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
787-
accum_t* scratch_for_quant_dequant =
788-
reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());
787+
std::unique_ptr<char[]> allocated_buf_for_qdq;
788+
Result<void*> scratch_for_quant_dequant_res =
789+
ctx.allocate_temp(size_qdq_bytes);
790+
accum_t* scratch_for_quant_dequant;
791+
if (!scratch_for_quant_dequant_res.ok()) {
792+
allocated_buf_for_qdq = std::make_unique<char[]>(size_qdq_bytes);
793+
scratch_for_quant_dequant =
794+
reinterpret_cast<accum_t*>(allocated_buf_for_qdq.get());
795+
} else {
796+
scratch_for_quant_dequant =
797+
reinterpret_cast<accum_t*>(scratch_for_quant_dequant_res.get());
798+
}
789799

790800
// Data ptrs
791801
const scalar_t* q_data = query.const_data_ptr<scalar_t>();

0 commit comments

Comments
 (0)