11#include < tensor-array/core/tensor.hh>
2+ #include < tensor-array/core/data_type_wrapper.hh>
23#include < pybind11/pybind11.h>
34#include < pybind11/numpy.h>
45#include < pybind11/operators.h>
56
67using namespace tensor_array ::value;
8+ using namespace tensor_array ::datatype;
79
810template <typename T>
911TensorBase convert_numpy_to_tensor_base (pybind11::array_t <T> py_buf)
@@ -20,12 +22,22 @@ TensorBase convert_numpy_to_tensor_base(pybind11::array_t<T> py_buf)
2022 return static_cast <unsigned int >(dim);
2123 }
2224 );
25+ warp_type (warp_type (typeid (T)));
2326 return TensorBase (typeid (T), shape_vec, info.ptr );
2427}
2528
29+ pybind11::dtype get_py_type (const std::type_info& info)
30+ {
31+ if (info == typeid (bool ))
32+ return pybind11::dtype::of<bool >();
33+ if (info == typeid (float ))
34+ return pybind11::dtype::of<float >();
35+ throw std::exception ();
36+ }
37+
2638pybind11::array convert_tensor_to_numpy (const Tensor& tensor)
2739{
28- const TensorBase& base_tensor = tensor.get_buffer ();
40+ const TensorBase& base_tensor = tensor.get_buffer (). change_device ({tensor_array::devices::CPU, 0 }) ;
2941 std::vector<pybind11::size_t > shape_vec (base_tensor.shape ().size ());
3042 std::transform
3143 (
@@ -37,8 +49,9 @@ pybind11::array convert_tensor_to_numpy(const Tensor& tensor)
3749 return static_cast <pybind11::size_t >(dim);
3850 }
3951 );
40- pybind11::array arr = pybind11::array ();
41- return arr;
52+ auto ty0 = pybind11::detail::get_type_info (base_tensor.type ());
53+ pybind11::dtype ty1 = get_py_type (base_tensor.type ());
54+ return pybind11::array (ty1, shape_vec, base_tensor.data ());
4255}
4356
4457Tensor python_tuple_slice (const Tensor& t, pybind11::tuple tuple_slice)
@@ -107,15 +120,34 @@ std::size_t python_len(const Tensor& t)
107120 return shape_list.size () != 0 ? shape_list.begin ()[0 ]: 1U ;
108121}
109122
110- std::string tensor_to_string (const Tensor& t)
123+ pybind11::str tensor_to_string (const Tensor& t)
111124{
112- std::ostringstream osstream;
113- osstream << t;
114- return osstream.str ();
125+ return pybind11::repr (convert_tensor_to_numpy (t));
126+ }
127+
128+ Tensor tensor_cast_1 (const Tensor& t, DataType dtype)
129+ {
130+ return t.tensor_cast (warp_type (dtype));
115131}
116132
117133PYBIND11_MODULE (tensor2, m)
118134{
135+ pybind11::enum_<DataType>(m, " DataType" )
136+ .value (" BOOL" , BOOL_DTYPE)
137+ .value (" S_INT_8" , S_INT_8)
138+ .value (" S_INT_16" , S_INT_16)
139+ .value (" S_INT_32" , S_INT_32)
140+ .value (" S_INT_64" , S_INT_64)
141+ .value (" FLOAT" , FLOAT_DTYPE)
142+ .value (" DOUBLE" , DOUBLE_DTYPE)
143+ .value (" HALF" , HALF_DTYPE)
144+ .value (" BFLOAT16" , BF16_DTYPE)
145+ .value (" U_INT_8" , U_INT_8)
146+ .value (" U_INT_16" , U_INT_16)
147+ .value (" U_INT_32" , U_INT_32)
148+ .value (" U_INT_64" , U_INT_64)
149+ .export_values ();
150+
119151 pybind11::class_<Tensor>(m, " Tensor" )
120152 .def (pybind11::init ())
121153 .def (pybind11::init (&convert_numpy_to_tensor_base<float >))
@@ -138,11 +170,22 @@ PYBIND11_MODULE(tensor2, m)
138170 .def (hash (pybind11::self))
139171 .def (" transpose" , &Tensor::transpose)
140172 .def (" calc_grad" , &Tensor::calc_grad)
173+ .def (" sin" , &Tensor::sin)
174+ .def (" sin" , &Tensor::sin)
175+ .def (" cos" , &Tensor::cos)
176+ .def (" tan" , &Tensor::tan)
177+ .def (" sinh" , &Tensor::sinh)
178+ .def (" cosh" , &Tensor::cosh)
179+ .def (" tanh" , &Tensor::tanh)
180+ .def (" log" , &Tensor::log)
181+ .def (" clone" , &Tensor::clone)
182+ .def (" cast" , &tensor_cast_1)
141183 .def (" add" , &add)
142184 .def (" multiply" , &multiply)
143185 .def (" divide" , ÷)
144186 .def (" matmul" , &matmul)
145187 .def (" condition" , &condition)
188+ .def (" numpy" , &convert_tensor_to_numpy)
146189 .def (" __getitem__" , &python_index)
147190 .def (" __getitem__" , &python_slice)
148191 .def (" __getitem__" , &python_tuple_slice)
0 commit comments