diff --git a/README.md b/README.md
index c6263c27..1f4715e6 100644
--- a/README.md
+++ b/README.md
@@ -7,9 +7,10 @@

| Research preview | STORM Paper| Co-STORM Paper | Website |

- **Latest News** 🔥
+- [2025/01] We add [litellm](https://github.com/BerriAI/litellm) integration for language models and embedding models in `knowledge-storm` v1.1.0.
+
- [2024/09] Co-STORM codebase is now released and integrated into `knowledge-storm` python package v1.0.0. Run `pip install knowledge-storm --upgrade` to check it out.
- [2024/09] We introduce collaborative STORM (Co-STORM) to support human-AI collaborative knowledge curation! [Co-STORM Paper](https://www.arxiv.org/abs/2408.15232) has been accepted to EMNLP 2024 main conference.
@@ -92,10 +93,11 @@ You could also install the source code which allows you to modify the behavior o
Currently, our package support:
-- `OpenAIModel`, `AzureOpenAIModel`, `ClaudeModel`, `VLLMClient`, `TGIClient`, `TogetherClient`, `OllamaClient`, `GoogleModel`, `DeepSeekModel`, `GroqModel` as language model components
-- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch`, and `AzureAISearch` as retrieval module components
+- Language model components: All language models supported by litellm as listed [here](https://docs.litellm.ai/docs/providers)
+- Embedding model components: All embedding models supported by litellm as listed [here](https://docs.litellm.ai/docs/embedding/supported_embedding)
+- Retrieval module components: `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch`, and `AzureAISearch`

-:star2: **PRs for integrating more language models into [knowledge_storm/lm.py](knowledge_storm/lm.py) and search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!**
+:star2: **PRs for integrating more search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!**

Both STORM and Co-STORM are working in the information curation layer, you need to set up the information retrieval module and language model module to create their `Runner` classes respectively.
@@ -106,7 +108,7 @@ The STORM knowledge curation engine is defined as a simple Python `STORMWikiRunn
```python
import os
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
-from knowledge_storm.lm import OpenAIModel
+from knowledge_storm.lm import LitellmModel
from knowledge_storm.rm import YouRM

lm_configs = STORMWikiLMConfigs()
@@ -118,8 +120,8 @@ openai_kwargs = {
# STORM is a LM system so different components can be powered by different models to reach a good balance between cost and quality.
# For a good practice, choose a cheaper/faster model for `conv_simulator_lm` which is used to split queries, synthesize answers in the conversation.
# Choose a more powerful model for `article_gen_lm` to generate verifiable text with citations.
-gpt_35 = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs)
-gpt_4 = OpenAIModel(model='gpt-4o', max_tokens=3000, **openai_kwargs)
+gpt_35 = LitellmModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs)
+gpt_4 = LitellmModel(model='gpt-4o', max_tokens=3000, **openai_kwargs)
lm_configs.set_conv_simulator_lm(gpt_35)
lm_configs.set_question_asker_lm(gpt_35)
lm_configs.set_outline_gen_lm(gpt_4)
@@ -155,7 +157,7 @@ The Co-STORM knowledge curation engine is defined as a simple Python `CoStormRun
```python
from knowledge_storm.collaborative_storm.engine import CollaborativeStormLMConfigs, RunnerArgument, CoStormRunner
-from knowledge_storm.lm import OpenAIModel
+from knowledge_storm.lm import LitellmModel
from knowledge_storm.logging_wrapper import LoggingWrapper
from knowledge_storm.rm import BingSearch
@@ -168,12 +170,12 @@ openai_kwargs = {
"top_p": 0.9,
"api_base": None,
}
-question_answering_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs)
-discourse_manage_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs)
-utterance_polishing_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs)
-warmstart_outline_gen_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs)
-question_asking_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=300, **openai_kwargs)
-knowledge_base_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs)
+question_answering_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs)
+discourse_manage_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs)
+utterance_polishing_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs)
+warmstart_outline_gen_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs)
+question_asking_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=300, **openai_kwargs)
+knowledge_base_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs)

lm_config.set_question_answering_lm(question_answering_lm)
lm_config.set_discourse_manage_lm(discourse_manage_lm)
@@ -222,6 +224,7 @@ We provide scripts in our [examples folder](examples) as a quick start to run ST
We suggest using `secrets.toml` to set up the API keys. Create a file `secrets.toml` under the root directory and add the following content:

```shell
+# ============ language model configurations ============
# Set up OpenAI API key.
OPENAI_API_KEY="your_openai_api_key"
# If you are using the API service provided by OpenAI, include the following line:
@@ -230,15 +233,10 @@ OPENAI_API_TYPE="openai"
OPENAI_API_TYPE="azure"
AZURE_API_BASE="your_azure_api_base_url"
AZURE_API_VERSION="your_azure_api_version"
-# Set up You.com search API key.
-YDC_API_KEY="your_youcom_api_key"
-```
-
-for **Co-STORM**, please also add following
-```
-# if use openai encoder
-ENCODER_API_TYPE="openai"
-# or ENCODER_API_TYPE="azure" if use azure openai encoder
+# ============ retriever configurations ============
+BING_SEARCH_API_KEY="your_bing_search_api_key" # if using bing search
+# ============ encoder configurations ============
+ENCODER_API_TYPE="openai" # if using openai encoder
```

### STORM examples
@@ -249,7 +247,7 @@ Run the following command.
```bash
python examples/storm_examples/run_storm_wiki_gpt.py \
--output-dir $OUTPUT_DIR \
- --retriever you \
+ --retriever bing \
--do-research \
--do-generate-outline \
--do-generate-article \
@@ -328,20 +326,44 @@ We are very grateful to [Michelle Lam](https://michelle123lam.github.io/) for de
## Citation
Please cite our paper if you use this code or part of it in your work:
```bibtex
-@misc{jiang2024unknownunknowns,
- title={Into the Unknown Unknowns: Engaged Human Learning through Participation in Language Model Agent Conversations},
- author={Yucheng Jiang and Yijia Shao and Dekun Ma and Sina J. Semnani and Monica S. Lam},
- year={2024},
- eprint={2408.15232},
- archivePrefix={arXiv},
- primaryClass={cs.CL},
- url={https://arxiv.org/abs/2408.15232},
+@inproceedings{jiang-etal-2024-unknown,
+ title = "Into the Unknown Unknowns: Engaged Human Learning through Participation in Language Model Agent Conversations",
+ author = "Jiang, Yucheng and
+ Shao, Yijia and
+ Ma, Dekun and
+ Semnani, Sina and
+ Lam, Monica",
+ editor = "Al-Onaizan, Yaser and
+ Bansal, Mohit and
+ Chen, Yun-Nung",
+ booktitle = "Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing",
+ month = nov,
+ year = "2024",
+ address = "Miami, Florida, USA",
+ publisher = "Association for Computational Linguistics",
+ url = "https://aclanthology.org/2024.emnlp-main.554/",
+ doi = "10.18653/v1/2024.emnlp-main.554",
+ pages = "9917--9955",
}

-@inproceedings{shao2024assisting,
- title={{Assisting in Writing Wikipedia-like Articles From Scratch with Large Language Models}},
- author={Yijia Shao and Yucheng Jiang and Theodore A. Kanell and Peter Xu and Omar Khattab and Monica S. Lam},
- year={2024},
- booktitle={Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)}
+@inproceedings{shao-etal-2024-assisting,
+ title = "Assisting in Writing {W}ikipedia-like Articles From Scratch with Large Language Models",
+ author = "Shao, Yijia and
+ Jiang, Yucheng and
+ Kanell, Theodore and
+ Xu, Peter and
+ Khattab, Omar and
+ Lam, Monica",
+ editor = "Duh, Kevin and
+ Gomez, Helena and
+ Bethard, Steven",
+ booktitle = "Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers)",
+ month = jun,
+ year = "2024",
+ address = "Mexico City, Mexico",
+ publisher = "Association for Computational Linguistics",
+ url = "https://aclanthology.org/2024.naacl-long.347/",
+ doi = "10.18653/v1/2024.naacl-long.347",
+ pages = "6252--6278",
}
```
diff --git a/examples/costorm_examples/run_costorm_gpt.py b/examples/costorm_examples/run_costorm_gpt.py
index 9046bd5f..40138915 100644
--- a/examples/costorm_examples/run_costorm_gpt.py
+++ b/examples/costorm_examples/run_costorm_gpt.py
@@ -16,51 +16,83 @@
import os
import json
from argparse import ArgumentParser
-from knowledge_storm.collaborative_storm.engine import CollaborativeStormLMConfigs, RunnerArgument, CoStormRunner
-from knowledge_storm.collaborative_storm.modules.callback import LocalConsolePrintCallBackHandler
+from knowledge_storm.collaborative_storm.engine import (
+ CollaborativeStormLMConfigs,
+ RunnerArgument,
+ CoStormRunner,
+)
+from knowledge_storm.collaborative_storm.modules.callback import (
+ LocalConsolePrintCallBackHandler,
+)
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
from knowledge_storm.logging_wrapper import
LoggingWrapper -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_config: CollaborativeStormLMConfigs = CollaborativeStormLMConfigs() - openai_kwargs = { - "api_key": os.getenv("OPENAI_API_KEY"), - "api_provider": "openai", - "temperature": 1.0, - "top_p": 0.9, - "api_base": None, - } if os.getenv('OPENAI_API_TYPE') == 'openai' else { - "api_key": os.getenv("AZURE_API_KEY"), - "temperature": 1.0, - "top_p": 0.9, - "api_base": os.getenv("AZURE_API_BASE"), - "api_version": os.getenv("AZURE_API_VERSION"), - } - - ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel + openai_kwargs = ( + { + "api_key": os.getenv("OPENAI_API_KEY"), + "api_provider": "openai", + "temperature": 1.0, + "top_p": 0.9, + "api_base": None, + } + if os.getenv("OPENAI_API_TYPE") == "openai" + else { + "api_key": os.getenv("AZURE_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + } + ) + + ModelClass = ( + OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel + ) # If you are using Azure service, make sure the model name matches your own deployed model name. # The default name here is only used for demonstration and may not match your case. - gpt_4o_mini_model_name = 'gpt-4o-mini' - gpt_4o_model_name = 'gpt-4o' - if os.getenv('OPENAI_API_TYPE') == 'azure': - openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE') - openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION') + gpt_4o_mini_model_name = "gpt-4o-mini" + gpt_4o_model_name = "gpt-4o" + if os.getenv("OPENAI_API_TYPE") == "azure": + openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE") + openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION") # STORM is a LM system so different components can be powered by different models. # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. 
- question_answering_lm = ModelClass(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs) - discourse_manage_lm = ModelClass(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs) - utterance_polishing_lm = ModelClass(model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs) - warmstart_outline_gen_lm = ModelClass(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs) - question_asking_lm = ModelClass(model=gpt_4o_model_name, max_tokens=300, **openai_kwargs) - knowledge_base_lm = ModelClass(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs) + question_answering_lm = ModelClass( + model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs + ) + discourse_manage_lm = ModelClass( + model=gpt_4o_model_name, max_tokens=500, **openai_kwargs + ) + utterance_polishing_lm = ModelClass( + model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs + ) + warmstart_outline_gen_lm = ModelClass( + model=gpt_4o_model_name, max_tokens=500, **openai_kwargs + ) + question_asking_lm = ModelClass( + model=gpt_4o_model_name, max_tokens=300, **openai_kwargs + ) + knowledge_base_lm = ModelClass( + model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs + ) lm_config.set_question_answering_lm(question_answering_lm) lm_config.set_discourse_manage_lm(discourse_manage_lm) @@ -69,7 +101,7 @@ def main(args): lm_config.set_question_asking_lm(question_asking_lm) lm_config.set_knowledge_base_lm(knowledge_base_lm) - topic = input('Topic: ') + topic = input("Topic: ") runner_argument = RunnerArgument( topic=topic, retrieve_top_k=args.retrieve_top_k, @@ -83,49 +115,76 @@ def main(args): max_thread_num=args.max_thread_num, max_num_round_table_experts=args.max_num_round_table_experts, moderator_override_N_consecutive_answering_turn=args.moderator_override_N_consecutive_answering_turn, - node_expansion_trigger_count=args.node_expansion_trigger_count) + node_expansion_trigger_count=args.node_expansion_trigger_count, + ) logging_wrapper = LoggingWrapper(lm_config) - callback_handler = LocalConsolePrintCallBackHandler() if args.enable_log_print else None + callback_handler = ( + LocalConsolePrintCallBackHandler() if args.enable_log_print else None + ) # Co-STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=runner_argument.retrieve_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=runner_argument.retrieve_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=runner_argument.retrieve_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=runner_argument.retrieve_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=runner_argument.retrieve_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=runner_argument.retrieve_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=runner_argument.retrieve_top_k, + ) + case "you": + rm = YouRM( + ydc_api_key=os.getenv("YDC_API_KEY"), k=runner_argument.retrieve_top_k + ) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=runner_argument.retrieve_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=runner_argument.retrieve_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=runner_argument.retrieve_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), + k=runner_argument.retrieve_top_k, + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + raise ValueError( + f'Invalid retriever: {args.retriever}. 
Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) - costorm_runner = CoStormRunner(lm_config=lm_config, - runner_argument=runner_argument, - logging_wrapper=logging_wrapper, - rm=rm, - callback_handler=callback_handler) + costorm_runner = CoStormRunner( + lm_config=lm_config, + runner_argument=runner_argument, + logging_wrapper=logging_wrapper, + rm=rm, + callback_handler=callback_handler, + ) # warm start the system costorm_runner.warm_start() # Below is an example of how users may interact with Co-STORM to seek information together # In actual deployment, we suggest allowing the user to decide whether to observe the agent utterance or inject a turn - + # observing Co-STORM LLM agent utterance for 5 turns for _ in range(1): conv_turn = costorm_runner.step() print(f"**{conv_turn.role}**: {conv_turn.utterance}\n") - + # active engaging by injecting your utterance - your_utterance = input('Your utterance: ') + your_utterance = input("Your utterance: ") costorm_runner.step(user_utterance=your_utterance) # continue observing @@ -154,93 +213,105 @@ def main(args): json.dump(log_dump, f, indent=2) -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/co-storm', - help='Directory to store the outputs.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/co-storm", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) # hyperparameters for co-storm parser.add_argument( - '--retrieve_top_k', + "--retrieve_top_k", type=int, default=10, - help='Retrieve top k results for each query in retriever.' + help="Retrieve top k results for each query in retriever.", ) parser.add_argument( - '--max_search_queries', + "--max_search_queries", type=int, default=2, - help='Maximum number of search queries to consider for each question.' + help="Maximum number of search queries to consider for each question.", ) parser.add_argument( - '--total_conv_turn', + "--total_conv_turn", type=int, default=20, - help='Maximum number of turns in conversation.' + help="Maximum number of turns in conversation.", ) parser.add_argument( - '--max_search_thread', + "--max_search_thread", type=int, default=5, - help='Maximum number of parallel threads for retriever.' + help="Maximum number of parallel threads for retriever.", ) parser.add_argument( - '--max_search_queries_per_turn', + "--max_search_queries_per_turn", type=int, default=3, - help='Maximum number of search queries to consider in each turn.' + help="Maximum number of search queries to consider in each turn.", ) parser.add_argument( - '--warmstart_max_num_experts', + "--warmstart_max_num_experts", type=int, default=3, - help='Max number of experts in perspective-guided QA during warm start.' + help="Max number of experts in perspective-guided QA during warm start.", ) parser.add_argument( - '--warmstart_max_turn_per_experts', + "--warmstart_max_turn_per_experts", type=int, default=2, - help='Max number of turns per perspective during warm start.' 
+ help="Max number of turns per perspective during warm start.", ) parser.add_argument( - '--warmstart_max_thread', + "--warmstart_max_thread", type=int, default=3, - help='Max number of threads for parallel perspective-guided QA during warm start.' + help="Max number of threads for parallel perspective-guided QA during warm start.", ) parser.add_argument( - '--max_thread_num', + "--max_thread_num", type=int, default=10, - help=("Maximum number of threads to use. " - "Consider reducing it if you keep getting 'Exceed rate limit' errors when calling the LM API.") + help=( + "Maximum number of threads to use. " + "Consider reducing it if you keep getting 'Exceed rate limit' errors when calling the LM API." + ), ) parser.add_argument( - '--max_num_round_table_experts', + "--max_num_round_table_experts", type=int, default=2, - help='Max number of active experts in round table discussion.' + help="Max number of active experts in round table discussion.", ) parser.add_argument( - '--moderator_override_N_consecutive_answering_turn', + "--moderator_override_N_consecutive_answering_turn", type=int, default=3, - help=('Number of consecutive expert answering turns before the moderator overrides the conversation.') + help=( + "Number of consecutive expert answering turns before the moderator overrides the conversation." + ), ) parser.add_argument( - '--node_expansion_trigger_count', + "--node_expansion_trigger_count", type=int, default=10, - help='Trigger node expansion for nodes that contain more than N snippets.' + help="Trigger node expansion for nodes that contain more than N snippets.", ) # Boolean flags parser.add_argument( - '--enable_log_print', - action='store_true', - help='If set, enable console log print.' + "--enable_log_print", + action="store_true", + help="If set, enable console log print.", ) main(parser.parse_args()) diff --git a/examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py b/examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py index 926192bf..f3dfe51d 100644 --- a/examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py +++ b/examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py @@ -9,21 +9,28 @@ if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("--input-path", type=str, help="Path to arxiv_data_210930-054931.csv.") - parser.add_argument("--output-path", type=str, - help="Path to store the csv file that is compatible with VectorRM.") + parser.add_argument( + "--input-path", type=str, help="Path to arxiv_data_210930-054931.csv." + ) + parser.add_argument( + "--output-path", + type=str, + help="Path to store the csv file that is compatible with VectorRM.", + ) args = parser.parse_args() df = pd.read_csv(args.input_path) - print(f'The original dataset has {len(df)} samples.') + print(f"The original dataset has {len(df)} samples.") # Downsample the dataset. - df = df[df['terms'] == "['cs.CV']"] + df = df[df["terms"] == "['cs.CV']"] # Reformat the dataset to match the VectorRM input format. df.rename(columns={"abstracts": "content", "titles": "title"}, inplace=True) - df['url'] = ['uid_' + str(idx) for idx in range(len(df))] # Ensure the url is unique. - df['description'] = '' + df["url"] = [ + "uid_" + str(idx) for idx in range(len(df)) + ] # Ensure the url is unique. 
+ df["description"] = "" - print(f'The downsampled dataset has {len(df)} samples.') - df.to_csv(args.output_path, index=False) \ No newline at end of file + print(f"The downsampled dataset has {len(df)} samples.") + df.to_csv(args.output_path, index=False) diff --git a/examples/storm_examples/run_storm_wiki_claude.py b/examples/storm_examples/run_storm_wiki_claude.py index 1da12c0e..10f6e369 100644 --- a/examples/storm_examples/run_storm_wiki_claude.py +++ b/examples/storm_examples/run_storm_wiki_claude.py @@ -19,19 +19,31 @@ import os from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import ClaudeModel -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() claude_kwargs = { - 'api_key': os.getenv("ANTHROPIC_API_KEY"), - 'temperature': 1.0, - 'top_p': 0.9 + "api_key": os.getenv("ANTHROPIC_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, } # STORM is a LM system so different components can be powered by different models. @@ -39,11 +51,21 @@ def main(args): # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. - conv_simulator_lm = ClaudeModel(model='claude-3-haiku-20240307', max_tokens=500, **claude_kwargs) - question_asker_lm = ClaudeModel(model='claude-3-sonnet-20240229', max_tokens=500, **claude_kwargs) - outline_gen_lm = ClaudeModel(model='claude-3-opus-20240229', max_tokens=400, **claude_kwargs) - article_gen_lm = ClaudeModel(model='claude-3-opus-20240229', max_tokens=700, **claude_kwargs) - article_polish_lm = ClaudeModel(model='claude-3-opus-20240229', max_tokens=4000, **claude_kwargs) + conv_simulator_lm = ClaudeModel( + model="claude-3-haiku-20240307", max_tokens=500, **claude_kwargs + ) + question_asker_lm = ClaudeModel( + model="claude-3-sonnet-20240229", max_tokens=500, **claude_kwargs + ) + outline_gen_lm = ClaudeModel( + model="claude-3-opus-20240229", max_tokens=400, **claude_kwargs + ) + article_gen_lm = ClaudeModel( + model="claude-3-opus-20240229", max_tokens=700, **claude_kwargs + ) + article_polish_lm = ClaudeModel( + model="claude-3-opus-20240229", max_tokens=4000, **claude_kwargs + ) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -62,26 +84,45 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') - + raise ValueError( + f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) + runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -93,38 +134,81 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/claude', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/claude", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_deepseek.py b/examples/storm_examples/run_storm_wiki_deepseek.py index e4cf5fca..e831797f 100644 --- a/examples/storm_examples/run_storm_wiki_deepseek.py +++ b/examples/storm_examples/run_storm_wiki_deepseek.py @@ -22,9 +22,21 @@ import logging from argparse import ArgumentParser -from 
knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import DeepSeekModel -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) from knowledge_storm.utils import load_api_key @@ -34,10 +46,10 @@ def sanitize_topic(topic): Remove or replace characters that are not allowed in file names. """ # Replace spaces with underscores - topic = topic.replace(' ', '_') + topic = topic.replace(" ", "_") # Remove any character that isn't alphanumeric, underscore, or hyphen - topic = re.sub(r'[^a-zA-Z0-9_-]', '', topic) + topic = re.sub(r"[^a-zA-Z0-9_-]", "", topic) # Ensure the topic isn't empty after sanitization if not topic: @@ -47,29 +59,37 @@ def sanitize_topic(topic): def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() logger = logging.getLogger(__name__) # Ensure DEEPSEEK_API_KEY is set if not os.getenv("DEEPSEEK_API_KEY"): - raise ValueError("DEEPSEEK_API_KEY environment variable is not set. Please set it in your secrets.toml file.") + raise ValueError( + "DEEPSEEK_API_KEY environment variable is not set. Please set it in your secrets.toml file." + ) deepseek_kwargs = { - 'api_key': os.getenv("DEEPSEEK_API_KEY"), - 'api_base': os.getenv("DEEPSEEK_API_BASE", "https://api.deepseek.com"), - 'temperature': args.temperature, - 'top_p': args.top_p, + "api_key": os.getenv("DEEPSEEK_API_KEY"), + "api_base": os.getenv("DEEPSEEK_API_BASE", "https://api.deepseek.com"), + "temperature": args.temperature, + "top_p": args.top_p, } # DeepSeek offers two main models: 'deepseek-chat' for general tasks and 'deepseek-coder' for coding tasks # Users can choose the appropriate model based on their needs - conv_simulator_lm = DeepSeekModel(model=args.model, max_tokens=500, **deepseek_kwargs) - question_asker_lm = DeepSeekModel(model=args.model, max_tokens=500, **deepseek_kwargs) + conv_simulator_lm = DeepSeekModel( + model=args.model, max_tokens=500, **deepseek_kwargs + ) + question_asker_lm = DeepSeekModel( + model=args.model, max_tokens=500, **deepseek_kwargs + ) outline_gen_lm = DeepSeekModel(model=args.model, max_tokens=400, **deepseek_kwargs) article_gen_lm = DeepSeekModel(model=args.model, max_tokens=700, **deepseek_kwargs) - article_polish_lm = DeepSeekModel(model=args.model, max_tokens=4000, **deepseek_kwargs) + article_polish_lm = DeepSeekModel( + model=args.model, max_tokens=4000, **deepseek_kwargs + ) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -88,26 +108,45 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + raise ValueError( + f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input('Topic: ') + topic = input("Topic: ") sanitized_topic = sanitize_topic(topic) try: @@ -126,44 +165,94 @@ def main(args): raise -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/deepseek', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') - parser.add_argument('--model', type=str, choices=['deepseek-chat', 'deepseek-coder'], default='deepseek-chat', - help='DeepSeek model to use. 
"deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.') - parser.add_argument('--temperature', type=float, default=1.0, - help='Sampling temperature to use.') - parser.add_argument('--top_p', type=float, default=0.9, - help='Top-p sampling parameter.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/deepseek", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) + parser.add_argument( + "--model", + type=str, + choices=["deepseek-chat", "deepseek-coder"], + default="deepseek-chat", + help='DeepSeek model to use. "deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.', + ) + parser.add_argument( + "--temperature", type=float, default=1.0, help="Sampling temperature to use." + ) + parser.add_argument( + "--top_p", type=float, default=0.9, help="Top-p sampling parameter." + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k 
search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_gemini.py b/examples/storm_examples/run_storm_wiki_gemini.py index d5740629..947832fa 100644 --- a/examples/storm_examples/run_storm_wiki_gemini.py +++ b/examples/storm_examples/run_storm_wiki_gemini.py @@ -19,18 +19,31 @@ import os from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import GoogleModel -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) from knowledge_storm.utils import load_api_key + def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() gemini_kwargs = { - 'api_key': os.getenv("GOOGLE_API_KEY"), - 'temperature': 1.0, - 'top_p': 0.9 + "api_key": os.getenv("GOOGLE_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, } # STORM is a LM system so different components can be powered by different models. @@ -40,11 +53,21 @@ def main(args): # which is responsible for generating sections with citations. # To check out available Google models, see: # https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python#list_models - conv_simulator_lm = GoogleModel(model='models/gemini-1.5-flash', max_tokens=500, **gemini_kwargs) - question_asker_lm = GoogleModel(model='models/gemini-1.5-flash', max_tokens=500, **gemini_kwargs) - outline_gen_lm = GoogleModel(model='models/gemini-1.5-pro-exp-0801', max_tokens=400, **gemini_kwargs) - article_gen_lm = GoogleModel(model='models/gemini-1.5-pro-exp-0801', max_tokens=700, **gemini_kwargs) - article_polish_lm = GoogleModel(model='models/gemini-1.5-pro-exp-0801', max_tokens=4000, **gemini_kwargs) + conv_simulator_lm = GoogleModel( + model="models/gemini-1.5-flash", max_tokens=500, **gemini_kwargs + ) + question_asker_lm = GoogleModel( + model="models/gemini-1.5-flash", max_tokens=500, **gemini_kwargs + ) + outline_gen_lm = GoogleModel( + model="models/gemini-1.5-pro-exp-0801", max_tokens=400, **gemini_kwargs + ) + article_gen_lm = GoogleModel( + model="models/gemini-1.5-pro-exp-0801", max_tokens=700, **gemini_kwargs + ) + article_polish_lm = GoogleModel( + model="models/gemini-1.5-pro-exp-0801", max_tokens=4000, **gemini_kwargs + ) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -63,26 +86,45 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. 
# Currently, the information source is the Internet and we use search engine API as the retrieval module. match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + raise ValueError( + f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -94,38 +136,81 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/gemini', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/gemini", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_gpt.py b/examples/storm_examples/run_storm_wiki_gpt.py index b1740a12..e96c6809 100644 --- a/examples/storm_examples/run_storm_wiki_gpt.py +++ b/examples/storm_examples/run_storm_wiki_gpt.py @@ -22,40 +22,63 @@ import os from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from 
knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG, AzureAISearch +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, + AzureAISearch, +) from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() openai_kwargs = { - 'api_key': os.getenv("OPENAI_API_KEY"), - 'temperature': 1.0, - 'top_p': 0.9, + "api_key": os.getenv("OPENAI_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, } - ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel + ModelClass = ( + OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel + ) # If you are using Azure service, make sure the model name matches your own deployed model name. # The default name here is only used for demonstration and may not match your case. - gpt_35_model_name = 'gpt-3.5-turbo' if os.getenv('OPENAI_API_TYPE') == 'openai' else 'gpt-35-turbo' - gpt_4_model_name = 'gpt-4o' - if os.getenv('OPENAI_API_TYPE') == 'azure': - openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE') - openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION') + gpt_35_model_name = ( + "gpt-3.5-turbo" if os.getenv("OPENAI_API_TYPE") == "openai" else "gpt-35-turbo" + ) + gpt_4_model_name = "gpt-4o" + if os.getenv("OPENAI_API_TYPE") == "azure": + openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE") + openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION") # STORM is a LM system so different components can be powered by different models. # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. - conv_simulator_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) - question_asker_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) + conv_simulator_lm = ModelClass( + model=gpt_35_model_name, max_tokens=500, **openai_kwargs + ) + question_asker_lm = ModelClass( + model=gpt_35_model_name, max_tokens=500, **openai_kwargs + ) outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs) article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs) - article_polish_lm = ModelClass(model=gpt_4_model_name, max_tokens=4000, **openai_kwargs) + article_polish_lm = ModelClass( + model=gpt_4_model_name, max_tokens=4000, **openai_kwargs + ) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -75,28 +98,50 @@ def main(args): # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) - case 'azure_ai_search': - rm = AzureAISearch(azure_ai_search_api_key=os.getenv('AZURE_AI_SEARCH_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) + case "azure_ai_search": + rm = AzureAISearch( + azure_ai_search_api_key=os.getenv("AZURE_AI_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", "searxng", or "azure_ai_search"') + raise ValueError( + f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", "searxng", or "azure_ai_search"' + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -108,38 +153,90 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/gpt', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng', 'azure_ai_search'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/gpt", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. 
The information seeking part and the article generation" + "part can speed up by using multiple threads. Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=[ + "bing", + "you", + "brave", + "serper", + "duckduckgo", + "tavily", + "searxng", + "azure_ai_search", + ], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py b/examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py index 8fa4d0d6..77c9dbf7 100644 --- a/examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py +++ 
b/examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py @@ -29,7 +29,11 @@ import os from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.rm import VectorRM from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel from knowledge_storm.utils import load_api_key, QdrantVectorStoreManager @@ -37,35 +41,45 @@ def main(args): # Load API key from the specified toml file path - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") # Initialize the language model configurations engine_lm_configs = STORMWikiLMConfigs() openai_kwargs = { - 'api_key': os.getenv("OPENAI_API_KEY"), - 'temperature': 1.0, - 'top_p': 0.9, + "api_key": os.getenv("OPENAI_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, } - ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel + ModelClass = ( + OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel + ) # If you are using Azure service, make sure the model name matches your own deployed model name. # The default name here is only used for demonstration and may not match your case. - gpt_35_model_name = 'gpt-3.5-turbo' if os.getenv('OPENAI_API_TYPE') == 'openai' else 'gpt-35-turbo' - gpt_4_model_name = 'gpt-4o' - if os.getenv('OPENAI_API_TYPE') == 'azure': - openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE') - openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION') + gpt_35_model_name = ( + "gpt-3.5-turbo" if os.getenv("OPENAI_API_TYPE") == "openai" else "gpt-35-turbo" + ) + gpt_4_model_name = "gpt-4o" + if os.getenv("OPENAI_API_TYPE") == "azure": + openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE") + openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION") # STORM is a LM system so different components can be powered by different models. - # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm + # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. 
- conv_simulator_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) - question_asker_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) + conv_simulator_lm = ModelClass( + model=gpt_35_model_name, max_tokens=500, **openai_kwargs + ) + question_asker_lm = ModelClass( + model=gpt_35_model_name, max_tokens=500, **openai_kwargs + ) outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs) article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs) - article_polish_lm = ModelClass(model=gpt_4_model_name, max_tokens=4000, **openai_kwargs) + article_polish_lm = ModelClass( + model=gpt_4_model_name, max_tokens=4000, **openai_kwargs + ) engine_lm_configs.set_conv_simulator_lm(conv_simulator_lm) engine_lm_configs.set_question_asker_lm(question_asker_lm) @@ -85,43 +99,49 @@ def main(args): # Create / update the vector store with the documents in the csv file if args.csv_file_path: kwargs = { - 'file_path': args.csv_file_path, - 'content_column': 'content', - 'title_column': 'title', - 'url_column': 'url', - 'desc_column': 'description', - 'batch_size': args.embed_batch_size, - 'vector_db_mode': args.vector_db_mode, - 'collection_name': args.collection_name, - 'embedding_model': args.embedding_model, - 'device': args.device, + "file_path": args.csv_file_path, + "content_column": "content", + "title_column": "title", + "url_column": "url", + "desc_column": "description", + "batch_size": args.embed_batch_size, + "vector_db_mode": args.vector_db_mode, + "collection_name": args.collection_name, + "embedding_model": args.embedding_model, + "device": args.device, } - if args.vector_db_mode == 'offline': + if args.vector_db_mode == "offline": QdrantVectorStoreManager.create_or_update_vector_store( - vector_store_path=args.offline_vector_db_dir, - **kwargs + vector_store_path=args.offline_vector_db_dir, **kwargs ) - elif args.vector_db_mode == 'online': + elif args.vector_db_mode == "online": QdrantVectorStoreManager.create_or_update_vector_store( url=args.online_vector_db_url, - api_key=os.getenv('QDRANT_API_KEY'), + api_key=os.getenv("QDRANT_API_KEY"), **kwargs ) # Setup VectorRM to retrieve information from your own data - rm = VectorRM(collection_name=args.collection_name, embedding_model=args.embedding_model, device=args.device, k=engine_args.search_top_k) + rm = VectorRM( + collection_name=args.collection_name, + embedding_model=args.embedding_model, + device=args.device, + k=engine_args.search_top_k, + ) # initialize the vector store, either online (store the db on Qdrant server) or offline (store the db locally): - if args.vector_db_mode == 'offline': + if args.vector_db_mode == "offline": rm.init_offline_vector_db(vector_store_path=args.offline_vector_db_dir) - elif args.vector_db_mode == 'online': - rm.init_online_vector_db(url=args.online_vector_db_url, api_key=os.getenv('QDRANT_API_KEY')) + elif args.vector_db_mode == "online": + rm.init_online_vector_db( + url=args.online_vector_db_url, api_key=os.getenv("QDRANT_API_KEY") + ) # Initialize the STORM Wiki Runner runner = STORMWikiRunner(engine_args, engine_lm_configs, rm) # run the pipeline - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -136,50 +156,120 @@ def main(args): if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/gpt_retrieval', - help='Directory to store the outputs.') - 
parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/gpt_retrieval", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) # provide local corpus and set up vector db - parser.add_argument('--collection-name', type=str, default="my_documents", - help='The collection name for vector store.') - parser.add_argument('--embedding_model', type=str, default="BAAI/bge-m3", - help='The collection name for vector store.') - parser.add_argument('--device', type=str, default="mps", - help='The device used to run the retrieval model (mps, cuda, cpu, etc).') - parser.add_argument('--vector-db-mode', type=str, choices=['offline', 'online'], - help='The mode of the Qdrant vector store (offline or online).') - parser.add_argument('--offline-vector-db-dir', type=str, default='./vector_store', - help='If use offline mode, please provide the directory to store the vector store.') - parser.add_argument('--online-vector-db-url', type=str, - help='If use online mode, please provide the url of the Qdrant server.') - parser.add_argument('--csv-file-path', type=str, default=None, - help='The path of the custom document corpus in CSV format. The CSV file should include ' - 'content, title, url, and description columns.') - parser.add_argument('--embed-batch-size', type=int, default=64, - help='Batch size for embedding the documents in the csv file.') + parser.add_argument( + "--collection-name", + type=str, + default="my_documents", + help="The collection name for vector store.", + ) + parser.add_argument( + "--embedding_model", + type=str, + default="BAAI/bge-m3", + help="The embedding model to use for building the vector store.", + ) + parser.add_argument( + "--device", + type=str, + default="mps", + help="The device used to run the retrieval model (mps, cuda, cpu, etc).", + ) + parser.add_argument( + "--vector-db-mode", + type=str, + choices=["offline", "online"], + help="The mode of the Qdrant vector store (offline or online).", + ) + parser.add_argument( + "--offline-vector-db-dir", + type=str, + default="./vector_store", + help="If use offline mode, please provide the directory to store the vector store.", + ) + parser.add_argument( + "--online-vector-db-url", + type=str, + help="If use online mode, please provide the url of the Qdrant server.", + ) + parser.add_argument( + "--csv-file-path", + type=str, + default=None, + help="The path of the custom document corpus in CSV format. 
The CSV file should include " + "content, title, url, and description columns.", + ) + parser.add_argument( + "--embed-batch-size", + type=int, + default=64, + help="Batch size for embedding the documents in the csv file.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') - main(parser.parse_args()) \ No newline at end of file + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) + main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_groq.py b/examples/storm_examples/run_storm_wiki_groq.py index 0dcaadbb..de5f6ae2 100644 --- a/examples/storm_examples/run_storm_wiki_groq.py +++ b/examples/storm_examples/run_storm_wiki_groq.py @@ -21,12 +21,24 @@ import re from argparse import ArgumentParser -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from 
knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) # Now import lm directly import lm from lm import GroqModel -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) from knowledge_storm.utils import load_api_key @@ -36,10 +48,10 @@ def sanitize_topic(topic): Remove or replace characters that are not allowed in file names. """ # Replace spaces with underscores - topic = topic.replace(' ', '_') + topic = topic.replace(" ", "_") # Remove any character that isn't alphanumeric, underscore, or hyphen - topic = re.sub(r'[^a-zA-Z0-9_-]', '', topic) + topic = re.sub(r"[^a-zA-Z0-9_-]", "", topic) # Ensure the topic isn't empty after sanitization if not topic: @@ -49,26 +61,34 @@ def sanitize_topic(topic): def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() # Ensure GROQ_API_KEY is set if not os.getenv("GROQ_API_KEY"): - raise ValueError("GROQ_API_KEY environment variable is not set. Please set it in your secrets.toml file.") + raise ValueError( + "GROQ_API_KEY environment variable is not set. Please set it in your secrets.toml file." + ) groq_kwargs = { - 'api_key': os.getenv("GROQ_API_KEY"), - 'api_base': "https://api.groq.com/openai/v1", - 'temperature': args.temperature, - 'top_p': args.top_p, + "api_key": os.getenv("GROQ_API_KEY"), + "api_base": "https://api.groq.com/openai/v1", + "temperature": args.temperature, + "top_p": args.top_p, } # Groq currently offers the "llama3-70b-8192" model with generous free API credits and the llama3.1 family of models as a preview for paying customers - conv_simulator_lm = GroqModel(model="llama3-70b-8192", max_tokens=500, **groq_kwargs) - question_asker_lm = GroqModel(model="llama3-70b-8192", max_tokens=500, **groq_kwargs) + conv_simulator_lm = GroqModel( + model="llama3-70b-8192", max_tokens=500, **groq_kwargs + ) + question_asker_lm = GroqModel( + model="llama3-70b-8192", max_tokens=500, **groq_kwargs + ) outline_gen_lm = GroqModel(model="llama3-70b-8192", max_tokens=400, **groq_kwargs) article_gen_lm = GroqModel(model="llama3-70b-8192", max_tokens=700, **groq_kwargs) - article_polish_lm = GroqModel(model="llama3-70b-8192", max_tokens=4000, **groq_kwargs) + article_polish_lm = GroqModel( + model="llama3-70b-8192", max_tokens=4000, **groq_kwargs + ) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -87,26 +107,45 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + raise ValueError( + f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) - topic = input('Topic: ') + topic = input("Topic: ") sanitized_topic = sanitize_topic(topic) try: @@ -125,42 +164,87 @@ def main(args): raise -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--output-dir', type=str, default='./results/groq', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') - parser.add_argument('--temperature', type=float, default=1.0, - help='Sampling temperature to use.') - parser.add_argument('--top_p', type=float, default=0.9, - help='Top-p sampling parameter.') + parser.add_argument( + "--output-dir", + type=str, + default="./results/groq", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) + parser.add_argument( + "--temperature", type=float, default=1.0, help="Sampling temperature to use." + ) + parser.add_argument( + "--top_p", type=float, default=0.9, help="Top-p sampling parameter." + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_mistral.py b/examples/storm_examples/run_storm_wiki_mistral.py index ceabd470..e9475413 
100644 --- a/examples/storm_examples/run_storm_wiki_mistral.py +++ b/examples/storm_examples/run_storm_wiki_mistral.py @@ -15,26 +15,41 @@ storm_gen_article.txt # Final article generated storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) """ + import os from argparse import ArgumentParser from dspy import Example -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import VLLMClient -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() mistral_kwargs = { "model": "mistralai/Mistral-7B-Instruct-v0.2", "port": args.port, "url": args.url, - "stop": ('\n\n---',) # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. + "stop": ( + "\n\n---", + ), # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. } conv_simulator_lm = VLLMClient(max_tokens=500, **mistral_kwargs) @@ -60,22 +75,41 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. 
Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + raise ValueError( + f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) @@ -87,26 +121,28 @@ def main(args): find_related_topic_example = Example( topic="Knowledge Curation", related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n" - "https://en.wikipedia.org/wiki/Information_science\n" - "https://en.wikipedia.org/wiki/Library_science\n" + "https://en.wikipedia.org/wiki/Information_science\n" + "https://en.wikipedia.org/wiki/Library_science\n", ) gen_persona_example = Example( topic="Knowledge Curation", examples="Title: Knowledge management\n" - "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" - "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" - " Knowledge protection methods\n Formal methods\n Informal methods\n" - " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", + "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" + "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" + " Knowledge protection methods\n Formal methods\n Informal methods\n" + " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", personas="1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge curation. They will provide context on how knowledge curation has changed over time and its impact on modern practices.\n" - "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" - "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" - "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" - "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm." + "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" + "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" + "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" + "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm.", ) runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [ - find_related_topic_example] + find_related_topic_example + ] runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [ - gen_persona_example] + gen_persona_example + ] # A trade-off of adding one-shot example is that it will increase the input length of the prompt. 
Also, some # examples may be very long (e.g., an example for writing a section based on the given information), which may @@ -118,24 +154,28 @@ def main(args): topic="Example Topic", conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...", old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n" - "# Section 2\n## Subsection 1\n## Subsection 2\n" - "# Section 3", + "# Section 2\n## Subsection 1\n## Subsection 2\n" + "# Section 3", outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n" - "# New Section 2\n" - "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3" + "# New Section 2\n" + "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3", ) - runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [write_page_outline_example] + runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [ + write_page_outline_example + ] write_section_example = Example( info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3", topic="Example Topic", section="Example Section", output="# Example Topic\n## Subsection 1\n" - "This is an example sentence [1]. This is another example sentence [2][3].\n" - "## Subsection 2\nThis is one more example sentence [1]." + "This is an example sentence [1]. This is another example sentence [2][3].\n" + "## Subsection 2\nThis is one more example sentence [1].", ) - runner.storm_article_generation.section_gen.write_section.demos = [write_section_example] + runner.storm_article_generation.section_gen.write_section.demos = [ + write_section_example + ] - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -147,42 +187,87 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--url', type=str, default='http://localhost', - help='URL of the VLLM server.') - parser.add_argument('--port', type=int, default=8000, - help='Port of the VLLM server.') - parser.add_argument('--output-dir', type=str, default='./results/mistral_7b', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--url", type=str, default="http://localhost", help="URL of the VLLM server." + ) + parser.add_argument( + "--port", type=int, default=8000, help="Port of the VLLM server." + ) + parser.add_argument( + "--output-dir", + type=str, + default="./results/mistral_7b", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_ollama.py b/examples/storm_examples/run_storm_wiki_ollama.py index 465e065a..8d2e2b53 100644 --- a/examples/storm_examples/run_storm_wiki_ollama.py +++ b/examples/storm_examples/run_storm_wiki_ollama.py @@ -15,6 +15,7 @@ storm_gen_article.txt # Final article generated 
storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) """ + import os import sys from argparse import ArgumentParser @@ -22,20 +23,34 @@ from dspy import Example from knowledge_storm.lm import OllamaClient -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm.rm import ( + YouRM, + BingSearch, + BraveRM, + SerperRM, + DuckDuckGoSearchRM, + TavilySearchRM, + SearXNG, +) +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() ollama_kwargs = { "model": args.model, "port": args.port, "url": args.url, - "stop": ('\n\n---',) # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. + "stop": ( + "\n\n---", + ), # dspy uses "\n\n---" to separate examples. Open models sometimes generate this. } conv_simulator_lm = OllamaClient(max_tokens=500, **ollama_kwargs) @@ -61,22 +76,41 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. match args.retriever: - case 'bing': - rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) - case 'you': - rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) - case 'brave': - rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k) - case 'duckduckgo': - rm = DuckDuckGoSearchRM(k=engine_args.search_top_k, safe_search='On', region='us-en') - case 'serper': - rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) - case 'tavily': - rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) - case 'searxng': - rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + case "bing": + rm = BingSearch( + bing_search_api=os.getenv("BING_SEARCH_API_KEY"), + k=engine_args.search_top_k, + ) + case "you": + rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) + case "brave": + rm = BraveRM( + brave_search_api_key=os.getenv("BRAVE_API_KEY"), + k=engine_args.search_top_k, + ) + case "duckduckgo": + rm = DuckDuckGoSearchRM( + k=engine_args.search_top_k, safe_search="On", region="us-en" + ) + case "serper": + rm = SerperRM( + serper_search_api_key=os.getenv("SERPER_API_KEY"), + query_params={"autocorrect": True, "num": 10, "page": 1}, + ) + case "tavily": + rm = TavilySearchRM( + tavily_search_api_key=os.getenv("TAVILY_API_KEY"), + k=engine_args.search_top_k, + include_raw_content=True, + ) + case "searxng": + rm = SearXNG( + searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k + ) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + raise ValueError( + f'Invalid retriever: {args.retriever}. 
Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) @@ -88,26 +122,28 @@ def main(args): find_related_topic_example = Example( topic="Knowledge Curation", related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n" - "https://en.wikipedia.org/wiki/Information_science\n" - "https://en.wikipedia.org/wiki/Library_science\n" + "https://en.wikipedia.org/wiki/Information_science\n" + "https://en.wikipedia.org/wiki/Library_science\n", ) gen_persona_example = Example( topic="Knowledge Curation", examples="Title: Knowledge management\n" - "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" - "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" - " Knowledge protection methods\n Formal methods\n Informal methods\n" - " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", + "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" + "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" + " Knowledge protection methods\n Formal methods\n Informal methods\n" + " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", personas="1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge curation. They will provide context on how knowledge curation has changed over time and its impact on modern practices.\n" - "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" - "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" - "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" - "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm." + "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" + "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" + "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" + "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm.", ) runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [ - find_related_topic_example] + find_related_topic_example + ] runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [ - gen_persona_example] + gen_persona_example + ] # A trade-off of adding one-shot example is that it will increase the input length of the prompt. 
Also, some # examples may be very long (e.g., an example for writing a section based on the given information), which may @@ -119,24 +155,28 @@ def main(args): topic="Example Topic", conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...", old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n" - "# Section 2\n## Subsection 1\n## Subsection 2\n" - "# Section 3", + "# Section 2\n## Subsection 1\n## Subsection 2\n" + "# Section 3", outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n" - "# New Section 2\n" - "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3" + "# New Section 2\n" + "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3", ) - runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [write_page_outline_example] + runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [ + write_page_outline_example + ] write_section_example = Example( info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3", topic="Example Topic", section="Example Section", output="# Example Topic\n## Subsection 1\n" - "This is an example sentence [1]. This is another example sentence [2][3].\n" - "## Subsection 2\nThis is one more example sentence [1]." + "This is an example sentence [1]. This is another example sentence [2][3].\n" + "## Subsection 2\nThis is one more example sentence [1].", ) - runner.storm_article_generation.section_gen.write_section.demos = [write_section_example] + runner.storm_article_generation.section_gen.write_section.demos = [ + write_section_example + ] - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -148,44 +188,90 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--url', type=str, default='http://localhost', - help='URL of the Ollama server.') - parser.add_argument('--port', type=int, default=11434, - help='Port of the Ollama server.') - parser.add_argument('--model', type=str, default='llama3:latest', - help='Model of the Ollama server.') - parser.add_argument('--output-dir', type=str, default='./results/ollama', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], - help='The search engine API to use for retrieving information.') + parser.add_argument( + "--url", type=str, default="http://localhost", help="URL of the Ollama server." + ) + parser.add_argument( + "--port", type=int, default=11434, help="Port of the Ollama server." + ) + parser.add_argument( + "--model", type=str, default="llama3:latest", help="Model of the Ollama server." + ) + parser.add_argument( + "--output-dir", + type=str, + default="./results/ollama", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. 
Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], + help="The search engine API to use for retrieving information.", + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/storm_examples/run_storm_wiki_ollama_with_searxng.py b/examples/storm_examples/run_storm_wiki_ollama_with_searxng.py index a4be3624..725a68b9 100644 --- a/examples/storm_examples/run_storm_wiki_ollama_with_searxng.py +++ b/examples/storm_examples/run_storm_wiki_ollama_with_searxng.py @@ -3,21 +3,25 @@ from dspy import Example 
-from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import OllamaClient from knowledge_storm.rm import SearXNG from knowledge_storm.utils import load_api_key def main(args): - load_api_key(toml_file_path='secrets.toml') + load_api_key(toml_file_path="secrets.toml") lm_configs = STORMWikiLMConfigs() ollama_kwargs = { "model": args.model, "port": args.port, "url": args.url, - "stop": ('\n\n---',) + "stop": ("\n\n---",), } conv_simulator_lm = OllamaClient(max_tokens=500, **ollama_kwargs) @@ -40,23 +44,27 @@ def main(args): max_thread_num=args.max_thread_num, ) - rm = SearXNG(searxng_api_url=args.searxng_api_url, searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + rm = SearXNG( + searxng_api_url=args.searxng_api_url, + searxng_api_key=os.getenv("SEARXNG_API_KEY"), + k=engine_args.search_top_k, + ) runner = STORMWikiRunner(engine_args, lm_configs, rm) find_related_topic_example = Example( topic="Knowledge Curation", related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n" - "https://en.wikipedia.org/wiki/Information_science\n" - "https://en.wikipedia.org/wiki/Library_science\n" + "https://en.wikipedia.org/wiki/Information_science\n" + "https://en.wikipedia.org/wiki/Library_science\n", ) gen_persona_example = Example( topic="Knowledge Curation", examples="Title: Knowledge management\n" - "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" - "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" - " Knowledge protection methods\n Formal methods\n Informal methods\n" - " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", + "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" + "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" + " Knowledge protection methods\n Formal methods\n Informal methods\n" + " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", personas=( "1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge " "curation. They will provide context on how knowledge curation has changed over time and its impact on " @@ -69,35 +77,41 @@ def main(args): "such as common features of content management systems.\n" "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and " "the transition of these practices into the digital realm." 
- ) + ), ) runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [ - find_related_topic_example] + find_related_topic_example + ] runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [ - gen_persona_example] + gen_persona_example + ] write_page_outline_example = Example( topic="Example Topic", conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...", old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n" - "# Section 2\n## Subsection 1\n## Subsection 2\n" - "# Section 3", + "# Section 2\n## Subsection 1\n## Subsection 2\n" + "# Section 3", outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n" - "# New Section 2\n" - "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3" + "# New Section 2\n" + "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3", ) - runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [write_page_outline_example] + runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [ + write_page_outline_example + ] write_section_example = Example( info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3", topic="Example Topic", section="Example Section", output="# Example Topic\n## Subsection 1\n" - "This is an example sentence [1]. This is another example sentence [2][3].\n" - "## Subsection 2\nThis is one more example sentence [1]." + "This is an example sentence [1]. This is another example sentence [2][3].\n" + "## Subsection 2\nThis is one more example sentence [1].", ) - runner.storm_article_generation.section_gen.write_section.demos = [write_section_example] + runner.storm_article_generation.section_gen.write_section.demos = [ + write_section_example + ] - topic = input('Topic: ') + topic = input("Topic: ") runner.run( topic=topic, do_research=args.do_research, @@ -109,46 +123,93 @@ def main(args): runner.summary() -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() # global arguments - parser.add_argument('--url', type=str, default='http://localhost', - help='URL of the Ollama server.') - parser.add_argument('--port', type=int, default=11434, - help='Port of the Ollama server.') - parser.add_argument('--model', type=str, default='llama3:latest', - help='Model of the Ollama server.') - parser.add_argument('--output-dir', type=str, default='./results/ollama', - help='Directory to store the outputs.') - parser.add_argument('--max-thread-num', type=int, default=3, - help='Maximum number of threads to use. The information seeking part and the article generation' - 'part can speed up by using multiple threads. Consider reducing it if keep getting ' - '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['searxng'], - help='The search engine API to use for retrieving information.') - parser.add_argument('--searxng-api-url', type=str, required=True, - help='URL of the SearXNG API.') + parser.add_argument( + "--url", type=str, default="http://localhost", help="URL of the Ollama server." + ) + parser.add_argument( + "--port", type=int, default=11434, help="Port of the Ollama server." + ) + parser.add_argument( + "--model", type=str, default="llama3:latest", help="Model of the Ollama server." 
+ ) + parser.add_argument( + "--output-dir", + type=str, + default="./results/ollama", + help="Directory to store the outputs.", + ) + parser.add_argument( + "--max-thread-num", + type=int, + default=3, + help="Maximum number of threads to use. The information seeking part and the article generation" + "part can speed up by using multiple threads. Consider reducing it if keep getting " + '"Exceed rate limit" error when calling LM API.', + ) + parser.add_argument( + "--retriever", + type=str, + choices=["searxng"], + help="The search engine API to use for retrieving information.", + ) + parser.add_argument( + "--searxng-api-url", type=str, required=True, help="URL of the SearXNG API." + ) # stage of the pipeline - parser.add_argument('--do-research', action='store_true', - help='If True, simulate conversation to research the topic; otherwise, load the results.') - parser.add_argument('--do-generate-outline', action='store_true', - help='If True, generate an outline for the topic; otherwise, load the results.') - parser.add_argument('--do-generate-article', action='store_true', - help='If True, generate an article for the topic; otherwise, load the results.') - parser.add_argument('--do-polish-article', action='store_true', - help='If True, polish the article by adding a summarization section and (optionally) removing ' - 'duplicate content.') + parser.add_argument( + "--do-research", + action="store_true", + help="If True, simulate conversation to research the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-outline", + action="store_true", + help="If True, generate an outline for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-generate-article", + action="store_true", + help="If True, generate an article for the topic; otherwise, load the results.", + ) + parser.add_argument( + "--do-polish-article", + action="store_true", + help="If True, polish the article by adding a summarization section and (optionally) removing " + "duplicate content.", + ) # hyperparameters for the pre-writing stage - parser.add_argument('--max-conv-turn', type=int, default=3, - help='Maximum number of questions in conversational question asking.') - parser.add_argument('--max-perspective', type=int, default=3, - help='Maximum number of perspectives to consider in perspective-guided question asking.') - parser.add_argument('--search-top-k', type=int, default=3, - help='Top k search results to consider for each search query.') + parser.add_argument( + "--max-conv-turn", + type=int, + default=3, + help="Maximum number of questions in conversational question asking.", + ) + parser.add_argument( + "--max-perspective", + type=int, + default=3, + help="Maximum number of perspectives to consider in perspective-guided question asking.", + ) + parser.add_argument( + "--search-top-k", + type=int, + default=3, + help="Top k search results to consider for each search query.", + ) # hyperparameters for the writing stage - parser.add_argument('--retrieve-top-k', type=int, default=3, - help='Top k collected references for each section title.') - parser.add_argument('--remove-duplicate', action='store_true', - help='If True, remove duplicate content from the article.') + parser.add_argument( + "--retrieve-top-k", + type=int, + default=3, + help="Top k collected references for each section title.", + ) + parser.add_argument( + "--remove-duplicate", + action="store_true", + help="If True, remove duplicate content from the article.", + ) main(parser.parse_args()) diff --git 
a/frontend/demo_light/demo_util.py b/frontend/demo_light/demo_util.py index 22258dda..cff1d9f2 100644 --- a/frontend/demo_light/demo_util.py +++ b/frontend/demo_light/demo_util.py @@ -13,7 +13,11 @@ # Uncomment the following lines: # import sys # sys.path.append('../../') -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) from knowledge_storm.lm import OpenAIModel from knowledge_storm.rm import YouRM from knowledge_storm.storm_wiki.modules.callback import BaseCallbackHandler @@ -21,7 +25,7 @@ from stoc import stoc -class DemoFileIOHelper(): +class DemoFileIOHelper: @staticmethod def read_structure_to_dict(articles_root_path): """ @@ -107,8 +111,10 @@ def set_file_modification_time(file_path, modification_time_string): file_path (str): The path to the file. modification_time_string (str): The desired modification time in 'YYYY-MM-DD HH:MM:SS' format. """ - california_tz = pytz.timezone('America/Los_Angeles') - modification_time = datetime.datetime.strptime(modification_time_string, '%Y-%m-%d %H:%M:%S') + california_tz = pytz.timezone("America/Los_Angeles") + modification_time = datetime.datetime.strptime( + modification_time_string, "%Y-%m-%d %H:%M:%S" + ) modification_time = california_tz.localize(modification_time) modification_time_utc = modification_time.astimezone(datetime.timezone.utc) modification_timestamp = modification_time_utc.timestamp() @@ -125,7 +131,7 @@ def get_latest_modification_time(path): Returns: str: The latest file's modification time in 'YYYY-MM-DD HH:MM:SS' format. """ - california_tz = pytz.timezone('America/Los_Angeles') + california_tz = pytz.timezone("America/Los_Angeles") latest_mod_time = None file_paths = [] @@ -138,17 +144,26 @@ def get_latest_modification_time(path): for file_path in file_paths: modification_timestamp = os.path.getmtime(file_path) - modification_time_utc = datetime.datetime.utcfromtimestamp(modification_timestamp) - modification_time_utc = modification_time_utc.replace(tzinfo=datetime.timezone.utc) - modification_time_california = modification_time_utc.astimezone(california_tz) - - if latest_mod_time is None or modification_time_california > latest_mod_time: + modification_time_utc = datetime.datetime.utcfromtimestamp( + modification_timestamp + ) + modification_time_utc = modification_time_utc.replace( + tzinfo=datetime.timezone.utc + ) + modification_time_california = modification_time_utc.astimezone( + california_tz + ) + + if ( + latest_mod_time is None + or modification_time_california > latest_mod_time + ): latest_mod_time = modification_time_california if latest_mod_time is not None: - return latest_mod_time.strftime('%Y-%m-%d %H:%M:%S') + return latest_mod_time.strftime("%Y-%m-%d %H:%M:%S") else: - return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") @staticmethod def assemble_article_data(article_file_path_dict): @@ -171,25 +186,44 @@ def assemble_article_data(article_file_path_dict): if neither the raw nor polished article text exists in the provided file paths. 
""" - if "storm_gen_article.txt" in article_file_path_dict or "storm_gen_article_polished.txt" in article_file_path_dict: - full_article_name = "storm_gen_article_polished.txt" if "storm_gen_article_polished.txt" in article_file_path_dict else "storm_gen_article.txt" - article_data = {"article": DemoTextProcessingHelper.parse( - DemoFileIOHelper.read_txt_file(article_file_path_dict[full_article_name]))} + if ( + "storm_gen_article.txt" in article_file_path_dict + or "storm_gen_article_polished.txt" in article_file_path_dict + ): + full_article_name = ( + "storm_gen_article_polished.txt" + if "storm_gen_article_polished.txt" in article_file_path_dict + else "storm_gen_article.txt" + ) + article_data = { + "article": DemoTextProcessingHelper.parse( + DemoFileIOHelper.read_txt_file( + article_file_path_dict[full_article_name] + ) + ) + } if "url_to_info.json" in article_file_path_dict: article_data["citations"] = _construct_citation_dict_from_search_result( - DemoFileIOHelper.read_json_file(article_file_path_dict["url_to_info.json"])) + DemoFileIOHelper.read_json_file( + article_file_path_dict["url_to_info.json"] + ) + ) if "conversation_log.json" in article_file_path_dict: article_data["conversation_log"] = DemoFileIOHelper.read_json_file( - article_file_path_dict["conversation_log.json"]) + article_file_path_dict["conversation_log.json"] + ) return article_data return None -class DemoTextProcessingHelper(): - +class DemoTextProcessingHelper: @staticmethod def remove_citations(sent): - return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)).replace(" |", "").replace("]", "") + return ( + re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)) + .replace(" |", "") + .replace("]", "") + ) @staticmethod def parse_conversation_history(json_data): @@ -199,43 +233,54 @@ def parse_conversation_history(json_data): """ parsed_data = [] for persona_conversation_data in json_data: - if ': ' in persona_conversation_data["perspective"]: - name, description = persona_conversation_data["perspective"].split(": ", 1) - elif '- ' in persona_conversation_data["perspective"]: - name, description = persona_conversation_data["perspective"].split("- ", 1) + if ": " in persona_conversation_data["perspective"]: + name, description = persona_conversation_data["perspective"].split( + ": ", 1 + ) + elif "- " in persona_conversation_data["perspective"]: + name, description = persona_conversation_data["perspective"].split( + "- ", 1 + ) else: name, description = "", persona_conversation_data["perspective"] cur_conversation = [] for dialogue_turn in persona_conversation_data["dlg_turns"]: - cur_conversation.append({"role": "user", "content": dialogue_turn["user_utterance"]}) cur_conversation.append( - {"role": "assistant", - "content": DemoTextProcessingHelper.remove_citations(dialogue_turn["agent_utterance"])}) + {"role": "user", "content": dialogue_turn["user_utterance"]} + ) + cur_conversation.append( + { + "role": "assistant", + "content": DemoTextProcessingHelper.remove_citations( + dialogue_turn["agent_utterance"] + ), + } + ) parsed_data.append((name, description, cur_conversation)) return parsed_data @staticmethod def parse(text): regex = re.compile(r']:\s+"(.*?)"\s+http') - text = regex.sub(']: http', text) + text = regex.sub("]: http", text) return text @staticmethod def add_markdown_indentation(input_string): - lines = input_string.split('\n') + lines = input_string.split("\n") processed_lines = [""] for line in lines: num_hashes = 0 for char in line: - if char == '#': + if char == "#": num_hashes += 1 else: break 
num_hashes -= 1 num_spaces = 4 * num_hashes - new_line = ' ' * num_spaces + line + new_line = " " * num_spaces + line processed_lines.append(new_line) - return '\n'.join(processed_lines) + return "\n".join(processed_lines) @staticmethod def get_current_time_string(): @@ -245,13 +290,15 @@ def get_current_time_string(): Returns: str: The current California time in 'YYYY-MM-DD HH:MM:SS' format. """ - california_tz = pytz.timezone('America/Los_Angeles') + california_tz = pytz.timezone("America/Los_Angeles") utc_now = datetime.datetime.now(datetime.timezone.utc) california_now = utc_now.astimezone(california_tz) - return california_now.strftime('%Y-%m-%d %H:%M:%S') + return california_now.strftime("%Y-%m-%d %H:%M:%S") @staticmethod - def compare_time_strings(time_string1, time_string2, time_format='%Y-%m-%d %H:%M:%S'): + def compare_time_strings( + time_string1, time_string2, time_format="%Y-%m-%d %H:%M:%S" + ): """ Compares two time strings to determine if they represent the same point in time. @@ -273,13 +320,13 @@ def compare_time_strings(time_string1, time_string2, time_format='%Y-%m-%d %H:%M @staticmethod def add_inline_citation_link(article_text, citation_dict): # Regular expression to find citations like [i] - pattern = r'\[(\d+)\]' + pattern = r"\[(\d+)\]" # Function to replace each citation with its Markdown link def replace_with_link(match): i = match.group(1) - url = citation_dict.get(int(i), {}).get('url', '#') - return f'[[{i}]]({url})' + url = citation_dict.get(int(i), {}).get("url", "#") + return f"[[{i}]]({url})" # Replace all citations in the text with Markdown links return re.sub(pattern, replace_with_link, article_text) @@ -292,26 +339,34 @@ def generate_html_toc(md_text): level = line.count("#") title = line.strip("# ").strip() anchor = title.lower().replace(" ", "-").replace(".", "") - toc.append(f"
<li><a href='#{anchor}'>{title}</a></li>")
+                toc.append(
+                    f"<li><a href='#{anchor}'>{title}</a></li>
  • " + ) return "" @staticmethod def construct_bibliography_from_url_to_info(url_to_info): bibliography_list = [] - sorted_url_to_unified_index = dict(sorted(url_to_info['url_to_unified_index'].items(), - key=lambda item: item[1])) + sorted_url_to_unified_index = dict( + sorted( + url_to_info["url_to_unified_index"].items(), key=lambda item: item[1] + ) + ) for url, index in sorted_url_to_unified_index.items(): - title = url_to_info['url_to_info'][url]['title'] + title = url_to_info["url_to_info"][url]["title"] bibliography_list.append(f"[{index}]: [{title}]({url})") bibliography_string = "\n\n".join(bibliography_list) return f"# References\n\n{bibliography_string}" -class DemoUIHelper(): +class DemoUIHelper: def st_markdown_adjust_size(content, font_size=20): - st.markdown(f""" + st.markdown( + f""" {content} - """, unsafe_allow_html=True) + """, + unsafe_allow_html=True, + ) @staticmethod def get_article_card_UI_style(boarder_color="#9AD8E1"): @@ -326,7 +381,7 @@ def get_article_card_UI_style(boarder_color="#9AD8E1"): "border-radius": "5px", "border-left": f"0.5rem solid {boarder_color}", "box-shadow": "0 0.15rem 1.75rem 0 rgba(58, 59, 69, 0.15)", - "margin": "0px" + "margin": "0px", }, "title": { "white-space": "nowrap", @@ -336,7 +391,7 @@ def get_article_card_UI_style(boarder_color="#9AD8E1"): "color": "rgb(49, 51, 63)", "text-align": "left", "width": "95%", - "font-weight": "normal" + "font-weight": "normal", }, "text": { "white-space": "nowrap", @@ -345,11 +400,9 @@ def get_article_card_UI_style(boarder_color="#9AD8E1"): "font-size": "25px", "color": "rgb(49, 51, 63)", "text-align": "left", - "width": "95%" + "width": "95%", }, - "filter": { - "background-color": "rgba(0, 0, 0, 0)" - } + "filter": {"background-color": "rgba(0, 0, 0, 0)"}, } @staticmethod @@ -373,7 +426,8 @@ def customize_toast_css_style(): line-height: 1.5; /* Adjust this value as needed */ } - """, unsafe_allow_html=True + """, + unsafe_allow_html=True, ) @staticmethod @@ -405,10 +459,12 @@ def _construct_citation_dict_from_search_result(search_results): if search_results is None: return None citation_dict = {} - for url, index in search_results['url_to_unified_index'].items(): - citation_dict[index] = {'url': url, - 'title': search_results['url_to_info'][url]['title'], - 'snippets': search_results['url_to_info'][url]['snippets']} + for url, index in search_results["url_to_unified_index"].items(): + citation_dict[index] = { + "url": url, + "title": search_results["url_to_info"][url]["title"], + "snippets": search_results["url_to_info"][url]["snippets"], + } return citation_dict @@ -416,10 +472,14 @@ def _display_main_article_text(article_text, citation_dict, table_content_sideba # Post-process the generated article for better display. 
if "Write the lead section:" in article_text: article_text = article_text[ - article_text.find("Write the lead section:") + len("Write the lead section:"):] - if article_text[0] == '#': - article_text = '\n'.join(article_text.split('\n')[1:]) - article_text = DemoTextProcessingHelper.add_inline_citation_link(article_text, citation_dict) + article_text.find("Write the lead section:") + + len("Write the lead section:") : + ] + if article_text[0] == "#": + article_text = "\n".join(article_text.split("\n")[1:]) + article_text = DemoTextProcessingHelper.add_inline_citation_link( + article_text, citation_dict + ) # '$' needs to be changed to '\$' to avoid being interpreted as LaTeX in st.markdown() article_text = article_text.replace("$", "\\$") stoc.from_markdown(article_text, table_content_sidebar) @@ -430,10 +490,10 @@ def _display_references(citation_dict): reference_list = [f"reference [{i}]" for i in range(1, len(citation_dict) + 1)] selected_key = st.selectbox("Select a reference", reference_list) citation_val = citation_dict[reference_list.index(selected_key) + 1] - citation_val['title'] = citation_val['title'].replace("$", "\\$") + citation_val["title"] = citation_val["title"].replace("$", "\\$") st.markdown(f"**Title:** {citation_val['title']}") st.markdown(f"**Url:** {citation_val['url']}") - snippets = '\n\n'.join(citation_val['snippets']).replace("$", "\\$") + snippets = "\n\n".join(citation_val["snippets"]).replace("$", "\\$") st.markdown(f"**Highlights:**\n\n {snippets}") else: st.markdown("**No references available**") @@ -444,7 +504,9 @@ def _display_persona_conversations(conversation_log): Display persona conversation in dialogue UI """ # get personas list as (persona_name, persona_description, dialogue turns list) tuple - parsed_conversation_history = DemoTextProcessingHelper.parse_conversation_history(conversation_log) + parsed_conversation_history = DemoTextProcessingHelper.parse_conversation_history( + conversation_log + ) # construct tabs for each persona conversation persona_tabs = st.tabs([name for (name, _, _) in parsed_conversation_history]) for idx, persona_tab in enumerate(persona_tabs): @@ -453,7 +515,7 @@ def _display_persona_conversations(conversation_log): st.info(parsed_conversation_history[idx][1]) # show user / agent utterance in dialogue UI for message in parsed_conversation_history[idx][2]: - message['content'] = message['content'].replace("$", "\\$") + message["content"] = message["content"].replace("$", "\\$") with st.chat_message(message["role"]): if message["role"] == "user": st.markdown(f"**{message['content']}**") @@ -461,14 +523,22 @@ def _display_persona_conversations(conversation_log): st.markdown(message["content"]) -def _display_main_article(selected_article_file_path_dict, show_reference=True, show_conversation=True): - article_data = DemoFileIOHelper.assemble_article_data(selected_article_file_path_dict) +def _display_main_article( + selected_article_file_path_dict, show_reference=True, show_conversation=True +): + article_data = DemoFileIOHelper.assemble_article_data( + selected_article_file_path_dict + ) with st.container(height=1000, border=True): - table_content_sidebar = st.sidebar.expander("**Table of contents**", expanded=True) - _display_main_article_text(article_text=article_data.get("article", ""), - citation_dict=article_data.get("citations", {}), - table_content_sidebar=table_content_sidebar) + table_content_sidebar = st.sidebar.expander( + "**Table of contents**", expanded=True + ) + _display_main_article_text( + 
article_text=article_data.get("article", ""), + citation_dict=article_data.get("citations", {}), + table_content_sidebar=table_content_sidebar, + ) # display reference panel if show_reference and "citations" in article_data: @@ -479,9 +549,12 @@ def _display_main_article(selected_article_file_path_dict, show_reference=True, # display conversation history if show_conversation and "conversation_log" in article_data: with st.expander( - "**STORM** is powered by a knowledge agent that proactively research a given topic by asking good questions coming from different perspectives.\n\n" - ":sunglasses: Click here to view the agent's brain**STORM**ing process!"): - _display_persona_conversations(conversation_log=article_data.get("conversation_log", {})) + "**STORM** is powered by a knowledge agent that proactively research a given topic by asking good questions coming from different perspectives.\n\n" + ":sunglasses: Click here to view the agent's brain**STORM**ing process!" + ): + _display_persona_conversations( + conversation_log=article_data.get("conversation_log", {}) + ) def get_demo_dir(): @@ -492,7 +565,11 @@ def clear_other_page_session_state(page_index: Optional[int]): if page_index is None: keys_to_delete = [key for key in st.session_state if key.startswith("page")] else: - keys_to_delete = [key for key in st.session_state if key.startswith("page") and f"page{page_index}" not in key] + keys_to_delete = [ + key + for key in st.session_state + if key.startswith("page") and f"page{page_index}" not in key + ] for key in set(keys_to_delete): del st.session_state[key] @@ -504,29 +581,44 @@ def set_storm_runner(): # configure STORM runner llm_configs = STORMWikiLMConfigs() - llm_configs.init_openai_model(openai_api_key=st.secrets['OPENAI_API_KEY'], openai_type='openai') - llm_configs.set_question_asker_lm(OpenAIModel(model='gpt-4-1106-preview', api_key=st.secrets['OPENAI_API_KEY'], - api_provider='openai', - max_tokens=500, temperature=1.0, top_p=0.9)) + llm_configs.init_openai_model( + openai_api_key=st.secrets["OPENAI_API_KEY"], openai_type="openai" + ) + llm_configs.set_question_asker_lm( + OpenAIModel( + model="gpt-4-1106-preview", + api_key=st.secrets["OPENAI_API_KEY"], + api_provider="openai", + max_tokens=500, + temperature=1.0, + top_p=0.9, + ) + ) engine_args = STORMWikiRunnerArguments( output_dir=current_working_dir, max_conv_turn=3, max_perspective=3, search_top_k=3, - retrieve_top_k=5 + retrieve_top_k=5, ) - rm = YouRM(ydc_api_key=st.secrets['YDC_API_KEY'], k=engine_args.search_top_k) + rm = YouRM(ydc_api_key=st.secrets["YDC_API_KEY"], k=engine_args.search_top_k) runner = STORMWikiRunner(engine_args, llm_configs, rm) st.session_state["runner"] = runner -def display_article_page(selected_article_name, selected_article_file_path_dict, - show_title=True, show_main_article=True): +def display_article_page( + selected_article_name, + selected_article_file_path_dict, + show_title=True, + show_main_article=True, +): if show_title: - st.markdown(f"
    {selected_article_name.replace('_', ' ')}
    ", - unsafe_allow_html=True) + st.markdown( + f"
    {selected_article_name.replace('_', ' ')}
    ", + unsafe_allow_html=True, + ) if show_main_article: _display_main_article(selected_article_file_path_dict) @@ -537,20 +629,25 @@ def __init__(self, status_container): self.status_container = status_container def on_identify_perspective_start(self, **kwargs): - self.status_container.info('Start identifying different perspectives for researching the topic.') + self.status_container.info( + "Start identifying different perspectives for researching the topic." + ) def on_identify_perspective_end(self, perspectives: list[str], **kwargs): perspective_list = "\n- ".join(perspectives) - self.status_container.success(f'Finish identifying perspectives. Will now start gathering information' - f' from the following perspectives:\n- {perspective_list}') + self.status_container.success( + f"Finish identifying perspectives. Will now start gathering information" + f" from the following perspectives:\n- {perspective_list}" + ) def on_information_gathering_start(self, **kwargs): - self.status_container.info('Start browsing the Internet.') + self.status_container.info("Start browsing the Internet.") def on_dialogue_turn_end(self, dlg_turn, **kwargs): urls = list(set([r.url for r in dlg_turn.search_results])) for url in urls: - self.status_container.markdown(f""" + self.status_container.markdown( + f"""
    Finish browsing {url}.
    - """, unsafe_allow_html=True) + """, + unsafe_allow_html=True, + ) def on_information_gathering_end(self, **kwargs): - self.status_container.success('Finish collecting information.') + self.status_container.success("Finish collecting information.") def on_information_organization_start(self, **kwargs): - self.status_container.info('Start organizing information into a hierarchical outline.') + self.status_container.info( + "Start organizing information into a hierarchical outline." + ) def on_direct_outline_generation_end(self, outline: str, **kwargs): - self.status_container.success(f'Finish leveraging the internal knowledge of the large language model.') + self.status_container.success( + f"Finish leveraging the internal knowledge of the large language model." + ) def on_outline_refinement_end(self, outline: str, **kwargs): - self.status_container.success(f'Finish leveraging the collected information.') + self.status_container.success(f"Finish leveraging the collected information.") diff --git a/frontend/demo_light/pages_util/CreateNewArticle.py b/frontend/demo_light/pages_util/CreateNewArticle.py index 2257044a..2cbef475 100644 --- a/frontend/demo_light/pages_util/CreateNewArticle.py +++ b/frontend/demo_light/pages_util/CreateNewArticle.py @@ -3,39 +3,60 @@ import demo_util import streamlit as st -from demo_util import DemoFileIOHelper, DemoTextProcessingHelper, DemoUIHelper, truncate_filename +from demo_util import ( + DemoFileIOHelper, + DemoTextProcessingHelper, + DemoUIHelper, + truncate_filename, +) + def handle_not_started(): if st.session_state["page3_write_article_state"] == "not started": - _, search_form_column, _ = st.columns([2, 5, 2]) with search_form_column: - with st.form(key='search_form'): + with st.form(key="search_form"): # Text input for the search topic - DemoUIHelper.st_markdown_adjust_size(content="Enter the topic you want to learn in depth:", - font_size=18) - st.session_state["page3_topic"] = st.text_input(label='page3_topic', label_visibility="collapsed") + DemoUIHelper.st_markdown_adjust_size( + content="Enter the topic you want to learn in depth:", font_size=18 + ) + st.session_state["page3_topic"] = st.text_input( + label="page3_topic", label_visibility="collapsed" + ) pass_appropriateness_check = True # Submit button for the form - submit_button = st.form_submit_button(label='Research') + submit_button = st.form_submit_button(label="Research") # only start new search when button is clicked, not started, or already finished previous one - if submit_button and st.session_state["page3_write_article_state"] in ["not started", "show results"]: + if submit_button and st.session_state["page3_write_article_state"] in [ + "not started", + "show results", + ]: if not st.session_state["page3_topic"].strip(): pass_appropriateness_check = False - st.session_state["page3_warning_message"] = "topic could not be empty" - - st.session_state["page3_topic_name_cleaned"] = st.session_state["page3_topic"].replace( - ' ', '_').replace('/', '_') - st.session_state["page3_topic_name_truncated"] = truncate_filename(st.session_state["page3_topic_name_cleaned"]) + st.session_state["page3_warning_message"] = ( + "topic could not be empty" + ) + + st.session_state["page3_topic_name_cleaned"] = ( + st.session_state["page3_topic"] + .replace(" ", "_") + .replace("/", "_") + ) + st.session_state["page3_topic_name_truncated"] = truncate_filename( + st.session_state["page3_topic_name_cleaned"] + ) if not pass_appropriateness_check: st.session_state["page3_write_article_state"] = "not 
started" - alert = st.warning(st.session_state["page3_warning_message"], icon="⚠️") + alert = st.warning( + st.session_state["page3_warning_message"], icon="⚠️" + ) time.sleep(5) alert.empty() else: st.session_state["page3_write_article_state"] = "initiated" + def handle_initiated(): if st.session_state["page3_write_article_state"] == "initiated": current_working_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR") @@ -47,9 +68,12 @@ def handle_initiated(): st.session_state["page3_current_working_dir"] = current_working_dir st.session_state["page3_write_article_state"] = "pre_writing" + def handle_pre_writing(): - if st.session_state["page3_write_article_state"] == "pre_writing": - status = st.status("I am brain**STORM**ing now to research the topic. (This may take 2-3 minutes.)") + if st.session_state["page3_write_article_state"] == "pre_writing": + status = st.status( + "I am brain**STORM**ing now to research the topic. (This may take 2-3 minutes.)" + ) st_callback_handler = demo_util.StreamlitCallbackHandler(status) with status: # STORM main gen outline @@ -59,23 +83,37 @@ def handle_pre_writing(): do_generate_outline=True, do_generate_article=False, do_polish_article=False, - callback_handler=st_callback_handler + callback_handler=st_callback_handler, + ) + conversation_log_path = os.path.join( + st.session_state["page3_current_working_dir"], + st.session_state["page3_topic_name_truncated"], + "conversation_log.json", + ) + demo_util._display_persona_conversations( + DemoFileIOHelper.read_json_file(conversation_log_path) ) - conversation_log_path = os.path.join(st.session_state["page3_current_working_dir"], - st.session_state["page3_topic_name_truncated"], "conversation_log.json") - demo_util._display_persona_conversations(DemoFileIOHelper.read_json_file(conversation_log_path)) st.session_state["page3_write_article_state"] = "final_writing" status.update(label="brain**STORM**ing complete!", state="complete") + def handle_final_writing(): if st.session_state["page3_write_article_state"] == "final_writing": # polish final article with st.status( - "Now I will connect the information I found for your reference. (This may take 4-5 minutes.)") as status: - st.info('Now I will connect the information I found for your reference. (This may take 4-5 minutes.)') - st.session_state["runner"].run(topic=st.session_state["page3_topic"], do_research=False, - do_generate_outline=False, - do_generate_article=True, do_polish_article=True, remove_duplicate=False) + "Now I will connect the information I found for your reference. (This may take 4-5 minutes.)" + ) as status: + st.info( + "Now I will connect the information I found for your reference. 
(This may take 4-5 minutes.)" + ) + st.session_state["runner"].run( + topic=st.session_state["page3_topic"], + do_research=False, + do_generate_outline=False, + do_generate_article=True, + do_polish_article=True, + remove_duplicate=False, + ) # finish the session st.session_state["runner"].post_run() @@ -83,6 +121,7 @@ def handle_final_writing(): st.session_state["page3_write_article_state"] = "prepare_to_show_result" status.update(label="information snythesis complete!", state="complete") + def handle_prepare_to_show_result(): if st.session_state["page3_write_article_state"] == "prepare_to_show_result": _, show_result_col, _ = st.columns([4, 3, 4]) @@ -91,16 +130,23 @@ def handle_prepare_to_show_result(): st.session_state["page3_write_article_state"] = "completed" st.rerun() + def handle_completed(): - if st.session_state["page3_write_article_state"] == "completed": # display polished article current_working_dir_paths = DemoFileIOHelper.read_structure_to_dict( - st.session_state["page3_current_working_dir"]) - current_article_file_path_dict = current_working_dir_paths[st.session_state["page3_topic_name_truncated"]] - demo_util.display_article_page(selected_article_name=st.session_state["page3_topic_name_cleaned"], - selected_article_file_path_dict=current_article_file_path_dict, - show_title=True, show_main_article=True) + st.session_state["page3_current_working_dir"] + ) + current_article_file_path_dict = current_working_dir_paths[ + st.session_state["page3_topic_name_truncated"] + ] + demo_util.display_article_page( + selected_article_name=st.session_state["page3_topic_name_cleaned"], + selected_article_file_path_dict=current_article_file_path_dict, + show_title=True, + show_main_article=True, + ) + def create_new_article_page(): demo_util.clear_other_page_session_state(page_index=3) diff --git a/frontend/demo_light/pages_util/MyArticles.py b/frontend/demo_light/pages_util/MyArticles.py index e4e3bd11..3d9a62d2 100644 --- a/frontend/demo_light/pages_util/MyArticles.py +++ b/frontend/demo_light/pages_util/MyArticles.py @@ -11,7 +11,10 @@ def my_articles_page(): with st.sidebar: _, return_button_col = st.columns([2, 5]) with return_button_col: - if st.button("Select another article", disabled="page2_selected_my_article" not in st.session_state): + if st.button( + "Select another article", + disabled="page2_selected_my_article" not in st.session_state, + ): if "page2_selected_my_article" in st.session_state: del st.session_state["page2_selected_my_article"] st.rerun() @@ -20,17 +23,22 @@ def my_articles_page(): if "page2_user_articles_file_path_dict" not in st.session_state: local_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR") os.makedirs(local_dir, exist_ok=True) - st.session_state["page2_user_articles_file_path_dict"] = DemoFileIOHelper.read_structure_to_dict(local_dir) + st.session_state["page2_user_articles_file_path_dict"] = ( + DemoFileIOHelper.read_structure_to_dict(local_dir) + ) # if no feature demo selected, display all featured articles as info cards def article_card_setup(column_to_add, card_title, article_name): with column_to_add: cleaned_article_title = article_name.replace("_", " ") - hasClicked = card(title=" / ".join(card_title), - text=article_name.replace("_", " "), - image=DemoFileIOHelper.read_image_as_base64( - os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg")), - styles=DemoUIHelper.get_article_card_UI_style(boarder_color="#9AD8E1")) + hasClicked = card( + title=" / ".join(card_title), + text=article_name.replace("_", " "), + 
image=DemoFileIOHelper.read_image_as_base64( + os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg") + ), + styles=DemoUIHelper.get_article_card_UI_style(boarder_color="#9AD8E1"), + ) if hasClicked: st.session_state["page2_selected_my_article"] = article_name st.rerun() @@ -40,7 +48,9 @@ def article_card_setup(column_to_add, card_title, article_name): my_article_columns = st.columns(3) if len(st.session_state["page2_user_articles_file_path_dict"]) > 0: # get article names - article_names = sorted(list(st.session_state["page2_user_articles_file_path_dict"].keys())) + article_names = sorted( + list(st.session_state["page2_user_articles_file_path_dict"].keys()) + ) # configure pagination pagination = st.container() bottom_menu = st.columns((1, 4, 1, 1, 1))[1:-1] @@ -48,7 +58,9 @@ def article_card_setup(column_to_add, card_title, article_name): batch_size = st.selectbox("Page Size", options=[24, 48, 72]) with bottom_menu[1]: total_pages = ( - int(len(article_names) / batch_size) if int(len(article_names) / batch_size) > 0 else 1 + int(len(article_names) / batch_size) + if int(len(article_names) / batch_size) > 0 + else 1 ) current_page = st.number_input( "Page", min_value=1, max_value=total_pages, step=1 @@ -60,19 +72,24 @@ def article_card_setup(column_to_add, card_title, article_name): my_article_count = 0 start_index = (current_page - 1) * batch_size end_index = min(current_page * batch_size, len(article_names)) - for article_name in article_names[start_index: end_index]: + for article_name in article_names[start_index:end_index]: column_to_add = my_article_columns[my_article_count % 3] my_article_count += 1 - article_card_setup(column_to_add=column_to_add, - card_title=["My Article"], - article_name=article_name) + article_card_setup( + column_to_add=column_to_add, + card_title=["My Article"], + article_name=article_name, + ) else: with my_article_columns[0]: - hasClicked = card(title="Get started", - text="Start your first research!", - image=DemoFileIOHelper.read_image_as_base64( - os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg")), - styles=DemoUIHelper.get_article_card_UI_style()) + hasClicked = card( + title="Get started", + text="Start your first research!", + image=DemoFileIOHelper.read_image_as_base64( + os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg") + ), + styles=DemoUIHelper.get_article_card_UI_style(), + ) if hasClicked: st.session_state.selected_page = 1 st.session_state["manual_selection_override"] = True @@ -80,8 +97,13 @@ def article_card_setup(column_to_add, card_title, article_name): st.rerun() else: selected_article_name = st.session_state["page2_selected_my_article"] - selected_article_file_path_dict = st.session_state["page2_user_articles_file_path_dict"][selected_article_name] + selected_article_file_path_dict = st.session_state[ + "page2_user_articles_file_path_dict" + ][selected_article_name] - demo_util.display_article_page(selected_article_name=selected_article_name, - selected_article_file_path_dict=selected_article_file_path_dict, - show_title=True, show_main_article=True) + demo_util.display_article_page( + selected_article_name=selected_article_name, + selected_article_file_path_dict=selected_article_file_path_dict, + show_title=True, + show_main_article=True, + ) diff --git a/frontend/demo_light/stoc.py b/frontend/demo_light/stoc.py index 7bd4402b..2abd3c45 100644 --- a/frontend/demo_light/stoc.py +++ b/frontend/demo_light/stoc.py @@ -44,9 +44,9 @@ def toc(self, expander): for title_size, title in self.toc_items: h = 
int(title_size.replace("h", "")) markdown_toc += ( - " " * 2 * h - + "- " - + f' {title} \n' + " " * 2 * h + + "- " + + f' {title} \n' ) # st.sidebar.write(markdown_toc, unsafe_allow_html=True) st.write(markdown_toc, unsafe_allow_html=True) @@ -56,27 +56,35 @@ def get_toc(cls, markdown_text: str, topic=""): def increase_heading_depth_and_add_top_heading(markdown_text, new_top_heading): lines = markdown_text.splitlines() # Increase the depth of each heading by adding an extra '#' - increased_depth_lines = ['#' + line if line.startswith('#') else line for line in lines] + increased_depth_lines = [ + "#" + line if line.startswith("#") else line for line in lines + ] # Add the new top-level heading at the beginning increased_depth_lines.insert(0, f"# {new_top_heading}") # Re-join the modified lines back into a single string - modified_text = '\n'.join(increased_depth_lines) + modified_text = "\n".join(increased_depth_lines) return modified_text if topic: - markdown_text = increase_heading_depth_and_add_top_heading(markdown_text, topic) + markdown_text = increase_heading_depth_and_add_top_heading( + markdown_text, topic + ) toc = [] for line in markdown_text.splitlines(): - if line.startswith('#'): + if line.startswith("#"): # Remove the '#' characters and strip leading/trailing spaces - heading_text = line.lstrip('#').strip() + heading_text = line.lstrip("#").strip() # Create slug (lowercase, spaces to hyphens, remove non-alphanumeric characters) - slug = re.sub(r'[^a-zA-Z0-9\s-]', '', heading_text).lower().replace(' ', '-') + slug = ( + re.sub(r"[^a-zA-Z0-9\s-]", "", heading_text) + .lower() + .replace(" ", "-") + ) # Determine heading level for indentation - level = line.count('#') - 1 + level = line.count("#") - 1 # Add to the table of contents - toc.append(' ' * level + f'- [{heading_text}](#{slug})') - return '\n'.join(toc) + toc.append(" " * level + f"- [{heading_text}](#{slug})") + return "\n".join(toc) @classmethod def from_markdown(cls, text: str, expander=None): diff --git a/frontend/demo_light/storm.py b/frontend/demo_light/storm.py index 1d095c35..5aca259e 100644 --- a/frontend/demo_light/storm.py +++ b/frontend/demo_light/storm.py @@ -11,13 +11,13 @@ def main(): global database - st.set_page_config(layout='wide') + st.set_page_config(layout="wide") if "first_run" not in st.session_state: - st.session_state['first_run'] = True + st.session_state["first_run"] = True # set api keys from secrets - if st.session_state['first_run']: + if st.session_state["first_run"]: for key, value in st.secrets.items(): if type(value) == str: os.environ[key] = value @@ -31,20 +31,26 @@ def main(): st.session_state["rerun_requested"] = False st.rerun() - st.write('', unsafe_allow_html=True) + st.write( + "", unsafe_allow_html=True + ) menu_container = st.container() with menu_container: pages = ["My Articles", "Create New Article"] - styles={ - "container": {"padding": "0.2rem 0", - "background-color": "#22222200"}, - } - menu_selection = option_menu(None, pages, - icons=['house', 'search'], - menu_icon="cast", default_index=0, orientation="horizontal", - manual_select=st.session_state.selected_page, - styles=styles, - key='menu_selection') + styles = { + "container": {"padding": "0.2rem 0", "background-color": "#22222200"}, + } + menu_selection = option_menu( + None, + pages, + icons=["house", "search"], + menu_icon="cast", + default_index=0, + orientation="horizontal", + manual_select=st.session_state.selected_page, + styles=styles, + key="menu_selection", + ) if 
st.session_state.get("manual_selection_override", False): menu_selection = pages[st.session_state["selected_page"]] st.session_state["manual_selection_override"] = False diff --git a/knowledge_storm/__init__.py b/knowledge_storm/__init__.py index 10338956..b036b4e4 100644 --- a/knowledge_storm/__init__.py +++ b/knowledge_storm/__init__.py @@ -7,4 +7,4 @@ from .utils import * from .dataclass import * -__version__ = "1.0.1" +__version__ = "1.1.0" diff --git a/knowledge_storm/collaborative_storm/engine.py b/knowledge_storm/collaborative_storm/engine.py index 2e281f6d..193d684d 100644 --- a/knowledge_storm/collaborative_storm/engine.py +++ b/knowledge_storm/collaborative_storm/engine.py @@ -14,9 +14,10 @@ from .modules.expert_generation import GenerateExpertModule from .modules.warmstart_hierarchical_chat import WarmStartModule from ..dataclass import ConversationTurn, KnowledgeBase +from ..encoder import Encoder from ..interface import LMConfigs, Agent from ..logging_wrapper import LoggingWrapper -from ..lm import OpenAIModel, AzureOpenAIModel, TogetherClient +from ..lm import LitellmModel from ..rm import BingSearch @@ -45,27 +46,26 @@ def init( if lm_type and lm_type == "openai": openai_kwargs = { "api_key": os.getenv("OPENAI_API_KEY"), - "api_provider": "openai", "temperature": temperature, "top_p": top_p, "api_base": None, } - self.question_answering_lm = OpenAIModel( + self.question_answering_lm = LitellmModel( model="gpt-4o-2024-05-13", max_tokens=1000, **openai_kwargs ) - self.discourse_manage_lm = OpenAIModel( + self.discourse_manage_lm = LitellmModel( model="gpt-4o-2024-05-13", max_tokens=500, **openai_kwargs ) - self.utterance_polishing_lm = OpenAIModel( + self.utterance_polishing_lm = LitellmModel( model="gpt-4o-2024-05-13", max_tokens=2000, **openai_kwargs ) - self.warmstart_outline_gen_lm = OpenAIModel( + self.warmstart_outline_gen_lm = LitellmModel( model="gpt-4-1106-preview", max_tokens=500, **openai_kwargs ) - self.question_asking_lm = OpenAIModel( + self.question_asking_lm = LitellmModel( model="gpt-4o-2024-05-13", max_tokens=300, **openai_kwargs ) - self.knowledge_base_lm = OpenAIModel( + self.knowledge_base_lm = LitellmModel( model="gpt-4o-2024-05-13", max_tokens=1000, **openai_kwargs ) elif lm_type and lm_type == "azure": @@ -76,23 +76,23 @@ def init( "api_base": os.getenv("AZURE_API_BASE"), "api_version": os.getenv("AZURE_API_VERSION"), } - self.question_answering_lm = AzureOpenAIModel( - model="gpt-4o", max_tokens=1000, **azure_kwargs, model_type="chat" + self.question_answering_lm = LitellmModel( + model="azure/gpt-4o", max_tokens=1000, **azure_kwargs, model_type="chat" ) - self.discourse_manage_lm = AzureOpenAIModel( - model="gpt-4o", max_tokens=500, **azure_kwargs, model_type="chat" + self.discourse_manage_lm = LitellmModel( + model="azure/gpt-4o", max_tokens=500, **azure_kwargs, model_type="chat" ) - self.utterance_polishing_lm = AzureOpenAIModel( - model="gpt-4o", max_tokens=2000, **azure_kwargs, model_type="chat" + self.utterance_polishing_lm = LitellmModel( + model="azure/gpt-4o", max_tokens=2000, **azure_kwargs, model_type="chat" ) - self.warmstart_outline_gen_lm = AzureOpenAIModel( - model="gpt-4o", max_tokens=300, **azure_kwargs, model_type="chat" + self.warmstart_outline_gen_lm = LitellmModel( + model="azure/gpt-4o", max_tokens=300, **azure_kwargs, model_type="chat" ) - self.question_asking_lm = AzureOpenAIModel( - model="gpt-4o", max_tokens=300, **azure_kwargs, model_type="chat" + self.question_asking_lm = LitellmModel( + model="azure/gpt-4o", max_tokens=300, 
**azure_kwargs, model_type="chat" ) - self.knowledge_base_lm = AzureOpenAIModel( - model="gpt-4o", max_tokens=1000, **azure_kwargs, model_type="chat" + self.knowledge_base_lm = LitellmModel( + model="azure/gpt-4o", max_tokens=1000, **azure_kwargs, model_type="chat" ) elif lm_type and lm_type == "together": together_kwargs = { @@ -100,38 +100,38 @@ def init( "temperature": temperature, "top_p": top_p, } - self.question_answering_lm = TogetherClient( - model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + self.question_answering_lm = LitellmModel( + model="together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", max_tokens=1000, model_type="chat", **together_kwargs, ) - self.discourse_manage_lm = TogetherClient( - model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + self.discourse_manage_lm = LitellmModel( + model="together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", max_tokens=500, model_type="chat", **together_kwargs, ) - self.utterance_polishing_lm = TogetherClient( - model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + self.utterance_polishing_lm = LitellmModel( + model="together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", max_tokens=2000, model_type="chat", **together_kwargs, ) - self.warmstart_outline_gen_lm = TogetherClient( - model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + self.warmstart_outline_gen_lm = LitellmModel( + model="together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", max_tokens=500, model_type="chat", **together_kwargs, ) - self.question_asking_lm = TogetherClient( - model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + self.question_asking_lm = LitellmModel( + model="together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", max_tokens=300, model_type="chat", **together_kwargs, ) - self.knowledge_base_lm = TogetherClient( - model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + self.knowledge_base_lm = LitellmModel( + model="together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", max_tokens=1000, model_type="chat", **together_kwargs, @@ -323,6 +323,7 @@ def __init__( lm_config: CollaborativeStormLMConfigs, runner_argument: RunnerArgument, rm: dspy.Retrieve, + encoder: Encoder, callback_handler: BaseCallbackHandler, ): # parameter management @@ -331,6 +332,7 @@ def __init__( self.logging_wrapper = logging_wrapper self.callback_handler = callback_handler self.rm = rm + self.encoder = encoder # role management self.experts: List[CoStormExpert] = [] self.simulated_user: SimulatedUser = SimulatedUser( @@ -360,6 +362,7 @@ def __init__( lm_config=self.lm_config, runner_argument=self.runner_argument, logging_wrapper=self.logging_wrapper, + encoder=self.encoder, callback_handler=self.callback_handler, ) self.general_knowledge_provider = CoStormExpert( @@ -469,16 +472,16 @@ def get_next_turn_policy( elif self.runner_argument.rag_only_baseline_mode: assert self.conversation_history[-1].role == "Guest" next_turn_policy.agent = self.pure_rag_agent + elif self.next_turn_moderator_override: + next_turn_policy.agent = self.moderator + if not dry_run: + self.next_turn_moderator_override = False elif ( not self.runner_argument.disable_moderator and self._should_generate_question(conversation_history) ): next_turn_policy.agent = self.moderator next_turn_policy.should_reorganize_knowledge_base = True - elif self.next_turn_moderator_override: - next_turn_policy.agent = self.moderator - if not dry_run: - self.next_turn_moderator_override = False # experts RAG gen else: next_turn_policy.agent = self.general_knowledge_provider @@ -516,18 +519,21 @@ def 
__init__( self.rm = BingSearch(k=runner_argument.retrieve_top_k) else: self.rm = rm + self.encoder = Encoder() self.conversation_history = [] self.warmstart_conv_archive = [] self.knowledge_base = KnowledgeBase( topic=self.runner_argument.topic, knowledge_base_lm=self.lm_config.knowledge_base_lm, node_expansion_trigger_count=self.runner_argument.node_expansion_trigger_count, + encoder=self.encoder, ) self.discourse_manager = DiscourseManager( lm_config=self.lm_config, runner_argument=self.runner_argument, logging_wrapper=self.logging_wrapper, rm=self.rm, + encoder=self.encoder, callback_handler=callback_handler, ) @@ -546,7 +552,7 @@ def to_dict(self): } @classmethod - def from_dict(cls, data): + def from_dict(cls, data, callback_handler: BaseCallbackHandler = None): # FIXME: does not use the lm_config data but naively use default setting lm_config = CollaborativeStormLMConfigs() lm_config.init(lm_type=os.getenv("OPENAI_API_TYPE")) @@ -554,7 +560,9 @@ def from_dict(cls, data): lm_config=lm_config, runner_argument=RunnerArgument.from_dict(data["runner_argument"]), logging_wrapper=LoggingWrapper(lm_config), + callback_handler=callback_handler, ) + costorm_runner.encoder = Encoder() costorm_runner.conversation_history = [ ConversationTurn.from_dict(turn) for turn in data["conversation_history"] ] @@ -567,6 +575,7 @@ def from_dict(cls, data): data=data["knowledge_base"], knowledge_base_lm=costorm_runner.lm_config.knowledge_base_lm, node_expansion_trigger_count=costorm_runner.runner_argument.node_expansion_trigger_count, + encoder=costorm_runner.encoder, ) return costorm_runner @@ -591,11 +600,13 @@ def warm_start(self): callback_handler=self.callback_handler, ) - warmstart_conv, warmstart_revised_conv, warmstart_experts = ( - warm_start_module.initiate_warm_start( - topic=self.runner_argument.topic, - knowledge_base=self.knowledge_base, - ) + ( + warmstart_conv, + warmstart_revised_conv, + warmstart_experts, + ) = warm_start_module.initiate_warm_start( + topic=self.runner_argument.topic, + knowledge_base=self.knowledge_base, ) self.discourse_manager.experts = ( self.discourse_manager._parse_expert_names_to_agent( @@ -607,11 +618,14 @@ def warm_start(self): warmstart_revised_conv if warmstart_revised_conv else warmstart_conv ) self.warmstart_conv_archive = warmstart_conv - self.knowledge_base.reorganize() + self.knowledge_base.reogranize() else: if self.knowledge_base is None: self.knowledge_base = KnowledgeBase( - topic=self.runner_argument.topic + topic=self.runner_argument.topic, + knowledge_base_lm=self.lm_config.knowledge_base_lm, + node_expansion_trigger_count=self.runner_argument.node_expansion_trigger_count, + encoder=self.encoder, ) if self.conversation_history is None: self.conversation_history = [] @@ -633,7 +647,9 @@ def generate_report(self) -> str: Returns: str: A string representing the report, with "#" "##" indicating hierarchical sections and [1][2] indicating references. 
""" - with self.logging_wrapper.log_pipeline_stage("report generation stage"): + with self.logging_wrapper.log_pipeline_stage( + f"report generation after conv turn: {len(self.conversation_history)}" + ): with self.logging_wrapper.log_event( "report generation stage: generate report" ): @@ -741,5 +757,5 @@ def step( ): if self.callback_handler is not None: self.callback_handler.on_mindmap_reorg_start() - self.knowledge_base.reorganize() + self.knowledge_base.reogranize() return conv_turn diff --git a/knowledge_storm/collaborative_storm/modules/article_generation.py b/knowledge_storm/collaborative_storm/modules/article_generation.py index be614007..532b7467 100644 --- a/knowledge_storm/collaborative_storm/modules/article_generation.py +++ b/knowledge_storm/collaborative_storm/modules/article_generation.py @@ -21,7 +21,7 @@ def _get_cited_information_string( self, all_citation_index: Set[int], knowledge_base: KnowledgeBase, - max_words: int = 1500, + max_words: int = 4000, ): information = [] cur_word_count = 0 diff --git a/knowledge_storm/collaborative_storm/modules/co_storm_agents.py b/knowledge_storm/collaborative_storm/modules/co_storm_agents.py index e7f60299..dd42cb8c 100644 --- a/knowledge_storm/collaborative_storm/modules/co_storm_agents.py +++ b/knowledge_storm/collaborative_storm/modules/co_storm_agents.py @@ -13,7 +13,7 @@ from .grounded_question_generation import GroundedQuestionGenerationModule from .simulate_user import GenSimulatedUserUtterance from ...dataclass import ConversationTurn, KnowledgeBase -from ...encoder import get_text_embeddings +from ...encoder import Encoder from ...interface import Agent, Information, LMConfigs from ...logging_wrapper import LoggingWrapper @@ -174,6 +174,7 @@ def __init__( lm_config: LMConfigs, runner_argument: "RunnerArgument", logging_wrapper: LoggingWrapper, + encoder: Encoder, callback_handler: BaseCallbackHandler = None, ): super().__init__(topic, role_name, role_description) @@ -184,6 +185,7 @@ def __init__( engine=self.lm_config.question_asking_lm ) self.callback_handler = callback_handler + self.encoder = encoder def _get_conv_turn_unused_information( self, conv_turn: ConversationTurn, knowledge_base: KnowledgeBase @@ -211,19 +213,12 @@ def _get_conv_turn_unused_information( # extract snippets to get embeddings unused_information_snippets = [info.snippets[0] for info in unused_information] # get embeddings - cache = knowledge_base.embedding_cache - unused_snippets_embeddings, _ = get_text_embeddings( - unused_information_snippets, embedding_cache=cache, max_workers=100 - ) - claim_embedding, _ = get_text_embeddings( - conv_turn.claim_to_make, embedding_cache=cache - ) - query_embedding, _ = get_text_embeddings( - conv_turn.queries, embedding_cache=cache - ) - cited_snippets_embedding, _ = get_text_embeddings( - cited_snippets, embedding_cache=cache + unused_snippets_embeddings = self.encoder.encode( + unused_information_snippets, max_workers=20 ) + claim_embedding = self.encoder.encode(conv_turn.claim_to_make) + query_embedding = self.encoder.encode(conv_turn.queries) + cited_snippets_embedding = self.encoder.encode(cited_snippets) # calculate similarity query_similarities = cosine_similarity( unused_snippets_embeddings, query_embedding @@ -270,8 +265,7 @@ def _get_sorted_unused_snippets( ) batch_snippets.append(conv_turn.claim_to_make) batch_snippets.extend(conv_turn.queries) - cache = knowledge_base.embedding_cache - get_text_embeddings(batch_snippets, embedding_cache=cache, max_workers=300) + self.encoder.encode(batch_snippets, 
max_workers=20) # get sorted unused snippets for each turn sorted_snippets = [] diff --git a/knowledge_storm/collaborative_storm/modules/information_insertion_module.py b/knowledge_storm/collaborative_storm/modules/information_insertion_module.py index c858671b..05cbde31 100644 --- a/knowledge_storm/collaborative_storm/modules/information_insertion_module.py +++ b/knowledge_storm/collaborative_storm/modules/information_insertion_module.py @@ -9,7 +9,7 @@ from .collaborative_storm_utils import trim_output_after_hint from ...dataclass import KnowledgeNode, KnowledgeBase -from ...encoder import get_text_embeddings +from ...encoder import Encoder from ...interface import Information @@ -51,8 +51,9 @@ class InsertInformationCandidateChoice(dspy.Signature): class InsertInformationModule(dspy.Module): - def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], encoder: Encoder): self.engine = engine + self.encoder = encoder self.insert_info = dspy.ChainOfThought(InsertInformation) self.candidate_choosing = dspy.Predict(InsertInformationCandidateChoice) @@ -153,7 +154,7 @@ def _get_sorted_embed_sim_section( query: str, ): if encoded_outline is not None and encoded_outline.size > 0: - encoded_query, token_usage = get_text_embeddings(f"{question}, {query}") + encoded_query = self.encoder.encode(f"{question}, {query}") sim = cosine_similarity([encoded_query], encoded_outline)[0] sorted_indices = np.argsort(sim) sorted_outlines = np.array(outlines)[sorted_indices[::-1]] @@ -226,7 +227,6 @@ def forward( insert_root: Optional[KnowledgeNode] = None, skip_candidate_from_embedding: bool = False, ): - if not isinstance(information, List): information = [information] intent_to_placement_dict: Dict = self._info_list_to_intent_mapping( @@ -270,9 +270,10 @@ def insert_info_to_kb(info, placement_prediction): root=insert_root, ) - encoded_outlines, outlines = ( - knowledge_base.get_knowledge_base_structure_embedding(root=insert_root) - ) + ( + encoded_outlines, + outlines, + ) = knowledge_base.get_knowledge_base_structure_embedding(root=insert_root) to_return = [] if not allow_create_new_node: # use multi thread as knowledge base structure does not change @@ -295,10 +296,11 @@ def insert_info_to_kb(info, placement_prediction): else: # use sequential insert as knowledge base structure might change for question, query in intent_to_placement_dict: - encoded_outlines, outlines = ( - knowledge_base.get_knowledge_base_structure_embedding( - root=insert_root - ) + ( + encoded_outlines, + outlines, + ) = knowledge_base.get_knowledge_base_structure_embedding( + root=insert_root ) _, placement_prediction = process_intent(question=question, query=query) intent_to_placement_dict[(question, query)] = placement_prediction diff --git a/knowledge_storm/dataclass.py b/knowledge_storm/dataclass.py index 1eff9a64..53bd7e4e 100644 --- a/knowledge_storm/dataclass.py +++ b/knowledge_storm/dataclass.py @@ -4,7 +4,7 @@ import threading from typing import Set, Dict, List, Optional, Union, Tuple -from .encoder import get_text_embeddings +from .encoder import Encoder from .interface import Information @@ -310,6 +310,7 @@ def __init__( topic: str, knowledge_base_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], node_expansion_trigger_count: int, + encoder: Encoder, ): """ Initializes a KnowledgeBase instance. 
@@ -333,9 +334,10 @@ def __init__( ) self.topic: str = topic + self.encoder: Encoder = encoder self.information_insert_module = InsertInformationModule( - engine=knowledge_base_lm + engine=knowledge_base_lm, encoder=self.encoder ) self.expand_node_module = ExpandNodeModule( engine=knowledge_base_lm, @@ -353,7 +355,6 @@ def __init__( "encoded_structure": np.array([[]]), "structure_string": "", } - self.embedding_cache: Dict[str, np.ndarray] = {} self.info_uuid_to_info_dict: Dict[int, Information] = {} self.info_hash_to_uuid_dict: Dict[int, int] = {} self._lock = threading.Lock() @@ -375,11 +376,13 @@ def from_dict( data: Dict, knowledge_base_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], node_expansion_trigger_count: int, + encoder: Encoder, ): knowledge_base = cls( topic=data["topic"], knowledge_base_lm=knowledge_base_lm, node_expansion_trigger_count=node_expansion_trigger_count, + encoder=encoder, ) knowledge_base.root = KnowledgeNode.from_dict(data["tree"]) knowledge_base.info_hash_to_uuid_dict = { @@ -408,9 +411,7 @@ def get_knowledge_base_structure_embedding( cleaned_outline_strings = [ outline.replace(" -> ", ", ") for outline in outline_strings ] - encoded_outline, _ = get_text_embeddings( - cleaned_outline_strings, embedding_cache=self.embedding_cache - ) + encoded_outline = self.encoder.encode(cleaned_outline_strings) self.kb_embedding = { "hash": outline_string_hash, "encoded_structure": encoded_outline, @@ -545,7 +546,6 @@ def get_node_hierarchy_string( cited_indices: Optional[List[int]] = None, root: Optional[KnowledgeNode] = None, ) -> str: - def find_node_contain_index(node, index): """ Traverses the tree downward from the given node. @@ -825,7 +825,7 @@ def update_from_conv_turn( def get_knowledge_base_summary(self): return self.gen_summary_module(self) - def reorganize(self): + def reogranize(self): """ Reorganizes the knowledge base through two main processes: top-down expansion and bottom-up cleaning. 
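For orientation, a minimal sketch (not part of the patch) of how the refactored constructors above fit together now that `KnowledgeBase` takes an explicit `Encoder`; the topic, model name, and trigger count are illustrative values, not package defaults:

```python
# Hedged sketch: wiring the new Encoder into KnowledgeBase (signatures as introduced in this diff).
# The topic, model name, and node_expansion_trigger_count are illustrative.
from knowledge_storm.encoder import Encoder
from knowledge_storm.dataclass import KnowledgeBase
from knowledge_storm.lm import LitellmModel

encoder = Encoder(encoder_type="openai")  # reads OPENAI_API_KEY from the environment when api_key is not passed
kb_lm = LitellmModel(model="gpt-4o", max_tokens=1000)

knowledge_base = KnowledgeBase(
    topic="quantum error correction",
    knowledge_base_lm=kb_lm,
    node_expansion_trigger_count=10,
    encoder=encoder,  # now a required argument; the per-instance embedding_cache is gone
)
```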
diff --git a/knowledge_storm/encoder.py b/knowledge_storm/encoder.py index 01fc9725..bb50eedc 100644 --- a/knowledge_storm/encoder.py +++ b/knowledge_storm/encoder.py @@ -1,144 +1,178 @@ -import requests import os -from typing import List, Tuple, Union, Optional, Dict, Literal import numpy as np from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Tuple, Union, Optional, Dict, Literal +from pathlib import Path +try: + import warnings -class EmbeddingModel: - def __init__(self): - pass - - def get_embedding(self, text: str) -> Tuple[np.ndarray, int]: - raise Exception("Not implemented") - - -class OpenAIEmbeddingModel(EmbeddingModel): - def __init__(self, model: str = "text-embedding-3-small", api_key: str = None): - if not api_key: - api_key = os.getenv("OPENAI_API_KEY") - - self.url = "https://api.openai.com/v1/embeddings" - self.headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - } - self.model = model - - def get_embedding(self, text: str) -> Tuple[np.ndarray, int]: - data = {"input": text, "model": self.model} - - response = requests.post(self.url, headers=self.headers, json=data) - if response.status_code == 200: - data = response.json() - embedding = np.array(data["data"][0]["embedding"]) - token = data["usage"]["prompt_tokens"] - return embedding, token - else: - response.raise_for_status() - - -class TogetherEmbeddingModel: - def __init__(self, model: str = "BAAI/bge-large-en-v1.5", api_key: str = None): - import together - - self.model = model - if not api_key: - api_key = os.getenv("TOGETHER_API_KEY") - self.together_client = together.Together(api_key=api_key) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ: + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + import litellm - def get_embedding(self, text: str) -> Tuple[np.ndarray, int]: - response = self.together_client.embeddings.create(input=text, model=self.model) - return response.data[0].embedding, -1 + litellm.drop_params = True + litellm.telemetry = False + from litellm.caching.caching import Cache -class AzureOpenAIEmbeddingModel: - def __init__(self, model: str = "text-embedding-3-small", api_key: str = None): - from openai import AzureOpenAI + disk_cache_dir = os.path.join(Path.home(), ".storm_local_cache") + litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk") - self.model = model - if not api_key: - api_key = os.getenv("AZURE_API_KEY") +except ImportError: - self.client = AzureOpenAI( - api_key=api_key, - api_version=os.getenv("AZURE_API_VERSION"), - azure_endpoint=os.getenv("AZURE_API_BASE"), - ) + class LitellmPlaceholder: + def __getattr__(self, _): + raise ImportError( + "The LiteLLM package is not installed. Run `pip install litellm`." + ) - def get_embedding(self, text: str) -> Tuple[np.ndarray, int]: - response = self.client.embeddings.create(input=text, model=self.model) + litellm = LitellmPlaceholder() - embedding = np.array(response.data[0].embedding) - token = response.usage.prompt_tokens - return embedding, token - -def get_text_embeddings( - texts: Union[str, List[str]], - max_workers: int = 5, - embedding_cache: Optional[Dict[str, np.ndarray]] = None, -) -> Tuple[np.ndarray, int]: +class Encoder: """ - Get text embeddings using OpenAI's text-embedding-3-small model. - - Args: - texts (Union[str, List[str]]): A single text string or a list of text strings to embed. 
- max_workers (int): The maximum number of workers for parallel processing. - api_key (str): The API key for accessing OpenAI's services. - embedding_cache (Optional[Dict[str, np.ndarray]]): A cache to store previously computed embeddings. - - Returns: - Tuple[np.ndarray, int]: The 2D array of embeddings and the total token usage. + A wrapper class for the LiteLLM embedding model, designed to handle embedding + generation tasks efficiently. It supports parallel processing and local caching of + embedding results for improved performance. + + The Encoder utilizes the LiteLLM library to interact with various embedding models, + such as OpenAI and Azure embeddings. Users can specify the desired encoder type and + provide relevant API credentials during initialization. + + Features: + - Support for multiple embedding models (e.g., OpenAI, Azure). + - Parallel processing for faster embedding generation. + - Local disk caching to store and reuse embedding results. + - Total token usage tracking for cost monitoring. + + Note: + Refer to the LiteLLM documentation for details on supported embedding models: + https://docs.litellm.ai/docs/embedding/supported_embedding """ - embedding_model = None - encoder_type = os.getenv("ENCODER_API_TYPE") - if encoder_type and encoder_type == "openai": - embedding_model = OpenAIEmbeddingModel() - elif encoder_type and encoder_type == "azure": - embedding_model = AzureOpenAIEmbeddingModel() - elif encoder_type == encoder_type == "together": - embedding_model = TogetherEmbeddingModel() - else: - raise Exception( - "No valid encoder type is provided. Check /secrets.toml for the field ENCODER_API_TYPE" - ) - def fetch_embedding(text: str) -> Tuple[str, np.ndarray, int]: - if embedding_cache is not None and text in embedding_cache: - return ( - text, - embedding_cache[text], - 0, - ) # Returning 0 tokens since no API call is made - embedding, token_usage = embedding_model.get_embedding(text) + def __init__( + self, + encoder_type: Optional[str] = None, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + ): + """ + Initializes the Encoder with the appropriate embedding model. + + Args: + encoder_type (Optional[str]): Type of encoder ('openai' or 'azure'). + api_key (Optional[str]): API key for the encoder service. + api_base (Optional[str]): API base URL for the encoder service. + api_version (Optional[str]): API version for the encoder service. + """ + self.embedding_model_name = None + self.kargs = {} + self.total_token_usage = 0 + + # Initialize the appropriate embedding model + encoder_type = encoder_type or os.getenv("ENCODER_API_TYPE") + if not encoder_type: + raise ValueError("ENCODER_API_TYPE environment variable is not set.") + + if encoder_type.lower() == "openai": + self.embedding_model_name = "text-embedding-3-small" + self.kargs = {"api_key": api_key or os.getenv("OPENAI_API_KEY")} + elif encoder_type.lower() == "azure": + self.embedding_model_name = "azure/text-embedding-3-small" + self.kargs = { + "api_key": api_key or os.getenv("AZURE_API_KEY"), + "api_base": api_base or os.getenv("AZURE_API_BASE"), + "api_version": api_version or os.getenv("AZURE_API_VERSION"), + } + else: + raise ValueError( + f"Unsupported ENCODER_API_TYPE '{encoder_type}'. Supported types are 'openai' and 'azure'." + ) + + def get_total_token_usage(self, reset: bool = False) -> int: + """ + Retrieves the total token usage. + + Args: + reset (bool): If True, resets the total token usage counter after retrieval.
+ + Returns: + int: The total number of tokens used. + """ + token_usage = self.total_token_usage + if reset: + self.total_token_usage = 0 + return token_usage + + def encode(self, texts: Union[str, List[str]], max_workers: int = 5) -> np.ndarray: + """ + Public method to get embeddings for the given texts. + + Args: + texts (Union[str, List[str]]): A single text string or a list of text strings to embed. + max_workers (int): The maximum number of workers for parallel processing. + + Returns: + np.ndarray: The array of embeddings. + """ + return self._get_text_embeddings(texts, max_workers=max_workers) + + def _get_single_text_embedding(self, text): + response = litellm.embedding( + model=self.embedding_model_name, input=text, caching=True, **self.kargs + ) + embedding = response.data[0]["embedding"] + token_usage = response.get("usage", {}).get("total_tokens", 0) return text, embedding, token_usage - if isinstance(texts, str): - _, embedding, tokens = fetch_embedding(texts) - return np.array(embedding), tokens - - embeddings = [] - total_tokens = 0 - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = {executor.submit(fetch_embedding, text): text for text in texts} - - for future in as_completed(futures): - try: - text, embedding, tokens = future.result() - embeddings.append((text, embedding, tokens)) - total_tokens += tokens - except Exception as e: - print(f"An error occurred for text: {futures[future]}") - print(e) - - # Sort results to match the order of the input texts - embeddings.sort(key=lambda x: texts.index(x[0])) - if embedding_cache is not None: - for text, embedding, _ in embeddings: - embedding_cache[text] = embedding - embeddings = [result[1] for result in embeddings] - - return np.array(embeddings), total_tokens + def _get_text_embeddings( + self, + texts: Union[str, List[str]], + max_workers: int = 5, + ) -> np.ndarray: + """ + Get text embeddings via LiteLLM using the embedding model configured at initialization. + + Args: + texts (Union[str, List[str]]): A single text string or a list of text strings to embed. + max_workers (int): The maximum number of workers for parallel processing. + + Returns: + np.ndarray: The 2D array of embeddings (a 1D array when a single string is passed); token usage is accumulated in self.total_token_usage.
+ """ + + if isinstance(texts, str): + _, embedding, tokens = self._get_single_text_embedding(texts) + self.total_token_usage += tokens + return np.array(embedding) + + embeddings = [] + total_tokens = 0 + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(self._get_single_text_embedding, text): text + for text in texts + } + + for future in as_completed(futures): + try: + text, embedding, tokens = future.result() + embeddings.append((text, embedding, tokens)) + total_tokens += tokens + except Exception as e: + print(f"An error occurred for text: {futures[future]}") + print(e) + + # Sort results to match the order of the input texts + embeddings.sort(key=lambda x: texts.index(x[0])) + embeddings = [result[1] for result in embeddings] + self.total_token_usage += total_tokens + + return np.array(embeddings) diff --git a/knowledge_storm/interface.py b/knowledge_storm/interface.py index 5922602f..d246742e 100644 --- a/knowledge_storm/interface.py +++ b/knowledge_storm/interface.py @@ -473,7 +473,6 @@ def collect_and_reset_lm_usage(self): return model_name_to_usage def log(self): - return OrderedDict( { attr_name: getattr(self, attr_name).kwargs diff --git a/knowledge_storm/lm.py b/knowledge_storm/lm.py index 0cae49be..d5836f72 100644 --- a/knowledge_storm/lm.py +++ b/knowledge_storm/lm.py @@ -1,16 +1,20 @@ +import backoff +import dspy +import functools import logging import os import random +import requests import threading from typing import Optional, Literal, Any +import ujson +from pathlib import Path + -import backoff -import dspy -import requests from dsp import ERRORS, backoff_hdlr, giveup_hdlr from dsp.modules.hf import openai_to_hf from dsp.modules.hf_client import send_hftgi_request_v01_wrapped -from openai import OpenAI +from openai import OpenAI, AzureOpenAI from transformers import AutoTokenizer try: @@ -18,6 +22,256 @@ except ImportError: RateLimitError = None +############################ +# Code copied from https://github.com/stanfordnlp/dspy/blob/main/dspy/clients/lm.py on Sep 29, 2024 + +# try: +import warnings + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ: + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + import litellm + + litellm.drop_params = True + litellm.telemetry = False + +from litellm.caching.caching import Cache + +disk_cache_dir = os.path.join(Path.home(), ".storm_local_cache") +litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk") + +# except ImportError: + +# class LitellmPlaceholder: +# def __getattr__(self, _): +# raise ImportError( +# "The LiteLLM package is not installed. Run `pip install litellm`." +# ) + +# litellm = LitellmPlaceholder() +LM_LRU_CACHE_MAX_SIZE = 3000 + + +class LM: + def __init__( + self, + model, + model_type="chat", + temperature=0.0, + max_tokens=1000, + cache=True, + **kwargs, + ): + self.model = model + self.model_type = model_type + self.cache = cache + self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) + self.history = [] + + if "o1-" in model: + assert ( + max_tokens >= 5000 and temperature == 1.0 + ), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`" + + def __call__(self, prompt=None, messages=None, **kwargs): + # Build the request. 
+ cache = kwargs.pop("cache", self.cache) + messages = messages or [{"role": "user", "content": prompt}] + kwargs = {**self.kwargs, **kwargs} + + # Make the request and handle LRU & disk caching. + if self.model_type == "chat": + completion = cached_litellm_completion if cache else litellm_completion + else: + completion = ( + cached_litellm_text_completion if cache else litellm_text_completion + ) + + response = completion( + ujson.dumps(dict(model=self.model, messages=messages, **kwargs)) + ) + outputs = [ + c.message.content if hasattr(c, "message") else c["text"] + for c in response["choices"] + ] + + # Logging, with removed api key & where `cost` is None on cache hit. + kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")} + entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response) + entry = dict(**entry, outputs=outputs, usage=dict(response["usage"])) + entry = dict( + **entry, cost=response.get("_hidden_params", {}).get("response_cost") + ) + self.history.append(entry) + + return outputs + + def inspect_history(self, n: int = 1): + _inspect_history(self, n) + + +@functools.lru_cache(maxsize=LM_LRU_CACHE_MAX_SIZE) +def cached_litellm_completion(request): + return litellm_completion(request, cache={"no-cache": False, "no-store": False}) + + +def litellm_completion(request, cache={"no-cache": True, "no-store": True}): + kwargs = ujson.loads(request) + return litellm.completion(cache=cache, **kwargs) + + +@functools.lru_cache(maxsize=LM_LRU_CACHE_MAX_SIZE) +def cached_litellm_text_completion(request): + return litellm_text_completion( + request, cache={"no-cache": False, "no-store": False} + ) + + +def litellm_text_completion(request, cache={"no-cache": True, "no-store": True}): + kwargs = ujson.loads(request) + + # Extract the provider and model from the model string. + model = kwargs.pop("model").split("/", 1) + provider, model = model[0] if len(model) > 1 else "openai", model[-1] + + # Use the API key and base from the kwargs, or from the environment. + api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") + api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") + + # Build the prompt from the messages. + prompt = "\n\n".join( + [x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"] + ) + + return litellm.text_completion( + cache=cache, + model=f"text-completion-openai/{model}", + api_key=api_key, + api_base=api_base, + prompt=prompt, + **kwargs, + ) + + +def _green(text: str, end: str = "\n"): + return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end + + +def _red(text: str, end: str = "\n"): + return "\x1b[31m" + str(text) + "\x1b[0m" + end + + +def _inspect_history(lm, n: int = 1): + """Prints the last n prompts and their completions.""" + + for item in lm.history[-n:]: + messages = item["messages"] or [{"role": "user", "content": item["prompt"]}] + outputs = item["outputs"] + + print("\n\n\n") + for msg in messages: + print(_red(f"{msg['role'].capitalize()} message:")) + print(msg["content"].strip()) + print("\n") + + print(_red("Response:")) + print(_green(outputs[0].strip())) + + if len(outputs) > 1: + choices_text = f" \t (and {len(outputs)-1} other completions)" + print(_red(choices_text, end="")) + + print("\n\n\n") + + +############################ + + +class LitellmModel(LM): + """A wrapper class for LiteLLM. + + Check out https://docs.litellm.ai/docs/providers for usage details. 
+ """ + + def __init__( + self, + model: str = "openai/gpt-4o-mini", + api_key: Optional[str] = None, + model_type: Literal["chat", "text"] = "chat", + **kwargs, + ): + super().__init__(model=model, api_key=api_key, model_type=model_type, **kwargs) + self._token_usage_lock = threading.Lock() + self.prompt_tokens = 0 + self.completion_tokens = 0 + + def log_usage(self, response): + """Log the total tokens from the OpenAI API response.""" + usage_data = response.get("usage") + if usage_data: + with self._token_usage_lock: + self.prompt_tokens += usage_data.get("prompt_tokens", 0) + self.completion_tokens += usage_data.get("completion_tokens", 0) + + def get_usage_and_reset(self): + """Get the total tokens used and reset the token usage.""" + usage = { + self.model + or self.kwargs.get("model") + or self.kwargs.get("engine"): { + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + } + } + self.prompt_tokens = 0 + self.completion_tokens = 0 + + return usage + + def __call__(self, prompt=None, messages=None, **kwargs): + # Build the request. + cache = kwargs.pop("cache", self.cache) + messages = messages or [{"role": "user", "content": prompt}] + kwargs = {**self.kwargs, **kwargs} + + # Make the request and handle LRU & disk caching. + if self.model_type == "chat": + completion = cached_litellm_completion if cache else litellm_completion + else: + completion = ( + cached_litellm_text_completion if cache else litellm_text_completion + ) + + response = completion( + ujson.dumps(dict(model=self.model, messages=messages, **kwargs)) + ) + response_dict = response.json() + self.log_usage(response_dict) + outputs = [ + c.message.content if hasattr(c, "message") else c["text"] + for c in response["choices"] + ] + + # Logging, with removed api key & where `cost` is None on cache hit. + kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")} + entry = dict( + prompt=prompt, messages=messages, kwargs=kwargs, response=response_dict + ) + entry = dict(**entry, outputs=outputs, usage=dict(response_dict["usage"])) + entry = dict( + **entry, cost=response.get("_hidden_params", {}).get("response_cost") + ) + self.history.append(entry) + + return outputs + + +# ======================================================================== +# The following language model classes were deprecated after v1.1.0. +# They remain in this file for backward compatibility but will no longer be maintained. + class OpenAIModel(dspy.OpenAI): """A wrapper class for dspy.OpenAI.""" @@ -204,54 +458,139 @@ def __call__( return completions -class AzureOpenAIModel(dspy.AzureOpenAI): - """A wrapper class for dspy.AzureOpenAI.""" +class AzureOpenAIModel(dspy.LM): + """A wrapper class of Azure OpenAI endpoint. + + Note: param::model should match the deployment_id on your Azure platform. 
+ """ def __init__( self, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - model: str = "gpt-4o-mini", - api_key: Optional[str] = None, + azure_endpoint: str, + api_version: str, + model: str, + api_key: str, model_type: Literal["chat", "text"] = "chat", **kwargs, ): - super().__init__( - api_base=api_base, - api_version=api_version, - model=model, + super().__init__(model=model) + self._token_usage_lock = threading.Lock() + self.prompt_tokens = 0 + self.completion_tokens = 0 + self.model = model + self.provider = "azure" + self.model_type = model_type + + self.client = AzureOpenAI( + azure_endpoint=azure_endpoint, api_key=api_key, - model_type=model_type, - **kwargs, + api_version=api_version, ) - self._token_usage_lock = threading.Lock() self.prompt_tokens = 0 self.completion_tokens = 0 + self.kwargs = { + "model": model, + "temperature": 0.0, + "max_tokens": 150, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + "n": 1, + **kwargs, + } + + @backoff.on_exception( + backoff.expo, + ERRORS, + max_time=1000, + on_backoff=backoff_hdlr, + giveup=giveup_hdlr, + ) + def basic_request(self, prompt: str, **kwargs) -> Any: + kwargs = {**self.kwargs, **kwargs} + + try: + if self.model_type == "chat": + messages = [{"role": "user", "content": prompt}] + + response = self.client.chat.completions.create( + messages=messages, **kwargs + ) + else: + response = self.client.completions.create(prompt=prompt, **kwargs) + + self.log_usage(response) + + history_entry = { + "prompt": prompt, + "response": dict(response), + "kwargs": kwargs, + } + self.history.append(history_entry) + + return response + + except Exception as e: + logging.error(f"Error making request to Azure OpenAI: {str(e)}") + raise + + def _get_choice_text(self, choice: Any) -> str: + """Extract text from a choice object based on model type.""" + if self.model_type == "chat": + return choice.message.content + return choice.text + def log_usage(self, response): - """Log the total tokens from the OpenAI API response. - Override log_usage() in dspy.AzureOpenAI for tracking accumulated token usage. - """ - usage_data = response.get("usage") + """Log the total tokens from response.""" + usage_data = response.usage if usage_data: with self._token_usage_lock: - self.prompt_tokens += usage_data.get("prompt_tokens", 0) - self.completion_tokens += usage_data.get("completion_tokens", 0) + self.prompt_tokens += usage_data.prompt_tokens + self.completion_tokens += usage_data.completion_tokens def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.kwargs.get("model") - or self.kwargs.get("engine"): { + self.model: { "prompt_tokens": self.prompt_tokens, "completion_tokens": self.completion_tokens, } } self.prompt_tokens = 0 self.completion_tokens = 0 - return usage + def __call__( + self, + prompt: str, + only_completed: bool = True, + return_sorted: bool = False, + **kwargs, + ) -> list[str]: + """Get completions from Azure OpenAI. 
+ + Args: + prompt: The prompt to send to the model + only_completed: Only return completed responses + return_sorted: Sort completions by probability (not implemented) + **kwargs: Additional arguments to pass to the API + + Returns: + List of completion strings + """ + response = self.basic_request(prompt, **kwargs) + + choices = response.choices + completed_choices = [c for c in choices if c.finish_reason != "length"] + + if only_completed and completed_choices: + choices = completed_choices + + completions = [self._get_choice_text(c) for c in choices] + + return completions + class GroqModel(dspy.OpenAI): """A wrapper class for Groq API (https://console.groq.com/), compatible with dspy.OpenAI.""" @@ -942,3 +1281,6 @@ def __call__( completions.append(response.parts[0].text) return completions + + +# ======================================================================== diff --git a/knowledge_storm/rm.py b/knowledge_storm/rm.py index a9836275..563116fe 100644 --- a/knowledge_storm/rm.py +++ b/knowledge_storm/rm.py @@ -7,10 +7,6 @@ import requests from dsp import backoff_hdlr, giveup_hdlr -from langchain_huggingface import HuggingFaceEmbeddings -from langchain_qdrant import Qdrant -from qdrant_client import QdrantClient - from .utils import WebPageHelper @@ -199,6 +195,8 @@ def __init__( device: str = "mps", k: int = 3, ): + from langchain_huggingface import HuggingFaceEmbeddings + """ Params: collection_name: Name of the Qdrant collection. @@ -228,6 +226,8 @@ def __init__( self.qdrant = None def _check_collection(self): + from langchain_qdrant import Qdrant + """ Check if the Qdrant collection exists and create it if it does not. """ @@ -248,6 +248,8 @@ def _check_collection(self): ) def init_online_vector_db(self, url: str, api_key: str): + from qdrant_client import QdrantClient + """ Initialize the Qdrant client that is connected to an online vector store with the given URL and API key. @@ -269,6 +271,8 @@ def init_online_vector_db(self, url: str, api_key: str): raise ValueError(f"Error occurs when connecting to the server: {e}") def init_offline_vector_db(self, vector_store_path: str): + from qdrant_client import QdrantClient + """ Initialize the Qdrant client that is connected to an offline vector store with the given vector store folder path. 
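A practical note on the deferred imports just above: `knowledge_storm.rm` now imports cleanly without the Qdrant/LangChain extras, and those heavy dependencies are only pulled in when a `VectorRM` method actually runs. A hedged sketch follows; the constructor keyword names are assumptions inferred from the docstring shown above, not verified against the full signature:

```python
# Hedged sketch: VectorRM after the lazy-import change. qdrant-client, langchain-qdrant, and
# langchain-huggingface are only needed once these methods execute, not at package import time.
# Constructor keyword names are assumptions based on the docstring in this diff.
from knowledge_storm.rm import VectorRM

rm = VectorRM(collection_name="my_corpus", embedding_model="BAAI/bge-m3", device="cpu", k=3)
rm.init_offline_vector_db(vector_store_path="./vector_store")  # QdrantClient is imported here
```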
@@ -336,19 +340,20 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st class StanfordOvalArxivRM(dspy.Retrieve): """[Alpha] This retrieval class is for internal use only, not intended for the public.""" - def __init__(self, endpoint, k=3): + def __init__(self, endpoint, k=3, rerank=True): super().__init__(k=k) self.endpoint = endpoint self.usage = 0 + self.rerank = rerank def get_usage_and_reset(self): usage = self.usage self.usage = 0 - return {"CS224vArxivRM": usage} + return {"StanfordOvalArxivRM": usage} def _retrieve(self, query: str): - payload = {"query": query, "num_blocks": self.k} + payload = {"query": query, "num_blocks": self.k, "rerank": self.rerank} response = requests.post( self.endpoint, json=payload, headers={"Content-Type": "application/json"} @@ -356,16 +361,21 @@ def _retrieve(self, query: str): # Check if the request was successful if response.status_code == 200: - data = response.json()[0] + response_data_list = response.json()[0]["results"] results = [] - for i in range(len(data["title"])): + for response_data in response_data_list: result = { - "title": data["title"][i], - "url": data["title"][i], - "snippets": [data["text"][i]], - "description": "N/A", - "meta": {"section_title": data["full_section_title"][i]}, + "title": response_data["document_title"], + "url": response_data["url"], + "snippets": [response_data["content"]], + "description": response_data.get("description", "N/A"), + "meta": { + key: value + for key, value in response_data.items() + if key not in ["document_title", "url", "content"] + }, } + results.append(result) return results @@ -537,9 +547,7 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st snippets = [organic.get("snippet")] if self.ENABLE_EXTRA_SNIPPET_EXTRACTION: snippets.extend( - valid_url_to_snippets.get(url.strip("'"), {}).get( - "snippets", [] - ) + valid_url_to_snippets.get(url, {}).get("snippets", []) ) collected_results.append( { diff --git a/knowledge_storm/storm_wiki/engine.py b/knowledge_storm/storm_wiki/engine.py index 9a94acff..c698a9a2 100644 --- a/knowledge_storm/storm_wiki/engine.py +++ b/knowledge_storm/storm_wiki/engine.py @@ -14,7 +14,7 @@ from .modules.persona_generator import StormPersonaGenerator from .modules.storm_dataclass import StormInformationTable, StormArticle from ..interface import Engine, LMConfigs, Retriever -from ..lm import OpenAIModel, AzureOpenAIModel +from ..lm import LitellmModel from ..utils import FileIOHelper, makeStringRed, truncate_filename @@ -56,50 +56,49 @@ def init_openai_model( openai_kwargs = { "api_key": openai_api_key, - "api_provider": "openai", "temperature": temperature, "top_p": top_p, "api_base": None, } if openai_type and openai_type == "openai": - self.conv_simulator_lm = OpenAIModel( + self.conv_simulator_lm = LitellmModel( model="gpt-4o-mini-2024-07-18", max_tokens=500, **openai_kwargs ) - self.question_asker_lm = OpenAIModel( + self.question_asker_lm = LitellmModel( model="gpt-4o-mini-2024-07-18", max_tokens=500, **openai_kwargs ) # 1/12/2024: Update gpt-4 to gpt-4-1106-preview. (Currently keep the original setup when using azure.) 
- self.outline_gen_lm = OpenAIModel( + self.outline_gen_lm = LitellmModel( model="gpt-4-0125-preview", max_tokens=400, **openai_kwargs ) - self.article_gen_lm = OpenAIModel( + self.article_gen_lm = LitellmModel( model="gpt-4o-2024-05-13", max_tokens=700, **openai_kwargs ) - self.article_polish_lm = OpenAIModel( + self.article_polish_lm = LitellmModel( model="gpt-4o-2024-05-13", max_tokens=4000, **openai_kwargs ) elif openai_type and openai_type == "azure": - self.conv_simulator_lm = OpenAIModel( - model="gpt-4o-mini-2024-07-18", max_tokens=500, **openai_kwargs + self.conv_simulator_lm = LitellmModel( + model="azure/gpt-4o-mini-2024-07-18", max_tokens=500, **openai_kwargs ) - self.question_asker_lm = AzureOpenAIModel( - model="gpt-4o-mini-2024-07-18", + self.question_asker_lm = LitellmModel( + model="azure/gpt-4o-mini-2024-07-18", max_tokens=500, **azure_kwargs, model_type="chat", ) # use combination of openai and azure-openai as azure-openai does not support gpt-4 in standard deployment - self.outline_gen_lm = AzureOpenAIModel( - model="gpt-4o", max_tokens=400, **azure_kwargs, model_type="chat" + self.outline_gen_lm = LitellmModel( + model="azure/gpt-4o", max_tokens=400, **azure_kwargs, model_type="chat" ) - self.article_gen_lm = AzureOpenAIModel( - model="gpt-4o-mini-2024-07-18", + self.article_gen_lm = LitellmModel( + model="azure/gpt-4o-mini-2024-07-18", max_tokens=700, **azure_kwargs, model_type="chat", ) - self.article_polish_lm = AzureOpenAIModel( - model="gpt-4o-mini-2024-07-18", + self.article_polish_lm = LitellmModel( + model="azure/gpt-4o-mini-2024-07-18", max_tokens=4000, **azure_kwargs, model_type="chat", @@ -214,16 +213,16 @@ def run_knowledge_curation_module( ground_truth_url: str = "None", callback_handler: BaseCallbackHandler = None, ) -> StormInformationTable: - - information_table, conversation_log = ( - self.storm_knowledge_curation_module.research( - topic=self.topic, - ground_truth_url=ground_truth_url, - callback_handler=callback_handler, - max_perspective=self.args.max_perspective, - disable_perspective=False, - return_conversation_log=True, - ) + ( + information_table, + conversation_log, + ) = self.storm_knowledge_curation_module.research( + topic=self.topic, + ground_truth_url=ground_truth_url, + callback_handler=callback_handler, + max_perspective=self.args.max_perspective, + disable_perspective=False, + return_conversation_log=True, ) FileIOHelper.dump_json( @@ -240,7 +239,6 @@ def run_outline_generation_module( information_table: StormInformationTable, callback_handler: BaseCallbackHandler = None, ) -> StormArticle: - outline, draft_outline = self.storm_outline_generation_module.generate_outline( topic=self.topic, information_table=information_table, @@ -258,9 +256,8 @@ def run_article_generation_module( self, outline: StormArticle, information_table: StormInformationTable, callback_handler: BaseCallbackHandler = None, ) -> StormArticle: - draft_article = self.storm_article_generation.generate_article( topic=self.topic, information_table=information_table, @@ -279,7 +276,6 @@ def run_article_polishing_module( self, draft_article: StormArticle, remove_duplicate: bool = False ) -> StormArticle: - polished_article = self.storm_article_polishing_module.polish_article( topic=self.topic, draft_article=draft_article, diff --git a/knowledge_storm/storm_wiki/modules/article_generation.py b/knowledge_storm/storm_wiki/modules/article_generation.py index
a23b7886..6289c7aa 100644 --- a/knowledge_storm/storm_wiki/modules/article_generation.py +++ b/knowledge_storm/storm_wiki/modules/article_generation.py @@ -88,7 +88,6 @@ def generate_article( ) section_output_dict_collection = [section_output_dict] else: - with concurrent.futures.ThreadPoolExecutor( max_workers=self.max_thread_num ) as executor: diff --git a/knowledge_storm/storm_wiki/modules/storm_dataclass.py b/knowledge_storm/storm_wiki/modules/storm_dataclass.py index 75812d9c..119869cd 100644 --- a/knowledge_storm/storm_wiki/modules/storm_dataclass.py +++ b/knowledge_storm/storm_wiki/modules/storm_dataclass.py @@ -114,9 +114,7 @@ def prepare_table_for_retrieval(self): for snippet in information.snippets: self.collected_urls.append(url) self.collected_snippets.append(snippet) - self.encoded_snippets = self.encoder.encode( - self.collected_snippets, show_progress_bar=False - ) + self.encoded_snippets = self.encoder.encode(self.collected_snippets) def retrieve_information( self, queries: Union[List[str], str], search_top_k @@ -126,7 +124,7 @@ def retrieve_information( if type(queries) is str: queries = [queries] for query in queries: - encoded_query = self.encoder.encode(query, show_progress_bar=False) + encoded_query = self.encoder.encode(query) sim = cosine_similarity([encoded_query], self.encoded_snippets)[0] sorted_indices = np.argsort(sim) for i in sorted_indices[-search_top_k:][::-1]: diff --git a/knowledge_storm/utils.py b/knowledge_storm/utils.py index 2e3cbb65..4411b05e 100644 --- a/knowledge_storm/utils.py +++ b/knowledge_storm/utils.py @@ -1,4 +1,6 @@ import concurrent.futures +import dspy +import httpx import json import logging import os @@ -6,21 +8,14 @@ import re import regex import sys -import time +import toml from typing import List, Dict +from tqdm import tqdm -import httpx -import pandas as pd -import toml -from langchain_core.documents import Document -from langchain_huggingface import HuggingFaceEmbeddings -from langchain_qdrant import Qdrant from langchain_text_splitters import RecursiveCharacterTextSplitter -from qdrant_client import QdrantClient, models -from tqdm import tqdm from trafilatura import extract -from .lm import OpenAIModel +from .lm import LitellmModel logging.getLogger("httpx").setLevel(logging.WARNING) # Disable INFO logging for httpx. @@ -72,8 +67,11 @@ class QdrantVectorStoreManager: @staticmethod def _check_create_collection( - client: QdrantClient, collection_name: str, model: HuggingFaceEmbeddings + client: "QdrantClient", collection_name: str, model: "HuggingFaceEmbeddings" ): + from langchain_qdrant import Qdrant + from qdrant_client import models + """Check if the Qdrant collection exists and create it if it does not.""" if client is None: raise ValueError("Qdrant client is not initialized.") @@ -103,8 +101,10 @@ def _check_create_collection( @staticmethod def _init_online_vector_db( - url: str, api_key: str, collection_name: str, model: HuggingFaceEmbeddings + url: str, api_key: str, collection_name: str, model: "HuggingFaceEmbeddings" ): + from qdrant_client import QdrantClient + """Initialize the Qdrant client that is connected to an online vector store with the given URL and API key. 
Args: @@ -128,8 +128,10 @@ def _init_online_vector_db( @staticmethod def _init_offline_vector_db( - vector_store_path: str, collection_name: str, model: HuggingFaceEmbeddings + vector_store_path: str, collection_name: str, model: "HuggingFaceEmbeddings" ): + from qdrant_client import QdrantClient + """Initialize the Qdrant client that is connected to an offline vector store with the given vector store folder path. Args: @@ -164,6 +166,8 @@ def create_or_update_vector_store( embedding_model: str = "BAAI/bge-m3", device: str = "mps", ): + from qdrant_client import Document + """ Takes a CSV file and adds each row in the CSV file to the Qdrant collection. @@ -192,6 +196,8 @@ def create_or_update_vector_store( model_kwargs = {"device": device} encode_kwargs = {"normalize_embeddings": True} + from langchain_huggingface import HuggingFaceEmbeddings + model = HuggingFaceEmbeddings( model_name=embedding_model, model_kwargs=model_kwargs, @@ -231,6 +237,8 @@ def create_or_update_vector_store( raise ValueError("Qdrant client is not initialized.") # read the csv file + import pandas as pd + df = pd.read_csv(file_path) # check that content column exists and url column exists if content_column not in df.columns: @@ -704,10 +712,8 @@ def urls_to_snippets(self, urls: List[str]) -> Dict: def user_input_appropriateness_check(user_input): - my_openai_model = OpenAIModel( - api_key=os.getenv("OPENAI_API_KEY"), - api_provider="openai", - model="gpt-4o-mini-2024-07-18", + my_openai_model = LitellmModel( + model="azure/gpt-4o-mini", max_tokens=10, temperature=0.0, top_p=0.9, @@ -761,10 +767,8 @@ def user_input_appropriateness_check(user_input): def purpose_appropriateness_check(user_input): - my_openai_model = OpenAIModel( - api_key=os.getenv("OPENAI_API_KEY"), - api_provider="openai", - model="gpt-4o-mini-2024-07-18", + my_openai_model = LitellmModel( + model="azure/gpt-4o-mini", max_tokens=10, temperature=0.0, top_p=0.9, diff --git a/requirements.txt b/requirements.txt index c6c10f67..14e68cf2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,5 @@ langchain-huggingface qdrant-client langchain-qdrant numpy==1.26.4 +litellm==1.59.3 +diskcache \ No newline at end of file diff --git a/setup.py b/setup.py index 0890bb74..970159ad 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( name="knowledge-storm", - version="1.0.1", + version="1.1.0", author="Yijia Shao, Yucheng Jiang", author_email="shaoyj@stanford.edu, yuchengj@stanford.edu", description="STORM: A language model-powered knowledge curation engine.",
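Finally, to tie the dependency bumps above to the new runtime surface, a short, hedged smoke test of the litellm-backed v1.1.0 interfaces (the model name is illustrative; both the completion and embedding calls are disk-cached under `~/.storm_local_cache` via the `Cache(disk_cache_dir=..., type="disk")` setup added in `lm.py` and `encoder.py`):

```python
# Hedged smoke test for the v1.1.0 litellm integration; assumes OPENAI_API_KEY is set and
# uses an illustrative model name.
import os
from knowledge_storm.lm import LitellmModel
from knowledge_storm.encoder import Encoder

os.environ.setdefault("ENCODER_API_TYPE", "openai")  # required before constructing Encoder()

lm = LitellmModel(model="gpt-4o-mini", max_tokens=50)
print(lm("Reply with the single word: ready")[0])
print(lm.get_usage_and_reset())  # accumulated prompt/completion token counts, keyed by model name

encoder = Encoder()  # picks up ENCODER_API_TYPE and OPENAI_API_KEY from the environment
embedding = encoder.encode("knowledge curation")  # 1D numpy array for a single string
print(embedding.shape, encoder.get_total_token_usage())
```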