
Commit c4809ed

fix fp32/bf16 cast (#135)
* fix fp32/bf16 cast
* fix code format
1 parent 7b8ce5b commit c4809ed

2 files changed: +32 -30 lines
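
What the fix does, in short: split_float_bfloat16 and cat_bfloat16_float used to allocate their results with default contiguous strides and only re-strided them afterwards when the *input* failed is_contiguous(). That check misses tensors whose strides differ from the default layout while still counting as contiguous (size-1 dims are ignored by the check), e.g. an N x C x 1 x 1 channels-last tensor. The commit instead allocates the results with at::empty_strided using the input's strides up front, so the split/cat round trip now preserves strides. A minimal round-trip sketch in Python, assuming split_float_bfloat16 is exposed under torch.ops.torch_ipex alongside cat_bfloat16_float (the extension import name varies by IPEX version, so both names here are assumptions):

import torch
import intel_extension_for_pytorch  # assumed import; registers the torch_ipex ops

# The previously mishandled case: channels-last with trailing size-1 dims.
x = torch.rand(128, 256, 1, 1).to(memory_format=torch.channels_last)

top, bot = torch.ops.torch_ipex.split_float_bfloat16(x)
y = torch.ops.torch_ipex.cat_bfloat16_float(top, bot)

assert torch.equal(x, y)                         # values survive the fp32 -> 2x bf16 round trip
assert x.stride() == top.stride() == y.stride()  # strides survive too, after this commit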

tests/cpu/test_ipex_optimize.py (5 additions, 0 deletions)

@@ -92,6 +92,8 @@ def _test_tensor_convert(self, tensor, bf16_tensor):
         # recovery float tensor with top half and bottom half
         float_tensor = torch.ops.torch_ipex.cat_bfloat16_float(top_half, bot_half)
         self.assertEqual(tensor, float_tensor)
+        self.assertEqual(tensor.stride(), top_half.stride())
+        self.assertEqual(tensor.stride(), float_tensor.stride())
 
     def test_tensor_convert(self):
         # contiguous case
@@ -101,5 +103,8 @@ def test_tensor_convert(self):
         self._test_tensor_convert(tensor.t(), tensor.bfloat16().t())
         # sliced-out case
         self._test_tensor_convert(tensor[2:5, 2:5], tensor.bfloat16()[2:5, 2:5])
+        # nc11 channel-last case
+        tensor = torch.rand(128, 256, 1, 1).to(memory_format=torch.channels_last)
+        self._test_tensor_convert(tensor, tensor.bfloat16())
 if __name__ == '__main__':
     test = unittest.main()
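
Why the new nc11 channels-last case is worth testing: for a (128, 256, 1, 1) tensor in torch.channels_last, the strides are (256, 1, 256, 256) rather than the default (256, 1, 1, 1), yet is_contiguous() still reports True because size-1 dims are ignored by the contiguity check. The old code keyed its re-striding on is_contiguous() and therefore dropped the channels-last strides for exactly this shape. A quick check in plain PyTorch (no IPEX needed):

import torch

t = torch.rand(128, 256, 1, 1).to(memory_format=torch.channels_last)
print(t.is_contiguous())  # True: strides of size-1 dims don't affect the check
print(t.stride())         # (256, 1, 256, 256), not the default (256, 1, 1, 1)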

torch_ipex/csrc/cpu/bf16/Converter.cpp (27 additions, 30 deletions)

@@ -32,19 +32,20 @@ at::Tensor cat_bfloat16_float(const at::Tensor top_half_,
                 "pack_bfloat16_float: expect both args to be at::BFloat16");
   at::Tensor top_half = top_half_.contiguous();
   at::Tensor bottom_half = bottom_half_.contiguous();
-  at::Tensor output =
-      at::empty(top_half.sizes(), top_half.options().dtype(at::kFloat));
+  at::Tensor output = at::empty_strided(top_half_.sizes(), top_half_.strides(),
+                                        top_half_.options().dtype(at::kFloat));
+  at::Tensor output_contiguous = output.contiguous();
   using bVec = at::vec::Vectorized<at::BFloat16>;
   using fVec = at::vec::Vectorized<float>;
-  at::BFloat16* top_half_data = top_half.data_ptr<at::BFloat16>();
-  at::BFloat16* bottom_half_data = bottom_half.data_ptr<at::BFloat16>();
-  float* output_data = output.data_ptr<float>();
+  at::BFloat16 *top_half_data = top_half.data_ptr<at::BFloat16>();
+  at::BFloat16 *bottom_half_data = bottom_half.data_ptr<at::BFloat16>();
+  float *output_data = output_contiguous.data_ptr<float>();
   int64_t grain_size = 512;
   at::parallel_for(0, top_half.numel(), grain_size, [&](int64_t begin, int64_t end) {
     // local pointers
-    at::BFloat16* top_half_ptr = top_half_data + begin;
-    at::BFloat16* bottom_half_ptr = bottom_half_data + begin;
-    float* output_ptr = output_data + begin;
+    at::BFloat16 *top_half_ptr = top_half_data + begin;
+    at::BFloat16 *bottom_half_ptr = bottom_half_data + begin;
+    float *output_ptr = output_data + begin;
     const int64_t size = end - begin;
     int64_t d = 0;
     for (; d < size - (size % bVec::size()); d += bVec::size()) {
@@ -59,10 +60,8 @@ at::Tensor cat_bfloat16_float(const at::Tensor top_half_,
       output_ptr[d] = bf16::pack_bfloat16_float(top_half_ptr[d], bottom_half_ptr[d]);
     }
   });
-  if (!top_half_.is_contiguous()) {
-    output = at::empty_strided(top_half_.sizes(), top_half_.strides(),
-                               top_half_.options().dtype(at::kFloat))
-                 .copy_(output);
+  if (!output.is_contiguous()) {
+    output.copy_(output_contiguous);
   }
   return output;
 }
@@ -73,22 +72,24 @@ split_float_bfloat16(const at::Tensor tensor_) {
                 "pack_bfloat16_float: expect both tensor to be at::kFloat");
 
   auto tensor = tensor_.contiguous();
-  at::Tensor top_half =
-      at::empty(tensor.sizes(), tensor.options().dtype(at::kBFloat16));
-  at::Tensor bottom_half =
-      at::empty(tensor.sizes(), tensor.options().dtype(at::kBFloat16));
-
+  auto top_half = at::empty_strided(tensor_.sizes(), tensor_.strides(),
+                                    tensor_.options().dtype(at::kBFloat16));
+  auto top_half_contiguous = top_half.contiguous();
+  auto bottom_half = at::empty_strided(tensor_.sizes(), tensor_.strides(),
+                                       tensor_.options().dtype(at::kBFloat16));
+  auto bottom_half_contiguous = bottom_half.contiguous();
   using bVec = at::vec::Vectorized<at::BFloat16>;
   using fVec = at::vec::Vectorized<float>;
-  at::BFloat16* top_half_data = top_half.data_ptr<at::BFloat16>();
-  at::BFloat16* bottom_half_data = bottom_half.data_ptr<at::BFloat16>();
-  float* tensor_data = tensor.data_ptr<float>();
+  at::BFloat16 *top_half_data = top_half_contiguous.data_ptr<at::BFloat16>();
+  at::BFloat16 *bottom_half_data =
+      bottom_half_contiguous.data_ptr<at::BFloat16>();
+  float *tensor_data = tensor.data_ptr<float>();
   int64_t grain_size = 512;
   at::parallel_for(0, top_half.numel(), grain_size, [&](int64_t begin, int64_t end) {
     // local pointers
-    at::BFloat16* top_half_ptr = top_half_data + begin;
-    at::BFloat16* bottom_half_ptr = bottom_half_data + begin;
-    float* tensor_ptr = tensor_data + begin;
+    at::BFloat16 *top_half_ptr = top_half_data + begin;
+    at::BFloat16 *bottom_half_ptr = bottom_half_data + begin;
+    float *tensor_ptr = tensor_data + begin;
     const int64_t size = end - begin;
     int64_t d = 0;
     for (; d < size - (size % bVec::size()); d += bVec::size()) {
@@ -107,13 +108,9 @@ split_float_bfloat16(const at::Tensor tensor_) {
      bottom_half_ptr[d] = bottom_half_val;
     }
   });
-  if (!tensor_.is_contiguous()) {
-    top_half = at::empty_strided(tensor_.sizes(), tensor_.strides(),
-                                 tensor_.options().dtype(at::kBFloat16))
-                   .copy_(top_half);
-    bottom_half = at::empty_strided(tensor_.sizes(), tensor_.strides(),
-                                    tensor_.options().dtype(at::kBFloat16))
-                      .copy_(bottom_half);
+  if (!top_half.is_contiguous()) {
+    top_half.copy_(top_half_contiguous);
+    bottom_half.copy_(bottom_half_contiguous);
   }
   return std::tie(top_half, bottom_half);
 }
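
The C++ change follows one pattern in both functions: allocate the result with at::empty_strided matching the input's sizes and strides, run the vectorized kernel over a contiguous view of that buffer, and copy back only when the strided result is not itself contiguous. For the nc11 channels-last case, the result counts as contiguous despite its non-default strides, so contiguous() aliases the result buffer, the kernel writes it directly, and the copy-back is skipped while the strides are still preserved. A Python analogue of the pattern (illustrative sketch only; the helper name and the stand-in elementwise op are hypothetical):

import torch

def compute_with_input_strides(src: torch.Tensor) -> torch.Tensor:
    # 1. Allocate the result with the input's sizes *and* strides.
    out = torch.empty_strided(src.size(), src.stride(), dtype=src.dtype)
    # 2. Compute into a contiguous buffer; if `out` is already contiguous,
    #    contiguous() aliases it and the "kernel" writes `out` directly.
    out_contig = out.contiguous()
    out_contig.copy_(src.contiguous() * 2)  # stand-in for the vectorized kernel
    # 3. Copy back only when the strided result is a distinct buffer.
    if not out.is_contiguous():
        out.copy_(out_contig)
    return out

x = torch.rand(128, 256, 1, 1).to(memory_format=torch.channels_last)
y = compute_with_input_strides(x)
assert y.stride() == x.stride() and torch.equal(y, x * 2)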
