
Commit 7b8ce5b

XiaobingSuper, zhuhaozhe, and chunyuan-w authored
enable onednn batchnorm for training path (#83)
* enable onednn batchnorm for training path
* add some note

Co-authored-by: zhuhaozhe <haozhe.zhu@intel.com>
Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
1 parent 951444b commit 7b8ce5b

File tree

4 files changed: +265 -13 lines changed


torch_ipex/csrc/autocast_kernel.cpp

Lines changed: 12 additions & 10 deletions

@@ -1,6 +1,7 @@
-#include "autocast_mode.h"
 #include "autocast_kernel.hpp"
+#include "autocast_mode.h"
 #include "autocast_verbose.h"
+#include "cpu/BatchNorm.h"
 #include "quantization/AutoCast.hpp"
 
 namespace torch_ipex {
@@ -104,12 +105,12 @@ at::Tensor batch_norm(const at::Tensor& input, const c10::optional<at::Tensor>&
 #if defined(ENABLE_AUTOCAST_VERBOSE)
   verbose::OpNameGuard op_name("batch_norm");
 #endif
-  return at::batch_norm(cpu_cached_cast(at::kFloat, input),
-                        cpu_cached_cast(at::kFloat, weight),
-                        cpu_cached_cast(at::kFloat, bias),
-                        cpu_cached_cast(at::kFloat, running_mean),
-                        cpu_cached_cast(at::kFloat, running_var),
-                        training, momentum, eps, cudnn_enabled);
+  // This is a temporary solution until batch_norm supports mixed precision in
+  // stock PyTorch: the input can be bf16 or fp32, but weight and bias are
+  // always fp32.
+  return torch_ipex::cpu::batch_norm(input, weight, bias, running_mean,
+                                     running_var, training, momentum, eps,
+                                     cudnn_enabled);
 }
 
 at::Tensor linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
@@ -225,9 +226,10 @@ at::Tensor gelu(const at::Tensor& input) {
   return at::gelu(input);
 }
 
-std::tuple<Tensor, Tensor, Tensor> lstm_aten(
-    const Tensor& _input, TensorList hx, TensorList _params, bool has_biases,
-    int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
+std::tuple<at::Tensor, at::Tensor, at::Tensor>
+lstm_aten(const at::Tensor &_input, at::TensorList hx, at::TensorList _params,
+          bool has_biases, int64_t num_layers, double dropout_p, bool train,
+          bool bidirectional, bool batch_first) {
   c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
   auto target_type = get_autocast_dtype();
   // The projection case is not supported; for it, fall through.
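Note: the autocast kernel above simply redirects the regular ATen entry point to torch_ipex::cpu::batch_norm. A minimal libtorch sketch of that entry point, using the same argument order, is shown below; the shapes and hyperparameters are illustrative only, and outside an IPEX autocast context this runs the stock CPU kernel.

// Illustrative only: exercises at::batch_norm with the argument order the
// autocast kernel forwards (input, weight, bias, running stats, flags).
#include <ATen/ATen.h>

int main() {
  at::Tensor input = at::randn({2, 3, 8, 8});  // NCHW, fp32
  at::Tensor weight = at::ones({3});
  at::Tensor bias = at::zeros({3});
  at::Tensor running_mean = at::zeros({3});
  at::Tensor running_var = at::ones({3});
  at::Tensor out = at::batch_norm(input, weight, bias, running_mean,
                                  running_var, /*training=*/true,
                                  /*momentum=*/0.1, /*eps=*/1e-5,
                                  /*cudnn_enabled=*/false);
  return out.defined() ? 0 : 1;
}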

torch_ipex/csrc/autocast_kernel.hpp

Lines changed: 4 additions & 3 deletions

@@ -49,9 +49,10 @@ at::Tensor dropout(const at::Tensor& input, double p, bool train);
 
 at::Tensor gelu(const at::Tensor& input);
 
-std::tuple<Tensor, Tensor, Tensor> lstm_aten(
-    const Tensor& _input, TensorList hx, TensorList _params, bool has_biases,
-    int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first);
+std::tuple<at::Tensor, at::Tensor, at::Tensor>
+lstm_aten(const at::Tensor &_input, at::TensorList hx, at::TensorList _params,
+          bool has_biases, int64_t num_layers, double dropout_p, bool train,
+          bool bidirectional, bool batch_first);
 
 } // autocast
 } // torch_ipex
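The only change here is the fully qualified std::tuple<at::Tensor, at::Tensor, at::Tensor> return type. A hypothetical call site could unpack it with structured bindings; input, hx, and params below are assumed to be prepared elsewhere.

// Hypothetical call-site sketch; argument values are illustrative.
auto [output, hy, cy] = torch_ipex::autocast::lstm_aten(
    input, hx, params, /*has_biases=*/true, /*num_layers=*/1,
    /*dropout_p=*/0.0, /*train=*/true, /*bidirectional=*/false,
    /*batch_first=*/false);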

torch_ipex/csrc/cpu/BatchNorm.cpp

Lines changed: 214 additions & 0 deletions

@@ -0,0 +1,214 @@
#include "BatchNorm.h"
#include "mkldnn/MKLDNNCommon.h"
#include "torch_ipex/csrc/autocast_mode.h"
#include "torch_ipex/csrc/autocast_verbose.h"
#include <torch/extension.h>

namespace torch_ipex {
namespace cpu {

std::tuple<at::Tensor, at::Tensor, at::Tensor>
batch_norm_impl(const at::Tensor &input, const at::Tensor &weight,
                const at::Tensor &bias,
                const c10::optional<at::Tensor> &running_mean_opt,
                const c10::optional<at::Tensor> &running_var_opt, bool train,
                double momentum, double eps) {
  const at::Tensor &running_mean =
      c10::value_or_else(running_mean_opt, [] { return at::Tensor(); });
  const at::Tensor &running_var =
      c10::value_or_else(running_var_opt, [] { return at::Tensor(); });

  ideep::tensor x = itensor_view_from_dense(input);
  ideep::tensor w = itensor_view_from_dense(weight);
  ideep::tensor b = itensor_view_from_dense(bias);
  bool use_running_stat = (running_mean.defined() && running_var.defined());

  bool is_channels_last =
      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
  auto output =
      at::empty(input.sizes(),
                input.options().memory_format(input.suggest_memory_format()));
  ideep::tensor y;
  if (is_channels_last) {
    y = itensor_view_from_dense(output);
  }
  if (train) {
    // TODO: enable 3d batchnorm.
    TORCH_CHECK(
        input.dim() == 4,
        "batch_norm: currently mkldnn training only supports 2d batchnorm");
    auto saved_mean = at::empty(input.size(1), weight.options());
    auto saved_var = at::empty(input.size(1), weight.options());
    ideep::tensor mkldnn_saved_mean = itensor_view_from_dense(saved_mean);
    ideep::tensor mkldnn_saved_var = itensor_view_from_dense(saved_var);
    ideep::batch_normalization_forward_training::compute(
        x, w, b, y, mkldnn_saved_mean, mkldnn_saved_var, momentum, eps);
    if (use_running_stat) {
      auto len = x.get_nelems() / w.get_nelems(); // n*h*w
      ideep::tensor m = itensor_view_from_dense(running_mean);
      ideep::tensor v = itensor_view_from_dense(running_var);
      const std::vector<float> scales_mean{static_cast<float>(1 - momentum),
                                           static_cast<float>(momentum)};
      const std::vector<float> scales_var{
          static_cast<float>(1 - momentum),
          static_cast<float>(momentum * len / (len - 1))};
      ideep::sum::compute(scales_mean, {m, mkldnn_saved_mean}, m);
      ideep::sum::compute(scales_var, {v, mkldnn_saved_var}, v);
    }
    if (is_channels_last) {
      return std::make_tuple(output, saved_mean, saved_var);
    } else {
      return std::make_tuple(
          at::native::mkldnn_to_dense(new_with_itensor_mkldnn(
              std::move(y),
              optTypeMetaToScalarType(input.options().dtype_opt()),
              input.options().device_opt())),
          saved_mean, saved_var);
    }
  } else {
    TORCH_CHECK(input.dim() == 4 || input.dim() == 5,
                "batch_norm: currently mkldnn inference only supports 2d and "
                "3d batchnorm");
    if (use_running_stat) {
      ideep::tensor m = itensor_view_from_dense(running_mean);
      ideep::tensor v = itensor_view_from_dense(running_var);
      ideep::batch_normalization_forward_inference::compute(x, m, v, w, b, y,
                                                            eps);
    } else {
      // TODO: keep running estimates.
      TORCH_CHECK(
          false,
          "mkldnn_batch_norm: mkldnn inference does not keep running estimates.");
    }

    if (is_channels_last) {
      return std::make_tuple(output, at::Tensor(), at::Tensor());
    } else {
      return std::make_tuple(
          at::native::mkldnn_to_dense(new_with_itensor_mkldnn(
              std::move(y),
              optTypeMetaToScalarType(input.options().dtype_opt()),
              input.options().device_opt())),
          at::Tensor(), at::Tensor());
    }
  }
}

std::tuple<at::Tensor, at::Tensor, at::Tensor>
batch_norm_backward_impl(const at::Tensor &grad_output, const at::Tensor &input,
                         const at::Tensor &weight, const at::Tensor &save_mean,
                         const at::Tensor &save_invstd, bool train, double eps,
                         std::array<bool, 3> grad_input_mask) {
  TORCH_CHECK(
      train,
      "mkldnn_batch_norm_backward: currently mkldnn only supports training mode");
  ideep::tensor grady = itensor_view_from_dense(grad_output);
  ideep::tensor x = itensor_view_from_dense(input);
  ideep::tensor w = itensor_view_from_dense(weight);
  ideep::tensor m = itensor_view_from_dense(save_mean);
  ideep::tensor v = itensor_view_from_dense(save_invstd);

  bool is_channels_last =
      grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
  auto grad_input = at::empty(
      grad_output.sizes(),
      grad_output.options().memory_format(grad_output.suggest_memory_format()));
  auto grad_weight = at::empty(grad_output.size(1), weight.options());
  auto grad_bias = at::empty(grad_output.size(1), weight.options());
  ideep::tensor gradx, gradw, gradb;
  if (is_channels_last) {
    gradx = itensor_view_from_dense(grad_input);
  }
  gradw = itensor_view_from_dense(grad_weight);
  gradb = itensor_view_from_dense(grad_bias);
  ideep::batch_normalization_backward::compute(x, m, v, grady, w, gradx, gradw,
                                               gradb, eps);

  if (is_channels_last) {
    return std::make_tuple(grad_input, grad_weight, grad_bias);
  } else {
    return std::make_tuple(
        at::native::mkldnn_to_dense(new_with_itensor_mkldnn(
            std::move(gradx),
            optTypeMetaToScalarType(grad_output.options().dtype_opt()),
            grad_output.options().device_opt())),
        grad_weight, grad_bias);
  }
}

at::Tensor
IPEXBatchNormOp::forward(torch::autograd::AutogradContext *ctx,
                         const at::Tensor &input, const at::Tensor &weight,
                         const at::Tensor &bias,
                         const c10::optional<at::Tensor> &running_mean_opt,
                         const c10::optional<at::Tensor> &running_var_opt,
                         bool train, double momentum, double eps) {
#if defined(IPEX_PROFILE_OP)
  RECORD_FUNCTION("IPEXBatchNormOp::forward", std::vector<c10::IValue>({}));
#endif
  ctx->saved_data["train"] = train;
  ctx->saved_data["eps"] = eps;
  ctx->saved_data["input_requires_grad"] = input.requires_grad();
  ctx->saved_data["weight_requires_grad"] = weight.requires_grad();
  ctx->saved_data["bias_requires_grad"] = bias.requires_grad();
  at::Tensor output, save_mean, save_invstd;
  std::tie(output, save_mean, save_invstd) =
      batch_norm_impl(input, weight, bias, running_mean_opt, running_var_opt,
                      train, momentum, eps);
  ctx->save_for_backward({input, weight, save_mean, save_invstd});
  return output;
}

torch::autograd::variable_list
IPEXBatchNormOp::backward(torch::autograd::AutogradContext *ctx,
                          torch::autograd::variable_list grad_outputs) {
#if defined(IPEX_PROFILE_OP)
  RECORD_FUNCTION("IPEXBatchNormOp::backward", std::vector<c10::IValue>({}));
#endif
  auto train = ctx->saved_data["train"].toBool();
  auto eps = ctx->saved_data["eps"].toDouble();

  std::array<bool, 3> output_mask;
  output_mask[0] = ctx->saved_data["input_requires_grad"].toBool();
  output_mask[1] = ctx->saved_data["weight_requires_grad"].toBool();
  output_mask[2] = ctx->saved_data["bias_requires_grad"].toBool();
  auto saved = ctx->get_saved_variables();
  at::Tensor input = saved[0];
  at::Tensor weight = saved[1];
  at::Tensor save_mean = saved[2];
  at::Tensor save_invstd = saved[3];
  at::Tensor grad_input, grad_weight, grad_bias;
  std::tie(grad_input, grad_weight, grad_bias) =
      batch_norm_backward_impl(grad_outputs[0], input, weight, save_mean,
                               save_invstd, train, eps, output_mask);
  return {grad_input, grad_weight, grad_bias, at::Tensor(),
          at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor()};
}

at::Tensor batch_norm(const at::Tensor &input,
                      const c10::optional<at::Tensor> &weight_opt,
                      const c10::optional<at::Tensor> &bias_opt,
                      const c10::optional<at::Tensor> &running_mean_opt,
                      const c10::optional<at::Tensor> &running_var_opt,
                      bool train, double momentum, double eps,
                      bool cudnn_enabled) {
#if defined(IPEX_PROFILE_OP)
  RECORD_FUNCTION("torch_ipex::batch_norm", std::vector<c10::IValue>({}));
#endif
  if (weight_opt.has_value() && bias_opt.has_value() && train &&
      !torch::jit::tracer::isTracing()) {
    return IPEXBatchNormOp::apply(input, weight_opt.value(), bias_opt.value(),
                                  running_mean_opt, running_var_opt, train,
                                  momentum, eps);
  } else {
    at::Tensor input_ = input;
    if (input.scalar_type() == at::kBFloat16) {
      input_ = input.to(at::kFloat);
    }
    return at::batch_norm(input_, weight_opt, bias_opt, running_mean_opt,
                          running_var_opt, train, momentum, eps, cudnn_enabled);
  }
}

} // namespace cpu
} // namespace torch_ipex
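In the training branch of batch_norm_impl, the two ideep::sum::compute calls fold the batch statistics into the running estimates with scales {1 - momentum, momentum} and, for the variance, an extra n/(n - 1) unbiasing factor where n = N*H*W per channel. The plain-C++ sketch below (no oneDNN, hypothetical function name) spells out the same per-channel update for reference.

// Reference-only sketch: the running-stat update that the scale vectors
// scales_mean and scales_var above express via ideep::sum.
#include <cstddef>
#include <vector>

void update_running_stats(std::vector<float> &running_mean,
                          std::vector<float> &running_var,
                          const std::vector<float> &batch_mean,
                          const std::vector<float> &batch_var, // biased, /n
                          double momentum, std::size_t n /* N*H*W */) {
  const float m = static_cast<float>(momentum);
  const float unbias = static_cast<float>(n) / static_cast<float>(n - 1);
  for (std::size_t c = 0; c < running_mean.size(); ++c) {
    running_mean[c] = (1.0f - m) * running_mean[c] + m * batch_mean[c];
    running_var[c] = (1.0f - m) * running_var[c] + m * unbias * batch_var[c];
  }
}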

torch_ipex/csrc/cpu/BatchNorm.h

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <torch/csrc/autograd/custom_function.h>

#include "ideep/ideep.hpp"

namespace torch_ipex {
namespace cpu {

class IPEXBatchNormOp : public torch::autograd::Function<IPEXBatchNormOp> {
public:
  static at::Tensor forward(torch::autograd::AutogradContext *ctx,
                            const at::Tensor &input, const at::Tensor &weight,
                            const at::Tensor &bias,
                            const c10::optional<at::Tensor> &running_mean_opt,
                            const c10::optional<at::Tensor> &running_var_opt,
                            bool train, double momentum, double eps);

  static torch::autograd::variable_list
  backward(torch::autograd::AutogradContext *ctx,
           torch::autograd::variable_list grad_outputs);
};

at::Tensor batch_norm(const at::Tensor &input,
                      const c10::optional<at::Tensor> &weight_opt,
                      const c10::optional<at::Tensor> &bias_opt,
                      const c10::optional<at::Tensor> &running_mean_opt,
                      const c10::optional<at::Tensor> &running_var_opt,
                      bool train, double momentum, double eps,
                      bool cudnn_enabled);

} // namespace cpu
} // namespace torch_ipex
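A minimal call-site sketch for the public entry point declared above; it assumes an IPEX build where this header is on the include path, and the helper name and tensor values are illustrative only.

#include "torch_ipex/csrc/cpu/BatchNorm.h" // assumed include path

at::Tensor run_bn_training_step(const at::Tensor &input) {
  auto weight = at::ones({input.size(1)});
  auto bias = at::zeros({input.size(1)});
  auto running_mean = at::zeros({input.size(1)});
  auto running_var = at::ones({input.size(1)});
  // Dispatches to IPEXBatchNormOp (oneDNN) when weight/bias are defined,
  // training is on, and the JIT tracer is inactive; otherwise falls back to
  // at::batch_norm.
  return torch_ipex::cpu::batch_norm(input, weight, bias, running_mean,
                                     running_var, /*train=*/true,
                                     /*momentum=*/0.1, /*eps=*/1e-5,
                                     /*cudnn_enabled=*/false);
}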
