disable_tf32 (bool): Force FP32 layers to use the traditional FP32 format instead of the default behavior of rounding the inputs to 10-bit mantissas before multiplying, then accumulating the sum using 23-bit mantissas
assume_dynamic_shape_support (bool): Setting this to true enables the converters to work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
- enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
+ enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe GPU kernels or safe DLA kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
workspace_size (int): Maximum size of workspace given to TensorRT
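These arguments map directly onto torch_tensorrt.compile. A minimal sketch of the corrected enabled_precisions argument (the toy model and input shapes are placeholders, not taken from this diff):

import torch
import torch_tensorrt

model = torch.nn.Linear(64, 64).eval().cuda()
inputs = [torch.randn(1, 64).cuda()]

# Let TensorRT choose among FP16 and FP32 kernels during kernel selection.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.float16, torch.float32},
)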
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
use_distributed_mode_trace (bool): Use aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in the distributed model
+ enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
@@ -596,6 +598,10 @@ def compile(
"\nThis feature is unimplemented in Torch-TRT Dynamo currently."
)
+
+    if enable_autocast:
+        use_explicit_typing = True
+        logger.debug("Autocast is enabled, setting use_explicit_typing to True.")
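For context, a minimal usage sketch of the new flag as it flows through compile(). The toy model, the regex passed to nodes_to_exclude, and the assumption that these settings are forwarded as keyword arguments are illustrative, not taken from the diff:

import torch
import torch_tensorrt

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.ReLU(),
).eval().cuda()
inputs = [torch.randn(1, 3, 32, 32).cuda()]

# enable_autocast forces use_explicit_typing=True (see the guard above);
# low_precision_type selects the reduced precision, and nodes_to_exclude
# keeps any node whose name matches one of the regex patterns in FP32.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enable_autocast=True,
    low_precision_type=torch.float16,
    nodes_to_exclude=["^conv"],
)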
py/torch_tensorrt/dynamo/_settings.py (4 additions, 0 deletions)
@@ -14,6 +14,7 @@
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
DRYRUN,
+    ENABLE_AUTOCAST,
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENABLE_WEIGHT_STREAMING,
@@ -103,6 +104,7 @@ class CompilationSettings:
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
use_distributed_mode_trace (bool): Use aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in the distributed model
+ enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
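A small sketch of the settings object with the new field (assuming all other fields keep their defaults; constructing CompilationSettings directly is shown only for illustration, since compile() normally builds it from keyword arguments):

from torch_tensorrt.dynamo._settings import CompilationSettings

# Only the new flag is set; every other field falls back to its default.
settings = CompilationSettings(enable_autocast=True)
print(settings.enable_autocast)  # True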