@@ -32,19 +32,20 @@ at::Tensor cat_bfloat16_float(const at::Tensor top_half_,
3232 " pack_bfloat16_float: expect both args to be at::BFloat16" );
3333 at::Tensor top_half = top_half_.contiguous ();
3434 at::Tensor bottom_half = bottom_half_.contiguous ();
35- at::Tensor output =
36- at::empty (top_half.sizes (), top_half.options ().dtype (at::kFloat ));
35+ at::Tensor output = at::empty_strided (top_half_.sizes (), top_half_.strides (),
36+ top_half_.options ().dtype (at::kFloat ));
37+ at::Tensor output_contiguous = output.contiguous ();
3738 using bVec = at::vec::Vectorized<at::BFloat16>;
3839 using fVec = at::vec::Vectorized<float >;
39- at::BFloat16* top_half_data = top_half.data_ptr <at::BFloat16>();
40- at::BFloat16* bottom_half_data = bottom_half.data_ptr <at::BFloat16>();
41- float * output_data = output .data_ptr <float >();
40+ at::BFloat16 * top_half_data = top_half.data_ptr <at::BFloat16>();
41+ at::BFloat16 * bottom_half_data = bottom_half.data_ptr <at::BFloat16>();
42+ float * output_data = output_contiguous .data_ptr <float >();
4243 int64_t grain_size = 512 ;
4344 at::parallel_for (0 , top_half.numel (), grain_size, [&](int64_t begin, int64_t end) {
4445 // local pointers
45- at::BFloat16* top_half_ptr = top_half_data + begin;
46- at::BFloat16* bottom_half_ptr = bottom_half_data + begin;
47- float * output_ptr = output_data + begin;
46+ at::BFloat16 * top_half_ptr = top_half_data + begin;
47+ at::BFloat16 * bottom_half_ptr = bottom_half_data + begin;
48+ float * output_ptr = output_data + begin;
4849 const int64_t size = end - begin;
4950 int64_t d = 0 ;
5051 for (; d < size - (size % bVec::size ()); d += bVec::size ()) {
@@ -59,10 +60,8 @@ at::Tensor cat_bfloat16_float(const at::Tensor top_half_,
5960 output_ptr[d] = bf16::pack_bfloat16_float (top_half_ptr[d], bottom_half_ptr[d]);
6061 }
6162 });
62- if (!top_half_.is_contiguous ()) {
63- output = at::empty_strided (top_half_.sizes (), top_half_.strides (),
64- top_half_.options ().dtype (at::kFloat ))
65- .copy_ (output);
63+ if (!output.is_contiguous ()) {
64+ output.copy_ (output_contiguous);
6665 }
6766 return output;
6867}
@@ -73,22 +72,24 @@ split_float_bfloat16(const at::Tensor tensor_) {
7372 " pack_bfloat16_float: expect both tensor to be at::kFloat" );
7473
7574 auto tensor = tensor_.contiguous ();
76- at::Tensor top_half =
77- at::empty (tensor.sizes (), tensor.options ().dtype (at::kBFloat16 ));
78- at::Tensor bottom_half =
79- at::empty (tensor.sizes (), tensor.options ().dtype (at::kBFloat16 ));
80-
75+ auto top_half = at::empty_strided (tensor_.sizes (), tensor_.strides (),
76+ tensor_.options ().dtype (at::kBFloat16 ));
77+ auto top_half_contiguous = top_half.contiguous ();
78+ auto bottom_half = at::empty_strided (tensor_.sizes (), tensor_.strides (),
79+ tensor_.options ().dtype (at::kBFloat16 ));
80+ auto bottom_half_contiguous = bottom_half.contiguous ();
8181 using bVec = at::vec::Vectorized<at::BFloat16>;
8282 using fVec = at::vec::Vectorized<float >;
83- at::BFloat16* top_half_data = top_half.data_ptr <at::BFloat16>();
84- at::BFloat16* bottom_half_data = bottom_half.data_ptr <at::BFloat16>();
85- float * tensor_data = tensor.data_ptr <float >();
83+ at::BFloat16 *top_half_data = top_half_contiguous.data_ptr <at::BFloat16>();
84+ at::BFloat16 *bottom_half_data =
85+ bottom_half_contiguous.data_ptr <at::BFloat16>();
86+ float *tensor_data = tensor.data_ptr <float >();
8687 int64_t grain_size = 512 ;
8788 at::parallel_for (0 , top_half.numel (), grain_size, [&](int64_t begin, int64_t end) {
8889 // local pointers
89- at::BFloat16* top_half_ptr = top_half_data + begin;
90- at::BFloat16* bottom_half_ptr = bottom_half_data + begin;
91- float * tensor_ptr = tensor_data + begin;
90+ at::BFloat16 * top_half_ptr = top_half_data + begin;
91+ at::BFloat16 * bottom_half_ptr = bottom_half_data + begin;
92+ float * tensor_ptr = tensor_data + begin;
9293 const int64_t size = end - begin;
9394 int64_t d = 0 ;
9495 for (; d < size - (size % bVec::size ()); d += bVec::size ()) {
@@ -107,13 +108,9 @@ split_float_bfloat16(const at::Tensor tensor_) {
107108 bottom_half_ptr[d] = bottom_half_val;
108109 }
109110 });
110- if (!tensor_.is_contiguous ()) {
111- top_half = at::empty_strided (tensor_.sizes (), tensor_.strides (),
112- tensor_.options ().dtype (at::kBFloat16 ))
113- .copy_ (top_half);
114- bottom_half = at::empty_strided (tensor_.sizes (), tensor_.strides (),
115- tensor_.options ().dtype (at::kBFloat16 ))
116- .copy_ (bottom_half);
111+ if (!top_half.is_contiguous ()) {
112+ top_half.copy_ (top_half_contiguous);
113+ bottom_half.copy_ (bottom_half_contiguous);
117114 }
118115 return std::tie (top_half, bottom_half);
119116}
0 commit comments