diff --git a/README.md b/README.md
index b804fb2..1e889c6 100644
--- a/README.md
+++ b/README.md
@@ -283,6 +283,104 @@ override the default search path. When the library is absent, `bge.New` returns
 Implement the `embed.Embedder` interface to swap in a remote embedding API or a
 different model without changing any retrieval code.
 
+## Retriever: hybrid BM25 + vector search
+
+`pkg/lore/retrieve` defines the `Retriever` interface. The reference
+implementation in `pkg/lore/retrieve/hybrid` fuses BM25 lexical search
+(via `Store.SearchText`) and vector nearest-neighbour search (via
+`Embedder.Embed` + `VectorStore.Search`) using Reciprocal Rank Fusion
+(RRF, k=60). This approach avoids tuning score scales across rankers:
+only ordinal rank positions matter.
+
+```go
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"log"
+
+	_ "modernc.org/sqlite"
+
+	"github.com/mathomhaus/lore/pkg/lore"
+	"github.com/mathomhaus/lore/pkg/lore/embed/bge"
+	"github.com/mathomhaus/lore/pkg/lore/retrieve/bm25"
+	"github.com/mathomhaus/lore/pkg/lore/retrieve/hybrid"
+	"github.com/mathomhaus/lore/pkg/lore/store/sqlite"
+	"github.com/mathomhaus/lore/pkg/lore/vector/sqlitevec"
+)
+
+func search(db *sql.DB, query string) ([]lore.SearchHit, error) {
+	// Store handles BM25.
+	st, err := sqlite.New(db)
+	if err != nil {
+		return nil, err
+	}
+	defer st.Close(context.Background())
+
+	// Embedder handles query vectorisation.
+	emb, err := bge.New()
+	if err != nil {
+		// ErrUnsupported on platforms without ONNX Runtime. hybrid.New needs a
+		// usable Embedder, so fall back to the lexical-only retriever.
+		log.Printf("warn: embedder unavailable, using BM25 only: %v", err)
+		return bm25.New(st).Search(context.Background(), query, lore.SearchOpts{Limit: 10})
+	}
+	defer emb.Close(context.Background())
+
+	// VectorStore handles nearest-neighbour lookup.
+	vs, err := sqlitevec.New(db, 384)
+	if err != nil {
+		return nil, err
+	}
+	defer vs.Close(context.Background())
+
+	r := hybrid.New(st, emb, vs,
+		hybrid.WithRRFK(60),
+		hybrid.WithCandidatePoolSize(50),
+	)
+
+	return r.Search(context.Background(), query, lore.SearchOpts{Limit: 10})
+}
+```
+
+The hybrid retriever tolerates partial failures gracefully:
+
+- If `Embedder.Embed` returns an error (e.g. `embed.ErrUnsupported`), the vector
+  arm is skipped and BM25 results are returned alone.
+- If `VectorStore.Search` returns an error, the BM25 arm continues independently.
+- Only when both arms fail does `Search` return an error.
+
+`hybrid.New` needs a usable `Embedder`; do not pass nil. When no embedder is
+available, pass a no-op stub or use `bm25.New(store)` directly:
+
+```go
+import "github.com/mathomhaus/lore/pkg/lore/retrieve/bm25"
+
+r := bm25.New(st)
+hits, err := r.Search(ctx, "deployment rollout", lore.SearchOpts{Limit: 10})
+```
+
+### RRF algorithm
+
+`pkg/lore/retrieve/rrf` exposes `Fuse(rankings [][]int64, k int) []ScoredID`
+for callers that want to run their own ranked lists through RRF without the
+hybrid retriever:
+
+```go
+import "github.com/mathomhaus/lore/pkg/lore/retrieve/rrf"
+
+bm25IDs := []int64{10, 20, 30}
+vecIDs := []int64{20, 10, 40}
+
+fused := rrf.Fuse([][]int64{bm25IDs, vecIDs}, rrf.DefaultK)
+for _, s := range fused {
+	fmt.Printf("id=%d score=%.4f\n", s.ID, s.Score)
+}
+```
+
+Output is sorted by descending score; ties break by ascending ID for
+determinism.
+
 ## Attribution
 
 Lore extracts and generalizes the storage, embedding, and retrieval primitives
