Commit 14482e5

Arm backend: Fix incorrect qparams propagation (#15698)
For some passes that run after q/dq folding, metadata was naively copied to all new nodes created during decomposition, including the folded quantization parameters. This can cause problems if those nodes are later computed ahead of time. This is solved by introducing a new optional argument to create_node, inherit_qparams, which defaults to False: qparams are inherited from the original node only when it is explicitly set to True. For call_operator passes, the metadata dict is copied and modified in each pass instead. The patch also removes a few duplicate passes in ArmPassManager.

cc @freddan80 @per @zingo @digantdesai

Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
Co-authored-by: Zingo Andersen <zingo.andersen@arm.com>
1 parent 9632137 commit 14482e5

16 files changed: +112 additions, -40 deletions

backends/arm/_passes/arm_pass_manager.py

Lines changed: 0 additions & 2 deletions

@@ -190,7 +190,6 @@ def _tosa_pipeline(
 
         # Node transformation passes (post q/dq folding)
 
-        self.add_pass(DecomposeExpm1Pass())
         self.add_pass(DecomposeLogitPass())
         self.add_pass(DecomposeMaskedFill())
         self.add_pass(DecomposeRoundPass())
@@ -209,7 +208,6 @@ def _tosa_pipeline(
         self.add_pass(DecomposeSinhPass())
         self.add_pass(DecomposeSignPass())
         self.add_pass(DecomposeFloorDividePass())
-        self.add_pass(DecomposeDivTensorModePass())
         self.add_pass(DecomposeGeluPass())
         self.add_pass(DecomposeAddSubAlphaPass())
         self.add_pass(DecomposeGroupedConv())

backends/arm/_passes/arm_pass_utils.py

Lines changed: 9 additions & 0 deletions

@@ -114,6 +114,7 @@ def create_node(
     quantize: bool = False,
     q_params: Optional[tuple] = None,
     from_node: Optional[torch.fx.Node] = None,
+    inherit_qparams: bool = False,
 ):
     """
     Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
@@ -132,6 +133,14 @@ def create_node(
         keys = from_node.meta.keys()
         for key in keys:
             new_meta[key] = from_node.meta[key]
+        if not inherit_qparams:
+            if "input_qparams" in new_meta:
+                new_meta["input_qparams"] = {}
+            if "output_qparams" in new_meta:
+                new_meta["output_qparams"] = {}
+    elif inherit_qparams:
+        raise ValueError("inherit_qparams is only valid when from_node is given")
+
     old_stack_trace = new_meta.get("stack_trace", "")
     new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
     node.meta = new_meta
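
As a usage illustration, here is a minimal sketch of how a graph-mutating pass can use the new argument. The helper function, op choices, and shapes below are illustrative assumptions, not code from this commit; the point is that only the node taking over the original op's role opts in to inheriting the folded qparams.

# Sketch only: decompose_with_helper and its arguments are hypothetical.
import torch

from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops


def decompose_with_helper(graph: torch.fx.Graph, node: torch.fx.Node, new_shape: list):
    """Replace `node` with a reshape helper followed by the actual replacement op."""
    with graph.inserting_before(node):
        # Helper reshape: inherit_qparams defaults to False, so the metadata
        # copied from `node` gets empty "input_qparams"/"output_qparams".
        reshaped = create_node(
            graph,
            exir_ops.edge.aten.view_copy.default,
            args=(node.args[0], new_shape),
            from_node=node,
        )
        # The node that replaces the original op keeps its folded qparams.
        replacement = create_node(
            graph,
            node.target,
            args=(reshaped, *node.args[1:]),
            from_node=node,
            inherit_qparams=True,
        )
    node.replace_all_uses_with(replacement)
    graph.erase_node(node)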

backends/arm/_passes/broadcast_args_pass.py

Lines changed: 1 addition & 0 deletions

@@ -63,6 +63,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                     args=(arg, multiples),
                     kwargs={},
                     from_node=node,
+                    inherit_qparams=False,
                 )
                 node.replace_input_with(arg, repeat)
 

backends/arm/_passes/conv1d_unsqueeze_pass.py

Lines changed: 9 additions & 2 deletions

@@ -40,13 +40,17 @@ def call_operator(self, op, args, kwargs, meta):
         if len(stride) != 1:
             return super().call_operator(op, args, kwargs, meta)
 
+        x_meta = meta.copy()
+        x_meta.data["input_qparams"] = {}
+        x_meta.data["output_qparams"] = {}
+
         x = args[0]
         x_unsqueezed_shape = list(x.data.shape) + [1]
         x = super().call_operator(
             exir_ops.edge.aten.view_copy.default,
             (x, x_unsqueezed_shape),
             {},
-            meta,
+            x_meta,
             updated=True,
         )
 
@@ -79,12 +83,15 @@ def call_operator(self, op, args, kwargs, meta):
             exir_ops.edge.aten.convolution.default, new_args, kwargs, meta, updated=True
         )
 
+        x_squeezed_meta = meta.copy()
+        x_squeezed_meta.data["input_qparams"] = {}
+        x_squeezed_meta.data["output_qparams"] = {}
         x_squeezed_shape = list(x.data.shape)[:-1]
         x = super().call_operator(
             exir_ops.edge.aten.view_copy.default,
             (x, x_squeezed_shape),
             {},
-            meta,
+            x_squeezed_meta,
             updated=True,
         )
 

backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py

Lines changed: 18 additions & 5 deletions

@@ -12,7 +12,7 @@
 from executorch.backends.arm._passes.decompose_avg_pool2d import DecomposeAvgPool2d
 
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
+from executorch.exir.pass_base import ExportPass, NodeMetadata
 
 edge_ops = (exir_ops.edge.aten._adaptive_avg_pool2d.default,)
 aten_ops = (torch.ops.aten.adaptive_avg_pool2d.default,)
@@ -60,6 +60,11 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
         # Vela currently only allows a stride in the interval of [1,3] for AvgPool2d.
         # To accommodate this, the AvgPool2d op is applied to pooling regions and the results are concatenated.
 
+        # Slices and concats does not require quantization parameters
+        metadata_dict = dict(meta.data)
+        metadata_dict["input_qparams"] = {}
+        metadata_dict["output_qparams"] = {}
+        meta_with_no_qparams = NodeMetadata(metadata_dict)
         res = []
         for out_i in range(output_size_h):
             row = []
@@ -72,11 +77,15 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
 
                 # Slice along H
                 x_h = super().call_operator(
-                    slice_op, (x, 2, start_h, end_h), kwargs, meta, True
+                    slice_op, (x, 2, start_h, end_h), kwargs, meta_with_no_qparams, True
                 )
                 # Slice along W
                 x_hw = super().call_operator(
-                    slice_op, (x_h, 3, start_w, end_w), kwargs, meta, True
+                    slice_op,
+                    (x_h, 3, start_w, end_w),
+                    kwargs,
+                    meta_with_no_qparams,
+                    True,
                 )
 
                 # Apply avg pooling with kernel size equal to the pooling region
@@ -89,9 +98,13 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
                 row.append(pooled)
 
             # Concatenate row results along width (dim=3)
-            row_tensor = super().call_operator(cat_op, (row, 3), kwargs, meta, True)
+            row_tensor = super().call_operator(
+                cat_op, (row, 3), kwargs, meta_with_no_qparams, True
+            )
             res.append(row_tensor)
 
         # Concatenate all rows along height (dim=2)
-        out = super().call_operator(cat_op, (res, 2), kwargs, meta, True)
+        out = super().call_operator(
+            cat_op, (res, 2), kwargs, meta_with_no_qparams, True
+        )
         return out
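
The call_operator passes follow the same principle by editing a copy of the metadata, as in the hunk above. Below is a condensed sketch of that pattern; the pass and the ops are placeholders chosen only to show the metadata handling, not code from this commit.

# Sketch only: ExampleDecomposition and the chosen ops are hypothetical.
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata


class ExampleDecomposition(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.aten._adaptive_avg_pool2d.default:
            return super().call_operator(op, args, kwargs, meta)

        # Copy the metadata and clear the folded qparams so helper ops are
        # not mistaken for quantized ops by later passes.
        metadata_dict = dict(meta.data)
        metadata_dict["input_qparams"] = {}
        metadata_dict["output_qparams"] = {}
        meta_no_qparams = NodeMetadata(metadata_dict)

        # Helper op created during the decomposition uses the cleared metadata.
        x = super().call_operator(
            exir_ops.edge.aten.slice_copy.Tensor,
            (args[0], 2, 0, 1),
            {},
            meta_no_qparams,
        )
        # The op producing the original node's output keeps `meta` unchanged.
        return super().call_operator(
            exir_ops.edge.aten.avg_pool2d.default, (x, [1, 1]), {}, meta
        )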

backends/arm/_passes/decompose_cumsum_pass.py

Lines changed: 20 additions & 4 deletions

@@ -101,7 +101,13 @@ def call(self, graph_module):
             with graph.inserting_before(node):
                 # Reshape to 4D with
                 view_args = (input_node, conv_shape)
-                view_node = create_node(graph, view_op, args=view_args, from_node=node)
+                view_node = create_node(
+                    graph,
+                    view_op,
+                    args=view_args,
+                    from_node=node,
+                    inherit_qparams=False,
+                )
 
                 conv_args = (
                     view_node,
@@ -114,7 +120,9 @@ def call(self, graph_module):
                     [0],
                     1,
                 )
-                conv_node = create_node(graph, conv_op, args=conv_args, from_node=node)
+                conv_node = create_node(
+                    graph, conv_op, args=conv_args, from_node=node, inherit_qparams=True
+                )
 
                 # The convolution is inserted after quantization, so we need to set our
                 # own quantization parameters for the weights here. However since the
@@ -129,12 +137,20 @@ def call(self, graph_module):
 
                 slice_args = (conv_node, 2, 0, original_shape[dim])
                 slice_node = create_node(
-                    graph, slice_op, args=slice_args, from_node=node
+                    graph,
+                    slice_op,
+                    args=slice_args,
+                    from_node=node,
+                    inherit_qparams=False,
                 )
 
                 view_original_args = (slice_node, original_shape)
                 view_original_node = create_node(
-                    graph, view_op, args=view_original_args, from_node=node
+                    graph,
+                    view_op,
+                    args=view_original_args,
+                    from_node=node,
+                    inherit_qparams=False,
                 )
 
                 # Replace and remove original

backends/arm/_passes/decompose_linear_pass.py

Lines changed: 6 additions & 7 deletions

@@ -55,6 +55,8 @@ def call(self, graph_module):
                     op_target=exir_ops.edge.aten.view_copy.default,
                     args=(input, input_reshaped_shape),
                     kwargs={},
+                    from_node=node,
+                    inherit_qparams=False,
                 )
 
                 # Reshape weights to 4D with shape (Co, Ci, 1, 1)
@@ -63,6 +65,8 @@ def call(self, graph_module):
                     op_target=exir_ops.edge.aten.view_copy.default,
                     args=(weights, weights_reshaped_shape),
                     kwargs={},
+                    from_node=node,
+                    inherit_qparams=False,
                 )
 
                 conv = create_node(
@@ -81,6 +85,7 @@ def call(self, graph_module):
                     ),
                     kwargs={},
                     from_node=node,
+                    inherit_qparams=True,
                 )
 
             with graph_module.graph.inserting_after(conv):
@@ -93,14 +98,8 @@ def call(self, graph_module):
                     args=(conv, list(output_shape)),
                     kwargs={},
                     from_node=node,
+                    inherit_qparams=False,
                 )
-                # Quantization parameters are inherited from original linear node, but
-                # output reshape should use the linear node's output qparams for both input
-                # and output.
-                if "input_qparams" in output.meta:
-                    output.meta["input_qparams"] = output.meta.get(
-                        "output_qparams", None
-                    )
 
             node.replace_all_uses_with(output)
             graph_module.graph.erase_node(node)
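
For reference, a sketch of the qparams layout this now yields when the pass decomposes a quantized linear node. The function and parameter names below are hypothetical stand-ins for the nodes created above, not identifiers from the pass itself.

# Sketch only: a hypothetical check of the expected metadata after the pass.
def check_decomposed_linear_qparams(input_view, conv, output_view, node):
    """The reshape helpers drop the folded qparams; only the conv inherits them."""
    assert input_view.meta["input_qparams"] == {}
    assert input_view.meta["output_qparams"] == {}
    assert conv.meta["input_qparams"] == node.meta["input_qparams"]
    assert conv.meta["output_qparams"] == node.meta["output_qparams"]
    assert output_view.meta["input_qparams"] == {}
    assert output_view.meta["output_qparams"] == {}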

backends/arm/_passes/decompose_maxpool2d_with_dilation.py

Lines changed: 21 additions & 15 deletions

@@ -70,34 +70,40 @@ def call_operator(self, op, args, kwargs, meta):
         ph2 += extra_h * d_h
         pw2 += extra_w * d_w
 
+        meta_with_no_qparams = meta.copy()
+        meta_with_no_qparams.data["output_qparams"] = {}
+        meta_with_no_qparams.data["input_qparams"] = {}
+        meta_with_no_output_qparams = meta.copy()
+        meta_with_no_output_qparams.data["output_qparams"] = {}
+
         # 1) Pad via EXIR edge pad (preserves dtype)
         pad_edge = exir_ops.edge.aten.constant_pad_nd.default
         pads = [pw, pw2, ph, ph2, 0, 0, 0, 0]
         x_pad = super().call_operator(
             pad_edge,
             (x, pads, 0),
             {},
-            meta,
+            meta_with_no_output_qparams,
         )
 
         # 2) Space-to-batch: reshape and permute
         x2 = super().call_operator(
             exir_ops.edge.aten.view_copy.default,
             (x_pad, [N, C, H_pack, d_h, W_pack, d_w]),
             {},
-            meta,
+            meta_with_no_qparams,
         )
         x2 = super().call_operator(
             exir_ops.edge.aten.permute_copy.default,
             (x2, [3, 5, 0, 1, 2, 4]),
             {},
-            meta,
+            meta_with_no_qparams,
         )
         x2 = super().call_operator(
             exir_ops.edge.aten.view_copy.default,
             (x2, [N * d_h * d_w, C, H_pack, W_pack]),
             {},
-            meta,
+            meta_with_no_qparams,
         )
 
         # 3) Core pooling on packed tensor
@@ -120,13 +126,13 @@ def call_operator(self, op, args, kwargs, meta):
                 operator.getitem,
                 (pool_out, 0),
                 {},
-                meta,
+                meta_with_no_qparams,
             )
             indices_proxy = super().call_operator(
                 operator.getitem,
                 (pool_out, 1),
                 {},
-                meta,
+                meta_with_no_qparams,
             )
             pooled_fake, _ = pool_out.data
         else:
@@ -141,20 +147,20 @@ def call_operator(self, op, args, kwargs, meta):
             exir_ops.edge.aten.view_copy.default,
             (pooled_proxy, [d_h, d_w, N, C_out, H_out, W_out]),
             {},
-            meta,
+            meta_with_no_qparams,
         )
         out = super().call_operator(
             exir_ops.edge.aten.permute_copy.default,
             (out, [2, 3, 4, 0, 5, 1]),
             {},
-            meta,
+            meta_with_no_qparams,
         )
         # now flatten back into (N, C, H_out*d_h, W_out*d_w)
         out = super().call_operator(
             exir_ops.edge.aten.view_copy.default,
             (out, [N, C_out, H_out * d_h, W_out * d_w]),
             {},
-            meta,
+            meta_with_no_qparams,
         )
 
         # 5) Final crop
@@ -166,13 +172,13 @@ def call_operator(self, op, args, kwargs, meta):
             exir_ops.edge.aten.slice_copy.Tensor,
             (out, 2, S_top, S_top + H),
             {},
-            meta,
+            meta_with_no_qparams,
         )
         out = super().call_operator(
             exir_ops.edge.aten.slice_copy.Tensor,
             (out, 3, S_left, S_left + W),
             {},
-            meta,
+            meta_with_no_qparams,
         )
 
         if is_with_indices:
@@ -181,7 +187,7 @@ def call_operator(self, op, args, kwargs, meta):
                 exir_ops.edge.aten.view_copy.default,
                 (indices_proxy, [d_h, d_w, N, C_out, H_out, W_out]),
                 {},
-                meta,
+                meta_with_no_qparams,
             )
             idx = super().call_operator(
                 exir_ops.edge.aten.permute_copy.default,
@@ -193,19 +199,19 @@ def call_operator(self, op, args, kwargs, meta):
                 exir_ops.edge.aten.view_copy.default,
                 (idx, [N, C_out, H_out * d_h, W_out * d_w]),
                 {},
-                meta,
+                meta_with_no_qparams,
             )
             idx = super().call_operator(
                 exir_ops.edge.aten.slice_copy.Tensor,
                 (idx, 2, S_top, S_top + H),
                 {},
-                meta,
+                meta_with_no_qparams,
            )
            idx = super().call_operator(
                 exir_ops.edge.aten.slice_copy.Tensor,
                 (idx, 3, S_left, S_left + W),
                 {},
-                meta,
+                meta_with_no_qparams,
             )
             return out, idx
 

backends/arm/_passes/decompose_select.py

Lines changed: 10 additions & 2 deletions

@@ -52,10 +52,18 @@ def call(self, graph_module: torch.fx.GraphModule):
 
             with graph_module.graph.inserting_before(node):
                 slice_node = create_node(
-                    graph_module.graph, slice_op, (input_node, dim, index, index + 1)
+                    graph_module.graph,
+                    slice_op,
+                    (input_node, dim, index, index + 1),
+                    from_node=node,
+                    inherit_qparams=False,
                 )
                 squeeze_node = create_node(
-                    graph_module.graph, squeeze_op, (slice_node, [dim]), from_node=node
+                    graph_module.graph,
+                    squeeze_op,
+                    (slice_node, [dim]),
+                    from_node=node,
+                    inherit_qparams=True,
                 )
 
             node.replace_all_uses_with(squeeze_node)

backends/arm/_passes/decompose_sum_pass.py

Lines changed: 5 additions & 1 deletion

@@ -77,7 +77,11 @@ def call_operator(self, op, args, kwargs, meta):
 
         for dim in dims:
             input_node = super().call_operator(
-                sum_op, (input_node, dim, True), kwargs, meta, updated=True
+                sum_op,
+                (input_node, dim, True),
+                kwargs,
+                meta,
+                updated=True,
             )
 
         if not keepdims:
