Skip to content

Commit c6c1c1d

Browse files
authored
Fix batch_norm_gather_stats in atsm (#2559)
1 parent 4ad6dd1 commit c6c1c1d

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

csrc/gpu/aten/operators/BatchNorm.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,9 @@ std::tuple<Tensor, Tensor> batch_norm_gather_stats_xpu_template(
339339
const auto ngroups = (features + wgroup_size - 1) / wgroup_size;
340340

341341
int world_size = mean_.size(0);
342+
// Avoid double issues in ATSM
343+
float momentum_ = momentum;
344+
float epsilon_ = epsilon;
342345

343346
auto cgf = DPCPP_Q_CGF(cgh) {
344347
cgh.parallel_for(
@@ -354,24 +357,24 @@ std::tuple<Tensor, Tensor> batch_norm_gather_stats_xpu_template(
354357
for (int j = 0; j < world_size; j++) {
355358
scalar_t count = counts[j];
356359
accscalar_t m = mean[j][tid];
357-
accscalar_t v = accscalar_t(1.0) / (invstd[j][tid]);
358-
v = (v * v - epsilon) * count;
359-
accscalar_t factor = 1.0 / (n + count);
360+
accscalar_t v = accscalar_t(1.0f) / (invstd[j][tid]);
361+
v = (v * v - epsilon_) * count;
362+
accscalar_t factor = 1.0f / (n + count);
360363
var_n += v + (avg - m) * (avg - m) * n * count * factor;
361364
avg = n * factor * avg + count * factor * m;
362365
n += count;
363366
}
364367
save_mean[tid] = avg;
365368
save_invstd[tid] = static_cast<accscalar_t>(1) /
366-
Numerics<accscalar_t>::sqrt(var_n / n + epsilon);
369+
Numerics<accscalar_t>::sqrt(var_n / n + epsilon_);
367370
if (running_mean != nullptr) {
368371
running_mean[tid] = static_cast<scalar_t>(
369-
(1 - momentum) * running_mean[tid] + momentum * avg);
372+
(1 - momentum_) * running_mean[tid] + momentum_ * avg);
370373
}
371374
accscalar_t unbiasedVar = var_n / (n - 1);
372375
if (running_var != nullptr) {
373376
running_var[tid] = static_cast<scalar_t>(
374-
(1 - momentum) * running_var[tid] + momentum * unbiasedVar);
377+
(1 - momentum_) * running_var[tid] + momentum_ * unbiasedVar);
375378
}
376379
}
377380
});

0 commit comments

Comments
 (0)