Skip to content

Commit da6306f

Browse files
authored
Arm backend: Support mixed TOSA profiles (#15773)
### Summary Add initial support for handling mixed TOSA INT and FP profiles. ### Test plan Tested through existing and new unit tests. cc @freddan80 @zingo @oscarandersson8218 @digantdesai --------- Signed-off-by: Per Åstrand <per.astrand@arm.com>
1 parent 100093f commit da6306f

File tree

5 files changed

+225
-23
lines changed

5 files changed

+225
-23
lines changed

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,11 @@
3939
TOSA_PRO_FP_SupportList,
4040
TOSA_PRO_INT_SupportList,
4141
)
42-
from executorch.backends.arm.tosa import TosaSpecification
43-
from executorch.backends.arm.tosa.specification import Tosa_1_00
42+
from executorch.backends.arm.tosa.specification import (
43+
Tosa_1_00,
44+
TosaSpecification,
45+
TosaSpecMapping,
46+
)
4447
from executorch.exir import ExportedProgram
4548
from executorch.exir.backend.utils import WhyNoPartitionReporter
4649
from executorch.exir.dialects._ops import ops as exir_ops
@@ -116,10 +119,9 @@ def is_node_tosa_supported(
116119

117120

118121
# container for all SupportedTosaOperatorCheck classes
119-
_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = {
120-
TosaSpecification.create_from_string("TOSA-1.0+INT"): [],
121-
TosaSpecification.create_from_string("TOSA-1.0+FP"): [],
122-
}
122+
_tosa_spec_support: TosaSpecMapping[Type[SupportedTOSAOperatorCheck]] = (
123+
TosaSpecMapping()
124+
)
123125

124126

125127
def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
@@ -134,7 +136,7 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
134136
135137
"""
136138
for tosa_spec in checker.tosa_specs:
137-
_tosa_spec_support[tosa_spec].append(checker)
139+
_tosa_spec_support.add(tosa_spec, checker)
138140
return checker
139141

140142

@@ -150,12 +152,12 @@ def get_registered_tosa_support_checks(
150152
list[Type[SupportedTOSAOperatorCheck]]: Registered checker classes.
151153
152154
"""
153-
if tosa_spec not in _tosa_spec_support:
155+
checks = _tosa_spec_support.get(tosa_spec)
156+
if not checks:
154157
raise RuntimeError(
155-
f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support.keys())}"
158+
f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support._mapping.keys())}"
156159
)
157-
158-
return _tosa_spec_support[tosa_spec]
160+
return checks
159161

160162

