Skip to content

Commit 2e3618d

Browse files
authored
add check for onednn linear (#1239) (#1244)
* add check for onednn linear
* disable prepack FP32 linear when dnnl is disabled in training
* fix ut
* minor change
1 parent c0ca047 commit 2e3618d

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

intel_extension_for_pytorch/nn/utils/_weight_prepack.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def __init__(self, dense_module):
261261
torch.nn.ConvTranspose3d: _IPEXConvTranspose3d,
262262
}
263263

264-
def _should_prepack(module):
264+
def _should_prepack(module, is_training):
265265
if type(module) not in IPEX_WEIGHT_PREPACK_MODULE:
266266
return False
267267
# If hook is on `weight` or `bias`, will not prepack.
@@ -277,6 +277,10 @@ def _should_prepack(module):
277277
for _, hook in module._backward_hooks.items():
278278
if hasattr(hook, 'name') and (hook.name == 'weight' or hook.name == 'bias'):
279279
return False
280+
281+
# for training, if auto_kernel_selection(onednn) is off, IPEX won't prepack FP32 linear.
282+
if isinstance(module, torch.nn.Linear) and not _using_dnnl() and is_training and module.weight.dtype is torch.float:
283+
return False
280284
if isinstance(module, torch.nn.ConvTranspose2d):
281285
if module.padding[0] - module.output_padding[0] + module.stride[0] <= 0:
282286
return False
@@ -296,7 +300,7 @@ def _should_prepack(module):
296300

297301
def weight_prepack_with_ipex(module, optimizer, params_attr):
298302
def convert(m, optimizer, params_attr):
299-
if _should_prepack(m) and (m.weight.dtype == torch.float32 or m.weight.dtype == torch.bfloat16 or m.weight.dtype == torch.half):
303+
if _should_prepack(m, optimizer!=None) and (m.weight.dtype == torch.float32 or m.weight.dtype == torch.bfloat16 or m.weight.dtype == torch.half):
300304
weight = m.master_weight if hasattr(m, "master_weight") else m.weight
301305
if weight not in params_attr:
302306
params_attr[weight] = {}

tests/cpu/test_ipex_optimize.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,10 @@ def test_module_conversion(self):
298298
self.assertTrue(isinstance(opt_M.linear, torch.nn.Linear))
299299
self.assertTrue(isinstance(opt_M.conv, torch.nn.Conv2d))
300300
else:
301-
self.assertTrue(isinstance(opt_M.linear, _IPEXLinear))
301+
if not auto_kernel_selection and dtype == torch.float32:
302+
self.assertTrue(isinstance(opt_M.linear, torch.nn.Linear))
303+
else:
304+
self.assertTrue(isinstance(opt_M.linear, _IPEXLinear))
302305
self.assertTrue(isinstance(opt_M.conv, _IPEXConv2d))
303306

304307
def test_record_shape(self):

0 commit comments

Comments (0)