
Commit eac8809

implement autocast
1 parent c9859b6 commit eac8809

File tree: 10 files changed (+626, -29 lines)

core/runtime/execute_engine.cpp

Lines changed: 0 additions & 6 deletions

@@ -107,12 +107,6 @@ void setup_input_tensors(
     TORCHTRT_CHECK(
         inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());

-    auto expected_type =
-        util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
-    TORCHTRT_CHECK(
-        inputs[i].dtype() == expected_type,
-        "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
-
     auto dims = core::util::toDims(inputs[i].sizes());
     auto shape = core::util::toVec(dims);
     LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
New file (path not shown in this view)

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torch_tensorrt
import torchvision


class MyModule(torch.nn.Module):
    def forward(self, a_float32, b_float32, c_float32, d_float32):
        with torch.autocast(device_type="cuda"):
            e_float16 = torch.mm(a_float32, b_float32)
            with torch.autocast(device_type="cuda", enabled=False):
                # Calls e_float16.float() to ensure float32 execution
                # (necessary because e_float16 was created in an autocasted region)
                f_float32 = torch.mm(c_float32, e_float16.float())

            # No manual casts are required when re-entering the autocast-enabled region.
            # torch.mm again runs in float16 and produces float16 output, regardless of input types.
            g_float16 = torch.mm(d_float32, f_float32)
        return g_float16


class AutocastExample(nn.Module):
    def __init__(self):
        super(AutocastExample, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1
        )
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1
        )
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(16 * 8 * 8, 10)

    def forward(self, x, y):
        out = self.pool1(self.relu1(self.conv1(x)))  # fp16
        x = self.pool2(self.relu2(self.conv2(out)))  # fp16
        x = self.flatten(x)
        with torch.autocast(x.device.type, enabled=True, dtype=torch.float32):
            x = self.fc1(x)  # fp32
            with torch.autocast(x.device.type, enabled=False):
                x = torch.sub(x.half(), y)  # fp16
        out2 = torch.add(x, x)  # fp16
        with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
            out2 = torch.log(out2)  # fp32
        return x, out, out2


class MyResNet18Wrapper(torch.nn.Module):
    def __init__(self, num_classes=1000, pretrained=True):
        super(MyResNet18Wrapper, self).__init__()
        self.resnet = torchvision.models.resnet18(
            num_classes=num_classes, weights="IMAGENET1K_V1" if pretrained else None
        )

    def forward(self, x):
        x = self.resnet(x)
        return x


if __name__ == "__main__":
    # model = MyModule().cuda().eval()
    # inputs = (torch.randn((8, 8), device="cuda"),
    #           torch.randn((8, 8), device="cuda"),
    #           torch.randn((8, 8), device="cuda"),
    #           torch.randn((8, 8), device="cuda"),)

    # model = AutocastExample().cuda().eval()
    # inputs = (torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"),
    #           torch.randn((1,), dtype=torch.float16, device="cuda"),)

    model = MyResNet18Wrapper().cuda().eval()
    inputs = (torch.randn((1, 3, 224, 224), dtype=torch.float32, device="cuda"),)

    ep = torch.export.export(model, inputs)

    with torch_tensorrt.dynamo.Debugger(
        "graphs",
        logging_dir=".",
        engine_builder_monitor=False,
    ):
        trt_mod = torch_tensorrt.compile(
            ep.module(),
            arg_inputs=inputs,
            use_explicit_typing=False,
            min_block_size=1,
            use_python_runtime=True,
            low_precision_type=torch.float16,
            # nodes_to_exclude={"^conv2d$"},
            targets_to_exclude={},
            data_max=512,
            max_depth_of_reduction=None,
        )

    trt_out = trt_mod(*inputs)
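As a hypothetical follow-up (not part of the committed script), the TensorRT result can be sanity-checked against the eager PyTorch output; a small numerical gap is expected because parts of the graph run in FP16:

    # Hypothetical sanity check appended after the script above.
    with torch.no_grad():
        ref_out = model(*inputs)  # eager FP32 reference
    print("max abs diff vs. eager:", (trt_out - ref_out).abs().max().item())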

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 63 additions & 0 deletions

@@ -434,6 +434,13 @@ def compile(
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
+    low_precision_type: Optional[
+        Union[torch.dtype, dtype]
+    ] = _defaults.LOW_PRECISION_TYPE,
+    nodes_to_exclude: Collection[str] = _defaults.NODES_TO_EXCLUDE,
+    targets_to_exclude: Collection[Target] = _defaults.TARGETS_TO_EXCLUDE,
+    data_max: float = _defaults.DATA_MAX,
+    max_depth_of_reduction: Optional[int] = _defaults.MAX_DEPTH_OF_REDUCTION,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -511,6 +518,11 @@ def compile(
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
+        low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
+        nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
+        targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
+        data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
+        max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
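A hedged illustration of how the nodes_to_exclude patterns are meant to be read: they are regexes over FX node names, and a node that matches any pattern stays in FP32. The actual matching lives in the rule_based_autocast pass (not shown in this diff), and whether it uses re.search or a stricter full match is an implementation detail assumed here:

    import re

    nodes_to_exclude = {"^conv2d$"}  # same pattern as the commented-out example above

    def keeps_fp32(node_name: str) -> bool:
        # A node is excluded from low precision if any pattern matches its name.
        return any(re.search(p, node_name) for p in nodes_to_exclude)

    print(keeps_fp32("conv2d"))    # True  -> stays FP32
    print(keeps_fp32("conv2d_1"))  # False -> eligible for FP16/BF16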
@@ -593,6 +605,19 @@ def compile(
             f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
         )

+    if low_precision_type is not None:
+        if not isinstance(low_precision_type, (torch.dtype, dtype)):
+            raise ValueError(
+                f"low_precision_type must be a torch.dtype or dtype, got {type(low_precision_type)}"
+            )
+        if low_precision_type not in {
+            torch.float16,
+            torch.bfloat16,
+        } and low_precision_type not in {dtype.f16, dtype.bf16}:
+            raise ValueError(
+                f"low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {low_precision_type}"
+            )
+
     if use_fp32_acc:
         logger.debug(
             "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
@@ -622,6 +647,38 @@ def compile(
     if not isinstance(arg_inputs, collections.abc.Sequence):
         arg_inputs = [arg_inputs]  # type: ignore

+    # save intermediate outputs of each node for Autocast
+    intermediate_node_outputs = {}
+    if not use_explicit_typing:
+
+        class DumpInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
+            """Dump intermediate outputs of each node"""
+
+            def run_node(self, n: torch.fx.Node) -> Any:
+                if (
+                    n.op == "call_function"
+                    and n.target != torch.ops.higher_order.wrap_with_autocast
+                ):
+                    out = super().run_node(n)
+                    if not isinstance(out, torch.Tensor):
+                        raise ValueError(
+                            f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}."
+                        )
+                    intermediate_node_outputs[n.name] = out
+                    return out
+                return super().run_node(n)
+
+        def _materialize(x: Input | torch.Tensor) -> torch.Tensor:
+            """Materialize an Input object to a tensor"""
+            if isinstance(x, Input):
+                return x.torch_tensor
+            return x
+
+        with torch.no_grad():
+            mat_args = tuple(_materialize(a) for a in arg_inputs)
+            mat_kwargs = {k: _materialize(v) for k, v in kwarg_inputs.items()}
+            DumpInterpreter(exported_program.module()).run(*mat_args, **mat_kwargs)
+
     # Prepare torch_trt inputs
     trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
     trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
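For readers unfamiliar with torch.fx.Interpreter, a minimal standalone sketch of the same capture pattern used by DumpInterpreter above, run here on a toy traced function rather than the exported program:

    import torch
    import torch.fx

    def toy(x):
        y = torch.relu(x)
        return y * 2

    class Capture(torch.fx.Interpreter):
        def __init__(self, gm):
            super().__init__(gm)
            self.outputs = {}

        def run_node(self, n):
            out = super().run_node(n)
            if n.op == "call_function":
                self.outputs[n.name] = out  # record per-node outputs by node name
            return out

    gm = torch.fx.symbolic_trace(toy)
    cap = Capture(gm)
    cap.run(torch.randn(4))
    print(list(cap.outputs))  # e.g. ['relu', 'mul']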
@@ -680,6 +737,12 @@ def compile(
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "offload_module_to_cpu": offload_module_to_cpu,
         "use_distributed_mode_trace": use_distributed_mode_trace,
+        "low_precision_type": low_precision_type,
+        "nodes_to_exclude": nodes_to_exclude,
+        "targets_to_exclude": targets_to_exclude,
+        "data_max": data_max,
+        "max_depth_of_reduction": max_depth_of_reduction,
+        "intermediate_node_outputs": intermediate_node_outputs,
     }

     settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 5 additions & 0 deletions

@@ -57,6 +57,11 @@
 L2_LIMIT_FOR_TILING = -1
 USE_DISTRIBUTED_MODE_TRACE = False
 OFFLOAD_MODULE_TO_CPU = False
+LOW_PRECISION_TYPE = None
+NODES_TO_EXCLUDE = set[str]()
+TARGETS_TO_EXCLUDE = set[torch.fx.node.Target]()
+DATA_MAX = 512
+MAX_DEPTH_OF_REDUCTION = None

 if platform.system() == "Linux":
     import pwd

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 22 additions & 0 deletions

@@ -1,12 +1,14 @@
 from dataclasses import dataclass, field
 from typing import Any, Collection, Optional, Set, Tuple, Union

+import torch
 from torch.fx.node import Target
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import EngineCapability, dtype
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
+    DATA_MAX,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
@@ -21,8 +23,11 @@
     IMMUTABLE_WEIGHTS,
     L2_LIMIT_FOR_TILING,
     LAZY_ENGINE_INIT,
+    LOW_PRECISION_TYPE,
     MAX_AUX_STREAMS,
+    MAX_DEPTH_OF_REDUCTION,
     MIN_BLOCK_SIZE,
+    NODES_TO_EXCLUDE,
     NUM_AVG_TIMING_ITERS,
     OFFLOAD_MODULE_TO_CPU,
     OPTIMIZATION_LEVEL,
@@ -32,6 +37,7 @@
     REUSE_CACHED_ENGINES,
     SPARSE_WEIGHTS,
     STRIP_ENGINE_WEIGHTS,
+    TARGETS_TO_EXCLUDE,
     TILING_OPTIMIZATION_LEVEL,
     TIMING_CACHE_PATH,
     TRUNCATE_DOUBLE,
@@ -97,6 +103,12 @@ class CompilationSettings:
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
+        low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
+        nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
+        targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
+        data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
+        max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
+        intermediate_node_outputs (dict[str, torch.Tensor]): The intermediate node outputs of the graph. Default is {}.
     """

     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -140,6 +152,16 @@ class CompilationSettings:
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
     offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
+    low_precision_type: Optional[dtype] = LOW_PRECISION_TYPE
+    nodes_to_exclude: Collection[str] = field(default_factory=lambda: NODES_TO_EXCLUDE)
+    targets_to_exclude: Collection[Target] = field(
+        default_factory=lambda: TARGETS_TO_EXCLUDE
+    )
+    data_max: float = DATA_MAX
+    max_depth_of_reduction: Optional[int] = MAX_DEPTH_OF_REDUCTION
+    intermediate_node_outputs: dict[str, torch.Tensor] = field(
+        default_factory=lambda: {}
+    )

     def __getstate__(self) -> dict[str, Any]:
         from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
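A hedged sketch of constructing the settings object directly with the new fields; in normal use torch_tensorrt.compile assembles it from its keyword arguments, as in the _compiler.py hunk above:

    import torch
    from torch_tensorrt._enums import dtype
    from torch_tensorrt.dynamo._settings import CompilationSettings

    settings = CompilationSettings(
        use_explicit_typing=False,
        low_precision_type=dtype.f16,       # reduce eligible nodes to FP16
        nodes_to_exclude={"^conv2d$"},      # regexes over node names kept in FP32
        data_max=512.0,                     # nodes with larger outputs stay FP32
        max_depth_of_reduction=None,        # no limit on reduction depth
    )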

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 3 additions & 11 deletions

@@ -292,17 +292,9 @@ def _populate_trt_builder_config(
             )

         if not self.compilation_settings.use_explicit_typing:
-            if dtype.float16 in self.compilation_settings.enabled_precisions:
-                builder_config.set_flag(trt.BuilderFlag.FP16)
-
-            if dtype.int8 in self.compilation_settings.enabled_precisions:
-                builder_config.set_flag(trt.BuilderFlag.INT8)
-
-            if dtype.fp8 in self.compilation_settings.enabled_precisions:
-                builder_config.set_flag(trt.BuilderFlag.FP8)
-
-            if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
-                builder_config.set_flag(trt.BuilderFlag.BF16)
+            _LOGGER.info(
+                "Torch-TensorRT uses Autocast to determine the precision of the graph, because weak typing has been deprecated in TensorRT 10.12."
+            )

         if self.compilation_settings.sparse_weights:
             builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
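For context on the removed flags, a hedged sketch using the plain TensorRT Python API (outside Torch-TensorRT): weak typing relied on builder-config flags that merely permitted reduced-precision kernels, whereas a strongly typed network takes its precision from the tensor dtypes the frontend (here, autocast) assigns:

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)

    # Weak typing (deprecated): hint flags on the builder config.
    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.FP16)

    # Strong typing: precision follows the network's tensor dtypes instead.
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
    )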

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 7 additions & 4 deletions

@@ -15,6 +15,13 @@
 from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
+from .rule_based_autocast import rule_based_autocast
+
+pre_lowering_pass_list = [
+    remove_detach,
+    rule_based_autocast,
+    remove_assert_nodes,  # rule_based_autocast might insert assert nodes
+]

 post_lowering_pass_list = [
     remove_input_alias_fixing_clones,
@@ -27,10 +34,6 @@
     complex_graph_detection,
 ]

-pre_lowering_pass_list = [
-    remove_detach,
-]
-
 if not is_tegra_platform():
     from .fuse_distributed_ops import fuse_distributed_ops