From 8f616b41f7d64e283e2676b6b7449c78e56a5bc2 Mon Sep 17 00:00:00 2001
From: Ethan Ng
Date: Fri, 5 Dec 2025 15:05:53 -0800
Subject: [PATCH] Add RmsNormNopQuantizer and Pattern

Differential Revision: D88520820
---
 backends/cadence/aot/quantizer/patterns.py  | 15 +++++++++++++++
 backends/cadence/aot/quantizer/quantizer.py | 12 ++++++++++--
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index f8131eb202a..1ec5ee28be1 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -721,3 +721,18 @@ def __init__(self, args, meta):
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_w8a32_gru.default
+
+
+class _RmsNormPattern(QuantizationPattern):
+    """Pattern that preserves rms_norm from decomposition without matching anything."""
+
+    def partition_types(self) -> list[torch._ops.OpOverload]:
+        return [torch.ops.aten.rms_norm.default]
+
+    def get_anchors(
+        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Tuple[PartitionAnchors, fx.Node]:
+        return PartitionAnchors(empty=True), None  # pyre-ignore[7]
+
+    def replacement_op(self) -> torch._ops.OpOverload:
+        return torch.ops.aten.rms_norm.default
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index e0256437022..704fe58babe 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -11,6 +11,7 @@
 import torch
 
 from executorch.backends.cadence.aot.quantizer.patterns import (
+    _RmsNormPattern,
     AddmmPattern,
     AddPattern,
     BmmPattern,
@@ -37,9 +38,7 @@
     is_annotated,
     no_outside_users,
 )
-
 from torch import fx
-
 from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
 from torchao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
@@ -285,6 +284,15 @@ def __init__(
         super().__init__([])
 
 
+class CadenceRmsNormNopQuantizer(CadenceQuantizer):
+    """
+    Nop quantizer that preserves rms_norm from decomposition.
+    """
+
+    def __init__(self) -> None:
+        super().__init__([CadenceAtenQuantizer(_RmsNormPattern(), qconfig_A8W8)])
+
+
 class CadenceWithLayerNormQuantizer(CadenceQuantizer):
     """
     Quantizer including layer norm
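
Usage note (not part of the patch): below is a minimal sketch of how the new CadenceRmsNormNopQuantizer might be exercised in a pt2e flow. Only the names introduced by this diff are taken from the patch; the toy RmsNormModel, the torch.export.export call, and the prepare_pt2e/convert_pt2e entry points under torchao.quantization.pt2e.quantize_pt2e are assumptions about the surrounding torch/torchao APIs, not something this diff adds.

# Illustrative sketch only -- not part of this patch. The toy model, the
# export call, and the prepare/convert imports are assumptions about the
# surrounding torch/torchao APIs.
import torch
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceRmsNormNopQuantizer,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class RmsNormModel(torch.nn.Module):
    """Toy module whose forward maps to aten.rms_norm.default."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.rms_norm(x, normalized_shape=(16,))


model = RmsNormModel().eval()
inputs = (torch.randn(4, 16),)

# Export to an ATen-level GraphModule for the pt2e flow.
gm = torch.export.export(model, inputs).module()

# The nop quantizer's _RmsNormPattern lists aten.rms_norm.default as a
# partition type but returns empty PartitionAnchors, so annotation inserts
# no observers for it; the op is only registered with the quantizer so the
# Cadence flow can keep it from being decomposed.
quantizer = CadenceRmsNormNopQuantizer()
prepared = prepare_pt2e(gm, quantizer)
prepared(*inputs)  # calibration pass (nothing is observed for rms_norm)
converted = convert_pt2e(prepared)

Since CadenceQuantizer derives from ComposableQuantizer and each subclass just passes a pattern list to super().__init__, the _RmsNormPattern entry could in principle be appended to another Cadence quantizer's list to combine real quantization with rms_norm preservation.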