File tree: 1 file changed, +6 -1 lines changed
csrc/gpu/aten/operators/xetla/kernels/SDP: 1 file changed, +6 -1 lines changed
Columns: original file line number | diff line number | diff line change
@@ -28,6 +28,8 @@ template <
28 28 class fmha_forward_t {
29 29 public:
30 30 using accum_t = float ;
31 + static constexpr accum_t kNegInfinity =
32 + -std::numeric_limits<accum_t >::infinity();
31 33
32 34 struct arguments_t {
33 35 // Input tensors
@@ -243,7 +245,7 @@ class fmha_forward_t {
243 245 // nbarrier
244 246 nbarrier.init_nbarrier (sg_idy, nbarrier_role::producer_consumer);
245 247 // softmax statistics
246 - softmax_m = -std::numeric_limits< accum_t >:: infinity () ;
248 + softmax_m = kNegInfinity ;
247 249 softmax_l = 0 .f ;
248 250
249 251 // mem desc variables
@@ -640,6 +642,9 @@ class fmha_forward_t {
640 642 xetla_vector<accum_t , kSgBr > m_new = wg_row_max (matAccSij);
641 643 m_new = xetla_max<accum_t , kSgBr >(m_new, ctx.softmax_m );
642 644
645 + xetla_mask<kSgBr > mask_inf = m_new == kNegInfinity ;
646 + m_new.xetla_merge (0 .f , mask_inf);
647 +
643 648 if constexpr (wg_size_x > 1 )
644 649 ctx.nbarrier .arrive ();
645 650
You can’t perform that action at this time.
0 commit comments