diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index c98fa1729fa..72bddce7b5b 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -273,6 +273,7 @@ Tensor& flash_attention_kernel_out(
     // we might consider another appraoch
     if (seq_len >= 768) {
       sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
+          ctx,
           output,
           query,
           key,
@@ -289,6 +290,7 @@ Tensor& flash_attention_kernel_out(
           nullopt);
     } else if (seq_len >= 192) {
       sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
+          ctx,
           output,
           query,
           key,
@@ -305,6 +307,7 @@ Tensor& flash_attention_kernel_out(
           nullopt);
     } else {
       sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
+          ctx,
           output,
           query,
           key,
@@ -418,6 +421,7 @@ Tensor& custom_sdpa_out_impl(
     // we might consider another appraoch
     if (seq_len >= 768) {
       sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
+          ctx,
           output,
           q,
           k,
@@ -437,6 +441,7 @@ Tensor& custom_sdpa_out_impl(
           num_keys_for_causal_attention);
     } else if (seq_len >= 192) {
       sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
+          ctx,
           output,
           q,
           k,
@@ -456,6 +461,7 @@ Tensor& custom_sdpa_out_impl(
           num_keys_for_causal_attention);
     } else {
       sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
+          ctx,
           output,
           q,
           k,
diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h
index 3fa1d694f02..73c5ccf707f 100644
--- a/extension/llm/custom_ops/op_sdpa_impl.h
+++ b/extension/llm/custom_ops/op_sdpa_impl.h
@@ -543,6 +543,7 @@ TODO: Just handle conversion of bool mask to float
 */
 template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
 void cpu_flash_attention(
+    RuntimeContext& ctx,
     Tensor& output,
     const Tensor& query,
     const Tensor& key,
@@ -763,29 +764,34 @@ void cpu_flash_attention(
   // Since all intermediate compute is accum_t, we need to
   // allocate a buffer accordingly.
-  int64_t size_of_intermediate_precision = sizeof(accum_t);
-  int64_t size_bytes = size_per_thread * num_thread * query.element_size() *
-      size_of_intermediate_precision;
-  std::vector<char> buf_vec(size_bytes);
-  void* buf = reinterpret_cast<void*>(buf_vec.data());
-  // Need to double check the following
-  size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
-  std::vector<char> buf_reduced_vec(size_bytes);
-  void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
-  // at::Tensor buf_reduced = at::empty(
-  //     {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
-  //     query.options());
+  int64_t size_bytes = size_per_thread * num_thread * sizeof(accum_t);
+  std::unique_ptr<char[]> allocated_buf;
+  void* buf;
+  Result<void*> scratch = ctx.allocate_temp(size_bytes, 64);
+  if (!scratch.ok()) {
+    allocated_buf = std::make_unique<char[]>(size_bytes);
+    buf = allocated_buf.get();
+  } else {
+    buf = scratch.get();
+  }
+  void* buf_reduced = nullptr;
   int64_t size_per_thread_qdq_vec = kvSplitSize * headSize;
   // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
   // by padding with right number of per thread elements
-  constexpr int64_t kAlignment = 64;
-  size_per_thread_qdq_vec =
-      (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
   int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof(accum_t);
   int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
-  std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
-  accum_t* scratch_for_quant_dequant =
-      reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());
+  std::unique_ptr<char[]> allocated_buf_for_qdq;
+  accum_t* scratch_for_quant_dequant;
+  Result<void*> scratch_for_quant_dequant_res =
+      ctx.allocate_temp(size_qdq_bytes, 64);
+  if (!scratch_for_quant_dequant_res.ok()) {
+    allocated_buf_for_qdq = std::make_unique<char[]>(size_qdq_bytes);
+    scratch_for_quant_dequant =
+        reinterpret_cast<accum_t*>(allocated_buf_for_qdq.get());
+  } else {
+    scratch_for_quant_dequant =
+        reinterpret_cast<accum_t*>(scratch_for_quant_dequant_res.get());
+  }
 
   // Data ptrs
   const scalar_t* q_data = query.const_data_ptr<scalar_t>();
@@ -819,6 +825,7 @@ void cpu_flash_attention(
         // Initialize max and sum
         fill_stub(
             qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
+        fill_stub(qk_sum_data, static_cast<accum_t>(0), qBlockSize);
         // Original flash sdpa wasnt really meant to be used
         // for decode the way we are using via start_pos here.
         // Thus when num_keys is 1 during decode phase, we
@@ -850,6 +857,7 @@ void cpu_flash_attention(
             is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize;
         int64_t m_start_pos = m + start_pos;
         auto j_kv = j / num_reps;
+        fill_stub(dst_data, static_cast<accum_t>(0), qSplitSize * headSize);
        for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
           int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
           // Calculate scale * q @ k.T
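
Note on the allocation pattern above: both hunks in op_sdpa_impl.h have the same shape, i.e. request scratch from the runtime context via ctx.allocate_temp and fall back to a heap allocation when the temp allocator is unavailable or exhausted. The following is a minimal, self-contained sketch of that pattern, not the actual ExecuTorch code: the Arena and Result types below are illustrative stand-ins that only mimic the shape of KernelRuntimeContext::allocate_temp and Result<void*> so the snippet compiles on its own. The sketch rounds the cursor up with the conventional power-of-two mask ~(alignment - 1); the qdq rounding line the diff removes used (-(kAlignment - 1)) instead, which is not that mask.

// Standalone sketch (stand-in types, not the real ExecuTorch runtime API).
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <memory>

struct Result {
  void* ptr;
  bool ok() const { return ptr != nullptr; }
  void* get() const { return ptr; }
};

// Bump allocator standing in for the runtime's temp allocator. It can run
// out of space, so callers must be prepared to fall back to the heap.
struct Arena {
  char* base;
  size_t size;
  size_t used = 0;
  Result allocate_temp(size_t nbytes, size_t alignment = 16) {
    // Round the cursor up to a multiple of alignment (power of two).
    size_t aligned = (used + alignment - 1) & ~(alignment - 1);
    if (aligned + nbytes > size) {
      return {nullptr};  // exhausted: report failure, like a !ok() Result
    }
    used = aligned + nbytes;
    return {base + aligned};
  }
};

int main() {
  char backing[1024];
  Arena ctx{backing, sizeof(backing)};

  // Same shape as the diff: prefer runtime-owned scratch; keep a unique_ptr
  // alive only when the heap fallback was taken, so the buffer outlives
  // every use of the raw pointer either way.
  const int64_t size_bytes = 4096;  // deliberately larger than the arena
  std::unique_ptr<char[]> allocated_buf;
  void* buf;
  Result scratch = ctx.allocate_temp(size_bytes, 64);
  if (!scratch.ok()) {
    allocated_buf = std::make_unique<char[]>(size_bytes);
    buf = allocated_buf.get();
  } else {
    buf = scratch.get();
  }
  std::printf("scratch came from the %s\n", scratch.ok() ? "arena" : "heap");
  (void)buf;  // a kernel would consume buf here
  return 0;
}

Declaring the unique_ptr before the raw pointer lets a single code path use buf no matter which allocator satisfied the request, which is the same design the diff applies to both the attention scratch and the quant/dequant scratch.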