diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index be74a8d957f..8363f022946 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -641,6 +641,9 @@ python_unittest(
     typing = True,
     deps = [
         "//caffe2:torch",
+        "//executorch/backends/cadence/aot:graph_builder",
         "//executorch/backends/cadence/aot/quantizer:quantizer",
+        "//executorch/exir:pass_base",
+        "//pytorch/ao:torchao",
     ],
 )
diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
index f0df592558f..9518ff0e202 100644
--- a/backends/cadence/aot/tests/test_quantizer_ops.py
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -9,14 +9,110 @@
 import unittest
 
 import torch
 
+from executorch.backends.cadence.aot.graph_builder import GraphBuilder
 from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
+from executorch.exir.pass_base import NodeMetadata
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceAtenQuantizer,
     CadenceDefaultQuantizer,
     CadenceW8A32MixedQuantizer,
+    CadenceWith16BitLinearActivationsQuantizer,
+    CadenceWith16BitMatmulActivationsQuantizer,
+    qconfig_A16,
     qconfig_A8W8,
 )
+from torchao.quantization.pt2e.quantizer.quantizer import (
+    Q_ANNOTATION_KEY,
+    QuantizationAnnotation,
+)
+
+
+class QuantizerAnnotationTest(unittest.TestCase):
+    """Unit tests for verifying quantizer annotations are correctly applied."""
+
+    def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a matmul operation."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 8))
+        y = builder.placeholder("y", torch.randn(8, 4))
+        matmul = builder.call_operator(
+            op=torch.ops.aten.matmul.default,
+            args=(x, y),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("matmul", torch.ops.aten.matmul.default)]}
+            ),
+        )
+        builder.output([matmul])
+        gm = builder.get_graph_module()
+
+        matmul_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.matmul.default,
+        )
+        self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
+        return gm, matmul_nodes[0]
+
+    def test_matmul_16bit_quantizer_annotation(self) -> None:
+        """Test that CadenceWith16BitMatmulActivationsQuantizer correctly annotates matmul."""
+        gm, matmul_node = self._build_matmul_graph()
+
+        quantizer = CadenceWith16BitMatmulActivationsQuantizer()
+        quantizer.annotate(gm)
+
+        annotation: QuantizationAnnotation = matmul_node.meta[Q_ANNOTATION_KEY]
+        self.assertTrue(annotation._annotated)
+
+        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
+
+        self.assertEqual(len(annotation.input_qspec_map), 2)
+        for _, input_qspec in annotation.input_qspec_map.items():
+            self.assertEqual(input_qspec, qconfig_A16.input_activation)
+
+    def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a linear operation (no bias)."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 10))
+        weight = builder.placeholder("weight", torch.randn(5, 10))
+        linear = builder.call_operator(
+            op=torch.ops.aten.linear.default,
+            args=(x, weight),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("linear", torch.ops.aten.linear.default)]}
+            ),
+        )
+        builder.output([linear])
+        gm = builder.get_graph_module()
+
+        linear_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.linear.default,
+        )
+        self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
+        return gm, linear_nodes[0]
+
+    def test_linear_16bit_quantizer_annotation(self) -> None:
+        """Test that CadenceWith16BitLinearActivationsQuantizer correctly annotates linear."""
+        gm, linear_node = self._build_linear_graph()
+
+        quantizer = CadenceWith16BitLinearActivationsQuantizer()
+        quantizer.annotate(gm)
+
+        annotation: QuantizationAnnotation = linear_node.meta[Q_ANNOTATION_KEY]
+        self.assertTrue(annotation._annotated)
+
+        # Verify output is annotated with qconfig_A16.output_activation (INT16)
+        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
+
+        # Verify inputs: activation (INT16) and weight (INT8)
+        self.assertEqual(len(annotation.input_qspec_map), 2)
+        for input_node, input_qspec in annotation.input_qspec_map.items():
+            if input_node == linear_node.args[0]:
+                # Activation input - should be INT16
+                self.assertEqual(input_qspec, qconfig_A16.input_activation)
+            elif input_node == linear_node.args[1]:
+                # Weight - should be INT8
+                self.assertEqual(input_qspec, qconfig_A16.weight)
 
 
 class QuantizerOpsPreserveTest(unittest.TestCase):
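
For context, the annotation these tests assert on is what the standard PT2E flow consumes. Below is a minimal, illustrative sketch of driving one of the new quantizers end to end. It assumes the Cadence quantizers plug into torchao's prepare_pt2e/convert_pt2e entry points (consistent with the torchao.quantization.pt2e imports in the diff); the exact quantize_pt2e module path is an assumption, not something this change touches.

# Illustrative sketch only (not part of this change). Assumes torchao's
# PT2E entry points live at torchao.quantization.pt2e.quantize_pt2e and
# that the Cadence quantizers are standard torchao Quantizers.
import torch
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceWith16BitMatmulActivationsQuantizer,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class MatmulModule(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y)


inputs = (torch.randn(4, 8), torch.randn(8, 4))
# Export to an ATen-level graph, let the quantizer annotate it (the step
# the unit tests above check directly), run one calibration pass, then
# convert the observers into quantize/dequantize ops.
gm = torch.export.export_for_training(MatmulModule(), inputs).module()
prepared = prepare_pt2e(gm, CadenceWith16BitMatmulActivationsQuantizer())
prepared(*inputs)
converted = convert_pt2e(prepared)

After prepare_pt2e, the matmul node's Q_ANNOTATION_KEY metadata should carry the same qconfig_A16 qspecs the unit tests assert on.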