Skip to content

Commit 9632137

Browse files
Arm backend: Add docstrings to operator_support/tosa_supported_operators (#15777)
This file contains the base class `SupportedTOSAOperatorCheck`, which all operator support check inherit from. The documentation for it is applicable to those derived classes as well. The derived classes should document members etc. that are not part of the base class. cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com> Co-authored-by: Zingo Andersen <zingo.andersen@arm.com>
1 parent b0eba38 commit 9632137

File tree

1 file changed

+166
-40
lines changed

1 file changed

+166
-40
lines changed

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 166 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Provide operator-support checks and registries for TOSA delegation.
6+
7+
Define a base check class, a registry/dispatcher, and several generic checks
8+
used by the TOSA partitioner to decide if FX nodes are eligible for delegation.
9+
10+
"""
511

612

713
import itertools
@@ -46,31 +52,65 @@
4652

4753

4854
class SupportedTOSAOperatorCheck(OperatorSupportBase):
49-
"""
50-
Supported OP for TOSA lowering
55+
"""Provide a base operator-support check for TOSA lowering.
56+
57+
Subclasses should implement :py:meth:`is_node_tosa_supported` and declare
58+
the class attributes below to indicate what they support.
59+
60+
Attributes:
61+
targets (list[OpOverload]): Operator overloads supported by this
62+
check.
63+
tosa_specs (list[TosaSpecification]): TOSA specs where the check is
64+
applicable.
65+
5166
"""
5267

5368
def __init__(self, tosa_spec: TosaSpecification, reporter: WhyNoPartitionReporter):
69+
"""Initialize the check with a TOSA spec and reporter.
70+
71+
Args:
72+
tosa_spec (TosaSpecification): Active TOSA specification.
73+
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
74+
75+
"""
5476
self.tosa_spec = tosa_spec
5577
self.reporter = reporter
5678

57-
# Should be populated by subclass implementation
79+
# Class attributes populated by subclasses
5880
tosa_specs: list[TosaSpecification] = []
5981
targets: list[str] = []
6082

6183
@final
6284
def is_node_supported(
6385
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
6486
) -> bool:
87+
"""Return True if the node matches targets and subclass-specific checks.
88+
89+
Args:
90+
submodules (typing.Mapping[str, torch.nn.Module]): Exported program
91+
modules.
92+
node (fx.Node): Node to evaluate.
93+
94+
Returns:
95+
bool: True if both the target and TOSA-specific checks pass.
96+
97+
"""
6598
if node.target not in self.targets:
6699
return False
67100
return self.is_node_tosa_supported(node, self.tosa_spec)
68101

69102
def is_node_tosa_supported(
70103
self, node: fx.Node, tosa_spec: TosaSpecification
71104
) -> bool:
72-
"""
73-
Checks if the fx.Node node is lowerable using the TOSA specification defined by tosa_spec.
105+
"""Check if the node is lowerable under the given TOSA spec.
106+
107+
Args:
108+
node (fx.Node): FX node to check.
109+
tosa_spec (TosaSpecification): Active TOSA specification.
110+
111+
Returns:
112+
bool: True if supported; otherwise, False.
113+
74114
"""
75115
raise NotImplementedError("SupportedTOSAOperatorCheck must be extended.")
76116

@@ -83,10 +123,15 @@ def is_node_tosa_supported(
83123

84124

85125
def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
86-
"""
87-
Decorator to mark a subclass implmentation of SupportedTosaOperatorCheck
88-
to be registered for checking if a torch.fx.Node is lowerable given
89-
a TOSA specification.
126+
"""Register an operator-support checker for one or more TOSA specs.
127+
128+
Decorate subclasses of :py:class:`SupportedTOSAOperatorCheck` so they are
129+
picked up by the factory and partitioner for the specs declared in their
130+
``tosa_specs`` class attribute.
131+
132+
Args:
133+
checker (Type[SupportedTOSAOperatorCheck]): Checker class to register.
134+
90135
"""
91136
for tosa_spec in checker.tosa_specs:
92137
_tosa_spec_support[tosa_spec].append(checker)
@@ -96,6 +141,15 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
96141
def get_registered_tosa_support_checks(
97142
tosa_spec: TosaSpecification,
98143
) -> list[Type[SupportedTOSAOperatorCheck]]:
144+
"""Get all registered operator-support checkers for a given spec.
145+
146+
Args:
147+
tosa_spec (TosaSpecification): TOSA spec to query.
148+
149+
Returns:
150+
list[Type[SupportedTOSAOperatorCheck]]: Registered checker classes.
151+
152+
"""
99153
if tosa_spec not in _tosa_spec_support:
100154
raise RuntimeError(
101155
f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support.keys())}"
@@ -110,8 +164,21 @@ def tosa_support_factory(
110164
reporter: WhyNoPartitionReporter,
111165
additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
112166
) -> OperatorSupportBase:
113-
"""Generates an OperatorSupport class depending on the given `tosa_spec`.
114-
Additional checks can be supplied to avoid partitioning additional nodes.
167+
"""Create an OperatorSupport composite for a TOSA spec.
168+
169+
Combine profile-specific positive checks, registered operator checks, and
170+
negative checks into a single :py:class:`OperatorSupportBase` chain.
171+
172+
Args:
173+
tosa_spec (TosaSpecification): Active TOSA specification.
174+
exported_program (ExportedProgram): Program context for checks.
175+
reporter (WhyNoPartitionReporter): Reporter for rejections.
176+
additional_checks (Optional[Sequence[OperatorSupportBase]]): Extra
177+
negative checks to apply.
178+
179+
Returns:
180+
OperatorSupportBase: Composite checker for the given spec.
181+
115182
"""
116183
# Postive checks: Add nodes to partitioning
117184
positive_checks: list[OperatorSupportBase] = [
@@ -158,37 +225,45 @@ def tosa_support_factory(
158225

159226

160227
class TOSAProINTSupportList(OperatorSupportBase):
161-
"""
162-
TOSA_PRO_INT_SupportList:
163-
Ops supported in INT profile via native TOSA ops, decomposition/transformation, pre-compute, or TableOps.
164-
Note that ops supported via pre-quantization decompositions are not included here.
228+
"""Provide the INT profile support list for TOSA.
229+
230+
TOSA_PRO_INT_SupportList enumerates ops supported in the INT profile via
231+
native TOSA ops, decompositions, pre-compute steps, or TableOps.
232+
233+
Note:
234+
Ops supported via pre-quantization decompositions are not included
235+
here.
236+
165237
"""
166238

