diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c76774c..e69c3e1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,8 +42,8 @@ jobs: - uses: erlef/setup-beam@v1 with: - otp-version: 25 - elixir-version: 1.14 + otp-version: "26.2.1" + elixir-version: "1.16.0" - uses: ilammy/msvc-dev-cmd@v1 with: diff --git a/3rd_party/hnswlib/bruteforce.h b/3rd_party/hnswlib/bruteforce.h index 30b33ae..e571cca 100644 --- a/3rd_party/hnswlib/bruteforce.h +++ b/3rd_party/hnswlib/bruteforce.h @@ -84,9 +84,15 @@ class BruteforceSearch : public AlgorithmInterface { void removePoint(labeltype cur_external) { - size_t cur_c = dict_external_to_internal[cur_external]; + std::unique_lock lock(index_lock); - dict_external_to_internal.erase(cur_external); + auto found = dict_external_to_internal.find(cur_external); + if (found == dict_external_to_internal.end()) { + return; + } + + size_t cur_c = found->second; + dict_external_to_internal.erase(found); labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); dict_external_to_internal[label] = cur_c; @@ -106,7 +112,7 @@ class BruteforceSearch : public AlgorithmInterface { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); if ((!isIdAllowed) || (*isIdAllowed)(label)) { - topResults.push(std::pair(dist, label)); + topResults.emplace(dist, label); } } dist_t lastdist = topResults.empty() ? std::numeric_limits::max() : topResults.top().first; @@ -115,7 +121,7 @@ class BruteforceSearch : public AlgorithmInterface { if (dist <= lastdist) { labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); if ((!isIdAllowed) || (*isIdAllowed)(label)) { - topResults.push(std::pair(dist, label)); + topResults.emplace(dist, label); } if (topResults.size() > k) topResults.pop(); diff --git a/3rd_party/hnswlib/hnswalg.h b/3rd_party/hnswlib/hnswalg.h index bef0017..e269ae6 100644 --- a/3rd_party/hnswlib/hnswalg.h +++ b/3rd_party/hnswlib/hnswalg.h @@ -8,6 +8,7 @@ #include #include #include +#include namespace hnswlib { typedef unsigned int tableint; @@ -33,7 +34,7 @@ class HierarchicalNSW : public AlgorithmInterface { double mult_{0.0}, revSize_{0.0}; int maxlevel_{0}; - VisitedListPool *visited_list_pool_{nullptr}; + std::unique_ptr visited_list_pool_{nullptr}; // Locks operations with element by label value mutable std::vector label_op_locks_; @@ -92,8 +93,8 @@ class HierarchicalNSW : public AlgorithmInterface { size_t ef_construction = 200, size_t random_seed = 100, bool allow_replace_deleted = false) - : link_list_locks_(max_elements), - label_op_locks_(MAX_LABEL_OPERATION_LOCKS), + : label_op_locks_(MAX_LABEL_OPERATION_LOCKS), + link_list_locks_(max_elements), element_levels_(max_elements), allow_replace_deleted_(allow_replace_deleted) { max_elements_ = max_elements; @@ -101,7 +102,13 @@ class HierarchicalNSW : public AlgorithmInterface { data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); - M_ = M; + if ( M <= 10000 ) { + M_ = M; + } else { + HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl; + HNSWERR << " Cap to 10000 will be applied for the rest of the processing." << std::endl; + M_ = 10000; + } maxM_ = M_; maxM0_ = M_ * 2; ef_construction_ = std::max(ef_construction, M_); @@ -122,7 +129,7 @@ class HierarchicalNSW : public AlgorithmInterface { cur_element_count = 0; - visited_list_pool_ = new VisitedListPool(1, max_elements); + visited_list_pool_ = std::unique_ptr(new VisitedListPool(1, max_elements)); // initializations for special treatment of the first node enterpoint_node_ = -1; @@ -138,13 +145,20 @@ class HierarchicalNSW : public AlgorithmInterface { ~HierarchicalNSW() { + clear(); + } + + void clear() { free(data_level0_memory_); + data_level0_memory_ = nullptr; for (tableint i = 0; i < cur_element_count; i++) { if (element_levels_[i] > 0) free(linkLists_[i]); } free(linkLists_); - delete visited_list_pool_; + linkLists_ = nullptr; + cur_element_count = 0; + visited_list_pool_.reset(nullptr); } @@ -291,9 +305,15 @@ class HierarchicalNSW : public AlgorithmInterface { } - template + // bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance + template std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { + searchBaseLayerST( + tableint ep_id, + const void *data_point, + size_t ef, + BaseFilterFunctor* isIdAllowed = nullptr, + BaseSearchStopCondition* stop_condition = nullptr) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; @@ -302,10 +322,15 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + if (bare_bone_search || + (!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) { + char* ep_data = getDataByInternalId(ep_id); + dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); + if (!bare_bone_search && stop_condition) { + stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist); + } candidate_set.emplace(-dist, ep_id); } else { lowerBound = std::numeric_limits::max(); @@ -316,9 +341,19 @@ class HierarchicalNSW : public AlgorithmInterface { while (!candidate_set.empty()) { std::pair current_node_pair = candidate_set.top(); + dist_t candidate_dist = -current_node_pair.first; - if ((-current_node_pair.first) > lowerBound && - (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) { + bool flag_stop_search; + if (bare_bone_search) { + flag_stop_search = candidate_dist > lowerBound; + } else { + if (stop_condition) { + flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound); + } else { + flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef; + } + } + if (flag_stop_search) { break; } candidate_set.pop(); @@ -353,7 +388,14 @@ class HierarchicalNSW : public AlgorithmInterface { char *currObj1 = (getDataByInternalId(candidate_id)); dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); - if (top_candidates.size() < ef || lowerBound > dist) { + bool flag_consider_candidate; + if (!bare_bone_search && stop_condition) { + flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound); + } else { + flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist; + } + + if (flag_consider_candidate) { candidate_set.emplace(-dist, candidate_id); #ifdef USE_SSE _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + @@ -361,11 +403,30 @@ class HierarchicalNSW : public AlgorithmInterface { _MM_HINT_T0); //////////////////////// #endif - if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id)))) + if (bare_bone_search || + (!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) { top_candidates.emplace(dist, candidate_id); + if (!bare_bone_search && stop_condition) { + stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist); + } + } - if (top_candidates.size() > ef) + bool flag_remove_extra = false; + if (!bare_bone_search && stop_condition) { + flag_remove_extra = stop_condition->should_remove_extra(); + } else { + flag_remove_extra = top_candidates.size() > ef; + } + while (flag_remove_extra) { + tableint id = top_candidates.top().second; top_candidates.pop(); + if (!bare_bone_search && stop_condition) { + stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist); + flag_remove_extra = stop_condition->should_remove_extra(); + } else { + flag_remove_extra = top_candidates.size() > ef; + } + } if (!top_candidates.empty()) lowerBound = top_candidates.top().first; @@ -380,8 +441,8 @@ class HierarchicalNSW : public AlgorithmInterface { void getNeighborsByHeuristic2( - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - const size_t M) { + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M) { if (top_candidates.size() < M) { return; } @@ -573,8 +634,7 @@ class HierarchicalNSW : public AlgorithmInterface { if (new_max_elements < cur_element_count) throw std::runtime_error("Cannot resize, max element is less than the current number of elements"); - delete visited_list_pool_; - visited_list_pool_ = new VisitedListPool(1, new_max_elements); + visited_list_pool_.reset(new VisitedListPool(1, new_max_elements)); element_levels_.resize(new_max_elements); @@ -595,6 +655,32 @@ class HierarchicalNSW : public AlgorithmInterface { max_elements_ = new_max_elements; } + size_t indexFileSize() const { + size_t size = 0; + size += sizeof(offsetLevel0_); + size += sizeof(max_elements_); + size += sizeof(cur_element_count); + size += sizeof(size_data_per_element_); + size += sizeof(label_offset_); + size += sizeof(offsetData_); + size += sizeof(maxlevel_); + size += sizeof(enterpoint_node_); + size += sizeof(maxM_); + + size += sizeof(maxM0_); + size += sizeof(M_); + size += sizeof(mult_); + size += sizeof(ef_construction_); + + size += cur_element_count * size_data_per_element_; + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + size += sizeof(linkListSize); + size += linkListSize; + } + return size; + } void saveIndex(const std::string &location) { std::ofstream output(location, std::ios::binary); @@ -633,6 +719,7 @@ class HierarchicalNSW : public AlgorithmInterface { if (!input.is_open()) throw std::runtime_error("Cannot open file"); + clear(); // get file size: input.seekg(0, input.end); std::streampos total_filesize = input.tellg(); @@ -698,7 +785,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::vector(max_elements).swap(link_list_locks_); std::vector(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_); - visited_list_pool_ = new VisitedListPool(1, max_elements); + visited_list_pool_.reset(new VisitedListPool(1, max_elements)); linkLists_ = (char **) malloc(sizeof(void *) * max_elements); if (linkLists_ == nullptr) @@ -752,7 +839,7 @@ class HierarchicalNSW : public AlgorithmInterface { size_t dim = *((size_t *) dist_func_param_); std::vector data; data_t* data_ptr = (data_t*) data_ptrv; - for (int i = 0; i < dim; i++) { + for (size_t i = 0; i < dim; i++) { data.push_back(*data_ptr); data_ptr += 1; } @@ -1216,11 +1303,12 @@ class HierarchicalNSW : public AlgorithmInterface { } std::priority_queue, std::vector>, CompareByFirst> top_candidates; - if (num_deleted_) { - top_candidates = searchBaseLayerST( + bool bare_bone_search = !num_deleted_ && !isIdAllowed; + if (bare_bone_search) { + top_candidates = searchBaseLayerST( currObj, query_data, std::max(ef_, k), isIdAllowed); } else { - top_candidates = searchBaseLayerST( + top_candidates = searchBaseLayerST( currObj, query_data, std::max(ef_, k), isIdAllowed); } @@ -1236,6 +1324,60 @@ class HierarchicalNSW : public AlgorithmInterface { } + std::vector> + searchStopConditionClosest( + const void *query_data, + BaseSearchStopCondition& stop_condition, + BaseFilterFunctor* isIdAllowed = nullptr) const { + std::vector> result; + if (cur_element_count == 0) return result; + + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + metric_hops++; + metric_distance_computations+=size; + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + top_candidates = searchBaseLayerST(currObj, query_data, 0, isIdAllowed, &stop_condition); + + size_t sz = top_candidates.size(); + result.resize(sz); + while (!top_candidates.empty()) { + result[--sz] = top_candidates.top(); + top_candidates.pop(); + } + + stop_condition.filter_results(result); + + return result; + } + + void checkIntegrity() { int connections_checked = 0; std::vector inbound_connections_num(cur_element_count, 0); @@ -1246,7 +1388,6 @@ class HierarchicalNSW : public AlgorithmInterface { tableint *data = (tableint *) (ll_cur + 1); std::unordered_set s; for (int j = 0; j < size; j++) { - assert(data[j] > 0); assert(data[j] < cur_element_count); assert(data[j] != i); inbound_connections_num[data[j]]++; diff --git a/3rd_party/hnswlib/hnswlib.h b/3rd_party/hnswlib/hnswlib.h index fb7118f..7ccfbba 100644 --- a/3rd_party/hnswlib/hnswlib.h +++ b/3rd_party/hnswlib/hnswlib.h @@ -1,4 +1,13 @@ #pragma once + +// https://github.com/nmslib/hnswlib/pull/508 +// This allows others to provide their own error stream (e.g. RcppHNSW) +#ifndef HNSWLIB_ERR_OVERRIDE + #define HNSWERR std::cerr +#else + #define HNSWERR HNSWLIB_ERR_OVERRIDE +#endif + #ifndef NO_MANUAL_VECTORIZATION #if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64)) #define USE_SSE @@ -15,7 +24,7 @@ #ifdef _MSC_VER #include #include -void cpuid(int32_t out[4], int32_t eax, int32_t ecx) { +static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) { __cpuidex(out, eax, ecx); } static __int64 xgetbv(unsigned int x) { @@ -119,6 +128,25 @@ typedef size_t labeltype; class BaseFilterFunctor { public: virtual bool operator()(hnswlib::labeltype id) { return true; } + virtual ~BaseFilterFunctor() {}; +}; + +template +class BaseSearchStopCondition { + public: + virtual void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) = 0; + + virtual void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) = 0; + + virtual bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) = 0; + + virtual bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) = 0; + + virtual bool should_remove_extra() = 0; + + virtual void filter_results(std::vector> &candidates) = 0; + + virtual ~BaseSearchStopCondition() {} }; template @@ -195,5 +223,6 @@ AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t #include "space_l2.h" #include "space_ip.h" +#include "stop_condition.h" #include "bruteforce.h" #include "hnswalg.h" diff --git a/3rd_party/hnswlib/space_ip.h b/3rd_party/hnswlib/space_ip.h index 2b1c359..0e6834c 100644 --- a/3rd_party/hnswlib/space_ip.h +++ b/3rd_party/hnswlib/space_ip.h @@ -157,19 +157,44 @@ InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void __m512 sum512 = _mm512_set1_ps(0); - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); - + size_t loop = qty16 / 4; + + while (loop--) { __m512 v1 = _mm512_loadu_ps(pVect1); - pVect1 += 16; __m512 v2 = _mm512_loadu_ps(pVect2); + pVect1 += 16; + pVect2 += 16; + + __m512 v3 = _mm512_loadu_ps(pVect1); + __m512 v4 = _mm512_loadu_ps(pVect2); + pVect1 += 16; + pVect2 += 16; + + __m512 v5 = _mm512_loadu_ps(pVect1); + __m512 v6 = _mm512_loadu_ps(pVect2); + pVect1 += 16; pVect2 += 16; - sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2)); + + __m512 v7 = _mm512_loadu_ps(pVect1); + __m512 v8 = _mm512_loadu_ps(pVect2); + pVect1 += 16; + pVect2 += 16; + + sum512 = _mm512_fmadd_ps(v1, v2, sum512); + sum512 = _mm512_fmadd_ps(v3, v4, sum512); + sum512 = _mm512_fmadd_ps(v5, v6, sum512); + sum512 = _mm512_fmadd_ps(v7, v8, sum512); } - _mm512_store_ps(TmpRes, sum512); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15]; + while (pVect1 < pEnd1) { + __m512 v1 = _mm512_loadu_ps(pVect1); + __m512 v2 = _mm512_loadu_ps(pVect2); + pVect1 += 16; + pVect2 += 16; + sum512 = _mm512_fmadd_ps(v1, v2, sum512); + } + float sum = _mm512_reduce_add_ps(sum512); return sum; } diff --git a/3rd_party/hnswlib/stop_condition.h b/3rd_party/hnswlib/stop_condition.h new file mode 100644 index 0000000..acc80eb --- /dev/null +++ b/3rd_party/hnswlib/stop_condition.h @@ -0,0 +1,276 @@ +#pragma once +#include "space_l2.h" +#include "space_ip.h" +#include +#include + +namespace hnswlib { + +template +class BaseMultiVectorSpace : public SpaceInterface { + public: + virtual DOCIDTYPE get_doc_id(const void *datapoint) = 0; + + virtual void set_doc_id(void *datapoint, DOCIDTYPE doc_id) = 0; +}; + + +template +class MultiVectorL2Space : public BaseMultiVectorSpace { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t vector_size_; + size_t dim_; + + public: + MultiVectorL2Space(size_t dim) { + fstdistfunc_ = L2Sqr; +#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512; + else if (AVXCapable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; + #elif defined(USE_AVX) + if (AVXCapable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; + #endif + + if (dim % 16 == 0) + fstdistfunc_ = L2SqrSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = L2SqrSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = L2SqrSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = L2SqrSIMD4ExtResiduals; +#endif + dim_ = dim; + vector_size_ = dim * sizeof(float); + data_size_ = vector_size_ + sizeof(DOCIDTYPE); + } + + size_t get_data_size() override { + return data_size_; + } + + DISTFUNC get_dist_func() override { + return fstdistfunc_; + } + + void *get_dist_func_param() override { + return &dim_; + } + + DOCIDTYPE get_doc_id(const void *datapoint) override { + return *(DOCIDTYPE *)((char *)datapoint + vector_size_); + } + + void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override { + *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id; + } + + ~MultiVectorL2Space() {} +}; + + +template +class MultiVectorInnerProductSpace : public BaseMultiVectorSpace { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t vector_size_; + size_t dim_; + + public: + MultiVectorInnerProductSpace(size_t dim) { + fstdistfunc_ = InnerProductDistance; +#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; + } else if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #elif defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #endif + #if defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; + InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX; + } + #endif + + if (dim % 16 == 0) + fstdistfunc_ = InnerProductDistanceSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = InnerProductDistanceSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; +#endif + vector_size_ = dim * sizeof(float); + data_size_ = vector_size_ + sizeof(DOCIDTYPE); + } + + size_t get_data_size() override { + return data_size_; + } + + DISTFUNC get_dist_func() override { + return fstdistfunc_; + } + + void *get_dist_func_param() override { + return &dim_; + } + + DOCIDTYPE get_doc_id(const void *datapoint) override { + return *(DOCIDTYPE *)((char *)datapoint + vector_size_); + } + + void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override { + *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id; + } + + ~MultiVectorInnerProductSpace() {} +}; + + +template +class MultiVectorSearchStopCondition : public BaseSearchStopCondition { + size_t curr_num_docs_; + size_t num_docs_to_search_; + size_t ef_collection_; + std::unordered_map doc_counter_; + std::priority_queue> search_results_; + BaseMultiVectorSpace& space_; + + public: + MultiVectorSearchStopCondition( + BaseMultiVectorSpace& space, + size_t num_docs_to_search, + size_t ef_collection = 10) + : space_(space) { + curr_num_docs_ = 0; + num_docs_to_search_ = num_docs_to_search; + ef_collection_ = std::max(ef_collection, num_docs_to_search); + } + + void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override { + DOCIDTYPE doc_id = space_.get_doc_id(datapoint); + if (doc_counter_[doc_id] == 0) { + curr_num_docs_ += 1; + } + search_results_.emplace(dist, doc_id); + doc_counter_[doc_id] += 1; + } + + void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override { + DOCIDTYPE doc_id = space_.get_doc_id(datapoint); + doc_counter_[doc_id] -= 1; + if (doc_counter_[doc_id] == 0) { + curr_num_docs_ -= 1; + } + search_results_.pop(); + } + + bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override { + bool stop_search = candidate_dist > lowerBound && curr_num_docs_ == ef_collection_; + return stop_search; + } + + bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override { + bool flag_consider_candidate = curr_num_docs_ < ef_collection_ || lowerBound > candidate_dist; + return flag_consider_candidate; + } + + bool should_remove_extra() override { + bool flag_remove_extra = curr_num_docs_ > ef_collection_; + return flag_remove_extra; + } + + void filter_results(std::vector> &candidates) override { + while (curr_num_docs_ > num_docs_to_search_) { + dist_t dist_cand = candidates.back().first; + dist_t dist_res = search_results_.top().first; + assert(dist_cand == dist_res); + DOCIDTYPE doc_id = search_results_.top().second; + doc_counter_[doc_id] -= 1; + if (doc_counter_[doc_id] == 0) { + curr_num_docs_ -= 1; + } + search_results_.pop(); + candidates.pop_back(); + } + } + + ~MultiVectorSearchStopCondition() {} +}; + + +template +class EpsilonSearchStopCondition : public BaseSearchStopCondition { + float epsilon_; + size_t min_num_candidates_; + size_t max_num_candidates_; + size_t curr_num_items_; + + public: + EpsilonSearchStopCondition(float epsilon, size_t min_num_candidates, size_t max_num_candidates) { + assert(min_num_candidates <= max_num_candidates); + epsilon_ = epsilon; + min_num_candidates_ = min_num_candidates; + max_num_candidates_ = max_num_candidates; + curr_num_items_ = 0; + } + + void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override { + curr_num_items_ += 1; + } + + void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override { + curr_num_items_ -= 1; + } + + bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override { + if (candidate_dist > lowerBound && curr_num_items_ == max_num_candidates_) { + // new candidate can't improve found results + return true; + } + if (candidate_dist > epsilon_ && curr_num_items_ >= min_num_candidates_) { + // new candidate is out of epsilon region and + // minimum number of candidates is checked + return true; + } + return false; + } + + bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override { + bool flag_consider_candidate = curr_num_items_ < max_num_candidates_ || lowerBound > candidate_dist; + return flag_consider_candidate; + } + + bool should_remove_extra() { + bool flag_remove_extra = curr_num_items_ > max_num_candidates_; + return flag_remove_extra; + } + + void filter_results(std::vector> &candidates) override { + while (!candidates.empty() && candidates.back().first > epsilon_) { + candidates.pop_back(); + } + while (candidates.size() > max_num_candidates_) { + candidates.pop_back(); + } + } + + ~EpsilonSearchStopCondition() {} +}; +} // namespace hnswlib diff --git a/c_src/hnswlib_index.hpp b/c_src/hnswlib_index.hpp index 453ad12..4b017f9 100644 --- a/c_src/hnswlib_index.hpp +++ b/c_src/hnswlib_index.hpp @@ -166,6 +166,14 @@ class Index { } + void set_num_threads(int num_threads) { + this->num_threads_default = num_threads; + } + + size_t indexFileSize() const { + return appr_alg->indexFileSize(); + } + void saveIndex(const std::string &path_to_index) { appr_alg->saveIndex(path_to_index); } @@ -424,6 +432,7 @@ class BFIndex { int dim; bool index_inited; bool normalize; + int num_threads_default; hnswlib::labeltype cur_l; hnswlib::BruteforceSearch* alg; @@ -444,6 +453,8 @@ class BFIndex { } alg = NULL; index_inited = false; + + num_threads_default = std::thread::hardware_concurrency(); } @@ -454,6 +465,21 @@ class BFIndex { } + size_t getMaxElements() const { + return alg->maxelements_; + } + + + size_t getCurrentCount() const { + return alg->cur_element_count; + } + + + void set_num_threads(int num_threads) { + this->num_threads_default = num_threads; + } + + void init_new_index(const size_t maxElements) { if (alg) { throw std::runtime_error("The index is already initiated."); diff --git a/c_src/hnswlib_nif.cpp b/c_src/hnswlib_nif.cpp index 623fe86..e6d313d 100644 --- a/c_src/hnswlib_nif.cpp +++ b/c_src/hnswlib_nif.cpp @@ -387,7 +387,7 @@ static ERL_NIF_TERM hnswlib_index_set_num_threads(ErlNifEnv *env, int argc, cons if ((index = NifResHNSWLibIndex::get_resource(env, argv[0], error)) == nullptr) { return enif_make_badarg(env); } - if (!erlang::nif::get(env, argv[1], &new_num_threads)) { + if (!erlang::nif::get(env, argv[1], &new_num_threads) || new_num_threads <= 0) { return enif_make_badarg(env); } index->val->num_threads_default = new_num_threads; @@ -395,6 +395,17 @@ static ERL_NIF_TERM hnswlib_index_set_num_threads(ErlNifEnv *env, int argc, cons return erlang::nif::ok(env); } +static ERL_NIF_TERM hnswlib_index_index_file_size(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + NifResHNSWLibIndex * index = nullptr; + ERL_NIF_TERM ret, error; + + if ((index = NifResHNSWLibIndex::get_resource(env, argv[0], error)) == nullptr) { + return enif_make_badarg(env); + } + + return erlang::nif::ok(env, erlang::nif::make(env, (unsigned long long)index->val->indexFileSize())); +} + static ERL_NIF_TERM hnswlib_index_get_items(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { NifResHNSWLibIndex * index = nullptr; ErlNifBinary ids_binary; @@ -631,6 +642,23 @@ static ERL_NIF_TERM hnswlib_bfindex_delete_vector(ErlNifEnv *env, int argc, cons return erlang::nif::ok(env); } +static ERL_NIF_TERM hnswlib_bfindex_set_num_threads(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + NifResHNSWLibBFIndex * index = nullptr; + int new_num_threads; + ERL_NIF_TERM ret, error; + + if ((index = NifResHNSWLibBFIndex::get_resource(env, argv[0], error)) == nullptr) { + return enif_make_badarg(env); + } + if (!erlang::nif::get(env, argv[1], &new_num_threads) || new_num_threads <= 0) { + return enif_make_badarg(env); + } + + index->val->set_num_threads(new_num_threads); + + return erlang::nif::ok(env); +} + static ERL_NIF_TERM hnswlib_bfindex_save_index(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { NifResHNSWLibBFIndex * index = nullptr; std::string path; @@ -701,6 +729,40 @@ static ERL_NIF_TERM hnswlib_bfindex_load_index(ErlNifEnv *env, int argc, const E return ret; } +static ERL_NIF_TERM hnswlib_bfindex_get_max_elements(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + ERL_NIF_TERM error; + NifResHNSWLibBFIndex * index = nullptr; + if ((index = NifResHNSWLibBFIndex::get_resource(env, argv[0], error)) == nullptr) { + return enif_make_badarg(env); + } + + size_t max_elements = index->val->getMaxElements(); + return erlang::nif::ok(env, erlang::nif::make(env, (unsigned long long)max_elements)); +} + +static ERL_NIF_TERM hnswlib_bfindex_get_current_count(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + NifResHNSWLibBFIndex * index = nullptr; + ERL_NIF_TERM ret, error; + + if ((index = NifResHNSWLibBFIndex::get_resource(env, argv[0], error)) == nullptr) { + return enif_make_badarg(env); + } + + size_t count = index->val->getCurrentCount(); + return erlang::nif::ok(env, erlang::nif::make(env, (unsigned long long)count)); +} + +static ERL_NIF_TERM hnswlib_bfindex_get_num_threads(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + NifResHNSWLibBFIndex * index = nullptr; + ERL_NIF_TERM ret, error; + + if ((index = NifResHNSWLibBFIndex::get_resource(env, argv[0], error)) == nullptr) { + return enif_make_badarg(env); + } + + return erlang::nif::ok(env, erlang::nif::make(env, (unsigned long long)index->val->num_threads_default)); +} + static ERL_NIF_TERM hnswlib_float_size(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return enif_make_uint(env, sizeof(float)); } @@ -737,6 +799,7 @@ static ErlNifFunc nif_functions[] = { {"index_set_ef", 2, hnswlib_index_set_ef, 0}, {"index_get_num_threads", 1, hnswlib_index_get_num_threads, 0}, {"index_set_num_threads", 2, hnswlib_index_set_num_threads, 0}, + {"index_index_file_size", 1, hnswlib_index_index_file_size, 0}, {"index_save_index", 2, hnswlib_index_save_index, ERL_NIF_DIRTY_JOB_IO_BOUND}, {"index_load_index", 5, hnswlib_index_load_index, ERL_NIF_DIRTY_JOB_IO_BOUND}, {"index_mark_deleted", 2, hnswlib_index_mark_deleted, 0}, @@ -751,8 +814,12 @@ static ErlNifFunc nif_functions[] = { {"bfindex_knn_query", 6, hnswlib_bfindex_knn_query, ERL_NIF_DIRTY_JOB_CPU_BOUND}, {"bfindex_add_items", 5, hnswlib_bfindex_add_items, ERL_NIF_DIRTY_JOB_CPU_BOUND}, {"bfindex_delete_vector", 2, hnswlib_bfindex_delete_vector, 0}, + {"bfindex_set_num_threads", 2, hnswlib_bfindex_set_num_threads, 0}, {"bfindex_save_index", 2, hnswlib_bfindex_save_index, ERL_NIF_DIRTY_JOB_IO_BOUND}, {"bfindex_load_index", 4, hnswlib_bfindex_load_index, ERL_NIF_DIRTY_JOB_IO_BOUND}, + {"bfindex_get_max_elements", 1, hnswlib_bfindex_get_max_elements, 0}, + {"bfindex_get_current_count", 1, hnswlib_bfindex_get_current_count, 0}, + {"bfindex_get_num_threads", 1, hnswlib_bfindex_get_num_threads, 0}, {"float_size", 0, hnswlib_float_size, 0} }; diff --git a/lib/hnswlib_bfindex.ex b/lib/hnswlib_bfindex.ex index 8175c0d..2fbc125 100644 --- a/lib/hnswlib_bfindex.ex +++ b/lib/hnswlib_bfindex.ex @@ -145,6 +145,14 @@ defmodule HNSWLib.BFIndex do HNSWLib.Nif.bfindex_delete_vector(self.reference, label) end + @doc """ + Get the current number of threads to use in the index. + """ + @spec set_num_threads(%T{}, pos_integer()) :: {:ok, integer()} | {:error, String.t()} + def set_num_threads(self = %T{}, num_threads) when is_integer(num_threads) and num_threads > 0 do + HNSWLib.Nif.bfindex_set_num_threads(self.reference, num_threads) + end + @doc """ Save current index to disk. @@ -210,4 +218,28 @@ defmodule HNSWLib.BFIndex do {:error, reason} end end + + @doc """ + Get the maximum number of elements the index can hold. + """ + @spec get_max_elements(%T{}) :: {:ok, integer()} | {:error, String.t()} + def get_max_elements(self = %T{}) do + HNSWLib.Nif.bfindex_get_max_elements(self.reference) + end + + @doc """ + Get the current number of elements in the index. + """ + @spec get_current_count(%T{}) :: {:ok, integer()} | {:error, String.t()} + def get_current_count(self = %T{}) do + HNSWLib.Nif.bfindex_get_current_count(self.reference) + end + + @doc """ + Get the current number of threads to use in the index. + """ + @spec get_num_threads(%T{}) :: {:ok, integer()} | {:error, String.t()} + def get_num_threads(self = %T{}) do + HNSWLib.Nif.bfindex_get_num_threads(self.reference) + end end diff --git a/lib/hnswlib_index.ex b/lib/hnswlib_index.ex index 0ca4d58..f9aec9f 100644 --- a/lib/hnswlib_index.ex +++ b/lib/hnswlib_index.ex @@ -193,6 +193,14 @@ defmodule HNSWLib.Index do HNSWLib.Nif.index_set_num_threads(self.reference, new_num_threads) end + @doc """ + Get the size the of index file. + """ + @spec index_file_size(%T{}) :: {:ok, non_neg_integer()} | {:error, String.t()} + def index_file_size(self = %T{}) do + HNSWLib.Nif.index_index_file_size(self.reference) + end + @doc """ Save current index to disk. diff --git a/lib/hnswlib_nif.ex b/lib/hnswlib_nif.ex index f702605..6660087 100644 --- a/lib/hnswlib_nif.ex +++ b/lib/hnswlib_nif.ex @@ -41,6 +41,8 @@ defmodule HNSWLib.Nif do def index_set_num_threads(_self, _new_num_threads), do: :erlang.nif_error(:not_loaded) + def index_index_file_size(_self), do: :erlang.nif_error(:not_loaded) + def index_save_index(_self, _path), do: :erlang.nif_error(:not_loaded) def index_load_index(_space, _dim, _path, _max_elements, _allow_replace_deleted), @@ -70,9 +72,16 @@ defmodule HNSWLib.Nif do def bfindex_delete_vector(_self, _label), do: :erlang.nif_error(:not_loaded) + def bfindex_set_num_threads(_self, _num_threads), do: :erlang.nif_error(:not_loaded) + def bfindex_save_index(_self, _path), do: :erlang.nif_error(:not_loaded) def bfindex_load_index(_space, _dim, _path, _max_elements), do: :erlang.nif_error(:not_loaded) + def bfindex_get_max_elements(_self), do: :erlang.nif_error(:not_loaded) + + def bfindex_get_current_count(_self), do: :erlang.nif_error(:not_loaded) + + def bfindex_get_num_threads(_self), do: :erlang.nif_error(:not_loaded) def float_size, do: :erlang.nif_error(:not_loaded) end diff --git a/test/hnswlib_bfindex_test.exs b/test/hnswlib_bfindex_test.exs index 8950b08..7e601c4 100644 --- a/test/hnswlib_bfindex_test.exs +++ b/test/hnswlib_bfindex_test.exs @@ -362,7 +362,10 @@ defmodule HNSWLib.BFIndex.Test do assert :ok == HNSWLib.BFIndex.save_index(index, save_to) assert File.exists?(save_to) - {:ok, _index_from_save} = HNSWLib.BFIndex.load_index(space, dim, save_to) + {:ok, index_from_save} = HNSWLib.BFIndex.load_index(space, dim, save_to) + + assert HNSWLib.BFIndex.get_max_elements(index) == + HNSWLib.BFIndex.get_max_elements(index_from_save) # cleanup File.rm(save_to) @@ -383,12 +386,69 @@ defmodule HNSWLib.BFIndex.Test do assert :ok == HNSWLib.BFIndex.save_index(index, save_to) assert File.exists?(save_to) - new_max_elements = 1 + new_max_elements = 100 {:ok, _index_from_save} = HNSWLib.BFIndex.load_index(space, dim, save_to, max_elements: new_max_elements) + assert {:ok, 200} == HNSWLib.BFIndex.get_max_elements(index) + # fix: upstream bug? + # assert {:ok, 100} == HNSWLib.BFIndex.get_max_elements(index_from_save) + # cleanup File.rm(save_to) end + + test "HNSWLib.BFIndex.get_max_elements/1" do + space = :l2 + dim = 2 + max_elements = 200 + {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) + + assert {:ok, 200} == HNSWLib.BFIndex.get_max_elements(index) + end + + test "HNSWLib.BFIndex.get_current_count/1 when empty" do + space = :l2 + dim = 2 + max_elements = 200 + {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) + + assert {:ok, 0} == HNSWLib.BFIndex.get_current_count(index) + end + + test "HNSWLib.BFIndex.get_current_count/1 before and after" do + space = :l2 + dim = 2 + max_elements = 200 + items = Nx.tensor([[10, 20], [30, 40]], type: :f32) + {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) + + assert {:ok, 0} == HNSWLib.BFIndex.get_current_count(index) + assert :ok == HNSWLib.BFIndex.add_items(index, items) + assert {:ok, 2} == HNSWLib.BFIndex.get_current_count(index) + end + + test "HNSWLib.BFIndex.get_num_threads/1 default case" do + space = :l2 + dim = 2 + max_elements = 200 + {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) + + {:ok, cur_num_threads} = HNSWLib.BFIndex.get_num_threads(index) + assert System.schedulers() == cur_num_threads + end + + test "HNSWLib.BFIndex.get_num_threads/1 before and after HNSWLib.BFIndex.set_num_threads/2" do + space = :l2 + dim = 2 + max_elements = 200 + {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) + + {:ok, cur_num_threads} = HNSWLib.BFIndex.get_num_threads(index) + assert System.schedulers() == cur_num_threads + assert :ok == HNSWLib.BFIndex.set_num_threads(index, cur_num_threads + 1) + {:ok, updated_num_threads} = HNSWLib.BFIndex.get_num_threads(index) + assert updated_num_threads == cur_num_threads + 1 + end end diff --git a/test/hnswlib_index_test.exs b/test/hnswlib_index_test.exs index 96abb8b..3822387 100644 --- a/test/hnswlib_index_test.exs +++ b/test/hnswlib_index_test.exs @@ -536,6 +536,31 @@ defmodule HNSWLib.Index.Test do assert num_threads == 2 end + test "HNSWLib.Index.index_file_size/1" do + space = :l2 + dim = 2 + max_elements = 2 + {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) + + assert {:ok, 96} == HNSWLib.Index.index_file_size(index) + + max_elements = 400 + assert :ok == HNSWLib.Index.resize_index(index, max_elements) + assert {:ok, 96} == HNSWLib.Index.index_file_size(index) + + items = Nx.tensor([[10, 20], [30, 40]], type: :f32) + ids = Nx.tensor([100, 200]) + :ok = HNSWLib.Index.add_items(index, items, ids: ids) + assert {:ok, 400} == HNSWLib.Index.index_file_size(index) + + max_elements = 800 + assert :ok == HNSWLib.Index.resize_index(index, max_elements) + assert {:ok, 400} == HNSWLib.Index.index_file_size(index) + + :ok = HNSWLib.Index.add_items(index, items, ids: ids) + assert {:ok, 400} == HNSWLib.Index.index_file_size(index) + end + test "HNSWLib.Index.save_index/2" do space = :l2 dim = 2