Skip to content

Commit 32ebe98

Browse files
authored
Enable parallelization for insert+search in HNSW - MOD-4628 (#316)
1 parent b4d2c47 commit 32ebe98

17 files changed

+857
-430
lines changed

src/VecSim/algorithms/hnsw/hnsw.h

Lines changed: 328 additions & 198 deletions
Large diffs are not rendered by default.

src/VecSim/algorithms/hnsw/hnsw_batch_iterator.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ HNSW_BatchIterator<DataType, DistType>::HNSW_BatchIterator(
7878

7979
this->dist_func = index->getDistFunc();
8080
this->dim = index->getDim();
81-
this->entry_point = index->getEntryPointId();
81+
this->entry_point = HNSW_INVALID_ID; // temporary until we store the entry point to level 0.
8282
// Use "fresh" tag to mark nodes that were visited along the search in some iteration.
8383
this->visited_list = index->getVisitedList();
8484
this->visited_tag = this->visited_list->getFreshTag();
@@ -114,6 +114,7 @@ VecSimQueryResult_Code HNSW_BatchIterator<DataType, DistType>::scanGraphInternal
114114

115115
// Take the current node out of the candidates queue and go over his neighbours.
116116
candidates.pop();
117+
this->index->lockNodeLinks(curr_node_id);
117118
idType *node_links = this->index->get_linklist_at_level(curr_node_id, 0);
118119
linkListSize links_num = this->index->getListCount(node_links);
119120

@@ -137,6 +138,7 @@ VecSimQueryResult_Code HNSW_BatchIterator<DataType, DistType>::scanGraphInternal
137138
candidates.emplace(candidate_dist, candidate_id);
138139
__builtin_prefetch(index->get_linklist_at_level(candidates.top().second, 0));
139140
}
141+
this->index->unlockNodeLinks(curr_node_id);
140142
}
141143
return VecSim_QueryResult_OK;
142144
}

src/VecSim/algorithms/hnsw/hnsw_factory.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ size_t EstimateInitialSize(const HNSWParams *params) {
6363
est += sizeof(size_t) * params->initialCapacity + sizeof(size_t); // element level
6464
est += sizeof(size_t) * params->initialCapacity +
6565
sizeof(size_t); // Labels lookup hash table buckets.
66+
est += sizeof(std::mutex) * params->initialCapacity + sizeof(size_t); // lock per vector
6667
}
6768

6869
// Explicit allocation calls - always allocate a header.
@@ -115,6 +116,7 @@ size_t EstimateElementSize(const HNSWParams *params) {
115116
// lookup hash map.
116117
size_t size_meta_data =
117118
sizeof(tag_t) + sizeof(size_t) + sizeof(size_t) + size_label_lookup_node;
119+
size_t size_lock = sizeof(std::mutex);
118120

119121
/* Disclaimer: we are neglecting two additional factors that consume memory:
120122
* 1. The overall bucket size in labels_lookup hash table is usually higher than the number of
@@ -123,7 +125,7 @@ size_t EstimateElementSize(const HNSWParams *params) {
123125
* 2. The incoming edges that aren't bidirectional are stored in a dynamic array
124126
* (vecsim_stl::vector) Those edges' memory *is omitted completely* from this estimation.
125127
*/
126-
return size_meta_data + size_total_data_per_element;
128+
return size_meta_data + size_total_data_per_element + size_lock;
127129
}
128130

