Skip to content

Commit a3b57b2

Browse files
committed
fix error
1 parent a7a8039 commit a3b57b2

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1855,7 +1855,11 @@ def forward(self, query, key, value, attn_mask=None):
18551855

18561856
# Validate that the results between Torch and Torch-TRT are similar
18571857
trt_model = torch_tensorrt.dynamo.compile(
1858-
exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1
1858+
exported_program,
1859+
inputs,
1860+
enabled_precisions={torch.half},
1861+
min_block_size=1,
1862+
use_explicit_typing=False,
18591863
)
18601864
torch.testing.assert_close(
18611865
trt_model(*inputs),
@@ -1951,7 +1955,11 @@ def forward(self, query, key, value, attn_mask=None):
19511955

19521956
# Validate that the results between Torch and Torch-TRT are similar
19531957
trt_model = torch_tensorrt.dynamo.compile(
1954-
exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1
1958+
exported_program,
1959+
inputs,
1960+
enabled_precisions={torch.half},
1961+
min_block_size=1,
1962+
use_explicit_typing=False,
19551963
)
19561964

19571965
inputs = [
@@ -2017,7 +2025,11 @@ def forward(self, query, key, value):
20172025

20182026
# Validate that the results between Torch and Torch-TRT are similar
20192027
trt_model = torch_tensorrt.dynamo.compile(
2020-
exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1
2028+
exported_program,
2029+
inputs,
2030+
enabled_precisions={torch.half},
2031+
min_block_size=1,
2032+
use_explicit_typing=False,
20212033
)
20222034
torch.testing.assert_close(
20232035
trt_model(*inputs),
@@ -2078,7 +2090,11 @@ def forward(self, query, key, value, attn_bias=None):
20782090

20792091
# Validate that the results between Torch and Torch-TRT are similar
20802092
trt_model = torch_tensorrt.dynamo.compile(
2081-
exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1
2093+
exported_program,
2094+
inputs,
2095+
enabled_precisions={torch.half},
2096+
min_block_size=1,
2097+
use_explicit_typing=False,
20822098
)
20832099
torch.testing.assert_close(
20842100
trt_model(*inputs),
@@ -2139,7 +2155,11 @@ def forward(self, query, key, value, attn_bias=None):
21392155

21402156
# Validate that the results between Torch and Torch-TRT are similar
21412157
trt_model = torch_tensorrt.dynamo.compile(
2142-
exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1
2158+
exported_program,
2159+
inputs,
2160+
enabled_precisions={torch.half},
2161+
min_block_size=1,
2162+
use_explicit_typing=False,
21432163
)
21442164
torch.testing.assert_close(
21452165
trt_model(*inputs),

0 commit comments

Comments
 (0)