Skip to content

Commit 4014597

Browse files
Arm backend: Replace ±inf and FP limit values with ±255.0 (#15976)
- Rename ReplaceInfValuesPass -> ReplaceInfAndLimitValuesPass - Extend to rewrite torch.finfo(torch.float32).{min,max} to ±255.0 to avoid generating TOSA RESCALE shifts < 2 (invalid per spec) Signed-off-by: Yufeng Shi <yufeng.shi@arm.com>
1 parent d886373 commit 4014597

File tree

4 files changed

+13
-10
lines changed

4 files changed

+13
-10
lines changed

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,5 +117,7 @@
117117
from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa
118118
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
119119
from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
120-
from .replace_inf_values_pass import ReplaceInfValuesPass # noqa # usort: skip
120+
from .replace_inf_and_limit_values_pass import ( # noqa # usort: skip
121+
ReplaceInfAndLimitValuesPass,
122+
)
121123
from .arm_pass_manager import ArmPassManager # noqa # usort: skip

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@
9999
RemoveGetItemPass,
100100
RemoveGraphAssertsPass,
101101
RemoveNoopPass,
102-
ReplaceInfValuesPass,
102+
ReplaceInfAndLimitValuesPass,
103103
ReplaceScalarWithTensorByProfilePass,
104104
RewriteConv2dPass,
105105
RewriteMatmulPass,
@@ -385,7 +385,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
385385
# Postprocessing passes
386386
self.add_passes(
387387
[
388-
ReplaceInfValuesPass(),
388+
ReplaceInfAndLimitValuesPass(),
389389
DecomposeMaskedFillPass() if not self.tosa_spec.is_U55_subset else None,
390390
]
391391
)

backends/arm/_passes/replace_inf_values_pass.py renamed to backends/arm/_passes/replace_inf_and_limit_values_pass.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
from executorch.exir.pass_base import ExportPass, PassResult
1515

1616

17-
class ReplaceInfValuesPass(ArmPass):
17+
class ReplaceInfAndLimitValuesPass(ArmPass):
1818
"""
19-
Due to limitation in Quantizer, we need to change inf/-inf to more quantizable values.
19+
Rewrites +inf/-inf and floating-point limit values (e.g., torch.finfo(...).min/max)
20+
to quantization-friendly values (±255 by default), improving quantizer stability
21+
(notably for attention mask paths).
2022
"""
2123

2224
_passes_required_after: Set[Type[ExportPass]] = set()
@@ -34,12 +36,12 @@ def call(self, graph_module: torch.fx.GraphModule):
3436
for node in graph_module.graph.nodes:
3537
arg_list = list(node.args)
3638
for index, arg in enumerate(arg_list):
37-
if arg == float("-inf"):
39+
if arg == float("-inf") or arg == torch.finfo(torch.float32).min:
3840
modified = True
39-
arg_list[index] = -255
40-
elif arg == float("inf"):
41+
arg_list[index] = -255.0
42+
elif arg == float("inf") or arg == torch.finfo(torch.float32).max:
4143
modified = True
42-
arg_list[index] = +255
44+
arg_list[index] = +255.0
4345
node.args = tuple(arg_list)
4446

4547
if modified:

backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ class TestCLIPTextModelWithProjection:
4242

4343
ops_after_partitioner_INT = {
4444
"executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
45-
"executorch_exir_dialects_edge__ops_aten_full_default": 1,
4645
"executorch_exir_dialects_edge__ops_aten_index_select_default": 1,
4746
"executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 1,
4847
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,

0 commit comments

Comments
 (0)