Skip to content

Commit fd78768

Browse files
authored
Fixed wrong condition in cat op (#2546)
Signed-off-by: majing <Jing1.Ma@intel.com>
1 parent d3f6fe1 commit fd78768

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

csrc/gpu/aten/operators/CatImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -336,7 +336,7 @@ static void cat(
336336
bool hasSkippedInput = false;
337337
Tensor notSkippedTensor; // non-owning reference
338338
auto should_skip = [](const Tensor& t) {
339-
return !t.defined() && t.dim() == 1;
339+
return t.numel() == 0 && t.dim() == 1;
340340
};
341341
int nDims = 0;
342342

tests/gpu/examples/test_cat_array.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -131,3 +131,17 @@ def test_cat_multi_device(self, dtype=torch.float):
131131
x_xpu2 = x_cpu2.clone().to("xpu:1")
132132
res_xpu = torch.cat((x_xpu1, x_xpu2))
133133
self.assertEqual(res_cpu, res_xpu.cpu())
134+
135+
def test_cat_size0_tensor(self):
136+
output1_cpu = torch.cat((torch.tensor([],device='cpu'), torch.tensor([1],device='cpu')), dim=0)
137+
output2_cpu = torch.cat((torch.tensor([],device='cpu'), torch.tensor([1,2],device='cpu')), dim=0)
138+
output3_cpu = torch.cat((torch.tensor([],device='cpu'), torch.tensor([[1,2],[3,4]],device='cpu')), dim=0)
139+
output4_cpu = torch.cat((torch.tensor([],device='cpu'), torch.tensor([[[1]],[[2]]],device='cpu')), dim=0)
140+
output1_xpu = torch.cat((torch.tensor([],device='xpu'), torch.tensor([1],device='xpu')), dim=0)
141+
output2_xpu = torch.cat((torch.tensor([],device='xpu'), torch.tensor([1,2],device='xpu')), dim=0)
142+
output3_xpu = torch.cat((torch.tensor([],device='xpu'), torch.tensor([[1,2],[3,4]],device='xpu')), dim=0)
143+
output4_xpu = torch.cat((torch.tensor([],device='xpu'), torch.tensor([[[1]],[[2]]],device='xpu')), dim=0)
144+
self.assertEqual(output1_cpu, output1_xpu.cpu())
145+
self.assertEqual(output2_cpu, output2_xpu.cpu())
146+
self.assertEqual(output3_cpu, output3_xpu.cpu())
147+
self.assertEqual(output4_cpu, output4_xpu.cpu())

0 commit comments

Comments
 (0)