forked from DataIntelligenceCrew/go-faiss
-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathindex_binary.go
More file actions
421 lines (358 loc) · 12.6 KB
/
index_binary.go
File metadata and controls
421 lines (358 loc) · 12.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
package faiss
/*
#include <stdlib.h>
#include <stdint.h>
#include <faiss/c_api/Index_c_ex.h>
#include <faiss/c_api/IndexBinary_c_ex.h>
#include <faiss/c_api/IndexBinaryIVF_c_ex.h>
#include <faiss/c_api/index_factory_c.h>
*/
import "C"
import (
"encoding/json"
"fmt"
"unsafe"
)
// BinaryIndex is an index over binary (bit-packed) vectors backed by FAISS.
// Vectors are passed as byte slices; since D is a bit count, each vector
// occupies D()/8 bytes.
type BinaryIndex interface {
	// D returns the dimension of the indexed vectors, in bits.
	D() int
	// MetricType returns the metric type of the index.
	MetricType() int
	// Ntotal returns the total number of vectors currently stored in the index.
	Ntotal() int64
	// SetDirectMap sets the direct map type of an IVF index:
	// 0 for no map, 1 for an array, 2 for a hash table.
	// It returns an error if the index is not an IVF index.
	SetDirectMap(maptype int) error
	// SetNProbe sets the number of clusters probed per query of an IVF index.
	// It has no effect on non-IVF indexes.
	SetNProbe(nprobe int32)
	// IsIVFIndex reports whether the underlying index is an IVF index.
	IsIVFIndex() bool
	// IVFParams returns the nprobe and nlist parameters of an IVF index, in
	// that order, or (0, 0) when the index is not an IVF index.
	IVFParams() (nprobe int, nlist int)
	// Train trains the index on a representative set of vectors.
	Train(xb []uint8) error
	// Add adds the vectors in xb to the index.
	Add(xb []uint8) error
	// Search queries the index with the vectors in xb and returns the
	// distances to, and IDs of, the k nearest neighbors of each query
	// (k entries per query).
	Search(xb []uint8, k int64) (distances []int32, labels []int64, err error)
	// SearchWithOptions performs a search with additional optional constraints:
	//   - sel restricts the search to a subset of the indexed vectors by ID;
	//   - params is a JSON object with index-type-specific search parameters,
	//     such as IVF search parameters.
	SearchWithOptions(xb []uint8, k int64, sel Selector, params json.RawMessage) (distances []int32, labels []int64, err error)
	// ObtainClusterVectorCountsFromIVFIndex returns a slice of length nlist
	// whose element i is the number of vectors in cluster i of an IVF index
	// that are accepted by the include selector.
	ObtainClusterVectorCountsFromIVFIndex(include Selector, nlist int) (
		[]int64, error)
	// ObtainClustersWithDistancesFromIVFIndex returns the IDs of, and
	// distances to, the numCentroids centroids closest to the query vector
	// xb, considering only centroids accepted by includedCentroids.
	ObtainClustersWithDistancesFromIVFIndex(xb []uint8, includedCentroids Selector,
		numCentroids int64) ([]int64, []int32, error)
	// ObtainKCentroidCardinalitiesFromIVFIndex applies only to IVF indexes.
	// It returns the top limit centroid cardinalities and the corresponding
	// centroid vectors, in descending or ascending order of cardinality.
	ObtainKCentroidCardinalitiesFromIVFIndex(limit int, descending bool) ([]uint64, [][]uint8, error)
	// SearchClustersFromIVFIndex searches only the clusters listed in
	// eligibleCentroidIDs (with query-to-centroid distances centroidDis) for
	// the k nearest neighbors of xb. It considers only vectors accepted by
	// include and applies the additional JSON search parameters in params.
	// At most centroidsToProbe clusters are probed, unless params overrides
	// nprobe.
	SearchClustersFromIVFIndex(eligibleCentroidIDs []int64, centroidDis []int32,
		centroidsToProbe int, xb []uint8, k int64, include Selector,
		params json.RawMessage) ([]int32, []int64, error)
	// Size returns the total size of the index in bytes.
	Size() uint64
	// Close frees the memory associated with the index.
	Close()
	// bPtr returns the underlying C index handle.
	bPtr() *C.FaissIndexBinary
}
// faissBinaryIndex implements BinaryIndex on top of a native FAISS binary
// index created through the C API.
type faissBinaryIndex struct {
	// bIdx is the native index handle; Close releases it.
	bIdx *C.FaissIndexBinary
}
// bPtr returns the raw C handle wrapped by b.
func (b *faissBinaryIndex) bPtr() *C.FaissIndexBinary {
	handle := b.bIdx
	return handle
}
// D returns the dimension of the indexed vectors as reported by the native
// index. Byte-slice inputs are converted to vector counts with len*8/D, so the
// value is a number of bits.
func (b *faissBinaryIndex) D() int {
	dim := C.faiss_IndexBinary_d(b.bIdx)
	return int(dim)
}
// MetricType returns the metric type stored in the native index, converted to
// a Go int.
func (b *faissBinaryIndex) MetricType() int {
	metric := C.faiss_IndexBinary_metric_type(b.bIdx)
	return int(metric)
}
// Ntotal returns how many vectors the native index currently holds.
func (b *faissBinaryIndex) Ntotal() int64 {
	count := C.faiss_IndexBinary_ntotal(b.bIdx)
	return int64(count)
}
// SetDirectMap configures the direct map of an IVF index (0 = none,
// 1 = array, 2 = hash). Non-IVF indexes are rejected with an error; a failure
// reported by the C API is returned via getLastError.
func (b *faissBinaryIndex) SetDirectMap(mapType int) error {
	ivf := C.faiss_IndexBinaryIVF_cast(b.bIdx)
	if ivf == nil {
		return fmt.Errorf("index is not of ivf type")
	}
	rc := C.faiss_IndexBinaryIVF_set_direct_map(ivf, C.int(mapType))
	if rc != 0 {
		return getLastError()
	}
	return nil
}
// SetNProbe sets the number of probed clusters on an IVF index. For any other
// index type the call silently does nothing.
func (b *faissBinaryIndex) SetNProbe(nprobe int32) {
	if ivf := C.faiss_IndexBinaryIVF_cast(b.bIdx); ivf != nil {
		C.faiss_IndexBinaryIVF_set_nprobe(ivf, C.size_t(nprobe))
	}
}
// IsIVFIndex reports whether the native index can be cast to an IVF binary
// index.
func (b *faissBinaryIndex) IsIVFIndex() bool {
	return C.faiss_IndexBinaryIVF_cast(b.bIdx) != nil
}
// IVFParams returns (nprobe, nlist) read from an IVF index, or (0, 0) when the
// index is not an IVF index.
func (b *faissBinaryIndex) IVFParams() (nprobe int, nlist int) {
	ivf := C.faiss_IndexBinaryIVF_cast(b.bIdx)
	if ivf != nil {
		nprobe = int(C.faiss_IndexBinaryIVF_nprobe(ivf))
		nlist = int(C.faiss_IndexBinaryIVF_nlist(ivf))
	}
	return nprobe, nlist
}
// Train trains the index on the bit-packed vectors in x (D()/8 bytes per
// vector). It returns an error if x is empty or if the C API reports a
// failure.
func (b *faissBinaryIndex) Train(x []uint8) error {
	// Guard before taking &x[0], which panics on an empty slice.
	if len(x) == 0 {
		return fmt.Errorf("no training vectors provided")
	}
	n := (len(x) * 8) / b.D()
	if c := C.faiss_IndexBinary_train(b.bIdx, C.idx_t(n),
		(*C.uint8_t)(&x[0])); c != 0 {
		return getLastError()
	}
	return nil
}
// Add appends the bit-packed vectors in x (D()/8 bytes per vector) to the
// index. Adding an empty slice is a no-op. A failure reported by the C API is
// returned via getLastError.
func (b *faissBinaryIndex) Add(x []uint8) error {
	// Nothing to add. This also avoids the panic &x[0] raises on an empty slice.
	if len(x) == 0 {
		return nil
	}
	n := (len(x) * 8) / b.D()
	if c := C.faiss_IndexBinary_add(b.bIdx, C.idx_t(n),
		(*C.uint8_t)(&x[0])); c != 0 {
		return getLastError()
	}
	return nil
}
// Search finds the k nearest neighbors of each bit-packed query vector in xb
// (D()/8 bytes per query). The returned distance and label slices hold k
// entries per query.
//
// It returns an error if k is not positive or the C API reports a failure.
// If xb does not contain a complete query vector, it returns empty results.
func (b *faissBinaryIndex) Search(xb []uint8, k int64) (
	[]int32, []int64, error) {
	if k <= 0 {
		return nil, nil, fmt.Errorf("k must be positive, got %d", k)
	}
	nq := (len(xb) * 8) / b.D()
	// With no queries the result buffers would be empty and &distances[0]
	// below would panic.
	if nq == 0 {
		return []int32{}, []int64{}, nil
	}
	distances := make([]int32, int64(nq)*k)
	labels := make([]int64, int64(nq)*k)
	if c := C.faiss_IndexBinary_search(
		b.bIdx,
		C.idx_t(nq),
		(*C.uint8_t)(&xb[0]),
		C.idx_t(k),
		(*C.int32_t)(&distances[0]),
		(*C.idx_t)(&labels[0]),
	); c != 0 {
		return nil, nil, getLastError()
	}
	return distances, labels, nil
}
// SearchWithOptions runs a plain Search when neither a selector nor extra
// params are supplied. Otherwise it delegates to searchWithOptions, which
// builds a FAISS search-params object from sel and params.
func (b *faissBinaryIndex) SearchWithOptions(xb []uint8, k int64, sel Selector, params json.RawMessage) ([]int32, []int64, error) {
	hasOptions := sel != nil || params != nil
	if hasOptions {
		return b.searchWithOptions(xb, k, sel, params)
	}
	return b.Search(xb, k)
}
// searchWithOptions searches like Search, but passes a search-params object
// built from selector and/or params to the C API.
//
// It returns an error if k is not positive, if the params object cannot be
// built, or if the C API reports a failure. If xb does not contain a complete
// query vector, it returns empty results.
func (b *faissBinaryIndex) searchWithOptions(xb []uint8, k int64, selector Selector,
	params json.RawMessage) ([]int32, []int64, error) {
	if k <= 0 {
		return nil, nil, fmt.Errorf("k must be positive, got %d", k)
	}
	// Build a binary search params object to contain either the selector, the
	// additional params, or both.
	searchParams, err := NewBinarySearchParams(b, params, selector, nil)
	if err != nil {
		return nil, nil, err
	}
	defer searchParams.Delete()
	nq := (len(xb) * 8) / b.D()
	// With no queries the result buffers would be empty and &distances[0]
	// below would panic.
	if nq == 0 {
		return []int32{}, []int64{}, nil
	}
	distances := make([]int32, int64(nq)*k)
	labels := make([]int64, int64(nq)*k)
	if c := C.faiss_IndexBinary_search_with_params(
		b.bIdx,
		C.idx_t(nq),
		(*C.uint8_t)(&xb[0]),
		C.idx_t(k),
		searchParams.sp,
		(*C.int32_t)(&distances[0]),
		(*C.idx_t)(&labels[0]),
	); c != 0 {
		return nil, nil, getLastError()
	}
	return distances, labels, nil
}
// ObtainClusterVectorCountsFromIVFIndex returns a slice of length nlist whose
// element i is the number of vectors in cluster i that are accepted by
// includedVectors.
//
// It returns an error if the index is not an IVF index, if nlist is not
// positive, if the search params cannot be built, or if the C API reports a
// failure.
func (b *faissBinaryIndex) ObtainClusterVectorCountsFromIVFIndex(includedVectors Selector, nlist int) ([]int64, error) {
	// Applicable only to IVF indexes.
	ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(b.bIdx)
	if ivfPtrBinary == nil {
		return nil, fmt.Errorf("index is not of ivf type")
	}
	// A negative nlist would make make() panic, and zero would make
	// &listCount[0] panic.
	if nlist <= 0 {
		return nil, fmt.Errorf("nlist must be positive, got %d", nlist)
	}
	// listCount[i] will hold the count of vectors in cluster i.
	// NOTE(review): the C call is told the buffer holds nlist entries;
	// presumably nlist must match the index's own nlist — confirm against the
	// C implementation.
	listCount := make([]int64, nlist)
	// Creating a FAISS selector based on the include bitmap.
	params, err := NewStandardSearchParams(includedVectors)
	if err != nil {
		return nil, err
	}
	defer params.Delete()
	// Populate listCount with per-cluster counts, considering only the
	// vectors specified in the include selector.
	if c := C.faiss_IndexBinaryIVF_list_vector_count(
		ivfPtrBinary,
		(*C.idx_t)(unsafe.Pointer(&listCount[0])),
		C.size_t(nlist),
		params.sp,
	); c != 0 {
		return nil, getLastError()
	}
	return listCount, nil
}
func (b *faissBinaryIndex) ObtainClustersWithDistancesFromIVFIndex(xb []uint8, includedCentroids Selector, numCentroids int64) ([]int64, []int32, error) {
// Applicable only to IVF indexes
ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(b.bIdx)
if ivfPtrBinary == nil {
return nil, nil, fmt.Errorf("index is not of ivf type")
}
params, err := NewStandardSearchParams(includedCentroids)
if err != nil {
return nil, nil, err
}
defer params.Delete()
// Populate these with the centroids and their distances.
centroids := make([]int64, numCentroids)
centroidDistances := make([]int32, numCentroids)
n := (len(xb) * 8) / b.D()
if c := C.faiss_IndexBinaryIVF_search_closest_eligible_centroids(
ivfPtrBinary,
(C.idx_t)(n),
(*C.uint8_t)(&xb[0]),
(C.idx_t)(numCentroids),
(*C.int32_t)(¢roidDistances[0]),
(*C.idx_t)(¢roids[0]),
params.sp,
); c != 0 {
return nil, nil, getLastError()
}
return centroids, centroidDistances, nil
}
func (b *faissBinaryIndex) ObtainKCentroidCardinalitiesFromIVFIndex(limit int, descending bool) (
[]uint64, [][]uint8, error) {
if limit <= 0 {
return nil, nil, nil
}
// Applicable only to IVF indexes
ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(b.bIdx)
if ivfPtrBinary == nil {
return nil, nil, fmt.Errorf("index is not of ivf type")
}
nlist := int(C.faiss_IndexBinaryIVF_nlist(ivfPtrBinary))
if nlist == 0 {
return nil, nil, nil
}
centroidCardinalities := make([]C.size_t, nlist)
// Allocate a flat buffer for all centroids, then slice it per centroid
d := b.D()
flatCentroids := make([]uint8, nlist*d/8)
// Call the C function to fill centroid vectors and cardinalities
c := C.faiss_IndexBinaryIVF_get_centroids_and_cardinality(
ivfPtrBinary,
(*C.uint8_t)(&flatCentroids[0]),
(*C.size_t)(¢roidCardinalities[0]),
nil,
)
if c != 0 {
return nil, nil, getLastError()
}
topIndices := getIndicesOfKCentroidCardinalities(
centroidCardinalities,
min(limit, nlist),
descending)
rvCardinalities := make([]uint64, len(topIndices))
rvCentroids := make([][]uint8, len(topIndices))
for i, idx := range topIndices {
rvCardinalities[i] = uint64(centroidCardinalities[idx])
rvCentroids[i] = flatCentroids[idx*d : (idx+1)*d]
}
return rvCardinalities, rvCentroids, nil
}
func (b *faissBinaryIndex) SearchClustersFromIVFIndex(eligibleCentroidIDs []int64, centroidDis []int32, centroidsToProbe int,
xb []uint8, k int64, include Selector, params json.RawMessage) ([]int32, []int64, error) {
// Applicable only to IVF indexes
ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(b.bIdx)
if ivfPtrBinary == nil {
return nil, nil, fmt.Errorf("index is not of ivf type")
}
// If no include selector is provided, we have no results to return.
// return an error indicating that the SearchClustersFromIVFIndex requires a valid selector.
if include == nil {
return nil, nil, fmt.Errorf("SearchClustersFromIVFIndex requires a valid include selector")
}
// create a temporary search params object to set nprobe, this will override
// the nprobe and the nlist set at index time, this will allow the search to
// probe only the clusters specified in eligibleCentroidIDs
tempParams := &defaultSearchParamsIVF{
// Nlist is set to the number of eligible centroids, which will override
// the nlist set at index time.
Nlist: len(eligibleCentroidIDs),
// Have to override nprobe so that more clusters will be searched for this
// query, if required.
Nprobe: centroidsToProbe,
}
searchParams, err := NewBinarySearchParams(b, params, include, tempParams)
if err != nil {
return nil, nil, err
}
defer searchParams.Delete()
n := (len(xb) * 8) / b.D()
distances := make([]int32, int64(n)*k)
labels := make([]int64, int64(n)*k)
// Adjust the slices to match the effective nprobe set in searchParams, as the input
// parameters may have different nprobe value, which will be a hard override, over the
// centroidsToProbe value passed to this function.
// If the effective nprobe is greater than the length of eligibleCentroidIDs,
// we limit it to the length of eligibleCentroidIDs.
effectiveNprobe := min(getNProbeFromSearchParams(searchParams), int32(len(eligibleCentroidIDs)))
eligibleCentroidIDs = eligibleCentroidIDs[:effectiveNprobe]
centroidDis = centroidDis[:effectiveNprobe]
if c := C.faiss_IndexBinaryIVF_search_preassigned_with_params(
ivfPtrBinary,
(C.idx_t)(n),
(*C.uint8_t)(&xb[0]),
(C.idx_t)(k),
(*C.idx_t)(&eligibleCentroidIDs[0]),
(*C.int32_t)(¢roidDis[0]),
(*C.int32_t)(&distances[0]),
(*C.idx_t)(&labels[0]),
(C.int)(0),
searchParams.sp,
); c != 0 {
return nil, nil, getLastError()
}
return distances, labels, nil
}
// Size reports the size of the index in bytes, as returned by the native
// faiss_IndexBinary_size call.
func (b *faissBinaryIndex) Size() uint64 {
	return uint64(C.faiss_IndexBinary_size(b.bIdx))
}
// Close frees the native index. The handle is cleared afterwards, so calling
// Close again is a no-op instead of a double free.
func (b *faissBinaryIndex) Close() {
	if b.bIdx == nil {
		return
	}
	C.faiss_IndexBinary_free(b.bIdx)
	b.bIdx = nil
}
// BinaryIndexImpl is the exported handle returned by BinaryIndexFactory. It
// embeds a BinaryIndex, so every BinaryIndex method can be called on it
// directly.
type BinaryIndexImpl struct {
	BinaryIndex
}
// BinaryIndexFactory creates a binary index with dims dimensions from a FAISS
// index-factory description and wraps it in a BinaryIndexImpl. A non-empty
// description is copied into C memory for the duration of the call; an empty
// one is passed as a NULL pointer. A failure reported by the C API is returned
// via getLastError.
func BinaryIndexFactory(dims int, description string) (*BinaryIndexImpl, error) {
	// NOTE(review): passing NULL for an empty description is only safe if the
	// linked faiss_index_binary_factory accepts a NULL description — confirm.
	var cDesc *C.char
	if len(description) > 0 {
		cDesc = C.CString(description)
		defer C.free(unsafe.Pointer(cDesc))
	}
	idx := new(faissBinaryIndex)
	if rc := C.faiss_index_binary_factory(&idx.bIdx, C.int(dims), cDesc); rc != 0 {
		return nil, getLastError()
	}
	return &BinaryIndexImpl{BinaryIndex: idx}, nil
}