161163
def tosa_support_factory(

backends/arm/operators/node_visitor.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
"""
1313

1414
import json
15+
16+
import logging
1517
from typing import Any, Dict, List, Optional
1618

1719
import torch
@@ -20,9 +22,14 @@
2022
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
2123
from executorch.backends.arm.debug.schema import DebugHook
2224
from executorch.backends.arm.tosa.mapping import TosaArg
23-
from executorch.backends.arm.tosa.specification import TosaSpecification
25+
from executorch.backends.arm.tosa.specification import (
26+
TosaSpecification,
27+
TosaSpecMapping,
28+
)
2429
from torch.export import ExportedProgram
2530

31+
logger = logging.getLogger(__name__)
32+
2633

2734
class NodeVisitor:
2835
"""Provide a visitor pattern to lower edge IR to TOSA.
@@ -125,23 +132,31 @@ def define_node(
125132

126133

127134
# container for all node visitors
128-
_node_visitor_dicts: Dict[TosaSpecification, Dict] = {
129-
TosaSpecification.create_from_string("TOSA-1.0+INT"): {},
130-
TosaSpecification.create_from_string("TOSA-1.0+FP"): {},
131-
}
135+
_node_visitor_tuples: TosaSpecMapping[tuple] = TosaSpecMapping()
132136

133137

134138
def register_node_visitor(visitor):
135139
"""Register a concrete ``NodeVisitor`` class for its TOSA specs."""
136140
for tosa_spec in visitor.tosa_specs:
137-
_node_visitor_dicts[tosa_spec][visitor.target] = visitor
141+
# Try to get the tuple to make sure it doesn't exist
142+
visitor_tuple = (visitor.target, visitor)
143+
try:
144+
tuples = _node_visitor_tuples.get(tosa_spec)
145+
except KeyError:
146+
tuples = []
147+
148+
if visitor_tuple in tuples:
149+
raise RuntimeError(
150+
f"Visitor for target {visitor.target} already registered for TOSA spec {tosa_spec}"
151+
)
152+
_node_visitor_tuples.add(tosa_spec, visitor_tuple)
138153
return visitor
139154

140155

141156
def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
142157
"""Return a mapping from target names to visitor instances for a spec."""
143-
node_visitors = {}
144-
tosa_spec = None
158+
node_visitors: Dict[str, NodeVisitor] = {}
159+
tosa_spec: TosaSpecification | None = None
145160
for arg in args:
146161
if isinstance(arg, TosaSpecification):
147162
tosa_spec = arg
@@ -150,7 +165,13 @@ def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
150165
if tosa_spec is None:
151166
raise RuntimeError("No TOSA specification supplied.")
152167

153-
for target, visitor in _node_visitor_dicts[tosa_spec].items():
168+
# Use the mapping to get the dict for this spec (handles combined specs)
169+
for node_visitor_tuple in _node_visitor_tuples.get(tosa_spec):
170+
target, visitor = node_visitor_tuple
171+
if target in node_visitors and node_visitors[target].__class__ != visitor:
172+
logger.warning(
173+
f"Target {target} already has visitor class {node_visitors[target].__class__.__name__} registered, overwriting with class: {visitor.__name__}"
174+
)
154175
node_visitors[target] = visitor(*args)
155176

156177
return node_visitors

backends/arm/operators/op_index_tensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from torch.fx import Node
2525

2626

27-
@register_node_visitor
2827
class CommonIndexTensorVisitor(NodeVisitor):
2928
target = "aten.index.Tensor"
3029

backends/arm/test/misc/test_tosa_spec.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
import unittest
77

8-
from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification
8+
from executorch.backends.arm.tosa.specification import (
9+
Tosa_1_00,
10+
TosaSpecification,
11+
TosaSpecMapping,
12+
)
913

1014
from parameterized import parameterized # type: ignore[import-untyped]
1115

@@ -66,3 +70,100 @@ def test_correct_string_representation(self, version_string: str):
6670
tosa_spec = TosaSpecification.create_from_string(version_string)
6771
assert isinstance(tosa_spec, Tosa_1_00)
6872
assert f"{tosa_spec}" == version_string
73+
74+
75+
class TestTosaSpecMapping(unittest.TestCase):
76+
"""Tests the TosaSpecMapping class"""
77+
78+
def test_mapping(self):
79+
mapping = TosaSpecMapping()
80+
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A")
81+
# check that the mapping is correct
82+
vals = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT"))
83+
84+
assert vals == ["A"]
85+
assert len(vals) == 1
86+
87+
def test_mapping_multiple(self):
88+
mapping = TosaSpecMapping()
89+
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A")
90+
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "B")
91+
# check that the mapping is correct
92+
vals = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT"))
93+
94+
assert vals == ["A", "B"]
95+
assert len(vals) == 2
96+
97+
def test_mapping_different_profiles(self):
98+
mapping = TosaSpecMapping()
99+
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A")
100+
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "B")
101+
# check that the mapping is correct
102+
vals_int = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT"))
103+
vals_fp = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+FP"))
104+
105+
assert vals_int == ["A"]
106+
assert vals_fp == ["B"]
107+
assert len(vals_int) == 1
108+
assert len(vals_fp) == 1
109+
110+
def test_mapping_different_profiles_combined_consumer(self):
111+
mapping = TosaSpecMapping()
112+
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A")
113+
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "B")
114+
# check that the mapping is correct
115+
combined_vals = mapping.get(
116+
TosaSpecification.create_from_string("TOSA-1.0+INT+FP")
117+
)
118+
119+
assert "A" in combined_vals
120+
assert "B" in combined_vals
121+
assert len(combined_vals) == 2
122+
123+
def test_mapping_no_spec(self):
124+
mapping = TosaSpecMapping()
125+
with self.assertRaises(KeyError):
126+
mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT"))
127+
128+
def test_mapping_no_values_for_spec(self):
129+
mapping = TosaSpecMapping()
130+
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "A")
131+
with self.assertRaises(KeyError):
132+
mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT"))
133+
134+
def test_spec_with_different_profiles(self):
135+
mapping = TosaSpecMapping()
136+
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "A")
137+
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "B")
138+
# check that the mapping is correct
139+
vals_int = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT"))
140+
vals_fp = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+FP"))
141+
vals_int_fp = mapping.get(
142+
TosaSpecification.create_from_string("TOSA-1.0+INT+FP")
143+
)
144+
145+
assert vals_fp == ["A"]
146+
assert vals_int == ["B"]
147+
assert len(vals_int) == 1
148+
assert len(vals_fp) == 1
149+
assert len(vals_int_fp) == 2
150+
151+
def test_combined_profiles(self):
152+
mapping = TosaSpecMapping()
153+
with self.assertRaises(ValueError):
154+
# Don't allow multiple profiles in a single spec
155+
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT+FP"), "A")
156+
157+
def test_spec_add_with_extension(self):
158+
mapping = TosaSpecMapping()
159+
with self.assertRaises(ValueError):
160+
mapping.add(
161+
TosaSpecification.create_from_string("TOSA-1.0.0+INT+int16"), "A"
162+
)
163+
164+
def test_spec_non_canonical_key(self):
165+
mapping = TosaSpecMapping()
166+
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A")
167+
168+
val = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT+u55"))
169+
assert val == ["A"]

