3 changes: 3 additions & 0 deletions backends/cadence/aot/TARGETS
@@ -641,6 +641,9 @@ python_unittest(
    typing = True,
    deps = [
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:graph_builder",
        "//executorch/backends/cadence/aot/quantizer:quantizer",
        "//executorch/exir:pass_base",
        "//pytorch/ao:torchao",
    ],
)
96 changes: 96 additions & 0 deletions backends/cadence/aot/tests/test_quantizer_ops.py
@@ -9,14 +9,110 @@
import unittest

import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
from executorch.exir.pass_base import NodeMetadata

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceAtenQuantizer,
    CadenceDefaultQuantizer,
    CadenceW8A32MixedQuantizer,
    CadenceWith16BitLinearActivationsQuantizer,
    CadenceWith16BitMatmulActivationsQuantizer,
    qconfig_A16,
    qconfig_A8W8,
)
from torchao.quantization.pt2e.quantizer.quantizer import (
    Q_ANNOTATION_KEY,
    QuantizationAnnotation,
)


class QuantizerAnnotationTest(unittest.TestCase):
    """Unit tests for verifying quantizer annotations are correctly applied."""

    def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
        """Build a simple graph with a matmul operation."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(4, 8))
        y = builder.placeholder("y", torch.randn(8, 4))
        matmul = builder.call_operator(
            op=torch.ops.aten.matmul.default,
            args=(x, y),
            meta=NodeMetadata(
                {"source_fn_stack": [("matmul", torch.ops.aten.matmul.default)]}
            ),
        )
        builder.output([matmul])
        gm = builder.get_graph_module()

        matmul_nodes = gm.graph.find_nodes(
            op="call_function",
            target=torch.ops.aten.matmul.default,
        )
        self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
        return gm, matmul_nodes[0]

    def test_matmul_16bit_quantizer_annotation(self) -> None:
        """Test that CadenceWith16BitMatmulActivationsQuantizer correctly annotates matmul."""
        gm, matmul_node = self._build_matmul_graph()

        quantizer = CadenceWith16BitMatmulActivationsQuantizer()
        quantizer.annotate(gm)

        annotation: QuantizationAnnotation = matmul_node.meta[Q_ANNOTATION_KEY]
        self.assertTrue(annotation._annotated)

        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)

        self.assertEqual(len(annotation.input_qspec_map), 2)
        for _, input_qspec in annotation.input_qspec_map.items():
            self.assertEqual(input_qspec, qconfig_A16.input_activation)

    def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
        """Build a simple graph with a linear operation (no bias)."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(1, 10))
        weight = builder.placeholder("weight", torch.randn(5, 10))
        linear = builder.call_operator(
            op=torch.ops.aten.linear.default,
            args=(x, weight),
            meta=NodeMetadata(
                {"source_fn_stack": [("linear", torch.ops.aten.linear.default)]}
            ),
        )
        builder.output([linear])
        gm = builder.get_graph_module()

        linear_nodes = gm.graph.find_nodes(
            op="call_function",
            target=torch.ops.aten.linear.default,
        )
        self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
        return gm, linear_nodes[0]

    def test_linear_16bit_quantizer_annotation(self) -> None:
        """Test that CadenceWith16BitLinearActivationsQuantizer correctly annotates linear."""
        gm, linear_node = self._build_linear_graph()

        quantizer = CadenceWith16BitLinearActivationsQuantizer()
        quantizer.annotate(gm)

        annotation: QuantizationAnnotation = linear_node.meta[Q_ANNOTATION_KEY]
        self.assertTrue(annotation._annotated)

        # Verify output is annotated with qconfig_A16.output_activation (INT16)
        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)

        # Verify inputs: activation (INT16) and weight (INT8)
        self.assertEqual(len(annotation.input_qspec_map), 2)
        for input_node, input_qspec in annotation.input_qspec_map.items():
            if input_node == linear_node.args[0]:
                # Activation input - should be INT16
                self.assertEqual(input_qspec, qconfig_A16.input_activation)
            elif input_node == linear_node.args[1]:
                # Weight - should be INT8
                self.assertEqual(input_qspec, qconfig_A16.weight)


Comment on lines 116 to 117

Copilot AI (Dec 5, 2025):
The test should verify that all inputs in input_qspec_map are expected. Currently, if an unexpected input node appears that doesn't match linear_node.args[0] or linear_node.args[1], the test will silently pass without checking its qspec. Consider adding an else clause with a self.fail() to catch unexpected inputs:

for input_node, input_qspec in annotation.input_qspec_map.items():
    if input_node == linear_node.args[0]:
        # Activation input - should be INT16
        self.assertEqual(input_qspec, qconfig_A16.input_activation)
    elif input_node == linear_node.args[1]:
        # Weight - should be INT8
        self.assertEqual(input_qspec, qconfig_A16.weight)
    else:
        self.fail(f"Unexpected input node in input_qspec_map: {input_node}")
class QuantizerOpsPreserveTest(unittest.TestCase):