Skip to content

Commit 462604a

Browse files
[cherry-pick] ut: relax the checker to make it work for DNNL and GC backend (#1443)
Co-authored-by: Weizhuo Zhang <weizhuo.zhang@intel.com>
1 parent d8ef113 commit 462604a

File tree

2 files changed

+3
-16
lines changed

2 files changed

+3
-16
lines changed

tests/cpu/test_ao_jit_llga_quantization_fuser.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -466,25 +466,12 @@ def forward(self, x, y, z, a):
466466
m = M()
467467

468468
# fp32 in int8 out softmax
469-
int8_fp32_patterns = [
470-
["aten::dequantize", "aten::matmul", "aten::div", "aten::add", "aten::softmax", "aten::quantize_per_tensor"],
471-
["aten::dequantize", "aten::matmul"],
472-
]
473469
graph = self.checkQuantizeTrace(m, [x, y, z, a], atol=2e-1, int8_bf16=False)
474-
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
475-
self.checkPatterns(graph, int8_fp32_patterns)
470+
self.assertFused(graph, ['aten::matmul', 'aten::div', 'aten::add', 'aten::softmax'])
476471

477472
# bf16 in int8 out softmax
478-
int8_bf16_patterns = [
479-
["aten::to", "aten::quantize_per_tensor"],
480-
["aten::to", "aten::quantize_per_tensor"],
481-
["aten::dequantize", "aten::to", "aten::matmul", "aten::div", "aten::add", "aten::softmax", "aten::to", "aten::quantize_per_tensor"],
482-
["aten::to", "aten::quantize_per_tensor"],
483-
["aten::dequantize", "aten::to", "aten::matmul"],
484-
]
485473
graph = self.checkQuantizeTrace(m, [x, y, z, a], atol=2e-1, int8_bf16=True)
486-
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
487-
self.checkPatterns(graph, int8_bf16_patterns)
474+
self.assertFused(graph, ['aten::matmul', 'aten::div', 'aten::add', 'aten::softmax'])
488475

489476
class TestFusionPattern(JitLlgaTestCase):
490477
def test_conv2d_eltwise(self):

tests/cpu/test_jit_llga_fuser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ def forward_test(x, y, z, a):
670670
a = torch.rand(128, 1, 1, 384)
671671

672672
graph, _ = self.checkTrace(forward_test, [x, y, z, a])
673-
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
673+
self.assertFused(graph, ['aten::matmul', 'aten::div', 'aten::add', 'aten::softmax', 'aten::contiguous'])
674674

675675
@llga_fp32_bf16_test_env
676676
def test_no_contiguous_no_op(self):

0 commit comments

Comments
 (0)