@@ -1855,7 +1855,11 @@ def forward(self, query, key, value, attn_mask=None):
18551855
18561856 # Validate that the results between Torch and Torch-TRT are similar
18571857 trt_model = torch_tensorrt .dynamo .compile (
1858- exported_program , inputs , enabled_precisions = {torch .half }, min_block_size = 1
1858+ exported_program ,
1859+ inputs ,
1860+ enabled_precisions = {torch .half },
1861+ min_block_size = 1 ,
1862+ use_explicit_typing = False ,
18591863 )
18601864 torch .testing .assert_close (
18611865 trt_model (* inputs ),
@@ -1951,7 +1955,11 @@ def forward(self, query, key, value, attn_mask=None):
19511955
19521956 # Validate that the results between Torch and Torch-TRT are similar
19531957 trt_model = torch_tensorrt .dynamo .compile (
1954- exported_program , inputs , enabled_precisions = {torch .half }, min_block_size = 1
1958+ exported_program ,
1959+ inputs ,
1960+ enabled_precisions = {torch .half },
1961+ min_block_size = 1 ,
1962+ use_explicit_typing = False ,
19551963 )
19561964
19571965 inputs = [
@@ -2017,7 +2025,11 @@ def forward(self, query, key, value):
20172025
20182026 # Validate that the results between Torch and Torch-TRT are similar
20192027 trt_model = torch_tensorrt .dynamo .compile (
2020- exported_program , inputs , enabled_precisions = {torch .half }, min_block_size = 1
2028+ exported_program ,
2029+ inputs ,
2030+ enabled_precisions = {torch .half },
2031+ min_block_size = 1 ,
2032+ use_explicit_typing = False ,
20212033 )
20222034 torch .testing .assert_close (
20232035 trt_model (* inputs ),
@@ -2078,7 +2090,11 @@ def forward(self, query, key, value, attn_bias=None):
20782090
20792091 # Validate that the results between Torch and Torch-TRT are similar
20802092 trt_model = torch_tensorrt .dynamo .compile (
2081- exported_program , inputs , enabled_precisions = {torch .half }, min_block_size = 1
2093+ exported_program ,
2094+ inputs ,
2095+ enabled_precisions = {torch .half },
2096+ min_block_size = 1 ,
2097+ use_explicit_typing = False ,
20822098 )
20832099 torch .testing .assert_close (
20842100 trt_model (* inputs ),
@@ -2139,7 +2155,11 @@ def forward(self, query, key, value, attn_bias=None):
21392155
21402156 # Validate that the results between Torch and Torch-TRT are similar
21412157 trt_model = torch_tensorrt .dynamo .compile (
2142- exported_program , inputs , enabled_precisions = {torch .half }, min_block_size = 1
2158+ exported_program ,
2159+ inputs ,
2160+ enabled_precisions = {torch .half },
2161+ min_block_size = 1 ,
2162+ use_explicit_typing = False ,
21432163 )
21442164 torch .testing .assert_close (
21452165 trt_model (* inputs ),
0 commit comments