Skip to content

Commit 304c4bc

Browse files
committed
feat: parallel_check
1 parent 26ae3e3 commit 304c4bc

File tree

5 files changed

+115
-7
lines changed

5 files changed

+115
-7
lines changed

api/v1/chat.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def on_agent_update(agent_id: str, update: Dict[str, Any]):
160160
model_stages=model_config.models,
161161
on_progress=on_progress,
162162
on_agent_update=on_agent_update,
163+
enable_parallel_check=model_config.parallel_check,
163164
llm_params=llm_params,
164165
)
165166

@@ -242,6 +243,7 @@ def on_progress(event):
242243
model_stages=model_config.models,
243244
on_progress=on_progress,
244245
enable_planning=model_config.has_plan_mode,
246+
enable_parallel_check=model_config.parallel_check,
245247
llm_params=llm_params,
246248
)
247249

@@ -404,6 +406,7 @@ async def chat_completions(
404406
num_agents=model_config.num_agent,
405407
parallel_run_agent=model_config.parallel_run_agent,
406408
model_stages=model_config.models,
409+
enable_parallel_check=model_config.parallel_check,
407410
llm_params=llm_params,
408411
)
409412
result = await engine.run()
@@ -423,6 +426,7 @@ async def chat_completions(
423426
required_successful_verifications=model_config.required_verifications,
424427
model_stages=model_config.models,
425428
enable_planning=model_config.has_plan_mode,
429+
enable_parallel_check=model_config.parallel_check,
426430
llm_params=llm_params,
427431
)
428432
result = await engine.run()

config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(self, model_id: str, config: Dict[str, Any]):
2121
self.max_iterations = config.get("max_iterations", 30)
2222
self.required_verifications = config.get("required_verifications", 3)
2323
self.max_errors = config.get("max_errors_before_give_up", 10)
24+
self.parallel_check = config.get("parallel_check", False) # 并行验证模式
2425

2526
# UltraThink 配置
2627
self.num_agent = config.get("num_agent")

config.yaml.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ model:
3737
rpm: 10 # 每分钟限制请求数 (可选)
3838
max_iterations: 30 # 最大迭代次数
3939
required_verifications: 3 # 需要的成功验证次数
40+
parallel_check: true # 并行验证模式 (同时启动3个验证LLM调用)
4041
feature:
4142
vision: true # 视觉能力
4243
summary_think: true # 生成思维链摘要

engine/deep_think.py