167239
def is_node_supported(
168240
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
169241
) -> bool:
170-
242+
"""Return True if the node is in the INT profile support list."""
171243
return node.op == "call_function" and node.target in TOSA_PRO_INT_SupportList
172244

173245

174246
class TOSAProFPSupportList(OperatorSupportBase):
175-
"""
176-
TOSA_PRO_FP_SupportList:
177-
Ops supported in FP profile via native TOSA ops, decomposition/transformation, pre-compute
247+
"""Provide the FP profile support list for TOSA.
248+
249+
Includes ops supported natively, via decomposition/transformation, and pre-
250+
compute.
251+
178252
"""
179253

180254
def is_node_supported(
181255
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
182256
) -> bool:
183-
257+
"""Return True if the node is in the FP profile support list."""
184258
return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList
185259

186260

187261
class CheckProperQuantization(OperatorSupportBase):
188-
"""
189-
For targeted nodes, check that it has been quantized as expected. In most cases this means that a pair of quantize
190-
and dequantize nodes surrounds the node. This is neccessary for table operators and operators that need to rescale
191-
activations.
262+
"""Ensure targeted nodes are properly quantized.
263+
264+
Verify that a pair of quantize/dequantize nodes surrounds targeted ops so
265+
rescaling and table operators behave correctly.
266+
192267
"""
193268

194269
targeted_ops = (
@@ -214,13 +289,28 @@ class CheckProperQuantization(OperatorSupportBase):
214289
)
215290

216291
def __init__(self, reporter: WhyNoPartitionReporter):
292+
"""Initialize the check with a reporter."""
217293
self.reporter = reporter
218294

219295
def _is_matmul_node_supported(
220296
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
221297
):
222-
"""
223-
Find the matmul source partition containing this node and check that all its inputs and outputs are quantized.
298+
"""Check quantization for decomposed matmul partitions.
299+
300+
Handles an edge case where the quantized pipeline
301+
`dq -> torch.matmul/operator.matmul -> q` decomposes into
302+
`dq -> expand -> view -> aten.mm -> view -> q`.
303+
304+
Args:
305+
submodules (Mapping[str, torch.nn.Module]): Map of child modules to
306+
inspect for matmul partitions.
307+
node (fx.Node): Node that should belong to a quantized matmul
308+
partition.
309+
310+
Returns:
311+
bool: True if the matched partition uses quantized inputs and
312+
outputs.
313+
224314
"""
225315
for graph_module in submodules.values():
226316
graph_module = typing.cast(fx.GraphModule, graph_module)
@@ -269,6 +359,12 @@ def _is_matmul_node_supported(
269359
def is_node_supported(
270360
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
271361
) -> bool:
362+
"""Return True if the node passes constant-cast and multi-output checks.
363+
364+
Ensures decomposition-specific matmul partitions keep quantized inputs
365+
and outputs.
366+
367+
"""
272368
output_quantized = False
273369
input_quantized = False
274370
if node.target not in self.targeted_ops:
@@ -320,21 +416,22 @@ def is_node_supported(
320416

321417

322418
class CheckInt64InputsAndOutputs(OperatorSupportBase):
323-
"""TOSA does not support int64 tensors so in general, ops with int64 inputs or outputs should not be partitioned.
324-
There are however some exceptions:
325-
- Nodes with int64 output can be partitioned if they are constant, within int32,
326-
and all users cast to something else. In this case, the int64 tensor can safely be cast to int32 AOT.
327-
- Nodes with int64 output can be partitioned if all users are getitem with non-int64 output.
328-
In this case, there are multiple outputs and the int64 ones are not used.
329-
- Nodes with int64 inputs can be partitioned if the inputs are constant placeholders, or constant
330-
ops fulfilling the criteria above.
331-
Note that we don't check placeholders here, they are partitioned based on whether their users are partitioned
332-
or not.
419+
"""Reject general int64 tensors while allowing safe exceptions.
420+
421+
Exceptions are:
422+
- Nodes with contant int64 output within int32 range that are cast away
423+
from int64 by all users.
424+
- Int64 output where all users are getitem nodes with non-int64 outputs.
425+
In this case there are multiple outputs and the int64 output is unused.
426+
- Nodes where all inputs are int64 constant placeholders or constant ops
427+
that fulfill the above exceptions.
428+
333429
"""
334430

