4 changes: 1 addition & 3 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -1339,15 +1339,13 @@ def convert_exported_program_to_serialized_trt_engine(
)

flattened_input_list = get_flat_args_with_check(
exported_program, list(trt_arg_inputs), trt_kwarg_inputs # type: ignore
exported_program, list(trt_arg_inputs), trt_kwarg_inputs
)[0]

try:
interpreter_result = interpret_module_to_result(
gm,
inputs=flattened_input_list,
arg_inputs=list(trt_arg_inputs),
kwarg_inputs=trt_kwarg_inputs,
settings=settings,
engine_cache=engine_cache,
)
11 changes: 6 additions & 5 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -52,7 +52,7 @@
logger = logging.getLogger(__name__)


@needs_refit
@needs_refit # type: ignore[misc]
def construct_refit_mapping(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
@@ -85,7 +85,7 @@ def construct_refit_mapping(
return weight_refit_map


@needs_refit
@needs_refit # type: ignore[misc]
def construct_refit_mapping_from_weight_name_map(
weight_name_map: dict[Any, Any],
state_dict: dict[Any, Any],
@@ -128,7 +128,7 @@ def construct_refit_mapping_from_weight_name_map(
return engine_weight_map


@needs_refit
@needs_refit # type: ignore[misc]
def _refit_single_trt_engine_with_gm(
new_gm: torch.fx.GraphModule,
old_engine: trt.ICudaEngine,
@@ -211,7 +211,7 @@ def _refit_single_trt_engine_with_gm(
raise AssertionError("Refitting failed.")


@needs_refit
@needs_refit # type: ignore[misc]
def refit_module_weights(
compiled_module: torch.fx.GraphModule | ExportedProgram,
new_weight_module: ExportedProgram,
@@ -484,9 +484,10 @@ def refit_module_weights(
weight_name_map=None,
)

# clear EXCLUDE_WEIGHTS flag
# clear EXCLUDE_WEIGHTS flag and set INCLUDE_REFIT flag to make the engine refittable
serialization_config = engine.create_serialization_config()
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
serialized_engine = engine.serialize_with_config(serialization_config)

if isinstance(compiled_submodule, PythonTorchTensorRTModule):
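For readers tracing the serialization change above, here is a minimal sketch of the refit-then-reserialize step, assuming a TensorRT build whose `SerializationFlag` enum exposes both `EXCLUDE_WEIGHTS` and `INCLUDE_REFIT` (the flag names are taken from the diff; the helper name `reserialize_refittable` is illustrative, not part of the PR).

```python
import tensorrt as trt


def reserialize_refittable(engine: trt.ICudaEngine) -> bytes:
    """Serialize an engine whose weights were just refitted in memory.

    Mirrors the hunk above: re-include the weights in the plan, but keep
    the refit capability, which TRT would otherwise drop when the
    EXCLUDE_WEIGHTS flag is cleared.
    """
    config = engine.create_serialization_config()
    config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
    config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
    return engine.serialize_with_config(config)
```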
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/_settings.py
@@ -167,7 +167,6 @@ def __setstate__(self, state: dict[str, Any]) -> None:
"engine_capability",
"hardware_compatible",
"refit_identical_engine_weights",
"strip_engine_weights", # TODO: @Evan to remove this after implementing caching weight-stripped engines as default?
"immutable_weights",
"enable_weight_streaming",
"tiling_optimization_level",
4 changes: 0 additions & 4 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -157,10 +157,6 @@ def _pretraced_backend(
logger.warning(
"require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
)
if settings.strip_engine_weights:
Collaborator: When would a torch.compile user try to use strip weights?

Collaborator (Author): Added the warning back. Not sure why the strip_engine_weights arg doesn't work for torch.compile().

(A usage sketch for this option follows the backends.py diff below.)

logger.error(
"strip_engine_weights arg is not supported for torch.compile()"
)
trt_compiled = compile_module(
gm,
torchtrt_inputs,
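On the question in the thread above, this is a hedged sketch of how a torch.compile user could reach this code path by forwarding `strip_engine_weights` through the backend `options` dict. The option names come from `CompilationSettings`; the toy model is illustrative, and whether the JIT workflow actually supports weight stripping is exactly what the thread questions.

```python
import torch
import torch_tensorrt  # noqa: F401  # registers the "torch_tensorrt" backend

model = torch.nn.Linear(8, 8).eval().cuda()
x = torch.randn(2, 8).cuda()

# Forwarding strip_engine_weights through torch.compile options is the
# configuration the removed check in _pretraced_backend guarded against.
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"strip_engine_weights": True, "min_block_size": 1},
)
out = compiled(x)  # compilation (and any warning/error) happens on first call
```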
95 changes: 1 addition & 94 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -31,7 +31,7 @@
from torch_tensorrt._utils import is_tensorrt_version_supported
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
@@ -594,79 +594,6 @@ def _save_weight_mapping(self) -> None:
gc.collect()
torch.cuda.empty_cache()

@needs_refit # type: ignore[misc]
def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
# query the cached TRT engine
cached_data = self.engine_cache.check(hash_val) # type: ignore[union-attr]
if cached_data is not None: # hit the cache
(
serialized_engine,
self._input_names,
self._output_names,
cached_engine_input_specs,
engine_compilation_settings,
self.weight_name_map,
self.ctx.requires_output_allocator,
) = cached_data

setting_compatiblity, incompattible_settings = settings_are_compatible(
self.compilation_settings, engine_compilation_settings
)
assert (
setting_compatiblity
), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})"

for i, e in enumerate(
[
Input.equivalent_spec(c, i)
for c, i in zip(cached_engine_input_specs, self.input_specs)
]
):
assert (
e
), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}"

_LOGGER.info(
"Found the cached engine that corresponds to this graph. It is directly loaded."
)

# refit the cached engine with the new graph module
if not self.compilation_settings.strip_engine_weights:
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)

from torch_tensorrt.dynamo._refit import (
_refit_single_trt_engine_with_gm,
)

_refit_single_trt_engine_with_gm(
new_gm=self.module,
old_engine=engine,
input_list=self.input_specs,
settings=self.compilation_settings,
weight_name_map=self.weight_name_map,
)

# TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine
# # EXCLUDE_WEIGHTS flag must be cleared
# serialization_config = engine.create_serialization_config()
# serialization_config.clear_flag(
# trt.SerializationFlag.EXCLUDE_WEIGHTS
# )
# serialized_engine = engine.serialize_with_config(
# serialization_config
# )
# # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller

return TRTInterpreterResult(
engine,
self._input_names,
self._output_names,
self.weight_name_map,
self.ctx.requires_output_allocator,
)
return None

def run(
self,
strict_type_constraints: bool = False,
@@ -682,26 +609,6 @@ def run(
Return:
TRTInterpreterResult
"""
# self.engine_cache could be None if:
# 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
# 2) both cache_built_engines and reuse_cached_engines are False
if (
self.engine_cache is not None
and not self.compilation_settings.immutable_weights
):
if (
self.compilation_settings.cache_built_engines
or self.compilation_settings.reuse_cached_engines
):
hash_val = self.engine_cache.get_hash(
self.module, self.input_specs, self.compilation_settings
)

if self.compilation_settings.reuse_cached_engines:
interpreter_result = self._pull_cached_engine(hash_val)
if interpreter_result is not None: # hit the cache
return interpreter_result # type: ignore[no-any-return]

self._construct_trt_network_def()
_LOGGER.debug(
f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
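The two deleted blocks above removed the interpreter-side cache lookup and in-place refit of a cached engine. For context, here is a hedged sketch of how that path was driven from the user-facing API: `cache_built_engines`, `reuse_cached_engines`, `engine_cache_dir`, and `immutable_weights` are existing compilation settings, while the model and cache path are illustrative.

```python
import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval().cuda()
inputs = (torch.randn(4, 16).cuda(),)
exp_program = torch.export.export(model, inputs)

# The first compile builds and caches the engine; a second identical compile
# could then (per the deleted code) load the cached engine and refit it with
# the current weights instead of rebuilding from scratch.
trt_module = torch_tensorrt.dynamo.compile(
    exp_program,
    arg_inputs=inputs,
    cache_built_engines=True,
    reuse_cached_engines=True,
    engine_cache_dir="/tmp/trt_engine_cache",  # illustrative path
    immutable_weights=False,  # caching + refit requires refittable engines
)
```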