|
9 | 9 | import unittest |
10 | 10 |
|
11 | 11 | import torch |
| 12 | +from executorch.backends.cadence.aot.graph_builder import GraphBuilder |
12 | 13 | from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern |
| 14 | +from executorch.exir.pass_base import NodeMetadata |
13 | 15 |
|
14 | 16 | from executorch.backends.cadence.aot.quantizer.quantizer import ( |
15 | 17 | CadenceAtenQuantizer, |
16 | 18 | CadenceDefaultQuantizer, |
17 | 19 | CadenceW8A32MixedQuantizer, |
| 20 | + CadenceWith16BitMatmulActivationsQuantizer, |
| 21 | + qconfig_A16, |
18 | 22 | qconfig_A8W8, |
19 | 23 | ) |
| 24 | +from torchao.quantization.pt2e.quantizer.quantizer import ( |
| 25 | + Q_ANNOTATION_KEY, |
| 26 | + QuantizationAnnotation, |
| 27 | +) |
| 28 | + |
| 29 | + |
class QuantizerAnnotationTest(unittest.TestCase):
    """Unit tests for verifying quantizer annotations are correctly applied."""

    def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
        """Construct a two-input matmul graph; return (graph_module, matmul_node)."""
        gb = GraphBuilder()
        lhs = gb.placeholder("x", torch.randn(4, 8))
        rhs = gb.placeholder("y", torch.randn(8, 4))
        mm = gb.call_operator(
            op=torch.ops.aten.matmul.default,
            args=(lhs, rhs),
            meta=NodeMetadata(
                {"source_fn_stack": [("matmul", torch.ops.aten.matmul.default)]}
            ),
        )
        gb.output([mm])
        graph_module = gb.get_graph_module()

        # Locate the single matmul call in the finished graph so tests can
        # inspect the annotation the quantizer attaches to it.
        found = graph_module.graph.find_nodes(
            op="call_function",
            target=torch.ops.aten.matmul.default,
        )
        self.assertEqual(len(found), 1, "Should find exactly one matmul node")
        return graph_module, found[0]

    def test_matmul_16bit_quantizer_annotation(self) -> None:
        """Test that CadenceWith16BitMatmulActivationsQuantizer correctly annotates matmul."""
        graph_module, mm_node = self._build_matmul_graph()

        CadenceWith16BitMatmulActivationsQuantizer().annotate(graph_module)

        annotation: QuantizationAnnotation = mm_node.meta[Q_ANNOTATION_KEY]
        self.assertTrue(annotation._annotated)

        # The output activation must carry the 16-bit qconfig's output spec.
        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)

        # Both matmul operands must carry the 16-bit input-activation spec.
        self.assertEqual(len(annotation.input_qspec_map), 2)
        for input_qspec in annotation.input_qspec_map.values():
            self.assertEqual(input_qspec, qconfig_A16.input_activation)
20 | 70 |
|
21 | 71 |
|
22 | 72 | class QuantizerOpsPreserveTest(unittest.TestCase): |
|
0 commit comments