@@ -69,24 +69,32 @@ class PyBatchIterator {
// Resets the wrapped batch iterator by calling VecSimBatchIterator_Reset on the
// raw pointer held by `batchIterator`.
6969 void reset () { VecSimBatchIterator_Reset (batchIterator.get ()); }
// Virtual destructor with an empty body; it performs no explicit cleanup itself.
7070 virtual ~PyBatchIterator () {}
7171};
72+ // The @input and @query arguments are py::object instances (numpy arrays are accepted).
73+
74+ // To convert an input or query object to a raw pointer, use input_to_blob(input).
75+ // For example:
76+ // VecSimIndex_AddVector(index, input_to_blob(input), id);
7277
7378class PyVecSimIndex {
7479public:
// Constructors of PyVecSimIndex, shown as a diff: lines marked '-' are the previous
// versions and lines marked '+' are their replacements. The replacements additionally
// bind `create_bytearray` to the `create_bytearray` attribute of the Python module
// src.python_bindings.Mybytearray, which input_to_blob() calls on every input.
75- PyVecSimIndex () {}
76-
77- PyVecSimIndex (const VecSimParams &params) { index = VecSimIndex_New (&params); }
// Default constructor: only binds create_bytearray; `index` is not assigned here.
// NOTE(review): presumably a derived class assigns `index` afterwards - confirm, since
// the destructor passes `index` to VecSimIndex_Free.
80+ PyVecSimIndex ()
81+ : create_bytearray(
82+ py::module::import ("src.python_bindings.Mybytearray" ).attr("create_bytearray" )) {}
83+
// Constructs the wrapper and creates the underlying index from `params` with
// VecSimIndex_New.
84+ PyVecSimIndex (const VecSimParams &params)
85+ : create_bytearray(
86+ py::module::import ("src.python_bindings.Mybytearray" ).attr("create_bytearray" )) {
87+ index = VecSimIndex_New (&params);
88+ }
7889
// Adds the vector carried by `input` to the index under label `id` by calling
// VecSimIndex_AddVector.
7990 void addVector (py::object input, size_t id) {
// Removed ('-') path: converted `input` to a py::array_t<float> with the
// c_style | forcecast flags and passed its data pointer.
80- py::array_t <float , py::array::c_style | py::array::forcecast> items (input);
81- VecSimIndex_AddVector (index, (void *)items.data (0 ), id);
// Added ('+') path: passes the pointer returned by input_to_blob(input) instead.
91+ VecSimIndex_AddVector (index, input_to_blob (input), id);
8292 }
83-
// Removes the entry with label `id` from the index by calling VecSimIndex_DeleteVector.
8493 void deleteVector (size_t id) { VecSimIndex_DeleteVector (index, id); }
8594
8695 py::object knn (py::object input, size_t k, VecSimQueryParams *query_params) {
87- py::array_t <float , py::array::c_style | py::array::forcecast> items (input);
8896 VecSimQueryResult_List res =
89- VecSimIndex_TopKQuery (index, ( void *)items. data ( 0 ), k, query_params, BY_SCORE);
97+ VecSimIndex_TopKQuery (index, input_to_blob (input ), k, query_params, BY_SCORE);
9098 if (VecSimQueryResult_Len (res) != k) {
9199 throw std::runtime_error (" Cannot return the results in a contiguous 2D array. Probably "
92100 " ef or M is too small" );
@@ -95,27 +103,32 @@ class PyVecSimIndex {
95103 }
96104
// Runs VecSimIndex_RangeQuery for `input` with the given `radius` and `query_params`,
// requesting BY_SCORE ordering, and returns the value of wrap_results for the result
// list and its length (VecSimQueryResult_Len).
97105 py::object range (py::object input, double radius, VecSimQueryParams *query_params) {
// Removed ('-') line: the query used to be converted to a py::array_t<float> first.
98- py::array_t <float , py::array::c_style | py::array::forcecast> items (input);
99106 VecSimQueryResult_List res =
100- VecSimIndex_RangeQuery (index, ( void *)items. data ( 0 ), radius, query_params, BY_SCORE);
// Added ('+') line: the query pointer now comes from input_to_blob(input).
107+ VecSimIndex_RangeQuery (index, input_to_blob (input ), radius, query_params, BY_SCORE);
101108 return wrap_results (res, VecSimQueryResult_Len (res));
102109 }
103110
// Returns the value reported by VecSimIndex_IndexSize for the wrapped index.
104111 size_t indexSize () { return VecSimIndex_IndexSize (index); }
105112
// Creates a PyBatchIterator wrapping the iterator returned by VecSimBatchIterator_New
// for the given query and `query_params`.
// Removed ('-') version: took the query as `py::object &query_blob`, converted it to a
// py::array_t<float>, and passed the resulting float pointer.
106- PyBatchIterator createBatchIterator (py::object &query_blob, VecSimQueryParams *query_params) {
107- py::array_t <float , py::array::c_style | py::array::forcecast> items (query_blob);
108- float *vector_data = (float *)items.data (0 );
109- return PyBatchIterator (VecSimBatchIterator_New (index, vector_data, query_params));
// Added ('+') version: takes the query by value as `input` and passes the pointer
// returned by input_to_blob(input).
// NOTE(review): that pointer refers to the buffer of the `tmp_bytearray` member, which
// the next input_to_blob call reassigns - confirm VecSimBatchIterator_New copies the
// query rather than retaining the pointer.
113+ PyBatchIterator createBatchIterator (py::object input, VecSimQueryParams *query_params) {
114+ return PyBatchIterator (VecSimBatchIterator_New (index, input_to_blob (input), query_params));
110115 }
111116
// Virtual destructor: passes `index` to VecSimIndex_Free.
// NOTE(review): the default constructor does not assign `index`; presumably a derived
// class always sets it before destruction - confirm.
112117 virtual ~PyVecSimIndex () { VecSimIndex_Free (index); }
113118
114119protected:
115120 VecSimIndex *index;
121+
122+ private:
123+ // Holds the most recently created bytearray so that the pointer returned by input_to_blob stays valid.
124+ py::bytearray tmp_bytearray;
125+ const py::function create_bytearray;
// Converts `input` to a raw byte pointer: calls the Python callable `create_bytearray`
// on it, stores the result in the `tmp_bytearray` member, and returns that bytearray's
// internal buffer via PyByteArray_AS_STRING.
// Each call reassigns `tmp_bytearray`, replacing the object stored by the previous call.
// NOTE(review): a pointer returned earlier should therefore be treated as invalid after
// the next call - confirm no caller keeps it longer.
126+ const char *input_to_blob (py::object input) {
127+ tmp_bytearray = create_bytearray (input);
128+ return PyByteArray_AS_STRING (tmp_bytearray.ptr ());
129+ }
116130};
117131
118- // Currently supports only floats. TODO change after serializer refactoring
119132class PyHNSWLibIndex : public PyVecSimIndex {
120133public:
121134 PyHNSWLibIndex (const HNSWParams &hnsw_params) {
0 commit comments