Skip to content

Commit d886373

Browse files
authored
Fix duplication bug in serialize_for_executorch
Differential Revision: D88415390 Pull Request resolved: #16087
1 parent 1a7b0dc commit d886373

File tree

3 files changed

+366
-83
lines changed

3 files changed

+366
-83
lines changed

exir/_serialize/_serialize.py

Lines changed: 63 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,41 @@
66

77
# pyre-strict
88

9-
from typing import Dict, Optional, Set, Tuple
9+
from typing import Dict, Optional, Tuple
1010

1111
from executorch.exir._serialize._cord import Cord
12-
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
13-
14-
from executorch.exir._serialize._program import PTEFile, serialize_pte_binary
15-
from executorch.exir._serialize.data_serializer import (
16-
DataEntry,
17-
DataPayload,
18-
DataSerializer,
12+
from executorch.exir._serialize._named_data_store import (
13+
NamedDataStore,
14+
NamedDataStoreOutput,
1915
)
20-
16+
from executorch.exir._serialize._program import PTEFile, serialize_pte_binary
17+
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
2118
from executorch.exir.capture._config import ExecutorchBackendConfig
2219
from executorch.exir.emit import EmitterOutput
23-
from executorch.exir.schema import Tensor, TensorDataLocation
20+
from executorch.exir.schema import Program, Tensor, TensorDataLocation
2421
from executorch.exir.tensor_layout import TensorLayout
2522

2623

24+
def _extract_external_tensor_layouts(program: Program) -> Dict[str, TensorLayout]:
25+
# Find all external tensors and organize into {fqn: TensorLayout}.
26+
fqn_to_tensor_layout: Dict[str, TensorLayout] = {}
27+
for plan in program.execution_plan:
28+
for evalue in plan.values:
29+
if isinstance(evalue.val, Tensor):
30+
tensor = evalue.val
31+
if (
32+
tensor.extra_tensor_info is not None
33+
and tensor.extra_tensor_info.fully_qualified_name is not None
34+
and tensor.extra_tensor_info.location is TensorDataLocation.EXTERNAL
35+
):
36+
fqn_to_tensor_layout[
37+
# pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`
38+
tensor.extra_tensor_info.fully_qualified_name
39+
] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order)
40+
41+
return fqn_to_tensor_layout
42+
43+
2744
def serialize_for_executorch(
2845
emitter_output: EmitterOutput,
2946
config: ExecutorchBackendConfig,
@@ -41,11 +58,14 @@ def serialize_for_executorch(
4158
):
4259
# Create a separate NamedDataStoreOutput with only pte_data; exclude
4360
# external_data, which shouldn't be serialized with the PTE file.
44-
pte_named_data = NamedDataStoreOutput(
45-
buffers=named_data_store.buffers,
46-
pte_data=named_data_store.pte_data,
47-
external_data={},
48-
)
61+
if len(named_data_store.external_data) == 0:
62+
pte_named_data = named_data_store
63+
else:
64+
pte_named_data = NamedDataStoreOutput(
65+
buffers=named_data_store.buffers,
66+
pte_data=named_data_store.pte_data,
67+
external_data={},
68+
)
4969
pte: Cord = serialize_pte_binary(
5070
pte_file=PTEFile(
5171
program=emitter_output.program,
@@ -58,85 +78,45 @@ def serialize_for_executorch(
5878
delegate_alignment=config.delegate_alignment,
5979
)
6080

61-
# Serialize PTD files.
62-
ptd_files: Dict[str, Cord] = {}
63-
64-
# Find all external tensors and organize into {fqn: TensorLayout}.
65-
fqn_to_tensor_layout: Dict[str, TensorLayout] = {}
66-
for plan in emitter_output.program.execution_plan:
67-
for evalue in plan.values:
68-
if isinstance(evalue.val, Tensor):
69-
tensor = evalue.val
70-
if (
71-
tensor.extra_tensor_info is not None
72-
and tensor.extra_tensor_info.fully_qualified_name is not None
73-
and tensor.extra_tensor_info.location is TensorDataLocation.EXTERNAL
74-
):
75-
fqn_to_tensor_layout[
76-
# pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`
77-
tensor.extra_tensor_info.fully_qualified_name
78-
] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order)
79-
80-
if len(fqn_to_tensor_layout) == 0 and (
81+
# Early exit if no external weights.
82+
if len(emitter_output.external_constant_map) == 0 and (
8183
named_data_store is None or len(named_data_store.external_data) == 0
8284
):
83-
return pte, ptd_files
85+
return pte, {}
8486

85-
# Consolidate tensors and opaque data with the same external tag so they
86-
# can be saved to the same PTD.
87-
all_external_tags: Set[str] = set()
88-
if named_data_store is not None and len(named_data_store.external_data) > 0:
89-
assert (
90-
len(named_data_store.buffers) > 0
91-
), "External data exists, but there are no buffers provided."
92-
all_external_tags = set(named_data_store.external_data.keys())
87+
ptd_files: Dict[str, Cord] = {}
9388

94-
if len(fqn_to_tensor_layout) > 0:
95-
# emitter_output.external_constant_map contains the mapping from
96-
# {file: {fqn: index into external_constant_buffer}}
97-
# Contains the locations of the tensor buffers, and must be non-empty
98-
# if there are external tensors to serialize.
99-
assert (
100-
emitter_output.external_constant_map is not None
101-
), "External exists, but there are no buffers provided."
102-
all_external_tags = all_external_tags | set(
103-
emitter_output.external_constant_map.keys()
104-
)
89+
# If there are no emitter constants, use named_data_store directly.
90+
if len(emitter_output.external_constant_map) == 0:
91+
for tag in named_data_store.external_data.keys():
92+
ptd_files[tag] = data_serializer.serialize(
93+
DataPayload(
94+
buffers=named_data_store.buffers,
95+
named_data=named_data_store.external_data[tag],
96+
)
97+
)
98+
return pte, ptd_files
10599

106-
for tag in all_external_tags:
107-
buffers = []
108-
key_to_data_entry: Dict[str, DataEntry] = {}
109-
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
110-
fqn_to_index = emitter_output.external_constant_map.get(tag, {})
111-
# Create a DataEntry for each external tensor.
100+
# Collect external weights from emitter output and merge them.
101+
fqn_to_tensor_layout = _extract_external_tensor_layouts(emitter_output.program)
102+
updated_named_data_store = NamedDataStore()
103+
# Add tensor constants from the emitter to the NamedDataStore.
104+
for tag, fqn_to_index in emitter_output.external_constant_map.items():
112105
for fqn, index in fqn_to_index.items():
113-
assert fqn in fqn_to_tensor_layout
114-
assert fqn not in key_to_data_entry # fqn must be unique
115-
key_to_data_entry[fqn] = DataEntry(
116-
buffer_index=len(buffers),
117-
alignment=config.constant_tensor_alignment,
106+
updated_named_data_store.add_named_data(
107+
fqn,
108+
emitter_output.external_constant_buffer[index],
118109
tensor_layout=fqn_to_tensor_layout[fqn],
110+
external_tag=tag,
119111
)
120-
buffers.append(emitter_output.external_constant_buffer[index])
121-
122-
# Extract external data from named_data_store.
123-
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
124-
blob_to_data_entry = named_data_store.external_data.get(tag, {})
125-
for key, data_entry in blob_to_data_entry.items():
126-
assert key not in key_to_data_entry # key must be unique
127-
key_to_data_entry[key] = DataEntry(
128-
buffer_index=len(buffers),
129-
alignment=data_entry.alignment,
130-
tensor_layout=data_entry.tensor_layout,
131-
)
132-
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
133-
buffers.append(named_data_store.buffers[data_entry.buffer_index])
112+
updated_named_data_store.merge_named_data_store(named_data_store)
134113

135-
# Serialize into PTD file.
114+
# Serialize each tag into a PTD file.
115+
for tag in updated_named_data_store.external_data.keys():
136116
ptd_files[tag] = data_serializer.serialize(
137117
DataPayload(
138-
buffers=buffers,
139-
named_data=key_to_data_entry,
118+
buffers=updated_named_data_store.buffers,
119+
named_data=updated_named_data_store.external_data[tag],
140120
)
141121
)
142122

exir/_serialize/test/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,17 @@ python_unittest(
4343
"//executorch/exir/_serialize:lib",
4444
],
4545
)
46+
47+
python_unittest(
48+
name = "test_serialize",
49+
srcs = [
50+
"test_serialize.py",
51+
],
52+
deps = [
53+
"//executorch/exir:schema",
54+
"//executorch/exir/_serialize:lib",
55+
"//executorch/exir/emit:lib",
56+
"//executorch/exir/capture:config",
57+
"//executorch/exir/tests:lib",
58+
],
59+
)

0 commit comments

Comments
 (0)