Commit 3e008c2

amend doc and settings
1 parent a990653 commit 3e008c2

2 files changed: +45 -24 lines changed

docsrc/user_guide/mixed_precision.rst

Lines changed: 38 additions & 24 deletions
@@ -32,18 +32,15 @@ Consider the following PyTorch model which explicitly casts intermediate layer t
         return x


-If we compile the above model using Torch-TensorRT, layer profiling logs indicate that all the layers are
-run in FP32. This is because TensorRT picks the kernels for layers which result in the best performance.
+If we compile the above model using Torch-TensorRT with the following settings, layer profiling logs indicate that all the layers are
+run in FP32. This is because TensorRT picks the kernels for layers which result in the best performance (i.e., weak typing in TensorRT).

 .. code-block:: python

     inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
     mod = MyModule().eval().cuda()
     ep = torch.export.export(mod, tuple(inputs))
-    with torch_tensorrt.logging.debug():
-        trt_gm = torch_tensorrt.dynamo.compile(ep,
-                                               inputs=inputs,
-                                               debug=True)
+    trt_gm = torch_tensorrt.dynamo.compile(ep, inputs=inputs)

     # Debug log info
     # Layers:
@@ -53,31 +50,50 @@ run in FP32. This is because TensorRT picks the kernels for layers which result


 In order to respect the types specified by the user in the model (eg: in this case, ``linear2`` layer to run in FP16), users can enable
-the compilation setting ``use_explicit_typing=True``. Compiling with this option results in the following TensorRT logs
-
-.. note:: If you enable ``use_explicit_typing=True``, only torch.float32 is supported in the enabled_precisions.
-
+the compilation setting ``use_explicit_typing=True``. Compiling with this option results in the following TensorRT logs:

 .. code-block:: python

     inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
     mod = MyModule().eval().cuda()
     ep = torch.export.export(mod, tuple(inputs))
-    with torch_tensorrt.logging.debug():
-        trt_gm = torch_tensorrt.dynamo.compile(ep,
-                                               inputs=inputs,
-                                               use_explicit_typing=True,
-                                               debug=True)
+    trt_gm = torch_tensorrt.dynamo.compile(ep, inputs=inputs, use_explicit_typing=True)

     # Debug log info
     # Layers:
     # Name: __myl_MulSumAddCas_myl0_0, LayerType: kgen, Inputs: [ { Name: linear1/addmm_constant_0 _ linear1/addmm_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,10], Format/Datatype: Float }, { Name: __mye112_dconst, Dimensions: [10,10], Format/Datatype: Float }, { Name: x, Dimensions: [10,1], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], TacticName: __myl_MulSumAddCas_0xacf8f5dd9be2f3e7bb09cdddeac6c936, StreamId: 0, Metadata:
     # Name: __myl_ResMulSumAddCas_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye127_dconst, Dimensions: [10,30], Format/Datatype: Half }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantHalf, Dimensions: [1,30], Format/Datatype: Half }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_ResMulSumAddCas_0x5a3b318b5a1c97b7d5110c0291481337, StreamId: 0, Metadata:
     # Name: __myl_ResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye142_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_ResMulSumAdd_0x3fad91127c640fd6db771aa9cde67db0, StreamId: 0, Metadata:

-Now the ``linear2`` layer runs in FP16 as shown in the above logs.
+Autocast
+---------------
+
+Weak typing in TensorRT is deprecated, but it remains a good way to maximize performance. Torch-TensorRT therefore
+provides its own way to get this behavior, called ``Autocast``.
+
+Torch-TensorRT Autocast intelligently selects the nodes to keep in FP32 precision to maintain model accuracy, while the
+remaining nodes benefit from reduced precision. Torch-TensorRT Autocast also lets users specify nodes to exclude from
+Autocast, since some nodes may be more sensitive to reduced precision than others. In addition, Torch-TensorRT Autocast
+can cooperate with PyTorch native Autocast, allowing both to be used in the same model: Torch-TensorRT respects the
+precision of the nodes within a PyTorch Autocast region.
+
+To enable Torch-TensorRT Autocast, users need to set both ``enable_autocast=True`` and ``use_explicit_typing=True``. For example:
+
+.. code-block:: python
+
+    inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
+    mod = MyModule().eval().cuda()
+    ep = torch.export.export(mod, tuple(inputs))
+    trt_gm = torch_tensorrt.dynamo.compile(ep, inputs=inputs, enable_autocast=True, use_explicit_typing=True)

+Users can also control the reduced precision via ``autocast_low_precision_type``, or use ``autocast_excluded_nodes`` /
+``autocast_excluded_ops`` to exclude certain nodes/ops from Autocast.

+In summary, there are three ways in Torch-TensorRT to enable mixed precision:
+
+1. TRT chooses precision (weak typing): ``use_explicit_typing=False`` + ``enable_autocast=False``
+2. User specifies precision (strong typing): ``use_explicit_typing=True`` + ``enable_autocast=False``
+3. Autocast chooses precision (autocast + strong typing): ``use_explicit_typing=True`` + ``enable_autocast=True``

 FP32 Accumulation
 -----------------
@@ -93,14 +109,12 @@ When ``use_fp32_acc=True`` is set, Torch-TensorRT will attempt to use FP32 accum
     inputs = [torch.randn((1, 10), dtype=torch.float16).cuda()]
     mod = MyModule().eval().cuda()
     ep = torch.export.export(mod, tuple(inputs))
-    with torch_tensorrt.logging.debug():
-        trt_gm = torch_tensorrt.dynamo.compile(
-            ep,
-            inputs=inputs,
-            use_fp32_acc=True,
-            use_explicit_typing=True, # Explicit typing must be enabled
-            debug=True
-        )
+    trt_gm = torch_tensorrt.dynamo.compile(
+        ep,
+        inputs=inputs,
+        use_fp32_acc=True,
+        use_explicit_typing=True, # Explicit typing must be enabled
+    )

     # Debug log info
     # Layers:
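
The new Autocast section states that Torch-TensorRT respects the precision of nodes inside a PyTorch native Autocast region. Below is a minimal sketch of what that combination could look like, assuming that claim holds; the ``MixedModule`` definition and the ``torch.autocast`` region are illustrative and not part of this commit:

.. code-block:: python

    import torch
    import torch_tensorrt

    class MixedModule(torch.nn.Module):  # hypothetical module, not from this commit
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(10, 10)
            self.linear2 = torch.nn.Linear(10, 30)

        def forward(self, x):
            x = self.linear1(x)
            # Per the doc above, nodes inside a PyTorch Autocast region keep the
            # precision PyTorch assigns them (FP16 here); Torch-TensorRT Autocast
            # then only chooses precisions for the remaining nodes.
            with torch.autocast("cuda", dtype=torch.float16):
                x = self.linear2(x)
            return x

    inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
    ep = torch.export.export(MixedModule().eval().cuda(), tuple(inputs))
    trt_gm = torch_tensorrt.dynamo.compile(
        ep, inputs=inputs, enable_autocast=True, use_explicit_typing=True
    )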
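
The settings added to ``_settings.py`` below suggest the Autocast knobs exposed on ``torch_tensorrt.dynamo.compile``. Here is a hedged sketch of how the knobs mentioned in the doc might be combined; the value types (a torch dtype for ``autocast_low_precision_type``, string matchers for the excluded nodes/ops) and the specific values are assumptions, not confirmed by this commit:

.. code-block:: python

    # Continues from the doc's example above (reuses ep and inputs).
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=inputs,
        enable_autocast=True,
        use_explicit_typing=True,  # must be enabled together with enable_autocast
        autocast_low_precision_type=torch.float16,  # assumed: the reduced precision Autocast lowers to
        autocast_excluded_nodes=["linear3"],  # assumed: node names kept at their original precision
        autocast_excluded_ops=["torch.ops.aten.addmm"],  # assumed: op identifiers kept at their original precision
    )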

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 7 additions & 0 deletions
@@ -201,6 +201,13 @@ def __setstate__(self, state: dict[str, Any]) -> None:
     "enable_weight_streaming",
     "tiling_optimization_level",
     "l2_limit_for_tiling",
+    "enable_autocast",
+    "autocast_low_precision_type",
+    "autocast_excluded_nodes",
+    "autocast_excluded_ops",
+    "autocast_max_output_threshold",
+    "autocast_max_depth_of_reduction",
+    "autocast_calibration_dataloader",
 )