@@ -589,7 +589,7 @@ def forward(self, x):
589589 a1 = self .conv_transpose (x )
590590 b1 = torch .sigmoid (a1 )
591591 c1 = self .mul_op (a1 , b1 )
592- return c1
592+ return c1
593593
594594class ChannelShuffle_with_Static_Shape (nn .Module ):
595595 def __init__ (self , batchsize , num_channels , height , width , groups ):
@@ -992,10 +992,10 @@ def _test_onednn_fp32(self, model, input, kind_in_graph=None, kind_not_in_graph=
992992 trace_graph = tr_model .graph_for (input )
993993 res_jit = tr_model (input )
994994 self .assertEqual (res_ref , res_jit )
995-
995+
996996 if kind_in_graph is not None :
997997 self .assertTrue (any (n .kind () == kind_in_graph for n in trace_graph .nodes ()))
998-
998+
999999 if kind_not_in_graph is not None :
10001000 self .assertTrue (all (n .kind () != kind_not_in_graph for n in trace_graph .nodes ()))
10011001
@@ -1054,8 +1054,8 @@ def _test_fusion_unsupported_case(self, m, x, auto_kernel_selection=False, kind_
10541054 traced_model = torch .jit .trace (model , x ).eval ()
10551055 traced_model = torch .jit .freeze (traced_model )
10561056 tresult = traced_model (x )
1057- trace_graph = traced_model .graph_for (x )
1058-
1057+ trace_graph = traced_model .graph_for (x )
1058+
10591059 if kind_in_graph is not None :
10601060 self .assertTrue (any (n .kind () == kind_in_graph for n in trace_graph .nodes ()))
10611061
@@ -1088,7 +1088,7 @@ def test_jit_freeze(self):
10881088 self .assertTrue (all (n .kind () != pack_node for n in freeze_graph .nodes ()))
10891089# for non-freeze model, since op-ctx dose not have value, cannot re-pack for this path
10901090 self .assertTrue (any (n .kind () == imperative_node for n in trace_graph .nodes ()))
1091-
1091+
10921092
10931093 def test_concat_linear (self ):
10941094 def check_op_count (graph_str , op_names = []):
@@ -1459,7 +1459,7 @@ def _test_conv_unary_fusion(self, op_list, seed=None):
14591459 print ("{} rand sed: {}" .format (sys ._getframe ().f_code .co_name , rand_seed ))
14601460 torch .manual_seed (rand_seed )
14611461 else :
1462- print ("{} rand sed: {}" .format (sys ._getframe ().f_code .co_name , seed ))
1462+ print ("{} rand sed: {}" .format (sys ._getframe ().f_code .co_name , seed ))
14631463 torch .manual_seed (seed )
14641464
14651465 for dim in [2 , 3 ]:
@@ -1502,7 +1502,7 @@ def _test_conv_transpose_unary_fusion(self, op_list, seed=None):
15021502 print ("{} rand sed: {}" .format (sys ._getframe ().f_code .co_name , rand_seed ))
15031503 torch .manual_seed (rand_seed )
15041504 else :
1505- print ("{} rand sed: {}" .format (sys ._getframe ().f_code .co_name , seed ))
1505+ print ("{} rand sed: {}" .format (sys ._getframe ().f_code .co_name , seed ))
15061506 torch .manual_seed (seed )
15071507
15081508 for dim in [2 , 3 ]:
@@ -1538,7 +1538,7 @@ def test_conv_unary_fusion(self):
15381538 self ._test_conv_unary_fusion (unary_PyTorch_op_to_IPEX_op_map )
15391539 self ._test_conv_unary_fusion (PyTorch_op_to_IPEX_op_fixed_seed_map , 1654064339261196288 )
15401540
1541- def test_conv_non_unary_fusion (self ):
1541+ def test_conv_non_unary_fusion (self ):
15421542 self ._test_conv_unary_fusion (non_unary_PyTorch_op_to_IPEX_op_map )
15431543
15441544 def test_conv_fusion_unsupported_case (self ):
@@ -1548,7 +1548,7 @@ def test_conv_fusion_unsupported_case(self):
15481548 out_channels = 16
15491549 in_channels = 3
15501550 kernel_size = 3
1551- image_size = 16
1551+ image_size = 16
15521552 for eltwise in unsupported_PyTorch_op_to_IPEX_op_map :
15531553 input_size = [batch_size , in_channels , image_size , image_size ]
15541554
@@ -1560,7 +1560,7 @@ def test_conv_fusion_unsupported_case(self):
15601560
15611561 x = torch .randn (input_size )
15621562 m = ConvEltwise (eltwise , dim , in_channels , out_channels , kernel_size , image_size , ** op_input_list )
1563-
1563+
15641564 self ._test_fusion_unsupported_case (
15651565 m ,
15661566 x ,
@@ -1640,7 +1640,7 @@ def test_output_frozen_conv_bn(self):
16401640 if use_channels_last :
16411641 x = x .to (memory_format = torch .channels_last )
16421642 model = model .to (memory_format = torch .channels_last )
1643-
1643+
16441644 model = ipex .optimize (model , dtype = dtype , conv_bn_folding = False )
16451645
16461646 with torch .cpu .amp .autocast (enabled = True , dtype = dtype ), torch .no_grad ():
@@ -2345,7 +2345,7 @@ def test_conv_transpose_unary_fusion(self):
23452345 self ._test_conv_transpose_unary_fusion (unary_PyTorch_op_to_IPEX_op_map )
23462346 self ._test_conv_transpose_unary_fusion (PyTorch_op_to_IPEX_op_fixed_seed_map , 1654583254233936896 )
23472347
2348- def test_conv_transpose_non_unary_fusion (self ):
2348+ def test_conv_transpose_non_unary_fusion (self ):
23492349 self ._test_conv_transpose_unary_fusion (non_unary_PyTorch_op_to_IPEX_op_map )
23502350
23512351 def test_conv_transpose_fusion_unsupported_case (self ):
@@ -2356,7 +2356,7 @@ def test_conv_transpose_fusion_unsupported_case(self):
23562356 in_channels = 3
23572357 kernel_size = 3
23582358 image_size = 8
2359-
2359+
23602360 for eltwise in unsupported_PyTorch_op_to_IPEX_op_map :
23612361 input_size = [batch_size , in_channels , image_size , image_size ]
23622362
@@ -2405,7 +2405,7 @@ def test_conv_transpose_sigmoid_mul(self):
24052405 # x,
24062406 # kind_in_graph="ipex_prepack::conv_transpose_%s_run" % ipex_eltwise_op,
24072407 # kind_not_in_graph="ipex_prepack::conv_transpose_prepack",
2408- # prec=prec)
2408+ # prec=prec)
24092409
24102410 def test_linear_auto_kernel_selection_fp32 (self ):
24112411 x = torch .rand (32 , 3 )
@@ -2559,22 +2559,22 @@ def _test_linear_unary_fusion(self, op_list, seed=None):
25592559 print ("{} rand sed: {}" .format (sys ._getframe ().f_code .co_name , rand_seed ))
25602560 torch .manual_seed (rand_seed )
25612561 else :
2562- print ("{} rand sed: {}" .format (sys ._getframe ().f_code .co_name , seed ))
2563- torch .manual_seed (seed )
2562+ print ("{} rand sed: {}" .format (sys ._getframe ().f_code .co_name , seed ))
2563+ torch .manual_seed (seed )
25642564
25652565 for bias in [True , False ]:
25662566 for eltwise in op_list :
25672567 input_size = [batch_size , in_channels ]
2568-
2568+
25692569 unary_fusion_op = op_list [eltwise ]
25702570 ipex_eltwise_op = unary_fusion_op .ipex_eltwise_op
25712571 bf16_supported = unary_fusion_op .bf16_supported
25722572 prec = unary_fusion_op .prec
25732573 op_input_list = unary_fusion_op .op_input_list
2574-
2574+
25752575 x = torch .randn (input_size )
25762576 m = LinearEltwise (eltwise , in_channels , out_channels , bias , ** op_input_list )
2577-
2577+
25782578 self ._test_output (
25792579 m ,
25802580 x ,
@@ -2583,7 +2583,7 @@ def _test_linear_unary_fusion(self, op_list, seed=None):
25832583 m ,
25842584 x ,
25852585 kind_in_graph = "ipex_prepack::linear_%s_run" % ipex_eltwise_op ,
2586- kind_not_in_graph = "ipex_prepack::linear_prepack" )
2586+ kind_not_in_graph = "ipex_prepack::linear_prepack" )
25872587 if bf16_supported :
25882588 self ._test_output_bf16 (
25892589 m ,
@@ -2603,17 +2603,17 @@ def test_linear_fusion_unsupported_case(self):
26032603 batch_size = 3
26042604 out_channels = 32
26052605 in_channels = 3
2606- bias = False
2606+ bias = False
26072607
26082608 for eltwise in unsupported_PyTorch_op_to_IPEX_op_map :
26092609 input_size = [batch_size , in_channels ]
2610-
2610+
26112611 unary_fusion_op = unsupported_PyTorch_op_to_IPEX_op_map [eltwise ]
26122612 ipex_eltwise_op = unary_fusion_op .ipex_eltwise_op
26132613 bf16_supported = unary_fusion_op .bf16_supported
26142614 prec = unary_fusion_op .prec
26152615 op_input_list = unary_fusion_op .op_input_list
2616-
2616+
26172617 x = torch .randn (input_size )
26182618 m = LinearEltwise (eltwise , in_channels , out_channels , bias , ** op_input_list )
26192619
@@ -2839,7 +2839,7 @@ def _test_fp32(model_test, input1, input2, bias=None, kind_in_graph='ipex::einsu
28392839 input2 = torch .randn (768 , 2304 )
28402840 model_v1 = EinsumAdd ('bsh,ho->bso' )
28412841 _test_fp32 (model_v1 , input1 , input2 , bias )
2842-
2842+
28432843 bias = torch .randn (1 , 1 , 1 , 4 )
28442844 input1 = torch .randn (12 , 1 , 4 , 16 )
28452845 input2 = torch .randn (12 , 4 , 4 , 16 )
@@ -2851,7 +2851,7 @@ def _test_fp32(model_test, input1, input2, bias=None, kind_in_graph='ipex::einsu
28512851 input2 = torch .randn (768 , 2304 )
28522852 model_v1 = EinsumAddInplace ('bsh,ho->bso' )
28532853 _test_fp32 (model_v1 , input1 , input2 , bias )
2854-
2854+
28552855 input1 = torch .randn (8 , 3 , 768 )
28562856 input2 = torch .randn (768 , 2304 )
28572857 model = EinsumAddScalar ('bsh,ho->bso' ).eval ()
@@ -2876,7 +2876,7 @@ def _test_fp32(model_test, input1, input2, bias=None, kind_in_graph='ipex::einsu
28762876 input4 = torch .randn (2 , 4 , 128 , 768 )
28772877 model_v2 = EinsumAdd ("bnqd,bnkd->bnqk" )
28782878 _test_fp32 (model_v2 , input3 , input4 , bias1 )
2879-
2879+
28802880 bias1 = torch .randn (8 , 1 , 1 , 128 )
28812881 input3 = torch .randn (8 , 4 , 128 , 768 )
28822882 input4 = torch .randn (8 , 4 , 128 , 768 )
@@ -2900,34 +2900,34 @@ def _test_fp32(model_test, input1, input2, bias=None, kind_in_graph='ipex::einsu
29002900 input2 = torch .randn (1024 , 768 )
29012901 model_v2 = EinsumAdd ("mc,cn->mn" )
29022902 _test_fp32 (model_v2 , input1 , input2 , bias1 )
2903-
2903+
29042904 bias1 = torch .randn (1024 )
29052905 input1 = torch .randn (1024 , 1024 )
29062906 input2 = torch .randn (1024 , 1024 )
29072907 model_v2 = EinsumAdd ("mc,cn->nm" )
29082908 _test_fp32 (model_v2 , input1 , input2 , bias1 )
2909-
2909+
29102910 bias1 = torch .randn (768 )
29112911 input1 = torch .randn (2 , 128 , 1024 )
29122912 input2 = torch .randn (1024 , 23 , 768 )
29132913 model_v2 = EinsumAdd ("bqc,chv->bqhv" )
29142914 _test_fp32 (model_v2 , input1 , input2 , bias1 )
2915-
2915+
29162916 bias = torch .randn (768 )
29172917 input1 = torch .randn (2 , 128 , 16 , 64 )
29182918 input2 = torch .randn (16 ,64 , 768 )
29192919 model = EinsumAdd ("bqhc,hco->bqo" )
29202920 _test_fp32 (model , input1 , input2 , bias )
2921-
2921+
29222922 bias = torch .randn (8 )
29232923 input1 = torch .randn (8 )
29242924 input2 = torch .randn (8 )
29252925 model = EinsumAdd ("i,i->" )
29262926 _test_fp32 (model , input1 , input2 , bias )
2927-
2928- #the output of torch.einsum("ij,j") is tensor([])
2927+
2928+ #the output of torch.einsum("ij,j") is tensor([])
29292929 bias = torch .randn (1 )
2930- input1 = torch .randn (0 , 3 )
2930+ input1 = torch .randn (0 , 3 )
29312931 input2 = torch .randn (3 )
29322932 model = EinsumAdd (("ij,j" ))
29332933 _test_fp32 (model , input1 , input2 , bias )
@@ -3042,7 +3042,7 @@ def forward(self, x):
30423042 x1 = self .eltwise (x1 , ** self .params_dict )
30433043 return x1
30443044
3045- for eltwise in ['sigmoid' , 'tanh' , 'celu' , 'elu' , 'hardsigmoid' , ' hardswish' , 'hardtanh' , 'leaky_relu' , 'relu6' , 'relu' , 'rrelu' , 'selu' , 'silu' ]:
3045+ for eltwise in ['sigmoid' , 'tanh' , 'celu' , 'elu' , 'hardswish' , 'hardtanh' , 'leaky_relu' , 'relu6' , 'relu' , 'rrelu' , 'selu' , 'silu' ]:
30463046 eltwise_fn_name = eltwise + '_'
30473047 if eltwise in ['sigmoid' , 'tanh' , 'celu' , 'relu' , 'rrelu' , 'selu' ]:
30483048#use torch.sigmoid_(x)
@@ -3137,6 +3137,22 @@ def forward(self, x):
31373137 kind_not_in_graph = "aten::mul" ,
31383138 prec = 0.1 )
31393139
3140+ def test_hardsigmoid_mul (self ):
3141+ class HardsigmoidMul (nn .Module ):
3142+ def __init__ (self ) -> None :
3143+ super (HardsigmoidMul , self ).__init__ ()
3144+ self .hard_sigmoid = nn .Hardsigmoid ()
3145+
3146+ def forward (self , x ):
3147+ return self .hard_sigmoid (x ) * x
3148+
3149+ model = HardsigmoidMul ().eval ()
3150+ self ._test_output (
3151+ model ,
3152+ torch .randn (2 , 3 , 4 , 5 ),
3153+ kind_in_graph = "ipex::hardsigmoid" ,
3154+ kind_not_in_graph = "aten::hardsigmoid" )
3155+
31403156if __name__ == '__main__' :
31413157 torch .manual_seed (2020 )
31423158 test = unittest .main ()
0 commit comments