129131
VecSimIndex *NewTieredIndex(const TieredHNSWParams *params,

src/VecSim/algorithms/hnsw/hnsw_multi.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ class HNSWIndex_Multi : public HNSWIndex<DataType, DistType> {
7373
int addVector(const void *vector_data, labelType label, bool overwrite_allowed = true) override;
7474
double getDistanceFrom(labelType label, const void *vector_data) const override;
7575
inline std::vector<idType> markDelete(labelType label) override;
76+
inline bool safeCheckIfLabelExistsInIndex(labelType label,
77+
bool also_done_processing) const override;
7678
};
7779

7880
/**
@@ -188,3 +190,23 @@ std::vector<idType> HNSWIndex_Multi<DataType, DistType>::markDelete(labelType la
188190
label_lookup_.erase(search);
189191
return idsToDelete;
190192
}
193+
194+
template <typename DataType, typename DistType>
195+
inline bool HNSWIndex_Multi<DataType, DistType>::safeCheckIfLabelExistsInIndex(
196+
labelType label, bool also_done_processing) const {
197+
std::unique_lock<std::mutex> index_data_lock(this->index_data_guard_);
198+
auto search_res = label_lookup_.find(label);
199+
bool exists = search_res != label_lookup_.end();
200+
// If we want to make sure that the vector(s) stored under the label were already indexed,
201+
// we go on and check that every associated vector is no longer in process.
202+
if (exists && also_done_processing) {
203+
for (auto id : search_res->second) {
204+
exists = !this->isInProcess(id);
205+
// If we find at least one internal id that is still in process, consider it as not
206+
// ready.
207+
if (!exists)
208+
return false;
209+
}
210+
}
211+
return exists;
212+
}

src/VecSim/algorithms/hnsw/hnsw_serializer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ HNSWIndex<DataType, DistType>::HNSWIndex(std::ifstream &input, const HNSWParams
88
params->blockSize, params->multi),
99
Serializer(version), max_elements_(params->initialCapacity), epsilon_(params->epsilon),
1010
element_levels_(max_elements_, allocator),
11-
visited_nodes_handler_pool(1, max_elements_, allocator) {
11+
visited_nodes_handler_pool(1, max_elements_, allocator),
12+
element_neighbors_locks_(max_elements_, allocator) {
1213

1314
this->restoreIndexFields(input);
1415
this->fieldsValidation();

src/VecSim/algorithms/hnsw/hnsw_single.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ class HNSWIndex_Single : public HNSWIndex<DataType, DistType> {
6464
int addVector(const void *vector_data, labelType label, bool overwrite_allowed = true) override;
6565
double getDistanceFrom(labelType label, const void *vector_data) const override;
6666
inline std::vector<idType> markDelete(labelType label) override;
67+
inline bool safeCheckIfLabelExistsInIndex(labelType label,
68+
bool also_done_processing = false) const override;
6769
};
6870

6971
/**
@@ -121,6 +123,7 @@ int HNSWIndex_Single<DataType, DistType>::addVector(const void *vector_data, con
121123
bool overwrite_allowed) {
122124

123125
// Checking if an element with the given label already exists.
126+
std::unique_lock<std::mutex> index_data_lock(this->index_data_guard_);
124127
bool label_exists = false;
125128
if (label_lookup_.find(label) != label_lookup_.end()) {
126129
label_exists = true;
@@ -132,7 +135,7 @@ int HNSWIndex_Single<DataType, DistType>::addVector(const void *vector_data, con
132135
return -1;
133136
}
134137
}
135-
138+
index_data_lock.unlock();
136139
this->appendVector(vector_data, label);
137140
// Return the delta in the index size due to the insertion.
138141
return label_exists ? 0 : 1;
@@ -159,6 +162,7 @@ HNSWIndex_Single<DataType, DistType>::newBatchIterator(const void *queryBlob,
159162
template <typename DataType, typename DistType>
160163
std::vector<idType> HNSWIndex_Single<DataType, DistType>::markDelete(labelType label) {
161164
std::vector<idType> idsToDelete;
165+
std::unique_lock<std::mutex> index_data_lock(this->index_data_guard_);
162166
auto search = label_lookup_.find(label);
163167
if (search == label_lookup_.end()) {
164168
return idsToDelete;
@@ -168,3 +172,17 @@ std::vector<idType> HNSWIndex_Single<DataType, DistType>::markDelete(labelType l
168172
label_lookup_.erase(search);
169173
return idsToDelete;
170174
}
175+
176+
template <typename DataType, typename DistType>
177+
inline bool HNSWIndex_Single<DataType, DistType>::safeCheckIfLabelExistsInIndex(
178+
labelType label, bool also_done_processing) const {
179+
std::unique_lock<std::mutex> index_data_lock(this->index_data_guard_);
180+
auto it = label_lookup_.find(label);
181+
bool exists = it != label_lookup_.end();
182+
// If we want to make sure that the vector stored under the label was already indexed,
183+
// we go on and check that its associated internal id is no longer in process.
184+
if (exists && also_done_processing) {
185+
return !this->isInProcess(it->second);
186+
}
187+
return exists;
188+
}

src/VecSim/algorithms/hnsw/hnsw_single_tests_friends.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ INDEX_TEST_FRIEND_CLASS(HNSWTest_preferAdHocOptimization_Test)
1010
INDEX_TEST_FRIEND_CLASS(HNSWTest_testSizeEstimation_Test)
1111
INDEX_TEST_FRIEND_CLASS(IndexAllocatorTest_testIncomingEdgesSet_Test)
1212
INDEX_TEST_FRIEND_CLASS(IndexAllocatorTest_test_hnsw_reclaim_memory_Test)
13+
INDEX_TEST_FRIEND_CLASS(HNSWTestParallel_parallelInsertSearch_Test)

src/VecSim/algorithms/hnsw/visited_nodes_handler.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ VisitedNodesHandler::~VisitedNodesHandler() { allocator->free_allocation(element
4141
/**
4242
* VisitedNodesHandlerPool methods to enable parallel graph scans.
4343
*/
44-
VisitedNodesHandlerPool::VisitedNodesHandlerPool(int initial_pool_size, int cap,
44+
VisitedNodesHandlerPool::VisitedNodesHandlerPool(size_t initial_pool_size, int cap,
4545
const std::shared_ptr<VecSimAllocator> &allocator)
4646
: VecsimBaseObject(allocator), pool(initial_pool_size, allocator), num_elements(cap),
4747
total_handlers_in_use(1) {
48-
for (int i = 0; i < initial_pool_size; i++)
48+
for (size_t i = 0; i < initial_pool_size; i++)
4949
pool[i] = new (allocator) VisitedNodesHandler(cap, allocator);
5050
}
5151

src/VecSim/algorithms/hnsw/visited_nodes_handler.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class VisitedNodesHandlerPool : public VecsimBaseObject {
5858
unsigned short total_handlers_in_use;
5959

6060
public:
61-
VisitedNodesHandlerPool(int initial_pool_size, int cap,
61+
VisitedNodesHandlerPool(size_t initial_pool_size, int cap,
6262
const std::shared_ptr<VecSimAllocator> &allocator);
6363

6464
VisitedNodesHandler *getAvailableVisitedNodesHandler();

src/VecSim/utils/vec_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ void normalizeVector(DataType *input_vector, size_t dim) {
8181
double sum = 0;
8282

8383
for (size_t i = 0; i < dim; i++) {
84-
sum += input_vector[i] * input_vector[i];
84+
sum += (double)input_vector[i] * (double)input_vector[i];
8585
}
8686
DataType norm = sqrt(sum);
8787

0 commit comments

Comments
 (0)