1 change: 1 addition & 0 deletions .gitignore
@@ -215,3 +215,4 @@ __marimo__/
__pycache__/
*.py[codz]
*$py.class
memory_entries.json
148 changes: 96 additions & 52 deletions examples/run_lightmem_bm25.py
@@ -1,76 +1,120 @@
"""
Example: Demonstrate basic usage of LightMemory with BM25 retrieval.
This script follows the same structure and annotation style as other examples in LightMem.
"""

import os
from lightmem.memory.lightmem import LightMemory
import sys
import traceback

src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))
if src_path not in sys.path:
sys.path.insert(0, src_path)

try:
from lightmem.memory.lightmem import LightMemory
from lightmem.configs.base import BaseMemoryConfigs, MemoryManagerConfig
from lightmem.configs.retriever.bm25 import BM25Config
from lightmem.configs.logging.base import LoggingConfig
from lightmem.configs.topic_segmenter.base import TopicSegmenterConfig
except ImportError as e:
print(f"--- Import failed! Error: {e} ---")
print(f"--- Please ensure the 'src' directory is at: {src_path} ---")
sys.exit(1)

# ============ Data Configuration ============
EXAMPLE_COLLECTION = "demo_collection_bm25"


def load_lightmem_bm25(collection_name):
"""
Load a LightMemory instance configured to use BM25 retrieval.
This function mirrors the config structure in other examples.
Load a LightMemory instance configured for BM25 retrieval.
"""
config = {
"pre_compress": False,
"topic_segment": False,
"index_strategy": "bm25", # ✅ Use BM25 instead of embedding
"retrieve_strategy": "bm25",
"bm25_retriever": {
"model_name": "bm25",
"configs": {
"collection_name": collection_name,
"path": f"./bm25_data/{collection_name}",
},
},
"memory_manager": {
"model_name": "openai",
"configs": {
"model": "gpt-3.5-turbo",
"max_tokens": 2048,
}
},
"logging": {
"level": "INFO",
"file_enabled": True,
"log_dir": "logs",
"log_filename_prefix": "example_bm25",
"console_enabled": True,
"file_level": "DEBUG",

# 1. Define BM25 retriever configuration
bm25_retriever_config = BM25Config(
collection_name=collection_name,
path=f"./bm25_data/{collection_name}"
)

# 2. Define Memory Manager configuration (requires OPENAI_API_KEY)
manager_config = MemoryManagerConfig(
model_name="openai",
configs={
"model": "gpt-3.5-turbo",
"max_tokens": 2048,
}
}
lightmem = LightMemory.from_config(config)
)

# 3. Define logging configuration
logging_config = LoggingConfig(
level="INFO",
file_enabled=True,
log_dir="logs",
log_filename_prefix="example_bm25",
console_enabled=True,
file_level="DEBUG",
)

# 4. Define segmenter configuration (requires torch/llmlingua-2)
segmenter_config = TopicSegmenterConfig(
model_name="llmlingua-2"
)

# 5. Assemble the final BaseMemoryConfigs
config_object = BaseMemoryConfigs(
pre_compress=False,
topic_segment=True, # Ensure add_memory does not exit prematurely
topic_segmenter=segmenter_config,
index_strategy="bm25",
retrieve_strategy="bm25",
bm25_retriever=bm25_retriever_config,
memory_manager=manager_config,
logging=logging_config
)

# 6. Initialize using the config object
print("--- Initializing LightMemory (requires API Key and torch)... ---")
lightmem = LightMemory.from_config(config_object.model_dump())
print("--- LightMemory initialization successful ---")

return lightmem


def main():
"""
Run a minimal demonstration of LightMemory add/retrieve workflow (BM25 version).
"""
print("========== LightMemory BM25 Example ==========")
lightmem = load_lightmem_bm25(EXAMPLE_COLLECTION)

print("\n========== LightMemory BM25 Example ==========")

try:
lightmem = load_lightmem_bm25(EXAMPLE_COLLECTION)

# ============ Add Example Memory ============
print("\n--- Attempting: lightmem.add_memory() ---")

session_timestamp = "2025/11/12 (Wed) 19:30"
messages = [
{"role": "user", "content": "The capital of France is Paris.", "time_stamp": session_timestamp},
{"role": "assistant", "content": "Correct, Paris is the capital city of France.", "time_stamp": session_timestamp}
]

result = lightmem.add_memory(messages=messages, force_segment=True, force_extract=True)
print("Memory added:", result)

# ============ Add Example Memory ============
messages = [
{"role": "user", "content": "The capital of France is Paris."},
{"role": "assistant", "content": "Correct, Paris is the capital city of France."}
]
result = lightmem.add_memory(messages=messages, force_segment=True, force_extract=True)
print("Memory added:", result)
# ============ Retrieve Example Memory ============
print("\n--- Attempting: lightmem.retrieve() ---")
query = "What is the capital of France?"
results = lightmem.retrieve(query, limit=5)
print("Query:", query)
print("Retrieved Results:", results)

# ============ Retrieve Example Memory ============
query = "What is the capital of France?"
results = lightmem.retrieve(query, limit=5)
print("Query:", query)
print("Retrieved Results:", results)
print("===============================================")
except Exception as e:
print("\n--- Example run failed ---")
print("This may be due to a missing OPENAI_API_KEY environment variable or missing 'torch' dependency.")
print(f"Error Type: {type(e)}")
print(f"Error Details: {e}")
traceback.print_exc()

print("\n===============================================")


# ============ Entry Point ============
if __name__ == "__main__":
main()
main()
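
The rewritten example fails fast on import errors and wraps the run in try/except because it needs both an OpenAI key and torch. A small preflight sketch in the same spirit (hypothetical helper, not part of this PR) could abort before LightMemory is even constructed:

import importlib.util
import os
import sys

def preflight() -> None:
    # Hypothetical helper: check the example's prerequisites up front.
    if not os.environ.get("OPENAI_API_KEY"):
        sys.exit("OPENAI_API_KEY is not set; the OpenAI memory manager cannot authenticate.")
    if importlib.util.find_spec("torch") is None:
        sys.exit("torch is not installed; the llmlingua-2 topic segmenter cannot load.")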
14 changes: 9 additions & 5 deletions src/lightmem/configs/base.py
@@ -9,6 +9,7 @@
from lightmem.configs.multimodal_embedder.base import MMEmbedderConfig
from lightmem.configs.retriever.contextretriever.base import ContextRetrieverConfig
from lightmem.configs.retriever.embeddingretriever.base import EmbeddingRetrieverConfig
from lightmem.configs.retriever.bm25 import BM25Config
from lightmem.configs.logging.base import LoggingConfig

lightmem_dir = ""
@@ -57,7 +58,7 @@ class BaseMemoryConfigs(BaseModel):
extract_threshold: float = Field(
default=0.5,
)
index_strategy: Optional[Literal["embedding", "context", "hybrid"]] = Field(
index_strategy: Optional[Literal["embedding", "context", "hybrid", "bm25"]] = Field(
default=None,
description="Indexing strategy to use. Choices: "
"embedding|context|hybrid|bm25"
@@ -74,10 +75,9 @@
description="Path to the history database",
default=os.path.join(lightmem_dir, "history.db"),
)
retrieve_strategy: Optional[Literal["context", "embedding", "hybrid"]] = Field(
default="embedding",
description="Retrieving strategy to use. Choices: "
"embedding|context|hybrid"
retrieve_strategy: Optional[Literal["context", "embedding", "hybrid", "bm25"]] = Field(
default="embedding",
description="Retrieving strategy to use. Choices: 'embedding' | 'context' | 'hybrid' | 'bm25'"
)
context_retriever: Optional[ContextRetrieverConfig] = Field(
description="Configuration for the context-based retriever (active only if retrieve_strategy is 'context' or 'hybrid')",
@@ -87,6 +87,10 @@
description="Configuration for the embedding-based retriever (active only if retrieve_strategy is 'embedding' or 'hybrid')",
default=None,
)
bm25_retriever: Optional[BM25Config] = Field(
default=None,
description="Configuration for BM25 retriever (active only if retrieve_strategy is 'bm25')."
)
update: Optional[Literal["online","offline"]] = Field(
description="'online'=immediate during execution, 'offline'=scheduled updates",
default="offline",
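
With the widened Literals and the new optional field, a BM25 configuration now validates end to end. A minimal sketch, assuming pydantic v2 and that the remaining BaseMemoryConfigs fields have usable defaults:

from lightmem.configs.base import BaseMemoryConfigs
from lightmem.configs.retriever.bm25 import BM25Config

cfg = BaseMemoryConfigs(
    index_strategy="bm25",        # accepted by the widened Literal
    retrieve_strategy="bm25",
    bm25_retriever=BM25Config(collection_name="demo", path="./bm25_data/demo"),
)
print(cfg.bm25_retriever.k1, cfg.bm25_retriever.b)  # defaults: 1.5 0.75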
2 changes: 1 addition & 1 deletion src/lightmem/configs/memory_manager/base.py
@@ -12,7 +12,7 @@ class MemoryManagerConfig(BaseModel):
"ollama",
]

configs: Optional[dict] = Field(description="Configuration for the specific MemoryManager model", default={})
configs: Optional[Any] = Field(description="Configuration for the specific MemoryManager model", default={})

@model_validator(mode='before')
def validate_model_name(cls, values):
2 changes: 2 additions & 0 deletions src/lightmem/configs/retriever/bm25.py
@@ -5,6 +5,8 @@
class BM25Config(BaseModel):
"""Configuration for BM25 retriever."""

collection_name: str
path: str
k1: float = Field(1.5, description="BM25 k1 parameter controlling term frequency scaling")
b: float = Field(0.75, description="BM25 b parameter controlling document length normalization")
tokenizer: Optional[str] = Field(None, description="Optional tokenizer identifier or function name")
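
For reference, k1 and b enter the standard Okapi BM25 term score as sketched below; this is the textbook formula, self-contained, not the retriever implementation LightMem ships:

import math

def bm25_term_score(tf: int, doc_len: int, avg_doc_len: float,
                    n_docs: int, docs_with_term: int,
                    k1: float = 1.5, b: float = 0.75) -> float:
    # Contribution of one query term to one document's BM25 score.
    # k1 saturates term frequency; b controls document-length normalization.
    idf = math.log(1 + (n_docs - docs_with_term + 0.5) / (docs_with_term + 0.5))
    length_norm = 1 - b + b * doc_len / avg_doc_len
    return idf * (tf * (k1 + 1)) / (tf + k1 * length_norm)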
106 changes: 74 additions & 32 deletions src/lightmem/memory/lightmem.py
@@ -22,6 +22,10 @@
from lightmem.configs.logging.utils import get_logger


class _DummyEmbedder:
def embed(self, text: str) -> List[float]:
return [0.0] * 128

class MessageNormalizer:

_SESSION_RE = re.compile(
@@ -138,6 +142,7 @@ def __init__(self, config: BaseMemoryConfigs = BaseMemoryConfigs()):
self.logger.info("Initializing LightMemory with provided configuration")

self.config = config
self.compressor = None
if self.config.pre_compress:
self.logger.info("Initializing pre-compressor")
self.compressor = PreCompressorFactory.from_config(self.config.pre_compressor)
@@ -147,12 +152,15 @@ def __init__(self, config: BaseMemoryConfigs = BaseMemoryConfigs()):
self.senmem_buffer_manager = SenMemBufferManager(max_tokens=self.segmenter.buffer_len, tokenizer=self.segmenter.tokenizer)
self.logger.info("Initializing memory manager")
self.manager = MemoryManagerFactory.from_config(self.config.memory_manager)
self.shortmem_buffer_manager = ShortMemBufferManager(max_tokens = 1024, tokenizer=getattr(self.manager, "tokenizer", self.manager.config.model))
self.shortmem_buffer_manager = ShortMemBufferManager(max_tokens = 1024, tokenizer=getattr(self.manager, "tokenizer", self.manager.config.model_name))
if self.config.index_strategy == 'embedding' or self.config.index_strategy == 'hybrid':
self.logger.info("Initializing text embedder")
self.text_embedder = TextEmbedderFactory.from_config(self.config.text_embedder)
else:
self.logger.debug("Initializing _DummyEmbedder for non-embedding strategy.")
self.text_embedder = _DummyEmbedder()
# if self.config.multimodal_embedder:
if self.config.retrieve_strategy in ["context", "hybrid"]:
if self.config.retrieve_strategy in ["context", "hybrid", "bm25"]:
self.logger.info("Initializing context retriever")
self.context_retriever = ContextRetrieverFactory.from_config(self.config.context_retriever)
if self.config.retrieve_strategy in ["embedding", "hybrid"]:
@@ -349,10 +357,20 @@ def offline_update(self, memory_list: List, construct_update_queue_trigger: bool
self.logger.info(f"[{call_id}] Received {len(memory_list)} memory entries")
self.logger.info(f"[{call_id}] construct_update_queue_trigger={construct_update_queue_trigger}, offline_update_trigger={offline_update_trigger}")

if self.config.index_strategy in ["context", "hybrid"]:
if self.config.index_strategy in ["context", "hybrid", "bm25"]:
self.logger.info(f"[{call_id}] Saving memory entries to file (strategy: {self.config.index_strategy})")
save_memory_entries(memory_list, "memory_entries.json")

if self.config.index_strategy == "bm25":
self.logger.info(f"[{call_id}] Indexing {len(memory_list)} entries for BM25")
corpus = [entry.memory for entry in memory_list if entry.memory]
if corpus:
self.context_retriever.index(corpus)
self.logger.info(f"[{call_id}] BM25 indexing complete.")
else:
self.logger.warning(f"[{call_id}] No memory content found to index for BM25.")


if self.config.index_strategy in ["embedding", "hybrid"]:
inserted_count = 0
self.logger.info(f"[{call_id}] Starting embedding and insertion to vector database")
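
offline_update hands the non-empty entry.memory strings to context_retriever.index(...), and retrieve (below) later calls context_retriever.retrieve(query=..., top_k=...). A hypothetical stand-in satisfying just that interface, reusing the textbook scoring sketched earlier (the retriever class LightMem actually wires in is not shown in this diff):

import math
from collections import Counter

class MinimalBM25Retriever:
    # Hypothetical stand-in for the index()/retrieve() interface used here.
    def __init__(self, k1: float = 1.5, b: float = 0.75):
        self.k1, self.b = k1, b
        self.raw: list[str] = []
        self.docs: list[list[str]] = []

    def index(self, corpus: list[str]) -> None:
        self.raw = corpus
        self.docs = [doc.lower().split() for doc in corpus]

    def retrieve(self, query: str, top_k: int = 10) -> list[str]:
        n = len(self.docs)
        avgdl = max(sum(len(d) for d in self.docs) / max(n, 1), 1e-9)
        terms = query.lower().split()

        def score(doc: list[str]) -> float:
            tf = Counter(doc)
            total = 0.0
            for term in terms:
                df = sum(1 for d in self.docs if term in d)
                if df == 0 or tf[term] == 0:
                    continue
                idf = math.log(1 + (n - df + 0.5) / (df + 0.5))
                norm = 1 - self.b + self.b * len(doc) / avgdl
                total += idf * tf[term] * (self.k1 + 1) / (tf[term] + self.k1 * norm)
            return total

        ranked = sorted(range(n), key=lambda i: score(self.docs[i]), reverse=True)
        return [self.raw[i] for i in ranked[:top_k]]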
@@ -557,42 +575,66 @@ def update_entry(entry):
def retrieve(self, query: str, limit: int = 10, filters: dict = None) -> str:
"""
Retrieve relevant entries and return them as formatted strings.

Args:
query (str): The natural language query string.
limit (int, optional): Number of results to return. Defaults to 10.
filters (dict, optional): Optional filters to narrow down the search. Defaults to None.

Returns:
list[str]: A list of formatted strings containing time_stamp, weekday, and memory.
This method checks the retrieve_strategy to decide which retriever to use.
"""
call_id = f"retrieve_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"

self.logger.info(f"========== START {call_id} ==========")
self.logger.info(f"[{call_id}] Query: {query}")
self.logger.info(f"[{call_id}] Strategy: {self.config.retrieve_strategy}")
self.logger.info(f"[{call_id}] Parameters: limit={limit}, filters={filters}")
self.logger.debug(f"[{call_id}] Generating embedding for query")
query_vector = self.text_embedder.embed(query)
self.logger.debug(f"[{call_id}] Query embedding dimension: {len(query_vector)}")
self.logger.info(f"[{call_id}] Searching vector database")
results = self.embedding_retriever.search(
query_vector=query_vector,
limit=limit,
filters=filters,
return_full=True,
)
self.logger.info(f"[{call_id}] Found {len(results)} results")

results = []
formatted_results = []
for r in results:
payload = r.get("payload", {})
time_stamp = payload.get("time_stamp", "")
weekday = payload.get("weekday", "")
memory = payload.get("memory", "")
formatted_results.append(f"{time_stamp} {weekday} {memory}")


if self.config.retrieve_strategy in ["context", "hybrid", "bm25"]:
self.logger.info(f"[{call_id}] Retrieving using Context retriever (strategy: {self.config.retrieve_strategy})")

results = self.context_retriever.retrieve(
query=query,
top_k=limit
)

self.logger.info(f"[{call_id}] Found {len(results)} results (strategy: {self.config.retrieve_strategy})")
for doc in results:
if isinstance(doc, str):
formatted_results.append(doc)
elif isinstance(doc, dict):
formatted_results.append(doc.get("content", str(doc)))
else:
formatted_results.append(str(doc))

elif self.config.retrieve_strategy == "embedding":
self.logger.debug(f"[{call_id}] Generating embedding for query")
query_vector = self.text_embedder.embed(query)
self.logger.debug(f"[{call_id}] Query embedding dimension: {len(query_vector)}")
self.logger.info(f"[{call_id}] Searching vector database")
results = self.embedding_retriever.search(
query_vector=query_vector,
limit=limit,
filters=filters,
return_full=True,
)

self.logger.info(f"[{call_id}] Found {len(results)} embedding results")
for r in results:
if isinstance(r, dict):
payload = r.get("payload", {})
else:
try:
payload = r.payload
except AttributeError:
payload = {}

time_stamp = payload.get("time_stamp", "")
weekday = payload.get("weekday", "")
memory = payload.get("memory", "")
formatted_results.append(f"{time_stamp} {weekday} {memory}")

else:
self.logger.warning(f"[{call_id}] Unknown retrieve_strategy: {self.config.retrieve_strategy}")

result_string = "\n".join(formatted_results)
self.logger.info(f"[{call_id}] Formatted {len(formatted_results)} results into output string")
self.logger.debug(f"[{call_id}] Output string length: {len(result_string)} characters")
self.logger.info(f"========== END {call_id} ==========")
return result_string

return result_string
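
Whichever branch runs, the method joins the formatted entries with newlines and returns a single string. A usage sketch with hypothetical values, matching the example script above:

# lightmem comes from load_lightmem_bm25() in the example script.
memories = lightmem.retrieve("What is the capital of France?", limit=5)
for line in memories.split("\n"):
    print(line)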