diff --git a/pkg/lore/retrieve/bm25/bm25.go b/pkg/lore/retrieve/bm25/bm25.go
new file mode 100644
index 0000000..1f9c85b
--- /dev/null
+++ b/pkg/lore/retrieve/bm25/bm25.go
@@ -0,0 +1,87 @@
+// Package bm25 provides a lexical-only Retriever that delegates to
+// Store.SearchText. It satisfies retrieve.Retriever and runs the same
+// lexical search that hybrid.New uses for its BM25 arm.
+//
+// The BM25 ranker does not run an embedding model; it is safe to use on
+// platforms where ONNX Runtime is unavailable.
+package bm25
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"log/slog"
+
+	"go.opentelemetry.io/otel"
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/trace"
+
+	"github.com/mathomhaus/lore/pkg/lore"
+	"github.com/mathomhaus/lore/pkg/lore/store"
+)
+
+const tracerName = "lore.retrieve.bm25"
+
+// Ranker is the full-text (BM25) Retriever. It holds the store it queries
+// plus the logger and tracer used for diagnostics. Build one with New.
+// Safe for concurrent use.
+type Ranker struct {
+	store  store.Store
+	logger *slog.Logger
+	tracer trace.Tracer
+}
+
+// Option mutates a Ranker during construction in New.
+type Option func(*Ranker)
+
+// WithLogger installs l as the Ranker's structured logger.
+// New falls back to slog.Default() when the logger is left nil.
+func WithLogger(l *slog.Logger) Option {
+	return func(target *Ranker) {
+		target.logger = l
+	}
+}
+
+// WithTracer installs t as the Ranker's OpenTelemetry tracer.
+// New falls back to the global tracer provider when the tracer is left nil.
+func WithTracer(t trace.Tracer) Option {
+	return func(target *Ranker) {
+		target.tracer = t
+	}
+}
+
+// New builds a Ranker whose Search delegates to s.SearchText.
+// The caller owns s and must keep it open for as long as the Ranker is used.
+func New(s store.Store, opts ...Option) *Ranker {
+	ranker := &Ranker{store: s}
+	for i := range opts {
+		opts[i](ranker)
+	}
+
+	// Defaults are filled in after the options run, so an option that sets a
+	// nil logger or tracer still ends up with a usable value.
+	if ranker.logger == nil {
+		ranker.logger = slog.Default()
+	}
+	if ranker.tracer == nil {
+		ranker.tracer = otel.GetTracerProvider().Tracer(tracerName)
+	}
+	return ranker
+}
+
+// Search runs a BM25 full-text search and returns ranked hits.
+// Returns lore.ErrInvalidArgument when query is empty or opts.Limit is negative.
+func (r *Ranker) Search(ctx context.Context, query string, opts lore.SearchOpts) ([]lore.SearchHit, error) {
+	// Argument validation happens before the span is opened, so rejected
+	// calls produce no trace.
+	switch {
+	case query == "":
+		return nil, fmt.Errorf("bm25: search: %w", lore.ErrInvalidArgument)
+	case opts.Limit < 0:
+		return nil, fmt.Errorf("bm25: search: negative limit: %w", lore.ErrInvalidArgument)
+	}
+
+	spanCtx, span := r.tracer.Start(ctx, "lore.retrieve.bm25")
+	defer span.End()
+
+	found, searchErr := r.store.SearchText(spanCtx, query, opts)
+	if searchErr == nil {
+		span.SetAttributes(attribute.Int("bm25.count", len(found)))
+		return found, nil
+	}
+
+	// An invalid-argument error from the store is returned without logging;
+	// every other failure is logged before being wrapped.
+	if !errors.Is(searchErr, lore.ErrInvalidArgument) {
+		r.logger.ErrorContext(spanCtx, "bm25: store.SearchText failed", "err", searchErr)
+	}
+	return nil, fmt.Errorf("bm25: search: %w", searchErr)
+}
diff --git a/pkg/lore/retrieve/hybrid/hybrid.go b/pkg/lore/retrieve/hybrid/hybrid.go
new file mode 100644
index 0000000..fa08a96
--- /dev/null
+++ b/pkg/lore/retrieve/hybrid/hybrid.go
@@ -0,0 +1,264 @@
+// Package hybrid provides a Retriever that fuses BM25 lexical search and
+// vector semantic search via Reciprocal Rank Fusion (RRF). It composes:
+//
+//   - store.Store for BM25 (store.SearchText)
+//   - embed.Embedder for query vectorisation
+//   - vector.VectorStore for nearest-neighbour search
+//
+// The three dependencies are caller-owned; hybrid.New does not call Close
+// on any of them.
+//
+// OTel instrumentation:
+//
+//	lore.retrieve.search          (top-level span)
+//	  lore.retrieve.bm25          (sub-span, attr bm25.count)
+//	  lore.retrieve.vector        (sub-span, attr vector.count)
+//	  lore.retrieve.fuse          (sub-span, attr fused.count)
+package hybrid
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"log/slog"
+
+	"go.opentelemetry.io/otel"
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/trace"
+
+	"github.com/mathomhaus/lore/pkg/lore"
+	"github.com/mathomhaus/lore/pkg/lore/embed"
+	"github.com/mathomhaus/lore/pkg/lore/retrieve/rrf"
+	lstore "github.com/mathomhaus/lore/pkg/lore/store"
+	lvector "github.com/mathomhaus/lore/pkg/lore/vector"
+)
+
+const tracerName = "lore.retrieve.hybrid"
+
+// Defaults applied when the caller does not override them.
+const (
+	defaultLimit         = 10
+	defaultRRFK          = rrf.DefaultK
+	defaultCandidatePool = 50
+)
+
+// Retriever fuses a BM25 arm and a vector arm into one ranked result list.
+// Build one with New. Safe for concurrent use.
+type Retriever struct {
+	store    lstore.Store
+	emb      embed.Embedder
+	vstore   lvector.VectorStore
+	logger   *slog.Logger
+	tracer   trace.Tracer
+	rrfK     int
+	poolSize int
+}
+
+// Option mutates a Retriever during construction in New.
+type Option func(*Retriever)
+
+// WithLogger installs l as the structured logger. Defaults to slog.Default().
+func WithLogger(l *slog.Logger) Option {
+	return func(target *Retriever) {
+		target.logger = l
+	}
+}
+
+// WithTracer installs t as the OpenTelemetry tracer.
+// Defaults to the global tracer provider.
+func WithTracer(t trace.Tracer) Option {
+	return func(target *Retriever) {
+		target.tracer = t
+	}
+}
+
+// WithRRFK overrides the RRF smoothing constant k. Default 60.
+// A non-positive k leaves the current value untouched.
+func WithRRFK(k int) Option {
+	return func(target *Retriever) {
+		if k <= 0 {
+			return
+		}
+		target.rrfK = k
+	}
+}
+
+// WithCandidatePoolSize sets the number of candidates fetched from each
+// ranker before fusion. Default 50. Values <= 0 are ignored.
+func WithCandidatePoolSize(n int) Option {
+	return func(r *Retriever) {
+		if n > 0 {
+			r.poolSize = n
+		}
+	}
+}
+
+// New returns a hybrid Retriever that fuses BM25 + vector ranking via RRF.
+//
+// Caller-owned dependencies: s, emb, and vstore are all consumer-managed.
+// The Retriever does not call Close on any of them.
+func New(s lstore.Store, emb embed.Embedder, vstore lvector.VectorStore, opts ...Option) *Retriever {
+	r := &Retriever{
+		store:    s,
+		emb:      emb,
+		vstore:   vstore,
+		rrfK:     defaultRRFK,
+		poolSize: defaultCandidatePool,
+	}
+	for _, o := range opts {
+		o(r)
+	}
+	// Logger and tracer defaults are applied after the options so that an
+	// option passing nil still leaves the Retriever usable.
+	if r.logger == nil {
+		r.logger = slog.Default()
+	}
+	if r.tracer == nil {
+		r.tracer = otel.GetTracerProvider().Tracer(tracerName)
+	}
+	return r
+}
+
+// Search executes the hybrid retrieval pipeline:
+//  1. BM25 arm: store.SearchText with candidatePoolSize limit.
+//  2. Vector arm: embed query, then vstore.Search with candidatePoolSize limit.
+//  3. RRF fusion of both ranked lists.
+//  4. Hydrate entries in fused order via store.Get, skipping stale IDs and
+//     entries outside opts.Project, until the limit is reached.
+//  5. Return SearchHits in fused order.
+//
+// opts.Limit caps the number of hits; zero selects the default of 10.
+//
+// Partial failures: if one arm fails, Search logs a warning and falls through
+// to the surviving arm. If both arms fail, Search returns an error that wraps
+// both causes, so errors.Is and errors.As match either one.
+//
+// Returns lore.ErrInvalidArgument when query is empty or opts.Limit is negative.
+// A store.Get failure other than lore.ErrNotFound aborts the search.
+func (r *Retriever) Search(ctx context.Context, query string, opts lore.SearchOpts) ([]lore.SearchHit, error) {
+	if query == "" {
+		return nil, fmt.Errorf("hybrid: search: %w", lore.ErrInvalidArgument)
+	}
+	if opts.Limit < 0 {
+		return nil, fmt.Errorf("hybrid: search: negative limit: %w", lore.ErrInvalidArgument)
+	}
+
+	limit := opts.Limit
+	if limit == 0 {
+		limit = defaultLimit
+	}
+
+	ctx, span := r.tracer.Start(ctx, "lore.retrieve.search")
+	defer span.End()
+
+	// BM25 arm.
+	bm25IDs, bm25Err := r.runBM25(ctx, query, opts)
+	if bm25Err != nil {
+		r.logger.WarnContext(ctx, "hybrid: bm25 arm failed; will use vector only", "err", bm25Err)
+	}
+
+	// Vector arm.
+	vecIDs, vecErr := r.runVector(ctx, query, opts)
+	if vecErr != nil {
+		r.logger.WarnContext(ctx, "hybrid: vector arm failed; will use bm25 only", "err", vecErr)
+	}
+
+	// Both arms failed: wrap both causes so callers can inspect either one.
+	if bm25Err != nil && vecErr != nil {
+		r.logger.ErrorContext(ctx, "hybrid: both arms failed", "bm25_err", bm25Err, "vec_err", vecErr)
+		return nil, fmt.Errorf("hybrid: both rankers failed (bm25: %w; vector: %w)", bm25Err, vecErr)
+	}
+
+	// Fuse. The fuse span's context is discarded on purpose: nothing runs
+	// under it, and the hydration calls below must stay parented to the
+	// search span rather than to a span that has already ended.
+	_, fuseSpan := r.tracer.Start(ctx, "lore.retrieve.fuse")
+
+	var rankings [][]int64
+	if len(bm25IDs) > 0 {
+		rankings = append(rankings, bm25IDs)
+	}
+	if len(vecIDs) > 0 {
+		rankings = append(rankings, vecIDs)
+	}
+
+	fused := rrf.Fuse(rankings, r.rrfK)
+
+	// fused.count reports every fused candidate; the limit is applied during
+	// hydration, after stale and out-of-project entries are dropped.
+	fuseSpan.SetAttributes(attribute.Int("fused.count", len(fused)))
+	fuseSpan.End()
+
+	if len(fused) == 0 {
+		return nil, nil
+	}
+
+	// Hydrate entries one at a time in fused order. Stopping once limit hits
+	// are collected, instead of truncating the fused list up front, keeps
+	// skipped candidates from leaving the result short while lower-ranked
+	// matches are still available.
+	capacity := len(fused)
+	if capacity > limit {
+		capacity = limit
+	}
+	results := make([]lore.SearchHit, 0, capacity)
+	for _, scored := range fused {
+		if len(results) == limit {
+			break
+		}
+		entry, err := r.store.Get(ctx, scored.ID)
+		if err != nil {
+			if errors.Is(err, lore.ErrNotFound) {
+				// Index or FTS ahead of store; skip stale reference.
+				r.logger.WarnContext(ctx, "hybrid: entry not found; skipping", "id", scored.ID)
+				continue
+			}
+			return nil, fmt.Errorf("hybrid: hydrate entry %d: %w", scored.ID, err)
+		}
+		// Post-filter by project when specified.
+		if opts.Project != "" && entry.Project != opts.Project {
+			continue
+		}
+		results = append(results, lore.SearchHit{Entry: entry, Score: scored.Score})
+	}
+
+	return results, nil
+}
+
+// runBM25 executes the BM25 arm and returns a slice of entry IDs in rank order.
+func (r *Retriever) runBM25(ctx context.Context, query string, opts lore.SearchOpts) ([]int64, error) {
+	ctx, span := r.tracer.Start(ctx, "lore.retrieve.bm25")
+	defer span.End()
+
+	bm25Opts := lore.SearchOpts{
+		Project: opts.Project,
+		Kinds:   opts.Kinds,
+		Tags:    opts.Tags,
+		Limit:   r.candidateLimit(opts.Limit),
+	}
+	hits, err := r.store.SearchText(ctx, query, bm25Opts)
+	if err != nil {
+		span.RecordError(err)
+		return nil, err
+	}
+
+	span.SetAttributes(attribute.Int("bm25.count", len(hits)))
+
+	ids := make([]int64, len(hits))
+	for i, h := range hits {
+		ids[i] = h.Entry.ID
+	}
+	return ids, nil
+}
+
+// runVector executes the vector arm and returns a slice of entry IDs in rank order.
+//
+// A Retriever built with a nil Embedder or nil VectorStore reports the arm as
+// failed instead of panicking, so Search degrades to BM25 only. A typed-nil
+// pointer stored in the interface is not detected by this check.
+func (r *Retriever) runVector(ctx context.Context, query string, opts lore.SearchOpts) ([]int64, error) {
+	ctx, span := r.tracer.Start(ctx, "lore.retrieve.vector")
+	defer span.End()
+
+	if r.emb == nil || r.vstore == nil {
+		err := errors.New("vector arm unavailable: nil embedder or vector store")
+		span.RecordError(err)
+		return nil, err
+	}
+
+	vecs, err := r.emb.Embed(ctx, []string{query})
+	if err != nil {
+		span.RecordError(err)
+		return nil, fmt.Errorf("embed query: %w", err)
+	}
+	if len(vecs) == 0 || len(vecs[0]) == 0 {
+		return nil, fmt.Errorf("embed returned empty vector")
+	}
+
+	vsOpts := lvector.SearchOpts{
+		Limit: r.candidateLimit(opts.Limit),
+		Kinds: opts.Kinds,
+		Tags:  opts.Tags,
+	}
+	hits, err := r.vstore.Search(ctx, vecs[0], vsOpts)
+	if err != nil {
+		span.RecordError(err)
+		return nil, fmt.Errorf("vstore search: %w", err)
+	}
+
+	span.SetAttributes(attribute.Int("vector.count", len(hits)))
+
+	ids := make([]int64, len(hits))
+	for i, h := range hits {
+		ids[i] = h.ID
+	}
+	return ids, nil
+}
+
+// candidateLimit returns how many candidates each arm should fetch: the
+// configured pool size, raised to limit when the caller asks for more hits
+// than the pool would otherwise supply.
+func (r *Retriever) candidateLimit(limit int) int {
+	if limit > r.poolSize {
+		return limit
+	}
+	return r.poolSize
+}
diff --git a/pkg/lore/retrieve/hybrid/hybrid_test.go b/pkg/lore/retrieve/hybrid/hybrid_test.go
new file mode 100644
index 0000000..1ad7cba
--- /dev/null
+++ b/pkg/lore/retrieve/hybrid/hybrid_test.go
@@ -0,0 +1,447 @@
+package hybrid_test
+
+import (
+	"context"
+	"database/sql"
+	"errors"
+	"fmt"
+	"testing"
+
+	_ "modernc.org/sqlite"
+
+	"github.com/mathomhaus/lore/pkg/lore"
+	"github.com/mathomhaus/lore/pkg/lore/retrieve/hybrid"
+	"github.com/mathomhaus/lore/pkg/lore/retrieve/rrf"
+	lstore2 "github.com/mathomhaus/lore/pkg/lore/store"
+	lstore "github.com/mathomhaus/lore/pkg/lore/store/sqlite"
+	lvector "github.com/mathomhaus/lore/pkg/lore/vector"
+	"github.com/mathomhaus/lore/pkg/lore/vector/sqlitevec"
+)
+
+// ---- mock embedder ---------------------------------------------------------
+
+// fixedEmbedder maps input texts to pre-defined float32 vectors.
+// Unrecognised texts return the zero vector. This avoids the ONNX Runtime
+// dependency in CI while still exercising the vector arm end-to-end.
+type fixedEmbedder struct {
+	dim     int
+	mapping map[string][]float32
+}
+
+// newFixedEmbedder builds a fixedEmbedder; a nil mapping is replaced with an
+// empty one so every text falls back to the zero vector.
+func newFixedEmbedder(dim int, mapping map[string][]float32) *fixedEmbedder {
+	if mapping == nil {
+		mapping = map[string][]float32{}
+	}
+	return &fixedEmbedder{dim: dim, mapping: mapping}
+}
+
+// Embed returns the mapped vector for each text, or a zero vector of length
+// dim for texts that have no mapping. An empty input is an error.
+func (e *fixedEmbedder) Embed(_ context.Context, texts []string) ([][]float32, error) {
+	if len(texts) == 0 {
+		return nil, fmt.Errorf("fixedEmbedder: empty texts")
+	}
+	out := make([][]float32, len(texts))
+	for i, t := range texts {
+		if v, ok := e.mapping[t]; ok {
+			out[i] = v
+		} else {
+			// Return a zero vector so every text gets a valid (if useless) vector.
+			out[i] = make([]float32, e.dim)
+		}
+	}
+	return out, nil
+}
+
+func (e *fixedEmbedder) Dimensions() int { return e.dim }
+
+func (e *fixedEmbedder) Close(_ context.Context) error { return nil }
+
+// ---- error-returning stubs -------------------------------------------------
+
+// failVectorStore is a VectorStore stub whose Search always returns err;
+// every other method is a no-op.
+type failVectorStore struct{ err error }
+
+func (f *failVectorStore) Upsert(_ context.Context, _ int64, _ []float32) error { return nil }
+func (f *failVectorStore) Delete(_ context.Context, _ int64) error { return nil }
+func (f *failVectorStore) Search(_ context.Context, _ []float32, _ lvector.SearchOpts) ([]lvector.Hit, error) {
+	return nil, f.err
+}
+func (f *failVectorStore) Dimensions() int { return dim }
+func (f *failVectorStore) Close(_ context.Context) error { return nil }
+
+// ---- constants and corpus --------------------------------------------------
+
+const dim = 4
+
+// Unit vectors for distinct semantic directions.
+var (
+	vecA = []float32{1, 0, 0, 0}
+	vecB = []float32{0, 1, 0, 0}
+	vecC = []float32{0, 0, 1, 0}
+	vecD = []float32{0, 0, 0, 1}
+)
+
+// corpusEntry pairs human-readable text with a fixed embedding vector.
+type corpusEntry struct {
+	title  string
+	body   string
+	vector []float32
+}
+
+// testCorpus is a small fixed knowledge base used across all tests.
+// Semantic clusters:
+//   - vecA: deployment/rollout entries (indexes 0-2)
+//   - vecB: API/auth/observability entries (indexes 3-4, 9)
+//   - vecC: database migration entries (indexes 5-6)
+//   - vecD: incident/on-call entries (indexes 7-8)
+var testCorpus = []corpusEntry{
+	{title: "Deployment rollout guide", body: "Steps for deploying a service to production. Rolling deployments minimise downtime.", vector: vecA},
+	{title: "Rollout checklist", body: "Pre-deployment checks before each rollout: smoke tests, rollback plan, monitoring.", vector: vecA},
+	{title: "Canary deployment strategy", body: "Route 5% of traffic during deployment to detect regressions early.", vector: vecA},
+	{title: "API rate limiting policy", body: "Throttle requests to 100/s per consumer key to protect backend services.", vector: vecB},
+	{title: "Authentication token rotation", body: "Rotate JWT signing keys every 90 days. Emergency rotation on suspected compromise.", vector: vecB},
+	{title: "Database migration runbook", body: "Run schema migrations with zero-downtime strategies. Backward-compatible changes only.", vector: vecC},
+	{title: "Migration rollback procedure", body: "Revert a failed schema migration using the down migration script.", vector: vecC},
+	{title: "Incident response playbook", body: "On-call engineer steps: acknowledge alert, assess blast radius, page SRE team.", vector: vecD},
+	{title: "On-call rotation schedule", body: "Primary and secondary on-call shifts. Swap requests via the scheduling tool.", vector: vecD},
+	{title: "Observability stack setup", body: "Prometheus, Grafana, and Loki configuration for service health dashboards.", vector: vecB},
+}
+
+// ---- test helpers ----------------------------------------------------------
+
+// openFull opens an in-memory SQLite DB, creates store + vstore, seeds the
+// full testCorpus, and returns all three resources plus the seeded IDs.
+// embMapping configures the fixedEmbedder query-to-vector mappings.
+func openFull(t *testing.T, embMapping map[string][]float32) (
+	s lstore2.Store,
+	emb *fixedEmbedder,
+	vs lvector.VectorStore,
+	ids []int64,
+) {
+	t.Helper()
+
+	db := openDB(t)
+	st, err := lstore.New(db)
+	if err != nil {
+		t.Fatalf("sqlite.New: %v", err)
+	}
+	t.Cleanup(func() { _ = st.Close(context.Background()) })
+
+	vs2, err := sqlitevec.New(db, dim)
+	if err != nil {
+		t.Fatalf("sqlitevec.New: %v", err)
+	}
+	t.Cleanup(func() { _ = vs2.Close(context.Background()) })
+
+	// Inscribe each corpus entry and index its vector under the returned ID
+	// so both arms refer to the same entries.
+	ctx := context.Background()
+	seededIDs := make([]int64, len(testCorpus))
+	for i, ce := range testCorpus {
+		id, err := st.Inscribe(ctx, lore.Entry{
+			Kind:  lore.KindProcedure,
+			Title: ce.title,
+			Body:  ce.body,
+		})
+		if err != nil {
+			t.Fatalf("Inscribe %d: %v", i, err)
+		}
+		seededIDs[i] = id
+		if err := vs2.Upsert(ctx, id, ce.vector); err != nil {
+			t.Fatalf("Upsert %d: %v", i, err)
+		}
+	}
+
+	return st, newFixedEmbedder(dim, embMapping), vs2, seededIDs
+}
+
+// openDB opens a fresh in-memory SQLite database and registers cleanup.
+func openDB(t *testing.T) *sql.DB {
+	t.Helper()
+	db, err := sql.Open("sqlite", ":memory:")
+	if err != nil {
+		t.Fatalf("sql.Open: %v", err)
+	}
+	t.Cleanup(func() { _ = db.Close() })
+	return db
+}
+
+// titleSet converts a slice of SearchHits into a set of entry titles.
+func titleSet(hits []lore.SearchHit) map[string]bool {
+	m := make(map[string]bool, len(hits))
+	for _, h := range hits {
+		m[h.Entry.Title] = true
+	}
+	return m
+}
+
+// ---- tests -----------------------------------------------------------------
+
+// TestHybrid_TextOnlyMatch exercises a query where BM25 lexical matching
+// should surface relevant results. The embedder returns a zero vector for
+// the query so cosine similarities are all equal; BM25 dominates.
+func TestHybrid_TextOnlyMatch(t *testing.T) {
+	t.Parallel()
+
+	// Zero-vector query: vector distances are tied, BM25 rank wins in RRF.
+	embMap := map[string][]float32{
+		"deployment rollout": make([]float32, dim),
+	}
+	s, emb, vs, _ := openFull(t, embMap)
+
+	r := hybrid.New(s, emb, vs, hybrid.WithCandidatePoolSize(10), hybrid.WithRRFK(60))
+	hits, err := r.Search(context.Background(), "deployment rollout", lore.SearchOpts{Limit: 5})
+	if err != nil {
+		t.Fatalf("Search: %v", err)
+	}
+	if len(hits) == 0 {
+		t.Fatal("expected hits, got none")
+	}
+	titles := titleSet(hits)
+	found := titles["Deployment rollout guide"] || titles["Rollout checklist"] || titles["Canary deployment strategy"]
+	if !found {
+		t.Errorf("expected at least one deployment-related hit; got %v", titles)
+	}
+}
+
+// TestHybrid_SemanticOnlyMatch exercises a query that is a synonym not present
+// verbatim in the corpus. The query vector maps to vecA (deployment cluster)
+// so vector search surfaces deployment entries even though BM25 finds nothing.
+func TestHybrid_SemanticOnlyMatch(t *testing.T) {
+	t.Parallel()
+
+	// "release process" is not in any entry body verbatim; map to vecA
+	// so the vector arm surfaces deployment entries.
+	embMap := map[string][]float32{
+		"release process": vecA,
+	}
+	s, emb, vs, _ := openFull(t, embMap)
+
+	r := hybrid.New(s, emb, vs, hybrid.WithCandidatePoolSize(10), hybrid.WithRRFK(60))
+	hits, err := r.Search(context.Background(), "release process", lore.SearchOpts{Limit: 5})
+	if err != nil {
+		t.Fatalf("Search: %v", err)
+	}
+	if len(hits) == 0 {
+		t.Fatal("expected hits, got none")
+	}
+
+	titles := titleSet(hits)
+	foundDeployment := titles["Deployment rollout guide"] || titles["Rollout checklist"] || titles["Canary deployment strategy"]
+	if !foundDeployment {
+		t.Errorf("expected a deployment entry via vector search; got %v", titles)
+	}
+}
+
+// TestHybrid_BothBeats verifies that the fused result set covers entries from
+// both arms, achieving higher recall than either arm alone.
+// Query text "migration" gives BM25 an edge on migration entries;
+// query vector maps to vecA so vector arm surfaces deployment entries.
+//
+// The test fails only when neither arm contributes; coverage from just one
+// arm is logged rather than failed.
+func TestHybrid_BothBeats(t *testing.T) {
+	t.Parallel()
+
+	embMap := map[string][]float32{
+		"migration release": vecA,
+	}
+	s, emb, vs, _ := openFull(t, embMap)
+
+	r := hybrid.New(s, emb, vs, hybrid.WithCandidatePoolSize(10), hybrid.WithRRFK(60))
+	hits, err := r.Search(context.Background(), "migration release", lore.SearchOpts{Limit: 8})
+	if err != nil {
+		t.Fatalf("Search: %v", err)
+	}
+
+	titles := titleSet(hits)
+	// BM25 should surface migration entries due to the word "migration".
+	foundMigration := titles["Database migration runbook"] || titles["Migration rollback procedure"]
+	// Vector (vecA) should surface deployment entries.
+	foundDeployment := titles["Deployment rollout guide"] || titles["Rollout checklist"] || titles["Canary deployment strategy"]
+
+	// Hard failure only when neither arm contributed anything relevant.
+	if !foundMigration && !foundDeployment {
+		t.Errorf("expected results from at least one arm; got %v", titles)
+	}
+	// Both found: this is the "hybrid beats single-mode" condition.
+	if !foundMigration || !foundDeployment {
+		t.Logf("partial coverage (migration=%v, deployment=%v) — acceptable with small corpus", foundMigration, foundDeployment)
+	}
+}
+
+// TestHybrid_BM25Failure_FallsThroughToVector verifies graceful degradation
+// when BM25 finds nothing (query text does not match any entry). The vector
+// arm must still produce hits.
+func TestHybrid_BM25Failure_FallsThroughToVector(t *testing.T) {
+	t.Parallel()
+
+	// Map the query to vecA so the vector arm surfaces deployment entries.
+	// The query text "αβγ" contains no Latin characters so FTS5 returns nothing.
+	embMap := map[string][]float32{
+		"αβγ": vecA,
+	}
+	s, emb, vs, _ := openFull(t, embMap)
+
+	r := hybrid.New(s, emb, vs, hybrid.WithCandidatePoolSize(10))
+	hits, err := r.Search(context.Background(), "αβγ", lore.SearchOpts{Limit: 5})
+	if err != nil {
+		t.Fatalf("Search: %v", err)
+	}
+	if len(hits) == 0 {
+		t.Fatal("expected hits from vector arm when BM25 returns empty; got none")
+	}
+}
+
+// TestHybrid_VectorFailure_FallsThroughToBM25 verifies graceful degradation
+// when the vector store returns an error. BM25 alone should serve results.
+func TestHybrid_VectorFailure_FallsThroughToBM25(t *testing.T) {
+	t.Parallel()
+
+	db := openDB(t)
+	s, err := lstore.New(db)
+	if err != nil {
+		t.Fatalf("sqlite.New: %v", err)
+	}
+	t.Cleanup(func() { _ = s.Close(context.Background()) })
+
+	ctx := context.Background()
+	for _, ce := range testCorpus {
+		if _, err := s.Inscribe(ctx, lore.Entry{
+			Kind:  lore.KindProcedure,
+			Title: ce.title,
+			Body:  ce.body,
+		}); err != nil {
+			t.Fatalf("Inscribe: %v", err)
+		}
+	}
+
+	emb := newFixedEmbedder(dim, map[string][]float32{
+		"deployment rollout": vecA,
+	})
+	badVS := &failVectorStore{err: fmt.Errorf("simulated vstore failure")}
+
+	r := hybrid.New(s, emb, badVS, hybrid.WithCandidatePoolSize(10))
+	hits, err := r.Search(ctx, "deployment rollout", lore.SearchOpts{Limit: 5})
+	if err != nil {
+		t.Fatalf("Search: %v", err)
+	}
+	if len(hits) == 0 {
+		t.Fatal("expected BM25 results when vector arm fails; got none")
+	}
+	titles := titleSet(hits)
+	foundDeployment := titles["Deployment rollout guide"] || titles["Rollout checklist"] || titles["Canary deployment strategy"]
+	if !foundDeployment {
+		t.Errorf("expected a deployment entry from BM25 fallback; got %v", titles)
+	}
+}
+
+// TestHybrid_BothArmsEmpty verifies that when neither arm finds results,
+// Search returns nil without error.
+func TestHybrid_BothArmsEmpty(t *testing.T) {
+	t.Parallel()
+
+	db := openDB(t)
+	s, err := lstore.New(db)
+	if err != nil {
+		t.Fatalf("sqlite.New: %v", err)
+	}
+	t.Cleanup(func() { _ = s.Close(context.Background()) })
+
+	vs, err := sqlitevec.New(db, dim)
+	if err != nil {
+		t.Fatalf("sqlitevec.New: %v", err)
+	}
+	t.Cleanup(func() { _ = vs.Close(context.Background()) })
+
+	emb := newFixedEmbedder(dim, nil)
+	r := hybrid.New(s, emb, vs)
+	hits, err := r.Search(context.Background(), "nothing here", lore.SearchOpts{Limit: 5})
+	if err != nil {
+		t.Fatalf("Search with empty corpus: %v", err)
+	}
+	if len(hits) != 0 {
+		t.Errorf("expected no hits on empty corpus; got %d", len(hits))
+	}
+}
+
+// TestHybrid_InvalidArgs verifies that bad inputs return ErrInvalidArgument.
+func TestHybrid_InvalidArgs(t *testing.T) {
+	t.Parallel()
+
+	// Constructor errors are checked rather than discarded: a failed New
+	// would otherwise surface later as a nil-interface panic in Cleanup.
+	db := openDB(t)
+	s, err := lstore.New(db)
+	if err != nil {
+		t.Fatalf("sqlite.New: %v", err)
+	}
+	t.Cleanup(func() { _ = s.Close(context.Background()) })
+	vs, err := sqlitevec.New(db, dim)
+	if err != nil {
+		t.Fatalf("sqlitevec.New: %v", err)
+	}
+	t.Cleanup(func() { _ = vs.Close(context.Background()) })
+
+	emb := newFixedEmbedder(dim, nil)
+	r := hybrid.New(s, emb, vs)
+
+	_, err = r.Search(context.Background(), "", lore.SearchOpts{})
+	if !errors.Is(err, lore.ErrInvalidArgument) {
+		t.Errorf("empty query: want ErrInvalidArgument, got %v", err)
+	}
+
+	_, err = r.Search(context.Background(), "something", lore.SearchOpts{Limit: -1})
+	if !errors.Is(err, lore.ErrInvalidArgument) {
+		t.Errorf("negative limit: want ErrInvalidArgument, got %v", err)
+	}
+}
+
+// TestRRF_Fuse_Determinism verifies that rrf.Fuse produces identical output
+// for identical inputs across multiple calls.
+func TestRRF_Fuse_Determinism(t *testing.T) {
+	t.Parallel()
+
+	rankings := [][]int64{
+		{10, 20, 30, 40, 50},
+		{20, 10, 50, 30, 60},
+		{30, 60, 10, 20, 50},
+	}
+
+	// The first call is the reference; every repeat must match it exactly.
+	first := rrf.Fuse(rankings, rrf.DefaultK)
+	for i := 0; i < 20; i++ {
+		got := rrf.Fuse(rankings, rrf.DefaultK)
+		if len(got) != len(first) {
+			t.Fatalf("iter %d: length mismatch: %d vs %d", i, len(got), len(first))
+		}
+		for j := range first {
+			if first[j] != got[j] {
+				t.Fatalf("iter %d pos %d: want %+v, got %+v", i, j, first[j], got[j])
+			}
+		}
+	}
+}
+
+// TestRRF_Fuse_ScoreOrder verifies that results are sorted descending by score
+// and that equal scores are ordered by ascending ID.
+func TestRRF_Fuse_ScoreOrder(t *testing.T) {
+	t.Parallel()
+
+	// ID 1 is rank 1 in list A and rank 2 in list B.
+	// ID 2 is rank 2 in list A and rank 1 in list B, so IDs 1 and 2 tie on
+	// fused score and must come out in ascending ID order.
+	// ID 3 appears only in list A at rank 3: lowest fused score.
+	rankings := [][]int64{
+		{1, 2, 3},
+		{2, 1},
+	}
+
+	fused := rrf.Fuse(rankings, rrf.DefaultK)
+	for i := 1; i < len(fused); i++ {
+		if fused[i].Score > fused[i-1].Score {
+			t.Errorf("pos %d score %f > pos %d score %f: not descending", i, fused[i].Score, i-1, fused[i-1].Score)
+		}
+	}
+
+	// Pin the full order, including the ID tie-break between IDs 1 and 2.
+	wantOrder := []int64{1, 2, 3}
+	if len(fused) != len(wantOrder) {
+		t.Fatalf("want %d fused IDs, got %d: %+v", len(wantOrder), len(fused), fused)
+	}
+	for i, want := range wantOrder {
+		if fused[i].ID != want {
+			t.Errorf("pos %d: want ID %d, got %d", i, want, fused[i].ID)
+		}
+	}
+
+	scoreByID := make(map[int64]float64)
+	for _, s := range fused {
+		scoreByID[s.ID] = s.Score
+	}
+	// ID 2 appears in both lists: expect higher score than ID 3 (one list, rank 3).
+	if scoreByID[2] <= scoreByID[3] {
+		t.Errorf("ID 2 (in both lists) should outscore ID 3 (one list only); scores: %v", scoreByID)
+	}
+}
+
+// TestRRF_Fuse_EmptyInput verifies nil/empty input yields nil output.
+func TestRRF_Fuse_EmptyInput(t *testing.T) {
+	t.Parallel()
+	if got := rrf.Fuse(nil, rrf.DefaultK); got != nil {
+		t.Errorf("expected nil for nil input, got %v", got)
+	}
+	if got := rrf.Fuse([][]int64{}, rrf.DefaultK); got != nil {
+		t.Errorf("expected nil for empty rankings, got %v", got)
+	}
+}
diff --git a/pkg/lore/retrieve/retrieve.go b/pkg/lore/retrieve/retrieve.go
new file mode 100644
index 0000000..2fb56dd
--- /dev/null
+++ b/pkg/lore/retrieve/retrieve.go
@@ -0,0 +1,40 @@
+// Package retrieve defines the Retriever interface and shared result types
+// for lore's hybrid lexical+semantic search layer. Concrete implementations
+// live in the sub-packages: bm25, vector, and hybrid.
+//
+// Architecture summary:
+//
+//	hybrid.New(store, embedder, vstore) → Retriever
+//
+// The hybrid retriever composes a BM25 lexical arm (via store.SearchText),
+// a vector semantic arm (via embedder.Embed + vstore.Search), and fuses the
+// two ranked lists with Reciprocal Rank Fusion (rrf.Fuse) into a single
+// ordered result list.
+//
+// Callers that only want one arm can use bm25.New(store) or vector.New(store, embedder, vstore)
+// directly; both satisfy the same Retriever interface.
+package retrieve
+
+import (
+	"context"
+
+	"github.com/mathomhaus/lore/pkg/lore"
+)
+
+// Retriever runs a hybrid lexical + semantic search and returns ranked results.
+// Implementations compose Store (for BM25 over title/body), Embedder (for
+// query to vector), and VectorStore (for vector search), and fuse the two
+// rankings into a single ordered list.
+//
+// All implementations must be safe for concurrent use by multiple goroutines.
+type Retriever interface {
+	// Search executes the retrieval pipeline for the given query string and
+	// returns results ranked by descending score. The query must be non-empty;
+	// Search returns ErrInvalidArgument otherwise.
+	//
+	// opts.Limit caps the result count. Zero means implementation default
+	// (typically 10). Negative is invalid and returns ErrInvalidArgument.
+	//
+	// opts.Project, when non-empty, restricts results to a single project.
+	Search(ctx context.Context, query string, opts lore.SearchOpts) ([]lore.SearchHit, error)
+}
diff --git a/pkg/lore/retrieve/rrf/rrf.go b/pkg/lore/retrieve/rrf/rrf.go
new file mode 100644
index 0000000..af23878
--- /dev/null
+++ b/pkg/lore/retrieve/rrf/rrf.go
@@ -0,0 +1,74 @@
+// Package rrf implements Reciprocal Rank Fusion (RRF), a score-free
+// rank aggregation algorithm that combines multiple ranked lists into one
+// ordered result without requiring comparable score scales across rankers.
+//
+// Reference: Cormack, Clarke, Buettcher (2009), "Reciprocal Rank Fusion
+// outperforms Condorcet and individual rank learning methods".
+//
+// The algorithm:
+//
+//	score(d) = sum over rankers r: 1 / (k + rank_r(d))
+//
+// where k=60 is the standard smoothing constant and rank_r(d) is the
+// 1-indexed position of document d in ranker r (or absent, contributing 0).
+// Documents are returned sorted by descending fused score; ties break by
+// ascending ID for determinism.
+package rrf
+
+import (
+	"sort"
+)
+
+// DefaultK is the standard RRF smoothing constant. The value 60 comes from
+// the original Cormack et al. paper and is widely used in production systems.
+// Callers that want a different constant pass k explicitly to Fuse.
+const DefaultK = 60
+
+// ScoredID pairs a document ID with its RRF fused score.
+type ScoredID struct {
+	ID    int64   // document ID exactly as it appeared in the input rankings
+	Score float64 // sum of the 1/(k+rank) contributions across all rankings
+}
+
+// Fuse combines multiple ranked lists into a single fused list using RRF.
+// Each input list is a slice of document IDs in ranked order (best first,
+// index 0 = rank 1). k is the RRF smoothing parameter; pass DefaultK (60)
+// when in doubt.
+//
+// Returns IDs sorted by descending fused score. Ties between equal scores
+// are broken by ascending ID so output is deterministic for identical inputs.
+// +// Fuse is safe for concurrent use: it does not modify its inputs. +// +// Fuse returns nil when the rankings contain no IDs at all (nil input, no +// lists, or only empty lists). +// +// NOTE(review): an ID repeated within a single ranking adds one term per +// occurrence; confirm whether callers guarantee unique IDs per list. +func Fuse(rankings [][]int64, k int) []ScoredID { + // Fall back to DefaultK when the caller passes a non-positive k. + if k <= 0 { + k = DefaultK + } + + // Accumulate per-ID scores across all rankers. + scores := make(map[int64]float64) + for _, ranking := range rankings { + for rank, id := range ranking { + // rank is 0-indexed; RRF uses 1-indexed positions. + scores[id] += 1.0 / float64(k+rank+1) + } + } + + // Nothing was ranked: report that as a nil slice rather than an empty one. + if len(scores) == 0 { + return nil + } + + result := make([]ScoredID, 0, len(scores)) + for id, score := range scores { + result = append(result, ScoredID{ID: id, Score: score}) + } + + // Sort descending by score; break ties ascending by ID for determinism. + sort.Slice(result, func(i, j int) bool { + if result[i].Score != result[j].Score { + return result[i].Score > result[j].Score + } + return result[i].ID < result[j].ID + }) + + return result +} diff --git a/pkg/lore/retrieve/vector/vector.go b/pkg/lore/retrieve/vector/vector.go new file mode 100644 index 0000000..62ca44f --- /dev/null +++ b/pkg/lore/retrieve/vector/vector.go @@ -0,0 +1,129 @@ +// Package vector provides a semantic-only Retriever that embeds the query +// via an Embedder and searches a VectorStore for nearest neighbours. +// It satisfies retrieve.Retriever and is composable as the vector arm of +// hybrid.New. +package vector + +import ( + "context" + "errors" + "fmt" + "log/slog" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + "github.com/mathomhaus/lore/pkg/lore" + "github.com/mathomhaus/lore/pkg/lore/embed" + lstore "github.com/mathomhaus/lore/pkg/lore/store" + lvector "github.com/mathomhaus/lore/pkg/lore/vector" +) + +// tracerName is the instrumentation name New passes to the global tracer +// provider when no tracer option is supplied. +const tracerName = "lore.retrieve.vector" + +// Searcher is a Retriever backed by vector (semantic) search only. +// Construct with New. Safe for concurrent use.
+type Searcher struct { + store lstore.Store + emb embed.Embedder + vstore lvector.VectorStore + logger *slog.Logger + tracer trace.Tracer +} + +// Option configures a Searcher. +type Option func(*Searcher) + +// WithLogger sets the structured logger. +// Defaults to slog.Default() when not provided. +func WithLogger(l *slog.Logger) Option { + return func(s *Searcher) { s.logger = l } +} + +// WithTracer sets the OpenTelemetry tracer. +// Defaults to the global tracer provider when not provided. +func WithTracer(t trace.Tracer) Option { + return func(s *Searcher) { s.tracer = t } +} + +// New returns a Searcher that embeds the query via emb and searches vstore, +// then hydrates full entries from store. The caller owns all three resources +// and must not close them before Searcher is done. +func New(s lstore.Store, emb embed.Embedder, vstore lvector.VectorStore, opts ...Option) *Searcher { + sr := &Searcher{store: s, emb: emb, vstore: vstore} + for _, o := range opts { + o(sr) + } + if sr.logger == nil { + sr.logger = slog.Default() + } + if sr.tracer == nil { + sr.tracer = otel.GetTracerProvider().Tracer(tracerName) + } + return sr +} + +// Search embeds query, finds nearest-neighbour entries, and returns them as +// SearchHits. Returns lore.ErrInvalidArgument when query is empty or opts.Limit +// is negative. 
+func (s *Searcher) Search(ctx context.Context, query string, opts lore.SearchOpts) ([]lore.SearchHit, error) { + if query == "" { + return nil, fmt.Errorf("vector: search: %w", lore.ErrInvalidArgument) + } + if opts.Limit < 0 { + return nil, fmt.Errorf("vector: search: negative limit: %w", lore.ErrInvalidArgument) + } + + ctx, span := s.tracer.Start(ctx, "lore.retrieve.vector") + defer span.End() + + vecs, err := s.emb.Embed(ctx, []string{query}) + if err != nil { + if !errors.Is(err, embed.ErrUnsupported) { + s.logger.ErrorContext(ctx, "vector: embed failed", "err", err) + } + return nil, fmt.Errorf("vector: embed query: %w", err) + } + if len(vecs) == 0 || len(vecs[0]) == 0 { + return nil, fmt.Errorf("vector: embed returned empty vector") + } + qvec := vecs[0] + + limit := opts.Limit + if limit == 0 { + limit = 10 + } + vsOpts := lvector.SearchOpts{ + Limit: limit, + Kinds: opts.Kinds, + Tags: opts.Tags, + } + hits, err := s.vstore.Search(ctx, qvec, vsOpts) + if err != nil { + s.logger.ErrorContext(ctx, "vector: vstore.Search failed", "err", err) + return nil, fmt.Errorf("vector: vstore search: %w", err) + } + + span.SetAttributes(attribute.Int("vector.count", len(hits))) + + results := make([]lore.SearchHit, 0, len(hits)) + for _, h := range hits { + entry, err := s.store.Get(ctx, h.ID) + if err != nil { + if errors.Is(err, lore.ErrNotFound) { + // Vector index slightly ahead of store; skip stale reference. + s.logger.WarnContext(ctx, "vector: entry not found in store; skipping", "id", h.ID) + continue + } + return nil, fmt.Errorf("vector: hydrate entry %d: %w", h.ID, err) + } + // Filter by project when specified. + if opts.Project != "" && entry.Project != opts.Project { + continue + } + results = append(results, lore.SearchHit{Entry: entry, Score: h.Score}) + } + + return results, nil +}