77#include < omp.h>
88#endif
99
10- #include " hnsw/builder.h"
11- #include " hnsw/hnsw.h"
12- #include " searcher.h"
10+ #include " core/interfaces.h"
11+ #include " graph/builder_factory.h"
12+ #include " graph/graph.h"
13+ #include " quantization/fp32_quant.h"
14+ #include " searcher/searcher.h"
1315
1416namespace py = pybind11;
1517
@@ -62,10 +64,10 @@ void parallel_for(size_t n, int num_threads, Func f) {
6264// 3. Graph wrapper
6365// -----------------------------------------------------------------------------
6466struct Graph {
65- deepsearch::Graph< int > graph;
67+ deepsearch::graph:: Graph graph;
6668
6769 Graph () = default ;
68- explicit Graph (const deepsearch::Graph< int > & graph) : graph(graph) {}
70+ explicit Graph (const deepsearch::graph:: Graph& graph) : graph(graph) {}
6971 explicit Graph (const std::string& filename) { graph.load (filename); }
7072
7173 void save (const std::string& filename) { graph.save (filename); }
@@ -76,45 +78,94 @@ struct Graph {
7678// 4. Index wrapper
7779// -----------------------------------------------------------------------------
7880struct Index {
79- std::unique_ptr<deepsearch::Builder> idx;
81+ std::unique_ptr<deepsearch::graph::GraphBuilder<float >> builder;
82+ deepsearch::core::DistanceType distance_type_;
83+ size_t dim_;
8084
8185 Index (const std::string& type, int dim, const std::string& metric, int R = 32 ,
8286 int L = 200 ) {
8387 if (dim <= 0 ) throw py::value_error (" `dim` must be positive" );
8488 if (R <= 0 ) throw py::value_error (" `R` must be positive" );
8589 if (L < 0 ) throw py::value_error (" `L` must be non-negative" );
8690
91+ dim_ = dim;
92+
93+ // 解析距离类型
94+ if (metric == " L2" ) {
95+ distance_type_ = deepsearch::core::DistanceType::L2;
96+ } else if (metric == " IP" ) {
97+ distance_type_ = deepsearch::core::DistanceType::IP;
98+ } else if (metric == " COSINE" ) {
99+ distance_type_ = deepsearch::core::DistanceType::COSINE;
100+ } else {
101+ throw py::value_error (" Unknown metric: " + metric);
102+ }
103+
87104 if (type == " HNSW" ) {
88- idx = std::make_unique<deepsearch::Hnsw>(dim, metric, R, L);
105+ // 使用BuilderFactory创建HNSW构建器
106+ deepsearch::graph::BuilderConfig config;
107+ config.M = R;
108+ config.ef_construction = L;
109+
110+ builder = deepsearch::graph::BuilderFactory<float >::create (
111+ deepsearch::graph::BuilderType::HNSW, distance_type_, dim_, config);
89112 } else {
90113 throw py::value_error (" Unknown index type: " + type);
91114 }
92115 }
93116
94117 Graph build (py::object data) {
95118 auto buf = to_buffer<float >(data);
96- if (buf.cols != idx-> Dim () )
119+ if (buf.cols != dim_ )
97120 throw py::value_error (" Dimension mismatch: expected " +
98- std::to_string (idx-> Dim () ) + " , got " +
121+ std::to_string (dim_ ) + " , got " +
99122 std::to_string (buf.cols ));
100- idx->Build (buf.ptr , buf.rows );
101- return Graph (idx->GetGraph ());
123+
124+ auto graph = builder->build (buf.ptr , buf.rows , buf.cols );
125+ return Graph (graph);
102126 }
103127};
104128
105129// -----------------------------------------------------------------------------
106130// 5. Searcher wrapper
107131// -----------------------------------------------------------------------------
108132struct Searcher {
109- std::unique_ptr<deepsearch::SearcherBase> sr ;
133+ std::unique_ptr<deepsearch::searcher:: SearcherBase> searcher ;
110134 ssize_t dim_;
111135
112136 Searcher (const Graph& graph, py::object data, const std::string& metric,
113- int level ) {
137+ const std::string& quant ) {
114138 auto buf = to_buffer<float >(data);
115139 dim_ = buf.cols ;
116- sr = deepsearch::create_searcher (graph.graph , metric, level);
117- sr->SetData (buf.ptr , buf.rows , buf.cols );
140+
141+ // 解析距离类型
142+ deepsearch::core::DistanceType distance_type;
143+ if (metric == " L2" ) {
144+ distance_type = deepsearch::core::DistanceType::L2;
145+ } else if (metric == " IP" ) {
146+ distance_type = deepsearch::core::DistanceType::IP;
147+ } else if (metric == " COSINE" ) {
148+ distance_type = deepsearch::core::DistanceType::COSINE;
149+ } else {
150+ throw py::value_error (" Unknown metric: " + metric);
151+ }
152+
153+ // 使用SearcherFactory创建FP32搜索器
154+ if (quant == " fp32" ) {
155+ searcher = deepsearch::searcher::SearcherFactory::createFP32 (
156+ graph.graph , distance_type, dim_);
157+ } else if (quant == " sq8" ) {
158+ searcher = deepsearch::searcher::SearcherFactory::createSQ8 (
159+ graph.graph , distance_type, dim_);
160+ } else if (quant == " sq4" ) {
161+ searcher = deepsearch::searcher::SearcherFactory::createSQ4 (
162+ graph.graph , distance_type, dim_);
163+ } else {
164+ throw py::value_error (" Unknown quant: " + quant);
165+ }
166+
167+ // 设置数据
168+ searcher->SetData (buf.ptr , buf.rows , buf.cols );
118169 }
119170
120171 py::array_t <int > search (py::object query, int k) {
@@ -124,7 +175,7 @@ struct Searcher {
124175 " )" );
125176
126177 int * ids = new int [k];
127- sr ->Search (buf.ptr , k, ids);
178+ searcher ->Search (buf.ptr , k, ids);
128179
129180 py::capsule free_when_done (ids,
130181 [](void * f) { delete[] static_cast <int *>(f); });
@@ -145,13 +196,12 @@ struct Searcher {
145196 int * ids = new int [nq * k];
146197
147198 parallel_for (nq, num_threads, [&](size_t i) {
148- sr ->Search (buf.ptr + i * dim_, k, ids + i * k);
199+ searcher ->Search (buf.ptr + i * dim_, k, ids + i * k);
149200 });
150201
151202 py::capsule free_when_done (ids,
152203 [](void * f) { delete[] static_cast <int *>(f); });
153204
154- // 返回二维数组,Python 侧析构时自动调用 capsule
155205 return py::array_t <int >({(ssize_t )nq, (ssize_t )k}, // shape
156206 {(ssize_t )(k * sizeof (int )), // row stride
157207 (ssize_t )(sizeof (int ))}, // col stride
@@ -162,12 +212,13 @@ struct Searcher {
162212
163213 void set_ef (int ef) {
164214 if (ef <= 0 ) throw py::value_error (" `ef` must be positive" );
165- sr ->SetEf (ef);
215+ searcher ->SetEf (ef);
166216 }
167217
168218 void optimize (int num_threads = 0 ) {
169219 // Use parallel_for with a single iteration to adjust threads
170- parallel_for (1 , num_threads, [&](size_t ) { sr->Optimize (num_threads); });
220+ parallel_for (1 , num_threads,
221+ [&](size_t ) { searcher->Optimize (num_threads); });
171222 }
172223};
173224
@@ -202,9 +253,10 @@ PYBIND11_MODULE(deepsearch, m) {
202253 " Build the index from a float array" );
203254
204255 py::class_<Searcher>(m, " Searcher" )
205- .def (py::init<const Graph&, py::object, const std::string&, int >(),
256+ .def (py::init<const Graph&, py::object, const std::string&,
257+ const std::string&>(),
206258 py::arg (" graph" ), py::arg (" data" ), py::arg (" metric" ),
207- py::arg (" level " ))
259+ py::arg (" quant " ))
208260 .def (" search" , &Searcher::search, py::arg (" query" ), py::arg (" k" ),
209261 " Search a single vector" )
210262 .def (" batch_search" , &Searcher::batch_search, py::arg (" query" ),
0 commit comments