
Commit 0a62149

optimize Cast insertion logic, fix io dtype issue and comments, and add tests
1 parent 4bf12e7 commit 0a62149

File tree

9 files changed: +432 -47 lines changed


core/runtime/execute_engine.cpp

Lines changed: 6 additions & 0 deletions
@@ -107,6 +107,12 @@ void setup_input_tensors(
     TORCHTRT_CHECK(
         inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
 
+    auto expected_type =
+        util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
+    TORCHTRT_CHECK(
+        inputs[i].dtype() == expected_type,
+        "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
+
     auto dims = core::util::toDims(inputs[i].sizes());
     auto shape = core::util::toVec(dims);
     LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
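
Note: for readers following along in Python, a minimal sketch of the validation the C++ runtime now performs; validate_inputs and expected_dtypes are illustrative names, not part of the actual runtime API:

    import torch

    def validate_inputs(inputs, expected_dtypes):
        # Mirrors the TORCHTRT_CHECKs above: every input must already be a CUDA
        # tensor whose dtype matches what the engine reports for that binding.
        for i, (t, expected) in enumerate(zip(inputs, expected_dtypes)):
            if not t.is_cuda:
                raise RuntimeError(f"Expected input tensors to have device cuda, found device {t.device}")
            if t.dtype != expected:
                raise RuntimeError(f"Expected input tensors to have type {expected}, found type {t.dtype}")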

examples/dynamo/autocast_example.py

Lines changed: 21 additions & 32 deletions
@@ -3,9 +3,9 @@
 import torch_tensorrt
 
 
-class AutocastExample(nn.Module):
+class MixedPytorchAutocastModel(nn.Module):
     def __init__(self):
-        super(AutocastExample, self).__init__()
+        super(MixedPytorchAutocastModel, self).__init__()
         self.conv1 = nn.Conv2d(
             in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1
         )
@@ -19,47 +19,36 @@ def __init__(self):
         self.flatten = nn.Flatten()
         self.fc1 = nn.Linear(16 * 8 * 8, 10)
 
-    def forward(self, x, y):
-        x = self.conv1(x)  # fp32 because of "^conv1$" in `autocast_excluded_nodes`
-        x = self.relu1(x)  # fp32 because of "relu" in `autocast_excluded_nodes`
-        out = self.pool1(x)  # fp16
-        x = self.conv2(out)  # fp16
-        x = self.relu2(x)  # fp32 because of "relu" in `autocast_excluded_nodes`
-        x = self.pool2(x)  # fp16
-        x = self.flatten(
-            x
-        )  # fp32 because of `torch.ops.aten.flatten.using_ints` in `autocast_excluded_ops`
-        # Respect the precisions in the pytorch autocast context
-        with torch.autocast(x.device.type, enabled=True, dtype=torch.float32):
-            x = self.fc1(x)  # fp32
-        with torch.autocast(x.device.type, enabled=False):
-            x = torch.sub(x.half(), y)  # fp16
-            out2 = torch.add(x, x)  # fp16
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.relu1(x)
+        x = self.pool1(x)
+        x = self.conv2(x)
+        x = self.relu2(x)
+        x = self.pool2(x)
+        x = self.flatten(x)
         with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
-            out2 = torch.log(
-                out2
-            )  # fp32 because Pytorch Autocast requires `log` to be in fp32
-        return x, out, out2
+            x = self.fc1(x)
+            out = torch.log(
+                torch.abs(x) + 1
+            )  # log is fp32 due to Pytorch Autocast requirements
+        return out
 
 
 if __name__ == "__main__":
-    model = AutocastExample().cuda().eval()
-    inputs = (
-        torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"),
-        torch.randn((1,), dtype=torch.float16, device="cuda"),
-    )
+    model = MixedPytorchAutocastModel().cuda().eval()
+    inputs = (torch.randn((8, 3, 32, 32), dtype=torch.float32, device="cuda"),)
+    ep = torch.export.export(model, inputs)
     calibration_dataloader = torch.utils.data.DataLoader(
-        torch.utils.data.TensorDataset(*inputs), batch_size=1, shuffle=False
+        torch.utils.data.TensorDataset(*inputs), batch_size=2, shuffle=False
     )
 
-    ep = torch.export.export(model, inputs)
-
     with torch_tensorrt.dynamo.Debugger(
         "graphs",
         logging_dir=".",
         engine_builder_monitor=False,
     ):
-        trt_mod = torch_tensorrt.compile(
+        trt_autocast_mod = torch_tensorrt.compile(
             ep.module(),
             arg_inputs=inputs,
             min_block_size=1,
@@ -78,4 +67,4 @@ def forward(self, x, y):
             autocast_calibration_dataloader=calibration_dataloader,
         )
 
-    trt_out = trt_mod(*inputs)
+    autocast_outs = trt_autocast_mod(*inputs)
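
Note: a quick, hedged way to sanity-check the compiled module against eager execution; the tolerances below are illustrative, and fp16 kernels warrant loose bounds:

    import torch

    with torch.no_grad():
        eager_outs = model(*inputs)
    # The TRT module should track the eager outputs within fp16-level error.
    torch.testing.assert_close(autocast_outs, eager_outs, rtol=5e-2, atol=5e-2)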

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 1 addition & 1 deletion
@@ -23,8 +23,8 @@
 
 pre_lowering_pass_list = [
     remove_detach,
+    remove_assert_nodes,
     rule_based_autocast,
-    remove_assert_nodes,  # rule_based_autocast might insert assert nodes
 ]
 
 post_lowering_pass_list = [
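
Note: pre-lowering passes run in list order, so this swap means assert nodes are stripped before rule_based_autocast classifies the graph and inserts casts. A simplified sketch of the dispatch; the real pass manager also threads compilation settings through:

    import torch

    def run_pre_lowering(gm: torch.fx.GraphModule, passes) -> torch.fx.GraphModule:
        # Each pass rewrites and returns the GraphModule; ordering is significant.
        for lowering_pass in passes:
            gm = lowering_pass(gm)
        return gm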

py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py

Lines changed: 7 additions & 3 deletions
@@ -54,9 +54,13 @@ def __init__(self, disabled_node_name_regex):
 
     def _check_inner(self, node):
         stack = node.meta.get("nn_module_stack")
-        node_name = next(reversed(stack), "").split("__")[
-            -1
-        ]  # get the user specified name of the node
+        try:
+            # get the user specified name of the node
+            node_name = stack.get(next(reversed(stack)), [""])[0]
+        except Exception as e:
+            raise ValueError(
+                f"Failed to get the user specified name of the node {node.name} because {e}. Please file a bug with Torch-TensorRT."
+            )
         return any(
             re.match(regex, node_name) for regex in self.disabled_node_name_regex
         )
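
Note: nn_module_stack maps generated keys to (qualified_name, module_type) pairs, with the innermost module last, which is why the new lookup takes the first tuple element rather than splitting the key. A hedged illustration; the sample stack below is made up:

    stack = {
        "L__self__": ("", "MyModel"),
        "L__self___conv1": ("conv1", "torch.nn.Conv2d"),
    }
    # The innermost entry's qualified name is the user-visible node name
    # matched against disabled_node_name_regex.
    node_name = stack.get(next(reversed(stack)), [""])[0]
    assert node_name == "conv1"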

py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py

Lines changed: 1 addition & 3 deletions
@@ -93,9 +93,7 @@ def run_node(self, n: torch.fx.Node) -> Any:
         out = super().run_node(n)
         if n.op == "call_function" and n.target not in excluded_ops:
             if not isinstance(out, torch.Tensor):
-                raise ValueError(
-                    f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}."
-                )
+                return out
             if n.name in intermediate_node_outputs:
                 intermediate_node_outputs[n.name] = torch.cat(
                     [intermediate_node_outputs[n.name], out], dim=0
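
Note: graphs routinely contain call_function nodes whose outputs are not tensors (shapes, plain ints), so passing them through is safer than raising. A minimal illustration with a made-up function:

    import torch

    def f(x):
        n = x.shape[0]  # traces to call_function nodes returning a Size and an int
        return x.reshape(n, -1)

    gm = torch.fx.symbolic_trace(f)
    # Interpreter.run_node yields non-Tensor values for the shape nodes;
    # the calibration hook above now skips those instead of erroring.
    out = torch.fx.Interpreter(gm).run(torch.randn(2, 3, 4))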

py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py

Lines changed: 14 additions & 8 deletions
@@ -66,15 +66,19 @@ def _cast_all_tensor_args_to_dtype(
     """
     if isinstance(arg, torch.fx.Node) and is_tensor_node(arg):
         val = arg.meta.get("val", None)
-        with gm.graph.inserting_before(node):
-            cast = gm.graph.call_function(
-                torch.ops.aten.to.dtype, args=(arg, dtype)
-            )
-
         if isinstance(val, torch.Tensor):
-            arg.meta["val"] = val.to(dtype)
-        cast.meta.update(arg.meta)
-        return cast
+            if val.dtype == dtype:
+                return arg
+            else:
+                with gm.graph.inserting_before(node):
+                    cast = gm.graph.call_function(
+                        torch.ops.aten.to.dtype, args=(arg, dtype)
+                    )
+                # copy the meta of the original tensor to the casted tensor
+                cast.meta.update(arg.meta)
+                # update the dtype of the casted tensor
+                cast.meta["val"] = cast.meta["val"].to(dtype)
+                return cast
     elif isinstance(arg, (tuple, list)):
         return type(arg)(
             _cast_all_tensor_args_to_dtype(node, a, dtype) for a in arg
@@ -102,13 +106,15 @@ def _cast_all_tensor_args_to_dtype(
             node.kwargs = _cast_all_tensor_args_to_dtype(
                 node, node.kwargs, autocast_low_precision_type
             )
+            node.meta["val"] = node.meta["val"].to(autocast_low_precision_type)
         elif node.name in high_precision_nodes:
             node.args = _cast_all_tensor_args_to_dtype(
                 node, node.args, autocast_high_precision_type
             )
             node.kwargs = _cast_all_tensor_args_to_dtype(
                 node, node.kwargs, autocast_high_precision_type
             )
+            node.meta["val"] = node.meta["val"].to(autocast_high_precision_type)
 
     gm = clean_up_graph_after_modifications(gm)
     logger.debug("Graph after Autocast based on the rules:\n%s", gm.graph)
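
Note: the rewrite only emits aten.to.dtype when the argument's faked value actually has a different dtype, so already-matching args no longer get no-op cast nodes, and a real cast's meta["val"] now reflects the post-cast dtype. Eager semantics show why the skipped node was pure overhead:

    import torch

    x = torch.randn(4, dtype=torch.float16)
    # .to() with a matching dtype returns the same tensor object;
    # the graph equivalent was a dead aten.to.dtype node.
    assert x.to(torch.float16) is x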

py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py

Lines changed: 4 additions & 0 deletions
@@ -154,6 +154,10 @@ def forward(
                     + contiguous_inputs[i + 1 :]
                 )
 
+            assert (
+                contiguous_inputs[i].dtype == inputs[i].dtype
+            ), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+
             if need_cudagraphs_record:
                 # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
                 # Clone is required to avoid re-using user-provided GPU memory
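
Note: under CUDA Graphs, inputs are copied into static buffers recorded at capture time, and copy_ converts dtypes silently rather than erroring, so a mismatch would replay with quietly degraded precision. A small sketch of the hazard the assert catches; the buffers are illustrative:

    import torch

    static_input = torch.empty(4, dtype=torch.float16, device="cuda")  # captured buffer
    fresh_input = torch.randn(4, dtype=torch.float32, device="cuda")
    static_input.copy_(fresh_input)  # silent downcast, no error raised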

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 9 additions & 0 deletions
@@ -275,6 +275,11 @@ def setup_engine(self) -> None:
             len(self.input_names) + len(self.output_names)
         )
 
+        self.input_dtypes = [
+            dtype._from(self.engine.get_tensor_dtype(input_name))
+            for input_name in self.input_names
+        ]
+
         self.input_shapes = [
             self.engine.get_tensor_shape(input_name) for input_name in self.input_names
         ]
@@ -367,6 +372,10 @@ def setup_input_tensors(
                     + contiguous_inputs[i + 1 :]
                 )
 
+            assert (
+                contiguous_inputs[i].dtype == self.input_dtypes[i]
+            ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
+
             if need_cudagraphs_record:
                 # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
                 # Clone is required to avoid re-using user-provided GPU memory
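
Note: querying the engine once in setup_engine keeps the per-call assert cheap. A hedged standalone sketch of the same query, using the TensorRT tensor API already used above; engine_input_dtypes is an illustrative helper, not part of the module:

    import tensorrt as trt
    from torch_tensorrt import dtype

    def engine_input_dtypes(engine: trt.ICudaEngine, input_names):
        # Same calls as setup_engine: map each named input binding to a
        # torch_tensorrt.dtype via the engine's reported tensor dtype.
        return [dtype._from(engine.get_tensor_dtype(name)) for name in input_names]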
