@@ -73,32 +73,32 @@ def test_conv2d_int8_in_f32_out(self):
                           groups=g,
                           bias=bias)
             x = torch.rand(1, in_channels * g, spatial, spatial)
-
-            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d")
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
-            self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel'])
-
             patterns = [
                 ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
             ]
-            self.checkPatterns(graph, patterns)
+            # TODO: enable torch.per_tensor_symmetric case.
+            for qscheme in [torch.per_tensor_affine]:
+                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d", qscheme=qscheme)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+                self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel'])
+                self.checkPatterns(graph, patterns)

     @llga_test_env
     def test_linear_int8_in_f32_out(self):
         for bias in [True, False]:
             x = torch.rand(32, 28)
             m = torch.nn.Linear(in_features=28, out_features=64, bias=bias)
-
-            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear")
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
-            self.assertFused(graph, ['aten::linear', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])

             patterns = [
                 ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"],
             ]
-            self.checkPatterns(graph, patterns)
+            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear", qscheme=qscheme)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+                self.assertFused(graph, ['aten::linear', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+                self.checkPatterns(graph, patterns)

     @llga_test_env
     def test_linear_int8_in_int8_out(self):
@@ -117,17 +117,19 @@ def forward(self, x, y):
             x = torch.randn(2, 15)
             y = torch.randn(2, 20)
             m = M(bias)
-            graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, config_name="linear_int8")
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-            self.assertFused(graph, ['aten::linear',
-                                     'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])

             patterns = [
                 ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"]
             ]
-            self.checkPatterns(graph, patterns)
+
+            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, config_name="linear_int8", qscheme=qscheme)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+                self.assertFused(graph, ['aten::linear',
+                                         'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+                self.checkPatterns(graph, patterns)

     @llga_test_env
     def test_max_pool2d(self):
@@ -153,16 +155,16 @@ def test_max_pool2d(self):
                              ceil_mode=ceil_mode)
             x = torch.rand(1, 3, spatial, spatial)

-            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="max_pool2d")
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-            self.assertFused(graph, ['aten::max_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
-
             patterns = [
                 ["aten::quantize_per_tensor"],
                 ["aten::dequantize", "aten::max_pool2d", "aten::quantize_per_tensor"],
                 ["aten::dequantize"]
             ]
-            self.checkPatterns(graph, patterns)
+            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="max_pool2d", qscheme=qscheme)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+                self.assertFused(graph, ['aten::max_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
+                self.checkPatterns(graph, patterns)

     @llga_test_env
     @unittest.skipIf(True, 'int8 adaptive_avg_pool2d is not supported in the backend')
@@ -172,16 +174,16 @@ def test_adaptive_avg_pool2d(self):
         C = torch.randint(3, 10, (1,)).item()
         x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100

-        graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="adaptive_avg_pool2d")
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-        self.assertFused(graph, ['aten::adaptive_avg_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
-
         patterns = [
             ["aten::quantize_per_tensor"],
             ["aten::dequantize", "aten::adaptive_avg_pool2d", "aten::quantize_per_tensor"],
             ["aten::dequantize"]
         ]
-        self.checkPatterns(graph, patterns)
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="adaptive_avg_pool2d", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+            self.assertFused(graph, ['aten::adaptive_avg_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
+            self.checkPatterns(graph, patterns)

 class TestFusionPattern(JitLlgaTestCase):
     @llga_test_env
@@ -206,17 +208,17 @@ def forward(self, x):

             m = M(eltwise_fn)
             x = torch.rand(1, 32, 28, 28)
-
-            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d_eltwise")
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-            self.assertFused(graph, ['aten::_convolution', 'aten::' + eltwise, 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])

             patterns = [
                 ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", 'aten::' + eltwise, "aten::quantize_per_tensor"],  # inplace op will become outplace op on the JIT graph
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
             ]
-            self.checkPatterns(graph, patterns)
+            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d_eltwise", qscheme=qscheme)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+                self.assertFused(graph, ['aten::_convolution', 'aten::' + eltwise, 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+                self.checkPatterns(graph, patterns)

     @llga_test_env
     def test_conv2d_bn(self):
@@ -235,16 +237,17 @@ def forward(self, x):
         x = torch.rand(1, 32, 16, 16)
         # TODO: This shape will fail
         # x = torch.rand(1, 32, 28, 28)
-
-        graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn")
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
-        self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel'])

         patterns = [
             ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
         ]
-        self.checkPatterns(graph, patterns)
+        # TODO: add torch.per_tensor_symmetric case.
+        for qscheme in [torch.per_tensor_affine]:
+            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+            self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel'])
+            self.checkPatterns(graph, patterns)

     @llga_test_env
     def test_conv2d_bn_relu(self):
@@ -262,17 +265,17 @@ def forward(self, x):

         m = M().eval()
         x = torch.rand(1, 32, 28, 28)
-        graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn_relu")
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-        self.assertFused(graph, ['aten::_convolution', 'aten::relu',
-                                 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
-
         patterns = [
             ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", "aten::relu", "aten::quantize_per_tensor"],
             ["aten::dequantize"]
         ]
-        self.checkPatterns(graph, patterns)
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn_relu", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+            self.assertFused(graph, ['aten::_convolution', 'aten::relu',
+                                     'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+            self.checkPatterns(graph, patterns)

     @llga_test_env
     def test_linear_eltwise(self):
@@ -299,17 +302,16 @@ def forward(self, x):
             eltwise_fn = get_eltwise_fn(eltwise)
             m = M(eltwise_fn, has_bias)
             x = torch.rand(32, 28, requires_grad=False)
-
-            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear_eltwise")
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-            self.assertFused(graph, ['aten::' + eltwise])
-
             patterns = [
                 ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear", "aten::" + eltwise, "aten::quantize_per_tensor"],
                 ["aten::dequantize"]
             ]
-            self.checkPatterns(graph, patterns)
+            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear_eltwise", qscheme=qscheme)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+                self.assertFused(graph, ['aten::' + eltwise])
+                self.checkPatterns(graph, patterns)

     @llga_test_env
     def test_conv2d_sum(self):
@@ -338,17 +340,17 @@ def forward(self, x, y):
             m = M(bias).eval()
             x = torch.rand(1, 32, 16, 16, requires_grad=False)
             y = torch.rand(1, 32, 16, 16, requires_grad=False)
-            graph = self.checkQuantizeTrace(m, [x, y], folding=True, atol=1e-1, config_name="conv2d_sum")
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
-
             patterns = [
                 ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", "aten::relu", "aten::add", "aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
             ]
-            self.checkPatterns(graph, patterns)
+            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                graph = self.checkQuantizeTrace(m, [x, y], folding=True, atol=1e-1, config_name="conv2d_sum", qscheme=qscheme)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
+                self.checkPatterns(graph, patterns)

     @llga_test_env
     def test_linear_dropout_sum(self):
@@ -368,17 +370,17 @@ def forward(self, x, y):
         x = torch.randn(2, 15)
         y = torch.randn(2, 20)
         m = M()
-        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, remove_dropout=True, config_name="linear_dropout_sum")
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
-        self.assertFused(graph, ['aten::linear', 'aten::add',
-                                 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
-
         patterns = [
             ["aten::quantize_per_tensor"],
             ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::linear", "aten::add", "aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"]
         ]
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, remove_dropout=True, config_name="linear_dropout_sum", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
+            self.assertFused(graph, ['aten::linear', 'aten::add',
+                                     'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
         self.checkPatterns(graph, patterns)

     # TODO: check patterns when oneDNN support sum post_ops with zps
@@ -407,20 +409,19 @@ def forward(self, x):
                 y = self.conv2(x)
                 y = y.reshape(x.size(0), -1)
                 return y
-
+
         m = M()
         x = torch.rand(1, 32, 28, 28)
-
-        graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="defer_size")
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-        self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
-
         patterns = [
             ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", 'aten::relu', "aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
         ]
-        self.checkPatterns(graph, patterns)
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="defer_size", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+            self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+            self.checkPatterns(graph, patterns)

 class TestModel(JitLlgaTestCase):
     @skipIfNoTorchVision
@@ -429,13 +430,14 @@ def _test_vision(self, model_name):
         m = getattr(torchvision.models, model_name)().eval()
         x = torch.rand(1, 3, 224, 224) / 10

-        graph = self.checkQuantizeTrace(m, [x], atol=2e-1, folding=True, config_name=model_name)
-
-        # TODO: aten::adaptive_avg_pool2d also need to be fused once backend supported it
-        self.assertFused(graph, ['aten::_convolution', 'aten::relu',
-                                 'aten::max_pool2d', 'aten::linear'
-                                 'aten::quantize_per_tensor', 'aten::quantize_per_channel',
-                                 'aten::dequantize'])
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, folding=True, config_name=model_name, qscheme=qscheme)
+
+            # TODO: aten::adaptive_avg_pool2d also need to be fused once backend supported it
+            self.assertFused(graph, ['aten::_convolution', 'aten::relu',
+                                     'aten::max_pool2d', 'aten::linear',
+                                     'aten::quantize_per_tensor', 'aten::quantize_per_channel',
+                                     'aten::dequantize'])


for model_name, enabled in [
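
The `qscheme` argument threaded through `checkQuantizeTrace` presumably selects the activation observer's quantization scheme when the calibration config is built. A minimal sketch of that mapping, assuming the stock `torch.quantization` observers; the helper name `build_qconfig` and the exact dtypes are assumptions, not part of this change:

```python
# Hypothetical helper illustrating how a qscheme value could be turned into a
# QConfig for calibration; not taken from this repository.
import torch
from torch.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

def build_qconfig(qscheme):
    # Activations: per-tensor observer using the requested scheme
    # (torch.per_tensor_affine or torch.per_tensor_symmetric).
    activation = MinMaxObserver.with_args(dtype=torch.quint8, qscheme=qscheme)
    # Weights: per-channel symmetric int8, which is what produces the
    # aten::quantize_per_channel nodes checked in the patterns above.
    weight = PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
    return QConfig(activation=activation, weight=weight)

# One config per scheme exercised by the tests.
qconfigs = {qs: build_qconfig(qs)
            for qs in (torch.per_tensor_affine, torch.per_tensor_symmetric)}
```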