 from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_dnnl, _disable_dnnl
 import intel_extension_for_pytorch._C as torch_ipex_cpp
 try:
-    from . import tpp
+    from . import tpp
 except:
     warnings.warn("please install the transformers repo when you want to use the fast_bert API")
 
@@ -88,10 +88,10 @@ def _deep_copy_params_attr(old_module, new_module):
 
 def enable_auto_channels_last():
     global auto_channels_last
-    auto_channels_last = True
+    auto_channels_last = True
 
 def disable_auto_channels_last():
-    global auto_channels_last
+    global auto_channels_last
     auto_channels_last = False
 
 class _Properties(object):
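For context, this is how the toggle above is meant to be used from the public API — a minimal sketch, assuming the flags are re-exported at the package top level (as `ipex.enable_auto_channels_last()` in released builds); the toy model is illustrative, not from this commit:

```python
import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Conv2d(3, 16, kernel_size=3).eval()

# While the flag is enabled (CPU only), ipex.optimize() converts eligible
# convolution weights to the channels_last memory format.
ipex.enable_auto_channels_last()
optimized = ipex.optimize(model)

# Turn the conversion off again for subsequent optimize() calls.
ipex.disable_auto_channels_last()
```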
@@ -185,7 +185,7 @@ def forward(*input, **kwargs):
                 else:
                     return self.model(*input, **kwargs)
             else:
-                # Lock the graph generation process to avoid multiple threads generating graph simultaneously.
+                # Lock the graph generation process to avoid multiple threads generating graph simultaneously.
                 with self.lock:
                     if self.method:
                         if self.train:
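The hunk above (truncated by the diff) touches a comment inside a double-checked locking pattern: `self.method` is checked without the lock on the fast path, then re-checked under `self.lock` before the graph is generated. A minimal standalone sketch of the same pattern (names are illustrative, not from this commit):

```python
import threading

class LazyGraph:
    """Double-checked locking: build an expensive object once, thread-safely."""

    def __init__(self):
        self.method = None
        self.lock = threading.Lock()

    def get(self, build):
        if self.method:  # fast path: already generated, no lock needed
            return self.method
        with self.lock:  # slow path: serialize generation across threads
            if self.method is None:  # re-check: another thread may have won
                self.method = build()
            return self.method
```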
@@ -324,7 +324,7 @@ def optimize(
             input data will impact the block format of packed weight. If a sample input is not fed,
             Intel® Extension for PyTorch* will pack the weight per some predefined heuristics.
             If a sample input with the real input shape is fed, Intel® Extension for PyTorch* can get the
-            best block format.
+            best block format.
         auto_kernel_selection (bool) [experimental]: Different backends may have
             different performances with different dtypes/shapes. Default value
             is False. Intel® Extension for PyTorch* will try to optimize the
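This hunk truncates mid-docstring, but the `sample_input` behavior it describes is easiest to see at the call site — a minimal sketch, assuming a CPU build and a toy model (the model and shapes are illustrative):

```python
import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3),
    torch.nn.ReLU(),
).eval()

# Feeding a sample input with the real inference shape lets weight
# prepacking pick the best block format; without it, predefined
# heuristics are used instead.
sample = torch.randn(1, 3, 224, 224)
optimized = ipex.optimize(model, sample_input=sample)

with torch.no_grad():
    out = optimized(sample)
```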
@@ -388,19 +388,18 @@ def optimize(
     opt_properties = _Properties()
     if level not in opt_levels:
         raise RuntimeError(
-            "Unexpected optimization level {}. ".format(level) +
-            "Options are 'O0', 'O1'.")
+            f"Unexpected optimization level {level}. Options are 'O0', 'O1'.")
     else:
         opt_properties = opt_levels[level](opt_properties)
 
     device_type = 'cpu'
     if len(list(model.parameters())) and list(model.parameters())[0].device.type == 'xpu':
         if not all([param.device.type == 'xpu' for param in list(model.parameters())]):
-            raise RuntimeError("The model is mixed with different device type")
+            raise RuntimeError("The model is mixed with different device types.")
         else:
             device_type = 'xpu'
 
-    # auto model channels_last memory format conversion
+    # auto model channels_last memory format conversion
     # TODO: for xpu, the auto channels last is temp disabled
     if auto_channels_last and device_type == 'cpu':
         _convert_convNd_weight_memory_format(model)
@@ -433,22 +432,25 @@ def optimize(
     # when on xpu, some features are not supported
     if device_type == 'xpu':
         if opt_properties.auto_kernel_selection:
-            warnings.warn("For XPU device , the auto kernel selection is unsupported, so disable it")
+            warnings.warn("For XPU device, auto kernel selection is unsupported, so it is disabled.")
             opt_properties.auto_kernel_selection = False
         if opt_properties.split_master_weight_for_bf16:
-            warnings.warn("For XPU device, the split master weight is unsupported for now, so temp to disable it")
+            warnings.warn("For XPU device, split master weight is unsupported for now, so it is temporarily disabled.")
             # TODO: for xpu, the split master weight will be supported soon
             opt_properties.split_master_weight_for_bf16 = False
         if opt_properties.graph_mode:
-            warnings.warn("For XPU, the oob solution for inference is to trace model outside of the ipex.optimize, so temp to disable the graph mode")
+            warnings.warn("For XPU, the out-of-box (OOB) solution for inference is to trace the model outside of " +
+                          "ipex.optimize, so graph mode is temporarily disabled.")
             # TODO: for xpu now, the oob solution for inference is to trace model outside of the ipex.optimize.
             opt_properties.graph_mode = False
         if not inplace:
-            warnings.warn("For XPU device to save valuable device memory, temp to do optimization on inplaced model, so make inplace to be true")
+            warnings.warn("For XPU device, to save valuable device memory, optimization is temporarily done on the " +
+                          "model in place, so inplace is set to True.")
             # TODO: for xpu, inplace is true will add device memory pressure, so set inplace to be true
             inplace = True
         if opt_properties.weights_prepack:
-            warnings.warn("For XPU, the weight prepack and sample input are disabled. For onednn layout, IPEX_XPU_ONEDNN_LAYOUT is recommended to use")
+            warnings.warn("For XPU, the weight prepack and sample input are disabled. For oneDNN layout, " +
+                          "IPEX_XPU_ONEDNN_LAYOUT is recommended.")
             opt_properties.weights_prepack = False
         sample_input = None
@@ -462,7 +464,7 @@ def optimize(
         if isinstance(sample_input, torch.Tensor):
             sample_input = (sample_input,)
         utils._weight_prepack.record_input_shape_for_prepack(optimized_model, sample_input)
-
+
     if not model.training:
         if opt_properties.conv_bn_folding:
             try:
@@ -487,22 +489,22 @@ def optimize(
         if not opt_properties.fuse_update_step:
             opt_properties.split_master_weight_for_bf16 = False
             warnings.warn(
-                "IPEX does not non-fused split master weight for bf16 training," +
-                "have reset split_master_weight_for_bf16 flag to False." +
-                "If you want to use split_master_weight_for_bf16." +
-                "Please set both split_master_weight_for_bf16 and fuse_update_step to True")
+                "IPEX does not support non-fused split master weight for bf16 training, " +
+                "so the split_master_weight_for_bf16 flag has been reset to False. " +
+                "If you want to use split_master_weight_for_bf16, " +
+                "please set both split_master_weight_for_bf16 and fuse_update_step to True.")
         elif type(optimizer) not in IPEX_FUSED_OPTIMIZER_LIST_CPU and device_type == 'cpu':
             opt_properties.split_master_weight_for_bf16 = False
             opt_properties.fuse_update_step = False
             warnings.warn(
-                "IPEX CPU does not support fused/fused split update for" + str(type(optimizer)) +
-                "will use non-fused master weight update for bf16 training on CPU")
+                "IPEX CPU does not support fused/fused split update for " + str(type(optimizer)) +
+                ", so the non-fused master weight update will be used for bf16 training on CPU.")
         elif type(optimizer) not in IPEX_FUSED_OPTIMIZER_LIST_XPU and device_type == 'xpu':
             opt_properties.split_master_weight_for_bf16 = False
             opt_properties.fuse_update_step = False
             warnings.warn(
-                "IPEX XPU does not support fused/fused split update for" + str(type(optimizer)) +
-                "will use non-fused master weight update for bf16 training on XPU")
+                "IPEX XPU does not support fused/fused split update for " + str(type(optimizer)) +
+                ", so the non-fused master weight update will be used for bf16 training on XPU.")
 
     # convert optimizer for training case.
     params_attr = {}
@@ -512,7 +514,7 @@ def optimize(
         optimized_model, optimized_optimizer, params_attr = utils._weight_cast.weight_dtype_convert_with_ipex(
             optimized_model, optimized_optimizer, params_attr, opt_properties.split_master_weight_for_bf16, convert_dtype=torch.bfloat16)
     if dtype == torch.half and model.training:
-        assert device_type != 'xpu', "For now, XPU device does not support model training with half precision"
+        assert device_type != 'xpu', "For now, XPU device does not support model training with half precision."
         optimized_model, optimized_optimizer, params_attr = utils._weight_cast.weight_dtype_convert_with_ipex(
             optimized_model, optimized_optimizer, params_attr, False, convert_dtype=torch.half)
     # Since TorchDynamo cannot handle custom operations yet, for the case of inference graph mode,
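The branches above cover bf16 and fp16 weight casting for training. A minimal bf16 training sketch of the surrounding `ipex.optimize` flow, assuming CPU and a toy model (the autocast wrapper is standard PyTorch, not part of this commit):

```python
import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Linear(64, 10).train()  # toy model, illustrative only
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# For bf16 training, optimize() casts the weights and, for fused optimizers
# such as SGD on CPU, can keep a split fp32 master copy of the weights.
model, optimizer = ipex.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)

data, target = torch.randn(8, 64), torch.randint(0, 10, (8,))
optimizer.zero_grad()
with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    loss = torch.nn.functional.cross_entropy(model(data), target)
loss.backward()
optimizer.step()
```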
@@ -669,32 +671,32 @@ def get_fp32_math_mode(device="cpu"):
 
 def fast_bert(model, dtype=torch.float, optimizer=None, unpad=False):
     r"""
-    Use TPP to speedup training/inference. fast_bert API is still a experimental
-    feature and now only optimized for bert model.
+    Use TPP to speed up training/inference. The fast_bert API is still an experimental
+    feature and is currently only optimized for the BERT model.
 
     Args:
         model (torch.nn.Module): User model to apply optimizations on.
         dtype (torch.dtype): Only works for ``torch.bfloat16`` and ``torch.float``.
             The default value is torch.float.
         optimizer (torch.optim.Optimizer): User optimizer to apply optimizations
             on, such as SGD. The default value is ``None``, meaning inference case.
-        unpad(bool): Unpad the squence to reduce the sparsity.
-        seed(string): The seed used for the libxsmm kernel. In general it should be same
-            to the torch.seed
+        unpad(bool): Unpad the sequence to reduce the sparsity.
+        seed(string): The seed used for the libxsmm kernel. In general it should be the same
+            as the torch seed.
 
     .. warning::
 
         Please invoke ``fast_bert`` function AFTER loading weights to model via
         ``model.load_state_dict(torch.load(PATH))``.
 
     .. warning::
-
+
         This API can't be used when you have applied the ipex.optimize.
 
     .. warning::
 
         Please invoke ``optimize`` function BEFORE invoking DDP in distributed
-        training scenario.
+        training scenario.
 
     Examples:
 
@@ -717,36 +719,36 @@ def fast_bert(model, dtype=torch.float, optimizer=None, unpad=False):
     max_version = '4.20.0'
     if 'transformers' not in installed_pkg:
         raise RuntimeError("Please install transformers with a version between {} and {}".format(min_version, max_version))
-
+
     import transformers
-    from packaging import version
+    from packaging import version
     trans_version = transformers.__version__
     if version.parse(trans_version) < version.parse(min_version) or version.parse(trans_version) > version.parse(max_version):
         raise RuntimeError("Please install transformers with a version between {} and {}; the current version is {}".format(min_version, max_version, trans_version))
     PT_OPTIMIZER_TO_TPP_OPTIMIZER = {torch.optim.AdamW: tpp.optim.AdamW,
                                      transformers.optimization.AdamW: tpp.optim.AdamW,
                                      torch.optim.SGD: tpp.optim.SGD}
-    assert(dtype == torch.float or dtype == torch.bfloat16, "TPP only support torch.float and torch.bfloat16")
-
+    assert dtype == torch.float or dtype == torch.bfloat16, "TPP only supports torch.float and torch.bfloat16"
+
     # set up the seed for libxsmm (can only be a positive int value), which will impact some ops using the seed, e.g., dropout
     torch_ipex_cpp.xsmm_manual_seed(torch.tensor(torch.initial_seed()).to(torch.int32).abs().item())
-    # replace the original transfomers module object with tpp module which has the same functionality but with more
-    # operator fusion optimization
+    # replace the original transformers module object with the tpp module, which has the same functionality but
+    # with more operator fusion optimization
     new_model = copy.deepcopy(model)
     tpp.fused_bert.layer_use_bf16 = True if dtype == torch.bfloat16 else False
-    if unpad:
-        tpp.fused_bert.unpad = True
+    if unpad:
+        tpp.fused_bert.unpad = True
     else:
         tpp.fused_bert.unpad = False
     assert(isinstance(new_model.bert.embeddings, transformers.models.bert.modeling_bert.BertEmbeddings))
     new_model.bert.embeddings = tpp.fused_bert.BertEmbeddings(model.bert.config)
     assert(isinstance(new_model.bert.encoder, transformers.models.bert.modeling_bert.BertEncoder))
     new_model.bert.encoder = tpp.fused_bert.BertEncoder(model.bert.config)
-    new_model.load_state_dict(model.state_dict())  # copy the original params into the tpp module
+    new_model.load_state_dict(model.state_dict())  # copy the original params into the tpp module
     tpp.block(new_model)  # get block format weights/bias
     if optimizer is None:
         return new_model
-    # replace the original pytorch/transformer optimizer with tpp optimizer for SGD/AdamW
+    # replace the original pytorch/transformers optimizer with the tpp optimizer for SGD/AdamW
     # keep the original optimizer state and replace the params with the blocked tpp params
     param_pair = {}
     for param_ori, param_tpp in zip(model.parameters(), new_model.parameters()):
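Taken together, `fast_bert` swaps the BERT embeddings/encoder for TPP implementations, copies the weights over, blocks them, and (when an optimizer is passed) remaps its params. A minimal usage sketch, assuming a `transformers` model that exposes a `.bert` submodule (e.g. `BertForSequenceClassification`) and, per the version check above, transformers between 4.6.0 and 4.20.0; the training-case return value is inferred from this truncated hunk, so it is shown commented out:

```python
import torch
import intel_extension_for_pytorch as ipex
from transformers import BertForSequenceClassification  # 4.6.0 <= version <= 4.20.0

# Load weights BEFORE calling fast_bert, per the docstring warning above.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
model.eval()

# Inference case: optimizer is None, so the blocked TPP model is returned.
model = ipex.fast_bert(model, dtype=torch.bfloat16, unpad=True)

# Training case (return value inferred; the hunk above is truncated):
# opt = torch.optim.AdamW(model.parameters(), lr=5e-5)
# model, opt = ipex.fast_bert(model, dtype=torch.bfloat16, optimizer=opt)
```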