Commit f77fcec

[LLGA] do not rewrite single quant and dequant node (#139)
* [LLGA] do not rewrite single quant/dequant
* [LLGA] update UTs since we don't rewrite single quant/dequant anymore
1 parent 77068a0 commit f77fcec
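In effect, a quantize or dequantize node whose oneDNN Graph partition contains only that single op is now left on the JIT graph as a plain aten op instead of being rewritten into a one-op LLGA fusion group, so the framework's own passes still get a chance to optimize it. A minimal sketch of the observable change (illustrative only, not part of this commit; the module and shapes are made up, and it assumes an IPEX build with the LLGA bridge enabled):

import torch

class SingleQuant(torch.nn.Module):
    def forward(self, x):
        # An isolated quant node: there is no neighboring op for oneDNN
        # Graph to fuse with, so its partition would hold exactly one op.
        return torch.quantize_per_tensor(x, scale=0.1, zero_point=0,
                                         dtype=torch.quint8)

m = torch.jit.trace(SingleQuant().eval(), torch.rand(1, 3, 8, 8))
# After this change the graph keeps aten::quantize_per_tensor rather than
# wrapping the lone op in an LLGA fusion group (LLGA_FUSION_GROUP in the UTs).
print(m.graph)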

File tree: 3 files changed (+68 −53 lines)


tests/cpu/test_jit_llga_quantization_fuser.py

Lines changed: 22 additions & 53 deletions
@@ -76,14 +76,13 @@ def test_conv2d_int8_in_f32_out(self):
                           bias=bias)
             x = torch.rand(1, in_channels * g, spatial, spatial)
             patterns = [
-                ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
             ]
             #TODO: enable torch.per_tensor_symmetric case.
             for qscheme in [torch.per_tensor_affine]:
                 graph = self.checkQuantizeTrace(m, [x], x_var=[torch.rand(5, in_channels * g, spatial, spatial, requires_grad=False)], atol=2e-1, config_name="conv2d", qscheme=qscheme)
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
-                self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel'])
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+                self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_channel', 'aten::dequantize'])
                 self.checkPatterns(graph, patterns)

     @llga_test_env
@@ -93,13 +92,12 @@ def test_linear_int8_in_f32_out(self):
             m = torch.nn.Linear(in_features=28, out_features=64, bias=bias)

             patterns = [
-                ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"],
             ]
             for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                 graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear", qscheme=qscheme)
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
-                self.assertFused(graph, ['aten::linear', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+                self.assertFused(graph, ['aten::linear', 'aten::quantize_per_channel', 'aten::dequantize'])
                 self.checkPatterns(graph, patterns)

     @llga_test_env
@@ -121,16 +119,14 @@ def forward(self, x, y):
             m = M(bias)

             patterns = [
-                ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"]
             ]

             for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                 graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, config_name="linear_int8", qscheme=qscheme)
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-                self.assertFused(graph, ['aten::linear',
-                                         'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+                self.assertFused(graph, ['aten::linear', 'aten::quantize_per_channel', 'aten::dequantize'])
                 self.checkPatterns(graph, patterns)

     @llga_test_env
@@ -158,14 +154,12 @@ def test_max_pool2d(self):
             x = torch.rand(1, 3, spatial, spatial)

             patterns = [
-                ["aten::quantize_per_tensor"],
                 ["aten::dequantize", "aten::max_pool2d", "aten::quantize_per_tensor"],
-                ["aten::dequantize"]
             ]
             for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                 graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="max_pool2d", qscheme=qscheme)
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-                self.assertFused(graph, ['aten::max_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+                self.assertFused(graph, ['aten::max_pool2d'])
                 self.checkPatterns(graph, patterns)

     @llga_test_env
@@ -212,14 +206,13 @@ def forward(self, x):
             x = torch.rand(1, 32, 28, 28)

             patterns = [
-                ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", 'aten::' + eltwise, "aten::quantize_per_tensor"], # inplace op will become outplace op on the JIT graph
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
             ]
             for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                 graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d_eltwise", qscheme=qscheme)
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-                self.assertFused(graph, ['aten::_convolution', 'aten::' + eltwise, 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+                self.assertFused(graph, ['aten::_convolution', 'aten::' + eltwise, 'aten::quantize_per_channel', 'aten::dequantize'])
                 self.checkPatterns(graph, patterns)

     @llga_test_env
@@ -241,14 +234,13 @@ def forward(self, x):
             # x = torch.rand(1, 32, 28, 28)

             patterns = [
-                ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
             ]
             # TODO: add torch.per_tensor_symmetric case.
             for qscheme in [torch.per_tensor_affine]:
                 graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn", qscheme=qscheme)
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
-                self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel'])
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+                self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_channel', 'aten::dequantize'])
                 self.checkPatterns(graph, patterns)

     @llga_test_env
@@ -268,15 +260,12 @@ def forward(self, x):
         m = M().eval()
         x = torch.rand(1, 32, 28, 28)
         patterns = [
-            ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", "aten::relu", "aten::quantize_per_tensor"],
-            ["aten::dequantize"]
         ]
         for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
             graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn_relu", qscheme=qscheme)
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-            self.assertFused(graph, ['aten::_convolution', 'aten::relu',
-                                     'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+            self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::quantize_per_channel'])
             self.checkPatterns(graph, patterns)

     @llga_test_env
@@ -305,13 +294,11 @@ def forward(self, x):
             m = M(eltwise_fn, has_bias)
             x = torch.rand(32, 28, requires_grad=False)
             patterns = [
-                ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear", "aten::" + eltwise, "aten::quantize_per_tensor"],
-                ["aten::dequantize"]
             ]
             for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                 graph = self.checkQuantizeTrace(m, [x], x_var=[torch.rand(2, 28, requires_grad=False)], atol=1e-1, config_name="linear_eltwise", qscheme=qscheme)
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                 self.assertFused(graph, ['aten::' + eltwise])
                 self.checkPatterns(graph, patterns)

@@ -343,15 +330,14 @@ def forward(self, x, y):
         x = torch.rand(1, 32, 16, 16, requires_grad=False)
         y = torch.rand(1, 32, 16, 16, requires_grad=False)
         patterns = [
-            ["aten::quantize_per_tensor"],
-            ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", "aten::relu", "aten::add", "aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
         ]
         for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
             graph = self.checkQuantizeTrace(m, [x, y], folding=True, atol=1e-1, config_name="conv2d_sum", qscheme=qscheme)
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+            self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::add', 'aten::quantize_per_channel', 'aten::dequantize'])
             self.checkPatterns(graph, patterns)

     @llga_test_env
@@ -373,29 +359,15 @@ def forward(self, x, y):
         y = torch.randn(2, 20)
         m = M()
         patterns = [
-            ["aten::quantize_per_tensor"],
-            ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::linear", "aten::add", "aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"]
         ]
         for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
             graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, remove_dropout=True, config_name="linear_dropout_sum", qscheme=qscheme)
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
-            self.assertFused(graph, ['aten::linear', 'aten::add',
-                                     'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+            self.assertFused(graph, ['aten::linear', 'aten::add', 'aten::quantize_per_channel', 'aten::dequantize'])
             self.checkPatterns(graph, patterns)

-        # TODO: check patterns when oneDNN support sum post_ops with zps
-        # patterns = [
-        #     ["aten::quantize_per_tensor"],
-        #     ["aten::quantize_per_channel"],
-        #     ["aten::dequantize", "aten::linear", "aten::add", "aten::quantize_per_tensor"],
-        #     ["aten::quantize_per_channel"],
-        #     ["aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
-        #     ["aten::dequantize"]
-        # ]
-        # self.checkPatterns(graph, patterns)
-
     @llga_test_env
     def test_defer_size(self):
         class M(nn.Module):
@@ -415,14 +387,13 @@ def forward(self, x):
         m = M()
         x = torch.rand(1, 32, 28, 28)
         patterns = [
-            ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", 'aten::relu', "aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
         ]
         for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
             graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="defer_size", qscheme=qscheme)
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-            self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+            self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::quantize_per_channel', 'aten::dequantize'])
             self.checkPatterns(graph, patterns)

 class TestShapeFallback(JitLlgaTestCase):
@@ -486,9 +457,7 @@ def _test_vision(self, model_name):

         # TODO: aten::adaptive_avg_pool2d also need to be fused once backend supported it
         self.assertFused(graph, ['aten::_convolution', 'aten::relu',
-                                 'aten::max_pool2d', 'aten::linear'
-                                 'aten::quantize_per_tensor', 'aten::quantize_per_channel',
-                                 'aten::dequantize'])
+                                 'aten::max_pool2d', 'aten::linear', 'aten::quantize_per_channel'])


 for model_name, enabled in [

torch_ipex/csrc/jit/codegen/onednn/graph_helper.cpp

Lines changed: 44 additions & 0 deletions
@@ -488,6 +488,41 @@ bool isViewOp(Node* n) {
   return false;
 }

+void checkAndRemoveAttr(Node *n, std::string attr) {
+  TORCH_CHECK(n->hasAttributeS(attr),
+              "dequant node with numAttributes != 0 must have attr: ", attr);
+  n->removeAttributeS(attr);
+}
+
+void removeAttrOfDequant(Node *n) {
+  if (n->kind() == Symbol::aten("dequantize")) {
+    if (n->numAttributes() == 0)
+      return;
+    std::vector<std::string> common_attrs{"zps", "scales", "in_type"};
+    for (const auto &attr : common_attrs) {
+      checkAndRemoveAttr(n, attr);
+    }
+
+    if (n->s(Symbol::attr("qtype")) == std::string("per_channel")) {
+      checkAndRemoveAttr(n, std::string("axis"));
+    }
+    checkAndRemoveAttr(n, std::string("qtype"));
+  }
+}
+
+bool LlgaGraphHelper::isSingleQuantDequant(Node *n) {
+  if (n->kind() != Symbol::aten("quantize_per_tensor") &&
+      n->kind() != Symbol::aten("quantize_per_channel") &&
+      n->kind() != Symbol::aten("dequantize"))
+    return false;
+  if (!opToOwningPartition_.has(n))
+    return false;
+
+  auto partitionId = opToOwningPartition_.get(n);
+  auto OpNum = partitions_[partitionId].get_ops_num();
+  return OpNum == 1;
+}
+
 bool LlgaGraphHelper::shouldConsiderForMerge(Node* node) {
   // if we're already in the process of merging
   if (isLlgaSubgraph(node)) {
@@ -496,6 +531,15 @@ bool LlgaGraphHelper::shouldConsiderForMerge(Node* node) {
   if (isViewOp(node)) {
     return false;
   }
+  // For a partition composed of 1 single quant or 1 single dequant,
+  // do not rewrite it in the bridge, so that the FWK may have chances
+  // to optimize single int8 op that LLGA does not support
+  if (isSingleQuantDequant(node)) {
+    // We have added attr on dequant node to create LLGA dequant op.
+    // If we won't rewrite it with LLGA op, remove the attr here.
+    removeAttrOfDequant(node);
+    return false;
+  }
   return opToOwningPartition_.has(node);
 }

torch_ipex/csrc/jit/codegen/onednn/graph_helper.h

Lines changed: 2 additions & 0 deletions
@@ -63,6 +63,8 @@ class LlgaGraphHelper {
  private:
   size_t countSupportedOps(const std::shared_ptr<Graph>& graph) const;

+  bool isSingleQuantDequant(Node *node);
+
   OpPartitionMap opToOwningPartition_;
   std::vector<dnnl::graph::partition> partitions_;
 };