Lines changed: 105 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
model_stages: Dict[str, str] = None,
4747
on_progress: Optional[Callable[[ProgressEvent], None]] = None,
4848
enable_planning: bool = False,
49+
enable_parallel_check: bool = False,
4950
llm_params: Optional[Dict[str, Any]] = None,
5051
):
5152
self.client = client
@@ -62,6 +63,7 @@ def __init__(
6263
self.on_progress = on_progress
6364
self.sources: List[Source] = []
6465
self.enable_planning = enable_planning
66+
self.enable_parallel_check = enable_parallel_check
6567
self.llm_params = llm_params or {}
6668
self._task = None # 用于存储当前任务,以便取消
6769

@@ -158,6 +160,93 @@ async def _verify_solution(
158160

159161
return {"bug_report": bug_report, "good_verify": good_verify}
160162

163+
async def _verify_solution_parallel(
164+
self,
165+
problem_statement: MessageContent,
166+
solution: str
167+
) -> Dict[str, str]:
168+
"""并行验证解决方案 - 同时启动required_verifications个验证LLM调用,全部通过才算成功"""
169+
detailed_solution = self._extract_detailed_solution(solution)
170+
# 提取文本用于构建提示词
171+
problem_text = extract_text_from_content(problem_statement)
172+
verification_prompt = build_verification_prompt(
173+
problem_text,
174+
detailed_solution
175+
)
176+
177+
num_checks = self.required_verifications
178+
self._emit("progress", {"message": f"Parallel verifying solution ({num_checks} concurrent checks)..."})
179+
180+
# 使用验证阶段的模型
181+
verification_model = self._get_model_for_stage("verification")
182+
183+
# 同时启动required_verifications个验证LLM调用
184+
verification_tasks = [
185+
self.client.generate_text(
186+
model=verification_model,
187+
system=VERIFICATION_SYSTEM_PROMPT,
188+
prompt=verification_prompt,
189+
**self.llm_params
190+
)
191+
for _ in range(num_checks)
192+
]
193+
194+
# 等待所有验证完成
195+
verification_outputs = await asyncio.gather(*verification_tasks)
196+
197+
# 检查每个验证结果
198+
check_tasks = []
199+
for verification_output in verification_outputs:
200+
check_prompt = (
201+
f'Response in "yes" or "no". Is the following statement saying the '
202+
f'solution is correct, or does not contain critical error or a major '
203+
f'justification gap?\n\n{verification_output}'
204+
)
205+
check_tasks.append(
206+
self.client.generate_text(
207+
model=verification_model,
208+
prompt=check_prompt,
209+
**self.llm_params
210+
)
211+
)
212+
213+
# 等待所有检查完成
214+
good_verifies = await asyncio.gather(*check_tasks)
215+
216+
# 统计通过的验证数量
217+
passed_count = sum(1 for gv in good_verifies if "yes" in gv.lower())
218+
219+
# 需要全部通过才算验证成功
220+
passed = passed_count == num_checks
221+
222+
# 收集 bug 报告(如果有的话)
223+
bug_reports = []
224+
if not passed:
225+
for i, (verification_output, good_verify) in enumerate(zip(verification_outputs, good_verifies)):
226+
if "yes" not in good_verify.lower():
227+
bug_report = self._extract_detailed_solution(
228+
verification_output,
229+
"Detailed Review",
230+
False
231+
)
232+
if bug_report:
233+
bug_reports.append(f"[Check {i+1}] {bug_report}")
234+
235+
combined_bug_report = "\n\n".join(bug_reports) if bug_reports else ""
236+
237+
# 返回综合结果
238+
good_verify_summary = f"yes (passed {passed_count}/{num_checks} checks)" if passed else f"no (passed {passed_count}/{num_checks} checks)"
239+
240+
return {
241+
"bug_report": combined_bug_report,
242+
"good_verify": good_verify_summary,
243+
"parallel_results": {
244+
"total_checks": num_checks,
245+
"passed_checks": passed_count,
246+
"individual_results": good_verifies
247+
}
248+
}
249+
161250
async def _initial_exploration(
162251
self,
163252
problem_statement: MessageContent,
@@ -242,11 +331,17 @@ async def _initial_exploration(
242331

243332
self._emit("solution", {"solution": improved_solution, "iteration": 0})
244333

245-
# 验证
246-
verification = await self._verify_solution(
247-
problem_statement,
248-
improved_solution
249-
)
334+
# 验证 - 根据配置选择串行或并行
335+
if self.enable_parallel_check:
336+
verification = await self._verify_solution_parallel(
337+
problem_statement,
338+
improved_solution
339+
)
340+
else:
341+
verification = await self._verify_solution(
342+
problem_statement,
343+
improved_solution
344+
)
250345

251346
self._emit("verification", {
252347
"passed": "yes" in verification["good_verify"].lower(),
@@ -400,8 +495,11 @@ async def run(self) -> DeepThinkResult:
400495
knowledge_enhanced=len(self.sources) > 0,
401496
)
402497

403-
# 再次验证
404-
verification = await self._verify_solution(self.problem_statement, solution)
498+
# 再次验证 - 根据配置选择串行或并行
499+
if self.enable_parallel_check:
500+
verification = await self._verify_solution_parallel(self.problem_statement, solution)
501+
else:
502+
verification = await self._verify_solution(self.problem_statement, solution)
405503
self._emit("verification", {
406504
"passed": "yes" in verification["good_verify"].lower(),
407505
"iteration": i + 1,

engine/ultra_think.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(
4343
model_stages: Dict[str, str] = None,
4444
on_progress: Optional[Callable[[ProgressEvent], None]] = None,
4545
on_agent_update: Optional[Callable[[str, Dict[str, Any]], None]] = None,
46+
enable_parallel_check: bool = False,
4647
llm_params: Optional[Dict[str, Any]] = None,
4748
):
4849
self.client = client
@@ -60,6 +61,7 @@ def __init__(
6061
self.model_stages = model_stages or {}
6162
self.on_progress = on_progress
6263
self.on_agent_update = on_agent_update
64+
self.enable_parallel_check = enable_parallel_check
6365
self.sources: List[Source] = []
6466
self.llm_params = llm_params or {}
6567

@@ -206,6 +208,7 @@ def agent_progress_handler(event: ProgressEvent):
206208
max_errors_before_give_up=self.max_errors,
207209
model_stages=self.model_stages,
208210
on_progress=agent_progress_handler,
211+
enable_parallel_check=self.enable_parallel_check,
209212
llm_params=self.llm_params,
210213
)
211214

@@ -340,6 +343,7 @@ def synthesis_progress_handler(event: ProgressEvent):
340343
max_errors_before_give_up=self.max_errors,
341344
model_stages=self.model_stages,
342345
on_progress=synthesis_progress_handler,
346+
enable_parallel_check=self.enable_parallel_check,
343347
llm_params=self.llm_params,
344348
)
345349

0 commit comments

Comments
 (0)