Skip to content

Commit b0eba38

Browse files
martinlsmMartin Lindströmzingo
authored
Arm backend: Add PrintGraphModuleCodePass (#15774)
Inspecting a module between passes that operate on it is often useful for debugging purposes. This patch adds a pass called `PrintGraphModuleCodePass` that prints the graph module's code in its current state. Compared to the already existing `VisualizePass`, `PrintGraphModuleCodePass` enables quicker feedback when the module is small enough to be visualized in a text-based print. Example output from the pass: ``` [arm_pass_manager.py:305] def forward(self, x, y): x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec) remainder = torch.ops.aten.remainder.Scalar(x, 0.25); x = None return pytree.tree_unflatten((remainder,), self._out_spec) ``` cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Co-authored-by: Martin Lindström <Martin.Lindstroem@arm.com> Co-authored-by: Zingo Andersen <zingo.andersen@arm.com>
1 parent 0bb9d18 commit b0eba38

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

backends/arm/_passes/_debug_passes.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import inspect
7+
import os
68
from typing import Set, Type
79

810
import torch
911
from executorch.backends.arm._passes import ArmPass
1012
from executorch.devtools.visualization.visualization_utils import visualize_graph
1113
from executorch.exir import ExportedProgram
1214
from executorch.exir.pass_base import ExportPass, PassResult
15+
from torch.fx import GraphModule
1316

1417

1518
class VisualizePass(ArmPass):
@@ -26,3 +29,30 @@ def __init__(self, exported_program: ExportedProgram) -> None:
2629
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
2730
visualize_graph(graph_module, self.exported_program)
2831
return PassResult(graph_module, False)
32+
33+
34+
class PrintGraphModuleCodePass(ArmPass):
35+
"""
36+
This pass prints the graph module's code to stdout for debugging purposes.
37+
38+
Example output:
39+
40+
[arm_pass_manager.py:305]
41+
def forward(self, x, y):
42+
x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
43+
remainder = torch.ops.aten.remainder.Scalar(x, 0.25); x = None
44+
return pytree.tree_unflatten((remainder,), self._out_spec)
45+
"""
46+
47+
_passes_required_after: Set[Type[ExportPass]] = set()
48+
49+
def __init__(self, label: str | None = None):
50+
super().__init__()
51+
caller_frame = inspect.stack()[1]
52+
origin = f"{os.path.basename(caller_frame.filename)}:{caller_frame.lineno}"
53+
self.label = f"[{label}]" if label is not None else f"[{origin}]"
54+
55+
def call(self, graph_module: GraphModule) -> PassResult:
56+
gm_code = graph_module.code.strip()
57+
print(f"\n{self.label}\n{gm_code}")
58+
return PassResult(graph_module, False)

0 commit comments

Comments
 (0)