 #include <ATen/Config.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Parallel.h>
+#include <ATen/native/cpu/mixed_data_type.h>
 #include <ATen/record_function.h>
 #include <c10/util/accumulate.h>
 #include "utils/library.h"
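The new `ATen/native/cpu/mixed_data_type.h` include provides the mixed-dtype helpers (`is_mixed_type`, `check_mixed_data_type`, `param_scalar_type`) used in `native_group_norm` below. As a rough, hypothetical paraphrase of what they decide (not the actual ATen implementation): a call is "mixed" when a reduced-precision input such as bfloat16 carries float affine parameters, and the statistics should then be kept in float.

```cpp
#include <ATen/ATen.h>

// Hypothetical sketch of the mixed-dtype logic; the real helpers live in
// ATen/native/cpu/mixed_data_type.h and may differ in detail.
inline bool is_mixed_type_sketch(
    const at::Tensor& input,
    const at::Tensor& gamma,
    const at::Tensor& beta) {
  const auto param_type = gamma.defined() ? gamma.scalar_type()
      : beta.defined()                    ? beta.scalar_type()
                                          : at::ScalarType::Undefined;
  // "Mixed" means defined parameters whose dtype differs from the input's
  // (typically float parameters with a bfloat16 input).
  return param_type != at::ScalarType::Undefined &&
      param_type != input.scalar_type();
}

inline at::ScalarType param_scalar_type_sketch(
    const at::Tensor& input, bool mixed_type) {
  // Statistics (mean/rstd) stay in float when dtypes are mixed.
  return mixed_type ? at::ScalarType::Float : input.scalar_type();
}
```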
@@ -22,6 +23,40 @@ namespace cpu {
 DEFINE_DISPATCH(GroupNormKernel);
 DEFINE_DISPATCH(GroupNormBackwardKernel);
 
+void check_group_norm_inputs(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& bias,
+    int64_t C,
+    int64_t num_groups) {
+  TORCH_CHECK(
+      num_groups > 0,
+      "Expected num groups to be greater than 0, got ",
+      num_groups);
+  TORCH_CHECK(
+      C % num_groups == 0,
+      "Expected number of channels in input to be divisible by ",
+      "num_groups, but got input of shape ",
+      input.sizes(),
+      " and "
+      "num_groups=",
+      num_groups);
+  TORCH_CHECK(
+      !weight.defined() || (weight.dim() == 1 && weight.numel() == C),
+      "Expected weight to be a vector of size equal to the number of ",
+      "channels in input, but got weight of shape ",
+      weight.sizes(),
+      " and input of shape ",
+      input.sizes());
+  TORCH_CHECK(
+      !bias.defined() || (bias.dim() == 1 && bias.numel() == C),
+      "Expected bias to be a vector of size equal to the number of ",
+      "channels in input, but got bias of shape ",
+      bias.sizes(),
+      " and input of shape ",
+      input.sizes());
+}
+
 std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm(
     const at::Tensor& X,
     const c10::optional<at::Tensor>& gamma_opt /* optional */,
@@ -44,9 +79,17 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm(
   const at::Tensor& beta =
       c10::value_or_else(beta_opt, [] { return at::Tensor(); });
 
+  // repeated check so expanded weights can call native_group_norm directly
+  // but save mean and variance from forward
+  check_group_norm_inputs(X, gamma, beta, C, group);
   auto memory_format = X.device().is_cpu() ? X.suggest_memory_format()
                                            : at::MemoryFormat::Contiguous;
 
+  bool mixed_type = at::native::is_mixed_type(X, gamma, beta);
+  if (mixed_type) {
+    at::native::check_mixed_data_type(X, gamma, beta);
+  }
+
   at::Tensor Y;
   // Add channels last 1d input support
   if (is_channels_last_1d(X)) {
@@ -60,8 +103,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm(
         c10::nullopt /* pin_memory */,
         memory_format);
   }
-  at::Tensor mean = at::empty({N, group}, X.options());
-  at::Tensor rstd = at::empty({N, group}, X.options());
+
+  const auto dtype = at::native::param_scalar_type(X, mixed_type);
+  at::Tensor mean = at::empty({N, group}, X.options().dtype(dtype));
+  at::Tensor rstd = at::empty({N, group}, X.options().dtype(dtype));
   GroupNormKernel(
       X.device().type(), X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd);
   return std::make_tuple(Y, mean, rstd);
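With this hunk, the saved statistics follow the parameter precision instead of the input dtype. A minimal illustration of the intended effect, assuming a build where `at::native_group_norm` reaches this CPU kernel (shapes and values are arbitrary):

```cpp
#include <ATen/ATen.h>
#include <tuple>

int main() {
  // bfloat16 input with float32 affine parameters: the mixed-type path.
  at::Tensor x = at::randn({2, 8, 4, 4}).to(at::kBFloat16);
  at::Tensor gamma = at::ones({8});  // float32
  at::Tensor beta = at::zeros({8});  // float32
  auto result = at::native_group_norm(
      x, gamma, beta, /*N=*/2, /*C=*/8, /*HxW=*/16, /*group=*/4, /*eps=*/1e-5);
  // Y keeps the input dtype; mean/rstd are now allocated as float32
  // (param_scalar_type) instead of following X.options() unconditionally.
  TORCH_INTERNAL_ASSERT(std::get<0>(result).scalar_type() == at::kBFloat16);
  TORCH_INTERNAL_ASSERT(std::get<1>(result).scalar_type() == at::kFloat);
  TORCH_INTERNAL_ASSERT(std::get<2>(result).scalar_type() == at::kFloat);
  return 0;
}
```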
@@ -157,28 +202,7 @@ at::Tensor group_norm(
 
   const int64_t N = input.size(0);
   const int64_t C = input.size(1);
-  TORCH_CHECK(
-      C % num_groups == 0,
-      "Expected number of channels in input to be divisible by ",
-      "num_groups, but got input of shape ",
-      input.sizes(),
-      " and "
-      "num_groups=",
-      num_groups);
-  TORCH_CHECK(
-      !weight.defined() || (weight.dim() == 1 && weight.numel() == C),
-      "Expected weight to be a vector of size equal to the number of ",
-      "channels in input, but got weight of shape ",
-      weight.sizes(),
-      " and input of shape ",
-      input.sizes());
-  TORCH_CHECK(
-      !bias.defined() || (bias.dim() == 1 && bias.numel() == C),
-      "Expected bias to be a vector of size equal to the number of ",
-      "channels in input, but got bias of shape ",
-      weight.sizes(),
-      " and input of shape ",
-      input.sizes());
+  check_group_norm_inputs(input, weight, bias, C, num_groups);
 
   const auto input_shape = input.sizes();
   const int64_t HxW =
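After this hunk, the public `group_norm` entry point and `native_group_norm` share the same validation, so both reject malformed inputs identically. A quick hypothetical driver (not part of the patch) showing the failure mode the shared helper catches:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // 10 channels cannot be split into 3 groups, so the
  // C % num_groups == 0 check in check_group_norm_inputs fires.
  at::Tensor x = at::randn({2, 10, 4, 4});
  try {
    at::group_norm(x, /*num_groups=*/3);
  } catch (const c10::Error& e) {
    std::cout << e.what_without_backtrace() << std::endl;
  }
  return 0;
}
```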
@@ -215,4 +239,4 @@ IPEX_TORCH_LIBRARY_IMPL(aten, CPU, m) {
 }
 
 } // namespace cpu
-} // namespace torch_ipex
+} // namespace torch_ipex