
Commit 971cf59

RahulC7 authored and facebook-github-bot committed
Adding Test For CadenceWith16BitLinearActivationsQuantizer (pytorch#16097)
Summary: We test the CadenceWith16BitLinearActivationsQuantizer. We use the graph builder to build the graph with metadata (needed for quantizer.annotate to recognize the nodes), and we ensure that the quantization params are as expected.

Reviewed By: hsharma35

Differential Revision: D88054651
1 parent 67878e4 commit 971cf59

File tree: 1 file changed, +46 −0 lines


backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 46 additions & 0 deletions
@@ -17,6 +17,7 @@
     CadenceAtenQuantizer,
     CadenceDefaultQuantizer,
     CadenceW8A32MixedQuantizer,
+    CadenceWith16BitLinearActivationsQuantizer,
     CadenceWith16BitMatmulActivationsQuantizer,
     qconfig_A16,
     qconfig_A8W8,
@@ -68,6 +69,51 @@ def test_matmul_16bit_quantizer_annotation(self) -> None:
         for _, input_qspec in annotation.input_qspec_map.items():
             self.assertEqual(input_qspec, qconfig_A16.input_activation)
 
+    def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a linear operation (no bias)."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 10))
+        weight = builder.placeholder("weight", torch.randn(5, 10))
+        linear = builder.call_operator(
+            op=torch.ops.aten.linear.default,
+            args=(x, weight),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("linear", torch.ops.aten.linear.default)]}
+            ),
+        )
+        builder.output([linear])
+        gm = builder.get_graph_module()
+
+        linear_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.linear.default,
+        )
+        self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
+        return gm, linear_nodes[0]
+
+    def test_linear_16bit_quantizer_annotation(self) -> None:
+        """Test that CadenceWith16BitLinearActivationsQuantizer correctly annotates linear."""
+        gm, linear_node = self._build_linear_graph()
+
+        quantizer = CadenceWith16BitLinearActivationsQuantizer()
+        quantizer.annotate(gm)
+
+        annotation: QuantizationAnnotation = linear_node.meta[Q_ANNOTATION_KEY]
+        self.assertTrue(annotation._annotated)
+
+        # Verify output is annotated with qconfig_A16.output_activation (INT16)
+        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
+
+        # Verify inputs: activation (INT16) and weight (INT8)
+        self.assertEqual(len(annotation.input_qspec_map), 2)
+        for input_node, input_qspec in annotation.input_qspec_map.items():
+            if input_node == linear_node.args[0]:
+                # Activation input - should be INT16
+                self.assertEqual(input_qspec, qconfig_A16.input_activation)
+            elif input_node == linear_node.args[1]:
+                # Weight - should be INT8
+                self.assertEqual(input_qspec, qconfig_A16.weight)
+
 
 class QuantizerOpsPreserveTest(unittest.TestCase):
     def test_mixed_w8a32_ops_to_preserve(self) -> None:
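
To exercise just the new test locally, one option (a sketch, assuming the repository's test suite can be driven with pytest from the repo root and that the -k expression below matches only the added test) is:

    python -m pytest backends/cadence/aot/tests/test_quantizer_ops.py -k test_linear_16bit_quantizer_annotation -v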
