Commit 4bf12e7

fix comments

1 parent 94757d2 commit 4bf12e7

8 files changed: +136 −99 lines changed

examples/dynamo/autocast_example.py

Lines changed: 11 additions & 5 deletions

```diff
@@ -31,12 +31,14 @@ def forward(self, x, y):
         )  # fp32 because of `torch.ops.aten.flatten.using_ints` in `autocast_excluded_ops`
         # Respect the precisions in the PyTorch autocast context
         with torch.autocast(x.device.type, enabled=True, dtype=torch.float32):
-            x = self.fc1(x)
+            x = self.fc1(x)  # fp32
         with torch.autocast(x.device.type, enabled=False):
-            x = torch.sub(x.half(), y)
-            out2 = torch.add(x, x)
+            x = torch.sub(x.half(), y)  # fp16
+            out2 = torch.add(x, x)  # fp16
         with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
-            out2 = torch.log(out2)
+            out2 = torch.log(
+                out2
+            )  # fp32 because PyTorch Autocast requires `log` to be in fp32
         return x, out, out2
 
 
@@ -46,6 +48,9 @@ def forward(self, x, y):
     torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"),
     torch.randn((1,), dtype=torch.float16, device="cuda"),
 )
+calibration_dataloader = torch.utils.data.DataLoader(
+    torch.utils.data.TensorDataset(*inputs), batch_size=1, shuffle=False
+)
 
 ep = torch.export.export(model, inputs)
 
@@ -68,8 +73,9 @@ def forward(self, x, y):
     autocast_low_precision_type=torch.float16,
     autocast_excluded_nodes={"^conv1$", "relu"},
     autocast_excluded_ops={torch.ops.aten.flatten.using_ints},
-    autocast_data_max=512,
+    autocast_max_output_threshold=512,
     autocast_max_depth_of_reduction=None,
+    autocast_calibration_dataloader=calibration_dataloader,
 )
 
 trt_out = trt_mod(*inputs)
```
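Taken together, the example now feeds a calibration dataloader into compilation under the renamed option. Below is a minimal end-to-end sketch of the updated flow, assuming `torch_tensorrt.dynamo.compile` as the entry point and a hypothetical `MyAutocastModel` standing in for the example's module (which this diff does not show):

```python
import torch
import torch_tensorrt

# Hypothetical model matching the example's input shapes; the real module
# definition is not part of this diff.
model = MyAutocastModel().eval().cuda()

inputs = (
    torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"),
    torch.randn((1,), dtype=torch.float16, device="cuda"),
)
# Wrap the sample inputs so the rule-based autocast pass can replay them.
calibration_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(*inputs), batch_size=1, shuffle=False
)

ep = torch.export.export(model, inputs)

trt_mod = torch_tensorrt.dynamo.compile(
    ep,
    inputs,
    enable_autocast=True,
    autocast_low_precision_type=torch.float16,
    autocast_excluded_nodes={"^conv1$", "relu"},
    autocast_excluded_ops={torch.ops.aten.flatten.using_ints},
    autocast_max_output_threshold=512,  # renamed from autocast_data_max
    autocast_max_depth_of_reduction=None,
    autocast_calibration_dataloader=calibration_dataloader,
)
trt_out = trt_mod(*inputs)
```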

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -440,7 +440,7 @@ def compile(
     ] = _defaults.AUTOCAST_LOW_PRECISION_TYPE,
     autocast_excluded_nodes: Collection[str] = _defaults.AUTOCAST_EXCLUDED_NODES,
     autocast_excluded_ops: Collection[Target] = _defaults.AUTOCAST_EXCLUDED_OPS,
-    autocast_data_max: float = _defaults.AUTOCAST_DATA_MAX,
+    autocast_max_output_threshold: float = _defaults.AUTOCAST_MAX_OUTPUT_THRESHOLD,
     autocast_max_depth_of_reduction: Optional[
         int
     ] = _defaults.AUTOCAST_MAX_DEPTH_OF_REDUCTION,
@@ -526,10 +526,10 @@ def compile(
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in the distributed model
         enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
         autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
-        autocast_excluded_nodes (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
+        autocast_excluded_nodes (Collection[str]): The set of regex patterns to match user-specified node names that should remain in FP32. Default is [].
         autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
-        autocast_data_max (float): Maximum absolute value for node outputs; nodes with outputs greater than this value will remain in FP32. Default is 512.
-        autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
+        autocast_max_output_threshold (float): Maximum absolute value for node outputs; nodes with outputs greater than this value will remain in FP32. Default is 512.
+        autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low-precision formats. If not provided, infinity will be used. Default is None.
         autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None.
         **kwargs: Any,
     Returns:
@@ -721,7 +721,7 @@ def compile(
         "autocast_low_precision_type": autocast_low_precision_type,
         "autocast_excluded_nodes": autocast_excluded_nodes,
         "autocast_excluded_ops": autocast_excluded_ops,
-        "autocast_data_max": autocast_data_max,
+        "autocast_max_output_threshold": autocast_max_output_threshold,
         "autocast_max_depth_of_reduction": autocast_max_depth_of_reduction,
         "autocast_calibration_dataloader": autocast_calibration_dataloader,
     }
```

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -61,7 +61,7 @@
 AUTOCAST_LOW_PRECISION_TYPE = None
 AUTOCAST_EXCLUDED_NODES = set[str]()
 AUTOCAST_EXCLUDED_OPS = set[torch.fx.node.Target]()
-AUTOCAST_DATA_MAX = 512
+AUTOCAST_MAX_OUTPUT_THRESHOLD = 512
 AUTOCAST_MAX_DEPTH_OF_REDUCTION = None
 AUTOCAST_CALIBRATION_DATALOADER = None
 
```

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -8,11 +8,11 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     AUTOCAST_CALIBRATION_DATALOADER,
-    AUTOCAST_DATA_MAX,
     AUTOCAST_EXCLUDED_NODES,
     AUTOCAST_EXCLUDED_OPS,
     AUTOCAST_LOW_PRECISION_TYPE,
     AUTOCAST_MAX_DEPTH_OF_REDUCTION,
+    AUTOCAST_MAX_OUTPUT_THRESHOLD,
     CACHE_BUILT_ENGINES,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
@@ -107,10 +107,10 @@ class CompilationSettings:
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in the distributed model
         enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
         autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
-        autocast_excluded_nodes (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
+        autocast_excluded_nodes (Collection[str]): The set of regex patterns to match user-specified node names that should remain in FP32. Default is [].
         autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
-        autocast_data_max (float): Maximum absolute value for node outputs; nodes with outputs greater than this value will remain in FP32. Default is 512.
-        autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
+        autocast_max_output_threshold (float): Maximum absolute value for node outputs; nodes with outputs greater than this value will remain in FP32. Default is 512.
+        autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low-precision formats. If not provided, infinity will be used. Default is None.
         autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None.
     """
 
@@ -163,7 +163,7 @@ class CompilationSettings:
     autocast_excluded_ops: Collection[Target] = field(
         default_factory=lambda: AUTOCAST_EXCLUDED_OPS
     )
-    autocast_data_max: float = AUTOCAST_DATA_MAX
+    autocast_max_output_threshold: float = AUTOCAST_MAX_OUTPUT_THRESHOLD
     autocast_max_depth_of_reduction: Optional[int] = AUTOCAST_MAX_DEPTH_OF_REDUCTION
     autocast_calibration_dataloader: Optional[torch.utils.data.DataLoader] = (
         AUTOCAST_CALIBRATION_DATALOADER
```
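For reference, a small sketch of the renamed dataclass field. Constructing `CompilationSettings` directly like this is only for illustration (the compiler normally builds it from `compile()` kwargs), but it shows which knob callers touch after this rename:

```python
from torch_tensorrt.dynamo._settings import CompilationSettings

# All other fields keep their _defaults values; only the renamed knob is set.
settings = CompilationSettings(
    enable_autocast=True,
    autocast_max_output_threshold=512.0,  # formerly autocast_data_max
)
assert settings.autocast_max_output_threshold == 512.0
```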

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 9 additions & 27 deletions

```diff
@@ -1,9 +1,13 @@
 import logging
+import operator
 from typing import Any, Callable, Optional, Sequence, Union
 
 import torch
 from torch_tensorrt._utils import is_tegra_platform
 from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    trace_intermediate_node_outputs,
+)
 
 from .complex_graph_rewrite import complex_graph_detection
 from .constant_folding import constant_fold
@@ -141,33 +145,11 @@ def pre_export_lowering(
 
     # Only for rule-based autocast to collect the intermediate node outputs
     if settings.enable_autocast:
-        autocast_intermediate_node_outputs: dict[str, torch.Tensor] = {}
-
-        class IntermediateNodeTracer(torch.fx.Interpreter):  # type: ignore[misc]
-            def run_node(self, n: torch.fx.Node) -> Any:
-                out = super().run_node(n)
-                if (
-                    n.op == "call_function"
-                    and n.target != torch.ops.higher_order.wrap_with_autocast
-                ):
-                    if not isinstance(out, torch.Tensor):
-                        raise ValueError(
-                            f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}."
-                        )
-                    if n.name in autocast_intermediate_node_outputs:
-                        autocast_intermediate_node_outputs[n.name] = torch.cat(
-                            [autocast_intermediate_node_outputs[n.name], out], dim=0
-                        )
-                    else:
-                        autocast_intermediate_node_outputs[n.name] = out
-                return out
-
-        if settings.autocast_calibration_dataloader is not None:
-            tracer = IntermediateNodeTracer(ep.module())
-            for batch in settings.autocast_calibration_dataloader:
-                tracer.run(tuple(batch))
-            settings.autocast_intermediate_node_outputs = autocast_intermediate_node_outputs
-
+        settings.autocast_intermediate_node_outputs = trace_intermediate_node_outputs(
+            ep.module(),
+            settings.autocast_calibration_dataloader,
+            [torch.ops.higher_order.wrap_with_autocast, operator.getitem],
+        )
     gm = ep.graph_module
     gm = ATEN_PRE_LOWERING_PASSES(gm, settings)
     return ep
```

py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py

Lines changed: 20 additions & 15 deletions

```diff
@@ -29,7 +29,7 @@ def check(self, node):
         """Check if a node should be skipped based on the rule.
 
         Args:
-            node: The ONNX node to check.
+            node: The torch.fx.Node to check.
 
         Returns:
             bool: True if the node should be kept in high precision, False otherwise.
@@ -42,13 +42,13 @@ def check(self, node):
 
 
 class DisabledNodeNameRegexRule(NodeRuleBase):
-    """Rule for keeping nodes with matching names in high precision."""
+    """Rule for keeping nodes with matching user-specified names in high precision."""
 
     def __init__(self, disabled_node_name_regex):
         """Initialize the rule.
 
         Args:
-            disabled_node_name_regex: List of regex patterns for node names to keep in high precision.
+            disabled_node_name_regex: List of regex patterns for user-specified node names to keep in high precision.
         """
         self.disabled_node_name_regex = disabled_node_name_regex
 
@@ -63,13 +63,13 @@ def _check_inner(self, node):
 
 
 class DisabledOpTypes(NodeRuleBase):
-    """Rule for keeping nodes with specific operation types in high precision."""
+    """Rule for keeping nodes with specific ATen ops in high precision."""
 
     def __init__(self, excluded_ops):
         """Initialize the rule.
 
         Args:
-            excluded_ops: List of operation types to keep in high precision.
+            excluded_ops: List of ATen ops that should remain in FP32.
         """
         self.excluded_ops = excluded_ops
 
@@ -80,14 +80,14 @@ def _check_inner(self, node):
 class IORangeRule(NodeRuleBase):
     """Rule for keeping nodes with out-of-range inputs/outputs in high precision."""
 
-    def __init__(self, data_max, reference_data):
+    def __init__(self, max_output_threshold, reference_data):
         """Initialize the rule.
 
         Args:
-            data_max: Maximum absolute value allowed for node I/O.
+            max_output_threshold: Maximum absolute value allowed for node I/O.
             reference_data: Reference data for checking I/O ranges.
         """
-        self.data_max = data_max
+        self.max_output_threshold = max_output_threshold
         self.reference_data = reference_data
         self.output_data = None
 
@@ -108,7 +108,7 @@ def is_io_out_of_range(node):
             logger.debug(
                 f"Node {node.name}: reference data: min={ref_data.min()}, max={ref_data.max()}"
             )
-            if torch.any(torch.abs(ref_data) > self.data_max):
+            if torch.any(torch.abs(ref_data) > self.max_output_threshold):
                 self.output_data = ref_data
                 return True
 
@@ -126,14 +126,17 @@ def _log_skipped(self, node, **kwargs):
         if self.output_data is not None:
             logger.info(
                 f"Skipping node {node.name}: reference IO out of range: min={torch.min(self.output_data)}, "
-                f"max={torch.max(self.output_data)}, range=[{-self.data_max}, {self.data_max}]"
+                f"max={torch.max(self.output_data)}, range=[{-self.max_output_threshold}, {self.max_output_threshold}]"
             )
         else:
             super()._log_skipped(node, **kwargs)
 
 
 class DepthOfReductionRule(NodeRuleBase):
-    """Rule for keeping nodes with high depth of reduction in high precision."""
+    """
+    Rule for keeping nodes with a high depth of reduction in high precision. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low-precision formats.
+    Reduction ops are those that aggregate data across one or more axes, decreasing the dimensionality of the input tensor, such as convolution and GEMM.
+    """
 
     def __init__(self, max_depth_of_reduction, reference_data):
         """Initialize the rule.
@@ -226,7 +229,7 @@ def __init__(
         excluded_nodes: Collection[str] | None = None,
         excluded_ops: Collection[torch.fx.node.Target] | None = None,
         custom_rule: NodeRuleBase | None = None,
-        data_max: float | None = 1000.0,
+        max_output_threshold: float | None = 512,
         max_depth_of_reduction: int | None = None,
     ):
         """Initialize the node classifier.
@@ -236,14 +239,14 @@
             nodes_to_exclude: Collection of regex patterns for node names to keep in high precision.
             targets_to_exclude: Collection of targets to keep in high precision.
             custom_rule: Optional custom classification rule.
-            data_max: Maximum absolute value allowed for node I/O.
+            max_output_threshold: Maximum absolute value allowed for node I/O.
             max_depth_of_reduction: Maximum depth of reduction allowed in low precision.
         """
         self.nodes = nodes
         self.excluded_nodes = excluded_nodes
         self.excluded_ops = excluded_ops
         self.custom_rule = custom_rule
-        self.data_max = data_max
+        self.max_output_threshold = max_output_threshold
         self.max_depth_of_reduction = max_depth_of_reduction
 
     def _gen_block_node_rules(self, reference_data):
@@ -261,7 +264,9 @@ def _gen_block_node_rules(self, reference_data):
         if self.excluded_ops:
             block_node_rules.append(DisabledOpTypes(self.excluded_ops))
         if reference_data:
-            block_node_rules.append(IORangeRule(self.data_max, reference_data))
+            block_node_rules.append(
+                IORangeRule(self.max_output_threshold, reference_data)
+            )
         if self.max_depth_of_reduction is not None:
             block_node_rules.append(
                 DepthOfReductionRule(
```
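The `custom_rule` parameter suggests user-defined rules can plug into the classifier alongside the built-in ones. A hypothetical sketch following the `_check_inner` pattern visible above; the class name `NodeClassifier`, the import path, and any `NodeRuleBase` contract beyond `check()`/`_check_inner()` are assumptions:

```python
import torch
from torch_tensorrt.dynamo.lowering.passes.nodeclassifier import (  # assumed import path
    NodeClassifier,
    NodeRuleBase,
)


class DisabledSoftmaxRule(NodeRuleBase):
    """Hypothetical custom rule: keep softmax nodes in high precision."""

    def _check_inner(self, node):
        # Returning True keeps the node in FP32.
        return node.target == torch.ops.aten._softmax.default


# `graph_nodes` is assumed to be the list of torch.fx.Node objects to classify.
classifier = NodeClassifier(
    graph_nodes,
    custom_rule=DisabledSoftmaxRule(),
    max_output_threshold=512,
)
```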

py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py

Lines changed: 42 additions & 1 deletion

```diff
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, Dict, List, Sequence
 
 import torch
 
@@ -68,3 +68,44 @@ def is_node_complex(node: torch.fx.Node, complexNodes):
             complexNodes[node.name] = True
             return True
     return False
+
+
+def trace_intermediate_node_outputs(
+    gm: torch.fx.GraphModule,
+    calibration_dataloader: torch.utils.data.DataLoader,
+    excluded_ops: Sequence[torch.fx.node.Target] = [],
+) -> Dict[str, torch.Tensor]:
+    """Trace the intermediate node outputs of a graph module.
+
+    Args:
+        gm (torch.fx.GraphModule): The graph module whose intermediate node outputs are traced.
+        calibration_dataloader (torch.utils.data.DataLoader): The dataloader to use for tracing.
+        excluded_ops (Sequence[torch.fx.node.Target]): The ops that should be excluded from the trace. For example, `[torch.ops.higher_order.wrap_with_autocast, operator.getitem]`. Default is an empty sequence.
+
+    Returns:
+        Dict[str, torch.Tensor]: A dictionary of intermediate node outputs. The key is the node name and the value is the tensor.
+    """
+
+    intermediate_node_outputs: Dict[str, torch.Tensor] = {}
+
+    class IntermediateNodeTracer(torch.fx.Interpreter):  # type: ignore[misc]
+        def run_node(self, n: torch.fx.Node) -> Any:
+            out = super().run_node(n)
+            if n.op == "call_function" and n.target not in excluded_ops:
+                if not isinstance(out, torch.Tensor):
+                    raise ValueError(
+                        f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}."
+                    )
+                if n.name in intermediate_node_outputs:
+                    intermediate_node_outputs[n.name] = torch.cat(
+                        [intermediate_node_outputs[n.name], out], dim=0
+                    )
+                else:
+                    intermediate_node_outputs[n.name] = out
+            return out
+
+    if calibration_dataloader is not None:
+        tracer = IntermediateNodeTracer(gm)
+        for batch in calibration_dataloader:
+            tracer.run(tuple(batch))
+    return intermediate_node_outputs
```
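A usage sketch of the extracted helper, mirroring the call site in `_aten_lowering_pass.py`; `ep` and `calibration_dataloader` are assumed to be an exported program and a dataloader whose batches match the module's input signature:

```python
import operator

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    trace_intermediate_node_outputs,
)

# Collect per-node reference outputs, skipping autocast wrappers and getitem
# nodes; batches from the dataloader are concatenated along dim 0 per node.
outputs = trace_intermediate_node_outputs(
    ep.module(),
    calibration_dataloader,
    [torch.ops.higher_order.wrap_with_autocast, operator.getitem],
)
for name, tensor in outputs.items():
    print(f"{name}: shape={tuple(tensor.shape)}, max_abs={tensor.abs().max():.3f}")
```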
