@@ -543,6 +543,7 @@ TODO: Just handle conversion of bool mask to float
543543 */
544544template <typename scalar_t , int64_t q_split_size, int64_t kv_split_size>
545545void cpu_flash_attention (
546+ RuntimeContext& ctx,
546547 Tensor& output,
547548 const Tensor& query,
548549 const Tensor& key,
@@ -763,29 +764,34 @@ void cpu_flash_attention(
763764
764765 // Since all intermediate compute is accum_t, we need to
765766 // allocate a buffer accordingly.
766- int64_t size_of_intermediate_precision = sizeof (accum_t );
767- int64_t size_bytes = size_per_thread * num_thread * query.element_size () *
768- size_of_intermediate_precision;
769- std::vector<char > buf_vec (size_bytes);
770- void * buf = reinterpret_cast <void *>(buf_vec.data ());
771- // Need to double check the following
772- size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size ();
773- std::vector<char > buf_reduced_vec (size_bytes);
774- void * buf_reduced = reinterpret_cast <void *>(buf_reduced_vec.data ());
775- // at::Tensor buf_reduced = at::empty(
776- // {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
777- // query.options());
778- int64_t size_per_thread_qdq_vec = kvSplitSize * headSize;
767+ int64_t size_bytes = size_per_thread * num_thread * sizeof (accum_t );
768+ std::unique_ptr<char []> allocated_buf;
769+ void * buf;
770+ Result<void *> scratch = ctx.allocate_temp (size_bytes, 64 );
771+ if (!scratch.ok ()) {
772+ allocated_buf = std::make_unique<char []>(size_bytes);
773+ buf = allocated_buf.get ();
774+ } else {
775+ buf = scratch.get ();
776+ }
777+ void * buf_reduced = nullptr ;
778+ int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
779779 // NOTE(review): the explicit 64-byte round-up of size_per_thread_qdq_vec was
780780 // removed here; alignment now appears to rely on allocate_temp(size, 64) — confirm per-thread padding is no longer needed for coalesced cache reads.
781- constexpr int64_t kAlignment = 64 ;
782- size_per_thread_qdq_vec =
783- (size_per_thread_qdq_vec + kAlignment - 1 ) & (-(kAlignment - 1 ));
784781 int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof (accum_t );
785782 int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
786- std::vector<char > scratch_for_quant_dequant_vec (size_qdq_bytes);
787- accum_t * scratch_for_quant_dequant =
788- reinterpret_cast <accum_t *>(scratch_for_quant_dequant_vec.data ());
783+ std::unique_ptr<char []> allocated_buf_for_qdq;
784+ accum_t * scratch_for_quant_dequant;
785+ Result<void *> scratch_for_quant_dequant_res =
786+ ctx.allocate_temp (size_qdq_bytes, 64 );
787+ if (!scratch_for_quant_dequant_res.ok ()) {
788+ allocated_buf_for_qdq = std::make_unique<char []>(size_qdq_bytes);
789+ scratch_for_quant_dequant =
790+ reinterpret_cast <accum_t *>(allocated_buf_for_qdq.get ());
791+ } else {
792+ scratch_for_quant_dequant =
793+ reinterpret_cast <accum_t *>(scratch_for_quant_dequant_res.get ());
794+ }
789795
790796 // Data ptrs
791797 const scalar_t * q_data = query.const_data_ptr <scalar_t >();
0 commit comments