From 432772765985e99dea97d98cb7960681543f5b53 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Fri, 5 Dec 2025 10:13:00 -0800
Subject: [PATCH] Update replace ops to correctly set modified bit

Summary:
Updated:
- ReplaceAtenAvgPoolWithCadenceAvgPoolPass
- ReplaceIm2RowWithViewPass
- ReplaceEmptyTensorsWithFullPass (doesn't use the new interface; it now runs the full export pass only when necessary)
- ReplaceWhereWithFullArgsWithWhereScalar
- ReplaceMulTensorWithMulAndFullOpsPass
- ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass

Reviewed By: hsharma35

Differential Revision: D87877176
---
 backends/cadence/aot/replace_ops.py           | 392 ++++++++++--------
 .../aot/tests/test_replace_ops_passes.py      | 294 +++++++++++--
 2 files changed, 483 insertions(+), 203 deletions(-)

diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index ccadd1e7a88..99f09439d0c 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -15,7 +15,7 @@
 import math
 import operator
 from operator import neg
-from typing import cast, Dict, Iterable, Optional, Sequence, Tuple
+from typing import cast, Dict, Iterable, Optional, Sequence
 
 import torch
 import torch.fx
@@ -1687,58 +1687,73 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceAtenAvgPoolWithCadenceAvgPoolPass(ExportPass):
+class ReplaceAtenAvgPoolWithCadenceAvgPoolPass(RemoveOrReplacePassInterface):
     """
     Replace the aten avg_pool op with the cadence custom avg_pool2d op.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
-        # Only continue for avg_pool op
-        if op not in {
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
             exir_ops.edge.aten.avg_pool1d.default,
             exir_ops.edge.aten.avg_pool2d.default,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
+        ]
 
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # Determine if the op is avg_pool1d or avg_pool2d
-        avg_pool1d: bool = op == exir_ops.edge.aten.avg_pool1d.default
-        # Get the input tensor
-        in_tensor = args[0].to_tensor()
+        avg_pool1d: bool = node.target == exir_ops.edge.aten.avg_pool1d.default
+
+        # Get the input tensor node
+        in_tensor_node = node.args[0]
+        assert isinstance(in_tensor_node, torch.fx.Node)
 
         # Replace avg_pool2d with custom avg_pool2d, and if the input tensor is
         # quantized, pass its zero_point tensor as arg to the custom avg_pool2d.
         # stride, padding, ceil_mode, count_include_pad, divisor_override, are
         # the native avg_pool2d args. 'channel_last' denotes NCHW vs NHWC layout,
         # and is False by default.
-        kernel_size = args[1]
-        stride = args[2] if len(args) >= 3 else [1, 1]
-        padding = args[3] if len(args) >= 4 else [0, 0]
-        ceil_mode = args[4] if len(args) >= 5 else False
-        count_include_pad = args[5] if len(args) >= 6 else True
-        divisor_override = args[6] if len(args) >= 7 else None
-        zero_point = args[7] if len(args) >= 8 else None
+        kernel_size = node.args[1]
+        # When stride is not provided or is empty, PyTorch defaults to kernel_size
+        stride = node.args[2] if len(node.args) >= 3 and node.args[2] else kernel_size
+        padding = node.args[3] if len(node.args) >= 4 else [0, 0]
+        ceil_mode = node.args[4] if len(node.args) >= 5 else False
+        count_include_pad = node.args[5] if len(node.args) >= 6 else True
+        divisor_override = node.args[6] if len(node.args) >= 7 else None
+        zero_point = node.args[7] if len(node.args) >= 8 else None
+
+        graph = node.graph
+        out_shape = node.meta["val"].shape
+
+        kernel_size = cast(Sequence[int], kernel_size)
+        stride = cast(Sequence[int], stride)
+        padding = cast(Sequence[int], padding)
 
         # If the op is avg_pool1d, then we need to reshape the 3d input to a 4d
         # tensor.
         if avg_pool1d:
-            in_shape = list(in_tensor.shape)
+            in_shape = list(in_tensor_node.meta["val"].shape)
             assert len(in_shape) == 3, "Expected 3d input for avg_pool1d"
-            in_shape.insert(2, 1)
-            out_shape = meta["val"].shape
-            in_view_op = super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (in_tensor, in_shape),
-                kwargs,
-                meta,
-            )
+            in_shape_4d = in_shape[:2] + [1] + in_shape[2:]
+
+            with graph.inserting_before(node):
+                in_view_node = graph.call_function(
+                    exir_ops.edge.aten.view_copy.default,
+                    args=(in_tensor_node, in_shape_4d),
+                )
+                in_view_node.meta = node.meta
+
             # Extend the kernel_size, stride and padding to 2d
-            kernel_size = [1] + kernel_size if len(kernel_size) == 1 else kernel_size
-            stride = [1] + stride if len(stride) == 1 else stride
-            padding = [0] + padding if len(padding) == 1 else padding
+            kernel_size = [1] + list(kernel_size) if len(kernel_size) == 1 else kernel_size
+            stride = [1] + list(stride) if len(stride) == 1 else stride
+            padding = [0] + list(padding) if len(padding) == 1 else padding
+
+            input_for_pool = in_view_node
+        else:
+            input_for_pool = in_tensor_node
 
         # Create a new avg_pool node with the updated args
         new_args = (
-            in_view_op if avg_pool1d else args[0],
+            input_for_pool,
             kernel_size,
             stride,
             padding,
@@ -1748,70 +1763,66 @@ def call_operator(self, op, args, kwargs, meta):
             zero_point,
             False,
         )
-        avg_pool2d_op = super().call_operator(
-            exir_ops.edge.cadence.avg_pool2d.default,
-            new_args,
-            kwargs,
-            meta,
-        )
-        # If the node was avg_pool1d, we again reshape the 4d output to 3d output
-        return (
-            super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (avg_pool2d_op, list(out_shape)),
-                kwargs,
-                meta,
+        with graph.inserting_before(node):
+            avg_pool2d_node = graph.call_function(
+                exir_ops.edge.cadence.avg_pool2d.default,
+                args=new_args,
             )
-            if avg_pool1d
-            else avg_pool2d_op
-        )
+            avg_pool2d_node.meta = node.meta
+
+        # If the node was avg_pool1d, we again reshape the 4d output to 3d output
+        if avg_pool1d:
+            with graph.inserting_before(node):
+                result_node = graph.call_function(
+                    exir_ops.edge.aten.view_copy.default,
+                    args=(avg_pool2d_node, list(out_shape)),
+                )
+                result_node.meta = node.meta
+            node.replace_all_uses_with(result_node)
+        else:
+            node.replace_all_uses_with(avg_pool2d_node)
+
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class ReplaceIm2RowWithViewPass(ExportPass):
-    def can_replace(self, op, args, kwargs, meta) -> bool:
-        if op != exir_ops.edge.cadence.im2row.default:
-            return False
+class ReplaceIm2RowWithViewPass(RemoveOrReplacePassInterface):
+    """
+    Replace im2row with view when possible (no padding, no dilation, and output spatial dimensions are 1).
+    """
 
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.cadence.im2row.default]
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # Check if im2row applies padding. If yes, we cannot replace it with view.
-        pad = cast(tuple[int, ...], args[3])
+        pad = cast(Sequence[int], node.args[3])
         if any(p != 0 for p in pad):
             return False
 
         # Check if im2row has dilation. If yes, we cannot replace it with view.
-        dilation = cast(tuple[int, ...], args[2])
+        dilation = cast(Sequence[int], node.args[2])
         if any(d != 1 for d in dilation):
             return False
 
         # im2row works on 3D or 4D tensors.
         # Output shape[1:-1] will be unit if input spatial dimensions are the same as kernel spatial dimensions.
-        output_shape = meta["val"].shape
-        if math.prod(output_shape[1:-1]) == 1:
-            return True
-
-        return False
-
-    def call_operator(
-        self,
-        op,
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op != exir_ops.edge.cadence.im2row.default:
-            return super().call_operator(op, args, kwargs, meta)
+        output_shape = node.meta["val"].shape
+        if math.prod(output_shape[1:-1]) != 1:
+            return False
 
-        if not self.can_replace(op, args, kwargs, meta):
-            return super().call_operator(op, args, kwargs, meta)
+        # Replace im2row with view_copy
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                exir_ops.edge.aten.view_copy.default,
+                args=(node.args[0], list(output_shape)),
+            )
+            new_node.meta = node.meta
 
-        output_shape = meta["val"].shape
-        return super().call_operator(
-            exir_ops.edge.aten.view_copy.default,
-            (args[0], tuple(output_shape)),
-            kwargs,
-            meta,
-        )
+        node.replace_all_uses_with(new_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -1830,57 +1841,84 @@ def call_operator(self, op, args, kwargs, meta):
         return super().call_operator(op, args, kwargs, meta)
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        ret = super().call(graph_module)
-        modified = ret.graph_module.graph.eliminate_dead_code() or ret.modified
-        return PassResult(ret.graph_module, modified)
+        changed = False
+        for module in filter(
+            lambda m: isinstance(m, torch.fx.GraphModule), graph_module.modules()
+        ):
+            module = cast(torch.fx.GraphModule, module)
+            for node in module.graph.nodes:
+                if node.op != "call_function":
+                    continue
+                val = node.meta.get("val", None)
+                if isinstance(val, torch.Tensor) and val.numel() == 0:
+                    with module.graph.inserting_before(node):
+                        new_node = module.graph.call_function(
+                            exir_ops.edge.aten.full.default,
+                            args=(val.shape, 0),
+                            kwargs={"dtype": val.dtype},
+                        )
+                        new_node.meta = node.meta
+                    node.replace_all_uses_with(new_node)
+                    changed = True
+
+        if changed:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            return super().call(graph_module)
+
+        return PassResult(graph_module, False)
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class ReplaceWhereWithFullArgsWithWhereScalar(ExportPass):
+class ReplaceWhereWithFullArgsWithWhereScalar(RemoveOrReplacePassInterface):
     """Replaces where ops using two full ops as tensors with a scalar version.
""" - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.aten.where.self, - }: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.where.self] - # If the args are not full ops, bail - # pyre-ignore[16]: `ProxyValue` has no attribute `node`. - if (args[1].node.target != exir_ops.edge.aten.full.default) or ( - args[2].node.target != exir_ops.edge.aten.full.default - ): - return super().call_operator(op, args, kwargs, meta) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Check if args[1] and args[2] are full ops + arg1 = node.args[1] + arg2 = node.args[2] + + if not isinstance(arg1, torch.fx.Node) or not isinstance(arg2, torch.fx.Node): + return False - # If one of the full ops is a different size than than the cond tensor, we need to broadcast. Bail. if ( - # pyre-ignore[16]: `ProxyValue` has no attribute `node`. - list(args[0].to_tensor().shape) != args[1].node.args[0] - or list(args[0].to_tensor().shape) != args[2].node.args[0] + arg1.target != exir_ops.edge.aten.full.default + or arg2.target != exir_ops.edge.aten.full.default ): - return super().call_operator(op, args, kwargs, meta) + return False + + # Get the condition tensor shape + cond_arg = node.args[0] + assert isinstance(cond_arg, torch.fx.Node) + cond_shape = list(cond_arg.meta["val"].shape) + + # Check if the full ops have the same size as the cond tensor + full1_shape = arg1.args[0] + full2_shape = arg2.args[0] + + if cond_shape != full1_shape or cond_shape != full2_shape: + return False # Get the scalar values from the full ops - scalar_value_1 = args[1].node.args[1] - scalar_value_2 = args[2].node.args[1] + scalar_value_1 = arg1.args[1] + scalar_value_2 = arg2.args[1] # Replace the where op with a scalar where op - return super().call_operator( - exir_ops.edge.cadence.where_Scalar.default, - (args[0], scalar_value_1, scalar_value_2), - kwargs, - meta, - ) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.cadence.where_Scalar.default, + args=(cond_arg, scalar_value_1, scalar_value_2), + ) + new_node.meta = node.meta - return super().call_operator(op, args, kwargs, meta) + node.replace_all_uses_with(new_node) + return True # Adapted from fbcode/pyspeech/opt_passes/replace_ops.py @@ -2116,53 +2154,56 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceMulTensorWithMulAndFullOpsPass(ExportPass): +class ReplaceMulTensorWithMulAndFullOpsPass(RemoveOrReplacePassInterface): """ Extracts a single value argument of mul op to a separate full op. 
""" - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - for mul_node in graph_module.graph.find_nodes( - op="call_function", target=torch.ops.aten.mul.Tensor - ): - x_arg, const_arg = mul_node.args + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mul.Tensor] - # Swap arguments if the order is wrong - if isinstance(const_arg, torch.fx.Node): - x_arg, const_arg = const_arg, x_arg + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + x_arg, const_arg = node.args - # Skip if the const_arg is not a scalar - if not isinstance(const_arg, (float, int)) or not isinstance( - x_arg, torch.fx.Node - ): - continue + # Swap arguments if the order is wrong + if isinstance(const_arg, torch.fx.Node): + x_arg, const_arg = const_arg, x_arg - # Cast the const_arg to the dtype of the x_arg - full_arg = self.resolve_full_arg(x_arg, const_arg) + # Skip if the const_arg is not a scalar + if not isinstance(const_arg, (float, int)) or not isinstance( + x_arg, torch.fx.Node + ): + return False - full_output_dtype = ( - torch.int32 if isinstance(full_arg, int) else torch.float32 - ) + # Cast the const_arg to the dtype of the x_arg + full_arg = self.resolve_full_arg(x_arg, const_arg) - # Extract an argument to a separate full op. - with graph_module.graph.inserting_before(mul_node): - full_node = graph_module.graph.call_function( - torch.ops.aten.full.default, - args=([1], full_arg), - kwargs={"dtype": full_output_dtype}, - ) - full_node.meta = mul_node.meta - full_node.meta["val"] = [1] - new_mul_node = graph_module.graph.call_function( - torch.ops.aten.mul.Tensor, args=(x_arg, full_node) - ) - new_mul_node.meta = mul_node.meta - # Replace the old mul with a newly created mul. - mul_node.replace_all_uses_with(new_mul_node) - graph_module.graph.erase_node(mul_node) - return super().call(graph_module) + full_output_dtype = ( + torch.int32 if isinstance(full_arg, int) else torch.float32 + ) + + # Extract an argument to a separate full op. + with node.graph.inserting_before(node): + full_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=([1], full_arg), + kwargs={"dtype": full_output_dtype}, + ) + full_node.meta = node.meta + full_node.meta["val"] = [1] + new_mul_node = node.graph.call_function( + exir_ops.edge.aten.mul.Tensor, args=(x_arg, full_node) + ) + new_mul_node.meta = node.meta + # Replace the old mul with a newly created mul. + node.replace_all_uses_with(new_mul_node) + node.graph.erase_node(node) + return True - def resolve_full_arg(self, x_arg, const_arg): + def resolve_full_arg( + self, x_arg: torch.fx.Node, const_arg: float | int + ) -> float | int: if x_arg.meta["val"].dtype == torch.float32 and isinstance(const_arg, int): const_arg = float(const_arg) if x_arg.meta["val"].dtype == torch.int32 and isinstance(const_arg, float): @@ -2171,40 +2212,41 @@ def resolve_full_arg(self, x_arg, const_arg): @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(ExportPass): +class ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(RemoveOrReplacePassInterface): """ Replace the aten adaptive avg_pool op with the aten avg_pool2d op. 
""" - def call_operator(self, op, args, kwargs, meta): - # Only continue for avg_pool op - if op not in {exir_ops.edge.aten._adaptive_avg_pool2d.default}: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten._adaptive_avg_pool2d.default] - # Get the input tensor - in_tensor = args[0].to_tensor() - # Permute NCHW to NHWC for computation - in_tensor_permuted = in_tensor.permute(0, 2, 3, 1) - in_tensor_shape = in_tensor_permuted.shape + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Get the input tensor node + in_tensor_node = node.args[0] + assert isinstance(in_tensor_node, torch.fx.Node) - output_size = args[1] + # Get input shape (in NCHW format) + in_shape = in_tensor_node.meta["val"].shape + output_size = cast(Sequence[int], node.args[1]) num_dims = len(output_size) + # Spatial dimensions are at indices [2:] for NCHW format # TODO: If in_tensor_shape is not a multiple of output size, # this pass will not work. T224984800 dim_multiples = [ - (in_tensor_shape[i + 1] % output_size[i]) == 0 for i in range(num_dims) + (in_shape[i + 2] % output_size[i]) == 0 for i in range(num_dims) ] if not all(dim_multiples): logging.info( - f"Unable to replace adaptive average pool with average pool. Input tensor shape of {in_tensor_shape} is not a multiple of output size: {output_size}" + f"Unable to replace adaptive average pool with average pool. Input tensor shape of {in_shape} is not a multiple of output size: {output_size}" ) - return super().call_operator(op, args, kwargs, meta) + return False - # Compute stride and kernel_size, then set default values for other arguments - stride = [(in_tensor_shape[i + 1] // output_size[i]) for i in range(num_dims)] + # Compute stride and kernel_size based on spatial dimensions + stride = [(in_shape[i + 2] // output_size[i]) for i in range(num_dims)] kernel_size = [ - in_tensor_shape[i + 1] - (output_size[i] - 1) * stride[i] + in_shape[i + 2] - (output_size[i] - 1) * stride[i] for i in range(num_dims) ] padding = [0] * num_dims @@ -2212,9 +2254,9 @@ def call_operator(self, op, args, kwargs, meta): count_include_pad = True divisor_override = None - # Create a new avg_pool node with the updated args + # Create a new avg_pool2d node with the computed args new_args = ( - args[0], + in_tensor_node, kernel_size, stride, padding, @@ -2222,12 +2264,16 @@ def call_operator(self, op, args, kwargs, meta): count_include_pad, divisor_override, ) - return super().call_operator( - exir_ops.edge.aten.avg_pool2d.default, - new_args, - kwargs, - meta, - ) + + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.avg_pool2d.default, + args=new_args, + ) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index cc891da4f46..d961ce595ea 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -23,6 +23,7 @@ MakeSliceAndCatDimOutermostPass, ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, ReplaceAddMMWithLinearPass, + ReplaceAtenAvgPoolWithCadenceAvgPoolPass, ReplaceAtenConvolutionWithCadenceConvolutionPass, ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass, ReplaceConstantPadNdWithSlicePass, @@ -1412,9 +1413,11 @@ def test_replace_permute_with_transpose_nop( 
             count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0
         )
 
+class TestReplaceWhereWithFullArgsWithWhereScalar(unittest.TestCase):
     def test_replace_aten_where_with_cadence(self) -> None:
         builder = GraphBuilder()
-        cond = builder.placeholder("cond", torch.randn(4, 8))
+        cond_input = torch.randn(4, 8)
+        cond = builder.placeholder("cond", cond_input)
         aten_gt_scalar = builder.call_operator(
             op=exir_ops.edge.aten.gt.Scalar,
             args=(cond, 0),
@@ -1433,8 +1436,24 @@ def test_replace_aten_where_with_cadence(self) -> None:
         )
         builder.output([aten_where_self])
         original_gm = builder.get_graph_module()
+
+        # Deepcopy before the pass
+        gm_before = copy.deepcopy(original_gm)
+
         p = ReplaceWhereWithFullArgsWithWhereScalar()
-        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
+        result = cast(PassResult, p(original_gm))
+        self.assertTrue(result.modified)
+        graph_after_passes = result.graph_module
+
+        # Validate numerical accuracy
+        inputs = [cond_input]
+        validate(
+            gm_before,
+            graph_after_passes,
+            inputs,
+            "ReplaceWhereWithFullArgsWithWhereScalar",
+        )
+
         self.assertEqual(
             count_node(
                 graph_after_passes,
@@ -1462,9 +1481,9 @@ def test_replace_aten_where_with_cadence_broadcast(
         val1: float,
         val2: float,
     ) -> None:
-        # cond_shape, a_shape, b_shape, val1, val2 =
         builder = GraphBuilder()
-        cond = builder.placeholder("cond", torch.randn(cond_shape))
+        cond_input = torch.randn(cond_shape)
+        cond = builder.placeholder("cond", cond_input)
         aten_gt_scalar = builder.call_operator(
             op=exir_ops.edge.aten.gt.Scalar,
             args=(cond, 0),
@@ -1483,8 +1502,25 @@ def test_replace_aten_where_with_cadence_broadcast(
         )
         builder.output([aten_where_self])
         original_gm = builder.get_graph_module()
+
+        # Deepcopy before the pass
+        gm_before = copy.deepcopy(original_gm)
+
         p = ReplaceWhereWithFullArgsWithWhereScalar()
-        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
+        result = cast(PassResult, p(original_gm))
+        # Broadcast case should not be replaced
+        self.assertFalse(result.modified)
+        graph_after_passes = result.graph_module
+
+        # Validate numerical accuracy (should be same since not modified)
+        inputs = [cond_input]
+        validate(
+            gm_before,
+            graph_after_passes,
+            inputs,
+            "ReplaceWhereWithFullArgsWithWhereScalar",
+        )
+
         self.assertEqual(
             count_node(
                 graph_after_passes,
@@ -1600,7 +1636,7 @@ class TestReplaceIm2rowWithViewPass(unittest.TestCase):
     def test_no_replacement_for_conv(self) -> None:
         # Create a graph with a single im2row node.
         x = torch.randn(1, 3, 224, 224)
-        pad_value = torch.randn(1)
+        pad_value = torch.tensor(0, dtype=torch.int32)
         channels_last = False
         gm = single_op_builder(
             placeholders=(x, pad_value),
@@ -1612,9 +1648,19 @@ def test_no_replacement_for_conv(self) -> None:
         self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
         self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0)
 
+        # Deepcopy before the pass
+        gm_before = copy.deepcopy(gm)
+
         # Apply replacement pass.
         p = ReplaceIm2RowWithViewPass()
-        gm_after_replacement = p.call(gm).graph_module
+        result = p.call(gm)
+        self.assertFalse(result.modified)
+        gm_after_replacement = result.graph_module
+
+        # Validate numerical accuracy
+        inputs = [x, pad_value]
+        validate(gm_before, gm_after_replacement, inputs, "ReplaceIm2RowWithViewPass")
+
         # Check that no replacement was made.
         self.assertEqual(
             count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1
@@ -1626,7 +1672,7 @@ def test_no_replacement_for_conv(self) -> None:
     def test_no_replace_for_dilation(self) -> None:
         # Create a graph with a single im2row node.
         x = torch.randn(1, 3, 5, 7)
-        pad_value = torch.randn(1)
+        pad_value = torch.tensor(0, dtype=torch.int32)
         channels_last = False
         gm = single_op_builder(
             placeholders=(x, pad_value),
@@ -1638,9 +1684,19 @@ def test_no_replace_for_dilation(self) -> None:
         self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
         self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0)
 
+        # Deepcopy before the pass
+        gm_before = copy.deepcopy(gm)
+
         # Apply replacement pass.
         p = ReplaceIm2RowWithViewPass()
-        gm_after_replacement = p.call(gm).graph_module
+        result = p.call(gm)
+        self.assertFalse(result.modified)
+        gm_after_replacement = result.graph_module
+
+        # Validate numerical accuracy
+        inputs = [x, pad_value]
+        validate(gm_before, gm_after_replacement, inputs, "ReplaceIm2RowWithViewPass")
+
         self.assertEqual(
             count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1
         )
@@ -1652,7 +1708,7 @@ def test_replace_linear_like_conv(self) -> None:
         # Create a graph with a single im2row node.
         in_h, in_w = 13, 15
         x = torch.randn(1, 3, in_h, in_w)
-        pad_value = torch.randn(1)
+        pad_value = torch.tensor(0, dtype=torch.int32)
         channels_last = False
         gm = single_op_builder(
             placeholders=(x, pad_value),
@@ -1664,9 +1720,19 @@ def test_replace_linear_like_conv(self) -> None:
         self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
         self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0)
 
+        # Deepcopy before the pass
+        gm_before = copy.deepcopy(gm)
+
         # Apply replacement pass.
         p = ReplaceIm2RowWithViewPass()
-        gm_after_replacement = p.call(gm).graph_module
+        result = p.call(gm)
+        self.assertTrue(result.modified)
+        gm_after_replacement = result.graph_module
+
+        # Validate numerical accuracy
+        inputs = [x, pad_value]
+        validate(gm_before, gm_after_replacement, inputs, "ReplaceIm2RowWithViewPass")
+
         # In this test, the kernel width/height is the same as the input width/height.
         self.assertEqual(
             count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 0
         )
@@ -1985,9 +2051,10 @@ def test_cat_insert_transpose(self) -> None:
 
 
 class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase):
-    def _get_slice_empty_gm(self) -> torch.fx.GraphModule:
+    def _get_slice_empty_gm(self) -> tuple[torch.fx.GraphModule, torch.Tensor]:
         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(4))
+        x_input = torch.randn(4)
+        x = builder.placeholder("x", x_input)
         # This is empty (numel == 0).
         slice0 = builder.call_operator(
             exir_ops.edge.aten.slice_copy.Tensor, (x, 0, 0, 0)
@@ -1999,10 +2066,10 @@ def _get_slice_empty_gm(self) -> torch.fx.GraphModule:
             ((slice0, slice1),),
         )
         builder.output([cat])
-        return builder.get_graph_module()
+        return builder.get_graph_module(), x_input
 
     def test_empty_slice(self) -> None:
-        gm = self._get_slice_empty_gm()
+        gm, x_input = self._get_slice_empty_gm()
         self.assertEqual(
             len(
                 gm.graph.find_nodes(
@@ -2019,8 +2086,19 @@ def test_empty_slice(self) -> None:
             ),
             0,
         )
+
+        # Deepcopy before the pass
+        gm_before = copy.deepcopy(gm)
+
         p = ReplaceEmptyTensorsWithFullPass()
-        updated_gm = cast(PassResult, p(gm)).graph_module
+        result = cast(PassResult, p(gm))
+        self.assertTrue(result.modified)
+        updated_gm = result.graph_module
+
+        # Validate numerical accuracy
+        inputs = [x_input]
+        validate(gm_before, updated_gm, inputs, "ReplaceEmptyTensorsWithFullPass")
+
         self.assertEqual(
             len(
                 updated_gm.graph.find_nodes(
@@ -2048,21 +2126,37 @@ def test_empty_slice(self) -> None:
     def test_extract_mul_argument_to_full(
         self, _: str, value: Union[int, float]
     ) -> None:
-        x = torch.randn(2, 1, 64)
+        x_input = torch.randn(2, 1, 64)
         gm = single_op_builder(
-            placeholders=(x,),
-            op=torch.ops.aten.mul.Tensor,
-            args=(x, value),
+            placeholders=(x_input,),
+            op=exir_ops.edge.aten.mul.Tensor,
+            args=(x_input, value),
             kwargs={},
         )
+
+        # Deepcopy before the pass
+        gm_before = copy.deepcopy(gm)
+
         p = ReplaceMulTensorWithMulAndFullOpsPass()
-        graph_after_passes = p.call(gm).graph_module
+        result = p.call(gm)
+        self.assertTrue(result.modified)
+        graph_after_passes = result.graph_module
+
+        # Validate numerical accuracy
+        inputs = [x_input]
+        validate(
+            gm_before,
+            graph_after_passes,
+            inputs,
+            "ReplaceMulTensorWithMulAndFullOpsPass",
+        )
+
         self.assertTrue(
             op_counts_match(
                 graph_after_passes,
                 expected_op_counts={
-                    torch.ops.aten.mul.Tensor: 1,
-                    torch.ops.aten.full.default: 1,
+                    exir_ops.edge.aten.mul.Tensor: 1,
+                    exir_ops.edge.aten.full.default: 1,
                 },
             )
         )
@@ -2071,17 +2165,18 @@ def test_extract_mul_argument_to_full(
 class TestReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(unittest.TestCase):
     def _get_adaptive_avg_pool_gm(
         self, input_shape: Tuple[int, int, int, int], output_shape: Tuple[int, int]
-    ) -> torch.fx.GraphModule:
+    ) -> tuple[torch.Tensor, torch.fx.GraphModule]:
         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(*input_shape))
+        x_input = torch.randn(*input_shape)
+        x = builder.placeholder("x", x_input)
         adaptive_avg_pool2d = builder.call_operator(
             exir_ops.edge.aten._adaptive_avg_pool2d.default, (x, output_shape)
         )
         builder.output([adaptive_avg_pool2d])
-        return builder.get_graph_module()
+        return x_input, builder.get_graph_module()
 
     def test_replace_adaptive_avg_pool_with_aten_avg_pool(self) -> None:
-        gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8))
+        x_input, gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8))
         self.assertEqual(
             len(
                 gm.graph.find_nodes(
@@ -2100,8 +2195,24 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool(self) -> None:
             ),
             0,
         )
+
+        # Deepcopy before the pass
+        gm_before = copy.deepcopy(gm)
+
         p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()
-        updated_gm = p.call(gm).graph_module
+        result = p.call(gm)
+        self.assertTrue(result.modified)
+        updated_gm = result.graph_module
+
+        # Validate numerical accuracy
+        inputs = [x_input]
+        validate(
+            gm_before,
+            updated_gm,
+            inputs,
+            "ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass",
+        )
+
         self.assertEqual(
             len(
                 updated_gm.graph.find_nodes(
@@ -2128,7 +2239,7 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool(self) -> None:
         self.assertEqual(avg_pool2d_node.args[6], None)  # divisor_override is None
 
     def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None:
-        gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9))
+        x_input, gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9))
         self.assertEqual(
             len(
                 gm.graph.find_nodes(
@@ -2146,9 +2257,25 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None:
             ),
             0,
         )
+
+        # Deepcopy before the pass
+        gm_before = copy.deepcopy(gm)
+
         # Shapes are not multiples of each other, so pass will not trigger
         p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()
-        updated_gm = p.call(gm).graph_module
+        result = p.call(gm)
+        self.assertFalse(result.modified)
+        updated_gm = result.graph_module
+
+        # Validate numerical accuracy (should be same since not modified)
+        inputs = [x_input]
+        validate(
+            gm_before,
+            updated_gm,
+            inputs,
+            "ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass",
+        )
+
         self.assertEqual(
             len(
                 updated_gm.graph.find_nodes(
@@ -2167,6 +2294,113 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None:
         )
 
 
+class TestReplaceAtenAvgPoolWithCadenceAvgPoolPass(unittest.TestCase):
+    def _get_aten_avg_pool1d_gm(
+        self, input_shape: Tuple[int, int, int], kernel_size: int
+    ) -> tuple[torch.Tensor, torch.fx.GraphModule]:
+        builder = GraphBuilder()
+        x_input = torch.randn(*input_shape)
+        x = builder.placeholder("x", x_input)
+        avg_pool1d = builder.call_operator(
+            exir_ops.edge.aten.avg_pool1d.default, (x, [kernel_size])
+        )
+        builder.output([avg_pool1d])
+        return x_input, builder.get_graph_module()
+
+    def _get_aten_avg_pool2d_gm(
+        self, input_shape: Tuple[int, int, int, int], kernel_size: Tuple[int, int]
+    ) -> tuple[torch.Tensor, torch.fx.GraphModule]:
+        builder = GraphBuilder()
+        x_input = torch.randn(*input_shape)
+        x = builder.placeholder("x", x_input)
+        avg_pool2d = builder.call_operator(
+            exir_ops.edge.aten.avg_pool2d.default, (x, list(kernel_size))
+        )
+        builder.output([avg_pool2d])
+        return x_input, builder.get_graph_module()
+
+    def test_replace_aten_avg_pool1d_with_cadence(self) -> None:
+        x_input, gm = self._get_aten_avg_pool1d_gm((1, 32, 64), 3)
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.aten.avg_pool1d.default),
+            1,
+        )
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.avg_pool2d.default),
+            0,
+        )
+
+        # Deepcopy before the pass
+        gm_before = copy.deepcopy(gm)
+
+        p = ReplaceAtenAvgPoolWithCadenceAvgPoolPass()
+        result = p.call(gm)
+        self.assertTrue(result.modified)
+        updated_gm = result.graph_module
+
+        # Validate numerical accuracy
+        inputs = [x_input]
+        validate(
+            gm_before,
+            updated_gm,
+            inputs,
+            "ReplaceAtenAvgPoolWithCadenceAvgPoolPass",
+        )
+
+        # avg_pool1d should be replaced with view operations and avg_pool2d
+        self.assertEqual(
+            count_node(updated_gm, exir_ops.edge.aten.avg_pool1d.default),
+            0,
+        )
+        self.assertEqual(
+            count_node(updated_gm, exir_ops.edge.cadence.avg_pool2d.default),
+            1,
+        )
+        # Should have view operations for reshaping
+        self.assertGreater(
+            count_node(updated_gm, exir_ops.edge.aten.view_copy.default),
+            0,
+        )
+
+    def test_replace_aten_avg_pool2d_with_cadence(self) -> None:
+        x_input, gm = self._get_aten_avg_pool2d_gm((1, 32, 64, 64), (3, 3))
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.aten.avg_pool2d.default),
+            1,
+        )
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.avg_pool2d.default),
+            0,
+        )
+
+        # Deepcopy before the pass
+        gm_before = copy.deepcopy(gm)
+
+        p = ReplaceAtenAvgPoolWithCadenceAvgPoolPass()
+        result = p.call(gm)
+        self.assertTrue(result.modified)
+        updated_gm = result.graph_module
+
+        # Validate numerical accuracy
+        inputs = [x_input]
+        validate(
+            gm_before,
+            updated_gm,
+            inputs,
+            "ReplaceAtenAvgPoolWithCadenceAvgPoolPass",
+        )
+
+        # avg_pool2d should be replaced with cadence avg_pool2d
+        self.assertEqual(
+            count_node(updated_gm, exir_ops.edge.aten.avg_pool2d.default),
+            0,
+        )
+        self.assertEqual(
+            count_node(updated_gm, exir_ops.edge.cadence.avg_pool2d.default),
+            1,
+        )
+
+
 class TestReplaceLinalgSvdPass(unittest.TestCase):
     @expand(
         [