
Commit ea2059e

enable symmetric quantization flow (#133)
* enable symmetric quantization flow
* change code style according to clang-format
* simplify the code and change the code format
1 parent 5d14f80 commit ea2059e

File tree

7 files changed: +196 -113 lines changed


intel_pytorch_extension_py/conf.py

Lines changed: 12 additions & 1 deletion
@@ -3,13 +3,24 @@
 import torch
 import _torch_ipex as core
 
+
+qscheme_dict ={torch.per_tensor_affine:0,
+               torch.per_channel_affine:1,
+               torch.per_tensor_symmetric:2,
+               torch.per_channel_symmetric:3,
+               torch.torch.per_channel_affine_float_qparams:4}
+
 class AmpConf(object):
-    def __init__(self, mixed_dtype = torch.bfloat16, configure_file = None):
+    def __init__(self, mixed_dtype=torch.bfloat16, configure_file=None, qscheme=torch.per_tensor_affine):
         self.dtype = mixed_dtype
         self.configure_file = configure_file
 
         if self.dtype == torch.int8:
             core.clear_indicators()
+            assert qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric], \
+                "qscheme is only support torch.per_tensor_affine and torch.per_tensor_symmetric now"
+            core.set_int8_qscheme(qscheme_dict[qscheme])
+
         # for int8 path, if user give a exited configure file, load it.
         if self.configure_file != None and self.dtype == torch.int8:
             if os.path.exists(self.configure_file) and os.stat(self.configure_file).st_size != 0:
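With this change, AmpConf accepts a qscheme argument and forwards its integer index from qscheme_dict to the core via core.set_int8_qscheme (for example, torch.per_tensor_symmetric maps to 2). Below is a minimal calibration sketch using the new argument; it mirrors the flow in tests/cpu/test_jit_llga_utils.py further down, and the import alias, model, input, and file name are placeholders rather than part of this commit.

import torch
import intel_pytorch_extension as ipex  # assumed import name for the extension

# Placeholder FP32 model and calibration batch; any eval-mode module works.
model = torch.nn.Linear(28, 64).eval()
x = torch.rand(32, 28)

# Opt into the symmetric per-tensor scheme enabled by this commit;
# torch.per_tensor_affine remains the default.
conf = ipex.AmpConf(torch.int8, qscheme=torch.per_tensor_symmetric)

# Run calibration to collect the int8 observer statistics.
with torch.no_grad(), ipex.amp.calibrate():
    model(x)

# Persist the collected configuration for the later int8 run.
conf.save("configure_linear_symmetric.json")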

tests/cpu/test_jit_llga_quantization_fuser.py

Lines changed: 73 additions & 71 deletions
@@ -73,32 +73,32 @@ def test_conv2d_int8_in_f32_out(self):
                     groups=g,
                     bias=bias)
                 x = torch.rand(1, in_channels * g, spatial, spatial)
-
-                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d")
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
-                self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel'])
-
                 patterns = [
                     ["aten::quantize_per_tensor"],
                     ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
                 ]
-                self.checkPatterns(graph, patterns)
+                #TODO: enable torch.per_tensor_symmetric case.
+                for qscheme in [torch.per_tensor_affine]:
+                    graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d", qscheme=qscheme)
+                    self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+                    self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel'])
+                    self.checkPatterns(graph, patterns)
 
     @llga_test_env
     def test_linear_int8_in_f32_out(self):
         for bias in [True, False]:
             x = torch.rand(32, 28)
             m = torch.nn.Linear(in_features=28, out_features=64, bias=bias)
-
-            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear")
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
-            self.assertFused(graph, ['aten::linear', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
 
             patterns = [
                 ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"],
             ]
-            self.checkPatterns(graph, patterns)
+            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear", qscheme=qscheme)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+                self.assertFused(graph, ['aten::linear', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+                self.checkPatterns(graph, patterns)
 
     @llga_test_env
     def test_linear_int8_in_int8_out(self):
@@ -117,17 +117,19 @@ def forward(self, x, y):
             x = torch.randn(2, 15)
             y = torch.randn(2, 20)
             m = M(bias)
-            graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, config_name="linear_int8")
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-            self.assertFused(graph, ['aten::linear',
-                                     'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
 
             patterns = [
                 ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"]
             ]
-            self.checkPatterns(graph, patterns)
+
+            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, config_name="linear_int8", qscheme=qscheme)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+                self.assertFused(graph, ['aten::linear',
+                                         'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+                self.checkPatterns(graph, patterns)
 
     @llga_test_env
     def test_max_pool2d(self):
@@ -153,16 +155,16 @@ def test_max_pool2d(self):
                             ceil_mode=ceil_mode)
                 x = torch.rand(1, 3, spatial, spatial)
 
-                graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="max_pool2d")
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-                self.assertFused(graph, ['aten::max_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
-
                 patterns = [
                     ["aten::quantize_per_tensor"],
                     ["aten::dequantize", "aten::max_pool2d", "aten::quantize_per_tensor"],
                     ["aten::dequantize"]
                 ]
-                self.checkPatterns(graph, patterns)
+                for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                    graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="max_pool2d", qscheme=qscheme)
+                    self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+                    self.assertFused(graph, ['aten::max_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
+                    self.checkPatterns(graph, patterns)
 
     @llga_test_env
     @unittest.skipIf(True, 'int8 adaptive_avg_pool2d is not supported in the backend')
@@ -172,16 +174,16 @@ def test_adaptive_avg_pool2d(self):
         C = torch.randint(3, 10, (1,)).item()
         x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100
 
-        graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="adaptive_avg_pool2d")
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-        self.assertFused(graph, ['aten::adaptive_avg_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
-
        patterns = [
            ["aten::quantize_per_tensor"],
            ["aten::dequantize", "aten::adaptive_avg_pool2d", "aten::quantize_per_tensor"],
            ["aten::dequantize"]
        ]
-        self.checkPatterns(graph, patterns)
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="adaptive_avg_pool2d", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+            self.assertFused(graph, ['aten::adaptive_avg_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
+            self.checkPatterns(graph, patterns)
 
 class TestFusionPattern(JitLlgaTestCase):
     @llga_test_env
@@ -206,17 +208,17 @@ def forward(self, x):
 
             m = M(eltwise_fn)
             x = torch.rand(1, 32, 28, 28)
-
-            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d_eltwise")
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-            self.assertFused(graph, ['aten::_convolution', 'aten::' + eltwise, 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
 
             patterns = [
                 ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", 'aten::' + eltwise, "aten::quantize_per_tensor"], # inplace op will become outplace op on the JIT graph
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
             ]
-            self.checkPatterns(graph, patterns)
+            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d_eltwise", qscheme=qscheme)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+                self.assertFused(graph, ['aten::_convolution', 'aten::' + eltwise, 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+                self.checkPatterns(graph, patterns)
 
     @llga_test_env
     def test_conv2d_bn(self):
@@ -235,16 +237,17 @@ def forward(self, x):
         x = torch.rand(1, 32, 16, 16)
         # TODO: This shape will fail
         # x = torch.rand(1, 32, 28, 28)
-
-        graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn")
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
-        self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel'])
 
         patterns = [
             ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
         ]
-        self.checkPatterns(graph, patterns)
+        # TODO: add torch.per_tensor_symmetric case.
+        for qscheme in [torch.per_tensor_affine]:
+            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+            self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel'])
+            self.checkPatterns(graph, patterns)
 
     @llga_test_env
     def test_conv2d_bn_relu(self):
@@ -262,17 +265,17 @@ def forward(self, x):
 
         m = M().eval()
         x = torch.rand(1, 32, 28, 28)
-        graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn_relu")
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-        self.assertFused(graph, ['aten::_convolution', 'aten::relu',
-                                 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
-
         patterns = [
             ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", "aten::relu", "aten::quantize_per_tensor"],
            ["aten::dequantize"]
        ]
-        self.checkPatterns(graph, patterns)
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn_relu", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+            self.assertFused(graph, ['aten::_convolution', 'aten::relu',
+                                     'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+            self.checkPatterns(graph, patterns)
 
     @llga_test_env
     def test_linear_eltwise(self):
@@ -299,17 +302,16 @@ def forward(self, x):
             eltwise_fn = get_eltwise_fn(eltwise)
             m = M(eltwise_fn, has_bias)
             x = torch.rand(32, 28, requires_grad=False)
-
-            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear_eltwise")
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-            self.assertFused(graph, ['aten::' + eltwise])
-
             patterns = [
                 ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear", "aten::" + eltwise, "aten::quantize_per_tensor"],
                 ["aten::dequantize"]
             ]
-            self.checkPatterns(graph, patterns)
+            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear_eltwise", qscheme=qscheme)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+                self.assertFused(graph, ['aten::' + eltwise])
+                self.checkPatterns(graph, patterns)
 
     @llga_test_env
     def test_conv2d_sum(self):
@@ -338,17 +340,17 @@ def forward(self, x, y):
             m = M(bias).eval()
             x = torch.rand(1, 32, 16, 16, requires_grad=False)
             y = torch.rand(1, 32, 16, 16, requires_grad=False)
-            graph = self.checkQuantizeTrace(m, [x, y], folding=True, atol=1e-1, config_name="conv2d_sum")
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
-
             patterns = [
                 ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", "aten::relu", "aten::add", "aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
             ]
-            self.checkPatterns(graph, patterns)
+            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                graph = self.checkQuantizeTrace(m, [x, y], folding=True, atol=1e-1, config_name="conv2d_sum", qscheme=qscheme)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
+                self.checkPatterns(graph, patterns)
 
     @llga_test_env
     def test_linear_dropout_sum(self):
@@ -368,17 +370,17 @@ def forward(self, x, y):
         x = torch.randn(2, 15)
         y = torch.randn(2, 20)
         m = M()
-        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, remove_dropout=True, config_name="linear_dropout_sum")
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
-        self.assertFused(graph, ['aten::linear', 'aten::add',
-                                 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
-
        patterns = [
            ["aten::quantize_per_tensor"],
            ["aten::quantize_per_tensor"],
            ["aten::quantize_per_channel", "aten::dequantize", "aten::linear", "aten::add", "aten::quantize_per_tensor"],
            ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"]
        ]
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, remove_dropout=True, config_name="linear_dropout_sum", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
+            self.assertFused(graph, ['aten::linear', 'aten::add',
+                                     'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
        self.checkPatterns(graph, patterns)
 
     # TODO: check patterns when oneDNN support sum post_ops with zps
@@ -407,20 +409,19 @@ def forward(self, x):
                 y = self.conv2(x)
                 y = y.reshape(x.size(0), -1)
                 return y
-
+
         m = M()
         x = torch.rand(1, 32, 28, 28)
-
-        graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="defer_size")
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-        self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
-
        patterns = [
            ["aten::quantize_per_tensor"],
            ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", 'aten::relu', "aten::quantize_per_tensor"],
            ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
        ]
-        self.checkPatterns(graph, patterns)
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="defer_size", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+            self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+            self.checkPatterns(graph, patterns)
 
 class TestModel(JitLlgaTestCase):
     @skipIfNoTorchVision
@@ -429,13 +430,14 @@ def _test_vision(self, model_name):
         m = getattr(torchvision.models, model_name)().eval()
         x = torch.rand(1, 3, 224, 224) / 10
 
-        graph = self.checkQuantizeTrace(m, [x], atol=2e-1, folding=True, config_name=model_name)
-
-        # TODO: aten::adaptive_avg_pool2d also need to be fused once backend supported it
-        self.assertFused(graph, ['aten::_convolution', 'aten::relu',
-                                 'aten::max_pool2d', 'aten::linear'
-                                 'aten::quantize_per_tensor', 'aten::quantize_per_channel',
-                                 'aten::dequantize'])
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, folding=True, config_name=model_name, qscheme=qscheme)
+
+            # TODO: aten::adaptive_avg_pool2d also need to be fused once backend supported it
+            self.assertFused(graph, ['aten::_convolution', 'aten::relu',
+                                     'aten::max_pool2d', 'aten::linear'
+                                     'aten::quantize_per_tensor', 'aten::quantize_per_channel',
+                                     'aten::dequantize'])
 
 
 for model_name, enabled in [

tests/cpu/test_jit_llga_utils.py

Lines changed: 4 additions & 4 deletions
@@ -73,24 +73,24 @@ def assertFused(self, graph, fused_patterns):
         for pat in fused_patterns:
             self.assertGraphContainsExactly(graph, pat, 0)
 
-    def checkQuantizeTrace(self, model, x, atol=1e-3, rtol=1e-2, folding=False, remove_dropout=False, config_name=""):
+    def checkQuantizeTrace(self, model, x, atol=1e-3, rtol=1e-2, folding=False, remove_dropout=False, config_name="", qscheme=torch.per_tensor_affine):
         model.eval()
         with torch.no_grad(), torch._jit_internal._disable_emit_hooks():
             # fold conv bn
-            if folding:
+            if folding:
                 model = ipex.fx.conv_bn_fuse(model)
 
             if remove_dropout:
                 ipex.utils._replace_dropout_with_identity(model)
 
             # do calibration
-            conf = ipex.AmpConf(torch.int8)
+            conf = ipex.AmpConf(torch.int8, qscheme=qscheme)
             with ipex.amp.calibrate():
                 y = model(*x)
 
             with tempfile.TemporaryDirectory() as tmp:
                 path = os.path.join(tmp, 'configure_%s.json' % config_name)
-
+
                 # TODO: remove the serialization and test it in another separate UT once IPEX supported
                 # directly using the conf for int8 path
                 conf.save(path)
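For reference, a new LLGA test would typically exercise the updated helper for both supported schemes, following the loop pattern added throughout tests/cpu/test_jit_llga_quantization_fuser.py above. The sketch below is illustrative only: the import line, class name, model, tolerances, and config_name are assumptions, not part of this commit; only checkQuantizeTrace, checkPatterns, assertGraphContainsExactly, LLGA_FUSION_GROUP, and the qscheme argument come from the diffs.

import torch
# Assumed import path: the utilities are defined in tests/cpu/test_jit_llga_utils.py.
from test_jit_llga_utils import JitLlgaTestCase, LLGA_FUSION_GROUP, llga_test_env

class TestSymmetricExample(JitLlgaTestCase):
    @llga_test_env
    def test_linear_symmetric(self):
        # Illustrative model and input; mirrors test_linear_int8_in_f32_out.
        m = torch.nn.Linear(in_features=28, out_features=64)
        x = torch.rand(32, 28)
        patterns = [
            ["aten::quantize_per_tensor"],
            ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"],
        ]
        # Exercise both qschemes accepted by AmpConf after this commit.
        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
            graph = self.checkQuantizeTrace(m, [x], atol=1e-1,
                                            config_name="linear_example", qscheme=qscheme)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
            self.checkPatterns(graph, patterns)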