backends/arm/tosa/specification.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,71 @@
1212

1313
import contextvars
1414
import re
15-
from typing import List
15+
from typing import Dict, Generic, List, Set, TypeVar
1616

1717
from packaging.version import Version
1818

19+
T = TypeVar("T")
20+
21+
22+
class TosaSpecMapping(Generic[T]):
23+
def __init__(self):
24+
self._mapping: Dict[TosaSpecification, List[T]] = {}
25+
26+
def add(self, spec: "TosaSpecification", value: T) -> None:
27+
"""
28+
Adds a value to the mapping for the given TOSA specification.
29+
The specification is normalized to its canonical form, which means that
30+
only the version and profiles are considered, without extensions.
31+
This allows for grouping of values under the same TOSA specification
32+
regardless of the extensions they may have.
33+
"""
34+
35+
if spec.is_U55_subset or spec.extensions:
36+
raise ValueError(
37+
f"TosaSpecMapping does not support extensions, got: {spec}"
38+
)
39+
40+
if isinstance(spec, Tosa_1_00) and len(spec.profiles) > 1:
41+
raise ValueError(
42+
f"TosaSpecMapping does not support multiple profiles, got: {spec}"
43+
)
44+
45+
norm_spec = spec._canonical_key()
46+
if norm_spec not in self._mapping:
47+
self._mapping[norm_spec] = []
48+
self._mapping[norm_spec].append(value)
49+
50+
@staticmethod
51+
def _get_base_specs(spec: "TosaSpecification") -> List["TosaSpecification"]:
52+
# Handles combined TOSA-1.0+FP+INT, etc.
53+
if isinstance(spec, Tosa_1_00):
54+
profiles: Set[str] = set(spec.profiles)
55+
if profiles == {"FP", "INT"}:
56+
version = spec.version
57+
return [
58+
TosaSpecification.create_from_string(f"TOSA-{version}+FP"),
59+
TosaSpecification.create_from_string(f"TOSA-{version}+INT"),
60+
]
61+
return [spec]
62+
63+
def get(self, spec: "TosaSpecification") -> List[T]:
64+
"""
65+
Returns a list of values associated with the given TOSA specification.
66+
The specification is normalized to its canonical form, which means that
67+
only the version and profiles are considered, without extensions.
68+
"""
69+
70+
base_specs = self._get_base_specs(spec)
71+
result: List[T] = []
72+
for base in base_specs:
73+
norm_base = base._canonical_key()
74+
result.extend(self._mapping.get(norm_base, []))
75+
if len(result) == 0:
76+
raise KeyError(f"No values found for TOSA specification: {spec}")
77+
78+
return result # Do not deduplicate with set(), as values may be unhashable
79+
1980

2081
class TosaSpecification:
2182
"""Represent a TOSA specification.
@@ -34,6 +95,7 @@ class TosaSpecification:
3495

3596
version: Version
3697
is_U55_subset: bool
98+
extensions: List[str]
3799

38100
def support_integer(self) -> bool:
39101
"""Return True if integer operations are supported."""
@@ -52,6 +114,7 @@ def __init__(self, version: Version, extras: List[str]):
52114
53115
"""
54116
self.version = version
117+
self.extensions = []
55118

56119
self.is_U55_subset = "u55" in extras
57120
if self.is_U55_subset:
@@ -89,6 +152,12 @@ def create_from_string(repr: str) -> "TosaSpecification":
89152

90153
raise ValueError(f"Failed to parse TOSA specification representation: {repr}")
91154

155+
def _canonical_key(self) -> "TosaSpecification":
156+
"""
157+
Returns a new TosaSpecification instance with only version and profiles (no extensions).
158+
"""
159+
raise NotImplementedError
160+
92161

93162
class Tosa_1_00(TosaSpecification):
94163
"""Provide TOSA 1.00 profile and extension semantics.
@@ -232,6 +301,16 @@ def support_extension(self, extension: str) -> bool:
232301

233302
return False
234303

304+
def _canonical_key(self) -> "Tosa_1_00":
305+
"""
306+
Returns a new Tosa_1_00 instance with only major.minor version and profiles (no extensions).
307+
Patch version is set to zero for normalization.
308+
"""
309+
from packaging.version import Version
310+
311+
norm_version = Version(f"{self.version.major}.{self.version.minor}.0")
312+
return Tosa_1_00(norm_version, self.profiles.copy())
313+
235314

236315
class TosaLoweringContext:
237316
"""Manage the TOSA specification context for lowering.

0 commit comments

Comments
 (0)