From e53905d246ff0ed0e00ac312cf222fda1b0ccaba Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 13 Nov 2025 19:11:06 +0000 Subject: [PATCH 01/10] squashed and cleaned the commits --- py/torch_tensorrt/dynamo/_compiler.py | 15 + py/torch_tensorrt/dynamo/_defaults.py | 2 + py/torch_tensorrt/dynamo/_settings.py | 2 + .../partitioning/_resource_partitioner.py | 562 ++++++++++++++++++ .../dynamo/partitioning/fusion_patterns.py | 185 ++++++ 5 files changed, 766 insertions(+) create mode 100644 py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py create mode 100644 py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index bc345947d3..f5c435884b 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -105,6 +105,7 @@ def cross_compile_for_windows( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows @@ -179,6 +180,7 @@ def cross_compile_for_windows( tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + cpu_memory_budget (int): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. If set to -1, the compilation will use all available CPU memory. 
**kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -334,6 +336,7 @@ def cross_compile_for_windows( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, "use_distributed_mode_trace": use_distributed_mode_trace, + "cpu_memory_budget": cpu_memory_budget, } # disable the following settings is not supported for cross compilation for windows feature @@ -435,6 +438,7 @@ def compile( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -681,6 +685,7 @@ def compile( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, + "cpu_memory_budget": cpu_memory_budget, } logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB") settings = CompilationSettings(**compilation_options) @@ -854,6 +859,16 @@ def preserve_module_specs( require_full_compilation=settings.require_full_compilation, ) + from torch_tensorrt.dynamo.partitioning._resource_partitioner import ( + resource_partition, + ) + + partitioned_module = resource_partition( + gm, + partitioned_module, + cpu_memory_budget=settings.cpu_memory_budget, + ) + dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators # The global partitioner leaves non-TRT nodes as-is diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index de970ecd81..0b4c0a2b54 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -2,6 +2,7 @@ import platform import tempfile +import psutil import torch from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype @@ -57,6 +58,7 @@ L2_LIMIT_FOR_TILING = -1 USE_DISTRIBUTED_MODE_TRACE = False OFFLOAD_MODULE_TO_CPU = False +CPU_MEMORY_BUDGET = psutil.virtual_memory().available if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index d8f6809eae..52ac86012c 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -7,6 +7,7 @@ from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, CACHE_BUILT_ENGINES, + CPU_MEMORY_BUDGET, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, @@ -140,6 +141,7 @@ class CompilationSettings: l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU + cpu_memory_budget: int = CPU_MEMORY_BUDGET def __getstate__(self) -> dict[str, Any]: from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( diff --git a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py new file mode 100644 index 0000000000..967711ba02 --- /dev/null +++ b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py @@ -0,0 +1,562 @@ +"""Resource-aware graph partitioner for TensorRT compilation. 
+
+This module refines an existing capability-based partitioning (accelerated vs
+non-accelerated subgraphs) by further splitting accelerated subgraphs to meet
+host CPU memory constraints during TensorRT engine building.
+
+High-level algorithm
+--------------------
+Given an original `torch.fx.GraphModule` and a capability-partitioned
+`GraphModule` (produced earlier in the pipeline), we:
+
+1) Reconstruct subgraphs on the original graph
+   - Iterate over the capability-partitioned module to determine which original
+     nodes belong to which subgraph (accelerated or not).
+   - Preserve fusion groups discovered in each subgraph so that all nodes in a
+     fusion group remain in the same subgraph and are never split across subgraphs.
+   - Verify that the subgraphs respect topological order; this ensures the
+     subgraphs are valid.
+   - Reconstructing subgraphs from the partitioned module is easier than building
+     nested partitioned graph modules and flattening them later.
+
+2) Estimate the memory cost of each subgraph
+   - Compute a per-subgraph "size" by traversing the graph to find weights
+     (get_attr) reachable from its nodes and summing tensor bytes.
+   - Use a set to record visited nodes and avoid double counting shared
+     parameters across subgraphs.
+
+3) Derive the per-engine size budget
+   - Subtract the host memory already in use from the CPU memory budget and
+     divide by a safety multiplier (4 by default) to bound each engine build.
+
+4) Split large accelerated subgraphs
+   - While a subgraph exceeds the per-engine budget, split it into two or more subgraphs.
+   - Move nodes incrementally from the front of the original subgraph into a
+     new left subgraph, repeatedly validating/correcting topological, partitioning, and
+     dependency constraints.
+   - Ensure we never split across a fusion group; when a split would break a
+     fusion, we backtrack dependencies and move the entire fusion and related nodes
+     into the left side.
+   - Continue until the left subgraph fits the budget.
+   - Repeat the process for the right subgraph until all subgraphs fit the budget.
+
+5) Finalize
+   - After splitting, assert all fusion groups reside in a single subgraph.
+   - Tag nodes and produce a `GraphModule` where each subgraph becomes either a
+     TRT engine (accelerated) or runs in Torch (non-accelerated).
+
+Notes
+-----
+- The budget is a heuristic bound. If the total model size exceeds 40x the
+  per-engine budget, we fail early with a clear error suggesting remedies.
+"""
+
+import logging
+from typing import Dict, List, Tuple
+
+import psutil
+import torch
+from torch.fx.passes.splitter_base import Subgraph, _SplitterBase
+from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
+from torch_tensorrt.dynamo.partitioning.fusion_patterns import (
+    get_node_in_fusion_pattern,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ResourcePartitioner(_SplitterBase):  # type: ignore
+    """Refine capability-based subgraphs to meet host CPU memory constraints.
+
+    This partitioner takes:
+    - an original `torch.fx.GraphModule` (`module`)
+    - a capability-partitioned `GraphModule` (`partitioned_module`) containing
+      submodules that delineate accelerated vs non-accelerated regions
+    - a CPU memory budget in bytes (`cpu_memory_budget`)
+
+    It maps nodes from `module` into subgraphs according to `partitioned_module`
+    and then splits oversized accelerated subgraphs so that each resulting TRT
+    engine's estimated size fits within a conservative budget derived from the
+    available CPU memory or a predefined CPU budget.
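+
+    For example, with a 32 GiB budget and 8 GiB of host memory already in
+    use, the default safety multiplier of 4 yields a per-engine budget of
+    (32 - 8) / 4 = 6 GiB; any accelerated subgraph whose reachable parameter
+    bytes exceed that estimate is split further.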
+ """ + + def __init__( + self, + module: torch.fx.GraphModule, + partitioned_module: torch.fx.GraphModule, + cpu_memory_budget: int, + ): + + assert isinstance(module, torch.fx.GraphModule) + assert isinstance(partitioned_module, torch.fx.GraphModule) + + self.module = module + self.partitioned_module = partitioned_module + self.cpu_memory_budget = cpu_memory_budget + + self.deps = self.find_deps() + + self.non_acc_submodule_name = "_run_on_gpu_" + self._node_submodule_map: Dict[str, str] = {} + self._return_tuple = False + self.fusion_patterns: Dict[torch.fx.Node, List[torch.fx.Node]] = {} + + def partition_graph(self) -> torch.fx.GraphModule: + """Build the final partitioned `GraphModule` honoring memory constraints. + + Steps: + - Build subgraph assignments from the capability-partitioned module + - Split oversized accelerated subgraphs based on memory budget + - Tag nodes and construct the final split graph + + Returns: + torch.fx.GraphModule: A graph split into subgraphs based on capability partitioning and memory constraints. + """ + # Delegate nodes based on operator coverage + subgraphs = self.put_nodes_into_subgraphs() + + subgraphs = self.break_subgraphs( + subgraphs, subgraph_size_budget=self.calculate_size_budget() + ) + + # Set the number of TRT engines to be generated + self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc]) + + # Tag the accelerated nodes and split the graph accordingly + self.tag(subgraphs) + + gm = self.split() + + return gm + + def put_nodes_into_subgraphs(self) -> list[Subgraph]: + """Map original graph nodes into capability-based subgraphs. + + - Iterates `partitioned_module` submodules to establish which node names + belong to which subgraph (accelerated or not). + - Builds a fusion pattern map for each subgraph so that known fusion groups remain intact. + Note that since fusion map is built for each subgraph, the capability partitioning can still break the fusion groups. + - Put the nodes into the subgraphs based on the capability partitioning. + - Verifies the resulting list of subgraphs is topologically ordered. + + Returns: + list[Subgraph]: Ordered subgraphs consisting of nodes in `module` based on capability partitioning. + """ + subgraphs_map = {} + subgraphs = [] + name_to_node_map = ( + {} + ) # We use this map to help map the nodes in partitioned module to the nodes in original module. + for name, _ in self.partitioned_module.named_children(): + # We first iterate over the partitioned module to find the subgraphs based on capability partitioning. + submodule = getattr(self.partitioned_module, name) + if not isinstance(submodule, torch.fx.graph_module.GraphModule): + continue + subgraph = Subgraph(is_acc="acc" in name, nodes=[]) + subgraphs.append(subgraph) + self.fusion_patterns.update(get_node_in_fusion_pattern(submodule.graph)) + + for node in submodule.graph.nodes: + # Erase the tag from previous partitioner if it exists + if hasattr(node, "tag"): + delattr(node, "tag") + + if node.op in CALLABLE_NODE_OPS: + # Store which subgraph the node should be put in + subgraphs_map[node.name] = subgraph + + # We then iterate over the original module to put the nodes into the subgraphs. 
+ for node in self.module.graph.nodes: + if hasattr(node, "tag"): + # Erase the tag from previous partitioner + delattr(node, "tag") + if node.op in CALLABLE_NODE_OPS: + name_to_node_map[node.name] = node + subgraphs_map[node.name].nodes.append(node) + + assert self.check_topological_order( + subgraphs + ), "The subgraphs are not topologically ordered" + self.fusion_patterns = { + name_to_node_map[node.name]: [ + name_to_node_map[n.name] for n in fusion_nodes + ] + for node, fusion_nodes in self.fusion_patterns.items() + } + + return subgraphs + + def check_topological_order(self, subgraphs: List[Subgraph]) -> bool: + """Return True if subgraphs are in a valid topological order. + + Each node's dependencies must appear in earlier subgraphs or earlier + positions within the same subgraph. Subgraphs should be topologically ordered to ensure the validity of the subgraphs. + """ + visited_nodes: set[torch.fx.Node] = set() + for subgraph in subgraphs: + for node in subgraph.nodes: + if self.deps[node] > visited_nodes: + return False + visited_nodes.add(node) + return True + + def calculate_size_budget( + self, engine_compilation_memory_usage_multiplier: int = 4 + ) -> int: + """Compute the per-engine size budget in bytes. + + Uses explicit `cpu_memory_budget` minus used RSS + divided by a safety multiplier. + + Args: + engine_compilation_memory_usage_multiplier: Safety divisor applied to + available memory to approximate a per-engine budget. By default we assume TensorRT + compilation requires up to 4x the model's size. + + Returns: + int: Budget in bytes for a single accelerated subgraph. + """ + + used_rss: int = psutil.virtual_memory().used + available_rss = self.cpu_memory_budget - used_rss + return available_rss // engine_compilation_memory_usage_multiplier + + def break_subgraphs( + self, subgraphs: List[Subgraph], subgraph_size_budget: int + ) -> List[Subgraph]: + """Split oversized accelerated subgraphs until they fit within budget. + + - Compute sizes for each subgraph (in bytes of parameters reachable from + that subgraph). + - If the sum of all sizes is catastrophically larger than budget + (threshold 40x), raise a ValueError with guidance. + - For any subgraph whose size exceeds `subgraph_size_budget`, iteratively + split it using `break_subgraph_by_size` and append resulting segments. + - Validate that fusion groups remain intact post splitting. + + Args: + subgraphs: Ordered list of subgraphs from capability partitioning. + subgraph_size_budget: Target maximum size per accelerated subgraph. + + Returns: + List[Subgraph]: New list of subgraphs after resource-aware splitting. + """ + + new_subgraphs = [] + # We throw an error if the remaining memory is almost empty compared to the model size. + # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation. + sizes = self.size_of_subgraphs(subgraphs) + if sum(sizes) > subgraph_size_budget * 40: + raise ValueError( + f"CPU memory budget or available memory is too small to compile the model. CPU memory budget: {self.cpu_memory_budget // (1024 * 1024) if self.cpu_memory_budget != -1 else "All available memory"} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. " + + "Consider setting cpu_memory_budget to a larger value or disable offload_module_to_cpu to save more CPU memory." 
+ ) + for subgraph, size in zip(subgraphs, sizes): + + while size > subgraph_size_budget: + broken_subgraphs, size_0, size_1 = self.break_subgraph_by_size( + subgraph, subgraph_size_budget + ) + size = size_1 + new_subgraphs.append(broken_subgraphs[0]) + subgraph = broken_subgraphs[1] + new_subgraphs.append(subgraph) + + self._varify_all_fusion_nodes_in_same_subgraph(new_subgraphs) + + return new_subgraphs + + def _varify_all_fusion_nodes_in_same_subgraph( + self, subgraphs: List[Subgraph] + ) -> None: + """Assert that every fusion group is contained in exactly one subgraph.""" + node_to_subgraph = {} + for i, s in enumerate(subgraphs): + for n in s.nodes: + node_to_subgraph[n] = i + + fusion_nodes_map_list = [ + len({node_to_subgraph[n] for n in ns}) == 1 + for ns in self.fusion_patterns.values() + ] # fusion nodes must be in the same subgraph + + assert all( + fusion_nodes_map_list + ), "All fusion nodes must be in the same subgraph" + logger.info("All fusion nodes are in the same subgraph.") + + def break_subgraph_by_size( + self, subgraph: Subgraph, size_to_break: int + ) -> Tuple[List[Subgraph], int, int]: + """Split a single oversized subgraph into two valid subgraphs. + + Moves nodes from the head of `subgraph` into a new left segment until + the left segment's estimated size exceeds `size_to_break`. During the + process we: + - Repeatedly validate/correct topological placement + - Detect and avoid splitting fusion groups by moving all fused nodes + (and their producer chain) into the left segment + + Returns: + (segments, size_left, size_right): + segments[0] is the new left subgraph, segments[1] is the residual + right subgraph. Sizes are estimated parameter bytes of each. + """ + all_nodes = subgraph.nodes + device_ordinal = subgraph.device_ordinal + new_subgraphs = [ + Subgraph( + is_acc=True, + nodes=[], + device_ordinal=device_ordinal, + ), + Subgraph( + is_acc=True, + nodes=all_nodes, + device_ordinal=device_ordinal, + ), + ] + + # We break the subgraph until the left subgraph fits the budget. + while True: + # Set a step size proportional to the size of the subgraph to make the algorithm more efficient. + # This reduce the time complexity from O(N**2) to O(N). The max number of steps is 50. + # Note: we want the first step size to be 1. + step_size = ( + 1 if not new_subgraphs[0].nodes else max(1, len(all_nodes) // 50) + ) + new_subgraphs = self.step_and_validate(new_subgraphs, step_size) + size_0, size_1 = self.size_of_subgraphs(new_subgraphs) + if size_0 > size_to_break: + break + + if len(new_subgraphs[1].nodes) == 0: + new_subgraphs.pop(1) + return new_subgraphs, size_0, size_1 + + def step_and_validate( + self, new_subgraphs: List[Subgraph], step_size: int = 1 + ) -> List[Subgraph]: + """Advance the split by `step_size` nodes, then add more nodes to the left subgraph if rules are broken. + There are two rules to check: + 1. The subgraphs should be ordered in a way that is safely to partition. + This is checked by validate_and_correct_subgraphs. Check that function for more details. + 2. The subgraphs should not break any fusion groups. + - Move `step_size` nodes from the right to the left subgraph. + - Run validation/correction to ensure a legal partitioning placement. + - Get all leaf nodes in the left subgraph and check whether any of them are in a fusion group. + - If the move splits a fusion group, migrate the entire fusion into the left subgraph. + + Returns: + List[Subgraph]: Updated pair of subgraphs after stabilization. 
+ """ + + for _ in range(step_size): + new_subgraphs[0].nodes.append(new_subgraphs[1].nodes.pop(0)) + + while True: + new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs) + nodes_in_first_subgraph = set(new_subgraphs[0].nodes) + nodes_in_second_subgraph = set(new_subgraphs[1].nodes) + leaf_node = self.get_leaf_node(nodes_in_first_subgraph) + broken_fusion = self.step_if_break_fusion( + new_subgraphs, + leaf_node, + nodes_in_first_subgraph, + nodes_in_second_subgraph, + ) + if not broken_fusion or len(new_subgraphs[1].nodes) == 0: + break + + return new_subgraphs + + def step_if_break_fusion( + self, + subgraphs: List[Subgraph], + leaf_nodes: set[torch.fx.Node], + nodes_in_first_subgraph: set[torch.fx.Node], + nodes_in_second_subgraph: set[torch.fx.Node], + ) -> bool: + """Detect a fusion split and migrate fused nodes to the left subgraph. + + Given the current split boundary (captured by `leaf_nodes` of the left + subgraph), check all recorded fusion groups. If any fused node remains + on the right while its peer is on the left, pull the node and all of its + producer chain into the left subgraph to keep fusions intact. + + Returns: + bool: True if any fusion was migrated (i.e., a split would have + broken a fusion), otherwise False. + """ + + def add_nodes(node: torch.fx.Node) -> None: + """ + This function adds a node and all its previous nodes to the first subgraph and removes it from the second subgraph in post order. + """ + if ( + node.op in CALLABLE_NODE_OPS + and node not in nodes_in_first_subgraph + and node in nodes_in_second_subgraph + ): + # Exclude all nodes already in the first subgraph + nodes_in_first_subgraph.add(node) + nodes_in_second_subgraph.remove(node) + for input_node in node._input_nodes: + add_nodes(input_node) + subgraphs[0].nodes.append(node) + subgraphs[1].nodes.remove(node) + + fusion_broken = False + for leaf in leaf_nodes: + for node in self.fusion_patterns.get(leaf, []): + if ( + node not in nodes_in_first_subgraph + and node in nodes_in_second_subgraph + ): + fusion_broken = True + add_nodes(node) + + return fusion_broken + + def get_leaf_node( + self, nodes_in_first_subgraph: set[torch.fx.Node] + ) -> set[torch.fx.Node]: + """Return nodes in the left subgraph that feed any node on the right. + + A node is considered a leaf if at least one of its users is not in the + left subgraph. + """ + leaf_node = set() + + for node in nodes_in_first_subgraph: + for user in node.users: + if user not in nodes_in_first_subgraph: + leaf_node.add(node) + break + return leaf_node + + def size_of_subgraphs(self, subgraphs: List[Subgraph]) -> List[int]: + """Estimate parameter footprint (bytes) for each subgraph. + + Traverses each subgraph's nodes and their producer chains to find + parameters referenced via `get_attr`, summing tensor bytes. Shared + parameters are counted only once globally. + + Returns: + List[int]: Size per subgraph in bytes. 
+ """ + state_dict = self.module.state_dict(keep_vars=True) + sizes = [] + weight_visited_nodes = set() + for subgraph in subgraphs: + nodes_in_subgraph = set(subgraph.nodes) + stack = subgraph.nodes.copy() + size = 0 + while stack: + node = stack.pop() + if node in weight_visited_nodes: + continue + weight_visited_nodes.add(node) + if node.op == "get_attr": + weight = state_dict.get(node.target, None) + if weight is None: + logger.warning(f"Weight {node.target} not found in state_dict") + continue + size += weight.numel() * weight.element_size() + continue + if node not in nodes_in_subgraph: + # Trace to other subgraphs + continue + for input_node in node._input_nodes: + if input_node not in weight_visited_nodes: + stack.append(input_node) + sizes.append(size) + return sizes + + def validate_and_correct_subgraphs( + self, subgraphs: List[Subgraph] + ) -> List[Subgraph]: + """This is very important for the correctness of the partitioning. Torch gives undefined behavior if the subgraphs are not ordered correctly. + + The principle is: nodes that have all dependencies resolved in previous subgraphs should also be moved to the previous subgraph. + For example, given a breakpoint node n resulting in two subgraphs S1 [..., n] and S2 [n+1, ...], all nodes in S2 that is not directly or indirectly depend on n should be moved to S1. + + We use a map to record the index of the subgraph that a node's users should belong to. If the node N is in subgraph S1 and is not the breakpoint node (subgraph.nodes[-1]), + then the users that only depend on N should also be moved to S1. However, N is a breakpoint node, then the users that only depend on N should also be moved to S2. + + With the map, we can determine with subgraph a later node should be moved to according to all its inputs. We take max indices of all inputs nodes to determine the subgraph index. + + Returns: + List[Subgraph]: Corrected subgraphs. + """ + # a map from a node to the index of the subgraph it's user should belong to + visited_nodes = {} + + for i, subgraph in enumerate(subgraphs): + if i == 0: + for node in subgraph.nodes: + visited_nodes[node] = i + # breakpoint node's users should belong to the next subgraph + visited_nodes[subgraph.nodes[-1]] = i + 1 + continue + + elif not subgraph.is_acc: + # non-accelerated subgraphs should be put in the next subgraph + for node in subgraph.nodes: + visited_nodes[subgraph.nodes[-1]] = i + 1 + continue + + else: + to_remove_nodes = [] + for j, node in enumerate(subgraph.nodes): + if j == len(subgraph.nodes) - 1: + # breakpoint node's users should belong to the next subgraph + visited_nodes[node] = i + 1 + continue + subgraph_idx = 0 + for dep in self.deps[node]: + if dep in visited_nodes: + # We take max indices of all inputs nodes to determine the subgraph index. + subgraph_idx = max(subgraph_idx, visited_nodes[dep]) + + if subgraph_idx != i: + # If the node should be moved to a different subgraph, we move it and remove it from the current subgraph. + subgraphs[subgraph_idx].nodes.append(node) + to_remove_nodes.append(node) + # Record the the subgraph that the users of this node should belong to + visited_nodes[node] = subgraph_idx + + # Remove the nodes that are moved to other subgraphs + for node in to_remove_nodes: + subgraph.nodes.remove(node) + + return subgraphs + + +def resource_partition( + gm: torch.fx.GraphModule, + partitioned_module: torch.fx.GraphModule, + cpu_memory_budget: int, +) -> torch.fx.GraphModule: + """Resource-aware partitioning entry point. 
+ + Takes an original FX graph (`gm`) and a capability-partitioned module + (`partitioned_module`) and returns a final graph where accelerated segments + are split further, if necessary, to satisfy CPU memory limits for TRT + engine compilation. + + Args: + gm: Original FX `GraphModule`. + partitioned_module: Capability-partitioned `GraphModule` indicating + accelerated vs non-accelerated regions. + cpu_memory_budget: CPU memory budget in bytes for engine compilation. + Use -1 to base the budget on currently available system memory. + + Returns: + torch.fx.GraphModule: Final graph with resource-constrained subgraphs. + """ + + # Construct + partitioner = ResourcePartitioner( + gm, + partitioned_module, + cpu_memory_budget=cpu_memory_budget, + ) + + partitioned_graph = partitioner.partition_graph() + + return partitioned_graph diff --git a/py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py b/py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py new file mode 100644 index 0000000000..a5b3e74ee5 --- /dev/null +++ b/py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py @@ -0,0 +1,185 @@ +from functools import lru_cache +from typing import Dict, List, Set + +import torch +from torch.fx.passes.utils.matcher_utils import SubgraphMatcher +from torch.ops import aten + + +class ConvBNReLU(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + running_mean: torch.Tensor, + running_var: torch.Tensor, + momentum: float, + eps: float, + ) -> torch.Tensor: + x = aten.convolution.default( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + x = aten._native_batch_norm_legit_no_training.default( + x, bn_weight, bn_bias, running_mean, running_var, momentum, eps + )[0] + x = aten.relu.default(x) + return x + + +class ConvReLU(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + ) -> torch.Tensor: + x = aten.convolution.default( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + x = aten.relu.default(x) + return x + + +class ConvGelu(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + ) -> torch.Tensor: + x = aten.convolution.default( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + x = aten.gelu.default(x) + return x + + +class ConvSilu(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor + ) -> torch.Tensor: + x = aten.convolution.default( + x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1 + ) + x = aten.silu.default(x) + return x + + +class MulAdd(torch.nn.Module): # type: ignore[misc] + def 
__init__(self) -> None: + super().__init__() + + def forward( + self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor + ) -> torch.Tensor: + x = aten.mul.Tensor(x, weight) + x = aten.add.Tensor(x, bias) + return x + + +class MulMul(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor + ) -> torch.Tensor: + x = aten.mul.Tensor(x, y) + x = aten.mul.Tensor(x, z) + return x + + +All_FUSION_PATTERNS = [ + ConvBNReLU, + ConvReLU, + ConvGelu, + ConvSilu, + MulAdd, + MulMul, +] + + +@lru_cache(maxsize=None) +def get_node_in_fusion_pattern( + graph: torch.fx.Graph, +) -> Dict[torch.fx.Node, Set[torch.fx.Node]]: + """ + This function gets the nodes map of the fusion pattern from the graph. + Key: node that appears in the fusion pattern + Value: the list of nodes that should be fused together + """ + fusion_nodes = {} + for pattern in All_FUSION_PATTERNS: + pattern_graph = torch.fx.symbolic_trace(pattern()) + subgraph_matcher = SubgraphMatcher(pattern_graph.graph) + match_result = subgraph_matcher.match(graph) + for match in match_result: + fusion_group = { + node + for node in match.nodes_map.values() + if node + and type(node) == torch.fx.Node + and node.op == "call_function" + and node not in match.placeholder_nodes + } + for node in fusion_group: + fusion_nodes[node] = fusion_group + + return fusion_nodes From 3f0030fd4ff9cfbd276d09fdc446e0cfe098d5e1 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Nov 2025 00:46:50 +0000 Subject: [PATCH 02/10] Added decorator and tests --- py/torch_tensorrt/dynamo/_compiler.py | 4 + ...usion_patterns.py => _atomic_subgraphs.py} | 34 ++++--- .../partitioning/_resource_partitioner.py | 4 +- .../test_resource_partitioning.py | 93 +++++++++++++++++++ 4 files changed, 121 insertions(+), 14 deletions(-) rename py/torch_tensorrt/dynamo/partitioning/{fusion_patterns.py => _atomic_subgraphs.py} (86%) create mode 100644 tests/py/dynamo/partitioning/test_resource_partitioning.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index f5c435884b..63fec44900 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -619,6 +619,10 @@ def compile( "'arg_inputs' and 'inputs' should not be used at the same time." 
) + assert ( + cpu_memory_budget >= 2 * 1024 * 1024 * 1024 + ), "CPU memory budget must be greater than 10GB" + arg_inputs = inputs or arg_inputs if kwarg_inputs is None: diff --git a/py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py similarity index 86% rename from py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py rename to py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py index a5b3e74ee5..dbda162dcb 100644 --- a/py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py +++ b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py @@ -1,11 +1,25 @@ from functools import lru_cache -from typing import Dict, List, Set +from typing import Callable, Dict, List, Set import torch from torch.fx.passes.utils.matcher_utils import SubgraphMatcher from torch.ops import aten +ATOMIC_SUBGRAPHS = [] + +def register_atomic_subgraph( + is_aten: bool = False, +) -> Callable[[torch.nn.Module], torch.nn.Module]: + + def decorator(subgraph: torch.nn.Module) -> torch.nn.Module: + ATOMIC_SUBGRAPHS.append((subgraph, is_aten)) + return subgraph + + return decorator + + +@register_atomic_subgraph(is_aten=True) class ConvBNReLU(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -46,6 +60,7 @@ def forward( return x +@register_atomic_subgraph(is_aten=True) class ConvReLU(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -77,6 +92,7 @@ def forward( return x +@register_atomic_subgraph(is_aten=True) class ConvGelu(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -108,6 +124,7 @@ def forward( return x +@register_atomic_subgraph(is_aten=True) class ConvSilu(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -122,6 +139,7 @@ def forward( return x +@register_atomic_subgraph(is_aten=True) class MulAdd(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -134,6 +152,7 @@ def forward( return x +@register_atomic_subgraph(is_aten=True) class MulMul(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -146,16 +165,6 @@ def forward( return x -All_FUSION_PATTERNS = [ - ConvBNReLU, - ConvReLU, - ConvGelu, - ConvSilu, - MulAdd, - MulMul, -] - - @lru_cache(maxsize=None) def get_node_in_fusion_pattern( graph: torch.fx.Graph, @@ -166,8 +175,9 @@ def get_node_in_fusion_pattern( Value: the list of nodes that should be fused together """ fusion_nodes = {} - for pattern in All_FUSION_PATTERNS: + for pattern, is_aten in ATOMIC_SUBGRAPHS: pattern_graph = torch.fx.symbolic_trace(pattern()) + # TODO: Add decomposition and lowering if is_aten is False subgraph_matcher = SubgraphMatcher(pattern_graph.graph) match_result = subgraph_matcher.match(graph) for match in match_result: diff --git a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py index 967711ba02..0ef8d76e3a 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py @@ -52,7 +52,7 @@ import torch from torch.fx.passes.splitter_base import Subgraph, _SplitterBase from torch.fx.passes.tools_common import CALLABLE_NODE_OPS -from torch_tensorrt.dynamo.partitioning.fusion_patterns import ( +from torch_tensorrt.dynamo.partitioning._atomic_subgraphs import ( get_node_in_fusion_pattern, ) @@ -211,7 +211,7 @@ def 
calculate_size_budget( int: Budget in bytes for a single accelerated subgraph. """ - used_rss: int = psutil.virtual_memory().used + used_rss: int = psutil.Process().memory_info().rss available_rss = self.cpu_memory_budget - used_rss return available_rss // engine_compilation_memory_usage_multiplier diff --git a/tests/py/dynamo/partitioning/test_resource_partitioning.py b/tests/py/dynamo/partitioning/test_resource_partitioning.py new file mode 100644 index 0000000000..f059b0b166 --- /dev/null +++ b/tests/py/dynamo/partitioning/test_resource_partitioning.py @@ -0,0 +1,93 @@ +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt.dynamo import partitioning +from torch_tensorrt.dynamo.conversion import CompilationSettings +from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + post_lowering, + pre_export_lowering, +) +from torch_tensorrt.dynamo.lowering.passes import post_lowering, pre_export_lowering +from torch_tensorrt.dynamo.partitioning._resource_partitioner import resource_partition + + +class TestResourcePartitioning(TestCase): + def test_resource_partitioning(self): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1) + self.bn1 = nn.BatchNorm2d(4096) + self.conv2 = nn.Conv2d(4096, 1024, 3, padding=1) + self.bn2 = nn.BatchNorm2d(1024) + self.fc1 = nn.Linear(1024 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + model = net().eval() + model.to("cuda") + inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")] + + enabled_precisions = {torch.float} + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + + compilation_options = { + "use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + } + settings = CompilationSettings(**compilation_options) + with torchtrt.dynamo.Debugger( + log_level="debug", + logging_dir="/home/profile/logging/moe", + engine_builder_monitor=False, + ): + + exported_program = pre_export_lowering(exp_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + partitioned_module, supported_ops = partitioning.fast_partition( + gm, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=True, + ) + + partitioned_module = resource_partition( + gm, partitioned_module, cpu_memory_budget=2 * 1024 * 1024 * 1024 # 2GB, + ) + + self.assertEqual( + len(list[Any](partitioned_module.named_children())), + 2, + "The graph should have 2 subgraphs", + ) + + +if __name__ == "__main__": + run_tests() From ea6ebc5a70044b287ad1bb6800276a71397c94da Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Nov 2025 20:17:10 +0000 Subject: [PATCH 03/10] Added example and fixed lru problem --- examples/dynamo/low_cpu_memory_compilation.py | 84 +++++++++++++++++++ py/torch_tensorrt/dynamo/_compiler.py | 7 +- .../dynamo/partitioning/_atomic_subgraphs.py | 25 ++++-- 3 files 
changed, 107 insertions(+), 9 deletions(-)
 create mode 100644 examples/dynamo/low_cpu_memory_compilation.py

diff --git a/examples/dynamo/low_cpu_memory_compilation.py b/examples/dynamo/low_cpu_memory_compilation.py
new file mode 100644
index 0000000000..2ff5356490
--- /dev/null
+++ b/examples/dynamo/low_cpu_memory_compilation.py
@@ -0,0 +1,84 @@
+"""
+
+.. _low_cpu_memory_compilation:
+
+Low CPU Memory Compilation Example
+==================================
+
+This example demonstrates compiling a model with a bounded CPU (host) memory
+budget using Torch-TensorRT Dynamo. Limiting host RAM use is helpful on
+memory-constrained machines or when compiling very large models.
+
+Key notes:
+- The toy model below has roughly 430 MB of parameters. We set the CPU
+  memory budget to 2 GiB. At compile time, only about 900 MB of that budget
+  may remain available, while building the whole model as a single engine is
+  expected to use up to 430 * 4 = 1720 MB of host memory.
+  So the model is partitioned into two subgraphs to fit the memory budget.
+
+- Performance impact varies by model. When the number of TensorRT engines
+  created is small, the impact is typically minimal.
+
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch_tensorrt as torchtrt
+from torch_tensorrt.dynamo.conversion import CompilationSettings
+
+
+class net(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # Intentionally large layers to stress host memory during compilation.
+        self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1)
+        self.bn1 = nn.BatchNorm2d(4096)
+        self.conv2 = nn.Conv2d(4096, 1024, 3, padding=1)
+        self.bn2 = nn.BatchNorm2d(1024)
+        self.fc1 = nn.Linear(1024 * 56 * 56, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, (2, 2))
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, (2, 2))
+        x = torch.flatten(x, 1)
+        return self.fc1(x)
+
+
+model = net().eval()
+model.to("cuda")
+inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")]
+
+enabled_precisions = {torch.float}
+use_python_runtime = False
+
+compilation_options = {
+    "use_python_runtime": use_python_runtime,
+    "enabled_precisions": enabled_precisions,
+    "min_block_size": 1,
+    "immutable_weights": True,
+    "reuse_cached_engines": False,
+    "cpu_memory_budget": 2 * 1024 * 1024 * 1024,  # 2 GiB in bytes
+}
+
+settings = CompilationSettings(**compilation_options)
+with torchtrt.dynamo.Debugger(
+    log_level="debug",
+    logging_dir="/home/profile/logging/moe",
+    engine_builder_monitor=False,
+):
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        inputs=inputs,
+        **compilation_options,
+    )
+
+    # Expect two back-to-back TensorRT engines due to partitioning under the memory budget.
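+    # Each "_run_on_acc_*" child of the compiled module below is a separate
+    # TensorRT engine; a larger cpu_memory_budget (or more free host RAM)
+    # results in fewer, larger engines.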
+ print(trt_gm) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 63fec44900..0052b6489d 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -40,6 +40,9 @@ post_lowering, pre_export_lowering, ) +from torch_tensorrt.dynamo.partitioning._resource_partitioner import ( + resource_partition, +) from torch_tensorrt.dynamo.utils import ( deallocate_module, get_cpu_memory_usage, @@ -863,10 +866,6 @@ def preserve_module_specs( require_full_compilation=settings.require_full_compilation, ) - from torch_tensorrt.dynamo.partitioning._resource_partitioner import ( - resource_partition, - ) - partitioned_module = resource_partition( gm, partitioned_module, diff --git a/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py index dbda162dcb..e9c0420add 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py +++ b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py @@ -165,7 +165,6 @@ def forward( return x -@lru_cache(maxsize=None) def get_node_in_fusion_pattern( graph: torch.fx.Graph, ) -> Dict[torch.fx.Node, Set[torch.fx.Node]]: @@ -175,10 +174,8 @@ def get_node_in_fusion_pattern( Value: the list of nodes that should be fused together """ fusion_nodes = {} - for pattern, is_aten in ATOMIC_SUBGRAPHS: - pattern_graph = torch.fx.symbolic_trace(pattern()) - # TODO: Add decomposition and lowering if is_aten is False - subgraph_matcher = SubgraphMatcher(pattern_graph.graph) + for compiled_pattern_graph in get_compiled_atomic_subgraphs(): + subgraph_matcher = SubgraphMatcher(compiled_pattern_graph.graph) match_result = subgraph_matcher.match(graph) for match in match_result: fusion_group = { @@ -193,3 +190,21 @@ def get_node_in_fusion_pattern( fusion_nodes[node] = fusion_group return fusion_nodes + + +@lru_cache(maxsize=None) +def get_compiled_atomic_subgraphs() -> List[torch.fx.GraphModule]: + """ + This function gets the compiled atomic subgraphs from the graph. + LRU cache the result to avoid recompiling the same pattern multiple times. + """ + compiled_atomic_subgraphs = [] + for pattern, is_aten in ATOMIC_SUBGRAPHS: + pattern_graph = torch.fx.symbolic_trace(pattern()) + if not is_aten: + # TODO: Add decomposition and lowering if is_aten is False + raise NotImplementedError( + "Atomic subgraphs are not supported for non-aten subgraphs yet." 
+ ) + compiled_atomic_subgraphs.append(pattern_graph) + return compiled_atomic_subgraphs From c526de967c7c2ed75ce9423bc1b7f2e2d6c0a109 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 19 Nov 2025 21:18:10 +0000 Subject: [PATCH 04/10] Fixed the comments --- py/torch_tensorrt/dynamo/_settings.py | 1 + py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 52ac86012c..05f37fdb43 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -174,6 +174,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: "enable_weight_streaming", "tiling_optimization_level", "l2_limit_for_tiling", + "cpu_memory_budget", ) diff --git a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py index 0ef8d76e3a..e85a5c7056 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py @@ -242,7 +242,8 @@ def break_subgraphs( sizes = self.size_of_subgraphs(subgraphs) if sum(sizes) > subgraph_size_budget * 40: raise ValueError( - f"CPU memory budget or available memory is too small to compile the model. CPU memory budget: {self.cpu_memory_budget // (1024 * 1024) if self.cpu_memory_budget != -1 else "All available memory"} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. " + "CPU memory budget or available memory is too small to compile the model. " + + f"CPU memory budget: {self.cpu_memory_budget // (1024 * 1024)} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. " + "Consider setting cpu_memory_budget to a larger value or disable offload_module_to_cpu to save more CPU memory." ) for subgraph, size in zip(subgraphs, sizes): From 014e5390a236b79e280a4f649cfe54da737a05c1 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 19 Nov 2025 22:34:54 +0000 Subject: [PATCH 05/10] Fixed the comments --- examples/dynamo/low_cpu_memory_compilation.py | 26 +++++++++++++ .../dynamo/partitioning/_atomic_subgraphs.py | 37 ++++++++++++------- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/examples/dynamo/low_cpu_memory_compilation.py b/examples/dynamo/low_cpu_memory_compilation.py index 2ff5356490..bf5c0bd43a 100644 --- a/examples/dynamo/low_cpu_memory_compilation.py +++ b/examples/dynamo/low_cpu_memory_compilation.py @@ -82,3 +82,29 @@ def forward(self, x): # Expect two back-to-back TensorRT engines due to partitioning under the memory budget. print(trt_gm) + + +""" +You should be able to see two back-to-back TensorRT engines in the graph +Graph Structure: + + Inputs: List[Tensor: (1, 1024, 224, 224)@float32] + ... + TRT Engine #1 - Submodule name: _run_on_acc_0 + Engine Inputs: List[Tensor: (1, 1024, 224, 224)@float32] + Number of Operators in Engine: 9 + Engine Outputs: List[Tensor: (1, 1024, 112, 112)@float32] + ... + TRT Engine #2 - Submodule name: _run_on_acc_1 + Engine Inputs: List[Tensor: (1, 1024, 112, 112)@float32] + Number of Operators in Engine: 3 + Engine Outputs: List[Tensor: (1, 10)@float32] + ... 
+ Outputs: List[Tensor: (1, 10)@float32] + + +GraphModule( + (_run_on_acc_0): TorchTensorRTModule() + (_run_on_acc_1): TorchTensorRTModule() +) +""" diff --git a/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py index e9c0420add..a57778bdc4 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py +++ b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py @@ -9,17 +9,17 @@ def register_atomic_subgraph( - is_aten: bool = False, + is_core_aten: bool = False, ) -> Callable[[torch.nn.Module], torch.nn.Module]: def decorator(subgraph: torch.nn.Module) -> torch.nn.Module: - ATOMIC_SUBGRAPHS.append((subgraph, is_aten)) + ATOMIC_SUBGRAPHS.append((subgraph, is_core_aten)) return subgraph return decorator -@register_atomic_subgraph(is_aten=True) +@register_atomic_subgraph(is_core_aten=True) class ConvBNReLU(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -60,7 +60,7 @@ def forward( return x -@register_atomic_subgraph(is_aten=True) +@register_atomic_subgraph(is_core_aten=True) class ConvReLU(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -92,7 +92,7 @@ def forward( return x -@register_atomic_subgraph(is_aten=True) +@register_atomic_subgraph(is_core_aten=True) class ConvGelu(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -124,7 +124,7 @@ def forward( return x -@register_atomic_subgraph(is_aten=True) +@register_atomic_subgraph(is_core_aten=True) class ConvSilu(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -139,7 +139,7 @@ def forward( return x -@register_atomic_subgraph(is_aten=True) +@register_atomic_subgraph(is_core_aten=True) class MulAdd(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -152,7 +152,7 @@ def forward( return x -@register_atomic_subgraph(is_aten=True) +@register_atomic_subgraph(is_core_aten=True) class MulMul(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -192,19 +192,30 @@ def get_node_in_fusion_pattern( return fusion_nodes -@lru_cache(maxsize=None) def get_compiled_atomic_subgraphs() -> List[torch.fx.GraphModule]: """ This function gets the compiled atomic subgraphs from the graph. LRU cache the result to avoid recompiling the same pattern multiple times. """ compiled_atomic_subgraphs = [] - for pattern, is_aten in ATOMIC_SUBGRAPHS: - pattern_graph = torch.fx.symbolic_trace(pattern()) - if not is_aten: - # TODO: Add decomposition and lowering if is_aten is False + for pattern, is_core_aten in ATOMIC_SUBGRAPHS: + pattern_graph = trace_atomic_graph(pattern, is_core_aten) + if not is_core_aten: + # TODO: Add decomposition and lowering if is_core_aten is False raise NotImplementedError( "Atomic subgraphs are not supported for non-aten subgraphs yet." 
) compiled_atomic_subgraphs.append(pattern_graph) return compiled_atomic_subgraphs + + +@lru_cache(maxsize=None) +def trace_atomic_graph( + graph: torch.nn.Module, is_core_aten: bool = True +) -> torch.fx.GraphModule: + if is_core_aten: + return torch.fx.symbolic_trace(graph()) + else: + raise NotImplementedError( + "Resource partitioner currently does not support unlowered atomic subgraphs" + ) From 812a4577f4034a8dc26b5c2e50c4b0c9eefa30d0 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 21 Nov 2025 21:50:58 +0000 Subject: [PATCH 06/10] supported global partitioner --- py/torch_tensorrt/dynamo/_compiler.py | 4 +- .../partitioning/_adjacency_partitioner.py | 2 +- .../partitioning/_resource_partitioner.py | 110 ++++----- .../test_resource_partitioning.py | 222 ++++++++++++++++-- 4 files changed, 251 insertions(+), 87 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 0052b6489d..5fd3fdc897 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -867,7 +867,6 @@ def preserve_module_specs( ) partitioned_module = resource_partition( - gm, partitioned_module, cpu_memory_budget=settings.cpu_memory_budget, ) @@ -895,6 +894,7 @@ def preserve_module_specs( for attr in dir(gm): if attr.startswith("_frozen_param"): delattr(gm, attr) + for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule @@ -1357,7 +1357,7 @@ def convert_exported_program_to_serialized_trt_engine( ) flattened_input_list = get_flat_args_with_check( - exported_program, list(trt_arg_inputs), trt_kwarg_inputs # type: ignore + exported_program, list(trt_arg_inputs), trt_kwarg_inputs )[0] try: diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index e2f544c2a7..72d0be42c7 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -230,7 +230,7 @@ def partition_graph(self) -> torch.fx.GraphModule: # Tag the accelerated nodes and split the graph accordingly self.tag(subgraphs) - return self.split() + return self.split(remove_tag=True) def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: """Generates starter nodes for partitioning + segmentation""" diff --git a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py index e85a5c7056..994ac1775e 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py @@ -50,6 +50,7 @@ import psutil import torch +from torch.fx.experimental.const_fold import _inline_module from torch.fx.passes.splitter_base import Subgraph, _SplitterBase from torch.fx.passes.tools_common import CALLABLE_NODE_OPS from torch_tensorrt.dynamo.partitioning._atomic_subgraphs import ( @@ -77,17 +78,16 @@ class ResourcePartitioner(_SplitterBase): # type: ignore def __init__( self, module: torch.fx.GraphModule, - partitioned_module: torch.fx.GraphModule, cpu_memory_budget: int, + submodule_name: str, ): assert isinstance(module, torch.fx.GraphModule) - assert isinstance(partitioned_module, torch.fx.GraphModule) self.module = module - self.partitioned_module = partitioned_module self.cpu_memory_budget = cpu_memory_budget - + self.resource_split_count = 0 + self.submodule_name = submodule_name self.deps = self.find_deps() 
self.non_acc_submodule_name = "_run_on_gpu_" @@ -119,64 +119,39 @@ def partition_graph(self) -> torch.fx.GraphModule: # Tag the accelerated nodes and split the graph accordingly self.tag(subgraphs) - gm = self.split() + gm = self.split(remove_tag=True) return gm - def put_nodes_into_subgraphs(self) -> list[Subgraph]: - """Map original graph nodes into capability-based subgraphs. - - - Iterates `partitioned_module` submodules to establish which node names - belong to which subgraph (accelerated or not). - - Builds a fusion pattern map for each subgraph so that known fusion groups remain intact. - Note that since fusion map is built for each subgraph, the capability partitioning can still break the fusion groups. - - Put the nodes into the subgraphs based on the capability partitioning. - - Verifies the resulting list of subgraphs is topologically ordered. + def tag(self, subgraphs: list[Subgraph]) -> None: + self.tags = [] + for subgraph in subgraphs: + tag = f"{self.submodule_name}_resource_split_{self.resource_split_count}" + self.resource_split_count += 1 + self.tags.append(tag) + for node in subgraph.nodes: + node.tag = tag + self._node_submodule_map[node.name] = tag + def put_nodes_into_subgraphs(self) -> list[Subgraph]: + """ + Put the nodes into the subgraphs and erase the tag from previous partitioner if it exists. Returns: - list[Subgraph]: Ordered subgraphs consisting of nodes in `module` based on capability partitioning. + list[Subgraph]: Ordered subgraphs consisting of nodes in `module`. """ - subgraphs_map = {} - subgraphs = [] - name_to_node_map = ( - {} - ) # We use this map to help map the nodes in partitioned module to the nodes in original module. - for name, _ in self.partitioned_module.named_children(): - # We first iterate over the partitioned module to find the subgraphs based on capability partitioning. - submodule = getattr(self.partitioned_module, name) - if not isinstance(submodule, torch.fx.graph_module.GraphModule): - continue - subgraph = Subgraph(is_acc="acc" in name, nodes=[]) - subgraphs.append(subgraph) - self.fusion_patterns.update(get_node_in_fusion_pattern(submodule.graph)) - - for node in submodule.graph.nodes: - # Erase the tag from previous partitioner if it exists - if hasattr(node, "tag"): - delattr(node, "tag") - - if node.op in CALLABLE_NODE_OPS: - # Store which subgraph the node should be put in - subgraphs_map[node.name] = subgraph - # We then iterate over the original module to put the nodes into the subgraphs. + nodes = [] for node in self.module.graph.nodes: if hasattr(node, "tag"): - # Erase the tag from previous partitioner - delattr(node, "tag") + del node.tag if node.op in CALLABLE_NODE_OPS: - name_to_node_map[node.name] = node - subgraphs_map[node.name].nodes.append(node) + nodes.append(node) + subgraphs = [Subgraph(is_acc=True, nodes=nodes)] + self.fusion_patterns = get_node_in_fusion_pattern(self.module.graph) assert self.check_topological_order( subgraphs ), "The subgraphs are not topologically ordered" - self.fusion_patterns = { - name_to_node_map[node.name]: [ - name_to_node_map[n.name] for n in fusion_nodes - ] - for node, fusion_nodes in self.fusion_patterns.items() - } return subgraphs @@ -240,6 +215,7 @@ def break_subgraphs( # We throw an error if the remaining memory is almost empty compared to the model size. # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation. 
sizes = self.size_of_subgraphs(subgraphs) + # subgraph_size_budget = 500*1024*1024 if sum(sizes) > subgraph_size_budget * 40: raise ValueError( "CPU memory budget or available memory is too small to compile the model. " @@ -255,7 +231,9 @@ def break_subgraphs( size = size_1 new_subgraphs.append(broken_subgraphs[0]) subgraph = broken_subgraphs[1] - new_subgraphs.append(subgraph) + + if len(subgraph.nodes) != 0: + new_subgraphs.append(subgraph) self._varify_all_fusion_nodes_in_same_subgraph(new_subgraphs) @@ -325,8 +303,6 @@ def break_subgraph_by_size( if size_0 > size_to_break: break - if len(new_subgraphs[1].nodes) == 0: - new_subgraphs.pop(1) return new_subgraphs, size_0, size_1 def step_and_validate( @@ -530,7 +506,6 @@ def validate_and_correct_subgraphs( def resource_partition( gm: torch.fx.GraphModule, - partitioned_module: torch.fx.GraphModule, cpu_memory_budget: int, ) -> torch.fx.GraphModule: """Resource-aware partitioning entry point. @@ -552,12 +527,29 @@ def resource_partition( """ # Construct - partitioner = ResourcePartitioner( - gm, - partitioned_module, - cpu_memory_budget=cpu_memory_budget, - ) + for name, _ in gm.named_children(): + submodule = getattr(gm, name) + if ( + not isinstance(submodule, torch.fx.graph_module.GraphModule) + or "_run_on_acc" not in name + ): + continue + partitioner = ResourcePartitioner( + submodule, + submodule_name=name, + cpu_memory_budget=cpu_memory_budget, + ) + + partitioned_graph = partitioner.partition_graph() + setattr(gm, name, partitioned_graph) - partitioned_graph = partitioner.partition_graph() + for name, module in list(gm.named_children()): + if "_run_on_acc" in name: + for subname, submodule in module.named_children(): + if "resource_split" in subname: + setattr(gm, subname, submodule) + _inline_module(gm, name) + delattr(gm, name) - return partitioned_graph + gm.recompile() + return gm diff --git a/tests/py/dynamo/partitioning/test_resource_partitioning.py b/tests/py/dynamo/partitioning/test_resource_partitioning.py index f059b0b166..5777fd938f 100644 --- a/tests/py/dynamo/partitioning/test_resource_partitioning.py +++ b/tests/py/dynamo/partitioning/test_resource_partitioning.py @@ -56,37 +56,209 @@ def forward(self, x): "reuse_cached_engines": False, } settings = CompilationSettings(**compilation_options) - with torchtrt.dynamo.Debugger( - log_level="debug", - logging_dir="/home/profile/logging/moe", - engine_builder_monitor=False, - ): - - exported_program = pre_export_lowering(exp_program, settings) - exported_program = exported_program.run_decompositions( - get_decompositions(False) - ) - gm = exported_program.module() - gm = post_lowering(gm, settings) + exported_program = pre_export_lowering(exp_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) - partitioned_module, supported_ops = partitioning.fast_partition( - gm, - min_block_size=settings.min_block_size, - torch_executed_ops=settings.torch_executed_ops, - require_full_compilation=settings.require_full_compilation, - skip_fusion=True, - ) + gm = exported_program.module() + gm = post_lowering(gm, settings) + + partitioned_module, supported_ops = partitioning.fast_partition( + gm, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=True, + ) + + partitioned_module = resource_partition( + partitioned_module, cpu_memory_budget=2 * 1024 * 1024 * 1024 # 2GB, + ) + + self.assertEqual( + 
len(list[Any](partitioned_module.named_children())), + 2, + "The graph should have 2 subgraphs", + ) + + def test_resource_partitioning_with_capability_partitioning(self): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1) + self.bn1 = nn.BatchNorm2d(4096) + self.conv2 = nn.Conv2d(4096, 4096, 3, padding=1) + self.bn2 = nn.BatchNorm2d(4096) + + self.conv3 = nn.Conv2d(4096, 4096, 3, padding=1) + self.bn3 = nn.BatchNorm2d(4096) + self.conv4 = nn.Conv2d(4096, 1024, 3, padding=1) + self.bn4 = nn.BatchNorm2d(1024) + + self.fc1 = nn.Linear(1024 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv3(x) + x = self.bn3(x) + x = F.relu(x) + x = self.conv4(x) + x = self.bn4(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + model = net().eval() + model.to("cuda") + inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")] + + enabled_precisions = {torch.float} + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + + compilation_options = { + "use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + "torch_executed_ops": {"torch.ops.aten.max_pool2d.default"}, + } + settings = CompilationSettings(**compilation_options) + + exported_program = pre_export_lowering(exp_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) - partitioned_module = resource_partition( - gm, partitioned_module, cpu_memory_budget=2 * 1024 * 1024 * 1024 # 2GB, + partitioned_module, supported_ops = partitioning.fast_partition( + gm, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=True, + ) + + partitioned_module = resource_partition( + partitioned_module, cpu_memory_budget=1.7 * 1024 * 1024 * 1024 # 1.7GB, + ) + + assert ( + len( + [ + name + for name, _ in partitioned_module.named_children() + if "_run_on_acc" in name + ] + ) + == 5 + ), "The graph should have 5 accelerated subgraphs" + assert ( + len( + [ + name + for name, _ in partitioned_module.named_children() + if "_run_on_gpu" in name + ] ) + == 2 + ), "The graph should have 2 non-accelerated subgraphs" + + def test_resource_partitioning_with_global_capability_partitioning(self): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1) + self.bn1 = nn.BatchNorm2d(4096) + self.conv2 = nn.Conv2d(4096, 4096, 3, padding=1) + self.bn2 = nn.BatchNorm2d(4096) + + self.conv3 = nn.Conv2d(4096, 4096, 3, padding=1) + self.bn3 = nn.BatchNorm2d(4096) + self.conv4 = nn.Conv2d(4096, 1024, 3, padding=1) + self.bn4 = nn.BatchNorm2d(1024) + + self.fc1 = nn.Linear(1024 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv3(x) + x = self.bn3(x) + x = F.relu(x) + x = self.conv4(x) + x = self.bn4(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + model = net().eval() + model.to("cuda") + inputs = 
[torch.randn((1, 1024, 224, 224)).to("cuda")] + + enabled_precisions = {torch.float} + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + + compilation_options = { + "use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + "torch_executed_ops": {"torch.ops.aten.max_pool2d.default"}, + } + settings = CompilationSettings(**compilation_options) + + exported_program = pre_export_lowering(exp_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + partitioned_module, supported_ops = partitioning.global_partition( + gm, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + ) + + partitioned_module = resource_partition( + partitioned_module, cpu_memory_budget=1.7 * 1024 * 1024 * 1024 # 1.7GB, + ) - self.assertEqual( - len(list[Any](partitioned_module.named_children())), - 2, - "The graph should have 2 subgraphs", + assert ( + len( + [ + name + for name, _ in partitioned_module.named_children() + if "_run_on_acc" in name + ] ) + == 5 + ), "The graph should have 5 accelerated subgraphs" if __name__ == "__main__": From 3544481a76af85752a5f559d7f882f2b3d438f13 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 21 Nov 2025 22:53:33 +0000 Subject: [PATCH 07/10] Added atomic_subgraph template --- .../dynamo/partitioning/_atomic_subgraphs.py | 88 ++++++------------- .../partitioning/_resource_partitioner.py | 1 - 2 files changed, 25 insertions(+), 64 deletions(-) diff --git a/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py index a57778bdc4..85cab6eba6 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py +++ b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py @@ -1,5 +1,5 @@ from functools import lru_cache -from typing import Callable, Dict, List, Set +from typing import Any, Callable, Dict, List, Set, Tuple import torch from torch.fx.passes.utils.matcher_utils import SubgraphMatcher @@ -9,20 +9,25 @@ def register_atomic_subgraph( + init_args: Tuple[Any, ...] 
= tuple(), is_core_aten: bool = False, ) -> Callable[[torch.nn.Module], torch.nn.Module]: def decorator(subgraph: torch.nn.Module) -> torch.nn.Module: - ATOMIC_SUBGRAPHS.append((subgraph, is_core_aten)) + ATOMIC_SUBGRAPHS.append((subgraph, init_args, is_core_aten)) return subgraph return decorator -@register_atomic_subgraph(is_core_aten=True) -class ConvBNReLU(torch.nn.Module): # type: ignore[misc] - def __init__(self) -> None: +@register_atomic_subgraph(init_args=(aten.silu.default,), is_core_aten=True) +@register_atomic_subgraph(init_args=(aten.gelu.default,), is_core_aten=True) +@register_atomic_subgraph(init_args=(aten.relu.default,), is_core_aten=True) +@register_atomic_subgraph(init_args=(aten.sigmoid.default,), is_core_aten=True) +class ConvBNActivation(torch.nn.Module): # type: ignore[misc] + def __init__(self, activation: torch._ops.OpOverload) -> None: super().__init__() + self.activation = activation def forward( self, @@ -56,46 +61,18 @@ def forward( x = aten._native_batch_norm_legit_no_training.default( x, bn_weight, bn_bias, running_mean, running_var, momentum, eps )[0] - x = aten.relu.default(x) - return x - - -@register_atomic_subgraph(is_core_aten=True) -class ConvReLU(torch.nn.Module): # type: ignore[misc] - def __init__(self) -> None: - super().__init__() - - def forward( - self, - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - stride: List[int], - padding: List[int], - dilation: List[int], - transposed: bool, - output_padding: List[int], - groups: int, - ) -> torch.Tensor: - x = aten.convolution.default( - x, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ) - x = aten.relu.default(x) + x = self.activation(x) return x -@register_atomic_subgraph(is_core_aten=True) -class ConvGelu(torch.nn.Module): # type: ignore[misc] - def __init__(self) -> None: +@register_atomic_subgraph(init_args=(aten.silu.default,), is_core_aten=True) +@register_atomic_subgraph(init_args=(aten.gelu.default,), is_core_aten=True) +@register_atomic_subgraph(init_args=(aten.relu.default,), is_core_aten=True) +@register_atomic_subgraph(init_args=(aten.sigmoid.default,), is_core_aten=True) +class ConvActivation(torch.nn.Module): # type: ignore[misc] + def __init__(self, activation: torch._ops.OpOverload) -> None: super().__init__() + self.activation = activation def forward( self, @@ -120,26 +97,11 @@ def forward( output_padding, groups, ) - x = aten.gelu.default(x) - return x - - -@register_atomic_subgraph(is_core_aten=True) -class ConvSilu(torch.nn.Module): # type: ignore[misc] - def __init__(self) -> None: - super().__init__() - - def forward( - self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor - ) -> torch.Tensor: - x = aten.convolution.default( - x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1 - ) - x = aten.silu.default(x) + x = self.activation(x) return x -@register_atomic_subgraph(is_core_aten=True) +@register_atomic_subgraph(init_args=(), is_core_aten=True) class MulAdd(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -152,7 +114,7 @@ def forward( return x -@register_atomic_subgraph(is_core_aten=True) +@register_atomic_subgraph(init_args=(), is_core_aten=True) class MulMul(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -198,8 +160,8 @@ def get_compiled_atomic_subgraphs() -> List[torch.fx.GraphModule]: LRU cache the result to avoid recompiling the same pattern multiple times. 
""" compiled_atomic_subgraphs = [] - for pattern, is_core_aten in ATOMIC_SUBGRAPHS: - pattern_graph = trace_atomic_graph(pattern, is_core_aten) + for pattern, init_args, is_core_aten in ATOMIC_SUBGRAPHS: + pattern_graph = trace_atomic_graph(pattern, init_args, is_core_aten) if not is_core_aten: # TODO: Add decomposition and lowering if is_core_aten is False raise NotImplementedError( @@ -211,10 +173,10 @@ def get_compiled_atomic_subgraphs() -> List[torch.fx.GraphModule]: @lru_cache(maxsize=None) def trace_atomic_graph( - graph: torch.nn.Module, is_core_aten: bool = True + graph: torch.nn.Module, init_args: Any, is_core_aten: bool = True ) -> torch.fx.GraphModule: if is_core_aten: - return torch.fx.symbolic_trace(graph()) + return torch.fx.symbolic_trace(graph(*init_args)) else: raise NotImplementedError( "Resource partitioner currently does not support unlowered atomic subgraphs" diff --git a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py index 994ac1775e..c6d1282543 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py @@ -90,7 +90,6 @@ def __init__( self.submodule_name = submodule_name self.deps = self.find_deps() - self.non_acc_submodule_name = "_run_on_gpu_" self._node_submodule_map: Dict[str, str] = {} self._return_tuple = False self.fusion_patterns: Dict[torch.fx.Node, List[torch.fx.Node]] = {} From 6a879b21905312c4cc975605d6efde3d4205cf39 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 1 Dec 2025 16:57:45 +0000 Subject: [PATCH 08/10] Fixed the comments. Added more tests --- examples/dynamo/low_cpu_memory_compilation.py | 27 +- py/torch_tensorrt/dynamo/_compiler.py | 4 - .../dynamo/partitioning/_atomic_subgraphs.py | 5 +- .../partitioning/_resource_partitioner.py | 11 +- .../test_resource_partitioning.py | 242 +++++++++++++++++- 5 files changed, 268 insertions(+), 21 deletions(-) diff --git a/examples/dynamo/low_cpu_memory_compilation.py b/examples/dynamo/low_cpu_memory_compilation.py index bf5c0bd43a..30906e91f0 100644 --- a/examples/dynamo/low_cpu_memory_compilation.py +++ b/examples/dynamo/low_cpu_memory_compilation.py @@ -86,25 +86,44 @@ def forward(self, x): """ You should be able to see two back-to-back TensorRT engines in the graph + Graph Structure: Inputs: List[Tensor: (1, 1024, 224, 224)@float32] ... - TRT Engine #1 - Submodule name: _run_on_acc_0 + TRT Engine #1 - Submodule name: _run_on_acc_0_resource_split_0 Engine Inputs: List[Tensor: (1, 1024, 224, 224)@float32] Number of Operators in Engine: 9 Engine Outputs: List[Tensor: (1, 1024, 112, 112)@float32] ... - TRT Engine #2 - Submodule name: _run_on_acc_1 + TRT Engine #2 - Submodule name: _run_on_acc_0_resource_split_1 Engine Inputs: List[Tensor: (1, 1024, 112, 112)@float32] Number of Operators in Engine: 3 Engine Outputs: List[Tensor: (1, 10)@float32] ... 
Outputs: List[Tensor: (1, 10)@float32] + ------------------------- Aggregate Stats ------------------------- + + Average Number of Operators per TRT Engine: 6.0 + Most Operators in a TRT Engine: 9 + ********** Recommendations ********** + + - For minimal graph segmentation, select min_block_size=9 which would generate 1 TRT engine(s) + - For moderate graph segmentation, select min_block_size=6 which would generate 1 TRT engine(s) + - The current level of graph segmentation is equivalent to selecting min_block_size=3 which generates 2 TRT engine(s) GraphModule( - (_run_on_acc_0): TorchTensorRTModule() - (_run_on_acc_1): TorchTensorRTModule() + (_run_on_acc_0_resource_split_0): TorchTensorRTModule() + (_run_on_acc_0_resource_split_1): TorchTensorRTModule() +) + + + +def forward(self, x): + x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + _run_on_acc_0_resource_split_0 = self._run_on_acc_0_resource_split_0(x); x = None + _run_on_acc_0_resource_split_1 = self._run_on_acc_0_resource_split_1(_run_on_acc_0_resource_split_0); _run_on_acc_0_resource_split_0 = None + return pytree.tree_unflatten((_run_on_acc_0_resource_split_1,), self._out_spec) ) """ diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 5fd3fdc897..7ca8f714e5 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -622,10 +622,6 @@ def compile( "'arg_inputs' and 'inputs' should not be used at the same time." ) - assert ( - cpu_memory_budget >= 2 * 1024 * 1024 * 1024 - ), "CPU memory budget must be greater than 10GB" - arg_inputs = inputs or arg_inputs if kwarg_inputs is None: diff --git a/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py index 85cab6eba6..b55fc0d873 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py +++ b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py @@ -1,3 +1,4 @@ +from collections import defaultdict from functools import lru_cache from typing import Any, Callable, Dict, List, Set, Tuple @@ -135,7 +136,7 @@ def get_node_in_fusion_pattern( Key: node that appears in the fusion pattern Value: the list of nodes that should be fused together """ - fusion_nodes = {} + fusion_nodes = defaultdict(set) for compiled_pattern_graph in get_compiled_atomic_subgraphs(): subgraph_matcher = SubgraphMatcher(compiled_pattern_graph.graph) match_result = subgraph_matcher.match(graph) @@ -149,7 +150,7 @@ def get_node_in_fusion_pattern( and node not in match.placeholder_nodes } for node in fusion_group: - fusion_nodes[node] = fusion_group + fusion_nodes[node].update(fusion_group) return fusion_nodes diff --git a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py index c6d1282543..ee14deee61 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py @@ -46,7 +46,7 @@ """ import logging -from typing import Dict, List, Tuple +from typing import Dict, List, Set, Tuple import psutil import torch @@ -92,7 +92,7 @@ def __init__( self._node_submodule_map: Dict[str, str] = {} self._return_tuple = False - self.fusion_patterns: Dict[torch.fx.Node, List[torch.fx.Node]] = {} + self.fusion_patterns: Dict[torch.fx.Node, Set[torch.fx.Node]] = {} def partition_graph(self) -> torch.fx.GraphModule: """Build the final partitioned `GraphModule` honoring memory constraints. 
@@ -214,7 +214,6 @@ def break_subgraphs( # We throw an error if the remaining memory is almost empty compared to the model size. # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation. sizes = self.size_of_subgraphs(subgraphs) - # subgraph_size_budget = 500*1024*1024 if sum(sizes) > subgraph_size_budget * 40: raise ValueError( "CPU memory budget or available memory is too small to compile the model. " @@ -470,12 +469,6 @@ def validate_and_correct_subgraphs( visited_nodes[subgraph.nodes[-1]] = i + 1 continue - elif not subgraph.is_acc: - # non-accelerated subgraphs should be put in the next subgraph - for node in subgraph.nodes: - visited_nodes[subgraph.nodes[-1]] = i + 1 - continue - else: to_remove_nodes = [] for j, node in enumerate(subgraph.nodes): diff --git a/tests/py/dynamo/partitioning/test_resource_partitioning.py b/tests/py/dynamo/partitioning/test_resource_partitioning.py index 5777fd938f..bb38a13d44 100644 --- a/tests/py/dynamo/partitioning/test_resource_partitioning.py +++ b/tests/py/dynamo/partitioning/test_resource_partitioning.py @@ -1,9 +1,11 @@ -from typing import Any +from typing import Any, List import torch import torch.nn as nn import torch.nn.functional as F import torch_tensorrt as torchtrt +from torch.fx.passes.splitter_base import Subgraph +from torch.ops import aten from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt.dynamo import partitioning from torch_tensorrt.dynamo.conversion import CompilationSettings @@ -13,7 +15,14 @@ pre_export_lowering, ) from torch_tensorrt.dynamo.lowering.passes import post_lowering, pre_export_lowering -from torch_tensorrt.dynamo.partitioning._resource_partitioner import resource_partition +from torch_tensorrt.dynamo.partitioning._atomic_subgraphs import ( + ATOMIC_SUBGRAPHS, + register_atomic_subgraph, +) +from torch_tensorrt.dynamo.partitioning._resource_partitioner import ( + ResourcePartitioner, + resource_partition, +) class TestResourcePartitioning(TestCase): @@ -83,6 +92,8 @@ def forward(self, x): "The graph should have 2 subgraphs", ) + torch._dynamo.reset() + def test_resource_partitioning_with_capability_partitioning(self): class net(nn.Module): def __init__(self): @@ -177,6 +188,231 @@ def forward(self, x): == 2 ), "The graph should have 2 non-accelerated subgraphs" + torch._dynamo.reset() + + def test_resource_partitioning_with_capability_partitioning_and_atomic_subgraphs( + self, + ): + """ + After defining the atomic subgraphs, the resource partitioner will not be able to find valid partition in the subgraph. + So there should only be 3 accelerated subgraphs and 2 non-accelerated subgraphs. 
+ """ + + @register_atomic_subgraph(init_args=(), is_core_aten=True) + class ReLUConv(nn.Module): + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + ) -> torch.Tensor: + x = aten.relu.default(x) + x = aten.convolution.default( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + return x + + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1) + self.bn1 = nn.BatchNorm2d(4096) + self.conv2 = nn.Conv2d(4096, 4096, 3, padding=1) + self.bn2 = nn.BatchNorm2d(4096) + + self.conv3 = nn.Conv2d(4096, 4096, 3, padding=1) + self.bn3 = nn.BatchNorm2d(4096) + self.conv4 = nn.Conv2d(4096, 1024, 3, padding=1) + self.bn4 = nn.BatchNorm2d(1024) + + self.fc1 = nn.Linear(1024 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv3(x) + x = self.bn3(x) + x = F.relu(x) + x = self.conv4(x) + x = self.bn4(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + model = net().eval() + model.to("cuda") + inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")] + + enabled_precisions = {torch.float} + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + + compilation_options = { + "use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + "torch_executed_ops": {"torch.ops.aten.max_pool2d.default"}, + } + settings = CompilationSettings(**compilation_options) + + exported_program = pre_export_lowering(exp_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + partitioned_module, supported_ops = partitioning.fast_partition( + gm, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=True, + ) + + partitioned_module = resource_partition( + partitioned_module, cpu_memory_budget=1.7 * 1024 * 1024 * 1024 # 1.7GB, + ) + + assert ( + len( + [ + name + for name, _ in partitioned_module.named_children() + if "_run_on_acc" in name + ] + ) + == 3 + ), "The graph should have 3 accelerated subgraphs" + assert ( + len( + [ + name + for name, _ in partitioned_module.named_children() + if "_run_on_gpu" in name + ] + ) + == 2 + ), "The graph should have 2 non-accelerated subgraphs" + + ATOMIC_SUBGRAPHS.remove((ReLUConv, (), True)) + + torch._dynamo.reset() + + def test_atomic_subgraph_correction(self): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 3, 3, padding=1) + self.bn1 = nn.BatchNorm2d(3) + self.relu = nn.ReLU() + self.fc = nn.Linear(3 * 224 * 224, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = torch.flatten(x, 1) + x = self.fc(x) + return x + + model = net().eval() + model.to("cuda") + inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] + + enabled_precisions = {torch.float} + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + + compilation_options = { + 
"use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + } + settings = CompilationSettings(**compilation_options) + + exported_program = pre_export_lowering(exp_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + partitioned_module, supported_ops = partitioning.fast_partition( + gm, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=True, + ) + + for name, _ in partitioned_module.named_children(): + submodule = getattr(partitioned_module, name) + if ( + not isinstance(submodule, torch.fx.graph_module.GraphModule) + or "_run_on_acc" not in name + ): + continue + partitioner = ResourcePartitioner( + submodule, + submodule_name=name, + cpu_memory_budget=2 * 1024 * 1024 * 1024, + ) + subgraphs = partitioner.put_nodes_into_subgraphs() + new_subgraphs = [] + current_subgraph = [] + # Split the subgraph into two subgraphs by the ReLU node, which breaks the fusion group. + for node in subgraphs[0].nodes: + if node.op == "call_function" and node.target == aten.relu.default: + new_subgraphs.append(Subgraph(is_acc=True, nodes=current_subgraph)) + current_subgraph = [] + current_subgraph.append(node) + if current_subgraph: + new_subgraphs.append(Subgraph(is_acc=True, nodes=current_subgraph)) + + leaf_node = partitioner.get_leaf_node(new_subgraphs[0].nodes) + broken_fusion = partitioner.step_if_break_fusion( + new_subgraphs, + leaf_node, + set(new_subgraphs[0].nodes), + set(new_subgraphs[1].nodes), + ) + # The fusion was broken + assert broken_fusion + + # The fusion should be fixed after the step + partitioner._varify_all_fusion_nodes_in_same_subgraph(new_subgraphs) + + break + def test_resource_partitioning_with_global_capability_partitioning(self): class net(nn.Module): def __init__(self): @@ -260,6 +496,8 @@ def forward(self, x): == 5 ), "The graph should have 5 accelerated subgraphs" + torch._dynamo.reset() + if __name__ == "__main__": run_tests() From 9ac07a93c2c2c354a815a72fbd1ee0918e1c1c14 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 1 Dec 2025 19:56:56 +0000 Subject: [PATCH 09/10] Fixed the comments --- examples/dynamo/low_cpu_memory_compilation.py | 1 + py/torch_tensorrt/dynamo/_compiler.py | 22 +++++++--- py/torch_tensorrt/dynamo/_defaults.py | 4 +- py/torch_tensorrt/dynamo/_settings.py | 4 +- .../partitioning/_resource_partitioner.py | 44 ++++++++++++------- .../test_resource_partitioning.py | 11 +++-- 6 files changed, 57 insertions(+), 29 deletions(-) diff --git a/examples/dynamo/low_cpu_memory_compilation.py b/examples/dynamo/low_cpu_memory_compilation.py index 30906e91f0..c508d3f0b4 100644 --- a/examples/dynamo/low_cpu_memory_compilation.py +++ b/examples/dynamo/low_cpu_memory_compilation.py @@ -63,6 +63,7 @@ def forward(self, x): "min_block_size": 1, "immutable_weights": True, "reuse_cached_engines": False, + "enable_resource_partitioning": True, "cpu_memory_budget": 2 * 1024 * 1024 * 1024, # 2 GiB in bytes } diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 7ca8f714e5..e0b6bbbff6 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -108,7 +108,8 @@ def cross_compile_for_windows( l2_limit_for_tiling: int = 
_defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, - cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, + enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING, + cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows @@ -183,7 +184,8 @@ def cross_compile_for_windows( tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model - cpu_memory_budget (int): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. If set to -1, the compilation will use all available CPU memory. + enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and the CPU memory is limited. + cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -339,6 +341,7 @@ def cross_compile_for_windows( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, "use_distributed_mode_trace": use_distributed_mode_trace, + "enable_resource_partitioning": enable_resource_partitioning, "cpu_memory_budget": cpu_memory_budget, } @@ -441,7 +444,8 @@ def compile( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, - cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, + cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET, + enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -519,6 +523,8 @@ def compile( l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and the CPU memory is limited. + cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. 
**kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -688,6 +694,7 @@ def compile( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, + "enable_resource_partitioning": enable_resource_partitioning, "cpu_memory_budget": cpu_memory_budget, } logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB") @@ -862,10 +869,11 @@ def preserve_module_specs( require_full_compilation=settings.require_full_compilation, ) - partitioned_module = resource_partition( - partitioned_module, - cpu_memory_budget=settings.cpu_memory_budget, - ) + if settings.enable_resource_partitioning: + partitioned_module = resource_partition( + partitioned_module, + cpu_memory_budget=settings.cpu_memory_budget, + ) dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 0b4c0a2b54..f2e35ce9f6 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -2,7 +2,6 @@ import platform import tempfile -import psutil import torch from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype @@ -58,7 +57,8 @@ L2_LIMIT_FOR_TILING = -1 USE_DISTRIBUTED_MODE_TRACE = False OFFLOAD_MODULE_TO_CPU = False -CPU_MEMORY_BUDGET = psutil.virtual_memory().available +ENABLE_RESOURCE_PARTITIONING = False +CPU_MEMORY_BUDGET = None if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 05f37fdb43..797292ce97 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -15,6 +15,7 @@ DRYRUN, ENABLE_CROSS_COMPILE_FOR_WINDOWS, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + ENABLE_RESOURCE_PARTITIONING, ENABLE_WEIGHT_STREAMING, ENABLED_PRECISIONS, ENGINE_CAPABILITY, @@ -141,6 +142,7 @@ class CompilationSettings: l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU + enable_resource_partitioning: bool = ENABLE_RESOURCE_PARTITIONING cpu_memory_budget: int = CPU_MEMORY_BUDGET def __getstate__(self) -> dict[str, Any]: @@ -174,7 +176,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: "enable_weight_streaming", "tiling_optimization_level", "l2_limit_for_tiling", - "cpu_memory_budget", + "enable_resource_partitioning", ) diff --git a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py index ee14deee61..6cbc9b3db9 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py @@ -46,7 +46,7 @@ """ import logging -from typing import Dict, List, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple import psutil import torch @@ -59,6 +59,8 @@ logger = logging.getLogger(__name__) +MAX_NUM_OF_ENGINES = 40 + class ResourcePartitioner(_SplitterBase): # type: ignore """Refine capability-based subgraphs to meet host CPU memory constraints. 
@@ -78,14 +80,19 @@ class ResourcePartitioner(_SplitterBase): # type: ignore def __init__( self, module: torch.fx.GraphModule, - cpu_memory_budget: int, + cpu_memory_budget: Optional[int], submodule_name: str, ): assert isinstance(module, torch.fx.GraphModule) self.module = module - self.cpu_memory_budget = cpu_memory_budget + self.cpu_memory_budget = ( + cpu_memory_budget + if cpu_memory_budget is not None + else psutil.virtual_memory().available + ) + self.not_set_limit = cpu_memory_budget is None self.resource_split_count = 0 self.submodule_name = submodule_name self.deps = self.find_deps() @@ -148,10 +155,6 @@ def put_nodes_into_subgraphs(self) -> list[Subgraph]: subgraphs = [Subgraph(is_acc=True, nodes=nodes)] self.fusion_patterns = get_node_in_fusion_pattern(self.module.graph) - assert self.check_topological_order( - subgraphs - ), "The subgraphs are not topologically ordered" - return subgraphs def check_topological_order(self, subgraphs: List[Subgraph]) -> bool: @@ -186,7 +189,11 @@ def calculate_size_budget( """ used_rss: int = psutil.Process().memory_info().rss - available_rss = self.cpu_memory_budget - used_rss + available_rss = ( + self.cpu_memory_budget + if self.not_set_limit + else self.cpu_memory_budget - used_rss + ) return available_rss // engine_compilation_memory_usage_multiplier def break_subgraphs( @@ -214,12 +221,17 @@ def break_subgraphs( # We throw an error if the remaining memory is almost empty compared to the model size. # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation. sizes = self.size_of_subgraphs(subgraphs) - if sum(sizes) > subgraph_size_budget * 40: - raise ValueError( - "CPU memory budget or available memory is too small to compile the model. " - + f"CPU memory budget: {self.cpu_memory_budget // (1024 * 1024)} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. " - + "Consider setting cpu_memory_budget to a larger value or disable offload_module_to_cpu to save more CPU memory." - ) + if sum(sizes) > subgraph_size_budget * MAX_NUM_OF_ENGINES: + if self.not_set_limit: + raise ValueError( + "The system memory is too constrained to compile the model without severe perf degradation. Consider setting offload_module_to_cpu=False to save more CPU memory." + ) + else: + raise ValueError( + "CPU memory budget is too small to compile the model. " + + f"CPU memory budget: {self.cpu_memory_budget // (1024 * 1024)} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. " + + "Consider setting cpu_memory_budget to a larger value." 
+ ) for subgraph, size in zip(subgraphs, sizes): while size > subgraph_size_budget: @@ -233,11 +245,11 @@ def break_subgraphs( if len(subgraph.nodes) != 0: new_subgraphs.append(subgraph) - self._varify_all_fusion_nodes_in_same_subgraph(new_subgraphs) + self._verify_all_fusion_nodes_in_same_subgraph(new_subgraphs) return new_subgraphs - def _varify_all_fusion_nodes_in_same_subgraph( + def _verify_all_fusion_nodes_in_same_subgraph( self, subgraphs: List[Subgraph] ) -> None: """Assert that every fusion group is contained in exactly one subgraph.""" diff --git a/tests/py/dynamo/partitioning/test_resource_partitioning.py b/tests/py/dynamo/partitioning/test_resource_partitioning.py index bb38a13d44..c2c1151046 100644 --- a/tests/py/dynamo/partitioning/test_resource_partitioning.py +++ b/tests/py/dynamo/partitioning/test_resource_partitioning.py @@ -63,6 +63,7 @@ def forward(self, x): "min_block_size": 1, "immutable_weights": True, "reuse_cached_engines": False, + "enable_resource_partitioning": True, } settings = CompilationSettings(**compilation_options) @@ -144,6 +145,7 @@ def forward(self, x): "immutable_weights": True, "reuse_cached_engines": False, "torch_executed_ops": {"torch.ops.aten.max_pool2d.default"}, + "enable_resource_partitioning": True, } settings = CompilationSettings(**compilation_options) @@ -175,8 +177,8 @@ def forward(self, x): if "_run_on_acc" in name ] ) - == 5 - ), "The graph should have 5 accelerated subgraphs" + > 3 + ), "The graph should have more than 3 accelerated subgraphs" assert ( len( [ @@ -275,6 +277,7 @@ def forward(self, x): "immutable_weights": True, "reuse_cached_engines": False, "torch_executed_ops": {"torch.ops.aten.max_pool2d.default"}, + "enable_resource_partitioning": True, } settings = CompilationSettings(**compilation_options) @@ -355,6 +358,7 @@ def forward(self, x): "min_block_size": 1, "immutable_weights": True, "reuse_cached_engines": False, + "enable_resource_partitioning": True, } settings = CompilationSettings(**compilation_options) @@ -409,7 +413,7 @@ def forward(self, x): assert broken_fusion # The fusion should be fixed after the step - partitioner._varify_all_fusion_nodes_in_same_subgraph(new_subgraphs) + partitioner._verify_all_fusion_nodes_in_same_subgraph(new_subgraphs) break @@ -463,6 +467,7 @@ def forward(self, x): "immutable_weights": True, "reuse_cached_engines": False, "torch_executed_ops": {"torch.ops.aten.max_pool2d.default"}, + "enable_resource_partitioning": True, } settings = CompilationSettings(**compilation_options) From 08841a1ed58ebb4c1a5b8ef9ec9fa0288b2db2b2 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 8 Dec 2025 21:10:54 +0000 Subject: [PATCH 10/10] Moved some tests to L1 --- .github/workflows/build-test-linux-x86_64.yml | 4 +- .../partitioning/_resource_partitioner.py | 32 +++-- ...oning.py => test_000_fast_partitioning.py} | 0 ... => test_000_flaky_global_partitioning.py} | 0 ...ing.py => test_000_global_partitioning.py} | 0 ... 
=> test_000_hierarchical_partitioning.py} | 0 .../test_000_resource_partitioning.py | 113 ++++++++++++++++++ ...g.py => test_001_resource_partitioning.py} | 91 -------------- 8 files changed, 136 insertions(+), 104 deletions(-) rename tests/py/dynamo/partitioning/{test_fast_partitioning.py => test_000_fast_partitioning.py} (100%) rename tests/py/dynamo/partitioning/{test_flaky_global_partitioning.py => test_000_flaky_global_partitioning.py} (100%) rename tests/py/dynamo/partitioning/{test_global_partitioning.py => test_000_global_partitioning.py} (100%) rename tests/py/dynamo/partitioning/{test_hierarchical_partitioning.py => test_000_hierarchical_partitioning.py} (100%) create mode 100644 tests/py/dynamo/partitioning/test_000_resource_partitioning.py rename tests/py/dynamo/partitioning/{test_resource_partitioning.py => test_001_resource_partitioning.py} (80%) diff --git a/.github/workflows/build-test-linux-x86_64.yml b/.github/workflows/build-test-linux-x86_64.yml index 4c0f31b256..3918b0f839 100644 --- a/.github/workflows/build-test-linux-x86_64.yml +++ b/.github/workflows/build-test-linux-x86_64.yml @@ -136,7 +136,7 @@ jobs: cd tests/py cd dynamo python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_* - python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/ + python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_* python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/ popd @@ -229,6 +229,8 @@ jobs: pushd . cd tests/py/dynamo python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_* + python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/test_001_* + popd L1-dynamo-compile-tests: diff --git a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py index 6cbc9b3db9..0a987d7a7f 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py @@ -60,6 +60,7 @@ logger = logging.getLogger(__name__) MAX_NUM_OF_ENGINES = 40 +ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER = 4 class ResourcePartitioner(_SplitterBase): # type: ignore @@ -87,8 +88,9 @@ def __init__( assert isinstance(module, torch.fx.GraphModule) self.module = module - self.cpu_memory_budget = ( - cpu_memory_budget + used_rss: int = psutil.Process().memory_info().rss + self.remaining_memory_budget = ( + cpu_memory_budget - used_rss if cpu_memory_budget is not None else psutil.virtual_memory().available ) @@ -114,6 +116,12 @@ def partition_graph(self) -> torch.fx.GraphModule: """ # Delegate nodes based on operator coverage subgraphs = self.put_nodes_into_subgraphs() + sizes = self.size_of_subgraphs(subgraphs) + if ( + sum(sizes) * ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER + < self.remaining_memory_budget + ): + return self.module subgraphs = self.break_subgraphs( subgraphs, subgraph_size_budget=self.calculate_size_budget() @@ -172,7 +180,8 @@ def check_topological_order(self, subgraphs: List[Subgraph]) -> bool: return True def calculate_size_budget( - self, engine_compilation_memory_usage_multiplier: int = 4 + self, + engine_compilation_memory_usage_multiplier: int = 
ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER, ) -> int: """Compute the per-engine size budget in bytes. @@ -188,13 +197,9 @@ def calculate_size_budget( int: Budget in bytes for a single accelerated subgraph. """ - used_rss: int = psutil.Process().memory_info().rss - available_rss = ( - self.cpu_memory_budget - if self.not_set_limit - else self.cpu_memory_budget - used_rss + return ( + self.remaining_memory_budget // engine_compilation_memory_usage_multiplier ) - return available_rss // engine_compilation_memory_usage_multiplier def break_subgraphs( self, subgraphs: List[Subgraph], subgraph_size_budget: int @@ -229,7 +234,7 @@ def break_subgraphs( else: raise ValueError( "CPU memory budget is too small to compile the model. " - + f"CPU memory budget: {self.cpu_memory_budget // (1024 * 1024)} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. " + + f"CPU memory budget: {self.remaining_memory_budget // (1024 * 1024)} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. " + "Consider setting cpu_memory_budget to a larger value." ) for subgraph, size in zip(subgraphs, sizes): @@ -548,12 +553,15 @@ def resource_partition( setattr(gm, name, partitioned_graph) for name, module in list(gm.named_children()): + split = False if "_run_on_acc" in name: for subname, submodule in module.named_children(): if "resource_split" in subname: + split = True setattr(gm, subname, submodule) - _inline_module(gm, name) - delattr(gm, name) + if split: + _inline_module(gm, name) + delattr(gm, name) gm.recompile() return gm diff --git a/tests/py/dynamo/partitioning/test_fast_partitioning.py b/tests/py/dynamo/partitioning/test_000_fast_partitioning.py similarity index 100% rename from tests/py/dynamo/partitioning/test_fast_partitioning.py rename to tests/py/dynamo/partitioning/test_000_fast_partitioning.py diff --git a/tests/py/dynamo/partitioning/test_flaky_global_partitioning.py b/tests/py/dynamo/partitioning/test_000_flaky_global_partitioning.py similarity index 100% rename from tests/py/dynamo/partitioning/test_flaky_global_partitioning.py rename to tests/py/dynamo/partitioning/test_000_flaky_global_partitioning.py diff --git a/tests/py/dynamo/partitioning/test_global_partitioning.py b/tests/py/dynamo/partitioning/test_000_global_partitioning.py similarity index 100% rename from tests/py/dynamo/partitioning/test_global_partitioning.py rename to tests/py/dynamo/partitioning/test_000_global_partitioning.py diff --git a/tests/py/dynamo/partitioning/test_hierarchical_partitioning.py b/tests/py/dynamo/partitioning/test_000_hierarchical_partitioning.py similarity index 100% rename from tests/py/dynamo/partitioning/test_hierarchical_partitioning.py rename to tests/py/dynamo/partitioning/test_000_hierarchical_partitioning.py diff --git a/tests/py/dynamo/partitioning/test_000_resource_partitioning.py b/tests/py/dynamo/partitioning/test_000_resource_partitioning.py new file mode 100644 index 0000000000..2014eea8fe --- /dev/null +++ b/tests/py/dynamo/partitioning/test_000_resource_partitioning.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn +from torch.fx.passes.splitter_base import Subgraph +from torch.ops import aten +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt.dynamo import partitioning +from torch_tensorrt.dynamo.conversion import CompilationSettings +from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + post_lowering, + pre_export_lowering, +) +from torch_tensorrt.dynamo.lowering.passes import post_lowering, pre_export_lowering +from 
torch_tensorrt.dynamo.partitioning._resource_partitioner import ( + ResourcePartitioner, +) + + +class TestResourcePartitioning(TestCase): + def test_atomic_subgraph_correction(self): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 3, 3, padding=1) + self.bn1 = nn.BatchNorm2d(3) + self.relu = nn.ReLU() + self.fc = nn.Linear(3 * 224 * 224, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = torch.flatten(x, 1) + x = self.fc(x) + return x + + model = net().eval() + model.to("cuda") + inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] + + enabled_precisions = {torch.float} + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + + compilation_options = { + "use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + "enable_resource_partitioning": True, + } + settings = CompilationSettings(**compilation_options) + + exported_program = pre_export_lowering(exp_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + partitioned_module, supported_ops = partitioning.fast_partition( + gm, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=True, + ) + + for name, _ in partitioned_module.named_children(): + submodule = getattr(partitioned_module, name) + if ( + not isinstance(submodule, torch.fx.graph_module.GraphModule) + or "_run_on_acc" not in name + ): + continue + partitioner = ResourcePartitioner( + submodule, + submodule_name=name, + cpu_memory_budget=2 * 1024 * 1024 * 1024, + ) + subgraphs = partitioner.put_nodes_into_subgraphs() + new_subgraphs = [] + current_subgraph = [] + # Split the subgraph into two subgraphs by the ReLU node, which breaks the fusion group. 
+ for node in subgraphs[0].nodes: + if node.op == "call_function" and node.target == aten.relu.default: + new_subgraphs.append(Subgraph(is_acc=True, nodes=current_subgraph)) + current_subgraph = [] + current_subgraph.append(node) + if current_subgraph: + new_subgraphs.append(Subgraph(is_acc=True, nodes=current_subgraph)) + + leaf_node = partitioner.get_leaf_node(new_subgraphs[0].nodes) + broken_fusion = partitioner.step_if_break_fusion( + new_subgraphs, + leaf_node, + set(new_subgraphs[0].nodes), + set(new_subgraphs[1].nodes), + ) + # The fusion was broken + assert broken_fusion + + # The fusion should be fixed after the step + partitioner._verify_all_fusion_nodes_in_same_subgraph(new_subgraphs) + + break + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/partitioning/test_resource_partitioning.py b/tests/py/dynamo/partitioning/test_001_resource_partitioning.py similarity index 80% rename from tests/py/dynamo/partitioning/test_resource_partitioning.py rename to tests/py/dynamo/partitioning/test_001_resource_partitioning.py index c2c1151046..b5b1b17628 100644 --- a/tests/py/dynamo/partitioning/test_resource_partitioning.py +++ b/tests/py/dynamo/partitioning/test_001_resource_partitioning.py @@ -326,97 +326,6 @@ def forward(self, x): torch._dynamo.reset() - def test_atomic_subgraph_correction(self): - class net(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(3, 3, 3, padding=1) - self.bn1 = nn.BatchNorm2d(3) - self.relu = nn.ReLU() - self.fc = nn.Linear(3 * 224 * 224, 10) - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = torch.flatten(x, 1) - x = self.fc(x) - return x - - model = net().eval() - model.to("cuda") - inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] - - enabled_precisions = {torch.float} - use_python_runtime = False - - exp_program = torch.export.export(model, tuple(inputs)) - - compilation_options = { - "use_python_runtime": use_python_runtime, - "enabled_precisions": enabled_precisions, - "min_block_size": 1, - "immutable_weights": True, - "reuse_cached_engines": False, - "enable_resource_partitioning": True, - } - settings = CompilationSettings(**compilation_options) - - exported_program = pre_export_lowering(exp_program, settings) - exported_program = exported_program.run_decompositions( - get_decompositions(False) - ) - - gm = exported_program.module() - gm = post_lowering(gm, settings) - - partitioned_module, supported_ops = partitioning.fast_partition( - gm, - min_block_size=settings.min_block_size, - torch_executed_ops=settings.torch_executed_ops, - require_full_compilation=settings.require_full_compilation, - skip_fusion=True, - ) - - for name, _ in partitioned_module.named_children(): - submodule = getattr(partitioned_module, name) - if ( - not isinstance(submodule, torch.fx.graph_module.GraphModule) - or "_run_on_acc" not in name - ): - continue - partitioner = ResourcePartitioner( - submodule, - submodule_name=name, - cpu_memory_budget=2 * 1024 * 1024 * 1024, - ) - subgraphs = partitioner.put_nodes_into_subgraphs() - new_subgraphs = [] - current_subgraph = [] - # Split the subgraph into two subgraphs by the ReLU node, which breaks the fusion group. 
- for node in subgraphs[0].nodes: - if node.op == "call_function" and node.target == aten.relu.default: - new_subgraphs.append(Subgraph(is_acc=True, nodes=current_subgraph)) - current_subgraph = [] - current_subgraph.append(node) - if current_subgraph: - new_subgraphs.append(Subgraph(is_acc=True, nodes=current_subgraph)) - - leaf_node = partitioner.get_leaf_node(new_subgraphs[0].nodes) - broken_fusion = partitioner.step_if_break_fusion( - new_subgraphs, - leaf_node, - set(new_subgraphs[0].nodes), - set(new_subgraphs[1].nodes), - ) - # The fusion was broken - assert broken_fusion - - # The fusion should be fixed after the step - partitioner._verify_all_fusion_nodes_in_same_subgraph(new_subgraphs) - - break - def test_resource_partitioning_with_global_capability_partitioning(self): class net(nn.Module): def __init__(self):