@@ -1,11 +1,9 @@
 import unittest
 import itertools
-from functools import wraps
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from test_jit_llga_utils import JitLlgaTestCase, run_tests, LLGA_FUSION_GROUP
+from test_jit_llga_utils import JitLlgaTestCase, run_tests, LLGA_FUSION_GROUP, llga_test_env
 from torch.testing._internal.common_utils import TEST_SCIPY

 import intel_pytorch_extension as ipex
@@ -27,21 +25,6 @@ def get_eltwise_fn(name):
     else:
         raise NameError('Eltwise function %s not found' % name)

-# For LLGA UT, disable the PyTorch profiling executor and the IPEX JIT opt
-def llga_test_env(func):
-    @wraps(func)
-    def wrapTheFunction(*args):
-        # make sure that the profiling mode is turned on
-        torch._C._jit_set_profiling_mode(True)
-        torch._C._jit_set_profiling_executor(True)
-
-        ipex.core._jit_set_llga_enabled(True)
-        ipex.core.disable_jit_opt()
-        func(*args)
-        ipex.core.enable_jit_opt()
-        ipex.core._jit_set_llga_enabled(False)
-    return wrapTheFunction
-
 class TestOp(JitLlgaTestCase):
     @llga_test_env
     def test_conv2d_int8_in_f32_out(self):
@@ -162,25 +145,6 @@ def test_max_pool2d(self):
         self.assertFused(graph, ['aten::max_pool2d'])
         self.checkPatterns(graph, patterns)

-    @llga_test_env
-    @unittest.skipIf(True, 'int8 adaptive_avg_pool2d is not supported in the backend')
-    def test_adaptive_avg_pool2d(self):
-        m = nn.AdaptiveAvgPool2d((1, 1))
-        N = torch.randint(3, 10, (1,)).item()
-        C = torch.randint(3, 10, (1,)).item()
-        x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100
-
-        patterns = [
-            ["aten::quantize_per_tensor"],
-            ["aten::dequantize", "aten::adaptive_avg_pool2d", "aten::quantize_per_tensor"],
-            ["aten::dequantize"]
-        ]
-        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
-            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="adaptive_avg_pool2d", qscheme=qscheme)
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-            self.assertFused(graph, ['aten::adaptive_avg_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
-            self.checkPatterns(graph, patterns)
-
 class TestFusionPattern(JitLlgaTestCase):
     @llga_test_env
     def test_conv2d_eltwise(self):
@@ -408,7 +372,7 @@ def forward(self, x):
                 new_x_shape = x.size()[:-1] + (3, 5)
                 x = x.view(*new_x_shape)
                 return x.permute(0, 2, 1, 3)
-
+
         x = torch.randn(5, 10, 15)
         m = M()

@@ -434,7 +398,7 @@ def forward(self, x):
                 x = self.conv1(x)
                 x = self.conv2(x).reshape(x.size(0), 4, -1)
                 return x
-
+
         x = torch.randn(15, 4, 28, 28)
         # change the size of the input, check the fallback
         x_var = torch.randn(7, 4, 16, 16)
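
For reference, the llga_test_env decorator deleted above is now imported from test_jit_llga_utils. A minimal sketch of how it presumably reads in test_jit_llga_utils.py, assuming the decorator was moved there verbatim:

# Sketch of the relocated helper (assumption: verbatim move into test_jit_llga_utils.py).
from functools import wraps

import torch
import intel_pytorch_extension as ipex

# For LLGA UT, disable the PyTorch profiling executor and the IPEX JIT opt
def llga_test_env(func):
    @wraps(func)
    def wrapTheFunction(*args):
        # make sure that the profiling mode is turned on
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_set_profiling_executor(True)

        # enable LLGA fusion and disable the IPEX JIT opt for the test body only
        ipex.core._jit_set_llga_enabled(True)
        ipex.core.disable_jit_opt()
        func(*args)
        # restore the default state afterwards
        ipex.core.enable_jit_opt()
        ipex.core._jit_set_llga_enabled(False)
    return wrapTheFunction

The decorator enables profiling mode and LLGA fusion only for the duration of each decorated test and restores the IPEX JIT optimization state afterwards, which is why the test classes above can keep using @llga_test_env unchanged.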