diff --git a/.gitignore b/.gitignore index b7b0b14f..adfe6f40 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,6 @@ jepsen/store/ # Jepsen local SSH keys (generated locally; never commit) jepsen/docker/id_rsa jepsen/.ssh/ + +# Go build cache +.cache/ diff --git a/distribution/engine.go b/distribution/engine.go index bff6abe7..986f2fa3 100644 --- a/distribution/engine.go +++ b/distribution/engine.go @@ -121,6 +121,36 @@ func (e *Engine) Stats() []Route { return stats } +// GetIntersectingRoutes returns all routes whose key ranges intersect with [start, end). +// A route [rStart, rEnd) intersects with [start, end) if: +// - rStart < end (or end is nil, meaning unbounded scan) +// - start < rEnd (or rEnd is nil, meaning unbounded route) +func (e *Engine) GetIntersectingRoutes(start, end []byte) []Route { + e.mu.RLock() + defer e.mu.RUnlock() + + var result []Route + for _, r := range e.routes { + // Check if route intersects with [start, end) + // Route ends before scan starts: rEnd != nil && rEnd <= start + if r.End != nil && bytes.Compare(r.End, start) <= 0 { + continue + } + // Route starts at or after scan ends: end != nil && rStart >= end + if end != nil && bytes.Compare(r.Start, end) >= 0 { + continue + } + // Route intersects with scan range + result = append(result, Route{ + Start: cloneBytes(r.Start), + End: cloneBytes(r.End), + GroupID: r.GroupID, + Load: atomic.LoadUint64(&r.Load), + }) + } + return result +} + func (e *Engine) routeIndex(key []byte) int { if len(e.routes) == 0 { return -1 diff --git a/distribution/engine_test.go b/distribution/engine_test.go index 69ee0373..8025de3d 100644 --- a/distribution/engine_test.go +++ b/distribution/engine_test.go @@ -160,3 +160,74 @@ func assertRange(t *testing.T, r Route, start, end []byte) { t.Errorf("expected range [%q, %q), got [%q, %q]", start, end, r.Start, r.End) } } + +func TestEngineGetIntersectingRoutes(t *testing.T) { + e := NewEngine() + e.UpdateRoute([]byte("a"), []byte("m"), 1) + e.UpdateRoute([]byte("m"), []byte("z"), 2) + e.UpdateRoute([]byte("z"), nil, 3) + + cases := []struct { + name string + start []byte + end []byte + groups []uint64 + }{ + { + name: "scan in first range", + start: []byte("b"), + end: []byte("d"), + groups: []uint64{1}, + }, + { + name: "scan across first two ranges", + start: []byte("k"), + end: []byte("p"), + groups: []uint64{1, 2}, + }, + { + name: "scan across all ranges", + start: []byte("a"), + end: nil, + groups: []uint64{1, 2, 3}, + }, + { + name: "scan in last unbounded range", + start: []byte("za"), + end: nil, + groups: []uint64{3}, + }, + { + name: "scan before first range", + start: []byte("0"), + end: []byte("9"), + groups: []uint64{}, + }, + { + name: "scan at boundary", + start: []byte("m"), + end: []byte("n"), + groups: []uint64{2}, + }, + { + name: "scan ending at boundary", + start: []byte("k"), + end: []byte("m"), + groups: []uint64{1}, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + routes := e.GetIntersectingRoutes(c.start, c.end) + if len(routes) != len(c.groups) { + t.Fatalf("expected %d routes, got %d", len(c.groups), len(routes)) + } + for i, expectedGroup := range c.groups { + if routes[i].GroupID != expectedGroup { + t.Errorf("route %d: expected group %d, got %d", i, expectedGroup, routes[i].GroupID) + } + } + }) + } +} diff --git a/kv/shard_store.go b/kv/shard_store.go index 21388fbb..6d0da333 100644 --- a/kv/shard_store.go +++ b/kv/shard_store.go @@ -59,9 +59,14 @@ func (s *ShardStore) ScanAt(ctx context.Context, start []byte, end []byte, limit if limit <= 0 { return []*store.KVPair{}, nil } + + // Get only the routes whose ranges intersect with [start, end) + intersectingRoutes := s.engine.GetIntersectingRoutes(start, end) + var out []*store.KVPair - for _, g := range s.groups { - if g == nil || g.Store == nil { + for _, route := range intersectingRoutes { + g, ok := s.groups[route.GroupID] + if !ok || g == nil || g.Store == nil { continue } kvs, err := g.Store.ScanAt(ctx, start, end, limit, ts)