Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,11 @@ python_unittest(
],
typing = True,
deps = [
"fbsource//third-party/pypi/parameterized:parameterized",
"//caffe2:torch",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/cadence/aot/quantizer:quantizer",
"//executorch/exir:pass_base",
"//pytorch/ao:torchao",
],
)
127 changes: 127 additions & 0 deletions backends/cadence/aot/tests/test_quantizer_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,143 @@
# pyre-strict

import unittest
from typing import Callable

import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern

from executorch.backends.cadence.aot.quantizer.quantizer import (
CadenceAtenQuantizer,
CadenceDefaultQuantizer,
CadenceQuantizer,
CadenceW8A32MixedQuantizer,
CadenceWith16BitLinearActivationsQuantizer,
CadenceWith16BitMatmulActivationsQuantizer,
qconfig_A16,
qconfig_A8W8,
)
from executorch.exir.pass_base import NodeMetadata
from parameterized import parameterized
from torch._ops import OpOverload
from torchao.quantization.pt2e.quantizer.quantizer import (
Q_ANNOTATION_KEY,
QuantizationAnnotation,
QuantizationSpec,
)

# Type alias for graph-builder test helpers: given the test instance, a
# helper builds a one-op graph and returns (graph_module, op_node) where
# op_node is the operator node whose annotations the test inspects.
GraphBuilderFn = Callable[
    ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
]


class QuantizerAnnotationTest(unittest.TestCase):
"""Unit tests for verifying quantizer annotations are correctly applied."""

def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a matmul operation."""
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4, 8))
y = builder.placeholder("y", torch.randn(8, 4))
matmul = builder.call_operator(
op=torch.ops.aten.matmul.default,
args=(x, y),
meta=NodeMetadata(
{"source_fn_stack": [("matmul", torch.ops.aten.matmul.default)]}
),
)
builder.output([matmul])
gm = builder.get_graph_module()

matmul_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.matmul.default,
)
self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
return gm, matmul_nodes[0]

def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a linear operation (no bias)."""
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(1, 10))
weight = builder.placeholder("weight", torch.randn(5, 10))
linear = builder.call_operator(
op=torch.ops.aten.linear.default,
args=(x, weight),
meta=NodeMetadata(
{"source_fn_stack": [("linear", torch.ops.aten.linear.default)]}
),
)
builder.output([linear])
gm = builder.get_graph_module()

linear_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.linear.default,
)
self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
return gm, linear_nodes[0]

@parameterized.expand(
[
(
"matmul_A16",
lambda self: self._build_matmul_graph(),
CadenceWith16BitMatmulActivationsQuantizer(),
torch.ops.aten.matmul.default,
qconfig_A16.output_activation,
# For matmul, both inputs are activations
[qconfig_A16.input_activation, qconfig_A16.input_activation],
),
(
"linear_A16",
lambda self: self._build_linear_graph(),
CadenceWith16BitLinearActivationsQuantizer(),
torch.ops.aten.linear.default,
qconfig_A16.output_activation,
# For linear: [input_activation, weight]
[qconfig_A16.input_activation, qconfig_A16.weight],
),
]
)
def test_quantizer_annotation(
self,
name: str,
graph_builder_fn: GraphBuilderFn,
quantizer: CadenceQuantizer,
target: OpOverload,
Copy link

Copilot AI Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The target parameter is unused in the test function. Consider removing it from the parameter list (lines 94 and 103) if it's not needed, or use it to validate that op_node.target matches the expected target operation for additional test robustness.

Copilot uses AI. Check for mistakes.
expected_output_qspec: QuantizationSpec,
expected_input_qspecs: list[QuantizationSpec],
) -> None:
"""Parameterized test for quantizer annotations."""
gm, op_node = graph_builder_fn(self)

quantizer.annotate(gm)

annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
self.assertTrue(annotation._annotated)

# Verify output annotation
self.assertEqual(annotation.output_qspec, expected_output_qspec)

# Verify input annotations
# Build actual_specs in the fixed order defined by op_node.args
self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
actual_specs = [
annotation.input_qspec_map[op_node.args[i]]
for i in range(len(expected_input_qspecs))
]
Comment on lines +133 to +136
Copy link

Copilot AI Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential KeyError: accessing annotation.input_qspec_map[op_node.args[i]] could fail if op_node.args[i] is not in the map. Consider adding a check or using .get() with a clear error message to make debugging easier if the key is missing.

Suggested change
actual_specs = [
annotation.input_qspec_map[op_node.args[i]]
for i in range(len(expected_input_qspecs))
]
actual_specs = []
for i in range(len(expected_input_qspecs)):
key = op_node.args[i]
if key not in annotation.input_qspec_map:
raise KeyError(
f"Key {key!r} not found in input_qspec_map. "
f"Available keys: {list(annotation.input_qspec_map.keys())}"
)
actual_specs.append(annotation.input_qspec_map[key])

Copilot uses AI. Check for mistakes.

# Compare expected vs actual specs
for i, (expected, actual) in enumerate(
zip(expected_input_qspecs, actual_specs)
):
self.assertEqual(
actual,
expected,
f"Input qspec mismatch at index {i}",
)


class QuantizerOpsPreserveTest(unittest.TestCase):
Expand Down
Loading