66
77# pyre-strict
88
9- from typing import Dict , Optional , Set , Tuple
9+ from typing import Dict , Optional , Tuple
1010
1111from executorch .exir ._serialize ._cord import Cord
12- from executorch .exir ._serialize ._named_data_store import NamedDataStoreOutput
13-
14- from executorch .exir ._serialize ._program import PTEFile , serialize_pte_binary
15- from executorch .exir ._serialize .data_serializer import (
16- DataEntry ,
17- DataPayload ,
18- DataSerializer ,
12+ from executorch .exir ._serialize ._named_data_store import (
13+ NamedDataStore ,
14+ NamedDataStoreOutput ,
1915)
20-
16+ from executorch .exir ._serialize ._program import PTEFile , serialize_pte_binary
17+ from executorch .exir ._serialize .data_serializer import DataPayload , DataSerializer
2118from executorch .exir .capture ._config import ExecutorchBackendConfig
2219from executorch .exir .emit import EmitterOutput
23- from executorch .exir .schema import Tensor , TensorDataLocation
20+ from executorch .exir .schema import Program , Tensor , TensorDataLocation
2421from executorch .exir .tensor_layout import TensorLayout
2522
2623
24+ def _extract_external_tensor_layouts (program : Program ) -> Dict [str , TensorLayout ]:
25+ # Find all external tensors and organize into {fqn: TensorLayout}.
26+ fqn_to_tensor_layout : Dict [str , TensorLayout ] = {}
27+ for plan in program .execution_plan :
28+ for evalue in plan .values :
29+ if isinstance (evalue .val , Tensor ):
30+ tensor = evalue .val
31+ if (
32+ tensor .extra_tensor_info is not None
33+ and tensor .extra_tensor_info .fully_qualified_name is not None
34+ and tensor .extra_tensor_info .location is TensorDataLocation .EXTERNAL
35+ ):
36+ fqn_to_tensor_layout [
37+ # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`
38+ tensor .extra_tensor_info .fully_qualified_name
39+ ] = TensorLayout (tensor .scalar_type , tensor .sizes , tensor .dim_order )
40+
41+ return fqn_to_tensor_layout
42+
43+
2744def serialize_for_executorch (
2845 emitter_output : EmitterOutput ,
2946 config : ExecutorchBackendConfig ,
@@ -41,11 +58,14 @@ def serialize_for_executorch(
4158 ):
4259 # Create a separate NamedDataStoreOutput with only pte_data; exclude
4360 # external_data, which shouldn't be serialized with the PTE file.
44- pte_named_data = NamedDataStoreOutput (
45- buffers = named_data_store .buffers ,
46- pte_data = named_data_store .pte_data ,
47- external_data = {},
48- )
61+ if len (named_data_store .external_data ) == 0 :
62+ pte_named_data = named_data_store
63+ else :
64+ pte_named_data = NamedDataStoreOutput (
65+ buffers = named_data_store .buffers ,
66+ pte_data = named_data_store .pte_data ,
67+ external_data = {},
68+ )
4969 pte : Cord = serialize_pte_binary (
5070 pte_file = PTEFile (
5171 program = emitter_output .program ,
@@ -58,85 +78,45 @@ def serialize_for_executorch(
5878 delegate_alignment = config .delegate_alignment ,
5979 )
6080
61- # Serialize PTD files.
62- ptd_files : Dict [str , Cord ] = {}
63-
64- # Find all external tensors and organize into {fqn: TensorLayout}.
65- fqn_to_tensor_layout : Dict [str , TensorLayout ] = {}
66- for plan in emitter_output .program .execution_plan :
67- for evalue in plan .values :
68- if isinstance (evalue .val , Tensor ):
69- tensor = evalue .val
70- if (
71- tensor .extra_tensor_info is not None
72- and tensor .extra_tensor_info .fully_qualified_name is not None
73- and tensor .extra_tensor_info .location is TensorDataLocation .EXTERNAL
74- ):
75- fqn_to_tensor_layout [
76- # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`
77- tensor .extra_tensor_info .fully_qualified_name
78- ] = TensorLayout (tensor .scalar_type , tensor .sizes , tensor .dim_order )
79-
80- if len (fqn_to_tensor_layout ) == 0 and (
81+ # Early exit if no external weights.
82+ if len (emitter_output .external_constant_map ) == 0 and (
8183 named_data_store is None or len (named_data_store .external_data ) == 0
8284 ):
83- return pte , ptd_files
85+ return pte , {}
8486
85- # Consolidate tensors and opaque data with the same external tag so they
86- # can be saved to the same PTD.
87- all_external_tags : Set [str ] = set ()
88- if named_data_store is not None and len (named_data_store .external_data ) > 0 :
89- assert (
90- len (named_data_store .buffers ) > 0
91- ), "External data exists, but there are no buffers provided."
92- all_external_tags = set (named_data_store .external_data .keys ())
87+ ptd_files : Dict [str , Cord ] = {}
9388
94- if len (fqn_to_tensor_layout ) > 0 :
95- # emitter_output.external_constant_map contains the mapping from
96- # {file: {fqn: index into external_constant_buffer}}
97- # Contains the locations of the tensor buffers, and must be non-empty
98- # if there are external tensors to serialize.
99- assert (
100- emitter_output .external_constant_map is not None
101- ), "External exists, but there are no buffers provided."
102- all_external_tags = all_external_tags | set (
103- emitter_output .external_constant_map .keys ()
104- )
89+ # If there are no emitter constants, use named_data_store directly.
90+ if len (emitter_output .external_constant_map ) == 0 :
91+ for tag in named_data_store .external_data .keys ():
92+ ptd_files [tag ] = data_serializer .serialize (
93+ DataPayload (
94+ buffers = named_data_store .buffers ,
95+ named_data = named_data_store .external_data [tag ],
96+ )
97+ )
98+ return pte , ptd_files
10599
106- for tag in all_external_tags :
107- buffers = []
108- key_to_data_entry : Dict [str , DataEntry ] = {}
109- # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
110- fqn_to_index = emitter_output .external_constant_map .get (tag , {})
111- # Create a DataEntry for each external tensor.
100+ # Collect external weights from emitter output and merge them.
101+ fqn_to_tensor_layout = _extract_external_tensor_layouts (emitter_output .program )
102+ updated_named_data_store = NamedDataStore ()
103+ # Add tensor constants from the emitter to the NamedDataStore.
104+ for tag , fqn_to_index in emitter_output .external_constant_map .items ():
112105 for fqn , index in fqn_to_index .items ():
113- assert fqn in fqn_to_tensor_layout
114- assert fqn not in key_to_data_entry # fqn must be unique
115- key_to_data_entry [fqn ] = DataEntry (
116- buffer_index = len (buffers ),
117- alignment = config .constant_tensor_alignment ,
106+ updated_named_data_store .add_named_data (
107+ fqn ,
108+ emitter_output .external_constant_buffer [index ],
118109 tensor_layout = fqn_to_tensor_layout [fqn ],
110+ external_tag = tag ,
119111 )
120- buffers .append (emitter_output .external_constant_buffer [index ])
121-
122- # Extract external data from named_data_store.
123- # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
124- blob_to_data_entry = named_data_store .external_data .get (tag , {})
125- for key , data_entry in blob_to_data_entry .items ():
126- assert key not in key_to_data_entry # key must be unique
127- key_to_data_entry [key ] = DataEntry (
128- buffer_index = len (buffers ),
129- alignment = data_entry .alignment ,
130- tensor_layout = data_entry .tensor_layout ,
131- )
132- # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
133- buffers .append (named_data_store .buffers [data_entry .buffer_index ])
112+ updated_named_data_store .merge_named_data_store (named_data_store )
134113
135- # Serialize into PTD file.
114+ # Serialize each tag into a PTD file.
115+ for tag in updated_named_data_store .external_data .keys ():
136116 ptd_files [tag ] = data_serializer .serialize (
137117 DataPayload (
138- buffers = buffers ,
139- named_data = key_to_data_entry ,
118+ buffers = updated_named_data_store . buffers ,
119+ named_data = updated_named_data_store . external_data [ tag ] ,
140120 )
141121 )
142122
0 commit comments