forked from feather-store/feather
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhnswlib.h
More file actions
228 lines (183 loc) · 5.98 KB
/
hnswlib.h
File metadata and controls
228 lines (183 loc) · 5.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#pragma once
// https://github.com/nmslib/hnswlib/pull/508
// This allows others to provide their own error stream (e.g. RcppHNSW)
// HNSWERR expands to std::cerr unless HNSWLIB_ERR_OVERRIDE is defined before
// this header is included, in which case it expands to that macro's value.
// The macro is only expanded where it is used, so <iostream> being included
// further down in this header is sufficient for the default case.
#ifndef HNSWLIB_ERR_OVERRIDE
#define HNSWERR std::cerr
#else
#define HNSWERR HNSWLIB_ERR_OVERRIDE
#endif
// Compile-time SIMD feature selection. Defining NO_MANUAL_VECTORIZATION
// before including this header suppresses all three USE_* macros.
// The macros are nested: USE_AVX is only defined when USE_SSE is, and
// USE_AVX512 only when USE_AVX is.
#ifndef NO_MANUAL_VECTORIZATION
// SSE: GCC/Clang signal it with __SSE__; the _M_* macros are the MSVC spellings.
#if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64))
#define USE_SSE
#ifdef __AVX__
#define USE_AVX
#ifdef __AVX512F__
#define USE_AVX512
#endif
#endif
#endif
#endif
#if defined(USE_AVX) || defined(USE_SSE)
#ifdef _MSC_VER
#include <intrin.h>
#include <stdexcept>
// MSVC variant: runs CPUID for leaf `eax` / sub-leaf `ecx` through the
// __cpuidex intrinsic and stores the four result registers in `out`.
static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
__cpuidex(out, eax, ecx);
}
// MSVC variant: returns the extended control register selected by `x`,
// as reported by the _xgetbv intrinsic.
static __int64 xgetbv(unsigned int x) {
return _xgetbv(x);
}
#else
#include <x86intrin.h>
#include <cpuid.h>
#include <stdint.h>
// Non-MSVC variant: runs CPUID for leaf `eax` / sub-leaf `ecx` via the
// __cpuid_count macro from <cpuid.h>. The results are written to
// cpuInfo[0..3] in the order EAX, EBX, ECX, EDX.
static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
__cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]);
}
// Non-MSVC variant: executes the XGETBV instruction with ECX = `index` and
// returns the 64-bit result assembled from EDX (high half) and EAX (low half).
static uint64_t xgetbv(unsigned int index) {
uint32_t eax, edx;
// Inputs: ECX = index. Outputs: EAX -> eax, EDX -> edx.
__asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index));
return ((uint64_t)edx << 32) | eax;
}
#endif
#if defined(USE_AVX512)
#include <immintrin.h>
#endif
// Alignment specifiers for 32- and 64-byte alignment. The attribute
// spelling is used when __GNUC__ is defined; otherwise the __declspec
// spelling is used.
#if defined(__GNUC__)
#define PORTABLE_ALIGN32 __attribute__((aligned(32)))
#define PORTABLE_ALIGN64 __attribute__((aligned(64)))
#else
#define PORTABLE_ALIGN32 __declspec(align(32))
#define PORTABLE_ALIGN64 __declspec(align(64))
#endif
// Adapted from https://github.com/Mysticial/FeatureDetector
// Register index handed to xgetbv() by AVXCapable() and AVX512Capable();
// index 0 selects XCR0.
#define _XCR_XFEATURE_ENABLED_MASK 0
// Runtime check that AVX can actually be used on this machine.
//
// Two things must hold:
//   * the CPU reports AVX (CPUID leaf 1, ECX bit 28), and
//   * the OS has enabled XSAVE/XRSTOR (CPUID leaf 1, ECX bit 27) and the
//     value read through xgetbv() has bits 1 and 2 set (mask 0x6).
//
// Returns true only when both the hardware and the OS checks pass.
static bool AVXCapable() {
    int regs[4];

    // Leaf 0: EAX holds the highest standard CPUID leaf available.
    cpuid(regs, 0, 0);
    const int maxLeaf = regs[0];

    // Hardware side: only query leaf 1 if the CPU advertises it.
    bool hwHasAvx = false;
    if (maxLeaf >= 1) {
        cpuid(regs, 1, 0);
        hwHasAvx = (regs[2] & (1 << 28)) != 0;
    }

    // OS side: leaf 1 is queried again, this time unconditionally.
    cpuid(regs, 1, 0);
    const bool osxsaveBit = (regs[2] & (1 << 27)) != 0;
    const bool avxBit = (regs[2] & (1 << 28)) != 0;

    bool osKeepsAvxState = false;
    if (osxsaveBit && avxBit) {
        const uint64_t xcr = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
        osKeepsAvxState = (xcr & 0x6) == 0x6;
    }

    return hwHasAvx && osKeepsAvxState;
}
// Runtime check that AVX-512 Foundation can actually be used on this machine.
//
// Requires, in order:
//   * AVXCapable() to succeed,
//   * the CPU to report AVX-512F (CPUID leaf 7 / sub-leaf 0, EBX bit 16), and
//   * the OS to have XSAVE/XRSTOR enabled with bits 1, 2, 5, 6 and 7 set in
//     the value read through xgetbv() (mask 0xe6).
//
// Returns true only when every check passes.
static bool AVX512Capable() {
    if (!AVXCapable()) {
        return false;
    }

    int regs[4];

    // Leaf 0: EAX holds the highest standard CPUID leaf available.
    cpuid(regs, 0, 0);
    const int maxLeaf = regs[0];

    // Hardware side: AVX512 Foundation lives in leaf 7.
    bool hwHasAvx512f = false;
    if (maxLeaf >= 7) {
        cpuid(regs, 7, 0);
        hwHasAvx512f = (regs[1] & (1 << 16)) != 0;
    }

    // OS side: same leaf-1 bits as the AVX check, but a wider xgetbv mask.
    cpuid(regs, 1, 0);
    const bool osxsaveBit = (regs[2] & (1 << 27)) != 0;
    const bool avxBit = (regs[2] & (1 << 28)) != 0;

    bool osKeepsAvx512State = false;
    if (osxsaveBit && avxBit) {
        const uint64_t xcr = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
        osKeepsAvx512State = (xcr & 0xe6) == 0xe6;
    }

    return hwHasAvx512f && osKeepsAvx512State;
}
#endif
#include <queue>
#include <vector>
#include <iostream>
#include <string.h>
namespace hnswlib {
// Integer type of the labels passed to addPoint() and returned next to each
// distance by the search functions declared in this header.
typedef size_t labeltype;
// Predicate over labels that can be passed to the search functions as
// `isIdAllowed`. The base implementation accepts every label; derive from it
// to store state for filtering (e.g. a std::set of permitted labels).
class BaseFilterFunctor {
 public:
    // Returns true for every `id`; overrides decide per label.
    virtual bool operator()(hnswlib::labeltype id) {
        return true;
    }

