From 4309ded72e7162d700c5e4edb7d8691b920d0010 Mon Sep 17 00:00:00 2001 From: Ryan O'Shea Date: Sun, 19 Oct 2025 04:00:08 +0200 Subject: [PATCH 1/3] Arm backend: Add conv3d support to Tosa/Vgf backends Conv3d is not supported by vela so it can not be supported in u55 or u85. * Adds Conv3D support for FP32, Int8 and int16A8W * Reworks to_tosa_memory_format_pass.py to handle spatial rank 3 tensors (DHW) * Adds support for rank 5 tensors to analyze_output_utils.py * Reworks conv2d passes to handle conv3d and renames them to be more generic Signed-off-by: Ryan O'Shea Change-Id: I0c888b46afe7bdfa0a26c26d5281ab02b945b528 --- backends/arm/_passes/__init__.py | 6 +- backends/arm/_passes/arm_pass_manager.py | 8 +- backends/arm/_passes/arm_pass_utils.py | 14 + backends/arm/_passes/conv1d_unsqueeze_pass.py | 4 +- backends/arm/_passes/decompose_cumsum_pass.py | 4 +- ...> decompose_int16_activation_conv_pass.py} | 46 ++- ...te_conv2d_pass.py => rewrite_conv_pass.py} | 75 +++-- .../arm/_passes/size_adjust_input_pass.py | 49 +-- .../arm/_passes/to_tosa_memory_format_pass.py | 295 +++++++++++++----- .../arm/operator_support/ethos_u55_support.py | 2 + backends/arm/operators/__init__.py | 1 + backends/arm/operators/op_tosa_conv3d.py | 24 ++ .../arm/quantizer/quantization_annotator.py | 52 +-- backends/arm/quantizer/quantization_config.py | 3 +- backends/arm/scripts/parse_test_names.py | 1 + backends/arm/test/ops/test_conv3d.py | 259 +++++++++++++-- .../arm/test/tester/analyze_output_utils.py | 76 +++-- backends/arm/tosa/dialect/__init__.py | 1 + backends/arm/tosa/dialect/ops/conv3d.py | 75 +++++ 19 files changed, 741 insertions(+), 254 deletions(-) rename backends/arm/_passes/{decompose_int16_activation_conv2d_pass.py => decompose_int16_activation_conv_pass.py} (73%) rename backends/arm/_passes/{rewrite_conv2d_pass.py => rewrite_conv_pass.py} (83%) create mode 100644 backends/arm/operators/op_tosa_conv3d.py create mode 100644 backends/arm/tosa/dialect/ops/conv3d.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 28775cc8614..f8f9cc36271 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -48,8 +48,8 @@ from .decompose_glu_pass import DecomposeGluPass # noqa from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa -from .decompose_int16_activation_conv2d_pass import ( # noqa - DecomposeConv2dWithInt16ActivationPass, +from .decompose_int16_activation_conv_pass import ( # noqa + DecomposeConvWithInt16ActivationPass, ) from .decompose_int32_clamp_pass import DecomposeInt32ClampPass # noqa from .decompose_int_pow_pass import DecomposeIntPowPass # noqa @@ -109,7 +109,7 @@ from .replace_scalar_with_tensor_pass import ( # noqa ReplaceScalarWithTensorByProfilePass, ) -from .rewrite_conv2d_pass import RewriteConv2dPass # noqa +from .rewrite_conv_pass import RewriteConvPass # noqa from .rewrite_matmul import RewriteMatmulPass # noqa from .rewrite_upsample import RewriteUpsamplePass # noqa from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 098e5f03506..4e20a864f76 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -40,7 +40,7 @@ DecomposeAtanPass, DecomposeAvgPool2dPass, DecomposeBatchNormNoStatsPass, - DecomposeConv2dWithInt16ActivationPass, + DecomposeConvWithInt16ActivationPass, DecomposeCoshPass, DecomposeCosineSimilarityPass, DecomposeCumsumPass, @@ -101,7 +101,7 @@ RemoveNoopPass, ReplaceInfAndLimitValuesPass, ReplaceScalarWithTensorByProfilePass, - RewriteConv2dPass, + RewriteConvPass, RewriteMatmulPass, RewriteUpsamplePass, ScalarsToAttributePass, @@ -277,7 +277,7 @@ def _tosa_pipeline( BroadcastArgsPass(), ConvertPermuteSingletonToViewPass(), FuseViewCopyTransformPass(), - DecomposeConv2dWithInt16ActivationPass(), + DecomposeConvWithInt16ActivationPass(), DecomposeSumPass(), InsertTableOpsPass(exported_program), ] @@ -287,7 +287,7 @@ def _tosa_pipeline( self.add_passes( [ RewriteUpsamplePass(), - RewriteConv2dPass(exported_program), + RewriteConvPass(exported_program), RewriteMatmulPass(), ] ) diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index b9aa04236eb..006d4fff953 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -106,6 +106,20 @@ def get_param_tensor( raise RuntimeError(f"unsupported param type, {node.op}.") +def expand_around_channel(param: Sequence[int] | int, spatial_rank: int) -> list[int]: + """ + Expand a scalar or 1-D parameter around the channel dimension into a broadcastable + shape while preserving the channel location. + """ + if isinstance(param, int): + return [param] * spatial_rank + + param_list = list(param) + if len(param_list) == 1 and spatial_rank > 1: + param_list = param_list * spatial_rank + return param_list + + def create_node( graph: torch.fx.Graph, op_target: OpOverload | EdgeOpOverload, diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py index b6cf8ffa41b..f0b1026577b 100644 --- a/backends/arm/_passes/conv1d_unsqueeze_pass.py +++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py @@ -10,7 +10,7 @@ from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.rewrite_conv2d_pass import RewriteConv2dPass +from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass from executorch.exir.dialects._ops import ops as exir_ops @@ -29,7 +29,7 @@ class Conv1dUnsqueezePass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { - RewriteConv2dPass, + RewriteConvPass, SizeAdjustInputPass, } diff --git a/backends/arm/_passes/decompose_cumsum_pass.py b/backends/arm/_passes/decompose_cumsum_pass.py index dedbc2c039f..8b7d31c97ac 100644 --- a/backends/arm/_passes/decompose_cumsum_pass.py +++ b/backends/arm/_passes/decompose_cumsum_pass.py @@ -10,7 +10,7 @@ from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.quant_args import QuantArgs -from executorch.backends.arm._passes.rewrite_conv2d_pass import RewriteConv2dPass +from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass from executorch.backends.transforms.utils import create_constant_placeholder from executorch.exir import ExportedProgram @@ -42,7 +42,7 @@ class DecomposeCumsumPass(ArmPass): And the convolution is applied over dimension H. """ - _passes_required_after: Set[Type[ExportPass]] = {RewriteConv2dPass} + _passes_required_after: Set[Type[ExportPass]] = {RewriteConvPass} def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() diff --git a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py b/backends/arm/_passes/decompose_int16_activation_conv_pass.py similarity index 73% rename from backends/arm/_passes/decompose_int16_activation_conv2d_pass.py rename to backends/arm/_passes/decompose_int16_activation_conv_pass.py index 2f160474c5b..0a8c5eea2b2 100644 --- a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py +++ b/backends/arm/_passes/decompose_int16_activation_conv_pass.py @@ -4,10 +4,10 @@ # LICENSE file in the root directory of this source tree. -from typing import cast, Set, Type +from typing import cast, Sequence, Set, Type import torch -from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.tosa.specification import get_context_spec @@ -15,16 +15,26 @@ from executorch.exir.pass_base import ExportPass -class DecomposeConv2dWithInt16ActivationPass(ArmPass): +class DecomposeConvWithInt16ActivationPass(ArmPass): """ This pass decomposes a convolution with input dtype int16 and bias - into a convolution without bias followed by an addition of the bias - since the TOSA op requires the bias to be int48 which is hard to represent + into a convolution without bias followed by an addition of the bias. + We also reshape the 1D bias to [1, C, 1, …] so it broadcasts along the channel + dimension. Since the TOSA op requires the bias to be int48 which is hard to represent in torch. Instead rescale the int48 output to int16 and add the bias in int16. """ + def __init__(self) -> None: + super().__init__() + _passes_required_after: Set[Type[ExportPass]] = set() + def bias_view_shape( + self, bias: torch.Tensor, activation_rank: int + ) -> Sequence[int]: + # reshape bias to match convolution output rank so addition broadcasts over channels + return [1, bias.shape[0], *([1] * (activation_rank - 2))] + def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.convolution.default: return super().call_operator(op, args, kwargs, meta) @@ -37,18 +47,22 @@ def call_operator(self, op, args, kwargs, meta): if args[2] is None: return super().call_operator(op, args, kwargs, meta) - if args[0].data.dtype == torch.int8: - return super().call_operator(op, args, kwargs, meta) - elif args[0].data.dtype == torch.int16: - if not tosa_spec.support_extension("int16"): - raise ValueError( - "int16 activation for convolution requires TOSA int16 extension" - ) - else: + activation_tensor = args[0].data + activation_rank = activation_tensor.dim() + + if activation_rank not in (4, 5) or activation_tensor.dtype != torch.int16: return super().call_operator(op, args, kwargs, meta) - # convolution with bias and activation is int16 - bias = args[2] + if not tosa_spec.support_extension("int16"): + raise ValueError( + "int16 activation for convolution requires TOSA int16 extension" + ) + + # convolution with bias and activation is int16 (expected activation rank enforced above) + # The bias is assumed to be quantized with the same quantization parameters as + # the output of the convolution + bias_arg = args[2] + bias_data = bias_arg.data no_bias_args = list(args) no_bias_args[2] = None @@ -63,7 +77,7 @@ def call_operator(self, op, args, kwargs, meta): # reshape the tensor to the same rank as the convolution output to add the bias to the channels channel_bias = super().call_operator( exir_ops.edge.aten.view_copy.default, - (bias, [1, len(bias.data), 1, 1]), + (bias_arg, self.bias_view_shape(bias_data, activation_rank)), {}, new_meta, ) diff --git a/backends/arm/_passes/rewrite_conv2d_pass.py b/backends/arm/_passes/rewrite_conv_pass.py similarity index 83% rename from backends/arm/_passes/rewrite_conv2d_pass.py rename to backends/arm/_passes/rewrite_conv_pass.py index 316c0e44136..9d3ad4f933f 100644 --- a/backends/arm/_passes/rewrite_conv2d_pass.py +++ b/backends/arm/_passes/rewrite_conv_pass.py @@ -12,6 +12,7 @@ from executorch.backends.arm._passes.arm_pass_utils import ( create_node, + expand_around_channel, get_first_fake_tensor, get_param_tensor, is_buffer, @@ -29,7 +30,7 @@ from torch.export.graph_signature import InputKind -class RewriteConv2dPass(ArmPass): +class RewriteConvPass(ArmPass): """Rewrites aten.convolution to tosa.CONV2D or tosa.DEPTHWISE_CONV2D.""" def __init__(self, exported_program: torch.export.ExportedProgram): @@ -88,11 +89,27 @@ def _is_depthwise_conv2d(self, node: torch.fx.Node) -> bool: or node.target != exir_ops.edge.aten.convolution.default ): return False + input_tensor = get_first_fake_tensor(node.all_input_nodes[0]) + if len(input_tensor.shape) != 4: + return False groups = node.args[-1] - in_channels = get_first_fake_tensor(node.all_input_nodes[0]).shape[1] + in_channels = input_tensor.shape[1] out_channels = get_first_fake_tensor(node).shape[1] return (in_channels == groups) and (out_channels % in_channels) == 0 + def _is_conv3d(self, rank, groups) -> bool: + if rank == 5: + # A Conv3D is considered depthwise if Group == InChannels and + # Group * N == OutChannels, where N is a possitive integer. + # Currently we do not support depthwise or grouped conv3d. + # @TODO Add grouped/depthwise conv3d support or reject in partitioner. + if groups != 1: + raise RuntimeError( + "CONV3D with groups != 1 is not supported in the Arm backend." + ) + return True + return False + def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None: """Reshape the weights for depthwise convolution such that when serialized to TOSA, the weights are in the format [H, W, in_channels, m_length] where @@ -201,7 +218,7 @@ def insert_output_rescale(self, graph_module, node): ) return rescale_node - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 modified = False for node in graph_module.graph.nodes: if ( @@ -224,30 +241,40 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: group, ) = node.args - pad = [val for val in pad for _ in (0, 1)] input_fake_tensor = get_first_fake_tensor(x) weight_fake_tensor = get_first_fake_tensor(weight) - # Adjust the pad value if needed to meet the - # strict convolution output shape calculation. - pad[1] = self._adjust_pad_if_needed( - input_fake_tensor.shape[2], - weight_fake_tensor.shape[2], - stride[0], - pad[1], - dilation[0], - ) - pad[3] = self._adjust_pad_if_needed( - input_fake_tensor.shape[3], - weight_fake_tensor.shape[3], - stride[1], - pad[3], - dilation[1], - ) + input_shape = input_fake_tensor.shape + weight_shape = weight_fake_tensor.shape + spatial_rank = len(input_shape) - 2 + stride_list = expand_around_channel(stride, spatial_rank) + dilation_list = expand_around_channel(dilation, spatial_rank) + pad_list = expand_around_channel(pad, spatial_rank) + + pad_attr: list[int] = [] + for value in pad_list: + pad_attr.extend([value, value]) # duplicate pad before/after per axis + + for axis_index in range(spatial_rank): + pad_index = axis_index * 2 + 1 # adjust trailing pad entry + pad_attr[pad_index] = self._adjust_pad_if_needed( + input_shape[axis_index + 2], + weight_shape[axis_index + 2], + stride_list[axis_index], + pad_attr[pad_index], + dilation_list[axis_index], + ) + + stride = tuple(stride_list) + dilation = tuple(dilation_list) + pad = pad_attr + has_bias = bias is not None if not has_bias: bias = self._add_bias(graph_module, node, weight) - if self._is_depthwise_conv2d(node): + if self._is_conv3d(len(input_shape), group): + target_op = exir_ops.backend.tosa.CONV3D.default + elif self._is_depthwise_conv2d(node): target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default # If there are any TOSA.DEPTHWISE_CONV2D nodes using the weights, we've already reshaped them. if all(user.target != target_op for user in weight.users): @@ -256,7 +283,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: else: target_op = exir_ops.backend.tosa.CONV2D.default - conv2d_args = ( + conv_args = ( x, weight, bias, @@ -272,7 +299,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: tosa_op = create_node( graph=graph_module.graph, op_target=target_op, - args=conv2d_args, + args=conv_args, from_node=node, inherit_qparams=True, ) @@ -281,7 +308,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: input_fake_tensor, weight_fake_tensor, bias_fake_tensor, - *conv2d_args[3:], + *conv_args[3:], ) if ( diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index 9460c8f199a..642a2499deb 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -4,12 +4,15 @@ # LICENSE file in the root directory of this source tree. -from typing import cast, Set, Type, TypeAlias +from typing import cast, Sequence, Set, Type, TypeAlias import torch.fx from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import create_node -from executorch.backends.arm._passes.rewrite_conv2d_pass import RewriteConv2dPass +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + expand_around_channel, +) +from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -39,19 +42,22 @@ def pooling_remainder(input_size, pad, kernel_size, stride) -> int: return (input_size + 2 * pad - kernel_size) % stride -def get_slices_conv2d(conv_node: torch.fx.Node) -> Slices: +def get_slices_convolution(conv_node: torch.fx.Node) -> Slices: slices = [] input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = conv_node.args weight_shape = cast(torch.fx.Node, weight).meta["val"].shape input_shape = cast(torch.fx.Node, input_node).meta["val"].shape + spatial_rank = len(input_shape) - 2 - for stride, pad, dilation, dim in zip( - cast(list, stride_hw), - cast(list, pad_hw), - cast(list, dilation_hw), - (2, 3), - ): + strides = expand_around_channel(cast(Sequence[int] | int, stride_hw), spatial_rank) + pads = expand_around_channel(cast(Sequence[int] | int, pad_hw), spatial_rank) + dilations = expand_around_channel( + cast(Sequence[int] | int, dilation_hw), spatial_rank + ) + + for axis_index, (stride, pad, dilation) in enumerate(zip(strides, pads, dilations)): + dim = axis_index + 2 remainder = conv_remainder( input_shape[dim], pad, dilation, weight_shape[dim], stride ) @@ -69,19 +75,16 @@ def get_slices_pooling(pooling_node: torch.fx.Node) -> Slices: input_node = pooling_node.args[0] kernel_size = pooling_node.args[1] stride = pooling_node.args[2] - padding = pooling_node.args[3] if len(pooling_node.args) >= 4 else [0, 0] - - # For the loop below, padding must be a list - if isinstance(padding, int): - padding = [padding, padding] + padding = pooling_node.args[3] if len(pooling_node.args) >= 4 else 0 input_shape = cast(torch.fx.Node, input_node).meta["val"].shape - for kernel_length, stride_length, pad_size, dim in zip( - cast(list, kernel_size), - cast(list, stride), - cast(list, padding), - (2, 3), + kernel_sizes = expand_around_channel(cast(Sequence[int] | int, kernel_size), 2) + strides = expand_around_channel(cast(Sequence[int] | int, stride), 2) + pads = expand_around_channel(cast(Sequence[int] | int, padding), 2) + + for dim, (kernel_length, stride_length, pad_size) in enumerate( + zip(kernel_sizes, strides, pads), start=2 ): remainder = pooling_remainder( input_shape[dim], pad_size, kernel_length, stride_length @@ -99,7 +102,7 @@ def get_slices(node: torch.fx.Node) -> Slices: Returns the remainder of input_length; given graph Node. """ if node.target == conv2d_op: - return get_slices_conv2d(node) + return get_slices_convolution(node) elif node.target == max_pooling_op or node.target == avg_pooling_op: return get_slices_pooling(node) else: @@ -186,7 +189,9 @@ class SizeAdjustInputPass(ArmPass): input. """ - _passes_required_after: Set[Type[ExportPass]] = {RewriteConv2dPass} + _passes_required_after: Set[Type[ExportPass]] = { + RewriteConvPass, + } def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph = graph_module.graph diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 7e998e3a436..07799a840dc 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -17,17 +17,7 @@ get_first_fake_tensor, is_param_node, ) -from executorch.backends.arm.constants import ( - NCHW_ORDER, - NHWC_INVERSE_ORDER, - NHWC_ORDER, - NNCHW_ORDER, - NNHWC_INVERSE_ORDER, - NNHWC_ORDER, - NNNCHW_ORDER, - NNNHWC_INVERSE_ORDER, - NNNHWC_ORDER, -) +from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -48,6 +38,7 @@ class ToTosaMemoryFormatPass(ArmPass): that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts backend.tosa.TRANSPOSE when a transition between 3D and 4D/5D tensors happen. The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape. + This pass also makes other values aware of spatial dimensions required by future operators by back propogating info as required. """ _passes_required_after: Set[Type[ExportPass]] = set() @@ -57,74 +48,169 @@ def __init__(self, exported_program: ExportedProgram) -> None: self.exported_program = exported_program @staticmethod - def memory_format_differs(shape): - """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format""" - if len(shape) >= 6: - C = shape[3] - H = shape[4] - W = shape[5] - elif len(shape) == 5: - C = shape[2] - H = shape[3] - W = shape[4] - elif len(shape) == 4: - C = shape[1] - H = shape[2] - W = shape[3] - elif len(shape) == 3: - C = shape[0] - H = shape[1] - W = shape[2] - if len(shape) <= 2: - return False + def _channels_last_order(rank: int, spatial_rank: int) -> tuple[int, ...]: + """ + Compute the permutation of tensor dimensions corresponding to a + "channels_last"-style memory layout for an arbitrary tensor rank. + + In standard PyTorch convention: + - "channels_first" order is (N, C, H, W) + - "channels_last" order is (N, H, W, C) + This helper generalizes that concept beyond 4D tensors, producing an index + ordering that moves the channel dimension to the end while preserving the + relative order of batch and spatial dimensions. + + Args: + rank (int): Total number of tensor dimensions (e.g. 4 for NCHW). + spatial_rank (int): Number of spatial dimensions (e.g. 2 for HW, 3 for DHW). + Values outside [0, rank - 2] are clamped to that range. + + Returns: + tuple[int, ...]: A permutation of dimension indices that reorders the + tensor into "channels_last" format. For example: + - rank=4, spatial_rank=2 → (0, 2, 3, 1) # NCHW → NHWC + - rank=5, spatial_rank=3 → (0, 2, 3, 4, 1) # NCDHW → NDHWC + - rank=3, spatial_rank=1 → (0, 2, 1) + + Notes: + If `rank <= 2`, the function returns the identity order since there + are no distinct channel/spatial dimensions. + In practice only rank 4+ tensors will reach this function as the dim order should be fixed for those. + """ + if rank <= 2: + return tuple(range(rank)) + spatial_rank = max(0, min(spatial_rank, rank - 2)) + channel_axis = rank - (spatial_rank + 1) + batch_axes = list(range(channel_axis)) + spatial_axes = list(range(channel_axis + 1, rank)) + return tuple(batch_axes + spatial_axes + [channel_axis]) + + @staticmethod + def _channels_last_inverse_order(rank: int, spatial_rank: int) -> tuple[int, ...]: + """ + Return the inverse permutation of `_channels_last_order`. + + This provides the axis order needed to map a tensor from + "channels_last" layout back to its original layout. + """ + order = ToTosaMemoryFormatPass._channels_last_order(rank, spatial_rank) + inverse = [0] * rank + for idx, axis in enumerate(order): + inverse[axis] = idx + return tuple(inverse) + + def _initial_spatial_rank(self, node: torch.fx.Node) -> int: + """ + Infer the initial spatial rank based on the current rank, input node spatial + ranks and node target. A spatial dimension includes Height, Width or Depth + fields. In most operators this will only ever be Height and Width, but for 3D + operators such as conv3d this would contain 3 spatial dims. + + Spatial rank is the max of any input node spatial ranks and the number of + trailing spatial dims we need to preserve (rank - 2, capped at 3). This + decides which axes must stay channels-last when inserting transposes. + """ + tensor = get_first_fake_tensor(node).data + # Start by assuming 2D when dealing with rank4+ to account for the base case + # of an increasing amount of batch dimensions. + rank = tensor.dim() + if rank >= 4: + spatial_rank = 2 + elif rank == 3: + spatial_rank = 1 + else: + spatial_rank = 0 + + # Look for supported 3D ops and update spatial rank if relevent. + # Currently only Conv3d is supported. + if node.target == exir_ops.backend.tosa.CONV3D.default: + spatial_rank = 3 - return C > 1 and (H > 1 or W > 1) + # Check input spatial ranks to know what the previous node spatial ranks were. + input_ranks = [ + input_node.meta.get("tosa_spatial_rank", 0) + for input_node in node.all_input_nodes + ] + if input_ranks: + spatial_rank = max([spatial_rank, *input_ranks]) + + # The max that spatial rank can be is 3. If the current rank not capable of holding + # the current spatial rank, we clamp the max to Rank - (Channels and a singular batch dimension). + # This ensures we revert back to lower spatial ranks after we are finished processing higher spatial ops. + return min(spatial_rank, max(rank - 2, 0)) + + @staticmethod + def memory_format_differs(shape, spatial_rank): + """ + Determine whether a tensor shape would be laid out differently in + channels-first ((N)NCHW) versus channels-last ((N)NHWC) memory format. + """ + if len(shape) <= 2 or spatial_rank <= 0: + return False + channel_idx = len(shape) - (spatial_rank + 1) + channel_idx = max(0, min(channel_idx, len(shape) - 1)) + spatial_dims = shape[channel_idx + 1 :] + if not spatial_dims: + return False + channel_dim = shape[channel_idx] + return channel_dim > 1 and any(dim > 1 for dim in spatial_dims) @staticmethod - def is_channel_reshape(input_shape, output_shape): - """Returns true if reshape changes the channel dimension or batch product dimension(s)""" + def is_channel_reshape( + input_shape, output_shape, input_spatial_rank, output_spatial_rank + ): + """ + Check whether a reshape touches the logical channel or consolidated + batch dimensions, which would invalidate dim-order annotations. + """ valid_ranks = {4, 5, 6} if not (len(input_shape) in valid_ranks and len(output_shape) in valid_ranks): return False - C_old = input_shape[-3] - C_new = output_shape[-3] + def channel_index(shape, spatial_rank): + if len(shape) <= 2: + return len(shape) - 1 + idx = len(shape) - (spatial_rank + 1) + return max(0, min(idx, len(shape) - 1)) - def get_batch_prod_dim(shape): + C_old = input_shape[channel_index(input_shape, input_spatial_rank)] + C_new = output_shape[channel_index(output_shape, output_spatial_rank)] + + def get_batch_prod_dim(shape, spatial_rank): product = 1 - for dim in shape[:-3]: + for dim in shape[: channel_index(shape, spatial_rank)]: product = product * dim return product - N_old = get_batch_prod_dim(input_shape) - N_new = get_batch_prod_dim(output_shape) + N_old = get_batch_prod_dim(input_shape, input_spatial_rank) + N_new = get_batch_prod_dim(output_shape, output_spatial_rank) return (N_old != N_new) or (C_old != C_new) @staticmethod def insert_input_transpose(node, input_node, graph_module): + """ + Ensure an input tensor is converted to channels-last ordering by + inserting (or folding) a backend `TRANSPOSE` node. + """ if input_node.target == exir_ops.backend.tosa.TRANSPOSE.default: pre_permute_node = input_node.all_input_nodes[0] node.replace_input_with(input_node, pre_permute_node) return - if len(get_first_fake_tensor(input_node).size()) == 6: - mem_format = NNNHWC_INVERSE_ORDER - elif len(get_first_fake_tensor(input_node).size()) == 5: - mem_format = NNHWC_INVERSE_ORDER - else: - mem_format = NHWC_INVERSE_ORDER + rank = len(get_first_fake_tensor(input_node).size()) + spatial_rank = input_node.meta["tosa_spatial_rank"] + mem_format = ToTosaMemoryFormatPass._channels_last_inverse_order( + rank, spatial_rank + ) # Guard: mem_format must be a true permutation for the current rank - _rank_ = len( - get_first_fake_tensor(input_node).size() - ) # or (node) in output path assert sorted(mem_format) == list( - range(_rank_) - ), f"bad perm {mem_format} for rank {_rank_} in insert_input_transpose" + range(rank) + ), f"bad perm {mem_format} for rank {rank} in insert_input_transpose" with graph_module.graph.inserting_before(node): permute_node = create_node( @@ -141,21 +227,22 @@ def insert_input_transpose(node, input_node, graph_module): permute_node.meta["tosa_dim_order"] = tuple( range(len(input_node.meta["val"].size())) ) + permute_node.meta["tosa_spatial_rank"] = spatial_rank @staticmethod def insert_output_transpose(node, graph_module): + """ + Convert a producer's output to channels-last by appending a backend + `TRANSPOSE` node and rewiring its users. + """ - if len(get_first_fake_tensor(node).size()) == 6: - mem_format = NNNHWC_ORDER - elif len(get_first_fake_tensor(node).size()) == 5: - mem_format = NNHWC_ORDER - else: - mem_format = NHWC_ORDER + rank = len(get_first_fake_tensor(node).size()) + spatial_rank = node.meta["tosa_spatial_rank"] + mem_format = ToTosaMemoryFormatPass._channels_last_order(rank, spatial_rank) # Guard: mem_format must be a true permutation for the current rank - _rank_ = len(get_first_fake_tensor(node).size()) # or (node) in output path assert sorted(mem_format) == list( - range(_rank_) - ), f"bad perm {mem_format} for rank {_rank_} in insert_input_transpose" + range(rank) + ), f"bad perm {mem_format} for rank {rank} in insert_input_transpose" with graph_module.graph.inserting_after(node): permute_node = create_node( @@ -169,16 +256,12 @@ def insert_output_transpose(node, graph_module): ) rank = len(get_first_fake_tensor(node).size()) - if rank == 6: - permute_node.meta["tosa_dim_order"] = NNNHWC_ORDER - elif rank == 5: - permute_node.meta["tosa_dim_order"] = NNHWC_ORDER - else: - permute_node.meta["tosa_dim_order"] = NHWC_ORDER + permute_node.meta["tosa_dim_order"] = mem_format node.meta["tosa_dim_order"] = tuple( range(len(get_first_fake_tensor(node).size())) ) + permute_node.meta["tosa_spatial_rank"] = spatial_rank users = [user for user in node.users if user != permute_node] for user in users: @@ -188,24 +271,33 @@ def insert_output_transpose(node, graph_module): def _insert_view_transpose( input_shape, output_shape, node, input_node, graph_module ): + """ + Insert the necessary input/output transposes around reshapes that cross + the (N)NCHW -> (N)NHWC boundary or that touch channel dimensions. + """ nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4 nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4 + + input_sr = input_node.meta["tosa_spatial_rank"] + output_sr = node.meta["tosa_spatial_rank"] + channel_reshape = ToTosaMemoryFormatPass.is_channel_reshape( - output_shape, input_shape + input_shape, + output_shape, + input_sr, + output_sr, ) if ( channel_reshape or nhwc_to_nchw - ) and ToTosaMemoryFormatPass.memory_format_differs(input_shape): - + ) and ToTosaMemoryFormatPass.memory_format_differs(input_shape, input_sr): ToTosaMemoryFormatPass.insert_input_transpose( node, input_node, graph_module ) if ( channel_reshape or nchw_to_nhwc - ) and ToTosaMemoryFormatPass.memory_format_differs(output_shape): - + ) and ToTosaMemoryFormatPass.memory_format_differs(output_shape, output_sr): ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module) def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): @@ -214,7 +306,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): This is relevant for the following cases: - view: <4D -> >=4D - view: >=4D -> <4D - Additionally, a 4D/5D->4D/5D view operation acting on the channel dimension currently needs to be performed in (N)NCHW format, leadning to one extra input and output transpose for this case. + Additionally, a 4D/5D->4D/5D view operation acting on the channel dimension currently needs to be performed in (N)NCHW format, leading to one extra input and output transpose for this case. Transposes can be avoided for shapes where there is no difference in actual memory, e.g for - H == W == 1 @@ -284,6 +376,10 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): def remove_dim_order_kwargs( self, graph_module: torch.fx.GraphModule, node: torch.fx.Node ): + """ + Drop any user-specified `dim_order` keyword arguments so the pass remains + the single source of truth for dim-order annotations. + """ if node.op != "call_function": return @@ -298,24 +394,31 @@ def remove_dim_order_kwargs( node.kwargs = kwargs def call(self, graph_module: torch.fx.GraphModule): - for node in graph_module.graph.nodes: + """ + Entry point for the pass: annotate spatial ranks, compute dim orders, + insert bridging transposes, and forward to child passes. + """ + nodes = list(graph_module.graph.nodes) + for node in nodes: if "val" not in node.meta: continue - node_data = get_first_fake_tensor(node).data - + node.meta["tosa_spatial_rank"] = self._initial_spatial_rank(node) self.remove_dim_order_kwargs(graph_module, node) - # Inputs and outputs may vary in dim_order + + self._propagate_spatial_ranks(nodes) + + for node in nodes: + if "val" not in node.meta: + continue + node_data = get_first_fake_tensor(node).data + spatial_rank = node.meta["tosa_spatial_rank"] if _is_input(node, self.exported_program) or node.op == "output": dim_order = node_data.dim_order() - elif node_data.dim() == 4: - dim_order = NHWC_ORDER - elif node_data.dim() == 5: - dim_order = NNHWC_ORDER - elif node_data.dim() == 6: - dim_order = NNNHWC_ORDER else: - dim_order = tuple(range(node_data.dim())) # type: ignore[assignment] - + if node_data.dim() >= 4: + dim_order = self._channels_last_order(node_data.dim(), spatial_rank) + else: + dim_order = tuple(range(node_data.dim())) # type: ignore[assignment] node.meta["tosa_dim_order"] = dim_order # Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format. @@ -325,3 +428,27 @@ def call(self, graph_module: torch.fx.GraphModule): graph_module = super().call(graph_module).graph_module return PassResult(graph_module, True) + + def _propagate_spatial_ranks(self, nodes): + """ + Propagate `tosa_spatial_rank` metadata backwards so earlier nodes learn + about upcoming spatial requirements from future ops. + """ + changed = True + while changed: + changed = False + for node in reversed(nodes): + if "val" not in node.meta: + continue + tensor = get_first_fake_tensor(node) + limit = max(tensor.dim() - 2, 0) + current = node.meta.get("tosa_spatial_rank") + propagated = current + for user in node.users: + user_rank = user.meta.get("tosa_spatial_rank") + if user_rank is None: + continue + propagated = max(propagated, min(user_rank, limit)) + if propagated != current: + node.meta["tosa_spatial_rank"] = propagated + changed = True diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py index 49ad820015d..bd43233454f 100644 --- a/backends/arm/operator_support/ethos_u55_support.py +++ b/backends/arm/operator_support/ethos_u55_support.py @@ -189,6 +189,8 @@ class EthosU55NotSupported(OperatorSupportBase): exir_ops.edge.aten.logical_not.default, exir_ops.edge.aten.amax.default, # REDUCE_MAX exir_ops.edge.aten.amin.default, # REDUCE_MIN + exir_ops.edge.aten.conv3d.default, # CONV3D + exir_ops.edge.aten.conv3d.padding, # CONV3D (deprecated alias) exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.eq.Scalar, exir_ops.edge.aten.ge.Tensor, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index b987e99cf4f..15be109d708 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -56,6 +56,7 @@ op_tanh, op_to_dim_order_copy, op_tosa_conv2d, + op_tosa_conv3d, op_tosa_depthwise_conv2d, op_tosa_matmul, op_tosa_rescale, diff --git a/backends/arm/operators/op_tosa_conv3d.py b/backends/arm/operators/op_tosa_conv3d.py new file mode 100644 index 00000000000..e0a8d2ef6ac --- /dev/null +++ b/backends/arm/operators/op_tosa_conv3d.py @@ -0,0 +1,24 @@ +# Copyright 2023-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Provide a visitor for lowering 3D convolution to TOSA (INT/FP).""" + +from executorch.backends.arm.operators.node_visitor import register_node_visitor +from executorch.backends.arm.operators.op_tosa_conv2d import Conv2dVisitor + + +@register_node_visitor +class Conv3dVisitor(Conv2dVisitor): + """Provide a visitor that serializes TOSA ``CONV3D``.""" + + target = "tosa.CONV3D.default" + + def _get_tosa_op(self): + import serializer.tosa_serializer as ts # type: ignore + + return ts.Op.CONV3D + + def _get_attr_func(self, attr): + return attr.Conv3dAttribute diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index c0f28cc3d87..b32e10bd1bf 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -335,6 +335,14 @@ def _match_pattern( return left_condition and right_condition +_conv_ops = [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv2d.padding, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv3d.padding, +] + _one_to_one = [ torch.ops.aten.abs.default, torch.ops.aten.ceil.default, @@ -485,14 +493,8 @@ def any_or_hardtanh_min_zero(n: Node): if _match_pattern( node, [ - [ - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, - torch.ops.aten.conv2d.padding, - ], - [ - torch.ops.aten.batch_norm.default, - ], + _conv_ops, + [torch.ops.aten.batch_norm.default], [ torch.ops.aten.relu.default, torch.ops.aten.relu_.default, @@ -502,11 +504,7 @@ def any_or_hardtanh_min_zero(n: Node): ], filter_fn=any_or_hardtanh_min_zero, ): - if node.target in ( - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, - torch.ops.aten.conv2d.padding, - ): + if node.target in _conv_ops: quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), _QuantProperty(1, weight_qspec, mark_annotated=True), @@ -523,21 +521,11 @@ def any_or_hardtanh_min_zero(n: Node): elif _match_pattern( node, [ - [ - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, - torch.ops.aten.conv2d.padding, - ], - [ - torch.ops.aten.batch_norm.default, - ], + _conv_ops, + [torch.ops.aten.batch_norm.default], ], ): - if node.target in ( - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, - torch.ops.aten.conv2d.padding, - ): + if node.target in _conv_ops: quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), _QuantProperty(1, weight_qspec, mark_annotated=True), @@ -551,10 +539,8 @@ def any_or_hardtanh_min_zero(n: Node): node, [ [ - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, + *_conv_ops, torch.ops.aten.linear.default, - torch.ops.aten.conv2d.padding, ], [ torch.ops.aten.relu.default, @@ -566,10 +552,8 @@ def any_or_hardtanh_min_zero(n: Node): any_or_hardtanh_min_zero, ): if node.target in ( - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, + *_conv_ops, torch.ops.aten.linear.default, - torch.ops.aten.conv2d.padding, ): quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), @@ -579,10 +563,8 @@ def any_or_hardtanh_min_zero(n: Node): else: quant_properties.quant_output = _QuantProperty(0, output_act_qspec) elif node.target in ( - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, + *_conv_ops, torch.ops.aten.linear.default, - torch.ops.aten.conv2d.padding, ): quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index 3e2939cff61..b2bc4a57329 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -177,6 +177,8 @@ def _derive_qparams_fn( torch.ops.aten.conv2d.default, torch.ops.aten.linear.default, torch.ops.aten.conv2d.padding, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv3d.padding, ]: if self.input_activation is None or self.weight is None: raise ValueError( @@ -187,7 +189,6 @@ def _derive_qparams_fn( self.input_activation.dtype == torch.int16 and self.weight.dtype == torch.int8 ): - input_act = node.args[0] weight = node.args[1] diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index 1315358b40b..0b531f1d86b 100644 --- a/backends/arm/scripts/parse_test_names.py +++ b/backends/arm/scripts/parse_test_names.py @@ -102,6 +102,7 @@ def parse_test_name( # Special case for convolution op = op.removesuffix("_1d") op = op.removesuffix("_2d") + op = op.removesuffix("_3d") # Remove suffix for 16 bit activation and 8 bit weight test cases op = op.removesuffix("_16a8w") diff --git a/backends/arm/test/ops/test_conv3d.py b/backends/arm/test/ops/test_conv3d.py index 46986103aa0..c5edff8808d 100644 --- a/backends/arm/test/ops/test_conv3d.py +++ b/backends/arm/test/ops/test_conv3d.py @@ -8,7 +8,11 @@ import pytest import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a16w8_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, @@ -17,6 +21,8 @@ TosaPipelineINT, VgfPipeline, ) +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.xnnpack.test.tester import Quantize aten_op = "torch.ops.aten.conv3d.default" exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default" @@ -109,7 +115,11 @@ def __init__( def get_inputs(self): return ( torch.randn( - self.batches, self.in_channels[0], self.height, self.width, self.depth + self.batches, + self.in_channels[0], + self.depth, + self.height, + self.width, ).to(self.dtype), ) @@ -120,26 +130,108 @@ def forward(self, x): return x -conv3d_2x2_3x2x40x40_nobias = Conv3d( +class Conv3dMultiOp(torch.nn.Module): + """ + Mixed Conv3d/Conv2d pipeline used to verify spatial-rank propagation across ops. + + Topology: + conv3d -> reshape -> conv2d -> reshape/permutation -> conv2d -> reshape -> add(5D) + """ + + def __init__(self, dtype=torch.float): + super().__init__() + self.dtype = dtype + self.conv3d = torch.nn.Conv3d( + in_channels=2, + out_channels=4, + kernel_size=(3, 3, 3), + stride=1, + padding=1, + ).to(dtype) + self.conv2d_main = torch.nn.Conv2d( + in_channels=4, + out_channels=4, + kernel_size=3, + stride=1, + padding=1, + ).to(dtype) + self.conv2d_pointwise = torch.nn.Conv2d( + in_channels=4, + out_channels=4, + kernel_size=1, + stride=1, + padding=0, + ).to(dtype) + self.activation = torch.nn.ReLU() + + def get_inputs(self): + return (torch.randn(1, 2, 3, 8, 8).to(self.dtype),) + + def forward(self, x): + x3d = self.conv3d(x) + batches, channels, depth, height, width = x3d.shape + + reshaped = x3d.reshape(batches * depth, channels, height, width) + conv2d_out = self.activation(self.conv2d_main(reshaped)) + + conv2d_out_5d = ( + conv2d_out.reshape(batches, depth, channels, height, width) + .permute(0, 2, 1, 3, 4) + .contiguous() + ) + + reshaped_again = conv2d_out_5d.permute(0, 2, 1, 3, 4).reshape( + batches * depth, channels, height, width + ) + conv2d_pointwise_out = self.conv2d_pointwise(reshaped_again) + conv2d_pointwise_out_5d = ( + conv2d_pointwise_out.reshape(batches, depth, channels, height, width) + .permute(0, 2, 1, 3, 4) + .contiguous() + ) + + return conv2d_pointwise_out_5d + x3d + + +class DepthwiseConv3d(torch.nn.Module): + def __init__(self, dtype=torch.float): + super().__init__() + self.dtype = dtype + self.conv = torch.nn.Conv3d( + in_channels=2, + out_channels=4, + kernel_size=(3, 3, 3), + padding=1, + groups=2, + ).to(dtype) + + def get_inputs(self): + return (torch.randn(1, 2, 3, 8, 8).to(self.dtype),) + + def forward(self, x): + return self.conv(x) + + +conv3d_2x2_3x2x14x14_nobias = Conv3d( in_channels=2, out_channels=3, kernel_size=(2, 2, 2), stride=1, bias=False, padding=0, - width=40, - height=40, - batches=3, + width=14, + height=14, + batches=2, ) -conv3d_3x3_1x3x256x256_st1 = Conv3d( +conv3d_3x3_1x3x24x24_st1 = Conv3d( in_channels=3, out_channels=10, kernel_size=(3, 3, 3), stride=1, padding=0, - width=256, - height=256, + width=24, + height=24, batches=1, ) @@ -154,14 +246,14 @@ def forward(self, x): batches=1, ) -conv3d_1x1_1x2x128x128_st1 = Conv3d( +conv3d_1x1_1x2x16x16_st1 = Conv3d( in_channels=2, out_channels=1, kernel_size=(1, 1, 1), stride=1, padding=0, - width=128, - height=128, + width=16, + height=16, batches=1, ) @@ -176,25 +268,25 @@ def forward(self, x): batches=1, ) -conv3d_5x5_3x2x128x128_st1 = Conv3d( +conv3d_5x5_3x2x24x24_st1 = Conv3d( in_channels=2, out_channels=3, kernel_size=(5, 5, 5), stride=1, padding=0, - width=128, - height=128, - batches=3, + width=24, + height=24, + batches=2, ) -conv3d_3x3_1x3x224x224_st2_pd1 = Conv3d( +conv3d_3x3_1x3x28x28_st2_pd1 = Conv3d( in_channels=3, out_channels=16, kernel_size=(3, 3, 3), stride=2, padding=1, - width=224, - height=224, + width=28, + height=28, batches=1, ) @@ -214,8 +306,8 @@ def forward(self, x): out_channels=3, kernel_size=(7, 7, 7), stride=2, - padding=1, - dilation=2, + padding=3, + dilation=1, width=16, height=16, batches=1, @@ -306,10 +398,10 @@ def forward(self, x): ) test_data_FP = { - "2x2_3x2x40x40_nobias": lambda: conv3d_2x2_3x2x40x40_nobias, - "3x3_1x3x256x256_st1": lambda: conv3d_3x3_1x3x256x256_st1, + "2x2_3x2x14x14_nobias": lambda: conv3d_2x2_3x2x14x14_nobias, + "3x3_1x3x24x24_st1": lambda: conv3d_3x3_1x3x24x24_st1, "3x3_1x3x12x12_st2_pd1": lambda: conv3d_3x3_1x3x12x12_st2_pd1, - "1x1_1x2x128x128_st1": lambda: conv3d_1x1_1x2x128x128_st1, + "1x1_1x2x16x16_st1": lambda: conv3d_1x1_1x2x16x16_st1, "2x2_1x1x14x13_st2_needs_adjust_pass": lambda: conv3d_2x2_1x1x14x13_st2, "5x5_1x3x14x15_st3_pd1_needs_adjust_pass": lambda: conv3d_5x5_1x3x14x15_st3_pd1, "7x7_1x3x16x16_st2_pd1_dl2_needs_adjust_pass": lambda: conv3d_7x7_1x3x16x16_st2_pd1_dl2, @@ -320,8 +412,8 @@ def forward(self, x): "3x3_1x3x8x9_st3_pd0_dl1_needs_adjust_pass": lambda: conv3d_3x3_1x3x8x9_st3_pd0_dl1, "3x4_1x3x7x7_st3_pd0_dl1_needs_adjust_pass": lambda: conv3d_3x4_1x3x7x7_st3_pd0_dl1, "4x3_1x3x7x7_st3_pd0_dl1_needs_adjust_pass": lambda: conv3d_4x3_1x3x7x7_st3_pd0_dl1, - "5x5_3x2x128x128_st1": lambda: conv3d_5x5_3x2x128x128_st1, - "3x3_1x3x224x224_st2_pd1": lambda: conv3d_3x3_1x3x224x224_st2_pd1, + "5x5_3x2x24x24_st1": lambda: conv3d_5x5_3x2x24x24_st1, + "3x3_1x3x28x28_st2_pd1": lambda: conv3d_3x3_1x3x28x28_st2_pd1, } # Generate a new test set paired with per_channel_quant=True/False. @@ -331,11 +423,36 @@ def forward(self, x): for q in [True, False] } +test_data_INT16 = { + f"{k},16a8w,per_channel_quant={q}": (lambda v=v, q=q: (v(), q)) + for (k, v) in test_data_FP.items() + for q in [True, False] +} + + +def get_symmetric_a16w8_conv3d_quantizer(per_channel_quantization: bool = False): + tosa_version = conftest.get_option("tosa_version") + tosa_profiles = { + "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), + } + + quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) + quant_config = get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ) + quantizer.set_global(quant_config) + quantizer.set_module_type(torch.nn.Conv3d, quant_config) + + return Quantize( + quantizer, + quant_config, + ) + + input_t = Tuple[torch.Tensor] @common.parametrize("test_data", test_data_FP) -@pytest.mark.skip # Not implemented, skip until it is. def test_convolution_3d_tosa_FP(test_data): pipeline = TosaPipelineFP[input_t]( test_data(), test_data().get_inputs(), aten_op, exir_op @@ -344,7 +461,6 @@ def test_convolution_3d_tosa_FP(test_data): @common.parametrize("test_data", test_data_INT) -@pytest.mark.skip # Not implemented, skip until it is. def test_convolution_3d_tosa_INT(test_data): model, per_channel_quantization = test_data() pipeline = TosaPipelineINT[input_t]( @@ -358,8 +474,63 @@ def test_convolution_3d_tosa_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_INT16) +def test_convolution_3d_tosa_INT16(test_data): + model, per_channel_quantization = test_data() + pipeline = TosaPipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + tosa_extensions=["int16"], + qtol=1, + ) + pipeline.change_args( + "quantize", + get_symmetric_a16w8_conv3d_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() + + +def test_convolution_3d_tosa_FP_multi_op(): + """Ensure mixed Conv3d/Conv2d graphs keep correct spatial annotations.""" + model = Conv3dMultiOp() + pipeline = TosaPipelineFP[input_t](model, model.get_inputs(), aten_op, exir_op) + pipeline.run() + + +def test_convolution_3d_tosa_INT_multi_op(): + """Ensure mixed Conv3d/Conv2d graphs keep correct spatial annotations.""" + model = Conv3dMultiOp() + pipeline = TosaPipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + ) + pipeline.run() + + +def test_convolution_3d_tosa_FP_depthwise(): + """Depthwise or Grouped Conv3d should be rejected until grouped support exists.""" + model = DepthwiseConv3d() + pipeline = TosaPipelineFP[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + run_on_tosa_ref_model=False, + ) + with pytest.raises(RuntimeError, match="CONV3D with groups != 1"): + pipeline.run() + + @common.parametrize("test_data", test_data_INT) -@pytest.mark.skip # Not implemented, skip until it is. +@pytest.mark.skip(reason="Ethos-U55 does not support CONV3D yet.") def test_convolution_3d_u55_INT(test_data): model, per_channel_quantization = test_data() pipeline = EthosU55PipelineINT[input_t]( @@ -373,7 +544,7 @@ def test_convolution_3d_u55_INT(test_data): @common.parametrize("test_data", test_data_INT) -@pytest.mark.skip # Not implemented, skip until it is. +@pytest.mark.skip(reason="Ethos-U85 does not support CONV3D yet.") def test_convolution_3d_u85_INT(test_data): model, per_channel_quantization = test_data() pipeline = EthosU85PipelineINT[input_t]( @@ -387,7 +558,6 @@ def test_convolution_3d_u85_INT(test_data): @common.parametrize("test_data", test_data_FP) -@pytest.mark.skip # Not implemented, skip until it is. @common.SkipIfNoModelConverter def test_convolution_3d_vgf_FP(test_data): pipeline = VgfPipeline[input_t]( @@ -401,7 +571,6 @@ def test_convolution_3d_vgf_FP(test_data): @common.parametrize("test_data", test_data_INT) -@pytest.mark.skip # Not implemented, skip until it is. @common.SkipIfNoModelConverter def test_convolution_3d_vgf_INT(test_data): model, per_channel_quantization = test_data() @@ -415,6 +584,32 @@ def test_convolution_3d_vgf_INT(test_data): pipeline.run() +def test_convolution_3d_vgf_FP_multi_op(): + """Ensure mixed Conv3d/Conv2d graphs keep correct spatial annotations.""" + model = Conv3dMultiOp() + pipeline = VgfPipeline[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + tosa_version="TOSA-1.0+FP", + ) + pipeline.run() + + +def test_convolution_3d_vgf_INT_multi_op(): + """Ensure mixed Conv3d/Conv2d graphs keep correct spatial annotations.""" + model = Conv3dMultiOp() + pipeline = VgfPipeline[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + tosa_version="TOSA-1.0+INT", + ) + pipeline.run() + + reject_suite = { "large_stride": lambda: Conv3d( in_channels=1, diff --git a/backends/arm/test/tester/analyze_output_utils.py b/backends/arm/test/tester/analyze_output_utils.py index 3bcac603a9e..527413e9d8f 100644 --- a/backends/arm/test/tester/analyze_output_utils.py +++ b/backends/arm/test/tester/analyze_output_utils.py @@ -39,7 +39,6 @@ def _print_channels( rtol: float, atol: float, ) -> str: - output_str = "" exp = "000" booldata = False @@ -121,7 +120,7 @@ def _print_elements( return output_str -def print_error_diffs( +def print_error_diffs( # noqa: C901 tester_or_result: Any, result_or_reference: TensorLike, reference: TensorLike | None = None, @@ -174,33 +173,53 @@ def print_error_diffs( f"Output needs to be of same shape: {result.shape} != {reference_tensor.shape}" ) shape = result.shape + rank = len(shape) + + if rank == 5: + N, C, D, H, W = shape + elif rank == 4: + N, C, H, W = shape + D = 1 + elif rank == 3: + C, H, W = shape + N, D = 1, 1 + elif rank == 2: + H, W = shape + N, C, D = 1, 1, 1 + elif rank == 1: + W = shape[0] + N, C, D, H = 1, 1, 1, 1 + elif rank == 0: + N = C = D = H = W = 1 + else: + raise ValueError("Invalid tensor rank") - match len(shape): - case 4: - N, C, H, W = (shape[0], shape[1], shape[2], shape[3]) - case 3: - N, C, H, W = (1, shape[0], shape[1], shape[2]) - case 2: - N, C, H, W = (1, 1, shape[0], shape[1]) - case 1: - N, C, H, W = (1, 1, 1, shape[0]) - case 0: - N, C, H, W = (1, 1, 1, 1) - case _: - raise ValueError("Invalid tensor rank") + if rank < 3: + C = 1 + if rank < 2: + H = 1 + if rank < 1: + W = 1 if quantization_scale is not None: atol += quantization_scale * qtol - # Reshape tensors to 4D NCHW format - result = torch.reshape(result, (N, C, H, W)) - reference_tensor = torch.reshape(reference_tensor, (N, C, H, W)) + # Reshape tensors to 4D NCHW format, optionally folding depth into batch. + total_batches = N * D + result = torch.reshape(result, (total_batches, C, H, W)) + reference_tensor = torch.reshape(reference_tensor, (total_batches, C, H, W)) output_str = "" - for n in range(N): - output_str += f"BATCH {n}\n" - result_batch = result[n, :, :, :] - reference_batch = reference_tensor[n, :, :, :] + for idx in range(total_batches): + batch_idx = idx // D if D > 0 else idx + depth_idx = idx % D if D > 0 else 0 + if D > 1: + output_str += f"BATCH {batch_idx} DEPTH {depth_idx}\n" + else: + output_str += f"BATCH {batch_idx}\n" + + result_batch = result[idx, :, :, :] + reference_batch = reference_tensor[idx, :, :, :] is_close = torch.allclose(result_batch, reference_batch, rtol, atol) if is_close: @@ -208,15 +227,15 @@ def print_error_diffs( else: channels_close: list[bool] = [False] * C for c in range(C): - result_hw = result[n, c, :, :] - reference_hw = reference_tensor[n, c, :, :] + result_hw = result[idx, c, :, :] + reference_hw = reference_tensor[idx, c, :, :] channels_close[c] = torch.allclose(result_hw, reference_hw, rtol, atol) if any(channels_close) or len(channels_close) == 1: output_str += _print_channels( - result[n, :, :, :], - reference_tensor[n, :, :, :], + result[idx, :, :, :], + reference_tensor[idx, :, :, :], channels_close, C, H, @@ -226,8 +245,8 @@ def print_error_diffs( ) else: output_str += _print_elements( - result[n, :, :, :], - reference_tensor[n, :, :, :], + result[idx, :, :, :], + reference_tensor[idx, :, :, :], C, H, W, @@ -312,7 +331,6 @@ def dump_error_output( if __name__ == "__main__": - """This is expected to produce the example output of print_diff""" torch.manual_seed(0) a = torch.rand(3, 3, 2, 2) * 0.01 diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py index adb5064454b..152f99d4431 100644 --- a/backends/arm/tosa/dialect/__init__.py +++ b/backends/arm/tosa/dialect/__init__.py @@ -5,6 +5,7 @@ from executorch.backends.arm.tosa.dialect.ops import ( # noqa F401 conv2d, + conv3d, depthwise_conv2d, matmul, rescale, diff --git a/backends/arm/tosa/dialect/ops/conv3d.py b/backends/arm/tosa/dialect/ops/conv3d.py new file mode 100644 index 00000000000..6428e091367 --- /dev/null +++ b/backends/arm/tosa/dialect/ops/conv3d.py @@ -0,0 +1,75 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops.conv2d import validate_conv2d_args_dtypes +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +def validate_conv3d_args_dtypes( + tosa_spec: TosaSpecification, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.dtype: + if len(x.shape) != 5 or len(weight.shape) != 5: + raise TosaValueError( + f"Expected 5D input/weight tensors for CONV3D, got {x.shape} and {weight.shape}", + op="CONV3D", + ) + return validate_conv2d_args_dtypes(tosa_spec, x, weight, bias, op="CONV3D") + + +@register_fake_tosa_op( + "CONV3D(Tensor input, " + "Tensor weight, " + "Tensor bias, " + "int[3] stride, " + "int[6] pad, " + "int[3] dialation, " + "bool transposed, " + "int[3] output_padding, " + "int groups) -> Tensor", + ( + TosaSpecification.create_from_string("TOSA-1.0+FP"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ), +) +def CONV3D( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: list[int], + pad: list[int], + dialation: list[int], + transposed: bool, + output_padding: list[int], + groups: int, +) -> torch.Tensor: + tosa_spec = get_context_spec() + + output_dtype = validate_conv3d_args_dtypes(tosa_spec, x, weight, bias) + + torch_pad = [pad[0], pad[2], pad[4]] + aten_fake_tensor = exir_ops.edge.aten.convolution.default( + x, + weight, + bias, + stride, + torch_pad, + dialation, + transposed, + output_padding, + groups, + ) + return aten_fake_tensor.to(dtype=output_dtype) From b53895f20581885794997d7c5b5330dd103d7a5f Mon Sep 17 00:00:00 2001 From: Ryan O'Shea Date: Fri, 5 Dec 2025 12:50:28 +0100 Subject: [PATCH 2/3] Fix conv3d linter errors Signed-off-by: Ryan O'Shea Change-Id: I9b096c71da7e1bf943857c04c5606e22ba35da4f --- backends/arm/_passes/rewrite_conv_pass.py | 2 +- backends/arm/test/misc/test_dw_convs_with_shared_weights.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/backends/arm/_passes/rewrite_conv_pass.py b/backends/arm/_passes/rewrite_conv_pass.py index 9d3ad4f933f..7582647eabb 100644 --- a/backends/arm/_passes/rewrite_conv_pass.py +++ b/backends/arm/_passes/rewrite_conv_pass.py @@ -218,7 +218,7 @@ def insert_output_rescale(self, graph_module, node): ) return rescale_node - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 modified = False for node in graph_module.graph.nodes: if ( diff --git a/backends/arm/test/misc/test_dw_convs_with_shared_weights.py b/backends/arm/test/misc/test_dw_convs_with_shared_weights.py index 0732924ecfd..8b3b99cf005 100644 --- a/backends/arm/test/misc/test_dw_convs_with_shared_weights.py +++ b/backends/arm/test/misc/test_dw_convs_with_shared_weights.py @@ -6,7 +6,7 @@ from typing import Any, Tuple import torch -from executorch.backends.arm._passes.rewrite_conv2d_pass import RewriteConv2dPass +from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass from executorch.backends.arm.test.tester.test_pipeline import ( PassPipeline, TosaPipelineFP, @@ -51,7 +51,7 @@ def test_convs_tosa_int(): def test_rewrite_conv_pass(): module = DWConvsModule() pipeline = PassPipeline( - module, module.get_inputs(), passes_with_exported_program=[RewriteConv2dPass] + module, module.get_inputs(), passes_with_exported_program=[RewriteConvPass] ) # We can't run TOSA backend dialect operators in eager mode pipeline.pop_stage("run_method_and_compare_outputs") From 40a4202b81f95d1689e773e7a7ba99fbd9410f63 Mon Sep 17 00:00:00 2001 From: Ryan O'Shea Date: Mon, 8 Dec 2025 13:42:44 +0100 Subject: [PATCH 3/3] fix model converter not found on conv3d tests Signed-off-by: Ryan O'Shea Change-Id: Ibce8fdf18a8e3fb2f16a9f5cb0c07d167ba2514c --- backends/arm/test/ops/test_conv3d.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backends/arm/test/ops/test_conv3d.py b/backends/arm/test/ops/test_conv3d.py index c5edff8808d..9debc19e079 100644 --- a/backends/arm/test/ops/test_conv3d.py +++ b/backends/arm/test/ops/test_conv3d.py @@ -584,6 +584,7 @@ def test_convolution_3d_vgf_INT(test_data): pipeline.run() +@common.SkipIfNoModelConverter def test_convolution_3d_vgf_FP_multi_op(): """Ensure mixed Conv3d/Conv2d graphs keep correct spatial annotations.""" model = Conv3dMultiOp() @@ -597,6 +598,7 @@ def test_convolution_3d_vgf_FP_multi_op(): pipeline.run() +@common.SkipIfNoModelConverter def test_convolution_3d_vgf_INT_multi_op(): """Ensure mixed Conv3d/Conv2d graphs keep correct spatial annotations.""" model = Conv3dMultiOp()