Skip to content

Commit 94757d2

Browse files
committed
support dataloader for calibration
1 parent e15ce94 commit 94757d2

File tree

4 files changed

+40
-36
lines changed

4 files changed

+40
-36
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 5 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,9 @@ def compile(
444444
autocast_max_depth_of_reduction: Optional[
445445
int
446446
] = _defaults.AUTOCAST_MAX_DEPTH_OF_REDUCTION,
447+
autocast_calibration_dataloader: Optional[
448+
torch.utils.data.DataLoader
449+
] = _defaults.AUTOCAST_CALIBRATION_DATALOADER,
447450
**kwargs: Any,
448451
) -> torch.fx.GraphModule:
449452
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -527,6 +530,7 @@ def compile(
527530
autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
528531
autocast_data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
529532
autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
533+
autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None.
530534
**kwargs: Any,
531535
Returns:
532536
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -655,38 +659,6 @@ def compile(
655659
if not isinstance(arg_inputs, collections.abc.Sequence):
656660
arg_inputs = [arg_inputs] # type: ignore
657661

658-
# save intermediate outputs of each node for Autocast
659-
autocast_intermediate_node_outputs = {}
660-
if not use_explicit_typing:
661-
662-
class DumpInterpreter(torch.fx.Interpreter): # type: ignore[misc]
663-
"""Dump intermediate outputs of each node"""
664-
665-
def run_node(self, n: torch.fx.Node) -> Any:
666-
if (
667-
n.op == "call_function"
668-
and n.target != torch.ops.higher_order.wrap_with_autocast
669-
):
670-
out = super().run_node(n)
671-
if not isinstance(out, torch.Tensor):
672-
raise ValueError(
673-
f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}."
674-
)
675-
autocast_intermediate_node_outputs[n.name] = out
676-
return out
677-
return super().run_node(n)
678-
679-
def _materialize(x: Input | torch.Tensor) -> torch.Tensor:
680-
"""Materialize an Input object to a tensor"""
681-
if isinstance(x, Input):
682-
return x.torch_tensor
683-
return x
684-
685-
with torch.no_grad():
686-
mat_args = tuple(_materialize(a) for a in arg_inputs)
687-
mat_kwargs = {k: _materialize(v) for k, v in kwarg_inputs.items()}
688-
DumpInterpreter(exported_program.module()).run(*mat_args, **mat_kwargs)
689-
690662
# Prepare torch_trt inputs
691663
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
692664
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
@@ -751,7 +723,7 @@ def _materialize(x: Input | torch.Tensor) -> torch.Tensor:
751723
"autocast_excluded_ops": autocast_excluded_ops,
752724
"autocast_data_max": autocast_data_max,
753725
"autocast_max_depth_of_reduction": autocast_max_depth_of_reduction,
754-
"autocast_intermediate_node_outputs": autocast_intermediate_node_outputs,
726+
"autocast_calibration_dataloader": autocast_calibration_dataloader,
755727
}
756728

757729
settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
AUTOCAST_EXCLUDED_OPS = set[torch.fx.node.Target]()
6464
AUTOCAST_DATA_MAX = 512
6565
AUTOCAST_MAX_DEPTH_OF_REDUCTION = None
66+
AUTOCAST_CALIBRATION_DATALOADER = None
6667

6768
if platform.system() == "Linux":
6869
import pwd

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torch_tensorrt._enums import EngineCapability, dtype
88
from torch_tensorrt.dynamo._defaults import (
99
ASSUME_DYNAMIC_SHAPE_SUPPORT,
10+
AUTOCAST_CALIBRATION_DATALOADER,
1011
AUTOCAST_DATA_MAX,
1112
AUTOCAST_EXCLUDED_NODES,
1213
AUTOCAST_EXCLUDED_OPS,
@@ -110,7 +111,7 @@ class CompilationSettings:
110111
autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
111112
autocast_data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
112113
autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
113-
autocast_intermediate_node_outputs (dict[str, torch.Tensor]): The intermediate node outputs of the graph. Default is {}.
114+
autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None.
114115
"""
115116

116117
enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -164,8 +165,8 @@ class CompilationSettings:
164165
)
165166
autocast_data_max: float = AUTOCAST_DATA_MAX
166167
autocast_max_depth_of_reduction: Optional[int] = AUTOCAST_MAX_DEPTH_OF_REDUCTION
167-
autocast_intermediate_node_outputs: dict[str, torch.Tensor] = field(
168-
default_factory=lambda: {}
168+
autocast_calibration_dataloader: Optional[torch.utils.data.DataLoader] = (
169+
AUTOCAST_CALIBRATION_DATALOADER
169170
)
170171

171172
def __getstate__(self) -> dict[str, Any]:

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,36 @@ def pre_export_lowering(
138138
logging.debug(
139139
f"Invoking DynamoPassManager and applying lowering passes: {ATEN_PRE_LOWERING_PASSES}"
140140
)
141+
142+
# Only for rule-based autocast to collect the intermediate node outputs
143+
if settings.enable_autocast:
144+
autocast_intermediate_node_outputs: dict[str, torch.Tensor] = {}
145+
146+
class IntermediateNodeTracer(torch.fx.Interpreter): # type: ignore[misc]
147+
def run_node(self, n: torch.fx.Node) -> Any:
148+
out = super().run_node(n)
149+
if (
150+
n.op == "call_function"
151+
and n.target != torch.ops.higher_order.wrap_with_autocast
152+
):
153+
if not isinstance(out, torch.Tensor):
154+
raise ValueError(
155+
f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}."
156+
)
157+
if n.name in autocast_intermediate_node_outputs:
158+
autocast_intermediate_node_outputs[n.name] = torch.cat(
159+
[autocast_intermediate_node_outputs[n.name], out], dim=0
160+
)
161+
else:
162+
autocast_intermediate_node_outputs[n.name] = out
163+
return out
164+
165+
if settings.autocast_calibration_dataloader is not None:
166+
tracer = IntermediateNodeTracer(ep.module())
167+
for batch in settings.autocast_calibration_dataloader:
168+
tracer.run(tuple(batch))
169+
settings.autocast_intermediate_node_outputs = autocast_intermediate_node_outputs
170+
141171
gm = ep.graph_module
142172
gm = ATEN_PRE_LOWERING_PASSES(gm, settings)
143173
return ep

0 commit comments

Comments
 (0)