@@ -7,26 +7,36 @@
 # pyre-strict

 import unittest
+from typing import Callable

 import torch
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
 from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
-from executorch.exir.pass_base import NodeMetadata

 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceAtenQuantizer,
     CadenceDefaultQuantizer,
+    CadenceQuantizer,
     CadenceW8A32MixedQuantizer,
     CadenceWith16BitLinearActivationsQuantizer,
     CadenceWith16BitMatmulActivationsQuantizer,
     qconfig_A16,
     qconfig_A8W8,
 )
+from executorch.exir.pass_base import NodeMetadata
+from parameterized import parameterized
+from torch._ops import OpOverload
 from torchao.quantization.pt2e.quantizer.quantizer import (
     Q_ANNOTATION_KEY,
     QuantizationAnnotation,
+    QuantizationSpec,
 )

+# Type alias for graph builder functions
+GraphBuilderFn = Callable[
+    ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
+]
+

 class QuantizerAnnotationTest(unittest.TestCase):
     """Unit tests for verifying quantizer annotations are correctly applied."""
@@ -53,22 +63,6 @@ def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
         return gm, matmul_nodes[0]

-    def test_matmul_16bit_quantizer_annotation(self) -> None:
-        """Test that CadenceWith16BitMatmulActivationsQuantizer correctly annotates matmul."""
-        gm, matmul_node = self._build_matmul_graph()
-
-        quantizer = CadenceWith16BitMatmulActivationsQuantizer()
-        quantizer.annotate(gm)
-
-        annotation: QuantizationAnnotation = matmul_node.meta[Q_ANNOTATION_KEY]
-        self.assertTrue(annotation._annotated)
-
-        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
-
-        self.assertEqual(len(annotation.input_qspec_map), 2)
-        for _, input_qspec in annotation.input_qspec_map.items():
-            self.assertEqual(input_qspec, qconfig_A16.input_activation)
-
     def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         """Build a simple graph with a linear operation (no bias)."""
         builder = GraphBuilder()
@@ -91,28 +85,65 @@ def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
         return gm, linear_nodes[0]

-    def test_linear_16bit_quantizer_annotation(self) -> None:
-        """Test that CadenceWith16BitLinearActivationsQuantizer correctly annotates linear."""
-        gm, linear_node = self._build_linear_graph()
+    @parameterized.expand(
+        [
+            (
+                "matmul_A16",
+                lambda self: self._build_matmul_graph(),
+                CadenceWith16BitMatmulActivationsQuantizer(),
+                torch.ops.aten.matmul.default,
+                qconfig_A16.output_activation,
+                # For matmul, both inputs are activations
+                [qconfig_A16.input_activation, qconfig_A16.input_activation],
+            ),
+            (
+                "linear_A16",
+                lambda self: self._build_linear_graph(),
+                CadenceWith16BitLinearActivationsQuantizer(),
+                torch.ops.aten.linear.default,
+                qconfig_A16.output_activation,
+                # For linear: [input_activation, weight]
+                [qconfig_A16.input_activation, qconfig_A16.weight],
+            ),
+        ]
+    )
+    def test_quantizer_annotation(
+        self,
+        name: str,
+        graph_builder_fn: GraphBuilderFn,
+        quantizer: CadenceQuantizer,
+        target: OpOverload,
+        expected_output_qspec: QuantizationSpec,
+        expected_input_qspecs: list[QuantizationSpec],
+    ) -> None:
+        """Parameterized test for quantizer annotations."""
+        gm, op_node = graph_builder_fn(self)

-        quantizer = CadenceWith16BitLinearActivationsQuantizer()
         quantizer.annotate(gm)

-        annotation: QuantizationAnnotation = linear_node.meta[Q_ANNOTATION_KEY]
+        annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
         self.assertTrue(annotation._annotated)

-        # Verify output is annotated with qconfig_A16.output_activation (INT16)
-        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
-
-        # Verify inputs: activation (INT16) and weight (INT8)
-        self.assertEqual(len(annotation.input_qspec_map), 2)
-        for input_node, input_qspec in annotation.input_qspec_map.items():
-            if input_node == linear_node.args[0]:
-                # Activation input - should be INT16
-                self.assertEqual(input_qspec, qconfig_A16.input_activation)
-            elif input_node == linear_node.args[1]:
-                # Weight - should be INT8
-                self.assertEqual(input_qspec, qconfig_A16.weight)
+        # Verify output annotation
+        self.assertEqual(annotation.output_qspec, expected_output_qspec)
+
+        # Verify input annotations
+        # Build actual_specs in the fixed order defined by op_node.args
+        self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
+        actual_specs = [
+            annotation.input_qspec_map[op_node.args[i]]
+            for i in range(len(expected_input_qspecs))
+        ]
+
+        # Compare expected vs actual specs
+        for i, (expected, actual) in enumerate(
+            zip(expected_input_qspecs, actual_specs)
+        ):
+            self.assertEqual(
+                actual,
+                expected,
+                f"Input qspec mismatch at index {i}",
+            )


 class QuantizerOpsPreserveTest(unittest.TestCase):