Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,9 @@ python_unittest(
typing = True,
deps = [
"//caffe2:torch",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/cadence/aot/quantizer:quantizer",
"//executorch/exir:pass_base",
"//pytorch/ao:torchao",
],
)
50 changes: 50 additions & 0 deletions backends/cadence/aot/tests/test_quantizer_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,64 @@
import unittest

import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
from executorch.exir.pass_base import NodeMetadata

from executorch.backends.cadence.aot.quantizer.quantizer import (
CadenceAtenQuantizer,
CadenceDefaultQuantizer,
CadenceW8A32MixedQuantizer,
CadenceWith16BitMatmulActivationsQuantizer,
qconfig_A16,
qconfig_A8W8,
)
from torchao.quantization.pt2e.quantizer.quantizer import (
Q_ANNOTATION_KEY,
QuantizationAnnotation,
)


class QuantizerAnnotationTest(unittest.TestCase):
    """Unit tests for verifying quantizer annotations are correctly applied."""

    def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
        """Build a minimal graph containing a single ``aten.matmul`` call.

        The graph has two placeholders, ``x`` (4x8) and ``y`` (8x4), feeding
        one ``torch.ops.aten.matmul.default`` node whose result is the only
        graph output.

        Returns:
            A ``(graph_module, matmul_node)`` pair, where ``matmul_node`` is
            the unique matmul ``call_function`` node inside ``graph_module``.
        """
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(4, 8))
        y = builder.placeholder("y", torch.randn(8, 4))
        matmul = builder.call_operator(
            op=torch.ops.aten.matmul.default,
            args=(x, y),
            # NOTE(review): "source_fn_stack" is presumably what the
            # quantizer's pattern matching consults to locate the op --
            # confirm against the quantizer implementation.
            meta=NodeMetadata(
                {"source_fn_stack": [("matmul", torch.ops.aten.matmul.default)]}
            ),
        )
        builder.output([matmul])
        gm = builder.get_graph_module()

        # Look the node up in the final graph module so the returned node is
        # the one that `annotate(gm)` will operate on.
        matmul_nodes = gm.graph.find_nodes(
            op="call_function",
            target=torch.ops.aten.matmul.default,
        )
        self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
        return gm, matmul_nodes[0]

    def test_matmul_16bit_quantizer_annotation(self) -> None:
        """Test that CadenceWith16BitMatmulActivationsQuantizer correctly annotates matmul."""
        gm, matmul_node = self._build_matmul_graph()

        quantizer = CadenceWith16BitMatmulActivationsQuantizer()
        quantizer.annotate(gm)

        # Fail with a descriptive assertion (rather than a bare KeyError) if
        # the quantizer did not attach any annotation to the matmul node.
        self.assertIn(
            Q_ANNOTATION_KEY,
            matmul_node.meta,
            "Quantizer did not annotate the matmul node",
        )
        annotation: QuantizationAnnotation = matmul_node.meta[Q_ANNOTATION_KEY]
        self.assertTrue(annotation._annotated)

        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)

        # Both matmul operands are expected to carry the A16 input spec.
        self.assertEqual(len(annotation.input_qspec_map), 2)
        for input_qspec in annotation.input_qspec_map.values():
            self.assertEqual(input_qspec, qconfig_A16.input_activation)


class QuantizerOpsPreserveTest(unittest.TestCase):
Expand Down
Loading