Skip to content

Commit 0a7c80d

Browse files
committed
refactor(quantization, python_bindings): Reconstruct Thread-Local Storage and Python Binding Interfaces
1. Thread-Local Storage Optimization - Replaced the thread_local static member `query_` in QuantizerBase with a local static variable inside a function -彻底 resolved symbol duplication issues caused by thread-local variables in template classes - Removed redundant out-of-class definitions in header files 2. Python Binding Interface Refactoring - Introduced BuilderFactory to standardize index builder creation logic - Implemented SearcherFactory to support searchers with multiple quantization types - Optimized header file reference paths and dependencies - Cleaned up legacy implementation code 3. Code Style and Structure Enhancement - Removed redundant header inclusions in builder_factory.cpp - Cleaned up outdated commented code in sq4_quant.h and sq8_quant.h - Standardized header inclusion order in fp32_quant.cpp
1 parent a377071 commit 0a7c80d

File tree

9 files changed

+174
-118
lines changed

9 files changed

+174
-118
lines changed

python_bindings/bindings.cpp

Lines changed: 74 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
#include <omp.h>
88
#endif
99

10-
#include "hnsw/builder.h"
11-
#include "hnsw/hnsw.h"
12-
#include "searcher.h"
10+
#include "core/interfaces.h"
11+
#include "graph/builder_factory.h"
12+
#include "graph/graph.h"
13+
#include "quantization/fp32_quant.h"
14+
#include "searcher/searcher.h"
1315

1416
namespace py = pybind11;
1517