335431
def __init__(
336432
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
337433
):
434+
"""Initialize the check with program context and reporter."""
338435
self.input_names = [
339436
spec.arg.name
340437
for spec in exported_program.graph_signature.input_specs
@@ -356,6 +453,7 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
356453
def is_node_supported(
357454
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
358455
) -> bool:
456+
"""Return True when int64 use is absent or safe per exceptions."""
359457
if is_submodule_node(node):
360458
return True
361459
vals = node.meta["val"]
@@ -427,16 +525,23 @@ def is_node_supported(
427525

428526

429527
class CheckFloat64Inputs(OperatorSupportBase):
528+
"""Reject nodes with float64 inputs.
529+
530+
Useful as a negative check for specs that do not allow float64.
531+
532+
"""
430533

431534
def __init__(
432535
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
433536
):
537+
"""Initialize the check with program context and reporter."""
434538
self.reporter = reporter
435539
super().__init__()
436540

437541
def is_node_supported(
438542
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
439543
) -> bool:
544+
"""Return True if no float64 inputs are present."""
440545
if is_submodule_node(node):
441546
return True
442547
for input_node in (
@@ -455,16 +560,18 @@ def is_node_supported(
455560

456561

457562
class RankCheck(OperatorSupportBase):
458-
"""Makes sure that nodes with input or output tensors with rank > max_rank are not partitioned"""
563+
"""Reject nodes with rank greater than ``max_rank``."""
459564

460565
def __init__(self, reporter: WhyNoPartitionReporter, max_rank: int):
566+
"""Initialize the check with a reporter and maximum rank."""
461567
self.reporter = reporter
462568
self.max_rank = max_rank
463569
super().__init__()
464570

465571
def is_node_supported(
466572
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
467573
) -> bool:
574+
"""Return True if input/output tensor ranks are within the limit."""
468575
if is_submodule_node(node):
469576
return True
470577
input_nodes = (
@@ -509,20 +616,36 @@ def is_node_supported(
509616

510617

511618
class CondSupported(OperatorSupportBase):
512-
"""Checks whether the cond operator, and it's submodule args, should be partitioned."""
619+
"""Check whether cond operator and submodule args should be partitioned.
620+
621+
Applies control-flow extension constraints before allowing delegation.
622+
623+
"""
513624

514625
def __init__(
515626
self,
516627
exported_program: ExportedProgram,
517628
tosa_spec: TosaSpecification,
518629
reporter: WhyNoPartitionReporter,
519630
):
631+
"""Initialize conditional support checks for TOSA delegation.
632+
633+
Args:
634+
exported_program (ExportedProgram): Program containing the cond
635+
submodules to inspect.
636+
tosa_spec (TosaSpecification): TOSA specification used to validate
637+
supported operators.
638+
reporter (WhyNoPartitionReporter): Reporter that records rejection
639+
reasons for unsupported nodes.
640+
641+
"""
520642
self.exported_program = exported_program
521643
self.reporter = reporter
522644
self.tosa_spec = tosa_spec
523645
super().__init__()
524646

525647
def _fully_partitioned(self, submodule: fx.GraphModule) -> bool:
648+
"""Check whether all call_function nodes share one delegation tag."""
526649
partition_tag = None
527650
for submodule_node in submodule.graph.nodes:
528651
if submodule_node.op == "call_function":
@@ -546,8 +669,10 @@ def _fully_partitioned(self, submodule: fx.GraphModule) -> bool:
546669
return True
547670

548671
def _cond_submodules_fully_partitioned(self, node: fx.Node) -> bool:
549-
"""Returns whether the submodule arguments to a cond node were fully partitioned.
550-
Updates "val" meta of the submodules if they are.
672+
"""Determine whether cond submodules were fully partitioned.
673+
674+
Update the "val" meta of the submodules when they are partitioned.
675+
551676
"""
552677
cond_submodules = (
553678
(
@@ -570,6 +695,7 @@ def _cond_submodules_fully_partitioned(self, node: fx.Node) -> bool:
570695
def is_node_supported( # noqa: C901
571696
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
572697
) -> bool:
698+
"""Check whether a cond node and its submodules are partitionable."""
573699
if is_submodule_node(node):
574700
if not isinstance(self.tosa_spec, Tosa_1_00):
575701
self.reporter.report_reject(

0 commit comments

Comments
 (0)