Skip to content

Commit 32b45ae

Browse files
committed
addressing review comments
1 parent 5fd75af commit 32b45ae

File tree

6 files changed

+134
-70
lines changed

6 files changed

+134
-70
lines changed

examples/distributed_inference/tensor_parallel_initialize_dist.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def initialize_logger(
4848

4949
# console handler
5050
ch = logging.StreamHandler()
51-
ch.setLevel(console_level) # Console handler controls what's printed in console output
51+
ch.setLevel(
52+
console_level
53+
) # Console handler controls what's printed in console output
5254
ch.setFormatter(logging.Formatter(f"[Rank {rank}] %(levelname)s: %(message)s"))
5355
logger.addHandler(ch)
5456

@@ -123,21 +125,12 @@ def initialize_distributed_env(
123125
torch.cuda.set_device(device_id)
124126

125127
# Set C++ TensorRT runtime log level based on most verbose handler
126-
# this is similar to set_log_level()
128+
# Use the most verbose level to ensure all important logs are captured
127129
cpp_level = min(file_level_int, console_level_int)
128130
try:
129-
import tensorrt as trt
130-
from torch_tensorrt._features import ENABLED_FEATURES
131-
132-
if ENABLED_FEATURES.torch_tensorrt_runtime:
133-
if cpp_level == logging.DEBUG:
134-
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE))
135-
elif cpp_level == logging.INFO:
136-
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.INFO))
137-
elif cpp_level == logging.WARNING:
138-
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.WARNING))
139-
else:
140-
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.ERROR))
131+
import torch_tensorrt.logging as torchtrt_logging
132+
133+
torchtrt_logging.set_level(cpp_level)
141134
except Exception as e:
142135
logger.warning(f"Could not set C++ TensorRT log level: {e}")
143136

py/torch_tensorrt/_features.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,25 @@ def _enabled_features_str() -> str:
7878
return out_str
7979

8080

81+
# Inline helper functions for checking feature availability
82+
def has_torch_tensorrt_runtime() -> bool:
83+
"""Check if Torch-TensorRT C++ runtime is available.
84+
85+
Returns:
86+
bool: True if libtorchtrt_runtime.so or libtorchtrt.so is available
87+
"""
88+
return bool(ENABLED_FEATURES.torch_tensorrt_runtime)
89+
90+
91+
def has_torchscript_frontend() -> bool:
92+
"""Check if TorchScript frontend is available.
93+
94+
Returns:
95+
bool: True if libtorchtrt.so is available
96+
"""
97+
return bool(ENABLED_FEATURES.torchscript_frontend)
98+
99+
81100
def needs_tensorrt_rtx(f: Callable[..., Any]) -> Callable[..., Any]:
82101
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
83102
if ENABLED_FEATURES.tensorrt_rtx:
@@ -180,6 +199,7 @@ def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
180199
if ENABLED_FEATURES.trtllm_for_nccl:
181200
return f(*args, **kwargs)
182201
else:
202+
183203
def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
184204
raise NotImplementedError(
185205
"TensorRT-LLM plugin for NCCL is not available"

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch
99
import torch._dynamo as td
10+
import torch_tensorrt.logging as torchtrt_logging
1011
from torch._dynamo.backends.common import aot_autograd
1112
from torch._dynamo.utils import detect_fake_mode
1213
from torch._functorch.aot_autograd import aot_export_joint_simple
@@ -23,7 +24,6 @@
2324
from torch_tensorrt.dynamo.utils import (
2425
parse_dynamo_kwargs,
2526
prepare_inputs,
26-
set_log_level,
2727
)
2828

2929
logger = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ def torch_tensorrt_backend(
4040
and "debug" in kwargs["options"]
4141
and kwargs["options"]["debug"]
4242
) or ("debug" in kwargs and kwargs["debug"]):
43-
set_log_level(logger.parent, logging.DEBUG)
43+
torchtrt_logging.set_level(logging.DEBUG)
4444

4545
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
4646

py/torch_tensorrt/dynamo/utils.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
2929
from torch_tensorrt._Device import Device
3030
from torch_tensorrt._enums import dtype
31-
from torch_tensorrt._features import ENABLED_FEATURES
3231
from torch_tensorrt._Input import Input
3332
from torch_tensorrt._utils import is_tensorrt_version_supported
3433
from torch_tensorrt.dynamo import _defaults
@@ -270,33 +269,6 @@ def get_model_device(module: torch.fx.GraphModule) -> torch.device:
270269
return device
271270

272271

273-
def set_log_level(parent_logger: Any, level: Any) -> None:
274-
"""
275-
Sets the log level to the user provided level.
276-
This is used to set debug logging at a global level
277-
at entry points of tracing, dynamo and torch_compile compilation.
278-
And set log level for c++ torch trt logger if runtime is available.
279-
"""
280-
if parent_logger:
281-
parent_logger.setLevel(level)
282-
283-
if ENABLED_FEATURES.torch_tensorrt_runtime:
284-
if level == logging.DEBUG:
285-
log_level = trt.ILogger.Severity.VERBOSE
286-
elif level == logging.INFO:
287-
log_level = trt.ILogger.Severity.INFO
288-
elif level == logging.WARNING:
289-
log_level = trt.ILogger.Severity.WARNING
290-
elif level == logging.ERROR:
291-
log_level = trt.ILogger.Severity.ERROR
292-
elif level == logging.CRITICAL:
293-
log_level = trt.ILogger.Severity.INTERNAL_ERROR
294-
else:
295-
raise AssertionError(f"{level} is not valid log level")
296-
297-
torch.ops.tensorrt.set_logging_level(int(log_level))
298-
299-
300272
def prepare_inputs(
301273
inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
302274
disable_memory_format_check: bool = False,

0 commit comments

Comments (0)