
Commit 84c4e09

Revert "fix"
This reverts commit 9cdb534.
1 parent cd4361b · commit 84c4e09

File tree

2 files changed: +71, -242 lines changed

Binary file changed (-3.74 KB); contents not shown.

agents/query_analyzer.py

Lines changed: 71 additions & 242 deletions
@@ -1,245 +1,74 @@
-import logging
-import time
-import faiss
-import pickle
-import numpy as np
-import itertools
-import re # Import re
+# agents/query_analyzer.py
+import spacy
+import logging # Added import
 from .base import BaseAgent
-from gemini_utils import embed_text
-from utils.text_utils import simple_keyword_score, simple_entity_score, section_relevance_score
-from config import Config
-
-logger = logging.getLogger(__name__)
-
-DEFAULT_HYBRID_INITIAL_TOP_K = Config.RETRIEVER_INITIAL_K
-DEFAULT_HYBRID_FINAL_TOP_K = Config.RETRIEVER_FINAL_K
-
-class RetrieverAgent(BaseAgent):
-    """Agent responsible for retrieving and re-ranking relevant text chunks."""
-    def __init__(self, index_path="faiss_index.index", metadata_path="faiss_metadata.pkl"):
-        logger.info(f"💾 Loading FAISS index from: {index_path}")
-        try:
-            self.index = faiss.read_index(index_path)
-            logger.info(f"✅ FAISS index loaded successfully. Index dimension: {self.index.d}, Total vectors: {self.index.ntotal}")
-        except Exception as e:
-            logger.error(f"❌ Failed to load FAISS index: {e}", exc_info=True)
-            raise
-        logger.info(f"💾 Loading metadata from: {metadata_path}")
-        try:
-            with open(metadata_path, "rb") as f:
-                self.metadatas = pickle.load(f)
-            # Pre-extract texts for faster access if needed elsewhere
-            self.texts = [m.pop('text', '') for m in self.metadatas] # Extract text and remove from metadata dict
-            logger.info(f"✅ Metadata loaded successfully. Number of entries: {len(self.metadatas)}")
-            if len(self.metadatas) != self.index.ntotal:
-                logger.warning(f"⚠️ Mismatch between index size ({self.index.ntotal}) and metadata count ({len(self.metadatas)}).")
-        except Exception as e:
-            logger.error(f"❌ Failed to load metadata: {e}", exc_info=True)
-            raise
-
-    def re_rank_chunks(self, initial_results, query, query_analysis):
-        """Re-rank chunks based on multiple factors using utility functions."""
-        rerank_start_time = time.time()
-        logger.info("⚖️ Re-ranking retrieved chunks...")
-        if not initial_results:
-            logger.warning("No initial results to re-rank.")
-            return []
-
-        keywords = query_analysis.get("keywords", [])
-        entities = query_analysis.get("entities", [])
-        query_type = query_analysis.get("query_type", "unknown")
-        intent_type = query_analysis.get("intent_type", "new_topic") # Get intent
-        topic_keywords = query_analysis.get("topic_keywords", []) # Get topic keywords
-        topic_entities = query_analysis.get("topic_entities", []) # Get topic entities
-
-        query_keywords_set = set(keywords)
-        topic_terms_set = set(topic_keywords + topic_entities) # Combine topic terms
-
-        logger.debug(f"Re-ranking based on -> Query Keywords: {keywords}, Entities: {entities}, Type: {query_type}, Intent: {intent_type}, Topic Terms: {topic_terms_set}")
-
-        # --- Tuned Weights ---
-        # Adjust weights based on intent? (Example)
-        if intent_type in ["follow_up", "clarification"] and topic_terms_set:
-            logger.debug("Adjusting weights for follow-up/clarification intent.")
-            weights = {
-                "semantic": 0.15, # Slightly lower semantic weight for current query
-                "keyword": 0.4, # Keep keyword weight
-                "entity": 0.25, # Keep entity weight
-                "topic": 0.2, # Add weight for topic relevance
-                "section": 0.0
-            }
-        else:
-            weights = {
-                "semantic": 0.2,
-                "keyword": 0.5,
-                "entity": 0.3,
-                "topic": 0.0, # No topic weight for new topics
-                "section": 0.0
-            }
-        # ---------------------
-
-        # Normalize semantic scores (FAISS distances are lower for better matches)
-        max_faiss_dist = max(r["score"] for r in initial_results) if initial_results else 1.0
-        if max_faiss_dist <= 0: # Avoid division by zero
-            max_faiss_dist = 1.0
-
-        logger.debug(f"Re-ranking {len(initial_results)} chunks...")
-        for i, result in enumerate(initial_results):
-            text_lower = self.texts[result["index"]].lower() # Get text using index
-            result["text"] = self.texts[result["index"]] # Add full text back for generator
-            result["metadata"] = self.metadatas[result["index"]] # Add metadata back
-
-            result["semantic_score"] = max(0.0, 1.0 - (max(0.0, result["score"]) / max_faiss_dist))
-            # Use utility functions for scoring
-            result["keyword_score"] = simple_keyword_score(text_lower, query_keywords_set)
-            result["entity_score"] = simple_entity_score(text_lower, entities)
-            result["section_score"] = section_relevance_score(result["metadata"], query_type)
-            # Add topic score if applicable
-            result["topic_score"] = simple_keyword_score(text_lower, topic_terms_set) if weights["topic"] > 0 else 0.0
-
-            combined_score = (
-                weights["semantic"] * result["semantic_score"] +
-                weights["keyword"] * result["keyword_score"] +
-                weights["entity"] * result["entity_score"] +
-                weights["topic"] * result["topic_score"] # Include topic score
-                # + weights["section"] * result["section_score"] # Section score currently unused
-            )
-            result["combined_score"] = combined_score
-
-            # Confidence calculation (can be refined)
-            if combined_score > 0.75:
-                confidence = 0.95
-            elif combined_score > 0.6:
-                confidence = 0.8
-            elif combined_score > 0.45:
-                confidence = 0.65
-            elif combined_score > 0.3:
-                confidence = 0.5
-            else:
-                confidence = 0.3
-            result["confidence"] = confidence
-
-        # Sort by combined score
-        ranked_results = sorted(initial_results, key=lambda x: x["combined_score"], reverse=True)
-
-        # Filter based on presence of *query* keywords/entities (important!)
-        logger.info(f"🔍 Filtering {len(ranked_results)} re-ranked chunks for *query* keyword/entity presence...")
-        filtered_results = []
-        query_terms_lower = {k.lower() for k in keywords} | {e.lower() for e in entities}
+import re
+import time
 
-        # If the query itself has no terms, but it's a follow-up, rely on topic terms for filtering?
-        # Or maybe skip filtering if query terms are absent? Let's skip for now.
-        if not query_terms_lower and intent_type not in ["follow_up", "clarification"]:
-            logger.warning("⚠️ No keywords or entities found in query analysis, and not a follow-up. Skipping filtering.")
-            filtered_results = ranked_results
-        elif not query_terms_lower and intent_type in ["follow_up", "clarification"]:
-            logger.warning("⚠️ No keywords or entities in query, but it's a follow-up/clarification. Filtering based on *topic* terms.")
-            filter_terms = {t.lower() for t in topic_terms_set} # Use topic terms for filtering
-            if not filter_terms:
-                logger.warning("⚠️ No topic terms found either. Skipping filtering.")
-                filtered_results = ranked_results
-            else:
-                for result in ranked_results:
-                    text_lower = result["text"].lower()
-                    # Check for topic terms instead of query terms
-                    if any(re.search(r'\b' + re.escape(term) + r'\b', text_lower) for term in filter_terms):
-                        filtered_results.append(result)
+# Load the spaCy model once when the class is instantiated
+try:
+    nlp = spacy.load("en_core_web_sm")
+    print("✅ spaCy model 'en_core_web_sm' loaded successfully.")
+except OSError:
+    print("❌ Error loading spaCy model 'en_core_web_sm'.")
+    print(" Please run: python -m spacy download en_core_web_sm")
+    nlp = None # Set nlp to None if loading fails
+
+logger = logging.getLogger(__name__) # Get a logger for this module
+
+class QueryAnalyzerAgent(BaseAgent):
+    """Agent responsible for analyzing the user query."""
+    def run(self, query: str, chat_history: list = None) -> dict: # Add chat_history parameter
+        start_time = time.time()
+        logger.debug(f"Analyzing query: '{query}' with history: {chat_history is not None}") # Log if history is present
+        if not nlp:
+            logger.warning("spaCy model not loaded, falling back to basic analysis.")
+            # Fallback basic extraction (similar to previous web.py logic)
+            keywords = re.findall(r'"(.*?)"|\b[A-Z][a-zA-Z]+\b', query)
+            entities = re.findall(r'\b[A-Z][a-zA-Z]+(?:\s+[A-Z][a-zA-Z]+)*\b', query)
+            keywords = list(set([k.strip().lower() for k in keywords if k]))
+            entities = list(set([e.strip() for e in entities if len(e.split()) > 1 or e in keywords]))
         else:
-            # Standard filtering based on query terms
-            filter_terms = query_terms_lower
-            for result in ranked_results:
-                text_lower = result["text"].lower()
-                if any(re.search(r'\b' + re.escape(term) + r'\b', text_lower) for term in filter_terms):
-                    filtered_results.append(result)
-
-
-        logger.info(f"✅ Filtered down to {len(filtered_results)} chunks containing relevant terms.")
-        logger.debug("Top 5 Filtered & Re-ranked Chunks (Combined | Sem | Key | Ent | Top | Conf | Page):")
-        for i, r in enumerate(filtered_results[:5]):
-            page = r.get("metadata", {}).get("page", "?")
-            logger.debug(f"{i+1}. Score={r['combined_score']:.3f} (S:{r['semantic_score']:.2f} K:{r['keyword_score']:.2f} E:{r['entity_score']:.2f} T:{r['topic_score']:.2f}) | Conf={r['confidence']:.2f} | Page={page} | Text: {r['text'][:100]}...")
-
-        total_rerank_time = time.time() - rerank_start_time
-        logger.info(f"Step 2b: Re-ranking & Filtering took: {total_rerank_time:.4f}s")
-        return filtered_results
-
-
-    def _simple_expand_query(self, query_analysis: dict, max_expansions: int = 2) -> list[str]:
-        """Generates simple query variations based on keywords and entities."""
-        expansions = []
-        keywords = query_analysis.get("keywords", [])
-        entities = query_analysis.get("entities", [])
-        # Consider adding topic terms if it's a follow-up with few query terms?
-        intent_type = query_analysis.get("intent_type", "new_topic")
-        topic_keywords = query_analysis.get("topic_keywords", [])
-        topic_entities = query_analysis.get("topic_entities", [])
-
-        terms = list(set(entities + keywords))
-
-        # If few terms in query but it's a follow-up, add topic terms to expansion base
-        if len(terms) < 2 and intent_type in ["follow_up", "clarification"]:
-            logger.debug("Expanding query using topic terms for follow-up.")
-            terms.extend(topic_keywords)
-            terms.extend(topic_entities)
-            terms = list(set(terms)) # Ensure uniqueness
-
-        if not terms:
-            return []
-
-        # Prioritize entities for combinations
-        priority_terms = entities if entities else keywords
-        other_terms = keywords if entities else []
-
-        # Generate pairs (priority x other, priority x priority)
-        pairs = []
-        if priority_terms and other_terms:
-            pairs.extend(list(itertools.product(priority_terms, other_terms)))
-        if len(priority_terms) >= 2:
-            pairs.extend(list(itertools.combinations(priority_terms, 2)))
-
-        # Add single terms if not enough pairs
-        if len(pairs) < max_expansions:
-            pairs.extend([(t,) for t in terms]) # Add single terms
-
-        # Create expansion strings
-        for pair in pairs:
-            expansions.append(" ".join(pair))
-            if len(expansions) >= max_expansions:
-                break
-
-        # Fallback: if still no expansions, use top terms directly
-        if not expansions and terms:
-            expansions.extend(terms[:max_expansions])
-
-        unique_expansions = list(dict.fromkeys(expansions)) # Maintain order while making unique
-        logger.debug(f"Generated query expansions: {unique_expansions[:max_expansions]}")
-        return unique_expansions[:max_expansions]
-
-
-    def run(self, query: str, query_analysis: dict, initial_top_k: int = DEFAULT_HYBRID_INITIAL_TOP_K, final_top_k: int = 5):
-        """Retrieves chunks using semantic search (with expansion), filters and re-ranks them."""
-        run_start_time = time.time()
-        logger.info(f"🔎 Running hybrid retrieval for: '{query}' (Initial K={initial_top_k}, Final K={final_top_k})")
-        logger.debug(f"Query Analysis for Retrieval: {query_analysis}") # Log full analysis
-
-        expansion_start_time = time.time()
-        # Use original query if analysis didn't refine, otherwise use refined
-        query_to_expand = query_analysis.get("original_query", query) # Use original for expansion base
-        expanded_queries = self._simple_expand_query(query_analysis)
-        all_queries = [query_to_expand] + expanded_queries # Include original query
-
-        query_embeddings = []
-        for q in all_queries:
-            emb = embed_text(q)
-            if emb:
-                query_embeddings.append(np.array([emb]).astype("float32"))
-            else:
-                logger.warning(f"Failed to generate embedding for query variant: '{q}'")
-
-        if not query_embeddings:
-            logger.error("Failed to generate any query embeddings.")
-            return []
-
-        expansion_time = time.time() - expansion_start_time
-        logger.info(f"Step 2a: Query expansion
+            # TODO: Incorporate chat_history into spaCy analysis if needed
+            # For now, just process the current query
+            doc = nlp(query)
+
+            # Extract Named Entities (GPE, PERSON, ORG, LOC, EVENT, DATE etc.)
+            entities = list(set([ent.text.strip() for ent in doc.ents if ent.label_ in ["GPE", "PERSON", "ORG", "LOC", "EVENT", "DATE", "FAC", "PRODUCT", "WORK_OF_ART"]]))
+
+            # Extract Keywords (Noun chunks and Proper Nouns)
+            keywords = list(set([chunk.text.lower().strip() for chunk in doc.noun_chunks]))
+            # Add proper nouns that might not be part of chunks or recognized entities
+            keywords.extend([token.text.lower().strip() for token in doc if token.pos_ == "PROPN" and token.text not in entities])
+            # Remove duplicates that might exist between entities and keywords after lowercasing
+            keywords = list(set(keywords))
+            # Optional: Remove very short keywords if needed
+            # keywords = [kw for kw in keywords if len(kw) > 2]
+
+        # Determine Query Type (Keep existing logic)
+        query_lower = query.lower()
+        query_type = "unknown"
+        if "cause" in query_lower or "why" in query_lower or "effect" in query_lower or "impact" in query_lower:
+            query_type = "causal/analytical"
+        elif "compare" in query_lower or "difference" in query_lower or "similar" in query_lower or "contrast" in query_lower:
+            query_type = "comparative"
+        elif re.match(r"^(what|who|when|where|which)\s+(is|was|are|were|did|do|does)\b", query_lower) or \
+             re.match(r"^(define|describe|explain|list)\b", query_lower):
+            query_type = "factual"
+        # Add more rules if needed
+
+        analysis = {
+            "original_query": query, # Add the original query here
+            "keywords": keywords,
+            "entities": entities,
+            "query_type": query_type,
+            # Optionally include history info if used
+            # "history_considered": chat_history is not None
+        }
+
+        end_time = time.time()
+        # Log the extracted information
+        logger.debug(f"Analysis Results: Keywords: {analysis['keywords']}, Entities: {analysis['entities']}, Query Type: {analysis['query_type']}")
+        logger.debug(f"Analysis Time: {end_time - start_time:.4f}s")
+
+        return analysis
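
For orientation, below is a minimal usage sketch of the QueryAnalyzerAgent that this revert restores. The import path follows the file location above (agents/query_analyzer.py); the zero-argument constructor and the example query are assumptions for illustration, while the run() signature and the returned keys are taken from the diff.

# Hypothetical usage sketch: the module path and no-arg constructor are assumptions;
# run()'s parameters and the returned dict keys come from the code added back above.
from agents.query_analyzer import QueryAnalyzerAgent

analyzer = QueryAnalyzerAgent()  # assumes BaseAgent needs no constructor arguments
analysis = analyzer.run("Compare Paris and London", chat_history=None)

print(analysis["original_query"])  # the raw query string
print(analysis["keywords"])        # noun chunks / proper nouns (regex fallback if spaCy is missing)
print(analysis["entities"])        # spaCy named entities such as GPE, PERSON, ORG
print(analysis["query_type"])      # "comparative" here; otherwise "causal/analytical", "factual", or "unknown"

Running the spaCy path requires the en_core_web_sm model (python -m spacy download en_core_web_sm); without it, the agent logs a warning and uses the regex fallback shown in the diff.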
