diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index f8131eb202a..7a11541b601 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -721,3 +721,18 @@ def __init__(self, args, meta):
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_w8a32_gru.default
+
+
+class RmsNormPattern(QuantizationPattern):
+    """Pattern that preserves rms_norm from decomposition without matching anything."""
+
+    def partition_types(self) -> list[torch._ops.OpOverload]:
+        return [torch.ops.aten.rms_norm.default]
+
+    def get_anchors(
+        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Tuple[PartitionAnchors, fx.Node]:
+        return PartitionAnchors(empty=True), None  # pyre-ignore[7]
+
+    def replacement_op(self) -> torch._ops.OpOverload:
+        return torch.ops.aten.rms_norm.default
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index e0256437022..bdd4cc810a0 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -30,6 +30,7 @@
     QuantizationPattern,
     ReluPattern0,
     ReluPattern1,
+    RmsNormPattern,
     SoftmaxPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
@@ -37,9 +38,7 @@
     is_annotated,
     no_outside_users,
 )
-
 from torch import fx
-
 from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
 from torchao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
@@ -285,6 +284,15 @@ def __init__(
         super().__init__([])
 
 
+class CadenceRmsNormNopQuantizer(CadenceQuantizer):
+    """
+    Nop quantizer that preserves rms_norm from decomposition.
+    """
+
+    def __init__(self) -> None:
+        super().__init__([CadenceAtenQuantizer(RmsNormPattern(), qconfig_A8W8)])
+
+
 class CadenceWithLayerNormQuantizer(CadenceQuantizer):
     """
     Quantizer including layer norm
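
Usage sketch (not part of the patch): the snippet below shows one way the new CadenceRmsNormNopQuantizer might be wired into a generic PT2E prepare/convert flow. The torchao import path, the toy model, and calling prepare_pt2e/convert_pt2e directly instead of going through the Cadence export pipeline are assumptions for illustration only.

import torch
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceRmsNormNopQuantizer,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class TinyRmsNorm(torch.nn.Module):
    # Hypothetical toy model whose forward should show up as
    # aten.rms_norm.default in the exported graph.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.rms_norm(x, normalized_shape=[8])


model = TinyRmsNorm().eval()
example_inputs = (torch.randn(2, 8),)

# Export to an ATen-level graph module, then run the PT2E quantization steps.
exported = torch.export.export(model, example_inputs).module()
quantizer = CadenceRmsNormNopQuantizer()

# RmsNormPattern returns empty anchors, so prepare/convert insert no observers
# and leave rms_norm untouched; in the Cadence flow the pattern's
# partition_types is presumably what keeps the op from being decomposed.
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass (effectively a no-op here)
converted = convert_pt2e(prepared)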