
Commit e15ce94

committed
Change API names and add support for user-specified node names
1 parent f7d8068 commit e15ce94

File tree

6 files changed: +111, -147 lines
Lines changed: 24 additions & 52 deletions
@@ -1,22 +1,6 @@
 import torch
 import torch.nn as nn
 import torch_tensorrt
-import torchvision
-
-
-class MyModule(torch.nn.Module):
-    def forward(self, a_float32, b_float32, c_float32, d_float32):
-        with torch.autocast(device_type="cuda"):
-            e_float16 = torch.mm(a_float32, b_float32)
-            with torch.autocast(device_type="cuda", enabled=False):
-                # Calls e_float16.float() to ensure float32 execution
-                # (necessary because e_float16 was created in an autocasted region)
-                f_float32 = torch.mm(c_float32, e_float16.float())
-
-            # No manual casts are required when re-entering the autocast-enabled region.
-            # torch.mm again runs in float16 and produces float16 output, regardless of input types.
-            g_float16 = torch.mm(d_float32, f_float32)
-        return g_float16
 
 
 class AutocastExample(nn.Module):
@@ -36,44 +20,32 @@ def __init__(self):
         self.fc1 = nn.Linear(16 * 8 * 8, 10)
 
     def forward(self, x, y):
-        out = self.pool1(self.relu1(self.conv1(x)))  # fp16
-        x = self.pool2(self.relu2(self.conv2(out)))  # fp16
-        x = self.flatten(x)
+        x = self.conv1(x)  # fp32 because of "^conv1$" in `autocast_excluded_nodes`
+        x = self.relu1(x)  # fp32 because of "relu" in `autocast_excluded_nodes`
+        out = self.pool1(x)  # fp16
+        x = self.conv2(out)  # fp16
+        x = self.relu2(x)  # fp32 because of "relu" in `autocast_excluded_nodes`
+        x = self.pool2(x)  # fp16
+        x = self.flatten(
+            x
+        )  # fp32 because of `torch.ops.aten.flatten.using_ints` in `autocast_excluded_ops`
+        # Respect the precisions in the pytorch autocast context
         with torch.autocast(x.device.type, enabled=True, dtype=torch.float32):
-            x = self.fc1(x)  # fp32
+            x = self.fc1(x)
         with torch.autocast(x.device.type, enabled=False):
-            x = torch.sub(x.half(), y)  # fp16
-            out2 = torch.add(x, x)  # fp16
+            x = torch.sub(x.half(), y)
+            out2 = torch.add(x, x)
         with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
-            out2 = torch.log(out2)  # fp32
+            out2 = torch.log(out2)
         return x, out, out2
 
 
-class MyResNet18Wrapper(torch.nn.Module):
-    def __init__(self, num_classes=1000, pretrained=True):
-        super(MyResNet18Wrapper, self).__init__()
-        self.resnet = torchvision.models.resnet18(
-            num_classes=num_classes, weights="IMAGENET1K_V1" if pretrained else None
-        )
-
-    def forward(self, x):
-        x = self.resnet(x)
-        return x
-
-
 if __name__ == "__main__":
-    # model = MyModule().cuda().eval()
-    # inputs = (torch.randn((8, 8), device="cuda"),
-    #           torch.randn((8, 8), device="cuda"),
-    #           torch.randn((8, 8), device="cuda"),
-    #           torch.randn((8, 8), device="cuda"),)
-
-    # model = AutocastExample().cuda().eval()
-    # inputs = (torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"),
-    #           torch.randn((1,), dtype=torch.float16, device="cuda"),)
-
-    model = MyResNet18Wrapper().cuda().eval()
-    inputs = (torch.randn((1, 3, 224, 224), dtype=torch.float32, device="cuda"),)
+    model = AutocastExample().cuda().eval()
+    inputs = (
+        torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"),
+        torch.randn((1,), dtype=torch.float16, device="cuda"),
+    )
 
     ep = torch.export.export(model, inputs)
 
@@ -93,11 +65,11 @@ def forward(self, x):
         ##### strong typing + autocast #####
         use_explicit_typing=True,
         enable_autocast=True,
-        low_precision_type=torch.float16,
-        # nodes_to_exclude={"^conv2d$"},
-        targets_to_exclude={},
-        data_max=512,
-        max_depth_of_reduction=None,
+        autocast_low_precision_type=torch.float16,
+        autocast_excluded_nodes={"^conv1$", "relu"},
+        autocast_excluded_ops={torch.ops.aten.flatten.using_ints},
+        autocast_data_max=512,
+        autocast_max_depth_of_reduction=None,
     )
 
     trt_out = trt_mod(*inputs)
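Note: the last hunk only shows the changed lines inside the compile call. A minimal sketch of how the renamed options fit together, continuing from the `ep` and `inputs` defined above — the call shape is reconstructed from the context lines and is an assumption; only the autocast_* keywords are confirmed by this commit:

trt_mod = torch_tensorrt.dynamo.compile(
    ep,  # the torch.export.export(...) program from above
    arg_inputs=inputs,
    use_explicit_typing=True,
    enable_autocast=True,
    autocast_low_precision_type=torch.float16,
    autocast_excluded_nodes={"^conv1$", "relu"},
    autocast_excluded_ops={torch.ops.aten.flatten.using_ints},
    autocast_data_max=512,
    autocast_max_depth_of_reduction=None,
)
trt_out = trt_mod(*inputs)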

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 27 additions & 25 deletions
@@ -435,13 +435,15 @@ def compile(
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     enable_autocast: bool = _defaults.ENABLE_AUTOCAST,
-    low_precision_type: Optional[
+    autocast_low_precision_type: Optional[
         Union[torch.dtype, dtype]
-    ] = _defaults.LOW_PRECISION_TYPE,
-    nodes_to_exclude: Collection[str] = _defaults.NODES_TO_EXCLUDE,
-    targets_to_exclude: Collection[Target] = _defaults.TARGETS_TO_EXCLUDE,
-    data_max: float = _defaults.DATA_MAX,
-    max_depth_of_reduction: Optional[int] = _defaults.MAX_DEPTH_OF_REDUCTION,
+    ] = _defaults.AUTOCAST_LOW_PRECISION_TYPE,
+    autocast_excluded_nodes: Collection[str] = _defaults.AUTOCAST_EXCLUDED_NODES,
+    autocast_excluded_ops: Collection[Target] = _defaults.AUTOCAST_EXCLUDED_OPS,
+    autocast_data_max: float = _defaults.AUTOCAST_DATA_MAX,
+    autocast_max_depth_of_reduction: Optional[
+        int
+    ] = _defaults.AUTOCAST_MAX_DEPTH_OF_REDUCTION,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -520,11 +522,11 @@ def compile(
         offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
         enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
-        low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
-        nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
-        targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
-        data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
-        max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
+        autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
+        autocast_excluded_nodes (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
+        autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
+        autocast_data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
+        autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -611,17 +613,17 @@ def compile(
             f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
         )
 
-    if low_precision_type is not None:
-        if not isinstance(low_precision_type, (torch.dtype, dtype)):
+    if autocast_low_precision_type is not None:
+        if not isinstance(autocast_low_precision_type, (torch.dtype, dtype)):
             raise ValueError(
-                f"low_precision_type must be a torch.dtype or torch_tensorrt._enums.dtype, got {type(low_precision_type)}"
+                f"autocast_low_precision_type must be a torch.dtype or torch_tensorrt._enums.dtype, got {type(autocast_low_precision_type)}"
             )
-        if low_precision_type not in {
+        if autocast_low_precision_type not in {
             torch.float16,
             torch.bfloat16,
-        } and low_precision_type not in {dtype.f16, dtype.bf16}:
+        } and autocast_low_precision_type not in {dtype.f16, dtype.bf16}:
             raise ValueError(
-                f"low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {low_precision_type}"
+                f"autocast_low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {autocast_low_precision_type}"
            )
 
     if use_fp32_acc:
@@ -654,7 +656,7 @@ def compile(
         arg_inputs = [arg_inputs]  # type: ignore
 
     # save intermediate outputs of each node for Autocast
-    intermediate_node_outputs = {}
+    autocast_intermediate_node_outputs = {}
    if not use_explicit_typing:
 
         class DumpInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
@@ -670,7 +672,7 @@ def run_node(self, n: torch.fx.Node) -> Any:
                     raise ValueError(
                         f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}."
                     )
-                intermediate_node_outputs[n.name] = out
+                autocast_intermediate_node_outputs[n.name] = out
                 return out
             return super().run_node(n)
@@ -744,12 +746,12 @@ def _materialize(x: Input | torch.Tensor) -> torch.Tensor:
         "offload_module_to_cpu": offload_module_to_cpu,
         "use_distributed_mode_trace": use_distributed_mode_trace,
         "enable_autocast": enable_autocast,
-        "low_precision_type": low_precision_type,
-        "nodes_to_exclude": nodes_to_exclude,
-        "targets_to_exclude": targets_to_exclude,
-        "data_max": data_max,
-        "max_depth_of_reduction": max_depth_of_reduction,
-        "intermediate_node_outputs": intermediate_node_outputs,
+        "autocast_low_precision_type": autocast_low_precision_type,
+        "autocast_excluded_nodes": autocast_excluded_nodes,
+        "autocast_excluded_ops": autocast_excluded_ops,
+        "autocast_data_max": autocast_data_max,
+        "autocast_max_depth_of_reduction": autocast_max_depth_of_reduction,
+        "autocast_intermediate_node_outputs": autocast_intermediate_node_outputs,
     }
 
     settings = CompilationSettings(**compilation_options)
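Note: `DumpInterpreter` appears only in fragments above. As a standalone illustration of the pattern it implements — a `torch.fx.Interpreter` subclass that records each node's concrete output so the autocast classifier can later check value ranges — here is a hedged sketch; the traced module and shapes are invented for the example:

import torch

class RecordingInterpreter(torch.fx.Interpreter):
    """Record every tensor-valued node output during one interpreted run."""

    def __init__(self, gm):
        super().__init__(gm)
        self.node_outputs = {}

    def run_node(self, n):
        out = super().run_node(n)
        if isinstance(out, torch.Tensor):
            # Keyed by FX node name, mirroring the dict the diff renames.
            self.node_outputs[n.name] = out
        return out

gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU()))
interp = RecordingInterpreter(gm)
interp.run(torch.randn(2, 4))
print(sorted(interp.node_outputs))  # node names depend on the traced graph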

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 5 additions & 5 deletions
@@ -58,11 +58,11 @@
 USE_DISTRIBUTED_MODE_TRACE = False
 OFFLOAD_MODULE_TO_CPU = False
 ENABLE_AUTOCAST = False
-LOW_PRECISION_TYPE = None
-NODES_TO_EXCLUDE = set[str]()
-TARGETS_TO_EXCLUDE = set[torch.fx.node.Target]()
-DATA_MAX = 512
-MAX_DEPTH_OF_REDUCTION = None
+AUTOCAST_LOW_PRECISION_TYPE = None
+AUTOCAST_EXCLUDED_NODES = set[str]()
+AUTOCAST_EXCLUDED_OPS = set[torch.fx.node.Target]()
+AUTOCAST_DATA_MAX = 512
+AUTOCAST_MAX_DEPTH_OF_REDUCTION = None
 
 if platform.system() == "Linux":
     import pwd
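Side note on the `set[str]()` / `set[torch.fx.node.Target]()` defaults carried over here: subscripting the builtin produces a `types.GenericAlias`, and calling it constructs a plain empty set, so the subscript only matters to type checkers. A quick demonstration:

import torch

# Same idiom as the defaults above: typed-looking, but a plain set() at runtime.
AUTOCAST_EXCLUDED_NODES = set[str]()
AUTOCAST_EXCLUDED_OPS = set[torch.fx.node.Target]()

print(AUTOCAST_EXCLUDED_NODES == set(), len(AUTOCAST_EXCLUDED_OPS))  # True 0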

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 20 additions & 18 deletions
@@ -7,8 +7,12 @@
 from torch_tensorrt._enums import EngineCapability, dtype
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
+    AUTOCAST_DATA_MAX,
+    AUTOCAST_EXCLUDED_NODES,
+    AUTOCAST_EXCLUDED_OPS,
+    AUTOCAST_LOW_PRECISION_TYPE,
+    AUTOCAST_MAX_DEPTH_OF_REDUCTION,
     CACHE_BUILT_ENGINES,
-    DATA_MAX,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
@@ -24,11 +28,8 @@
     IMMUTABLE_WEIGHTS,
     L2_LIMIT_FOR_TILING,
     LAZY_ENGINE_INIT,
-    LOW_PRECISION_TYPE,
     MAX_AUX_STREAMS,
-    MAX_DEPTH_OF_REDUCTION,
     MIN_BLOCK_SIZE,
-    NODES_TO_EXCLUDE,
     NUM_AVG_TIMING_ITERS,
     OFFLOAD_MODULE_TO_CPU,
     OPTIMIZATION_LEVEL,
@@ -38,7 +39,6 @@
     REUSE_CACHED_ENGINES,
     SPARSE_WEIGHTS,
     STRIP_ENGINE_WEIGHTS,
-    TARGETS_TO_EXCLUDE,
     TILING_OPTIMIZATION_LEVEL,
     TIMING_CACHE_PATH,
     TRUNCATE_DOUBLE,
@@ -105,12 +105,12 @@ class CompilationSettings:
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
         enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
-        low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
-        nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
-        targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
-        data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
-        max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
-        intermediate_node_outputs (dict[str, torch.Tensor]): The intermediate node outputs of the graph. Default is {}.
+        autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
+        autocast_excluded_nodes (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
+        autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
+        autocast_data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
+        autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
+        autocast_intermediate_node_outputs (dict[str, torch.Tensor]): The intermediate node outputs of the graph. Default is {}.
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -155,14 +155,16 @@ class CompilationSettings:
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
     offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
     enable_autocast: bool = ENABLE_AUTOCAST
-    low_precision_type: Optional[dtype] = LOW_PRECISION_TYPE
-    nodes_to_exclude: Collection[str] = field(default_factory=lambda: NODES_TO_EXCLUDE)
-    targets_to_exclude: Collection[Target] = field(
-        default_factory=lambda: TARGETS_TO_EXCLUDE
+    autocast_low_precision_type: Optional[dtype] = AUTOCAST_LOW_PRECISION_TYPE
+    autocast_excluded_nodes: Collection[str] = field(
+        default_factory=lambda: AUTOCAST_EXCLUDED_NODES
     )
-    data_max: float = DATA_MAX
-    max_depth_of_reduction: Optional[int] = MAX_DEPTH_OF_REDUCTION
-    intermediate_node_outputs: dict[str, torch.Tensor] = field(
+    autocast_excluded_ops: Collection[Target] = field(
+        default_factory=lambda: AUTOCAST_EXCLUDED_OPS
+    )
+    autocast_data_max: float = AUTOCAST_DATA_MAX
+    autocast_max_depth_of_reduction: Optional[int] = AUTOCAST_MAX_DEPTH_OF_REDUCTION
+    autocast_intermediate_node_outputs: dict[str, torch.Tensor] = field(
         default_factory=lambda: {}
     )
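Note: every renamed dataclass field keeps a default, so existing `CompilationSettings(...)` call sites break only if they passed the old keyword names. A hedged construction example using the new fields (values illustrative):

from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(
    enable_autocast=True,
    autocast_excluded_nodes={"^conv1$", "relu"},
    autocast_data_max=512.0,
)
assert settings.autocast_max_depth_of_reduction is None  # default preserved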

py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py

Lines changed: 18 additions & 14 deletions
@@ -53,24 +53,28 @@ def __init__(self, disabled_node_name_regex):
         self.disabled_node_name_regex = disabled_node_name_regex
 
     def _check_inner(self, node):
+        stack = node.meta.get("nn_module_stack")
+        node_name = next(reversed(stack), "").split("__")[
+            -1
+        ]  # get the user specified name of the node
         return any(
-            re.match(regex, node.name) for regex in self.disabled_node_name_regex
+            re.match(regex, node_name) for regex in self.disabled_node_name_regex
         )
 
 
-class DisabledTargets(NodeRuleBase):
+class DisabledOpTypes(NodeRuleBase):
     """Rule for keeping nodes with specific operation types in high precision."""
 
-    def __init__(self, targets_to_exclude):
+    def __init__(self, excluded_ops):
         """Initialize the rule.
 
         Args:
-            targets_to_exclude: List of operation types to keep in high precision.
+            excluded_ops: List of operation types to keep in high precision.
         """
-        self.targets_to_exclude = targets_to_exclude
+        self.excluded_ops = excluded_ops
 
     def _check_inner(self, node):
-        return node.target in self.targets_to_exclude
+        return node.target in self.excluded_ops
 
 
 class IORangeRule(NodeRuleBase):
@@ -219,8 +223,8 @@ class NodeClassifier:
     def __init__(
         self,
         nodes,
-        nodes_to_exclude: Collection[str] | None = None,
-        targets_to_exclude: Collection[torch.fx.node.Target] | None = None,
+        excluded_nodes: Collection[str] | None = None,
+        excluded_ops: Collection[torch.fx.node.Target] | None = None,
         custom_rule: NodeRuleBase | None = None,
         data_max: float | None = 1000.0,
         max_depth_of_reduction: int | None = None,
@@ -236,8 +240,8 @@ def __init__(
             max_depth_of_reduction: Maximum depth of reduction allowed in low precision.
         """
         self.nodes = nodes
-        self.nodes_to_exclude = nodes_to_exclude
-        self.targets_to_exclude = targets_to_exclude
+        self.excluded_nodes = excluded_nodes
+        self.excluded_ops = excluded_ops
         self.custom_rule = custom_rule
         self.data_max = data_max
         self.max_depth_of_reduction = max_depth_of_reduction
@@ -252,10 +256,10 @@ def _gen_block_node_rules(self, reference_data):
             list[NodeRuleBase]: List of rules to apply.
         """
         block_node_rules: list[NodeRuleBase] = []
-        if self.nodes_to_exclude:
-            block_node_rules.append(DisabledNodeNameRegexRule(self.nodes_to_exclude))
-        if self.targets_to_exclude:
-            block_node_rules.append(DisabledTargets(self.targets_to_exclude))
+        if self.excluded_nodes:
+            block_node_rules.append(DisabledNodeNameRegexRule(self.excluded_nodes))
+        if self.excluded_ops:
+            block_node_rules.append(DisabledOpTypes(self.excluded_ops))
         if reference_data:
             block_node_rules.append(IORangeRule(self.data_max, reference_data))
         if self.max_depth_of_reduction is not None:
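Note: the rewritten `_check_inner` is what makes user-specified module names matchable — it takes the innermost entry of the node's `nn_module_stack` metadata rather than FX's generated `node.name`, and keeps the text after the final `__`. A hedged illustration on a hand-built dict; the exact key mangling `torch.export` records can vary by version, so the keys below are hypothetical:

import re

# Hypothetical nn_module_stack metadata: innermost module last, with the
# user-specified attribute name ("conv1") after the final "__" in the key.
stack = {"l__self": ("", "AutocastExample"), "l__self__conv1": ("conv1", "Conv2d")}

node_name = next(reversed(stack), "").split("__")[-1]  # -> "conv1"
print(any(re.match(p, node_name) for p in {"^conv1$", "relu"}))  # True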
