Skip to content

Commit c9bf066

Browse files
authored
[etrecord] Implement generic fallback for GraphModuleSerializer.handle_call_function (#16069)
Title says it all! Implement the case where `node.target` is neither `torch._ops.OpOverload` nor `torch._ops.HigherOrderOperator`, instead of throwing an exception. Differential Revision: [D88216198](https://our.internmc.facebook.com/intern/diff/D88216198/)
1 parent aea2784 commit c9bf066

File tree

3 files changed

+52
-13
lines changed

3 files changed

+52
-13
lines changed

exir/serde/export_serialize.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
from torch.utils._sympy.numbers import int_oo
5555
from torch.utils._sympy.value_ranges import ValueRanges
5656

57+
import executorch.exir as exir
58+
5759
# pyre-ignore
5860

5961
from .schema import ( # type: ignore[attr-defined]
@@ -205,6 +207,11 @@ def _reverse_map(d: Dict[Any, Enum]):
205207
}
206208

207209

210+
_KNOWN_FUNCTIONS = {
211+
exir.memory.view,
212+
}
213+
214+
208215
@dataclass
209216
class SerializedArtifact:
210217
exported_program: bytes
@@ -545,6 +552,14 @@ def handle_call_function(self, node: torch.fx.Node):
545552
outputs=self.serialize_hoo_outputs(node),
546553
metadata=self.serialize_metadata(node),
547554
)
555+
elif node.target in _KNOWN_FUNCTIONS:
556+
ex_node = Node(
557+
name=node.name,
558+
target=node._pretty_print_target(node.target),
559+
inputs=self.serialize_hoo_inputs(node.args, node.kwargs),
560+
outputs=self.serialize_hoo_outputs(node),
561+
metadata=self.serialize_metadata(node),
562+
)
548563
else:
549564
raise SerializeError(f"Serializing {node.target} is not supported")
550565

exir/serde/serialize.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,11 @@ def serialize(
374374
)
375375

376376

377+
_KNOWN_FUNCTIONS_MAP = {
378+
"executorch.exir.memory.view": exir.memory.view,
379+
}
380+
381+
377382
class GraphModuleDeserializer(export_serialize.GraphModuleDeserializer):
378383
def deserialize_operator(self, serialized_target: str) -> str:
379384
def find_operator(module: _DialectNamespace, serialized_target: str) -> str:
@@ -450,19 +455,23 @@ def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> No
450455
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
451456
return
452457
elif isinstance(target, str):
453-
# Create a dummy fake op if the target does not exist
454-
# because we cannot create a call_function node w/o a
455-
# callable target
456-
log.warning(
457-
f"Could not find operator {target}. Returning fake operator."
458-
) # noqa: G004
459-
460-
# pyre-ignore
461-
def fake_op(x):
462-
raise NotImplementedError("Fake op is not meant to be run.")
463-
464-
fake_op.__name__ = target
465-
target = fake_op
458+
# Special handling for known functions, which are serialized as a
459+
# string but are still somewhat expected in serialized graphs.
460+
if target in _KNOWN_FUNCTIONS_MAP:
461+
target = _KNOWN_FUNCTIONS_MAP[target]
462+
else:
463+
# Otherwise, create a dummy fake op if the target does not exist
464+
# because we cannot create a call_function node w/o a callable
465+
# target
466+
log.warning(
467+
f"Could not find operator {target}. Returning fake operator."
468+
) # noqa: G004
469+
# pyre-ignore
470+
def fake_op(x):
471+
raise NotImplementedError("Fake op is not meant to be run.")
472+
473+
fake_op.__name__ = target
474+
target = fake_op
466475

467476
args = self.deserialize_inputs_no_schema(serialized_node)
468477
fx_node = self.graph.create_node("call_function", target, args, None, None)

exir/tests/test_serde.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,18 @@ def forward(self, x):
335335
node.meta.get("from_node"), node_new.meta.get("from_node")
336336
):
337337
self.assertEqual(node_source.to_dict(), node_source_new.to_dict())
338+
339+
def test_memory_ops(self) -> None:
340+
class MemoryOpsModule(nn.Module):
341+
def __init__(self):
342+
super().__init__()
343+
344+
def forward(self, x, y):
345+
x = exir.memory.view(x, (10, 10))
346+
return x + y
347+
348+
inputs = (
349+
torch.randn(100),
350+
torch.randn(10, 10),
351+
)
352+
self.check_serde(MemoryOpsModule(), inputs)

0 commit comments

Comments
 (0)