33#include < pybind11/pybind11.h>
44#include < pybind11/numpy.h>
55#include < pybind11/operators.h>
6+ #include < pybind11/stl.h>
67
78using namespace tensor_array ::value;
89using namespace tensor_array ::datatype;
@@ -35,9 +36,9 @@ pybind11::dtype get_py_type(const std::type_info& info)
3536 throw std::exception ();
3637}
3738
38- pybind11::array convert_tensor_to_numpy (const Tensor& tensor )
39+ pybind11::array convert_tensor_to_numpy (const Tensor& self )
3940{
40- const TensorBase& base_tensor = tensor .get_buffer ().change_device ({tensor_array::devices::CPU, 0 });
41+ const TensorBase& base_tensor = self .get_buffer ().change_device ({tensor_array::devices::CPU, 0 });
4142 std::vector<pybind11::size_t > shape_vec (base_tensor.shape ().size ());
4243 std::transform
4344 (
@@ -54,15 +55,15 @@ pybind11::array convert_tensor_to_numpy(const Tensor& tensor)
5455 return pybind11::array (ty1, shape_vec, base_tensor.data ());
5556}
5657
57- Tensor python_tuple_slice (const Tensor& t , pybind11::tuple tuple_slice)
58+ Tensor python_tuple_slice (const Tensor& self , pybind11::tuple tuple_slice)
5859{
5960 std::vector<Tensor::Slice> t_slices;
6061 for (size_t i = 0 ; i < tuple_slice.size (); i++)
6162 {
6263 ssize_t start, stop, step;
6364 ssize_t length;
6465 pybind11::slice py_slice = tuple_slice[i].cast <pybind11::slice>();
65- if (!py_slice.compute (t .get_buffer ().shape ().begin ()[i], &start, &stop, &step, &length))
66+ if (!py_slice.compute (self .get_buffer ().shape ().begin ()[i], &start, &stop, &step, &length))
6667 throw std::runtime_error (" Invalid slice" );
6768 t_slices.insert
6869 (
@@ -75,17 +76,17 @@ Tensor python_tuple_slice(const Tensor& t, pybind11::tuple tuple_slice)
7576 }
7677 );
7778 }
78- return t [tensor_array::wrapper::initializer_wrapper (t_slices.begin ().operator ->(), t_slices.end ().operator ->())];
79+ return self [tensor_array::wrapper::initializer_wrapper (t_slices.begin ().operator ->(), t_slices.end ().operator ->())];
7980}
8081
81- Tensor python_slice (const Tensor& t , pybind11::slice py_slice)
82+ Tensor python_slice (const Tensor& self , pybind11::slice py_slice)
8283{
8384 std::vector<Tensor::Slice> t_slices;
8485 ssize_t start, stop, step;
8586 ssize_t length;
86- if (!py_slice.compute (t .get_buffer ().shape ().begin ()[0 ], &start, &stop, &step, &length))
87+ if (!py_slice.compute (self .get_buffer ().shape ().begin ()[0 ], &start, &stop, &step, &length))
8788 throw std::runtime_error (" Invalid slice" );
88- return t
89+ return self
8990 [
9091 {
9192 Tensor::Slice
@@ -98,25 +99,43 @@ Tensor python_slice(const Tensor& t, pybind11::slice py_slice)
9899 ];
99100}
100101
101- Tensor python_index (const Tensor& t , unsigned int i)
102+ Tensor python_index (const Tensor& self , unsigned int i)
102103{
103- return t [i];
104+ return self [i];
104105}
105106
106- std::size_t python_len (const Tensor& t )
107+ std::size_t python_len (const Tensor& self )
107108{
108- std::initializer_list<unsigned int > shape_list = t .get_buffer ().shape ();
109+ std::initializer_list<unsigned int > shape_list = self .get_buffer ().shape ();
109110 return shape_list.size () != 0 ? shape_list.begin ()[0 ]: 1U ;
110111}
111112
112- pybind11::str tensor_to_string (const Tensor& t )
113+ pybind11::str tensor_to_string (const Tensor& self )
113114{
114- return pybind11::repr (convert_tensor_to_numpy (t ));
115+ return pybind11::repr (convert_tensor_to_numpy (self ));
115116}
116117
117- Tensor tensor_cast_1 (const Tensor& t , DataType dtype)
118+ Tensor tensor_cast_1 (const Tensor& self , DataType dtype)
118119{
119- return t.tensor_cast (warp_type (dtype));
120+ return self.tensor_cast (warp_type (dtype));
121+ }
122+
123+ pybind11::tuple tensor_shape (const Tensor& self)
124+ {
125+ return pybind11::cast (std::vector (self.get_buffer ().shape ()));
126+ }
127+
128+ Tensor tensor_copying (const Tensor& self)
129+ {
130+ return self;
131+ }
132+
133+ Tensor py_zeros (pybind11::tuple shape_tuple, DataType dtype)
134+ {
135+ std::vector<unsigned int > shape_vec;
136+ for (auto & it: shape_tuple)
137+ shape_vec.push_back (it.cast <unsigned int >());
138+ return TensorBase (warp_type (dtype), shape_vec);
120139}
121140
122141PYBIND11_MODULE (tensor2, m)
@@ -136,9 +155,18 @@ PYBIND11_MODULE(tensor2, m)
136155 .value (" U_INT_32" , U_INT_32)
137156 .value (" U_INT_64" , U_INT_64)
138157 .export_values ();
158+
159+ m.def
160+ (
161+ " zeros" ,
162+ &py_zeros,
163+ pybind11::arg (" shape" ),
164+ pybind11::arg (" dtype" ) = S_INT_32
165+ );
139166
140167 pybind11::class_<Tensor>(m, " Tensor" )
141168 .def (pybind11::init ())
169+ .def (pybind11::init (&tensor_copying))
142170 .def (pybind11::init (&convert_numpy_to_tensor_base<float >))
143171 .def (pybind11::self + pybind11::self)
144172 .def (pybind11::self - pybind11::self)
@@ -176,11 +204,13 @@ PYBIND11_MODULE(tensor2, m)
176204 .def (" matmul" , &matmul)
177205 .def (" condition" , &condition)
178206 .def (" numpy" , &convert_tensor_to_numpy)
207+ .def (" shape" , &tensor_shape)
179208 .def (" __getitem__" , &python_index)
180209 .def (" __getitem__" , &python_slice)
181210 .def (" __getitem__" , &python_tuple_slice)
182211 .def (" __len__" , &python_len)
183212 .def (" __matmul__" , &matmul)
184213 .def (" __rmatmul__" , &matmul)
185- .def (" __repr__" , &tensor_to_string);
214+ .def (" __repr__" , &tensor_to_string)
215+ .def (" __copy__" , &tensor_copying);
186216}
0 commit comments