@@ -35,6 +35,7 @@ enum class SeqDim { ONE = 1, TWO };
 
 namespace sdpa::impl {
 
+static std::vector<char> scratch_for_quant_dequant_vec;
 struct MaybeQuantizedMatrixData {
   const void* data{nullptr};
   const int8_t* zero_points{nullptr};
@@ -543,6 +544,7 @@ TODO: Just handle conversion of bool mask to float
  */
 template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
 void cpu_flash_attention(
+    RuntimeContext& ctx,
     Tensor& output,
     const Tensor& query,
     const Tensor& key,
@@ -763,18 +765,17 @@ void cpu_flash_attention(
 
   // Since all intermediate compute is accum_t, we need to
   // allocate a buffer accordingly.
-  int64_t size_of_intermediate_precision = sizeof(accum_t);
-  int64_t size_bytes = size_per_thread * num_thread * query.element_size() *
-      size_of_intermediate_precision;
-  std::vector<char> buf_vec(size_bytes);
-  void* buf = reinterpret_cast<void*>(buf_vec.data());
-  // Need to double check the following
-  size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
-  std::vector<char> buf_reduced_vec(size_bytes);
-  void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
-  // at::Tensor buf_reduced = at::empty(
-  //     {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
-  //     query.options());
+  int64_t size_bytes = size_per_thread * num_thread * query.element_size();
+  Result<void*> buff_res = ctx.allocate_temp(size_bytes);
+  std::unique_ptr<char[]> allocated_buf;
+  void* buf;
+  if (!buff_res.ok()) {
+    allocated_buf = std::make_unique<char[]>(size_bytes);
+    buf = reinterpret_cast<void*>(allocated_buf.get());
+  } else {
+    buf = buff_res.get();
+  }
+  void* buf_reduced = nullptr;
   int64_t size_per_thread_qdq_vec = kvSplitSize * headSize;
   // Let's align size_per_thread_qdq_vec to 64 bytes, for coalesced cache
   // reads, by padding with the right number of per-thread elements.
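
The allocate-from-`ctx`-or-fall-back-to-heap sequence above reappears almost verbatim for the quant/dequant scratch in the next hunk. A minimal sketch of how the pattern could be factored out, assuming only the `RuntimeContext::allocate_temp()` / `Result<void*>` API already used in this diff (the helper name `temp_or_heap_alloc` is hypothetical, not part of this PR):

static void* temp_or_heap_alloc(
    RuntimeContext& ctx,
    int64_t size_bytes,
    std::unique_ptr<char[]>& heap_fallback) {
  // Prefer the runtime's temp allocator, which may be backed by a
  // caller-provided arena, avoiding a heap allocation per call.
  Result<void*> res = ctx.allocate_temp(size_bytes);
  if (res.ok()) {
    return res.get();
  }
  // No temp allocator available (or allocation failed): fall back to
  // the heap and let the caller own the storage.
  heap_fallback = std::make_unique<char[]>(size_bytes);
  return heap_fallback.get();
}

With such a helper, the two call sites would each reduce to a single line, e.g. `void* buf = temp_or_heap_alloc(ctx, size_bytes, allocated_buf);`.
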
@@ -783,9 +784,18 @@ void cpu_flash_attention(
       (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
   int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof(accum_t);
   int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
-  std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
-  accum_t* scratch_for_quant_dequant =
-      reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());
+  std::unique_ptr<char[]> allocated_buf_for_qdq;
+  Result<void*> scratch_for_quant_dequant_res =
+      ctx.allocate_temp(size_qdq_bytes);
+  accum_t* scratch_for_quant_dequant;
+  if (!scratch_for_quant_dequant_res.ok()) {
+    allocated_buf_for_qdq = std::make_unique<char[]>(size_qdq_bytes);
+    scratch_for_quant_dequant =
+        reinterpret_cast<accum_t*>(allocated_buf_for_qdq.get());
+  } else {
+    scratch_for_quant_dequant =
+        reinterpret_cast<accum_t*>(scratch_for_quant_dequant_res.get());
+  }
 
   // Data ptrs
   const scalar_t* q_data = query.const_data_ptr<scalar_t>();
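
A note on the alignment arithmetic kept as context at the top of this hunk: the conventional idiom for rounding up to a power-of-two multiple masks with `~(kAlignment - 1)` (equivalently `-kAlignment`), not `-(kAlignment - 1)`. A standalone sketch of the idiom, using an illustrative alignment of 16 elements (64 bytes of 4-byte `accum_t`); the real definition of `kAlignment` is outside the hunks shown here:

#include <cassert>
#include <cstdint>

constexpr int64_t kAlignment = 16; // illustrative: 64 bytes / sizeof(float)

// Round n up to the next multiple of kAlignment (which must be a
// power of two for the mask trick to work).
constexpr int64_t round_up(int64_t n) {
  return (n + kAlignment - 1) & ~(kAlignment - 1);
}

int main() {
  assert(round_up(1) == 16);  // pads a single element up to one cache line
  assert(round_up(16) == 16); // already aligned, unchanged
  assert(round_up(17) == 32); // spills into the next 64-byte line
  return 0;
}
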