diff --git a/backends/cadence/aot/memory_constraints.py b/backends/cadence/aot/memory_constraints.py
index 0eaaa8987c6..4b637da8d48 100644
--- a/backends/cadence/aot/memory_constraints.py
+++ b/backends/cadence/aot/memory_constraints.py
@@ -417,6 +417,10 @@ def is_slice_view(self, node: torch.fx.Node) -> bool:
             return not self.constraint.is_alias_of(source_info.source, node)
         return False
 
+    def has_relative_placement_constraint(self, node: torch.fx.Node) -> bool:
+        """Return True if `node` already has a relative placement constraint."""
+        return self.constraint.get_relative_placement_source(node) is not None
+
     # Return true if the cat node performs concatenation along outermost dimension
     def is_cat_along_outermost_dim(
         self, graph_module: torch.fx.GraphModule, cat_node: torch.fx.Node
@@ -481,6 +485,17 @@ def is_removable_cat_op(
         if any(self.is_slice_view(arg) for arg in cat_tensors):
             return False
 
+        # If any of the tensors already has a relative placement constraint,
+        # we cannot add a new constraint for this cat without conflicting.
+        # This can happen when a tensor is used in multiple cat operations.
+        if any(self.has_relative_placement_constraint(arg) for arg in cat_tensors):
+            return False
+
+        # If the same tensor appears multiple times in the cat inputs,
+        # we cannot place it at multiple different offsets relative to the output.
+        if len(cat_tensors) != len(set(cat_tensors)):
+            return False
+
         # Many ops in HiFi require the input to be aligned to 8-byte boundary.
         # If the cat is not the graph's output, then ensure that the relative
         # offset of any concatenated non-placeholder tensor is a multiple of
diff --git a/backends/cadence/aot/tests/test_memory_passes.py b/backends/cadence/aot/tests/test_memory_passes.py
index 41f903ccf06..6c8da2202d4 100644
--- a/backends/cadence/aot/tests/test_memory_passes.py
+++ b/backends/cadence/aot/tests/test_memory_passes.py
@@ -947,6 +947,110 @@ def test_cat_then_cat(self) -> None:
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         self.verify_nop_memory_alloc(graph_module)
 
+    def test_cat_with_duplicate_input_tensor(self) -> None:
+        """
+        Test that cat is NOT optimized when the same tensor appears multiple
+        times in the cat input list. This is because we cannot place the same
+        tensor at multiple different offsets relative to the output.
+ """ + builder = GraphBuilder() + x = builder.placeholder("x", torch.ones(3, 6, dtype=torch.float32)) + to_add_to_x = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([3, 6], 123.0), + kwargs={"dtype": torch.float32}, + ) + add_x = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x, to_add_to_x), + ) + pre_created_output = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([6, 6], 0.0), + kwargs={"dtype": torch.float32}, + ) + # Same tensor (add_x) appears twice in the cat inputs + cat = builder.call_operator( + op=torch.ops.aten.cat.out, + args=([add_x, add_x],), + kwargs={"dim": 0, "out": pre_created_output}, + ) + builder.output([cat]) + original = builder.get_graph_module() + graph_module = self.run_memory_planning(original) + graph_module.graph.eliminate_dead_code() + + # Assert that cat op is NOT optimized away since the same tensor + # appears multiple times in the input list + self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) + self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0) + self.verify_nop_memory_alloc(graph_module) + + def test_cat_with_tensor_having_existing_constraint(self) -> None: + """ + Test that the second cat is NOT optimized when a tensor already has a + relative placement constraint from a previous cat operation. + """ + builder = GraphBuilder() + x = builder.placeholder("x", torch.ones(8, 8, dtype=torch.float32)) + to_add = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([8, 8], 1.0), + kwargs={"dtype": torch.float32}, + ) + x1 = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x, to_add), + ) + x2 = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x1, to_add), + ) + x3 = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x2, to_add), + ) + # First cat: cat(x1, x2) - this will give x1 and x2 relative placement constraints + pre_created_output1 = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([16, 8], 0.0), + kwargs={"dtype": torch.float32}, + ) + cat1 = builder.call_operator( + op=torch.ops.aten.cat.out, + args=([x1, x2],), + kwargs={"dim": 0, "out": pre_created_output1}, + ) + # Second cat: cat(x2, x3) - x2 already has a constraint from cat1, + # so this cat cannot be optimized + pre_created_output2 = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([16, 8], 0.0), + kwargs={"dtype": torch.float32}, + ) + cat2 = builder.call_operator( + op=torch.ops.aten.cat.out, + args=([x2, x3],), + kwargs={"dim": 0, "out": pre_created_output2}, + ) + # Use both cat results to keep them alive + graph_output = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(cat1, cat2), + ) + builder.output([graph_output]) + original = builder.get_graph_module() + graph_module = self.run_memory_planning( + original, opt_level=3, alloc_graph_input=False + ) + graph_module.graph.eliminate_dead_code() + + # The first cat should be optimized to _cat_nop, but the second cat + # cannot be optimized because x2 already has a relative placement constraint + self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) + self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) + self.verify_nop_memory_alloc(graph_module) + def test_view_for_unallocated_output(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.ones(3, 5, dtype=torch.float32))