    virtual ~BaseFilterFunctor() = default;
};
// Abstract interface for a pluggable search stop condition, parameterised on
// the distance type. Every hook is pure virtual. The code that invokes these
// hooks is not in this header, so the per-method notes are inferred from the
// names and signatures only.
// NOTE(review): confirm each note against the search implementation.
template <typename dist_t>
class BaseSearchStopCondition {
 public:
    // Presumably called when a point enters the current result set.
    virtual void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) = 0;

    // Presumably called when a point is dropped from the current result set.
    virtual void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) = 0;

    // Returns whether the search should terminate, given a candidate distance
    // and the current lower bound.
    virtual bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) = 0;

    // Returns whether a candidate at `candidate_dist` is worth considering,
    // given the current lower bound.
    virtual bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) = 0;

    // Returns whether surplus results should be removed.
    virtual bool should_remove_extra() = 0;

    // Receives the (distance, label) candidates by reference so that an
    // implementation can modify the list in place.
    virtual void filter_results(std::vector<std::pair<dist_t, labeltype>> &candidates) = 0;

    virtual ~BaseSearchStopCondition() = default;
};
// Comparator that orders pair-like values by descending `.first`; `.second`
// is never inspected. Used as the Compare of a std::priority_queue it yields
// a min-heap on `.first`.
//
// operator() is const so the comparator can be invoked through const
// objects and references, which standard containers and algorithms are
// allowed to do.
template <typename T>
class pairGreater {
 public:
    // Returns true when p1.first is strictly greater than p2.first.
    bool operator()(const T& p1, const T& p2) const {
        return p1.first > p2.first;
    }
};
// Writes the raw object representation of `podRef` (sizeof(T) bytes) to
// `out`. No byte-order conversion is performed and the stream state is not
// checked here.
template <typename T>
static void writeBinaryPOD(std::ostream &out, const T &podRef) {
    const char *rawBytes = reinterpret_cast<const char *>(&podRef);
    out.write(rawBytes, sizeof(T));
}
// Reads sizeof(T) raw bytes from `in` directly into the storage of `podRef`.
// No byte-order conversion is performed and the stream state is not checked
// here, so callers that care about short reads must inspect `in` themselves.
template <typename T>
static void readBinaryPOD(std::istream &in, T &podRef) {
    char *rawBytes = reinterpret_cast<char *>(&podRef);
    in.read(rawBytes, sizeof(T));
}
// Pointer to a distance function: takes three untyped const pointers and
// returns a value of type MTYPE.
// NOTE(review): presumably the arguments are (vector a, vector b,
// space-specific parameter as returned by get_dist_func_param()) — confirm
// against space_l2.h / space_ip.h.
template<typename MTYPE>
using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);
// Abstract description of a metric space whose distances have type MTYPE.
// Implementations supply a distance function plus the data it needs.
template <typename MTYPE>
class SpaceInterface {
 public:
    // virtual void search(void *);

