Skip to content

Commit 5348bd4

Browse files
committed
turbopuffer testing
1 parent 00bc109 commit 5348bd4

File tree

4 files changed

+133
-57
lines changed

4 files changed

+133
-57
lines changed

plugins/turbopuffer/tests/test_rag.py

Lines changed: 0 additions & 38 deletions
This file was deleted.
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""Tests for TurboPufferRAG."""
2+
3+
import uuid
4+
5+
import pytest
6+
from dotenv import load_dotenv
7+
8+
from vision_agents.core.rag import Document
9+
from vision_agents.plugins.turbopuffer import TurboPufferRAG
10+
11+
load_dotenv()
12+
13+
# Skip blockbuster for all tests in this module (they make real API calls)
14+
pytestmark = [pytest.mark.integration, pytest.mark.skip_blockbuster]
15+
16+
17+
@pytest.fixture
async def rag():
    """Yield a TurboPufferRAG bound to a unique, throwaway namespace.

    The namespace name embeds a random hex suffix so parallel test runs
    never collide.  Teardown clears the namespace and closes the client;
    ``close()`` is wrapped in ``finally`` so the client connection is
    released even when ``clear()`` raises (the original ran the two calls
    unprotected, leaking the client on a failed clear).
    """
    namespace = f"test-rag-{uuid.uuid4().hex[:8]}"
    rag = TurboPufferRAG(namespace=namespace)
    try:
        yield rag
    finally:
        try:
            await rag.clear()
        finally:
            await rag.close()
25+
26+
27+
@pytest.fixture
def unique_doc():
    """Create a document with unique content."""
    marker = uuid.uuid4()
    document = Document(
        text=f"Test document {marker}. Contains quantum computing and AI info.",
        source="test_doc.txt",
    )
    return document, str(marker)
35+
36+
37+
async def test_basic_upload_and_search(rag: TurboPufferRAG, unique_doc):
    """Upload a document and verify it can be found."""
    document, marker = unique_doc
    uploaded = await rag.add_documents([document])

    # At least one chunk must have been indexed, from exactly one file.
    assert uploaded >= 1
    assert len(rag.indexed_files) == 1

    answer = await rag.search(f"document {marker}")
    assert marker in answer
47+
48+
49+
async def test_vector_search_mode(rag: TurboPufferRAG):
    """Test vector-only search finds semantically similar content."""
    await rag.add_documents(
        [Document(text="Neural networks for pattern recognition.", source="ml.txt")]
    )

    # Query shares no exact keywords with the text, so a hit proves
    # semantic (embedding) matching rather than keyword overlap.
    answer = await rag.search("deep learning patterns", mode="vector")
    lowered = answer.lower()
    assert "neural" in lowered or "pattern" in lowered
56+
57+
58+
async def test_bm25_search_mode(rag: TurboPufferRAG):
    """Test BM25 keyword search finds exact matches."""
    # A random SKU guarantees the match can only come from this document.
    sku = f"SKU-{uuid.uuid4().hex[:8].upper()}"
    product = Document(
        text=f"Product code: {sku}. High-quality widget.", source="product.txt"
    )
    await rag.add_documents([product])

    answer = await rag.search(sku, mode="bm25")
    assert sku in answer
66+
67+
68+
async def test_hybrid_search_mode(rag: TurboPufferRAG):
    """Test hybrid search combines vector and BM25."""
    document = Document(
        text="The API endpoint supports real-time data streaming.",
        source="api.txt",
    )
    await rag.add_documents([document])

    # No explicit mode: hybrid is the default search path.
    answer = await rag.search("real-time streaming API")
    lowered = answer.lower()
    assert "streaming" in lowered or "api" in lowered
75+
76+
77+
async def test_batch_upload_multiple_documents(rag: TurboPufferRAG):
    """Test uploading multiple documents in a batch."""
    batch = []
    for topic in ("cats", "dogs", "birds"):
        batch.append(
            Document(
                text=f"Document about {topic}: {uuid.uuid4()}",
                source=f"{topic}.txt",
            )
        )

    total = await rag.add_documents(batch)
    # One or more chunks per document; three distinct source files.
    assert total >= 3
    assert len(rag.indexed_files) == 3
87+
88+
89+
async def test_search_empty_namespace(rag: TurboPufferRAG):
    """Test search returns appropriate message when namespace is empty."""
    # Nothing was indexed, so the fallback message must come back.
    answer = await rag.search("anything")
    assert "No relevant information found" in answer
93+
94+
95+
async def test_clear_removes_all_documents(rag: TurboPufferRAG, unique_doc):
    """Test that clear() removes all indexed documents."""
    document, _ = unique_doc
    await rag.add_documents([document])
    assert len(rag.indexed_files) == 1

    await rag.clear()
    assert len(rag.indexed_files) == 0

    # After clearing, a search must report that nothing was found.
    answer = await rag.search("anything")
    assert "No relevant information found" in answer
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .rag import TurboPufferRAG, create_rag
1+
from .turbopuffer_rag import TurboPufferRAG, create_rag
22

33
__all__ = ["TurboPufferRAG", "create_rag"]

plugins/turbopuffer/vision_agents/plugins/turbopuffer/rag.py renamed to plugins/turbopuffer/vision_agents/plugins/turbopuffer/turbopuffer_rag.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646
from langchain_google_genai import GoogleGenerativeAIEmbeddings
4747
from langchain_text_splitters import RecursiveCharacterTextSplitter
48-
from turbopuffer import AsyncTurbopuffer
48+
from turbopuffer import AsyncTurbopuffer, NotFoundError
4949

5050
from vision_agents.core.rag import RAG, Document
5151

@@ -247,45 +247,51 @@ async def _vector_search(self, query: str, top_k: int) -> list[tuple[str, float]
247247
)
248248

249249
ns = self._client.namespace(self._namespace_name)
250-
results = await ns.query(
251-
rank_by=("vector", "ANN", query_embedding),
252-
top_k=top_k,
253-
include_attributes=["text", "source"],
254-
)
250+
try:
251+
results = await ns.query(
252+
rank_by=("vector", "ANN", query_embedding),
253+
top_k=top_k,
254+
include_attributes=["text", "source"],
255+
)
256+
except NotFoundError:
257+
return []
255258

256259
ranked = []
257260
for row in results.rows:
258261
doc_id = str(row.id)
259262
# Cache the document for later retrieval
260263
self._doc_cache[doc_id] = {
261-
"text": row.text if row.text else "",
262-
"source": row.source if row.source else "unknown",
264+
"text": row["text"] or "",
265+
"source": row["source"] or "unknown",
263266
}
264267
# Lower distance = better, so we use negative for ranking
265-
dist = row.dist if row.dist else 0
268+
dist = row["$dist"] or 0
266269
ranked.append((doc_id, -dist))
267270

268271
return ranked
269272

270273
async def _bm25_search(self, query: str, top_k: int) -> list[tuple[str, float]]:
    """Run BM25 full-text search."""
    namespace = self._client.namespace(self._namespace_name)
    try:
        response = await namespace.query(
            rank_by=("text", "BM25", query),
            top_k=top_k,
            include_attributes=["text", "source"],
        )
    except NotFoundError:
        # A namespace that has never been written to does not exist yet.
        return []

    scored: list[tuple[str, float]] = []
    for record in response.rows:
        record_id = str(record.id)
        # Keep the document body cached for later retrieval.
        self._doc_cache[record_id] = {
            "text": record["text"] or "",
            "source": record["source"] or "unknown",
        }
        # BM25 relevance score: larger means a better match.
        scored.append((record_id, record["$dist"] or 0))

    return scored
@@ -349,7 +355,10 @@ async def search(
349355
async def clear(self) -> None:
    """Clear all vectors from the namespace."""
    namespace = self._client.namespace(self._namespace_name)
    try:
        await namespace.delete_all()
    except NotFoundError:
        # The namespace was never created — nothing to clear.
        pass
    # Reset local bookkeeping regardless of whether remote data existed.
    self._indexed_files = []
    self._doc_cache.clear()
    logger.info(f"Cleared namespace: {self._namespace_name}")

0 commit comments

Comments
 (0)