Skip to content

Commit 9637b08

Browse files
RahulC7 and facebook-github-bot
authored and committed
Creating Parameterized Test For Quantizers For Easier Testing (pytorch#16098)
Summary: We consolidate the two tests we created into a single testing function using parameterization. This will make testing future Quantizers much easier, and there will be a lot less code duplication. Reviewed By: hsharma35 Differential Revision: D88054917
1 parent d0ee8a3 commit 9637b08

File tree

2 files changed

+66
-34
lines changed

2 files changed

+66
-34
lines changed

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,7 @@ python_unittest(
640640
],
641641
typing = True,
642642
deps = [
643+
"fbsource//third-party/pypi/parameterized:parameterized",
643644
"//caffe2:torch",
644645
"//executorch/backends/cadence/aot:graph_builder",
645646
"//executorch/backends/cadence/aot/quantizer:quantizer",

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 65 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,36 @@
77
# pyre-strict
88

99
import unittest
10+
from typing import Callable
1011

1112
import torch
1213
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
1314
from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
14-
from executorch.exir.pass_base import NodeMetadata
1515

1616
from executorch.backends.cadence.aot.quantizer.quantizer import (
1717
CadenceAtenQuantizer,
1818
CadenceDefaultQuantizer,
19+
CadenceQuantizer,
1920
CadenceW8A32MixedQuantizer,
2021
CadenceWith16BitLinearActivationsQuantizer,
2122
CadenceWith16BitMatmulActivationsQuantizer,
2223
qconfig_A16,
2324
qconfig_A8W8,
2425
)
26+
from executorch.exir.pass_base import NodeMetadata
27+
from parameterized import parameterized
28+
from torch._ops import OpOverload
2529
from torchao.quantization.pt2e.quantizer.quantizer import (
2630
Q_ANNOTATION_KEY,
2731
QuantizationAnnotation,
32+
QuantizationSpec,
2833
)
2934

35+
# Type alias for graph builder functions
36+
GraphBuilderFn = Callable[
37+
["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
38+
]
39+
3040

3141
class QuantizerAnnotationTest(unittest.TestCase):
3242
"""Unit tests for verifying quantizer annotations are correctly applied."""
@@ -53,22 +63,6 @@ def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
5363
self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
5464
return gm, matmul_nodes[0]
5565

56-
def test_matmul_16bit_quantizer_annotation(self) -> None:
57-
"""Test that CadenceWith16BitMatmulActivationsQuantizer correctly annotates matmul."""
58-
gm, matmul_node = self._build_matmul_graph()
59-
60-
quantizer = CadenceWith16BitMatmulActivationsQuantizer()
61-
quantizer.annotate(gm)
62-
63-
annotation: QuantizationAnnotation = matmul_node.meta[Q_ANNOTATION_KEY]
64-
self.assertTrue(annotation._annotated)
65-
66-
self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
67-
68-
self.assertEqual(len(annotation.input_qspec_map), 2)
69-
for _, input_qspec in annotation.input_qspec_map.items():
70-
self.assertEqual(input_qspec, qconfig_A16.input_activation)
71-
7266
def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
7367
"""Build a simple graph with a linear operation (no bias)."""
7468
builder = GraphBuilder()
@@ -91,28 +85,65 @@ def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
9185
self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
9286
return gm, linear_nodes[0]
9387

94-
def test_linear_16bit_quantizer_annotation(self) -> None:
95-
"""Test that CadenceWith16BitLinearActivationsQuantizer correctly annotates linear."""
96-
gm, linear_node = self._build_linear_graph()
88+
@parameterized.expand(
89+
[
90+
(
91+
"matmul_A16",
92+
lambda self: self._build_matmul_graph(),
93+
CadenceWith16BitMatmulActivationsQuantizer(),
94+
torch.ops.aten.matmul.default,
95+
qconfig_A16.output_activation,
96+
# For matmul, both inputs are activations
97+
[qconfig_A16.input_activation, qconfig_A16.input_activation],
98+
),
99+
(
100+
"linear_A16",
101+
lambda self: self._build_linear_graph(),
102+
CadenceWith16BitLinearActivationsQuantizer(),
103+
torch.ops.aten.linear.default,
104+
qconfig_A16.output_activation,
105+
# For linear: [input_activation, weight]
106+
[qconfig_A16.input_activation, qconfig_A16.weight],
107+
),
108+
]
109+
)
110+
def test_quantizer_annotation(
111+
self,
112+
name: str,
113+
graph_builder_fn: GraphBuilderFn,
114+
quantizer: CadenceQuantizer,
115+
target: OpOverload,
116+
expected_output_qspec: QuantizationSpec,
117+
expected_input_qspecs: list[QuantizationSpec],
118+
) -> None:
119+
"""Parameterized test for quantizer annotations."""
120+
gm, op_node = graph_builder_fn(self)
97121

98-
quantizer = CadenceWith16BitLinearActivationsQuantizer()
99122
quantizer.annotate(gm)
100123

101-
annotation: QuantizationAnnotation = linear_node.meta[Q_ANNOTATION_KEY]
124+
annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
102125
self.assertTrue(annotation._annotated)
103126

104-
# Verify output is annotated with qconfig_A16.output_activation (INT16)
105-
self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
106-
107-
# Verify inputs: activation (INT16) and weight (INT8)
108-
self.assertEqual(len(annotation.input_qspec_map), 2)
109-
for input_node, input_qspec in annotation.input_qspec_map.items():
110-
if input_node == linear_node.args[0]:
111-
# Activation input - should be INT16
112-
self.assertEqual(input_qspec, qconfig_A16.input_activation)
113-
elif input_node == linear_node.args[1]:
114-
# Weight - should be INT8
115-
self.assertEqual(input_qspec, qconfig_A16.weight)
127+
# Verify output annotation
128+
self.assertEqual(annotation.output_qspec, expected_output_qspec)
129+
130+
# Verify input annotations
131+
# Build actual_specs in the fixed order defined by op_node.args
132+
self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
133+
actual_specs = [
134+
annotation.input_qspec_map[op_node.args[i]]
135+
for i in range(len(expected_input_qspecs))
136+
]
137+
138+
# Compare expected vs actual specs
139+
for i, (expected, actual) in enumerate(
140+
zip(expected_input_qspecs, actual_specs)
141+
):
142+
self.assertEqual(
143+
actual,
144+
expected,
145+
f"Input qspec mismatch at index {i}",
146+
)
116147

117148

118149
class QuantizerOpsPreserveTest(unittest.TestCase):

0 commit comments

Comments
 (0)