@@ -466,25 +466,12 @@ def forward(self, x, y, z, a):
         m = M()
 
         # fp32 in int8 out softmax
-        int8_fp32_patterns = [
-            ["aten::dequantize", "aten::matmul", "aten::div", "aten::add", "aten::softmax", "aten::quantize_per_tensor"],
-            ["aten::dequantize", "aten::matmul"],
-        ]
         graph = self.checkQuantizeTrace(m, [x, y, z, a], atol=2e-1, int8_bf16=False)
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
-        self.checkPatterns(graph, int8_fp32_patterns)
+        self.assertFused(graph, ['aten::matmul', 'aten::div', 'aten::add', 'aten::softmax'])
 
         # bf16 in int8 out softmax
-        int8_bf16_patterns = [
-            ["aten::to", "aten::quantize_per_tensor"],
-            ["aten::to", "aten::quantize_per_tensor"],
-            ["aten::dequantize", "aten::to", "aten::matmul", "aten::div", "aten::add", "aten::softmax", "aten::to", "aten::quantize_per_tensor"],
-            ["aten::to", "aten::quantize_per_tensor"],
-            ["aten::dequantize", "aten::to", "aten::matmul"],
-        ]
         graph = self.checkQuantizeTrace(m, [x, y, z, a], atol=2e-1, int8_bf16=True)
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
-        self.checkPatterns(graph, int8_bf16_patterns)
+        self.assertFused(graph, ['aten::matmul', 'aten::div', 'aten::add', 'aten::softmax'])
 
 class TestFusionPattern(JitLlgaTestCase):
     def test_conv2d_eltwise(self):
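Note on the change: the hunk drops the exact fused-pattern lists (assertGraphContainsExactly plus checkPatterns) in favor of a single fusion-membership assertion, so the test no longer pins the exact partition layout chosen by the oneDNN Graph partitioner. A minimal sketch of that kind of check, assuming assertFused only needs to verify that the listed op kinds no longer survive as standalone nodes at the top level of the traced graph (the real helper in JitLlgaTestCase may be implemented differently):

    # Hypothetical sketch, not the actual JitLlgaTestCase.assertFused:
    # treat an op as fused when its kind no longer appears among the
    # top-level nodes of the traced graph (fused ops live inside the
    # LLGA fusion-group subgraph instead).
    def assert_fused(graph, fused_ops):
        leftover = [n.kind() for n in graph.nodes() if n.kind() in fused_ops]
        assert not leftover, "ops left unfused: %s" % leftover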