@@ -261,7 +261,7 @@ def __init__(self, dense_module):
     torch.nn.ConvTranspose3d: _IPEXConvTranspose3d,
 }
 
-def _should_prepack(module):
+def _should_prepack(module, is_training):
     if type(module) not in IPEX_WEIGHT_PREPACK_MODULE:
         return False
     # If hook is on `weight` or `bias`, will not prepack.
@@ -277,6 +277,10 @@ def _should_prepack(module):
     for _, hook in module._backward_hooks.items():
         if hasattr(hook, 'name') and (hook.name == 'weight' or hook.name == 'bias'):
             return False
+
+    # for training, if auto_kernel_selection(onednn) is off, IPEX won't prepack FP32 linear.
+    if isinstance(module, torch.nn.Linear) and not _using_dnnl() and is_training and module.weight.dtype is torch.float:
+        return False
     if isinstance(module, torch.nn.ConvTranspose2d):
         if module.padding[0] - module.output_padding[0] + module.stride[0] <= 0:
             return False
@@ -296,7 +300,7 @@ def _should_prepack(module):
 
 def weight_prepack_with_ipex(module, optimizer, params_attr):
     def convert(m, optimizer, params_attr):
-        if _should_prepack(m) and (m.weight.dtype == torch.float32 or m.weight.dtype == torch.bfloat16 or m.weight.dtype == torch.half):
+        if _should_prepack(m, optimizer != None) and (m.weight.dtype == torch.float32 or m.weight.dtype == torch.bfloat16 or m.weight.dtype == torch.half):
             weight = m.master_weight if hasattr(m, "master_weight") else m.weight
             if weight not in params_attr:
                 params_attr[weight] = {}
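
For reference, below is a minimal sketch (not the IPEX code path itself) of the rule this hunk introduces: `_should_prepack` now receives an `is_training` flag (derived from `optimizer != None` at the call site), and an FP32 `torch.nn.Linear` is left unpacked during training unless oneDNN auto kernel selection is enabled. The helper name `should_prepack_sketch` and the `using_dnnl` argument are illustrative stand-ins for `_should_prepack` and the internal `_using_dnnl()` check.

# A minimal sketch of the decision rule added by this change, under the
# assumptions stated above. `using_dnnl` is a hypothetical stand-in for
# IPEX's internal _using_dnnl() helper.
import torch

def should_prepack_sketch(module, is_training, using_dnnl):
    # Skip prepacking for FP32 Linear during training when oneDNN
    # auto kernel selection is off; prepack in all other cases.
    if isinstance(module, torch.nn.Linear) and not using_dnnl and is_training \
            and module.weight.dtype is torch.float:
        return False
    return True

linear = torch.nn.Linear(4, 4)  # FP32 weights by default
print(should_prepack_sketch(linear, is_training=True, using_dnnl=False))   # False: weight stays dense
print(should_prepack_sketch(linear, is_training=False, using_dnnl=False))  # True: inference still prepacks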