"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -511,6 +518,11 @@ def compile(
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
+        low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
+        nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
+        targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
+        data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
+        max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
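The new keyword arguments make low-precision reduction an opt-in feature of compile(). A minimal usage sketch follows; the toy model, input shapes, regex pattern, and threshold values are assumptions for illustration, and only the parameter names come from the docstring hunk above.

import torch
import torch_tensorrt

# Toy model and inputs, purely illustrative; any exportable module would do.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda().eval()
example_inputs = (torch.randn(8, 64).cuda(),)
exported = torch.export.export(model, example_inputs)

# Sketch: reduce eligible nodes to FP16 while keeping selected nodes in FP32.
trt_module = torch_tensorrt.dynamo.compile(
    exported,
    inputs=example_inputs,
    low_precision_type=torch.float16,   # reduce to FP16 (torch.bfloat16 is also supported)
    nodes_to_exclude=["^linear_1$"],    # regex patterns on node names (assumed node name)
    data_max=512,                       # nodes whose output magnitude exceeds this stay in FP32
    max_depth_of_reduction=32,          # deeper reductions stay in FP32 (assumed value)
)
outputs = trt_module(*example_inputs)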
@@ -593,6 +605,19 @@ def compile(
             f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
+            f"low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {low_precision_type}"
+        )
     if use_fp32_acc:
         logger.debug(
             "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
+        low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
+        nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
+        targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
+        data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
+        max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
+        intermediate_node_outputs (dict[str, torch.Tensor]): The intermediate node outputs of the graph. Default is {}.
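Taken together, these parameters describe a rule set for deciding which nodes are kept in FP32. The sketch below only illustrates those documented rules; the helper name, its signature, and the per-node reduction_depth argument are invented for the example and are not the PR's implementation.

import re
from typing import Collection, Optional

import torch
import torch.fx
from torch.fx.node import Target

def should_stay_in_fp32(
    node: torch.fx.Node,
    nodes_to_exclude: Collection[str],
    targets_to_exclude: Collection[Target],
    data_max: float,
    max_depth_of_reduction: Optional[int],
    intermediate_node_outputs: dict[str, torch.Tensor],
    reduction_depth: int,  # assumed to be computed elsewhere for this node
) -> bool:
    # Rule 1: node name matches one of the user-supplied regex patterns.
    if any(re.match(pattern, node.name) for pattern in nodes_to_exclude):
        return True
    # Rule 2: the node's ATen target is explicitly excluded.
    if node.target in targets_to_exclude:
        return True
    # Rule 3: the recorded intermediate output magnitude exceeds data_max.
    output = intermediate_node_outputs.get(node.name)
    if output is not None and output.abs().max().item() > data_max:
        return True
    # Rule 4: reduction depth exceeds the configured maximum (None means no limit).
    if max_depth_of_reduction is not None and reduction_depth > max_depth_of_reduction:
        return True
    return False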