6 changes: 3 additions & 3 deletions backends/arm/_passes/__init__.py
@@ -48,8 +48,8 @@
from .decompose_glu_pass import DecomposeGluPass # noqa
from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
from .decompose_int16_activation_conv2d_pass import ( # noqa
DecomposeConv2dWithInt16ActivationPass,
from .decompose_int16_activation_conv_pass import ( # noqa
DecomposeConvWithInt16ActivationPass,
)
from .decompose_int32_clamp_pass import DecomposeInt32ClampPass # noqa
from .decompose_int_pow_pass import DecomposeIntPowPass # noqa
@@ -109,7 +109,7 @@
from .replace_scalar_with_tensor_pass import ( # noqa
ReplaceScalarWithTensorByProfilePass,
)
from .rewrite_conv2d_pass import RewriteConv2dPass # noqa
from .rewrite_conv_pass import RewriteConvPass # noqa
from .rewrite_matmul import RewriteMatmulPass # noqa
from .rewrite_upsample import RewriteUpsamplePass # noqa
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
8 changes: 4 additions & 4 deletions backends/arm/_passes/arm_pass_manager.py
@@ -40,7 +40,7 @@
DecomposeAtanPass,
DecomposeAvgPool2dPass,
DecomposeBatchNormNoStatsPass,
DecomposeConv2dWithInt16ActivationPass,
DecomposeConvWithInt16ActivationPass,
DecomposeCoshPass,
DecomposeCosineSimilarityPass,
DecomposeCumsumPass,
@@ -101,7 +101,7 @@
RemoveNoopPass,
ReplaceInfValuesPass,
ReplaceScalarWithTensorByProfilePass,
RewriteConv2dPass,
RewriteConvPass,
RewriteMatmulPass,
RewriteUpsamplePass,
ScalarsToAttributePass,
@@ -277,7 +277,7 @@ def _tosa_pipeline(
BroadcastArgsPass(),
ConvertPermuteSingletonToViewPass(),
FuseViewCopyTransformPass(),
DecomposeConv2dWithInt16ActivationPass(),
DecomposeConvWithInt16ActivationPass(),
DecomposeSumPass(),
InsertTableOpsPass(exported_program),
]
@@ -287,7 +287,7 @@
self.add_passes(
[
RewriteUpsamplePass(),
RewriteConv2dPass(exported_program),
RewriteConvPass(exported_program),
RewriteMatmulPass(),
]
)
14 changes: 14 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
@@ -106,6 +106,20 @@ def get_param_tensor(
raise RuntimeError(f"unsupported param type, {node.op}.")


def expand_around_channel(param: Sequence[int] | int, spatial_rank: int) -> list[int]:
"""
Expand a scalar or length-1 parameter (e.g. stride, dilation or padding) into a list
with one value per spatial dimension; longer sequences are returned unchanged.
"""
if isinstance(param, int):
return [param] * spatial_rank

param_list = list(param)
if len(param_list) == 1 and spatial_rank > 1:
param_list = param_list * spatial_rank
return param_list


def create_node(
graph: torch.fx.Graph,
op_target: OpOverload | EdgeOpOverload,
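As a quick illustration (not part of the diff itself, and assuming the import path used elsewhere in this PR), the new expand_around_channel helper simply repeats a scalar or length-1 value once per spatial axis and leaves fully specified lists untouched:

from executorch.backends.arm._passes.arm_pass_utils import expand_around_channel

# Scalar stride for a conv3d-style input (3 spatial axes): repeated per axis.
assert expand_around_channel(2, 3) == [2, 2, 2]
# Length-1 padding for a conv2d-style input (2 spatial axes): broadcast to both.
assert expand_around_channel([3], 2) == [3, 3]
# Fully specified parameter: returned unchanged.
assert expand_around_channel([1, 2], 2) == [1, 2]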
4 changes: 2 additions & 2 deletions backends/arm/_passes/conv1d_unsqueeze_pass.py
@@ -10,7 +10,7 @@

from executorch.backends.arm._passes import ArmPass

from executorch.backends.arm._passes.rewrite_conv2d_pass import RewriteConv2dPass
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass

from executorch.exir.dialects._ops import ops as exir_ops
@@ -29,7 +29,7 @@ class Conv1dUnsqueezePass(ArmPass):
"""

_passes_required_after: Set[Type[ExportPass]] = {
RewriteConv2dPass,
RewriteConvPass,
SizeAdjustInputPass,
}

4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_cumsum_pass.py
@@ -10,7 +10,7 @@
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.backends.arm._passes.rewrite_conv2d_pass import RewriteConv2dPass
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass

from executorch.backends.transforms.utils import create_constant_placeholder
from executorch.exir import ExportedProgram
@@ -42,7 +42,7 @@ class DecomposeCumsumPass(ArmPass):
And the convolution is applied over dimension H.
"""

_passes_required_after: Set[Type[ExportPass]] = {RewriteConv2dPass}
_passes_required_after: Set[Type[ExportPass]] = {RewriteConvPass}

def __init__(self, exported_program: ExportedProgram) -> None:
super().__init__()
backends/arm/_passes/decompose_int16_activation_conv2d_pass.py → backends/arm/_passes/decompose_int16_activation_conv_pass.py
@@ -4,27 +4,37 @@
# LICENSE file in the root directory of this source tree.


from typing import cast, Set, Type
from typing import cast, Sequence, Set, Type

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.quant_args import QuantArgs

from executorch.backends.arm.tosa.specification import get_context_spec
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class DecomposeConv2dWithInt16ActivationPass(ArmPass):
class DecomposeConvWithInt16ActivationPass(ArmPass):
"""
This pass decomposes a convolution with input dtype int16 and bias
into a convolution without bias followed by an addition of the bias
since the TOSA op requires the bias to be int48 which is hard to represent
into a convolution without bias followed by an addition of the bias.
We also reshape the 1D bias to [1, C, 1, …] so it broadcasts along the channel
dimension. Since the TOSA op requires the bias to be int48, which is hard to represent
in torch, we instead rescale the int48 output to int16 and add the bias in int16.
"""

def __init__(self) -> None:
super().__init__()

_passes_required_after: Set[Type[ExportPass]] = set()

def bias_view_shape(
self, bias: torch.Tensor, activation_rank: int
) -> Sequence[int]:
# reshape bias to match convolution output rank so addition broadcasts over channels
return [1, bias.shape[0], *([1] * (activation_rank - 2))]

def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.aten.convolution.default:
return super().call_operator(op, args, kwargs, meta)
@@ -37,18 +37,22 @@ def call_operator(self, op, args, kwargs, meta):
if args[2] is None:
return super().call_operator(op, args, kwargs, meta)

if args[0].data.dtype == torch.int8:
return super().call_operator(op, args, kwargs, meta)
elif args[0].data.dtype == torch.int16:
if not tosa_spec.support_extension("int16"):
raise ValueError(
"int16 activation for convolution requires TOSA int16 extension"
)
else:
activation_tensor = args[0].data
activation_rank = activation_tensor.dim()

if activation_rank not in (4, 5) or activation_tensor.dtype != torch.int16:
return super().call_operator(op, args, kwargs, meta)

# convolution with bias and activation is int16
bias = args[2]
if not tosa_spec.support_extension("int16"):
raise ValueError(
"int16 activation for convolution requires TOSA int16 extension"
)

# convolution with bias and activation is int16 (expected activation rank enforced above)
# The bias is assumed to be quantized with the same quantization parameters as
# the output of the convolution
bias_arg = args[2]
bias_data = bias_arg.data

no_bias_args = list(args)
no_bias_args[2] = None
@@ -63,7 +77,7 @@
# reshape the tensor to the same rank as the convolution output to add the bias to the channels
channel_bias = super().call_operator(
exir_ops.edge.aten.view_copy.default,
(bias, [1, len(bias.data), 1, 1]),
(bias_arg, self.bias_view_shape(bias_data, activation_rank)),
{},
new_meta,
)
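For reference, a minimal standalone sketch (assuming only the bias_view_shape body shown above, restated here outside the pass) of the broadcast shapes this pass produces for 4-D and 5-D convolution outputs:

import torch

def bias_view_shape(bias: torch.Tensor, activation_rank: int) -> list[int]:
    # Same expression as in the pass: [1, C, 1, ...] so the bias broadcasts
    # over the channel dimension of the convolution output.
    return [1, bias.shape[0], *([1] * (activation_rank - 2))]

bias = torch.zeros(8, dtype=torch.int32)
assert bias_view_shape(bias, 4) == [1, 8, 1, 1]     # conv2d output (N, C, H, W)
assert bias_view_shape(bias, 5) == [1, 8, 1, 1, 1]  # conv3d output (N, C, D, H, W)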
backends/arm/_passes/rewrite_conv2d_pass.py → backends/arm/_passes/rewrite_conv_pass.py
@@ -12,6 +12,7 @@

from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
expand_around_channel,
get_first_fake_tensor,
get_param_tensor,
is_buffer,
@@ -29,7 +30,7 @@
from torch.export.graph_signature import InputKind


class RewriteConv2dPass(ArmPass):
class RewriteConvPass(ArmPass):
"""Rewrites aten.convolution to tosa.CONV2D or tosa.DEPTHWISE_CONV2D."""

def __init__(self, exported_program: torch.export.ExportedProgram):
@@ -88,11 +89,27 @@ def _is_depthwise_conv2d(self, node: torch.fx.Node) -> bool:
or node.target != exir_ops.edge.aten.convolution.default
):
return False
input_tensor = get_first_fake_tensor(node.all_input_nodes[0])
if len(input_tensor.shape) != 4:
return False
groups = node.args[-1]
in_channels = get_first_fake_tensor(node.all_input_nodes[0]).shape[1]
in_channels = input_tensor.shape[1]
out_channels = get_first_fake_tensor(node).shape[1]
return (in_channels == groups) and (out_channels % in_channels) == 0

def _is_conv3d(self, rank, groups) -> bool:
if rank == 5:
# A Conv3D is considered depthwise if Group == InChannels and
# Group * N == OutChannels, where N is a positive integer.
# Currently we do not support depthwise or grouped conv3d.
# @TODO Add grouped/depthwise conv3d support or reject in partitioner.
if groups != 1:
raise RuntimeError(
"CONV3D with groups != 1 is not supported in the Arm backend."
)
return True
return False

def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None:
"""Reshape the weights for depthwise convolution such that when serialized to TOSA,
the weights are in the format [H, W, in_channels, m_length] where
@@ -201,7 +218,7 @@ def insert_output_rescale(self, graph_module, node):
)
return rescale_node

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
modified = False
for node in graph_module.graph.nodes:
if (
@@ -224,30 +241,40 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
group,
) = node.args

pad = [val for val in pad for _ in (0, 1)]
input_fake_tensor = get_first_fake_tensor(x)
weight_fake_tensor = get_first_fake_tensor(weight)
# Adjust the pad value if needed to meet the
# strict convolution output shape calculation.
pad[1] = self._adjust_pad_if_needed(
input_fake_tensor.shape[2],
weight_fake_tensor.shape[2],
stride[0],
pad[1],
dilation[0],
)
pad[3] = self._adjust_pad_if_needed(
input_fake_tensor.shape[3],
weight_fake_tensor.shape[3],
stride[1],
pad[3],
dilation[1],
)
input_shape = input_fake_tensor.shape
weight_shape = weight_fake_tensor.shape
spatial_rank = len(input_shape) - 2
stride_list = expand_around_channel(stride, spatial_rank)
dilation_list = expand_around_channel(dilation, spatial_rank)
pad_list = expand_around_channel(pad, spatial_rank)

pad_attr: list[int] = []
for value in pad_list:
pad_attr.extend([value, value]) # duplicate pad before/after per axis

for axis_index in range(spatial_rank):
pad_index = axis_index * 2 + 1 # adjust trailing pad entry
pad_attr[pad_index] = self._adjust_pad_if_needed(
input_shape[axis_index + 2],
weight_shape[axis_index + 2],
stride_list[axis_index],
pad_attr[pad_index],
dilation_list[axis_index],
)

stride = tuple(stride_list)
dilation = tuple(dilation_list)
pad = pad_attr

has_bias = bias is not None
if not has_bias:
bias = self._add_bias(graph_module, node, weight)

if self._is_depthwise_conv2d(node):
if self._is_conv3d(len(input_shape), group):
target_op = exir_ops.backend.tosa.CONV3D.default
elif self._is_depthwise_conv2d(node):
target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default
# If there are any TOSA.DEPTHWISE_CONV2D nodes using the weights, we've already reshaped them.
if all(user.target != target_op for user in weight.users):
Expand All @@ -256,7 +283,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
else:
target_op = exir_ops.backend.tosa.CONV2D.default

conv2d_args = (
conv_args = (
x,
weight,
bias,
Expand All @@ -272,7 +299,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
tosa_op = create_node(
graph=graph_module.graph,
op_target=target_op,
args=conv2d_args,
args=conv_args,
from_node=node,
inherit_qparams=True,
)
@@ -281,7 +308,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
input_fake_tensor,
weight_fake_tensor,
bias_fake_tensor,
*conv2d_args[3:],
*conv_args[3:],
)

if (
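To make the new padding layout concrete, a small sketch (illustrative values, not taken from the PR) of how pad_attr is built in the rewritten call() and which entries the _adjust_pad_if_needed call touches for a 2-D convolution:

# One pad value per spatial axis (H, W) after expand_around_channel.
pad_list = [1, 2]
spatial_rank = len(pad_list)

pad_attr = []
for value in pad_list:
    # Duplicate into (before, after) pairs: [H_before, H_after, W_before, W_after].
    pad_attr.extend([value, value])
assert pad_attr == [1, 1, 2, 2]

# Only the trailing ("after") entry of each axis is adjusted, i.e. indices 1 and 3,
# so the strict convolution output-shape calculation still holds.
assert [axis * 2 + 1 for axis in range(spatial_rank)] == [1, 3]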