
Commit f7d8068

add arg enable_autocast

1 parent f6c7c7c

7 files changed: 33 additions & 12 deletions

examples/dynamo/autocast_example.py

Lines changed: 6 additions & 1 deletion

@@ -85,9 +85,14 @@ def forward(self, x):
 trt_mod = torch_tensorrt.compile(
     ep.module(),
     arg_inputs=inputs,
-    use_explicit_typing=False,
     min_block_size=1,
     use_python_runtime=True,
+    ##### weak typing #####
+    # use_explicit_typing=False,
+    # enabled_precisions={torch.float16},
+    ##### strong typing + autocast #####
+    use_explicit_typing=True,
+    enable_autocast=True,
     low_precision_type=torch.float16,
     # nodes_to_exclude={"^conv2d$"},
     targets_to_exclude={},
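
For context, a minimal end-to-end sketch of the new path. The ToyModel, its shapes, and the export step are illustrative assumptions; only the compile() keywords come from this commit:

import torch
import torch_tensorrt


class ToyModel(torch.nn.Module):  # hypothetical stand-in for the example's model
    def __init__(self) -> None:
        super().__init__()
        self.conv2d = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv2d(x))


inputs = [torch.randn(1, 3, 32, 32).cuda()]
ep = torch.export.export(ToyModel().eval().cuda(), (inputs[0],))

# Strong typing + autocast: enable_autocast forces use_explicit_typing=True
# inside compile(), and low_precision_type selects the reduced precision.
trt_mod = torch_tensorrt.compile(
    ep.module(),
    arg_inputs=inputs,
    min_block_size=1,
    use_python_runtime=True,
    enable_autocast=True,
    low_precision_type=torch.float16,
)
out = trt_mod(*inputs)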

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 9 additions & 2 deletions

@@ -141,7 +141,7 @@ def cross_compile_for_windows(
         disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
         assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
         sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
-        enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
+        enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
         capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
         num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
         workspace_size (int): Maximum size of workspace given to TensorRT

@@ -434,6 +434,7 @@ def compile(
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
+    enable_autocast: bool = _defaults.ENABLE_AUTOCAST,
     low_precision_type: Optional[
         Union[torch.dtype, dtype]
     ] = _defaults.LOW_PRECISION_TYPE,

@@ -518,6 +519,7 @@ def compile(
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
+        enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
         low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
         nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
         targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].

@@ -596,6 +598,10 @@ def compile(
             "\nThis feature is unimplemented in Torch-TRT Dynamo currently."
         )
 
+    if enable_autocast:
+        use_explicit_typing = True
+        logger.debug("Autocast is enabled, setting use_explicit_typing to True.")
+
     if use_explicit_typing:
         if len(enabled_precisions) != 1 or not any(
             x in enabled_precisions

@@ -608,7 +614,7 @@ def compile(
     if low_precision_type is not None:
         if not isinstance(low_precision_type, (torch.dtype, dtype)):
             raise ValueError(
-                f"low_precision_type must be a torch.dtype or dtype, got {type(low_precision_type)}"
+                f"low_precision_type must be a torch.dtype or torch_tensorrt._enums.dtype, got {type(low_precision_type)}"
             )
         if low_precision_type not in {
             torch.float16,

@@ -737,6 +743,7 @@ def _materialize(x: Input | torch.Tensor) -> torch.Tensor:
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "offload_module_to_cpu": offload_module_to_cpu,
         "use_distributed_mode_trace": use_distributed_mode_trace,
+        "enable_autocast": enable_autocast,
         "low_precision_type": low_precision_type,
         "nodes_to_exclude": nodes_to_exclude,
         "targets_to_exclude": targets_to_exclude,

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions

@@ -57,6 +57,7 @@
 L2_LIMIT_FOR_TILING = -1
 USE_DISTRIBUTED_MODE_TRACE = False
 OFFLOAD_MODULE_TO_CPU = False
+ENABLE_AUTOCAST = False
 LOW_PRECISION_TYPE = None
 NODES_TO_EXCLUDE = set[str]()
 TARGETS_TO_EXCLUDE = set[torch.fx.node.Target]()

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 4 additions & 0 deletions

@@ -14,6 +14,7 @@
     DLA_LOCAL_DRAM_SIZE,
     DLA_SRAM_SIZE,
     DRYRUN,
+    ENABLE_AUTOCAST,
     ENABLE_CROSS_COMPILE_FOR_WINDOWS,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     ENABLE_WEIGHT_STREAMING,

@@ -103,6 +104,7 @@ class CompilationSettings:
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
+        enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
         low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
         nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
         targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].

@@ -152,6 +154,7 @@ class CompilationSettings:
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
     offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
+    enable_autocast: bool = ENABLE_AUTOCAST
     low_precision_type: Optional[dtype] = LOW_PRECISION_TYPE
     nodes_to_exclude: Collection[str] = field(default_factory=lambda: NODES_TO_EXCLUDE)
     targets_to_exclude: Collection[Target] = field(

@@ -179,6 +182,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
         self.__dict__.update(state)
 
 
+# If any of the following settings is changed, the engine should be rebuilt.
 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
     "enabled_precisions",
     "max_aux_streams",

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 11 additions & 3 deletions

@@ -292,9 +292,17 @@ def _populate_trt_builder_config(
         )
 
         if not self.compilation_settings.use_explicit_typing:
-            _LOGGER.info(
-                "Torch-TensorRT uses Autocast to determine the precision of the graph, because weak typing has been deprecated in TensorRT 10.12."
-            )
+            if dtype.float16 in self.compilation_settings.enabled_precisions:
+                builder_config.set_flag(trt.BuilderFlag.FP16)
+
+            if dtype.int8 in self.compilation_settings.enabled_precisions:
+                builder_config.set_flag(trt.BuilderFlag.INT8)
+
+            if dtype.fp8 in self.compilation_settings.enabled_precisions:
+                builder_config.set_flag(trt.BuilderFlag.FP8)
+
+            if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
+                builder_config.set_flag(trt.BuilderFlag.BF16)
 
         if self.compilation_settings.sparse_weights:
             builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
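
The restored weak-typing branch is a straight precision-to-builder-flag mapping. A standalone sketch of the same logic, with builder and config setup elided and the enum members taken from the hunk above:

import tensorrt as trt
from torch_tensorrt._enums import dtype

# One TensorRT builder flag per allowed reduced precision.
_PRECISION_FLAGS = {
    dtype.float16: trt.BuilderFlag.FP16,
    dtype.int8: trt.BuilderFlag.INT8,
    dtype.fp8: trt.BuilderFlag.FP8,
    dtype.bfloat16: trt.BuilderFlag.BF16,
}


def apply_precision_flags(
    builder_config: trt.IBuilderConfig, enabled_precisions: set
) -> None:
    # Only meaningful under weak typing; with use_explicit_typing the
    # network's explicit tensor dtypes drive kernel selection instead.
    for precision, flag in _PRECISION_FLAGS.items():
        if precision in enabled_precisions:
            builder_config.set_flag(flag)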

py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py

Lines changed: 2 additions & 2 deletions

@@ -26,8 +26,8 @@ def rule_based_autocast(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
     """Rule-based autocast"""
-    if settings.use_explicit_typing:
-        logger.debug("Strong typing is enabled, skipping rule-based autocast.")
+    if not settings.enable_autocast:
+        logger.debug("Autocast is not enabled, skipping rule-based autocast.")
         return gm
 
     # nodes = list(gm.graph.nodes)
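
In context, the gate sits at the top of the pass. A skeletal sketch follows; the real rewrite body is elided in this diff, so everything past the guard here is a hypothetical placeholder:

import logging

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings

logger = logging.getLogger(__name__)


def rule_based_autocast(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Rule-based autocast"""
    if not settings.enable_autocast:
        logger.debug("Autocast is not enabled, skipping rule-based autocast.")
        return gm

    for node in gm.graph.nodes:
        # Hypothetical body: insert casts around nodes that are safe to run
        # in settings.low_precision_type, honoring nodes_to_exclude and
        # targets_to_exclude.
        pass

    gm.recompile()
    return gm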

py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py

Lines changed: 0 additions & 4 deletions

@@ -154,10 +154,6 @@ def forward(
                 + contiguous_inputs[i + 1 :]
             )
 
-            assert (
-                contiguous_inputs[i].dtype == inputs[i].dtype
-            ), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
-
             if need_cudagraphs_record:
                 # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
                 # Clone is required to avoid re-using user-provided GPU memory
