33from torch .testing ._internal .common_utils import TestCase
44
55import time
6- import intel_extension_for_pytorch # noqa
6+ import intel_extension_for_pytorch # noqa
77
88from torch .quantization .quantize_jit import (
99 convert_jit ,
1010 prepare_jit ,
1111)
12+ import pytest
1213
1314checking_atol = 3e-2
1415checking_rtol = 3e-2
1516
17+
1618class ConvSigmoid (torch .nn .Module ):
1719 def __init__ (self ):
1820 super ().__init__ ()
@@ -27,6 +29,7 @@ def forward(self, x):
2729 x = self .block (x )
2830 return x
2931
32+
3033class ConvLeakyRelu (torch .nn .Module ):
3134 def __init__ (self ):
3235 super ().__init__ ()
@@ -39,6 +42,7 @@ def forward(self, x):
3942 x = self .block (x )
4043 return x
4144
45+
4246class Mish (torch .nn .Module ):
4347 def __init__ (self ):
4448 super ().__init__ ()
@@ -47,6 +51,7 @@ def forward(self, x):
4751 x = x * (torch .tanh (torch .nn .functional .softplus (x )))
4852 return x
4953
54+
5055class ConvMish (torch .nn .Module ):
5156 def __init__ (self ):
5257 super ().__init__ ()
@@ -60,6 +65,7 @@ def forward(self, x):
6065 x = self .conv2 (x )
6166 return x
6267
68+
6369class ConvMishAdd (torch .nn .Module ):
6470 def __init__ (self ):
6571 super ().__init__ ()
@@ -75,10 +81,12 @@ def forward(self, x):
7581 x = x + h
7682 return x
7783
84+
7885def impe_fp32_model (model , device , test_input ):
7986 modelImpe = model .to (device )
8087 return modelImpe (test_input .clone ().to (device ))
8188
89+
8290def impe_int8_model (model , device , test_input ):
8391 modelImpe = torch .quantization .QuantWrapper (model )
8492 modelImpe = modelImpe .to (device )
@@ -104,6 +112,7 @@ def impe_int8_model(model, device, test_input):
104112
105113 return modelImpe (test_input .to (device ))
106114
115+
107116def trace_int8_model (model , device , test_input ):
108117 model = model .to (device )
109118 modelJit = torch .jit .trace (model , test_input .to (device ))
@@ -122,7 +131,6 @@ def trace_int8_model(model, device, test_input):
122131 weight = torch .quantization .default_weight_observer
123132 )
124133
125-
126134 modelJit = prepare_jit (modelJit , {'' : qconfig_u8 }, True )
127135
128136 # do calibration
@@ -149,7 +157,9 @@ def trace_int8_model(model, device, test_input):
149157 print ("finish " , device , " testing......." )
150158 return output
151159
160+
152161class TestTorchMethod (TestCase ):
162+ @pytest .mark .skipif (not torch .xpu .utils .has_fp64_dtype (), reason = "fp64 not support by this device" )
153163 def test_qConv2d_sigmoid (self , dtype = torch .float ):
154164 model = ConvSigmoid ()
155165 model1 = copy .deepcopy (model )
0 commit comments