@@ -131,3 +131,17 @@ def test_cat_multi_device(self, dtype=torch.float):
131131 x_xpu2 = x_cpu2 .clone ().to ("xpu:1" )
132132 res_xpu = torch .cat ((x_xpu1 , x_xpu2 ))
133133 self .assertEqual (res_cpu , res_xpu .cpu ())
134+
135+ def test_cat_size0_tensor (self ):
136+ output1_cpu = torch .cat ((torch .tensor ([],device = 'cpu' ), torch .tensor ([1 ],device = 'cpu' )), dim = 0 )
137+ output2_cpu = torch .cat ((torch .tensor ([],device = 'cpu' ), torch .tensor ([1 ,2 ],device = 'cpu' )), dim = 0 )
138+ output3_cpu = torch .cat ((torch .tensor ([],device = 'cpu' ), torch .tensor ([[1 ,2 ],[3 ,4 ]],device = 'cpu' )), dim = 0 )
139+ output4_cpu = torch .cat ((torch .tensor ([],device = 'cpu' ), torch .tensor ([[[1 ]],[[2 ]]],device = 'cpu' )), dim = 0 )
140+ output1_xpu = torch .cat ((torch .tensor ([],device = 'xpu' ), torch .tensor ([1 ],device = 'xpu' )), dim = 0 )
141+ output2_xpu = torch .cat ((torch .tensor ([],device = 'xpu' ), torch .tensor ([1 ,2 ],device = 'xpu' )), dim = 0 )
142+ output3_xpu = torch .cat ((torch .tensor ([],device = 'xpu' ), torch .tensor ([[1 ,2 ],[3 ,4 ]],device = 'xpu' )), dim = 0 )
143+ output4_xpu = torch .cat ((torch .tensor ([],device = 'xpu' ), torch .tensor ([[[1 ]],[[2 ]]],device = 'xpu' )), dim = 0 )
144+ self .assertEqual (output1_cpu , output1_xpu .cpu ())
145+ self .assertEqual (output2_cpu , output2_xpu .cpu ())
146+ self .assertEqual (output3_cpu , output3_xpu .cpu ())
147+ self .assertEqual (output4_cpu , output4_xpu .cpu ())
0 commit comments