Skip to content

Commit 3994f58

Browse files
YizhouZ and tye1 authored
[SDP][FWD] fix nan issue when attn_mask is -inf (#4548) (#4660)
* fix nan issue when attn_mask is -inf
* fix clang-format

---------

Co-authored-by: Ye Ting <ting.ye@intel.com>
1 parent 1f09fc2 commit 3994f58

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

csrc/gpu/aten/operators/xetla/kernels/SDP/fmha_forward.hpp

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,8 @@ template <
2828
class fmha_forward_t {
2929
public:
3030
using accum_t = float;
31+
static constexpr accum_t kNegInfinity =
32+
-std::numeric_limits<accum_t>::infinity();
3133

3234
struct arguments_t {
3335
// Input tensors
@@ -243,7 +245,7 @@ class fmha_forward_t {
243245
// nbarrier
244246
nbarrier.init_nbarrier(sg_idy, nbarrier_role::producer_consumer);
245247
// softmax statistics
246-
softmax_m = -std::numeric_limits<accum_t>::infinity();
248+
softmax_m = kNegInfinity;
247249
softmax_l = 0.f;
248250

249251
// mem desc variables
@@ -640,6 +642,9 @@ class fmha_forward_t {
640642
xetla_vector<accum_t, kSgBr> m_new = wg_row_max(matAccSij);
641643
m_new = xetla_max<accum_t, kSgBr>(m_new, ctx.softmax_m);
642644

645+
xetla_mask<kSgBr> mask_inf = m_new == kNegInfinity;
646+
m_new.xetla_merge(0.f, mask_inf);
647+
643648
if constexpr (wg_size_x > 1)
644649
ctx.nbarrier.arrive();
645650

0 commit comments

Comments
 (0)