@@ -46,6 +46,7 @@ def __init__(
4646 model_stages : Dict [str , str ] = None ,
4747 on_progress : Optional [Callable [[ProgressEvent ], None ]] = None ,
4848 enable_planning : bool = False ,
49+ enable_parallel_check : bool = False ,
4950 llm_params : Optional [Dict [str , Any ]] = None ,
5051 ):
5152 self .client = client
@@ -62,6 +63,7 @@ def __init__(
6263 self .on_progress = on_progress
6364 self .sources : List [Source ] = []
6465 self .enable_planning = enable_planning
66+ self .enable_parallel_check = enable_parallel_check
6567 self .llm_params = llm_params or {}
6668 self ._task = None # 用于存储当前任务,以便取消
6769
@@ -158,6 +160,93 @@ async def _verify_solution(
158160
159161 return {"bug_report" : bug_report , "good_verify" : good_verify }
160162
163+ async def _verify_solution_parallel (
164+ self ,
165+ problem_statement : MessageContent ,
166+ solution : str
167+ ) -> Dict [str , str ]:
168+ """并行验证解决方案 - 同时启动required_verifications个验证LLM调用,全部通过才算成功"""
169+ detailed_solution = self ._extract_detailed_solution (solution )
170+ # 提取文本用于构建提示词
171+ problem_text = extract_text_from_content (problem_statement )
172+ verification_prompt = build_verification_prompt (
173+ problem_text ,
174+ detailed_solution
175+ )
176+
177+ num_checks = self .required_verifications
178+ self ._emit ("progress" , {"message" : f"Parallel verifying solution ({ num_checks } concurrent checks)..." })
179+
180+ # 使用验证阶段的模型
181+ verification_model = self ._get_model_for_stage ("verification" )
182+
183+ # 同时启动required_verifications个验证LLM调用
184+ verification_tasks = [
185+ self .client .generate_text (
186+ model = verification_model ,
187+ system = VERIFICATION_SYSTEM_PROMPT ,
188+ prompt = verification_prompt ,
189+ ** self .llm_params
190+ )
191+ for _ in range (num_checks )
192+ ]
193+
194+ # 等待所有验证完成
195+ verification_outputs = await asyncio .gather (* verification_tasks )
196+
197+ # 检查每个验证结果
198+ check_tasks = []
199+ for verification_output in verification_outputs :
200+ check_prompt = (
201+ f'Response in "yes" or "no". Is the following statement saying the '
202+ f'solution is correct, or does not contain critical error or a major '
203+ f'justification gap?\n \n { verification_output } '
204+ )
205+ check_tasks .append (
206+ self .client .generate_text (
207+ model = verification_model ,
208+ prompt = check_prompt ,
209+ ** self .llm_params
210+ )
211+ )
212+
213+ # 等待所有检查完成
214+ good_verifies = await asyncio .gather (* check_tasks )
215+
216+ # 统计通过的验证数量
217+ passed_count = sum (1 for gv in good_verifies if "yes" in gv .lower ())
218+
219+ # 需要全部通过才算验证成功
220+ passed = passed_count == num_checks
221+
222+ # 收集 bug 报告(如果有的话)
223+ bug_reports = []
224+ if not passed :
225+ for i , (verification_output , good_verify ) in enumerate (zip (verification_outputs , good_verifies )):
226+ if "yes" not in good_verify .lower ():
227+ bug_report = self ._extract_detailed_solution (
228+ verification_output ,
229+ "Detailed Review" ,
230+ False
231+ )
232+ if bug_report :
233+ bug_reports .append (f"[Check { i + 1 } ] { bug_report } " )
234+
235+ combined_bug_report = "\n \n " .join (bug_reports ) if bug_reports else ""
236+
237+ # 返回综合结果
238+ good_verify_summary = f"yes (passed { passed_count } /{ num_checks } checks)" if passed else f"no (passed { passed_count } /{ num_checks } checks)"
239+
240+ return {
241+ "bug_report" : combined_bug_report ,
242+ "good_verify" : good_verify_summary ,
243+ "parallel_results" : {
244+ "total_checks" : num_checks ,
245+ "passed_checks" : passed_count ,
246+ "individual_results" : good_verifies
247+ }
248+ }
249+
161250 async def _initial_exploration (
162251 self ,
163252 problem_statement : MessageContent ,
@@ -242,11 +331,17 @@ async def _initial_exploration(
242331
243332 self ._emit ("solution" , {"solution" : improved_solution , "iteration" : 0 })
244333
245- # 验证
246- verification = await self ._verify_solution (
247- problem_statement ,
248- improved_solution
249- )
334+ # 验证 - 根据配置选择串行或并行
335+ if self .enable_parallel_check :
336+ verification = await self ._verify_solution_parallel (
337+ problem_statement ,
338+ improved_solution
339+ )
340+ else :
341+ verification = await self ._verify_solution (
342+ problem_statement ,
343+ improved_solution
344+ )
250345
251346 self ._emit ("verification" , {
252347 "passed" : "yes" in verification ["good_verify" ].lower (),
@@ -400,8 +495,11 @@ async def run(self) -> DeepThinkResult:
400495 knowledge_enhanced = len (self .sources ) > 0 ,
401496 )
402497
403- # 再次验证
404- verification = await self ._verify_solution (self .problem_statement , solution )
498+ # 再次验证 - 根据配置选择串行或并行
499+ if self .enable_parallel_check :
500+ verification = await self ._verify_solution_parallel (self .problem_statement , solution )
501+ else :
502+ verification = await self ._verify_solution (self .problem_statement , solution )
405503 self ._emit ("verification" , {
406504 "passed" : "yes" in verification ["good_verify" ].lower (),
407505 "iteration" : i + 1 ,
0 commit comments