From 45952652bf0487ab10a281a0620709888aa84f16 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 1 Oct 2025 12:24:32 -0700 Subject: [PATCH 1/9] chore: updates Signed-off-by: Dheeraj Peri --- tools/llm/README.md | 2 +- tools/llm/test_trt_sdpa.py | 35 +++++ tools/llm/torchtrt_ext/register_sdpa.py | 3 +- tools/llm/torchtrt_ext/trt_sdpa_converter.py | 157 +++++++++++++++++++ 4 files changed, 194 insertions(+), 3 deletions(-) create mode 100644 tools/llm/test_trt_sdpa.py create mode 100644 tools/llm/torchtrt_ext/trt_sdpa_converter.py diff --git a/tools/llm/README.md b/tools/llm/README.md index 05a1e3cc60..5d807384b1 100644 --- a/tools/llm/README.md +++ b/tools/llm/README.md @@ -38,7 +38,7 @@ We have officially verified support for the following models: #### Text-only LLMs: `run_llm.py` ```bash -python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark +python run_llm.py --model Qwen/Qwen2.5-0.5B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark ``` #### Vision Language Models: `run_vlm.py` diff --git a/tools/llm/test_trt_sdpa.py b/tools/llm/test_trt_sdpa.py new file mode 100644 index 0000000000..bf9e1165a5 --- /dev/null +++ b/tools/llm/test_trt_sdpa.py @@ -0,0 +1,35 @@ +import torch +import torch_tensorrt +from torchtrt_ext import register_sdpa + + +class ModelNoCache(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v): + return torch._C._nn.scaled_dot_product_attention( + q, k, v, dropout_p=0.0, is_causal=True + ) + + +model = ModelNoCache().cuda().eval().to(torch.float16) +q = torch.randn(1, 32, 6, 64).cuda().to(torch.float16) +k = torch.randn(1, 32, 6, 64).cuda().to(torch.float16) +v = torch.randn(1, 32, 6, 64).cuda().to(torch.float16) +pyt_outputs = model(q, k, v) +register_sdpa.enable_sdpa_converter("default", None) +ep = torch.export.export(model, (q, k, v), strict=False) + +with torch_tensorrt.dynamo.Debugger(): + trt_gm = torch_tensorrt.dynamo.compile( + ep, + inputs=(q, k, v), + enabled_precisions={torch.float16}, + min_block_size=1, + disable_tf32=True, + ) + +trt_outputs = trt_gm(q, k, v) + +print("Diff between pyt and trt: ", torch.mean(torch.abs(pyt_outputs - trt_outputs))) diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index a82384fda9..57bf5d6537 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -15,7 +15,7 @@ ) from transformers import AutoConfig, Gemma3TextConfig -from .sdpa_converter import * +from .trt_sdpa_converter import * logger = logging.getLogger(__name__) @@ -138,7 +138,6 @@ def _process_sdpa_node( dropout_p, is_causal, ) - # Create a new node with torch.nn.functional.scaled_dot_product_attention with gm.graph.inserting_after(node): new_node = gm.graph.call_function( diff --git a/tools/llm/torchtrt_ext/trt_sdpa_converter.py b/tools/llm/torchtrt_ext/trt_sdpa_converter.py new file mode 100644 index 0000000000..b0cf0a7042 --- /dev/null +++ b/tools/llm/torchtrt_ext/trt_sdpa_converter.py @@ -0,0 +1,157 @@ +import logging +import math +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import tensorrt as trt +import torch +import torch_tensorrt +from torch.fx.node import Target +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import 
ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, + cast_trt_tensor, + get_trt_tensor, +) +from torch_tensorrt.dynamo.types import TRTTensor + +logger = logging.getLogger(__name__) + + +def tril( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + row: TRTTensor, + col: TRTTensor, + sliding_window_size: Optional[int] = None, +) -> TRTTensor: + """ + Create a lower triangular mask tensor for attention mechanisms. + + This function generates a lower triangular mask that can be used in attention + operations to enforce causal attention (each position can only attend to itself + and previous positions). It optionally supports sliding window attention by + limiting the attention span to a specified window size. + + The function creates the mask by: + 1. Generating row and column index tensors + 2. Computing the difference between row and column indices + 3. Creating a mask where row >= col (lower triangular) + 4. Optionally applying sliding window constraints + + Args: + ctx: TensorRT conversion context for managing the conversion process + target: Target operation identifier (usually the operation being converted) + source_ir: Source IR type (e.g., ATEN, TRT) - can be None + name: Base name for generated TensorRT operations (will be extended with suffixes) + row: Tensor representing the number of rows (sequence length dimension) + col: Tensor representing the number of columns (sequence length dimension) + sliding_window_size: Optional sliding window size for attention span limitation. + If None, creates a full lower triangular mask. + If specified, creates a sliding window mask where each position + can only attend to positions within the window. + + Returns: + TRTTensor: A boolean mask tensor with shape [batch, heads, seq_len, seq_len] + where True values indicate allowed attention positions. + + Example: + # Create a full lower triangular mask for causal attention + mask = tril(ctx, target, source_ir, "causal_mask", seq_len, seq_len) + + # Create a sliding window mask with window size 3 + mask = tril(ctx, target, source_ir, "sliding_mask", seq_len, seq_len, 3) + + Mask Examples: + Without sliding window (sliding_window_size=None): + For seq_len=5, returns: + [[ True, False, False, False, False], + [ True, True, False, False, False], + [ True, True, True, False, False], + [ True, True, True, True, False], + [ True, True, True, True, True]] + + With sliding window (sliding_window_size=3): + For seq_len=5, returns: + [[ True, False, False, False, False], + [ True, True, False, False, False], + [ True, True, True, False, False], + [False, True, True, True, False], + [False, False, True, True, True]] + + Note: + This function is specifically designed for attention mechanisms in transformer + models and is used internally by the scaled_dot_product_attention converter. + The sliding window functionality is particularly useful for models like Gemma3 + that use sliding window attention to reduce computational complexity. 
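+
+    Reference:
+        For intuition only, the same masks can be written eagerly in PyTorch,
+        assuming plain integer lengths q_len/k_len and an optional window in
+        place of the TensorRT shape tensors used here:
+
+            idx_q = torch.arange(q_len)
+            idx_k = torch.arange(k_len)
+            diff = idx_q.unsqueeze(-1) - idx_k.unsqueeze(0)  # row - col
+            causal = diff >= 0                               # lower triangular
+            mask = causal & (diff < window) if window else causal
+
+        The implementation below builds the same row/column difference with
+        TensorRT arange/sub/ge ops so that it remains valid for dynamic shapes.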
+ """ + row_arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 + ) + col_arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 + ) + row_arange_tensor = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_row", row_arange_tensor, -1 + ) + col_arange_tensor = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_col", col_arange_tensor, 0 + ) + # sub will return the following mask tensor: + # [[0, -1, -2, -3], + # [1, 0, -1, -2], + # [2, 1, 0, -1], + # [3, 2, 1, 0]] + mask = impl.elementwise.sub( + ctx, target, source_ir, name + "_sub", row_arange_tensor, col_arange_tensor + ) + ge_0_mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge_0", mask, 0.0) + if sliding_window_size is None: + # return the following lower triangular mask includes the main diagonal: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ■ ■ ■ ■ ⬚ [ True, True, True, True, False], + # 4 ■ ■ ■ ■ ■ [ True, True, True, True, True]]]]) + return ge_0_mask + + lt_window_mask = impl.elementwise.lt( + ctx, target, source_ir, name + "_lt_window_size", mask, sliding_window_size + ) + mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", ge_0_mask, lt_window_mask + ) + # return the following mask if sliding_window_size is 3: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ⬚ ■ ■ ■ ⬚ [False, True, True, True, False], + # 4 ⬚ ⬚ ■ ■ ■ [False, False, True, True,True]]]]) + return mask + + +@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter( + torch.nn.functional.scaled_dot_product_attention, + enabled=True, + supports_dynamic_shapes=True, +) +def scaled_dot_product_attention( + ctx: torch_tensorrt.dynamo.conversion.ConversionContext, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: str, +) -> TRTTensor: + + # always create our own attn_mask + query, key, value, _, dropout_p, is_causal = args + breakpoint() + attention_layer = ctx.net.add_attention( + query, key, value, trt.AttentionNormalizationOp.SOFTMAX, False + ) + breakpoint() + return attention_layer.get_output(0) From 4f92b93d85441c28546e0f43de6a854eff846f1d Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 1 Oct 2025 13:48:20 -0700 Subject: [PATCH 2/9] chore: updates Signed-off-by: Dheeraj Peri --- tools/llm/test_trt_sdpa.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/llm/test_trt_sdpa.py b/tools/llm/test_trt_sdpa.py index bf9e1165a5..702cd5bac3 100644 --- a/tools/llm/test_trt_sdpa.py +++ b/tools/llm/test_trt_sdpa.py @@ -25,9 +25,10 @@ def forward(self, q, k, v): trt_gm = torch_tensorrt.dynamo.compile( ep, inputs=(q, k, v), - enabled_precisions={torch.float16}, + enabled_precisions={torch.float32}, min_block_size=1, disable_tf32=True, + use_explicit_typing=True, ) trt_outputs = trt_gm(q, k, v) From 7a5f592d8ec8096674b10eee9e49278f563ac185 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 1 Oct 2025 14:18:55 -0700 Subject: [PATCH 3/9] chore: updates Signed-off-by: Dheeraj Peri --- tools/llm/torchtrt_ext/trt_sdpa_converter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/llm/torchtrt_ext/trt_sdpa_converter.py 
b/tools/llm/torchtrt_ext/trt_sdpa_converter.py index b0cf0a7042..ec3264ce8d 100644 --- a/tools/llm/torchtrt_ext/trt_sdpa_converter.py +++ b/tools/llm/torchtrt_ext/trt_sdpa_converter.py @@ -149,9 +149,9 @@ def scaled_dot_product_attention( # always create our own attn_mask query, key, value, _, dropout_p, is_causal = args - breakpoint() + attention_layer = ctx.net.add_attention( query, key, value, trt.AttentionNormalizationOp.SOFTMAX, False ) - breakpoint() + return attention_layer.get_output(0) From 8c5d6599ebce41e3bf2e255069b10079d74c6056 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 28 Oct 2025 23:18:46 -0700 Subject: [PATCH 4/9] chore: updates Signed-off-by: Dheeraj Peri --- tools/llm/test_trt_sdpa.py | 16 ++++++++++++--- tools/llm/torchtrt_ext/trt_sdpa_converter.py | 21 +++++++++++++++++++- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/tools/llm/test_trt_sdpa.py b/tools/llm/test_trt_sdpa.py index 702cd5bac3..f513cb2087 100644 --- a/tools/llm/test_trt_sdpa.py +++ b/tools/llm/test_trt_sdpa.py @@ -9,7 +9,7 @@ def __init__(self): def forward(self, q, k, v): return torch._C._nn.scaled_dot_product_attention( - q, k, v, dropout_p=0.0, is_causal=True + q, k, v, dropout_p=0.0, is_causal=True, scale=1.0 ) @@ -18,8 +18,12 @@ def forward(self, q, k, v): k = torch.randn(1, 32, 6, 64).cuda().to(torch.float16) v = torch.randn(1, 32, 6, 64).cuda().to(torch.float16) pyt_outputs = model(q, k, v) + register_sdpa.enable_sdpa_converter("default", None) -ep = torch.export.export(model, (q, k, v), strict=False) +seq_len_query = torch.export.Dim("seq_len_query", min=2, max=128) +seq_len_key = torch.export.Dim("seq_len_key", min=2, max=128) +dynamic_shapes = {"q": {2: seq_len_key}, "k": {2: seq_len_key}, "v": {2: seq_len_key}} +ep = torch.export.export(model, (q, k, v), dynamic_shapes=dynamic_shapes, strict=False) with torch_tensorrt.dynamo.Debugger(): trt_gm = torch_tensorrt.dynamo.compile( @@ -32,5 +36,11 @@ def forward(self, q, k, v): ) trt_outputs = trt_gm(q, k, v) - print("Diff between pyt and trt: ", torch.mean(torch.abs(pyt_outputs - trt_outputs))) +# breakpoint() +# q = torch.randn(1, 32, 1, 64).cuda().to(torch.float16) +# k = torch.randn(1, 32, 10, 64).cuda().to(torch.float16) +# v = torch.randn(1, 32, 10, 64).cuda().to(torch.float16) +# pyt_outputs = model(q, k, v) +# trt_outputs = trt_gm(q, k, v) +# print("Diff between pyt and trt: ", torch.mean(torch.abs(pyt_outputs - trt_outputs))) diff --git a/tools/llm/torchtrt_ext/trt_sdpa_converter.py b/tools/llm/torchtrt_ext/trt_sdpa_converter.py index ec3264ce8d..711bc0442d 100644 --- a/tools/llm/torchtrt_ext/trt_sdpa_converter.py +++ b/tools/llm/torchtrt_ext/trt_sdpa_converter.py @@ -14,6 +14,7 @@ SourceIR, cast_trt_tensor, get_trt_tensor, + prepend_ones, ) from torch_tensorrt.dynamo.types import TRTTensor @@ -146,12 +147,30 @@ def scaled_dot_product_attention( kwargs: Dict[str, Any], name: str, ) -> TRTTensor: + source_ir = SourceIR.ATEN # always create our own attn_mask - query, key, value, _, dropout_p, is_causal = args + query, key, value, mask, dropout_p, is_causal = args + breakpoint() + # L, S = query.shape[-2], key.shape[-2] + query_len = impl.shape.shape(ctx, target, source_ir, name + "_query_len", query, -2) + key_len = impl.shape.shape(ctx, target, source_ir, name + "_key_len", query, -2) + mask_tensor = tril( + ctx, + target, + source_ir, + name + "_tril", + query_len, + key_len, + ) + + diff = len(query.shape) - len(mask_tensor.shape) + mask_tensor = prepend_ones(ctx, mask_tensor, name + "_prepend_ones", diff) 
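+    # tril() above returns a rank-2 [S_q, S_k] mask; prepend_ones pads it with
+    # leading size-1 dims so its rank matches query's [B, H, S_q, E] layout and
+    # the boolean mask can broadcast over batch and heads when it is attached
+    # to the attention layer below.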
attention_layer = ctx.net.add_attention( query, key, value, trt.AttentionNormalizationOp.SOFTMAX, False ) + if is_causal: + attention_layer.mask = mask_tensor return attention_layer.get_output(0) From 58c3da51793f3bc48c53ceec70c783eb67d06db5 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 29 Oct 2025 14:17:39 -0700 Subject: [PATCH 5/9] chore: updates Signed-off-by: Dheeraj Peri --- tools/llm/run_llm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 1531c30622..a50babcca5 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -54,6 +54,7 @@ def get_model(args): args.model, use_cache=False, attn_implementation="sdpa", + num_hidden_layers=1, ) .eval() .cuda() @@ -108,7 +109,7 @@ def compile_torchtrt(model, input_ids, args): else: enabled_precisions = {torch.float32} - with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + with torch_tensorrt.dynamo.Debugger() if args.debug else nullcontext(): trt_model = torch_tensorrt.dynamo.compile( ep, inputs=[input_ids, position_ids], From 5ebd94f1ec07468ea32c637f2e8fdb10991621a1 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 14 Nov 2025 10:51:14 -0800 Subject: [PATCH 6/9] chore: updates Signed-off-by: Dheeraj Peri --- tests/py/dynamo/llm/test_llm_models.py | 2 +- tools/llm/torchtrt_ext/trt_sdpa_converter.py | 16 ++++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/py/dynamo/llm/test_llm_models.py b/tests/py/dynamo/llm/test_llm_models.py index 73811572f9..f47bc9f02a 100644 --- a/tests/py/dynamo/llm/test_llm_models.py +++ b/tests/py/dynamo/llm/test_llm_models.py @@ -14,7 +14,7 @@ @pytest.mark.unit -@pytest.mark.parametrize("precision", ["FP16", "BF16", "FP32"]) +@pytest.mark.parametrize("precision", ["FP16"]) # "BF16", "FP32" def test_llm_decoder_layer(precision): if torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx and precision == "BF16": pytest.skip("TensorRT-RTX does not support bfloat16, skipping test") diff --git a/tools/llm/torchtrt_ext/trt_sdpa_converter.py b/tools/llm/torchtrt_ext/trt_sdpa_converter.py index 711bc0442d..cd7c9e8b4f 100644 --- a/tools/llm/torchtrt_ext/trt_sdpa_converter.py +++ b/tools/llm/torchtrt_ext/trt_sdpa_converter.py @@ -151,7 +151,16 @@ def scaled_dot_product_attention( # always create our own attn_mask query, key, value, mask, dropout_p, is_causal = args - breakpoint() + + num_attention_heads = -1 # key.shape[1] + num_attention_heads_tensor = ( + -1 + ) # get_trt_tensor(ctx, num_attention_heads, name + "_num_attention_heads") + + # Reshape key and value tensors to have -1 in the attn_heads dimension due to TRT MHA API restriction. 
+ # key = impl.shuffle.reshape(ctx, target, source_ir, name + "_key_reshape", input=key, shape=[key.shape[0], num_attention_heads_tensor, key.shape[2], key.shape[3]]) + # value = impl.shuffle.reshape(ctx, target, source_ir, name + "_value_reshape", input=value, shape=[value.shape[0], num_attention_heads_tensor, value.shape[2], value.shape[3]]) + # breakpoint() # L, S = query.shape[-2], key.shape[-2] query_len = impl.shape.shape(ctx, target, source_ir, name + "_query_len", query, -2) key_len = impl.shape.shape(ctx, target, source_ir, name + "_key_len", query, -2) @@ -170,7 +179,10 @@ def scaled_dot_product_attention( attention_layer = ctx.net.add_attention( query, key, value, trt.AttentionNormalizationOp.SOFTMAX, False ) + if is_causal: attention_layer.mask = mask_tensor - return attention_layer.get_output(0) + attention_output = attention_layer.get_output(0) + + return attention_output From d7441318c732326936fb7089ef0311417101d1d7 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 25 Nov 2025 12:35:25 -0800 Subject: [PATCH 7/9] chore: updates Signed-off-by: Dheeraj Peri --- tools/llm/run_llm.py | 1 - tools/llm/torchtrt_ext/trt_sdpa_converter.py | 24 ++++++++++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index a50babcca5..69a43ff07e 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -54,7 +54,6 @@ def get_model(args): args.model, use_cache=False, attn_implementation="sdpa", - num_hidden_layers=1, ) .eval() .cuda() diff --git a/tools/llm/torchtrt_ext/trt_sdpa_converter.py b/tools/llm/torchtrt_ext/trt_sdpa_converter.py index cd7c9e8b4f..c5c7f0d566 100644 --- a/tools/llm/torchtrt_ext/trt_sdpa_converter.py +++ b/tools/llm/torchtrt_ext/trt_sdpa_converter.py @@ -152,18 +152,20 @@ def scaled_dot_product_attention( # always create our own attn_mask query, key, value, mask, dropout_p, is_causal = args - num_attention_heads = -1 # key.shape[1] - num_attention_heads_tensor = ( - -1 - ) # get_trt_tensor(ctx, num_attention_heads, name + "_num_attention_heads") - - # Reshape key and value tensors to have -1 in the attn_heads dimension due to TRT MHA API restriction. - # key = impl.shuffle.reshape(ctx, target, source_ir, name + "_key_reshape", input=key, shape=[key.shape[0], num_attention_heads_tensor, key.shape[2], key.shape[3]]) - # value = impl.shuffle.reshape(ctx, target, source_ir, name + "_value_reshape", input=value, shape=[value.shape[0], num_attention_heads_tensor, value.shape[2], value.shape[3]]) - # breakpoint() + # The exported graph of LLM models have -1 in the attention heads dimension for the query tensor. This value is static for key and value tensors though. + # TODO: We assume that the attention heads dimension is the same for key and value and query tensors. We can implement a lowering pass + # that reads number of attention heads from model config similar to gemma3. For now, we directly use the key.shape[1] as the attention heads dimension. 
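+    # Note: the reshape below only swaps query's symbolic -1 head dimension for
+    # key's static head count; for grouped- or multi-query attention (fewer KV
+    # heads than query heads) that assumption breaks, and the config-driven
+    # lowering pass mentioned above would be required instead.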
+ query = impl.shuffle.reshape( + ctx, + target, + source_ir, + name + "_query_reshape", + input=query, + shape=[query.shape[0], key.shape[1], query.shape[2], query.shape[3]], + ) # L, S = query.shape[-2], key.shape[-2] query_len = impl.shape.shape(ctx, target, source_ir, name + "_query_len", query, -2) - key_len = impl.shape.shape(ctx, target, source_ir, name + "_key_len", query, -2) + key_len = impl.shape.shape(ctx, target, source_ir, name + "_key_len", key, -2) mask_tensor = tril( ctx, target, @@ -180,6 +182,8 @@ def scaled_dot_product_attention( query, key, value, trt.AttentionNormalizationOp.SOFTMAX, False ) + assert attention_layer is not None, "attention layer is None" + if is_causal: attention_layer.mask = mask_tensor From 13faeb664d4f7e6e147c81c1954a0de7b3064c18 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 25 Nov 2025 12:40:18 -0800 Subject: [PATCH 8/9] chore: updates Signed-off-by: Dheeraj Peri --- tests/py/dynamo/llm/test_llm_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/py/dynamo/llm/test_llm_models.py b/tests/py/dynamo/llm/test_llm_models.py index f47bc9f02a..73811572f9 100644 --- a/tests/py/dynamo/llm/test_llm_models.py +++ b/tests/py/dynamo/llm/test_llm_models.py @@ -14,7 +14,7 @@ @pytest.mark.unit -@pytest.mark.parametrize("precision", ["FP16"]) # "BF16", "FP32" +@pytest.mark.parametrize("precision", ["FP16", "BF16", "FP32"]) def test_llm_decoder_layer(precision): if torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx and precision == "BF16": pytest.skip("TensorRT-RTX does not support bfloat16, skipping test") From 239397c8ab41d80fbe2389539bc07b4a334c5918 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 25 Nov 2025 12:49:35 -0800 Subject: [PATCH 9/9] chore: updates Signed-off-by: Dheeraj Peri --- tools/llm/run_llm.py | 1 + tools/llm/torchtrt_ext/trt_sdpa_converter.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 69a43ff07e..81b92f2013 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -107,6 +107,7 @@ def compile_torchtrt(model, input_ids, args): use_fp32_acc = False else: enabled_precisions = {torch.float32} + use_explicit_typing = True with torch_tensorrt.dynamo.Debugger() if args.debug else nullcontext(): trt_model = torch_tensorrt.dynamo.compile( diff --git a/tools/llm/torchtrt_ext/trt_sdpa_converter.py b/tools/llm/torchtrt_ext/trt_sdpa_converter.py index c5c7f0d566..36d37d8316 100644 --- a/tools/llm/torchtrt_ext/trt_sdpa_converter.py +++ b/tools/llm/torchtrt_ext/trt_sdpa_converter.py @@ -181,6 +181,7 @@ def scaled_dot_product_attention( attention_layer = ctx.net.add_attention( query, key, value, trt.AttentionNormalizationOp.SOFTMAX, False ) + attention_layer.decomposable = True assert attention_layer is not None, "attention layer is None"
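
For reference, a minimal sketch of the dynamic-shape parity check that the
commented-out block at the end of test_trt_sdpa.py gestures at. It assumes
`model` and `trt_gm` from that script are in scope, stays inside the exported
[2, 128] sequence range, and keeps q/k/v the same length because the export
binds all three to the same `seq_len_query`/`seq_len_key` dims:

    # Re-run both paths at a different sequence length inside the exported range.
    q = torch.randn(1, 32, 16, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 32, 16, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 32, 16, 64, device="cuda", dtype=torch.float16)
    pyt_outputs = model(q, k, v)
    trt_outputs = trt_gm(q, k, v)
    # Loose FP16 tolerances; the test compiles with disable_tf32=True.
    torch.testing.assert_close(trt_outputs, pyt_outputs, rtol=1e-2, atol=1e-2)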