@@ -17,6 +17,7 @@
     CadenceAtenQuantizer,
     CadenceDefaultQuantizer,
     CadenceW8A32MixedQuantizer,
+    CadenceWith16BitLinearActivationsQuantizer,
     CadenceWith16BitMatmulActivationsQuantizer,
     qconfig_A16,
     qconfig_A8W8,
@@ -68,6 +69,53 @@ def test_matmul_16bit_quantizer_annotation(self) -> None:
         for _, input_qspec in annotation.input_qspec_map.items():
             self.assertEqual(input_qspec, qconfig_A16.input_activation)
 
+    def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a linear operation (no bias)."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 10))
+        weight = builder.placeholder("weight", torch.randn(5, 10))
+        linear = builder.call_operator(
+            op=torch.ops.aten.linear.default,
+            args=(x, weight),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("linear", torch.ops.aten.linear.default)]}
+            ),
+        )
+        builder.output([linear])
+        gm = builder.get_graph_module()
+
+        linear_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.linear.default,
+        )
+        self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
+        return gm, linear_nodes[0]
+
+    def test_linear_16bit_quantizer_annotation(self) -> None:
+        """Test that CadenceWith16BitLinearActivationsQuantizer correctly annotates linear."""
+        gm, linear_node = self._build_linear_graph()
+
+        quantizer = CadenceWith16BitLinearActivationsQuantizer()
+        quantizer.annotate(gm)
+
+        annotation: QuantizationAnnotation = linear_node.meta[Q_ANNOTATION_KEY]
+        self.assertTrue(annotation._annotated)
+
+        # Verify the output is annotated with qconfig_A16.output_activation (INT16).
+        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
+
+        # Verify the inputs: activation (INT16) and weight (INT8).
+        self.assertEqual(len(annotation.input_qspec_map), 2)
+        for input_node, input_qspec in annotation.input_qspec_map.items():
+            if input_node == linear_node.args[0]:
+                # Activation input: should be INT16.
+                self.assertEqual(input_qspec, qconfig_A16.input_activation)
+            elif input_node == linear_node.args[1]:
+                # Weight: should be INT8.
+                self.assertEqual(input_qspec, qconfig_A16.weight)
+            else:
+                self.fail(f"Unexpected input node in annotation: {input_node}")
+
 
 
 class QuantizerOpsPreserveTest(unittest.TestCase):
     def test_mixed_w8a32_ops_to_preserve(self) -> None:
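
For context, here is a minimal sketch of how the new quantizer might be exercised end to end, outside of these annotation unit tests. The import path is taken from the surrounding Cadence quantizer module, but the use of the standard PT2E entry points (torch.export.export, prepare_pt2e, convert_pt2e) with this quantizer is an assumption; the in-tree Cadence flow may wrap these steps differently.

# Sketch only: assumes CadenceWith16BitLinearActivationsQuantizer plugs
# directly into the standard PT2E prepare/convert entry points.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceWith16BitLinearActivationsQuantizer,
)


class SmallLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Same shapes as the test graph: (1, 10) activation, (5, 10) weight, no bias.
        self.linear = torch.nn.Linear(10, 5, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


example_inputs = (torch.randn(1, 10),)
gm = torch.export.export(SmallLinear(), example_inputs).module()

quantizer = CadenceWith16BitLinearActivationsQuantizer()
prepared = prepare_pt2e(gm, quantizer)  # annotate() runs here, placing observers
prepared(*example_inputs)               # calibrate on representative data
converted = convert_pt2e(prepared)      # rewrite to quantize/dequantize ops

If that assumption holds, the annotations checked above (INT16 activations, INT8 weight) determine which observers prepare_pt2e inserts around the linear node.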