22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
5+ """Provide operator-support checks and registries for TOSA delegation.
6+
7+ Define a base check class, a registry/dispatcher, and several generic checks
8+ used by the TOSA partitioner to decide if FX nodes are eligible for delegation.
9+
10+ """
511
612
713import itertools
4652
4753
4854class SupportedTOSAOperatorCheck (OperatorSupportBase ):
49- """
50- Supported OP for TOSA lowering
55+ """Provide a base operator-support check for TOSA lowering.
56+
57+ Subclasses should implement :py:meth:`is_node_tosa_supported` and declare
58+ the class attributes below to indicate what they support.
59+
60+ Attributes:
61+ targets (list[OpOverload]): Operator overloads supported by this
62+ check.
63+ tosa_specs (list[TosaSpecification]): TOSA specs where the check is
64+ applicable.
65+
5166 """
5267
5368 def __init__ (self , tosa_spec : TosaSpecification , reporter : WhyNoPartitionReporter ):
69+ """Initialize the check with a TOSA spec and reporter.
70+
71+ Args:
72+ tosa_spec (TosaSpecification): Active TOSA specification.
73+ reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
74+
75+ """
5476 self .tosa_spec = tosa_spec
5577 self .reporter = reporter
5678
57- # Should be populated by subclass implementation
79+ # Class attributes populated by subclasses
5880 tosa_specs : list [TosaSpecification ] = []
5981 targets : list [str ] = []
6082
6183 @final
6284 def is_node_supported (
6385 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
6486 ) -> bool :
87+ """Return True if the node matches targets and subclass-specific checks.
88+
89+ Args:
90+ submodules (typing.Mapping[str, torch.nn.Module]): Exported program
91+ modules.
92+ node (fx.Node): Node to evaluate.
93+
94+ Returns:
95+ bool: True if both the target and TOSA-specific checks pass.
96+
97+ """
6598 if node .target not in self .targets :
6699 return False
67100 return self .is_node_tosa_supported (node , self .tosa_spec )
68101
69102 def is_node_tosa_supported (
70103 self , node : fx .Node , tosa_spec : TosaSpecification
71104 ) -> bool :
72- """
73- Checks if the fx.Node node is lowerable using the TOSA specification defined by tosa_spec.
105+ """Check if the node is lowerable under the given TOSA spec.
106+
107+ Args:
108+ node (fx.Node): FX node to check.
109+ tosa_spec (TosaSpecification): Active TOSA specification.
110+
111+ Returns:
112+ bool: True if supported; otherwise, False.
113+
74114 """
75115 raise NotImplementedError ("SupportedTOSAOperatorCheck must be extended." )
76116
@@ -83,10 +123,15 @@ def is_node_tosa_supported(
83123
84124
85125def register_tosa_support_check (checker : Type [SupportedTOSAOperatorCheck ]):
86- """
87- Decorator to mark a subclass implmentation of SupportedTosaOperatorCheck
88- to be registered for checking if a torch.fx.Node is lowerable given
89- a TOSA specification.
126+ """Register an operator-support checker for one or more TOSA specs.
127+
128+ Decorate subclasses of :py:class:`SupportedTOSAOperatorCheck` so they are
129+ picked up by the factory and partitioner for the specs declared in their
130+ ``tosa_specs`` class attribute.
131+
132+ Args:
133+ checker (Type[SupportedTOSAOperatorCheck]): Checker class to register.
134+
90135 """
91136 for tosa_spec in checker .tosa_specs :
92137 _tosa_spec_support [tosa_spec ].append (checker )
@@ -96,6 +141,15 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
96141def get_registered_tosa_support_checks (
97142 tosa_spec : TosaSpecification ,
98143) -> list [Type [SupportedTOSAOperatorCheck ]]:
144+ """Get all registered operator-support checkers for a given spec.
145+
146+ Args:
147+ tosa_spec (TosaSpecification): TOSA spec to query.
148+
149+ Returns:
150+ list[Type[SupportedTOSAOperatorCheck]]: Registered checker classes.
151+
152+ """
99153 if tosa_spec not in _tosa_spec_support :
100154 raise RuntimeError (
101155 f"TOSA specification not valid: { tosa_spec } not in { list (_tosa_spec_support .keys ())} "
@@ -110,8 +164,21 @@ def tosa_support_factory(
110164 reporter : WhyNoPartitionReporter ,
111165 additional_checks : Optional [Sequence [OperatorSupportBase ]] = None ,
112166) -> OperatorSupportBase :
113- """Generates an OperatorSupport class depending on the given `tosa_spec`.
114- Additional checks can be supplied to avoid partitioning additional nodes.
167+ """Create an OperatorSupport composite for a TOSA spec.
168+
169+ Combine profile-specific positive checks, registered operator checks, and
170+ negative checks into a single :py:class:`OperatorSupportBase` chain.
171+
172+ Args:
173+ tosa_spec (TosaSpecification): Active TOSA specification.
174+ exported_program (ExportedProgram): Program context for checks.
175+ reporter (WhyNoPartitionReporter): Reporter for rejections.
176+ additional_checks (Optional[Sequence[OperatorSupportBase]]): Extra
177+ negative checks to apply.
178+
179+ Returns:
180+ OperatorSupportBase: Composite checker for the given spec.
181+
115182 """
116183 # Postive checks: Add nodes to partitioning
117184 positive_checks : list [OperatorSupportBase ] = [
@@ -158,37 +225,45 @@ def tosa_support_factory(
158225
159226
160227class TOSAProINTSupportList (OperatorSupportBase ):
161- """
162- TOSA_PRO_INT_SupportList:
163- Ops supported in INT profile via native TOSA ops, decomposition/transformation, pre-compute, or TableOps.
164- Note that ops supported via pre-quantization decompositions are not included here.
228+ """Provide the INT profile support list for TOSA.
229+
230+ TOSA_PRO_INT_SupportList enumerates ops supported in the INT profile via
231+ native TOSA ops, decompositions, pre-compute steps, or TableOps.
232+
233+ Note:
234+ Ops supported via pre-quantization decompositions are not included
235+ here.
236+
165237 """
166238
167239 def is_node_supported (
168240 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
169241 ) -> bool :
170-
242+ """Return True if the node is in the INT profile support list."""
171243 return node .op == "call_function" and node .target in TOSA_PRO_INT_SupportList
172244
173245
174246class TOSAProFPSupportList (OperatorSupportBase ):
175- """
176- TOSA_PRO_FP_SupportList:
177- Ops supported in FP profile via native TOSA ops, decomposition/transformation, pre-compute
247+ """Provide the FP profile support list for TOSA.
248+
249+ Includes ops supported natively, via decomposition/transformation, and pre-
250+ compute.
251+
178252 """
179253
180254 def is_node_supported (
181255 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
182256 ) -> bool :
183-
257+ """Return True if the node is in the FP profile support list."""
184258 return node .op == "call_function" and node .target in TOSA_PRO_FP_SupportList
185259
186260
187261class CheckProperQuantization (OperatorSupportBase ):
188- """
189- For targeted nodes, check that it has been quantized as expected. In most cases this means that a pair of quantize
190- and dequantize nodes surrounds the node. This is neccessary for table operators and operators that need to rescale
191- activations.
262+ """Ensure targeted nodes are properly quantized.
263+
264+ Verify that a pair of quantize/dequantize nodes surrounds targeted ops so
265+ rescaling and table operators behave correctly.
266+
192267 """
193268
194269 targeted_ops = (
@@ -214,13 +289,28 @@ class CheckProperQuantization(OperatorSupportBase):
214289 )
215290
216291 def __init__ (self , reporter : WhyNoPartitionReporter ):
292+ """Initialize the check with a reporter."""
217293 self .reporter = reporter
218294
219295 def _is_matmul_node_supported (
220296 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
221297 ):
222- """
223- Find the matmul source partition containing this node and check that all its inputs and outputs are quantized.
298+ """Check quantization for decomposed matmul partitions.
299+
300+ Handles an edge case where the quantized pipeline
301+ `dq -> torch.matmul/operator.matmul -> q` decomposes into
302+ `dq -> expand -> view -> aten.mm -> view -> q`.
303+
304+ Args:
305+ submodules (Mapping[str, torch.nn.Module]): Map of child modules to
306+ inspect for matmul partitions.
307+ node (fx.Node): Node that should belong to a quantized matmul
308+ partition.
309+
310+ Returns:
311+ bool: True if the matched partition uses quantized inputs and
312+ outputs.
313+
224314 """
225315 for graph_module in submodules .values ():
226316 graph_module = typing .cast (fx .GraphModule , graph_module )
@@ -269,6 +359,12 @@ def _is_matmul_node_supported(
269359 def is_node_supported (
270360 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
271361 ) -> bool :
362+ """Return True if the node passes constant-cast and multi-output checks.
363+
364+ Ensures decomposition-specific matmul partitions keep quantized inputs
365+ and outputs.
366+
367+ """
272368 output_quantized = False
273369 input_quantized = False
274370 if node .target not in self .targeted_ops :
@@ -320,21 +416,22 @@ def is_node_supported(
320416
321417
322418class CheckInt64InputsAndOutputs (OperatorSupportBase ):
323- """TOSA does not support int64 tensors so in general, ops with int64 inputs or outputs should not be partitioned .
324- There are however some exceptions:
325- - Nodes with int64 output can be partitioned if they are constant, within int32,
326- and all users cast to something else. In this case, the int64 tensor can safely be cast to int32 AOT.
327- - Nodes with int64 output can be partitioned if all users are getitem with non-int64 output .
328- In this case, there are multiple outputs and the int64 ones are not used .
329- - Nodes with int64 inputs can be partitioned if the inputs are constant placeholders, or constant
330- ops fulfilling the criteria above.
331- Note that we don't check placeholders here, they are partitioned based on whether their users are partitioned
332- or not.
419+ """Reject general int64 tensors while allowing safe exceptions .
420+
421+ Exceptions are:
422+ - Nodes with contant int64 output within int32 range that are cast away
423+ from int64 by all users.
424+ - Int64 output where all users are getitem nodes with non- int64 outputs .
425+ In this case there are multiple outputs and the int64 output is unused.
426+ - Nodes where all inputs are int64 constant placeholders or constant ops
427+ that fulfill the above exceptions.
428+
333429 """
334430
335431 def __init__ (
336432 self , exported_program : ExportedProgram , reporter : WhyNoPartitionReporter
337433 ):
434+ """Initialize the check with program context and reporter."""
338435 self .input_names = [
339436 spec .arg .name
340437 for spec in exported_program .graph_signature .input_specs
@@ -356,6 +453,7 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
356453 def is_node_supported (
357454 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
358455 ) -> bool :
456+ """Return True when int64 use is absent or safe per exceptions."""
359457 if is_submodule_node (node ):
360458 return True
361459 vals = node .meta ["val" ]
@@ -427,16 +525,23 @@ def is_node_supported(
427525
428526
429527class CheckFloat64Inputs (OperatorSupportBase ):
528+ """Reject nodes with float64 inputs.
529+
530+ Useful as a negative check for specs that do not allow float64.
531+
532+ """
430533
431534 def __init__ (
432535 self , exported_program : ExportedProgram , reporter : WhyNoPartitionReporter
433536 ):
537+ """Initialize the check with program context and reporter."""
434538 self .reporter = reporter
435539 super ().__init__ ()
436540
437541 def is_node_supported (
438542 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
439543 ) -> bool :
544+ """Return True if no float64 inputs are present."""
440545 if is_submodule_node (node ):
441546 return True
442547 for input_node in (
@@ -455,16 +560,18 @@ def is_node_supported(
455560
456561
457562class RankCheck (OperatorSupportBase ):
458- """Makes sure that nodes with input or output tensors with rank > max_rank are not partitioned """
563+ """Reject nodes with rank greater than ``max_rank``. """
459564
460565 def __init__ (self , reporter : WhyNoPartitionReporter , max_rank : int ):
566+ """Initialize the check with a reporter and maximum rank."""
461567 self .reporter = reporter
462568 self .max_rank = max_rank
463569 super ().__init__ ()
464570
465571 def is_node_supported (
466572 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
467573 ) -> bool :
574+ """Return True if input/output tensor ranks are within the limit."""
468575 if is_submodule_node (node ):
469576 return True
470577 input_nodes = (
@@ -509,20 +616,36 @@ def is_node_supported(
509616
510617
511618class CondSupported (OperatorSupportBase ):
512- """Checks whether the cond operator, and it's submodule args, should be partitioned."""
619+ """Check whether cond operator and submodule args should be partitioned.
620+
621+ Applies control-flow extension constraints before allowing delegation.
622+
623+ """
513624
514625 def __init__ (
515626 self ,
516627 exported_program : ExportedProgram ,
517628 tosa_spec : TosaSpecification ,
518629 reporter : WhyNoPartitionReporter ,
519630 ):
631+ """Initialize conditional support checks for TOSA delegation.
632+
633+ Args:
634+ exported_program (ExportedProgram): Program containing the cond
635+ submodules to inspect.
636+ tosa_spec (TosaSpecification): TOSA specification used to validate
637+ supported operators.
638+ reporter (WhyNoPartitionReporter): Reporter that records rejection
639+ reasons for unsupported nodes.
640+
641+ """
520642 self .exported_program = exported_program
521643 self .reporter = reporter
522644 self .tosa_spec = tosa_spec
523645 super ().__init__ ()
524646
525647 def _fully_partitioned (self , submodule : fx .GraphModule ) -> bool :
648+ """Check whether all call_function nodes share one delegation tag."""
526649 partition_tag = None
527650 for submodule_node in submodule .graph .nodes :
528651 if submodule_node .op == "call_function" :
@@ -546,8 +669,10 @@ def _fully_partitioned(self, submodule: fx.GraphModule) -> bool:
546669 return True
547670
548671 def _cond_submodules_fully_partitioned (self , node : fx .Node ) -> bool :
549- """Returns whether the submodule arguments to a cond node were fully partitioned.
550- Updates "val" meta of the submodules if they are.
672+ """Determine whether cond submodules were fully partitioned.
673+
674+ Update the "val" meta of the submodules when they are partitioned.
675+
551676 """
552677 cond_submodules = (
553678 (
@@ -570,6 +695,7 @@ def _cond_submodules_fully_partitioned(self, node: fx.Node) -> bool:
570695 def is_node_supported ( # noqa: C901
571696 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
572697 ) -> bool :
698+ """Check whether a cond node and its submodules are partitionable."""
573699 if is_submodule_node (node ):
574700 if not isinstance (self .tosa_spec , Tosa_1_00 ):
575701 self .reporter .report_reject (
0 commit comments