Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update hnswlib to v0.8.0 #17

Merged
merged 8 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions 3rd_party/hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,16 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {


void removePoint(labeltype cur_external) {
size_t cur_c = dict_external_to_internal[cur_external];
std::unique_lock<std::mutex> lock(index_lock);

dict_external_to_internal.erase(cur_external);
auto found = dict_external_to_internal.find(cur_external);
if (found == dict_external_to_internal.end()) {
return;
}

dict_external_to_internal.erase(found);

size_t cur_c = found->second;
labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
dict_external_to_internal[label] = cur_c;
memcpy(data_ + size_per_element_ * cur_c,
Expand All @@ -106,7 +112,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
topResults.emplace(dist, label);
}
}
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
Expand All @@ -115,7 +121,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
if (dist <= lastdist) {
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
topResults.emplace(dist, label);
}
if (topResults.size() > k)
topResults.pop();
Expand Down
191 changes: 166 additions & 25 deletions 3rd_party/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <assert.h>
#include <unordered_set>
#include <list>
#include <memory>

namespace hnswlib {
typedef unsigned int tableint;
Expand All @@ -33,7 +34,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
double mult_{0.0}, revSize_{0.0};
int maxlevel_{0};

VisitedListPool *visited_list_pool_{nullptr};
std::unique_ptr<VisitedListPool> visited_list_pool_{nullptr};

// Locks operations with element by label value
mutable std::vector<std::mutex> label_op_locks_;
Expand Down Expand Up @@ -92,16 +93,22 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
size_t ef_construction = 200,
size_t random_seed = 100,
bool allow_replace_deleted = false)
: link_list_locks_(max_elements),
label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
: label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
link_list_locks_(max_elements),
element_levels_(max_elements),
allow_replace_deleted_(allow_replace_deleted) {
max_elements_ = max_elements;
num_deleted_ = 0;
data_size_ = s->get_data_size();
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();
M_ = M;
if ( M <= 10000 ) {
M_ = M;
} else {
HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl;
HNSWERR << " Cap to 10000 will be applied for the rest of the processing." << std::endl;
M_ = 10000;
}
maxM_ = M_;
maxM0_ = M_ * 2;
ef_construction_ = std::max(ef_construction, M_);
Expand All @@ -122,7 +129,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

cur_element_count = 0;

visited_list_pool_ = new VisitedListPool(1, max_elements);
visited_list_pool_ = std::unique_ptr<VisitedListPool>(new VisitedListPool(1, max_elements));

// initializations for special treatment of the first node
enterpoint_node_ = -1;
Expand All @@ -138,13 +145,20 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {


~HierarchicalNSW() {
clear();
}

void clear() {
free(data_level0_memory_);
data_level0_memory_ = nullptr;
for (tableint i = 0; i < cur_element_count; i++) {
if (element_levels_[i] > 0)
free(linkLists_[i]);
}
free(linkLists_);
delete visited_list_pool_;
linkLists_ = nullptr;
cur_element_count = 0;
visited_list_pool_.reset(nullptr);
}


Expand Down Expand Up @@ -291,9 +305,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}


template <bool has_deletions, bool collect_metrics = false>
// bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance
template <bool bare_bone_search = true, bool collect_metrics = false>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
searchBaseLayerST(
tableint ep_id,
const void *data_point,
size_t ef,
BaseFilterFunctor* isIdAllowed = nullptr,
BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
Expand All @@ -302,10 +322,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;

dist_t lowerBound;
if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
if (bare_bone_search ||
(!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) {
char* ep_data = getDataByInternalId(ep_id);
dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_);
lowerBound = dist;
top_candidates.emplace(dist, ep_id);
if (!bare_bone_search && stop_condition) {
stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist);
}
candidate_set.emplace(-dist, ep_id);
} else {
lowerBound = std::numeric_limits<dist_t>::max();
Expand All @@ -316,9 +341,19 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

while (!candidate_set.empty()) {
std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
dist_t candidate_dist = -current_node_pair.first;

if ((-current_node_pair.first) > lowerBound &&
(top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) {
bool flag_stop_search;
if (bare_bone_search) {
flag_stop_search = candidate_dist > lowerBound;
} else {
if (stop_condition) {
flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound);
} else {
flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef;
}
}
if (flag_stop_search) {
break;
}
candidate_set.pop();
Expand Down Expand Up @@ -353,19 +388,45 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
char *currObj1 = (getDataByInternalId(candidate_id));
dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);

if (top_candidates.size() < ef || lowerBound > dist) {
bool flag_consider_candidate;
if (!bare_bone_search && stop_condition) {
flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound);
} else {
flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist;
}

if (flag_consider_candidate) {
candidate_set.emplace(-dist, candidate_id);
#ifdef USE_SSE
_mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
offsetLevel0_, ///////////
_MM_HINT_T0); ////////////////////////
#endif

if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
if (bare_bone_search ||
(!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) {
top_candidates.emplace(dist, candidate_id);
if (!bare_bone_search && stop_condition) {
stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist);
}
}

if (top_candidates.size() > ef)
bool flag_remove_extra = false;
if (!bare_bone_search && stop_condition) {
flag_remove_extra = stop_condition->should_remove_extra();
} else {
flag_remove_extra = top_candidates.size() > ef;
}
while (flag_remove_extra) {
tableint id = top_candidates.top().second;
top_candidates.pop();
if (!bare_bone_search && stop_condition) {
stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist);
flag_remove_extra = stop_condition->should_remove_extra();
} else {
flag_remove_extra = top_candidates.size() > ef;
}
}

if (!top_candidates.empty())
lowerBound = top_candidates.top().first;
Expand All @@ -380,8 +441,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {


void getNeighborsByHeuristic2(
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
const size_t M) {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
const size_t M) {
if (top_candidates.size() < M) {
return;
}
Expand Down Expand Up @@ -573,8 +634,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
if (new_max_elements < cur_element_count)
throw std::runtime_error("Cannot resize, max element is less than the current number of elements");

delete visited_list_pool_;
visited_list_pool_ = new VisitedListPool(1, new_max_elements);
visited_list_pool_.reset(new VisitedListPool(1, new_max_elements));

element_levels_.resize(new_max_elements);

Expand All @@ -595,6 +655,32 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
max_elements_ = new_max_elements;
}

size_t indexFileSize() const {
size_t size = 0;
size += sizeof(offsetLevel0_);
size += sizeof(max_elements_);
size += sizeof(cur_element_count);
size += sizeof(size_data_per_element_);
size += sizeof(label_offset_);
size += sizeof(offsetData_);
size += sizeof(maxlevel_);
size += sizeof(enterpoint_node_);
size += sizeof(maxM_);

size += sizeof(maxM0_);
size += sizeof(M_);
size += sizeof(mult_);
size += sizeof(ef_construction_);

size += cur_element_count * size_data_per_element_;

for (size_t i = 0; i < cur_element_count; i++) {
unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
size += sizeof(linkListSize);
size += linkListSize;
}
return size;
}

void saveIndex(const std::string &location) {
std::ofstream output(location, std::ios::binary);
Expand Down Expand Up @@ -633,6 +719,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
if (!input.is_open())
throw std::runtime_error("Cannot open file");

clear();
// get file size:
input.seekg(0, input.end);
std::streampos total_filesize = input.tellg();
Expand Down Expand Up @@ -698,7 +785,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::vector<std::mutex>(max_elements).swap(link_list_locks_);
std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);

visited_list_pool_ = new VisitedListPool(1, max_elements);
visited_list_pool_.reset(new VisitedListPool(1, max_elements));

linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
if (linkLists_ == nullptr)
Expand Down Expand Up @@ -752,7 +839,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
size_t dim = *((size_t *) dist_func_param_);
std::vector<data_t> data;
data_t* data_ptr = (data_t*) data_ptrv;
for (int i = 0; i < dim; i++) {
for (size_t i = 0; i < dim; i++) {
data.push_back(*data_ptr);
data_ptr += 1;
}
Expand Down Expand Up @@ -1216,11 +1303,12 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
if (num_deleted_) {
top_candidates = searchBaseLayerST<true, true>(
bool bare_bone_search = !num_deleted_ && !isIdAllowed;
if (bare_bone_search) {
top_candidates = searchBaseLayerST<true>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
} else {
top_candidates = searchBaseLayerST<false, true>(
top_candidates = searchBaseLayerST<false>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
}

Expand All @@ -1236,6 +1324,60 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}


std::vector<std::pair<dist_t, labeltype >>
searchStopConditionClosest(
const void *query_data,
BaseSearchStopCondition<dist_t>& stop_condition,
BaseFilterFunctor* isIdAllowed = nullptr) const {
std::vector<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;

tableint currObj = enterpoint_node_;
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);

for (int level = maxlevel_; level > 0; level--) {
bool changed = true;
while (changed) {
changed = false;
unsigned int *data;

data = (unsigned int *) get_linklist(currObj, level);
int size = getListCount(data);
metric_hops++;
metric_distance_computations+=size;

tableint *datal = (tableint *) (data + 1);
for (int i = 0; i < size; i++) {
tableint cand = datal[i];
if (cand < 0 || cand > max_elements_)
throw std::runtime_error("cand error");
dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);

if (d < curdist) {
curdist = d;
currObj = cand;
changed = true;
}
}
}
}

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
top_candidates = searchBaseLayerST<false>(currObj, query_data, 0, isIdAllowed, &stop_condition);

size_t sz = top_candidates.size();
result.resize(sz);
while (!top_candidates.empty()) {
result[--sz] = top_candidates.top();
top_candidates.pop();
}

stop_condition.filter_results(result);

return result;
}


void checkIntegrity() {
int connections_checked = 0;
std::vector <int > inbound_connections_num(cur_element_count, 0);
Expand All @@ -1246,7 +1388,6 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
tableint *data = (tableint *) (ll_cur + 1);
std::unordered_set<tableint> s;
for (int j = 0; j < size; j++) {
assert(data[j] > 0);
assert(data[j] < cur_element_count);
assert(data[j] != i);
inbound_connections_num[data[j]]++;
Expand Down
Loading
Loading