Skip to content

Commit d94ca6e

Browse files
authored
[1.12] Update the backend (d2567d7) to fix the pattern match for regnet_y and mul perf issue for efficientnet (#859)
* add ut to reproduce * update llga to d2567d to include fix for pattern match and mul perf
1 parent 7b2b561 commit d94ca6e

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

tests/cpu/test_ao_jit_llga_quantization_fuser.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,63 @@ def forward(self, x):
487487
self.assertFused(graph, ['aten::' + eltwise])
488488
self.checkPatterns(graph, patterns)
489489

490+
def test_conv_relu_sigmoid_mul(self):
491+
# dequant
492+
# |
493+
# conv
494+
# |
495+
# relu
496+
# / |
497+
# quant |
498+
# / |
499+
# dequant |
500+
# | |
501+
# conv |
502+
# | |
503+
# relu |
504+
# | |
505+
# quant |
506+
# | |
507+
# dequant |
508+
# | |
509+
# conv |
510+
# | |
511+
# sigmoid |
512+
# \ /
513+
# mul
514+
515+
class M(nn.Module):
516+
def __init__(self):
517+
super(M, self).__init__()
518+
self.conv1 = nn.Conv2d(32, 32, 3, padding=1)
519+
self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
520+
self.conv3 = nn.Conv2d(32, 32, 3, padding=1)
521+
522+
def forward(self, x):
523+
x = self.conv1(x)
524+
525+
# The output y of relu is used by mul
526+
y = x.relu()
527+
528+
z = self.conv2(y)
529+
z = z.relu()
530+
z = self.conv3(z)
531+
z = z.sigmoid()
532+
z = z.mul(y)
533+
return z
534+
535+
x = torch.rand(1, 32,16, 16, requires_grad=False)
536+
m = M()
537+
graph = self.checkQuantizeTrace(m, [x], atol=1e-1)
538+
patterns = [
539+
["aten::dequantize", "aten::_convolution", "aten::relu"],
540+
["aten::dequantize", "aten::_convolution", "aten::relu", "aten::quantize_per_tensor"],
541+
["aten::dequantize", "aten::_convolution", "aten::sigmoid", "aten::mul"],
542+
]
543+
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
544+
self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::sigmoid','aten::mul'])
545+
self.checkPatterns(graph, patterns)
546+
490547
def test_conv_eltwise_tensor_method(self):
491548
class ConvSigmoid(nn.Module):
492549
def __init__(self):

third_party/llga

Submodule llga updated from e15979e to d2567d7

0 commit comments

Comments
 (0)