Skip to content

Commit a557ebd

Browse files
enable convolution fusion path for n-D weight case (#120)
* enable convolution fusion path for n-D weight case * fix the comment's spelling typo * fix code format
1 parent c2794b1 commit a557ebd

File tree

11 files changed

+1059
-450
lines changed

11 files changed

+1059
-450
lines changed

intel_pytorch_extension_py/utils.py

Lines changed: 9 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -76,25 +76,17 @@ def optimize(model, dtype=torch.bfloat16, optimizer=None, level='O1', inplace=Fa
7676
optimized_optimizer = optimizer
7777
else:
7878
optimized_model, optimized_optimizer = _copy_model_and_optimizer(model, optimizer)
79+
if not model.training:
80+
try:
81+
optimized_model = conv_bn_fuse(optimized_model, inplace=inplace)
82+
except:
83+
warnings.warn("Conv BN folding failed during the optimize process.")
84+
# do weight data type convert for inference model.
85+
if dtype == torch.bfloat16:
86+
optimized_model = _convert_module_data_type(optimized_model, torch.bfloat16)
7987
if level == 'O0':
80-
# will be removed after customer op can be traced with autocast,
81-
# see https://github.com/pytorch/pytorch/pull/60251.
82-
# after removed, will directly return original model and optimizer.
83-
if not model.training:
84-
try:
85-
optimized_model = conv_bn_fuse(optimized_model, inplace=inplace)
86-
except:
87-
warnings.warn("Conv BN folding failed during the optimize process.")
88-
# do weight data type convert for inference model.
89-
if dtype == torch.bfloat16:
90-
optimized_model = _convert_module_data_type(optimized_model, torch.bfloat16)
88+
pass
9189
elif level == 'O1':
92-
if not model.training:
93-
try:
94-
optimized_model = conv_bn_fuse(optimized_model, inplace=inplace)
95-
except:
96-
warnings.warn("Conv BN folding failed during the optimize process.")
97-
9890
# Do weight prepack, and convert optimizer for training case.
9991
optimized_model, optimized_optimizer, weight_params_attr = _weight_prepack_with_ipex(optimized_model, optimized_optimizer, dtype)
10092
if dtype == torch.bfloat16 and model.training and optimizer is not None:

0 commit comments

Comments (0)