diff --git a/README.md b/README.md
index c6263c27..1f4715e6 100644
--- a/README.md
+++ b/README.md
@@ -7,9 +7,10 @@
| Research preview | STORM Paper| Co-STORM Paper | Website |
- **Latest News** 🔥 +- [2025/01] We add [litellm](https://github.com/BerriAI/litellm) integration for language models and embedding models in `knowledge-storm` v1.1.0. + - [2024/09] Co-STORM codebase is now released and integrated into `knowledge-storm` python package v1.0.0. Run `pip install knowledge-storm --upgrade` to check it out. - [2024/09] We introduce collaborative STORM (Co-STORM) to support human-AI collaborative knowledge curation! [Co-STORM Paper](https://www.arxiv.org/abs/2408.15232) has been accepted to EMNLP 2024 main conference. @@ -92,10 +93,11 @@ You could also install the source code which allows you to modify the behavior o Currently, our package support: -- `OpenAIModel`, `AzureOpenAIModel`, `ClaudeModel`, `VLLMClient`, `TGIClient`, `TogetherClient`, `OllamaClient`, `GoogleModel`, `DeepSeekModel`, `GroqModel` as language model components -- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch`, and `AzureAISearch` as retrieval module components +- Language model components: All language models supported by litellm as listed [here](https://docs.litellm.ai/docs/providers) +- Embedding model components: All embedding models supported by litellm as listed [here](https://docs.litellm.ai/docs/embedding/supported_embedding) +- Retrieval module components: `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch`, and `AzureAISearch` -:star2: **PRs for integrating more language models into [knowledge_storm/lm.py](knowledge_storm/lm.py) and search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!** +:star2: **PRs for integrating more search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!** Both STORM and Co-STORM are working in the information curation layer, you need to set up the information retrieval module and language model module to create their `Runner` classes respectively. @@ -106,7 +108,7 @@ The STORM knowledge curation engine is defined as a simple Python `STORMWikiRunn ```python import os from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs -from knowledge_storm.lm import OpenAIModel +from knowledge_storm.lm import LitellmModel from knowledge_storm.rm import YouRM lm_configs = STORMWikiLMConfigs() @@ -118,8 +120,8 @@ openai_kwargs = { # STORM is a LM system so different components can be powered by different models to reach a good balance between cost and quality. # For a good practice, choose a cheaper/faster model for `conv_simulator_lm` which is used to split queries, synthesize answers in the conversation. # Choose a more powerful model for `article_gen_lm` to generate verifiable text with citations.
-gpt_35 = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs) -gpt_4 = OpenAIModel(model='gpt-4o', max_tokens=3000, **openai_kwargs) +gpt_35 = LitellmModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs) +gpt_4 = LitellmModel(model='gpt-4o', max_tokens=3000, **openai_kwargs) lm_configs.set_conv_simulator_lm(gpt_35) lm_configs.set_question_asker_lm(gpt_35) lm_configs.set_outline_gen_lm(gpt_4) @@ -155,7 +157,7 @@ The Co-STORM knowledge curation engine is defined as a simple Python `CoStormRun ```python from knowledge_storm.collaborative_storm.engine import CollaborativeStormLMConfigs, RunnerArgument, CoStormRunner -from knowledge_storm.lm import OpenAIModel +from knowledge_storm.lm import LitellmModel from knowledge_storm.logging_wrapper import LoggingWrapper from knowledge_storm.rm import BingSearch @@ -168,12 +170,12 @@ openai_kwargs = { "top_p": 0.9, "api_base": None, } -question_answering_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs) -discourse_manage_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs) -utterance_polishing_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs) -warmstart_outline_gen_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs) -question_asking_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=300, **openai_kwargs) -knowledge_base_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs) +question_answering_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs) +discourse_manage_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs) +utterance_polishing_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs) +warmstart_outline_gen_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs) +question_asking_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=300, **openai_kwargs) +knowledge_base_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs) lm_config.set_question_answering_lm(question_answering_lm) lm_config.set_discourse_manage_lm(discourse_manage_lm) @@ -222,6 +224,7 @@ We provide scripts in our [examples folder](examples) as a quick start to run ST We suggest using `secrets.toml` to set up the API keys. Create a file `secrets.toml` under the root directory and add the following content: ```shell +# ============ language model configurations ============ # Set up OpenAI API key. OPENAI_API_KEY="your_openai_api_key" # If you are using the API service provided by OpenAI, include the following line: @@ -230,15 +233,10 @@ OPENAI_API_TYPE="openai" OPENAI_API_TYPE="azure" AZURE_API_BASE="your_azure_api_base_url" AZURE_API_VERSION="your_azure_api_version" -# Set up You.com search API key. -YDC_API_KEY="your_youcom_api_key" -``` - -for **Co-STORM**, please also add following -``` -# if use openai encoder -ENCODER_API_TYPE="openai" -# or ENCODER_API_TYPE="azure" if use azure openai encoder +# ============ retriever configurations ============ +BING_SEARCH_API_KEY="your_bing_search_api_key" # if using bing search +# ============ encoder configurations ============ +ENCODER_API_TYPE="openai" # if using openai encoder ``` ### STORM examples @@ -249,7 +247,7 @@ Run the following command. 
```bash python examples/storm_examples/run_storm_wiki_gpt.py \ --output-dir $OUTPUT_DIR \ - --retriever you \ + --retriever bing \ --do-research \ --do-generate-outline \ --do-generate-article \ @@ -328,20 +326,44 @@ We are very grateful to [Michelle Lam](https://michelle123lam.github.io/) for de ## Citation Please cite our paper if you use this code or part of it in your work: ```bibtex -@misc{jiang2024unknownunknowns, - title={Into the Unknown Unknowns: Engaged Human Learning through Participation in Language Model Agent Conversations}, - author={Yucheng Jiang and Yijia Shao and Dekun Ma and Sina J. Semnani and Monica S. Lam}, - year={2024}, - eprint={2408.15232}, - archivePrefix={arXiv}, - primaryClass={cs.CL}, - url={https://arxiv.org/abs/2408.15232}, +@inproceedings{jiang-etal-2024-unknown, + title = "Into the Unknown Unknowns: Engaged Human Learning through Participation in Language Model Agent Conversations", + author = "Jiang, Yucheng and + Shao, Yijia and + Ma, Dekun and + Semnani, Sina and + Lam, Monica", + editor = "Al-Onaizan, Yaser and + Bansal, Mohit and + Chen, Yun-Nung", + booktitle = "Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing", + month = nov, + year = "2024", + address = "Miami, Florida, USA", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2024.emnlp-main.554/", + doi = "10.18653/v1/2024.emnlp-main.554", + pages = "9917--9955", } -@inproceedings{shao2024assisting, - title={{Assisting in Writing Wikipedia-like Articles From Scratch with Large Language Models}}, - author={Yijia Shao and Yucheng Jiang and Theodore A. Kanell and Peter Xu and Omar Khattab and Monica S. Lam}, - year={2024}, - booktitle={Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)} +@inproceedings{shao-etal-2024-assisting, + title = "Assisting in Writing {W}ikipedia-like Articles From Scratch with Large Language Models", + author = "Shao, Yijia and + Jiang, Yucheng and + Kanell, Theodore and + Xu, Peter and + Khattab, Omar and + Lam, Monica", + editor = "Duh, Kevin and + Gomez, Helena and + Bethard, Steven", + booktitle = "Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers)", + month = jun, + year = "2024", + address = "Mexico City, Mexico", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2024.naacl-long.347/", + doi = "10.18653/v1/2024.naacl-long.347", + pages = "6252--6278", } ``` diff --git a/examples/costorm_examples/run_costorm_gpt.py b/examples/costorm_examples/run_costorm_gpt.py index 9046bd5f..40138915 100644 --- a/examples/costorm_examples/run_costorm_gpt.py +++ b/examples/costorm_examples/run_costorm_gpt.py @@ -16,51 +16,83 @@ import os import json from argparse import ArgumentParser -from knowledge_storm.collaborative_storm.engine import CollaborativeStormLMConfigs, RunnerArgument, CoStormRunner -from knowledge_storm.collaborative_storm.modules.callback import LocalConsolePrintCallBackHandler +from knowledge_storm.collaborative_storm.engine import ( + CollaborativeStormLMConfigs, + RunnerArgument, + CoStormRunner, +) +from knowledge_storm.collaborative_storm.modules.callback import ( + LocalConsolePrintCallBackHandler, +) from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel from knowledge_storm.logging_wrapper import 
LoggingWrapper -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_config: CollaborativeStormLMConfigs = CollaborativeStormLMConfigs() - openai_kwargs = { - "api_key": os.getenv("OPENAI_API_KEY"), - "api_provider": "openai", - "temperature": 1.0, - "top_p": 0.9, - "api_base": None, - } if os.getenv('OPENAI_API_TYPE') == 'openai' else { - "api_key": os.getenv("AZURE_API_KEY"), - "temperature": 1.0, - "top_p": 0.9, - "api_base": os.getenv("AZURE_API_BASE"), - "api_version": os.getenv("AZURE_API_VERSION"), - } - - ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel + openai_kwargs = ( + { + "api_key": os.getenv("OPENAI_API_KEY"), + "api_provider": "openai", + "temperature": 1.0, + "top_p": 0.9, + "api_base": None, + } + if os.getenv("OPENAI_API_TYPE") == "openai" + else { + "api_key": os.getenv("AZURE_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + } + ) + + ModelClass = ( + OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel + ) # If you are using Azure service, make sure the model name matches your own deployed model name. # The default name here is only used for demonstration and may not match your case. - gpt_4o_mini_model_name = 'gpt-4o-mini' - gpt_4o_model_name = 'gpt-4o' - if os.getenv('OPENAI_API_TYPE') == 'azure': - openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE') - openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION') + gpt_4o_mini_model_name = "gpt-4o-mini" + gpt_4o_model_name = "gpt-4o" + if os.getenv("OPENAI_API_TYPE") == "azure": + openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE") + openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION") # STORM is a LM system so different components can be powered by different models. # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. 
- question_answering_lm = ModelClass(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs) - discourse_manage_lm = ModelClass(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs) - utterance_polishing_lm = ModelClass(model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs) - warmstart_outline_gen_lm = ModelClass(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs) - question_asking_lm = ModelClass(model=gpt_4o_model_name, max_tokens=300, **openai_kwargs) - knowledge_base_lm = ModelClass(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs) + question_answering_lm = ModelClass( + model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs + ) + discourse_manage_lm = ModelClass( + model=gpt_4o_model_name, max_tokens=500, **openai_kwargs + ) + utterance_polishing_lm = ModelClass( + model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs + ) + warmstart_outline_gen_lm = ModelClass( + model=gpt_4o_model_name, max_tokens=500, **openai_kwargs + ) + question_asking_lm = ModelClass( + model=gpt_4o_model_name, max_tokens=300, **openai_kwargs + ) + knowledge_base_lm = ModelClass( + model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs + ) lm_config.set_question_answering_lm(question_answering_lm) lm_config.set_discourse_manage_lm(discourse_manage_lm) @@ -69,7 +101,7 @@ def main(args): lm_config.set_question_asking_lm(question_asking_lm) lm_config.set_knowledge_base_lm(knowledge_base_lm) - topic = input('Topic: ') + topic = input("Topic: ") runner_argument = RunnerArgument( topic=topic, retrieve_top_k=args.retrieve_top_k, @@ -83,49 +115,76 @@ def main(args): max_thread_num=args.max_thread_num, max_num_round_table_experts=args.max_num_round_table_experts, moderator_override_N_consecutive_answering_turn=args.moderator_override_N_consecutive_answering_turn, - node_expansion_trigger_count=args.node_expansion_trigger_count) + node_expansion_trigger_count=args.node_expansion_trigger_count, + ) logging_wrapper = LoggingWrapper(lm_config) - callback_handler = LocalConsolePrintCallBackHandler() if args.enable_log_print else None + callback_handler = ( + LocalConsolePrintCallBackHandler() if args.enable_log_print else None + ) # Co-STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=runner_argument.retrieve_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=runner_argument.retrieve_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=runner_argument.retrieve_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=runner_argument.retrieve_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=runner_argument.retrieve_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=runner_argument.retrieve_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=runner_argument.retrieve_top_k, + ) + case "you": + rm = YouRM( + ydc_api_key=os.getenv("YDC_API_KEY"), k=runner_argument.retrieve_top_k + ) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=runner_argument.retrieve_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=runner_argument.retrieve_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=runner_argument.retrieve_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), + k=runner_argument.retrieve_top_k, + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + raise ValueError( + f'Invalid retriever: {args.retriever}. 
Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) - costorm_runner = CoStormRunner(lm_config=lm_config, - runner_argument=runner_argument, - logging_wrapper=logging_wrapper, - rm=rm, - callback_handler=callback_handler) + costorm_runner = CoStormRunner( + lm_config=lm_config, + runner_argument=runner_argument, + logging_wrapper=logging_wrapper, + rm=rm, + callback_handler=callback_handler, + ) # warm start the system costorm_runner.warm_start() # Below is an example of how users may interact with Co-STORM to seek information together # In actual deployment, we suggest allowing the user to decide whether to observe the agent utterance or inject a turn - + # observing Co-STORM LLM agent utterance for 5 turns for _ in range(1): conv_turn = costorm_runner.step() print(f"**{conv_turn.role}**: {conv_turn.utterance}\n") - + # active engaging by injecting your utterance - your_utterance = input('Your utterance: ') + your_utterance = input("Your utterance: ") costorm_runner.step(user_utterance=your_utterance) # continue observing @@ -154,93 +213,105 @@ def main(args): json.dump(log_dump, f, indent=2) -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/co-storm', - help='Directory to store the outputs.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/co-storm", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) # hyperparameters for co-storm parser.add_argument( - '--retrieve_top_k', + "--retrieve_top_k", type=int, default=10, - help='Retrieve top k results for each query in retriever.' + help="Retrieve top k results for each query in retriever.", ) parser.add_argument( - '--max_search_queries', + "--max_search_queries", type=int, default=2, - help='Maximum number of search queries to consider for each question.' + help="Maximum number of search queries to consider for each question.", ) parser.add_argument( - '--total_conv_turn', + "--total_conv_turn", type=int, default=20, - help='Maximum number of turns in conversation.' + help="Maximum number of turns in conversation.", ) parser.add_argument( - '--max_search_thread', + "--max_search_thread", type=int, default=5, - help='Maximum number of parallel threads for retriever.' + help="Maximum number of parallel threads for retriever.", ) parser.add_argument( - '--max_search_queries_per_turn', + "--max_search_queries_per_turn", type=int, default=3, - help='Maximum number of search queries to consider in each turn.' + help="Maximum number of search queries to consider in each turn.", ) parser.add_argument( - '--warmstart_max_num_experts', + "--warmstart_max_num_experts", type=int, default=3, - help='Max number of experts in perspective-guided QA during warm start.' + help="Max number of experts in perspective-guided QA during warm start.", ) parser.add_argument( - '--warmstart_max_turn_per_experts', + "--warmstart_max_turn_per_experts", type=int, default=2, - help='Max number of turns per perspective during warm start.' 
+ help="Max number of turns per perspective during warm start.", ) parser.add_argument( - '--warmstart_max_thread', + "--warmstart_max_thread", type=int, default=3, - help='Max number of threads for parallel perspective-guided QA during warm start.' + help="Max number of threads for parallel perspective-guided QA during warm start.", ) parser.add_argument( - '--max_thread_num', + "--max_thread_num", type=int, default=10, - help=("Maximum number of threads to use. " - "Consider reducing it if you keep getting 'Exceed rate limit' errors when calling the LM API.") + help=( + "Maximum number of threads to use. " + "Consider reducing it if you keep getting 'Exceed rate limit' errors when calling the LM API." + ), ) parser.add_argument( - '--max_num_round_table_experts', + "--max_num_round_table_experts", type=int, default=2, - help='Max number of active experts in round table discussion.' + help="Max number of active experts in round table discussion.", ) parser.add_argument( - '--moderator_override_N_consecutive_answering_turn', + "--moderator_override_N_consecutive_answering_turn", type=int, default=3, - help=('Number of consecutive expert answering turns before the moderator overrides the conversation.') + help=( + "Number of consecutive expert answering turns before the moderator overrides the conversation." + ), ) parser.add_argument( - '--node_expansion_trigger_count', + "--node_expansion_trigger_count", type=int, default=10, - help='Trigger node expansion for nodes that contain more than N snippets.' + help="Trigger node expansion for nodes that contain more than N snippets.", ) # Boolean flags parser.add_argument( - '--enable_log_print', - action='store_true', - help='If set, enable console log print.' + "--enable_log_print", + action="store_true", + help="If set, enable console log print.", ) main(parser.parse_args()) diff --git a/examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py b/examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py index 926192bf..f3dfe51d 100644 --- a/examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py +++ b/examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py @@ -9,21 +9,28 @@ if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("--input-path", type=str, help="Path to arxiv_data_210930-054931.csv.") - parser.add_argument("--output-path", type=str, - help="Path to store the csv file that is compatible with VectorRM.") + parser.add_argument( + "--input-path", type=str, help="Path to arxiv_data_210930-054931.csv." + ) + parser.add_argument( + "--output-path", + type=str, + help="Path to store the csv file that is compatible with VectorRM.", + ) args = parser.parse_args() df = pd.read_csv(args.input_path) - print(f'The original dataset has {len(df)} samples.') + print(f"The original dataset has {len(df)} samples.") # Downsample the dataset. - df = df[df['terms'] == "['cs.CV']"] + df = df[df["terms"] == "['cs.CV']"] # Reformat the dataset to match the VectorRM input format. df.rename(columns={"abstracts": "content", "titles": "title"}, inplace=True) - df['url'] = ['uid_' + str(idx) for idx in range(len(df))] # Ensure the url is unique. - df['description'] = '' + df["url"] = [ + "uid_" + str(idx) for idx in range(len(df)) + ] # Ensure the url is unique. 
+ df["description"] = "" - print(f'The downsampled dataset has {len(df)} samples.') - df.to_csv(args.output_path, index=False) \ No newline at end of file + print(f"The downsampled dataset has {len(df)} samples.") + df.to_csv(args.output_path, index=False) diff --git a/examples/storm_examples/run_storm_wiki_claude.py b/examples/storm_examples/run_storm_wiki_claude.py index 1da12c0e..10f6e369 100644 --- a/examples/storm_examples/run_storm_wiki_claude.py +++ b/examples/storm_examples/run_storm_wiki_claude.py @@ -19,19 +19,31 @@ import os from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import ClaudeModel -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() claude_kwargs = { - 'api_key': os.getenv("ANTHROPIC_API_KEY"), - 'temperature': 1.0, - 'top_p': 0.9 + "api_key": os.getenv("ANTHROPIC_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, } # STORM is a LM system so different components can be powered by different models. @@ -39,11 +51,21 @@ def main(args): # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. - conv_simulator_lm = ClaudeModel(model='claude-3-haiku-20240307', max_tokens=500, **claude_kwargs) - question_asker_lm = ClaudeModel(model='claude-3-sonnet-20240229', max_tokens=500, **claude_kwargs) - outline_gen_lm = ClaudeModel(model='claude-3-opus-20240229', max_tokens=400, **claude_kwargs) - article_gen_lm = ClaudeModel(model='claude-3-opus-20240229', max_tokens=700, **claude_kwargs) - article_polish_lm = ClaudeModel(model='claude-3-opus-20240229', max_tokens=4000, **claude_kwargs) + conv_simulator_lm = ClaudeModel( + model="claude-3-haiku-20240307", max_tokens=500, **claude_kwargs + ) + question_asker_lm = ClaudeModel( + model="claude-3-sonnet-20240229", max_tokens=500, **claude_kwargs + ) + outline_gen_lm = ClaudeModel( + model="claude-3-opus-20240229", max_tokens=400, **claude_kwargs + ) + article_gen_lm = ClaudeModel( + model="claude-3-opus-20240229", max_tokens=700, **claude_kwargs + ) + article_polish_lm = ClaudeModel( + model="claude-3-opus-20240229", max_tokens=4000, **claude_kwargs + ) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -62,26 +84,45 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') - + raise ValueError( + f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) + runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -93,38 +134,81 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/claude', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/claude", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_deepseek.py b/examples/storm_examples/run_storm_wiki_deepseek.py index e4cf5fca..e831797f 100644 --- a/examples/storm_examples/run_storm_wiki_deepseek.py +++ b/examples/storm_examples/run_storm_wiki_deepseek.py @@ -22,9 +22,21 @@ import logging from argparse import ArgumentParser -from 
knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import DeepSeekModel -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) from knowledge_storm.utils import load_api_key @@ -34,10 +46,10 @@ def sanitize_topic(topic): Remove or replace characters that are not allowed in file names. """ # Replace spaces with underscores - topic = topic.replace(' ', '_') + topic = topic.replace(" ", "_") # Remove any character that isn't alphanumeric, underscore, or hyphen - topic = re.sub(r'[^a-zA-Z0-9_-]', '', topic) + topic = re.sub(r"[^a-zA-Z0-9_-]", "", topic) # Ensure the topic isn't empty after sanitization if not topic: @@ -47,29 +59,37 @@ def sanitize_topic(topic): def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() logger = logging.getLogger(__name__) # Ensure DEEPSEEK_API_KEY is set if not os.getenv("DEEPSEEK_API_KEY"): - raise ValueError("DEEPSEEK_API_KEY environment variable is not set. Please set it in your secrets.toml file.") + raise ValueError( + "DEEPSEEK_API_KEY environment variable is not set. Please set it in your secrets.toml file." + ) deepseek_kwargs = { - 'api_key': os.getenv("DEEPSEEK_API_KEY"), - 'api_base': os.getenv("DEEPSEEK_API_BASE", "https://api.deepseek.com"), - 'temperature': args.temperature, - 'top_p': args.top_p, + "api_key": os.getenv("DEEPSEEK_API_KEY"), + "api_base": os.getenv("DEEPSEEK_API_BASE", "https://api.deepseek.com"), + "temperature": args.temperature, + "top_p": args.top_p, } # DeepSeek offers two main models: 'deepseek-chat' for general tasks and 'deepseek-coder' for coding tasks # Users can choose the appropriate model based on their needs - conv_simulator_lm = DeepSeekModel(model=args.model, max_tokens=500, **deepseek_kwargs) - question_asker_lm = DeepSeekModel(model=args.model, max_tokens=500, **deepseek_kwargs) + conv_simulator_lm = DeepSeekModel( + model=args.model, max_tokens=500, **deepseek_kwargs + ) + question_asker_lm = DeepSeekModel( + model=args.model, max_tokens=500, **deepseek_kwargs + ) outline_gen_lm = DeepSeekModel(model=args.model, max_tokens=400, **deepseek_kwargs) article_gen_lm = DeepSeekModel(model=args.model, max_tokens=700, **deepseek_kwargs) - article_polish_lm = DeepSeekModel(model=args.model, max_tokens=4000, **deepseek_kwargs) + article_polish_lm = DeepSeekModel( + model=args.model, max_tokens=4000, **deepseek_kwargs + ) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -88,26 +108,45 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + raise ValueError( + f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input('Topic: ') + topic = input("Topic: ") sanitized_topic = sanitize_topic(topic) try: @@ -126,44 +165,94 @@ def main(args): raise -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/deepseek', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') - parser.add_argument('--model', type=str, choices=['deepseek-chat', 'deepseek-coder'], default='deepseek-chat', - help='DeepSeek model to use. 
"deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.') - parser.add_argument('--temperature', type=float, default=1.0, - help='Sampling temperature to use.') - parser.add_argument('--top_p', type=float, default=0.9, - help='Top-p sampling parameter.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/deepseek", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) + parser.add_argument( + "--model", + type=str, + choices=["deepseek-chat", "deepseek-coder"], + default="deepseek-chat", + help='DeepSeek model to use. "deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.', + ) + parser.add_argument( + "--temperature", type=float, default=1.0, help="Sampling temperature to use." + ) + parser.add_argument( + "--top_p", type=float, default=0.9, help="Top-p sampling parameter." + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k 
search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_gemini.py b/examples/storm_examples/run_storm_wiki_gemini.py index d5740629..947832fa 100644 --- a/examples/storm_examples/run_storm_wiki_gemini.py +++ b/examples/storm_examples/run_storm_wiki_gemini.py @@ -19,18 +19,31 @@ import os from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import GoogleModel -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) from knowledge_storm.utils import load_api_key + def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() gemini_kwargs = { - 'api_key': os.getenv("GOOGLE_API_KEY"), - 'temperature': 1.0, - 'top_p': 0.9 + "api_key": os.getenv("GOOGLE_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, } # STORM is a LM system so different components can be powered by different models. @@ -40,11 +53,21 @@ def main(args): # which is responsible for generating sections with citations. # To check out available Google models, see: # https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python#list_models - conv_simulator_lm = GoogleModel(model='models/gemini-1.5-flash', max_tokens=500, **gemini_kwargs) - question_asker_lm = GoogleModel(model='models/gemini-1.5-flash', max_tokens=500, **gemini_kwargs) - outline_gen_lm = GoogleModel(model='models/gemini-1.5-pro-exp-0801', max_tokens=400, **gemini_kwargs) - article_gen_lm = GoogleModel(model='models/gemini-1.5-pro-exp-0801', max_tokens=700, **gemini_kwargs) - article_polish_lm = GoogleModel(model='models/gemini-1.5-pro-exp-0801', max_tokens=4000, **gemini_kwargs) + conv_simulator_lm = GoogleModel( + model="models/gemini-1.5-flash", max_tokens=500, **gemini_kwargs + ) + question_asker_lm = GoogleModel( + model="models/gemini-1.5-flash", max_tokens=500, **gemini_kwargs + ) + outline_gen_lm = GoogleModel( + model="models/gemini-1.5-pro-exp-0801", max_tokens=400, **gemini_kwargs + ) + article_gen_lm = GoogleModel( + model="models/gemini-1.5-pro-exp-0801", max_tokens=700, **gemini_kwargs + ) + article_polish_lm = GoogleModel( + model="models/gemini-1.5-pro-exp-0801", max_tokens=4000, **gemini_kwargs + ) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -63,26 +86,45 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. 
# Currently, the information source is the Internet and we use search engine API as the retrieval module. match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + raise ValueError( + f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -94,38 +136,81 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/gemini', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/gemini", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_gpt.py b/examples/storm_examples/run_storm_wiki_gpt.py index b1740a12..e96c6809 100644 --- a/examples/storm_examples/run_storm_wiki_gpt.py +++ b/examples/storm_examples/run_storm_wiki_gpt.py @@ -22,40 +22,63 @@ import os from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from 
knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG, AzureAISearch +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, + AzureAISearch, +) from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() openai_kwargs = { - 'api_key': os.getenv("OPENAI_API_KEY"), - 'temperature': 1.0, - 'top_p': 0.9, + "api_key": os.getenv("OPENAI_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, } - ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel + ModelClass = ( + OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel + ) # If you are using Azure service, make sure the model name matches your own deployed model name. # The default name here is only used for demonstration and may not match your case. - gpt_35_model_name = 'gpt-3.5-turbo' if os.getenv('OPENAI_API_TYPE') == 'openai' else 'gpt-35-turbo' - gpt_4_model_name = 'gpt-4o' - if os.getenv('OPENAI_API_TYPE') == 'azure': - openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE') - openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION') + gpt_35_model_name = ( + "gpt-3.5-turbo" if os.getenv("OPENAI_API_TYPE") == "openai" else "gpt-35-turbo" + ) + gpt_4_model_name = "gpt-4o" + if os.getenv("OPENAI_API_TYPE") == "azure": + openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE") + openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION") # STORM is a LM system so different components can be powered by different models. # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. - conv_simulator_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) - question_asker_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) + conv_simulator_lm = ModelClass( + model=gpt_35_model_name, max_tokens=500, **openai_kwargs + ) + question_asker_lm = ModelClass( + model=gpt_35_model_name, max_tokens=500, **openai_kwargs + ) outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs) article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs) - article_polish_lm = ModelClass(model=gpt_4_model_name, max_tokens=4000, **openai_kwargs) + article_polish_lm = ModelClass( + model=gpt_4_model_name, max_tokens=4000, **openai_kwargs + ) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -75,28 +98,50 @@ def main(args): # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) - case 'azure_ai_search': - rm = AzureAISearch(azure_ai_search_api_key=os.getenv('AZURE_AI_SEARCH_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) + case "azure_ai_search": + rm = AzureAISearch( + azure_ai_search_api_key=os.getenv("AZURE_AI_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", "searxng", or "azure_ai_search"') + raise ValueError( + f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", "searxng", or "azure_ai_search"' + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -108,38 +153,90 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/gpt', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng', 'azure_ai_search'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/gpt", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. 
The information seeking part and the article generation" + "part can speed up by using multiple threads. Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=[ + "bing", + "you", + "brave", + "serper", + "duckduckgo", + "tavily", + "searxng", + "azure_ai_search", + ], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py b/examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py index 8fa4d0d6..77c9dbf7 100644 --- a/examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py +++ 
b/examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py @@ -29,7 +29,11 @@ import os from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.rm import VectorRM from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel from knowledge_storm.utils import load_api_key, QdrantVectorStoreManager @@ -37,35 +41,45 @@ def main(args): # Load API key from the specified toml file path - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") # Initialize the language model configurations engine_lm_configs = STORMWikiLMConfigs() openai_kwargs = { - 'api_key': os.getenv("OPENAI_API_KEY"), - 'temperature': 1.0, - 'top_p': 0.9, + "api_key": os.getenv("OPENAI_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, } - ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel + ModelClass = ( + OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel + ) # If you are using Azure service, make sure the model name matches your own deployed model name. # The default name here is only used for demonstration and may not match your case. - gpt_35_model_name = 'gpt-3.5-turbo' if os.getenv('OPENAI_API_TYPE') == 'openai' else 'gpt-35-turbo' - gpt_4_model_name = 'gpt-4o' - if os.getenv('OPENAI_API_TYPE') == 'azure': - openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE') - openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION') + gpt_35_model_name = ( + "gpt-3.5-turbo" if os.getenv("OPENAI_API_TYPE") == "openai" else "gpt-35-turbo" + ) + gpt_4_model_name = "gpt-4o" + if os.getenv("OPENAI_API_TYPE") == "azure": + openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE") + openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION") # STORM is a LM system so different components can be powered by different models. - # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm + # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. 
- conv_simulator_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) - question_asker_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) + conv_simulator_lm = ModelClass( + model=gpt_35_model_name, max_tokens=500, **openai_kwargs + ) + question_asker_lm = ModelClass( + model=gpt_35_model_name, max_tokens=500, **openai_kwargs + ) outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs) article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs) - article_polish_lm = ModelClass(model=gpt_4_model_name, max_tokens=4000, **openai_kwargs) + article_polish_lm = ModelClass( + model=gpt_4_model_name, max_tokens=4000, **openai_kwargs + ) engine_lm_configs.set_conv_simulator_lm(conv_simulator_lm) engine_lm_configs.set_question_asker_lm(question_asker_lm) @@ -85,43 +99,49 @@ def main(args): # Create / update the vector store with the documents in the csv file if args.csv_file_path: kwargs = { - 'file_path': args.csv_file_path, - 'content_column': 'content', - 'title_column': 'title', - 'url_column': 'url', - 'desc_column': 'description', - 'batch_size': args.embed_batch_size, - 'vector_db_mode': args.vector_db_mode, - 'collection_name': args.collection_name, - 'embedding_model': args.embedding_model, - 'device': args.device, + "file_path": args.csv_file_path, + "content_column": "content", + "title_column": "title", + "url_column": "url", + "desc_column": "description", + "batch_size": args.embed_batch_size, + "vector_db_mode": args.vector_db_mode, + "collection_name": args.collection_name, + "embedding_model": args.embedding_model, + "device": args.device, } - if args.vector_db_mode == 'offline': + if args.vector_db_mode == "offline": QdrantVectorStoreManager.create_or_update_vector_store( - vector_store_path=args.offline_vector_db_dir, - **kwargs + vector_store_path=args.offline_vector_db_dir, **kwargs ) - elif args.vector_db_mode == 'online': + elif args.vector_db_mode == "online": QdrantVectorStoreManager.create_or_update_vector_store( url=args.online_vector_db_url, - api_key=os.getenv('QDRANT_API_KEY'), + api_key=os.getenv("QDRANT_API_KEY"), **kwargs ) # Setup VectorRM to retrieve information from your own data - rm = VectorRM(collection_name=args.collection_name, embedding_model=args.embedding_model, device=args.device, k=engine_args.search_top_k) + rm = VectorRM( + collection_name=args.collection_name, + embedding_model=args.embedding_model, + device=args.device, + k=engine_args.search_top_k, + ) # initialize the vector store, either online (store the db on Qdrant server) or offline (store the db locally): - if args.vector_db_mode == 'offline': + if args.vector_db_mode == "offline": rm.init_offline_vector_db(vector_store_path=args.offline_vector_db_dir) - elif args.vector_db_mode == 'online': - rm.init_online_vector_db(url=args.online_vector_db_url, api_key=os.getenv('QDRANT_API_KEY')) + elif args.vector_db_mode == "online": + rm.init_online_vector_db( + url=args.online_vector_db_url, api_key=os.getenv("QDRANT_API_KEY") + ) # Initialize the STORM Wiki Runner runner = STORMWikiRunner(engine_args, engine_lm_configs, rm) # run the pipeline - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -136,50 +156,120 @@ def main(args): if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/gpt_retrieval', - help='Directory to store the outputs.') - 
parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/gpt_retrieval", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) # provide local corpus and set up vector db - parser.add_argument('--collection-name', type=str, default="my_documents", - help='The collection name for vector store.') - parser.add_argument('--embedding_model', type=str, default="BAAI/bge-m3", - help='The collection name for vector store.') - parser.add_argument('--device', type=str, default="mps", - help='The device used to run the retrieval model (mps, cuda, cpu, etc).') - parser.add_argument('--vector-db-mode', type=str, choices=['offline', 'online'], - help='The mode of the Qdrant vector store (offline or online).') - parser.add_argument('--offline-vector-db-dir', type=str, default='./vector_store', - help='If use offline mode, please provide the directory to store the vector store.') - parser.add_argument('--online-vector-db-url', type=str, - help='If use online mode, please provide the url of the Qdrant server.') - parser.add_argument('--csv-file-path', type=str, default=None, - help='The path of the custom document corpus in CSV format. The CSV file should include ' - 'content, title, url, and description columns.') - parser.add_argument('--embed-batch-size', type=int, default=64, - help='Batch size for embedding the documents in the csv file.') + parser.add_argument( + "--collection-name", + type=str, + default="my_documents", + help="The collection name for vector store.", + ) + parser.add_argument( + "--embedding_model", + type=str, + default="BAAI/bge-m3", + help="The collection name for vector store.", + ) + parser.add_argument( + "--device", + type=str, + default="mps", + help="The device used to run the retrieval model (mps, cuda, cpu, etc).", + ) + parser.add_argument( + "--vector-db-mode", + type=str, + choices=["offline", "online"], + help="The mode of the Qdrant vector store (offline or online).", + ) + parser.add_argument( + "--offline-vector-db-dir", + type=str, + default="./vector_store", + help="If use offline mode, please provide the directory to store the vector store.", + ) + parser.add_argument( + "--online-vector-db-url", + type=str, + help="If use online mode, please provide the url of the Qdrant server.", + ) + parser.add_argument( + "--csv-file-path", + type=str, + default=None, + help="The path of the custom document corpus in CSV format. 
The CSV file should include " + "content, title, url, and description columns.", + ) + parser.add_argument( + "--embed-batch-size", + type=int, + default=64, + help="Batch size for embedding the documents in the csv file.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') - main(parser.parse_args()) \ No newline at end of file + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) + main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_groq.py b/examples/storm_examples/run_storm_wiki_groq.py index 0dcaadbb..de5f6ae2 100644 --- a/examples/storm_examples/run_storm_wiki_groq.py +++ b/examples/storm_examples/run_storm_wiki_groq.py @@ -21,12 +21,24 @@ import re from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from 
knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) # Now import lm directly import lm from lm import GroqModel -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) from knowledge_storm.utils import load_api_key @@ -36,10 +48,10 @@ def sanitize_topic(topic): Remove or replace characters that are not allowed in file names. """ # Replace spaces with underscores - topic = topic.replace(' ', '_') + topic = topic.replace(" ", "_") # Remove any character that isn't alphanumeric, underscore, or hyphen - topic = re.sub(r'[^a-zA-Z0-9_-]', '', topic) + topic = re.sub(r"[^a-zA-Z0-9_-]", "", topic) # Ensure the topic isn't empty after sanitization if not topic: @@ -49,26 +61,34 @@ def sanitize_topic(topic): def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() # Ensure GROQ_API_KEY is set if not os.getenv("GROQ_API_KEY"): - raise ValueError("GROQ_API_KEY environment variable is not set. Please set it in your secrets.toml file.") + raise ValueError( + "GROQ_API_KEY environment variable is not set. Please set it in your secrets.toml file." + ) groq_kwargs = { - 'api_key': os.getenv("GROQ_API_KEY"), - 'api_base': "https://api.groq.com/openai/v1", - 'temperature': args.temperature, - 'top_p': args.top_p, + "api_key": os.getenv("GROQ_API_KEY"), + "api_base": "https://api.groq.com/openai/v1", + "temperature": args.temperature, + "top_p": args.top_p, } # Groq currently offers the "llama3-70b-8192" model with generous free API credits and the llama3.1 family of models as a preview for paying customers - conv_simulator_lm = GroqModel(model="llama3-70b-8192", max_tokens=500, **groq_kwargs) - question_asker_lm = GroqModel(model="llama3-70b-8192", max_tokens=500, **groq_kwargs) + conv_simulator_lm = GroqModel( + model="llama3-70b-8192", max_tokens=500, **groq_kwargs + ) + question_asker_lm = GroqModel( + model="llama3-70b-8192", max_tokens=500, **groq_kwargs + ) outline_gen_lm = GroqModel(model="llama3-70b-8192", max_tokens=400, **groq_kwargs) article_gen_lm = GroqModel(model="llama3-70b-8192", max_tokens=700, **groq_kwargs) - article_polish_lm = GroqModel(model="llama3-70b-8192", max_tokens=4000, **groq_kwargs) + article_polish_lm = GroqModel( + model="llama3-70b-8192", max_tokens=4000, **groq_kwargs + ) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -87,26 +107,45 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + raise ValueError( + f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input('Topic: ') + topic = input("Topic: ") sanitized_topic = sanitize_topic(topic) try: @@ -125,42 +164,87 @@ def main(args): raise -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/groq', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') - parser.add_argument('--temperature', type=float, default=1.0, - help='Sampling temperature to use.') - parser.add_argument('--top_p', type=float, default=0.9, - help='Top-p sampling parameter.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/groq", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) + parser.add_argument( + "--temperature", type=float, default=1.0, help="Sampling temperature to use." + ) + parser.add_argument( + "--top_p", type=float, default=0.9, help="Top-p sampling parameter." + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_mistral.py b/examples/storm_examples/run_storm_wiki_mistral.py index ceabd470..e9475413 
100644 --- a/examples/storm_examples/run_storm_wiki_mistral.py +++ b/examples/storm_examples/run_storm_wiki_mistral.py @@ -15,26 +15,41 @@ storm_gen_article.txt # Final article generated storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) """ + import os from argparse import ArgumentParser from dspy import Example -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import VLLMClient -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() mistral_kwargs = { "model": "mistralai/Mistral-7B-Instruct-v0.2", "port": args.port, "url": args.url, - "stop": ('\n\n---',) # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. + "stop": ( + "\n\n---", + ), # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. } conv_simulator_lm = VLLMClient(max_tokens=500, **mistral_kwargs) @@ -60,22 +75,41 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. 
Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + raise ValueError( + f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) @@ -87,26 +121,28 @@ def main(args): find_related_topic_example = Example( topic="Knowledge Curation", related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n" - "https://en.wikipedia.org/wiki/Information_science\n" - "https://en.wikipedia.org/wiki/Library_science\n" + "https://en.wikipedia.org/wiki/Information_science\n" + "https://en.wikipedia.org/wiki/Library_science\n", ) gen_persona_example = Example( topic="Knowledge Curation", examples="Title: Knowledge management\n" - "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" - "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" - " Knowledge protection methods\n Formal methods\n Informal methods\n" - " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", + "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" + "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" + " Knowledge protection methods\n Formal methods\n Informal methods\n" + " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", personas="1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge curation. They will provide context on how knowledge curation has changed over time and its impact on modern practices.\n" - "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" - "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" - "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" - "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm." + "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" + "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" + "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" + "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm.", ) runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [ - find_related_topic_example] + find_related_topic_example + ] runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [ - gen_persona_example] + gen_persona_example + ] # A trade-off of adding one-shot example is that it will increase the input length of the prompt. 
Also, some # examples may be very long (e.g., an example for writing a section based on the given information), which may @@ -118,24 +154,28 @@ def main(args): topic="Example Topic", conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...", old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n" - "# Section 2\n## Subsection 1\n## Subsection 2\n" - "# Section 3", + "# Section 2\n## Subsection 1\n## Subsection 2\n" + "# Section 3", outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n" - "# New Section 2\n" - "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3" + "# New Section 2\n" + "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3", ) - runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [write_page_outline_example] + runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [ + write_page_outline_example + ] write_section_example = Example( info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3", topic="Example Topic", section="Example Section", output="# Example Topic\n## Subsection 1\n" - "This is an example sentence [1]. This is another example sentence [2][3].\n" - "## Subsection 2\nThis is one more example sentence [1]." + "This is an example sentence [1]. This is another example sentence [2][3].\n" + "## Subsection 2\nThis is one more example sentence [1].", ) - runner.storm_article_generation.section_gen.write_section.demos = [write_section_example] + runner.storm_article_generation.section_gen.write_section.demos = [ + write_section_example + ] - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -147,42 +187,87 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--url', type=str, default='http://localhost', - help='URL of the VLLM server.') - parser.add_argument('--port', type=int, default=8000, - help='Port of the VLLM server.') - parser.add_argument('--output-dir', type=str, default='./results/mistral_7b', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--url", type=str, default="http://localhost", help="URL of the VLLM server." + ) + parser.add_argument( + "--port", type=int, default=8000, help="Port of the VLLM server." + ) + parser.add_argument( + "--output-dir", + type=str, + default="./results/mistral_7b", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_ollama.py b/examples/storm_examples/run_storm_wiki_ollama.py index 465e065a..8d2e2b53 100644 --- a/examples/storm_examples/run_storm_wiki_ollama.py +++ b/examples/storm_examples/run_storm_wiki_ollama.py @@ -15,6 +15,7 @@ storm_gen_article.txt # Final article generated 
storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) """ + import os import sys from argparse import ArgumentParser @@ -22,20 +23,34 @@ from dspy import Example from knowledge_storm.lm import OllamaClient -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() ollama_kwargs = { "model": args.model, "port": args.port, "url": args.url, - "stop": ('\n\n---',) # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. + "stop": ( + "\n\n---", + ), # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. } conv_simulator_lm = OllamaClient(max_tokens=500, **ollama_kwargs) @@ -61,22 +76,41 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + raise ValueError( + f'Invalid retriever: {args.retriever}. 
Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) @@ -88,26 +122,28 @@ def main(args): find_related_topic_example = Example( topic="Knowledge Curation", related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n" - "https://en.wikipedia.org/wiki/Information_science\n" - "https://en.wikipedia.org/wiki/Library_science\n" + "https://en.wikipedia.org/wiki/Information_science\n" + "https://en.wikipedia.org/wiki/Library_science\n", ) gen_persona_example = Example( topic="Knowledge Curation", examples="Title: Knowledge management\n" - "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" - "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" - " Knowledge protection methods\n Formal methods\n Informal methods\n" - " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", + "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" + "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" + " Knowledge protection methods\n Formal methods\n Informal methods\n" + " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", personas="1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge curation. They will provide context on how knowledge curation has changed over time and its impact on modern practices.\n" - "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" - "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" - "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" - "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm." + "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" + "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" + "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" + "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm.", ) runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [ - find_related_topic_example] + find_related_topic_example + ] runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [ - gen_persona_example] + gen_persona_example + ] # A trade-off of adding one-shot example is that it will increase the input length of the prompt. 
Also, some # examples may be very long (e.g., an example for writing a section based on the given information), which may @@ -119,24 +155,28 @@ def main(args): topic="Example Topic", conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...", old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n" - "# Section 2\n## Subsection 1\n## Subsection 2\n" - "# Section 3", + "# Section 2\n## Subsection 1\n## Subsection 2\n" + "# Section 3", outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n" - "# New Section 2\n" - "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3" + "# New Section 2\n" + "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3", ) - runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [write_page_outline_example] + runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [ + write_page_outline_example + ] write_section_example = Example( info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3", topic="Example Topic", section="Example Section", output="# Example Topic\n## Subsection 1\n" - "This is an example sentence [1]. This is another example sentence [2][3].\n" - "## Subsection 2\nThis is one more example sentence [1]." + "This is an example sentence [1]. This is another example sentence [2][3].\n" + "## Subsection 2\nThis is one more example sentence [1].", ) - runner.storm_article_generation.section_gen.write_section.demos = [write_section_example] + runner.storm_article_generation.section_gen.write_section.demos = [ + write_section_example + ] - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -148,44 +188,90 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--url', type=str, default='http://localhost', - help='URL of the Ollama server.') - parser.add_argument('--port', type=int, default=11434, - help='Port of the Ollama server.') - parser.add_argument('--model', type=str, default='llama3:latest', - help='Model of the Ollama server.') - parser.add_argument('--output-dir', type=str, default='./results/ollama', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--url", type=str, default="http://localhost", help="URL of the Ollama server." + ) + parser.add_argument( + "--port", type=int, default=11434, help="Port of the Ollama server." + ) + parser.add_argument( + "--model", type=str, default="llama3:latest", help="Model of the Ollama server." + ) + parser.add_argument( + "--output-dir", + type=str, + default="./results/ollama", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_ollama_with_searxng.py b/examples/storm_examples/run_storm_wiki_ollama_with_searxng.py index a4be3624..725a68b9 100644 --- a/examples/storm_examples/run_storm_wiki_ollama_with_searxng.py +++ b/examples/storm_examples/run_storm_wiki_ollama_with_searxng.py @@ -3,21 +3,25 @@ from dspy import Example 
-from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
+from knowledge_storm import (
+    STORMWikiRunnerArguments,
+    STORMWikiRunner,
+    STORMWikiLMConfigs,
+)
 from knowledge_storm.lm import OllamaClient
 from knowledge_storm.rm import SearXNG
 from knowledge_storm.utils import load_api_key
 def main(args):
-    load_api_key(toml_file_path='secrets.toml')
+    load_api_key(toml_file_path="secrets.toml")
     lm_configs = STORMWikiLMConfigs()
     ollama_kwargs = {
         "model": args.model,
         "port": args.port,
         "url": args.url,
-        "stop": ('\n\n---',)
+        "stop": ("\n\n---",),
     }
     conv_simulator_lm = OllamaClient(max_tokens=500, **ollama_kwargs)
@@ -40,23 +44,27 @@ def main(args):
         max_thread_num=args.max_thread_num,
     )
-    rm = SearXNG(searxng_api_url=args.searxng_api_url, searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k)
+    rm = SearXNG(
+        searxng_api_url=args.searxng_api_url,
+        searxng_api_key=os.getenv("SEARXNG_API_KEY"),
+        k=engine_args.search_top_k,
+    )
     runner = STORMWikiRunner(engine_args, lm_configs, rm)
     find_related_topic_example = Example(
         topic="Knowledge Curation",
         related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n"
-                       "https://en.wikipedia.org/wiki/Information_science\n"
-                       "https://en.wikipedia.org/wiki/Library_science\n"
+        "https://en.wikipedia.org/wiki/Information_science\n"
+        "https://en.wikipedia.org/wiki/Library_science\n",
     )
     gen_persona_example = Example(
         topic="Knowledge Curation",
         examples="Title: Knowledge management\n"
-                 "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies"
-                 "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n"
-                 " Knowledge protection methods\n Formal methods\n Informal methods\n"
-                 " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks",
+        "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies"
+        "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n"
+        " Knowledge protection methods\n Formal methods\n Informal methods\n"
+        " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks",
         personas=(
             "1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge "
             "curation. They will provide context on how knowledge curation has changed over time and its impact on "
@@ -69,35 +77,41 @@ def main(args):
             "such as common features of content management systems.\n"
             "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and "
             "the transition of these practices into the digital realm."
-        )
+        ),
     )
     runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [
-        find_related_topic_example]
+        find_related_topic_example
+    ]
     runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [
-        gen_persona_example]
+        gen_persona_example
+    ]
     write_page_outline_example = Example(
         topic="Example Topic",
         conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...",
         old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n"
-                    "# Section 2\n## Subsection 1\n## Subsection 2\n"
-                    "# Section 3",
-        outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n"
-                "# New Section 2\n"
-                "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3"
+        "# Section 2\n## Subsection 1\n## Subsection 2\n"
+        "# Section 3",
+        outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n"
+        "# New Section 2\n"
+        "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3",
     )
-    runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [write_page_outline_example]
+    runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [
+        write_page_outline_example
+    ]
     write_section_example = Example(
         info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3",
         topic="Example Topic",
         section="Example Section",
         output="# Example Topic\n## Subsection 1\n"
-               "This is an example sentence [1]. This is another example sentence [2][3].\n"
-               "## Subsection 2\nThis is one more example sentence [1]."
+        "This is an example sentence [1]. This is another example sentence [2][3].\n"
+        "## Subsection 2\nThis is one more example sentence [1].",
     )
-    runner.storm_article_generation.section_gen.write_section.demos = [write_section_example]
+    runner.storm_article_generation.section_gen.write_section.demos = [
+        write_section_example
+    ]
-    topic = input('Topic: ')
+    topic = input("Topic: ")
     runner.run(
         topic=topic,
         do_research=args.do_research,
@@ -109,46 +123,93 @@ def main(args):
     runner.summary()
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = ArgumentParser()
     # global arguments
-    parser.add_argument('--url', type=str, default='http://localhost',
-                        help='URL of the Ollama server.')
-    parser.add_argument('--port', type=int, default=11434,
-                        help='Port of the Ollama server.')
-    parser.add_argument('--model', type=str, default='llama3:latest',
-                        help='Model of the Ollama server.')
-    parser.add_argument('--output-dir', type=str, default='./results/ollama',
-                        help='Directory to store the outputs.')
-    parser.add_argument('--max-thread-num', type=int, default=3,
-                        help='Maximum number of threads to use. The information seeking part and the article generation'
-                             'part can speed up by using multiple threads. Consider reducing it if keep getting '
-                             '"Exceed rate limit" error when calling LM API.')
-    parser.add_argument('--retriever', type=str, choices=['searxng'],
-                        help='The search engine API to use for retrieving information.')
-    parser.add_argument('--searxng-api-url', type=str, required=True,
-                        help='URL of the SearXNG API.')
+    parser.add_argument(
+        "--url", type=str, default="http://localhost", help="URL of the Ollama server."
+    )
+    parser.add_argument(
+        "--port", type=int, default=11434, help="Port of the Ollama server."
+    )
+    parser.add_argument(
+        "--model", type=str, default="llama3:latest", help="Model of the Ollama server."
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="./results/ollama",
+        help="Directory to store the outputs.",
+    )
+    parser.add_argument(
+        "--max-thread-num",
+        type=int,
+        default=3,
+        help="Maximum number of threads to use. The information seeking part and the article generation"
+        "part can speed up by using multiple threads. Consider reducing it if keep getting "
+        '"Exceed rate limit" error when calling LM API.',
+    )
+    parser.add_argument(
+        "--retriever",
+        type=str,
+        choices=["searxng"],
+        help="The search engine API to use for retrieving information.",
+    )
+    parser.add_argument(
+        "--searxng-api-url", type=str, required=True, help="URL of the SearXNG API."
+    )
     # stage of the pipeline
-    parser.add_argument('--do-research', action='store_true',
-                        help='If True, simulate conversation to research the topic; otherwise, load the results.')
-    parser.add_argument('--do-generate-outline', action='store_true',
-                        help='If True, generate an outline for the topic; otherwise, load the results.')
-    parser.add_argument('--do-generate-article', action='store_true',
-                        help='If True, generate an article for the topic; otherwise, load the results.')
-    parser.add_argument('--do-polish-article', action='store_true',
-                        help='If True, polish the article by adding a summarization section and (optionally) removing '
-                             'duplicate content.')
+    parser.add_argument(
+        "--do-research",
+        action="store_true",
+        help="If True, simulate conversation to research the topic; otherwise, load the results.",
+    )
+    parser.add_argument(
+        "--do-generate-outline",
+        action="store_true",
+        help="If True, generate an outline for the topic; otherwise, load the results.",
+    )
+    parser.add_argument(
+        "--do-generate-article",
+        action="store_true",
+        help="If True, generate an article for the topic; otherwise, load the results.",
+    )
+    parser.add_argument(
+        "--do-polish-article",
+        action="store_true",
+        help="If True, polish the article by adding a summarization section and (optionally) removing "
+        "duplicate content.",
+    )
     # hyperparameters for the pre-writing stage
-    parser.add_argument('--max-conv-turn', type=int, default=3,
-                        help='Maximum number of questions in conversational question asking.')
-    parser.add_argument('--max-perspective', type=int, default=3,
-                        help='Maximum number of perspectives to consider in perspective-guided question asking.')
-    parser.add_argument('--search-top-k', type=int, default=3,
-                        help='Top k search results to consider for each search query.')
+    parser.add_argument(
+        "--max-conv-turn",
+        type=int,
+        default=3,
+        help="Maximum number of questions in conversational question asking.",
+    )
+    parser.add_argument(
+        "--max-perspective",
+        type=int,
+        default=3,
+        help="Maximum number of perspectives to consider in perspective-guided question asking.",
+    )
+    parser.add_argument(
+        "--search-top-k",
+        type=int,
+        default=3,
+        help="Top k search results to consider for each search query.",
+    )
     # hyperparameters for the writing stage
-    parser.add_argument('--retrieve-top-k', type=int, default=3,
-                        help='Top k collected references for each section title.')
-    parser.add_argument('--remove-duplicate', action='store_true',
-                        help='If True, remove duplicate content from the article.')
+    parser.add_argument(
+        "--retrieve-top-k",
+        type=int,
+        default=3,
+        help="Top k collected references for each section title.",
+    )
+    parser.add_argument(
+        "--remove-duplicate",
+        action="store_true",
+        help="If True, remove duplicate content from the article.",
+    )
     main(parser.parse_args())
diff --git a/frontend/demo_light/demo_util.py b/frontend/demo_light/demo_util.py
index 22258dda..cff1d9f2 100644
--- a/frontend/demo_light/demo_util.py
+++ b/frontend/demo_light/demo_util.py
@@ -13,7 +13,11 @@
 # Uncomment the following lines:
 # import sys
 # sys.path.append('../../')
-from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
+from knowledge_storm import (
+    STORMWikiRunnerArguments,
+    STORMWikiRunner,
+    STORMWikiLMConfigs,
+)
 from knowledge_storm.lm import OpenAIModel
 from knowledge_storm.rm import YouRM
 from knowledge_storm.storm_wiki.modules.callback import BaseCallbackHandler
@@ -21,7 +25,7 @@
 from stoc import stoc
-class DemoFileIOHelper():
+class DemoFileIOHelper:
     @staticmethod
     def read_structure_to_dict(articles_root_path):
         """
@@ -107,8 +111,10 @@ def set_file_modification_time(file_path, modification_time_string):
             file_path (str): The path to the file.
             modification_time_string (str): The desired modification time in 'YYYY-MM-DD HH:MM:SS' format.
         """
-        california_tz = pytz.timezone('America/Los_Angeles')
-        modification_time = datetime.datetime.strptime(modification_time_string, '%Y-%m-%d %H:%M:%S')
+        california_tz = pytz.timezone("America/Los_Angeles")
+        modification_time = datetime.datetime.strptime(
+            modification_time_string, "%Y-%m-%d %H:%M:%S"
+        )
         modification_time = california_tz.localize(modification_time)
         modification_time_utc = modification_time.astimezone(datetime.timezone.utc)
         modification_timestamp = modification_time_utc.timestamp()
@@ -125,7 +131,7 @@ def get_latest_modification_time(path):
         Returns:
             str: The latest file's modification time in 'YYYY-MM-DD HH:MM:SS' format.
         """
-        california_tz = pytz.timezone('America/Los_Angeles')
+        california_tz = pytz.timezone("America/Los_Angeles")
         latest_mod_time = None
         file_paths = []
@@ -138,17 +144,26 @@ def get_latest_modification_time(path):
         for file_path in file_paths:
             modification_timestamp = os.path.getmtime(file_path)
-            modification_time_utc = datetime.datetime.utcfromtimestamp(modification_timestamp)
-            modification_time_utc = modification_time_utc.replace(tzinfo=datetime.timezone.utc)
-            modification_time_california = modification_time_utc.astimezone(california_tz)
-
-            if latest_mod_time is None or modification_time_california > latest_mod_time:
+            modification_time_utc = datetime.datetime.utcfromtimestamp(
+                modification_timestamp
+            )
+            modification_time_utc = modification_time_utc.replace(
+                tzinfo=datetime.timezone.utc
+            )
+            modification_time_california = modification_time_utc.astimezone(
+                california_tz
+            )
+
+            if (
+                latest_mod_time is None
+                or modification_time_california > latest_mod_time
+            ):
                 latest_mod_time = modification_time_california
         if latest_mod_time is not None:
-            return latest_mod_time.strftime('%Y-%m-%d %H:%M:%S')
+            return latest_mod_time.strftime("%Y-%m-%d %H:%M:%S")
         else:
-            return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+            return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     @staticmethod
     def assemble_article_data(article_file_path_dict):
@@ -171,25 +186,44 @@ def assemble_article_data(article_file_path_dict):
             if neither the raw nor polished article text exists in the provided file paths.
""" - if "storm_gen_article.txt" in article_file_path_dict or "storm_gen_article_polished.txt" in article_file_path_dict: - full_article_name = "storm_gen_article_polished.txt" if "storm_gen_article_polished.txt" in article_file_path_dict else "storm_gen_article.txt" - article_data = {"article": DemoTextProcessingHelper.parse( - DemoFileIOHelper.read_txt_file(article_file_path_dict[full_article_name]))} + if ( + "storm_gen_article.txt" in article_file_path_dict + or "storm_gen_article_polished.txt" in article_file_path_dict + ): + full_article_name = ( + "storm_gen_article_polished.txt" + if "storm_gen_article_polished.txt" in article_file_path_dict + else "storm_gen_article.txt" + ) + article_data = { + "article": DemoTextProcessingHelper.parse( + DemoFileIOHelper.read_txt_file( + article_file_path_dict[full_article_name] + ) + ) + } if "url_to_info.json" in article_file_path_dict: article_data["citations"] = _construct_citation_dict_from_search_result( - DemoFileIOHelper.read_json_file(article_file_path_dict["url_to_info.json"])) + DemoFileIOHelper.read_json_file( + article_file_path_dict["url_to_info.json"] + ) + ) if "conversation_log.json" in article_file_path_dict: article_data["conversation_log"] = DemoFileIOHelper.read_json_file( - article_file_path_dict["conversation_log.json"]) + article_file_path_dict["conversation_log.json"] + ) return article_data return None -class DemoTextProcessingHelper(): - +class DemoTextProcessingHelper: @staticmethod def remove_citations(sent): - return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)).replace(" |", "").replace("]", "") + return ( + re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)) + .replace(" |", "") + .replace("]", "") + ) @staticmethod def parse_conversation_history(json_data): @@ -199,43 +233,54 @@ def parse_conversation_history(json_data): """ parsed_data = [] for persona_conversation_data in json_data: - if ': ' in persona_conversation_data["perspective"]: - name, description = persona_conversation_data["perspective"].split(": ", 1) - elif '- ' in persona_conversation_data["perspective"]: - name, description = persona_conversation_data["perspective"].split("- ", 1) + if ": " in persona_conversation_data["perspective"]: + name, description = persona_conversation_data["perspective"].split( + ": ", 1 + ) + elif "- " in persona_conversation_data["perspective"]: + name, description = persona_conversation_data["perspective"].split( + "- ", 1 + ) else: name, description = "", persona_conversation_data["perspective"] cur_conversation = [] for dialogue_turn in persona_conversation_data["dlg_turns"]: - cur_conversation.append({"role": "user", "content": dialogue_turn["user_utterance"]}) cur_conversation.append( - {"role": "assistant", - "content": DemoTextProcessingHelper.remove_citations(dialogue_turn["agent_utterance"])}) + {"role": "user", "content": dialogue_turn["user_utterance"]} + ) + cur_conversation.append( + { + "role": "assistant", + "content": DemoTextProcessingHelper.remove_citations( + dialogue_turn["agent_utterance"] + ), + } + ) parsed_data.append((name, description, cur_conversation)) return parsed_data @staticmethod def parse(text): regex = re.compile(r']:\s+"(.*?)"\s+http') - text = regex.sub(']: http', text) + text = regex.sub("]: http", text) return text @staticmethod def add_markdown_indentation(input_string): - lines = input_string.split('\n') + lines = input_string.split("\n") processed_lines = [""] for line in lines: num_hashes = 0 for char in line: - if char == '#': + if char == "#": num_hashes += 1 else: break 
             num_hashes -= 1
             num_spaces = 4 * num_hashes
-            new_line = ' ' * num_spaces + line
+            new_line = " " * num_spaces + line
             processed_lines.append(new_line)
-        return '\n'.join(processed_lines)
+        return "\n".join(processed_lines)
     @staticmethod
     def get_current_time_string():
@@ -245,13 +290,15 @@ def get_current_time_string():
         Returns:
             str: The current California time in 'YYYY-MM-DD HH:MM:SS' format.
         """
-        california_tz = pytz.timezone('America/Los_Angeles')
+        california_tz = pytz.timezone("America/Los_Angeles")
         utc_now = datetime.datetime.now(datetime.timezone.utc)
         california_now = utc_now.astimezone(california_tz)
-        return california_now.strftime('%Y-%m-%d %H:%M:%S')
+        return california_now.strftime("%Y-%m-%d %H:%M:%S")
     @staticmethod
-    def compare_time_strings(time_string1, time_string2, time_format='%Y-%m-%d %H:%M:%S'):
+    def compare_time_strings(
+        time_string1, time_string2, time_format="%Y-%m-%d %H:%M:%S"
+    ):
         """
         Compares two time strings to determine if they represent the same point in time.
@@ -273,13 +320,13 @@ def compare_time_strings(time_string1, time_string2, time_format='%Y-%m-%d %H:%M
     @staticmethod
     def add_inline_citation_link(article_text, citation_dict):
         # Regular expression to find citations like [i]
-        pattern = r'\[(\d+)\]'
+        pattern = r"\[(\d+)\]"
         # Function to replace each citation with its Markdown link
         def replace_with_link(match):
             i = match.group(1)
-            url = citation_dict.get(int(i), {}).get('url', '#')
-            return f'[[{i}]]({url})'
+            url = citation_dict.get(int(i), {}).get("url", "#")
+            return f"[[{i}]]({url})"
         # Replace all citations in the text with Markdown links
         return re.sub(pattern, replace_with_link, article_text)
@@ -292,26 +339,34 @@ def generate_html_toc(md_text):
             level = line.count("#")
             title = line.strip("# ").strip()
             anchor = title.lower().replace(" ", "-").replace(".", "")
-            toc.append(f"