Skip to content

Commit 4cf3b92

Browse files
committed
feat: Add RAG (Retrieval-Augmented Generation) support
- Add core RAG module with Document, Chunk, RetrievalResult models - Add RAGProvider abstract base class for pluggable implementations - Add RAG events for observability (retrieval start/complete, document/file added) - Integrate RAG into base LLM class with automatic context injection - Implement GeminiFileSearchRAG using Google's native File Search Tool - Add comprehensive tests for RAG functionality - Add gemini_rag_demo example showing RAG in action
1 parent 7f758e3 commit 4cf3b92

File tree

16 files changed

+4266
-3
lines changed

16 files changed

+4266
-3
lines changed

agents-core/vision_agents/core/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,15 @@
44

55
from vision_agents.core.cli.cli_runner import cli
66
from vision_agents.core.agents.agent_launcher import AgentLauncher
7+
from vision_agents.core.rag import RAGProvider, Document, Chunk, RetrievalResult
78

8-
__all__ = ["Agent", "User", "cli", "AgentLauncher"]
9+
__all__ = [
10+
"Agent",
11+
"User",
12+
"cli",
13+
"AgentLauncher",
14+
"RAGProvider",
15+
"Document",
16+
"Chunk",
17+
"RetrievalResult",
18+
]

agents-core/vision_agents/core/llm/llm.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
if TYPE_CHECKING:
2424
from vision_agents.core.agents import Agent
2525
from vision_agents.core.agents.conversation import Conversation
26+
from vision_agents.core.rag import RAGProvider
2627

2728
from getstream.video.rtc import PcmData
2829
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant
@@ -62,6 +63,10 @@ def __init__(self):
6263
# LLM instructions. Provided by the Agent via `set_instructions` method
6364
self._instructions: str = ""
6465
self._conversation: Optional[Conversation] = None
66+
# RAG provider for retrieval-augmented generation
67+
self._rag_provider: Optional[RAGProvider] = None
68+
self._rag_top_k: int = 5
69+
self._rag_include_citations: bool = True
6570

