Skip to content

Commit 100093f

Browse files
martinlsmMartin Lindström
andauthored
Arm backend: Sort passes in transform_for_annotation_pipeline (#15790)
The passes listed in ArmPassManager.transform_for_annotation_pipeline can feel a bit arbitrary because there is no clearly intended structure or pattern being applied there. Restructure the list into clearly labelled blocks to make the code easier to read and maintain. cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Co-authored-by: Martin Lindström <Martin.Lindstroem@arm.com>
1 parent 14482e5 commit 100093f

File tree

2 files changed

+43
-28
lines changed

2 files changed

+43
-28
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,14 @@ def transform_to_backend_pipeline(
291291
)
292292

293293
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
294+
# Preprocessing passes
295+
294296
self.add_pass(
295297
RemoveGraphAssertsPass()
296298
) # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertation in Graph
299+
300+
# Transformation passes (pre scalar -> tensor)
301+
297302
self.add_pass(ConvertInt64ConstOpsToInt32Pass())
298303
self.add_pass(ConvertInt64OutputOpsToInt32Pass())
299304
self.add_pass(InsertInt32CastsAfterInt64PlaceholdersPass())
@@ -304,12 +309,18 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
304309
self.add_pass(CastBoolToInt8Pass())
305310
self.add_pass(DecomposeSignPass())
306311
self.add_pass(DecomposeAddmmPass())
307-
self.add_pass(ReplaceScalarWithTensorByProfilePass())
308312
self.add_pass(DecomposeRemainderPass())
309313
self.add_pass(DecomposeFloorDividePass())
310314
self.add_pass(DecomposeDivTensorModePass())
311-
self.add_pass(DecomposeAddSubAlphaPass())
315+
316+
# Scalars -> tensors
317+
318+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
312319
self.add_pass(ScalarsToAttributePass())
320+
321+
# Transformation passes (post scalar removal)
322+
323+
self.add_pass(DecomposeAddSubAlphaPass())
313324
self.add_pass(DecomposeGroupNormPass())
314325
self.add_pass(DecomposeLayerNormPass())
315326
self.add_pass(DecomposeVarPass())
@@ -323,16 +334,16 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
323334
self.add_pass(DecomposeSqrtPass())
324335
self.add_pass(DecomposeSiluPass())
325336
self.add_pass(DecomposeAvgPool2d())
326-
327337
if self.tosa_spec.is_U55_subset:
328338
# Numerically stable softmax uses amax which is not supported on Ethos-U55
329339
self.add_pass(DecomposeSoftmaxUnstablePass())
330340
else:
331341
self.add_pass(DecomposeSoftmaxPass())
332-
333342
self.add_pass(ConvertMinMaxPass())
334-
self.add_pass(ReplaceInfValues())
335343

344+
# Postprocessing passes
345+
346+
self.add_pass(ReplaceInfValues())
336347
if not self.tosa_spec.is_U55_subset:
337348
# Uses where which is not supported on Ethos-U55
338349
self.add_pass(DecomposeMaskedFill())

backends/arm/_passes/decompose_remainder_pass.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import Set, Type
6+
from typing import Dict, Set, Type
77

88
import torch
99
from executorch.backends.arm._passes import ArmPass
@@ -17,46 +17,50 @@
1717

1818
Op = OpOverload | EdgeOpOverload
1919

20-
21-
def _get_remainder_decomposition_ops(op: Op) -> tuple[Op, Op, Op]:
22-
"""
23-
Returns the (div_mode_op, mul_op, sub_op) needed to lower the provided
24-
remainder operator. The concrete ops depend on whether the remainder op is
25-
the aten or edge variant.
26-
"""
27-
if op == exir_ops.edge.aten.remainder.Tensor:
28-
return (
29-
exir_ops.edge.aten.div.Tensor_mode,
30-
exir_ops.edge.aten.mul.Tensor,
31-
exir_ops.edge.aten.sub.Tensor,
32-
)
33-
if op == torch.ops.aten.remainder.Tensor:
34-
return (
35-
torch.ops.aten.div.Tensor_mode,
36-
torch.ops.aten.mul.Tensor,
37-
torch.ops.aten.sub.Tensor,
38-
)
39-
raise RuntimeError(f"Can't get remainder decomposition ops for op {op}")
20+
_decomposition_ops: Dict[Op, tuple[Op, Op, Op]] = {
21+
exir_ops.edge.aten.remainder.Scalar: (
22+
exir_ops.edge.aten.div.Tensor_mode,
23+
exir_ops.edge.aten.mul.Scalar,
24+
exir_ops.edge.aten.sub.Tensor,
25+
),
26+
torch.ops.aten.remainder.Tensor: (
27+
torch.ops.aten.div.Tensor_mode,
28+
torch.ops.aten.mul.Tensor,
29+
torch.ops.aten.sub.Tensor,
30+
),
31+
torch.ops.aten.remainder.Scalar: (
32+
torch.ops.aten.div.Tensor_mode,
33+
torch.ops.aten.mul.Scalar,
34+
torch.ops.aten.sub.Tensor,
35+
),
36+
exir_ops.edge.aten.remainder.Tensor: (
37+
exir_ops.edge.aten.div.Tensor_mode,
38+
exir_ops.edge.aten.mul.Tensor,
39+
exir_ops.edge.aten.sub.Tensor,
40+
),
41+
}
4042

4143

4244
class DecomposeRemainderPass(ArmPass):
4345
"""
4446
Decompose the remainder operation into primitive arithmetic:
4547
remainder(x, y) -> x - floor_div(x, y) * y
46-
where floor_div(x, y) == div(x, y, rounding_mode=\"floor\").
48+
where floor_div(x, y) == div(x, y, rounding_mode="floor").
4749
"""
4850

4951
_passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass}
5052

5153
def call_operator(self, op, args, kwargs, meta, updated=False):
5254
supported_ops = (
55+
exir_ops.edge.aten.remainder.Scalar,
5356
exir_ops.edge.aten.remainder.Tensor,
57+
torch.ops.aten.remainder.Scalar,
5458
torch.ops.aten.remainder.Tensor,
5559
)
5660
if op not in supported_ops:
5761
return super().call_operator(op, args, kwargs, meta, updated)
5862

59-
div_op, mul_op, sub_op = _get_remainder_decomposition_ops(op)
63+
div_op, mul_op, sub_op = _decomposition_ops[op]
6064
x, y = args[0], args[1]
6165

6266
floor_div = super().call_operator(

0 commit comments

Comments
 (0)