Commit 7b2b561

fix bf16 runtime error when a CPU device doesn't meet the OneDNN ISA requirement (#867)
Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
1 parent d9a2680 commit 7b2b561

11 files changed: +117 -64 lines changed

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite_conv.cpp

Lines changed: 8 additions & 6 deletions

@@ -1,4 +1,5 @@
 #include "csrc/aten/cpu/WeightPack.h"
+#include "csrc/cpu/ideep/ideep.hpp"
 #include "csrc/jit/cpu/kernels/OpContext.h"
 #include "csrc/jit/cpu/passes/utils.h"
 #include "graph_rewrite.h"
@@ -106,19 +107,20 @@ void insertPrePackedConvOp(Block* b) {
     IValue input_size_value(input_size_option.value());
     if (n->kind() == aten::conv1d || n->kind() == aten::conv2d ||
         n->kind() == aten::conv3d) {
-      auto weight_size_option = n->inputs()
-                                    .at(1)
-                                    ->type()
-                                    ->cast<TensorType>()
-                                    ->sizes()
-                                    .concrete_sizes();
+      auto weight_tensor_type = n->inputs().at(1)->type()->cast<TensorType>();
+      auto weight_size_option = weight_tensor_type->sizes().concrete_sizes();
       // weight has no shape info, will not do weight prepack.
       if (!(weight_size_option.has_value() &&
             (weight_size_option.value().size() == 3 ||
              weight_size_option.value().size() == 4 ||
              weight_size_option.value().size() == 5))) {
         continue;
       }
+      const auto dtype = weight_tensor_type->scalarType();
+      if (dtype.has_value() && *dtype == at::ScalarType::BFloat16 &&
+          !ideep::has_bf16_type_support()) {
+        continue;
+      }
       bool w_is_channels_last = false;
       if (constant_as<at::Tensor>(n->namedInput("weight")).has_value()) {
         at::Tensor weight_tensor =

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite_conv_transpose.cpp

Lines changed: 8 additions & 7 deletions

@@ -1,3 +1,4 @@
+#include "csrc/cpu/ideep/ideep.hpp"
 #include "graph_rewrite.h"
 #include "graph_rewrite_utils.h"
 #include "utils.h"
@@ -35,19 +36,19 @@ void insertPrePackedConvTransposeOpForATen(Block* b) {
     }
     IValue input_size_value(input_size_option.value());

-    auto weight_size_option = n->inputs()
-                                  .at(1)
-                                  ->type()
-                                  ->cast<TensorType>()
-                                  ->sizes()
-                                  .concrete_sizes();
+    auto weight_tensor_type = n->inputs().at(1)->type()->cast<TensorType>();
+    auto weight_size_option = weight_tensor_type->sizes().concrete_sizes();
     // weight has no shape info, will not do weight prepack.
     if (!(weight_size_option.has_value() &&
           (weight_size_option.value().size() == 4 ||
           weight_size_option.value().size() == 5))) {
       continue;
     }
-
+    const auto dtype = weight_tensor_type->scalarType();
+    if (dtype.has_value() && *dtype == at::ScalarType::BFloat16 &&
+        !ideep::has_bf16_type_support()) {
+      continue;
+    }
     // # padding - output_padding + stride <= 0 unsupported in mkldnn
     auto stride = toIValue(n->input(3))->toIntList();
     auto padding = toIValue(n->input(4))->toIntList();

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite_linear.cpp

Lines changed: 4 additions & 1 deletion

@@ -1,5 +1,7 @@
 #include <ATen/code_template.h>
+#include "csrc/cpu/ideep/ideep.hpp"
 #include "csrc/jit/cpu/passes/utils.h"
+
 #include "graph_rewrite.h"
 #include "graph_rewrite_utils.h"

@@ -98,7 +100,8 @@ void insertPrePackedLinearOp(Block* b, std::unordered_set<Node*>& aten_linear) {
     }
     auto weight_dtype_option = tt->scalarType();
     if (!(weight_dtype_option.has_value() &&
-          (weight_dtype_option.value() == at::ScalarType::BFloat16) ||
+          (weight_dtype_option.value() == at::ScalarType::BFloat16) &&
+          ideep::has_bf16_type_support() ||
           aten_linear.find(n) == aten_linear.end())) {
       continue;
     }
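
All three rewrite passes above now skip weight prepack when the weights are BF16 but ideep::has_bf16_type_support() is false, so JIT-compiled models fall back to the stock ATen kernels instead of failing inside oneDNN at runtime. A minimal sketch of the effect from the Python side (illustrative only; the model and shapes are not from this commit):

import torch
import intel_extension_for_pytorch as ipex

# Before this fix, freezing and running a bf16 conv on a CPU without
# avx512bw/avx512vl/avx512dq could hit a oneDNN runtime error; with the
# guard, the prepack rewrite is skipped and the original aten::conv2d runs.
conv = torch.nn.Conv2d(3, 8, kernel_size=3).eval().to(torch.bfloat16)
x = torch.randn(1, 3, 32, 32).to(torch.bfloat16)
with torch.no_grad():
    traced = torch.jit.freeze(torch.jit.trace(conv, x))
    out = traced(x)
print(out.shape)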

intel_extension_for_pytorch/csrc/python/init_python_bindings.cpp

Lines changed: 6 additions & 2 deletions

@@ -24,8 +24,8 @@
 #include "intel_extension_for_pytorch/csrc/jit/auto_opt_config.h"
 #include "intel_extension_for_pytorch/csrc/utils/env_settings.h"
 #include "intel_extension_for_pytorch/csrc/utils/fpmath_mode.h"
+#include "intel_extension_for_pytorch/csrc/utils/onednn_utils.h"
 #include "intel_extension_for_pytorch/csrc/utils/rw_lock.h"
-#include "intel_extension_for_pytorch/csrc/utils/verbose.hpp"

 #include <c10/core/DeviceType.h>
 #include <torch/csrc/Exceptions.h>
@@ -73,7 +73,11 @@ void InitIpexModuleBindings(py::module m) {
     EnvSettings::get_instance().set_settings_profile_op(b_enable);
   });

-  m.def("mkldnn_set_verbose", &torch_ipex::verbose::_mkldnn_set_verbose);
+  m.def("mkldnn_set_verbose", &torch_ipex::utils::onednn_set_verbose);
+  m.def("onednn_has_bf16_support", []() {
+    return torch_ipex::utils::onednn_has_bf16_type_support();
+  });
+
   // ipex amp autocast
   m.def("get_autocast_dtype", []() {
     at::ScalarType current_dtype = torch_ipex::autocast::get_autocast_dtype();
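
The added onednn_has_bf16_support binding can be probed directly from Python; a minimal sketch (same private _C import the test file below uses):

import intel_extension_for_pytorch._C as core

# True only when the CPU meets oneDNN's bf16 ISA requirements
# (avx512bw, avx512vl and avx512dq, per the assertion added in frontend.py).
print(core.onednn_has_bf16_support())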
intel_extension_for_pytorch/csrc/utils/onednn_utils.cpp

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+#include "onednn_utils.h"
+
+#include "csrc/cpu/ideep/ideep.hpp"
+
+namespace torch_ipex {
+namespace utils {
+
+int onednn_set_verbose(int level) {
+  return ideep::utils::set_verbose(level);
+}
+
+bool onednn_has_bf16_type_support() {
+  return ideep::has_bf16_type_support();
+}
+
+} // namespace utils
+} // namespace torch_ipex
intel_extension_for_pytorch/csrc/utils/onednn_utils.h

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+namespace torch_ipex {
+namespace utils {
+
+int onednn_set_verbose(int level);
+bool onednn_has_bf16_type_support();
+
+} // namespace utils
+} // namespace torch_ipex

intel_extension_for_pytorch/csrc/utils/verbose.cpp

Lines changed: 0 additions & 13 deletions
This file was deleted.

intel_extension_for_pytorch/csrc/utils/verbose.hpp

Lines changed: 0 additions & 7 deletions
This file was deleted.

intel_extension_for_pytorch/frontend.py

Lines changed: 4 additions & 0 deletions

@@ -287,6 +287,10 @@ def optimize(
             optimized_model, optimized_optimizer, params_attr = utils._weight_cast.weight_dtype_convert_with_ipex(
                 optimized_model, optimized_optimizer, params_attr, opt_properties.split_master_weight_for_bf16)
         if opt_properties.weights_prepack:
+            if dtype == torch.bfloat16:
+                assert core.onednn_has_bf16_support(), \
+                    "BF16 weight prepack needs the cpu support avx512bw, avx512vl and avx512dq, " + \
+                    "please set dtype to torch.float or set weights_prepack to False."
             optimized_model, optimized_optimizer, params_attr = utils._weight_prepack.weight_prepack_with_ipex(
                 optimized_model, optimized_optimizer, params_attr, opt_properties.auto_kernel_selection)
         # TODO: model list, optimizer list.
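
With this guard, ipex.optimize now fails fast with a clear message instead of erroring later inside oneDNN. A hedged sketch of the two fallbacks the message suggests (model is illustrative, not from this commit):

import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Conv2d(3, 64, kernel_size=7).eval()
try:
    model = ipex.optimize(model, dtype=torch.bfloat16)
except AssertionError:
    # Either stay in fp32 ...
    # model = ipex.optimize(model, dtype=torch.float)
    # ... or keep bf16 and skip weight prepack, as the message suggests.
    model = ipex.optimize(model, dtype=torch.bfloat16, weights_prepack=False)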

tests/cpu/test_ipex_optimize.py

Lines changed: 16 additions & 0 deletions

@@ -1,5 +1,6 @@
 import torch
 import intel_extension_for_pytorch as ipex
+import intel_extension_for_pytorch._C as core
 from intel_extension_for_pytorch.nn.utils._weight_prepack import _IPEXLinear as _IPEXLinear, _IPEXConv2d as _IPEXConv2d
 from torch.testing._internal.common_utils import TestCase
 from torch.optim import Adadelta, Adagrad, Adam, AdamW, Adamax, ASGD, RMSprop, Rprop, SGD
@@ -126,6 +127,21 @@ def forward(self, x):
                                     "WARNING: Can't convert model's parameters dtype"):
             optimized_model = ipex.optimize(model.eval(), dtype=torch.bfloat16)

+    def test_optimize_bf16_upsupported(self):
+        class Conv(torch.nn.Module):
+            def __init__(self,):
+                super(Conv, self).__init__()
+                self.conv = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        model = Conv()
+        if not core.onednn_has_bf16_support():
+            msg = r"BF16 weight prepack needs the cpu support avx512bw, avx512vl and avx512dq, please set dtype to torch.float or set weights_prepack to False."
+            with self.assertRaisesRegex(AssertionError, msg):
+                optimized_model = ipex.optimize(model.eval(), dtype=torch.bfloat16)
+
     def test_optimize_unsupport_freeze_optimization(self):
         model = ConvBatchNorm().eval()
         x = model.input1
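
Note that the new case only exercises the assertion on machines where core.onednn_has_bf16_support() returns False; on AVX512-capable hosts it is effectively a no-op. Assuming a pytest-style runner works against this TestCase (it usually does), the single test can be selected with:

python -m pytest tests/cpu/test_ipex_optimize.py -k test_optimize_bf16_upsupported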