    // Size of the data handled by this space.
    // NOTE(review): presumably the byte size of one stored vector — confirm.
    virtual size_t get_data_size() = 0;

    // Distance function for this space.
    virtual DISTFUNC<MTYPE> get_dist_func() = 0;

    // Opaque parameter belonging to the distance function.
    // NOTE(review): presumably passed as its third argument — confirm.
    virtual void *get_dist_func_param() = 0;

    virtual ~SpaceInterface() = default;
};
// Common interface of the nearest-neighbour index implementations,
// parameterised on the distance type.
template <typename dist_t>
class AlgorithmInterface {
 public:
    // Inserts `datapoint` under `label`. `replace_deleted` defaults to false.
    virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0;

    // k-nearest-neighbour query. The result is a max-heap of
    // (distance, label) pairs, so top() is the farthest match. The optional
    // filter defaults to nullptr.
    virtual std::priority_queue<std::pair<dist_t, labeltype>>
    searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;

    // Return k nearest neighbors in closer-first order.
    // Non-pure: a default implementation built on searchKnn() is provided.
    virtual std::vector<std::pair<dist_t, labeltype>>
    searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const;

    // Persists the index to `location`.
    virtual void saveIndex(const std::string &location) = 0;

    virtual ~AlgorithmInterface() = default;
};
// Default closer-first search: runs searchKnn() and reverses its output.
// searchKnn() yields a heap whose top() is the farthest match, so the heap
// is drained from the back of the output vector towards the front.
template <typename dist_t>
std::vector<std::pair<dist_t, labeltype>>
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
                                                 BaseFilterFunctor* isIdAllowed) const {
    auto farthestFirst = searchKnn(query_data, k, isIdAllowed);

    std::vector<std::pair<dist_t, labeltype>> closestFirst(farthestFirst.size());
    for (size_t slot = closestFirst.size(); slot > 0; --slot) {
        closestFirst[slot - 1] = farthestFirst.top();
        farthestFirst.pop();
    }
    return closestFirst;
}
} // namespace hnswlib
#include "space_l2.h"
#include "space_ip.h"
#include "stop_condition.h"
#include "bruteforce.h"
#include "hnswalg.h"