@@ -339,6 +339,9 @@ std::tuple<Tensor, Tensor> batch_norm_gather_stats_xpu_template(
339339 const auto ngroups = (features + wgroup_size - 1 ) / wgroup_size;
340340
341341 int world_size = mean_.size (0 );
342+ // Copy momentum/epsilon into float locals so the kernel below avoids double-precision issues on ATSM
343+ float momentum_ = momentum;
344+ float epsilon_ = epsilon;
342345
343346 auto cgf = DPCPP_Q_CGF (cgh) {
344347 cgh.parallel_for (
@@ -354,24 +357,24 @@ std::tuple<Tensor, Tensor> batch_norm_gather_stats_xpu_template(
354357 for (int j = 0 ; j < world_size; j++) {
355358 scalar_t count = counts[j];
356359 accscalar_t m = mean[j][tid];
357- accscalar_t v = accscalar_t (1.0 ) / (invstd[j][tid]);
358- v = (v * v - epsilon ) * count;
359- accscalar_t factor = 1.0 / (n + count);
360+ accscalar_t v = accscalar_t (1 .0f ) / (invstd[j][tid]);
361+ v = (v * v - epsilon_ ) * count;
362+ accscalar_t factor = 1 .0f / (n + count);
360363 var_n += v + (avg - m) * (avg - m) * n * count * factor;
361364 avg = n * factor * avg + count * factor * m;
362365 n += count;
363366 }
364367 save_mean[tid] = avg;
365368 save_invstd[tid] = static_cast <accscalar_t >(1 ) /
366- Numerics<accscalar_t >::sqrt (var_n / n + epsilon );
369+ Numerics<accscalar_t >::sqrt (var_n / n + epsilon_ );
367370 if (running_mean != nullptr ) {
368371 running_mean[tid] = static_cast <scalar_t >(
369- (1 - momentum ) * running_mean[tid] + momentum * avg);
372+ (1 - momentum_ ) * running_mean[tid] + momentum_ * avg);
370373 }
371374 accscalar_t unbiasedVar = var_n / (n - 1 );
372375 if (running_var != nullptr ) {
373376 running_var[tid] = static_cast <scalar_t >(
374- (1 - momentum ) * running_var[tid] + momentum * unbiasedVar);
377+ (1 - momentum_ ) * running_var[tid] + momentum_ * unbiasedVar);
375378 }
376379 }
377380 });
0 commit comments