Skip to content

Commit f1af6d8

Browse files
authored
Merge pull request stanford-oval#198 from dfusion-dev/main
[New RM] Add AzureAISearch
2 parents aef7f0c + 81760f9 commit f1af6d8

File tree

3 files changed

+142
-5
lines changed

3 files changed

+142
-5
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ You could also install the source code which allows you to modify the behavior o
9393
Currently, our package support:
9494

9595
- `OpenAIModel`, `AzureOpenAIModel`, `ClaudeModel`, `VLLMClient`, `TGIClient`, `TogetherClient`, `OllamaClient`, `GoogleModel`, `DeepSeekModel`, `GroqModel` as language model components
96-
- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch` as retrieval module components
96+
- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch`, and `AzureAISearch` as retrieval module components
9797

9898
:star2: **PRs for integrating more language models into [knowledge_storm/lm.py](knowledge_storm/lm.py) and search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!**
9999

examples/storm_examples/run_storm_wiki_gpt.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
"""
2121

2222
import os
23+
2324
from argparse import ArgumentParser
2425
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
2526
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
26-
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG
27+
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG, AzureAISearch
2728
from knowledge_storm.utils import load_api_key
2829

2930

@@ -72,6 +73,7 @@ def main(args):
7273

7374
# STORM is a knowledge curation system which consumes information from the retrieval module.
7475
# Currently, the information source is the Internet and we use search engine API as the retrieval module.
76+
7577
match args.retriever:
7678
case 'bing':
7779
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
@@ -87,8 +89,10 @@ def main(args):
8789
rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True)
8890
case 'searxng':
8991
rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k)
92+
case 'azure_ai_search':
93+
rm = AzureAISearch(azure_ai_search_api_key=os.getenv('AZURE_AI_SEARCH_API_KEY'), k=engine_args.search_top_k)
9094
case _:
91-
raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"')
95+
raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", "searxng", or "azure_ai_search"')
9296

9397
runner = STORMWikiRunner(engine_args, lm_configs, rm)
9498

@@ -113,7 +117,7 @@ def main(args):
113117
help='Maximum number of threads to use. The information seeking part and the article generation'
114118
'part can speed up by using multiple threads. Consider reducing it if keep getting '
115119
'"Exceed rate limit" error when calling LM API.')
116-
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'],
120+
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng', 'azure_ai_search'],
117121
help='The search engine API to use for retrieving information.')
118122
# stage of the pipeline
119123
parser.add_argument('--do-research', action='store_true',
@@ -138,4 +142,4 @@ def main(args):
138142
parser.add_argument('--remove-duplicate', action='store_true',
139143
help='If True, remove duplicate content from the article.')
140144

141-
main(parser.parse_args())
145+
main(parser.parse_args())

knowledge_storm/rm.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,3 +1093,136 @@ def forward(
10931093
collected_results.append(r)
10941094

10951095
return collected_results
1096+
1097+
1098+
class AzureAISearch(dspy.Retrieve):
1099+
"""Retrieve information from custom queries using Azure AI Search.
1100+
1101+
General Documentation: https://learn.microsoft.com/en-us/azure/search/search-create-service-portal.
1102+
Python Documentation: https://learn.microsoft.com/en-us/python/api/overview/azure/search-documents-readme?view=azure-python.
1103+
"""
1104+
1105+
def __init__(
1106+
self,
1107+
azure_ai_search_api_key=None,
1108+
azure_ai_search_url=None,
1109+
azure_ai_search_index_name=None,
1110+
k=3,
1111+
is_valid_source: Callable = None,
1112+
):
1113+
"""
1114+
Params:
1115+
azure_ai_search_api_key: Azure AI Search API key. Check out https://learn.microsoft.com/en-us/azure/search/search-security-api-keys?tabs=rest-use%2Cportal-find%2Cportal-query
1116+
"API key" section
1117+
azure_ai_search_url: Custom Azure AI Search Endpoint URL. Check out https://learn.microsoft.com/en-us/azure/search/search-create-service-portal#name-the-service
1118+
azure_ai_search_index_name: Custom Azure AI Search Index Name. Check out https://learn.microsoft.com/en-us/azure/search/search-how-to-create-search-index?tabs=portal
1119+
k: Number of top results to retrieve.
1120+
is_valid_source: Optional function to filter valid sources.
1121+
min_char_count: Minimum character count for the article to be considered valid.
1122+
snippet_chunk_size: Maximum character count for each snippet.
1123+
webpage_helper_max_threads: Maximum number of threads to use for webpage helper.
1124+
"""
1125+
super().__init__(k=k)
1126+
1127+
try:
1128+
from azure.core.credentials import AzureKeyCredential
1129+
from azure.search.documents import SearchClient
1130+
except ImportError as err:
1131+
raise ImportError(
1132+
"AzureAISearch requires `pip install azure-search-documents`."
1133+
) from err
1134+
1135+
if not azure_ai_search_api_key and not os.environ.get(
1136+
"AZURE_AI_SEARCH_API_KEY"
1137+
):
1138+
raise RuntimeError(
1139+
"You must supply azure_ai_search_api_key or set environment variable AZURE_AI_SEARCH_API_KEY"
1140+
)
1141+
elif azure_ai_search_api_key:
1142+
self.azure_ai_search_api_key = azure_ai_search_api_key
1143+
else:
1144+
self.azure_ai_search_api_key = os.environ["AZURE_AI_SEARCH_API_KEY"]
1145+
1146+
if not azure_ai_search_url and not os.environ.get("AZURE_AI_SEARCH_URL"):
1147+
raise RuntimeError(
1148+
"You must supply azure_ai_search_url or set environment variable AZURE_AI_SEARCH_URL"
1149+
)
1150+
elif azure_ai_search_url:
1151+
self.azure_ai_search_url = azure_ai_search_url
1152+
else:
1153+
self.azure_ai_search_url = os.environ["AZURE_AI_SEARCH_URL"]
1154+
1155+
if not azure_ai_search_index_name and not os.environ.get(
1156+
"AZURE_AI_SEARCH_INDEX_NAME"
1157+
):
1158+
raise RuntimeError(
1159+
"You must supply azure_ai_search_index_name or set environment variable AZURE_AI_SEARCH_INDEX_NAME"
1160+
)
1161+
elif azure_ai_search_index_name:
1162+
self.azure_ai_search_index_name = azure_ai_search_index_name
1163+
else:
1164+
self.azure_ai_search_index_name = os.environ["AZURE_AI_SEARCH_INDEX_NAME"]
1165+
1166+
self.usage = 0
1167+
1168+
# If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
1169+
if is_valid_source:
1170+
self.is_valid_source = is_valid_source
1171+
else:
1172+
self.is_valid_source = lambda x: True
1173+
1174+
def get_usage_and_reset(self):
1175+
usage = self.usage
1176+
self.usage = 0
1177+
1178+
return {"AzureAISearch": usage}
1179+
1180+
def forward(
1181+
self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
1182+
):
1183+
"""Search with Azure Open AI for self.k top passages for query or queries
1184+
1185+
Args:
1186+
query_or_queries (Union[str, List[str]]): The query or queries to search for.
1187+
exclude_urls (List[str]): A list of urls to exclude from the search results.
1188+
1189+
Returns:
1190+
a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
1191+
"""
1192+
try:
1193+
from azure.core.credentials import AzureKeyCredential
1194+
from azure.search.documents import SearchClient
1195+
except ImportError as err:
1196+
raise ImportError(
1197+
"AzureAISearch requires `pip install azure-search-documents`."
1198+
) from err
1199+
queries = (
1200+
[query_or_queries]
1201+
if isinstance(query_or_queries, str)
1202+
else query_or_queries
1203+
)
1204+
self.usage += len(queries)
1205+
collected_results = []
1206+
1207+
client = SearchClient(
1208+
self.azure_ai_search_url,
1209+
self.azure_ai_search_index_name,
1210+
AzureKeyCredential(self.azure_ai_search_api_key),
1211+
)
1212+
for query in queries:
1213+
try:
1214+
# https://learn.microsoft.com/en-us/python/api/azure-search-documents/azure.search.documents.searchclient?view=azure-python#azure-search-documents-searchclient-search
1215+
results = client.search(search_text=query, top=1)
1216+
1217+
for result in results:
1218+
document = {
1219+
"url": result["metadata_storage_path"],
1220+
"title": result["title"],
1221+
"description": "N/A",
1222+
"snippets": [result["chunk"]],
1223+
}
1224+
collected_results.append(document)
1225+
except Exception as e:
1226+
logging.error(f"Error occurs when searching query {query}: {e}")
1227+
1228+
return collected_results

0 commit comments

Comments
 (0)