However, weak typing has been deprecated since TensorRT 10.12, so we must set ``use_explicit_typing=True``
to enable strong typing, which means users must specify the precision of the nodes in the model. For example,
in the case above we set the ``linear2`` layer to run in FP16, so if we compile the model with the following
settings, the ``linear2`` layer will run in FP16 and the other layers will run in FP32, as shown in the
following TensorRT logs:

.. code-block:: python
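
    # A minimal sketch of such a compilation call; the input shape is an illustrative
    # assumption, and ``model`` refers to the module defined above.
    import torch
    import torch_tensorrt

    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=[torch.randn(1, 10, device="cuda")],
        use_explicit_typing=True,  # strong typing: the FP16 cast on ``linear2`` is respected
    )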

Autocast
---------------

Weak typing behavior in TensorRT is deprecated, but mixed precision is still a good way to maximize performance.
Therefore, Torch-TensorRT provides a way to enable mixed precision behavior similar to weak typing in
older versions of TensorRT, which is called `Autocast`.

Before we dive into Torch-TensorRT Autocast, let's first take a look at PyTorch Autocast. PyTorch Autocast is a
context-based autocast, which means it affects the precision of the nodes inside its context. For example,
in PyTorch, we can do the following:

.. code-block:: python

    x = self.linear1(x)
    # only the operations inside the autocast context run in reduced precision
    with torch.autocast(device_type="cuda", enabled=True, dtype=torch.float16):
        x = self.linear2(x)
    x = self.linear3(x)

This will run ``linear2`` in FP16 while the other layers remain in FP32. Please refer to the `PyTorch Autocast documentation <https://docs.pytorch.org/docs/stable/amp.html#torch.autocast>`_ for more details.

Unlike PyTorch Autocast, Torch-TensorRT Autocast is a rule-based autocast: it intelligently selects which nodes to
keep in FP32 to maintain model accuracy while benefiting from reduced precision on the rest of the nodes.
Torch-TensorRT Autocast also lets users specify nodes to exclude from Autocast, since some nodes may be more
sensitive to reduced precision and affect accuracy. In addition, Torch-TensorRT Autocast can cooperate with PyTorch
Autocast, allowing both to be used in the same model; Torch-TensorRT Autocast respects the precision of the nodes
within the PyTorch Autocast context.

To enable Torch-TensorRT Autocast, we need to set both ``enable_autocast=True`` and ``use_explicit_typing=True``.
On top of these, we can also specify the precision that eligible nodes are reduced to via ``autocast_low_precision_type``,
and exclude certain nodes or ops from Torch-TensorRT Autocast via ``autocast_excluded_nodes`` or ``autocast_excluded_ops``.
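
For example, a compilation call using these settings might look like the following sketch (the model, the input
shape, and the concrete values passed to the autocast options are illustrative assumptions; only the option names
follow the settings described above):

.. code-block:: python

    import torch
    import torch_tensorrt

    # ``model`` is assumed to be the module defined earlier in this guide
    inputs = [torch.randn(1, 10, device="cuda")]

    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        use_explicit_typing=True,                   # required together with enable_autocast
        enable_autocast=True,                       # turn on Torch-TensorRT Autocast
        autocast_low_precision_type=torch.float16,  # precision to reduce eligible nodes to
        # ``autocast_excluded_nodes`` / ``autocast_excluded_ops`` can additionally be
        # passed to keep precision-sensitive nodes or ops in FP32
    )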