@@ -32,6 +32,20 @@ def forward(self, x):
3232 x2 = nn .Softmax (dim = - 1 )(x1 )
3333 return x2
3434
class inplace_softmax_with_TE_group(torch.nn.Module):
    """Module producing two softmax outputs fed by a TE (tensor-expression)
    fusion group, used to check that both softmaxes are rewritten to the
    in-place IPEX variant after JIT tracing.

    Each softmax input is a freshly computed intermediate (single-use), so
    an in-place softmax rewrite is legal for both.
    """
    # NOTE: no explicit __init__ — the module holds no parameters or
    # submodules, so the inherited nn.Module.__init__ suffices.

    def forward(self, x):
        """Return two softmax results computed from arithmetic on *x*.

        Args:
            x: input tensor; any floating-point shape broadcastable with
               scalar addition works.

        Returns:
            Tuple (y1, y2) of softmax outputs over the last dimension.
        """
        x1 = x + 1
        x2 = x + 2
        x3 = x + 3
        x4 = x + 4
        x5 = x + 5
        y1 = (x1 / x2).softmax(dim=-1)
        # (x4 - x3) is elementwise 1, so y2 is softmax(1 / (x + 5)).
        y2 = ((x4 - x3) / x5).softmax(dim=-1)
        return y1, y2
47+
48+
3549class SoftmaxTester (JitTestCase ):
3650 def test_softmax (self ):
3751 for dtype in ["fp32" , "bf16" ]:
@@ -40,19 +54,22 @@ def test_softmax(self):
4054 test3 = torch .tensor ([[1.0 ,1.0 ],[1.0 ,1.0 ]])
4155 test4 = torch .tensor ([[1.0 ,1.0 ],[1.0 ,1.0 ]]).transpose (1 ,0 )
4256 test5 = torch .tensor ([[2.0 ,2.0 ],[2.0 ,2.0 ]]).transpose (1 ,0 )
57+ test6 = torch .tensor ([[1.0 ,1.0 ],[1.0 ,1.0 ]])
4358
4459 if dtype == "bf16" :
4560 test1 = test1 .bfloat16 ()
4661 test2 = test2 .bfloat16 ()
4762 test3 = test3 .bfloat16 ()
4863 test4 = test4 .bfloat16 ()
4964 test5 = test5 .bfloat16 ()
65+ test6 = test6 .bfloat16 ()
5066
5167 model1 = softmax_with_multiuse_input ().eval ()
5268 model2 = softmax_with_alias_input ().eval ()
5369 model3 = inplace_softmax ().eval ()
5470 model4 = inplace_softmax ().eval ()
5571 model5 = softmax_with_multiuse_input ().eval ()
72+ model6 = inplace_softmax_with_TE_group ().eval ()
5673
5774 with torch .no_grad ():
5875 model1 = torch .jit .trace (model1 , test1 )
@@ -65,6 +82,9 @@ def test_softmax(self):
6582 res4 = model4 (test4 )
6683 model5 = torch .jit .trace (model5 , test5 )
6784 res5 = model5 (test5 )
85+ model6_traced = torch .jit .trace (model6 , test6 )
86+ res6_traced = model6_traced (test6 )
87+ res6 = model6 (test6 )
6888
6989
7090 # should be outplace since multi-use
@@ -82,12 +102,17 @@ def test_softmax(self):
82102 # outplace test, but should be aten::softmax due to non-contiguous input
83103 graph5 = model5 .graph_for (test5 )
84104 self .assertGraphContainsExactly (graph5 , ATEN_SOFTMAX , 1 )
105+ # should be inplace
106+ graph6 = model6_traced .graph_for (test6 )
107+ self .assertGraphContainsExactly (graph6 , IPEX_SOFTMAX_ , 2 )
85108
86109 # the output results of above inplace/outplace softmax should be the same
87110 self .assertEqual (res1 [0 ], res2 [1 ], 0 )
88111 self .assertEqual (res1 [0 ], res3 , 0 )
89112 self .assertEqual (res1 [0 ], res4 , 0 )
90113 self .assertEqual (res1 [0 ], res5 [0 ], 0 )
114+ self .assertEqual (res6 [0 ], res6_traced [0 ], 0 )
115+ self .assertEqual (res6 [1 ], res6_traced [1 ], 0 )
91116
92117
93118if __name__ == '__main__' :
0 commit comments