@@ -487,6 +487,63 @@ def forward(self, x):
487487 self .assertFused (graph , ['aten::' + eltwise ])
488488 self .checkPatterns (graph , patterns )
489489
490+ def test_conv_relu_sigmoid_mul (self ):
491+ # dequant
492+ # |
493+ # conv
494+ # |
495+ # relu
496+ # / |
497+ # quant |
498+ # / |
499+ # dequant |
500+ # | |
501+ # conv |
502+ # | |
503+ # relu |
504+ # | |
505+ # quant |
506+ # | |
507+ # dequant |
508+ # | |
509+ # conv |
510+ # | |
511+ # sigmoid |
512+ # \ /
513+ # mul
514+
515+ class M (nn .Module ):
516+ def __init__ (self ):
517+ super (M , self ).__init__ ()
518+ self .conv1 = nn .Conv2d (32 , 32 , 3 , padding = 1 )
519+ self .conv2 = nn .Conv2d (32 , 32 , 3 , padding = 1 )
520+ self .conv3 = nn .Conv2d (32 , 32 , 3 , padding = 1 )
521+
522+ def forward (self , x ):
523+ x = self .conv1 (x )
524+
525+ # The output y of relu is used by mul
526+ y = x .relu ()
527+
528+ z = self .conv2 (y )
529+ z = z .relu ()
530+ z = self .conv3 (z )
531+ z = z .sigmoid ()
532+ z = z .mul (y )
533+ return z
534+
535+ x = torch .rand (1 , 32 ,16 , 16 , requires_grad = False )
536+ m = M ()
537+ graph = self .checkQuantizeTrace (m , [x ], atol = 1e-1 )
538+ patterns = [
539+ ["aten::dequantize" , "aten::_convolution" , "aten::relu" ],
540+ ["aten::dequantize" , "aten::_convolution" , "aten::relu" , "aten::quantize_per_tensor" ],
541+ ["aten::dequantize" , "aten::_convolution" , "aten::sigmoid" , "aten::mul" ],
542+ ]
543+ self .assertGraphContainsExactly (graph , LLGA_FUSION_GROUP , 3 )
544+ self .assertFused (graph , ['aten::_convolution' , 'aten::relu' , 'aten::sigmoid' ,'aten::mul' ])
545+ self .checkPatterns (graph , patterns )
546+
490547 def test_conv_eltwise_tensor_method (self ):
491548 class ConvSigmoid (nn .Module ):
492549 def __init__ (self ):
0 commit comments