@@ -76,25 +76,17 @@ def optimize(model, dtype=torch.bfloat16, optimizer=None, level='O1', inplace=Fa
7676 optimized_optimizer = optimizer
7777 else :
7878 optimized_model , optimized_optimizer = _copy_model_and_optimizer (model , optimizer )
79+ if not model .training :
80+ try :
81+ optimized_model = conv_bn_fuse (optimized_model , inplace = inplace )
82+ except :
83+ warnings .warn ("Conv BN folding failed during the optimize process." )
84+ # do weight data type convert for inference model.
85+ if dtype == torch .bfloat16 :
86+ optimized_model = _convert_module_data_type (optimized_model , torch .bfloat16 )
7987 if level == 'O0' :
80- # will be removed after customer op can be traced with autocast,
81- # see https://github.com/pytorch/pytorch/pull/60251.
82- # after removed, will directly return original model and optimizer.
83- if not model .training :
84- try :
85- optimized_model = conv_bn_fuse (optimized_model , inplace = inplace )
86- except :
87- warnings .warn ("Conv BN folding failed during the optimize process." )
88- # do weight data type convert for inference model.
89- if dtype == torch .bfloat16 :
90- optimized_model = _convert_module_data_type (optimized_model , torch .bfloat16 )
88+ pass
9189 elif level == 'O1' :
92- if not model .training :
93- try :
94- optimized_model = conv_bn_fuse (optimized_model , inplace = inplace )
95- except :
96- warnings .warn ("Conv BN folding failed during the optimize process." )
97-
9890 # Do weight prepack, and convert optimizer for training case.
9991 optimized_model , optimized_optimizer , weight_params_attr = _weight_prepack_with_ipex (optimized_model , optimized_optimizer , dtype )
10092 if dtype == torch .bfloat16 and model .training and optimizer is not None :
0 commit comments