diff --git a/.gitignore b/.gitignore index 780116a..3d0d8bb 100644 --- a/.gitignore +++ b/.gitignore @@ -39,6 +39,9 @@ __pycache__/ venv .venv /build/ +/.cache/ +/.vscode/ +/.claude/ /cmake-build-debug/ /cmake-build-release/ /cmake-build*/ @@ -76,6 +79,7 @@ pdxearch.egg-info /benchmarks/datasets/queries /benchmarks/datasets/selection_vectors /benchmarks/datasets/ground_truth_filtered +/benchmarks/results/DEFAULT/*.csv /benchmarks/gt_filtered @@ -101,6 +105,8 @@ cmake_install.cmake /benchmarks/BenchmarkPDXIVF /benchmarks/BenchmarkFiltered /benchmarks/BenchmarkSpecialFilters +/benchmarks/BenchmarkInsertion +/benchmarks/BenchmarkWorkload # Test binaries (but keep the committed test data) *.bin diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7a2922d..d086713 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,11 +5,9 @@ We are actively developing PDX and accepting contributions! Any kind of PR is we These are our current priorities: **Features**: -- Inserts and Updates (wip). - Out-of-core execution (disk-based setting). - Implement multi-threading capabilities. - Add PDX to the [VIBE benchmark](https://vector-index-bench.github.io/). -- Create a documentation. **Improvements**: - Regression tests on CI. diff --git a/README.md b/README.md index cbfd6ec..f874b23 100644 --- a/README.md +++ b/README.md @@ -21,8 +21,10 @@ - ⚡ [**Sub-millisecond similarity search**](https://www.lkuffo.com/sub-milisecond-similarity-search-with-pdx/), up to [**10x faster**](./BENCHMARKING.md#two-level-ivf-ivf2-) than FAISS IVF. - ⚡ Up to [**30x faster**](./BENCHMARKING.md#exhaustive-search--ivf) exhaustive search. - 🔍 Efficient [**filtered search**](https://github.com/cwida/PDX/issues/7). +- ⚙️ Fast and reliable [**index maintenance**](https://github.com/cwida/PDX/pull/13). - Query latency competitive with HNSW, with the ease of use of IVF. + ## Our secret sauce [PDX](https://ir.cwi.nl/pub/35044/35044.pdf) is a data layout that **transposes** vectors in a column-major order. 
This layout unleashes the true potential of dimension pruning. @@ -42,14 +44,20 @@ query = ... # Numpy 1D array d = 1024 knn = 20 +# Build index = IndexPDXIVFTreeSQ8(num_dimensions=d) index.build(data) +# Search ids, dists = index.search(query, knn) +# Maintenance +index.append(row_id_to_insert, new_embedding) +index.delete(row_id_to_delete) + ``` -`IndexPDXIVFTreeSQ8` is our fastest index that will give you the best performance. It is a two-level IVF index with 8-bit quantization. +`IndexPDXIVFTreeSQ8` is our fastest index that will give you the best performance alongside lightweight maintenance. It is a two-level IVF index with 8-bit quantization. Check our [examples](./examples/) for fully working examples in Python and our [benchmarks](./benchmarks) for fully working examples in C++. We support Flat (`float32`) and Quantized (`8-bit`) indexes, as well as the most common distance metrics. diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index fa95195..4f47b15 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -12,12 +12,16 @@ add_executable(BenchmarkEndToEnd pdx_end_to_end.cpp) add_executable(BenchmarkSerialization pdx_serialization.cpp) add_executable(BenchmarkFiltered pdx_filtered.cpp) add_executable(BenchmarkSpecialFilters pdx_special_filtered.cpp) +add_executable(BenchmarkInsertion pdx_insertion.cpp) +add_executable(BenchmarkWorkload pdx_workload.cpp) target_link_libraries(BenchmarkPDXIVF ${BENCH_COMMON_LIBS}) target_link_libraries(BenchmarkEndToEnd ${BENCH_COMMON_LIBS}) target_link_libraries(BenchmarkSerialization ${BENCH_COMMON_LIBS}) target_link_libraries(BenchmarkFiltered ${BENCH_COMMON_LIBS}) target_link_libraries(BenchmarkSpecialFilters ${BENCH_COMMON_LIBS}) +target_link_libraries(BenchmarkInsertion ${BENCH_COMMON_LIBS}) +target_link_libraries(BenchmarkWorkload ${BENCH_COMMON_LIBS}) add_custom_target(benchmarks DEPENDS @@ -26,4 +30,6 @@ add_custom_target(benchmarks BenchmarkSerialization BenchmarkFiltered 
BenchmarkSpecialFilters + BenchmarkInsertion + BenchmarkWorkload ) diff --git a/benchmarks/benchmark_utils.hpp b/benchmarks/benchmark_utils.hpp index 345a8c9..de2c8f9 100644 --- a/benchmarks/benchmark_utils.hpp +++ b/benchmarks/benchmark_utils.hpp @@ -43,7 +43,8 @@ class TicToc { }; // Raw binary data paths (SuperKMeans convention: data_.bin / data__test.bin) -inline std::string RAW_DATA_DIR = std::string{CMAKE_SOURCE_DIR} + "/../SuperKMeans/benchmarks/data"; +inline std::string RAW_DATA_DIR = + std::string{CMAKE_SOURCE_DIR} + "/../../SuperKMeans/benchmarks/data"; inline std::string GROUND_TRUTH_JSON_DIR = std::string{CMAKE_SOURCE_DIR} + "/../SuperKMeans/benchmarks/ground_truth"; @@ -88,6 +89,13 @@ struct PhasesRuntime { size_t end_to_end{0}; }; +enum class StepType { BUILD, INSERT, DELETE }; + +struct WorkloadStep { + StepType type; + float proportion; // fraction of total dataset size N +}; + class BenchmarkUtils { public: inline static std::string PDX_DATA = diff --git a/benchmarks/pdx_end_to_end.cpp b/benchmarks/pdx_end_to_end.cpp index 927213f..ba3b793 100644 --- a/benchmarks/pdx_end_to_end.cpp +++ b/benchmarks/pdx_end_to_end.cpp @@ -7,11 +7,11 @@ #include #include #include -#include #include #include "benchmark_utils.hpp" #include "pdx/index.hpp" +#include "pdx/profiler.hpp" #include "pdx/utils.hpp" template @@ -106,6 +106,7 @@ void RunBenchmark( runtimes[j + l * NUM_MEASURE_RUNS] = {clock.accum_time}; } } + PDX::Profiler::Get().PrintHierarchical(); BenchmarkMetadata results_metadata = { dataset, diff --git a/benchmarks/pdx_insertion.cpp b/benchmarks/pdx_insertion.cpp new file mode 100644 index 0000000..ea7fce5 --- /dev/null +++ b/benchmarks/pdx_insertion.cpp @@ -0,0 +1,225 @@ +#ifndef BENCHMARK_TIME +#define BENCHMARK_TIME = true +#endif + +#include +#include +#include +#include +#include + +#include "benchmark_utils.hpp" +#include "pdx/index.hpp" +#include "pdx/profiler.hpp" +#include "pdx/utils.hpp" + +template +void RunBenchmark( + const RawDatasetInfo& 
info, + const std::string& dataset, + const std::string& algorithm, + const float* data, + const float* queries, + const std::vector& nprobes_to_use, + const float proportion_to_build +) { + const size_t d = info.num_dimensions; + const size_t n = info.num_embeddings; + const size_t n_queries = info.num_queries; + uint8_t KNN = BenchmarkUtils::KNN; + size_t NUM_MEASURE_RUNS = BenchmarkUtils::NUM_MEASURE_RUNS; + std::string RESULTS_PATH = BENCHMARK_UTILS.RESULTS_DIR_PATH + "INSERTION_PDX.csv"; + + const size_t n_build = static_cast(n * proportion_to_build); + const size_t n_insert = n - n_build; + + PDX::PDXIndexConfig index_config{ + .num_dimensions = static_cast(d), + .distance_metric = info.distance_metric, + .seed = 42, + .normalize = true, + .sampling_fraction = 1.0f + }; + + // Build index with 75% of the data + TicToc clock; + std::cout << "Building index with " << n_build << " / " << n << " embeddings...\n"; + clock.Reset(); + clock.Tic(); + IndexT pdx_index(index_config); + pdx_index.BuildIndex(data, n_build); + clock.Toc(); + std::cout << "Build time: " << clock.GetMilliseconds() << " ms\n"; + std::cout << "Clusters: " << pdx_index.GetNumClusters() << "\n"; + std::cout << "Index in-memory size: " << std::fixed << std::setprecision(2) + << static_cast(pdx_index.GetInMemorySizeInBytes()) / (1024.0 * 1024.0) + << " MB\n"; + + // Insert remaining 25% + std::cout << "Inserting " << n_insert << " embeddings...\n"; + clock.Reset(); + clock.Tic(); + for (size_t i = 0; i < n_insert; ++i) { + size_t row_id = n_build + i; + std::cout << "Inserting embedding " << row_id << " / " << n - 1 << "\r" << std::flush; + pdx_index.Append(row_id, data + row_id * d); + } + clock.Toc(); + std::cout << "Insertion time: " << clock.GetMilliseconds() << " ms\n"; + std::cout << "Avg insertion time: " << clock.GetMilliseconds() / n_insert << " ms/embedding\n"; + std::cout << "Clusters after insertion: " << pdx_index.GetNumClusters() << "\n"; + std::cout << "Index in-memory size after 
insertion: " << std::fixed << std::setprecision(2) + << static_cast(pdx_index.GetInMemorySizeInBytes()) / (1024.0 * 1024.0) + << " MB\n"; + + PDX::Profiler::Get().PrintHierarchical(); + + // Load ground truth + std::string gt_path = BenchmarkUtils::GROUND_TRUTH_DATA + info.pdx_dataset_name + "_100_norm"; + auto gt_buffer = MmapFile(gt_path); + uint32_t* int_ground_truth = reinterpret_cast(gt_buffer.get()); + std::cout << "Ground truth loaded: " << gt_path << "\n"; + + for (size_t ivf_nprobe : nprobes_to_use) { + if (pdx_index.GetNumClusters() < ivf_nprobe) + continue; + + pdx_index.SetNProbe(ivf_nprobe); + + // Recall pass + float recalls = 0; + for (size_t l = 0; l < n_queries; ++l) { + auto result = pdx_index.Search(queries + l * d, KNN); + BenchmarkUtils::VerifyResult(recalls, result, KNN, int_ground_truth, l); + } + + // Timing pass + std::vector runtimes; + runtimes.resize(NUM_MEASURE_RUNS * n_queries); + TicToc search_clock; + for (size_t j = 0; j < NUM_MEASURE_RUNS; ++j) { + for (size_t l = 0; l < n_queries; ++l) { + search_clock.Reset(); + search_clock.Tic(); + pdx_index.Search(queries + l * d, KNN); + search_clock.Toc(); + runtimes[j + l * NUM_MEASURE_RUNS] = {search_clock.accum_time}; + } + } + + BenchmarkMetadata results_metadata = { + dataset, + algorithm, + NUM_MEASURE_RUNS, + n_queries, + ivf_nprobe, + KNN, + recalls, + }; + BenchmarkUtils::SaveResults(runtimes, RESULTS_PATH, results_metadata); + } +} + +int main(int argc, char* argv[]) { + if (argc < 2) { + std::cerr << "Usage: " << argv[0] << " [index_type] [nprobe] [build_fraction]\n"; + std::cerr << "Index types: pdx_tree_f32 (default), pdx_tree_u8\n"; + std::cerr << "Available datasets:"; + for (const auto& [name, _] : RAW_DATASET_PARAMS) { + std::cerr << " " << name; + } + std::cerr << "\n"; + return 1; + } + std::string dataset = argv[1]; + std::string index_type = (argc > 2) ? argv[2] : "pdx_tree_f32"; + size_t arg_ivf_nprobe = (argc > 3) ? 
std::atoi(argv[3]) : 0; + float proportion_to_build = (argc > 4) ? std::atof(argv[4]) : 0.75f; + + if (proportion_to_build <= 0.0f || proportion_to_build >= 1.0f) { + std::cerr << "Error: build_fraction must be in (0, 1). Got: " << proportion_to_build + << "\n"; + return 1; + } + + if (index_type != "pdx_tree_f32" && index_type != "pdx_tree_u8") { + std::cerr << "Error: Only pdx_tree_f32 and pdx_tree_u8 support maintenance (insertion).\n"; + std::cerr << "Got: " << index_type << "\n"; + return 1; + } + + auto it = RAW_DATASET_PARAMS.find(dataset); + if (it == RAW_DATASET_PARAMS.end()) { + std::cerr << "Unknown dataset: " << dataset << "\n"; + return 1; + } + const auto& info = it->second; + const size_t n = info.num_embeddings; + const size_t d = info.num_dimensions; + const size_t n_queries = info.num_queries; + + std::cout << "==> PDX Insertion Benchmark (Build " + << static_cast(proportion_to_build * 100) << "% + Insert " + << static_cast((1.0f - proportion_to_build) * 100) << "% + Search)\n"; + std::cout << "Dataset: " << dataset << " (n=" << n << ", d=" << d << ")\n"; + std::cout << "Index type: " << index_type << "\n"; + + // Read data + std::string data_path = RAW_DATA_DIR + "/data_" + dataset + ".bin"; + std::string query_path = RAW_DATA_DIR + "/data_" + dataset + "_test.bin"; + + std::vector data(n * d); + { + std::ifstream file(data_path, std::ios::binary); + if (!file) { + std::cerr << "Failed to open " << data_path << "\n"; + return 1; + } + file.read(reinterpret_cast(data.data()), n * d * sizeof(float)); + } + + std::vector queries(n_queries * d); + { + std::ifstream file(query_path, std::ios::binary); + if (!file) { + std::cerr << "Failed to open " << query_path << "\n"; + return 1; + } + file.read(reinterpret_cast(queries.data()), n_queries * d * sizeof(float)); + } + + std::vector nprobes_to_use; + if (arg_ivf_nprobe > 0) { + nprobes_to_use = {arg_ivf_nprobe}; + } else { + nprobes_to_use.assign( + std::begin(BenchmarkUtils::IVF_PROBES), 
std::end(BenchmarkUtils::IVF_PROBES) + ); + } + + std::string algorithm = "insertion_" + index_type; + + if (index_type == "pdx_tree_f32") { + RunBenchmark( + info, + dataset, + algorithm, + data.data(), + queries.data(), + nprobes_to_use, + proportion_to_build + ); + } else if (index_type == "pdx_tree_u8") { + RunBenchmark( + info, + dataset, + algorithm, + data.data(), + queries.data(), + nprobes_to_use, + proportion_to_build + ); + } + + return 0; +} diff --git a/benchmarks/pdx_workload.cpp b/benchmarks/pdx_workload.cpp new file mode 100644 index 0000000..eeab60e --- /dev/null +++ b/benchmarks/pdx_workload.cpp @@ -0,0 +1,330 @@ +#ifndef BENCHMARK_TIME +#define BENCHMARK_TIME = true +#endif + +#include +#include +#include +#include +#include + +#include "benchmark_utils.hpp" +#include "pdx/index.hpp" +#include "pdx/profiler.hpp" +#include "pdx/utils.hpp" + +// ---- Edit workload here ---- +static const std::vector WORKLOAD = { + // {StepType::BUILD, 0.50f}, + // {StepType::INSERT, 0.20f}, + // {StepType::DELETE, 0.10f}, + // {StepType::INSERT, 0.30f}, + // {StepType::DELETE, 0.20f}, + // {StepType::INSERT, 0.30f}, + + {StepType::BUILD, 0.50f}, + {StepType::INSERT, 0.50f} +}; + +template +void RunWorkload( + const RawDatasetInfo& info, + const std::string& dataset, + const std::string& algorithm, + const float* data, + const float* queries, + const std::vector& nprobes_to_use, + const std::vector& workload +) { + const size_t d = info.num_dimensions; + const size_t n = info.num_embeddings; + const size_t n_queries = info.num_queries; + uint8_t KNN = BenchmarkUtils::KNN; + size_t NUM_MEASURE_RUNS = BenchmarkUtils::NUM_MEASURE_RUNS; + std::string RESULTS_PATH = BENCHMARK_UTILS.RESULTS_DIR_PATH + "WORKLOAD_PDX.csv"; + + PDX::PDXIndexConfig index_config{ + .num_dimensions = static_cast(d), + .distance_metric = info.distance_metric, + .seed = 42, + .normalize = true, + .sampling_fraction = 1.0f + }; + + IndexT pdx_index(index_config); + TicToc clock; + + // State 
tracking: + // - next_row_id: monotonically increasing row_id counter (row_ids are never reused) + // - live_entries: (row_id, data_index) pairs currently in the index (stack for deletes) + // - available_data: data indices freed by deletes, available for re-insertion + // - next_data_cursor: next fresh data index (for inserts when available_data is empty) + // - data_to_row_id: maps data array index → current row_id (for ground truth remapping) + size_t next_row_id = 0; + size_t next_data_cursor = 0; + std::vector> live_entries; // (row_id, data_index) + std::vector available_data; + std::vector data_to_row_id(n); + live_entries.reserve(n); + + // Execute workload steps + for (size_t step_idx = 0; step_idx < workload.size(); ++step_idx) { + const auto& step = workload[step_idx]; + size_t count = static_cast(n * step.proportion); + + switch (step.type) { + case StepType::BUILD: { + std::cout << "\n=== Step " << step_idx << ": BUILD " << count << " embeddings ===\n"; + clock.Reset(); + clock.Tic(); + pdx_index.BuildIndex(data, count); + clock.Toc(); + for (size_t i = 0; i < count; ++i) { + live_entries.push_back({i, i}); + data_to_row_id[i] = static_cast(i); + } + next_row_id = count; + next_data_cursor = count; + std::cout << "Build time: " << clock.GetMilliseconds() << " ms\n"; + break; + } + case StepType::INSERT: { + if (available_data.size() + (n - next_data_cursor) < count) { + std::cerr << "Step " << step_idx << ": INSERT " << count << " but only " + << available_data.size() + (n - next_data_cursor) + << " data points available\n"; + return; + } + std::cout << "\n=== Step " << step_idx << ": INSERT " << count << " embeddings ===\n"; + clock.Reset(); + clock.Tic(); + for (size_t i = 0; i < count; ++i) { + // Pick a data point: reuse freed ones first, then fresh + size_t data_idx; + if (!available_data.empty()) { + data_idx = available_data.back(); + available_data.pop_back(); + } else { + data_idx = next_data_cursor++; + } + size_t row_id = next_row_id++; + 
std::cout << "Inserting row_id=" << row_id << " (data=" << data_idx << ")\r" + << std::flush; + pdx_index.Append(row_id, data + data_idx * d); + live_entries.push_back({row_id, data_idx}); + data_to_row_id[data_idx] = static_cast(row_id); + } + clock.Toc(); + std::cout << "\nInsertion time: " << clock.GetMilliseconds() << " ms\n"; + std::cout << "Avg insertion time: " << clock.GetMilliseconds() / count + << " ms/embedding\n"; + break; + } + case StepType::DELETE: { + if (count > live_entries.size()) { + std::cerr << "Step " << step_idx << ": DELETE " << count << " but only " + << live_entries.size() << " live entries\n"; + return; + } + std::cout << "\n=== Step " << step_idx << ": DELETE " << count << " embeddings ===\n"; + clock.Reset(); + clock.Tic(); + for (size_t i = 0; i < count; ++i) { + auto [row_id, data_idx] = live_entries.back(); + live_entries.pop_back(); + std::cout << "Deleting row_id=" << row_id << " (" << i + 1 << "/" << count << ")\r" + << std::flush; + pdx_index.Delete(row_id); + available_data.push_back(data_idx); + } + clock.Toc(); + std::cout << "\nDeletion time: " << clock.GetMilliseconds() << " ms\n"; + std::cout << "Avg deletion time: " << clock.GetMilliseconds() / count + << " ms/embedding\n"; + break; + } + } + + std::cout << "Clusters: " << pdx_index.GetNumClusters() << "\n"; + std::cout << "Index in-memory size: " << std::fixed << std::setprecision(2) + << static_cast(pdx_index.GetInMemorySizeInBytes()) / (1024.0 * 1024.0) + << " MB\n"; + std::cout << "Live embeddings: " << live_entries.size() << "\n"; + } + + PDX::Profiler::Get().PrintHierarchical(); + + // Load ground truth and remap data indices → current row_ids. + // Ground truth entries are data indices (0..N-1). After deletes + re-inserts, + // some data points have new row_ids. We remap so VerifyResult can compare. 
+ std::string gt_path = BenchmarkUtils::GROUND_TRUTH_DATA + info.pdx_dataset_name + "_100_norm"; + auto gt_buffer = MmapFile(gt_path); + uint32_t* original_gt = reinterpret_cast(gt_buffer.get()); + + const size_t gt_max_k = BenchmarkUtils::GROUND_TRUTH_MAX_K; + std::vector remapped_gt(n_queries * gt_max_k); + for (size_t q = 0; q < n_queries; ++q) { + for (size_t k = 0; k < gt_max_k; ++k) { + uint32_t data_idx = original_gt[k + q * gt_max_k]; + remapped_gt[k + q * gt_max_k] = data_to_row_id[data_idx]; + } + } + std::cout << "\nGround truth loaded and remapped: " << gt_path << "\n"; + + for (size_t ivf_nprobe : nprobes_to_use) { + if (pdx_index.GetNumClusters() < ivf_nprobe) + continue; + + pdx_index.SetNProbe(ivf_nprobe); + + // Recall pass + float recalls = 0; + for (size_t l = 0; l < n_queries; ++l) { + auto result = pdx_index.Search(queries + l * d, KNN); + BenchmarkUtils::VerifyResult(recalls, result, KNN, remapped_gt.data(), l); + } + + // Timing pass + std::vector runtimes; + runtimes.resize(NUM_MEASURE_RUNS * n_queries); + TicToc search_clock; + for (size_t j = 0; j < NUM_MEASURE_RUNS; ++j) { + for (size_t l = 0; l < n_queries; ++l) { + search_clock.Reset(); + search_clock.Tic(); + pdx_index.Search(queries + l * d, KNN); + search_clock.Toc(); + runtimes[j + l * NUM_MEASURE_RUNS] = {search_clock.accum_time}; + } + } + + BenchmarkMetadata results_metadata = { + dataset, + algorithm, + NUM_MEASURE_RUNS, + n_queries, + ivf_nprobe, + KNN, + recalls, + }; + BenchmarkUtils::SaveResults(runtimes, RESULTS_PATH, results_metadata); + } +} + +int main(int argc, char* argv[]) { + const auto& workload = WORKLOAD; + + // Validate workload: build + inserts - deletes must equal 1.0 + float net_proportion = 0.0f; + for (const auto& step : workload) { + if (step.type == StepType::DELETE) { + net_proportion -= step.proportion; + } else { + net_proportion += step.proportion; + } + } + if (std::abs(net_proportion - 1.0f) > 1e-5f) { + std::cerr << "Error: workload net proportion 
must equal 1.0 " + << "(build + inserts - deletes), got: " << net_proportion << "\n"; + return 1; + } + + if (argc < 2) { + std::cerr << "Usage: " << argv[0] << " [index_type] [nprobe]\n"; + std::cerr << "Index types: pdx_tree_f32 (default), pdx_tree_u8\n"; + std::cerr << "Available datasets:"; + for (const auto& [name, _] : RAW_DATASET_PARAMS) { + std::cerr << " " << name; + } + std::cerr << "\n"; + return 1; + } + std::string dataset = argv[1]; + std::string index_type = (argc > 2) ? argv[2] : "pdx_tree_f32"; + size_t arg_ivf_nprobe = (argc > 3) ? std::atoi(argv[3]) : 0; + + if (index_type != "pdx_tree_f32" && index_type != "pdx_tree_u8") { + std::cerr << "Error: Only pdx_tree_f32 and pdx_tree_u8 support maintenance.\n"; + std::cerr << "Got: " << index_type << "\n"; + return 1; + } + + auto it = RAW_DATASET_PARAMS.find(dataset); + if (it == RAW_DATASET_PARAMS.end()) { + std::cerr << "Unknown dataset: " << dataset << "\n"; + return 1; + } + const auto& info = it->second; + const size_t n = info.num_embeddings; + const size_t d = info.num_dimensions; + const size_t n_queries = info.num_queries; + + // Print workload summary + std::cout << "==> PDX Workload Benchmark\n"; + std::cout << "Dataset: " << dataset << " (n=" << n << ", d=" << d << ")\n"; + std::cout << "Index type: " << index_type << "\n"; + std::cout << "Workload: "; + for (size_t i = 0; i < workload.size(); ++i) { + if (i > 0) + std::cout << " -> "; + switch (workload[i].type) { + case StepType::BUILD: + std::cout << "build(" << workload[i].proportion << ")"; + break; + case StepType::INSERT: + std::cout << "insert(" << workload[i].proportion << ")"; + break; + case StepType::DELETE: + std::cout << "delete(" << workload[i].proportion << ")"; + break; + } + } + std::cout << "\n"; + + // Read data + std::string data_path = RAW_DATA_DIR + "/data_" + dataset + ".bin"; + std::string query_path = RAW_DATA_DIR + "/data_" + dataset + "_test.bin"; + + std::vector data(n * d); + { + std::ifstream file(data_path, 
std::ios::binary); + if (!file) { + std::cerr << "Failed to open " << data_path << "\n"; + return 1; + } + file.read(reinterpret_cast(data.data()), n * d * sizeof(float)); + } + + std::vector queries(n_queries * d); + { + std::ifstream file(query_path, std::ios::binary); + if (!file) { + std::cerr << "Failed to open " << query_path << "\n"; + return 1; + } + file.read(reinterpret_cast(queries.data()), n_queries * d * sizeof(float)); + } + + std::vector nprobes_to_use; + if (arg_ivf_nprobe > 0) { + nprobes_to_use = {arg_ivf_nprobe}; + } else { + nprobes_to_use.assign( + std::begin(BenchmarkUtils::IVF_PROBES), std::end(BenchmarkUtils::IVF_PROBES) + ); + } + + std::string algorithm = "workload_" + index_type; + + if (index_type == "pdx_tree_f32") { + RunWorkload( + info, dataset, algorithm, data.data(), queries.data(), nprobes_to_use, workload + ); + } else if (index_type == "pdx_tree_u8") { + RunWorkload( + info, dataset, algorithm, data.data(), queries.data(), nprobes_to_use, workload + ); + } + + return 0; +} diff --git a/examples/README.md index eee0556..dfd96fc 100644 --- a/examples/README.md +++ b/examples/README.md @@ -30,3 +30,5 @@ Our examples look for `.hdf5` files in `/benchmarks/datasets/downloaded`. These - **`pdx_filtered.py`**: Filtered (predicated) search using `IndexPDXIVFTreeSQ8`. Demonstrates how to pass a set of allowed row IDs to `filtered_search()`, restricting results to only those vectors. Includes a correctness check verifying that returned IDs are a subset of the allowed set. - **`pdx_persist.py`**: Save and load a PDX index to/from disk. Builds an index, saves it with `index.save()`, then reloads it with `load_index()` and queries the restored index. + +- **`pdx_maintenance.py`**: Builds an index with 50% of the data, then inserts the rest of the data and queries the index. Recall is maintained and maintenance is very lightweight. 
diff --git a/examples/pdx_maintenance.py b/examples/pdx_maintenance.py new file mode 100644 index 0000000..d7490f5 --- /dev/null +++ b/examples/pdx_maintenance.py @@ -0,0 +1,60 @@ +import os +import numpy as np +from examples_utils import TicToc, read_hdf5_data +from pdxearch import IndexPDXIVFTreeSQ8 + +np.random.seed(42) + +""" +PDXearch maintenance example: build with 50% of data, then append the remaining 50%. +Uses a two-level IVF index with 8-bit scalar quantization (U8). +Download the .hdf5 data here: https://drive.google.com/drive/folders/1f76UCrU52N2wToGMFg9ir1MY8ZocrN34?usp=sharing +""" +if __name__ == "__main__": + dataset_name = 'agnews-mxbai-1024-euclidean.hdf5' + num_dimensions = 1024 + nprobe = 25 + knn = 20 + print(f'Running example: PDXearch Maintenance (Build 50% + Append 50%)') + print(f'- D={num_dimensions}, k={knn}, nprobe={nprobe}, dataset={dataset_name}') + train, queries = read_hdf5_data(os.path.join('./benchmarks/datasets/downloaded', dataset_name)) + + n = len(train) + n_build = n // 2 + n_append = n - n_build + + # Build index with first 50% of data + index = IndexPDXIVFTreeSQ8(num_dimensions=num_dimensions, normalize=True) + print(f'\nBuilding index with {n_build}/{n} embeddings...') + clock = TicToc() + clock.tic() + index.build(train[:n_build]) + build_time = clock.toc() + print(f'Build time: {build_time:.1f} ms') + print(f'Clusters: {index.num_clusters}') + print(f'Index size: {index.in_memory_size_bytes / (1024 * 1024):.2f} MB') + + # Append remaining 50% one by one + print(f'\nAppending {n_append} embeddings...') + clock.tic() + for i in range(n_build, n): + index.append(i, train[i]) + append_time = clock.toc() + print(f'Append time: {append_time:.1f} ms ({append_time / n_append:.2f} ms/embedding)') + print(f'Clusters after append: {index.num_clusters}') + print(f'Index size after append: {index.in_memory_size_bytes / (1024 * 1024):.2f} MB') + + # Search + print(f'\nSearching {len(queries)} queries...') + times = [] + clock = 
TicToc() + for i in range(len(queries)): + clock.tic() + index.search(queries[i], knn, nprobe=nprobe) + times.append(clock.toc()) + print(f'Median search time: {np.median(np.array(times)):.3f} ms') + + # Show results of first query + ids, dists = index.search(queries[0], knn, nprobe=nprobe) + print(f'\nFirst query results (ids): {ids[:10]}') + print(f'First query results (dists): {dists[:10]}') diff --git a/extern/SuperKMeans b/extern/SuperKMeans index 4a8ce02..11b86f8 160000 --- a/extern/SuperKMeans +++ b/extern/SuperKMeans @@ -1 +1 @@ -Subproject commit 4a8ce028f3fef8f276dc187cd46f12e76c75f21e +Subproject commit 11b86f8936c597450487114291641c062110c2b2 diff --git a/include/pdx/cluster.hpp b/include/pdx/cluster.hpp new file mode 100644 index 0000000..548ffd4 --- /dev/null +++ b/include/pdx/cluster.hpp @@ -0,0 +1,429 @@ +#pragma once + +#include "pdx/common.hpp" +#include "pdx/profiler.hpp" +#include "pdx/utils.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace PDX { + +template +struct Cluster { + using data_t = pdx_data_t; + using tombstones_t = std::unordered_set; + + constexpr static float CAPACITY_THRESHOLD = 1.3f; // 30% more than the current capacity + constexpr static float MIN_CAPACITY_THRESHOLD = 0.5f; + constexpr static uint32_t MIN_MAX_CAPACITY = 256; + + Cluster(uint32_t num_embeddings, uint32_t num_dimensions) + : num_embeddings(num_embeddings), used_capacity(num_embeddings), + max_capacity( + std::max(static_cast(num_embeddings * CAPACITY_THRESHOLD), MIN_MAX_CAPACITY) + ), + min_capacity(static_cast(num_embeddings * MIN_CAPACITY_THRESHOLD)), + num_dimensions(num_dimensions), indices(new uint32_t[max_capacity]), + data(new data_t[static_cast(max_capacity) * num_dimensions]) {} + + Cluster(uint32_t num_embeddings, uint32_t max_capacity, uint32_t num_dimensions) + : num_embeddings(num_embeddings), used_capacity(num_embeddings), max_capacity(max_capacity), + min_capacity(static_cast(num_embeddings * 
MIN_CAPACITY_THRESHOLD)), + num_dimensions(num_dimensions), indices(new uint32_t[max_capacity]), + data(new data_t[static_cast(max_capacity) * num_dimensions]) {} + + Cluster(Cluster&& other) noexcept + : num_embeddings(other.num_embeddings), used_capacity(other.used_capacity), + max_capacity(other.max_capacity), min_capacity(other.min_capacity), + num_dimensions(other.num_dimensions), n_accessed(other.n_accessed), + n_inserted(other.n_inserted), n_deleted(other.n_deleted), id(other.id), + mesocluster_id(other.mesocluster_id), indices(other.indices), data(other.data), + tombstones(std::move(other.tombstones)) { + other.indices = nullptr; + other.data = nullptr; + } + + ~Cluster() { + delete[] data; + delete[] indices; + } + + Cluster(const Cluster&) = delete; + Cluster& operator=(const Cluster&) = delete; + + // Move-assignment: transfers data ownership, keeps destination's mutex and const + // num_dimensions. Caller must ensure no concurrent access to *this during assignment. + Cluster& operator=(Cluster&& other) noexcept { + if (this != &other) { + assert(num_dimensions == other.num_dimensions); + delete[] data; + delete[] indices; + + num_embeddings = other.num_embeddings; + used_capacity = other.used_capacity; + max_capacity = other.max_capacity; + min_capacity = other.min_capacity; + // num_dimensions: const and guaranteed same — skip + // cluster_mutex: keep our own — skip + n_accessed = other.n_accessed; + n_inserted = other.n_inserted; + n_deleted = other.n_deleted; + id = other.id; + mesocluster_id = other.mesocluster_id; + indices = other.indices; + data = other.data; + tombstones = std::move(other.tombstones); + + other.indices = nullptr; + other.data = nullptr; + } + return *this; + } + + uint32_t num_embeddings{ + }; // Number of valid embeddings in the cluster, i.e., excluding tombstones + uint32_t used_capacity{}; // Total capacity of the cluster, i.e., including tombstones but + // excluding empty slots that are not yet used + uint32_t 
max_capacity{}; + uint32_t min_capacity{}; + const uint32_t num_dimensions{}; + std::mutex cluster_mutex; + size_t n_accessed = 0; + size_t n_inserted = 0; + size_t n_deleted = 0; + uint32_t id{}; // Position in IVF::clusters vector + uint32_t mesocluster_id{}; // Which L0 meso-cluster contains this L1 cluster (L1 only) + + uint32_t* indices = nullptr; // !These are row_ids + data_t* data = nullptr; + tombstones_t tombstones; // ! Need to have indexes, not row_ids + + void AddTombstone(uint32_t index) { tombstones.insert(index); } + + bool HasTombstone(uint32_t index) const { return tombstones.count(index); } + + uint32_t PopTombstone() { + auto it = tombstones.begin(); + uint32_t val = *it; + tombstones.erase(it); + return val; + } + + void RemoveTombstone(uint32_t index) { tombstones.erase(index); } + + // Returns the index in cluster of the newly appended embedding + uint32_t AppendEmbedding(uint32_t row_id, const data_t* PDX_RESTRICT embedding) { + PDX_PROFILE_SCOPE("LeafAppend"); + std::lock_guard lock(cluster_mutex); + uint32_t next_free_idx = used_capacity; + bool replaced_tombstone = false; + if (!tombstones.empty()) { + next_free_idx = PopTombstone(); + replaced_tombstone = true; + } + if (next_free_idx >= max_capacity) { + throw std::runtime_error( + "AppendEmbedding: cluster buffer overflow (used_capacity=" + + std::to_string(used_capacity) + ", max_capacity=" + std::to_string(max_capacity) + + ")" + ); + } + InsertEmbedding(next_free_idx, row_id, embedding); + num_embeddings++; + if (!replaced_tombstone) { + used_capacity++; + } + assert(num_embeddings <= used_capacity); + + n_inserted++; + return next_free_idx; + } + + void DeleteEmbedding(uint32_t index_in_cluster) { + PDX_PROFILE_SCOPE("LeafDelete"); + std::lock_guard lock(cluster_mutex); + AddTombstone(index_in_cluster); + num_embeddings--; + n_deleted++; + } + + size_t GetInMemorySizeInBytes() const { + return sizeof(*this) + num_embeddings * sizeof(*indices) + + num_embeddings * 
static_cast(num_dimensions) * sizeof(*data); + } + + // Gather all embeddings from the PDX layout into a contiguous row-major buffer. + // Assumes no tombstones (call CompactCluster first). + // Uses blocked transpose for the vertical block and group-first iteration for + // the horizontal block to maximise cache locality on the source side. + std::unique_ptr GetHorizontalEmbeddingsFromPDXBuffer() const { + std::unique_ptr out( + new data_t[static_cast(num_embeddings) * num_dimensions] + ); + for (uint32_t i = 0; i < num_embeddings; i++) { + ReadEmbeddingFromPDXBuffer(i, out.get() + static_cast(i) * num_dimensions); + } + return out; + } + + // Gather a single embedding from the PDX layout into a row-major buffer. + std::unique_ptr GetHorizontalEmbeddingFromPDXBuffer(uint32_t idx_in_cluster) const { + PDX_PROFILE_SCOPE("DePDXify-One"); + std::unique_ptr out(new data_t[num_dimensions]); + ReadEmbeddingFromPDXBuffer(idx_in_cluster, out.get()); + return out; + } + + // Writes the valid PDX data row-by-row, stripping stride gaps. + // Assumes no tombstones (call CompactCluster first). 
+ void SavePDXData(std::ostream& out) const { + const auto split = GetPDXDimensionSplit(num_dimensions); + const uint32_t vertical_d = split.vertical_dimensions; + const uint32_t horizontal_d = split.horizontal_dimensions; + const size_t stride = max_capacity; + + if constexpr (Q == Quantization::F32) { + for (uint32_t d = 0; d < vertical_d; d++) { + out.write( + reinterpret_cast(data + d * stride), + sizeof(data_t) * num_embeddings + ); + } + } else { + uint32_t d = 0; + for (; d + U8_INTERLEAVE_SIZE <= vertical_d; d += U8_INTERLEAVE_SIZE) { + out.write( + reinterpret_cast(data + d * stride), + num_embeddings * U8_INTERLEAVE_SIZE + ); + } + if (d < vertical_d) { + uint32_t remaining = vertical_d - d; + out.write( + reinterpret_cast(data + d * stride), num_embeddings * remaining + ); + } + } + + const data_t* h_base = data + stride * vertical_d; + for (uint32_t j = 0; j < horizontal_d; j += H_DIM_SIZE) { + out.write( + reinterpret_cast(h_base), sizeof(data_t) * num_embeddings * H_DIM_SIZE + ); + h_base += stride * H_DIM_SIZE; + } + } + + // Reads compact PDX data from ptr and places it into the strided buffer. + // Advances ptr past all read data. 
+ void LoadPDXData(char*& ptr) { + const auto split = GetPDXDimensionSplit(num_dimensions); + const uint32_t vertical_d = split.vertical_dimensions; + const uint32_t horizontal_d = split.horizontal_dimensions; + const size_t stride = max_capacity; + + if constexpr (Q == Quantization::F32) { + for (uint32_t d = 0; d < vertical_d; d++) { + memcpy(data + d * stride, ptr, sizeof(data_t) * num_embeddings); + ptr += sizeof(data_t) * num_embeddings; + } + } else { + uint32_t d = 0; + for (; d + U8_INTERLEAVE_SIZE <= vertical_d; d += U8_INTERLEAVE_SIZE) { + memcpy(data + d * stride, ptr, num_embeddings * U8_INTERLEAVE_SIZE); + ptr += num_embeddings * U8_INTERLEAVE_SIZE; + } + if (d < vertical_d) { + uint32_t remaining = vertical_d - d; + memcpy(data + d * stride, ptr, num_embeddings * remaining); + ptr += num_embeddings * remaining; + } + } + + data_t* h_base = data + stride * vertical_d; + for (uint32_t j = 0; j < horizontal_d; j += H_DIM_SIZE) { + memcpy(h_base, ptr, sizeof(data_t) * num_embeddings * H_DIM_SIZE); + ptr += sizeof(data_t) * num_embeddings * H_DIM_SIZE; + h_base += stride * H_DIM_SIZE; + } + } + + // Caller must hold cluster_mutex + // Returns: vector of (row_id, new_index_in_cluster) for each moved embedding + // TODO(@lkuffo, med): I dont like this while loops too much. 
Its confusing (but it works) + std::vector> CompactCluster() { + PDX_PROFILE_SCOPE("CompactCluster"); + std::vector> moves; + if (tombstones.empty()) { + return moves; + } + moves.reserve(tombstones.size()); + + // shrink past any tombstoned tail positions (no data movement needed) + while (used_capacity > 0 && HasTombstone(used_capacity - 1)) { + RemoveTombstone(used_capacity - 1); + indices[used_capacity - 1] = 0; + used_capacity--; + } + + // fill remaining interior tombstones by moving from the tail + while (!tombstones.empty()) { + uint32_t tombstone_idx = PopTombstone(); + uint32_t last_idx = used_capacity - 1; + CopyEmbeddingInPDXLayout(last_idx, tombstone_idx); + indices[tombstone_idx] = indices[last_idx]; + moves.emplace_back(indices[tombstone_idx], tombstone_idx); + indices[last_idx] = 0; + used_capacity--; + // The new tail might also be a tombstone, drain it + while (used_capacity > 0 && HasTombstone(used_capacity - 1)) { + RemoveTombstone(used_capacity - 1); + indices[used_capacity - 1] = 0; + used_capacity--; + } + } + + assert(num_embeddings == used_capacity); + return moves; + } + + private: + // Gather-reads one embedding from the transposed PDX buffer into a horizontal (row-major) + // output. Reverse of InsertEmbedding. 
+ void ReadEmbeddingFromPDXBuffer(uint32_t idx_in_cluster, data_t* out) const { + const auto split = GetPDXDimensionSplit(num_dimensions); + const uint32_t vertical_d = split.vertical_dimensions; + const uint32_t horizontal_d = split.horizontal_dimensions; + const size_t stride = max_capacity; + + if constexpr (Q == Quantization::F32) { + for (uint32_t d = 0; d < vertical_d; d++) { + out[d] = data[d * stride + idx_in_cluster]; + } + } else { + uint32_t d = 0; + for (; d + U8_INTERLEAVE_SIZE <= vertical_d; d += U8_INTERLEAVE_SIZE) { + memcpy( + out + d, + data + d * stride + static_cast(idx_in_cluster) * U8_INTERLEAVE_SIZE, + U8_INTERLEAVE_SIZE + ); + } + if (d < vertical_d) { + uint32_t remaining = vertical_d - d; + memcpy( + out + d, + data + d * stride + static_cast(idx_in_cluster) * remaining, + remaining + ); + } + } + + const data_t* h_base = data + stride * vertical_d; + for (uint32_t j = 0; j < horizontal_d; j += H_DIM_SIZE) { + memcpy( + out + vertical_d + j, + h_base + static_cast(idx_in_cluster) * H_DIM_SIZE, + H_DIM_SIZE * sizeof(data_t) + ); + h_base += stride * H_DIM_SIZE; + } + } + + // Scatter-writes a horizontal (row-major) embedding into the transposed PDX buffer layout. + // This function assumes thread safety (caller must hold cluster_mutex). 
+ void InsertEmbedding(uint32_t idx_in_cluster, uint32_t row_id, const data_t* embedding) { + const auto split = GetPDXDimensionSplit(num_dimensions); + const uint32_t vertical_d = split.vertical_dimensions; + const uint32_t horizontal_d = split.horizontal_dimensions; + const size_t stride = max_capacity; + + if constexpr (Q == Quantization::F32) { + // Vertical: column-major, one float per dimension row + for (uint32_t d = 0; d < vertical_d; d++) { + data[d * stride + idx_in_cluster] = embedding[d]; + } + } else { + // U8 Vertical: interleaved in groups of U8_INTERLEAVE_SIZE + uint32_t d = 0; + for (; d + U8_INTERLEAVE_SIZE <= vertical_d; d += U8_INTERLEAVE_SIZE) { + memcpy( + data + d * stride + static_cast(idx_in_cluster) * U8_INTERLEAVE_SIZE, + embedding + d, + U8_INTERLEAVE_SIZE + ); + } + if (d < vertical_d) { + uint32_t remaining = vertical_d - d; + memcpy( + data + d * stride + static_cast(idx_in_cluster) * remaining, + embedding + d, + remaining + ); + } + } + + // Horizontal: groups of H_DIM_SIZE, row-major within each group + data_t* h_base = data + stride * vertical_d; + for (uint32_t j = 0; j < horizontal_d; j += H_DIM_SIZE) { + memcpy( + h_base + static_cast(idx_in_cluster) * H_DIM_SIZE, + embedding + vertical_d + j, + H_DIM_SIZE * sizeof(data_t) + ); + h_base += stride * H_DIM_SIZE; + } + + indices[idx_in_cluster] = row_id; + } + + // Copies an embedding within the PDX buffer from one position to another. + // This function assumes thread safety (caller must hold cluster_mutex). 
+ void CopyEmbeddingInPDXLayout(uint32_t src_idx, uint32_t dst_idx) { + const auto split = GetPDXDimensionSplit(num_dimensions); + const uint32_t vertical_d = split.vertical_dimensions; + const uint32_t horizontal_d = split.horizontal_dimensions; + const size_t stride = max_capacity; + + if constexpr (Q == Quantization::F32) { + for (uint32_t d = 0; d < vertical_d; d++) { + data[d * stride + dst_idx] = data[d * stride + src_idx]; + } + } else { + uint32_t d = 0; + for (; d + U8_INTERLEAVE_SIZE <= vertical_d; d += U8_INTERLEAVE_SIZE) { + memcpy( + data + d * stride + static_cast(dst_idx) * U8_INTERLEAVE_SIZE, + data + d * stride + static_cast(src_idx) * U8_INTERLEAVE_SIZE, + U8_INTERLEAVE_SIZE + ); + } + if (d < vertical_d) { + uint32_t remaining = vertical_d - d; + memcpy( + data + d * stride + static_cast(dst_idx) * remaining, + data + d * stride + static_cast(src_idx) * remaining, + remaining + ); + } + } + + data_t* h_base = data + stride * vertical_d; + for (uint32_t j = 0; j < horizontal_d; j += H_DIM_SIZE) { + memcpy( + h_base + static_cast(dst_idx) * H_DIM_SIZE, + h_base + static_cast(src_idx) * H_DIM_SIZE, + H_DIM_SIZE * sizeof(data_t) + ); + h_base += stride * H_DIM_SIZE; + } + } +}; + +} // namespace PDX diff --git a/include/pdx/clustering.hpp b/include/pdx/clustering.hpp index 48fa3bf..e2981fb 100644 --- a/include/pdx/clustering.hpp +++ b/include/pdx/clustering.hpp @@ -35,7 +35,8 @@ struct KMeansResult { const bool normalize = false, const float sampling_fraction = 0.0f, const uint32_t kmeans_iters = 8, - const bool hierarchical_indexing = true + const bool hierarchical_indexing = true, + const uint32_t n_threads = 0 ) { assert(num_embeddings >= 1); assert(num_dimensions >= 1); @@ -76,11 +77,17 @@ struct KMeansResult { config.iters_refinement = 0; config.seed = seed; // config.verbose = true; - config.n_threads = PDX::g_n_threads; + config.n_threads = n_threads > 0 ? 
n_threads : PDX::g_n_threads; auto kmeans = skmeans::HierarchicalSuperKMeans(num_clusters, num_dimensions, config); result.centroids = kmeans.Train(embeddings, num_embeddings); - assignments = - kmeans.FastAssign(embeddings, result.centroids.data(), num_embeddings, num_clusters); + if (num_clusters > skmeans::N_CLUSTERS_THRESHOLD_FOR_PRUNING) { + assignments = kmeans.FastAssign( + embeddings, result.centroids.data(), num_embeddings, num_clusters + ); + } else { + assignments = + kmeans.Assign(embeddings, result.centroids.data(), num_embeddings, num_clusters); + } } else { skmeans::SuperKMeansConfig config; config.sampling_fraction = chosen_sampling_fraction; @@ -90,11 +97,17 @@ struct KMeansResult { config.iters = kmeans_iters; config.seed = seed; // config.verbose = true; - config.n_threads = PDX::g_n_threads; + config.n_threads = n_threads > 0 ? n_threads : PDX::g_n_threads; auto kmeans = skmeans::SuperKMeans(num_clusters, num_dimensions, config); result.centroids = kmeans.Train(embeddings, num_embeddings); - assignments = - kmeans.FastAssign(embeddings, result.centroids.data(), num_embeddings, num_clusters); + if (num_clusters > skmeans::N_CLUSTERS_THRESHOLD_FOR_PRUNING) { + assignments = kmeans.FastAssign( + embeddings, result.centroids.data(), num_embeddings, num_clusters + ); + } else { + assignments = + kmeans.Assign(embeddings, result.centroids.data(), num_embeddings, num_clusters); + } } // Convert from vec_id -> centroid_idx into centroid_idx -> vec_id diff --git a/include/pdx/common.hpp b/include/pdx/common.hpp index 71a9f32..cc83270 100644 --- a/include/pdx/common.hpp +++ b/include/pdx/common.hpp @@ -48,6 +48,8 @@ static constexpr uint32_t DIMENSIONS_FETCHING_SIZES[20] = {16, 16, 32, 32, 64, 64, 64, 128, 128, 128, 128, 256, 256, 512, 1024, 2048, 16384}; +static constexpr float CENTROID_PERTURBATION_EPS = 1.0f / 1024.0f; + static constexpr bool AllFetchingSizesMultipleOfU8InterleaveSize() { for (auto s : DIMENSIONS_FETCHING_SIZES) { if (s % 
U8_INTERLEAVE_SIZE != 0) { @@ -75,7 +77,6 @@ enum Quantization { F32, U8, F16, BF }; enum class PDXIndexType : uint8_t { PDX_F32 = 0, PDX_U8 = 1, PDX_TREE_F32 = 2, PDX_TREE_U8 = 3 }; -// TODO: Do the same for indexes? template struct DistanceType { using type = uint32_t; @@ -87,7 +88,6 @@ struct DistanceType { template using pdx_distance_t = typename DistanceType::type; -// TODO: Do the same for indexes? template struct DataType { using type = uint8_t; // U8 diff --git a/include/pdx/index.hpp b/include/pdx/index.hpp index 6e3b933..bd26103 100644 --- a/include/pdx/index.hpp +++ b/include/pdx/index.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -14,6 +15,7 @@ #include "pdx/common.hpp" #include "pdx/ivf_wrapper.hpp" #include "pdx/layout.hpp" +#include "pdx/profiler.hpp" #include "pdx/pruners/adsampling.hpp" #include "pdx/quantizers/scalar.hpp" #include "pdx/searcher.hpp" @@ -83,7 +85,7 @@ inline std::unique_ptr NormalizeAndRotate( if (normalize) { normalized.reset(new float[total_floats]); Quantizer quantizer(num_dimensions); -#pragma omp parallel for num_threads(PDX::g_n_threads) +#pragma omp parallel for if (num_embeddings > 1) num_threads(PDX::g_n_threads) for (size_t i = 0; i < num_embeddings; i++) { quantizer.NormalizeQuery( embeddings + i * num_dimensions, normalized.get() + i * num_dimensions @@ -117,6 +119,7 @@ void PopulateIVFClusters( // Pre-allocate all clusters sequentially for (size_t cluster_idx = 0; cluster_idx < num_clusters; cluster_idx++) { ivf.clusters.emplace_back(kmeans_result.assignments[cluster_idx].size(), num_dimensions); + ivf.clusters[cluster_idx].id = cluster_idx; } // Per-thread tmp buffers for gather + quantize @@ -155,6 +158,8 @@ void PopulateIVFClusters( } StoreClusterEmbeddings(cluster, ivf, tmp, cluster_size); } + + ivf.ComputeClusterOffsets(); } class IPDXIndex { @@ -168,19 +173,26 @@ class IPDXIndex { ) const = 0; virtual void BuildIndex(const float* embeddings, size_t num_embeddings) = 0; 
virtual void SetNProbe(uint32_t n_probe) const = 0; - virtual void Save(const std::string& path) const = 0; + virtual void Save(const std::string& path) = 0; virtual void Restore(const std::string& path) = 0; virtual uint32_t GetNumDimensions() const = 0; virtual uint32_t GetNumClusters() const = 0; virtual uint32_t GetClusterSize(uint32_t cluster_id) const = 0; virtual std::vector GetClusterRowIds(uint32_t cluster_id) const = 0; virtual size_t GetInMemorySizeInBytes() const = 0; + virtual void Append(size_t /*row_id*/, const float* /*embedding*/) { + throw std::runtime_error("Append is not supported by this index type. Use PDXTreeIndex."); + } + virtual void Delete(size_t /*row_id*/) { + throw std::runtime_error("Delete is not supported by this index type. Use PDXTreeIndex."); + } }; template class PDXIndex : public IPDXIndex { public: using embedding_storage_t = PDX::pdx_data_t; + using cluster_t = PDX::Cluster; private: PDXIndexConfig config{}; @@ -189,38 +201,6 @@ class PDXIndex : public IPDXIndex { std::unique_ptr> searcher; std::vector> row_id_cluster_mapping; - static constexpr PDXIndexType GetIndexType() { - if constexpr (Q == F32) - return PDXIndexType::PDX_F32; - else - return PDXIndexType::PDX_U8; - } - - void BuildRowIdClusterMapping() { - size_t total = 0; - for (size_t c = 0; c < index.num_clusters; c++) { - total += index.clusters[c].num_embeddings; - } - row_id_cluster_mapping.resize(total); - for (uint32_t c = 0; c < index.num_clusters; c++) { - for (uint32_t p = 0; p < index.clusters[c].num_embeddings; p++) { - row_id_cluster_mapping[index.clusters[c].indices[p]] = {c, p}; - } - } - } - - PDX::PredicateEvaluator CreatePredicateEvaluator(const std::vector& passing_row_ids - ) const { - PDX::PredicateEvaluator evaluator(index.num_clusters, row_id_cluster_mapping.size()); - for (const auto row_id : passing_row_ids) { - const auto& [cluster_id, index_in_cluster] = row_id_cluster_mapping[row_id]; - evaluator.n_passing_tuples[cluster_id]++; - 
evaluator.selection_vector[searcher->cluster_offsets[cluster_id] + index_in_cluster] = - 1; - } - return evaluator; - } - public: PDXIndex() = default; @@ -230,10 +210,17 @@ class PDXIndex : public IPDXIndex { pruner = std::make_unique(config.num_dimensions, config.seed); } - void Save(const std::string& path) const override { + void Save(const std::string& path) override { + // Compact all clusters before saving + for (uint32_t c = 0; c < index.num_clusters; c++) { + auto moves = index.clusters[c].CompactCluster(); + for (const auto& [row_id, new_idx] : moves) { + row_id_cluster_mapping[row_id] = {c, new_idx}; + } + } + std::ofstream out(path, std::ios::binary); - // Index type flag uint8_t type_flag = static_cast(GetIndexType()); out.write(reinterpret_cast(&type_flag), sizeof(uint8_t)); @@ -306,7 +293,14 @@ class PDXIndex : public IPDXIndex { std::vector GetClusterRowIds(uint32_t cluster_id) const override { const auto& cluster = index.clusters[cluster_id]; - return {cluster.indices, cluster.indices + cluster.num_embeddings}; + std::vector row_ids; + row_ids.reserve(cluster.num_embeddings); + for (uint32_t i = 0; i < cluster.used_capacity; i++) { + if (!cluster.HasTombstone(i)) { + row_ids.push_back(cluster.indices[i]); + } + } + return row_ids; } size_t GetInMemorySizeInBytes() const override { @@ -324,10 +318,8 @@ class PDXIndex : public IPDXIndex { size += pruner->num_dimensions * sizeof(uint32_t); // flip_masks } } - // Searcher: cluster_offsets array if (searcher) { size += sizeof(*searcher); - size += index.num_clusters * sizeof(size_t); } // Row ID to cluster mapping size += row_id_cluster_mapping.capacity() * sizeof(std::pair); @@ -409,26 +401,23 @@ class PDXIndex : public IPDXIndex { searcher = std::make_unique>(index, *pruner); BuildRowIdClusterMapping(); } -}; -template -class PDXTreeIndex : public IPDXIndex { - public: - using embedding_storage_t = PDX::pdx_data_t; + void Append(size_t /*row_id*/, const float* /*embedding*/) override { + throw 
std::runtime_error("Append is not implemented in PDXIndex. Use PDXTreeIndex instead." + ); + } - private: - PDXIndexConfig config{}; - PDX::IVFTree index; - std::unique_ptr pruner; - std::unique_ptr> searcher; - std::unique_ptr> top_level_searcher; - std::vector> row_id_cluster_mapping; + void Delete(size_t /*row_id*/) override { + throw std::runtime_error("Delete is not implemented in PDXIndex. Use PDXTreeIndex instead." + ); + } + private: static constexpr PDXIndexType GetIndexType() { if constexpr (Q == F32) - return PDXIndexType::PDX_TREE_F32; + return PDXIndexType::PDX_F32; else - return PDXIndexType::PDX_TREE_U8; + return PDXIndexType::PDX_U8; } void BuildRowIdClusterMapping() { @@ -446,26 +435,64 @@ class PDXTreeIndex : public IPDXIndex { PDX::PredicateEvaluator CreatePredicateEvaluator(const std::vector& passing_row_ids ) const { - PDX::PredicateEvaluator evaluator(index.num_clusters, row_id_cluster_mapping.size()); + PDX_PROFILE_SCOPE("PredicateEvaluator"); + PDX::PredicateEvaluator evaluator(index.num_clusters, index.total_capacity); for (const auto row_id : passing_row_ids) { const auto& [cluster_id, index_in_cluster] = row_id_cluster_mapping[row_id]; evaluator.n_passing_tuples[cluster_id]++; - evaluator.selection_vector[searcher->cluster_offsets[cluster_id] + index_in_cluster] = - 1; + evaluator.selection_vector[index.cluster_offsets[cluster_id] + index_in_cluster] = 1; } return evaluator; } +}; + +template +class PDXTreeIndex : public IPDXIndex { + public: + using embedding_storage_t = PDX::pdx_data_t; + using cluster_t = PDX::Cluster; + using distance_computer_t = DistanceComputer; + using distance_computer_f32_t = DistanceComputer; + using batch_computer = + skmeans::BatchComputer; + using MatrixR = Eigen::Matrix; + using VectorR = Eigen::VectorXf; + + private: + static constexpr uint32_t DELETED_MARKER = std::numeric_limits::max(); + + PDXIndexConfig config{}; + uint32_t d = 0; + PDX::IVFTree index; + std::unique_ptr pruner; + std::unique_ptr> 
searcher; + std::unique_ptr> top_level_searcher; + ScalarQuantizer quantizer{0}; + std::vector> row_id_cluster_mapping; public: PDXTreeIndex() = default; - explicit PDXTreeIndex(PDXIndexConfig config) : config(config) { + explicit PDXTreeIndex(PDXIndexConfig config) + : config(config), d(config.num_dimensions), quantizer(config.num_dimensions) { config.Validate(); PDX::g_n_threads = (config.n_threads == 0) ? omp_get_max_threads() : config.n_threads; pruner = std::make_unique(config.num_dimensions, config.seed); } - void Save(const std::string& path) const override { + void Save(const std::string& path) override { + // Compact L1 clusters before saving (update row_id_cluster_mapping from moves) + for (uint32_t c = 0; c < index.num_clusters; c++) { + auto moves = index.clusters[c].CompactCluster(); + for (const auto& [row_id, new_idx] : moves) { + row_id_cluster_mapping[row_id] = {c, new_idx}; + } + } + // Compact L0 clusters (no mapping to update for meso-clusters) + for (uint32_t c = 0; c < index.l0.num_clusters; c++) { + index.l0.clusters[c].CompactCluster(); + } + std::ofstream out(path, std::ios::binary); // Index type flag @@ -506,16 +533,21 @@ class PDXTreeIndex : public IPDXIndex { // Load IVFTree data index.Load(ptr); + d = index.num_dimensions; // Create pruner and searchers - pruner = - std::make_unique(index.num_dimensions, aligned_matrix.get()); + pruner = std::make_unique(d, aligned_matrix.get()); searcher = std::make_unique>(index, *pruner); top_level_searcher = std::make_unique>(index.l0, *pruner); BuildRowIdClusterMapping(); } std::vector Search(const float* query_embedding, size_t knn) const override { + PDX_PROFILE_SCOPE("Search"); + auto n_probe = searcher->GetNProbe(); + if (n_probe == 0) { + searcher->SetNProbe(GetNumClusters()); + } auto n_probe_top_level = GetTopLevelNumClusters(); // We confidently prune half of the search space if (searcher->GetNProbe() < GetNumClusters() / 2) { @@ -539,7 +571,70 @@ class PDXTreeIndex : public IPDXIndex { 
const std::vector& passing_row_ids ) const override { auto evaluator = CreatePredicateEvaluator(passing_row_ids); - return searcher->FilteredSearch(query_embedding, knn, evaluator); + { + PDX_PROFILE_SCOPE("FilteredSearch"); + return searcher->FilteredSearch(query_embedding, knn, evaluator); + } + } + + // Concurrent writes must always go through a single writer thread + void Append(size_t row_id, const float* PDX_RESTRICT embedding) override { + PDX_PROFILE_SCOPE("Append"); + if (row_id != row_id_cluster_mapping.size()) { + throw std::invalid_argument( + "Append: row_id " + std::to_string(row_id) + " is not sequential (expected " + + std::to_string(row_id_cluster_mapping.size()) + ")" + ); + } + ReserveClusterSlotIfNeeded(); + + const bool normalize = + config.normalize || DistanceMetricRequiresNormalization(config.distance_metric); + + auto preprocessed = NormalizeAndRotate(embedding, 1, d, normalize, *pruner); + + // Find nearest centroid for the new embedding + uint32_t closest_centroid_idx; + { + PDX_PROFILE_SCOPE("Append/FindNearestCentroid"); + auto n_probe_top_level = GetTopLevelNumClusters(); + // We confidently prune 1/8 of the search space + n_probe_top_level = std::max(1u, n_probe_top_level / 8); + top_level_searcher->SetNProbe(n_probe_top_level); + std::vector centroid_candidates = + top_level_searcher->Search(preprocessed.get(), 1, true); + closest_centroid_idx = centroid_candidates[0].index; + } + + auto& cluster = index.clusters[closest_centroid_idx]; + + uint32_t new_index_in_cluster = + QuantizeAndAppend(cluster, static_cast(row_id), preprocessed.get()); + row_id_cluster_mapping.emplace_back(closest_centroid_idx, new_index_in_cluster); + index.total_num_embeddings++; + CheckClusterHealth(cluster); + } + + // Concurrent deletes must always go through a single writer thread + void Delete(size_t row_id) override { + PDX_PROFILE_SCOPE("Delete"); + if (row_id >= row_id_cluster_mapping.size()) { + throw std::invalid_argument( + "Delete: row_id " + 
std::to_string(row_id) + " is not in the index" + ); + } + const auto& [cluster_id, index_in_cluster] = row_id_cluster_mapping[row_id]; + if (cluster_id == DELETED_MARKER) { + throw std::invalid_argument( + "Delete: row_id " + std::to_string(row_id) + " was already deleted" + ); + } + ReserveClusterSlotIfNeeded(); + auto& cluster = index.clusters[cluster_id]; + cluster.DeleteEmbedding(index_in_cluster); + row_id_cluster_mapping[row_id] = {DELETED_MARKER, DELETED_MARKER}; + index.total_num_embeddings--; + CheckClusterHealth(cluster); } void BuildIndex(const float* const embeddings, const size_t num_embeddings) override { @@ -651,6 +746,13 @@ class PDXTreeIndex : public IPDXIndex { 1.0f ); + // Set mesocluster_id on each L1 cluster from L0 kmeans assignments + for (uint32_t mc = 0; mc < l0_num_clusters; mc++) { + for (uint32_t l1_id : l0_kmeans_result.assignments[mc]) { + index.clusters[l1_id].mesocluster_id = mc; + } + } + top_level_searcher = std::make_unique>(index.l0, *pruner); BuildRowIdClusterMapping(); } @@ -659,7 +761,7 @@ class PDXTreeIndex : public IPDXIndex { const PDX::PDXearch& GetSearcher() const { return *searcher; } - uint32_t GetNumDimensions() const override { return index.num_dimensions; } + uint32_t GetNumDimensions() const override { return d; } uint32_t GetNumClusters() const override { return index.num_clusters; } @@ -669,7 +771,14 @@ class PDXTreeIndex : public IPDXIndex { std::vector GetClusterRowIds(uint32_t cluster_id) const override { const auto& cluster = index.clusters[cluster_id]; - return {cluster.indices, cluster.indices + cluster.num_embeddings}; + std::vector row_ids; + row_ids.reserve(cluster.num_embeddings); + for (uint32_t i = 0; i < cluster.used_capacity; i++) { + if (!cluster.HasTombstone(i)) { + row_ids.push_back(cluster.indices[i]); + } + } + return row_ids; } uint32_t GetTopLevelNumClusters() const { return index.l0.num_clusters; } @@ -689,20 +798,934 @@ class PDXTreeIndex : public IPDXIndex { size += pruner->num_dimensions 
* sizeof(uint32_t); // flip_masks } } - // L1 searcher: cluster_offsets array if (searcher) { size += sizeof(*searcher); - size += index.num_clusters * sizeof(size_t); } - // L0 top-level searcher: cluster_offsets array if (top_level_searcher) { size += sizeof(*top_level_searcher); - size += index.l0.num_clusters * sizeof(size_t); } // Row ID to cluster mapping size += row_id_cluster_mapping.capacity() * sizeof(std::pair); return size; } + + private: + static constexpr PDXIndexType GetIndexType() { + if constexpr (Q == F32) + return PDXIndexType::PDX_TREE_F32; + else + return PDXIndexType::PDX_TREE_U8; + } + + void BuildRowIdClusterMapping() { + size_t total = 0; + for (size_t c = 0; c < index.num_clusters; c++) { + total += index.clusters[c].num_embeddings; + } + row_id_cluster_mapping.resize(total); + for (uint32_t c = 0; c < index.num_clusters; c++) { + for (uint32_t p = 0; p < index.clusters[c].num_embeddings; p++) { + row_id_cluster_mapping[index.clusters[c].indices[p]] = {c, p}; + } + } + } + + PDX::PredicateEvaluator CreatePredicateEvaluator(const std::vector& passing_row_ids + ) const { + PDX_PROFILE_SCOPE("PredicateEvaluator"); + PDX::PredicateEvaluator evaluator(index.num_clusters, index.total_capacity); + for (const auto row_id : passing_row_ids) { + const auto& [cluster_id, index_in_cluster] = row_id_cluster_mapping[row_id]; + if (cluster_id == DELETED_MARKER) + continue; + evaluator.n_passing_tuples[cluster_id]++; + evaluator.selection_vector[index.cluster_offsets[cluster_id] + index_in_cluster] = 1; + } + return evaluator; + } + + // Ensure the clusters vector won't reallocate while we hold a reference + void ReserveClusterSlotIfNeeded() { + if (index.clusters.size() == index.clusters.capacity()) { + index.clusters.reserve(index.clusters.capacity() * 2); + } + } + + void ReserveL0ClusterSlotIfNeeded() { + if (index.l0.clusters.size() == index.l0.clusters.capacity()) { + index.l0.clusters.reserve(index.l0.clusters.capacity() * 2); + } + } + + // 
Dequantize raw (Q-type) embeddings to float. For F32 this is a memcpy. + std::unique_ptr DequantizeClusterEmbeddings( + const embedding_storage_t* raw_embeddings, + uint32_t n_emb + ) const { + PDX_PROFILE_SCOPE("Dequantize"); + std::unique_ptr result(new float[static_cast(n_emb) * d]); + if constexpr (Q == U8) { + for (size_t i = 0; i < n_emb; i++) { + searcher->quantizer.DequantizeEmbedding( + raw_embeddings + i * d, + index.quantization_base, + index.quantization_scale, + result.get() + i * d + ); + } + } else { + std::memcpy( + result.get(), raw_embeddings, static_cast(n_emb) * d * sizeof(float) + ); + } + return result; + } + + // Quantize (if U8) and append a float embedding to a cluster. + uint32_t QuantizeAndAppend(cluster_t& cluster, uint32_t row_id, const float* embedding) { + if constexpr (Q == U8) { + std::unique_ptr quantized(new embedding_storage_t[d]); + quantizer.QuantizeEmbedding( + embedding, index.quantization_base, index.quantization_scale, quantized.get() + ); + return cluster.AppendEmbedding(row_id, quantized.get()); + } else { + return cluster.AppendEmbedding(row_id, embedding); + } + } + + // Gather raw embeddings, row IDs, and accumulate centroid sum for a group of indices. + void GatherGroupEmbeddings( + const std::vector& group_idx, + const embedding_storage_t* raw_embeddings, + const float* float_embeddings, + const cluster_t& cluster, + std::vector& embs_out, + std::vector& ids_out, + float* centroid_sum + ) const { + for (uint32_t idx : group_idx) { + embs_out.insert( + embs_out.end(), + raw_embeddings + static_cast(idx) * d, + raw_embeddings + (static_cast(idx) + 1) * d + ); + ids_out.push_back(cluster.indices[idx]); + const float* emb_f = float_embeddings + static_cast(idx) * d; + for (size_t j = 0; j < d; j++) { + centroid_sum[j] += emb_f[j]; + } + } + } + + // Compute mean centroid from accumulated sum. Falls back to fallback if count == 0. 
+ void ComputeCentroidMean( + const float* centroid_sum, + size_t count, + const float* fallback, + float* output + ) const { + if (count == 0) { + std::memcpy(output, fallback, d * sizeof(float)); + } else { + float inv = 1.0f / static_cast(count); +#pragma clang loop vectorize(enable) + for (size_t j = 0; j < d; j++) { + output[j] = centroid_sum[j] * inv; + } + } + const bool normalize = + config.normalize || DistanceMetricRequiresNormalization(config.distance_metric); + if (normalize) { + Quantizer q(d); + q.NormalizeQuery(output, output); + } + } + + // Get neighboring cluster IDs from the same meso-cluster, limited to max_neighbors nearest. + std::vector GetNearestNeighborClusterIds( + uint32_t cluster_id, + uint32_t mesocluster_id, + const float* centroid, + size_t max_neighbors = 32 + ) const { + PDX_PROFILE_SCOPE("GetNeighboringClusters"); + std::vector neighbor_ids; + auto& mesocluster = index.l0.clusters[mesocluster_id]; + for (uint32_t pos = 0; pos < mesocluster.used_capacity; pos++) { + if (mesocluster.HasTombstone(pos)) + continue; + uint32_t nid = mesocluster.indices[pos]; + if (nid == cluster_id) + continue; + neighbor_ids.push_back(nid); + } + if (neighbor_ids.size() > max_neighbors) { + std::vector> neighbor_dists; + neighbor_dists.reserve(neighbor_ids.size()); + for (uint32_t nid : neighbor_ids) { + float dist = distance_computer_f32_t::Horizontal( + centroid, index.centroids.data() + static_cast(nid) * d, d + ); + neighbor_dists.push_back({dist, nid}); + } + std::nth_element( + neighbor_dists.begin(), + neighbor_dists.begin() + max_neighbors, + neighbor_dists.end(), + [](const auto& a, const auto& b) { return a.first < b.first; } + ); + neighbor_ids.clear(); + for (size_t i = 0; i < max_neighbors; ++i) { + neighbor_ids.push_back(neighbor_dists[i].second); + } + } + return neighbor_ids; + } + + uint32_t FindPositionInMesoCluster(uint32_t l1_cluster_id, uint32_t mesocluster_id) const { + PDX_PROFILE_SCOPE("FindPositionInMesoCluster"); + auto& 
l0_cluster = index.l0.clusters[mesocluster_id]; + for (uint32_t idx = 0; idx < l0_cluster.used_capacity; idx++) { + if (!l0_cluster.HasTombstone(idx) && l0_cluster.indices[idx] == l1_cluster_id) { + return idx; + } + } + throw std::runtime_error( + "FindPositionInMesoCluster: L1 cluster " + std::to_string(l1_cluster_id) + + " not found in L0 meso-cluster " + std::to_string(mesocluster_id) + ); + } + + // ****************************************** + // L0 (Mesoclusters) Maintenance + // ****************************************** + + void CheckL0ClusterHealth(Cluster& l0_cluster, bool allow_merges = true) { + if (l0_cluster.used_capacity == l0_cluster.max_capacity) { + if (l0_cluster.num_embeddings < l0_cluster.used_capacity) { + l0_cluster.CompactCluster(); + } else { + SplitL0Cluster(l0_cluster); + } + } else if (allow_merges && l0_cluster.num_embeddings <= l0_cluster.min_capacity) { + DestroyAndMergeL0Cluster(l0_cluster); + } + } + + void SplitL0Cluster(Cluster& l0_cluster) { + PDX_PROFILE_SCOPE("SplitL0"); + const uint32_t l0_cluster_id = l0_cluster.id; + const uint32_t num_embeddings = l0_cluster.num_embeddings; + + // Gather L1 cluster IDs and their centroids + std::vector l1_ids(num_embeddings); + std::vector l1_centroids(static_cast(num_embeddings) * d); + for (uint32_t i = 0; i < num_embeddings; i++) { + uint32_t l1_id = l0_cluster.indices[i]; + l1_ids[i] = l1_id; + std::memcpy( + l1_centroids.data() + static_cast(i) * d, + index.centroids.data() + static_cast(l1_id) * d, + d * sizeof(float) + ); + } + + KMeansResult split_result = ComputeKMeans( + l1_centroids.data(), + num_embeddings, + d, + 2, + config.distance_metric, + config.seed, + true, + 1.0f, + 4, + false, + 1 + ); + auto& group_a = split_result.assignments[0]; + auto& group_b = split_result.assignments[1]; + + // Gather IDs and compute true centroids + std::vector ids_a(group_a.size()), ids_b(group_b.size()); + auto centroid_sum_a = std::make_unique(d); // zero-init needed + auto centroid_sum_b = 
std::make_unique(d); // zero-init needed + for (size_t i = 0; i < group_a.size(); i++) { + ids_a[i] = l1_ids[group_a[i]]; + const float* c = l1_centroids.data() + static_cast(group_a[i]) * d; + for (uint32_t j = 0; j < d; j++) + centroid_sum_a[j] += c[j]; + } + for (size_t i = 0; i < group_b.size(); i++) { + ids_b[i] = l1_ids[group_b[i]]; + const float* c = l1_centroids.data() + static_cast(group_b[i]) * d; + for (uint32_t j = 0; j < d; j++) + centroid_sum_b[j] += c[j]; + } + std::unique_ptr true_centroid_a(new float[d]); + std::unique_ptr true_centroid_b(new float[d]); + ComputeCentroidMean( + centroid_sum_a.get(), + group_a.size(), + split_result.centroids.data(), + true_centroid_a.get() + ); + ComputeCentroidMean( + centroid_sum_b.get(), + group_b.size(), + split_result.centroids.data() + d, + true_centroid_b.get() + ); + + // Create new L0 clusters + ReserveL0ClusterSlotIfNeeded(); + uint32_t new_l0_id = index.l0.num_clusters; + Cluster new_a(static_cast(ids_a.size()), d); + new_a.id = l0_cluster_id; + if (!ids_a.empty()) { + std::memcpy(new_a.indices, ids_a.data(), ids_a.size() * sizeof(uint32_t)); + std::vector centroids_a(ids_a.size() * d); + for (size_t i = 0; i < ids_a.size(); i++) { + std::memcpy( + centroids_a.data() + i * d, + index.centroids.data() + static_cast(ids_a[i]) * d, + d * sizeof(float) + ); + } + StoreClusterEmbeddings(new_a, index.l0, centroids_a.data(), ids_a.size()); + } + Cluster new_b(static_cast(ids_b.size()), d); + new_b.id = new_l0_id; + if (!ids_b.empty()) { + std::memcpy(new_b.indices, ids_b.data(), ids_b.size() * sizeof(uint32_t)); + std::vector centroids_b(ids_b.size() * d); + for (size_t i = 0; i < ids_b.size(); i++) { + std::memcpy( + centroids_b.data() + i * d, + index.centroids.data() + static_cast(ids_b[i]) * d, + d * sizeof(float) + ); + } + StoreClusterEmbeddings(new_b, index.l0, centroids_b.data(), ids_b.size()); + } + + // Replace old L0 cluster with A, append B + index.l0.clusters[l0_cluster_id] = std::move(new_a); + 
index.l0.clusters.push_back(std::move(new_b)); + index.l0.num_clusters++; + + // Update L0 centroids + std::memcpy( + index.l0.centroids.data() + static_cast(l0_cluster_id) * d, + true_centroid_a.get(), + d * sizeof(float) + ); + index.l0.centroids.insert( + index.l0.centroids.end(), true_centroid_b.get(), true_centroid_b.get() + d + ); + + // Update mesocluster_id on affected L1 clusters + for (uint32_t id : ids_a) { + index.clusters[id].mesocluster_id = l0_cluster_id; + } + for (uint32_t id : ids_b) { + index.clusters[id].mesocluster_id = new_l0_id; + } + + index.l0.ComputeClusterOffsets(); + } + + void DestroyAndMergeL0Cluster(Cluster& l0_cluster) { + PDX_PROFILE_SCOPE("MergeL0"); + l0_cluster.CompactCluster(); + const uint32_t l0_id = l0_cluster.id; + const uint32_t num_embeddings = l0_cluster.num_embeddings; + + // Gather L1 cluster IDs and their centroids + std::vector l1_ids(l0_cluster.indices, l0_cluster.indices + num_embeddings); + std::vector l1_centroids(static_cast(num_embeddings) * d); + for (uint32_t i = 0; i < num_embeddings; i++) { + std::memcpy( + l1_centroids.data() + static_cast(i) * d, + index.centroids.data() + static_cast(l1_ids[i]) * d, + d * sizeof(float) + ); + } + + // Swap-and-pop: move last L0 cluster into the dead slot + uint32_t last_l0_id = index.l0.num_clusters - 1; + if (l0_id != last_l0_id) { + index.l0.clusters[l0_id] = std::move(index.l0.clusters[last_l0_id]); + index.l0.clusters[l0_id].id = l0_id; + + std::memcpy( + index.l0.centroids.data() + static_cast(l0_id) * d, + index.l0.centroids.data() + static_cast(last_l0_id) * d, + d * sizeof(float) + ); + + // Update mesocluster_id on all L1 clusters that referenced the moved L0 cluster + auto& moved = index.l0.clusters[l0_id]; + for (uint32_t i = 0; i < moved.used_capacity; i++) { + if (!moved.HasTombstone(i)) { + index.clusters[moved.indices[i]].mesocluster_id = l0_id; + } + } + } + + // Pop dead L0 cluster + index.l0.clusters.pop_back(); + 
index.l0.centroids.resize(index.l0.centroids.size() - d); + index.l0.num_clusters--; + index.l0.total_num_embeddings -= num_embeddings; + + // Reassign L1 clusters to nearest L0 centroids + ReassignEmbeddingsL0(l1_ids.data(), l1_centroids.data(), num_embeddings); + + index.l0.ComputeClusterOffsets(); + } + + void ReassignEmbeddingsL0( + const uint32_t* l1_ids, + const float* l1_centroids, + uint32_t num_embeddings + ) { + PDX_PROFILE_SCOPE("ReassignL0"); + const uint32_t n_l0 = index.l0.num_clusters; + + std::unique_ptr assignments(new uint32_t[num_embeddings]); + std::unique_ptr result_distances(new float[num_embeddings]); + std::unique_ptr tmp_distances_buf( + new float[skmeans::X_BATCH_SIZE * skmeans::Y_BATCH_SIZE] + ); + + std::vector entry_norms(num_embeddings); + Eigen::Map entries_matrix(l1_centroids, num_embeddings, d); + Eigen::Map e_norms(entry_norms.data(), num_embeddings); + e_norms.noalias() = entries_matrix.rowwise().squaredNorm(); + + std::vector l0_norms(n_l0); + Eigen::Map l0_matrix(index.l0.centroids.data(), n_l0, d); + Eigen::Map c_norms(l0_norms.data(), n_l0); + c_norms.noalias() = l0_matrix.rowwise().squaredNorm(); + + batch_computer::FindNearestNeighbor( + l1_centroids, + index.l0.centroids.data(), + num_embeddings, + n_l0, + d, + entry_norms.data(), + l0_norms.data(), + assignments.get(), + result_distances.get(), + tmp_distances_buf.get() + ); + + for (uint32_t i = 0; i < num_embeddings; i++) { + uint32_t target_l0 = assignments[i]; + index.l0.clusters[target_l0].AppendEmbedding( + l1_ids[i], l1_centroids + static_cast(i) * d + ); + index.l0.total_num_embeddings++; + index.clusters[l1_ids[i]].mesocluster_id = target_l0; + CheckL0ClusterHealth(index.l0.clusters[target_l0], false); + } + } + + // ****************************************** + // L1 (Leaf Clusters) Maintenance + // ****************************************** + + void CheckClusterHealth(cluster_t& cluster, bool allow_merges = true) { + if (cluster.used_capacity == 
cluster.max_capacity) { + // It's less expensive to compact than to split + if (cluster.num_embeddings < cluster.used_capacity) { + auto moves = cluster.CompactCluster(); + for (const auto& [row_id, new_idx] : moves) { + row_id_cluster_mapping[row_id] = {cluster.id, new_idx}; + } + } else { + SplitCluster(cluster); + } + } else if (allow_merges && cluster.num_embeddings <= cluster.min_capacity) { + DestroyAndMergeCluster(cluster); + } + } + + void DestroyAndMergeCluster(cluster_t& cluster) { + PDX_PROFILE_SCOPE("Merge"); + cluster.CompactCluster(); + const uint32_t cluster_id = cluster.id; + const uint32_t mesocluster_id = cluster.mesocluster_id; + const uint32_t n_emb = cluster.num_embeddings; + + auto raw_embeddings = cluster.GetHorizontalEmbeddingsFromPDXBuffer(); + std::vector cluster_indices(cluster.indices, cluster.indices + n_emb); + auto cluster_embeddings = DequantizeClusterEmbeddings(raw_embeddings.get(), n_emb); + + // Remove from L0 + uint32_t position_in_mesocluster = FindPositionInMesoCluster(cluster_id, mesocluster_id); + index.l0.clusters[mesocluster_id].DeleteEmbedding(position_in_mesocluster); + index.l0.total_num_embeddings--; + + // Swap-and-pop: move last cluster into the dead slot + uint32_t last_id = index.num_clusters - 1; + if (cluster_id != last_id) { + index.clusters[cluster_id] = std::move(index.clusters[last_id]); + index.clusters[cluster_id].id = cluster_id; + + auto& moved_cluster = index.clusters[cluster_id]; + std::memcpy( + index.centroids.data() + static_cast(cluster_id) * d, + index.centroids.data() + static_cast(last_id) * d, + d * sizeof(float) + ); + for (uint32_t i = 0; i < moved_cluster.used_capacity; i++) { + if (!moved_cluster.HasTombstone(i)) { + row_id_cluster_mapping[moved_cluster.indices[i]] = {cluster_id, i}; + } + } + + uint32_t l0_moved_cluster_position = + FindPositionInMesoCluster(last_id, moved_cluster.mesocluster_id); + index.l0.clusters[moved_cluster.mesocluster_id].indices[l0_moved_cluster_position] = + 
cluster_id; + } + + // Pop the dead cluster and its centroid (ensured to be at the end) + index.clusters.pop_back(); + index.centroids.resize(index.centroids.size() - d); + index.num_clusters--; + + index.ComputeClusterOffsets(); + index.l0.ComputeClusterOffsets(); + + // Fully removed the dying cluster before reassignment + ReassignEmbeddings( + cluster_indices.data(), cluster_embeddings.get(), n_emb, mesocluster_id, false + ); + + CheckL0ClusterHealth(index.l0.clusters[mesocluster_id]); + } + + // Assumes cluster is compacted and has no tombstones + void SplitCluster(cluster_t& cluster) { + PDX_PROFILE_SCOPE("Split"); + const uint32_t cluster_id = cluster.id; + const uint32_t mesocluster_id = cluster.mesocluster_id; + + auto raw_embeddings = cluster.GetHorizontalEmbeddingsFromPDXBuffer(); + auto cluster_embeddings = + DequantizeClusterEmbeddings(raw_embeddings.get(), cluster.num_embeddings); + + auto centroid_to_split = index.centroids.data() + static_cast(cluster_id) * d; + auto neighboring_clusters_ids = + GetNearestNeighborClusterIds(cluster_id, mesocluster_id, centroid_to_split); + + // 2-means split + std::unique_ptr centroid_a(new float[d]); + std::unique_ptr centroid_b(new float[d]); + std::vector group_a_idx, group_b_idx, group_rest_idx; + { + PDX_PROFILE_SCOPE("Split/KMeans"); + KMeansResult split_result = ComputeKMeans( + cluster_embeddings.get(), + cluster.num_embeddings, + d, + 2, + config.distance_metric, + config.seed, + true, + 1.0f, + 4, + false, + 1 + ); + std::memcpy(centroid_a.get(), split_result.centroids.data(), d * sizeof(float)); + std::memcpy(centroid_b.get(), split_result.centroids.data() + d, d * sizeof(float)); + group_a_idx.reserve(split_result.assignments[0].size()); + group_b_idx.reserve(split_result.assignments[1].size()); + } + + // Assign each embedding to A, B, or rest (closer elsewhere) + { + PDX_PROFILE_SCOPE("Split/Partition"); + for (size_t i = 0; i < cluster.num_embeddings; i++) { + const float* emb = 
cluster_embeddings.get() + i * d; + float dist_old = distance_computer_f32_t::Horizontal(emb, centroid_to_split, d); + // TODO(@lkuffo, med): We could avoid one of these + // since we have the distance from k-means, we just need to bring it here + float dist_a = distance_computer_f32_t::Horizontal(emb, centroid_a.get(), d); + float dist_b = distance_computer_f32_t::Horizontal(emb, centroid_b.get(), d); + float min_ab = std::min(dist_a, dist_b); + + if (min_ab <= dist_old) { + (dist_a <= dist_b ? group_a_idx : group_b_idx).push_back(i); + } else { + bool closer_elsewhere = false; + for (uint32_t c : neighboring_clusters_ids) { + float dist = distance_computer_f32_t::Horizontal( + emb, index.centroids.data() + static_cast(c) * d, d + ); + if (dist < min_ab) { + closer_elsewhere = true; + break; + } + } + if (closer_elsewhere) { + group_rest_idx.push_back(i); + } else { + (dist_a <= dist_b ? group_a_idx : group_b_idx).push_back(i); + } + } + } + } + + // Gather embeddings and IDs, accumulate centroid sums + std::vector embs_a, embs_b; + std::vector ids_a, ids_b; + embs_a.reserve(group_a_idx.size() * d); + embs_b.reserve(group_b_idx.size() * d); + ids_a.reserve(group_a_idx.size()); + ids_b.reserve(group_b_idx.size()); + auto centroid_sum_a = std::make_unique(d); + auto centroid_sum_b = std::make_unique(d); + { + PDX_PROFILE_SCOPE("Split/GatherEmbeddings"); + GatherGroupEmbeddings( + group_a_idx, + raw_embeddings.get(), + cluster_embeddings.get(), + cluster, + embs_a, + ids_a, + centroid_sum_a.get() + ); + GatherGroupEmbeddings( + group_b_idx, + raw_embeddings.get(), + cluster_embeddings.get(), + cluster, + embs_b, + ids_b, + centroid_sum_b.get() + ); + } + + // Gather group_rest NOW, before the cluster is replaced + std::unique_ptr float_rest(new float[group_rest_idx.size() * d]); + std::unique_ptr ids_rest(new uint32_t[group_rest_idx.size()]); + for (size_t i = 0; i < group_rest_idx.size(); i++) { + std::memcpy( + float_rest.get() + i * d, + cluster_embeddings.get() + 
static_cast(group_rest_idx[i]) * d, + d * sizeof(float) + ); + ids_rest[i] = cluster.indices[group_rest_idx[i]]; + } + + // Steal neighbors closer to A or B than to their own centroid + { + PDX_PROFILE_SCOPE("Split/NeighborReassign"); + for (uint32_t neighbor_id : neighboring_clusters_ids) { + auto& neighbor = index.clusters[neighbor_id]; + const float* neighbor_centroid = + index.centroids.data() + static_cast(neighbor_id) * d; + + // Quantize centroids for U8, or use directly for F32 + std::unique_ptr q_own, q_a, q_b; + const query_t* query_own; + const query_t* query_a; + const query_t* query_b; + if constexpr (Q == U8) { + q_own.reset(new query_t[d]); + q_a.reset(new query_t[d]); + q_b.reset(new query_t[d]); + searcher->quantizer.QuantizeEmbedding( + neighbor_centroid, + index.quantization_base, + index.quantization_scale, + q_own.get() + ); + searcher->quantizer.QuantizeEmbedding( + centroid_a.get(), + index.quantization_base, + index.quantization_scale, + q_a.get() + ); + searcher->quantizer.QuantizeEmbedding( + centroid_b.get(), + index.quantization_base, + index.quantization_scale, + q_b.get() + ); + query_own = q_own.get(); + query_a = q_a.get(); + query_b = q_b.get(); + } else { + query_own = neighbor_centroid; + query_a = centroid_a.get(); + query_b = centroid_b.get(); + } + + auto distances_to_own = + CalculateDistanceFromEmbeddingToCluster(query_own, neighbor.data, neighbor); + auto distances_to_a = + CalculateDistanceFromEmbeddingToCluster(query_a, neighbor.data, neighbor); + auto distances_to_b = + CalculateDistanceFromEmbeddingToCluster(query_b, neighbor.data, neighbor); + + for (uint32_t p = 0; p < neighbor.used_capacity; p++) { + if (neighbor.HasTombstone(p)) + continue; + + distance_t dist_a = distances_to_a[p]; + distance_t dist_b = distances_to_b[p]; + distance_t dist_to_own = distances_to_own[p]; + + if (dist_to_own < dist_a && dist_to_own < dist_b) { + continue; + } + + // We need the horizontal embedding (this happens in less than 1% of 
points) + auto raw_emb = neighbor.GetHorizontalEmbeddingFromPDXBuffer(p); + const float* emb_ptr; + std::unique_ptr emb_f32; + if constexpr (Q == U8) { + emb_f32.reset(new float[d]); + searcher->quantizer.DequantizeEmbedding( + raw_emb.get(), + index.quantization_base, + index.quantization_scale, + emb_f32.get() + ); + emb_ptr = emb_f32.get(); + } else { + emb_ptr = raw_emb.get(); + } + + if (dist_a <= dist_b) { + uint32_t row_id = neighbor.indices[p]; + neighbor.DeleteEmbedding(p); + embs_a.insert(embs_a.end(), raw_emb.get(), raw_emb.get() + d); + ids_a.push_back(row_id); + for (size_t j = 0; j < d; j++) + centroid_sum_a[j] += emb_ptr[j]; + } else if (dist_b < dist_a) { + uint32_t row_id = neighbor.indices[p]; + neighbor.DeleteEmbedding(p); + embs_b.insert(embs_b.end(), raw_emb.get(), raw_emb.get() + d); + ids_b.push_back(row_id); + for (size_t j = 0; j < d; j++) + centroid_sum_b[j] += emb_ptr[j]; + } + } + } + } + // Compute true centroids from accumulated sums + size_t count_a = ids_a.size(); + size_t count_b = ids_b.size(); + std::unique_ptr true_centroid_a(new float[d]); + std::unique_ptr true_centroid_b(new float[d]); + { + PDX_PROFILE_SCOPE("Split/ComputeTrueCentroids"); + ComputeCentroidMean( + centroid_sum_a.get(), count_a, centroid_a.get(), true_centroid_a.get() + ); + ComputeCentroidMean( + centroid_sum_b.get(), count_b, centroid_b.get(), true_centroid_b.get() + ); + } + + // Create new clusters and update all data structures + { + PDX_PROFILE_SCOPE("Split/ConsolidateNewClusters"); + cluster_t new_cluster_a(static_cast(count_a), d); + new_cluster_a.id = cluster_id; + new_cluster_a.mesocluster_id = mesocluster_id; + if (count_a > 0) { + std::memcpy(new_cluster_a.indices, ids_a.data(), count_a * sizeof(uint32_t)); + StoreClusterEmbeddings( + new_cluster_a, index, embs_a.data(), count_a + ); + } + uint32_t new_cluster_b_id = index.num_clusters; + cluster_t new_cluster_b(static_cast(count_b), d); + new_cluster_b.id = new_cluster_b_id; + 
new_cluster_b.mesocluster_id = mesocluster_id; + if (count_b > 0) { + std::memcpy(new_cluster_b.indices, ids_b.data(), count_b * sizeof(uint32_t)); + StoreClusterEmbeddings( + new_cluster_b, index, embs_b.data(), count_b + ); + } + // Replace old cluster with A, append B + index.clusters[cluster_id] = std::move(new_cluster_a); + index.clusters.push_back(std::move(new_cluster_b)); + index.num_clusters++; + // Update centroids + std::memcpy( + index.centroids.data() + static_cast(cluster_id) * d, + true_centroid_a.get(), + d * sizeof(float) + ); + index.centroids.insert( + index.centroids.end(), true_centroid_b.get(), true_centroid_b.get() + d + ); + // Update row_id_cluster_mapping (includes both original and stolen-neighbor points) + for (size_t i = 0; i < count_a; i++) { + row_id_cluster_mapping[ids_a[i]] = {cluster_id, static_cast(i)}; + } + for (size_t i = 0; i < count_b; i++) { + row_id_cluster_mapping[ids_b[i]] = {new_cluster_b_id, static_cast(i)}; + } + // Update L0: remove old centroid, add both new centroids + uint32_t pos = FindPositionInMesoCluster(cluster_id, mesocluster_id); + index.l0.clusters[mesocluster_id].DeleteEmbedding(pos); + index.l0.clusters[mesocluster_id].CompactCluster(); + index.l0.clusters[mesocluster_id].AppendEmbedding(cluster_id, true_centroid_a.get()); + // CheckL0ClusterHealth may reallocate index.l0.clusters + CheckL0ClusterHealth(index.l0.clusters[mesocluster_id]); + index.l0.clusters[mesocluster_id].AppendEmbedding( + new_cluster_b_id, true_centroid_b.get() + ); + index.l0.total_num_embeddings++; + } + + // Reassign rest group (closer to other centroids than A or B) + if (!group_rest_idx.empty()) { + ReassignEmbeddings( + ids_rest.get(), + float_rest.get(), + static_cast(group_rest_idx.size()), + mesocluster_id + ); + } + + index.ComputeClusterOffsets(); + index.l0.ComputeClusterOffsets(); + + CheckL0ClusterHealth(index.l0.clusters[mesocluster_id]); + } + + // Reassign dequantized (float) embeddings to their closest centroid + // 
within the given mesocluster. + // allow_merges: passed to CheckClusterHealth — false suppresses merge cascades. + // TODO(@lkuffo, med): We can optimize reassignments by doing GEMM+PRUNING for assignments + void ReassignEmbeddings( + uint32_t* row_ids, + const float* embeddings, + uint32_t num_embeddings, + uint32_t mesocluster_id, + bool allow_merges = true + ) { + PDX_PROFILE_SCOPE("Reassign"); + + // Gather cluster IDs and centroids from the mesocluster + auto& meso = index.l0.clusters[mesocluster_id]; + std::vector candidate_ids; + candidate_ids.reserve(meso.used_capacity); + for (uint32_t p = 0; p < meso.used_capacity; p++) { + if (!meso.HasTombstone(p)) { + candidate_ids.push_back(meso.indices[p]); + } + } + const uint32_t n_candidates = static_cast(candidate_ids.size()); + + std::vector candidate_centroids(static_cast(n_candidates) * d); + for (size_t i = 0; i < n_candidates; i++) { + std::memcpy( + candidate_centroids.data() + i * d, + index.centroids.data() + static_cast(candidate_ids[i]) * d, + d * sizeof(float) + ); + } + + std::unique_ptr assignments(new uint32_t[num_embeddings]); + std::unique_ptr result_distances(new float[num_embeddings]); + std::unique_ptr tmp_distances_buf( + new float[skmeans::X_BATCH_SIZE * skmeans::Y_BATCH_SIZE] + ); + + std::vector embeddings_norms(num_embeddings); + Eigen::Map embeddings_matrix(embeddings, num_embeddings, d); + Eigen::Map v_norms(embeddings_norms.data(), num_embeddings); + v_norms.noalias() = embeddings_matrix.rowwise().squaredNorm(); + + std::vector centroid_norms(n_candidates); + Eigen::Map centroids_matrix(candidate_centroids.data(), n_candidates, d); + Eigen::Map c_norms(centroid_norms.data(), n_candidates); + c_norms.noalias() = centroids_matrix.rowwise().squaredNorm(); + + batch_computer::FindNearestNeighbor( + embeddings, + candidate_centroids.data(), + num_embeddings, + n_candidates, + d, + embeddings_norms.data(), + centroid_norms.data(), + assignments.get(), + result_distances.get(), + 
tmp_distances_buf.get() + ); + + // assignments[i] is an index into candidate_ids, map back to actual cluster ID + for (size_t i = 0; i < num_embeddings; i++) { + uint32_t best_cluster = candidate_ids[assignments[i]]; + uint32_t row_id = row_ids[i]; + uint32_t new_pos = + QuantizeAndAppend(index.clusters[best_cluster], row_id, embeddings + i * d); + row_id_cluster_mapping[row_id] = {best_cluster, new_pos}; + ReserveClusterSlotIfNeeded(); + CheckClusterHealth(index.clusters[best_cluster], allow_merges); + } + } + + using distance_t = pdx_distance_t; + using query_t = pdx_quantized_embedding_t; + + inline std::unique_ptr CalculateDistanceFromEmbeddingToCluster( + const query_t* embedding, + const embedding_storage_t* pdx_embeddings, + cluster_t& cluster + ) { + PDX_PROFILE_SCOPE("Split/CalculatePDXDistance"); + using distance_computer_t = DistanceComputer; + + auto n_vectors = cluster.used_capacity; + auto buffer_stride = cluster.max_capacity; + std::unique_ptr pruning_distances = + std::make_unique(cluster.used_capacity); + std::unique_ptr pruning_positions(new uint32_t[cluster.used_capacity]); + distance_computer_t::Vertical( + embedding, + pdx_embeddings, + n_vectors, + buffer_stride, + 0, + index.num_vertical_dimensions, + pruning_distances.get(), + pruning_positions.get() + ); + for (size_t horizontal_dimension = 0; + horizontal_dimension < index.num_horizontal_dimensions; + horizontal_dimension += H_DIM_SIZE) { + for (size_t vector_idx = 0; vector_idx < n_vectors; vector_idx++) { + size_t data_pos = (index.num_vertical_dimensions * buffer_stride) + + (horizontal_dimension * buffer_stride) + + (vector_idx * H_DIM_SIZE); + pruning_distances[vector_idx] += distance_computer_t::Horizontal( + embedding + index.num_vertical_dimensions + horizontal_dimension, + pdx_embeddings + data_pos, + H_DIM_SIZE + ); + } + } + return pruning_distances; + } }; using PDXIndexF32 = PDXIndex; diff --git a/include/pdx/ivf_wrapper.hpp b/include/pdx/ivf_wrapper.hpp index 
dd20d05..8c2f8c6 100644 --- a/include/pdx/ivf_wrapper.hpp +++ b/include/pdx/ivf_wrapper.hpp @@ -1,5 +1,6 @@ #pragma once +#include "pdx/cluster.hpp" #include "pdx/common.hpp" #include "pdx/utils.hpp" #include @@ -11,31 +12,6 @@ namespace PDX { -template -struct Cluster { - using data_t = pdx_data_t; - - Cluster(uint32_t num_embeddings, uint32_t num_dimensions) - : num_embeddings(num_embeddings), num_dimensions(num_dimensions), - indices(new uint32_t[num_embeddings]), - data(new data_t[static_cast(num_embeddings) * num_dimensions]) {} - - ~Cluster() { - delete[] data; - delete[] indices; - } - - uint32_t num_embeddings{}; - const uint32_t num_dimensions{}; - uint32_t* indices = nullptr; - data_t* data = nullptr; - - size_t GetInMemorySizeInBytes() const { - return sizeof(*this) + num_embeddings * sizeof(*indices) + - num_embeddings * static_cast(num_dimensions) * sizeof(*data); - } -}; - template class IVF { public: @@ -48,6 +24,9 @@ class IVF { uint32_t num_vertical_dimensions{}; uint32_t num_horizontal_dimensions{}; std::vector clusters; + size_t max_cluster_capacity{0}; + size_t total_capacity{0}; + std::unique_ptr cluster_offsets; bool is_normalized{}; std::vector centroids; @@ -91,6 +70,22 @@ class IVF { clusters.reserve(num_clusters); } + // Compute cluster_offsets, total_capacity, and max_cluster_capacity from current clusters. + // Must be called after all clusters have been created or after structural changes + // (split/merge). 
+ void ComputeClusterOffsets() { + PDX_PROFILE_SCOPE("ComputeClusterOffsets"); + cluster_offsets.reset(new size_t[num_clusters]); + total_capacity = 0; + max_cluster_capacity = 0; + for (size_t i = 0; i < num_clusters; ++i) { + cluster_offsets[i] = total_capacity; + total_capacity += clusters[i].max_capacity; + max_cluster_capacity = + std::max(max_cluster_capacity, static_cast(clusters[i].max_capacity)); + } + } + void Load(char* input) { char* next_value = input; num_dimensions = ((uint32_t*) input)[0]; @@ -100,17 +95,15 @@ class IVF { next_value += sizeof(uint32_t) * 3; num_clusters = ((uint32_t*) next_value)[0]; next_value += sizeof(uint32_t); - auto* nums_embeddings = (uint32_t*) next_value; - next_value += num_clusters * sizeof(uint32_t); + auto* cluster_headers = (uint32_t*) next_value; + next_value += num_clusters * 2 * sizeof(uint32_t); clusters.reserve(num_clusters); for (size_t i = 0; i < num_clusters; ++i) { - clusters.emplace_back(nums_embeddings[i], num_dimensions); - memcpy( - clusters[i].data, - next_value, - sizeof(data_t) * clusters[i].num_embeddings * num_dimensions - ); - next_value += sizeof(data_t) * clusters[i].num_embeddings * num_dimensions; + uint32_t n_emb = cluster_headers[i * 2]; + uint32_t max_cap = cluster_headers[i * 2 + 1]; + clusters.emplace_back(n_emb, max_cap, num_dimensions); + clusters[i].id = i; + clusters[i].LoadPDXData(next_value); } for (size_t i = 0; i < num_clusters; ++i) { memcpy(clusters[i].indices, next_value, sizeof(uint32_t) * clusters[i].num_embeddings); @@ -134,6 +127,7 @@ class IVF { quantization_scale_squared = quantization_scale * quantization_scale; inverse_quantization_scale_squared = 1.0f / quantization_scale_squared; } + ComputeClusterOffsets(); } void Save(std::ostream& out) const { @@ -144,12 +138,10 @@ class IVF { for (size_t i = 0; i < num_clusters; ++i) { out.write(reinterpret_cast(&clusters[i].num_embeddings), sizeof(uint32_t)); + out.write(reinterpret_cast(&clusters[i].max_capacity), 
sizeof(uint32_t)); } for (size_t i = 0; i < num_clusters; ++i) { - out.write( - reinterpret_cast(clusters[i].data), - sizeof(data_t) * clusters[i].num_embeddings * num_dimensions - ); + clusters[i].SavePDXData(out); } for (size_t i = 0; i < num_clusters; ++i) { out.write( @@ -181,6 +173,7 @@ class IVF { in_memory_size_in_bytes += (clusters.capacity() - clusters.size()) * sizeof(*clusters.data()); in_memory_size_in_bytes += centroids.capacity() * sizeof(*centroids.data()); + in_memory_size_in_bytes += num_clusters * sizeof(size_t); // cluster_offsets return in_memory_size_in_bytes; } }; @@ -242,18 +235,16 @@ class IVFTree : public IVF { l0.num_horizontal_dimensions = h_dims; l0.num_clusters = n_clusters_l0; - auto* nums_embeddings_l0 = (uint32_t*) next_value; - next_value += n_clusters_l0 * sizeof(uint32_t); + auto* l0_headers = (uint32_t*) next_value; + next_value += n_clusters_l0 * 2 * sizeof(uint32_t); l0.clusters.reserve(n_clusters_l0); for (size_t i = 0; i < n_clusters_l0; ++i) { - l0.clusters.emplace_back(nums_embeddings_l0[i], dims); - memcpy( - l0.clusters[i].data, - next_value, - sizeof(float) * l0.clusters[i].num_embeddings * dims - ); - next_value += sizeof(float) * l0.clusters[i].num_embeddings * dims; + uint32_t n_emb = l0_headers[i * 2]; + uint32_t max_cap = l0_headers[i * 2 + 1]; + l0.clusters.emplace_back(n_emb, max_cap, dims); + l0.clusters[i].id = i; + l0.clusters[i].LoadPDXData(next_value); } for (size_t i = 0; i < n_clusters_l0; ++i) { memcpy( @@ -268,18 +259,16 @@ class IVFTree : public IVF { this->num_horizontal_dimensions = h_dims; this->num_clusters = n_clusters_l1; - auto* nums_embeddings_l1 = (uint32_t*) next_value; - next_value += n_clusters_l1 * sizeof(uint32_t); + auto* l1_headers = (uint32_t*) next_value; + next_value += n_clusters_l1 * 2 * sizeof(uint32_t); this->clusters.reserve(n_clusters_l1); for (size_t i = 0; i < n_clusters_l1; ++i) { - this->clusters.emplace_back(nums_embeddings_l1[i], dims); - memcpy( - this->clusters[i].data, - 
next_value, - sizeof(data_t) * this->clusters[i].num_embeddings * dims - ); - next_value += sizeof(data_t) * this->clusters[i].num_embeddings * dims; + uint32_t n_emb = l1_headers[i * 2]; + uint32_t max_cap = l1_headers[i * 2 + 1]; + this->clusters.emplace_back(n_emb, max_cap, dims); + this->clusters[i].id = i; + this->clusters[i].LoadPDXData(next_value); } for (size_t i = 0; i < n_clusters_l1; ++i) { memcpy( @@ -310,6 +299,16 @@ class IVFTree : public IVF { this->quantization_scale_squared = this->quantization_scale * this->quantization_scale; this->inverse_quantization_scale_squared = 1.0f / this->quantization_scale_squared; } + // Set mesocluster_id on L1 clusters by scanning L0 + for (uint32_t mc = 0; mc < n_clusters_l0; mc++) { + auto& l0c = l0.clusters[mc]; + for (uint32_t p = 0; p < l0c.num_embeddings; p++) { + this->clusters[l0c.indices[p]].mesocluster_id = mc; + } + } + + l0.ComputeClusterOffsets(); + this->ComputeClusterOffsets(); } void Save(std::ostream& out) const { @@ -330,13 +329,13 @@ class IVFTree : public IVF { out.write( reinterpret_cast(&l0.clusters[i].num_embeddings), sizeof(uint32_t) ); - } - for (size_t i = 0; i < n_clusters_l0; ++i) { out.write( - reinterpret_cast(l0.clusters[i].data), - sizeof(float) * l0.clusters[i].num_embeddings * this->num_dimensions + reinterpret_cast(&l0.clusters[i].max_capacity), sizeof(uint32_t) ); } + for (size_t i = 0; i < n_clusters_l0; ++i) { + l0.clusters[i].SavePDXData(out); + } for (size_t i = 0; i < n_clusters_l0; ++i) { out.write( reinterpret_cast(l0.clusters[i].indices), @@ -349,13 +348,13 @@ class IVFTree : public IVF { out.write( reinterpret_cast(&this->clusters[i].num_embeddings), sizeof(uint32_t) ); - } - for (size_t i = 0; i < this->num_clusters; ++i) { out.write( - reinterpret_cast(this->clusters[i].data), - sizeof(data_t) * this->clusters[i].num_embeddings * this->num_dimensions + reinterpret_cast(&this->clusters[i].max_capacity), sizeof(uint32_t) ); } + for (size_t i = 0; i < this->num_clusters; 
++i) { + this->clusters[i].SavePDXData(out); + } for (size_t i = 0; i < this->num_clusters; ++i) { out.write( reinterpret_cast(this->clusters[i].indices), diff --git a/include/pdx/layout.hpp b/include/pdx/layout.hpp index 16bf727..f082f9b 100644 --- a/include/pdx/layout.hpp +++ b/include/pdx/layout.hpp @@ -29,17 +29,21 @@ inline void StoreClusterEmbeddings( const auto vertical_d = index.num_vertical_dimensions; const auto horizontal_d = index.num_horizontal_dimensions; + const auto stride = static_cast(cluster.max_capacity); Eigen::Map in(embeddings, num_embeddings, index.num_dimensions); - Eigen::Map out(cluster.data, vertical_d, num_embeddings); + // Vertical block: (vertical_d x num_embeddings) with row stride = max_capacity + Eigen::Map> out( + cluster.data, vertical_d, num_embeddings, Eigen::OuterStride(stride) + ); out.noalias() = in.leftCols(vertical_d).transpose(); - float* horizontal_out = cluster.data + num_embeddings * vertical_d; + float* horizontal_out = cluster.data + stride * vertical_d; for (size_t j = 0; j < horizontal_d; j += PDX::H_DIM_SIZE) { Eigen::Map out_h(horizontal_out, num_embeddings, PDX::H_DIM_SIZE); out_h.noalias() = in.block(0, vertical_d + j, num_embeddings, PDX::H_DIM_SIZE); - horizontal_out += num_embeddings * PDX::H_DIM_SIZE; + horizontal_out += stride * PDX::H_DIM_SIZE; } } @@ -57,29 +61,28 @@ inline void StoreClusterEmbeddings( const auto vertical_d = index.num_vertical_dimensions; const auto horizontal_d = index.num_horizontal_dimensions; + const auto stride = static_cast(cluster.max_capacity); Eigen::Map in(embeddings, num_embeddings, index.num_dimensions); size_t dim = 0; for (; dim + PDX::U8_INTERLEAVE_SIZE <= vertical_d; dim += PDX::U8_INTERLEAVE_SIZE) { Eigen::Map out_v( - cluster.data + dim * num_embeddings, num_embeddings, PDX::U8_INTERLEAVE_SIZE + cluster.data + dim * stride, num_embeddings, PDX::U8_INTERLEAVE_SIZE ); out_v.noalias() = in.block(0, dim, num_embeddings, PDX::U8_INTERLEAVE_SIZE); } if (dim < vertical_d) { 
auto remaining = static_cast(vertical_d - dim); - Eigen::Map out_v( - cluster.data + dim * num_embeddings, num_embeddings, remaining - ); + Eigen::Map out_v(cluster.data + dim * stride, num_embeddings, remaining); out_v.noalias() = in.block(0, dim, num_embeddings, remaining); } - uint8_t* horizontal_out = cluster.data + num_embeddings * vertical_d; + uint8_t* horizontal_out = cluster.data + stride * vertical_d; for (size_t j = 0; j < horizontal_d; j += PDX::H_DIM_SIZE) { Eigen::Map out_h(horizontal_out, num_embeddings, PDX::H_DIM_SIZE); out_h.noalias() = in.block(0, vertical_d + j, num_embeddings, PDX::H_DIM_SIZE); - horizontal_out += num_embeddings * PDX::H_DIM_SIZE; + horizontal_out += stride * PDX::H_DIM_SIZE; } } diff --git a/include/pdx/lib/lib.hpp b/include/pdx/lib/lib.hpp index b6c2cef..5316065 100644 --- a/include/pdx/lib/lib.hpp +++ b/include/pdx/lib/lib.hpp @@ -162,6 +162,16 @@ class PyPDXIndex { } size_t GetInMemorySizeInBytes() const { return index->GetInMemorySizeInBytes(); } + + void Append(size_t row_id, const py::array_t& embedding) { + auto buf = embedding.request(); + if (buf.ndim != 1) { + throw std::runtime_error("embedding must be a 1D numpy array"); + } + index->Append(row_id, static_cast(buf.ptr)); + } + + void Delete(size_t row_id) { index->Delete(row_id); } }; } // namespace PDX diff --git a/include/pdx/profiler.hpp b/include/pdx/profiler.hpp new file mode 100644 index 0000000..c45f00c --- /dev/null +++ b/include/pdx/profiler.hpp @@ -0,0 +1,323 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +// Disclamer: Code produced by Opus 4.5 + +namespace PDX { + +/** + * @brief A centralized profiler for timing code sections. + * + * Usage: + * // Start/stop manually: + * Profiler::Get().Start("my_section"); + * // ... code ... + * Profiler::Get().Stop("my_section"); + * + * // Or use RAII scoped timer: + * { + * PDX_PROFILE_SCOPE("my_section"); + * // ... code automatically timed ... 
+ * } + * + * // Print results: + * Profiler::Get().Print(); + * + * // Reset all timers: + * Profiler::Get().Reset(); + */ +class Profiler { + public: + struct TimerData { + size_t accum_time_ns = 0; // Accumulated time in nanoseconds + size_t call_count = 0; + std::chrono::high_resolution_clock::time_point start; + bool running = false; + }; + + // Get the global profiler instance + static Profiler& Get() { + static Profiler instance; + return instance; + } + + // Start timing a named section + void Start(const std::string& name) { + std::lock_guard lock(mutex_); + auto& timer = timers_[name]; + if (!timer.running) { + timer.start = std::chrono::high_resolution_clock::now(); + timer.running = true; + } + } + + // Stop timing a named section + void Stop(const std::string& name) { + auto end = std::chrono::high_resolution_clock::now(); + std::lock_guard lock(mutex_); + auto it = timers_.find(name); + if (it != timers_.end() && it->second.running) { + it->second.accum_time_ns += + std::chrono::duration_cast(end - it->second.start) + .count(); + it->second.call_count++; + it->second.running = false; + } + } + + // Get accumulated time in seconds for a timer + double GetTimeSeconds(const std::string& name) const { + std::lock_guard lock(mutex_); + auto it = timers_.find(name); + if (it != timers_.end()) { + return it->second.accum_time_ns / 1e9; + } + return 0.0; + } + + // Get accumulated time in nanoseconds for a timer + size_t GetTimeNanos(const std::string& name) const { + std::lock_guard lock(mutex_); + auto it = timers_.find(name); + if (it != timers_.end()) { + return it->second.accum_time_ns; + } + return 0; + } + + // Get call count for a timer + size_t GetCallCount(const std::string& name) const { + std::lock_guard lock(mutex_); + auto it = timers_.find(name); + if (it != timers_.end()) { + return it->second.call_count; + } + return 0; + } + + // Reset a specific timer + void Reset(const std::string& name) { + std::lock_guard lock(mutex_); + auto it = 
timers_.find(name); + if (it != timers_.end()) { + it->second.accum_time_ns = 0; + it->second.call_count = 0; + it->second.running = false; + } + } + + // Reset all timers + void Reset() { + std::lock_guard lock(mutex_); + timers_.clear(); + } + + // Print all timers with formatting + void Print(std::ostream& os = std::cout) const { + std::lock_guard lock(mutex_); + + // Calculate total time for percentage calculation + size_t total_ns = 0; + for (const auto& [name, data] : timers_) { + total_ns += data.accum_time_ns; + } + + // Collect and sort timer names for consistent output + std::vector names; + names.reserve(timers_.size()); + for (const auto& [name, _] : timers_) { + names.push_back(name); + } + std::sort(names.begin(), names.end()); + + os << std::fixed << std::setprecision(3); + os << "\n========== PROFILER RESULTS ==========\n"; + + for (const auto& name : names) { + const auto& data = timers_.at(name); + double secs = data.accum_time_ns / 1e9; + double pct = total_ns > 0 ? (data.accum_time_ns * 100.0 / total_ns) : 0.0; + + double ms_per_call = + data.call_count > 0 ? 
(data.accum_time_ns / 1e6 / data.call_count) : 0.0; + os << std::left << std::setw(35) << name << std::right << std::setw(10) << secs << "s" + << " (" << std::setw(5) << pct << "%)" + << " [" << data.call_count << " calls, " << ms_per_call << " ms/call]" + << "\n"; + } + + os << "---------------------------------------\n"; + os << std::left << std::setw(35) << "TOTAL" << std::right << std::setw(10) + << (static_cast(total_ns) / 1e9) << "s\n"; + os << "=======================================\n"; + } + + // Print a hierarchical view (timers with '/' are grouped) + void PrintHierarchical(std::ostream& os = std::cout) const { + std::lock_guard lock(mutex_); + + // Calculate total time from top-level timers only + size_t total_ns = 0; + for (const auto& [name, data] : timers_) { + if (name.find('/') == std::string::npos) { + total_ns += data.accum_time_ns; + } + } + + os << std::fixed << std::setprecision(3); + os << "\n========== PROFILER RESULTS ==========\n"; + + // Group timers by prefix + std::unordered_map> groups; + std::vector top_level; + + for (const auto& [name, _] : timers_) { + auto pos = name.find('/'); + if (pos != std::string::npos) { + std::string parent = name.substr(0, pos); + groups[parent].push_back(name); + } else { + top_level.push_back(name); + } + } + + // Sort top-level by accumulated time (descending) + std::sort( + top_level.begin(), + top_level.end(), + [this](const std::string& a, const std::string& b) { + return timers_.at(a).accum_time_ns > timers_.at(b).accum_time_ns; + } + ); + + for (const auto& name : top_level) { + const auto& data = timers_.at(name); + double secs = data.accum_time_ns / 1e9; + double pct = total_ns > 0 ? (data.accum_time_ns * 100.0 / total_ns) : 0.0; + + double ms_per_call = + data.call_count > 0 ? 
(data.accum_time_ns / 1e6 / data.call_count) : 0.0; + os << std::left << std::setw(40) << name << std::right << std::setw(10) << secs << "s" + << " (" << std::setw(5) << pct << "%)"; + if (data.call_count > 1) { + os << " [" << data.call_count << " calls, " << ms_per_call << " ms/call]"; + } + os << "\n"; + + // Print children (sorted by time descending) + auto it = groups.find(name); + if (it != groups.end()) { + auto& children = it->second; + std::sort( + children.begin(), + children.end(), + [this](const std::string& a, const std::string& b) { + return timers_.at(a).accum_time_ns > timers_.at(b).accum_time_ns; + } + ); + for (const auto& child : children) { + const auto& child_data = timers_.at(child); + double child_secs = child_data.accum_time_ns / 1e9; + double child_pct = + total_ns > 0 ? (child_data.accum_time_ns * 100.0 / total_ns) : 0.0; + std::string short_name = " - " + child.substr(name.length() + 1); + + double child_ms_per_call = + child_data.call_count > 0 + ? (child_data.accum_time_ns / 1e6 / child_data.call_count) + : 0.0; + os << std::left << std::setw(40) << short_name << std::right << std::setw(10) + << child_secs << "s" + << " (" << std::setw(5) << child_pct << "%)"; + if (child_data.call_count > 1) { + os << " [" << child_data.call_count << " calls, " << child_ms_per_call + << " ms/call]"; + } + os << "\n"; + } + } + } + + os << "-------------------------------------------\n"; + os << std::left << std::setw(40) << "TOTAL" << std::right << std::setw(10) + << (static_cast(total_ns) / 1e9) << "s\n"; + os << "===========================================\n"; + } + + // Check if profiling is enabled + bool IsEnabled() const { return enabled_; } + + // Enable/disable profiling globally + void SetEnabled(bool enabled) { enabled_ = enabled; } + + private: + Profiler() = default; + ~Profiler() = default; + Profiler(const Profiler&) = delete; + Profiler& operator=(const Profiler&) = delete; + + mutable std::mutex mutex_; + std::unordered_map timers_; + 
bool enabled_ = true; +}; + +/** + * @brief RAII scoped timer that automatically starts on construction and stops on destruction. + */ +class ScopedTimer { + public: + explicit ScopedTimer(std::string name) : name_(std::move(name)) { + if (Profiler::Get().IsEnabled()) { + Profiler::Get().Start(name_); + } + } + + ~ScopedTimer() { + if (Profiler::Get().IsEnabled()) { + Profiler::Get().Stop(name_); + } + } + + // Non-copyable, non-movable + ScopedTimer(const ScopedTimer&) = delete; + ScopedTimer& operator=(const ScopedTimer&) = delete; + ScopedTimer(ScopedTimer&&) = delete; + ScopedTimer& operator=(ScopedTimer&&) = delete; + + private: + std::string name_; +}; + +// Convenience macros for profiling +// Helper macros for unique variable name generation +#define PDX_CONCAT_IMPL(x, y) x##y +#define PDX_CONCAT(x, y) PDX_CONCAT_IMPL(x, y) + +// Profiling macros - only enabled when BENCHMARK_TIME is defined +#ifdef BENCHMARK_TIME +// PDX_PROFILE_SCOPE creates a scoped timer with the given name +#define PDX_PROFILE_SCOPE(name) ::PDX::ScopedTimer PDX_CONCAT(_pdx_timer_, __LINE__)(name) + +// PDX_PROFILE_FUNCTION creates a scoped timer with the function name +#define PDX_PROFILE_FUNCTION() PDX_PROFILE_SCOPE(__func__) + +// Manual start/stop macros +#define PDX_PROFILE_START(name) ::PDX::Profiler::Get().Start(name) +#define PDX_PROFILE_STOP(name) ::PDX::Profiler::Get().Stop(name) +#else +// No-op macros when profiling is disabled +#define PDX_PROFILE_SCOPE(name) ((void) 0) +#define PDX_PROFILE_FUNCTION() ((void) 0) +#define PDX_PROFILE_START(name) ((void) 0) +#define PDX_PROFILE_STOP(name) ((void) 0) +#endif + +} // namespace PDX diff --git a/include/pdx/quantizers/scalar.hpp b/include/pdx/quantizers/scalar.hpp index e37ad10..27c6f35 100644 --- a/include/pdx/quantizers/scalar.hpp +++ b/include/pdx/quantizers/scalar.hpp @@ -92,6 +92,18 @@ class ScalarQuantizer : public Quantizer { } } } + + void DequantizeEmbedding( + const quantized_embedding_t* quantized_embedding, + const 
float quantization_base, + const float quantization_scale, + float* output_embedding + ) { + for (size_t i = 0; i < num_dimensions; ++i) { + output_embedding[i] = + static_cast(quantized_embedding[i]) / quantization_scale + quantization_base; + } + } }; } // namespace PDX diff --git a/include/pdx/searcher.hpp b/include/pdx/searcher.hpp index 4cf9b4f..3b0cc69 100644 --- a/include/pdx/searcher.hpp +++ b/include/pdx/searcher.hpp @@ -1,9 +1,11 @@ #pragma once +#include "pdx/cluster.hpp" #include "pdx/common.hpp" #include "pdx/db_mock/predicate_evaluator.hpp" #include "pdx/distance_computers/base_computers.hpp" #include "pdx/ivf_wrapper.hpp" +#include "pdx/profiler.hpp" #include "pdx/pruners/adsampling.hpp" #include "pdx/quantizers/scalar.hpp" #include "pdx/utils.hpp" @@ -29,6 +31,7 @@ class PDXearch { using quantized_embedding_t = pdx_quantized_embedding_t; using index_t = Index; using cluster_t = Cluster; + using tombstones_t = typename cluster_t::tombstones_t; using distance_computer_t = DistanceComputer; Quantizer quantizer; @@ -36,16 +39,7 @@ class PDXearch { index_t& pdx_data; PDXearch(index_t& data_index, Pruner& pruner) - : quantizer(data_index.num_dimensions), pruner(pruner), pdx_data(data_index), - cluster_offsets(new size_t[data_index.num_clusters]) { - for (size_t i = 0; i < data_index.num_clusters; ++i) { - cluster_offsets[i] = total_embeddings; - total_embeddings += data_index.clusters[i].num_embeddings; - max_cluster_size = std::max( - max_cluster_size, static_cast(data_index.clusters[i].num_embeddings) - ); - } - } + : quantizer(data_index.num_dimensions), pruner(pruner), pdx_data(data_index) {} void SetNProbe(size_t nprobe) { ivf_nprobe = nprobe; } @@ -59,15 +53,10 @@ class PDXearch { ); } - std::unique_ptr cluster_offsets; - protected: float selectivity_threshold = 0.80; size_t ivf_nprobe = 0; - size_t total_embeddings{0}; - size_t max_cluster_size{0}; - // Prioritized list of indices of the clusters to probe. E.g., [0, 2, 1]. 
std::unique_ptr cluster_indices_in_access_order; size_t cluster_access_order_size = 0; @@ -177,6 +166,18 @@ class PDXearch { } }; + void MaskDistancesWithTombstones( + const typename cluster_t::tombstones_t& tombstones, + distance_t* pruning_distances + ) { + if (tombstones.empty()) + return; + const distance_t mask = std::numeric_limits::max() / 2; + for (uint32_t idx : tombstones) { + pruning_distances[idx] = mask; + } + } + static void GetClustersAccessOrderIVF( const float* PDX_RESTRICT query, const index_t& data, @@ -218,18 +219,20 @@ class PDXearch { const quantized_embedding_t* PDX_RESTRICT query, const data_t* data, const size_t n_vectors, + const size_t buffer_stride, uint32_t k, const uint32_t* vector_indices, uint32_t* pruning_positions, distance_t* pruning_distances, - std::priority_queue, VectorComparator>& heap + std::priority_queue, VectorComparator>& heap, + const tombstones_t& tombstones ) { ResetPruningDistances(n_vectors, pruning_distances); distance_computer_t::Vertical( query, data, n_vectors, - n_vectors, + buffer_stride, 0, pdx_data.num_vertical_dimensions, pruning_distances, @@ -239,8 +242,9 @@ class PDXearch { horizontal_dimension < pdx_data.num_horizontal_dimensions; horizontal_dimension += H_DIM_SIZE) { for (size_t vector_idx = 0; vector_idx < n_vectors; vector_idx++) { - size_t data_pos = (pdx_data.num_vertical_dimensions * n_vectors) + - (horizontal_dimension * n_vectors) + (vector_idx * H_DIM_SIZE); + size_t data_pos = (pdx_data.num_vertical_dimensions * buffer_stride) + + (horizontal_dimension * buffer_stride) + + (vector_idx * H_DIM_SIZE); pruning_distances[vector_idx] += distance_computer_t::Horizontal( query + pdx_data.num_vertical_dimensions + horizontal_dimension, data + data_pos, @@ -248,6 +252,7 @@ class PDXearch { ); } } + MaskDistancesWithTombstones(tombstones, pruning_distances); size_t max_possible_k = std::min( static_cast(k) - heap.size(), n_vectors @@ -280,13 +285,15 @@ class PDXearch { const quantized_embedding_t* 
PDX_RESTRICT query, const data_t* data, const size_t n_vectors, + const size_t buffer_stride, uint32_t k, const uint32_t* vector_indices, uint32_t* pruning_positions, distance_t* pruning_distances, std::priority_queue, VectorComparator>& heap, uint8_t* selection_vector, - uint32_t passing_tuples + uint32_t passing_tuples, + const tombstones_t& tombstones ) { ResetPruningDistances(n_vectors, pruning_distances); size_t n_vectors_not_pruned = 0; @@ -299,8 +306,8 @@ class PDXearch { for (size_t horizontal_dimension = 0; horizontal_dimension < pdx_data.num_horizontal_dimensions; horizontal_dimension += H_DIM_SIZE) { - size_t offset_data = - (pdx_data.num_vertical_dimensions * n_vectors) + (horizontal_dimension * n_vectors); + size_t offset_data = (pdx_data.num_vertical_dimensions * buffer_stride) + + (horizontal_dimension * buffer_stride); for (size_t vector_idx = 0; vector_idx < n_vectors_not_pruned; vector_idx++) { size_t v_idx = pruning_positions[vector_idx]; size_t data_pos = offset_data + (v_idx * H_DIM_SIZE); @@ -317,7 +324,7 @@ class PDXearch { query, data, n_vectors, - n_vectors, + buffer_stride, 0, pdx_data.num_vertical_dimensions, pruning_distances, @@ -329,7 +336,7 @@ class PDXearch { query, data, n_vectors_not_pruned, - n_vectors, + buffer_stride, 0, pdx_data.num_vertical_dimensions, pruning_distances, @@ -340,6 +347,7 @@ class PDXearch { size_t max_possible_k = std::min(static_cast(k) - heap.size(), static_cast(passing_tuples)); MaskDistancesWithSelectionVector(n_vectors, pruning_distances, selection_vector); + MaskDistancesWithTombstones(tombstones, pruning_distances); std::unique_ptr indices_sorted(new size_t[n_vectors]); std::iota(indices_sorted.get(), indices_sorted.get() + n_vectors, static_cast(0)); std::partial_sort( @@ -370,6 +378,7 @@ class PDXearch { const quantized_embedding_t* PDX_RESTRICT query, const data_t* PDX_RESTRICT data, const size_t n_vectors, + const size_t buffer_stride, uint32_t k, float tuples_threshold, uint32_t* pruning_positions, 
@@ -378,6 +387,7 @@ class PDXearch { std::priority_queue, VectorComparator>& heap, uint32_t& current_dimension_idx, size_t& n_vectors_not_pruned, + const tombstones_t& tombstones, uint32_t passing_tuples = 0, uint8_t* selection_vector = nullptr ) { @@ -386,6 +396,7 @@ class PDXearch { size_t tuples_needed_to_exit = static_cast(std::ceil(tuples_threshold * static_cast(n_vectors))); ResetPruningDistances(n_vectors, pruning_distances); + MaskDistancesWithTombstones(tombstones, pruning_distances); uint32_t n_tuples_to_prune = 0; if constexpr (FILTERED) { float selection_percentage = @@ -407,7 +418,7 @@ class PDXearch { query, data, n_vectors, - n_vectors, + buffer_stride, current_dimension_idx, last_dimension_to_fetch, pruning_distances, @@ -429,6 +440,7 @@ class PDXearch { const quantized_embedding_t* PDX_RESTRICT query, const data_t* PDX_RESTRICT data, const size_t n_vectors, + const size_t buffer_stride, uint32_t k, uint32_t* pruning_positions, distance_t* pruning_distances, @@ -436,9 +448,11 @@ class PDXearch { std::priority_queue, VectorComparator>& heap, uint32_t& current_dimension_idx, size_t& n_vectors_not_pruned, + const tombstones_t& tombstones, const uint8_t* selection_vector = nullptr ) { GetPruningThreshold(k, heap, pruning_threshold, current_dimension_idx); + MaskDistancesWithTombstones(tombstones, pruning_distances); InitPositionsArray( n_vectors, n_vectors_not_pruned, @@ -453,8 +467,8 @@ class PDXearch { while (pdx_data.num_horizontal_dimensions && n_vectors_not_pruned && current_horizontal_dimension < pdx_data.num_horizontal_dimensions) { cur_n_vectors_not_pruned = n_vectors_not_pruned; - size_t offset_data = (pdx_data.num_vertical_dimensions * n_vectors) + - (current_horizontal_dimension * n_vectors); + size_t offset_data = (pdx_data.num_vertical_dimensions * buffer_stride) + + (current_horizontal_dimension * buffer_stride); for (size_t vector_idx = 0; vector_idx < n_vectors_not_pruned; vector_idx++) { size_t v_idx = pruning_positions[vector_idx]; 
size_t data_pos = offset_data + (v_idx * H_DIM_SIZE); @@ -495,7 +509,7 @@ class PDXearch { query, data, cur_n_vectors_not_pruned, - n_vectors, + buffer_stride, current_vertical_dimension, last_dimension_to_test_idx, pruning_distances, @@ -579,18 +593,23 @@ class PDXearch { } public: - /****************************************************************** - * Search methods - ******************************************************************/ - std::vector Search(const float* PDX_RESTRICT const raw_query, const uint32_t k) { + std::vector Search( + const float* PDX_RESTRICT const raw_query, + const uint32_t k, + const bool is_query_trasnformed = false + ) { Heap local_heap{}; std::unique_ptr query(new float[pdx_data.num_dimensions]); - if (!pdx_data.is_normalized) { - pruner.PreprocessQuery(raw_query, query.get()); + if (is_query_trasnformed) { + std::copy(raw_query, raw_query + pdx_data.num_dimensions, query.get()); } else { - std::unique_ptr normalized_query(new float[pdx_data.num_dimensions]); - quantizer.NormalizeQuery(raw_query, normalized_query.get()); - pruner.PreprocessQuery(normalized_query.get(), query.get()); + if (!pdx_data.is_normalized) { + pruner.PreprocessQuery(raw_query, query.get()); + } else { + std::unique_ptr normalized_query(new float[pdx_data.num_dimensions]); + quantizer.NormalizeQuery(raw_query, normalized_query.get()); + pruner.PreprocessQuery(normalized_query.get(), query.get()); + } } size_t clusters_to_visit = (ivf_nprobe == 0 || ivf_nprobe > pdx_data.num_clusters) ? 
pdx_data.num_clusters @@ -628,8 +647,10 @@ class PDXearch { local_prepared_query = query.get(); } - std::unique_ptr pruning_distances(new distance_t[max_cluster_size]); - std::unique_ptr pruning_positions(new uint32_t[max_cluster_size]); + std::unique_ptr pruning_distances( + new distance_t[pdx_data.max_cluster_capacity] + ); + std::unique_ptr pruning_positions(new uint32_t[pdx_data.max_cluster_capacity]); for (size_t cluster_idx = 0; cluster_idx < clusters_to_visit; ++cluster_idx) { distance_t pruning_threshold = std::numeric_limits::max(); @@ -641,24 +662,28 @@ class PDXearch { if (cluster.num_embeddings == 0) { continue; } + cluster.n_accessed++; if (local_heap.size() < k) { // We cannot prune until we fill the heap Start( local_prepared_query, cluster.data, - cluster.num_embeddings, + cluster.used_capacity, + cluster.max_capacity, k, cluster.indices, pruning_positions.get(), pruning_distances.get(), - local_heap + local_heap, + cluster.tombstones ); continue; } Warmup( local_prepared_query, cluster.data, - cluster.num_embeddings, + cluster.used_capacity, + cluster.max_capacity, k, selectivity_threshold, pruning_positions.get(), @@ -666,19 +691,22 @@ class PDXearch { pruning_threshold, local_heap, current_dimension_idx, - n_vectors_not_pruned + n_vectors_not_pruned, + cluster.tombstones ); Prune( local_prepared_query, cluster.data, - cluster.num_embeddings, + cluster.used_capacity, + cluster.max_capacity, k, pruning_positions.get(), pruning_distances.get(), pruning_threshold, local_heap, current_dimension_idx, - n_vectors_not_pruned + n_vectors_not_pruned, + cluster.tombstones ); if (n_vectors_not_pruned) { MergeIntoHeap( @@ -698,16 +726,21 @@ class PDXearch { std::vector FilteredSearch( const float* PDX_RESTRICT const raw_query, const uint32_t k, - const PredicateEvaluator& predicate_evaluator + const PredicateEvaluator& predicate_evaluator, + const bool is_query_transformed = false ) { Heap local_heap{}; std::unique_ptr query(new 
float[pdx_data.num_dimensions]); - if (!pdx_data.is_normalized) { - pruner.PreprocessQuery(raw_query, query.get()); + if (is_query_transformed) { + std::copy(raw_query, raw_query + pdx_data.num_dimensions, query.get()); } else { - std::unique_ptr normalized_query(new float[pdx_data.num_dimensions]); - quantizer.NormalizeQuery(raw_query, normalized_query.get()); - pruner.PreprocessQuery(normalized_query.get(), query.get()); + if (!pdx_data.is_normalized) { + pruner.PreprocessQuery(raw_query, query.get()); + } else { + std::unique_ptr normalized_query(new float[pdx_data.num_dimensions]); + quantizer.NormalizeQuery(raw_query, normalized_query.get()); + pruner.PreprocessQuery(normalized_query.get(), query.get()); + } } size_t clusters_to_visit = (ivf_nprobe == 0 || ivf_nprobe > pdx_data.num_clusters) @@ -735,8 +768,10 @@ class PDXearch { local_prepared_query = query.get(); } - std::unique_ptr pruning_distances(new distance_t[max_cluster_size]); - std::unique_ptr pruning_positions(new uint32_t[max_cluster_size]); + std::unique_ptr pruning_distances( + new distance_t[pdx_data.max_cluster_capacity] + ); + std::unique_ptr pruning_positions(new uint32_t[pdx_data.max_cluster_capacity]); for (size_t cluster_idx = 0; cluster_idx < clusters_to_visit; ++cluster_idx) { distance_t pruning_threshold = std::numeric_limits::max(); @@ -745,7 +780,7 @@ class PDXearch { const size_t current_cluster_idx = local_cluster_order[cluster_idx]; auto [selection_vector, passing_tuples] = predicate_evaluator.GetSelectionVector( - current_cluster_idx, cluster_offsets[current_cluster_idx] + current_cluster_idx, pdx_data.cluster_offsets[current_cluster_idx] ); if (passing_tuples == 0) { continue; @@ -754,26 +789,30 @@ class PDXearch { if (cluster.num_embeddings == 0) { continue; } + cluster.n_accessed++; if (local_heap.size() < k) { // We cannot prune until we fill the heap FilteredStart( local_prepared_query, cluster.data, - cluster.num_embeddings, + cluster.used_capacity, + cluster.max_capacity, 
k, cluster.indices, pruning_positions.get(), pruning_distances.get(), local_heap, selection_vector, - passing_tuples + passing_tuples, + cluster.tombstones ); continue; } Warmup( local_prepared_query, cluster.data, - cluster.num_embeddings, + cluster.used_capacity, + cluster.max_capacity, k, selectivity_threshold, pruning_positions.get(), @@ -782,13 +821,15 @@ class PDXearch { local_heap, current_dimension_idx, n_vectors_not_pruned, + cluster.tombstones, passing_tuples, selection_vector ); Prune( local_prepared_query, cluster.data, - cluster.num_embeddings, + cluster.used_capacity, + cluster.max_capacity, k, pruning_positions.get(), pruning_distances.get(), @@ -796,6 +837,7 @@ class PDXearch { local_heap, current_dimension_idx, n_vectors_not_pruned, + cluster.tombstones, selection_vector ); if (n_vectors_not_pruned) { diff --git a/python/lib.cpp b/python/lib.cpp index 60b648c..eab0282 100644 --- a/python/lib.cpp +++ b/python/lib.cpp @@ -47,7 +47,9 @@ PYBIND11_MODULE(compiled, m) { .def("get_num_clusters", &PDX::PyPDXIndex::GetNumClusters) .def("get_cluster_size", &PDX::PyPDXIndex::GetClusterSize, py::arg("cluster_id")) .def("get_cluster_row_ids", &PDX::PyPDXIndex::GetClusterRowIds, py::arg("cluster_id")) - .def("get_in_memory_size_in_bytes", &PDX::PyPDXIndex::GetInMemorySizeInBytes); + .def("get_in_memory_size_in_bytes", &PDX::PyPDXIndex::GetInMemorySizeInBytes) + .def("append", &PDX::PyPDXIndex::Append, py::arg("row_id"), py::arg("embedding")) + .def("delete", &PDX::PyPDXIndex::Delete, py::arg("row_id")); m.def( "load_index", diff --git a/python/pdxearch/index_factory.py b/python/pdxearch/index_factory.py index 5d02706..72b0fc5 100644 --- a/python/pdxearch/index_factory.py +++ b/python/pdxearch/index_factory.py @@ -44,6 +44,12 @@ def filtered_search(self, query: np.ndarray, knn: int, row_ids: np.ndarray, npro np.ascontiguousarray(row_ids, dtype=np.uint64), ) + def append(self, row_id: int, embedding: np.ndarray) -> None: + self._index.append(row_id, 
np.ascontiguousarray(embedding, dtype=np.float32)) + + def delete(self, row_id: int) -> None: + self._index.delete(row_id) + def save(self, path: str) -> None: self._index.save(path) @@ -97,6 +103,12 @@ def filtered_search(self, query: np.ndarray, knn: int, row_ids: np.ndarray, npro np.ascontiguousarray(row_ids, dtype=np.uint64), ) + def append(self, row_id: int, embedding: np.ndarray) -> None: + self._index.append(row_id, np.ascontiguousarray(embedding, dtype=np.float32)) + + def delete(self, row_id: int) -> None: + self._index.delete(row_id) + def save(self, path: str) -> None: self._index.save(path) @@ -151,6 +163,12 @@ def filtered_search(self, query: np.ndarray, knn: int, row_ids: np.ndarray, npro np.ascontiguousarray(row_ids, dtype=np.uint64), ) + def append(self, row_id: int, embedding: np.ndarray) -> None: + self._index.append(row_id, np.ascontiguousarray(embedding, dtype=np.float32)) + + def delete(self, row_id: int) -> None: + self._index.delete(row_id) + def save(self, path: str) -> None: self._index.save(path) @@ -205,6 +223,12 @@ def filtered_search(self, query: np.ndarray, knn: int, row_ids: np.ndarray, npro np.ascontiguousarray(row_ids, dtype=np.uint64), ) + def append(self, row_id: int, embedding: np.ndarray) -> None: + self._index.append(row_id, np.ascontiguousarray(embedding, dtype=np.float32)) + + def delete(self, row_id: int) -> None: + self._index.delete(row_id) + def save(self, path: str) -> None: self._index.save(path) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index dc8b128..508d6cc 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -43,12 +43,16 @@ target_link_libraries(test_filtered_search.out PRIVATE ${TEST_COMMON_LIBS}) add_executable(test_index_properties.out test_index_properties.cpp) target_link_libraries(test_index_properties.out PRIVATE ${TEST_COMMON_LIBS}) +add_executable(test_maintenance.out test_maintenance.cpp) +target_link_libraries(test_maintenance.out PRIVATE ${TEST_COMMON_LIBS}) + 
include(GoogleTest) gtest_discover_tests(test_distance_computers.out) gtest_discover_tests(test_search.out) gtest_discover_tests(test_serialization.out) gtest_discover_tests(test_filtered_search.out) gtest_discover_tests(test_index_properties.out) +gtest_discover_tests(test_maintenance.out) add_custom_target(tests DEPENDS @@ -57,4 +61,5 @@ add_custom_target(tests test_serialization.out test_filtered_search.out test_index_properties.out + test_maintenance.out ) diff --git a/tests/test_maintenance.cpp b/tests/test_maintenance.cpp new file mode 100644 index 0000000..03e5117 --- /dev/null +++ b/tests/test_maintenance.cpp @@ -0,0 +1,145 @@ +#undef HAS_FFTW + +#include +#include +#include +#include +#include + +#include "pdx/index.hpp" +#include "test_utils.hpp" + +namespace { + +static constexpr size_t D = 384; + +template +IndexT BuildTreeIndex(const float* data, size_t n, size_t d) { + PDX::PDXIndexConfig config{ + .num_dimensions = static_cast(d), + .distance_metric = PDX::DistanceMetric::L2SQ, + .seed = TestUtils::SEED, + .normalize = true, + .sampling_fraction = 1.0f, + .hierarchical_indexing = true, + }; + IndexT index(config); + index.BuildIndex(data, n); + return index; +} + +// Test 1: Build with N-1 points, insert the last one, search for it +template +void RunInsertSingleAndSearch() { + auto data = TestUtils::LoadTestData(D); + const size_t n_build = TestUtils::N_TRAIN - 1; + const size_t inserted_row_id = n_build; + + auto index = BuildTreeIndex(data.train.data(), n_build, D); + index.Append(inserted_row_id, data.train.data() + inserted_row_id * D); + index.SetNProbe(0); + + auto results = index.Search(data.train.data() + inserted_row_id * D, TestUtils::KNN); + + bool found = false; + for (const auto& r : results) { + if (r.index == static_cast(inserted_row_id)) { + found = true; + break; + } + } + EXPECT_TRUE(found) << "Inserted point (row_id=" << inserted_row_id + << ") not found in search results"; +} + +// Test 2: Build with N-10 points, insert 10, 
filtered search should return all 10 +template +void RunInsertMultipleAndFilteredSearch() { + auto data = TestUtils::LoadTestData(D); + const size_t n_insert = 10; + const size_t n_build = TestUtils::N_TRAIN - n_insert; + + auto index = BuildTreeIndex(data.train.data(), n_build, D); + + std::vector inserted_ids; + for (size_t i = 0; i < n_insert; ++i) { + size_t row_id = n_build + i; + index.Append(row_id, data.train.data() + row_id * D); + inserted_ids.push_back(row_id); + } + + index.SetNProbe(0); + + // Use the first inserted embedding as query + const float* query = data.train.data() + n_build * D; + auto results = index.FilteredSearch(query, n_insert, inserted_ids); + + std::unordered_set result_ids; + for (const auto& r : results) { + result_ids.insert(r.index); + } + + for (size_t id : inserted_ids) { + EXPECT_TRUE(result_ids.count(static_cast(id))) + << "Inserted point (row_id=" << id << ") not found in filtered search results"; + } +} + +// Test 3: Build with N-1 points, insert 1, delete it, search should not find it +template +void RunInsertDeleteAndSearch() { + auto data = TestUtils::LoadTestData(D); + const size_t n_build = TestUtils::N_TRAIN - 1; + const size_t inserted_row_id = n_build; + + auto index = BuildTreeIndex(data.train.data(), n_build, D); + index.Append(inserted_row_id, data.train.data() + inserted_row_id * D); + index.Delete(inserted_row_id); + index.SetNProbe(0); + + auto results = index.Search(data.train.data() + inserted_row_id * D, TestUtils::KNN); + + for (const auto& r : results) { + EXPECT_NE(r.index, static_cast(inserted_row_id)) + << "Deleted point (row_id=" << inserted_row_id + << ") should not appear in search results"; + } +} + +class MaintenanceTest : public ::testing::TestWithParam {}; + +TEST_P(MaintenanceTest, InsertSingleAndSearch) { + std::string index_type = GetParam(); + if (index_type == "pdx_tree_f32") { + RunInsertSingleAndSearch(); + } else { + RunInsertSingleAndSearch(); + } +} + +TEST_P(MaintenanceTest, 
InsertMultipleAndFilteredSearch) { + std::string index_type = GetParam(); + if (index_type == "pdx_tree_f32") { + RunInsertMultipleAndFilteredSearch(); + } else { + RunInsertMultipleAndFilteredSearch(); + } +} + +TEST_P(MaintenanceTest, InsertDeleteAndSearch) { + std::string index_type = GetParam(); + if (index_type == "pdx_tree_f32") { + RunInsertDeleteAndSearch(); + } else { + RunInsertDeleteAndSearch(); + } +} + +INSTANTIATE_TEST_SUITE_P( + TreeIndexTypes, + MaintenanceTest, + ::testing::Values("pdx_tree_f32", "pdx_tree_u8"), + [](const ::testing::TestParamInfo& info) { return info.param; } +); + +} // namespace