6671
async def warmup(self) -> None:
6772
"""
@@ -192,6 +197,59 @@ def set_instructions(self, instructions: Instructions | str) -> None:
192197
f"Invalid instructions type {type(instructions)}, expected str or Instructions"
193198
)
194199

200+
def set_rag_provider(
201+
self,
202+
provider: RAGProvider,
203+
top_k: int = 5,
204+
include_citations: bool = True,
205+
) -> None:
206+
"""Attach a RAG provider to this LLM for retrieval-augmented generation.
207+
208+
When a RAG provider is attached, queries will automatically be augmented
209+
with relevant context retrieved from the knowledge base.
210+
211+
Args:
212+
provider: The RAG provider to use for retrieval.
213+
top_k: Number of results to retrieve per query.
214+
include_citations: Whether to include citations in the context.
215+
"""
216+
self._rag_provider = provider
217+
self._rag_top_k = top_k
218+
self._rag_include_citations = include_citations
219+
220+
@property
221+
def rag_provider(self) -> Optional[RAGProvider]:
222+
"""Get the attached RAG provider, if any."""
223+
return self._rag_provider
224+
225+
async def _augment_with_rag(self, text: str) -> str:
226+
"""Augment a query with RAG context if a provider is attached.
227+
228+
Args:
229+
text: The original query text.
230+
231+
Returns:
232+
The query augmented with retrieved context, or the original
233+
text if no RAG provider is attached or no results found.
234+
"""
235+
if self._rag_provider is None:
236+
return text
237+
238+
results = await self._rag_provider.search_with_events(
239+
query=text,
240+
top_k=self._rag_top_k,
241+
)
242+
243+
if not results:
244+
return text
245+
246+
context = self._rag_provider.build_context_prompt(
247+
results,
248+
include_citations=self._rag_include_citations,
249+
)
250+
251+
return f"{context}\n\nUser question: {text}"
252+
195253
def register_function(
196254
self, name: Optional[str] = None, description: Optional[str] = None
197255
) -> Callable:
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""RAG (Retrieval-Augmented Generation) module for Vision Agents."""
2+
3+
from .base import RAGProvider
4+
from .events import (
5+
RAGDocumentAddedEvent,
6+
RAGFileAddedEvent,
7+
RAGRetrievalCompleteEvent,
8+
RAGRetrievalStartEvent,
9+
)
10+
from .types import Chunk, Document, RetrievalResult
11+
12+
__all__ = [
13+
"RAGProvider",
14+
"Document",
15+
"Chunk",
16+
"RetrievalResult",
17+
"RAGRetrievalStartEvent",
18+
"RAGRetrievalCompleteEvent",
19+
"RAGDocumentAddedEvent",
20+
"RAGFileAddedEvent",
21+
]
22+
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""Base RAG provider interface."""
2+
3+
import abc
4+
import time
5+
from typing import Optional
6+
7+
from vision_agents.core.events.manager import EventManager
8+
9+
from . import events
10+
from .types import Document, RetrievalResult
11+
12+
13+
class RAGProvider(abc.ABC):
14+
"""Abstract base class for RAG (Retrieval-Augmented Generation) providers.
15+
16+
RAG providers handle document ingestion, storage, and retrieval for
17+
augmenting LLM responses with relevant context.
18+
19+
Implementations can be:
20+
- Provider-native (e.g., Gemini File Search, OpenAI Vector Store)
21+
- Local (e.g., FAISS, ChromaDB with custom embeddings)
22+
"""
23+
24+
def __init__(self):
25+
self.events = EventManager()
26+
self.events.register_events_from_module(events)
27+
28+
@abc.abstractmethod
29+
async def add_documents(self, documents: list[Document]) -> None:
30+
"""Ingest documents into the knowledge base.
31+
32+
Args:
33+
documents: List of documents to add.
34+
"""
35+
36+
@abc.abstractmethod
37+
async def add_file(self, file_path: str, metadata: Optional[dict] = None) -> str:
38+
"""Ingest a file into the knowledge base.
39+
40+
Args:
41+
file_path: Path to the file to ingest.
42+
metadata: Optional metadata to associate with the file.
43+
44+
Returns:
45+
ID of the ingested file/document.
46+
"""
47+
48+
async def add_files(self, file_paths: list[str]) -> list[str]:
49+
"""Ingest multiple files into the knowledge base.
50+
51+
Args:
52+
file_paths: List of file paths to ingest.
53+
54+
Returns:
55+
List of IDs for the ingested files.
56+
"""
57+
ids = []
58+
for path in file_paths:
59+
file_id = await self.add_file(path)
60+
ids.append(file_id)
61+
return ids
62+
63+
@abc.abstractmethod
64+
async def search(
65+
self,
66+
query: str,
67+
top_k: int = 5,
68+
) -> list[RetrievalResult]:
69+
"""Retrieve relevant chunks for a query.
70+
71+
Args:
72+
query: The search query.
73+
top_k: Maximum number of results to return.
74+
75+
Returns:
76+
List of retrieval results ordered by relevance.
77+
"""
78+
79+
async def search_with_events(
80+
self,
81+
query: str,
82+
top_k: int = 5,
83+
) -> list[RetrievalResult]:
84+
"""Search with event emission for observability.
85+
86+
Args:
87+
query: The search query.
88+
top_k: Maximum number of results to return.
89+
90+
Returns:
91+
List of retrieval results ordered by relevance.
92+
"""
93+
self.events.send(
94+
events.RAGRetrievalStartEvent(
95+
query=query,
96+
top_k=top_k,
97+
)
98+
)
99+
100+
start_time = time.time()
101+
results = await self.search(query, top_k)
102+
elapsed_ms = (time.time() - start_time) * 1000
103+
104+
self.events.send(
105+
events.RAGRetrievalCompleteEvent(
106+
query=query,
107+
results=results,
108+
retrieval_time_ms=elapsed_ms,
109+
)
110+
)
111+
112+
return results
113+
114+
def build_context_prompt(
115+
self,
116+
results: list[RetrievalResult],
117+
include_citations: bool = True,
118+
) -> str:
119+
"""Format retrieved results for injection into LLM prompt.
120+
121+
Args:
122+
results: List of retrieval results.
123+
include_citations: Whether to include citation markers.
124+
125+
Returns:
126+
Formatted context string to prepend to the user's query.
127+
"""
128+
if not results:
129+
return ""
130+
131+
context_parts = [
132+
"Use the following context to answer the question. "
133+
"If the context doesn't contain relevant information, say so.\n"
134+
]
135+
136+
for i, result in enumerate(results, 1):
137+
citation = f" {result.format_citation()}" if include_citations else ""
138+
context_parts.append(f"[{i}]{citation}: {result.content}\n")
139+
140+
return "\n".join(context_parts)
141+
142+
@abc.abstractmethod
143+
async def delete_document(self, document_id: str) -> bool:
144+
"""Delete a document from the knowledge base.
145+
146+
Args:
147+
document_id: ID of the document to delete.
148+
149+
Returns:
150+
True if deleted, False if not found.
151+
"""
152+
153+
async def clear(self) -> None:
154+
"""Clear all documents from the knowledge base.
155+
156+
Default implementation does nothing. Override if supported.
157+
"""
158+
pass
159+
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""RAG-specific events for observability."""
2+
3+
from dataclasses import dataclass, field
4+
from typing import Any, Optional
5+
6+
from vision_agents.core.events import BaseEvent
7+
8+
from .types import RetrievalResult
9+
10+
11+
@dataclass
12+
class RAGRetrievalStartEvent(BaseEvent):
13+
"""Emitted when RAG retrieval begins."""
14+
15+
type: str = field(default="rag.retrieval.start", init=False)
16+
plugin_name: str = "rag"
17+
query: str = ""
18+
top_k: int = 5
19+
20+
21+
@dataclass
22+
class RAGRetrievalCompleteEvent(BaseEvent):
23+
"""Emitted when RAG retrieval completes."""
24+
25+
type: str = field(default="rag.retrieval.complete", init=False)
26+
plugin_name: str = "rag"
27+
query: str = ""
28+
results: list[RetrievalResult] = field(default_factory=list)
29+
retrieval_time_ms: float = 0.0
30+
31+
@property
32+
def result_count(self) -> int:
33+
return len(self.results)
34+
35+
36+
@dataclass
37+
class RAGDocumentAddedEvent(BaseEvent):
38+
"""Emitted when a document is added to the RAG system."""
39+
40+
type: str = field(default="rag.document.added", init=False)
41+
plugin_name: str = "rag"
42+
document_id: str = ""
43+
metadata: dict[str, Any] = field(default_factory=dict)
44+
chunk_count: int = 0
45+
46+
47+
@dataclass
48+
class RAGFileAddedEvent(BaseEvent):
49+
"""Emitted when a file is uploaded to the RAG system."""
50+
51+
type: str = field(default="rag.file.added", init=False)
52+
plugin_name: str = "rag"
53+
file_path: str = ""
54+
file_id: Optional[str] = None
55+
metadata: dict[str, Any] = field(default_factory=dict)
56+
57+
58+

0 commit comments

Comments
 (0)