Commit c31d354

fix position inside subgraph in inplace check (#870)
1 parent d94ca6e

2 files changed: +48 -10 lines

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite_inplace_replace.cpp

Lines changed: 23 additions & 10 deletions
@@ -21,7 +21,6 @@ bool hasSideEffectInDefNode(Node* def_node, int position) {
         def_node->hasSideEffects() || (def_node->kind() == prim::Param);
     }
   }
-
   return checkresult;
 }
 
@@ -45,31 +44,45 @@ bool hasSideEffectInBlocks(Block* block, Value* v) {
 
 bool hasSideEffectOrAliasInSubgraphs(Node* node, Value* v) {
   bool checkresult = false;
-  // A LLGAFusionGroup must have its fallbackgraph, we only need to check one of
-  // them
+  // A LLGAFusionGroup or TensorExprGroup must have its fallbackgraph, we only
+  // need to check one of them
   if (node->kind().toQualString() ==
       Symbol::fromQualString("ipex::LlgaFusionGroup").toQualString()) {
     return false;
   }
+  if (node->kind().toQualString() ==
+      Symbol::fromQualString("prim::TensorExprGroup").toQualString()) {
+    return false;
+  }
+
   // get the subgraph of the def node
   auto subgraph = node->g(attr::Subgraph);
 
   // find the position of the target value in its def node in the subgraph
   // for example, here find (%input.1), and the position is 0:
   // graph(---),
   //   %input.1 : Tensor = Ops
-  //   return (%input.1)
-  int position = v->offset();
-  auto def_node = subgraph->outputs()[position]->node();
-  std::unique_ptr<AliasDb> aliasDb_ = std::make_unique<AliasDb>(subgraph);
+  //   %input.2 : Tensor = Ops
+  //   return (%input.1, %input.2)
 
-  checkresult = hasSideEffectInDefNode(def_node, position);
+  // position_in_subgraph is the graph-return position, e.g., 0 for %input.1
+  // and 1 for %input.2
+  int position_in_subgraph = v->offset();
+  auto def_node = subgraph->outputs()[position_in_subgraph]->node();
+  // position_in_def_node is the position among the def node's own outputs,
+  // e.g., 0 for both %input.1 and %input.2
+  int position_in_def_node =
+      subgraph->outputs()[position_in_subgraph]->offset();
+
+  checkresult = hasSideEffectInDefNode(def_node, position_in_def_node);
 
   // for the def node in the subgraph, its aliases have to be checked too:
+  // if the output isn't contained in or aliased by the inputs to its node,
+  // it's unique. No need to check for aliasing if the node is a ListConstruct.
+  std::unique_ptr<AliasDb> aliasDb_ = std::make_unique<AliasDb>(subgraph);
   bool mayAliasInputs = (def_node->kind() != prim::ListConstruct) &&
       aliasDb_->mayContainAlias(
-          def_node->inputs(), def_node->outputs()[position]);
-
+          def_node->inputs(), def_node->outputs()[position_in_def_node]);
   checkresult = checkresult || mayAliasInputs;
   return checkresult;
 }
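
The crux of the fix is that two different offsets were being conflated: `v->offset()` is the value's position in the subgraph's return tuple, while `hasSideEffectInDefNode` needs the value's position among its defining node's outputs. The two agree only when the subgraph returns a single value. Below is a minimal sketch in plain PyTorch (not part of the commit; the function name `two_outputs` is illustrative) showing a graph whose two outputs sit at graph-return positions 0 and 1 while each is output 0 of its own def node:

import torch

# A minimal sketch, assuming stock PyTorch: a scripted function whose graph
# returns two values defined by two different nodes, mirroring the
# %input.1 / %input.2 example in the comment above.
@torch.jit.script
def two_outputs(x: torch.Tensor):
    a = x + 1  # graph-return position 0
    b = x + 2  # graph-return position 1
    return a, b

graph = two_outputs.graph
for graph_pos, out in enumerate(graph.outputs()):
    def_node = out.node()
    # Position of this value among its defining node's own outputs:
    # aten::add has a single output, so this is 0 for both values,
    # even though their graph-return positions are 0 and 1.
    def_node_pos = [o.debugName() for o in def_node.outputs()].index(
        out.debugName())
    print(graph_pos, def_node.kind(), def_node_pos)

Before this commit, the graph-return position was also used to index the def node's outputs, which mis-indexes whenever the subgraph returns more than one value; the new `position_in_def_node` is taken from the output value itself.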

tests/cpu/test_softmax.py

Lines changed: 25 additions & 0 deletions
@@ -32,6 +32,20 @@ def forward(self, x):
         x2 = nn.Softmax(dim=-1)(x1)
         return x2
 
+class inplace_softmax_with_TE_group(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+    def forward(self, x):
+        x1 = x + 1
+        x2 = x + 2
+        x3 = x + 3
+        x4 = x + 4
+        x5 = x + 5
+        y1 = (x1 / x2).softmax(dim=-1)
+        y2 = ((x4 - x3) / x5).softmax(dim=-1)
+        return y1, y2
+
+
 class SoftmaxTester(JitTestCase):
     def test_softmax(self):
         for dtype in ["fp32", "bf16"]:
@@ -40,19 +54,22 @@ def test_softmax(self):
             test3 = torch.tensor([[1.0,1.0],[1.0,1.0]])
             test4 = torch.tensor([[1.0,1.0],[1.0,1.0]]).transpose(1,0)
             test5 = torch.tensor([[2.0,2.0],[2.0,2.0]]).transpose(1,0)
+            test6 = torch.tensor([[1.0,1.0],[1.0,1.0]])
 
             if dtype == "bf16":
                 test1 = test1.bfloat16()
                 test2 = test2.bfloat16()
                 test3 = test3.bfloat16()
                 test4 = test4.bfloat16()
                 test5 = test5.bfloat16()
+                test6 = test6.bfloat16()
 
             model1 = softmax_with_multiuse_input().eval()
             model2 = softmax_with_alias_input().eval()
             model3 = inplace_softmax().eval()
             model4 = inplace_softmax().eval()
             model5 = softmax_with_multiuse_input().eval()
+            model6 = inplace_softmax_with_TE_group().eval()
 
             with torch.no_grad():
                 model1 = torch.jit.trace(model1, test1)
@@ -65,6 +82,9 @@ def test_softmax(self):
                 res4 = model4(test4)
                 model5 = torch.jit.trace(model5, test5)
                 res5 = model5(test5)
+                model6_traced = torch.jit.trace(model6, test6)
+                res6_traced = model6_traced(test6)
+                res6 = model6(test6)
 
 
             # should be outplace since multi-use
@@ -82,12 +102,17 @@ def test_softmax(self):
             # outplace test, but should be aten::softmax due to non-contiguous input
             graph5 = model5.graph_for(test5)
             self.assertGraphContainsExactly(graph5, ATEN_SOFTMAX, 1)
+            # should be inplace
+            graph6 = model6_traced.graph_for(test6)
+            self.assertGraphContainsExactly(graph6, IPEX_SOFTMAX_, 2)
 
             # the output results of above inplace/outplace softmax should be the same
             self.assertEqual(res1[0], res2[1], 0)
             self.assertEqual(res1[0], res3, 0)
             self.assertEqual(res1[0], res4, 0)
             self.assertEqual(res1[0], res5[0], 0)
+            self.assertEqual(res6[0], res6_traced[0], 0)
+            self.assertEqual(res6[1], res6_traced[1], 0)
 
 
 if __name__ == '__main__':
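
The new test exercises exactly the case the C++ change targets: once the TensorExpr fuser kicks in, the elementwise chains in `inplace_softmax_with_TE_group` collapse into a `prim::TensorExprGroup` whose two outputs feed the two softmax calls, so the inplace-replacement pass must resolve each value's def node inside that subgraph. A hedged sketch of how such a group can be observed with stock PyTorch follows (fuser flags vary by build and may already be enabled by default; the helper `f` is illustrative, not from the commit):

import torch

# A hedged sketch, not part of the commit: force the NNC (TensorExpr) fuser
# on CPU so elementwise chains fuse into a prim::TensorExprGroup node.
torch._C._jit_set_texpr_fuser_enabled(True)
torch._C._jit_override_can_fuse_on_cpu(True)

def f(x):
    x1 = x + 1
    x2 = x + 2
    return (x1 / x2).softmax(dim=-1)

traced = torch.jit.trace(f, torch.randn(4, 4))
with torch.no_grad():
    for _ in range(3):  # warm up the profiling executor so fusion kicks in
        traced(torch.randn(4, 4))
# The optimized graph should now contain a prim::TensorExprGroup feeding
# aten::softmax.
print(traced.graph_for(torch.randn(4, 4)))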

0 commit comments