@@ -62,10 +64,10 @@ void parallel_for(size_t n, int num_threads, Func f) {
6264
// 3. Graph wrapper
6365
// -----------------------------------------------------------------------------
6466
struct Graph {
65-
deepsearch::Graph<int> graph;
67+
deepsearch::graph::Graph graph;
6668

6769
Graph() = default;
68-
explicit Graph(const deepsearch::Graph<int>& graph) : graph(graph) {}
70+
explicit Graph(const deepsearch::graph::Graph& graph) : graph(graph) {}
6971
explicit Graph(const std::string& filename) { graph.load(filename); }
7072

7173
void save(const std::string& filename) { graph.save(filename); }
@@ -76,45 +78,94 @@ struct Graph {
7678
// 4. Index wrapper
7779
// -----------------------------------------------------------------------------
7880
struct Index {
79-
std::unique_ptr<deepsearch::Builder> idx;
81+
std::unique_ptr<deepsearch::graph::GraphBuilder<float>> builder;
82+
deepsearch::core::DistanceType distance_type_;
83+
size_t dim_;
8084

8185
Index(const std::string& type, int dim, const std::string& metric, int R = 32,
8286
int L = 200) {
8387
if (dim <= 0) throw py::value_error("`dim` must be positive");
8488
if (R <= 0) throw py::value_error("`R` must be positive");
8589
if (L < 0) throw py::value_error("`L` must be non-negative");
8690

91+
dim_ = dim;
92+
93+
// 解析距离类型
94+
if (metric == "L2") {
95+
distance_type_ = deepsearch::core::DistanceType::L2;
96+
} else if (metric == "IP") {
97+
distance_type_ = deepsearch::core::DistanceType::IP;
98+
} else if (metric == "COSINE") {
99+
distance_type_ = deepsearch::core::DistanceType::COSINE;
100+
} else {
101+
throw py::value_error("Unknown metric: " + metric);
102+
}
103+
87104
if (type == "HNSW") {
88-
idx = std::make_unique<deepsearch::Hnsw>(dim, metric, R, L);
105+
// 使用BuilderFactory创建HNSW构建器
106+
deepsearch::graph::BuilderConfig config;
107+
config.M = R;
108+
config.ef_construction = L;
109+
110+
builder = deepsearch::graph::BuilderFactory<float>::create(
111+
deepsearch::graph::BuilderType::HNSW, distance_type_, dim_, config);
89112
} else {
90113
throw py::value_error("Unknown index type: " + type);
91114
}
92115
}
93116

94117
Graph build(py::object data) {
95118
auto buf = to_buffer<float>(data);
96-
if (buf.cols != idx->Dim())
119+
if (buf.cols != dim_)
97120
throw py::value_error("Dimension mismatch: expected " +
98-
std::to_string(idx->Dim()) + ", got " +
121+
std::to_string(dim_) + ", got " +
99122
std::to_string(buf.cols));
100-
idx->Build(buf.ptr, buf.rows);
101-
return Graph(idx->GetGraph());
123+
124+
auto graph = builder->build(buf.ptr, buf.rows, buf.cols);
125+
return Graph(graph);
102126
}
103127
};
104128

105129
// -----------------------------------------------------------------------------
106130
// 5. Searcher wrapper
107131
// -----------------------------------------------------------------------------
108132
struct Searcher {
109-
std::unique_ptr<deepsearch::SearcherBase> sr;
133+
std::unique_ptr<deepsearch::searcher::SearcherBase> searcher;
110134
ssize_t dim_;
111135

112136
Searcher(const Graph& graph, py::object data, const std::string& metric,
113-
int level) {
137+
const std::string& quant) {
114138
auto buf = to_buffer<float>(data);
115139
dim_ = buf.cols;
116-
sr = deepsearch::create_searcher(graph.graph, metric, level);
117-
sr->SetData(buf.ptr, buf.rows, buf.cols);
140+
141+
// 解析距离类型
142+
deepsearch::core::DistanceType distance_type;
143+
if (metric == "L2") {
144+
distance_type = deepsearch::core::DistanceType::L2;
145+
} else if (metric == "IP") {
146+
distance_type = deepsearch::core::DistanceType::IP;
147+
} else if (metric == "COSINE") {
148+
distance_type = deepsearch::core::DistanceType::COSINE;
149+
} else {
150+
throw py::value_error("Unknown metric: " + metric);
151+
}
152+
153+
// 使用SearcherFactory创建FP32搜索器
154+
if (quant == "fp32") {
155+
searcher = deepsearch::searcher::SearcherFactory::createFP32(
156+
graph.graph, distance_type, dim_);
157+
} else if (quant == "sq8") {
158+
searcher = deepsearch::searcher::SearcherFactory::createSQ8(
159+
graph.graph, distance_type, dim_);
160+
} else if (quant == "sq4") {
161+
searcher = deepsearch::searcher::SearcherFactory::createSQ4(
162+
graph.graph, distance_type, dim_);
163+
} else {
164+
throw py::value_error("Unknown quant: " + quant);
165+
}
166+
167+
// 设置数据
168+
searcher->SetData(buf.ptr, buf.rows, buf.cols);
118169
}
119170

120171
py::array_t<int> search(py::object query, int k) {
@@ -124,7 +175,7 @@ struct Searcher {
124175
")");
125176

126177
int* ids = new int[k];
127-
sr->Search(buf.ptr, k, ids);
178+
searcher->Search(buf.ptr, k, ids);
128179

129180
py::capsule free_when_done(ids,
130181
[](void* f) { delete[] static_cast<int*>(f); });
@@ -145,13 +196,12 @@ struct Searcher {
145196
int* ids = new int[nq * k];
146197

147198
parallel_for(nq, num_threads, [&](size_t i) {
148-
sr->Search(buf.ptr + i * dim_, k, ids + i * k);
199+
searcher->Search(buf.ptr + i * dim_, k, ids + i * k);
149200
});
150201

151202
py::capsule free_when_done(ids,
152203
[](void* f) { delete[] static_cast<int*>(f); });
153204

154-
// 返回二维数组,Python 侧析构时自动调用 capsule
155205
return py::array_t<int>({(ssize_t)nq, (ssize_t)k}, // shape
156206
{(ssize_t)(k * sizeof(int)), // row stride
157207
(ssize_t)(sizeof(int))}, // col stride
@@ -162,12 +212,13 @@ struct Searcher {
162212

163213
void set_ef(int ef) {
164214
if (ef <= 0) throw py::value_error("`ef` must be positive");
165-
sr->SetEf(ef);
215+
searcher->SetEf(ef);
166216
}
167217

168218
void optimize(int num_threads = 0) {
169219
// Use parallel_for with a single iteration to adjust threads
170-
parallel_for(1, num_threads, [&](size_t) { sr->Optimize(num_threads); });
220+
parallel_for(1, num_threads,
221+
[&](size_t) { searcher->Optimize(num_threads); });
171222
}
172223
};
173224

@@ -202,9 +253,10 @@ PYBIND11_MODULE(deepsearch, m) {
202253
"Build the index from a float array");
203254

204255
py::class_<Searcher>(m, "Searcher")
205-
.def(py::init<const Graph&, py::object, const std::string&, int>(),
256+
.def(py::init<const Graph&, py::object, const std::string&,
257+
const std::string&>(),
206258
py::arg("graph"), py::arg("data"), py::arg("metric"),
207-
py::arg("level"))
259+
py::arg("quant"))
208260
.def("search", &Searcher::search, py::arg("query"), py::arg("k"),
209261
"Search a single vector")
210262
.def("batch_search", &Searcher::batch_search, py::arg("query"),

src/graph/builder_factory.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include <stdexcept>
44

55
#include "core/exceptions.h"
6-
#include "hnsw_builder.cpp" // 包含模板实现
76

87
namespace deepsearch {
98
namespace graph {

src/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
#include <iostream>
44
#include <string>
55

6-
#include "builder_factory.h"
7-
#include "searcher.h"
6+
#include "graph/builder_factory.h"
7+
#include "searcher/searcher.h"
88

99
using namespace deepsearch::core;
1010
using namespace deepsearch::graph;

src/quantization/fp32_quant.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
#include "fp32_quant.h"
22

3-
#include <core/factory.h>
4-
#include <distance/computers.h>
5-
63
#include <cstring>
74
#include <stdexcept>
85

96
#include "allocator.h"
7+
#include "distance/computers.h"
108

119
namespace deepsearch {
1210
namespace quantization {
@@ -67,21 +65,22 @@ void FP32Quantizer::prefetch_data(size_t index, int lines) const {
6765

6866
void FP32Quantizer::encode_query(const float* query) {
6967
ensure_thread_query_initialized();
70-
encode(query, query_.get());
68+
encode(query, get_query().get());
7169
}
7270

7371
float FP32Quantizer::compute_query_distance(size_t index) const {
7472
const float* data_code = reinterpret_cast<const float*>(get_data(index));
75-
return distance_computer_->compute(query_.get(), data_code);
73+
return distance_computer_->compute(get_query().get(), data_code);
7674
}
7775

7876
float FP32Quantizer::compute_query_distance(const float* code) const {
79-
return distance_computer_->compute(query_.get(), code);
77+
return distance_computer_->compute(get_query().get(), code);
8078
}
8179

8280
void FP32Quantizer::ensure_thread_query_initialized() {
83-
if (query_ == nullptr) {
84-
query_.reset(static_cast<float*>(alloc64B(d_align * sizeof(float))));
81+
auto& query = get_query();
82+
if (query == nullptr) {
83+
query.reset(static_cast<float*>(alloc64B(d_align * sizeof(float))));
8584
}
8685
}
8786

src/quantization/quantizer.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,19 @@ class QuantizerBase {
4646
}
4747
};
4848

49-
// 将 query_ 改为线程局部存储
50-
static thread_local std::unique_ptr<CodeType[], FreeDeleter> query_;
49+
// 使用函数返回局部静态变量,完全避免重复符号问题
50+
static std::unique_ptr<CodeType[], FreeDeleter>& get_query() {
51+
thread_local std::unique_ptr<CodeType[], FreeDeleter> query_ = nullptr;
52+
return query_;
53+
}
54+
55+
// query_ 使用线程局部存储,使用 inline 关键字解决重复符号问题
56+
// inline static thread_local std::unique_ptr<CodeType[], FreeDeleter> query_;
5157

5258
// 添加线程局部初始化方法
5359
virtual void ensure_thread_query_initialized() = 0;
5460
};
5561

56-
// 线程局部变量类外定义
57-
template <typename InputType, typename CodeType>
58-
thread_local std::unique_ptr<
59-
CodeType[], typename QuantizerBase<InputType, CodeType>::FreeDeleter>
60-
QuantizerBase<InputType, CodeType>::query_ = nullptr;
61-
6262
// 距离计算器基类
6363
template <typename T>
6464
class DistanceComputerBase {

src/quantization/sq4_quant.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,21 +103,22 @@ float SQ4Quantizer::compute_distance(const uint8_t* a, const uint8_t* b) const {
103103

104104
void SQ4Quantizer::encode_query(const float* query) {
105105
ensure_thread_query_initialized();
106-
encode(query, query_.get());
106+
encode(query, get_query().get());
107107
}
108108

109109
float SQ4Quantizer::compute_query_distance(size_t index) const {
110110
const uint8_t* data_code = reinterpret_cast<const uint8_t*>(get_data(index));
111-
return distance_computer_->compute(query_.get(), data_code);
111+
return distance_computer_->compute(get_query().get(), data_code);
112112
}
113113

114114
float SQ4Quantizer::compute_query_distance(const uint8_t* code) const {
115-
return distance_computer_->compute(query_.get(), code);
115+
return distance_computer_->compute(get_query().get(), code);
116116
}
117117

118118
void SQ4Quantizer::ensure_thread_query_initialized() {
119-
if (query_ == nullptr) {
120-
query_.reset(static_cast<uint8_t*>(alloc64B(d_align * sizeof(uint8_t))));
119+
auto& query = get_query();
120+
if (query == nullptr) {
121+
query.reset(static_cast<uint8_t*>(alloc64B(d_align * sizeof(uint8_t))));
121122
}
122123
}
123124

src/quantization/sq4_quant.h

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
#include <algorithm>
44
#include <vector>
55

6-
#include "fp32_quant.h"
76
#include "quantizer.h"
87
#include "simd/distance_functions.h"
98

109
namespace deepsearch {
1110
namespace quantization {
1211

12+
class FP32Quantizer;
13+
1314
class SQ4Quantizer : public QuantizerBase<float, uint8_t> {
1415
public:
1516
using data_type = uint8_t;
@@ -49,36 +50,37 @@ class SQ4Quantizer : public QuantizerBase<float, uint8_t> {
4950
// 重排序接口 - 使用FP32量化器进行精排
5051
template <typename Pool>
5152
void reorder(const Pool& pool, const float* query, int* dst, int k) const {
52-
if (reorder_quantizer_) {
53-
// 使用FP32量化器进行精确重排序
54-
std::vector<std::pair<int, float>> candidates;
55-
candidates.reserve(std::min(k, pool.size()));
56-
57-
for (int i = 0; i < std::min(k, pool.size()); ++i) {
58-
int id = pool.id(i);
59-
if (id >= 0) {
60-
// 从FP32量化器获取原始浮点数据
61-
const float* fp32_data =
62-
reinterpret_cast<const float*>(reorder_quantizer_->get_data(id));
63-
// 使用FP32量化器的距离计算方法
64-
float exact_dist =
65-
reorder_quantizer_->compute_distance(query, fp32_data);
66-
candidates.emplace_back(id, exact_dist);
67-
}
68-
}
69-
70-
std::sort(
71-
candidates.begin(), candidates.end(),
72-
[](const auto& a, const auto& b) { return a.second < b.second; });
73-
74-
int result_size = std::min(k, static_cast<int>(candidates.size()));
75-
for (int i = 0; i < result_size; ++i) {
76-
dst[i] = candidates[i].first;
77-
}
78-
for (int i = result_size; i < k; ++i) {
79-
dst[i] = -1;
80-
}
81-
}
53+
// if (reorder_quantizer_) {
54+
// // 使用FP32量化器进行精确重排序
55+
// std::vector<std::pair<int, float>> candidates;
56+
// candidates.reserve(std::min(k, pool.size()));
57+
//
58+
// for (int i = 0; i < std::min(k, pool.size()); ++i) {
59+
// int id = pool.id(i);
60+
// if (id >= 0) {
61+
// // 从FP32量化器获取原始浮点数据
62+
// const float* fp32_data =
63+
// reinterpret_cast<const
64+
// float*>(reorder_quantizer_->get_data(id));
65+
// // 使用FP32量化器的距离计算方法
66+
// float exact_dist =
67+
// reorder_quantizer_->compute_distance(query, fp32_data);
68+
// candidates.emplace_back(id, exact_dist);
69+
// }
70+
// }
71+
//
72+
// std::sort(
73+
// candidates.begin(), candidates.end(),
74+
// [](const auto& a, const auto& b) { return a.second < b.second; });
75+
//
76+
// int result_size = std::min(k, static_cast<int>(candidates.size()));
77+
// for (int i = 0; i < result_size; ++i) {
78+
// dst[i] = candidates[i].first;
79+
// }
80+
// for (int i = result_size; i < k; ++i) {
81+
// dst[i] = -1;
82+
// }
83+
// }
8284
}
8385

8486
private:

0 commit comments

Comments
 (0)